Implement continuation tokens

This commit is contained in:
Pijus Kamandulis
2026-01-29 21:45:46 +02:00
parent cae6fda95c
commit d3d238fa98
7 changed files with 309 additions and 18 deletions

View File

@@ -11,6 +11,7 @@ import (
apimodels "github.com/pikami/cosmium/api/api_models"
"github.com/pikami/cosmium/api/headers"
"github.com/pikami/cosmium/internal/constants"
continuationtoken "github.com/pikami/cosmium/internal/continuation_token"
"github.com/pikami/cosmium/internal/converters"
"github.com/pikami/cosmium/internal/datastore"
"github.com/pikami/cosmium/internal/logger"
@@ -262,20 +263,50 @@ func (h *Handlers) handleDocumentQuery(c *gin.Context, requestBody map[string]in
queryParameters = parametersToMap(paramsArray)
}
collection, collectionStatus := h.dataStore.GetCollection(databaseId, collectionId)
if collectionStatus == datastore.StatusNotFound {
c.IndentedJSON(http.StatusNotFound, constants.NotFoundResponse)
return
}
if collectionStatus != datastore.StatusOk {
c.IndentedJSON(http.StatusInternalServerError, constants.UnknownErrorResponse)
return
}
continuationToken := continuationtoken.GenerateDefault(collection.ResourceID)
continuationTokenHeader := c.GetHeader(headers.ContinuationToken)
if continuationTokenHeader != "" {
continuationToken = continuationtoken.FromString(continuationTokenHeader)
}
pageMaxItemCount, maxItemCountError := strconv.Atoi(c.GetHeader(headers.MaxItemCount))
if maxItemCountError != nil {
pageMaxItemCount = 1000
}
queryText := requestBody["query"].(string)
docs, status := h.executeQueryDocuments(databaseId, collectionId, queryText, queryParameters)
executeQueryResult, status := h.executeQueryDocuments(
databaseId, collectionId, queryText, queryParameters, pageMaxItemCount, continuationToken.Token.TotalResults)
if status != datastore.StatusOk {
// TODO: Currently we return everything if the query fails
logger.Infof("Query failed: %s", queryText)
h.GetAllDocuments(c)
return
}
collection, _ := h.dataStore.GetCollection(databaseId, collectionId)
c.Header(headers.ItemCount, fmt.Sprintf("%d", len(docs)))
resultCount := len(executeQueryResult.Rows)
if executeQueryResult.HasMorePages {
nextContinuationToken := continuationtoken.Generate(
collection.ResourceID, continuationToken.Token.PageIndex+1, continuationToken.Token.TotalResults+resultCount)
c.Header(headers.ContinuationToken, nextContinuationToken.ToString())
}
c.Header(headers.ItemCount, fmt.Sprintf("%d", resultCount))
c.IndentedJSON(http.StatusOK, gin.H{
"_rid": collection.ResourceID,
"Documents": docs,
"_count": len(docs),
"Documents": executeQueryResult.Rows,
"_count": resultCount,
})
}
@@ -377,16 +408,23 @@ func dataStoreStatusToResponseCode(status datastore.DataStoreStatus) int {
}
}
func (h *Handlers) executeQueryDocuments(databaseId string, collectionId string, query string, queryParameters map[string]interface{}) ([]memoryexecutor.RowType, datastore.DataStoreStatus) {
func (h *Handlers) executeQueryDocuments(
databaseId string,
collectionId string,
query string,
queryParameters map[string]interface{},
pageMaxItemCount int,
pageCursor int,
) (memoryexecutor.ExecuteQueryResult, datastore.DataStoreStatus) {
parsedQuery, err := nosql.Parse("", []byte(query))
if err != nil {
logger.Errorf("Failed to parse query: %s\nerr: %v", query, err)
return nil, datastore.BadRequest
return memoryexecutor.ExecuteQueryResult{}, datastore.BadRequest
}
allDocumentsIterator, status := h.dataStore.GetDocumentIterator(databaseId, collectionId)
if status != datastore.StatusOk {
return nil, status
return memoryexecutor.ExecuteQueryResult{}, status
}
defer allDocumentsIterator.Close()
@@ -394,8 +432,8 @@ func (h *Handlers) executeQueryDocuments(databaseId string, collectionId string,
if typedQuery, ok := parsedQuery.(parsers.SelectStmt); ok {
typedQuery.Parameters = queryParameters
return memoryexecutor.ExecuteQuery(typedQuery, rowsIterator), datastore.StatusOk
return memoryexecutor.ExecuteQuery(typedQuery, rowsIterator, pageCursor, pageMaxItemCount), datastore.StatusOk
}
return nil, datastore.BadRequest
return memoryexecutor.ExecuteQueryResult{}, datastore.BadRequest
}

View File

@@ -13,6 +13,8 @@ const (
ItemCount = "x-ms-item-count"
LSN = "lsn"
XDate = "x-ms-date"
MaxItemCount = "x-ms-max-item-count"
ContinuationToken = "x-ms-continuation"
// Kinda retarded, but what can I do ¯\_(ツ)_/¯
IsQuery = "x-ms-documentdb-isquery" // Sent from python sdk and web explorer

View File

@@ -14,6 +14,7 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"
"github.com/pikami/cosmium/api/config"
continuationtoken "github.com/pikami/cosmium/internal/continuation_token"
"github.com/pikami/cosmium/internal/datastore"
"github.com/stretchr/testify/assert"
)
@@ -512,4 +513,46 @@ func Test_Documents(t *testing.T) {
assert.Equal(t, "67890", itemResponseBody["id"])
})
})
runTestsWithPresets(t, "Test_Documents_With_Continuation_Token", presets, func(t *testing.T, ts *TestServer, client *azcosmos.Client) {
collectionClient := documents_InitializeDb(t, ts)
t.Run("Should query document with continuation token", func(t *testing.T) {
context := context.TODO()
pager := collectionClient.NewQueryItemsPager(
"SELECT c.id, c[\"pk\"] FROM c ORDER BY c.id",
azcosmos.PartitionKey{},
&azcosmos.QueryOptions{
PageSizeHint: 1,
})
assert.True(t, pager.More())
firstResponse, err := pager.NextPage(context)
assert.Nil(t, err)
assert.Equal(t, 1, len(firstResponse.Items))
var firstItem map[string]interface{}
err = json.Unmarshal(firstResponse.Items[0], &firstItem)
assert.Nil(t, err)
assert.Equal(t, "12345", firstItem["id"])
assert.Equal(t, "123", firstItem["pk"])
firstContinuationToken := continuationtoken.FromString(*firstResponse.ContinuationToken)
assert.Equal(t, 1, firstContinuationToken.Token.PageIndex)
assert.Equal(t, 1, firstContinuationToken.Token.TotalResults)
assert.True(t, pager.More())
secondResponse, err := pager.NextPage(context)
assert.Nil(t, err)
assert.Equal(t, 1, len(secondResponse.Items))
var secondItem map[string]interface{}
err = json.Unmarshal(secondResponse.Items[0], &secondItem)
assert.Nil(t, err)
assert.Equal(t, "67890", secondItem["id"])
assert.Equal(t, "456", secondItem["pk"])
assert.Nil(t, secondResponse.ContinuationToken)
assert.False(t, pager.More())
})
})
}