Implement continuation tokens

This commit is contained in:
Pijus Kamandulis
2026-01-29 21:45:46 +02:00
parent cae6fda95c
commit d3d238fa98
7 changed files with 309 additions and 18 deletions

View File

@@ -11,6 +11,7 @@ import (
apimodels "github.com/pikami/cosmium/api/api_models"
"github.com/pikami/cosmium/api/headers"
"github.com/pikami/cosmium/internal/constants"
continuationtoken "github.com/pikami/cosmium/internal/continuation_token"
"github.com/pikami/cosmium/internal/converters"
"github.com/pikami/cosmium/internal/datastore"
"github.com/pikami/cosmium/internal/logger"
@@ -262,20 +263,50 @@ func (h *Handlers) handleDocumentQuery(c *gin.Context, requestBody map[string]in
queryParameters = parametersToMap(paramsArray)
}
collection, collectionStatus := h.dataStore.GetCollection(databaseId, collectionId)
if collectionStatus == datastore.StatusNotFound {
c.IndentedJSON(http.StatusNotFound, constants.NotFoundResponse)
return
}
if collectionStatus != datastore.StatusOk {
c.IndentedJSON(http.StatusInternalServerError, constants.UnknownErrorResponse)
return
}
continuationToken := continuationtoken.GenerateDefault(collection.ResourceID)
continuationTokenHeader := c.GetHeader(headers.ContinuationToken)
if continuationTokenHeader != "" {
continuationToken = continuationtoken.FromString(continuationTokenHeader)
}
pageMaxItemCount, maxItemCountError := strconv.Atoi(c.GetHeader(headers.MaxItemCount))
if maxItemCountError != nil {
pageMaxItemCount = 1000
}
queryText := requestBody["query"].(string)
docs, status := h.executeQueryDocuments(databaseId, collectionId, queryText, queryParameters)
executeQueryResult, status := h.executeQueryDocuments(
databaseId, collectionId, queryText, queryParameters, pageMaxItemCount, continuationToken.Token.TotalResults)
if status != datastore.StatusOk {
// TODO: Currently we return everything if the query fails
logger.Infof("Query failed: %s", queryText)
h.GetAllDocuments(c)
return
}
collection, _ := h.dataStore.GetCollection(databaseId, collectionId)
c.Header(headers.ItemCount, fmt.Sprintf("%d", len(docs)))
resultCount := len(executeQueryResult.Rows)
if executeQueryResult.HasMorePages {
nextContinuationToken := continuationtoken.Generate(
collection.ResourceID, continuationToken.Token.PageIndex+1, continuationToken.Token.TotalResults+resultCount)
c.Header(headers.ContinuationToken, nextContinuationToken.ToString())
}
c.Header(headers.ItemCount, fmt.Sprintf("%d", resultCount))
c.IndentedJSON(http.StatusOK, gin.H{
"_rid": collection.ResourceID,
"Documents": docs,
"_count": len(docs),
"Documents": executeQueryResult.Rows,
"_count": resultCount,
})
}
@@ -377,16 +408,23 @@ func dataStoreStatusToResponseCode(status datastore.DataStoreStatus) int {
}
}
func (h *Handlers) executeQueryDocuments(databaseId string, collectionId string, query string, queryParameters map[string]interface{}) ([]memoryexecutor.RowType, datastore.DataStoreStatus) {
func (h *Handlers) executeQueryDocuments(
databaseId string,
collectionId string,
query string,
queryParameters map[string]interface{},
pageMaxItemCount int,
pageCursor int,
) (memoryexecutor.ExecuteQueryResult, datastore.DataStoreStatus) {
parsedQuery, err := nosql.Parse("", []byte(query))
if err != nil {
logger.Errorf("Failed to parse query: %s\nerr: %v", query, err)
return nil, datastore.BadRequest
return memoryexecutor.ExecuteQueryResult{}, datastore.BadRequest
}
allDocumentsIterator, status := h.dataStore.GetDocumentIterator(databaseId, collectionId)
if status != datastore.StatusOk {
return nil, status
return memoryexecutor.ExecuteQueryResult{}, status
}
defer allDocumentsIterator.Close()
@@ -394,8 +432,8 @@ func (h *Handlers) executeQueryDocuments(databaseId string, collectionId string,
if typedQuery, ok := parsedQuery.(parsers.SelectStmt); ok {
typedQuery.Parameters = queryParameters
return memoryexecutor.ExecuteQuery(typedQuery, rowsIterator), datastore.StatusOk
return memoryexecutor.ExecuteQuery(typedQuery, rowsIterator, pageCursor, pageMaxItemCount), datastore.StatusOk
}
return nil, datastore.BadRequest
return memoryexecutor.ExecuteQueryResult{}, datastore.BadRequest
}

View File

@@ -13,6 +13,8 @@ const (
ItemCount = "x-ms-item-count"
LSN = "lsn"
XDate = "x-ms-date"
MaxItemCount = "x-ms-max-item-count"
ContinuationToken = "x-ms-continuation"
// Kinda retarded, but what can I do ¯\_(ツ)_/¯
IsQuery = "x-ms-documentdb-isquery" // Sent from python sdk and web explorer

View File

@@ -14,6 +14,7 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"
"github.com/pikami/cosmium/api/config"
continuationtoken "github.com/pikami/cosmium/internal/continuation_token"
"github.com/pikami/cosmium/internal/datastore"
"github.com/stretchr/testify/assert"
)
@@ -512,4 +513,46 @@ func Test_Documents(t *testing.T) {
assert.Equal(t, "67890", itemResponseBody["id"])
})
})
runTestsWithPresets(t, "Test_Documents_With_Continuation_Token", presets, func(t *testing.T, ts *TestServer, client *azcosmos.Client) {
collectionClient := documents_InitializeDb(t, ts)
t.Run("Should query document with continuation token", func(t *testing.T) {
context := context.TODO()
pager := collectionClient.NewQueryItemsPager(
"SELECT c.id, c[\"pk\"] FROM c ORDER BY c.id",
azcosmos.PartitionKey{},
&azcosmos.QueryOptions{
PageSizeHint: 1,
})
assert.True(t, pager.More())
firstResponse, err := pager.NextPage(context)
assert.Nil(t, err)
assert.Equal(t, 1, len(firstResponse.Items))
var firstItem map[string]interface{}
err = json.Unmarshal(firstResponse.Items[0], &firstItem)
assert.Nil(t, err)
assert.Equal(t, "12345", firstItem["id"])
assert.Equal(t, "123", firstItem["pk"])
firstContinuationToken := continuationtoken.FromString(*firstResponse.ContinuationToken)
assert.Equal(t, 1, firstContinuationToken.Token.PageIndex)
assert.Equal(t, 1, firstContinuationToken.Token.TotalResults)
assert.True(t, pager.More())
secondResponse, err := pager.NextPage(context)
assert.Nil(t, err)
assert.Equal(t, 1, len(secondResponse.Items))
var secondItem map[string]interface{}
err = json.Unmarshal(secondResponse.Items[0], &secondItem)
assert.Nil(t, err)
assert.Equal(t, "67890", secondItem["id"])
assert.Equal(t, "456", secondItem["pk"])
assert.Nil(t, secondResponse.ContinuationToken)
assert.False(t, pager.More())
})
})
}

View File

@@ -0,0 +1,145 @@
package continuationtoken
import (
"encoding/json"
"fmt"
"strconv"
"strings"
"github.com/pikami/cosmium/internal/logger"
)
type ContinuationTokenExternal struct {
Token string `json:"token"`
Range struct {
Min string `json:"min"`
Max string `json:"max"`
} `json:"range"`
}
type ContinuationToken struct {
Token struct {
ResourceId string // RID
PageIndex int // RT
TotalResults int // TRC
ISV int // ISV
IEO int // IEO
QCF int // QCF
LR int // LR
}
Range struct {
Min string
Max string
}
}
func Generate(resourceid string, pageIndex int, totalResults int) ContinuationToken {
ct := ContinuationToken{}
ct.Token.ResourceId = resourceid
ct.Token.PageIndex = pageIndex
ct.Token.TotalResults = totalResults
ct.Token.ISV = 2
ct.Token.IEO = 65567
ct.Token.QCF = 8
ct.Token.LR = 1
ct.Range.Min = ""
ct.Range.Max = "FF"
return ct
}
func GenerateDefault(resourceid string) ContinuationToken {
return Generate(resourceid, 0, 0)
}
func (ct *ContinuationToken) ToString() string {
token := fmt.Sprintf(
"-RID:~%s#RT:%d#TRC:%d#ISV:%d#IEO:%d#QCF:%d#LR:%d",
ct.Token.ResourceId,
ct.Token.PageIndex,
ct.Token.TotalResults,
ct.Token.ISV,
ct.Token.IEO,
ct.Token.QCF,
ct.Token.LR,
)
ect := ContinuationTokenExternal{}
ect.Token = token
ect.Range.Min = ct.Range.Min
ect.Range.Max = ct.Range.Max
json, err := json.Marshal(ect)
if err != nil {
logger.Error(err, "failed to marshal continuation token")
return ""
}
return string(json)
}
func FromString(token string) ContinuationToken {
ect := ContinuationTokenExternal{}
err := json.Unmarshal([]byte(token), &ect)
if err != nil {
logger.Error(err, "failed to unmarshal continuation token")
return ContinuationToken{}
}
ct, err := parseContinuationToken(ect.Token, ect.Range.Min, ect.Range.Max)
if err != nil {
logger.Error(err, "failed to parse continuation token")
return ContinuationToken{}
}
return *ct
}
func parseContinuationToken(token string, minRange string, maxRange string) (*ContinuationToken, error) {
const prefix = "-RID:~"
if !strings.HasPrefix(token, prefix) {
return nil, fmt.Errorf("invalid token prefix")
}
parts := strings.Split(token[len(prefix):], "#")
if len(parts) != 7 {
return nil, fmt.Errorf("invalid token format: expected 7 fields, got %d", len(parts))
}
ct := &ContinuationToken{}
ct.Token.ResourceId = parts[0]
parseIntField := func(part, key string) (int, error) {
if !strings.HasPrefix(part, key+":") {
return 0, fmt.Errorf("expected %s field", key)
}
return strconv.Atoi(strings.TrimPrefix(part, key+":"))
}
var err error
if ct.Token.PageIndex, err = parseIntField(parts[1], "RT"); err != nil {
return nil, err
}
if ct.Token.TotalResults, err = parseIntField(parts[2], "TRC"); err != nil {
return nil, err
}
if ct.Token.ISV, err = parseIntField(parts[3], "ISV"); err != nil {
return nil, err
}
if ct.Token.IEO, err = parseIntField(parts[4], "IEO"); err != nil {
return nil, err
}
if ct.Token.QCF, err = parseIntField(parts[5], "QCF"); err != nil {
return nil, err
}
if ct.Token.LR, err = parseIntField(parts[6], "LR"); err != nil {
return nil, err
}
ct.Range.Min = minRange
ct.Range.Max = maxRange
return ct, nil
}

