Improved concurrency handling

This commit is contained in:
Pijus Kamandulis 2024-12-08 17:54:58 +02:00
parent 66ea859f34
commit e5ddc143f0
9 changed files with 117 additions and 3 deletions

View File

@ -8,5 +8,11 @@ import (
) )
func CosmiumExport(c *gin.Context) { func CosmiumExport(c *gin.Context) {
c.IndentedJSON(http.StatusOK, repositories.GetState()) repositoryState, err := repositories.GetState()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.Data(http.StatusOK, "application/json", []byte(repositoryState))
} }

View File

@ -8,7 +8,9 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect" "reflect"
"sync"
"testing" "testing"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"
@ -162,6 +164,56 @@ func Test_Documents(t *testing.T) {
}, },
) )
}) })
t.Run("Should handle parallel writes", func(t *testing.T) {
var wg sync.WaitGroup
rutineCount := 100
results := make(chan error, rutineCount)
createCall := func(i int) {
defer wg.Done()
item := map[string]interface{}{
"id": fmt.Sprintf("id-%d", i),
"pk": fmt.Sprintf("pk-%d", i),
"val": i,
}
bytes, err := json.Marshal(item)
if err != nil {
results <- err
return
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_, err = collectionClient.CreateItem(
ctx,
azcosmos.PartitionKey{},
bytes,
&azcosmos.ItemOptions{
EnableContentResponseOnWrite: false,
},
)
results <- err
collectionClient.ReadItem(ctx, azcosmos.PartitionKey{}, fmt.Sprintf("id-%d", i), nil)
collectionClient.DeleteItem(ctx, azcosmos.PartitionKey{}, fmt.Sprintf("id-%d", i), nil)
}
for i := 0; i < rutineCount; i++ {
wg.Add(1)
go createCall(i)
}
wg.Wait()
close(results)
for err := range results {
if err != nil {
t.Errorf("Error creating item: %v", err)
}
}
})
} }
func Test_Documents_Patch(t *testing.T) { func Test_Documents_Patch(t *testing.T) {

View File

@ -12,6 +12,9 @@ import (
) )
func GetAllCollections(databaseId string) ([]repositorymodels.Collection, repositorymodels.RepositoryStatus) { func GetAllCollections(databaseId string) ([]repositorymodels.Collection, repositorymodels.RepositoryStatus) {
storeState.RLock()
defer storeState.RUnlock()
if _, ok := storeState.Databases[databaseId]; !ok { if _, ok := storeState.Databases[databaseId]; !ok {
return make([]repositorymodels.Collection, 0), repositorymodels.StatusNotFound return make([]repositorymodels.Collection, 0), repositorymodels.StatusNotFound
} }
@ -20,6 +23,9 @@ func GetAllCollections(databaseId string) ([]repositorymodels.Collection, reposi
} }
func GetCollection(databaseId string, collectionId string) (repositorymodels.Collection, repositorymodels.RepositoryStatus) { func GetCollection(databaseId string, collectionId string) (repositorymodels.Collection, repositorymodels.RepositoryStatus) {
storeState.RLock()
defer storeState.RUnlock()
if _, ok := storeState.Databases[databaseId]; !ok { if _, ok := storeState.Databases[databaseId]; !ok {
return repositorymodels.Collection{}, repositorymodels.StatusNotFound return repositorymodels.Collection{}, repositorymodels.StatusNotFound
} }
@ -32,6 +38,9 @@ func GetCollection(databaseId string, collectionId string) (repositorymodels.Col
} }
func DeleteCollection(databaseId string, collectionId string) repositorymodels.RepositoryStatus { func DeleteCollection(databaseId string, collectionId string) repositorymodels.RepositoryStatus {
storeState.Lock()
defer storeState.Unlock()
if _, ok := storeState.Databases[databaseId]; !ok { if _, ok := storeState.Databases[databaseId]; !ok {
return repositorymodels.StatusNotFound return repositorymodels.StatusNotFound
} }
@ -46,6 +55,9 @@ func DeleteCollection(databaseId string, collectionId string) repositorymodels.R
} }
func CreateCollection(databaseId string, newCollection repositorymodels.Collection) (repositorymodels.Collection, repositorymodels.RepositoryStatus) { func CreateCollection(databaseId string, newCollection repositorymodels.Collection) (repositorymodels.Collection, repositorymodels.RepositoryStatus) {
storeState.Lock()
defer storeState.Unlock()
var ok bool var ok bool
var database repositorymodels.Database var database repositorymodels.Database
if database, ok = storeState.Databases[databaseId]; !ok { if database, ok = storeState.Databases[databaseId]; !ok {

View File

@ -11,10 +11,16 @@ import (
) )
func GetAllDatabases() ([]repositorymodels.Database, repositorymodels.RepositoryStatus) { func GetAllDatabases() ([]repositorymodels.Database, repositorymodels.RepositoryStatus) {
storeState.RLock()
defer storeState.RUnlock()
return maps.Values(storeState.Databases), repositorymodels.StatusOk return maps.Values(storeState.Databases), repositorymodels.StatusOk
} }
func GetDatabase(id string) (repositorymodels.Database, repositorymodels.RepositoryStatus) { func GetDatabase(id string) (repositorymodels.Database, repositorymodels.RepositoryStatus) {
storeState.RLock()
defer storeState.RUnlock()
if database, ok := storeState.Databases[id]; ok { if database, ok := storeState.Databases[id]; ok {
return database, repositorymodels.StatusOk return database, repositorymodels.StatusOk
} }
@ -23,6 +29,9 @@ func GetDatabase(id string) (repositorymodels.Database, repositorymodels.Reposit
} }
func DeleteDatabase(id string) repositorymodels.RepositoryStatus { func DeleteDatabase(id string) repositorymodels.RepositoryStatus {
storeState.Lock()
defer storeState.Unlock()
if _, ok := storeState.Databases[id]; !ok { if _, ok := storeState.Databases[id]; !ok {
return repositorymodels.StatusNotFound return repositorymodels.StatusNotFound
} }
@ -33,6 +42,9 @@ func DeleteDatabase(id string) repositorymodels.RepositoryStatus {
} }
func CreateDatabase(newDatabase repositorymodels.Database) (repositorymodels.Database, repositorymodels.RepositoryStatus) { func CreateDatabase(newDatabase repositorymodels.Database) (repositorymodels.Database, repositorymodels.RepositoryStatus) {
storeState.Lock()
defer storeState.Unlock()
if _, ok := storeState.Databases[newDatabase.ID]; ok { if _, ok := storeState.Databases[newDatabase.ID]; ok {
return repositorymodels.Database{}, repositorymodels.Conflict return repositorymodels.Database{}, repositorymodels.Conflict
} }

View File

@ -15,6 +15,9 @@ import (
) )
func GetAllDocuments(databaseId string, collectionId string) ([]repositorymodels.Document, repositorymodels.RepositoryStatus) { func GetAllDocuments(databaseId string, collectionId string) ([]repositorymodels.Document, repositorymodels.RepositoryStatus) {
storeState.RLock()
defer storeState.RUnlock()
if _, ok := storeState.Databases[databaseId]; !ok { if _, ok := storeState.Databases[databaseId]; !ok {
return make([]repositorymodels.Document, 0), repositorymodels.StatusNotFound return make([]repositorymodels.Document, 0), repositorymodels.StatusNotFound
} }
@ -27,6 +30,9 @@ func GetAllDocuments(databaseId string, collectionId string) ([]repositorymodels
} }
func GetDocument(databaseId string, collectionId string, documentId string) (repositorymodels.Document, repositorymodels.RepositoryStatus) { func GetDocument(databaseId string, collectionId string, documentId string) (repositorymodels.Document, repositorymodels.RepositoryStatus) {
storeState.RLock()
defer storeState.RUnlock()
if _, ok := storeState.Databases[databaseId]; !ok { if _, ok := storeState.Databases[databaseId]; !ok {
return repositorymodels.Document{}, repositorymodels.StatusNotFound return repositorymodels.Document{}, repositorymodels.StatusNotFound
} }
@ -43,6 +49,9 @@ func GetDocument(databaseId string, collectionId string, documentId string) (rep
} }
func DeleteDocument(databaseId string, collectionId string, documentId string) repositorymodels.RepositoryStatus { func DeleteDocument(databaseId string, collectionId string, documentId string) repositorymodels.RepositoryStatus {
storeState.Lock()
defer storeState.Unlock()
if _, ok := storeState.Databases[databaseId]; !ok { if _, ok := storeState.Databases[databaseId]; !ok {
return repositorymodels.StatusNotFound return repositorymodels.StatusNotFound
} }
@ -61,6 +70,9 @@ func DeleteDocument(databaseId string, collectionId string, documentId string) r
} }
func CreateDocument(databaseId string, collectionId string, document map[string]interface{}) (repositorymodels.Document, repositorymodels.RepositoryStatus) { func CreateDocument(databaseId string, collectionId string, document map[string]interface{}) (repositorymodels.Document, repositorymodels.RepositoryStatus) {
storeState.Lock()
defer storeState.Unlock()
var ok bool var ok bool
var documentId string var documentId string
var database repositorymodels.Database var database repositorymodels.Database

View File

@ -10,6 +10,9 @@ import (
// I have no idea what this is tbh // I have no idea what this is tbh
func GetPartitionKeyRanges(databaseId string, collectionId string) ([]repositorymodels.PartitionKeyRange, repositorymodels.RepositoryStatus) { func GetPartitionKeyRanges(databaseId string, collectionId string) ([]repositorymodels.PartitionKeyRange, repositorymodels.RepositoryStatus) {
storeState.RLock()
defer storeState.RUnlock()
databaseRid := databaseId databaseRid := databaseId
collectionRid := collectionId collectionRid := collectionId
var timestamp int64 = 0 var timestamp int64 = 0

View File

@ -66,6 +66,9 @@ func LoadStateFS(filePath string) {
} }
func SaveStateFS(filePath string) { func SaveStateFS(filePath string) {
storeState.RLock()
defer storeState.RUnlock()
data, err := json.MarshalIndent(storeState, "", "\t") data, err := json.MarshalIndent(storeState, "", "\t")
if err != nil { if err != nil {
logger.Errorf("Failed to save state: %v\n", err) logger.Errorf("Failed to save state: %v\n", err)
@ -80,8 +83,17 @@ func SaveStateFS(filePath string) {
logger.Infof("Documents: %d\n", getLength(storeState.Documents)) logger.Infof("Documents: %d\n", getLength(storeState.Documents))
} }
func GetState() repositorymodels.State { func GetState() (string, error) {
return storeState storeState.RLock()
defer storeState.RUnlock()
data, err := json.MarshalIndent(storeState, "", "\t")
if err != nil {
logger.Errorf("Failed to serialize state: %v\n", err)
return "", err
}
return string(data), nil
} }
func getLength(v interface{}) int { func getLength(v interface{}) int {

View File

@ -1,5 +1,7 @@
package repositorymodels package repositorymodels
import "sync"
type Database struct { type Database struct {
ID string `json:"id"` ID string `json:"id"`
TimeStamp int64 `json:"_ts"` TimeStamp int64 `json:"_ts"`
@ -101,6 +103,8 @@ type PartitionKeyRange struct {
} }
type State struct { type State struct {
sync.RWMutex
// Map databaseId -> Database // Map databaseId -> Database
Databases map[string]Database `json:"databases"` Databases map[string]Database `json:"databases"`

View File

@ -659,6 +659,7 @@ func compareValues(val1, val2 interface{}) int {
func deduplicate[T RowType | interface{}](slice []T) []T { func deduplicate[T RowType | interface{}](slice []T) []T {
var result []T var result []T
result = make([]T, 0)
for i := 0; i < len(slice); i++ { for i := 0; i < len(slice); i++ {
unique := true unique := true