mirror of https://github.com/pikami/cosmium.git
Partial JOIN implementation
This commit is contained in:
parent
3bdff9b643
commit
20af73ee9c
|
@ -3,6 +3,7 @@ package parsers
|
|||
type SelectStmt struct {
|
||||
SelectItems []SelectItem
|
||||
Table Table
|
||||
JoinItems []JoinItem
|
||||
Filters interface{}
|
||||
Distinct bool
|
||||
Count int
|
||||
|
@ -16,6 +17,11 @@ type Table struct {
|
|||
Value string
|
||||
}
|
||||
|
||||
type JoinItem struct {
|
||||
Table Table
|
||||
SelectItem SelectItem
|
||||
}
|
||||
|
||||
type SelectItemType int
|
||||
|
||||
const (
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
package nosql_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/pikami/cosmium/parsers"
|
||||
)
|
||||
|
||||
func Test_Parse_Join(t *testing.T) {
|
||||
|
||||
t.Run("Should parse simple JOIN", func(t *testing.T) {
|
||||
testQueryParse(
|
||||
t,
|
||||
`SELECT c.id, c["pk"] FROM c JOIN cc IN c["tags"]`,
|
||||
parsers.SelectStmt{
|
||||
SelectItems: []parsers.SelectItem{
|
||||
{Path: []string{"c", "id"}},
|
||||
{Path: []string{"c", "pk"}},
|
||||
},
|
||||
Table: parsers.Table{Value: "c"},
|
||||
JoinItems: []parsers.JoinItem{
|
||||
{
|
||||
Table: parsers.Table{
|
||||
Value: "cc",
|
||||
},
|
||||
SelectItem: parsers.SelectItem{
|
||||
Path: []string{"c", "tags"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("Should parse JOIN VALUE", func(t *testing.T) {
|
||||
testQueryParse(
|
||||
t,
|
||||
`SELECT VALUE cc FROM c JOIN cc IN c["tags"]`,
|
||||
parsers.SelectStmt{
|
||||
SelectItems: []parsers.SelectItem{
|
||||
{Path: []string{"cc"}, IsTopLevel: true},
|
||||
},
|
||||
Table: parsers.Table{Value: "c"},
|
||||
JoinItems: []parsers.JoinItem{
|
||||
{
|
||||
Table: parsers.Table{
|
||||
Value: "cc",
|
||||
},
|
||||
SelectItem: parsers.SelectItem{
|
||||
Path: []string{"c", "tags"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
})
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -4,7 +4,7 @@ package nosql
|
|||
import "github.com/pikami/cosmium/parsers"
|
||||
|
||||
func makeSelectStmt(
|
||||
columns, table,
|
||||
columns, table, joinItems,
|
||||
whereClause interface{}, distinctClause interface{},
|
||||
count interface{}, groupByClause interface{}, orderList interface{},
|
||||
offsetClause interface{},
|
||||
|
@ -14,6 +14,13 @@ func makeSelectStmt(
|
|||
Table: table.(parsers.Table),
|
||||
}
|
||||
|
||||
if joinItemsArray, ok := joinItems.([]interface{}); ok && len(joinItemsArray) > 0 {
|
||||
selectStmt.JoinItems = make([]parsers.JoinItem, len(joinItemsArray))
|
||||
for i, joinItem := range joinItemsArray {
|
||||
selectStmt.JoinItems[i] = joinItem.(parsers.JoinItem)
|
||||
}
|
||||
}
|
||||
|
||||
switch v := whereClause.(type) {
|
||||
case parsers.ComparisonExpression, parsers.LogicalExpression, parsers.Constant, parsers.SelectItem:
|
||||
selectStmt.Filters = v
|
||||
|
@ -48,6 +55,13 @@ func makeSelectStmt(
|
|||
return selectStmt, nil
|
||||
}
|
||||
|
||||
func makeJoin(table interface{}, column interface{}) (parsers.JoinItem, error) {
|
||||
return parsers.JoinItem{
|
||||
Table: table.(parsers.Table),
|
||||
SelectItem: column.(parsers.SelectItem),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func makeSelectItem(name interface{}, path interface{}, selectItemType parsers.SelectItemType) (parsers.SelectItem, error) {
|
||||
ps := path.([]interface{})
|
||||
|
||||
|
@ -161,13 +175,15 @@ Input <- selectStmt:SelectStmt {
|
|||
|
||||
SelectStmt <- Select ws
|
||||
distinctClause:DistinctClause? ws
|
||||
topClause:TopClause? ws columns:Selection ws
|
||||
topClause:TopClause? ws
|
||||
columns:Selection ws
|
||||
From ws table:TableName ws
|
||||
joinClauses:JoinClause* ws
|
||||
whereClause:(ws Where ws condition:Condition { return condition, nil })?
|
||||
groupByClause:(ws GroupBy ws columns:ColumnList { return columns, nil })?
|
||||
orderByClause:OrderByClause?
|
||||
offsetClause:OffsetClause? {
|
||||
return makeSelectStmt(columns, table, whereClause,
|
||||
return makeSelectStmt(columns, table, joinClauses, whereClause,
|
||||
distinctClause, topClause, groupByClause, orderByClause, offsetClause)
|
||||
}
|
||||
|
||||
|
@ -177,6 +193,10 @@ TopClause <- Top ws count:Integer {
|
|||
return count, nil
|
||||
}
|
||||
|
||||
JoinClause <- Join ws table:TableName ws "IN"i ws column:SelectItem {
|
||||
return makeJoin(table, column)
|
||||
}
|
||||
|
||||
OffsetClause <- "OFFSET"i ws offset:IntegerLiteral ws "LIMIT"i ws limit:IntegerLiteral {
|
||||
return []interface{}{offset.(parsers.Constant).Value, limit.(parsers.Constant).Value}, nil
|
||||
}
|
||||
|
@ -300,6 +320,8 @@ As <- "AS"i
|
|||
|
||||
From <- "FROM"i
|
||||
|
||||
Join <- "JOIN"i
|
||||
|
||||
Where <- "WHERE"i
|
||||
|
||||
And <- "AND"i
|
||||
|
|
|
@ -11,7 +11,7 @@ func (c memoryExecutorContext) aggregate_Avg(arguments []interface{}, row RowTyp
|
|||
sum := 0.0
|
||||
count := 0
|
||||
|
||||
if array, isArray := row.([]RowType); isArray {
|
||||
if array, isArray := row.([]RowWithJoins); isArray {
|
||||
for _, item := range array {
|
||||
value := c.getFieldValue(selectExpression, item)
|
||||
if numericValue, ok := value.(float64); ok {
|
||||
|
@ -35,7 +35,7 @@ func (c memoryExecutorContext) aggregate_Count(arguments []interface{}, row RowT
|
|||
selectExpression := arguments[0].(parsers.SelectItem)
|
||||
count := 0
|
||||
|
||||
if array, isArray := row.([]RowType); isArray {
|
||||
if array, isArray := row.([]RowWithJoins); isArray {
|
||||
for _, item := range array {
|
||||
value := c.getFieldValue(selectExpression, item)
|
||||
if value != nil {
|
||||
|
@ -52,7 +52,7 @@ func (c memoryExecutorContext) aggregate_Max(arguments []interface{}, row RowTyp
|
|||
max := 0.0
|
||||
count := 0
|
||||
|
||||
if array, isArray := row.([]RowType); isArray {
|
||||
if array, isArray := row.([]RowWithJoins); isArray {
|
||||
for _, item := range array {
|
||||
value := c.getFieldValue(selectExpression, item)
|
||||
if numericValue, ok := value.(float64); ok {
|
||||
|
@ -81,7 +81,7 @@ func (c memoryExecutorContext) aggregate_Min(arguments []interface{}, row RowTyp
|
|||
min := math.MaxFloat64
|
||||
count := 0
|
||||
|
||||
if array, isArray := row.([]RowType); isArray {
|
||||
if array, isArray := row.([]RowWithJoins); isArray {
|
||||
for _, item := range array {
|
||||
value := c.getFieldValue(selectExpression, item)
|
||||
if numericValue, ok := value.(float64); ok {
|
||||
|
@ -110,7 +110,7 @@ func (c memoryExecutorContext) aggregate_Sum(arguments []interface{}, row RowTyp
|
|||
sum := 0.0
|
||||
count := 0
|
||||
|
||||
if array, isArray := row.([]RowType); isArray {
|
||||
if array, isArray := row.([]RowWithJoins); isArray {
|
||||
for _, item := range array {
|
||||
value := c.getFieldValue(selectExpression, item)
|
||||
if numericValue, ok := value.(float64); ok {
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
package memoryexecutor_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/pikami/cosmium/parsers"
|
||||
memoryexecutor "github.com/pikami/cosmium/query_executors/memory_executor"
|
||||
)
|
||||
|
||||
func Test_Execute_Joins(t *testing.T) {
|
||||
mockData := []memoryexecutor.RowType{
|
||||
map[string]interface{}{
|
||||
"id": 1,
|
||||
"tags": []map[string]interface{}{
|
||||
{"name": "a"},
|
||||
{"name": "b"},
|
||||
},
|
||||
},
|
||||
map[string]interface{}{
|
||||
"id": 2,
|
||||
"tags": []map[string]interface{}{
|
||||
{"name": "b"},
|
||||
{"name": "c"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("Should execute JOIN on 'tags'", func(t *testing.T) {
|
||||
testQueryExecute(
|
||||
t,
|
||||
parsers.SelectStmt{
|
||||
SelectItems: []parsers.SelectItem{
|
||||
{Path: []string{"c", "id"}},
|
||||
{Path: []string{"cc", "name"}},
|
||||
},
|
||||
Table: parsers.Table{Value: "c"},
|
||||
JoinItems: []parsers.JoinItem{
|
||||
{
|
||||
Table: parsers.Table{
|
||||
Value: "cc",
|
||||
},
|
||||
SelectItem: parsers.SelectItem{
|
||||
Path: []string{"c", "tags"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
mockData,
|
||||
[]memoryexecutor.RowType{
|
||||
map[string]interface{}{"id": 1, "name": "a"},
|
||||
map[string]interface{}{"id": 1, "name": "b"},
|
||||
map[string]interface{}{"id": 2, "name": "b"},
|
||||
map[string]interface{}{"id": 2, "name": "c"},
|
||||
},
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("Should execute JOIN VALUE on 'tags'", func(t *testing.T) {
|
||||
testQueryExecute(
|
||||
t,
|
||||
parsers.SelectStmt{
|
||||
SelectItems: []parsers.SelectItem{
|
||||
{Path: []string{"cc"}, IsTopLevel: true},
|
||||
},
|
||||
Table: parsers.Table{Value: "c"},
|
||||
JoinItems: []parsers.JoinItem{
|
||||
{
|
||||
Table: parsers.Table{
|
||||
Value: "cc",
|
||||
},
|
||||
SelectItem: parsers.SelectItem{
|
||||
Path: []string{"c", "tags"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
mockData,
|
||||
[]memoryexecutor.RowType{
|
||||
map[string]interface{}{"name": "a"},
|
||||
map[string]interface{}{"name": "b"},
|
||||
map[string]interface{}{"name": "b"},
|
||||
map[string]interface{}{"name": "c"},
|
||||
},
|
||||
)
|
||||
})
|
||||
}
|
|
@ -13,6 +13,7 @@ import (
|
|||
)
|
||||
|
||||
type RowType interface{}
|
||||
type RowWithJoins map[string]RowType
|
||||
type ExpressionType interface{}
|
||||
|
||||
type memoryExecutorContext struct {
|
||||
|
@ -24,24 +25,52 @@ func Execute(query parsers.SelectStmt, data []RowType) []RowType {
|
|||
parameters: query.Parameters,
|
||||
}
|
||||
|
||||
result := make([]RowType, 0)
|
||||
|
||||
// Apply Filter
|
||||
joinedRows := make([]RowWithJoins, 0)
|
||||
for _, row := range data {
|
||||
if ctx.evaluateFilters(query.Filters, row) {
|
||||
result = append(result, row)
|
||||
// Perform joins
|
||||
dataTables := map[string][]RowType{}
|
||||
|
||||
for _, join := range query.JoinItems {
|
||||
joinedData := ctx.getFieldValue(join.SelectItem, row)
|
||||
if joinedDataArray, isArray := joinedData.([]map[string]interface{}); isArray {
|
||||
var rows []RowType
|
||||
for _, m := range joinedDataArray {
|
||||
rows = append(rows, RowType(m))
|
||||
}
|
||||
dataTables[join.Table.Value] = rows
|
||||
}
|
||||
}
|
||||
|
||||
// Generate flat rows
|
||||
flatRows := []RowWithJoins{
|
||||
{query.Table.Value: row},
|
||||
}
|
||||
for joinedTableName, joinedTable := range dataTables {
|
||||
flatRows = zipRows(flatRows, joinedTableName, joinedTable)
|
||||
}
|
||||
|
||||
// Apply filters
|
||||
filteredRows := []RowWithJoins{}
|
||||
for _, rowWithJoins := range flatRows {
|
||||
if ctx.evaluateFilters(query.Filters, rowWithJoins) {
|
||||
filteredRows = append(filteredRows, rowWithJoins)
|
||||
}
|
||||
}
|
||||
|
||||
joinedRows = append(joinedRows, filteredRows...)
|
||||
}
|
||||
|
||||
// Apply order
|
||||
if query.OrderExpressions != nil && len(query.OrderExpressions) > 0 {
|
||||
ctx.orderBy(query.OrderExpressions, result)
|
||||
ctx.orderBy(query.OrderExpressions, joinedRows)
|
||||
}
|
||||
|
||||
result := make([]RowType, 0)
|
||||
|
||||
// Apply group
|
||||
isGroupSelect := query.GroupBy != nil && len(query.GroupBy) > 0
|
||||
if isGroupSelect {
|
||||
result = ctx.groupBy(query, result)
|
||||
result = ctx.groupBy(query, joinedRows)
|
||||
}
|
||||
|
||||
// Apply select
|
||||
|
@ -50,9 +79,9 @@ func Execute(query parsers.SelectStmt, data []RowType) []RowType {
|
|||
if hasAggregateFunctions(query.SelectItems) {
|
||||
// When can have aggregate functions without GROUP BY clause,
|
||||
// we should aggregate all rows in that case
|
||||
selectedData = append(selectedData, ctx.selectRow(query.SelectItems, result))
|
||||
selectedData = append(selectedData, ctx.selectRow(query.SelectItems, joinedRows))
|
||||
} else {
|
||||
for _, row := range result {
|
||||
for _, row := range joinedRows {
|
||||
selectedData = append(selectedData, ctx.selectRow(query.SelectItems, row))
|
||||
}
|
||||
}
|
||||
|
@ -79,7 +108,7 @@ func Execute(query parsers.SelectStmt, data []RowType) []RowType {
|
|||
return result
|
||||
}
|
||||
|
||||
func (c memoryExecutorContext) selectRow(selectItems []parsers.SelectItem, row RowType) interface{} {
|
||||
func (c memoryExecutorContext) selectRow(selectItems []parsers.SelectItem, row interface{}) interface{} {
|
||||
// When the first value is top level, select it instead
|
||||
if len(selectItems) > 0 && selectItems[0].IsTopLevel {
|
||||
return c.getFieldValue(selectItems[0], row)
|
||||
|
@ -103,7 +132,7 @@ func (c memoryExecutorContext) selectRow(selectItems []parsers.SelectItem, row R
|
|||
return newRow
|
||||
}
|
||||
|
||||
func (c memoryExecutorContext) evaluateFilters(expr ExpressionType, row RowType) bool {
|
||||
func (c memoryExecutorContext) evaluateFilters(expr ExpressionType, row RowWithJoins) bool {
|
||||
if expr == nil {
|
||||
return true
|
||||
}
|
||||
|
@ -164,7 +193,7 @@ func (c memoryExecutorContext) evaluateFilters(expr ExpressionType, row RowType)
|
|||
return false
|
||||
}
|
||||
|
||||
func (c memoryExecutorContext) getFieldValue(field parsers.SelectItem, row RowType) interface{} {
|
||||
func (c memoryExecutorContext) getFieldValue(field parsers.SelectItem, row interface{}) interface{} {
|
||||
if field.Type == parsers.SelectItemTypeArray {
|
||||
arrayValue := make([]interface{}, 0)
|
||||
for _, selectItem := range field.SelectItems {
|
||||
|
@ -200,7 +229,8 @@ func (c memoryExecutorContext) getFieldValue(field parsers.SelectItem, row RowTy
|
|||
}
|
||||
|
||||
rowValue := row
|
||||
if array, isArray := row.([]RowType); isArray {
|
||||
// Used for aggregates
|
||||
if array, isArray := row.([]RowWithJoins); isArray {
|
||||
rowValue = array[0]
|
||||
}
|
||||
|
||||
|
@ -374,6 +404,9 @@ func (c memoryExecutorContext) getFieldValue(field parsers.SelectItem, row RowTy
|
|||
}
|
||||
|
||||
value := rowValue
|
||||
if joinedRow, isRowWithJoins := value.(RowWithJoins); isRowWithJoins {
|
||||
value = joinedRow[field.Path[0]]
|
||||
}
|
||||
|
||||
if len(field.Path) > 1 {
|
||||
for _, pathSegment := range field.Path[1:] {
|
||||
|
@ -381,6 +414,8 @@ func (c memoryExecutorContext) getFieldValue(field parsers.SelectItem, row RowTy
|
|||
switch nestedValue := value.(type) {
|
||||
case map[string]interface{}:
|
||||
value = nestedValue[pathSegment]
|
||||
case RowWithJoins:
|
||||
value = nestedValue[pathSegment]
|
||||
case []int, []string, []interface{}:
|
||||
slice := reflect.ValueOf(nestedValue)
|
||||
if arrayIndex, err := strconv.Atoi(pathSegment); err == nil && slice.Len() > arrayIndex {
|
||||
|
@ -398,7 +433,7 @@ func (c memoryExecutorContext) getFieldValue(field parsers.SelectItem, row RowTy
|
|||
|
||||
func (c memoryExecutorContext) getExpressionParameterValue(
|
||||
parameter interface{},
|
||||
row RowType,
|
||||
row RowWithJoins,
|
||||
) interface{} {
|
||||
switch typedParameter := parameter.(type) {
|
||||
case parsers.SelectItem:
|
||||
|
@ -410,7 +445,7 @@ func (c memoryExecutorContext) getExpressionParameterValue(
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c memoryExecutorContext) orderBy(orderBy []parsers.OrderExpression, data []RowType) {
|
||||
func (c memoryExecutorContext) orderBy(orderBy []parsers.OrderExpression, data []RowWithJoins) {
|
||||
less := func(i, j int) bool {
|
||||
for _, order := range orderBy {
|
||||
val1 := c.getFieldValue(order.SelectItem, data[i])
|
||||
|
@ -430,8 +465,8 @@ func (c memoryExecutorContext) orderBy(orderBy []parsers.OrderExpression, data [
|
|||
sort.SliceStable(data, less)
|
||||
}
|
||||
|
||||
func (c memoryExecutorContext) groupBy(selectStmt parsers.SelectStmt, data []RowType) []RowType {
|
||||
groupedRows := make(map[string][]RowType)
|
||||
func (c memoryExecutorContext) groupBy(selectStmt parsers.SelectStmt, data []RowWithJoins) []RowType {
|
||||
groupedRows := make(map[string][]RowWithJoins)
|
||||
groupedKeys := make([]string, 0)
|
||||
|
||||
// Group rows by group by columns
|
||||
|
@ -454,7 +489,7 @@ func (c memoryExecutorContext) groupBy(selectStmt parsers.SelectStmt, data []Row
|
|||
return aggregatedRows
|
||||
}
|
||||
|
||||
func (c memoryExecutorContext) generateGroupKey(groupByFields []parsers.SelectItem, row RowType) string {
|
||||
func (c memoryExecutorContext) generateGroupKey(groupByFields []parsers.SelectItem, row RowWithJoins) string {
|
||||
var keyBuilder strings.Builder
|
||||
for _, column := range groupByFields {
|
||||
fieldValue := c.getFieldValue(column, row)
|
||||
|
@ -465,7 +500,7 @@ func (c memoryExecutorContext) generateGroupKey(groupByFields []parsers.SelectIt
|
|||
return keyBuilder.String()
|
||||
}
|
||||
|
||||
func (c memoryExecutorContext) aggregateGroup(selectStmt parsers.SelectStmt, groupRows []RowType) RowType {
|
||||
func (c memoryExecutorContext) aggregateGroup(selectStmt parsers.SelectStmt, groupRows []RowWithJoins) RowType {
|
||||
aggregatedRow := c.selectRow(selectStmt.SelectItems, groupRows)
|
||||
|
||||
return aggregatedRow
|
||||
|
@ -553,3 +588,27 @@ func hasAggregateFunctions(selectItems []parsers.SelectItem) bool {
|
|||
|
||||
return false
|
||||
}
|
||||
|
||||
func zipRows(current []RowWithJoins, joinedTableName string, rowsToZip []RowType) []RowWithJoins {
|
||||
resultMap := make([]RowWithJoins, 0)
|
||||
|
||||
for _, currentRow := range current {
|
||||
for _, rowToZip := range rowsToZip {
|
||||
newRow := copyMap(currentRow)
|
||||
newRow[joinedTableName] = rowToZip
|
||||
resultMap = append(resultMap, newRow)
|
||||
}
|
||||
}
|
||||
|
||||
return resultMap
|
||||
}
|
||||
|
||||
func copyMap(originalMap map[string]RowType) map[string]RowType {
|
||||
targetMap := make(map[string]RowType)
|
||||
|
||||
for k, v := range originalMap {
|
||||
targetMap[k] = v
|
||||
}
|
||||
|
||||
return targetMap
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue