From e5ddc143f06522b6da8a42898a80c2bfc93b3f3d Mon Sep 17 00:00:00 2001 From: Pijus Kamandulis Date: Sun, 8 Dec 2024 17:54:58 +0200 Subject: [PATCH] Improved concurrency handling --- api/handlers/cosmium.go | 8 ++- api/tests/documents_test.go | 52 +++++++++++++++++++ internal/repositories/collections.go | 12 +++++ internal/repositories/databases.go | 12 +++++ internal/repositories/documents.go | 12 +++++ internal/repositories/partition_key_ranges.go | 3 ++ internal/repositories/state.go | 16 +++++- internal/repository_models/models.go | 4 ++ .../memory_executor/memory_executor.go | 1 + 9 files changed, 117 insertions(+), 3 deletions(-) diff --git a/api/handlers/cosmium.go b/api/handlers/cosmium.go index 0844652..cb8512c 100644 --- a/api/handlers/cosmium.go +++ b/api/handlers/cosmium.go @@ -8,5 +8,11 @@ import ( ) func CosmiumExport(c *gin.Context) { - c.IndentedJSON(http.StatusOK, repositories.GetState()) + repositoryState, err := repositories.GetState() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.Data(http.StatusOK, "application/json", []byte(repositoryState)) } diff --git a/api/tests/documents_test.go b/api/tests/documents_test.go index 33492ad..cf6c08e 100644 --- a/api/tests/documents_test.go +++ b/api/tests/documents_test.go @@ -8,7 +8,9 @@ import ( "net/http" "net/http/httptest" "reflect" + "sync" "testing" + "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" @@ -162,6 +164,56 @@ func Test_Documents(t *testing.T) { }, ) }) + + t.Run("Should handle parallel writes", func(t *testing.T) { + var wg sync.WaitGroup + rutineCount := 100 + results := make(chan error, rutineCount) + + createCall := func(i int) { + defer wg.Done() + item := map[string]interface{}{ + "id": fmt.Sprintf("id-%d", i), + "pk": fmt.Sprintf("pk-%d", i), + "val": i, + } + bytes, err := json.Marshal(item) + if err != nil { + results <- err + return + } + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + _, err = collectionClient.CreateItem( + ctx, + azcosmos.PartitionKey{}, + bytes, + &azcosmos.ItemOptions{ + EnableContentResponseOnWrite: false, + }, + ) + results <- err + + collectionClient.ReadItem(ctx, azcosmos.PartitionKey{}, fmt.Sprintf("id-%d", i), nil) + collectionClient.DeleteItem(ctx, azcosmos.PartitionKey{}, fmt.Sprintf("id-%d", i), nil) + } + + for i := 0; i < rutineCount; i++ { + wg.Add(1) + go createCall(i) + } + + wg.Wait() + close(results) + + for err := range results { + if err != nil { + t.Errorf("Error creating item: %v", err) + } + } + }) } func Test_Documents_Patch(t *testing.T) { diff --git a/internal/repositories/collections.go b/internal/repositories/collections.go index 3e2703c..54e3705 100644 --- a/internal/repositories/collections.go +++ b/internal/repositories/collections.go @@ -12,6 +12,9 @@ import ( ) func GetAllCollections(databaseId string) ([]repositorymodels.Collection, repositorymodels.RepositoryStatus) { + storeState.RLock() + defer storeState.RUnlock() + if _, ok := storeState.Databases[databaseId]; !ok { return make([]repositorymodels.Collection, 0), repositorymodels.StatusNotFound } @@ -20,6 +23,9 @@ func GetAllCollections(databaseId string) ([]repositorymodels.Collection, reposi } func GetCollection(databaseId string, collectionId string) (repositorymodels.Collection, repositorymodels.RepositoryStatus) { + storeState.RLock() + defer storeState.RUnlock() + if _, ok := storeState.Databases[databaseId]; !ok { return repositorymodels.Collection{}, repositorymodels.StatusNotFound } @@ -32,6 +38,9 @@ func GetCollection(databaseId string, collectionId string) (repositorymodels.Col } func DeleteCollection(databaseId string, collectionId string) repositorymodels.RepositoryStatus { + storeState.Lock() + defer storeState.Unlock() + if _, ok := storeState.Databases[databaseId]; !ok { return repositorymodels.StatusNotFound } @@ -46,6 +55,9 @@ func DeleteCollection(databaseId string, collectionId string) repositorymodels.R } func CreateCollection(databaseId string, newCollection repositorymodels.Collection) (repositorymodels.Collection, repositorymodels.RepositoryStatus) { + storeState.Lock() + defer storeState.Unlock() + var ok bool var database repositorymodels.Database if database, ok = storeState.Databases[databaseId]; !ok { diff --git a/internal/repositories/databases.go b/internal/repositories/databases.go index 23b9e8f..abd609a 100644 --- a/internal/repositories/databases.go +++ b/internal/repositories/databases.go @@ -11,10 +11,16 @@ import ( ) func GetAllDatabases() ([]repositorymodels.Database, repositorymodels.RepositoryStatus) { + storeState.RLock() + defer storeState.RUnlock() + return maps.Values(storeState.Databases), repositorymodels.StatusOk } func GetDatabase(id string) (repositorymodels.Database, repositorymodels.RepositoryStatus) { + storeState.RLock() + defer storeState.RUnlock() + if database, ok := storeState.Databases[id]; ok { return database, repositorymodels.StatusOk } @@ -23,6 +29,9 @@ func GetDatabase(id string) (repositorymodels.Database, repositorymodels.Reposit } func DeleteDatabase(id string) repositorymodels.RepositoryStatus { + storeState.Lock() + defer storeState.Unlock() + if _, ok := storeState.Databases[id]; !ok { return repositorymodels.StatusNotFound } @@ -33,6 +42,9 @@ func DeleteDatabase(id string) repositorymodels.RepositoryStatus { } func CreateDatabase(newDatabase repositorymodels.Database) (repositorymodels.Database, repositorymodels.RepositoryStatus) { + storeState.Lock() + defer storeState.Unlock() + if _, ok := storeState.Databases[newDatabase.ID]; ok { return repositorymodels.Database{}, repositorymodels.Conflict } diff --git a/internal/repositories/documents.go b/internal/repositories/documents.go index 8b8a978..0a6feb6 100644 --- a/internal/repositories/documents.go +++ b/internal/repositories/documents.go @@ -15,6 +15,9 @@ import ( ) func GetAllDocuments(databaseId string, collectionId string) ([]repositorymodels.Document, repositorymodels.RepositoryStatus) { + storeState.RLock() + defer storeState.RUnlock() + if _, ok := storeState.Databases[databaseId]; !ok { return make([]repositorymodels.Document, 0), repositorymodels.StatusNotFound } @@ -27,6 +30,9 @@ func GetAllDocuments(databaseId string, collectionId string) ([]repositorymodels } func GetDocument(databaseId string, collectionId string, documentId string) (repositorymodels.Document, repositorymodels.RepositoryStatus) { + storeState.RLock() + defer storeState.RUnlock() + if _, ok := storeState.Databases[databaseId]; !ok { return repositorymodels.Document{}, repositorymodels.StatusNotFound } @@ -43,6 +49,9 @@ func GetDocument(databaseId string, collectionId string, documentId string) (rep } func DeleteDocument(databaseId string, collectionId string, documentId string) repositorymodels.RepositoryStatus { + storeState.Lock() + defer storeState.Unlock() + if _, ok := storeState.Databases[databaseId]; !ok { return repositorymodels.StatusNotFound } @@ -61,6 +70,9 @@ func DeleteDocument(databaseId string, collectionId string, documentId string) r } func CreateDocument(databaseId string, collectionId string, document map[string]interface{}) (repositorymodels.Document, repositorymodels.RepositoryStatus) { + storeState.Lock() + defer storeState.Unlock() + var ok bool var documentId string var database repositorymodels.Database diff --git a/internal/repositories/partition_key_ranges.go b/internal/repositories/partition_key_ranges.go index 689d6d9..4b1ff91 100644 --- a/internal/repositories/partition_key_ranges.go +++ b/internal/repositories/partition_key_ranges.go @@ -10,6 +10,9 @@ import ( // I have no idea what this is tbh func GetPartitionKeyRanges(databaseId string, collectionId string) ([]repositorymodels.PartitionKeyRange, repositorymodels.RepositoryStatus) { + storeState.RLock() + defer storeState.RUnlock() + databaseRid := databaseId collectionRid := collectionId var timestamp int64 = 0 diff --git a/internal/repositories/state.go b/internal/repositories/state.go index 4d13846..063f436 100644 --- a/internal/repositories/state.go +++ b/internal/repositories/state.go @@ -66,6 +66,9 @@ func LoadStateFS(filePath string) { } func SaveStateFS(filePath string) { + storeState.RLock() + defer storeState.RUnlock() + data, err := json.MarshalIndent(storeState, "", "\t") if err != nil { logger.Errorf("Failed to save state: %v\n", err) @@ -80,8 +83,17 @@ func SaveStateFS(filePath string) { logger.Infof("Documents: %d\n", getLength(storeState.Documents)) } -func GetState() repositorymodels.State { - return storeState +func GetState() (string, error) { + storeState.RLock() + defer storeState.RUnlock() + + data, err := json.MarshalIndent(storeState, "", "\t") + if err != nil { + logger.Errorf("Failed to serialize state: %v\n", err) + return "", err + } + + return string(data), nil } func getLength(v interface{}) int { diff --git a/internal/repository_models/models.go b/internal/repository_models/models.go index d828a14..490cf52 100644 --- a/internal/repository_models/models.go +++ b/internal/repository_models/models.go @@ -1,5 +1,7 @@ package repositorymodels +import "sync" + type Database struct { ID string `json:"id"` TimeStamp int64 `json:"_ts"` @@ -101,6 +103,8 @@ type PartitionKeyRange struct { } type State struct { + sync.RWMutex + // Map databaseId -> Database Databases map[string]Database `json:"databases"` diff --git a/query_executors/memory_executor/memory_executor.go b/query_executors/memory_executor/memory_executor.go index 4551eea..779d502 100644 --- a/query_executors/memory_executor/memory_executor.go +++ b/query_executors/memory_executor/memory_executor.go @@ -659,6 +659,7 @@ func compareValues(val1, val2 interface{}) int { func deduplicate[T RowType | interface{}](slice []T) []T { var result []T + result = make([]T, 0) for i := 0; i < len(slice); i++ { unique := true