View File

@@ -0,0 +1,35 @@
package continuationtoken
import (
"testing"
"github.com/stretchr/testify/assert"
)
func Test_Generate(t *testing.T) {
token := Generate("test-resource-id", 1, 100)
assert.Equal(t, "test-resource-id", token.Token.ResourceId)
assert.Equal(t, 1, token.Token.PageIndex)
assert.Equal(t, 100, token.Token.TotalResults)
}
func Test_FromString(t *testing.T) {
token := FromString("{\"token\":\"-RID:~test-resource-id#RT:1#TRC:100#ISV:2#IEO:65567#QCF:8#LR:1\",\"range\":{\"min\":\"\",\"max\":\"FF\"}}")
assert.Equal(t, "test-resource-id", token.Token.ResourceId)
assert.Equal(t, 1, token.Token.PageIndex)
assert.Equal(t, 100, token.Token.TotalResults)
}
func Test_ToString(t *testing.T) {
token := Generate("test-resource-id", 1, 100)
assert.Equal(t, "{\"token\":\"-RID:~test-resource-id#RT:1#TRC:100#ISV:2#IEO:65567#QCF:8#LR:1\",\"range\":{\"min\":\"\",\"max\":\"FF\"}}", token.ToString())
}
func Test_GenerateDefault(t *testing.T) {
token := GenerateDefault("test-resource-id")
assert.Equal(t, "test-resource-id", token.Token.ResourceId)
assert.Equal(t, 0, token.Token.PageIndex)
assert.Equal(t, 0, token.Token.TotalResults)
}

View File

@@ -5,18 +5,46 @@ import (
"github.com/pikami/cosmium/parsers"
)
func ExecuteQuery(query parsers.SelectStmt, documents rowTypeIterator) []RowType {
type ExecuteQueryResult struct {
Rows []RowType
HasMorePages bool
}
func ExecuteQuery(
query parsers.SelectStmt,
documents rowTypeIterator,
offset int,
limit int,
) ExecuteQueryResult {
resultIter := executeQuery(query, &rowTypeToRowContextIterator{documents: documents, query: query})
result := make([]RowType, 0)
for {
result := &ExecuteQueryResult{
Rows: make([]RowType, 0),
HasMorePages: false,
}
for i := 0; i < offset; i++ {
_, status := resultIter.Next()
if status != datastore.StatusOk {
break
}
}
for i := 0; i < limit; i++ {
row, status := resultIter.Next()
if status != datastore.StatusOk {
break
}
result = append(result, row)
result.Rows = append(result.Rows, row)
}
return result
_, status := resultIter.Next()
if status == datastore.StatusOk {
result.HasMorePages = true
}
return *result
}
func executeQuery(query parsers.SelectStmt, documents rowIterator) rowTypeIterator {

View File

@@ -38,10 +38,10 @@ func testQueryExecute(
expectedData []memoryexecutor.RowType,
) {
iter := NewTestDocumentIterator(data)
result := memoryexecutor.ExecuteQuery(query, iter)
result := memoryexecutor.ExecuteQuery(query, iter, 0, 1000)
if !reflect.DeepEqual(result, expectedData) {
t.Errorf("execution result does not match expected data.\nExpected: %+v\nGot: %+v", expectedData, result)
if !reflect.DeepEqual(result.Rows, expectedData) {
t.Errorf("execution result does not match expected data.\nExpected: %+v\nGot: %+v", expectedData, result.Rows)
}
}