diff --git a/api/handlers/documents.go b/api/handlers/documents.go index 91c9021..5f3fa68 100644 --- a/api/handlers/documents.go +++ b/api/handlers/documents.go @@ -11,6 +11,7 @@ import ( apimodels "github.com/pikami/cosmium/api/api_models" "github.com/pikami/cosmium/api/headers" "github.com/pikami/cosmium/internal/constants" + continuationtoken "github.com/pikami/cosmium/internal/continuation_token" "github.com/pikami/cosmium/internal/converters" "github.com/pikami/cosmium/internal/datastore" "github.com/pikami/cosmium/internal/logger" @@ -262,20 +263,50 @@ func (h *Handlers) handleDocumentQuery(c *gin.Context, requestBody map[string]in queryParameters = parametersToMap(paramsArray) } + collection, collectionStatus := h.dataStore.GetCollection(databaseId, collectionId) + if collectionStatus == datastore.StatusNotFound { + c.IndentedJSON(http.StatusNotFound, constants.NotFoundResponse) + return + } + + if collectionStatus != datastore.StatusOk { + c.IndentedJSON(http.StatusInternalServerError, constants.UnknownErrorResponse) + return + } + + continuationToken := continuationtoken.GenerateDefault(collection.ResourceID) + continuationTokenHeader := c.GetHeader(headers.ContinuationToken) + if continuationTokenHeader != "" { + continuationToken = continuationtoken.FromString(continuationTokenHeader) + } + + pageMaxItemCount, maxItemCountError := strconv.Atoi(c.GetHeader(headers.MaxItemCount)) + if maxItemCountError != nil { + pageMaxItemCount = 1000 + } + queryText := requestBody["query"].(string) - docs, status := h.executeQueryDocuments(databaseId, collectionId, queryText, queryParameters) + executeQueryResult, status := h.executeQueryDocuments( + databaseId, collectionId, queryText, queryParameters, pageMaxItemCount, continuationToken.Token.TotalResults) if status != datastore.StatusOk { // TODO: Currently we return everything if the query fails + logger.Infof("Query failed: %s", queryText) h.GetAllDocuments(c) return } - collection, _ := h.dataStore.GetCollection(databaseId, collectionId) - c.Header(headers.ItemCount, fmt.Sprintf("%d", len(docs))) + resultCount := len(executeQueryResult.Rows) + if executeQueryResult.HasMorePages { + nextContinuationToken := continuationtoken.Generate( + collection.ResourceID, continuationToken.Token.PageIndex+1, continuationToken.Token.TotalResults+resultCount) + c.Header(headers.ContinuationToken, nextContinuationToken.ToString()) + } + + c.Header(headers.ItemCount, fmt.Sprintf("%d", resultCount)) c.IndentedJSON(http.StatusOK, gin.H{ "_rid": collection.ResourceID, - "Documents": docs, - "_count": len(docs), + "Documents": executeQueryResult.Rows, + "_count": resultCount, }) } @@ -377,16 +408,23 @@ func dataStoreStatusToResponseCode(status datastore.DataStoreStatus) int { } } -func (h *Handlers) executeQueryDocuments(databaseId string, collectionId string, query string, queryParameters map[string]interface{}) ([]memoryexecutor.RowType, datastore.DataStoreStatus) { +func (h *Handlers) executeQueryDocuments( + databaseId string, + collectionId string, + query string, + queryParameters map[string]interface{}, + pageMaxItemCount int, + pageCursor int, +) (memoryexecutor.ExecuteQueryResult, datastore.DataStoreStatus) { parsedQuery, err := nosql.Parse("", []byte(query)) if err != nil { logger.Errorf("Failed to parse query: %s\nerr: %v", query, err) - return nil, datastore.BadRequest + return memoryexecutor.ExecuteQueryResult{}, datastore.BadRequest } allDocumentsIterator, status := h.dataStore.GetDocumentIterator(databaseId, collectionId) if status != datastore.StatusOk { - return nil, status + return memoryexecutor.ExecuteQueryResult{}, status } defer allDocumentsIterator.Close() @@ -394,8 +432,8 @@ func (h *Handlers) executeQueryDocuments(databaseId string, collectionId string, if typedQuery, ok := parsedQuery.(parsers.SelectStmt); ok { typedQuery.Parameters = queryParameters - return memoryexecutor.ExecuteQuery(typedQuery, rowsIterator), datastore.StatusOk + return memoryexecutor.ExecuteQuery(typedQuery, rowsIterator, pageCursor, pageMaxItemCount), datastore.StatusOk } - return nil, datastore.BadRequest + return memoryexecutor.ExecuteQueryResult{}, datastore.BadRequest } diff --git a/api/headers/headers.go b/api/headers/headers.go index 836d59e..76a5999 100644 --- a/api/headers/headers.go +++ b/api/headers/headers.go @@ -13,6 +13,8 @@ const ( ItemCount = "x-ms-item-count" LSN = "lsn" XDate = "x-ms-date" + MaxItemCount = "x-ms-max-item-count" + ContinuationToken = "x-ms-continuation" // Kinda retarded, but what can I do ¯\_(ツ)_/¯ IsQuery = "x-ms-documentdb-isquery" // Sent from python sdk and web explorer diff --git a/api/tests/documents_test.go b/api/tests/documents_test.go index ab21e5b..8c8cc58 100644 --- a/api/tests/documents_test.go +++ b/api/tests/documents_test.go @@ -14,6 +14,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" "github.com/pikami/cosmium/api/config" + continuationtoken "github.com/pikami/cosmium/internal/continuation_token" "github.com/pikami/cosmium/internal/datastore" "github.com/stretchr/testify/assert" ) @@ -512,4 +513,46 @@ func Test_Documents(t *testing.T) { assert.Equal(t, "67890", itemResponseBody["id"]) }) }) + + runTestsWithPresets(t, "Test_Documents_With_Continuation_Token", presets, func(t *testing.T, ts *TestServer, client *azcosmos.Client) { + collectionClient := documents_InitializeDb(t, ts) + + t.Run("Should query document with continuation token", func(t *testing.T) { + context := context.TODO() + pager := collectionClient.NewQueryItemsPager( + "SELECT c.id, c[\"pk\"] FROM c ORDER BY c.id", + azcosmos.PartitionKey{}, + &azcosmos.QueryOptions{ + PageSizeHint: 1, + }) + + assert.True(t, pager.More()) + + firstResponse, err := pager.NextPage(context) + assert.Nil(t, err) + assert.Equal(t, 1, len(firstResponse.Items)) + var firstItem map[string]interface{} + err = json.Unmarshal(firstResponse.Items[0], &firstItem) + assert.Nil(t, err) + assert.Equal(t, "12345", firstItem["id"]) + assert.Equal(t, "123", firstItem["pk"]) + + firstContinuationToken := continuationtoken.FromString(*firstResponse.ContinuationToken) + assert.Equal(t, 1, firstContinuationToken.Token.PageIndex) + assert.Equal(t, 1, firstContinuationToken.Token.TotalResults) + + assert.True(t, pager.More()) + secondResponse, err := pager.NextPage(context) + assert.Nil(t, err) + assert.Equal(t, 1, len(secondResponse.Items)) + var secondItem map[string]interface{} + err = json.Unmarshal(secondResponse.Items[0], &secondItem) + assert.Nil(t, err) + assert.Equal(t, "67890", secondItem["id"]) + assert.Equal(t, "456", secondItem["pk"]) + assert.Nil(t, secondResponse.ContinuationToken) + + assert.False(t, pager.More()) + }) + }) } diff --git a/internal/continuation_token/continuation_token.go b/internal/continuation_token/continuation_token.go new file mode 100644 index 0000000..d0671fa --- /dev/null +++ b/internal/continuation_token/continuation_token.go @@ -0,0 +1,145 @@ +package continuationtoken + +import ( + "encoding/json" + "fmt" + "strconv" + "strings" + + "github.com/pikami/cosmium/internal/logger" +) + +type ContinuationTokenExternal struct { + Token string `json:"token"` + Range struct { + Min string `json:"min"` + Max string `json:"max"` + } `json:"range"` +} + +type ContinuationToken struct { + Token struct { + ResourceId string // RID + PageIndex int // RT + TotalResults int // TRC + ISV int // ISV + IEO int // IEO + QCF int // QCF + LR int // LR + } + Range struct { + Min string + Max string + } +} + +func Generate(resourceid string, pageIndex int, totalResults int) ContinuationToken { + ct := ContinuationToken{} + ct.Token.ResourceId = resourceid + ct.Token.PageIndex = pageIndex + ct.Token.TotalResults = totalResults + ct.Token.ISV = 2 + ct.Token.IEO = 65567 + ct.Token.QCF = 8 + ct.Token.LR = 1 + ct.Range.Min = "" + ct.Range.Max = "FF" + + return ct +} + +func GenerateDefault(resourceid string) ContinuationToken { + return Generate(resourceid, 0, 0) +} + +func (ct *ContinuationToken) ToString() string { + token := fmt.Sprintf( + "-RID:~%s#RT:%d#TRC:%d#ISV:%d#IEO:%d#QCF:%d#LR:%d", + ct.Token.ResourceId, + ct.Token.PageIndex, + ct.Token.TotalResults, + ct.Token.ISV, + ct.Token.IEO, + ct.Token.QCF, + ct.Token.LR, + ) + + ect := ContinuationTokenExternal{} + ect.Token = token + ect.Range.Min = ct.Range.Min + ect.Range.Max = ct.Range.Max + + json, err := json.Marshal(ect) + if err != nil { + logger.Error(err, "failed to marshal continuation token") + return "" + } + + return string(json) +} + +func FromString(token string) ContinuationToken { + ect := ContinuationTokenExternal{} + err := json.Unmarshal([]byte(token), &ect) + if err != nil { + logger.Error(err, "failed to unmarshal continuation token") + return ContinuationToken{} + } + + ct, err := parseContinuationToken(ect.Token, ect.Range.Min, ect.Range.Max) + if err != nil { + logger.Error(err, "failed to parse continuation token") + return ContinuationToken{} + } + + return *ct +} + +func parseContinuationToken(token string, minRange string, maxRange string) (*ContinuationToken, error) { + const prefix = "-RID:~" + if !strings.HasPrefix(token, prefix) { + return nil, fmt.Errorf("invalid token prefix") + } + + parts := strings.Split(token[len(prefix):], "#") + if len(parts) != 7 { + return nil, fmt.Errorf("invalid token format: expected 7 fields, got %d", len(parts)) + } + + ct := &ContinuationToken{} + + ct.Token.ResourceId = parts[0] + + parseIntField := func(part, key string) (int, error) { + if !strings.HasPrefix(part, key+":") { + return 0, fmt.Errorf("expected %s field", key) + } + return strconv.Atoi(strings.TrimPrefix(part, key+":")) + } + + var err error + + if ct.Token.PageIndex, err = parseIntField(parts[1], "RT"); err != nil { + return nil, err + } + if ct.Token.TotalResults, err = parseIntField(parts[2], "TRC"); err != nil { + return nil, err + } + if ct.Token.ISV, err = parseIntField(parts[3], "ISV"); err != nil { + return nil, err + } + if ct.Token.IEO, err = parseIntField(parts[4], "IEO"); err != nil { + return nil, err + } + if ct.Token.QCF, err = parseIntField(parts[5], "QCF"); err != nil { + return nil, err + } + if ct.Token.LR, err = parseIntField(parts[6], "LR"); err != nil { + return nil, err + } + + ct.Range.Min = minRange + ct.Range.Max = maxRange + + return ct, nil +} diff --git a/internal/continuation_token/continuation_token_test.go b/internal/continuation_token/continuation_token_test.go new file mode 100644 index 0000000..1ba3d5d --- /dev/null +++ b/internal/continuation_token/continuation_token_test.go @@ -0,0 +1,35 @@ +package continuationtoken + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_Generate(t *testing.T) { + token := Generate("test-resource-id", 1, 100) + + assert.Equal(t, "test-resource-id", token.Token.ResourceId) + assert.Equal(t, 1, token.Token.PageIndex) + assert.Equal(t, 100, token.Token.TotalResults) +} + +func Test_FromString(t *testing.T) { + token := FromString("{\"token\":\"-RID:~test-resource-id#RT:1#TRC:100#ISV:2#IEO:65567#QCF:8#LR:1\",\"range\":{\"min\":\"\",\"max\":\"FF\"}}") + + assert.Equal(t, "test-resource-id", token.Token.ResourceId) + assert.Equal(t, 1, token.Token.PageIndex) + assert.Equal(t, 100, token.Token.TotalResults) +} + +func Test_ToString(t *testing.T) { + token := Generate("test-resource-id", 1, 100) + assert.Equal(t, "{\"token\":\"-RID:~test-resource-id#RT:1#TRC:100#ISV:2#IEO:65567#QCF:8#LR:1\",\"range\":{\"min\":\"\",\"max\":\"FF\"}}", token.ToString()) +} + +func Test_GenerateDefault(t *testing.T) { + token := GenerateDefault("test-resource-id") + assert.Equal(t, "test-resource-id", token.Token.ResourceId) + assert.Equal(t, 0, token.Token.PageIndex) + assert.Equal(t, 0, token.Token.TotalResults) +} diff --git a/query_executors/memory_executor/memory_executor.go b/query_executors/memory_executor/memory_executor.go index 89fd362..7905c22 100644 --- a/query_executors/memory_executor/memory_executor.go +++ b/query_executors/memory_executor/memory_executor.go @@ -5,18 +5,46 @@ import ( "github.com/pikami/cosmium/parsers" ) -func ExecuteQuery(query parsers.SelectStmt, documents rowTypeIterator) []RowType { +type ExecuteQueryResult struct { + Rows []RowType + HasMorePages bool +} + +func ExecuteQuery( + query parsers.SelectStmt, + documents rowTypeIterator, + offset int, + limit int, +) ExecuteQueryResult { resultIter := executeQuery(query, &rowTypeToRowContextIterator{documents: documents, query: query}) - result := make([]RowType, 0) - for { + + result := &ExecuteQueryResult{ + Rows: make([]RowType, 0), + HasMorePages: false, + } + + for i := 0; i < offset; i++ { + _, status := resultIter.Next() + if status != datastore.StatusOk { + break + } + } + + for i := 0; i < limit; i++ { row, status := resultIter.Next() if status != datastore.StatusOk { break } - result = append(result, row) + result.Rows = append(result.Rows, row) } - return result + + _, status := resultIter.Next() + if status == datastore.StatusOk { + result.HasMorePages = true + } + + return *result } func executeQuery(query parsers.SelectStmt, documents rowIterator) rowTypeIterator { diff --git a/query_executors/memory_executor/misc_test.go b/query_executors/memory_executor/misc_test.go index f81f09d..ee8f9d9 100644 --- a/query_executors/memory_executor/misc_test.go +++ b/query_executors/memory_executor/misc_test.go @@ -38,10 +38,10 @@ func testQueryExecute( expectedData []memoryexecutor.RowType, ) { iter := NewTestDocumentIterator(data) - result := memoryexecutor.ExecuteQuery(query, iter) + result := memoryexecutor.ExecuteQuery(query, iter, 0, 1000) - if !reflect.DeepEqual(result, expectedData) { - t.Errorf("execution result does not match expected data.\nExpected: %+v\nGot: %+v", expectedData, result) + if !reflect.DeepEqual(result.Rows, expectedData) { + t.Errorf("execution result does not match expected data.\nExpected: %+v\nGot: %+v", expectedData, result.Rows) } }