diff --git a/api/handlers/documents.go b/api/handlers/documents.go index faca57a..10bc0bc 100644 --- a/api/handlers/documents.go +++ b/api/handlers/documents.go @@ -10,6 +10,7 @@ import ( "github.com/gin-gonic/gin" apimodels "github.com/pikami/cosmium/api/api_models" "github.com/pikami/cosmium/internal/constants" + "github.com/pikami/cosmium/internal/converters" "github.com/pikami/cosmium/internal/datastore" "github.com/pikami/cosmium/internal/logger" "github.com/pikami/cosmium/parsers" @@ -378,20 +379,16 @@ func (h *Handlers) executeQueryDocuments(databaseId string, collectionId string, return nil, datastore.BadRequest } - collectionDocuments, status := h.dataStore.GetAllDocuments(databaseId, collectionId) + allDocumentsIterator, status := h.dataStore.GetDocumentIterator(databaseId, collectionId) if status != datastore.StatusOk { return nil, status } - // TODO: Investigate, this could cause unnecessary memory usage - covDocs := make([]memoryexecutor.RowType, 0) - for _, doc := range collectionDocuments { - covDocs = append(covDocs, map[string]interface{}(doc)) - } + rowsIterator := converters.NewDocumentToRowTypeIterator(allDocumentsIterator) if typedQuery, ok := parsedQuery.(parsers.SelectStmt); ok { typedQuery.Parameters = queryParameters - return memoryexecutor.ExecuteQuery(typedQuery, covDocs), datastore.StatusOk + return memoryexecutor.ExecuteQuery(typedQuery, rowsIterator), datastore.StatusOk } return nil, datastore.BadRequest diff --git a/internal/converters/document_to_rowtype.go b/internal/converters/document_to_rowtype.go new file mode 100644 index 0000000..5a343b0 --- /dev/null +++ b/internal/converters/document_to_rowtype.go @@ -0,0 +1,20 @@ +package converters + +import ( + "github.com/pikami/cosmium/internal/datastore" + memoryexecutor "github.com/pikami/cosmium/query_executors/memory_executor" +) + +type DocumentToRowTypeIterator struct { + documents datastore.DocumentIterator +} + +func NewDocumentToRowTypeIterator(documents 
datastore.DocumentIterator) *DocumentToRowTypeIterator { + return &DocumentToRowTypeIterator{ + documents: documents, + } +} + +func (di *DocumentToRowTypeIterator) Next() (memoryexecutor.RowType, datastore.DataStoreStatus) { + return di.documents.Next() +} diff --git a/internal/datastore/datastore.go b/internal/datastore/datastore.go index c742f5b..8645b7a 100644 --- a/internal/datastore/datastore.go +++ b/internal/datastore/datastore.go @@ -40,5 +40,4 @@ type DataStore interface { type DocumentIterator interface { Next() (Document, DataStoreStatus) - HasMore() bool } diff --git a/internal/datastore/map_datastore/array_document_iterator.go b/internal/datastore/map_datastore/array_document_iterator.go index c52ade2..060998d 100644 --- a/internal/datastore/map_datastore/array_document_iterator.go +++ b/internal/datastore/map_datastore/array_document_iterator.go @@ -15,7 +15,3 @@ func (i *ArrayDocumentIterator) Next() (datastore.Document, datastore.DataStoreS return i.documents[i.index], datastore.StatusOk } - -func (i *ArrayDocumentIterator) HasMore() bool { - return i.index < len(i.documents)-1 -} diff --git a/internal/datastore/models.go b/internal/datastore/models.go index 3dc3a58..4d8a5cd 100644 --- a/internal/datastore/models.go +++ b/internal/datastore/models.go @@ -15,6 +15,7 @@ const ( StatusNotFound = 2 Conflict = 3 BadRequest = 4 + IterEOF = 5 ) type TriggerOperation string diff --git a/query_executors/memory_executor/array_functions.go b/query_executors/memory_executor/array_functions.go index 43b2a27..6d60517 100644 --- a/query_executors/memory_executor/array_functions.go +++ b/query_executors/memory_executor/array_functions.go @@ -196,6 +196,10 @@ func (r rowContext) parseArray(argument interface{}) []interface{} { ex := r.resolveSelectItem(exItem) arrValue := reflect.ValueOf(ex) + if arrValue.Kind() == reflect.Invalid { + return nil + } + if arrValue.Kind() != reflect.Slice { logger.ErrorLn("parseArray got parameters of wrong type") return nil diff 
--git a/query_executors/memory_executor/array_iterator.go b/query_executors/memory_executor/array_iterator.go new file mode 100644 index 0000000..e476306 --- /dev/null +++ b/query_executors/memory_executor/array_iterator.go @@ -0,0 +1,27 @@ +package memoryexecutor + +import "github.com/pikami/cosmium/internal/datastore" + +type rowArrayIterator struct { + documents []rowContext + index int +} + +func NewRowArrayIterator(documents []rowContext) *rowArrayIterator { + return &rowArrayIterator{ + documents: documents, + index: -1, + } +} + +func (i *rowArrayIterator) Next() (rowContext, datastore.DataStoreStatus) { + i.index++ + if i.index >= len(i.documents) { + return rowContext{}, datastore.IterEOF + } + + row := i.documents[i.index] + i.documents[i.index] = rowContext{} // Help GC reclaim memory + + return row, datastore.StatusOk +} diff --git a/query_executors/memory_executor/common.go b/query_executors/memory_executor/common.go new file mode 100644 index 0000000..866ed38 --- /dev/null +++ b/query_executors/memory_executor/common.go @@ -0,0 +1,397 @@ +package memoryexecutor + +import ( + "fmt" + "reflect" + "strconv" + "strings" + + "github.com/pikami/cosmium/internal/datastore" + "github.com/pikami/cosmium/internal/logger" + "github.com/pikami/cosmium/parsers" +) + +type RowType interface{} +type rowContext struct { + tables map[string]RowType + parameters map[string]interface{} + grouppedRows []rowContext +} + +type rowIterator interface { + Next() (rowContext, datastore.DataStoreStatus) +} + +type rowTypeIterator interface { + Next() (RowType, datastore.DataStoreStatus) +} + +func resolveDestinationColumnName(selectItem parsers.SelectItem, itemIndex int, queryParameters map[string]interface{}) string { + if selectItem.Alias != "" { + return selectItem.Alias + } + + destinationName := fmt.Sprintf("$%d", itemIndex+1) + if len(selectItem.Path) > 0 { + destinationName = selectItem.Path[len(selectItem.Path)-1] + } + + if destinationName[0] == '@' { + destinationName 
= queryParameters[destinationName].(string) + } + + return destinationName +} + +func (r rowContext) resolveSelectItem(selectItem parsers.SelectItem) interface{} { + if selectItem.Type == parsers.SelectItemTypeArray { + return r.selectItem_SelectItemTypeArray(selectItem) + } + + if selectItem.Type == parsers.SelectItemTypeObject { + return r.selectItem_SelectItemTypeObject(selectItem) + } + + if selectItem.Type == parsers.SelectItemTypeConstant { + return r.selectItem_SelectItemTypeConstant(selectItem) + } + + if selectItem.Type == parsers.SelectItemTypeSubQuery { + return r.selectItem_SelectItemTypeSubQuery(selectItem) + } + + if selectItem.Type == parsers.SelectItemTypeFunctionCall { + if typedFunctionCall, ok := selectItem.Value.(parsers.FunctionCall); ok { + return r.selectItem_SelectItemTypeFunctionCall(typedFunctionCall) + } + + logger.ErrorLn("parsers.SelectItem has incorrect Value type (expected parsers.FunctionCall)") + return nil + } + + return r.selectItem_SelectItemTypeField(selectItem) +} + +func (r rowContext) selectItem_SelectItemTypeArray(selectItem parsers.SelectItem) interface{} { + arrayValue := make([]interface{}, 0) + for _, subSelectItem := range selectItem.SelectItems { + arrayValue = append(arrayValue, r.resolveSelectItem(subSelectItem)) + } + return arrayValue +} + +func (r rowContext) selectItem_SelectItemTypeObject(selectItem parsers.SelectItem) interface{} { + objectValue := make(map[string]interface{}) + for _, subSelectItem := range selectItem.SelectItems { + objectValue[subSelectItem.Alias] = r.resolveSelectItem(subSelectItem) + } + return objectValue +} + +func (r rowContext) selectItem_SelectItemTypeConstant(selectItem parsers.SelectItem) interface{} { + var typedValue parsers.Constant + var ok bool + if typedValue, ok = selectItem.Value.(parsers.Constant); !ok { + // TODO: Handle error + logger.ErrorLn("parsers.Constant has incorrect Value type") + } + + if typedValue.Type == parsers.ConstantTypeParameterConstant && + r.parameters 
!= nil { + if key, ok := typedValue.Value.(string); ok { + return r.parameters[key] + } + } + + return typedValue.Value +} + +func (r rowContext) selectItem_SelectItemTypeSubQuery(selectItem parsers.SelectItem) interface{} { + subQuery := selectItem.Value.(parsers.SelectStmt) + subQueryResult := executeQuery( + subQuery, + NewRowArrayIterator([]rowContext{r}), + ) + + if subQuery.Exists { + _, status := subQueryResult.Next() + return status == datastore.StatusOk + } + + allDocuments := make([]RowType, 0) + for { + row, status := subQueryResult.Next() + if status != datastore.StatusOk { + break + } + allDocuments = append(allDocuments, row) + } + + return allDocuments +} + +func (r rowContext) selectItem_SelectItemTypeFunctionCall(functionCall parsers.FunctionCall) interface{} { + switch functionCall.Type { + case parsers.FunctionCallStringEquals: + return r.strings_StringEquals(functionCall.Arguments) + case parsers.FunctionCallContains: + return r.strings_Contains(functionCall.Arguments) + case parsers.FunctionCallEndsWith: + return r.strings_EndsWith(functionCall.Arguments) + case parsers.FunctionCallStartsWith: + return r.strings_StartsWith(functionCall.Arguments) + case parsers.FunctionCallConcat: + return r.strings_Concat(functionCall.Arguments) + case parsers.FunctionCallIndexOf: + return r.strings_IndexOf(functionCall.Arguments) + case parsers.FunctionCallToString: + return r.strings_ToString(functionCall.Arguments) + case parsers.FunctionCallUpper: + return r.strings_Upper(functionCall.Arguments) + case parsers.FunctionCallLower: + return r.strings_Lower(functionCall.Arguments) + case parsers.FunctionCallLeft: + return r.strings_Left(functionCall.Arguments) + case parsers.FunctionCallLength: + return r.strings_Length(functionCall.Arguments) + case parsers.FunctionCallLTrim: + return r.strings_LTrim(functionCall.Arguments) + case parsers.FunctionCallReplace: + return r.strings_Replace(functionCall.Arguments) + case parsers.FunctionCallReplicate: + return 
r.strings_Replicate(functionCall.Arguments) + case parsers.FunctionCallReverse: + return r.strings_Reverse(functionCall.Arguments) + case parsers.FunctionCallRight: + return r.strings_Right(functionCall.Arguments) + case parsers.FunctionCallRTrim: + return r.strings_RTrim(functionCall.Arguments) + case parsers.FunctionCallSubstring: + return r.strings_Substring(functionCall.Arguments) + case parsers.FunctionCallTrim: + return r.strings_Trim(functionCall.Arguments) + + case parsers.FunctionCallIsDefined: + return r.typeChecking_IsDefined(functionCall.Arguments) + case parsers.FunctionCallIsArray: + return r.typeChecking_IsArray(functionCall.Arguments) + case parsers.FunctionCallIsBool: + return r.typeChecking_IsBool(functionCall.Arguments) + case parsers.FunctionCallIsFiniteNumber: + return r.typeChecking_IsFiniteNumber(functionCall.Arguments) + case parsers.FunctionCallIsInteger: + return r.typeChecking_IsInteger(functionCall.Arguments) + case parsers.FunctionCallIsNull: + return r.typeChecking_IsNull(functionCall.Arguments) + case parsers.FunctionCallIsNumber: + return r.typeChecking_IsNumber(functionCall.Arguments) + case parsers.FunctionCallIsObject: + return r.typeChecking_IsObject(functionCall.Arguments) + case parsers.FunctionCallIsPrimitive: + return r.typeChecking_IsPrimitive(functionCall.Arguments) + case parsers.FunctionCallIsString: + return r.typeChecking_IsString(functionCall.Arguments) + + case parsers.FunctionCallArrayConcat: + return r.array_Concat(functionCall.Arguments) + case parsers.FunctionCallArrayContains: + return r.array_Contains(functionCall.Arguments) + case parsers.FunctionCallArrayContainsAny: + return r.array_Contains_Any(functionCall.Arguments) + case parsers.FunctionCallArrayContainsAll: + return r.array_Contains_All(functionCall.Arguments) + case parsers.FunctionCallArrayLength: + return r.array_Length(functionCall.Arguments) + case parsers.FunctionCallArraySlice: + return r.array_Slice(functionCall.Arguments) + case 
parsers.FunctionCallSetIntersect: + return r.set_Intersect(functionCall.Arguments) + case parsers.FunctionCallSetUnion: + return r.set_Union(functionCall.Arguments) + + case parsers.FunctionCallMathAbs: + return r.math_Abs(functionCall.Arguments) + case parsers.FunctionCallMathAcos: + return r.math_Acos(functionCall.Arguments) + case parsers.FunctionCallMathAsin: + return r.math_Asin(functionCall.Arguments) + case parsers.FunctionCallMathAtan: + return r.math_Atan(functionCall.Arguments) + case parsers.FunctionCallMathCeiling: + return r.math_Ceiling(functionCall.Arguments) + case parsers.FunctionCallMathCos: + return r.math_Cos(functionCall.Arguments) + case parsers.FunctionCallMathCot: + return r.math_Cot(functionCall.Arguments) + case parsers.FunctionCallMathDegrees: + return r.math_Degrees(functionCall.Arguments) + case parsers.FunctionCallMathExp: + return r.math_Exp(functionCall.Arguments) + case parsers.FunctionCallMathFloor: + return r.math_Floor(functionCall.Arguments) + case parsers.FunctionCallMathIntBitNot: + return r.math_IntBitNot(functionCall.Arguments) + case parsers.FunctionCallMathLog10: + return r.math_Log10(functionCall.Arguments) + case parsers.FunctionCallMathRadians: + return r.math_Radians(functionCall.Arguments) + case parsers.FunctionCallMathRound: + return r.math_Round(functionCall.Arguments) + case parsers.FunctionCallMathSign: + return r.math_Sign(functionCall.Arguments) + case parsers.FunctionCallMathSin: + return r.math_Sin(functionCall.Arguments) + case parsers.FunctionCallMathSqrt: + return r.math_Sqrt(functionCall.Arguments) + case parsers.FunctionCallMathSquare: + return r.math_Square(functionCall.Arguments) + case parsers.FunctionCallMathTan: + return r.math_Tan(functionCall.Arguments) + case parsers.FunctionCallMathTrunc: + return r.math_Trunc(functionCall.Arguments) + case parsers.FunctionCallMathAtn2: + return r.math_Atn2(functionCall.Arguments) + case parsers.FunctionCallMathIntAdd: + return 
r.math_IntAdd(functionCall.Arguments) + case parsers.FunctionCallMathIntBitAnd: + return r.math_IntBitAnd(functionCall.Arguments) + case parsers.FunctionCallMathIntBitLeftShift: + return r.math_IntBitLeftShift(functionCall.Arguments) + case parsers.FunctionCallMathIntBitOr: + return r.math_IntBitOr(functionCall.Arguments) + case parsers.FunctionCallMathIntBitRightShift: + return r.math_IntBitRightShift(functionCall.Arguments) + case parsers.FunctionCallMathIntBitXor: + return r.math_IntBitXor(functionCall.Arguments) + case parsers.FunctionCallMathIntDiv: + return r.math_IntDiv(functionCall.Arguments) + case parsers.FunctionCallMathIntMod: + return r.math_IntMod(functionCall.Arguments) + case parsers.FunctionCallMathIntMul: + return r.math_IntMul(functionCall.Arguments) + case parsers.FunctionCallMathIntSub: + return r.math_IntSub(functionCall.Arguments) + case parsers.FunctionCallMathPower: + return r.math_Power(functionCall.Arguments) + case parsers.FunctionCallMathLog: + return r.math_Log(functionCall.Arguments) + case parsers.FunctionCallMathNumberBin: + return r.math_NumberBin(functionCall.Arguments) + case parsers.FunctionCallMathPi: + return r.math_Pi() + case parsers.FunctionCallMathRand: + return r.math_Rand() + + case parsers.FunctionCallAggregateAvg: + return r.aggregate_Avg(functionCall.Arguments) + case parsers.FunctionCallAggregateCount: + return r.aggregate_Count(functionCall.Arguments) + case parsers.FunctionCallAggregateMax: + return r.aggregate_Max(functionCall.Arguments) + case parsers.FunctionCallAggregateMin: + return r.aggregate_Min(functionCall.Arguments) + case parsers.FunctionCallAggregateSum: + return r.aggregate_Sum(functionCall.Arguments) + + case parsers.FunctionCallIn: + return r.misc_In(functionCall.Arguments) + } + + logger.Errorf("Unknown function call type: %v", functionCall.Type) + return nil +} + +func (r rowContext) selectItem_SelectItemTypeField(selectItem parsers.SelectItem) interface{} { + value := r.tables[selectItem.Path[0]] 
+ + if len(selectItem.Path) > 1 { + for _, pathSegment := range selectItem.Path[1:] { + if pathSegment[0] == '@' { + pathSegment = r.parameters[pathSegment].(string) + } + + switch nestedValue := value.(type) { + case map[string]interface{}: + value = nestedValue[pathSegment] + case map[string]RowType: + value = nestedValue[pathSegment] + case datastore.Document: + value = nestedValue[pathSegment] + case map[string]datastore.Document: + value = nestedValue[pathSegment] + case []int, []string, []interface{}: + slice := reflect.ValueOf(nestedValue) + if arrayIndex, err := strconv.Atoi(pathSegment); err == nil && slice.Len() > arrayIndex { + value = slice.Index(arrayIndex).Interface() + } else { + return nil + } + default: + return nil + } + } + } + + return value +} + +func compareValues(val1, val2 interface{}) int { + if val1 == nil && val2 == nil { + return 0 + } else if val1 == nil { + return -1 + } else if val2 == nil { + return 1 + } + + if reflect.TypeOf(val1) != reflect.TypeOf(val2) { + return 1 + } + + switch val1 := val1.(type) { + case int: + val2 := val2.(int) + if val1 < val2 { + return -1 + } else if val1 > val2 { + return 1 + } + return 0 + case float64: + val2 := val2.(float64) + if val1 < val2 { + return -1 + } else if val1 > val2 { + return 1 + } + return 0 + case string: + val2 := val2.(string) + return strings.Compare(val1, val2) + case bool: + val2 := val2.(bool) + if val1 == val2 { + return 0 + } else if val1 { + return 1 + } else { + return -1 + } + // TODO: Add more types + default: + if reflect.DeepEqual(val1, val2) { + return 0 + } + return 1 + } +} + +func copyMap[T RowType | []RowType](originalMap map[string]T) map[string]T { + targetMap := make(map[string]T) + + for k, v := range originalMap { + targetMap[k] = v + } + + return targetMap +} diff --git a/query_executors/memory_executor/distinct_iterator.go b/query_executors/memory_executor/distinct_iterator.go new file mode 100644 index 0000000..a5c95fe --- /dev/null +++ 
b/query_executors/memory_executor/distinct_iterator.go @@ -0,0 +1,36 @@ +package memoryexecutor + +import "github.com/pikami/cosmium/internal/datastore" + +type distinctIterator struct { + documents rowTypeIterator + seenDocs []RowType +} + +func (di *distinctIterator) Next() (RowType, datastore.DataStoreStatus) { + if di.documents == nil { + return rowContext{}, datastore.IterEOF + } + + for { + row, status := di.documents.Next() + if status != datastore.StatusOk { + di.documents = nil + return rowContext{}, status + } + + if !di.seen(row) { + di.seenDocs = append(di.seenDocs, row) + return row, status + } + } +} + +func (di *distinctIterator) seen(row RowType) bool { + for _, seenRow := range di.seenDocs { + if compareValues(seenRow, row) == 0 { + return true + } + } + return false +} diff --git a/query_executors/memory_executor/filter_iterator.go b/query_executors/memory_executor/filter_iterator.go new file mode 100644 index 0000000..ed4a68b --- /dev/null +++ b/query_executors/memory_executor/filter_iterator.go @@ -0,0 +1,143 @@ +package memoryexecutor + +import ( + "github.com/pikami/cosmium/internal/datastore" + "github.com/pikami/cosmium/internal/logger" + "github.com/pikami/cosmium/parsers" +) + +type filterIterator struct { + documents rowIterator + filters interface{} +} + +func (fi *filterIterator) Next() (rowContext, datastore.DataStoreStatus) { + if fi.documents == nil { + return rowContext{}, datastore.IterEOF + } + + for { + row, status := fi.documents.Next() + if status != datastore.StatusOk { + fi.documents = nil + return rowContext{}, status + } + + if fi.evaluateFilters(row) { + return row, status + } + } +} + +func (fi *filterIterator) evaluateFilters(row rowContext) bool { + if fi.filters == nil { + return true + } + + switch typedFilters := fi.filters.(type) { + case parsers.ComparisonExpression: + return row.filters_ComparisonExpression(typedFilters) + case parsers.LogicalExpression: + return row.filters_LogicalExpression(typedFilters) + case 
parsers.Constant: + if value, ok := typedFilters.Value.(bool); ok { + return value + } + return false + case parsers.SelectItem: + resolvedValue := row.resolveSelectItem(typedFilters) + if value, ok := resolvedValue.(bool); ok { + if typedFilters.Invert { + return !value + } + + return value + } + } + + return false +} + +func (r rowContext) applyFilters(filters interface{}) bool { + if filters == nil { + return true + } + + switch typedFilters := filters.(type) { + case parsers.ComparisonExpression: + return r.filters_ComparisonExpression(typedFilters) + case parsers.LogicalExpression: + return r.filters_LogicalExpression(typedFilters) + case parsers.Constant: + if value, ok := typedFilters.Value.(bool); ok { + return value + } + return false + case parsers.SelectItem: + resolvedValue := r.resolveSelectItem(typedFilters) + if value, ok := resolvedValue.(bool); ok { + if typedFilters.Invert { + return !value + } + + return value + } + } + + return false +} + +func (r rowContext) filters_ComparisonExpression(expression parsers.ComparisonExpression) bool { + leftExpression, leftExpressionOk := expression.Left.(parsers.SelectItem) + rightExpression, rightExpressionOk := expression.Right.(parsers.SelectItem) + + if !leftExpressionOk || !rightExpressionOk { + logger.ErrorLn("ComparisonExpression has incorrect Left or Right type") + return false + } + + leftValue := r.resolveSelectItem(leftExpression) + rightValue := r.resolveSelectItem(rightExpression) + + cmp := compareValues(leftValue, rightValue) + switch expression.Operation { + case "=": + return cmp == 0 + case "!=": + return cmp != 0 + case "<": + return cmp < 0 + case ">": + return cmp > 0 + case "<=": + return cmp <= 0 + case ">=": + return cmp >= 0 + } + + return false +} + +func (r rowContext) filters_LogicalExpression(expression parsers.LogicalExpression) bool { + var result bool + for i, subExpression := range expression.Expressions { + expressionResult := r.applyFilters(subExpression) + if i == 0 { + 
result = expressionResult + } + + switch expression.Operation { + case parsers.LogicalExpressionTypeAnd: + result = result && expressionResult + if !result { + return false + } + case parsers.LogicalExpressionTypeOr: + result = result || expressionResult + if result { + return true + } + } + } + return result +} diff --git a/query_executors/memory_executor/from_iterator.go b/query_executors/memory_executor/from_iterator.go new file mode 100644 index 0000000..5b5d58a --- /dev/null +++ b/query_executors/memory_executor/from_iterator.go @@ -0,0 +1,73 @@ +package memoryexecutor + +import ( + "github.com/pikami/cosmium/internal/datastore" + "github.com/pikami/cosmium/parsers" +) + +type fromIterator struct { + documents rowIterator + table parsers.Table + buffer []rowContext + bufferIndex int +} + +func (fi *fromIterator) Next() (rowContext, datastore.DataStoreStatus) { + if fi.documents == nil { + return rowContext{}, datastore.IterEOF + } + + // Return from buffer if available + if fi.bufferIndex < len(fi.buffer) { + result := fi.buffer[fi.bufferIndex] + fi.buffer[fi.bufferIndex] = rowContext{} + fi.bufferIndex++ + return result, datastore.StatusOk + } + + // Resolve next row from documents + row, status := fi.documents.Next() + if status != datastore.StatusOk { + fi.documents = nil + return row, status + } + + if fi.table.SelectItem.Path != nil || fi.table.SelectItem.Type == parsers.SelectItemTypeSubQuery { + destinationTableName := fi.table.SelectItem.Alias + if destinationTableName == "" { + destinationTableName = fi.table.Value + } + if destinationTableName == "" { + destinationTableName = resolveDestinationColumnName(fi.table.SelectItem, 0, row.parameters) + } + + if fi.table.IsInSelect || fi.table.SelectItem.Type == parsers.SelectItemTypeSubQuery { + selectValue := row.parseArray(fi.table.SelectItem) + rowContexts := make([]rowContext, len(selectValue)) + for i, newRowData := range selectValue { + rowContexts[i].parameters = row.parameters + 
rowContexts[i].tables = copyMap(row.tables) + rowContexts[i].tables[destinationTableName] = newRowData + } + + fi.buffer = rowContexts + fi.bufferIndex = 0 + return fi.Next() + } + + if len(fi.table.SelectItem.Path) > 0 { + sourceTableName := fi.table.SelectItem.Path[0] + sourceTableData := row.tables[sourceTableName] + if sourceTableData == nil { + // When source table is not found, assume it's root document + row.tables[sourceTableName] = row.tables["$root"] + } + } + + newRowData := row.resolveSelectItem(fi.table.SelectItem) + row.tables[destinationTableName] = newRowData + return row, status + } + + return row, status +} diff --git a/query_executors/memory_executor/groupBy_iterator.go b/query_executors/memory_executor/groupBy_iterator.go new file mode 100644 index 0000000..5e132b3 --- /dev/null +++ b/query_executors/memory_executor/groupBy_iterator.go @@ -0,0 +1,69 @@ +package memoryexecutor + +import ( + "fmt" + "strings" + + "github.com/pikami/cosmium/internal/datastore" + "github.com/pikami/cosmium/parsers" +) + +type groupByIterator struct { + documents rowIterator + groupBy []parsers.SelectItem + groupedRows []rowContext +} + +func (gi *groupByIterator) Next() (rowContext, datastore.DataStoreStatus) { + if gi.groupedRows != nil { + if len(gi.groupedRows) == 0 { + return rowContext{}, datastore.IterEOF + } + row := gi.groupedRows[0] + gi.groupedRows = gi.groupedRows[1:] + return row, datastore.StatusOk + } + + documents := make([]rowContext, 0) + for { + row, status := gi.documents.Next() + if status != datastore.StatusOk { + break + } + + documents = append(documents, row) + } + gi.documents = nil + + groupedRows := make(map[string][]rowContext) + groupedKeys := make([]string, 0) + + for _, row := range documents { + key := row.generateGroupByKey(gi.groupBy) + if _, ok := groupedRows[key]; !ok { + groupedKeys = append(groupedKeys, key) + } + groupedRows[key] = append(groupedRows[key], row) + } + + gi.groupedRows = make([]rowContext, 0) + for _, key := 
range groupedKeys { + gi.groupedRows = append(gi.groupedRows, rowContext{ + tables: groupedRows[key][0].tables, + parameters: groupedRows[key][0].parameters, + grouppedRows: groupedRows[key], + }) + } + + return gi.Next() +} + +func (r rowContext) generateGroupByKey(groupBy []parsers.SelectItem) string { + var keyBuilder strings.Builder + for _, selectItem := range groupBy { + value := r.resolveSelectItem(selectItem) + keyBuilder.WriteString(fmt.Sprintf("%v", value)) + keyBuilder.WriteString(":") + } + return keyBuilder.String() +} diff --git a/query_executors/memory_executor/join_iterator.go b/query_executors/memory_executor/join_iterator.go new file mode 100644 index 0000000..a747420 --- /dev/null +++ b/query_executors/memory_executor/join_iterator.go @@ -0,0 +1,62 @@ +package memoryexecutor + +import ( + "github.com/pikami/cosmium/internal/datastore" + "github.com/pikami/cosmium/parsers" +) + +type joinIterator struct { + documents rowIterator + query parsers.SelectStmt + buffer []rowContext +} + +func (ji *joinIterator) Next() (rowContext, datastore.DataStoreStatus) { + if ji.documents == nil { + return rowContext{}, datastore.IterEOF + } + + if len(ji.buffer) > 0 { + row := ji.buffer[0] + ji.buffer = ji.buffer[1:] + return row, datastore.StatusOk + } + + doc, status := ji.documents.Next() + if status != datastore.StatusOk { + ji.documents = nil + return rowContext{}, status + } + + ji.buffer = []rowContext{doc} + for _, joinItem := range ji.query.JoinItems { + nextDocuments := make([]rowContext, 0) + for _, row := range ji.buffer { + joinedItems := row.resolveJoinItemSelect(joinItem.SelectItem) + for _, joinedItem := range joinedItems { + tablesCopy := copyMap(row.tables) + tablesCopy[joinItem.Table.Value] = joinedItem + nextDocuments = append(nextDocuments, rowContext{ + parameters: row.parameters, + tables: tablesCopy, + }) + } + } + ji.buffer = nextDocuments + } + + return ji.Next() +} + +func (r rowContext) resolveJoinItemSelect(selectItem 
parsers.SelectItem) []RowType { + if selectItem.Path != nil || selectItem.Type == parsers.SelectItemTypeSubQuery { + selectValue := r.parseArray(selectItem) + documents := make([]RowType, len(selectValue)) + for i, newRowData := range selectValue { + documents[i] = newRowData + } + return documents + } + + return []RowType{} +} diff --git a/query_executors/memory_executor/limit_iterator.go b/query_executors/memory_executor/limit_iterator.go new file mode 100644 index 0000000..7667e14 --- /dev/null +++ b/query_executors/memory_executor/limit_iterator.go @@ -0,0 +1,19 @@ +package memoryexecutor + +import "github.com/pikami/cosmium/internal/datastore" + +type limitIterator struct { + documents rowTypeIterator + limit int + count int +} + +func (li *limitIterator) Next() (RowType, datastore.DataStoreStatus) { + if li.count >= li.limit { + li.documents = nil + return rowContext{}, datastore.IterEOF + } + + li.count++ + return li.documents.Next() +} diff --git a/query_executors/memory_executor/memory_executor.go b/query_executors/memory_executor/memory_executor.go index 6c5c460..89fd362 100644 --- a/query_executors/memory_executor/memory_executor.go +++ b/query_executors/memory_executor/memory_executor.go @@ -1,752 +1,92 @@ package memoryexecutor import ( - "fmt" - "reflect" - "sort" - "strconv" - "strings" - - "github.com/pikami/cosmium/internal/logger" + "github.com/pikami/cosmium/internal/datastore" "github.com/pikami/cosmium/parsers" - "golang.org/x/exp/slices" ) -type RowType interface{} -type rowContext struct { - tables map[string]RowType - parameters map[string]interface{} - grouppedRows []rowContext +func ExecuteQuery(query parsers.SelectStmt, documents rowTypeIterator) []RowType { + resultIter := executeQuery(query, &rowTypeToRowContextIterator{documents: documents, query: query}) + result := make([]RowType, 0) + for { + row, status := resultIter.Next() + if status != datastore.StatusOk { + break + } + + result = append(result, row) + } + return result } -func 
ExecuteQuery(query parsers.SelectStmt, documents []RowType) []RowType { - currentDocuments := make([]rowContext, 0) - for _, doc := range documents { - currentDocuments = append(currentDocuments, resolveFrom(query, doc)...) +func executeQuery(query parsers.SelectStmt, documents rowIterator) rowTypeIterator { + // Resolve FROM + var iter rowIterator = &fromIterator{ + documents: documents, + table: query.Table, } - // Handle JOINS - nextDocuments := make([]rowContext, 0) - for _, currentDocument := range currentDocuments { - rowContexts := currentDocument.handleJoin(query) - nextDocuments = append(nextDocuments, rowContexts...) - } - currentDocuments = nextDocuments - - // Apply filters - nextDocuments = make([]rowContext, 0) - for _, currentDocument := range currentDocuments { - if currentDocument.applyFilters(query.Filters) { - nextDocuments = append(nextDocuments, currentDocument) + // Apply JOIN + if len(query.JoinItems) > 0 { + iter = &joinIterator{ + documents: iter, + query: query, } } - currentDocuments = nextDocuments - // Apply order + // Apply WHERE + if query.Filters != nil { + iter = &filterIterator{ + documents: iter, + filters: query.Filters, + } + } + + // Apply ORDER BY if len(query.OrderExpressions) > 0 { - applyOrder(currentDocuments, query.OrderExpressions) + iter = &orderIterator{ + documents: iter, + orderExpressions: query.OrderExpressions, + } } - // Apply group by + // Apply GROUP BY if len(query.GroupBy) > 0 { - currentDocuments = applyGroupBy(currentDocuments, query.GroupBy) + iter = &groupByIterator{ + documents: iter, + groupBy: query.GroupBy, + } } - // Apply select - projectedDocuments := applyProjection(currentDocuments, query.SelectItems, query.GroupBy) + // Apply SELECT + var projectedIterator rowTypeIterator = &projectIterator{ + documents: iter, + selectItems: query.SelectItems, + groupBy: query.GroupBy, + } - // Apply distinct + // Apply DISTINCT if query.Distinct { - projectedDocuments = deduplicate(projectedDocuments) + 
projectedIterator = &distinctIterator{ + documents: projectedIterator, + } } - // Apply offset + // Apply OFFSET if query.Offset > 0 { - if query.Offset < len(projectedDocuments) { - projectedDocuments = projectedDocuments[query.Offset:] - } else { - projectedDocuments = []RowType{} + projectedIterator = &offsetIterator{ + documents: projectedIterator, + offset: query.Offset, } } - // Apply result limit - if query.Count > 0 && len(projectedDocuments) > query.Count { - projectedDocuments = projectedDocuments[:query.Count] + // Apply LIMIT + if query.Count > 0 { + projectedIterator = &limitIterator{ + documents: projectedIterator, + limit: query.Count, + } } - return projectedDocuments -} - -func resolveFrom(query parsers.SelectStmt, doc RowType) []rowContext { - initialRow, gotParentContext := doc.(rowContext) - if !gotParentContext { - var initialTableName string - if query.Table.SelectItem.Type == parsers.SelectItemTypeSubQuery { - initialTableName = query.Table.SelectItem.Value.(parsers.SelectStmt).Table.Value - } - - if initialTableName == "" { - initialTableName = query.Table.Value - } - - if initialTableName == "" { - initialTableName = resolveDestinationColumnName(query.Table.SelectItem, 0, query.Parameters) - } - - initialRow = rowContext{ - parameters: query.Parameters, - tables: map[string]RowType{ - initialTableName: doc, - "$root": doc, - }, - } - } - - if query.Table.SelectItem.Path != nil || query.Table.SelectItem.Type == parsers.SelectItemTypeSubQuery { - destinationTableName := query.Table.SelectItem.Alias - if destinationTableName == "" { - destinationTableName = query.Table.Value - } - if destinationTableName == "" { - destinationTableName = resolveDestinationColumnName(query.Table.SelectItem, 0, initialRow.parameters) - } - - if query.Table.IsInSelect || query.Table.SelectItem.Type == parsers.SelectItemTypeSubQuery { - selectValue := initialRow.parseArray(query.Table.SelectItem) - rowContexts := make([]rowContext, len(selectValue)) - for i, 
newRowData := range selectValue { - rowContexts[i].parameters = initialRow.parameters - rowContexts[i].tables = copyMap(initialRow.tables) - rowContexts[i].tables[destinationTableName] = newRowData - } - return rowContexts - } - - if len(query.Table.SelectItem.Path) > 0 { - sourceTableName := query.Table.SelectItem.Path[0] - sourceTableData := initialRow.tables[sourceTableName] - if sourceTableData == nil { - // When source table is not found, assume it's root document - initialRow.tables[sourceTableName] = initialRow.tables["$root"] - } - } - - newRowData := initialRow.resolveSelectItem(query.Table.SelectItem) - initialRow.tables[destinationTableName] = newRowData - return []rowContext{initialRow} - } - - return []rowContext{initialRow} -} - -func (r rowContext) handleJoin(query parsers.SelectStmt) []rowContext { - currentDocuments := []rowContext{r} - - for _, joinItem := range query.JoinItems { - nextDocuments := make([]rowContext, 0) - for _, currentDocument := range currentDocuments { - joinedItems := currentDocument.resolveJoinItemSelect(joinItem.SelectItem) - for _, joinedItem := range joinedItems { - tablesCopy := copyMap(currentDocument.tables) - tablesCopy[joinItem.Table.Value] = joinedItem - nextDocuments = append(nextDocuments, rowContext{ - parameters: currentDocument.parameters, - tables: tablesCopy, - }) - } - } - currentDocuments = nextDocuments - } - - return currentDocuments -} - -func (r rowContext) resolveJoinItemSelect(selectItem parsers.SelectItem) []RowType { - if selectItem.Path != nil || selectItem.Type == parsers.SelectItemTypeSubQuery { - selectValue := r.parseArray(selectItem) - documents := make([]RowType, len(selectValue)) - for i, newRowData := range selectValue { - documents[i] = newRowData - } - return documents - } - - return []RowType{} -} - -func (r rowContext) applyFilters(filters interface{}) bool { - if filters == nil { - return true - } - - switch typedFilters := filters.(type) { - case parsers.ComparisonExpression: - return 
r.filters_ComparisonExpression(typedFilters) - case parsers.LogicalExpression: - return r.filters_LogicalExpression(typedFilters) - case parsers.Constant: - if value, ok := typedFilters.Value.(bool); ok { - return value - } - return false - case parsers.SelectItem: - resolvedValue := r.resolveSelectItem(typedFilters) - if value, ok := resolvedValue.(bool); ok { - if typedFilters.Invert { - return !value - } - - return value - } - } - - return false -} - -func (r rowContext) filters_ComparisonExpression(expression parsers.ComparisonExpression) bool { - leftExpression, leftExpressionOk := expression.Left.(parsers.SelectItem) - rightExpression, rightExpressionOk := expression.Right.(parsers.SelectItem) - - if !leftExpressionOk || !rightExpressionOk { - logger.ErrorLn("ComparisonExpression has incorrect Left or Right type") - return false - } - - leftValue := r.resolveSelectItem(leftExpression) - rightValue := r.resolveSelectItem(rightExpression) - - cmp := compareValues(leftValue, rightValue) - switch expression.Operation { - case "=": - return cmp == 0 - case "!=": - return cmp != 0 - case "<": - return cmp < 0 - case ">": - return cmp > 0 - case "<=": - return cmp <= 0 - case ">=": - return cmp >= 0 - } - - return false -} - -func (r rowContext) filters_LogicalExpression(expression parsers.LogicalExpression) bool { - var result bool - for i, subExpression := range expression.Expressions { - expressionResult := r.applyFilters(subExpression) - if i == 0 { - result = expressionResult - } - - switch expression.Operation { - case parsers.LogicalExpressionTypeAnd: - result = result && expressionResult - if !result { - return false - } - case parsers.LogicalExpressionTypeOr: - result = result || expressionResult - if result { - return true - } - } - } - return result -} - -func applyOrder(documents []rowContext, orderExpressions []parsers.OrderExpression) { - less := func(i, j int) bool { - for _, order := range orderExpressions { - val1 := 
documents[i].resolveSelectItem(order.SelectItem) - val2 := documents[j].resolveSelectItem(order.SelectItem) - - cmp := compareValues(val1, val2) - if cmp != 0 { - if order.Direction == parsers.OrderDirectionDesc { - return cmp > 0 - } - return cmp < 0 - } - } - return i < j - } - - sort.SliceStable(documents, less) -} - -func applyGroupBy(documents []rowContext, groupBy []parsers.SelectItem) []rowContext { - groupedRows := make(map[string][]rowContext) - groupedKeys := make([]string, 0) - - for _, row := range documents { - key := row.generateGroupByKey(groupBy) - if _, ok := groupedRows[key]; !ok { - groupedKeys = append(groupedKeys, key) - } - groupedRows[key] = append(groupedRows[key], row) - } - - grouppedRows := make([]rowContext, 0) - for _, key := range groupedKeys { - grouppedRowContext := rowContext{ - tables: groupedRows[key][0].tables, - parameters: groupedRows[key][0].parameters, - grouppedRows: groupedRows[key], - } - grouppedRows = append(grouppedRows, grouppedRowContext) - } - - return grouppedRows -} - -func (r rowContext) generateGroupByKey(groupBy []parsers.SelectItem) string { - var keyBuilder strings.Builder - for _, selectItem := range groupBy { - value := r.resolveSelectItem(selectItem) - keyBuilder.WriteString(fmt.Sprintf("%v", value)) - keyBuilder.WriteString(":") - } - return keyBuilder.String() -} - -func applyProjection(documents []rowContext, selectItems []parsers.SelectItem, groupBy []parsers.SelectItem) []RowType { - if len(documents) == 0 { - return []RowType{} - } - - if hasAggregateFunctions(selectItems) && len(groupBy) == 0 { - // When can have aggregate functions without GROUP BY clause, - // we should aggregate all rows in that case - rowContext := rowContext{ - tables: documents[0].tables, - parameters: documents[0].parameters, - grouppedRows: documents, - } - return []RowType{rowContext.applyProjection(selectItems)} - } - - projectedDocuments := make([]RowType, len(documents)) - for index, row := range documents { - 
projectedDocuments[index] = row.applyProjection(selectItems) - } - - return projectedDocuments -} - -func (r rowContext) applyProjection(selectItems []parsers.SelectItem) RowType { - // When the first value is top level, select it instead - if len(selectItems) > 0 && selectItems[0].IsTopLevel { - return r.resolveSelectItem(selectItems[0]) - } - - // Construct a new row based on the selected columns - row := make(map[string]interface{}) - for index, selectItem := range selectItems { - destinationName := resolveDestinationColumnName(selectItem, index, r.parameters) - - row[destinationName] = r.resolveSelectItem(selectItem) - } - - return row -} - -func resolveDestinationColumnName(selectItem parsers.SelectItem, itemIndex int, queryParameters map[string]interface{}) string { - if selectItem.Alias != "" { - return selectItem.Alias - } - - destinationName := fmt.Sprintf("$%d", itemIndex+1) - if len(selectItem.Path) > 0 { - destinationName = selectItem.Path[len(selectItem.Path)-1] - } - - if destinationName[0] == '@' { - destinationName = queryParameters[destinationName].(string) - } - - return destinationName -} - -func (r rowContext) resolveSelectItem(selectItem parsers.SelectItem) interface{} { - if selectItem.Type == parsers.SelectItemTypeArray { - return r.selectItem_SelectItemTypeArray(selectItem) - } - - if selectItem.Type == parsers.SelectItemTypeObject { - return r.selectItem_SelectItemTypeObject(selectItem) - } - - if selectItem.Type == parsers.SelectItemTypeConstant { - return r.selectItem_SelectItemTypeConstant(selectItem) - } - - if selectItem.Type == parsers.SelectItemTypeSubQuery { - return r.selectItem_SelectItemTypeSubQuery(selectItem) - } - - if selectItem.Type == parsers.SelectItemTypeFunctionCall { - if typedFunctionCall, ok := selectItem.Value.(parsers.FunctionCall); ok { - return r.selectItem_SelectItemTypeFunctionCall(typedFunctionCall) - } - - logger.ErrorLn("parsers.SelectItem has incorrect Value type (expected parsers.FunctionCall)") - return 
nil - } - - return r.selectItem_SelectItemTypeField(selectItem) -} - -func (r rowContext) selectItem_SelectItemTypeArray(selectItem parsers.SelectItem) interface{} { - arrayValue := make([]interface{}, 0) - for _, subSelectItem := range selectItem.SelectItems { - arrayValue = append(arrayValue, r.resolveSelectItem(subSelectItem)) - } - return arrayValue -} - -func (r rowContext) selectItem_SelectItemTypeObject(selectItem parsers.SelectItem) interface{} { - objectValue := make(map[string]interface{}) - for _, subSelectItem := range selectItem.SelectItems { - objectValue[subSelectItem.Alias] = r.resolveSelectItem(subSelectItem) - } - return objectValue -} - -func (r rowContext) selectItem_SelectItemTypeConstant(selectItem parsers.SelectItem) interface{} { - var typedValue parsers.Constant - var ok bool - if typedValue, ok = selectItem.Value.(parsers.Constant); !ok { - // TODO: Handle error - logger.ErrorLn("parsers.Constant has incorrect Value type") - } - - if typedValue.Type == parsers.ConstantTypeParameterConstant && - r.parameters != nil { - if key, ok := typedValue.Value.(string); ok { - return r.parameters[key] - } - } - - return typedValue.Value -} - -func (r rowContext) selectItem_SelectItemTypeSubQuery(selectItem parsers.SelectItem) interface{} { - subQuery := selectItem.Value.(parsers.SelectStmt) - subQueryResult := ExecuteQuery( - subQuery, - []RowType{r}, - ) - - if subQuery.Exists { - return len(subQueryResult) > 0 - } - - return subQueryResult -} - -func (r rowContext) selectItem_SelectItemTypeFunctionCall(functionCall parsers.FunctionCall) interface{} { - switch functionCall.Type { - case parsers.FunctionCallStringEquals: - return r.strings_StringEquals(functionCall.Arguments) - case parsers.FunctionCallContains: - return r.strings_Contains(functionCall.Arguments) - case parsers.FunctionCallEndsWith: - return r.strings_EndsWith(functionCall.Arguments) - case parsers.FunctionCallStartsWith: - return r.strings_StartsWith(functionCall.Arguments) - case 
parsers.FunctionCallConcat: - return r.strings_Concat(functionCall.Arguments) - case parsers.FunctionCallIndexOf: - return r.strings_IndexOf(functionCall.Arguments) - case parsers.FunctionCallToString: - return r.strings_ToString(functionCall.Arguments) - case parsers.FunctionCallUpper: - return r.strings_Upper(functionCall.Arguments) - case parsers.FunctionCallLower: - return r.strings_Lower(functionCall.Arguments) - case parsers.FunctionCallLeft: - return r.strings_Left(functionCall.Arguments) - case parsers.FunctionCallLength: - return r.strings_Length(functionCall.Arguments) - case parsers.FunctionCallLTrim: - return r.strings_LTrim(functionCall.Arguments) - case parsers.FunctionCallReplace: - return r.strings_Replace(functionCall.Arguments) - case parsers.FunctionCallReplicate: - return r.strings_Replicate(functionCall.Arguments) - case parsers.FunctionCallReverse: - return r.strings_Reverse(functionCall.Arguments) - case parsers.FunctionCallRight: - return r.strings_Right(functionCall.Arguments) - case parsers.FunctionCallRTrim: - return r.strings_RTrim(functionCall.Arguments) - case parsers.FunctionCallSubstring: - return r.strings_Substring(functionCall.Arguments) - case parsers.FunctionCallTrim: - return r.strings_Trim(functionCall.Arguments) - - case parsers.FunctionCallIsDefined: - return r.typeChecking_IsDefined(functionCall.Arguments) - case parsers.FunctionCallIsArray: - return r.typeChecking_IsArray(functionCall.Arguments) - case parsers.FunctionCallIsBool: - return r.typeChecking_IsBool(functionCall.Arguments) - case parsers.FunctionCallIsFiniteNumber: - return r.typeChecking_IsFiniteNumber(functionCall.Arguments) - case parsers.FunctionCallIsInteger: - return r.typeChecking_IsInteger(functionCall.Arguments) - case parsers.FunctionCallIsNull: - return r.typeChecking_IsNull(functionCall.Arguments) - case parsers.FunctionCallIsNumber: - return r.typeChecking_IsNumber(functionCall.Arguments) - case parsers.FunctionCallIsObject: - return 
r.typeChecking_IsObject(functionCall.Arguments) - case parsers.FunctionCallIsPrimitive: - return r.typeChecking_IsPrimitive(functionCall.Arguments) - case parsers.FunctionCallIsString: - return r.typeChecking_IsString(functionCall.Arguments) - - case parsers.FunctionCallArrayConcat: - return r.array_Concat(functionCall.Arguments) - case parsers.FunctionCallArrayContains: - return r.array_Contains(functionCall.Arguments) - case parsers.FunctionCallArrayContainsAny: - return r.array_Contains_Any(functionCall.Arguments) - case parsers.FunctionCallArrayContainsAll: - return r.array_Contains_All(functionCall.Arguments) - case parsers.FunctionCallArrayLength: - return r.array_Length(functionCall.Arguments) - case parsers.FunctionCallArraySlice: - return r.array_Slice(functionCall.Arguments) - case parsers.FunctionCallSetIntersect: - return r.set_Intersect(functionCall.Arguments) - case parsers.FunctionCallSetUnion: - return r.set_Union(functionCall.Arguments) - - case parsers.FunctionCallMathAbs: - return r.math_Abs(functionCall.Arguments) - case parsers.FunctionCallMathAcos: - return r.math_Acos(functionCall.Arguments) - case parsers.FunctionCallMathAsin: - return r.math_Asin(functionCall.Arguments) - case parsers.FunctionCallMathAtan: - return r.math_Atan(functionCall.Arguments) - case parsers.FunctionCallMathCeiling: - return r.math_Ceiling(functionCall.Arguments) - case parsers.FunctionCallMathCos: - return r.math_Cos(functionCall.Arguments) - case parsers.FunctionCallMathCot: - return r.math_Cot(functionCall.Arguments) - case parsers.FunctionCallMathDegrees: - return r.math_Degrees(functionCall.Arguments) - case parsers.FunctionCallMathExp: - return r.math_Exp(functionCall.Arguments) - case parsers.FunctionCallMathFloor: - return r.math_Floor(functionCall.Arguments) - case parsers.FunctionCallMathIntBitNot: - return r.math_IntBitNot(functionCall.Arguments) - case parsers.FunctionCallMathLog10: - return r.math_Log10(functionCall.Arguments) - case 
parsers.FunctionCallMathRadians: - return r.math_Radians(functionCall.Arguments) - case parsers.FunctionCallMathRound: - return r.math_Round(functionCall.Arguments) - case parsers.FunctionCallMathSign: - return r.math_Sign(functionCall.Arguments) - case parsers.FunctionCallMathSin: - return r.math_Sin(functionCall.Arguments) - case parsers.FunctionCallMathSqrt: - return r.math_Sqrt(functionCall.Arguments) - case parsers.FunctionCallMathSquare: - return r.math_Square(functionCall.Arguments) - case parsers.FunctionCallMathTan: - return r.math_Tan(functionCall.Arguments) - case parsers.FunctionCallMathTrunc: - return r.math_Trunc(functionCall.Arguments) - case parsers.FunctionCallMathAtn2: - return r.math_Atn2(functionCall.Arguments) - case parsers.FunctionCallMathIntAdd: - return r.math_IntAdd(functionCall.Arguments) - case parsers.FunctionCallMathIntBitAnd: - return r.math_IntBitAnd(functionCall.Arguments) - case parsers.FunctionCallMathIntBitLeftShift: - return r.math_IntBitLeftShift(functionCall.Arguments) - case parsers.FunctionCallMathIntBitOr: - return r.math_IntBitOr(functionCall.Arguments) - case parsers.FunctionCallMathIntBitRightShift: - return r.math_IntBitRightShift(functionCall.Arguments) - case parsers.FunctionCallMathIntBitXor: - return r.math_IntBitXor(functionCall.Arguments) - case parsers.FunctionCallMathIntDiv: - return r.math_IntDiv(functionCall.Arguments) - case parsers.FunctionCallMathIntMod: - return r.math_IntMod(functionCall.Arguments) - case parsers.FunctionCallMathIntMul: - return r.math_IntMul(functionCall.Arguments) - case parsers.FunctionCallMathIntSub: - return r.math_IntSub(functionCall.Arguments) - case parsers.FunctionCallMathPower: - return r.math_Power(functionCall.Arguments) - case parsers.FunctionCallMathLog: - return r.math_Log(functionCall.Arguments) - case parsers.FunctionCallMathNumberBin: - return r.math_NumberBin(functionCall.Arguments) - case parsers.FunctionCallMathPi: - return r.math_Pi() - case 
parsers.FunctionCallMathRand: - return r.math_Rand() - - case parsers.FunctionCallAggregateAvg: - return r.aggregate_Avg(functionCall.Arguments) - case parsers.FunctionCallAggregateCount: - return r.aggregate_Count(functionCall.Arguments) - case parsers.FunctionCallAggregateMax: - return r.aggregate_Max(functionCall.Arguments) - case parsers.FunctionCallAggregateMin: - return r.aggregate_Min(functionCall.Arguments) - case parsers.FunctionCallAggregateSum: - return r.aggregate_Sum(functionCall.Arguments) - - case parsers.FunctionCallIn: - return r.misc_In(functionCall.Arguments) - } - - logger.Errorf("Unknown function call type: %v", functionCall.Type) - return nil -} - -func (r rowContext) selectItem_SelectItemTypeField(selectItem parsers.SelectItem) interface{} { - value := r.tables[selectItem.Path[0]] - - if len(selectItem.Path) > 1 { - for _, pathSegment := range selectItem.Path[1:] { - if pathSegment[0] == '@' { - pathSegment = r.parameters[pathSegment].(string) - } - - switch nestedValue := value.(type) { - case map[string]interface{}: - value = nestedValue[pathSegment] - case map[string]RowType: - value = nestedValue[pathSegment] - case []int, []string, []interface{}: - slice := reflect.ValueOf(nestedValue) - if arrayIndex, err := strconv.Atoi(pathSegment); err == nil && slice.Len() > arrayIndex { - value = slice.Index(arrayIndex).Interface() - } else { - return nil - } - default: - return nil - } - } - } - - return value -} - -func hasAggregateFunctions(selectItems []parsers.SelectItem) bool { - if selectItems == nil { - return false - } - - for _, selectItem := range selectItems { - if selectItem.Type == parsers.SelectItemTypeFunctionCall { - if typedValue, ok := selectItem.Value.(parsers.FunctionCall); ok && slices.Contains[[]parsers.FunctionCallType](parsers.AggregateFunctions, typedValue.Type) { - return true - } - } - - if hasAggregateFunctions(selectItem.SelectItems) { - return true - } - } - - return false -} - -func compareValues(val1, val2 
interface{}) int { - if val1 == nil && val2 == nil { - return 0 - } else if val1 == nil { - return -1 - } else if val2 == nil { - return 1 - } - - if reflect.TypeOf(val1) != reflect.TypeOf(val2) { - return 1 - } - - switch val1 := val1.(type) { - case int: - val2 := val2.(int) - if val1 < val2 { - return -1 - } else if val1 > val2 { - return 1 - } - return 0 - case float64: - val2 := val2.(float64) - if val1 < val2 { - return -1 - } else if val1 > val2 { - return 1 - } - return 0 - case string: - val2 := val2.(string) - return strings.Compare(val1, val2) - case bool: - val2 := val2.(bool) - if val1 == val2 { - return 0 - } else if val1 { - return 1 - } else { - return -1 - } - // TODO: Add more types - default: - if reflect.DeepEqual(val1, val2) { - return 0 - } - return 1 - } -} - -func deduplicate[T RowType | interface{}](slice []T) []T { - var result []T - result = make([]T, 0) - - for i := 0; i < len(slice); i++ { - unique := true - for j := 0; j < len(result); j++ { - if compareValues(slice[i], result[j]) == 0 { - unique = false - break - } - } - - if unique { - result = append(result, slice[i]) - } - } - - return result -} - -func copyMap[T RowType | []RowType](originalMap map[string]T) map[string]T { - targetMap := make(map[string]T) - - for k, v := range originalMap { - targetMap[k] = v - } - - return targetMap + return projectedIterator } diff --git a/query_executors/memory_executor/misc_test.go b/query_executors/memory_executor/misc_test.go index 3862f6b..83f05da 100644 --- a/query_executors/memory_executor/misc_test.go +++ b/query_executors/memory_executor/misc_test.go @@ -4,18 +4,41 @@ import ( "reflect" "testing" + "github.com/pikami/cosmium/internal/datastore" "github.com/pikami/cosmium/parsers" memoryexecutor "github.com/pikami/cosmium/query_executors/memory_executor" testutils "github.com/pikami/cosmium/test_utils" ) +type TestDocumentIterator struct { + documents []memoryexecutor.RowType + index int +} + +func NewTestDocumentIterator(documents 
[]memoryexecutor.RowType) *TestDocumentIterator {
+	return &TestDocumentIterator{
+		documents: documents,
+		index:     -1,
+	}
+}
+
+func (i *TestDocumentIterator) Next() (memoryexecutor.RowType, datastore.DataStoreStatus) {
+	i.index++
+	if i.index >= len(i.documents) {
+		return nil, datastore.IterEOF
+	}
+
+	return i.documents[i.index], datastore.StatusOk
+}
+
 func testQueryExecute(
 	t *testing.T,
 	query parsers.SelectStmt,
 	data []memoryexecutor.RowType,
 	expectedData []memoryexecutor.RowType,
 ) {
-	result := memoryexecutor.ExecuteQuery(query, data)
+	iter := NewTestDocumentIterator(data)
+	result := memoryexecutor.ExecuteQuery(query, iter)
 
 	if !reflect.DeepEqual(result, expectedData) {
 		t.Errorf("execution result does not match expected data.\nExpected: %+v\nGot: %+v", expectedData, result)
diff --git a/query_executors/memory_executor/offset_iterator.go b/query_executors/memory_executor/offset_iterator.go
new file mode 100644
index 0000000..adb1197
--- /dev/null
+++ b/query_executors/memory_executor/offset_iterator.go
@@ -0,0 +1,31 @@
+package memoryexecutor
+
+import "github.com/pikami/cosmium/internal/datastore"
+
+// offsetIterator drops the first offset rows of the wrapped iterator and
+// passes every later row through unchanged (the OFFSET clause).
+type offsetIterator struct {
+	documents rowTypeIterator
+	offset    int
+	skipped   bool // set once the skip phase has run
+}
+
+// Next returns the next row after the skipped prefix. The first call
+// discards up to offset rows; if the wrapped iterator reports a non-OK
+// status while skipping, that result is returned immediately.
+func (oi *offsetIterator) Next() (RowType, datastore.DataStoreStatus) {
+	if !oi.skipped {
+		// Mark the skip phase done up front so it never runs twice,
+		// even when it ends early below.
+		oi.skipped = true
+		for i := 0; i < oi.offset; i++ {
+			// A non-OK status means the source has nothing more to give;
+			// stop here rather than polling it for the rest of the offset.
+			if row, status := oi.documents.Next(); status != datastore.StatusOk {
+				return row, status
+			}
+		}
+	}
+
+	return oi.documents.Next()
+}
diff --git a/query_executors/memory_executor/order_iterator.go b/query_executors/memory_executor/order_iterator.go
new file mode 100644
index 0000000..c541f43
--- /dev/null
+++ b/query_executors/memory_executor/order_iterator.go
@@ -0,0 +1,72 @@
+package memoryexecutor
+
+import (
+	"sort"
+
+	"github.com/pikami/cosmium/internal/datastore"
+	"github.com/pikami/cosmium/parsers"
+)
+
+// orderIterator yields the rows of the wrapped iterator sorted by
+// orderExpressions (the ORDER BY clause). Sorting needs every row, so the
+// whole input is buffered on the first call to Next.
+type orderIterator struct {
+	documents        rowIterator
+	orderExpressions []parsers.OrderExpression
+	orderedDocs      []rowContext // sorted buffer; nil until the first Next
+	docsIndex        int          // position of the next row to hand out
+}
+
+// Next returns the next row in sorted order, or datastore.IterEOF once the
+// buffer is exhausted.
+func (oi *orderIterator) Next() (rowContext, datastore.DataStoreStatus) {
+	if oi.orderedDocs == nil {
+		oi.loadAndSort()
+	}
+
+	if oi.docsIndex >= len(oi.orderedDocs) {
+		return rowContext{}, datastore.IterEOF
+	}
+
+	row := oi.orderedDocs[oi.docsIndex]
+	// Zero the slot so the buffer does not keep a row referenced after it
+	// has been handed out; this applies to every row, including the first.
+	oi.orderedDocs[oi.docsIndex] = rowContext{}
+	oi.docsIndex++
+	return row, datastore.StatusOk
+}
+
+// loadAndSort drains the wrapped iterator into orderedDocs and sorts the
+// buffer by orderExpressions. Any non-OK status ends the drain.
+func (oi *orderIterator) loadAndSort() {
+	// Non-nil even when empty, so Next does not load a second time.
+	oi.orderedDocs = make([]rowContext, 0)
+	for {
+		row, status := oi.documents.Next()
+		if status != datastore.StatusOk {
+			break
+		}
+
+		oi.orderedDocs = append(oi.orderedDocs, row)
+	}
+	// The source is not read again; release it.
+	oi.documents = nil
+
+	less := func(i, j int) bool {
+		for _, order := range oi.orderExpressions {
+			val1 := oi.orderedDocs[i].resolveSelectItem(order.SelectItem)
+			val2 := oi.orderedDocs[j].resolveSelectItem(order.SelectItem)
+
+			cmp := compareValues(val1, val2)
+			if cmp != 0 {
+				if order.Direction == parsers.OrderDirectionDesc {
+					return cmp > 0
+				}
+				return cmp < 0
+			}
+		}
+		return i < j
+	}
+
+	sort.SliceStable(oi.orderedDocs, less)
+}
diff --git a/query_executors/memory_executor/project_iterator.go b/query_executors/memory_executor/project_iterator.go
new file mode 100644
index 0000000..d2c440a
--- /dev/null
+++ b/query_executors/memory_executor/project_iterator.go
@@ -0,0 +1,99 @@
+package memoryexecutor
+
+import (
+	"github.com/pikami/cosmium/internal/datastore"
+	"github.com/pikami/cosmium/parsers"
+	"golang.org/x/exp/slices"
+)
+
+// projectIterator applies the SELECT list to each row of the wrapped
+// iterator, turning row contexts into result rows.
+type projectIterator struct {
+	documents   rowIterator
+	selectItems []parsers.SelectItem
+	groupBy     []parsers.SelectItem
+}
+
+// Next returns the projection of the next row. When the select list uses
+// aggregate functions and there is no GROUP BY, all remaining rows are
+// folded into a single result row and the iterator is finished afterwards.
+func (pi *projectIterator) Next() (RowType, datastore.DataStoreStatus) {
+	if pi.documents == nil {
+		return rowContext{}, datastore.IterEOF
+	}
+
+	row, status := pi.documents.Next()
+	if status != datastore.StatusOk {
+		pi.documents = nil
+		return rowContext{}, status
+	}
+
+	if hasAggregateFunctions(pi.selectItems) && len(pi.groupBy) == 0 {
+		// Aggregate functions without a GROUP BY clause aggregate over
+		// all rows, so drain the rest of the source into one group.
+		allDocuments := []rowContext{row}
+		for {
+			nextRow, nextStatus := pi.documents.Next()
+			if nextStatus != datastore.StatusOk {
+				break
+			}
+
+			allDocuments = append(allDocuments, nextRow)
+		}
+
+		// The source is drained; drop it so later calls report EOF
+		// through the guard above instead of reading it again.
+		pi.documents = nil
+
+		// The first row supplies the tables and parameters of the group.
+		aggRow := rowContext{
+			tables:       row.tables,
+			parameters:   row.parameters,
+			grouppedRows: allDocuments,
+		}
+
+		return aggRow.applyProjection(pi.selectItems), datastore.StatusOk
+	}
+
+	return row.applyProjection(pi.selectItems), datastore.StatusOk
+}
+
+// applyProjection builds the result row for r from selectItems.
+func (r rowContext) applyProjection(selectItems []parsers.SelectItem) RowType {
+	// When the first value is top level, select it instead
+	if len(selectItems) > 0 && selectItems[0].IsTopLevel {
+		return r.resolveSelectItem(selectItems[0])
+	}
+
+	// Construct a new row based on the selected columns
+	row := make(map[string]interface{})
+	for index, selectItem := range selectItems {
+		destinationName := resolveDestinationColumnName(selectItem, index, r.parameters)
+
+		row[destinationName] = r.resolveSelectItem(selectItem)
+	}
+
+	return row
+}
+
+// hasAggregateFunctions reports whether selectItems, or any select item
+// nested inside them, is a call to one of parsers.AggregateFunctions.
+func hasAggregateFunctions(selectItems []parsers.SelectItem) bool {
+	if selectItems == nil {
+		return false
+	}
+
+	for _, selectItem := range selectItems {
+		if selectItem.Type == parsers.SelectItemTypeFunctionCall {
+			if typedValue, ok := selectItem.Value.(parsers.FunctionCall); ok && slices.Contains[[]parsers.FunctionCallType](parsers.AggregateFunctions, typedValue.Type) {
+				return true
+			}
+		}
+
+		if hasAggregateFunctions(selectItem.SelectItems) {
+			return true
+		}
+	}
+
+	return false
+}
diff --git a/query_executors/memory_executor/rowTypeToRowContext_iterator.go b/query_executors/memory_executor/rowTypeToRowContext_iterator.go
new file mode 100644
index 0000000..304dbb0
--- /dev/null
+++ b/query_executors/memory_executor/rowTypeToRowContext_iterator.go
@@ -0,0 +1,51 @@
+package memoryexecutor
+
+import (
+	"github.com/pikami/cosmium/internal/datastore"
+	"github.com/pikami/cosmium/parsers"
+)
+
+// rowTypeToRowContextIterator wraps each row of the underlying iterator in
+// a rowContext carrying the query parameters and the row's table bindings.
+type rowTypeToRowContextIterator struct {
+	documents rowTypeIterator
+	query     parsers.SelectStmt
+}
+
+// Next wraps the next source row in a rowContext. The row is registered
+// under the query's initial table name and under "$root". After the first
+// non-OK status the source is dropped and later calls return IterEOF.
+func (di *rowTypeToRowContextIterator) Next() (rowContext, datastore.DataStoreStatus) {
+	if di.documents == nil {
+		return rowContext{}, datastore.IterEOF
+	}
+
+	doc, status := di.documents.Next()
+	if status != datastore.StatusOk {
+		di.documents = nil
+		return rowContext{}, status
+	}
+
+	// Table name lookup order: the sub-query's table, then the query's
+	// own table, then a name derived from the table's select item.
+	var initialTableName string
+	if di.query.Table.SelectItem.Type == parsers.SelectItemTypeSubQuery {
+		initialTableName = di.query.Table.SelectItem.Value.(parsers.SelectStmt).Table.Value
+	}
+
+	if initialTableName == "" {
+		initialTableName = di.query.Table.Value
+	}
+
+	if initialTableName == "" {
+		initialTableName = resolveDestinationColumnName(di.query.Table.SelectItem, 0, di.query.Parameters)
+	}
+
+	return rowContext{
+		parameters: di.query.Parameters,
+		tables: map[string]RowType{
+			initialTableName: doc,
+			"$root":          doc,
+		},
+	}, status
+}