mirror of
https://github.com/pikami/cosmium.git
synced 2026-01-30 14:53:00 +00:00
Implement continuation tokens
This commit is contained in:
@@ -11,6 +11,7 @@ import (
|
||||
apimodels "github.com/pikami/cosmium/api/api_models"
|
||||
"github.com/pikami/cosmium/api/headers"
|
||||
"github.com/pikami/cosmium/internal/constants"
|
||||
continuationtoken "github.com/pikami/cosmium/internal/continuation_token"
|
||||
"github.com/pikami/cosmium/internal/converters"
|
||||
"github.com/pikami/cosmium/internal/datastore"
|
||||
"github.com/pikami/cosmium/internal/logger"
|
||||
@@ -262,20 +263,50 @@ func (h *Handlers) handleDocumentQuery(c *gin.Context, requestBody map[string]in
|
||||
queryParameters = parametersToMap(paramsArray)
|
||||
}
|
||||
|
||||
collection, collectionStatus := h.dataStore.GetCollection(databaseId, collectionId)
|
||||
if collectionStatus == datastore.StatusNotFound {
|
||||
c.IndentedJSON(http.StatusNotFound, constants.NotFoundResponse)
|
||||
return
|
||||
}
|
||||
|
||||
if collectionStatus != datastore.StatusOk {
|
||||
c.IndentedJSON(http.StatusInternalServerError, constants.UnknownErrorResponse)
|
||||
return
|
||||
}
|
||||
|
||||
continuationToken := continuationtoken.GenerateDefault(collection.ResourceID)
|
||||
continuationTokenHeader := c.GetHeader(headers.ContinuationToken)
|
||||
if continuationTokenHeader != "" {
|
||||
continuationToken = continuationtoken.FromString(continuationTokenHeader)
|
||||
}
|
||||
|
||||
pageMaxItemCount, maxItemCountError := strconv.Atoi(c.GetHeader(headers.MaxItemCount))
|
||||
if maxItemCountError != nil {
|
||||
pageMaxItemCount = 1000
|
||||
}
|
||||
|
||||
queryText := requestBody["query"].(string)
|
||||
docs, status := h.executeQueryDocuments(databaseId, collectionId, queryText, queryParameters)
|
||||
executeQueryResult, status := h.executeQueryDocuments(
|
||||
databaseId, collectionId, queryText, queryParameters, pageMaxItemCount, continuationToken.Token.TotalResults)
|
||||
if status != datastore.StatusOk {
|
||||
// TODO: Currently we return everything if the query fails
|
||||
logger.Infof("Query failed: %s", queryText)
|
||||
h.GetAllDocuments(c)
|
||||
return
|
||||
}
|
||||
|
||||
collection, _ := h.dataStore.GetCollection(databaseId, collectionId)
|
||||
c.Header(headers.ItemCount, fmt.Sprintf("%d", len(docs)))
|
||||
resultCount := len(executeQueryResult.Rows)
|
||||
if executeQueryResult.HasMorePages {
|
||||
nextContinuationToken := continuationtoken.Generate(
|
||||
collection.ResourceID, continuationToken.Token.PageIndex+1, continuationToken.Token.TotalResults+resultCount)
|
||||
c.Header(headers.ContinuationToken, nextContinuationToken.ToString())
|
||||
}
|
||||
|
||||
c.Header(headers.ItemCount, fmt.Sprintf("%d", resultCount))
|
||||
c.IndentedJSON(http.StatusOK, gin.H{
|
||||
"_rid": collection.ResourceID,
|
||||
"Documents": docs,
|
||||
"_count": len(docs),
|
||||
"Documents": executeQueryResult.Rows,
|
||||
"_count": resultCount,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -377,16 +408,23 @@ func dataStoreStatusToResponseCode(status datastore.DataStoreStatus) int {
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handlers) executeQueryDocuments(databaseId string, collectionId string, query string, queryParameters map[string]interface{}) ([]memoryexecutor.RowType, datastore.DataStoreStatus) {
|
||||
func (h *Handlers) executeQueryDocuments(
|
||||
databaseId string,
|
||||
collectionId string,
|
||||
query string,
|
||||
queryParameters map[string]interface{},
|
||||
pageMaxItemCount int,
|
||||
pageCursor int,
|
||||
) (memoryexecutor.ExecuteQueryResult, datastore.DataStoreStatus) {
|
||||
parsedQuery, err := nosql.Parse("", []byte(query))
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to parse query: %s\nerr: %v", query, err)
|
||||
return nil, datastore.BadRequest
|
||||
return memoryexecutor.ExecuteQueryResult{}, datastore.BadRequest
|
||||
}
|
||||
|
||||
allDocumentsIterator, status := h.dataStore.GetDocumentIterator(databaseId, collectionId)
|
||||
if status != datastore.StatusOk {
|
||||
return nil, status
|
||||
return memoryexecutor.ExecuteQueryResult{}, status
|
||||
}
|
||||
defer allDocumentsIterator.Close()
|
||||
|
||||
@@ -394,8 +432,8 @@ func (h *Handlers) executeQueryDocuments(databaseId string, collectionId string,
|
||||
|
||||
if typedQuery, ok := parsedQuery.(parsers.SelectStmt); ok {
|
||||
typedQuery.Parameters = queryParameters
|
||||
return memoryexecutor.ExecuteQuery(typedQuery, rowsIterator), datastore.StatusOk
|
||||
return memoryexecutor.ExecuteQuery(typedQuery, rowsIterator, pageCursor, pageMaxItemCount), datastore.StatusOk
|
||||
}
|
||||
|
||||
return nil, datastore.BadRequest
|
||||
return memoryexecutor.ExecuteQueryResult{}, datastore.BadRequest
|
||||
}
|
||||
|
||||
@@ -13,6 +13,8 @@ const (
|
||||
ItemCount = "x-ms-item-count"
|
||||
LSN = "lsn"
|
||||
XDate = "x-ms-date"
|
||||
MaxItemCount = "x-ms-max-item-count"
|
||||
ContinuationToken = "x-ms-continuation"
|
||||
|
||||
// Kinda retarded, but what can I do ¯\_(ツ)_/¯
|
||||
IsQuery = "x-ms-documentdb-isquery" // Sent from python sdk and web explorer
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"
|
||||
"github.com/pikami/cosmium/api/config"
|
||||
continuationtoken "github.com/pikami/cosmium/internal/continuation_token"
|
||||
"github.com/pikami/cosmium/internal/datastore"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
@@ -512,4 +513,46 @@ func Test_Documents(t *testing.T) {
|
||||
assert.Equal(t, "67890", itemResponseBody["id"])
|
||||
})
|
||||
})
|
||||
|
||||
runTestsWithPresets(t, "Test_Documents_With_Continuation_Token", presets, func(t *testing.T, ts *TestServer, client *azcosmos.Client) {
|
||||
collectionClient := documents_InitializeDb(t, ts)
|
||||
|
||||
t.Run("Should query document with continuation token", func(t *testing.T) {
|
||||
context := context.TODO()
|
||||
pager := collectionClient.NewQueryItemsPager(
|
||||
"SELECT c.id, c[\"pk\"] FROM c ORDER BY c.id",
|
||||
azcosmos.PartitionKey{},
|
||||
&azcosmos.QueryOptions{
|
||||
PageSizeHint: 1,
|
||||
})
|
||||
|
||||
assert.True(t, pager.More())
|
||||
|
||||
firstResponse, err := pager.NextPage(context)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, len(firstResponse.Items))
|
||||
var firstItem map[string]interface{}
|
||||
err = json.Unmarshal(firstResponse.Items[0], &firstItem)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "12345", firstItem["id"])
|
||||
assert.Equal(t, "123", firstItem["pk"])
|
||||
|
||||
firstContinuationToken := continuationtoken.FromString(*firstResponse.ContinuationToken)
|
||||
assert.Equal(t, 1, firstContinuationToken.Token.PageIndex)
|
||||
assert.Equal(t, 1, firstContinuationToken.Token.TotalResults)
|
||||
|
||||
assert.True(t, pager.More())
|
||||
secondResponse, err := pager.NextPage(context)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, len(secondResponse.Items))
|
||||
var secondItem map[string]interface{}
|
||||
err = json.Unmarshal(secondResponse.Items[0], &secondItem)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "67890", secondItem["id"])
|
||||
assert.Equal(t, "456", secondItem["pk"])
|
||||
assert.Nil(t, secondResponse.ContinuationToken)
|
||||
|
||||
assert.False(t, pager.More())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
145
internal/continuation_token/continuation_token.go
Normal file
145
internal/continuation_token/continuation_token.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package continuationtoken
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/pikami/cosmium/internal/logger"
|
||||
)
|
||||
|
||||
type ContinuationTokenExternal struct {
|
||||
Token string `json:"token"`
|
||||
Range struct {
|
||||
Min string `json:"min"`
|
||||
Max string `json:"max"`
|
||||
} `json:"range"`
|
||||
}
|
||||
|
||||
type ContinuationToken struct {
|
||||
Token struct {
|
||||
ResourceId string // RID
|
||||
PageIndex int // RT
|
||||
TotalResults int // TRC
|
||||
ISV int // ISV
|
||||
IEO int // IEO
|
||||
QCF int // QCF
|
||||
LR int // LR
|
||||
}
|
||||
Range struct {
|
||||
Min string
|
||||
Max string
|
||||
}
|
||||
}
|
||||
|
||||
func Generate(resourceid string, pageIndex int, totalResults int) ContinuationToken {
|
||||
ct := ContinuationToken{}
|
||||
ct.Token.ResourceId = resourceid
|
||||
ct.Token.PageIndex = pageIndex
|
||||
ct.Token.TotalResults = totalResults
|
||||
ct.Token.ISV = 2
|
||||
ct.Token.IEO = 65567
|
||||
ct.Token.QCF = 8
|
||||
ct.Token.LR = 1
|
||||
ct.Range.Min = ""
|
||||
ct.Range.Max = "FF"
|
||||
|
||||
return ct
|
||||
}
|
||||
|
||||
func GenerateDefault(resourceid string) ContinuationToken {
|
||||
return Generate(resourceid, 0, 0)
|
||||
}
|
||||
|
||||
func (ct *ContinuationToken) ToString() string {
|
||||
token := fmt.Sprintf(
|
||||
"-RID:~%s#RT:%d#TRC:%d#ISV:%d#IEO:%d#QCF:%d#LR:%d",
|
||||
ct.Token.ResourceId,
|
||||
ct.Token.PageIndex,
|
||||
ct.Token.TotalResults,
|
||||
ct.Token.ISV,
|
||||
ct.Token.IEO,
|
||||
ct.Token.QCF,
|
||||
ct.Token.LR,
|
||||
)
|
||||
|
||||
ect := ContinuationTokenExternal{}
|
||||
ect.Token = token
|
||||
ect.Range.Min = ct.Range.Min
|
||||
ect.Range.Max = ct.Range.Max
|
||||
|
||||
json, err := json.Marshal(ect)
|
||||
if err != nil {
|
||||
logger.Error(err, "failed to marshal continuation token")
|
||||
return ""
|
||||
}
|
||||
|
||||
return string(json)
|
||||
}
|
||||
|
||||
func FromString(token string) ContinuationToken {
|
||||
ect := ContinuationTokenExternal{}
|
||||
err := json.Unmarshal([]byte(token), &ect)
|
||||
if err != nil {
|
||||
logger.Error(err, "failed to unmarshal continuation token")
|
||||
return ContinuationToken{}
|
||||
}
|
||||
|
||||
ct, err := parseContinuationToken(ect.Token, ect.Range.Min, ect.Range.Max)
|
||||
if err != nil {
|
||||
logger.Error(err, "failed to parse continuation token")
|
||||
return ContinuationToken{}
|
||||
}
|
||||
|
||||
return *ct
|
||||
}
|
||||
|
||||
func parseContinuationToken(token string, minRange string, maxRange string) (*ContinuationToken, error) {
|
||||
const prefix = "-RID:~"
|
||||
if !strings.HasPrefix(token, prefix) {
|
||||
return nil, fmt.Errorf("invalid token prefix")
|
||||
}
|
||||
|
||||
parts := strings.Split(token[len(prefix):], "#")
|
||||
if len(parts) != 7 {
|
||||
return nil, fmt.Errorf("invalid token format: expected 7 fields, got %d", len(parts))
|
||||
}
|
||||
|
||||
ct := &ContinuationToken{}
|
||||
|
||||
ct.Token.ResourceId = parts[0]
|
||||
|
||||
parseIntField := func(part, key string) (int, error) {
|
||||
if !strings.HasPrefix(part, key+":") {
|
||||
return 0, fmt.Errorf("expected %s field", key)
|
||||
}
|
||||
return strconv.Atoi(strings.TrimPrefix(part, key+":"))
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
if ct.Token.PageIndex, err = parseIntField(parts[1], "RT"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ct.Token.TotalResults, err = parseIntField(parts[2], "TRC"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ct.Token.ISV, err = parseIntField(parts[3], "ISV"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ct.Token.IEO, err = parseIntField(parts[4], "IEO"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ct.Token.QCF, err = parseIntField(parts[5], "QCF"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ct.Token.LR, err = parseIntField(parts[6], "LR"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ct.Range.Min = minRange
|
||||
ct.Range.Max = maxRange
|
||||
|
||||
return ct, nil
|
||||
}
|
||||
35
internal/continuation_token/continuation_token_test.go
Normal file
35
internal/continuation_token/continuation_token_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package continuationtoken
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_Generate(t *testing.T) {
|
||||
token := Generate("test-resource-id", 1, 100)
|
||||
|
||||
assert.Equal(t, "test-resource-id", token.Token.ResourceId)
|
||||
assert.Equal(t, 1, token.Token.PageIndex)
|
||||
assert.Equal(t, 100, token.Token.TotalResults)
|
||||
}
|
||||
|
||||
func Test_FromString(t *testing.T) {
|
||||
token := FromString("{\"token\":\"-RID:~test-resource-id#RT:1#TRC:100#ISV:2#IEO:65567#QCF:8#LR:1\",\"range\":{\"min\":\"\",\"max\":\"FF\"}}")
|
||||
|
||||
assert.Equal(t, "test-resource-id", token.Token.ResourceId)
|
||||
assert.Equal(t, 1, token.Token.PageIndex)
|
||||
assert.Equal(t, 100, token.Token.TotalResults)
|
||||
}
|
||||
|
||||
func Test_ToString(t *testing.T) {
|
||||
token := Generate("test-resource-id", 1, 100)
|
||||
assert.Equal(t, "{\"token\":\"-RID:~test-resource-id#RT:1#TRC:100#ISV:2#IEO:65567#QCF:8#LR:1\",\"range\":{\"min\":\"\",\"max\":\"FF\"}}", token.ToString())
|
||||
}
|
||||
|
||||
func Test_GenerateDefault(t *testing.T) {
|
||||
token := GenerateDefault("test-resource-id")
|
||||
assert.Equal(t, "test-resource-id", token.Token.ResourceId)
|
||||
assert.Equal(t, 0, token.Token.PageIndex)
|
||||
assert.Equal(t, 0, token.Token.TotalResults)
|
||||
}
|
||||
@@ -5,18 +5,46 @@ import (
|
||||
"github.com/pikami/cosmium/parsers"
|
||||
)
|
||||
|
||||
func ExecuteQuery(query parsers.SelectStmt, documents rowTypeIterator) []RowType {
|
||||
type ExecuteQueryResult struct {
|
||||
Rows []RowType
|
||||
HasMorePages bool
|
||||
}
|
||||
|
||||
func ExecuteQuery(
|
||||
query parsers.SelectStmt,
|
||||
documents rowTypeIterator,
|
||||
offset int,
|
||||
limit int,
|
||||
) ExecuteQueryResult {
|
||||
resultIter := executeQuery(query, &rowTypeToRowContextIterator{documents: documents, query: query})
|
||||
result := make([]RowType, 0)
|
||||
for {
|
||||
|
||||
result := &ExecuteQueryResult{
|
||||
Rows: make([]RowType, 0),
|
||||
HasMorePages: false,
|
||||
}
|
||||
|
||||
for i := 0; i < offset; i++ {
|
||||
_, status := resultIter.Next()
|
||||
if status != datastore.StatusOk {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < limit; i++ {
|
||||
row, status := resultIter.Next()
|
||||
if status != datastore.StatusOk {
|
||||
break
|
||||
}
|
||||
|
||||
result = append(result, row)
|
||||
result.Rows = append(result.Rows, row)
|
||||
}
|
||||
return result
|
||||
|
||||
_, status := resultIter.Next()
|
||||
if status == datastore.StatusOk {
|
||||
result.HasMorePages = true
|
||||
}
|
||||
|
||||
return *result
|
||||
}
|
||||
|
||||
func executeQuery(query parsers.SelectStmt, documents rowIterator) rowTypeIterator {
|
||||
|
||||
@@ -38,10 +38,10 @@ func testQueryExecute(
|
||||
expectedData []memoryexecutor.RowType,
|
||||
) {
|
||||
iter := NewTestDocumentIterator(data)
|
||||
result := memoryexecutor.ExecuteQuery(query, iter)
|
||||
result := memoryexecutor.ExecuteQuery(query, iter, 0, 1000)
|
||||
|
||||
if !reflect.DeepEqual(result, expectedData) {
|
||||
t.Errorf("execution result does not match expected data.\nExpected: %+v\nGot: %+v", expectedData, result)
|
||||
if !reflect.DeepEqual(result.Rows, expectedData) {
|
||||
t.Errorf("execution result does not match expected data.\nExpected: %+v\nGot: %+v", expectedData, result.Rows)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user