mirror of
https://github.com/pikami/cosmium.git
synced 2025-02-08 17:06:45 +00:00
Improved concurrency handling
This commit is contained in:
parent
66ea859f34
commit
e5ddc143f0
@ -8,5 +8,11 @@ import (
|
||||
)
|
||||
|
||||
func CosmiumExport(c *gin.Context) {
|
||||
c.IndentedJSON(http.StatusOK, repositories.GetState())
|
||||
repositoryState, err := repositories.GetState()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.Data(http.StatusOK, "application/json", []byte(repositoryState))
|
||||
}
|
||||
|
@ -8,7 +8,9 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"
|
||||
@ -162,6 +164,56 @@ func Test_Documents(t *testing.T) {
|
||||
},
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("Should handle parallel writes", func(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
rutineCount := 100
|
||||
results := make(chan error, rutineCount)
|
||||
|
||||
createCall := func(i int) {
|
||||
defer wg.Done()
|
||||
item := map[string]interface{}{
|
||||
"id": fmt.Sprintf("id-%d", i),
|
||||
"pk": fmt.Sprintf("pk-%d", i),
|
||||
"val": i,
|
||||
}
|
||||
bytes, err := json.Marshal(item)
|
||||
if err != nil {
|
||||
results <- err
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, err = collectionClient.CreateItem(
|
||||
ctx,
|
||||
azcosmos.PartitionKey{},
|
||||
bytes,
|
||||
&azcosmos.ItemOptions{
|
||||
EnableContentResponseOnWrite: false,
|
||||
},
|
||||
)
|
||||
results <- err
|
||||
|
||||
collectionClient.ReadItem(ctx, azcosmos.PartitionKey{}, fmt.Sprintf("id-%d", i), nil)
|
||||
collectionClient.DeleteItem(ctx, azcosmos.PartitionKey{}, fmt.Sprintf("id-%d", i), nil)
|
||||
}
|
||||
|
||||
for i := 0; i < rutineCount; i++ {
|
||||
wg.Add(1)
|
||||
go createCall(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
for err := range results {
|
||||
if err != nil {
|
||||
t.Errorf("Error creating item: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Documents_Patch(t *testing.T) {
|
||||
|
@ -12,6 +12,9 @@ import (
|
||||
)
|
||||
|
||||
func GetAllCollections(databaseId string) ([]repositorymodels.Collection, repositorymodels.RepositoryStatus) {
|
||||
storeState.RLock()
|
||||
defer storeState.RUnlock()
|
||||
|
||||
if _, ok := storeState.Databases[databaseId]; !ok {
|
||||
return make([]repositorymodels.Collection, 0), repositorymodels.StatusNotFound
|
||||
}
|
||||
@ -20,6 +23,9 @@ func GetAllCollections(databaseId string) ([]repositorymodels.Collection, reposi
|
||||
}
|
||||
|
||||
func GetCollection(databaseId string, collectionId string) (repositorymodels.Collection, repositorymodels.RepositoryStatus) {
|
||||
storeState.RLock()
|
||||
defer storeState.RUnlock()
|
||||
|
||||
if _, ok := storeState.Databases[databaseId]; !ok {
|
||||
return repositorymodels.Collection{}, repositorymodels.StatusNotFound
|
||||
}
|
||||
@ -32,6 +38,9 @@ func GetCollection(databaseId string, collectionId string) (repositorymodels.Col
|
||||
}
|
||||
|
||||
func DeleteCollection(databaseId string, collectionId string) repositorymodels.RepositoryStatus {
|
||||
storeState.Lock()
|
||||
defer storeState.Unlock()
|
||||
|
||||
if _, ok := storeState.Databases[databaseId]; !ok {
|
||||
return repositorymodels.StatusNotFound
|
||||
}
|
||||
@ -46,6 +55,9 @@ func DeleteCollection(databaseId string, collectionId string) repositorymodels.R
|
||||
}
|
||||
|
||||
func CreateCollection(databaseId string, newCollection repositorymodels.Collection) (repositorymodels.Collection, repositorymodels.RepositoryStatus) {
|
||||
storeState.Lock()
|
||||
defer storeState.Unlock()
|
||||
|
||||
var ok bool
|
||||
var database repositorymodels.Database
|
||||
if database, ok = storeState.Databases[databaseId]; !ok {
|
||||
|
@ -11,10 +11,16 @@ import (
|
||||
)
|
||||
|
||||
func GetAllDatabases() ([]repositorymodels.Database, repositorymodels.RepositoryStatus) {
|
||||
storeState.RLock()
|
||||
defer storeState.RUnlock()
|
||||
|
||||
return maps.Values(storeState.Databases), repositorymodels.StatusOk
|
||||
}
|
||||
|
||||
func GetDatabase(id string) (repositorymodels.Database, repositorymodels.RepositoryStatus) {
|
||||
storeState.RLock()
|
||||
defer storeState.RUnlock()
|
||||
|
||||
if database, ok := storeState.Databases[id]; ok {
|
||||
return database, repositorymodels.StatusOk
|
||||
}
|
||||
@ -23,6 +29,9 @@ func GetDatabase(id string) (repositorymodels.Database, repositorymodels.Reposit
|
||||
}
|
||||
|
||||
func DeleteDatabase(id string) repositorymodels.RepositoryStatus {
|
||||
storeState.Lock()
|
||||
defer storeState.Unlock()
|
||||
|
||||
if _, ok := storeState.Databases[id]; !ok {
|
||||
return repositorymodels.StatusNotFound
|
||||
}
|
||||
@ -33,6 +42,9 @@ func DeleteDatabase(id string) repositorymodels.RepositoryStatus {
|
||||
}
|
||||
|
||||
func CreateDatabase(newDatabase repositorymodels.Database) (repositorymodels.Database, repositorymodels.RepositoryStatus) {
|
||||
storeState.Lock()
|
||||
defer storeState.Unlock()
|
||||
|
||||
if _, ok := storeState.Databases[newDatabase.ID]; ok {
|
||||
return repositorymodels.Database{}, repositorymodels.Conflict
|
||||
}
|
||||
|
@ -15,6 +15,9 @@ import (
|
||||
)
|
||||
|
||||
func GetAllDocuments(databaseId string, collectionId string) ([]repositorymodels.Document, repositorymodels.RepositoryStatus) {
|
||||
storeState.RLock()
|
||||
defer storeState.RUnlock()
|
||||
|
||||
if _, ok := storeState.Databases[databaseId]; !ok {
|
||||
return make([]repositorymodels.Document, 0), repositorymodels.StatusNotFound
|
||||
}
|
||||
@ -27,6 +30,9 @@ func GetAllDocuments(databaseId string, collectionId string) ([]repositorymodels
|
||||
}
|
||||
|
||||
func GetDocument(databaseId string, collectionId string, documentId string) (repositorymodels.Document, repositorymodels.RepositoryStatus) {
|
||||
storeState.RLock()
|
||||
defer storeState.RUnlock()
|
||||
|
||||
if _, ok := storeState.Databases[databaseId]; !ok {
|
||||
return repositorymodels.Document{}, repositorymodels.StatusNotFound
|
||||
}
|
||||
@ -43,6 +49,9 @@ func GetDocument(databaseId string, collectionId string, documentId string) (rep
|
||||
}
|
||||
|
||||
func DeleteDocument(databaseId string, collectionId string, documentId string) repositorymodels.RepositoryStatus {
|
||||
storeState.Lock()
|
||||
defer storeState.Unlock()
|
||||
|
||||
if _, ok := storeState.Databases[databaseId]; !ok {
|
||||
return repositorymodels.StatusNotFound
|
||||
}
|
||||
@ -61,6 +70,9 @@ func DeleteDocument(databaseId string, collectionId string, documentId string) r
|
||||
}
|
||||
|
||||
func CreateDocument(databaseId string, collectionId string, document map[string]interface{}) (repositorymodels.Document, repositorymodels.RepositoryStatus) {
|
||||
storeState.Lock()
|
||||
defer storeState.Unlock()
|
||||
|
||||
var ok bool
|
||||
var documentId string
|
||||
var database repositorymodels.Database
|
||||
|
@ -10,6 +10,9 @@ import (
|
||||
|
||||
// I have no idea what this is tbh
|
||||
func GetPartitionKeyRanges(databaseId string, collectionId string) ([]repositorymodels.PartitionKeyRange, repositorymodels.RepositoryStatus) {
|
||||
storeState.RLock()
|
||||
defer storeState.RUnlock()
|
||||
|
||||
databaseRid := databaseId
|
||||
collectionRid := collectionId
|
||||
var timestamp int64 = 0
|
||||
|
@ -66,6 +66,9 @@ func LoadStateFS(filePath string) {
|
||||
}
|
||||
|
||||
func SaveStateFS(filePath string) {
|
||||
storeState.RLock()
|
||||
defer storeState.RUnlock()
|
||||
|
||||
data, err := json.MarshalIndent(storeState, "", "\t")
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to save state: %v\n", err)
|
||||
@ -80,8 +83,17 @@ func SaveStateFS(filePath string) {
|
||||
logger.Infof("Documents: %d\n", getLength(storeState.Documents))
|
||||
}
|
||||
|
||||
func GetState() repositorymodels.State {
|
||||
return storeState
|
||||
func GetState() (string, error) {
|
||||
storeState.RLock()
|
||||
defer storeState.RUnlock()
|
||||
|
||||
data, err := json.MarshalIndent(storeState, "", "\t")
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to serialize state: %v\n", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
func getLength(v interface{}) int {
|
||||
|
@ -1,5 +1,7 @@
|
||||
package repositorymodels
|
||||
|
||||
import "sync"
|
||||
|
||||
type Database struct {
|
||||
ID string `json:"id"`
|
||||
TimeStamp int64 `json:"_ts"`
|
||||
@ -101,6 +103,8 @@ type PartitionKeyRange struct {
|
||||
}
|
||||
|
||||
type State struct {
|
||||
sync.RWMutex
|
||||
|
||||
// Map databaseId -> Database
|
||||
Databases map[string]Database `json:"databases"`
|
||||
|
||||
|
@ -659,6 +659,7 @@ func compareValues(val1, val2 interface{}) int {
|
||||
|
||||
func deduplicate[T RowType | interface{}](slice []T) []T {
|
||||
var result []T
|
||||
result = make([]T, 0)
|
||||
|
||||
for i := 0; i < len(slice); i++ {
|
||||
unique := true
|
||||
|
Loading…
x
Reference in New Issue
Block a user