From 5d99b653cc5e91a070498229a5ca2fb50ae3d772 Mon Sep 17 00:00:00 2001 From: Pijus Kamandulis Date: Sun, 9 Feb 2025 00:36:35 +0200 Subject: [PATCH] Generate more realistic resource ids --- api/handlers/middleware/authentication.go | 3 +- api/handlers/partition_key_ranges.go | 4 +- internal/repositories/collections.go | 6 +- internal/repositories/databases.go | 7 +- internal/repositories/documents.go | 2 +- internal/repositories/partition_key_ranges.go | 2 +- internal/repositories/stored_procedures.go | 2 +- internal/repositories/triggers.go | 2 +- .../repositories/user_defined_functions.go | 2 +- internal/resourceid/resourceid.go | 74 ++++++++++++++++--- 10 files changed, 84 insertions(+), 20 deletions(-) diff --git a/api/handlers/middleware/authentication.go b/api/handlers/middleware/authentication.go index 380f576..5f94ce5 100644 --- a/api/handlers/middleware/authentication.go +++ b/api/handlers/middleware/authentication.go @@ -75,8 +75,7 @@ func requestToResourceId(c *gin.Context) string { isFeed := c.Request.Header.Get("A-Im") == "Incremental Feed" if resourceType == "pkranges" && isFeed { - // CosmosSDK replaces '/' with '-' in resource id requests - resourceId = strings.Replace(collId, "-", "/", -1) + resourceId = collId } return resourceId diff --git a/api/handlers/partition_key_ranges.go b/api/handlers/partition_key_ranges.go index 1f60ef6..063e51c 100644 --- a/api/handlers/partition_key_ranges.go +++ b/api/handlers/partition_key_ranges.go @@ -6,6 +6,7 @@ import ( "github.com/gin-gonic/gin" repositorymodels "github.com/pikami/cosmium/internal/repository_models" + "github.com/pikami/cosmium/internal/resourceid" ) func (h *Handlers) GetPartitionKeyRanges(c *gin.Context) { @@ -31,8 +32,9 @@ func (h *Handlers) GetPartitionKeyRanges(c *gin.Context) { collectionRid = collection.ResourceID } + rid := resourceid.NewCombined(collectionRid, resourceid.New(resourceid.ResourceTypePartitionKeyRange)) c.IndentedJSON(http.StatusOK, gin.H{ - "_rid": collectionRid, + "_rid": rid, "_count": len(partitionKeyRanges), "PartitionKeyRanges": partitionKeyRanges, }) diff --git a/internal/repositories/collections.go b/internal/repositories/collections.go index b8d9fdb..b1054a5 100644 --- a/internal/repositories/collections.go +++ b/internal/repositories/collections.go @@ -50,6 +50,10 @@ func (r *DataRepository) DeleteCollection(databaseId string, collectionId string } delete(r.storeState.Collections[databaseId], collectionId) + delete(r.storeState.Documents[databaseId], collectionId) + delete(r.storeState.Triggers[databaseId], collectionId) + delete(r.storeState.StoredProcedures[databaseId], collectionId) + delete(r.storeState.UserDefinedFunctions[databaseId], collectionId) return repositorymodels.StatusOk } @@ -71,7 +75,7 @@ func (r *DataRepository) CreateCollection(databaseId string, newCollection repos newCollection = structhidrators.Hidrate(newCollection).(repositorymodels.Collection) newCollection.TimeStamp = time.Now().Unix() - newCollection.ResourceID = resourceid.NewCombined(database.ResourceID, resourceid.New()) + newCollection.ResourceID = resourceid.NewCombined(database.ResourceID, resourceid.New(resourceid.ResourceTypeCollection)) newCollection.ETag = fmt.Sprintf("\"%s\"", uuid.New()) newCollection.Self = fmt.Sprintf("dbs/%s/colls/%s/", database.ResourceID, newCollection.ResourceID) diff --git a/internal/repositories/databases.go b/internal/repositories/databases.go index d094911..dee0dbb 100644 --- a/internal/repositories/databases.go +++ b/internal/repositories/databases.go @@ -37,6 +37,11 @@ func (r *DataRepository) DeleteDatabase(id string) repositorymodels.RepositorySt } delete(r.storeState.Databases, id) + delete(r.storeState.Collections, id) + delete(r.storeState.Documents, id) + delete(r.storeState.Triggers, id) + delete(r.storeState.StoredProcedures, id) + delete(r.storeState.UserDefinedFunctions, id) return repositorymodels.StatusOk } @@ -50,7 +55,7 @@ func (r *DataRepository) CreateDatabase(newDatabase repositorymodels.Database) ( } newDatabase.TimeStamp = time.Now().Unix() - newDatabase.ResourceID = resourceid.New() + newDatabase.ResourceID = resourceid.New(resourceid.ResourceTypeDatabase) newDatabase.ETag = fmt.Sprintf("\"%s\"", uuid.New()) newDatabase.Self = fmt.Sprintf("dbs/%s/", newDatabase.ResourceID) diff --git a/internal/repositories/documents.go b/internal/repositories/documents.go index 187840d..8bb1763 100644 --- a/internal/repositories/documents.go +++ b/internal/repositories/documents.go @@ -95,7 +95,7 @@ func (r *DataRepository) CreateDocument(databaseId string, collectionId string, } document["_ts"] = time.Now().Unix() - document["_rid"] = resourceid.NewCombined(database.ResourceID, collection.ResourceID, resourceid.New()) + document["_rid"] = resourceid.NewCombined(collection.ResourceID, resourceid.New(resourceid.ResourceTypeDocument)) document["_etag"] = fmt.Sprintf("\"%s\"", uuid.New()) document["_self"] = fmt.Sprintf("dbs/%s/colls/%s/docs/%s/", database.ResourceID, collection.ResourceID, document["_rid"]) diff --git a/internal/repositories/partition_key_ranges.go b/internal/repositories/partition_key_ranges.go index 8424ed1..3ae4aa5 100644 --- a/internal/repositories/partition_key_ranges.go +++ b/internal/repositories/partition_key_ranges.go @@ -26,7 +26,7 @@ func (r *DataRepository) GetPartitionKeyRanges(databaseId string, collectionId s timestamp = collection.TimeStamp } - pkrResourceId := resourceid.NewCombined(databaseRid, collectionRid, resourceid.New()) + pkrResourceId := resourceid.NewCombined(collectionRid, resourceid.New(resourceid.ResourceTypePartitionKeyRange)) pkrSelf := fmt.Sprintf("dbs/%s/colls/%s/pkranges/%s/", databaseRid, collectionRid, pkrResourceId) etag := fmt.Sprintf("\"%s\"", uuid.New()) diff --git a/internal/repositories/stored_procedures.go b/internal/repositories/stored_procedures.go index 5cfd060..95e8819 100644 --- a/internal/repositories/stored_procedures.go +++ b/internal/repositories/stored_procedures.go @@ -81,7 +81,7 @@ func (r *DataRepository) CreateStoredProcedure(databaseId string, collectionId s } sp.TimeStamp = time.Now().Unix() - sp.ResourceID = resourceid.NewCombined(database.ResourceID, collection.ResourceID, resourceid.New()) + sp.ResourceID = resourceid.NewCombined(collection.ResourceID, resourceid.New(resourceid.ResourceTypeStoredProcedure)) sp.ETag = fmt.Sprintf("\"%s\"", uuid.New()) sp.Self = fmt.Sprintf("dbs/%s/colls/%s/sprocs/%s/", database.ResourceID, collection.ResourceID, sp.ResourceID) diff --git a/internal/repositories/triggers.go b/internal/repositories/triggers.go index 3ef6535..0ec60fb 100644 --- a/internal/repositories/triggers.go +++ b/internal/repositories/triggers.go @@ -81,7 +81,7 @@ func (r *DataRepository) CreateTrigger(databaseId string, collectionId string, t } trigger.TimeStamp = time.Now().Unix() - trigger.ResourceID = resourceid.NewCombined(database.ResourceID, collection.ResourceID, resourceid.New()) + trigger.ResourceID = resourceid.NewCombined(collection.ResourceID, resourceid.New(resourceid.ResourceTypeTrigger)) trigger.ETag = fmt.Sprintf("\"%s\"", uuid.New()) trigger.Self = fmt.Sprintf("dbs/%s/colls/%s/triggers/%s/", database.ResourceID, collection.ResourceID, trigger.ResourceID) diff --git a/internal/repositories/user_defined_functions.go b/internal/repositories/user_defined_functions.go index aee22b4..3f55468 100644 --- a/internal/repositories/user_defined_functions.go +++ b/internal/repositories/user_defined_functions.go @@ -81,7 +81,7 @@ func (r *DataRepository) CreateUserDefinedFunction(databaseId string, collection } udf.TimeStamp = time.Now().Unix() - udf.ResourceID = resourceid.NewCombined(database.ResourceID, collection.ResourceID, resourceid.New()) + udf.ResourceID = resourceid.NewCombined(collection.ResourceID, resourceid.New(resourceid.ResourceTypeUserDefinedFunction)) udf.ETag = fmt.Sprintf("\"%s\"", uuid.New()) udf.Self = fmt.Sprintf("dbs/%s/colls/%s/udfs/%s/", database.ResourceID, collection.ResourceID, udf.ResourceID) diff --git a/internal/resourceid/resourceid.go b/internal/resourceid/resourceid.go index c001e6c..1b07d16 100644 --- a/internal/resourceid/resourceid.go +++ b/internal/resourceid/resourceid.go @@ -3,32 +3,76 @@ package resourceid import ( "encoding/base64" "math/rand" + "strings" "github.com/google/uuid" ) -func New() string { - id := uuid.New().ID() - idBytes := uintToBytes(id) +type ResourceType int - // first byte should be bigger than 0x80 for collection ids - // clients classify this id as "user" otherwise - if (idBytes[0] & 0x80) <= 0 { - idBytes[0] = byte(rand.Intn(0x80) + 0x80) +const ( + ResourceTypeDatabase ResourceType = iota + ResourceTypeCollection + ResourceTypeDocument + ResourceTypeStoredProcedure + ResourceTypeTrigger + ResourceTypeUserDefinedFunction + ResourceTypeConflict + ResourceTypePartitionKeyRange + ResourceTypeSchema +) + +func New(resourceType ResourceType) string { + var idBytes []byte + switch resourceType { + case ResourceTypeDatabase: + idBytes = randomBytes(4) + case ResourceTypeCollection: + idBytes = randomBytes(4) + // first byte should be bigger than 0x80 for collection ids + // clients classify this id as "user" otherwise + if (idBytes[0] & 0x80) <= 0 { + idBytes[0] = byte(rand.Intn(0x80) + 0x80) + } + case ResourceTypeDocument: + idBytes = randomBytes(8) + idBytes[7] = byte(rand.Intn(0x10)) // Upper 4 bits = 0 + case ResourceTypeStoredProcedure: + idBytes = randomBytes(8) + idBytes[7] = byte(rand.Intn(0x10)) | 0x08 // Upper 4 bits = 0x08 + case ResourceTypeTrigger: + idBytes = randomBytes(8) + idBytes[7] = byte(rand.Intn(0x10)) | 0x07 // Upper 4 bits = 0x07 + case ResourceTypeUserDefinedFunction: + idBytes = randomBytes(8) + idBytes[7] = byte(rand.Intn(0x10)) | 0x06 // Upper 4 bits = 0x06 + case ResourceTypeConflict: + idBytes = randomBytes(8) + idBytes[7] = byte(rand.Intn(0x10)) | 0x04 // Upper 4 bits = 0x04 + case ResourceTypePartitionKeyRange: + // we don't do partitions yet, so just use a fixed id + idBytes = []byte{0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x50} + case ResourceTypeSchema: + idBytes = randomBytes(8) + idBytes[7] = byte(rand.Intn(0x10)) | 0x09 // Upper 4 bits = 0x09 + default: + idBytes = randomBytes(4) } - return base64.StdEncoding.EncodeToString(idBytes) + encoded := base64.StdEncoding.EncodeToString(idBytes) + return strings.ReplaceAll(encoded, "/", "-") } func NewCombined(ids ...string) string { combinedIdBytes := make([]byte, 0) for _, id := range ids { - idBytes, _ := base64.StdEncoding.DecodeString(id) + idBytes, _ := base64.StdEncoding.DecodeString(strings.ReplaceAll(id, "-", "/")) combinedIdBytes = append(combinedIdBytes, idBytes...) } - return base64.StdEncoding.EncodeToString(combinedIdBytes) + encoded := base64.StdEncoding.EncodeToString(combinedIdBytes) + return strings.ReplaceAll(encoded, "/", "-") } func uintToBytes(id uint32) []byte { @@ -39,3 +83,13 @@ func uintToBytes(id uint32) []byte { return buf } + +func randomBytes(count int) []byte { + buf := make([]byte, count) + for i := 0; i < count; i += 4 { + id := uuid.New().ID() + idBytes := uintToBytes(id) + copy(buf[i:], idBytes) + } + return buf +}