Refactored query engine utilizing iterators

This commit is contained in:
Pijus Kamandulis 2025-03-11 17:36:28 +02:00
parent 221f029a1d
commit e526b2269e
20 changed files with 1160 additions and 735 deletions

View File

@ -10,6 +10,7 @@ import (
"github.com/gin-gonic/gin"
apimodels "github.com/pikami/cosmium/api/api_models"
"github.com/pikami/cosmium/internal/constants"
"github.com/pikami/cosmium/internal/converters"
"github.com/pikami/cosmium/internal/datastore"
"github.com/pikami/cosmium/internal/logger"
"github.com/pikami/cosmium/parsers"
@ -378,20 +379,16 @@ func (h *Handlers) executeQueryDocuments(databaseId string, collectionId string,
return nil, datastore.BadRequest
}
collectionDocuments, status := h.dataStore.GetAllDocuments(databaseId, collectionId)
allDocumentsIterator, status := h.dataStore.GetDocumentIterator(databaseId, collectionId)
if status != datastore.StatusOk {
return nil, status
}
// TODO: Investigate, this could cause unnecessary memory usage
covDocs := make([]memoryexecutor.RowType, 0)
for _, doc := range collectionDocuments {
covDocs = append(covDocs, map[string]interface{}(doc))
}
rowsIterator := converters.NewDocumentToRowTypeIterator(allDocumentsIterator)
if typedQuery, ok := parsedQuery.(parsers.SelectStmt); ok {
typedQuery.Parameters = queryParameters
return memoryexecutor.ExecuteQuery(typedQuery, covDocs), datastore.StatusOk
return memoryexecutor.ExecuteQuery(typedQuery, rowsIterator), datastore.StatusOk
}
return nil, datastore.BadRequest

View File

@ -0,0 +1,20 @@
package converters
import (
"github.com/pikami/cosmium/internal/datastore"
memoryexecutor "github.com/pikami/cosmium/query_executors/memory_executor"
)
// DocumentToRowTypeIterator adapts a datastore.DocumentIterator so it can be
// consumed wherever an iterator yielding memoryexecutor.RowType is expected.
type DocumentToRowTypeIterator struct {
	documents datastore.DocumentIterator
}

// NewDocumentToRowTypeIterator wraps documents in a DocumentToRowTypeIterator.
func NewDocumentToRowTypeIterator(documents datastore.DocumentIterator) *DocumentToRowTypeIterator {
	iterator := new(DocumentToRowTypeIterator)
	iterator.documents = documents
	return iterator
}

// Next fetches the next document from the wrapped iterator and hands it back
// as a RowType together with the status reported by the wrapped iterator.
func (di *DocumentToRowTypeIterator) Next() (memoryexecutor.RowType, datastore.DataStoreStatus) {
	document, status := di.documents.Next()
	return document, status
}

View File

@ -40,5 +40,4 @@ type DataStore interface {
type DocumentIterator interface {
Next() (Document, DataStoreStatus)
HasMore() bool
}

View File

@ -15,7 +15,3 @@ func (i *ArrayDocumentIterator) Next() (datastore.Document, datastore.DataStoreS
return i.documents[i.index], datastore.StatusOk
}
func (i *ArrayDocumentIterator) HasMore() bool {
return i.index < len(i.documents)-1
}

View File

@ -15,6 +15,7 @@ const (
StatusNotFound = 2
Conflict = 3
BadRequest = 4
IterEOF = 5
)
type TriggerOperation string

View File

@ -196,6 +196,10 @@ func (r rowContext) parseArray(argument interface{}) []interface{} {
ex := r.resolveSelectItem(exItem)
arrValue := reflect.ValueOf(ex)
if arrValue.Kind() == reflect.Invalid {
return nil
}
if arrValue.Kind() != reflect.Slice {
logger.ErrorLn("parseArray got parameters of wrong type")
return nil

View File

@ -0,0 +1,27 @@
package memoryexecutor
import "github.com/pikami/cosmium/internal/datastore"
// rowArrayIterator walks an in-memory slice of rowContext values.
type rowArrayIterator struct {
	documents []rowContext
	index     int
}

// NewRowArrayIterator returns an iterator over documents. The cursor starts at
// -1 so that the first call to Next lands on element 0.
func NewRowArrayIterator(documents []rowContext) *rowArrayIterator {
	iterator := &rowArrayIterator{index: -1}
	iterator.documents = documents
	return iterator
}

// Next advances the cursor and returns the row it now points at, or
// datastore.IterEOF once the cursor has moved past the last element.
func (i *rowArrayIterator) Next() (rowContext, datastore.DataStoreStatus) {
	i.index++
	if i.index < len(i.documents) {
		current := i.documents[i.index]
		// Clear the consumed slot so the slice no longer references the row
		// (helps the GC reclaim memory).
		i.documents[i.index] = rowContext{}
		return current, datastore.StatusOk
	}

	return rowContext{}, datastore.IterEOF
}

View File

@ -0,0 +1,397 @@
package memoryexecutor
import (
"fmt"
"reflect"
"strconv"
"strings"
"github.com/pikami/cosmium/internal/datastore"
"github.com/pikami/cosmium/internal/logger"
"github.com/pikami/cosmium/parsers"
)
// RowType is the untyped value the executor passes around for documents and
// projected results.
type RowType interface{}
// rowContext carries everything needed to evaluate expressions against a
// single row while a query is being executed.
type rowContext struct {
// tables maps a table name or alias to the value bound to it for this row.
tables map[string]RowType
// parameters holds the query parameters; lookups in this file use keys that
// start with '@'.
parameters map[string]interface{}
// grouppedRows is populated by GROUP BY with the rows that share this row's
// group key.
grouppedRows []rowContext
}
// rowIterator yields rowContext values one at a time. Consumers in this
// package stop iterating on any status other than datastore.StatusOk.
type rowIterator interface {
Next() (rowContext, datastore.DataStoreStatus)
}
// rowTypeIterator is the RowType counterpart of rowIterator.
type rowTypeIterator interface {
Next() (RowType, datastore.DataStoreStatus)
}
// resolveDestinationColumnName picks the output name for a select item.
//
// The alias wins when present. Otherwise the last path segment is used, and
// when there is no path the positional name "$<itemIndex+1>" is used. A name
// that starts with '@' is treated as a parameter reference and replaced by the
// parameter's string value; if the parameter is missing or is not a string the
// '@' name is returned unchanged instead of panicking.
func resolveDestinationColumnName(selectItem parsers.SelectItem, itemIndex int, queryParameters map[string]interface{}) string {
	if selectItem.Alias != "" {
		return selectItem.Alias
	}

	destinationName := fmt.Sprintf("$%d", itemIndex+1)
	if len(selectItem.Path) > 0 {
		destinationName = selectItem.Path[len(selectItem.Path)-1]
	}

	// HasPrefix is safe for an empty segment, unlike indexing byte 0.
	if strings.HasPrefix(destinationName, "@") {
		if parameterValue, ok := queryParameters[destinationName].(string); ok {
			destinationName = parameterValue
		}
	}

	return destinationName
}
// resolveSelectItem evaluates selectItem against this row by dispatching on
// the item's type. Anything that is not an array, object, constant, sub-query
// or function call is resolved as a field path.
func (r rowContext) resolveSelectItem(selectItem parsers.SelectItem) interface{} {
	switch selectItem.Type {
	case parsers.SelectItemTypeArray:
		return r.selectItem_SelectItemTypeArray(selectItem)
	case parsers.SelectItemTypeObject:
		return r.selectItem_SelectItemTypeObject(selectItem)
	case parsers.SelectItemTypeConstant:
		return r.selectItem_SelectItemTypeConstant(selectItem)
	case parsers.SelectItemTypeSubQuery:
		return r.selectItem_SelectItemTypeSubQuery(selectItem)
	case parsers.SelectItemTypeFunctionCall:
		functionCall, isFunctionCall := selectItem.Value.(parsers.FunctionCall)
		if !isFunctionCall {
			logger.ErrorLn("parsers.SelectItem has incorrect Value type (expected parsers.FunctionCall)")
			return nil
		}
		return r.selectItem_SelectItemTypeFunctionCall(functionCall)
	default:
		return r.selectItem_SelectItemTypeField(selectItem)
	}
}
// selectItem_SelectItemTypeArray resolves every nested select item in order
// and returns the results as a (never nil) slice.
func (r rowContext) selectItem_SelectItemTypeArray(selectItem parsers.SelectItem) interface{} {
	elements := make([]interface{}, 0, len(selectItem.SelectItems))
	for index := range selectItem.SelectItems {
		elements = append(elements, r.resolveSelectItem(selectItem.SelectItems[index]))
	}
	return elements
}
// selectItem_SelectItemTypeObject builds an object whose keys are the aliases
// of the nested select items and whose values are their resolved results.
func (r rowContext) selectItem_SelectItemTypeObject(selectItem parsers.SelectItem) interface{} {
	fields := map[string]interface{}{}
	for index := range selectItem.SelectItems {
		member := selectItem.SelectItems[index]
		fields[member.Alias] = r.resolveSelectItem(member)
	}
	return fields
}
// selectItem_SelectItemTypeConstant returns the value of a constant select
// item. A parameter constant whose value is a string key is looked up in the
// row's parameters; every other constant yields its literal value.
func (r rowContext) selectItem_SelectItemTypeConstant(selectItem parsers.SelectItem) interface{} {
	constant, isConstant := selectItem.Value.(parsers.Constant)
	if !isConstant {
		// TODO: Handle error
		// Execution continues with the zero-value constant.
		logger.ErrorLn("parsers.Constant has incorrect Value type")
	}

	isParameter := constant.Type == parsers.ConstantTypeParameterConstant
	if isParameter && r.parameters != nil {
		if parameterName, isString := constant.Value.(string); isString {
			return r.parameters[parameterName]
		}
	}

	return constant.Value
}
// selectItem_SelectItemTypeSubQuery runs the sub-query stored in
// selectItem.Value with the current row as its only input row.
//
// For an EXISTS sub-query the result is a bool telling whether the sub-query
// produced at least one row. Otherwise the result is a []RowType holding every
// row the sub-query produced. If selectItem.Value is not a parsers.SelectStmt
// the problem is logged and nil is returned instead of panicking, matching how
// malformed function-call items are handled in resolveSelectItem.
func (r rowContext) selectItem_SelectItemTypeSubQuery(selectItem parsers.SelectItem) interface{} {
	subQuery, ok := selectItem.Value.(parsers.SelectStmt)
	if !ok {
		logger.ErrorLn("parsers.SelectItem has incorrect Value type (expected parsers.SelectStmt)")
		return nil
	}

	subQueryResult := executeQuery(
		subQuery,
		NewRowArrayIterator([]rowContext{r}),
	)

	if subQuery.Exists {
		// Only the presence of a first row matters.
		_, status := subQueryResult.Next()
		return status == datastore.StatusOk
	}

	allDocuments := make([]RowType, 0)
	for {
		row, status := subQueryResult.Next()
		if status != datastore.StatusOk {
			break
		}
		allDocuments = append(allDocuments, row)
	}

	return allDocuments
}
// selectItem_SelectItemTypeFunctionCall evaluates a function call against this
// row by dispatching on functionCall.Type to the rowContext method that
// implements it, passing functionCall.Arguments through (Pi and Rand take no
// arguments). An unrecognised type is logged and resolves to nil.
func (r rowContext) selectItem_SelectItemTypeFunctionCall(functionCall parsers.FunctionCall) interface{} {
switch functionCall.Type {
// String functions
case parsers.FunctionCallStringEquals:
return r.strings_StringEquals(functionCall.Arguments)
case parsers.FunctionCallContains:
return r.strings_Contains(functionCall.Arguments)
case parsers.FunctionCallEndsWith:
return r.strings_EndsWith(functionCall.Arguments)
case parsers.FunctionCallStartsWith:
return r.strings_StartsWith(functionCall.Arguments)
case parsers.FunctionCallConcat:
return r.strings_Concat(functionCall.Arguments)
case parsers.FunctionCallIndexOf:
return r.strings_IndexOf(functionCall.Arguments)
case parsers.FunctionCallToString:
return r.strings_ToString(functionCall.Arguments)
case parsers.FunctionCallUpper:
return r.strings_Upper(functionCall.Arguments)
case parsers.FunctionCallLower:
return r.strings_Lower(functionCall.Arguments)
case parsers.FunctionCallLeft:
return r.strings_Left(functionCall.Arguments)
case parsers.FunctionCallLength:
return r.strings_Length(functionCall.Arguments)
case parsers.FunctionCallLTrim:
return r.strings_LTrim(functionCall.Arguments)
case parsers.FunctionCallReplace:
return r.strings_Replace(functionCall.Arguments)
case parsers.FunctionCallReplicate:
return r.strings_Replicate(functionCall.Arguments)
case parsers.FunctionCallReverse:
return r.strings_Reverse(functionCall.Arguments)
case parsers.FunctionCallRight:
return r.strings_Right(functionCall.Arguments)
case parsers.FunctionCallRTrim:
return r.strings_RTrim(functionCall.Arguments)
case parsers.FunctionCallSubstring:
return r.strings_Substring(functionCall.Arguments)
case parsers.FunctionCallTrim:
return r.strings_Trim(functionCall.Arguments)
// Type-checking functions
case parsers.FunctionCallIsDefined:
return r.typeChecking_IsDefined(functionCall.Arguments)
case parsers.FunctionCallIsArray:
return r.typeChecking_IsArray(functionCall.Arguments)
case parsers.FunctionCallIsBool:
return r.typeChecking_IsBool(functionCall.Arguments)
case parsers.FunctionCallIsFiniteNumber:
return r.typeChecking_IsFiniteNumber(functionCall.Arguments)
case parsers.FunctionCallIsInteger:
return r.typeChecking_IsInteger(functionCall.Arguments)
case parsers.FunctionCallIsNull:
return r.typeChecking_IsNull(functionCall.Arguments)
case parsers.FunctionCallIsNumber:
return r.typeChecking_IsNumber(functionCall.Arguments)
case parsers.FunctionCallIsObject:
return r.typeChecking_IsObject(functionCall.Arguments)
case parsers.FunctionCallIsPrimitive:
return r.typeChecking_IsPrimitive(functionCall.Arguments)
case parsers.FunctionCallIsString:
return r.typeChecking_IsString(functionCall.Arguments)
// Array and set functions
case parsers.FunctionCallArrayConcat:
return r.array_Concat(functionCall.Arguments)
case parsers.FunctionCallArrayContains:
return r.array_Contains(functionCall.Arguments)
case parsers.FunctionCallArrayContainsAny:
return r.array_Contains_Any(functionCall.Arguments)
case parsers.FunctionCallArrayContainsAll:
return r.array_Contains_All(functionCall.Arguments)
case parsers.FunctionCallArrayLength:
return r.array_Length(functionCall.Arguments)
case parsers.FunctionCallArraySlice:
return r.array_Slice(functionCall.Arguments)
case parsers.FunctionCallSetIntersect:
return r.set_Intersect(functionCall.Arguments)
case parsers.FunctionCallSetUnion:
return r.set_Union(functionCall.Arguments)
// Math functions
case parsers.FunctionCallMathAbs:
return r.math_Abs(functionCall.Arguments)
case parsers.FunctionCallMathAcos:
return r.math_Acos(functionCall.Arguments)
case parsers.FunctionCallMathAsin:
return r.math_Asin(functionCall.Arguments)
case parsers.FunctionCallMathAtan:
return r.math_Atan(functionCall.Arguments)
case parsers.FunctionCallMathCeiling:
return r.math_Ceiling(functionCall.Arguments)
case parsers.FunctionCallMathCos:
return r.math_Cos(functionCall.Arguments)
case parsers.FunctionCallMathCot:
return r.math_Cot(functionCall.Arguments)
case parsers.FunctionCallMathDegrees:
return r.math_Degrees(functionCall.Arguments)
case parsers.FunctionCallMathExp:
return r.math_Exp(functionCall.Arguments)
case parsers.FunctionCallMathFloor:
return r.math_Floor(functionCall.Arguments)
case parsers.FunctionCallMathIntBitNot:
return r.math_IntBitNot(functionCall.Arguments)
case parsers.FunctionCallMathLog10:
return r.math_Log10(functionCall.Arguments)
case parsers.FunctionCallMathRadians:
return r.math_Radians(functionCall.Arguments)
case parsers.FunctionCallMathRound:
return r.math_Round(functionCall.Arguments)
case parsers.FunctionCallMathSign:
return r.math_Sign(functionCall.Arguments)
case parsers.FunctionCallMathSin:
return r.math_Sin(functionCall.Arguments)
case parsers.FunctionCallMathSqrt:
return r.math_Sqrt(functionCall.Arguments)
case parsers.FunctionCallMathSquare:
return r.math_Square(functionCall.Arguments)
case parsers.FunctionCallMathTan:
return r.math_Tan(functionCall.Arguments)
case parsers.FunctionCallMathTrunc:
return r.math_Trunc(functionCall.Arguments)
case parsers.FunctionCallMathAtn2:
return r.math_Atn2(functionCall.Arguments)
case parsers.FunctionCallMathIntAdd:
return r.math_IntAdd(functionCall.Arguments)
case parsers.FunctionCallMathIntBitAnd:
return r.math_IntBitAnd(functionCall.Arguments)
case parsers.FunctionCallMathIntBitLeftShift:
return r.math_IntBitLeftShift(functionCall.Arguments)
case parsers.FunctionCallMathIntBitOr:
return r.math_IntBitOr(functionCall.Arguments)
case parsers.FunctionCallMathIntBitRightShift:
return r.math_IntBitRightShift(functionCall.Arguments)
case parsers.FunctionCallMathIntBitXor:
return r.math_IntBitXor(functionCall.Arguments)
case parsers.FunctionCallMathIntDiv:
return r.math_IntDiv(functionCall.Arguments)
case parsers.FunctionCallMathIntMod:
return r.math_IntMod(functionCall.Arguments)
case parsers.FunctionCallMathIntMul:
return r.math_IntMul(functionCall.Arguments)
case parsers.FunctionCallMathIntSub:
return r.math_IntSub(functionCall.Arguments)
case parsers.FunctionCallMathPower:
return r.math_Power(functionCall.Arguments)
case parsers.FunctionCallMathLog:
return r.math_Log(functionCall.Arguments)
case parsers.FunctionCallMathNumberBin:
return r.math_NumberBin(functionCall.Arguments)
case parsers.FunctionCallMathPi:
return r.math_Pi()
case parsers.FunctionCallMathRand:
return r.math_Rand()
// Aggregate functions
case parsers.FunctionCallAggregateAvg:
return r.aggregate_Avg(functionCall.Arguments)
case parsers.FunctionCallAggregateCount:
return r.aggregate_Count(functionCall.Arguments)
case parsers.FunctionCallAggregateMax:
return r.aggregate_Max(functionCall.Arguments)
case parsers.FunctionCallAggregateMin:
return r.aggregate_Min(functionCall.Arguments)
case parsers.FunctionCallAggregateSum:
return r.aggregate_Sum(functionCall.Arguments)
// Miscellaneous
case parsers.FunctionCallIn:
return r.misc_In(functionCall.Arguments)
}
// No case matched: report the unknown type and resolve to nil.
logger.Errorf("Unknown function call type: %v", functionCall.Type)
return nil
}
// selectItem_SelectItemTypeField resolves a field path against this row.
//
// The first path segment names the table to start from; each further segment
// descends one level, either by map key or, for slices, by numeric index. A
// segment starting with '@' is replaced by the string value of that query
// parameter. The result is nil whenever the path cannot be followed: an empty
// path, a missing or non-string parameter, a non-numeric / negative /
// out-of-range slice index, or a value that is neither a map nor a slice.
func (r rowContext) selectItem_SelectItemTypeField(selectItem parsers.SelectItem) interface{} {
	if len(selectItem.Path) == 0 {
		return nil
	}

	value := r.tables[selectItem.Path[0]]
	for _, pathSegment := range selectItem.Path[1:] {
		// HasPrefix is safe for an empty segment, unlike indexing byte 0.
		if strings.HasPrefix(pathSegment, "@") {
			parameterValue, ok := r.parameters[pathSegment].(string)
			if !ok {
				return nil
			}
			pathSegment = parameterValue
		}

		switch nestedValue := value.(type) {
		case map[string]interface{}:
			value = nestedValue[pathSegment]
		case map[string]RowType:
			value = nestedValue[pathSegment]
		case datastore.Document:
			value = nestedValue[pathSegment]
		case map[string]datastore.Document:
			value = nestedValue[pathSegment]
		case []int, []string, []interface{}:
			slice := reflect.ValueOf(nestedValue)
			arrayIndex, err := strconv.Atoi(pathSegment)
			// reflect.Value.Index panics outside [0, Len), so reject negative
			// indexes as well as ones past the end.
			if err != nil || arrayIndex < 0 || arrayIndex >= slice.Len() {
				return nil
			}
			value = slice.Index(arrayIndex).Interface()
		default:
			return nil
		}
	}

	return value
}
// compareValues orders two dynamically typed values and returns -1, 0 or 1.
//
// nil sorts before everything else and equals only nil. Values whose dynamic
// types differ always compare as 1, regardless of argument order. ints,
// float64s and strings use their natural ordering, false sorts before true,
// and any other type compares as 0 when deeply equal and 1 otherwise.
func compareValues(val1, val2 interface{}) int {
	switch {
	case val1 == nil && val2 == nil:
		return 0
	case val1 == nil:
		return -1
	case val2 == nil:
		return 1
	}

	if reflect.TypeOf(val1) != reflect.TypeOf(val2) {
		return 1
	}

	// From here on both values share the same dynamic type, so asserting val2
	// to the type of val1 cannot fail.
	switch left := val1.(type) {
	case int:
		right := val2.(int)
		switch {
		case left < right:
			return -1
		case left > right:
			return 1
		}
		return 0
	case float64:
		right := val2.(float64)
		switch {
		case left < right:
			return -1
		case left > right:
			return 1
		}
		return 0
	case string:
		return strings.Compare(left, val2.(string))
	case bool:
		right := val2.(bool)
		switch {
		case left == right:
			return 0
		case left:
			return 1
		}
		return -1
	// TODO: Add more types
	default:
		if reflect.DeepEqual(val1, val2) {
			return 0
		}
		return 1
	}
}
// copyMap returns a new map holding the same key/value pairs as originalMap.
// Only the map itself is duplicated; the values are copied by assignment.
func copyMap[T RowType | []RowType](originalMap map[string]T) map[string]T {
	clone := make(map[string]T, len(originalMap))
	for key := range originalMap {
		clone[key] = originalMap[key]
	}
	return clone
}

View File

@ -0,0 +1,36 @@
package memoryexecutor
import "github.com/pikami/cosmium/internal/datastore"
// distinctIterator passes through only the first occurrence of each row
// produced by the wrapped iterator (DISTINCT).
type distinctIterator struct {
	documents rowTypeIterator
	seenDocs  []RowType
}

// Next returns the next row that has not been returned before. Once the
// wrapped iterator reports a non-OK status that status is returned, the
// wrapped iterator is released, and later calls return datastore.IterEOF.
func (di *distinctIterator) Next() (RowType, datastore.DataStoreStatus) {
	if di.documents == nil {
		return rowContext{}, datastore.IterEOF
	}

	row, status := di.documents.Next()
	for status == datastore.StatusOk {
		if !di.seen(row) {
			di.seenDocs = append(di.seenDocs, row)
			return row, status
		}
		row, status = di.documents.Next()
	}

	di.documents = nil
	return rowContext{}, status
}

// seen reports whether a row comparing equal to row (per compareValues) has
// already been returned.
func (di *distinctIterator) seen(row RowType) bool {
	for index := range di.seenDocs {
		if compareValues(di.seenDocs[index], row) == 0 {
			return true
		}
	}
	return false
}

View File

@ -0,0 +1,143 @@
package memoryexecutor
import (
"github.com/pikami/cosmium/internal/datastore"
"github.com/pikami/cosmium/internal/logger"
"github.com/pikami/cosmium/parsers"
)
// filterIterator yields only the rows of the wrapped iterator that satisfy
// filters (WHERE).
type filterIterator struct {
	documents rowIterator
	filters   interface{}
}

// Next returns the next row accepted by the filter. Once the wrapped iterator
// reports a non-OK status that status is returned, the wrapped iterator is
// released, and later calls return datastore.IterEOF.
func (fi *filterIterator) Next() (rowContext, datastore.DataStoreStatus) {
	if fi.documents == nil {
		return rowContext{}, datastore.IterEOF
	}

	for {
		row, status := fi.documents.Next()
		if status != datastore.StatusOk {
			fi.documents = nil
			return rowContext{}, status
		}

		if fi.evaluateFilters(row) {
			return row, status
		}
	}
}

// evaluateFilters reports whether row satisfies the iterator's filters. It
// delegates to rowContext.applyFilters so that top-level filters and nested
// logical sub-expressions are always evaluated by the same code.
func (fi *filterIterator) evaluateFilters(row rowContext) bool {
	return row.applyFilters(fi.filters)
}
// applyFilters reports whether this row satisfies filters. A nil filter
// accepts every row. Comparison and logical expressions are evaluated
// recursively, a constant must be a bool and is used as-is, and a select item
// must resolve to a bool (negated when the item is marked Invert). Any other
// filter, or a non-bool result, rejects the row.
func (r rowContext) applyFilters(filters interface{}) bool {
	if filters == nil {
		return true
	}

	switch expression := filters.(type) {
	case parsers.ComparisonExpression:
		return r.filters_ComparisonExpression(expression)
	case parsers.LogicalExpression:
		return r.filters_LogicalExpression(expression)
	case parsers.Constant:
		literal, isBool := expression.Value.(bool)
		return isBool && literal
	case parsers.SelectItem:
		resolved, isBool := r.resolveSelectItem(expression).(bool)
		if !isBool {
			return false
		}
		// Equivalent to negating resolved when Invert is set.
		return resolved != expression.Invert
	}

	return false
}
// filters_ComparisonExpression resolves both sides of expression against this
// row, compares them with compareValues and applies the operator. It returns
// false when either side is not a parsers.SelectItem or the operator is not
// one of =, !=, <, >, <=, >=.
func (r rowContext) filters_ComparisonExpression(expression parsers.ComparisonExpression) bool {
	left, leftOk := expression.Left.(parsers.SelectItem)
	right, rightOk := expression.Right.(parsers.SelectItem)
	if !(leftOk && rightOk) {
		logger.ErrorLn("ComparisonExpression has incorrect Left or Right type")
		return false
	}

	// Arguments are evaluated left to right, so the left side resolves first.
	ordering := compareValues(r.resolveSelectItem(left), r.resolveSelectItem(right))

	switch expression.Operation {
	case "=":
		return ordering == 0
	case "!=":
		return ordering != 0
	case "<":
		return ordering < 0
	case ">":
		return ordering > 0
	case "<=":
		return ordering <= 0
	case ">=":
		return ordering >= 0
	default:
		return false
	}
}
// filters_LogicalExpression combines the sub-expressions of expression with
// AND or OR, short-circuiting as soon as the outcome is decided. With no
// sub-expressions the result is false; with an operation that is neither AND
// nor OR every sub-expression is still evaluated and the result is that of
// the first one.
func (r rowContext) filters_LogicalExpression(expression parsers.LogicalExpression) bool {
	outcome := false
	for index, operand := range expression.Expressions {
		matched := r.applyFilters(operand)
		if index == 0 {
			outcome = matched
		}

		if expression.Operation == parsers.LogicalExpressionTypeAnd {
			outcome = outcome && matched
			if !outcome {
				return false
			}
		} else if expression.Operation == parsers.LogicalExpressionTypeOr {
			outcome = outcome || matched
			if outcome {
				return true
			}
		}
	}
	return outcome
}

View File

@ -0,0 +1,73 @@
package memoryexecutor
import (
"github.com/pikami/cosmium/internal/datastore"
"github.com/pikami/cosmium/parsers"
)
// fromIterator applies the FROM clause to each row of the wrapped iterator.
// A single input row can expand into several output rows; those are held in
// buffer and handed out one per call.
type fromIterator struct {
	documents   rowIterator
	table       parsers.Table
	buffer      []rowContext
	bufferIndex int
}

// Next returns the next row produced by the FROM clause. Once the wrapped
// iterator reports a non-OK status that status is returned together with the
// row it came with, the wrapped iterator is released, and later calls return
// datastore.IterEOF.
func (fi *fromIterator) Next() (rowContext, datastore.DataStoreStatus) {
	if fi.documents == nil {
		return rowContext{}, datastore.IterEOF
	}

	for {
		// Hand out rows left over from a previous expansion first.
		if fi.bufferIndex < len(fi.buffer) {
			buffered := fi.buffer[fi.bufferIndex]
			fi.buffer[fi.bufferIndex] = rowContext{}
			fi.bufferIndex++
			return buffered, datastore.StatusOk
		}

		// Buffer is empty: pull the next input row.
		row, status := fi.documents.Next()
		if status != datastore.StatusOk {
			fi.documents = nil
			return row, status
		}

		isSubQuery := fi.table.SelectItem.Type == parsers.SelectItemTypeSubQuery
		if fi.table.SelectItem.Path == nil && !isSubQuery {
			// Plain table reference: the row passes through untouched.
			return row, status
		}

		destinationTableName := fi.table.SelectItem.Alias
		if destinationTableName == "" {
			destinationTableName = fi.table.Value
		}
		if destinationTableName == "" {
			destinationTableName = resolveDestinationColumnName(fi.table.SelectItem, 0, row.parameters)
		}

		if fi.table.IsInSelect || isSubQuery {
			// One output row per element, each with its own copy of the
			// tables map so the rows do not share bindings.
			elements := row.parseArray(fi.table.SelectItem)
			expanded := make([]rowContext, len(elements))
			for i, element := range elements {
				expanded[i].parameters = row.parameters
				expanded[i].tables = copyMap(row.tables)
				expanded[i].tables[destinationTableName] = element
			}

			fi.buffer = expanded
			fi.bufferIndex = 0
			// Loop again: serve from the new buffer, or move on to the next
			// input row if the expansion was empty.
			continue
		}

		if len(fi.table.SelectItem.Path) > 0 {
			sourceTableName := fi.table.SelectItem.Path[0]
			if row.tables[sourceTableName] == nil {
				// When source table is not found, assume it's root document
				row.tables[sourceTableName] = row.tables["$root"]
			}
		}

		row.tables[destinationTableName] = row.resolveSelectItem(fi.table.SelectItem)
		return row, status
	}
}

View File

@ -0,0 +1,69 @@
package memoryexecutor
import (
"fmt"
"strings"
"github.com/pikami/cosmium/internal/datastore"
"github.com/pikami/cosmium/parsers"
)
// groupByIterator groups the rows of the wrapped iterator by the GROUP BY
// expressions and yields one row per group, in order of first appearance.
type groupByIterator struct {
	documents   rowIterator
	groupBy     []parsers.SelectItem
	groupedRows []rowContext
}

// Next returns the next group. The first call drains the wrapped iterator
// (stopping at the first non-OK status), builds all groups and releases the
// wrapped iterator; every call then pops one group until none are left, after
// which datastore.IterEOF is returned.
func (gi *groupByIterator) Next() (rowContext, datastore.DataStoreStatus) {
	if gi.groupedRows == nil {
		// Phase 1: collect every input row.
		collected := make([]rowContext, 0)
		for {
			row, status := gi.documents.Next()
			if status != datastore.StatusOk {
				break
			}
			collected = append(collected, row)
		}
		gi.documents = nil

		// Phase 2: bucket rows by key, remembering first-seen key order.
		buckets := make(map[string][]rowContext)
		keyOrder := make([]string, 0)
		for _, row := range collected {
			key := row.generateGroupByKey(gi.groupBy)
			if _, known := buckets[key]; !known {
				keyOrder = append(keyOrder, key)
			}
			buckets[key] = append(buckets[key], row)
		}

		// Phase 3: one representative row per group. The slice must be
		// non-nil even when empty so this branch runs only once.
		groups := make([]rowContext, 0)
		for _, key := range keyOrder {
			members := buckets[key]
			groups = append(groups, rowContext{
				tables:       members[0].tables,
				parameters:   members[0].parameters,
				grouppedRows: members,
			})
		}
		gi.groupedRows = groups
	}

	if len(gi.groupedRows) == 0 {
		return rowContext{}, datastore.IterEOF
	}

	next := gi.groupedRows[0]
	gi.groupedRows = gi.groupedRows[1:]
	return next, datastore.StatusOk
}
// generateGroupByKey builds the string used to bucket this row for GROUP BY:
// the resolved value of every groupBy item, each followed by ':'.
//
// Values are rendered with %#v rather than %v so that strings are quoted and
// escaped. That keeps a string such as "1" distinct from the number 1, and
// stops a ':' inside a string from being mistaken for the separator, both of
// which previously merged unrelated rows into the same group.
func (r rowContext) generateGroupByKey(groupBy []parsers.SelectItem) string {
	var keyBuilder strings.Builder
	for _, selectItem := range groupBy {
		value := r.resolveSelectItem(selectItem)
		fmt.Fprintf(&keyBuilder, "%#v:", value)
	}

	return keyBuilder.String()
}

View File

@ -0,0 +1,62 @@
package memoryexecutor
import (
"github.com/pikami/cosmium/internal/datastore"
"github.com/pikami/cosmium/parsers"
)
// joinIterator expands each row of the wrapped iterator through the JOIN
// items of query and yields the resulting rows one at a time.
type joinIterator struct {
	documents rowIterator
	query     parsers.SelectStmt
	buffer    []rowContext
}

// Next returns the next joined row. Input rows whose joins produce nothing
// are skipped. Once the wrapped iterator reports a non-OK status that status
// is returned, the wrapped iterator is released, and later calls return
// datastore.IterEOF.
func (ji *joinIterator) Next() (rowContext, datastore.DataStoreStatus) {
	if ji.documents == nil {
		return rowContext{}, datastore.IterEOF
	}

	// Refill the buffer until it holds at least one joined row.
	for len(ji.buffer) == 0 {
		doc, status := ji.documents.Next()
		if status != datastore.StatusOk {
			ji.documents = nil
			return rowContext{}, status
		}

		// Each JOIN item multiplies the rows produced by the previous one.
		ji.buffer = []rowContext{doc}
		for _, joinItem := range ji.query.JoinItems {
			expanded := make([]rowContext, 0)
			for _, row := range ji.buffer {
				for _, joinedItem := range row.resolveJoinItemSelect(joinItem.SelectItem) {
					tables := copyMap(row.tables)
					tables[joinItem.Table.Value] = joinedItem
					expanded = append(expanded, rowContext{
						parameters: row.parameters,
						tables:     tables,
					})
				}
			}
			ji.buffer = expanded
		}
	}

	head := ji.buffer[0]
	ji.buffer = ji.buffer[1:]
	return head, datastore.StatusOk
}
// resolveJoinItemSelect returns the elements a JOIN item expands to for this
// row. Items that have neither a path nor a sub-query expand to nothing.
func (r rowContext) resolveJoinItemSelect(selectItem parsers.SelectItem) []RowType {
	if selectItem.Path == nil && selectItem.Type != parsers.SelectItemTypeSubQuery {
		return []RowType{}
	}

	elements := r.parseArray(selectItem)
	rows := make([]RowType, 0, len(elements))
	for _, element := range elements {
		rows = append(rows, element)
	}
	return rows
}

View File

@ -0,0 +1,19 @@
package memoryexecutor
import "github.com/pikami/cosmium/internal/datastore"
// limitIterator caps the number of reads made from the wrapped iterator.
type limitIterator struct {
	documents rowTypeIterator
	limit     int
	count     int
}

// Next forwards to the wrapped iterator for the first limit calls. After
// that it releases the wrapped iterator and returns datastore.IterEOF.
func (li *limitIterator) Next() (RowType, datastore.DataStoreStatus) {
	if li.count < li.limit {
		li.count++
		return li.documents.Next()
	}

	li.documents = nil
	return rowContext{}, datastore.IterEOF
}

View File

@ -1,752 +1,92 @@
package memoryexecutor
import (
"fmt"
"reflect"
"sort"
"strconv"
"strings"
"github.com/pikami/cosmium/internal/logger"
"github.com/pikami/cosmium/internal/datastore"
"github.com/pikami/cosmium/parsers"
"golang.org/x/exp/slices"
)
type RowType interface{}
type rowContext struct {
tables map[string]RowType
parameters map[string]interface{}
grouppedRows []rowContext
// ExecuteQuery runs query over documents and collects every resulting row
// into a slice. Collection stops at the first status other than
// datastore.StatusOk; the returned slice is never nil.
func ExecuteQuery(query parsers.SelectStmt, documents rowTypeIterator) []RowType {
	source := &rowTypeToRowContextIterator{documents: documents, query: query}
	resultIter := executeQuery(query, source)

	result := make([]RowType, 0)
	for row, status := resultIter.Next(); status == datastore.StatusOk; row, status = resultIter.Next() {
		result = append(result, row)
	}
	return result
}
func ExecuteQuery(query parsers.SelectStmt, documents []RowType) []RowType {
currentDocuments := make([]rowContext, 0)
for _, doc := range documents {
currentDocuments = append(currentDocuments, resolveFrom(query, doc)...)
func executeQuery(query parsers.SelectStmt, documents rowIterator) rowTypeIterator {
// Resolve FROM
var iter rowIterator = &fromIterator{
documents: documents,
table: query.Table,
}
// Handle JOINS
nextDocuments := make([]rowContext, 0)
for _, currentDocument := range currentDocuments {
rowContexts := currentDocument.handleJoin(query)
nextDocuments = append(nextDocuments, rowContexts...)
}
currentDocuments = nextDocuments
// Apply filters
nextDocuments = make([]rowContext, 0)
for _, currentDocument := range currentDocuments {
if currentDocument.applyFilters(query.Filters) {
nextDocuments = append(nextDocuments, currentDocument)
// Apply JOIN
if len(query.JoinItems) > 0 {
iter = &joinIterator{
documents: iter,
query: query,
}
}
currentDocuments = nextDocuments
// Apply order
// Apply WHERE
if query.Filters != nil {
iter = &filterIterator{
documents: iter,
filters: query.Filters,
}
}
// Apply ORDER BY
if len(query.OrderExpressions) > 0 {
applyOrder(currentDocuments, query.OrderExpressions)
iter = &orderIterator{
documents: iter,
orderExpressions: query.OrderExpressions,
}
}
// Apply group by
// Apply GROUP BY
if len(query.GroupBy) > 0 {
currentDocuments = applyGroupBy(currentDocuments, query.GroupBy)
iter = &groupByIterator{
documents: iter,
groupBy: query.GroupBy,
}
}
// Apply select
projectedDocuments := applyProjection(currentDocuments, query.SelectItems, query.GroupBy)
// Apply SELECT
var projectedIterator rowTypeIterator = &projectIterator{
documents: iter,
selectItems: query.SelectItems,
groupBy: query.GroupBy,
}
// Apply distinct
// Apply DISTINCT
if query.Distinct {
projectedDocuments = deduplicate(projectedDocuments)
projectedIterator = &distinctIterator{
documents: projectedIterator,
}
}
// Apply offset
// Apply OFFSET
if query.Offset > 0 {
if query.Offset < len(projectedDocuments) {
projectedDocuments = projectedDocuments[query.Offset:]
} else {
projectedDocuments = []RowType{}
projectedIterator = &offsetIterator{
documents: projectedIterator,
offset: query.Offset,
}
}
// Apply result limit
if query.Count > 0 && len(projectedDocuments) > query.Count {
projectedDocuments = projectedDocuments[:query.Count]
// Apply LIMIT
if query.Count > 0 {
projectedIterator = &limitIterator{
documents: projectedIterator,
limit: query.Count,
}
}
return projectedDocuments
}
// resolveFrom applies the FROM clause of query to a single input value and
// returns the row contexts it produces (zero or more).
//
// doc is either a raw document or, when it already is a rowContext, the
// context handed down by a caller; in the latter case it is used as the
// starting row unchanged.
func resolveFrom(query parsers.SelectStmt, doc RowType) []rowContext {
initialRow, gotParentContext := doc.(rowContext)
if !gotParentContext {
// Raw document: build the initial context. The table name comes from the
// sub-query's table, then the query's table, then the select item.
var initialTableName string
if query.Table.SelectItem.Type == parsers.SelectItemTypeSubQuery {
initialTableName = query.Table.SelectItem.Value.(parsers.SelectStmt).Table.Value
}
if initialTableName == "" {
initialTableName = query.Table.Value
}
if initialTableName == "" {
initialTableName = resolveDestinationColumnName(query.Table.SelectItem, 0, query.Parameters)
}
initialRow = rowContext{
parameters: query.Parameters,
tables: map[string]RowType{
initialTableName: doc,
// The document is also reachable under the fixed "$root" name.
"$root": doc,
},
}
}
if query.Table.SelectItem.Path != nil || query.Table.SelectItem.Type == parsers.SelectItemTypeSubQuery {
// Name under which the FROM result is bound: alias, then table value,
// then the name derived from the select item.
destinationTableName := query.Table.SelectItem.Alias
if destinationTableName == "" {
destinationTableName = query.Table.Value
}
if destinationTableName == "" {
destinationTableName = resolveDestinationColumnName(query.Table.SelectItem, 0, initialRow.parameters)
}
if query.Table.IsInSelect || query.Table.SelectItem.Type == parsers.SelectItemTypeSubQuery {
// One row context per element, each with its own copy of the tables map.
selectValue := initialRow.parseArray(query.Table.SelectItem)
rowContexts := make([]rowContext, len(selectValue))
for i, newRowData := range selectValue {
rowContexts[i].parameters = initialRow.parameters
rowContexts[i].tables = copyMap(initialRow.tables)
rowContexts[i].tables[destinationTableName] = newRowData
}
return rowContexts
}
if len(query.Table.SelectItem.Path) > 0 {
sourceTableName := query.Table.SelectItem.Path[0]
sourceTableData := initialRow.tables[sourceTableName]
if sourceTableData == nil {
// When source table is not found, assume it's root document
initialRow.tables[sourceTableName] = initialRow.tables["$root"]
}
}
// Bind the resolved value in place; this writes to initialRow's tables map.
newRowData := initialRow.resolveSelectItem(query.Table.SelectItem)
initialRow.tables[destinationTableName] = newRowData
return []rowContext{initialRow}
}
// Plain table reference: the initial row is the only result.
return []rowContext{initialRow}
}
// handleJoin expands this row through every JOIN item of query and returns
// the resulting rows. Each JOIN item multiplies the rows produced by the
// previous one; with no JOIN items the row itself is the only result.
func (r rowContext) handleJoin(query parsers.SelectStmt) []rowContext {
	rows := []rowContext{r}

	for _, joinItem := range query.JoinItems {
		expanded := make([]rowContext, 0)
		for _, row := range rows {
			for _, joinedItem := range row.resolveJoinItemSelect(joinItem.SelectItem) {
				tables := copyMap(row.tables)
				tables[joinItem.Table.Value] = joinedItem
				expanded = append(expanded, rowContext{
					parameters: row.parameters,
					tables:     tables,
				})
			}
		}
		rows = expanded
	}

	return rows
}
// resolveJoinItemSelect returns the elements a JOIN item expands to for this
// row. Items that have neither a path nor a sub-query expand to nothing.
func (r rowContext) resolveJoinItemSelect(selectItem parsers.SelectItem) []RowType {
	if selectItem.Path == nil && selectItem.Type != parsers.SelectItemTypeSubQuery {
		return []RowType{}
	}

	elements := r.parseArray(selectItem)
	rows := make([]RowType, 0, len(elements))
	for _, element := range elements {
		rows = append(rows, element)
	}
	return rows
}
// applyFilters reports whether this row satisfies filters. A nil filter
// accepts every row. Comparison and logical expressions are evaluated
// recursively, a constant must be a bool and is used as-is, and a select item
// must resolve to a bool (negated when the item is marked Invert). Any other
// filter, or a non-bool result, rejects the row.
func (r rowContext) applyFilters(filters interface{}) bool {
	if filters == nil {
		return true
	}

	switch expression := filters.(type) {
	case parsers.ComparisonExpression:
		return r.filters_ComparisonExpression(expression)
	case parsers.LogicalExpression:
		return r.filters_LogicalExpression(expression)
	case parsers.Constant:
		literal, isBool := expression.Value.(bool)
		return isBool && literal
	case parsers.SelectItem:
		resolved, isBool := r.resolveSelectItem(expression).(bool)
		if !isBool {
			return false
		}
		// Equivalent to negating resolved when Invert is set.
		return resolved != expression.Invert
	}

	return false
}
// filters_ComparisonExpression resolves both sides of expression against this
// row, compares them with compareValues and applies the operator. It returns
// false when either side is not a parsers.SelectItem or the operator is not
// one of =, !=, <, >, <=, >=.
func (r rowContext) filters_ComparisonExpression(expression parsers.ComparisonExpression) bool {
	left, leftOk := expression.Left.(parsers.SelectItem)
	right, rightOk := expression.Right.(parsers.SelectItem)
	if !(leftOk && rightOk) {
		logger.ErrorLn("ComparisonExpression has incorrect Left or Right type")
		return false
	}

	// Arguments are evaluated left to right, so the left side resolves first.
	ordering := compareValues(r.resolveSelectItem(left), r.resolveSelectItem(right))

	switch expression.Operation {
	case "=":
		return ordering == 0
	case "!=":
		return ordering != 0
	case "<":
		return ordering < 0
	case ">":
		return ordering > 0
	case "<=":
		return ordering <= 0
	case ">=":
		return ordering >= 0
	default:
		return false
	}
}
// filters_LogicalExpression combines the sub-expressions of expression with
// AND or OR, short-circuiting as soon as the outcome is decided. With no
// sub-expressions the result is false; with an operation that is neither AND
// nor OR every sub-expression is still evaluated and the result is that of
// the first one.
func (r rowContext) filters_LogicalExpression(expression parsers.LogicalExpression) bool {
	outcome := false
	for index, operand := range expression.Expressions {
		matched := r.applyFilters(operand)
		if index == 0 {
			outcome = matched
		}

		if expression.Operation == parsers.LogicalExpressionTypeAnd {
			outcome = outcome && matched
			if !outcome {
				return false
			}
		} else if expression.Operation == parsers.LogicalExpressionTypeOr {
			outcome = outcome || matched
			if outcome {
				return true
			}
		}
	}
	return outcome
}
// applyOrder sorts documents in place according to orderExpressions.
//
// Expressions are applied in sequence: the first one whose resolved values
// differ decides the order, honouring its direction. Rows that compare equal
// on every expression keep their relative input order, which is guaranteed by
// the stable sort itself.
func applyOrder(documents []rowContext, orderExpressions []parsers.OrderExpression) {
	sort.SliceStable(documents, func(i, j int) bool {
		for _, order := range orderExpressions {
			left := documents[i].resolveSelectItem(order.SelectItem)
			right := documents[j].resolveSelectItem(order.SelectItem)

			cmp := compareValues(left, right)
			if cmp == 0 {
				continue
			}
			if order.Direction == parsers.OrderDirectionDesc {
				return cmp > 0
			}
			return cmp < 0
		}
		// Ties must report "not less". The indices passed in are current
		// positions during the sort, not original positions, so they cannot
		// serve as a tie-breaker.
		return false
	})
}
// applyGroupBy partitions documents by their GROUP BY key and returns one
// rowContext per group, in order of first appearance of each key. Each
// returned context exposes the tables and parameters of the group's first
// row and carries all member rows in grouppedRows.
func applyGroupBy(documents []rowContext, groupBy []parsers.SelectItem) []rowContext {
	buckets := make(map[string][]rowContext)
	keyOrder := make([]string, 0)

	for _, document := range documents {
		bucketKey := document.generateGroupByKey(groupBy)
		members, seen := buckets[bucketKey]
		if !seen {
			keyOrder = append(keyOrder, bucketKey)
		}
		buckets[bucketKey] = append(members, document)
	}

	result := make([]rowContext, 0, len(keyOrder))
	for _, bucketKey := range keyOrder {
		members := buckets[bucketKey]
		result = append(result, rowContext{
			tables:       members[0].tables,
			parameters:   members[0].parameters,
			grouppedRows: members,
		})
	}
	return result
}
// generateGroupByKey builds the string key that identifies the group this row
// belongs to: the resolved value of every GROUP BY item, each followed by ":".
//
// Values are rendered with %#v rather than %v so that strings are quoted.
// That keeps a string such as "1" distinct from the number 1, and prevents a
// ":" inside a string value from being mistaken for the separator (with %v,
// the pairs ("a:", "b") and ("a", ":b") produced the same key).
func (r rowContext) generateGroupByKey(groupBy []parsers.SelectItem) string {
	var keyBuilder strings.Builder
	for _, selectItem := range groupBy {
		value := r.resolveSelectItem(selectItem)
		fmt.Fprintf(&keyBuilder, "%#v:", value)
	}
	return keyBuilder.String()
}
// applyProjection turns row contexts into the final result rows.
//
// With no input rows the result is an empty, non-nil slice. When the select
// list contains an aggregate function and there is no GROUP BY clause, all
// rows are folded into a single group and exactly one row is produced.
// Otherwise every input row is projected individually.
func applyProjection(documents []rowContext, selectItems []parsers.SelectItem, groupBy []parsers.SelectItem) []RowType {
	if len(documents) == 0 {
		return []RowType{}
	}

	aggregateEverything := hasAggregateFunctions(selectItems) && len(groupBy) == 0
	if aggregateEverything {
		// The first row lends its tables and parameters to the implicit group.
		combined := rowContext{
			tables:       documents[0].tables,
			parameters:   documents[0].parameters,
			grouppedRows: documents,
		}
		return []RowType{combined.applyProjection(selectItems)}
	}

	projected := make([]RowType, 0, len(documents))
	for _, document := range documents {
		projected = append(projected, document.applyProjection(selectItems))
	}
	return projected
}
// applyProjection builds the output row for this context. If the first select
// item is flagged as top level, its resolved value is returned directly.
// Otherwise a map is returned that holds one entry per select item, keyed by
// the item's destination column name.
func (r rowContext) applyProjection(selectItems []parsers.SelectItem) RowType {
	if len(selectItems) != 0 {
		if first := selectItems[0]; first.IsTopLevel {
			return r.resolveSelectItem(first)
		}
	}

	projected := make(map[string]interface{}, len(selectItems))
	for position := range selectItems {
		column := resolveDestinationColumnName(selectItems[position], position, r.parameters)
		projected[column] = r.resolveSelectItem(selectItems[position])
	}
	return projected
}
// resolveDestinationColumnName determines the output column name for a select
// item.
//
// Precedence: an explicit alias wins; otherwise the last path segment is
// used; otherwise the positional name "$N" (1-based) is generated. A name
// starting with '@' is treated as a query parameter reference and replaced by
// that parameter's string value.
//
// If the referenced parameter is missing or is not a string, the "@name"
// text is kept as the column name instead of panicking. An empty last path
// segment is returned as is.
func resolveDestinationColumnName(selectItem parsers.SelectItem, itemIndex int, queryParameters map[string]interface{}) string {
	if selectItem.Alias != "" {
		return selectItem.Alias
	}

	destinationName := fmt.Sprintf("$%d", itemIndex+1)
	if len(selectItem.Path) > 0 {
		destinationName = selectItem.Path[len(selectItem.Path)-1]
	}

	// Guard the index: a path segment may be the empty string.
	if len(destinationName) > 0 && destinationName[0] == '@' {
		if resolved, ok := queryParameters[destinationName].(string); ok {
			destinationName = resolved
		}
	}

	return destinationName
}
// resolveSelectItem evaluates a select item against the row by dispatching on
// the item's type. Items of any type not listed here are resolved as field
// paths. A function-call item whose Value is not a parsers.FunctionCall is
// logged and resolves to nil.
func (r rowContext) resolveSelectItem(selectItem parsers.SelectItem) interface{} {
	switch selectItem.Type {
	case parsers.SelectItemTypeArray:
		return r.selectItem_SelectItemTypeArray(selectItem)
	case parsers.SelectItemTypeObject:
		return r.selectItem_SelectItemTypeObject(selectItem)
	case parsers.SelectItemTypeConstant:
		return r.selectItem_SelectItemTypeConstant(selectItem)
	case parsers.SelectItemTypeSubQuery:
		return r.selectItem_SelectItemTypeSubQuery(selectItem)
	case parsers.SelectItemTypeFunctionCall:
		functionCall, isFunctionCall := selectItem.Value.(parsers.FunctionCall)
		if !isFunctionCall {
			logger.ErrorLn("parsers.SelectItem has incorrect Value type (expected parsers.FunctionCall)")
			return nil
		}
		return r.selectItem_SelectItemTypeFunctionCall(functionCall)
	default:
		return r.selectItem_SelectItemTypeField(selectItem)
	}
}
// selectItem_SelectItemTypeArray resolves each nested select item and returns
// the results as a slice in the same order. The slice is never nil.
func (r rowContext) selectItem_SelectItemTypeArray(selectItem parsers.SelectItem) interface{} {
	elements := make([]interface{}, len(selectItem.SelectItems))
	for position, element := range selectItem.SelectItems {
		elements[position] = r.resolveSelectItem(element)
	}
	return elements
}
// selectItem_SelectItemTypeObject resolves each nested select item and
// returns a map that stores every result under the nested item's alias.
func (r rowContext) selectItem_SelectItemTypeObject(selectItem parsers.SelectItem) interface{} {
	fields := make(map[string]interface{}, len(selectItem.SelectItems))
	for _, field := range selectItem.SelectItems {
		resolved := r.resolveSelectItem(field)
		fields[field.Alias] = resolved
	}
	return fields
}
// selectItem_SelectItemTypeConstant resolves a constant select item. A
// parameter constant whose value is a string is looked up in the row's query
// parameters (when a parameter map is present); every other constant resolves
// to its own value. If the item does not hold a parsers.Constant, the problem
// is logged and nil is returned.
func (r rowContext) selectItem_SelectItemTypeConstant(selectItem parsers.SelectItem) interface{} {
	constant, isConstant := selectItem.Value.(parsers.Constant)
	if !isConstant {
		// TODO: Handle error
		logger.ErrorLn("parsers.Constant has incorrect Value type")
		return nil
	}

	isParameter := constant.Type == parsers.ConstantTypeParameterConstant
	if isParameter && r.parameters != nil {
		if name, isString := constant.Value.(string); isString {
			return r.parameters[name]
		}
	}

	return constant.Value
}
// selectItem_SelectItemTypeSubQuery runs the sub-query stored in selectItem
// with the current row as its only input.
//
// For a sub-query flagged as Exists the result is a bool telling whether the
// sub-query produced at least one row; otherwise the sub-query's result is
// returned as is. If the item does not hold a parsers.SelectStmt, the problem
// is logged and nil is returned, mirroring how malformed function-call items
// are handled, instead of panicking on the type assertion.
func (r rowContext) selectItem_SelectItemTypeSubQuery(selectItem parsers.SelectItem) interface{} {
	subQuery, ok := selectItem.Value.(parsers.SelectStmt)
	if !ok {
		logger.ErrorLn("parsers.SelectItem has incorrect Value type (expected parsers.SelectStmt)")
		return nil
	}

	subQueryResult := ExecuteQuery(
		subQuery,
		[]RowType{r},
	)

	if subQuery.Exists {
		return len(subQueryResult) > 0
	}

	return subQueryResult
}
// selectItem_SelectItemTypeFunctionCall evaluates a built-in function call
// against the row by dispatching on functionCall.Type to the matching
// rowContext method. All handlers except math_Pi and math_Rand receive the
// call's argument list. An unrecognised function type is logged and
// evaluates to nil.
func (r rowContext) selectItem_SelectItemTypeFunctionCall(functionCall parsers.FunctionCall) interface{} {
switch functionCall.Type {
// String functions.
case parsers.FunctionCallStringEquals:
return r.strings_StringEquals(functionCall.Arguments)
case parsers.FunctionCallContains:
return r.strings_Contains(functionCall.Arguments)
case parsers.FunctionCallEndsWith:
return r.strings_EndsWith(functionCall.Arguments)
case parsers.FunctionCallStartsWith:
return r.strings_StartsWith(functionCall.Arguments)
case parsers.FunctionCallConcat:
return r.strings_Concat(functionCall.Arguments)
case parsers.FunctionCallIndexOf:
return r.strings_IndexOf(functionCall.Arguments)
case parsers.FunctionCallToString:
return r.strings_ToString(functionCall.Arguments)
case parsers.FunctionCallUpper:
return r.strings_Upper(functionCall.Arguments)
case parsers.FunctionCallLower:
return r.strings_Lower(functionCall.Arguments)
case parsers.FunctionCallLeft:
return r.strings_Left(functionCall.Arguments)
case parsers.FunctionCallLength:
return r.strings_Length(functionCall.Arguments)
case parsers.FunctionCallLTrim:
return r.strings_LTrim(functionCall.Arguments)
case parsers.FunctionCallReplace:
return r.strings_Replace(functionCall.Arguments)
case parsers.FunctionCallReplicate:
return r.strings_Replicate(functionCall.Arguments)
case parsers.FunctionCallReverse:
return r.strings_Reverse(functionCall.Arguments)
case parsers.FunctionCallRight:
return r.strings_Right(functionCall.Arguments)
case parsers.FunctionCallRTrim:
return r.strings_RTrim(functionCall.Arguments)
case parsers.FunctionCallSubstring:
return r.strings_Substring(functionCall.Arguments)
case parsers.FunctionCallTrim:
return r.strings_Trim(functionCall.Arguments)
// Type-checking functions.
case parsers.FunctionCallIsDefined:
return r.typeChecking_IsDefined(functionCall.Arguments)
case parsers.FunctionCallIsArray:
return r.typeChecking_IsArray(functionCall.Arguments)
case parsers.FunctionCallIsBool:
return r.typeChecking_IsBool(functionCall.Arguments)
case parsers.FunctionCallIsFiniteNumber:
return r.typeChecking_IsFiniteNumber(functionCall.Arguments)
case parsers.FunctionCallIsInteger:
return r.typeChecking_IsInteger(functionCall.Arguments)
case parsers.FunctionCallIsNull:
return r.typeChecking_IsNull(functionCall.Arguments)
case parsers.FunctionCallIsNumber:
return r.typeChecking_IsNumber(functionCall.Arguments)
case parsers.FunctionCallIsObject:
return r.typeChecking_IsObject(functionCall.Arguments)
case parsers.FunctionCallIsPrimitive:
return r.typeChecking_IsPrimitive(functionCall.Arguments)
case parsers.FunctionCallIsString:
return r.typeChecking_IsString(functionCall.Arguments)
// Array and set functions.
case parsers.FunctionCallArrayConcat:
return r.array_Concat(functionCall.Arguments)
case parsers.FunctionCallArrayContains:
return r.array_Contains(functionCall.Arguments)
case parsers.FunctionCallArrayContainsAny:
return r.array_Contains_Any(functionCall.Arguments)
case parsers.FunctionCallArrayContainsAll:
return r.array_Contains_All(functionCall.Arguments)
case parsers.FunctionCallArrayLength:
return r.array_Length(functionCall.Arguments)
case parsers.FunctionCallArraySlice:
return r.array_Slice(functionCall.Arguments)
case parsers.FunctionCallSetIntersect:
return r.set_Intersect(functionCall.Arguments)
case parsers.FunctionCallSetUnion:
return r.set_Union(functionCall.Arguments)
// Math functions.
case parsers.FunctionCallMathAbs:
return r.math_Abs(functionCall.Arguments)
case parsers.FunctionCallMathAcos:
return r.math_Acos(functionCall.Arguments)
case parsers.FunctionCallMathAsin:
return r.math_Asin(functionCall.Arguments)
case parsers.FunctionCallMathAtan:
return r.math_Atan(functionCall.Arguments)
case parsers.FunctionCallMathCeiling:
return r.math_Ceiling(functionCall.Arguments)
case parsers.FunctionCallMathCos:
return r.math_Cos(functionCall.Arguments)
case parsers.FunctionCallMathCot:
return r.math_Cot(functionCall.Arguments)
case parsers.FunctionCallMathDegrees:
return r.math_Degrees(functionCall.Arguments)
case parsers.FunctionCallMathExp:
return r.math_Exp(functionCall.Arguments)
case parsers.FunctionCallMathFloor:
return r.math_Floor(functionCall.Arguments)
case parsers.FunctionCallMathIntBitNot:
return r.math_IntBitNot(functionCall.Arguments)
case parsers.FunctionCallMathLog10:
return r.math_Log10(functionCall.Arguments)
case parsers.FunctionCallMathRadians:
return r.math_Radians(functionCall.Arguments)
case parsers.FunctionCallMathRound:
return r.math_Round(functionCall.Arguments)
case parsers.FunctionCallMathSign:
return r.math_Sign(functionCall.Arguments)
case parsers.FunctionCallMathSin:
return r.math_Sin(functionCall.Arguments)
case parsers.FunctionCallMathSqrt:
return r.math_Sqrt(functionCall.Arguments)
case parsers.FunctionCallMathSquare:
return r.math_Square(functionCall.Arguments)
case parsers.FunctionCallMathTan:
return r.math_Tan(functionCall.Arguments)
case parsers.FunctionCallMathTrunc:
return r.math_Trunc(functionCall.Arguments)
case parsers.FunctionCallMathAtn2:
return r.math_Atn2(functionCall.Arguments)
case parsers.FunctionCallMathIntAdd:
return r.math_IntAdd(functionCall.Arguments)
case parsers.FunctionCallMathIntBitAnd:
return r.math_IntBitAnd(functionCall.Arguments)
case parsers.FunctionCallMathIntBitLeftShift:
return r.math_IntBitLeftShift(functionCall.Arguments)
case parsers.FunctionCallMathIntBitOr:
return r.math_IntBitOr(functionCall.Arguments)
case parsers.FunctionCallMathIntBitRightShift:
return r.math_IntBitRightShift(functionCall.Arguments)
case parsers.FunctionCallMathIntBitXor:
return r.math_IntBitXor(functionCall.Arguments)
case parsers.FunctionCallMathIntDiv:
return r.math_IntDiv(functionCall.Arguments)
case parsers.FunctionCallMathIntMod:
return r.math_IntMod(functionCall.Arguments)
case parsers.FunctionCallMathIntMul:
return r.math_IntMul(functionCall.Arguments)
case parsers.FunctionCallMathIntSub:
return r.math_IntSub(functionCall.Arguments)
case parsers.FunctionCallMathPower:
return r.math_Power(functionCall.Arguments)
case parsers.FunctionCallMathLog:
return r.math_Log(functionCall.Arguments)
case parsers.FunctionCallMathNumberBin:
return r.math_NumberBin(functionCall.Arguments)
// These two handlers are called without the argument list.
case parsers.FunctionCallMathPi:
return r.math_Pi()
case parsers.FunctionCallMathRand:
return r.math_Rand()
// Aggregate functions.
case parsers.FunctionCallAggregateAvg:
return r.aggregate_Avg(functionCall.Arguments)
case parsers.FunctionCallAggregateCount:
return r.aggregate_Count(functionCall.Arguments)
case parsers.FunctionCallAggregateMax:
return r.aggregate_Max(functionCall.Arguments)
case parsers.FunctionCallAggregateMin:
return r.aggregate_Min(functionCall.Arguments)
case parsers.FunctionCallAggregateSum:
return r.aggregate_Sum(functionCall.Arguments)
// Miscellaneous.
case parsers.FunctionCallIn:
return r.misc_In(functionCall.Arguments)
}
// No case matched: report the unknown type and resolve to nil.
logger.Errorf("Unknown function call type: %v", functionCall.Type)
return nil
}
// selectItem_SelectItemTypeField resolves a field path against the row.
//
// The first path segment selects a table from r.tables; each following
// segment descends one level into the current value. Maps are indexed by the
// segment as a key, and slices ([]int, []string, []interface{}) by the
// segment parsed as a decimal index. A segment starting with '@' is first
// replaced by the string value of the query parameter with that name.
//
// nil is returned whenever the path cannot be followed: an empty path, a
// parameter that is missing or not a string, a non-numeric or out-of-range
// (including negative) slice index, or a value that cannot be descended into.
func (r rowContext) selectItem_SelectItemTypeField(selectItem parsers.SelectItem) interface{} {
	if len(selectItem.Path) == 0 {
		return nil
	}

	value := r.tables[selectItem.Path[0]]
	for _, pathSegment := range selectItem.Path[1:] {
		// Guard the index: a segment may be the empty string.
		if len(pathSegment) > 0 && pathSegment[0] == '@' {
			resolved, ok := r.parameters[pathSegment].(string)
			if !ok {
				return nil
			}
			pathSegment = resolved
		}

		switch nestedValue := value.(type) {
		case map[string]interface{}:
			value = nestedValue[pathSegment]
		case map[string]RowType:
			value = nestedValue[pathSegment]
		case []int, []string, []interface{}:
			slice := reflect.ValueOf(nestedValue)
			arrayIndex, err := strconv.Atoi(pathSegment)
			// Atoi accepts negative numbers, so both bounds must be checked
			// before indexing.
			if err != nil || arrayIndex < 0 || arrayIndex >= slice.Len() {
				return nil
			}
			value = slice.Index(arrayIndex).Interface()
		default:
			return nil
		}
	}

	return value
}
// hasAggregateFunctions reports whether any select item, at any nesting
// depth, is a call to one of the functions listed in
// parsers.AggregateFunctions.
func hasAggregateFunctions(selectItems []parsers.SelectItem) bool {
	// Ranging over a nil slice is a no-op, so no explicit nil guard is needed.
	for _, candidate := range selectItems {
		if candidate.Type == parsers.SelectItemTypeFunctionCall {
			call, isCall := candidate.Value.(parsers.FunctionCall)
			if isCall && slices.Contains[[]parsers.FunctionCallType](parsers.AggregateFunctions, call.Type) {
				return true
			}
		}

		if hasAggregateFunctions(candidate.SelectItems) {
			return true
		}
	}
	return false
}
// compareValues performs a three-way comparison of two dynamically typed
// values and returns a negative number, zero, or a positive number when val1
// is respectively less than, equal to, or greater than val2.
//
// Rules, in order:
//   - nil equals nil and sorts before every non-nil value.
//   - Two numbers of different Go types (for example an int and a float64)
//     are compared by numeric value.
//   - Any other pair of differently typed values is reported as 1.
//   - int, float64 and string use their natural ordering; false sorts before
//     true.
//   - Any other type yields 0 when the values are deeply equal and 1
//     otherwise.
func compareValues(val1, val2 interface{}) int {
	switch {
	case val1 == nil && val2 == nil:
		return 0
	case val1 == nil:
		return -1
	case val2 == nil:
		return 1
	}

	if reflect.TypeOf(val1) != reflect.TypeOf(val2) {
		// Numbers may arrive as different Go types; compare them by value so
		// that, for instance, 1 and 1.0 are equal and the ordering is
		// symmetric.
		left, leftIsNumber := compareValuesToFloat64(val1)
		right, rightIsNumber := compareValuesToFloat64(val2)
		if leftIsNumber && rightIsNumber {
			return compareValuesFloat64(left, right)
		}
		return 1
	}

	switch val1 := val1.(type) {
	case int:
		// Same-typed ints are compared exactly, without going through float64.
		val2 := val2.(int)
		if val1 < val2 {
			return -1
		}
		if val1 > val2 {
			return 1
		}
		return 0
	case float64:
		return compareValuesFloat64(val1, val2.(float64))
	case string:
		return strings.Compare(val1, val2.(string))
	case bool:
		val2 := val2.(bool)
		switch {
		case val1 == val2:
			return 0
		case val1:
			return 1
		default:
			return -1
		}
	// TODO: Add more types
	default:
		if reflect.DeepEqual(val1, val2) {
			return 0
		}
		return 1
	}
}

// compareValuesFloat64 is the three-way comparison of two float64 values.
func compareValuesFloat64(left, right float64) int {
	if left < right {
		return -1
	}
	if left > right {
		return 1
	}
	return 0
}

// compareValuesToFloat64 converts any built-in Go numeric value to float64
// and reports whether value was numeric at all.
func compareValuesToFloat64(value interface{}) (float64, bool) {
	switch number := value.(type) {
	case int:
		return float64(number), true
	case int8:
		return float64(number), true
	case int16:
		return float64(number), true
	case int32:
		return float64(number), true
	case int64:
		return float64(number), true
	case uint:
		return float64(number), true
	case uint8:
		return float64(number), true
	case uint16:
		return float64(number), true
	case uint32:
		return float64(number), true
	case uint64:
		return float64(number), true
	case float32:
		return float64(number), true
	case float64:
		return number, true
	default:
		return 0, false
	}
}
// deduplicate returns a new, non-nil slice holding the first occurrence of
// every distinct element of slice, in input order. Two elements count as
// duplicates when compareValues reports them as equal.
func deduplicate[T RowType | interface{}](slice []T) []T {
	unique := make([]T, 0)

candidates:
	for _, candidate := range slice {
		for _, kept := range unique {
			if compareValues(candidate, kept) == 0 {
				continue candidates
			}
		}
		unique = append(unique, candidate)
	}

	return unique
}
// copyMap returns a shallow copy of originalMap: a new map holding the same
// keys and values. The result is never nil, so callers may add entries to it
// even when originalMap is nil or empty.
func copyMap[T RowType | []RowType](originalMap map[string]T) map[string]T {
	targetMap := make(map[string]T, len(originalMap))
	for k, v := range originalMap {
		targetMap[k] = v
	}
	return targetMap
}

View File

@ -4,18 +4,41 @@ import (
"reflect"
"testing"
"github.com/pikami/cosmium/internal/datastore"
"github.com/pikami/cosmium/parsers"
memoryexecutor "github.com/pikami/cosmium/query_executors/memory_executor"
testutils "github.com/pikami/cosmium/test_utils"
)
// TestDocumentIterator is an iterator over a fixed, in-memory slice of rows.
type TestDocumentIterator struct {
// documents holds the rows to hand out, in order.
documents []memoryexecutor.RowType
// index is the position of the row most recently handed out by Next;
// NewTestDocumentIterator starts it at -1.
index int
}
// NewTestDocumentIterator returns an iterator positioned before the first
// element of documents.
func NewTestDocumentIterator(documents []memoryexecutor.RowType) *TestDocumentIterator {
	iterator := new(TestDocumentIterator)
	iterator.documents = documents
	iterator.index = -1
	return iterator
}
// Next advances the iterator and returns the row at the new position together
// with datastore.StatusOk, or nil and datastore.IterEOF once the position has
// moved past the last row.
func (i *TestDocumentIterator) Next() (memoryexecutor.RowType, datastore.DataStoreStatus) {
	i.index++
	if position := i.index; position < len(i.documents) {
		return i.documents[position], datastore.StatusOk
	}
	return nil, datastore.IterEOF
}
func testQueryExecute(
t *testing.T,
query parsers.SelectStmt,
data []memoryexecutor.RowType,
expectedData []memoryexecutor.RowType,
) {
result := memoryexecutor.ExecuteQuery(query, data)
iter := NewTestDocumentIterator(data)
result := memoryexecutor.ExecuteQuery(query, iter)
if !reflect.DeepEqual(result, expectedData) {
t.Errorf("execution result does not match expected data.\nExpected: %+v\nGot: %+v", expectedData, result)

View File

@ -0,0 +1,22 @@
package memoryexecutor
import "github.com/pikami/cosmium/internal/datastore"
// offsetIterator wraps a row iterator and discards a fixed number of leading
// rows before handing out the remaining ones.
type offsetIterator struct {
// documents is the underlying source of rows.
documents rowTypeIterator
// offset is the number of leading rows to discard.
offset int
// skipped records whether the leading rows have already been discarded.
skipped bool
}
// Next returns the next row that follows the skipped prefix.
//
// On the first call, up to oi.offset rows are pulled from the underlying
// iterator and discarded. If the source reports anything other than
// datastore.StatusOk while skipping (it ran out of rows or failed), skipping
// stops at once and that result is returned to the caller instead of being
// ignored. Every later call simply forwards to the underlying iterator.
func (oi *offsetIterator) Next() (RowType, datastore.DataStoreStatus) {
	if !oi.skipped {
		// Mark first so the prefix is discarded at most once, even when
		// skipping ends early.
		oi.skipped = true

		for i := 0; i < oi.offset; i++ {
			if row, status := oi.documents.Next(); status != datastore.StatusOk {
				return row, status
			}
		}
	}

	return oi.documents.Next()
}

View File

@ -0,0 +1,63 @@
package memoryexecutor
import (
"sort"
"github.com/pikami/cosmium/internal/datastore"
"github.com/pikami/cosmium/parsers"
)
// orderIterator yields the rows of an underlying iterator in the order
// defined by a list of order expressions. Sorting needs every row, so the
// source is read completely on the first call to Next.
type orderIterator struct {
// documents is the underlying source of rows; it is set to nil once it has
// been read completely.
documents rowIterator
// orderExpressions are the sort keys, applied in sequence.
orderExpressions []parsers.OrderExpression
// orderedDocs holds the sorted rows; it stays nil until the first call to
// Next.
orderedDocs []rowContext
// docsIndex is the position in orderedDocs of the next row to hand out.
docsIndex int
}
// Next returns the next row in sorted order, or an empty rowContext and
// datastore.IterEOF when all rows have been handed out.
//
// The first call reads the whole underlying iterator and sorts the rows.
// Each slot of the sorted buffer is cleared as soon as its row has been
// returned, so the buffer does not keep rows alive that the consumer has
// already received.
func (oi *orderIterator) Next() (rowContext, datastore.DataStoreStatus) {
	if oi.orderedDocs == nil {
		oi.loadAndSort()
	}

	if oi.docsIndex >= len(oi.orderedDocs) {
		return rowContext{}, datastore.IterEOF
	}

	row := oi.orderedDocs[oi.docsIndex]
	oi.orderedDocs[oi.docsIndex] = rowContext{}
	oi.docsIndex++

	return row, datastore.StatusOk
}

// loadAndSort drains the underlying iterator into oi.orderedDocs, sorts the
// rows by oi.orderExpressions and resets the read position to the start.
func (oi *orderIterator) loadAndSort() {
	// A non-nil (possibly empty) slice marks the rows as loaded.
	oi.orderedDocs = make([]rowContext, 0)
	for {
		row, status := oi.documents.Next()
		if status != datastore.StatusOk {
			break
		}
		oi.orderedDocs = append(oi.orderedDocs, row)
	}
	oi.documents = nil

	sort.SliceStable(oi.orderedDocs, func(i, j int) bool {
		for _, order := range oi.orderExpressions {
			left := oi.orderedDocs[i].resolveSelectItem(order.SelectItem)
			right := oi.orderedDocs[j].resolveSelectItem(order.SelectItem)

			cmp := compareValues(left, right)
			if cmp == 0 {
				continue
			}
			if order.Direction == parsers.OrderDirectionDesc {
				return cmp > 0
			}
			return cmp < 0
		}
		// Ties must report "not less". The stable sort keeps the input order
		// of equal rows; the indices are current positions during the sort
		// and cannot serve as a tie-breaker.
		return false
	})

	oi.docsIndex = 0
}

View File

@ -0,0 +1,90 @@
package memoryexecutor
import (
"github.com/pikami/cosmium/internal/datastore"
"github.com/pikami/cosmium/parsers"
"golang.org/x/exp/slices"
)
// projectIterator turns the row contexts of an underlying iterator into
// output rows by applying the select list.
type projectIterator struct {
// documents is the underlying source of row contexts; it is set to nil once
// the source has reported a status other than datastore.StatusOk.
documents rowIterator
// selectItems is the select list used to build every output row.
selectItems []parsers.SelectItem
// groupBy holds the GROUP BY items; Next only checks whether it is empty.
groupBy []parsers.SelectItem
}
// Next returns the next projected row.
//
// Normally each row of the underlying iterator is projected on its own. When
// the select list contains an aggregate function and there is no GROUP BY
// clause, all remaining rows are read and folded into a single group, so
// exactly one row is produced; the iterator is then exhausted.
//
// When the source reports a status other than datastore.StatusOk, an empty
// rowContext is returned with that status, and every later call returns
// datastore.IterEOF.
func (pi *projectIterator) Next() (RowType, datastore.DataStoreStatus) {
	if pi.documents == nil {
		return rowContext{}, datastore.IterEOF
	}

	row, status := pi.documents.Next()
	if status != datastore.StatusOk {
		pi.documents = nil
		return rowContext{}, status
	}

	if !hasAggregateFunctions(pi.selectItems) || len(pi.groupBy) > 0 {
		return row.applyProjection(pi.selectItems), datastore.StatusOk
	}

	// We can have aggregate functions without a GROUP BY clause;
	// in that case all rows are aggregated into one group.
	allDocuments := []rowContext{row}
	for {
		nextRow, nextStatus := pi.documents.Next()
		if nextStatus != datastore.StatusOk {
			break
		}
		allDocuments = append(allDocuments, nextRow)
	}

	// The source has been drained; mark the iterator as exhausted so later
	// calls do not touch it again.
	pi.documents = nil

	// The first row lends its tables and parameters to the implicit group.
	aggRow := rowContext{
		tables:       row.tables,
		parameters:   row.parameters,
		grouppedRows: allDocuments,
	}

	return aggRow.applyProjection(pi.selectItems), datastore.StatusOk
}
// applyProjection builds the output row for this context. If the first select
// item is flagged as top level, its resolved value is returned directly.
// Otherwise a map is returned that holds one entry per select item, keyed by
// the item's destination column name.
func (r rowContext) applyProjection(selectItems []parsers.SelectItem) RowType {
	if len(selectItems) != 0 {
		if first := selectItems[0]; first.IsTopLevel {
			return r.resolveSelectItem(first)
		}
	}

	projected := make(map[string]interface{}, len(selectItems))
	for position := range selectItems {
		column := resolveDestinationColumnName(selectItems[position], position, r.parameters)
		projected[column] = r.resolveSelectItem(selectItems[position])
	}
	return projected
}
// hasAggregateFunctions reports whether any select item, at any nesting
// depth, is a call to one of the functions listed in
// parsers.AggregateFunctions.
func hasAggregateFunctions(selectItems []parsers.SelectItem) bool {
	// Ranging over a nil slice is a no-op, so no explicit nil guard is needed.
	for _, candidate := range selectItems {
		if candidate.Type == parsers.SelectItemTypeFunctionCall {
			call, isCall := candidate.Value.(parsers.FunctionCall)
			if isCall && slices.Contains[[]parsers.FunctionCallType](parsers.AggregateFunctions, call.Type) {
				return true
			}
		}

		if hasAggregateFunctions(candidate.SelectItems) {
			return true
		}
	}
	return false
}

View File

@ -0,0 +1,44 @@
package memoryexecutor
import (
"github.com/pikami/cosmium/internal/datastore"
"github.com/pikami/cosmium/parsers"
)
// rowTypeToRowContextIterator wraps every row of an underlying iterator in a
// rowContext that carries the query's parameters and exposes the row under
// the query's table name and under "$root".
type rowTypeToRowContextIterator struct {
// documents is the underlying source of rows; it is set to nil once the
// source has reported a status other than datastore.StatusOk.
documents rowTypeIterator
// query supplies the table name and the parameters for each context.
query parsers.SelectStmt
}
// Next pulls one row from the underlying iterator and wraps it in a
// rowContext. The row is registered under two table names: the query's
// initial table name and "$root".
//
// The initial table name is the first non-empty candidate of: the table name
// of the sub-query used as source (only when the source is a sub-query), the
// query's own table name, and the destination column name of the source
// select item.
//
// When the source reports a status other than datastore.StatusOk, an empty
// rowContext is returned with that status, and every later call returns
// datastore.IterEOF.
func (di *rowTypeToRowContextIterator) Next() (rowContext, datastore.DataStoreStatus) {
	if di.documents == nil {
		return rowContext{}, datastore.IterEOF
	}

	doc, status := di.documents.Next()
	if status != datastore.StatusOk {
		di.documents = nil
		return rowContext{}, status
	}

	table := di.query.Table
	tableName := ""
	if table.SelectItem.Type == parsers.SelectItemTypeSubQuery {
		tableName = table.SelectItem.Value.(parsers.SelectStmt).Table.Value
	}
	if tableName == "" {
		tableName = table.Value
	}
	if tableName == "" {
		tableName = resolveDestinationColumnName(table.SelectItem, 0, di.query.Parameters)
	}

	tables := make(map[string]RowType, 2)
	tables[tableName] = doc
	tables["$root"] = doc

	return rowContext{
		parameters: di.query.Parameters,
		tables:     tables,
	}, status
}