Implement 'GROUP BY' statement

This commit is contained in:
Pijus Kamandulis 2024-03-11 17:50:20 +02:00
parent 18edb925bf
commit b72bba86c8
6 changed files with 1120 additions and 942 deletions

View File

@ -8,6 +8,7 @@ type SelectStmt struct {
Count int
Parameters map[string]interface{}
OrderExpressions []OrderExpression
GroupBy []SelectItem
}
type Table struct {

View File

@ -63,6 +63,24 @@ func Test_Parse(t *testing.T) {
)
})
t.Run("Should parse SELECT with GROUP BY", func(t *testing.T) {
testQueryParse(
t,
`SELECT c.id, c["pk"] FROM c GROUP BY c.id, c.pk`,
parsers.SelectStmt{
SelectItems: []parsers.SelectItem{
{Path: []string{"c", "id"}},
{Path: []string{"c", "pk"}},
},
Table: parsers.Table{Value: "c"},
GroupBy: []parsers.SelectItem{
{Path: []string{"c", "id"}},
{Path: []string{"c", "pk"}},
},
},
)
})
t.Run("Should parse IN function", func(t *testing.T) {
testQueryParse(
t,

File diff suppressed because it is too large Load Diff

View File

@ -6,7 +6,7 @@ import "github.com/pikami/cosmium/parsers"
func makeSelectStmt(
columns, table,
whereClause interface{}, distinctClause interface{},
count interface{}, orderList interface{},
count interface{}, groupByClause interface{}, orderList interface{},
) (parsers.SelectStmt, error) {
selectStmt := parsers.SelectStmt{
SelectItems: columns.([]parsers.SelectItem),
@ -30,6 +30,10 @@ func makeSelectStmt(
selectStmt.OrderExpressions = orderExpressions
}
if groupByClause != nil {
selectStmt.GroupBy = groupByClause.([]parsers.SelectItem)
}
return selectStmt, nil
}
@ -149,8 +153,10 @@ SelectStmt <- Select ws
topClause:TopClause? ws columns:Selection ws
From ws table:TableName ws
whereClause:(ws Where ws condition:Condition { return condition, nil })?
groupByClause:(ws GroupBy ws columns:ColumnList { return columns, nil })?
orderByClause:OrderByClause? {
return makeSelectStmt(columns, table, whereClause, distinctClause, topClause, orderByClause)
return makeSelectStmt(columns, table, whereClause,
distinctClause, topClause, groupByClause, orderByClause)
}
DistinctClause <- "DISTINCT"i
@ -285,6 +291,8 @@ And <- "AND"i
Or <- "OR"i
GroupBy <- "GROUP"i ws "BY"i
OrderBy <- "ORDER"i ws "BY"i
ComparisonOperator <- ("=" / "!=" / "<" / "<=" / ">" / ">=") {

View File

@ -36,12 +36,20 @@ func Execute(query parsers.SelectStmt, data []RowType) []RowType {
ctx.orderBy(query.OrderExpressions, result)
}
// Apply group
isGroupSelect := query.GroupBy != nil && len(query.GroupBy) > 0
if isGroupSelect {
result = ctx.groupBy(query, result)
}
// Apply select
if !isGroupSelect {
selectedData := make([]RowType, 0)
for _, row := range result {
selectedData = append(selectedData, ctx.selectRow(query.SelectItems, row))
}
result = selectedData
}
// Apply distinct
if query.Distinct {
@ -182,6 +190,11 @@ func (c memoryExecutorContext) getFieldValue(field parsers.SelectItem, row RowTy
return typedValue.Value
}
rowValue := row
if array, isArray := row.([]RowType); isArray {
rowValue = array[0]
}
if field.Type == parsers.SelectItemTypeFunctionCall {
var typedValue parsers.FunctionCall
var ok bool
@ -192,82 +205,83 @@ func (c memoryExecutorContext) getFieldValue(field parsers.SelectItem, row RowTy
switch typedValue.Type {
case parsers.FunctionCallStringEquals:
return c.strings_StringEquals(typedValue.Arguments, row)
return c.strings_StringEquals(typedValue.Arguments, rowValue)
case parsers.FunctionCallContains:
return c.strings_Contains(typedValue.Arguments, row)
return c.strings_Contains(typedValue.Arguments, rowValue)
case parsers.FunctionCallEndsWith:
return c.strings_EndsWith(typedValue.Arguments, row)
return c.strings_EndsWith(typedValue.Arguments, rowValue)
case parsers.FunctionCallStartsWith:
return c.strings_StartsWith(typedValue.Arguments, row)
return c.strings_StartsWith(typedValue.Arguments, rowValue)
case parsers.FunctionCallConcat:
return c.strings_Concat(typedValue.Arguments, row)
return c.strings_Concat(typedValue.Arguments, rowValue)
case parsers.FunctionCallIndexOf:
return c.strings_IndexOf(typedValue.Arguments, row)
return c.strings_IndexOf(typedValue.Arguments, rowValue)
case parsers.FunctionCallToString:
return c.strings_ToString(typedValue.Arguments, row)
return c.strings_ToString(typedValue.Arguments, rowValue)
case parsers.FunctionCallUpper:
return c.strings_Upper(typedValue.Arguments, row)
return c.strings_Upper(typedValue.Arguments, rowValue)
case parsers.FunctionCallLower:
return c.strings_Lower(typedValue.Arguments, row)
return c.strings_Lower(typedValue.Arguments, rowValue)
case parsers.FunctionCallLeft:
return c.strings_Left(typedValue.Arguments, row)
return c.strings_Left(typedValue.Arguments, rowValue)
case parsers.FunctionCallLength:
return c.strings_Length(typedValue.Arguments, row)
return c.strings_Length(typedValue.Arguments, rowValue)
case parsers.FunctionCallLTrim:
return c.strings_LTrim(typedValue.Arguments, row)
return c.strings_LTrim(typedValue.Arguments, rowValue)
case parsers.FunctionCallReplace:
return c.strings_Replace(typedValue.Arguments, row)
return c.strings_Replace(typedValue.Arguments, rowValue)
case parsers.FunctionCallReplicate:
return c.strings_Replicate(typedValue.Arguments, row)
return c.strings_Replicate(typedValue.Arguments, rowValue)
case parsers.FunctionCallReverse:
return c.strings_Reverse(typedValue.Arguments, row)
return c.strings_Reverse(typedValue.Arguments, rowValue)
case parsers.FunctionCallRight:
return c.strings_Right(typedValue.Arguments, row)
return c.strings_Right(typedValue.Arguments, rowValue)
case parsers.FunctionCallRTrim:
return c.strings_RTrim(typedValue.Arguments, row)
return c.strings_RTrim(typedValue.Arguments, rowValue)
case parsers.FunctionCallSubstring:
return c.strings_Substring(typedValue.Arguments, row)
return c.strings_Substring(typedValue.Arguments, rowValue)
case parsers.FunctionCallTrim:
return c.strings_Trim(typedValue.Arguments, row)
return c.strings_Trim(typedValue.Arguments, rowValue)
case parsers.FunctionCallIsDefined:
return c.typeChecking_IsDefined(typedValue.Arguments, row)
return c.typeChecking_IsDefined(typedValue.Arguments, rowValue)
case parsers.FunctionCallIsArray:
return c.typeChecking_IsArray(typedValue.Arguments, row)
return c.typeChecking_IsArray(typedValue.Arguments, rowValue)
case parsers.FunctionCallIsBool:
return c.typeChecking_IsBool(typedValue.Arguments, row)
return c.typeChecking_IsBool(typedValue.Arguments, rowValue)
case parsers.FunctionCallIsFiniteNumber:
return c.typeChecking_IsFiniteNumber(typedValue.Arguments, row)
return c.typeChecking_IsFiniteNumber(typedValue.Arguments, rowValue)
case parsers.FunctionCallIsInteger:
return c.typeChecking_IsInteger(typedValue.Arguments, row)
return c.typeChecking_IsInteger(typedValue.Arguments, rowValue)
case parsers.FunctionCallIsNull:
return c.typeChecking_IsNull(typedValue.Arguments, row)
return c.typeChecking_IsNull(typedValue.Arguments, rowValue)
case parsers.FunctionCallIsNumber:
return c.typeChecking_IsNumber(typedValue.Arguments, row)
return c.typeChecking_IsNumber(typedValue.Arguments, rowValue)
case parsers.FunctionCallIsObject:
return c.typeChecking_IsObject(typedValue.Arguments, row)
return c.typeChecking_IsObject(typedValue.Arguments, rowValue)
case parsers.FunctionCallIsPrimitive:
return c.typeChecking_IsPrimitive(typedValue.Arguments, row)
return c.typeChecking_IsPrimitive(typedValue.Arguments, rowValue)
case parsers.FunctionCallIsString:
return c.typeChecking_IsString(typedValue.Arguments, row)
return c.typeChecking_IsString(typedValue.Arguments, rowValue)
case parsers.FunctionCallArrayConcat:
return c.array_Concat(typedValue.Arguments, row)
return c.array_Concat(typedValue.Arguments, rowValue)
case parsers.FunctionCallArrayLength:
return c.array_Length(typedValue.Arguments, row)
return c.array_Length(typedValue.Arguments, rowValue)
case parsers.FunctionCallArraySlice:
return c.array_Slice(typedValue.Arguments, row)
return c.array_Slice(typedValue.Arguments, rowValue)
case parsers.FunctionCallSetIntersect:
return c.set_Intersect(typedValue.Arguments, row)
return c.set_Intersect(typedValue.Arguments, rowValue)
case parsers.FunctionCallSetUnion:
return c.set_Union(typedValue.Arguments, row)
return c.set_Union(typedValue.Arguments, rowValue)
case parsers.FunctionCallIn:
return c.misc_In(typedValue.Arguments, row)
return c.misc_In(typedValue.Arguments, rowValue)
}
}
value := row
value := rowValue
if len(field.Path) > 1 {
for _, pathSegment := range field.Path[1:] {
if nestedValue, ok := value.(map[string]interface{}); ok {
@ -314,6 +328,47 @@ func (c memoryExecutorContext) orderBy(orderBy []parsers.OrderExpression, data [
sort.SliceStable(data, less)
}
// groupBy partitions data by the values of the GROUP BY fields of selectStmt
// and returns one aggregated row per partition.
//
// Groups are emitted in the order in which their key was first encountered
// in data; an empty input yields an empty (non-nil) slice.
func (c memoryExecutorContext) groupBy(selectStmt parsers.SelectStmt, data []RowType) []RowType {
	// buckets holds the member rows of every group. keyOrder records the
	// first-seen order of the keys, because map iteration order is random.
	buckets := make(map[string][]RowType)
	keyOrder := make([]string, 0)

	for _, row := range data {
		groupKey := c.generateGroupKey(selectStmt.GroupBy, row)
		members, seen := buckets[groupKey]
		if !seen {
			keyOrder = append(keyOrder, groupKey)
		}
		buckets[groupKey] = append(members, row)
	}

	// Collapse every bucket into a single output row.
	result := make([]RowType, 0)
	for _, groupKey := range keyOrder {
		result = append(result, c.aggregateGroup(selectStmt, buckets[groupKey]))
	}

	return result
}
// generateGroupKey builds the string key that identifies the group row
// belongs to: the values of groupByFields, resolved with getFieldValue and
// concatenated in order, each followed by ":".
//
// Values are rendered with %#v rather than %v. With %v the key was ambiguous:
// the tuples ("a:", "b") and ("a", ":b") both produced "a::b:", and the number
// 1 and the string "1" both produced "1:", so unrelated rows were merged into
// one group. %#v emits strings quoted and escaped, which keeps the separator
// unambiguous and keeps strings distinct from numbers, booleans and nil.
func (c memoryExecutorContext) generateGroupKey(groupByFields []parsers.SelectItem, row RowType) string {
	var keyBuilder strings.Builder

	for _, column := range groupByFields {
		fieldValue := c.getFieldValue(column, row)
		// Write straight into the builder instead of allocating an
		// intermediate string with Sprintf.
		fmt.Fprintf(&keyBuilder, "%#v:", fieldValue)
	}

	return keyBuilder.String()
}
// aggregateGroup reduces the rows of one group to a single output row by
// projecting the SELECT items of selectStmt over the group.
//
// The whole slice of group rows is handed to selectRow as the row value.
// NOTE(review): getFieldValue appears to resolve a slice row through its
// first element, so every projected field presumably comes from the first
// row of the group — confirm against selectRow.
func (c memoryExecutorContext) aggregateGroup(selectStmt parsers.SelectStmt, groupRows []RowType) RowType {
	return c.selectRow(selectStmt.SelectItems, groupRows)
}
func compareValues(val1, val2 interface{}) int {
if reflect.TypeOf(val1) != reflect.TypeOf(val2) {
return 1

View File

@ -59,6 +59,26 @@ func Test_Execute(t *testing.T) {
)
})
t.Run("Should execute SELECT with GROUP BY", func(t *testing.T) {
testQueryExecute(
t,
parsers.SelectStmt{
SelectItems: []parsers.SelectItem{
{Path: []string{"c", "pk"}},
},
Table: parsers.Table{Value: "c"},
GroupBy: []parsers.SelectItem{
{Path: []string{"c", "pk"}},
},
},
mockData,
[]memoryexecutor.RowType{
map[string]interface{}{"pk": 123},
map[string]interface{}{"pk": 456},
},
)
})
t.Run("Should execute IN function", func(t *testing.T) {
testQueryExecute(
t,