Partial JOIN implementation

This commit is contained in:
Pijus Kamandulis 2024-07-17 21:40:28 +03:00
parent 3bdff9b643
commit 20af73ee9c
7 changed files with 1866 additions and 1536 deletions

View File

@ -3,6 +3,7 @@ package parsers
type SelectStmt struct {
SelectItems []SelectItem
Table Table
JoinItems []JoinItem
Filters interface{}
Distinct bool
Count int
@ -16,6 +17,11 @@ type Table struct {
Value string
}
// JoinItem describes one JOIN clause of a SELECT statement, i.e.
// `JOIN <Table> IN <SelectItem>`.
type JoinItem struct {
// Table is the alias introduced by the JOIN (the identifier before IN).
Table Table
// SelectItem is the expression after IN that the alias is bound to.
SelectItem SelectItem
}
type SelectItemType int
const (

View File

@ -0,0 +1,57 @@
package nosql_test
import (
"testing"
"github.com/pikami/cosmium/parsers"
)
func Test_Parse_Join(t *testing.T) {
t.Run("Should parse simple JOIN", func(t *testing.T) {
testQueryParse(
t,
`SELECT c.id, c["pk"] FROM c JOIN cc IN c["tags"]`,
parsers.SelectStmt{
SelectItems: []parsers.SelectItem{
{Path: []string{"c", "id"}},
{Path: []string{"c", "pk"}},
},
Table: parsers.Table{Value: "c"},
JoinItems: []parsers.JoinItem{
{
Table: parsers.Table{
Value: "cc",
},
SelectItem: parsers.SelectItem{
Path: []string{"c", "tags"},
},
},
},
},
)
})
t.Run("Should parse JOIN VALUE", func(t *testing.T) {
testQueryParse(
t,
`SELECT VALUE cc FROM c JOIN cc IN c["tags"]`,
parsers.SelectStmt{
SelectItems: []parsers.SelectItem{
{Path: []string{"cc"}, IsTopLevel: true},
},
Table: parsers.Table{Value: "c"},
JoinItems: []parsers.JoinItem{
{
Table: parsers.Table{
Value: "cc",
},
SelectItem: parsers.SelectItem{
Path: []string{"c", "tags"},
},
},
},
},
)
})
}

File diff suppressed because it is too large Load Diff

View File

@ -4,7 +4,7 @@ package nosql
import "github.com/pikami/cosmium/parsers"
func makeSelectStmt(
columns, table,
columns, table, joinItems,
whereClause interface{}, distinctClause interface{},
count interface{}, groupByClause interface{}, orderList interface{},
offsetClause interface{},
@ -14,6 +14,13 @@ func makeSelectStmt(
Table: table.(parsers.Table),
}
if joinItemsArray, ok := joinItems.([]interface{}); ok && len(joinItemsArray) > 0 {
selectStmt.JoinItems = make([]parsers.JoinItem, len(joinItemsArray))
for i, joinItem := range joinItemsArray {
selectStmt.JoinItems[i] = joinItem.(parsers.JoinItem)
}
}
switch v := whereClause.(type) {
case parsers.ComparisonExpression, parsers.LogicalExpression, parsers.Constant, parsers.SelectItem:
selectStmt.Filters = v
@ -48,6 +55,13 @@ func makeSelectStmt(
return selectStmt, nil
}
// makeJoin builds the parsers.JoinItem for a single JOIN clause from the
// values captured by the grammar.
//
// table must hold a parsers.Table and column a parsers.SelectItem; both are
// asserted without a check, so any other dynamic type panics. The returned
// error is always nil.
func makeJoin(table interface{}, column interface{}) (parsers.JoinItem, error) {
	var item parsers.JoinItem
	item.Table = table.(parsers.Table)
	item.SelectItem = column.(parsers.SelectItem)
	return item, nil
}
func makeSelectItem(name interface{}, path interface{}, selectItemType parsers.SelectItemType) (parsers.SelectItem, error) {
ps := path.([]interface{})
@ -161,13 +175,15 @@ Input <- selectStmt:SelectStmt {
// SelectStmt matches a whole SELECT query: optional DISTINCT and TOP, the
// selection list, FROM <table>, zero or more JOIN clauses, then optional
// WHERE, GROUP BY, ORDER BY and OFFSET clauses. The captured parts are handed
// to makeSelectStmt.
// NOTE(review): JoinClause* repeats with nothing between repetitions. Unless
// the preceding SelectItem already consumes trailing whitespace, a query with
// two whitespace-separated JOIN clauses may fail to match — confirm, and
// consider `(ws join:JoinClause { return join, nil })*`.
SelectStmt <- Select ws
distinctClause:DistinctClause? ws
topClause:TopClause? ws columns:Selection ws
topClause:TopClause? ws
columns:Selection ws
From ws table:TableName ws
joinClauses:JoinClause* ws
whereClause:(ws Where ws condition:Condition { return condition, nil })?
groupByClause:(ws GroupBy ws columns:ColumnList { return columns, nil })?
orderByClause:OrderByClause?
offsetClause:OffsetClause? {
return makeSelectStmt(columns, table, whereClause,
return makeSelectStmt(columns, table, joinClauses, whereClause,
distinctClause, topClause, groupByClause, orderByClause, offsetClause)
}
@ -177,6 +193,10 @@ TopClause <- Top ws count:Integer {
return count, nil
}
// JoinClause matches `JOIN <table> IN <select item>` (both keywords are
// case-insensitive) and builds a parsers.JoinItem through makeJoin.
JoinClause <- Join ws table:TableName ws "IN"i ws column:SelectItem {
return makeJoin(table, column)
}
// OffsetClause matches `OFFSET <integer> LIMIT <integer>` (case-insensitive
// keywords) and returns the two literal values as a slice: [offset, limit].
// Both captures are asserted to parsers.Constant without a check.
OffsetClause <- "OFFSET"i ws offset:IntegerLiteral ws "LIMIT"i ws limit:IntegerLiteral {
return []interface{}{offset.(parsers.Constant).Value, limit.(parsers.Constant).Value}, nil
}
@ -300,6 +320,8 @@ As <- "AS"i
// Keyword tokens. The trailing `i` makes each literal match case-insensitively.
From <- "FROM"i
// Join introduces a JoinClause.
Join <- "JOIN"i
Where <- "WHERE"i
And <- "AND"i

View File

@ -11,7 +11,7 @@ func (c memoryExecutorContext) aggregate_Avg(arguments []interface{}, row RowTyp
sum := 0.0
count := 0
if array, isArray := row.([]RowType); isArray {
if array, isArray := row.([]RowWithJoins); isArray {
for _, item := range array {
value := c.getFieldValue(selectExpression, item)
if numericValue, ok := value.(float64); ok {
@ -35,7 +35,7 @@ func (c memoryExecutorContext) aggregate_Count(arguments []interface{}, row RowT
selectExpression := arguments[0].(parsers.SelectItem)
count := 0
if array, isArray := row.([]RowType); isArray {
if array, isArray := row.([]RowWithJoins); isArray {
for _, item := range array {
value := c.getFieldValue(selectExpression, item)
if value != nil {
@ -52,7 +52,7 @@ func (c memoryExecutorContext) aggregate_Max(arguments []interface{}, row RowTyp
max := 0.0
count := 0
if array, isArray := row.([]RowType); isArray {
if array, isArray := row.([]RowWithJoins); isArray {
for _, item := range array {
value := c.getFieldValue(selectExpression, item)
if numericValue, ok := value.(float64); ok {
@ -81,7 +81,7 @@ func (c memoryExecutorContext) aggregate_Min(arguments []interface{}, row RowTyp
min := math.MaxFloat64
count := 0
if array, isArray := row.([]RowType); isArray {
if array, isArray := row.([]RowWithJoins); isArray {
for _, item := range array {
value := c.getFieldValue(selectExpression, item)
if numericValue, ok := value.(float64); ok {
@ -110,7 +110,7 @@ func (c memoryExecutorContext) aggregate_Sum(arguments []interface{}, row RowTyp
sum := 0.0
count := 0
if array, isArray := row.([]RowType); isArray {
if array, isArray := row.([]RowWithJoins); isArray {
for _, item := range array {
value := c.getFieldValue(selectExpression, item)
if numericValue, ok := value.(float64); ok {

View File

@ -0,0 +1,86 @@
package memoryexecutor_test
import (
"testing"
"github.com/pikami/cosmium/parsers"
memoryexecutor "github.com/pikami/cosmium/query_executors/memory_executor"
)
// Test_Execute_Joins runs queries containing `JOIN cc IN c["tags"]` against
// two documents that each carry two tags, and expects one output row per
// (document, tag) pair.
func Test_Execute_Joins(t *testing.T) {
	mockData := []memoryexecutor.RowType{
		map[string]interface{}{
			"id": 1,
			"tags": []map[string]interface{}{
				{"name": "a"},
				{"name": "b"},
			},
		},
		map[string]interface{}{
			"id": 2,
			"tags": []map[string]interface{}{
				{"name": "b"},
				{"name": "c"},
			},
		},
	}

	// Each case uses the same JOIN clause; build it fresh so the cases do not
	// share backing slices.
	tagsJoin := func() []parsers.JoinItem {
		return []parsers.JoinItem{
			{
				Table:      parsers.Table{Value: "cc"},
				SelectItem: parsers.SelectItem{Path: []string{"c", "tags"}},
			},
		}
	}

	cases := []struct {
		name     string
		query    parsers.SelectStmt
		expected []memoryexecutor.RowType
	}{
		{
			name: "Should execute JOIN on 'tags'",
			query: parsers.SelectStmt{
				SelectItems: []parsers.SelectItem{
					{Path: []string{"c", "id"}},
					{Path: []string{"cc", "name"}},
				},
				Table:     parsers.Table{Value: "c"},
				JoinItems: tagsJoin(),
			},
			expected: []memoryexecutor.RowType{
				map[string]interface{}{"id": 1, "name": "a"},
				map[string]interface{}{"id": 1, "name": "b"},
				map[string]interface{}{"id": 2, "name": "b"},
				map[string]interface{}{"id": 2, "name": "c"},
			},
		},
		{
			name: "Should execute JOIN VALUE on 'tags'",
			query: parsers.SelectStmt{
				SelectItems: []parsers.SelectItem{
					{Path: []string{"cc"}, IsTopLevel: true},
				},
				Table:     parsers.Table{Value: "c"},
				JoinItems: tagsJoin(),
			},
			expected: []memoryexecutor.RowType{
				map[string]interface{}{"name": "a"},
				map[string]interface{}{"name": "b"},
				map[string]interface{}{"name": "b"},
				map[string]interface{}{"name": "c"},
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			testQueryExecute(t, tc.query, mockData, tc.expected)
		})
	}
}

View File

@ -13,6 +13,7 @@ import (
)
// RowType is an untyped row value, as accepted and returned by Execute.
type RowType interface{}
// RowWithJoins is a row keyed by table alias: the FROM table name maps to the
// source row and each JOIN alias maps to the joined element for that row.
type RowWithJoins map[string]RowType
// ExpressionType is an untyped filter expression, as passed to evaluateFilters.
type ExpressionType interface{}
type memoryExecutorContext struct {
@ -24,24 +25,52 @@ func Execute(query parsers.SelectStmt, data []RowType) []RowType {
parameters: query.Parameters,
}
result := make([]RowType, 0)
// Apply Filter
joinedRows := make([]RowWithJoins, 0)
for _, row := range data {
if ctx.evaluateFilters(query.Filters, row) {
result = append(result, row)
// Perform joins
dataTables := map[string][]RowType{}
for _, join := range query.JoinItems {
joinedData := ctx.getFieldValue(join.SelectItem, row)
if joinedDataArray, isArray := joinedData.([]map[string]interface{}); isArray {
var rows []RowType
for _, m := range joinedDataArray {
rows = append(rows, RowType(m))
}
dataTables[join.Table.Value] = rows
}
}
// Generate flat rows
flatRows := []RowWithJoins{
{query.Table.Value: row},
}
for joinedTableName, joinedTable := range dataTables {
flatRows = zipRows(flatRows, joinedTableName, joinedTable)
}
// Apply filters
filteredRows := []RowWithJoins{}
for _, rowWithJoins := range flatRows {
if ctx.evaluateFilters(query.Filters, rowWithJoins) {
filteredRows = append(filteredRows, rowWithJoins)
}
}
joinedRows = append(joinedRows, filteredRows...)
}
// Apply order
if query.OrderExpressions != nil && len(query.OrderExpressions) > 0 {
ctx.orderBy(query.OrderExpressions, result)
ctx.orderBy(query.OrderExpressions, joinedRows)
}
result := make([]RowType, 0)
// Apply group
isGroupSelect := query.GroupBy != nil && len(query.GroupBy) > 0
if isGroupSelect {
result = ctx.groupBy(query, result)
result = ctx.groupBy(query, joinedRows)
}
// Apply select
@ -50,9 +79,9 @@ func Execute(query parsers.SelectStmt, data []RowType) []RowType {
if hasAggregateFunctions(query.SelectItems) {
// When can have aggregate functions without GROUP BY clause,
// we should aggregate all rows in that case
selectedData = append(selectedData, ctx.selectRow(query.SelectItems, result))
selectedData = append(selectedData, ctx.selectRow(query.SelectItems, joinedRows))
} else {
for _, row := range result {
for _, row := range joinedRows {
selectedData = append(selectedData, ctx.selectRow(query.SelectItems, row))
}
}
@ -79,7 +108,7 @@ func Execute(query parsers.SelectStmt, data []RowType) []RowType {
return result
}
func (c memoryExecutorContext) selectRow(selectItems []parsers.SelectItem, row RowType) interface{} {
func (c memoryExecutorContext) selectRow(selectItems []parsers.SelectItem, row interface{}) interface{} {
// When the first value is top level, select it instead
if len(selectItems) > 0 && selectItems[0].IsTopLevel {
return c.getFieldValue(selectItems[0], row)
@ -103,7 +132,7 @@ func (c memoryExecutorContext) selectRow(selectItems []parsers.SelectItem, row R
return newRow
}
func (c memoryExecutorContext) evaluateFilters(expr ExpressionType, row RowType) bool {
func (c memoryExecutorContext) evaluateFilters(expr ExpressionType, row RowWithJoins) bool {
if expr == nil {
return true
}
@ -164,7 +193,7 @@ func (c memoryExecutorContext) evaluateFilters(expr ExpressionType, row RowType)
return false
}
func (c memoryExecutorContext) getFieldValue(field parsers.SelectItem, row RowType) interface{} {
func (c memoryExecutorContext) getFieldValue(field parsers.SelectItem, row interface{}) interface{} {
if field.Type == parsers.SelectItemTypeArray {
arrayValue := make([]interface{}, 0)
for _, selectItem := range field.SelectItems {
@ -200,7 +229,8 @@ func (c memoryExecutorContext) getFieldValue(field parsers.SelectItem, row RowTy
}
rowValue := row
if array, isArray := row.([]RowType); isArray {
// Used for aggregates
if array, isArray := row.([]RowWithJoins); isArray {
rowValue = array[0]
}
@ -374,6 +404,9 @@ func (c memoryExecutorContext) getFieldValue(field parsers.SelectItem, row RowTy
}
value := rowValue
if joinedRow, isRowWithJoins := value.(RowWithJoins); isRowWithJoins {
value = joinedRow[field.Path[0]]
}
if len(field.Path) > 1 {
for _, pathSegment := range field.Path[1:] {
@ -381,6 +414,8 @@ func (c memoryExecutorContext) getFieldValue(field parsers.SelectItem, row RowTy
switch nestedValue := value.(type) {
case map[string]interface{}:
value = nestedValue[pathSegment]
case RowWithJoins:
value = nestedValue[pathSegment]
case []int, []string, []interface{}:
slice := reflect.ValueOf(nestedValue)
if arrayIndex, err := strconv.Atoi(pathSegment); err == nil && slice.Len() > arrayIndex {
@ -398,7 +433,7 @@ func (c memoryExecutorContext) getFieldValue(field parsers.SelectItem, row RowTy
func (c memoryExecutorContext) getExpressionParameterValue(
parameter interface{},
row RowType,
row RowWithJoins,
) interface{} {
switch typedParameter := parameter.(type) {
case parsers.SelectItem:
@ -410,7 +445,7 @@ func (c memoryExecutorContext) getExpressionParameterValue(
return nil
}
func (c memoryExecutorContext) orderBy(orderBy []parsers.OrderExpression, data []RowType) {
func (c memoryExecutorContext) orderBy(orderBy []parsers.OrderExpression, data []RowWithJoins) {
less := func(i, j int) bool {
for _, order := range orderBy {
val1 := c.getFieldValue(order.SelectItem, data[i])
@ -430,8 +465,8 @@ func (c memoryExecutorContext) orderBy(orderBy []parsers.OrderExpression, data [
sort.SliceStable(data, less)
}
func (c memoryExecutorContext) groupBy(selectStmt parsers.SelectStmt, data []RowType) []RowType {
groupedRows := make(map[string][]RowType)
func (c memoryExecutorContext) groupBy(selectStmt parsers.SelectStmt, data []RowWithJoins) []RowType {
groupedRows := make(map[string][]RowWithJoins)
groupedKeys := make([]string, 0)
// Group rows by group by columns
@ -454,7 +489,7 @@ func (c memoryExecutorContext) groupBy(selectStmt parsers.SelectStmt, data []Row
return aggregatedRows
}
func (c memoryExecutorContext) generateGroupKey(groupByFields []parsers.SelectItem, row RowType) string {
func (c memoryExecutorContext) generateGroupKey(groupByFields []parsers.SelectItem, row RowWithJoins) string {
var keyBuilder strings.Builder
for _, column := range groupByFields {
fieldValue := c.getFieldValue(column, row)
@ -465,7 +500,7 @@ func (c memoryExecutorContext) generateGroupKey(groupByFields []parsers.SelectIt
return keyBuilder.String()
}
func (c memoryExecutorContext) aggregateGroup(selectStmt parsers.SelectStmt, groupRows []RowType) RowType {
func (c memoryExecutorContext) aggregateGroup(selectStmt parsers.SelectStmt, groupRows []RowWithJoins) RowType {
aggregatedRow := c.selectRow(selectStmt.SelectItems, groupRows)
return aggregatedRow
@ -553,3 +588,27 @@ func hasAggregateFunctions(selectItems []parsers.SelectItem) bool {
return false
}
// zipRows returns the cross product of current and rowsToZip.
//
// For every row in current and every row in rowsToZip it emits a shallow copy
// of the current row with the zipped row stored under joinedTableName. The
// rows in current are never modified. When either input is empty the result
// is an empty, non-nil slice.
func zipRows(current []RowWithJoins, joinedTableName string, rowsToZip []RowType) []RowWithJoins {
	// The output length is known up front, so allocate once instead of
	// letting append grow the slice repeatedly.
	zipped := make([]RowWithJoins, 0, len(current)*len(rowsToZip))
	for _, currentRow := range current {
		for _, rowToZip := range rowsToZip {
			// Copy first so sibling output rows do not share one map.
			newRow := copyMap(currentRow)
			newRow[joinedTableName] = rowToZip
			zipped = append(zipped, newRow)
		}
	}
	return zipped
}
// copyMap returns a shallow copy of originalMap: a new map holding the same
// keys and the same values (values are not cloned). A nil or empty input
// yields an empty, non-nil map.
func copyMap(originalMap map[string]RowType) map[string]RowType {
	// Size the copy up front so inserting the entries does not trigger
	// incremental map growth.
	targetMap := make(map[string]RowType, len(originalMap))
	for k, v := range originalMap {
		targetMap[k] = v
	}
	return targetMap
}