8 Commits

Author SHA1 Message Date
Pijus Kamandulis
f5b8453995 Support patch operations 'set' and 'incr' #7 2024-12-25 23:32:50 +02:00
Pijus Kamandulis
928ca29fe4 Support parameter in bracket #8 2024-12-25 21:28:42 +02:00
Pijus Kamandulis
39cd9e2357 Update dependancies 2024-12-20 20:27:42 +02:00
Pijus Kamandulis
bcf4b513b6 Expose repository functions to sharedlibs 2024-12-20 20:25:32 +02:00
Pijus Kamandulis
363f822e5a Added some tests for sharedlibrary 2024-12-19 23:21:45 +02:00
Pijus Kamandulis
be7a615931 Cross-Compile Shared Libraries 2024-12-19 00:48:17 +02:00
Pijus Kamandulis
83f086a2dc Configuration fixes 2024-12-18 23:28:04 +02:00
Pijus Kamandulis
777034181f Refactor to support multiple server instances in shared library 2024-12-18 19:39:57 +02:00
55 changed files with 2686 additions and 1713 deletions

View File

@@ -0,0 +1,30 @@
name: Cross-Compile Shared Libraries
on:
workflow_dispatch:
jobs:
build:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: Cross-Compile with xgo
uses: crazy-max/ghaction-xgo@v3.1.0
with:
xgo_version: latest
go_version: 1.22.0
dest: dist
pkg: sharedlibrary
prefix: cosmium
targets: linux/amd64,linux/arm64,windows/amd64,windows/arm64,darwin/amd64,darwin/arm64
v: true
buildmode: c-shared
- name: Upload artifact
uses: actions/upload-artifact@v3
with:
name: shared-libraries
path: dist/*

View File

@@ -8,9 +8,17 @@ SERVER_LOCATION=./cmd/server
SHARED_LIB_LOCATION=./sharedlibrary SHARED_LIB_LOCATION=./sharedlibrary
SHARED_LIB_OPT=-buildmode=c-shared SHARED_LIB_OPT=-buildmode=c-shared
XGO_TARGETS=linux/amd64,linux/arm64,windows/amd64,windows/arm64,darwin/amd64,darwin/arm64
GOVERSION=1.22.0
DIST_DIR=dist DIST_DIR=dist
SHARED_LIB_TEST_CC=gcc
SHARED_LIB_TEST_CFLAGS=-Wall -ldl
SHARED_LIB_TEST_TARGET=$(DIST_DIR)/sharedlibrary_test
SHARED_LIB_TEST_DIR=./sharedlibrary/tests
SHARED_LIB_TEST_SOURCES=$(wildcard $(SHARED_LIB_TEST_DIR)/*.c)
all: test build-all all: test build-all
build-all: build-darwin-arm64 build-darwin-amd64 build-linux-amd64 build-linux-arm64 build-windows-amd64 build-windows-arm64 build-all: build-darwin-arm64 build-darwin-amd64 build-linux-amd64 build-linux-arm64 build-windows-amd64 build-windows-arm64
@@ -43,6 +51,19 @@ build-sharedlib-linux-amd64:
@echo "Building shared library for Linux x64..." @echo "Building shared library for Linux x64..."
@GOOS=linux GOARCH=amd64 $(GOBUILD) $(SHARED_LIB_OPT) -o $(DIST_DIR)/$(BINARY_NAME)-linux-amd64.so $(SHARED_LIB_LOCATION) @GOOS=linux GOARCH=amd64 $(GOBUILD) $(SHARED_LIB_OPT) -o $(DIST_DIR)/$(BINARY_NAME)-linux-amd64.so $(SHARED_LIB_LOCATION)
build-sharedlib-tests: build-sharedlib-linux-amd64
@echo "Building shared library tests..."
@$(SHARED_LIB_TEST_CC) $(SHARED_LIB_TEST_CFLAGS) -o $(SHARED_LIB_TEST_TARGET) $(SHARED_LIB_TEST_SOURCES)
run-sharedlib-tests: build-sharedlib-tests
@echo "Running shared library tests..."
@$(SHARED_LIB_TEST_TARGET) $(DIST_DIR)/$(BINARY_NAME)-linux-amd64.so
xgo-compile-sharedlib:
@echo "Building shared libraries using xgo..."
@mkdir -p $(DIST_DIR)
@xgo -targets=$(XGO_TARGETS) -go $(GOVERSION) -buildmode=c-shared -dest=$(DIST_DIR) -out=$(BINARY_NAME) -pkg=$(SHARED_LIB_LOCATION) .
generate-parser-nosql: generate-parser-nosql:
pigeon -o ./parsers/nosql/nosql.go ./parsers/nosql/nosql.peg pigeon -o ./parsers/nosql/nosql.go ./parsers/nosql/nosql.peg

35
api/api_server.go Normal file
View File

@@ -0,0 +1,35 @@
package api
import (
"github.com/gin-gonic/gin"
"github.com/pikami/cosmium/api/config"
"github.com/pikami/cosmium/internal/repositories"
)
type ApiServer struct {
stopServer chan interface{}
isActive bool
router *gin.Engine
config config.ServerConfig
}
func NewApiServer(dataRepository *repositories.DataRepository, config config.ServerConfig) *ApiServer {
stopChan := make(chan interface{})
apiServer := &ApiServer{
stopServer: stopChan,
config: config,
}
apiServer.CreateRouter(dataRepository)
return apiServer
}
func (s *ApiServer) GetRouter() *gin.Engine {
return s.router
}
func (s *ApiServer) Stop() {
s.stopServer <- true
}

View File

@@ -5,6 +5,8 @@ import (
"fmt" "fmt"
"os" "os"
"strings" "strings"
"github.com/pikami/cosmium/internal/logger"
) )
const ( const (
@@ -13,9 +15,7 @@ const (
ExplorerBaseUrlLocation = "/_explorer" ExplorerBaseUrlLocation = "/_explorer"
) )
var Config = ServerConfig{} func ParseFlags() ServerConfig {
func ParseFlags() {
host := flag.String("Host", "localhost", "Hostname") host := flag.String("Host", "localhost", "Hostname")
port := flag.Int("Port", 8081, "Listen port") port := flag.Int("Port", 8081, "Listen port")
explorerPath := flag.String("ExplorerDir", "", "Path to cosmos-explorer files") explorerPath := flag.String("ExplorerDir", "", "Path to cosmos-explorer files")
@@ -31,22 +31,42 @@ func ParseFlags() {
flag.Parse() flag.Parse()
setFlagsFromEnvironment() setFlagsFromEnvironment()
Config.Host = *host config := ServerConfig{}
Config.Port = *port config.Host = *host
Config.ExplorerPath = *explorerPath config.Port = *port
Config.TLS_CertificatePath = *tlsCertificatePath config.ExplorerPath = *explorerPath
Config.TLS_CertificateKey = *tlsCertificateKey config.TLS_CertificatePath = *tlsCertificatePath
Config.InitialDataFilePath = *initialDataPath config.TLS_CertificateKey = *tlsCertificateKey
Config.PersistDataFilePath = *persistDataPath config.InitialDataFilePath = *initialDataPath
Config.DisableAuth = *disableAuthentication config.PersistDataFilePath = *persistDataPath
Config.DisableTls = *disableTls config.DisableAuth = *disableAuthentication
Config.Debug = *debug config.DisableTls = *disableTls
config.Debug = *debug
config.AccountKey = *accountKey
Config.DatabaseAccount = Config.Host config.PopulateCalculatedFields()
Config.DatabaseDomain = Config.Host
Config.DatabaseEndpoint = fmt.Sprintf("https://%s:%d/", Config.Host, Config.Port) return config
Config.AccountKey = *accountKey }
Config.ExplorerBaseUrlLocation = ExplorerBaseUrlLocation
func (c *ServerConfig) PopulateCalculatedFields() {
c.DatabaseAccount = c.Host
c.DatabaseDomain = c.Host
c.DatabaseEndpoint = fmt.Sprintf("https://%s:%d/", c.Host, c.Port)
c.ExplorerBaseUrlLocation = ExplorerBaseUrlLocation
logger.EnableDebugOutput = c.Debug
}
func (c *ServerConfig) ApplyDefaultsToEmptyFields() {
if c.Host == "" {
c.Host = "localhost"
}
if c.Port == 0 {
c.Port = 8081
}
if c.AccountKey == "" {
c.AccountKey = DefaultAccountKey
}
} }
func setFlagsFromEnvironment() (err error) { func setFlagsFromEnvironment() (err error) {

View File

@@ -1,20 +1,20 @@
package config package config
type ServerConfig struct { type ServerConfig struct {
DatabaseAccount string DatabaseAccount string `json:"databaseAccount"`
DatabaseDomain string DatabaseDomain string `json:"databaseDomain"`
DatabaseEndpoint string DatabaseEndpoint string `json:"databaseEndpoint"`
AccountKey string AccountKey string `json:"accountKey"`
ExplorerPath string ExplorerPath string `json:"explorerPath"`
Port int Port int `json:"port"`
Host string Host string `json:"host"`
TLS_CertificatePath string TLS_CertificatePath string `json:"tlsCertificatePath"`
TLS_CertificateKey string TLS_CertificateKey string `json:"tlsCertificateKey"`
InitialDataFilePath string InitialDataFilePath string `json:"initialDataFilePath"`
PersistDataFilePath string PersistDataFilePath string `json:"persistDataFilePath"`
DisableAuth bool DisableAuth bool `json:"disableAuth"`
DisableTls bool DisableTls bool `json:"disableTls"`
Debug bool Debug bool `json:"debug"`
ExplorerBaseUrlLocation string ExplorerBaseUrlLocation string `json:"explorerBaseUrlLocation"`
} }

View File

@@ -5,16 +5,15 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pikami/cosmium/internal/repositories"
repositorymodels "github.com/pikami/cosmium/internal/repository_models" repositorymodels "github.com/pikami/cosmium/internal/repository_models"
) )
func GetAllCollections(c *gin.Context) { func (h *Handlers) GetAllCollections(c *gin.Context) {
databaseId := c.Param("databaseId") databaseId := c.Param("databaseId")
collections, status := repositories.GetAllCollections(databaseId) collections, status := h.repository.GetAllCollections(databaseId)
if status == repositorymodels.StatusOk { if status == repositorymodels.StatusOk {
database, _ := repositories.GetDatabase(databaseId) database, _ := h.repository.GetDatabase(databaseId)
c.Header("x-ms-item-count", fmt.Sprintf("%d", len(collections))) c.Header("x-ms-item-count", fmt.Sprintf("%d", len(collections)))
c.IndentedJSON(http.StatusOK, gin.H{ c.IndentedJSON(http.StatusOK, gin.H{
@@ -28,11 +27,11 @@ func GetAllCollections(c *gin.Context) {
c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"}) c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"})
} }
func GetCollection(c *gin.Context) { func (h *Handlers) GetCollection(c *gin.Context) {
databaseId := c.Param("databaseId") databaseId := c.Param("databaseId")
id := c.Param("collId") id := c.Param("collId")
collection, status := repositories.GetCollection(databaseId, id) collection, status := h.repository.GetCollection(databaseId, id)
if status == repositorymodels.StatusOk { if status == repositorymodels.StatusOk {
c.IndentedJSON(http.StatusOK, collection) c.IndentedJSON(http.StatusOK, collection)
return return
@@ -46,11 +45,11 @@ func GetCollection(c *gin.Context) {
c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"}) c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"})
} }
func DeleteCollection(c *gin.Context) { func (h *Handlers) DeleteCollection(c *gin.Context) {
databaseId := c.Param("databaseId") databaseId := c.Param("databaseId")
id := c.Param("collId") id := c.Param("collId")
status := repositories.DeleteCollection(databaseId, id) status := h.repository.DeleteCollection(databaseId, id)
if status == repositorymodels.StatusOk { if status == repositorymodels.StatusOk {
c.Status(http.StatusNoContent) c.Status(http.StatusNoContent)
return return
@@ -64,7 +63,7 @@ func DeleteCollection(c *gin.Context) {
c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"}) c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"})
} }
func CreateCollection(c *gin.Context) { func (h *Handlers) CreateCollection(c *gin.Context) {
databaseId := c.Param("databaseId") databaseId := c.Param("databaseId")
var newCollection repositorymodels.Collection var newCollection repositorymodels.Collection
@@ -78,7 +77,7 @@ func CreateCollection(c *gin.Context) {
return return
} }
createdCollection, status := repositories.CreateCollection(databaseId, newCollection) createdCollection, status := h.repository.CreateCollection(databaseId, newCollection)
if status == repositorymodels.Conflict { if status == repositorymodels.Conflict {
c.IndentedJSON(http.StatusConflict, gin.H{"message": "Conflict"}) c.IndentedJSON(http.StatusConflict, gin.H{"message": "Conflict"})
return return

View File

@@ -4,11 +4,10 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pikami/cosmium/internal/repositories"
) )
func CosmiumExport(c *gin.Context) { func (h *Handlers) CosmiumExport(c *gin.Context) {
repositoryState, err := repositories.GetState() repositoryState, err := h.repository.GetState()
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return

View File

@@ -5,12 +5,11 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pikami/cosmium/internal/repositories"
repositorymodels "github.com/pikami/cosmium/internal/repository_models" repositorymodels "github.com/pikami/cosmium/internal/repository_models"
) )
func GetAllDatabases(c *gin.Context) { func (h *Handlers) GetAllDatabases(c *gin.Context) {
databases, status := repositories.GetAllDatabases() databases, status := h.repository.GetAllDatabases()
if status == repositorymodels.StatusOk { if status == repositorymodels.StatusOk {
c.Header("x-ms-item-count", fmt.Sprintf("%d", len(databases))) c.Header("x-ms-item-count", fmt.Sprintf("%d", len(databases)))
c.IndentedJSON(http.StatusOK, gin.H{ c.IndentedJSON(http.StatusOK, gin.H{
@@ -24,10 +23,10 @@ func GetAllDatabases(c *gin.Context) {
c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"}) c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"})
} }
func GetDatabase(c *gin.Context) { func (h *Handlers) GetDatabase(c *gin.Context) {
id := c.Param("databaseId") id := c.Param("databaseId")
database, status := repositories.GetDatabase(id) database, status := h.repository.GetDatabase(id)
if status == repositorymodels.StatusOk { if status == repositorymodels.StatusOk {
c.IndentedJSON(http.StatusOK, database) c.IndentedJSON(http.StatusOK, database)
return return
@@ -41,10 +40,10 @@ func GetDatabase(c *gin.Context) {
c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"}) c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"})
} }
func DeleteDatabase(c *gin.Context) { func (h *Handlers) DeleteDatabase(c *gin.Context) {
id := c.Param("databaseId") id := c.Param("databaseId")
status := repositories.DeleteDatabase(id) status := h.repository.DeleteDatabase(id)
if status == repositorymodels.StatusOk { if status == repositorymodels.StatusOk {
c.Status(http.StatusNoContent) c.Status(http.StatusNoContent)
return return
@@ -58,7 +57,7 @@ func DeleteDatabase(c *gin.Context) {
c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"}) c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"})
} }
func CreateDatabase(c *gin.Context) { func (h *Handlers) CreateDatabase(c *gin.Context) {
var newDatabase repositorymodels.Database var newDatabase repositorymodels.Database
if err := c.BindJSON(&newDatabase); err != nil { if err := c.BindJSON(&newDatabase); err != nil {
@@ -71,7 +70,7 @@ func CreateDatabase(c *gin.Context) {
return return
} }
createdDatabase, status := repositories.CreateDatabase(newDatabase) createdDatabase, status := h.repository.CreateDatabase(newDatabase)
if status == repositorymodels.Conflict { if status == repositorymodels.Conflict {
c.IndentedJSON(http.StatusConflict, gin.H{"message": "Conflict"}) c.IndentedJSON(http.StatusConflict, gin.H{"message": "Conflict"})
return return

View File

@@ -6,21 +6,20 @@ import (
"net/http" "net/http"
"strconv" "strconv"
jsonpatch "github.com/evanphx/json-patch/v5"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pikami/cosmium/internal/constants" "github.com/pikami/cosmium/internal/constants"
"github.com/pikami/cosmium/internal/logger" "github.com/pikami/cosmium/internal/logger"
"github.com/pikami/cosmium/internal/repositories"
repositorymodels "github.com/pikami/cosmium/internal/repository_models" repositorymodels "github.com/pikami/cosmium/internal/repository_models"
jsonpatch "github.com/pikami/json-patch/v5"
) )
func GetAllDocuments(c *gin.Context) { func (h *Handlers) GetAllDocuments(c *gin.Context) {
databaseId := c.Param("databaseId") databaseId := c.Param("databaseId")
collectionId := c.Param("collId") collectionId := c.Param("collId")
documents, status := repositories.GetAllDocuments(databaseId, collectionId) documents, status := h.repository.GetAllDocuments(databaseId, collectionId)
if status == repositorymodels.StatusOk { if status == repositorymodels.StatusOk {
collection, _ := repositories.GetCollection(databaseId, collectionId) collection, _ := h.repository.GetCollection(databaseId, collectionId)
c.Header("x-ms-item-count", fmt.Sprintf("%d", len(documents))) c.Header("x-ms-item-count", fmt.Sprintf("%d", len(documents)))
c.IndentedJSON(http.StatusOK, gin.H{ c.IndentedJSON(http.StatusOK, gin.H{
@@ -34,12 +33,12 @@ func GetAllDocuments(c *gin.Context) {
c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"}) c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"})
} }
func GetDocument(c *gin.Context) { func (h *Handlers) GetDocument(c *gin.Context) {
databaseId := c.Param("databaseId") databaseId := c.Param("databaseId")
collectionId := c.Param("collId") collectionId := c.Param("collId")
documentId := c.Param("docId") documentId := c.Param("docId")
document, status := repositories.GetDocument(databaseId, collectionId, documentId) document, status := h.repository.GetDocument(databaseId, collectionId, documentId)
if status == repositorymodels.StatusOk { if status == repositorymodels.StatusOk {
c.IndentedJSON(http.StatusOK, document) c.IndentedJSON(http.StatusOK, document)
return return
@@ -53,12 +52,12 @@ func GetDocument(c *gin.Context) {
c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"}) c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"})
} }
func DeleteDocument(c *gin.Context) { func (h *Handlers) DeleteDocument(c *gin.Context) {
databaseId := c.Param("databaseId") databaseId := c.Param("databaseId")
collectionId := c.Param("collId") collectionId := c.Param("collId")
documentId := c.Param("docId") documentId := c.Param("docId")
status := repositories.DeleteDocument(databaseId, collectionId, documentId) status := h.repository.DeleteDocument(databaseId, collectionId, documentId)
if status == repositorymodels.StatusOk { if status == repositorymodels.StatusOk {
c.Status(http.StatusNoContent) c.Status(http.StatusNoContent)
return return
@@ -73,7 +72,7 @@ func DeleteDocument(c *gin.Context) {
} }
// TODO: Maybe move "replace" logic to repository // TODO: Maybe move "replace" logic to repository
func ReplaceDocument(c *gin.Context) { func (h *Handlers) ReplaceDocument(c *gin.Context) {
databaseId := c.Param("databaseId") databaseId := c.Param("databaseId")
collectionId := c.Param("collId") collectionId := c.Param("collId")
documentId := c.Param("docId") documentId := c.Param("docId")
@@ -84,13 +83,13 @@ func ReplaceDocument(c *gin.Context) {
return return
} }
status := repositories.DeleteDocument(databaseId, collectionId, documentId) status := h.repository.DeleteDocument(databaseId, collectionId, documentId)
if status == repositorymodels.StatusNotFound { if status == repositorymodels.StatusNotFound {
c.IndentedJSON(http.StatusNotFound, gin.H{"message": "NotFound"}) c.IndentedJSON(http.StatusNotFound, gin.H{"message": "NotFound"})
return return
} }
createdDocument, status := repositories.CreateDocument(databaseId, collectionId, requestBody) createdDocument, status := h.repository.CreateDocument(databaseId, collectionId, requestBody)
if status == repositorymodels.Conflict { if status == repositorymodels.Conflict {
c.IndentedJSON(http.StatusConflict, gin.H{"message": "Conflict"}) c.IndentedJSON(http.StatusConflict, gin.H{"message": "Conflict"})
return return
@@ -104,12 +103,12 @@ func ReplaceDocument(c *gin.Context) {
c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"}) c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"})
} }
func PatchDocument(c *gin.Context) { func (h *Handlers) PatchDocument(c *gin.Context) {
databaseId := c.Param("databaseId") databaseId := c.Param("databaseId")
collectionId := c.Param("collId") collectionId := c.Param("collId")
documentId := c.Param("docId") documentId := c.Param("docId")
document, status := repositories.GetDocument(databaseId, collectionId, documentId) document, status := h.repository.GetDocument(databaseId, collectionId, documentId)
if status == repositorymodels.StatusNotFound { if status == repositorymodels.StatusNotFound {
c.IndentedJSON(http.StatusNotFound, gin.H{"message": "NotFound"}) c.IndentedJSON(http.StatusNotFound, gin.H{"message": "NotFound"})
return return
@@ -160,13 +159,13 @@ func PatchDocument(c *gin.Context) {
return return
} }
status = repositories.DeleteDocument(databaseId, collectionId, documentId) status = h.repository.DeleteDocument(databaseId, collectionId, documentId)
if status == repositorymodels.StatusNotFound { if status == repositorymodels.StatusNotFound {
c.IndentedJSON(http.StatusNotFound, gin.H{"message": "NotFound"}) c.IndentedJSON(http.StatusNotFound, gin.H{"message": "NotFound"})
return return
} }
createdDocument, status := repositories.CreateDocument(databaseId, collectionId, modifiedDocument) createdDocument, status := h.repository.CreateDocument(databaseId, collectionId, modifiedDocument)
if status == repositorymodels.Conflict { if status == repositorymodels.Conflict {
c.IndentedJSON(http.StatusConflict, gin.H{"message": "Conflict"}) c.IndentedJSON(http.StatusConflict, gin.H{"message": "Conflict"})
return return
@@ -180,7 +179,7 @@ func PatchDocument(c *gin.Context) {
c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"}) c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"})
} }
func DocumentsPost(c *gin.Context) { func (h *Handlers) DocumentsPost(c *gin.Context) {
databaseId := c.Param("databaseId") databaseId := c.Param("databaseId")
collectionId := c.Param("collId") collectionId := c.Param("collId")
@@ -202,14 +201,14 @@ func DocumentsPost(c *gin.Context) {
queryParameters = parametersToMap(paramsArray) queryParameters = parametersToMap(paramsArray)
} }
docs, status := repositories.ExecuteQueryDocuments(databaseId, collectionId, query.(string), queryParameters) docs, status := h.repository.ExecuteQueryDocuments(databaseId, collectionId, query.(string), queryParameters)
if status != repositorymodels.StatusOk { if status != repositorymodels.StatusOk {
// TODO: Currently we return everything if the query fails // TODO: Currently we return everything if the query fails
GetAllDocuments(c) h.GetAllDocuments(c)
return return
} }
collection, _ := repositories.GetCollection(databaseId, collectionId) collection, _ := h.repository.GetCollection(databaseId, collectionId)
c.Header("x-ms-item-count", fmt.Sprintf("%d", len(docs))) c.Header("x-ms-item-count", fmt.Sprintf("%d", len(docs)))
c.IndentedJSON(http.StatusOK, gin.H{ c.IndentedJSON(http.StatusOK, gin.H{
"_rid": collection.ResourceID, "_rid": collection.ResourceID,
@@ -226,10 +225,10 @@ func DocumentsPost(c *gin.Context) {
isUpsert, _ := strconv.ParseBool(c.GetHeader("x-ms-documentdb-is-upsert")) isUpsert, _ := strconv.ParseBool(c.GetHeader("x-ms-documentdb-is-upsert"))
if isUpsert { if isUpsert {
repositories.DeleteDocument(databaseId, collectionId, requestBody["id"].(string)) h.repository.DeleteDocument(databaseId, collectionId, requestBody["id"].(string))
} }
createdDocument, status := repositories.CreateDocument(databaseId, collectionId, requestBody) createdDocument, status := h.repository.CreateDocument(databaseId, collectionId, requestBody)
if status == repositorymodels.Conflict { if status == repositorymodels.Conflict {
c.IndentedJSON(http.StatusConflict, gin.H{"message": "Conflict"}) c.IndentedJSON(http.StatusConflict, gin.H{"message": "Conflict"})
return return

View File

@@ -4,15 +4,14 @@ import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pikami/cosmium/api/config"
) )
func RegisterExplorerHandlers(router *gin.Engine) { func (h *Handlers) RegisterExplorerHandlers(router *gin.Engine) {
explorer := router.Group(config.Config.ExplorerBaseUrlLocation) explorer := router.Group(h.config.ExplorerBaseUrlLocation)
{ {
explorer.Use(func(ctx *gin.Context) { explorer.Use(func(ctx *gin.Context) {
if ctx.Param("filepath") == "/config.json" { if ctx.Param("filepath") == "/config.json" {
endpoint := fmt.Sprintf("https://%s:%d", config.Config.Host, config.Config.Port) endpoint := fmt.Sprintf("https://%s:%d", h.config.Host, h.config.Port)
ctx.JSON(200, gin.H{ ctx.JSON(200, gin.H{
"BACKEND_ENDPOINT": endpoint, "BACKEND_ENDPOINT": endpoint,
"MONGO_BACKEND_ENDPOINT": endpoint, "MONGO_BACKEND_ENDPOINT": endpoint,
@@ -25,8 +24,8 @@ func RegisterExplorerHandlers(router *gin.Engine) {
} }
}) })
if config.Config.ExplorerPath != "" { if h.config.ExplorerPath != "" {
explorer.Static("/", config.Config.ExplorerPath) explorer.Static("/", h.config.ExplorerPath)
} }
} }
} }

18
api/handlers/handlers.go Normal file
View File

@@ -0,0 +1,18 @@
package handlers
import (
"github.com/pikami/cosmium/api/config"
"github.com/pikami/cosmium/internal/repositories"
)
type Handlers struct {
repository *repositories.DataRepository
config config.ServerConfig
}
func NewHandlers(dataRepository *repositories.DataRepository, config config.ServerConfig) *Handlers {
return &Handlers{
repository: dataRepository,
config: config,
}
}

View File

@@ -10,11 +10,11 @@ import (
"github.com/pikami/cosmium/internal/logger" "github.com/pikami/cosmium/internal/logger"
) )
func Authentication() gin.HandlerFunc { func Authentication(config config.ServerConfig) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
requestUrl := c.Request.URL.String() requestUrl := c.Request.URL.String()
if config.Config.DisableAuth || if config.DisableAuth ||
strings.HasPrefix(requestUrl, config.Config.ExplorerBaseUrlLocation) || strings.HasPrefix(requestUrl, config.ExplorerBaseUrlLocation) ||
strings.HasPrefix(requestUrl, "/cosmium") { strings.HasPrefix(requestUrl, "/cosmium") {
return return
} }
@@ -25,7 +25,7 @@ func Authentication() gin.HandlerFunc {
authHeader := c.Request.Header.Get("authorization") authHeader := c.Request.Header.Get("authorization")
date := c.Request.Header.Get("x-ms-date") date := c.Request.Header.Get("x-ms-date")
expectedSignature := authentication.GenerateSignature( expectedSignature := authentication.GenerateSignature(
c.Request.Method, resourceType, resourceId, date, config.Config.AccountKey) c.Request.Method, resourceType, resourceId, date, config.AccountKey)
decoded, _ := url.QueryUnescape(authHeader) decoded, _ := url.QueryUnescape(authHeader)
params, _ := url.ParseQuery(decoded) params, _ := url.ParseQuery(decoded)

View File

@@ -7,10 +7,10 @@ import (
"github.com/pikami/cosmium/api/config" "github.com/pikami/cosmium/api/config"
) )
func StripTrailingSlashes(r *gin.Engine) gin.HandlerFunc { func StripTrailingSlashes(r *gin.Engine, config config.ServerConfig) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
path := c.Request.URL.Path path := c.Request.URL.Path
if len(path) > 1 && path[len(path)-1] == '/' && !strings.Contains(path, config.Config.ExplorerBaseUrlLocation) { if len(path) > 1 && path[len(path)-1] == '/' && !strings.Contains(path, config.ExplorerBaseUrlLocation) {
c.Request.URL.Path = path[:len(path)-1] c.Request.URL.Path = path[:len(path)-1]
r.HandleContext(c) r.HandleContext(c)
c.Abort() c.Abort()

View File

@@ -5,11 +5,10 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pikami/cosmium/internal/repositories"
repositorymodels "github.com/pikami/cosmium/internal/repository_models" repositorymodels "github.com/pikami/cosmium/internal/repository_models"
) )
func GetPartitionKeyRanges(c *gin.Context) { func (h *Handlers) GetPartitionKeyRanges(c *gin.Context) {
databaseId := c.Param("databaseId") databaseId := c.Param("databaseId")
collectionId := c.Param("collId") collectionId := c.Param("collId")
@@ -18,7 +17,7 @@ func GetPartitionKeyRanges(c *gin.Context) {
return return
} }
partitionKeyRanges, status := repositories.GetPartitionKeyRanges(databaseId, collectionId) partitionKeyRanges, status := h.repository.GetPartitionKeyRanges(databaseId, collectionId)
if status == repositorymodels.StatusOk { if status == repositorymodels.StatusOk {
c.Header("etag", "\"420\"") c.Header("etag", "\"420\"")
c.Header("lsn", "420") c.Header("lsn", "420")
@@ -27,7 +26,7 @@ func GetPartitionKeyRanges(c *gin.Context) {
c.Header("x-ms-item-count", fmt.Sprintf("%d", len(partitionKeyRanges))) c.Header("x-ms-item-count", fmt.Sprintf("%d", len(partitionKeyRanges)))
collectionRid := collectionId collectionRid := collectionId
collection, _ := repositories.GetCollection(databaseId, collectionId) collection, _ := h.repository.GetCollection(databaseId, collectionId)
if collection.ResourceID != "" { if collection.ResourceID != "" {
collectionRid = collection.ResourceID collectionRid = collection.ResourceID
} }

View File

@@ -5,27 +5,26 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pikami/cosmium/api/config"
) )
func GetServerInfo(c *gin.Context) { func (h *Handlers) GetServerInfo(c *gin.Context) {
c.IndentedJSON(http.StatusOK, gin.H{ c.IndentedJSON(http.StatusOK, gin.H{
"_self": "", "_self": "",
"id": config.Config.DatabaseAccount, "id": h.config.DatabaseAccount,
"_rid": fmt.Sprintf("%s.%s", config.Config.DatabaseAccount, config.Config.DatabaseDomain), "_rid": fmt.Sprintf("%s.%s", h.config.DatabaseAccount, h.config.DatabaseDomain),
"media": "//media/", "media": "//media/",
"addresses": "//addresses/", "addresses": "//addresses/",
"_dbs": "//dbs/", "_dbs": "//dbs/",
"writableLocations": []map[string]interface{}{ "writableLocations": []map[string]interface{}{
{ {
"name": "South Central US", "name": "South Central US",
"databaseAccountEndpoint": config.Config.DatabaseEndpoint, "databaseAccountEndpoint": h.config.DatabaseEndpoint,
}, },
}, },
"readableLocations": []map[string]interface{}{ "readableLocations": []map[string]interface{}{
{ {
"name": "South Central US", "name": "South Central US",
"databaseAccountEndpoint": config.Config.DatabaseEndpoint, "databaseAccountEndpoint": h.config.DatabaseEndpoint,
}, },
}, },
"enableMultipleWriteLocations": false, "enableMultipleWriteLocations": false,

View File

@@ -5,15 +5,14 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pikami/cosmium/internal/repositories"
repositorymodels "github.com/pikami/cosmium/internal/repository_models" repositorymodels "github.com/pikami/cosmium/internal/repository_models"
) )
func GetAllStoredProcedures(c *gin.Context) { func (h *Handlers) GetAllStoredProcedures(c *gin.Context) {
databaseId := c.Param("databaseId") databaseId := c.Param("databaseId")
collectionId := c.Param("collId") collectionId := c.Param("collId")
sps, status := repositories.GetAllStoredProcedures(databaseId, collectionId) sps, status := h.repository.GetAllStoredProcedures(databaseId, collectionId)
if status == repositorymodels.StatusOk { if status == repositorymodels.StatusOk {
c.Header("x-ms-item-count", fmt.Sprintf("%d", len(sps))) c.Header("x-ms-item-count", fmt.Sprintf("%d", len(sps)))

View File

@@ -5,15 +5,14 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pikami/cosmium/internal/repositories"
repositorymodels "github.com/pikami/cosmium/internal/repository_models" repositorymodels "github.com/pikami/cosmium/internal/repository_models"
) )
func GetAllTriggers(c *gin.Context) { func (h *Handlers) GetAllTriggers(c *gin.Context) {
databaseId := c.Param("databaseId") databaseId := c.Param("databaseId")
collectionId := c.Param("collId") collectionId := c.Param("collId")
triggers, status := repositories.GetAllTriggers(databaseId, collectionId) triggers, status := h.repository.GetAllTriggers(databaseId, collectionId)
if status == repositorymodels.StatusOk { if status == repositorymodels.StatusOk {
c.Header("x-ms-item-count", fmt.Sprintf("%d", len(triggers))) c.Header("x-ms-item-count", fmt.Sprintf("%d", len(triggers)))

View File

@@ -5,15 +5,14 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pikami/cosmium/internal/repositories"
repositorymodels "github.com/pikami/cosmium/internal/repository_models" repositorymodels "github.com/pikami/cosmium/internal/repository_models"
) )
func GetAllUserDefinedFunctions(c *gin.Context) { func (h *Handlers) GetAllUserDefinedFunctions(c *gin.Context) {
databaseId := c.Param("databaseId") databaseId := c.Param("databaseId")
collectionId := c.Param("collId") collectionId := c.Param("collId")
udfs, status := repositories.GetAllUserDefinedFunctions(databaseId, collectionId) udfs, status := h.repository.GetAllUserDefinedFunctions(databaseId, collectionId)
if status == repositorymodels.StatusOk { if status == repositorymodels.StatusOk {
c.Header("x-ms-item-count", fmt.Sprintf("%d", len(udfs))) c.Header("x-ms-item-count", fmt.Sprintf("%d", len(udfs)))

View File

@@ -6,78 +6,75 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pikami/cosmium/api/config"
"github.com/pikami/cosmium/api/handlers" "github.com/pikami/cosmium/api/handlers"
"github.com/pikami/cosmium/api/handlers/middleware" "github.com/pikami/cosmium/api/handlers/middleware"
"github.com/pikami/cosmium/internal/logger" "github.com/pikami/cosmium/internal/logger"
"github.com/pikami/cosmium/internal/repositories"
tlsprovider "github.com/pikami/cosmium/internal/tls_provider" tlsprovider "github.com/pikami/cosmium/internal/tls_provider"
) )
type Server struct { func (s *ApiServer) CreateRouter(repository *repositories.DataRepository) {
StopServer chan interface{} routeHandlers := handlers.NewHandlers(repository, s.config)
}
if !s.config.Debug {
gin.SetMode(gin.ReleaseMode)
}
func CreateRouter() *gin.Engine {
router := gin.Default(func(e *gin.Engine) { router := gin.Default(func(e *gin.Engine) {
e.RedirectTrailingSlash = false e.RedirectTrailingSlash = false
}) })
if config.Config.Debug { if s.config.Debug {
router.Use(middleware.RequestLogger()) router.Use(middleware.RequestLogger())
} }
router.Use(middleware.StripTrailingSlashes(router)) router.Use(middleware.StripTrailingSlashes(router, s.config))
router.Use(middleware.Authentication()) router.Use(middleware.Authentication(s.config))
router.GET("/dbs/:databaseId/colls/:collId/pkranges", handlers.GetPartitionKeyRanges) router.GET("/dbs/:databaseId/colls/:collId/pkranges", routeHandlers.GetPartitionKeyRanges)
router.POST("/dbs/:databaseId/colls/:collId/docs", handlers.DocumentsPost) router.POST("/dbs/:databaseId/colls/:collId/docs", routeHandlers.DocumentsPost)
router.GET("/dbs/:databaseId/colls/:collId/docs", handlers.GetAllDocuments) router.GET("/dbs/:databaseId/colls/:collId/docs", routeHandlers.GetAllDocuments)
router.GET("/dbs/:databaseId/colls/:collId/docs/:docId", handlers.GetDocument) router.GET("/dbs/:databaseId/colls/:collId/docs/:docId", routeHandlers.GetDocument)
router.PUT("/dbs/:databaseId/colls/:collId/docs/:docId", handlers.ReplaceDocument) router.PUT("/dbs/:databaseId/colls/:collId/docs/:docId", routeHandlers.ReplaceDocument)
router.PATCH("/dbs/:databaseId/colls/:collId/docs/:docId", handlers.PatchDocument) router.PATCH("/dbs/:databaseId/colls/:collId/docs/:docId", routeHandlers.PatchDocument)
router.DELETE("/dbs/:databaseId/colls/:collId/docs/:docId", handlers.DeleteDocument) router.DELETE("/dbs/:databaseId/colls/:collId/docs/:docId", routeHandlers.DeleteDocument)
router.POST("/dbs/:databaseId/colls", handlers.CreateCollection) router.POST("/dbs/:databaseId/colls", routeHandlers.CreateCollection)
router.GET("/dbs/:databaseId/colls", handlers.GetAllCollections) router.GET("/dbs/:databaseId/colls", routeHandlers.GetAllCollections)
router.GET("/dbs/:databaseId/colls/:collId", handlers.GetCollection) router.GET("/dbs/:databaseId/colls/:collId", routeHandlers.GetCollection)
router.DELETE("/dbs/:databaseId/colls/:collId", handlers.DeleteCollection) router.DELETE("/dbs/:databaseId/colls/:collId", routeHandlers.DeleteCollection)
router.POST("/dbs", handlers.CreateDatabase) router.POST("/dbs", routeHandlers.CreateDatabase)
router.GET("/dbs", handlers.GetAllDatabases) router.GET("/dbs", routeHandlers.GetAllDatabases)
router.GET("/dbs/:databaseId", handlers.GetDatabase) router.GET("/dbs/:databaseId", routeHandlers.GetDatabase)
router.DELETE("/dbs/:databaseId", handlers.DeleteDatabase) router.DELETE("/dbs/:databaseId", routeHandlers.DeleteDatabase)
router.GET("/dbs/:databaseId/colls/:collId/udfs", handlers.GetAllUserDefinedFunctions) router.GET("/dbs/:databaseId/colls/:collId/udfs", routeHandlers.GetAllUserDefinedFunctions)
router.GET("/dbs/:databaseId/colls/:collId/sprocs", handlers.GetAllStoredProcedures) router.GET("/dbs/:databaseId/colls/:collId/sprocs", routeHandlers.GetAllStoredProcedures)
router.GET("/dbs/:databaseId/colls/:collId/triggers", handlers.GetAllTriggers) router.GET("/dbs/:databaseId/colls/:collId/triggers", routeHandlers.GetAllTriggers)
router.GET("/offers", handlers.GetOffers) router.GET("/offers", handlers.GetOffers)
router.GET("/", handlers.GetServerInfo) router.GET("/", routeHandlers.GetServerInfo)
router.GET("/cosmium/export", handlers.CosmiumExport) router.GET("/cosmium/export", routeHandlers.CosmiumExport)
handlers.RegisterExplorerHandlers(router) routeHandlers.RegisterExplorerHandlers(router)
return router s.router = router
} }
func StartAPI() *Server { func (s *ApiServer) Start() {
if !config.Config.Debug { listenAddress := fmt.Sprintf(":%d", s.config.Port)
gin.SetMode(gin.ReleaseMode) s.isActive = true
}
router := CreateRouter()
listenAddress := fmt.Sprintf(":%d", config.Config.Port)
stopChan := make(chan interface{})
server := &http.Server{ server := &http.Server{
Addr: listenAddress, Addr: listenAddress,
Handler: router.Handler(), Handler: s.router.Handler(),
} }
go func() { go func() {
<-stopChan <-s.stopServer
logger.Info("Shutting down server...") logger.Info("Shutting down server...")
err := server.Shutdown(context.TODO()) err := server.Shutdown(context.TODO())
if err != nil { if err != nil {
@@ -86,24 +83,22 @@ func StartAPI() *Server {
}() }()
go func() { go func() {
if config.Config.DisableTls { if s.config.DisableTls {
logger.Infof("Listening and serving HTTP on %s\n", server.Addr) logger.Infof("Listening and serving HTTP on %s\n", server.Addr)
err := server.ListenAndServe() err := server.ListenAndServe()
if err != nil { if err != nil {
logger.Error("Failed to start HTTP server:", err) logger.Error("Failed to start HTTP server:", err)
} }
return s.isActive = false
} } else if s.config.TLS_CertificatePath != "" && s.config.TLS_CertificateKey != "" {
if config.Config.TLS_CertificatePath != "" && config.Config.TLS_CertificateKey != "" {
logger.Infof("Listening and serving HTTPS on %s\n", server.Addr) logger.Infof("Listening and serving HTTPS on %s\n", server.Addr)
err := server.ListenAndServeTLS( err := server.ListenAndServeTLS(
config.Config.TLS_CertificatePath, s.config.TLS_CertificatePath,
config.Config.TLS_CertificateKey) s.config.TLS_CertificateKey)
if err != nil { if err != nil {
logger.Error("Failed to start HTTPS server:", err) logger.Error("Failed to start HTTPS server:", err)
} }
return s.isActive = false
} else { } else {
tlsConfig := tlsprovider.GetDefaultTlsConfig() tlsConfig := tlsprovider.GetDefaultTlsConfig()
server.TLSConfig = tlsConfig server.TLSConfig = tlsConfig
@@ -113,9 +108,7 @@ func StartAPI() *Server {
if err != nil { if err != nil {
logger.Error("Failed to start HTTPS server:", err) logger.Error("Failed to start HTTPS server:", err)
} }
return s.isActive = false
} }
}() }()
return &Server{StopServer: stopChan}
} }

View File

@@ -11,16 +11,15 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"
"github.com/pikami/cosmium/api/config" "github.com/pikami/cosmium/api/config"
"github.com/pikami/cosmium/internal/repositories"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func Test_Authentication(t *testing.T) { func Test_Authentication(t *testing.T) {
ts := runTestServer() ts := runTestServer()
defer ts.Close() defer ts.Server.Close()
t.Run("Should get 200 when correct account key is used", func(t *testing.T) { t.Run("Should get 200 when correct account key is used", func(t *testing.T) {
repositories.DeleteDatabase(testDatabaseName) ts.Repository.DeleteDatabase(testDatabaseName)
client, err := azcosmos.NewClientFromConnectionString( client, err := azcosmos.NewClientFromConnectionString(
fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, config.DefaultAccountKey), fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, config.DefaultAccountKey),
&azcosmos.ClientOptions{}, &azcosmos.ClientOptions{},
@@ -35,26 +34,8 @@ func Test_Authentication(t *testing.T) {
assert.Equal(t, createResponse.DatabaseProperties.ID, testDatabaseName) assert.Equal(t, createResponse.DatabaseProperties.ID, testDatabaseName)
}) })
t.Run("Should get 200 when wrong account key is used, but authentication is dissabled", func(t *testing.T) {
config.Config.DisableAuth = true
repositories.DeleteDatabase(testDatabaseName)
client, err := azcosmos.NewClientFromConnectionString(
fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, "AAAA"),
&azcosmos.ClientOptions{},
)
assert.Nil(t, err)
createResponse, err := client.CreateDatabase(
context.TODO(),
azcosmos.DatabaseProperties{ID: testDatabaseName},
&azcosmos.CreateDatabaseOptions{})
assert.Nil(t, err)
assert.Equal(t, createResponse.DatabaseProperties.ID, testDatabaseName)
config.Config.DisableAuth = false
})
t.Run("Should get 401 when wrong account key is used", func(t *testing.T) { t.Run("Should get 401 when wrong account key is used", func(t *testing.T) {
repositories.DeleteDatabase(testDatabaseName) ts.Repository.DeleteDatabase(testDatabaseName)
client, err := azcosmos.NewClientFromConnectionString( client, err := azcosmos.NewClientFromConnectionString(
fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, "AAAA"), fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, "AAAA"),
&azcosmos.ClientOptions{}, &azcosmos.ClientOptions{},
@@ -85,3 +66,29 @@ func Test_Authentication(t *testing.T) {
assert.Contains(t, string(responseBody), "BACKEND_ENDPOINT") assert.Contains(t, string(responseBody), "BACKEND_ENDPOINT")
}) })
} }
func Test_Authentication_Disabled(t *testing.T) {
ts := runTestServerCustomConfig(config.ServerConfig{
AccountKey: config.DefaultAccountKey,
ExplorerPath: "/tmp/nothing",
ExplorerBaseUrlLocation: config.ExplorerBaseUrlLocation,
DisableAuth: true,
})
defer ts.Server.Close()
t.Run("Should get 200 when wrong account key is used, but authentication is dissabled", func(t *testing.T) {
ts.Repository.DeleteDatabase(testDatabaseName)
client, err := azcosmos.NewClientFromConnectionString(
fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, "AAAA"),
&azcosmos.ClientOptions{},
)
assert.Nil(t, err)
createResponse, err := client.CreateDatabase(
context.TODO(),
azcosmos.DatabaseProperties{ID: testDatabaseName},
&azcosmos.CreateDatabaseOptions{})
assert.Nil(t, err)
assert.Equal(t, createResponse.DatabaseProperties.ID, testDatabaseName)
})
}

View File

@@ -10,22 +10,21 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"
"github.com/pikami/cosmium/api/config" "github.com/pikami/cosmium/api/config"
"github.com/pikami/cosmium/internal/repositories"
repositorymodels "github.com/pikami/cosmium/internal/repository_models" repositorymodels "github.com/pikami/cosmium/internal/repository_models"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func Test_Collections(t *testing.T) { func Test_Collections(t *testing.T) {
ts := runTestServer() ts := runTestServer()
defer ts.Close() defer ts.Server.Close()
client, err := azcosmos.NewClientFromConnectionString( client, err := azcosmos.NewClientFromConnectionString(
fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, config.Config.AccountKey), fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, config.DefaultAccountKey),
&azcosmos.ClientOptions{}, &azcosmos.ClientOptions{},
) )
assert.Nil(t, err) assert.Nil(t, err)
repositories.CreateDatabase(repositorymodels.Database{ID: testDatabaseName}) ts.Repository.CreateDatabase(repositorymodels.Database{ID: testDatabaseName})
databaseClient, err := client.NewDatabase(testDatabaseName) databaseClient, err := client.NewDatabase(testDatabaseName)
assert.Nil(t, err) assert.Nil(t, err)
@@ -40,7 +39,7 @@ func Test_Collections(t *testing.T) {
}) })
t.Run("Should return conflict when collection exists", func(t *testing.T) { t.Run("Should return conflict when collection exists", func(t *testing.T) {
repositories.CreateCollection(testDatabaseName, repositorymodels.Collection{ ts.Repository.CreateCollection(testDatabaseName, repositorymodels.Collection{
ID: testCollectionName, ID: testCollectionName,
}) })
@@ -60,7 +59,7 @@ func Test_Collections(t *testing.T) {
t.Run("Collection Read", func(t *testing.T) { t.Run("Collection Read", func(t *testing.T) {
t.Run("Should read collection", func(t *testing.T) { t.Run("Should read collection", func(t *testing.T) {
repositories.CreateCollection(testDatabaseName, repositorymodels.Collection{ ts.Repository.CreateCollection(testDatabaseName, repositorymodels.Collection{
ID: testCollectionName, ID: testCollectionName,
}) })
@@ -74,7 +73,7 @@ func Test_Collections(t *testing.T) {
}) })
t.Run("Should return not found when collection does not exist", func(t *testing.T) { t.Run("Should return not found when collection does not exist", func(t *testing.T) {
repositories.DeleteCollection(testDatabaseName, testCollectionName) ts.Repository.DeleteCollection(testDatabaseName, testCollectionName)
collectionResponse, err := databaseClient.NewContainer(testCollectionName) collectionResponse, err := databaseClient.NewContainer(testCollectionName)
assert.Nil(t, err) assert.Nil(t, err)
@@ -93,7 +92,7 @@ func Test_Collections(t *testing.T) {
t.Run("Collection Delete", func(t *testing.T) { t.Run("Collection Delete", func(t *testing.T) {
t.Run("Should delete collection", func(t *testing.T) { t.Run("Should delete collection", func(t *testing.T) {
repositories.CreateCollection(testDatabaseName, repositorymodels.Collection{ ts.Repository.CreateCollection(testDatabaseName, repositorymodels.Collection{
ID: testCollectionName, ID: testCollectionName,
}) })
@@ -106,7 +105,7 @@ func Test_Collections(t *testing.T) {
}) })
t.Run("Should return not found when collection does not exist", func(t *testing.T) { t.Run("Should return not found when collection does not exist", func(t *testing.T) {
repositories.DeleteCollection(testDatabaseName, testCollectionName) ts.Repository.DeleteCollection(testDatabaseName, testCollectionName)
collectionResponse, err := databaseClient.NewContainer(testCollectionName) collectionResponse, err := databaseClient.NewContainer(testCollectionName)
assert.Nil(t, err) assert.Nil(t, err)

View File

@@ -5,14 +5,37 @@ import (
"github.com/pikami/cosmium/api" "github.com/pikami/cosmium/api"
"github.com/pikami/cosmium/api/config" "github.com/pikami/cosmium/api/config"
"github.com/pikami/cosmium/internal/repositories"
) )
func runTestServer() *httptest.Server { type TestServer struct {
config.Config.AccountKey = config.DefaultAccountKey Server *httptest.Server
config.Config.ExplorerPath = "/tmp/nothing" Repository *repositories.DataRepository
config.Config.ExplorerBaseUrlLocation = config.ExplorerBaseUrlLocation URL string
}
return httptest.NewServer(api.CreateRouter()) func runTestServerCustomConfig(config config.ServerConfig) *TestServer {
repository := repositories.NewDataRepository(repositories.RepositoryOptions{})
api := api.NewApiServer(repository, config)
server := httptest.NewServer(api.GetRouter())
return &TestServer{
Server: server,
Repository: repository,
URL: server.URL,
}
}
func runTestServer() *TestServer {
config := config.ServerConfig{
AccountKey: config.DefaultAccountKey,
ExplorerPath: "/tmp/nothing",
ExplorerBaseUrlLocation: config.ExplorerBaseUrlLocation,
}
return runTestServerCustomConfig(config)
} }
const ( const (

View File

@@ -10,24 +10,23 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"
"github.com/pikami/cosmium/api/config" "github.com/pikami/cosmium/api/config"
"github.com/pikami/cosmium/internal/repositories"
repositorymodels "github.com/pikami/cosmium/internal/repository_models" repositorymodels "github.com/pikami/cosmium/internal/repository_models"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func Test_Databases(t *testing.T) { func Test_Databases(t *testing.T) {
ts := runTestServer() ts := runTestServer()
defer ts.Close() defer ts.Server.Close()
client, err := azcosmos.NewClientFromConnectionString( client, err := azcosmos.NewClientFromConnectionString(
fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, config.Config.AccountKey), fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, config.DefaultAccountKey),
&azcosmos.ClientOptions{}, &azcosmos.ClientOptions{},
) )
assert.Nil(t, err) assert.Nil(t, err)
t.Run("Database Create", func(t *testing.T) { t.Run("Database Create", func(t *testing.T) {
t.Run("Should create database", func(t *testing.T) { t.Run("Should create database", func(t *testing.T) {
repositories.DeleteDatabase(testDatabaseName) ts.Repository.DeleteDatabase(testDatabaseName)
createResponse, err := client.CreateDatabase(context.TODO(), azcosmos.DatabaseProperties{ createResponse, err := client.CreateDatabase(context.TODO(), azcosmos.DatabaseProperties{
ID: testDatabaseName, ID: testDatabaseName,
@@ -38,7 +37,7 @@ func Test_Databases(t *testing.T) {
}) })
t.Run("Should return conflict when database exists", func(t *testing.T) { t.Run("Should return conflict when database exists", func(t *testing.T) {
repositories.CreateDatabase(repositorymodels.Database{ ts.Repository.CreateDatabase(repositorymodels.Database{
ID: testDatabaseName, ID: testDatabaseName,
}) })
@@ -58,7 +57,7 @@ func Test_Databases(t *testing.T) {
t.Run("Database Read", func(t *testing.T) { t.Run("Database Read", func(t *testing.T) {
t.Run("Should read database", func(t *testing.T) { t.Run("Should read database", func(t *testing.T) {
repositories.CreateDatabase(repositorymodels.Database{ ts.Repository.CreateDatabase(repositorymodels.Database{
ID: testDatabaseName, ID: testDatabaseName,
}) })
@@ -72,7 +71,7 @@ func Test_Databases(t *testing.T) {
}) })
t.Run("Should return not found when database does not exist", func(t *testing.T) { t.Run("Should return not found when database does not exist", func(t *testing.T) {
repositories.DeleteDatabase(testDatabaseName) ts.Repository.DeleteDatabase(testDatabaseName)
databaseResponse, err := client.NewDatabase(testDatabaseName) databaseResponse, err := client.NewDatabase(testDatabaseName)
assert.Nil(t, err) assert.Nil(t, err)
@@ -91,7 +90,7 @@ func Test_Databases(t *testing.T) {
t.Run("Database Delete", func(t *testing.T) { t.Run("Database Delete", func(t *testing.T) {
t.Run("Should delete database", func(t *testing.T) { t.Run("Should delete database", func(t *testing.T) {
repositories.CreateDatabase(repositorymodels.Database{ ts.Repository.CreateDatabase(repositorymodels.Database{
ID: testDatabaseName, ID: testDatabaseName,
}) })
@@ -104,7 +103,7 @@ func Test_Databases(t *testing.T) {
}) })
t.Run("Should return not found when database does not exist", func(t *testing.T) { t.Run("Should return not found when database does not exist", func(t *testing.T) {
repositories.DeleteDatabase(testDatabaseName) ts.Repository.DeleteDatabase(testDatabaseName)
databaseResponse, err := client.NewDatabase(testDatabaseName) databaseResponse, err := client.NewDatabase(testDatabaseName)
assert.Nil(t, err) assert.Nil(t, err)

View File

@@ -6,7 +6,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest"
"reflect" "reflect"
"sync" "sync"
"testing" "testing"
@@ -15,7 +14,6 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"
"github.com/pikami/cosmium/api/config" "github.com/pikami/cosmium/api/config"
"github.com/pikami/cosmium/internal/repositories"
repositorymodels "github.com/pikami/cosmium/internal/repository_models" repositorymodels "github.com/pikami/cosmium/internal/repository_models"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@@ -55,9 +53,11 @@ func testCosmosQuery(t *testing.T,
} }
} }
func documents_InitializeDb(t *testing.T) (*httptest.Server, *azcosmos.ContainerClient) { func documents_InitializeDb(t *testing.T) (*TestServer, *azcosmos.ContainerClient) {
repositories.CreateDatabase(repositorymodels.Database{ID: testDatabaseName}) ts := runTestServer()
repositories.CreateCollection(testDatabaseName, repositorymodels.Collection{
ts.Repository.CreateDatabase(repositorymodels.Database{ID: testDatabaseName})
ts.Repository.CreateCollection(testDatabaseName, repositorymodels.Collection{
ID: testCollectionName, ID: testCollectionName,
PartitionKey: struct { PartitionKey: struct {
Paths []string "json:\"paths\"" Paths []string "json:\"paths\""
@@ -67,13 +67,11 @@ func documents_InitializeDb(t *testing.T) (*httptest.Server, *azcosmos.Container
Paths: []string{"/pk"}, Paths: []string{"/pk"},
}, },
}) })
repositories.CreateDocument(testDatabaseName, testCollectionName, map[string]interface{}{"id": "12345", "pk": "123", "isCool": false, "arr": []int{1, 2, 3}}) ts.Repository.CreateDocument(testDatabaseName, testCollectionName, map[string]interface{}{"id": "12345", "pk": "123", "isCool": false, "arr": []int{1, 2, 3}})
repositories.CreateDocument(testDatabaseName, testCollectionName, map[string]interface{}{"id": "67890", "pk": "456", "isCool": true, "arr": []int{6, 7, 8}}) ts.Repository.CreateDocument(testDatabaseName, testCollectionName, map[string]interface{}{"id": "67890", "pk": "456", "isCool": true, "arr": []int{6, 7, 8}})
ts := runTestServer()
client, err := azcosmos.NewClientFromConnectionString( client, err := azcosmos.NewClientFromConnectionString(
fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, config.Config.AccountKey), fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, config.DefaultAccountKey),
&azcosmos.ClientOptions{}, &azcosmos.ClientOptions{},
) )
assert.Nil(t, err) assert.Nil(t, err)
@@ -86,7 +84,7 @@ func documents_InitializeDb(t *testing.T) (*httptest.Server, *azcosmos.Container
func Test_Documents(t *testing.T) { func Test_Documents(t *testing.T) {
ts, collectionClient := documents_InitializeDb(t) ts, collectionClient := documents_InitializeDb(t)
defer ts.Close() defer ts.Server.Close()
t.Run("Should query document", func(t *testing.T) { t.Run("Should query document", func(t *testing.T) {
testCosmosQuery(t, collectionClient, testCosmosQuery(t, collectionClient,
@@ -149,6 +147,21 @@ func Test_Documents(t *testing.T) {
) )
}) })
t.Run("Should query document with query parameters as accessor", func(t *testing.T) {
testCosmosQuery(t, collectionClient,
`select c.id
FROM c
WHERE c[@param]="67890"
ORDER BY c.id`,
[]azcosmos.QueryParameter{
{Name: "@param", Value: "id"},
},
[]interface{}{
map[string]interface{}{"id": "67890"},
},
)
})
t.Run("Should query array accessor", func(t *testing.T) { t.Run("Should query array accessor", func(t *testing.T) {
testCosmosQuery(t, collectionClient, testCosmosQuery(t, collectionClient,
`SELECT c.id, `SELECT c.id,
@@ -218,15 +231,18 @@ func Test_Documents(t *testing.T) {
func Test_Documents_Patch(t *testing.T) { func Test_Documents_Patch(t *testing.T) {
ts, collectionClient := documents_InitializeDb(t) ts, collectionClient := documents_InitializeDb(t)
defer ts.Close() defer ts.Server.Close()
t.Run("Should PATCH document", func(t *testing.T) { t.Run("Should PATCH document", func(t *testing.T) {
context := context.TODO() context := context.TODO()
expectedData := map[string]interface{}{"id": "67890", "pk": "456", "newField": "newValue"} expectedData := map[string]interface{}{"id": "67890", "pk": "666", "newField": "newValue", "incr": 15., "setted": "isSet"}
patch := azcosmos.PatchOperations{} patch := azcosmos.PatchOperations{}
patch.AppendAdd("/newField", "newValue") patch.AppendAdd("/newField", "newValue")
patch.AppendIncrement("/incr", 15)
patch.AppendRemove("/isCool") patch.AppendRemove("/isCool")
patch.AppendReplace("/pk", "666")
patch.AppendSet("/setted", "isSet")
itemResponse, err := collectionClient.PatchItem( itemResponse, err := collectionClient.PatchItem(
context, context,
@@ -239,13 +255,15 @@ func Test_Documents_Patch(t *testing.T) {
) )
assert.Nil(t, err) assert.Nil(t, err)
var itemResponseBody map[string]string var itemResponseBody map[string]interface{}
json.Unmarshal(itemResponse.Value, &itemResponseBody) json.Unmarshal(itemResponse.Value, &itemResponseBody)
assert.Equal(t, expectedData["id"], itemResponseBody["id"]) assert.Equal(t, expectedData["id"], itemResponseBody["id"])
assert.Equal(t, expectedData["pk"], itemResponseBody["pk"]) assert.Equal(t, expectedData["pk"], itemResponseBody["pk"])
assert.Empty(t, itemResponseBody["isCool"]) assert.Empty(t, itemResponseBody["isCool"])
assert.Equal(t, expectedData["newField"], itemResponseBody["newField"]) assert.Equal(t, expectedData["newField"], itemResponseBody["newField"])
assert.Equal(t, expectedData["incr"], itemResponseBody["incr"])
assert.Equal(t, expectedData["setted"], itemResponseBody["setted"])
}) })
t.Run("Should not allow to PATCH document ID", func(t *testing.T) { t.Run("Should not allow to PATCH document ID", func(t *testing.T) {

View File

@@ -15,14 +15,14 @@ import (
// Request document with trailing slash like python cosmosdb client does. // Request document with trailing slash like python cosmosdb client does.
func Test_Documents_Read_Trailing_Slash(t *testing.T) { func Test_Documents_Read_Trailing_Slash(t *testing.T) {
ts, _ := documents_InitializeDb(t) ts, _ := documents_InitializeDb(t)
defer ts.Close() defer ts.Server.Close()
t.Run("Read doc with client that appends slash to path", func(t *testing.T) { t.Run("Read doc with client that appends slash to path", func(t *testing.T) {
resourceIdTemplate := "dbs/%s/colls/%s/docs/%s" resourceIdTemplate := "dbs/%s/colls/%s/docs/%s"
path := fmt.Sprintf(resourceIdTemplate, testDatabaseName, testCollectionName, "12345") path := fmt.Sprintf(resourceIdTemplate, testDatabaseName, testCollectionName, "12345")
testUrl := ts.URL + "/" + path + "/" testUrl := ts.URL + "/" + path + "/"
date := time.Now().Format(time.RFC1123) date := time.Now().Format(time.RFC1123)
signature := authentication.GenerateSignature("GET", "docs", path, date, config.Config.AccountKey) signature := authentication.GenerateSignature("GET", "docs", path, date, config.DefaultAccountKey)
httpClient := &http.Client{} httpClient := &http.Client{}
req, _ := http.NewRequest("GET", testUrl, nil) req, _ := http.NewRequest("GET", testUrl, nil)
req.Header.Add("x-ms-date", date) req.Header.Add("x-ms-date", date)

View File

@@ -11,16 +11,20 @@ import (
) )
func main() { func main() {
config.ParseFlags() configuration := config.ParseFlags()
repositories.InitializeRepository() repository := repositories.NewDataRepository(repositories.RepositoryOptions{
InitialDataFilePath: configuration.InitialDataFilePath,
PersistDataFilePath: configuration.PersistDataFilePath,
})
server := api.StartAPI() server := api.NewApiServer(repository, configuration)
server.Start()
waitForExit(server) waitForExit(server, repository, configuration)
} }
func waitForExit(server *api.Server) { func waitForExit(server *api.ApiServer, repository *repositories.DataRepository, config config.ServerConfig) {
sigs := make(chan os.Signal, 1) sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
@@ -28,9 +32,9 @@ func waitForExit(server *api.Server) {
<-sigs <-sigs
// Stop the server // Stop the server
server.StopServer <- true server.Stop()
if config.Config.PersistDataFilePath != "" { if config.PersistDataFilePath != "" {
repositories.SaveStateFS(config.Config.PersistDataFilePath) repository.SaveStateFS(config.PersistDataFilePath)
} }
} }

4
go.mod
View File

@@ -5,9 +5,9 @@ go 1.22.0
require ( require (
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.2 github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.2
github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos v0.3.6 github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos v0.3.6
github.com/evanphx/json-patch/v5 v5.9.0
github.com/gin-gonic/gin v1.10.0 github.com/gin-gonic/gin v1.10.0
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/pikami/json-patch/v5 v5.9.2
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67 golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67
) )
@@ -39,7 +39,7 @@ require (
github.com/ugorji/go/codec v1.2.12 // indirect github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.12.0 // indirect golang.org/x/arch v0.12.0 // indirect
golang.org/x/crypto v0.31.0 // indirect golang.org/x/crypto v0.31.0 // indirect
golang.org/x/net v0.32.0 // indirect golang.org/x/net v0.33.0 // indirect
golang.org/x/sys v0.28.0 // indirect golang.org/x/sys v0.28.0 // indirect
golang.org/x/text v0.21.0 // indirect golang.org/x/text v0.21.0 // indirect
google.golang.org/protobuf v1.36.0 // indirect google.golang.org/protobuf v1.36.0 // indirect

8
go.sum
View File

@@ -22,8 +22,6 @@ github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQ
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/evanphx/json-patch/v5 v5.9.0 h1:kcBlZQbplgElYIlo/n1hJbls2z/1awpXxpRi0/FOJfg=
github.com/evanphx/json-patch/v5 v5.9.0/go.mod h1:VNkHZ/282BpEyt/tObQO8s5CMPmYYq14uClGH4abBuQ=
github.com/gabriel-vasile/mimetype v1.4.7 h1:SKFKl7kD0RiPdbht0s7hFtjl489WcQ1VyPW8ZzUMYCA= github.com/gabriel-vasile/mimetype v1.4.7 h1:SKFKl7kD0RiPdbht0s7hFtjl489WcQ1VyPW8ZzUMYCA=
github.com/gabriel-vasile/mimetype v1.4.7/go.mod h1:GDlAgAyIRT27BhFl53XNAFtfjzOkLaF35JdEG0P7LtU= github.com/gabriel-vasile/mimetype v1.4.7/go.mod h1:GDlAgAyIRT27BhFl53XNAFtfjzOkLaF35JdEG0P7LtU=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
@@ -66,6 +64,8 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M= github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M=
github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc= github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc=
github.com/pikami/json-patch/v5 v5.9.2 h1:ciTlocWccYVE3DEa45dsMm02c/tOvcaBY7PpEUNZhrU=
github.com/pikami/json-patch/v5 v5.9.2/go.mod h1:eJIScZ4xgf2aBHLi2UMzYtjlWESUBDOBf7EAx3JW0nI=
github.com/pkg/browser v0.0.0-20210115035449-ce105d075bb4 h1:Qj1ukM4GlMWXNdMBuXcXfz/Kw9s1qm0CLY32QxuSImI= github.com/pkg/browser v0.0.0-20210115035449-ce105d075bb4 h1:Qj1ukM4GlMWXNdMBuXcXfz/Kw9s1qm0CLY32QxuSImI=
github.com/pkg/browser v0.0.0-20210115035449-ce105d075bb4/go.mod h1:N6UoU20jOqggOuDwUaBQpluzLNDqif3kq9z2wpdYEfQ= github.com/pkg/browser v0.0.0-20210115035449-ce105d075bb4/go.mod h1:N6UoU20jOqggOuDwUaBQpluzLNDqif3kq9z2wpdYEfQ=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
@@ -92,8 +92,8 @@ golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67 h1:1UoZQm6f0P/ZO0w1Ri+f+ifG/gXhegadRdwBIXEFWDo= golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67 h1:1UoZQm6f0P/ZO0w1Ri+f+ifG/gXhegadRdwBIXEFWDo=
golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67/go.mod h1:qj5a5QZpwLU2NLQudwIN5koi3beDhSAlJwa67PuM98c= golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67/go.mod h1:qj5a5QZpwLU2NLQudwIN5koi3beDhSAlJwa67PuM98c=
golang.org/x/net v0.32.0 h1:ZqPmj8Kzc+Y6e0+skZsuACbx+wzMgo5MQsJh9Qd6aYI= golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.32.0/go.mod h1:CwU0IoeOlnQQWJ6ioyFrfRuomB8GKF6KbYXZVyeXNfs= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=

View File

@@ -3,22 +3,22 @@ package logger
import ( import (
"log" "log"
"os" "os"
"github.com/pikami/cosmium/api/config"
) )
var EnableDebugOutput = false
var DebugLogger = log.New(os.Stdout, "", log.Ldate|log.Ltime|log.Lshortfile) var DebugLogger = log.New(os.Stdout, "", log.Ldate|log.Ltime|log.Lshortfile)
var InfoLogger = log.New(os.Stdout, "", log.Ldate|log.Ltime) var InfoLogger = log.New(os.Stdout, "", log.Ldate|log.Ltime)
var ErrorLogger = log.New(os.Stderr, "", log.Ldate|log.Ltime|log.Lshortfile) var ErrorLogger = log.New(os.Stderr, "", log.Ldate|log.Ltime|log.Lshortfile)
func Debug(v ...any) { func Debug(v ...any) {
if config.Config.Debug { if EnableDebugOutput {
DebugLogger.Println(v...) DebugLogger.Println(v...)
} }
} }
func Debugf(format string, v ...any) { func Debugf(format string, v ...any) {
if config.Config.Debug { if EnableDebugOutput {
DebugLogger.Printf(format, v...) DebugLogger.Printf(format, v...)
} }
} }

View File

@@ -11,60 +11,60 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
) )
func GetAllCollections(databaseId string) ([]repositorymodels.Collection, repositorymodels.RepositoryStatus) { func (r *DataRepository) GetAllCollections(databaseId string) ([]repositorymodels.Collection, repositorymodels.RepositoryStatus) {
storeState.RLock() r.storeState.RLock()
defer storeState.RUnlock() defer r.storeState.RUnlock()
if _, ok := storeState.Databases[databaseId]; !ok { if _, ok := r.storeState.Databases[databaseId]; !ok {
return make([]repositorymodels.Collection, 0), repositorymodels.StatusNotFound return make([]repositorymodels.Collection, 0), repositorymodels.StatusNotFound
} }
return maps.Values(storeState.Collections[databaseId]), repositorymodels.StatusOk return maps.Values(r.storeState.Collections[databaseId]), repositorymodels.StatusOk
} }
func GetCollection(databaseId string, collectionId string) (repositorymodels.Collection, repositorymodels.RepositoryStatus) { func (r *DataRepository) GetCollection(databaseId string, collectionId string) (repositorymodels.Collection, repositorymodels.RepositoryStatus) {
storeState.RLock() r.storeState.RLock()
defer storeState.RUnlock() defer r.storeState.RUnlock()
if _, ok := storeState.Databases[databaseId]; !ok { if _, ok := r.storeState.Databases[databaseId]; !ok {
return repositorymodels.Collection{}, repositorymodels.StatusNotFound return repositorymodels.Collection{}, repositorymodels.StatusNotFound
} }
if _, ok := storeState.Collections[databaseId][collectionId]; !ok { if _, ok := r.storeState.Collections[databaseId][collectionId]; !ok {
return repositorymodels.Collection{}, repositorymodels.StatusNotFound return repositorymodels.Collection{}, repositorymodels.StatusNotFound
} }
return storeState.Collections[databaseId][collectionId], repositorymodels.StatusOk return r.storeState.Collections[databaseId][collectionId], repositorymodels.StatusOk
} }
func DeleteCollection(databaseId string, collectionId string) repositorymodels.RepositoryStatus { func (r *DataRepository) DeleteCollection(databaseId string, collectionId string) repositorymodels.RepositoryStatus {
storeState.Lock() r.storeState.Lock()
defer storeState.Unlock() defer r.storeState.Unlock()
if _, ok := storeState.Databases[databaseId]; !ok { if _, ok := r.storeState.Databases[databaseId]; !ok {
return repositorymodels.StatusNotFound return repositorymodels.StatusNotFound
} }
if _, ok := storeState.Collections[databaseId][collectionId]; !ok { if _, ok := r.storeState.Collections[databaseId][collectionId]; !ok {
return repositorymodels.StatusNotFound return repositorymodels.StatusNotFound
} }
delete(storeState.Collections[databaseId], collectionId) delete(r.storeState.Collections[databaseId], collectionId)
return repositorymodels.StatusOk return repositorymodels.StatusOk
} }
func CreateCollection(databaseId string, newCollection repositorymodels.Collection) (repositorymodels.Collection, repositorymodels.RepositoryStatus) { func (r *DataRepository) CreateCollection(databaseId string, newCollection repositorymodels.Collection) (repositorymodels.Collection, repositorymodels.RepositoryStatus) {
storeState.Lock() r.storeState.Lock()
defer storeState.Unlock() defer r.storeState.Unlock()
var ok bool var ok bool
var database repositorymodels.Database var database repositorymodels.Database
if database, ok = storeState.Databases[databaseId]; !ok { if database, ok = r.storeState.Databases[databaseId]; !ok {
return repositorymodels.Collection{}, repositorymodels.StatusNotFound return repositorymodels.Collection{}, repositorymodels.StatusNotFound
} }
if _, ok = storeState.Collections[databaseId][newCollection.ID]; ok { if _, ok = r.storeState.Collections[databaseId][newCollection.ID]; ok {
return repositorymodels.Collection{}, repositorymodels.Conflict return repositorymodels.Collection{}, repositorymodels.Conflict
} }
@@ -75,8 +75,8 @@ func CreateCollection(databaseId string, newCollection repositorymodels.Collecti
newCollection.ETag = fmt.Sprintf("\"%s\"", uuid.New()) newCollection.ETag = fmt.Sprintf("\"%s\"", uuid.New())
newCollection.Self = fmt.Sprintf("dbs/%s/colls/%s/", database.ResourceID, newCollection.ResourceID) newCollection.Self = fmt.Sprintf("dbs/%s/colls/%s/", database.ResourceID, newCollection.ResourceID)
storeState.Collections[databaseId][newCollection.ID] = newCollection r.storeState.Collections[databaseId][newCollection.ID] = newCollection
storeState.Documents[databaseId][newCollection.ID] = make(map[string]repositorymodels.Document) r.storeState.Documents[databaseId][newCollection.ID] = make(map[string]repositorymodels.Document)
return newCollection, repositorymodels.StatusOk return newCollection, repositorymodels.StatusOk
} }

View File

@@ -10,42 +10,42 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
) )
func GetAllDatabases() ([]repositorymodels.Database, repositorymodels.RepositoryStatus) { func (r *DataRepository) GetAllDatabases() ([]repositorymodels.Database, repositorymodels.RepositoryStatus) {
storeState.RLock() r.storeState.RLock()
defer storeState.RUnlock() defer r.storeState.RUnlock()
return maps.Values(storeState.Databases), repositorymodels.StatusOk return maps.Values(r.storeState.Databases), repositorymodels.StatusOk
} }
func GetDatabase(id string) (repositorymodels.Database, repositorymodels.RepositoryStatus) { func (r *DataRepository) GetDatabase(id string) (repositorymodels.Database, repositorymodels.RepositoryStatus) {
storeState.RLock() r.storeState.RLock()
defer storeState.RUnlock() defer r.storeState.RUnlock()
if database, ok := storeState.Databases[id]; ok { if database, ok := r.storeState.Databases[id]; ok {
return database, repositorymodels.StatusOk return database, repositorymodels.StatusOk
} }
return repositorymodels.Database{}, repositorymodels.StatusNotFound return repositorymodels.Database{}, repositorymodels.StatusNotFound
} }
func DeleteDatabase(id string) repositorymodels.RepositoryStatus { func (r *DataRepository) DeleteDatabase(id string) repositorymodels.RepositoryStatus {
storeState.Lock() r.storeState.Lock()
defer storeState.Unlock() defer r.storeState.Unlock()
if _, ok := storeState.Databases[id]; !ok { if _, ok := r.storeState.Databases[id]; !ok {
return repositorymodels.StatusNotFound return repositorymodels.StatusNotFound
} }
delete(storeState.Databases, id) delete(r.storeState.Databases, id)
return repositorymodels.StatusOk return repositorymodels.StatusOk
} }
func CreateDatabase(newDatabase repositorymodels.Database) (repositorymodels.Database, repositorymodels.RepositoryStatus) { func (r *DataRepository) CreateDatabase(newDatabase repositorymodels.Database) (repositorymodels.Database, repositorymodels.RepositoryStatus) {
storeState.Lock() r.storeState.Lock()
defer storeState.Unlock() defer r.storeState.Unlock()
if _, ok := storeState.Databases[newDatabase.ID]; ok { if _, ok := r.storeState.Databases[newDatabase.ID]; ok {
return repositorymodels.Database{}, repositorymodels.Conflict return repositorymodels.Database{}, repositorymodels.Conflict
} }
@@ -54,9 +54,9 @@ func CreateDatabase(newDatabase repositorymodels.Database) (repositorymodels.Dat
newDatabase.ETag = fmt.Sprintf("\"%s\"", uuid.New()) newDatabase.ETag = fmt.Sprintf("\"%s\"", uuid.New())
newDatabase.Self = fmt.Sprintf("dbs/%s/", newDatabase.ResourceID) newDatabase.Self = fmt.Sprintf("dbs/%s/", newDatabase.ResourceID)
storeState.Databases[newDatabase.ID] = newDatabase r.storeState.Databases[newDatabase.ID] = newDatabase
storeState.Collections[newDatabase.ID] = make(map[string]repositorymodels.Collection) r.storeState.Collections[newDatabase.ID] = make(map[string]repositorymodels.Collection)
storeState.Documents[newDatabase.ID] = make(map[string]map[string]repositorymodels.Document) r.storeState.Documents[newDatabase.ID] = make(map[string]map[string]repositorymodels.Document)
return newDatabase, repositorymodels.StatusOk return newDatabase, repositorymodels.StatusOk
} }

View File

@@ -14,64 +14,64 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
) )
func GetAllDocuments(databaseId string, collectionId string) ([]repositorymodels.Document, repositorymodels.RepositoryStatus) { func (r *DataRepository) GetAllDocuments(databaseId string, collectionId string) ([]repositorymodels.Document, repositorymodels.RepositoryStatus) {
storeState.RLock() r.storeState.RLock()
defer storeState.RUnlock() defer r.storeState.RUnlock()
if _, ok := storeState.Databases[databaseId]; !ok { if _, ok := r.storeState.Databases[databaseId]; !ok {
return make([]repositorymodels.Document, 0), repositorymodels.StatusNotFound return make([]repositorymodels.Document, 0), repositorymodels.StatusNotFound
} }
if _, ok := storeState.Collections[databaseId][collectionId]; !ok { if _, ok := r.storeState.Collections[databaseId][collectionId]; !ok {
return make([]repositorymodels.Document, 0), repositorymodels.StatusNotFound return make([]repositorymodels.Document, 0), repositorymodels.StatusNotFound
} }
return maps.Values(storeState.Documents[databaseId][collectionId]), repositorymodels.StatusOk return maps.Values(r.storeState.Documents[databaseId][collectionId]), repositorymodels.StatusOk
} }
func GetDocument(databaseId string, collectionId string, documentId string) (repositorymodels.Document, repositorymodels.RepositoryStatus) { func (r *DataRepository) GetDocument(databaseId string, collectionId string, documentId string) (repositorymodels.Document, repositorymodels.RepositoryStatus) {
storeState.RLock() r.storeState.RLock()
defer storeState.RUnlock() defer r.storeState.RUnlock()
if _, ok := storeState.Databases[databaseId]; !ok { if _, ok := r.storeState.Databases[databaseId]; !ok {
return repositorymodels.Document{}, repositorymodels.StatusNotFound return repositorymodels.Document{}, repositorymodels.StatusNotFound
} }
if _, ok := storeState.Collections[databaseId][collectionId]; !ok { if _, ok := r.storeState.Collections[databaseId][collectionId]; !ok {
return repositorymodels.Document{}, repositorymodels.StatusNotFound return repositorymodels.Document{}, repositorymodels.StatusNotFound
} }
if _, ok := storeState.Documents[databaseId][collectionId][documentId]; !ok { if _, ok := r.storeState.Documents[databaseId][collectionId][documentId]; !ok {
return repositorymodels.Document{}, repositorymodels.StatusNotFound return repositorymodels.Document{}, repositorymodels.StatusNotFound
} }
return storeState.Documents[databaseId][collectionId][documentId], repositorymodels.StatusOk return r.storeState.Documents[databaseId][collectionId][documentId], repositorymodels.StatusOk
} }
func DeleteDocument(databaseId string, collectionId string, documentId string) repositorymodels.RepositoryStatus { func (r *DataRepository) DeleteDocument(databaseId string, collectionId string, documentId string) repositorymodels.RepositoryStatus {
storeState.Lock() r.storeState.Lock()
defer storeState.Unlock() defer r.storeState.Unlock()
if _, ok := storeState.Databases[databaseId]; !ok { if _, ok := r.storeState.Databases[databaseId]; !ok {
return repositorymodels.StatusNotFound return repositorymodels.StatusNotFound
} }
if _, ok := storeState.Collections[databaseId][collectionId]; !ok { if _, ok := r.storeState.Collections[databaseId][collectionId]; !ok {
return repositorymodels.StatusNotFound return repositorymodels.StatusNotFound
} }
if _, ok := storeState.Documents[databaseId][collectionId][documentId]; !ok { if _, ok := r.storeState.Documents[databaseId][collectionId][documentId]; !ok {
return repositorymodels.StatusNotFound return repositorymodels.StatusNotFound
} }
delete(storeState.Documents[databaseId][collectionId], documentId) delete(r.storeState.Documents[databaseId][collectionId], documentId)
return repositorymodels.StatusOk return repositorymodels.StatusOk
} }
func CreateDocument(databaseId string, collectionId string, document map[string]interface{}) (repositorymodels.Document, repositorymodels.RepositoryStatus) { func (r *DataRepository) CreateDocument(databaseId string, collectionId string, document map[string]interface{}) (repositorymodels.Document, repositorymodels.RepositoryStatus) {
storeState.Lock() r.storeState.Lock()
defer storeState.Unlock() defer r.storeState.Unlock()
var ok bool var ok bool
var documentId string var documentId string
@@ -82,15 +82,15 @@ func CreateDocument(databaseId string, collectionId string, document map[string]
document["id"] = documentId document["id"] = documentId
} }
if database, ok = storeState.Databases[databaseId]; !ok { if database, ok = r.storeState.Databases[databaseId]; !ok {
return repositorymodels.Document{}, repositorymodels.StatusNotFound return repositorymodels.Document{}, repositorymodels.StatusNotFound
} }
if collection, ok = storeState.Collections[databaseId][collectionId]; !ok { if collection, ok = r.storeState.Collections[databaseId][collectionId]; !ok {
return repositorymodels.Document{}, repositorymodels.StatusNotFound return repositorymodels.Document{}, repositorymodels.StatusNotFound
} }
if _, ok := storeState.Documents[databaseId][collectionId][documentId]; ok { if _, ok := r.storeState.Documents[databaseId][collectionId][documentId]; ok {
return repositorymodels.Document{}, repositorymodels.Conflict return repositorymodels.Document{}, repositorymodels.Conflict
} }
@@ -99,19 +99,19 @@ func CreateDocument(databaseId string, collectionId string, document map[string]
document["_etag"] = fmt.Sprintf("\"%s\"", uuid.New()) document["_etag"] = fmt.Sprintf("\"%s\"", uuid.New())
document["_self"] = fmt.Sprintf("dbs/%s/colls/%s/docs/%s/", database.ResourceID, collection.ResourceID, document["_rid"]) document["_self"] = fmt.Sprintf("dbs/%s/colls/%s/docs/%s/", database.ResourceID, collection.ResourceID, document["_rid"])
storeState.Documents[databaseId][collectionId][documentId] = document r.storeState.Documents[databaseId][collectionId][documentId] = document
return document, repositorymodels.StatusOk return document, repositorymodels.StatusOk
} }
func ExecuteQueryDocuments(databaseId string, collectionId string, query string, queryParameters map[string]interface{}) ([]memoryexecutor.RowType, repositorymodels.RepositoryStatus) { func (r *DataRepository) ExecuteQueryDocuments(databaseId string, collectionId string, query string, queryParameters map[string]interface{}) ([]memoryexecutor.RowType, repositorymodels.RepositoryStatus) {
parsedQuery, err := nosql.Parse("", []byte(query)) parsedQuery, err := nosql.Parse("", []byte(query))
if err != nil { if err != nil {
log.Printf("Failed to parse query: %s\nerr: %v", query, err) log.Printf("Failed to parse query: %s\nerr: %v", query, err)
return nil, repositorymodels.BadRequest return nil, repositorymodels.BadRequest
} }
collectionDocuments, status := GetAllDocuments(databaseId, collectionId) collectionDocuments, status := r.GetAllDocuments(databaseId, collectionId)
if status != repositorymodels.StatusOk { if status != repositorymodels.StatusOk {
return nil, status return nil, status
} }

View File

@@ -9,19 +9,19 @@ import (
) )
// I have no idea what this is tbh // I have no idea what this is tbh
func GetPartitionKeyRanges(databaseId string, collectionId string) ([]repositorymodels.PartitionKeyRange, repositorymodels.RepositoryStatus) { func (r *DataRepository) GetPartitionKeyRanges(databaseId string, collectionId string) ([]repositorymodels.PartitionKeyRange, repositorymodels.RepositoryStatus) {
storeState.RLock() r.storeState.RLock()
defer storeState.RUnlock() defer r.storeState.RUnlock()
databaseRid := databaseId databaseRid := databaseId
collectionRid := collectionId collectionRid := collectionId
var timestamp int64 = 0 var timestamp int64 = 0
if database, ok := storeState.Databases[databaseId]; !ok { if database, ok := r.storeState.Databases[databaseId]; !ok {
databaseRid = database.ResourceID databaseRid = database.ResourceID
} }
if collection, ok := storeState.Collections[databaseId][collectionId]; !ok { if collection, ok := r.storeState.Collections[databaseId][collectionId]; !ok {
collectionRid = collection.ResourceID collectionRid = collection.ResourceID
timestamp = collection.TimeStamp timestamp = collection.TimeStamp
} }

View File

@@ -0,0 +1,37 @@
package repositories
import repositorymodels "github.com/pikami/cosmium/internal/repository_models"
type DataRepository struct {
storedProcedures []repositorymodels.StoredProcedure
triggers []repositorymodels.Trigger
userDefinedFunctions []repositorymodels.UserDefinedFunction
storeState repositorymodels.State
initialDataFilePath string
persistDataFilePath string
}
type RepositoryOptions struct {
InitialDataFilePath string
PersistDataFilePath string
}
func NewDataRepository(options RepositoryOptions) *DataRepository {
repository := &DataRepository{
storedProcedures: []repositorymodels.StoredProcedure{},
triggers: []repositorymodels.Trigger{},
userDefinedFunctions: []repositorymodels.UserDefinedFunction{},
storeState: repositorymodels.State{
Databases: make(map[string]repositorymodels.Database),
Collections: make(map[string]map[string]repositorymodels.Collection),
Documents: make(map[string]map[string]map[string]repositorymodels.Document),
},
initialDataFilePath: options.InitialDataFilePath,
persistDataFilePath: options.PersistDataFilePath,
}
repository.InitializeRepository()
return repository
}

View File

@@ -6,28 +6,18 @@ import (
"os" "os"
"reflect" "reflect"
"github.com/pikami/cosmium/api/config"
"github.com/pikami/cosmium/internal/logger" "github.com/pikami/cosmium/internal/logger"
repositorymodels "github.com/pikami/cosmium/internal/repository_models" repositorymodels "github.com/pikami/cosmium/internal/repository_models"
) )
var storedProcedures = []repositorymodels.StoredProcedure{} func (r *DataRepository) InitializeRepository() {
var triggers = []repositorymodels.Trigger{} if r.initialDataFilePath != "" {
var userDefinedFunctions = []repositorymodels.UserDefinedFunction{} r.LoadStateFS(r.initialDataFilePath)
var storeState = repositorymodels.State{
Databases: make(map[string]repositorymodels.Database),
Collections: make(map[string]map[string]repositorymodels.Collection),
Documents: make(map[string]map[string]map[string]repositorymodels.Document),
}
func InitializeRepository() {
if config.Config.InitialDataFilePath != "" {
LoadStateFS(config.Config.InitialDataFilePath)
return return
} }
if config.Config.PersistDataFilePath != "" { if r.persistDataFilePath != "" {
stat, err := os.Stat(config.Config.PersistDataFilePath) stat, err := os.Stat(r.persistDataFilePath)
if err != nil { if err != nil {
return return
} }
@@ -37,39 +27,52 @@ func InitializeRepository() {
os.Exit(1) os.Exit(1)
} }
LoadStateFS(config.Config.PersistDataFilePath) r.LoadStateFS(r.persistDataFilePath)
return return
} }
} }
func LoadStateFS(filePath string) { func (r *DataRepository) LoadStateFS(filePath string) {
data, err := os.ReadFile(filePath) data, err := os.ReadFile(filePath)
if err != nil { if err != nil {
log.Fatalf("Error reading state JSON file: %v", err) log.Fatalf("Error reading state JSON file: %v", err)
return return
} }
var state repositorymodels.State err = r.LoadStateJSON(string(data))
if err := json.Unmarshal(data, &state); err != nil { if err != nil {
log.Fatalf("Error unmarshalling state JSON: %v", err) log.Fatalf("Error unmarshalling state JSON: %v", err)
return
} }
logger.Info("Loaded state:")
logger.Infof("Databases: %d\n", getLength(state.Databases))
logger.Infof("Collections: %d\n", getLength(state.Collections))
logger.Infof("Documents: %d\n", getLength(state.Documents))
storeState = state
ensureStoreStateNoNullReferences()
} }
func SaveStateFS(filePath string) { func (r *DataRepository) LoadStateJSON(jsonData string) error {
storeState.RLock() r.storeState.RLock()
defer storeState.RUnlock() defer r.storeState.RUnlock()
data, err := json.MarshalIndent(storeState, "", "\t") var state repositorymodels.State
if err := json.Unmarshal([]byte(jsonData), &state); err != nil {
return err
}
r.storeState.Collections = state.Collections
r.storeState.Databases = state.Databases
r.storeState.Documents = state.Documents
r.ensureStoreStateNoNullReferences()
logger.Info("Loaded state:")
logger.Infof("Databases: %d\n", getLength(r.storeState.Databases))
logger.Infof("Collections: %d\n", getLength(r.storeState.Collections))
logger.Infof("Documents: %d\n", getLength(r.storeState.Documents))
return nil
}
func (r *DataRepository) SaveStateFS(filePath string) {
r.storeState.RLock()
defer r.storeState.RUnlock()
data, err := json.MarshalIndent(r.storeState, "", "\t")
if err != nil { if err != nil {
logger.Errorf("Failed to save state: %v\n", err) logger.Errorf("Failed to save state: %v\n", err)
return return
@@ -78,16 +81,16 @@ func SaveStateFS(filePath string) {
os.WriteFile(filePath, data, os.ModePerm) os.WriteFile(filePath, data, os.ModePerm)
logger.Info("Saved state:") logger.Info("Saved state:")
logger.Infof("Databases: %d\n", getLength(storeState.Databases)) logger.Infof("Databases: %d\n", getLength(r.storeState.Databases))
logger.Infof("Collections: %d\n", getLength(storeState.Collections)) logger.Infof("Collections: %d\n", getLength(r.storeState.Collections))
logger.Infof("Documents: %d\n", getLength(storeState.Documents)) logger.Infof("Documents: %d\n", getLength(r.storeState.Documents))
} }
func GetState() (string, error) { func (r *DataRepository) GetState() (string, error) {
storeState.RLock() r.storeState.RLock()
defer storeState.RUnlock() defer r.storeState.RUnlock()
data, err := json.MarshalIndent(storeState, "", "\t") data, err := json.MarshalIndent(r.storeState, "", "\t")
if err != nil { if err != nil {
logger.Errorf("Failed to serialize state: %v\n", err) logger.Errorf("Failed to serialize state: %v\n", err)
return "", err return "", err
@@ -121,36 +124,36 @@ func getLength(v interface{}) int {
return count return count
} }
func ensureStoreStateNoNullReferences() { func (r *DataRepository) ensureStoreStateNoNullReferences() {
if storeState.Databases == nil { if r.storeState.Databases == nil {
storeState.Databases = make(map[string]repositorymodels.Database) r.storeState.Databases = make(map[string]repositorymodels.Database)
} }
if storeState.Collections == nil { if r.storeState.Collections == nil {
storeState.Collections = make(map[string]map[string]repositorymodels.Collection) r.storeState.Collections = make(map[string]map[string]repositorymodels.Collection)
} }
if storeState.Documents == nil { if r.storeState.Documents == nil {
storeState.Documents = make(map[string]map[string]map[string]repositorymodels.Document) r.storeState.Documents = make(map[string]map[string]map[string]repositorymodels.Document)
} }
for database := range storeState.Databases { for database := range r.storeState.Databases {
if storeState.Collections[database] == nil { if r.storeState.Collections[database] == nil {
storeState.Collections[database] = make(map[string]repositorymodels.Collection) r.storeState.Collections[database] = make(map[string]repositorymodels.Collection)
} }
if storeState.Documents[database] == nil { if r.storeState.Documents[database] == nil {
storeState.Documents[database] = make(map[string]map[string]repositorymodels.Document) r.storeState.Documents[database] = make(map[string]map[string]repositorymodels.Document)
} }
for collection := range storeState.Collections[database] { for collection := range r.storeState.Collections[database] {
if storeState.Documents[database][collection] == nil { if r.storeState.Documents[database][collection] == nil {
storeState.Documents[database][collection] = make(map[string]repositorymodels.Document) r.storeState.Documents[database][collection] = make(map[string]repositorymodels.Document)
} }
for document := range storeState.Documents[database][collection] { for document := range r.storeState.Documents[database][collection] {
if storeState.Documents[database][collection][document] == nil { if r.storeState.Documents[database][collection][document] == nil {
delete(storeState.Documents[database][collection], document) delete(r.storeState.Documents[database][collection], document)
} }
} }
} }

View File

@@ -2,6 +2,6 @@ package repositories
import repositorymodels "github.com/pikami/cosmium/internal/repository_models" import repositorymodels "github.com/pikami/cosmium/internal/repository_models"
func GetAllStoredProcedures(databaseId string, collectionId string) ([]repositorymodels.StoredProcedure, repositorymodels.RepositoryStatus) { func (r *DataRepository) GetAllStoredProcedures(databaseId string, collectionId string) ([]repositorymodels.StoredProcedure, repositorymodels.RepositoryStatus) {
return storedProcedures, repositorymodels.StatusOk return r.storedProcedures, repositorymodels.StatusOk
} }

View File

@@ -2,6 +2,6 @@ package repositories
import repositorymodels "github.com/pikami/cosmium/internal/repository_models" import repositorymodels "github.com/pikami/cosmium/internal/repository_models"
func GetAllTriggers(databaseId string, collectionId string) ([]repositorymodels.Trigger, repositorymodels.RepositoryStatus) { func (r *DataRepository) GetAllTriggers(databaseId string, collectionId string) ([]repositorymodels.Trigger, repositorymodels.RepositoryStatus) {
return triggers, repositorymodels.StatusOk return r.triggers, repositorymodels.StatusOk
} }

View File

@@ -2,6 +2,6 @@ package repositories
import repositorymodels "github.com/pikami/cosmium/internal/repository_models" import repositorymodels "github.com/pikami/cosmium/internal/repository_models"
func GetAllUserDefinedFunctions(databaseId string, collectionId string) ([]repositorymodels.UserDefinedFunction, repositorymodels.RepositoryStatus) { func (r *DataRepository) GetAllUserDefinedFunctions(databaseId string, collectionId string) ([]repositorymodels.UserDefinedFunction, repositorymodels.RepositoryStatus) {
return userDefinedFunctions, repositorymodels.StatusOk return r.userDefinedFunctions, repositorymodels.StatusOk
} }

File diff suppressed because it is too large Load Diff

View File

@@ -325,6 +325,7 @@ DotFieldAccess <- "." id:Identifier {
ArrayFieldAccess <- "[\"" id:Identifier "\"]" { return id, nil } ArrayFieldAccess <- "[\"" id:Identifier "\"]" { return id, nil }
/ "[" id:Integer "]" { return strconv.Itoa(id.(int)), nil } / "[" id:Integer "]" { return strconv.Itoa(id.(int)), nil }
/ "[" id:ParameterConstant "]" { return id.(parsers.Constant).Value.(string), nil }
Identifier <- [a-zA-Z_][a-zA-Z0-9_]* { Identifier <- [a-zA-Z_][a-zA-Z0-9_]* {
return string(c.text), nil return string(c.text), nil

View File

@@ -22,6 +22,20 @@ func Test_Parse_Select(t *testing.T) {
) )
}) })
t.Run("Should parse SELECT with query parameters as accessor", func(t *testing.T) {
testQueryParse(
t,
`SELECT c.id, c[@param] FROM c`,
parsers.SelectStmt{
SelectItems: []parsers.SelectItem{
{Path: []string{"c", "id"}},
{Path: []string{"c", "@param"}},
},
Table: parsers.Table{Value: "c"},
},
)
})
t.Run("Should parse SELECT DISTINCT", func(t *testing.T) { t.Run("Should parse SELECT DISTINCT", func(t *testing.T) {
testQueryParse( testQueryParse(
t, t,

View File

@@ -317,6 +317,10 @@ func (r rowContext) applyProjection(selectItems []parsers.SelectItem) RowType {
} else { } else {
destinationName = fmt.Sprintf("$%d", index+1) destinationName = fmt.Sprintf("$%d", index+1)
} }
if destinationName[0] == '@' {
destinationName = r.parameters[destinationName].(string)
}
} }
row[destinationName] = r.resolveSelectItem(selectItem) row[destinationName] = r.resolveSelectItem(selectItem)
@@ -572,6 +576,9 @@ func (r rowContext) selectItem_SelectItemTypeField(selectItem parsers.SelectItem
if len(selectItem.Path) > 1 { if len(selectItem.Path) > 1 {
for _, pathSegment := range selectItem.Path[1:] { for _, pathSegment := range selectItem.Path[1:] {
if pathSegment[0] == '@' {
pathSegment = r.parameters[pathSegment].(string)
}
switch nestedValue := value.(type) { switch nestedValue := value.(type) {
case map[string]interface{}: case map[string]interface{}:

View File

@@ -35,6 +35,29 @@ func Test_Execute_Select(t *testing.T) {
) )
}) })
t.Run("Should execute SELECT with query parameters as accessor", func(t *testing.T) {
testQueryExecute(
t,
parsers.SelectStmt{
SelectItems: []parsers.SelectItem{
{Path: []string{"c", "id"}},
{Path: []string{"c", "@param"}},
},
Table: parsers.Table{Value: "c"},
Parameters: map[string]interface{}{
"@param": "pk",
},
},
mockData,
[]memoryexecutor.RowType{
map[string]interface{}{"id": "12345", "pk": 123},
map[string]interface{}{"id": "67890", "pk": 456},
map[string]interface{}{"id": "456", "pk": 456},
map[string]interface{}{"id": "123", "pk": 456},
},
)
})
t.Run("Should execute SELECT DISTINCT", func(t *testing.T) { t.Run("Should execute SELECT DISTINCT", func(t *testing.T) {
testQueryExecute( testQueryExecute(
t, t,

View File

@@ -0,0 +1,89 @@
package main
import "C"
import (
"encoding/json"
repositorymodels "github.com/pikami/cosmium/internal/repository_models"
)
//export CreateCollection
func CreateCollection(serverName *C.char, databaseId *C.char, collectionJson *C.char) int {
serverNameStr := C.GoString(serverName)
databaseIdStr := C.GoString(databaseId)
collectionStr := C.GoString(collectionJson)
var ok bool
var serverInstance *ServerInstance
if serverInstance, ok = getInstance(serverNameStr); !ok {
return ResponseServerInstanceNotFound
}
var collection repositorymodels.Collection
err := json.Unmarshal([]byte(collectionStr), &collection)
if err != nil {
return ResponseFailedToParseRequest
}
_, code := serverInstance.repository.CreateCollection(databaseIdStr, collection)
return repositoryStatusToResponseCode(code)
}
//export GetCollection
func GetCollection(serverName *C.char, databaseId *C.char, collectionId *C.char) *C.char {
serverNameStr := C.GoString(serverName)
databaseIdStr := C.GoString(databaseId)
collectionIdStr := C.GoString(collectionId)
var ok bool
var serverInstance *ServerInstance
if serverInstance, ok = getInstance(serverNameStr); !ok {
return C.CString("")
}
collection, code := serverInstance.repository.GetCollection(databaseIdStr, collectionIdStr)
if code != repositorymodels.StatusOk {
return C.CString("")
}
collectionJson, _ := json.Marshal(collection)
return C.CString(string(collectionJson))
}
//export GetAllCollections
func GetAllCollections(serverName *C.char, databaseId *C.char) *C.char {
serverNameStr := C.GoString(serverName)
databaseIdStr := C.GoString(databaseId)
var ok bool
var serverInstance *ServerInstance
if serverInstance, ok = getInstance(serverNameStr); !ok {
return C.CString("")
}
collections, code := serverInstance.repository.GetAllCollections(databaseIdStr)
if code != repositorymodels.StatusOk {
return C.CString("")
}
collectionsJson, _ := json.Marshal(collections)
return C.CString(string(collectionsJson))
}
//export DeleteCollection
func DeleteCollection(serverName *C.char, databaseId *C.char, collectionId *C.char) int {
serverNameStr := C.GoString(serverName)
databaseIdStr := C.GoString(databaseId)
collectionIdStr := C.GoString(collectionId)
var ok bool
var serverInstance *ServerInstance
if serverInstance, ok = getInstance(serverNameStr); !ok {
return ResponseServerInstanceNotFound
}
code := serverInstance.repository.DeleteCollection(databaseIdStr, collectionIdStr)
return repositoryStatusToResponseCode(code)
}

View File

@@ -0,0 +1,89 @@
package main
import "C"
import (
"encoding/json"
repositorymodels "github.com/pikami/cosmium/internal/repository_models"
)
//export CreateDatabase
func CreateDatabase(serverName *C.char, databaseJson *C.char) int {
serverNameStr := C.GoString(serverName)
databaseStr := C.GoString(databaseJson)
var ok bool
var serverInstance *ServerInstance
if serverInstance, ok = getInstance(serverNameStr); !ok {
return ResponseServerInstanceNotFound
}
var database repositorymodels.Database
err := json.Unmarshal([]byte(databaseStr), &database)
if err != nil {
return ResponseFailedToParseRequest
}
_, code := serverInstance.repository.CreateDatabase(database)
return repositoryStatusToResponseCode(code)
}
//export GetDatabase
func GetDatabase(serverName *C.char, databaseId *C.char) *C.char {
serverNameStr := C.GoString(serverName)
databaseIdStr := C.GoString(databaseId)
var ok bool
var serverInstance *ServerInstance
if serverInstance, ok = getInstance(serverNameStr); !ok {
return C.CString("")
}
database, code := serverInstance.repository.GetDatabase(databaseIdStr)
if code != repositorymodels.StatusOk {
return C.CString("")
}
databaseJson, _ := json.Marshal(database)
return C.CString(string(databaseJson))
}
//export GetAllDatabases
func GetAllDatabases(serverName *C.char) *C.char {
serverNameStr := C.GoString(serverName)
var ok bool
var serverInstance *ServerInstance
if serverInstance, ok = getInstance(serverNameStr); !ok {
return C.CString("")
}
databases, code := serverInstance.repository.GetAllDatabases()
if code != repositorymodels.StatusOk {
return C.CString("")
}
databasesJson, err := json.Marshal(databases)
if err != nil {
return C.CString("")
}
return C.CString(string(databasesJson))
}
//export DeleteDatabase
func DeleteDatabase(serverName *C.char, databaseId *C.char) int {
serverNameStr := C.GoString(serverName)
databaseIdStr := C.GoString(databaseId)
var ok bool
var serverInstance *ServerInstance
if serverInstance, ok = getInstance(serverNameStr); !ok {
return ResponseServerInstanceNotFound
}
code := serverInstance.repository.DeleteDatabase(databaseIdStr)
return repositoryStatusToResponseCode(code)
}

122
sharedlibrary/documents.go Normal file
View File

@@ -0,0 +1,122 @@
package main
import "C"
import (
"encoding/json"
repositorymodels "github.com/pikami/cosmium/internal/repository_models"
)
//export CreateDocument
func CreateDocument(serverName *C.char, databaseId *C.char, collectionId *C.char, documentJson *C.char) int {
serverNameStr := C.GoString(serverName)
databaseIdStr := C.GoString(databaseId)
collectionIdStr := C.GoString(collectionId)
documentStr := C.GoString(documentJson)
var ok bool
var serverInstance *ServerInstance
if serverInstance, ok = getInstance(serverNameStr); !ok {
return ResponseServerInstanceNotFound
}
var document repositorymodels.Document
err := json.Unmarshal([]byte(documentStr), &document)
if err != nil {
return ResponseFailedToParseRequest
}
_, code := serverInstance.repository.CreateDocument(databaseIdStr, collectionIdStr, document)
return repositoryStatusToResponseCode(code)
}
//export GetDocument
func GetDocument(serverName *C.char, databaseId *C.char, collectionId *C.char, documentId *C.char) *C.char {
serverNameStr := C.GoString(serverName)
databaseIdStr := C.GoString(databaseId)
collectionIdStr := C.GoString(collectionId)
documentIdStr := C.GoString(documentId)
var ok bool
var serverInstance *ServerInstance
if serverInstance, ok = getInstance(serverNameStr); !ok {
return C.CString("")
}
document, code := serverInstance.repository.GetDocument(databaseIdStr, collectionIdStr, documentIdStr)
if code != repositorymodels.StatusOk {
return C.CString("")
}
documentJson, _ := json.Marshal(document)
return C.CString(string(documentJson))
}
//export GetAllDocuments
func GetAllDocuments(serverName *C.char, databaseId *C.char, collectionId *C.char) *C.char {
serverNameStr := C.GoString(serverName)
databaseIdStr := C.GoString(databaseId)
collectionIdStr := C.GoString(collectionId)
var ok bool
var serverInstance *ServerInstance
if serverInstance, ok = getInstance(serverNameStr); !ok {
return C.CString("")
}
documents, code := serverInstance.repository.GetAllDocuments(databaseIdStr, collectionIdStr)
if code != repositorymodels.StatusOk {
return C.CString("")
}
documentsJson, _ := json.Marshal(documents)
return C.CString(string(documentsJson))
}
//export UpdateDocument
func UpdateDocument(serverName *C.char, databaseId *C.char, collectionId *C.char, documentId *C.char, documentJson *C.char) int {
serverNameStr := C.GoString(serverName)
databaseIdStr := C.GoString(databaseId)
collectionIdStr := C.GoString(collectionId)
documentIdStr := C.GoString(documentId)
documentStr := C.GoString(documentJson)
var ok bool
var serverInstance *ServerInstance
if serverInstance, ok = getInstance(serverNameStr); !ok {
return ResponseServerInstanceNotFound
}
var document repositorymodels.Document
err := json.Unmarshal([]byte(documentStr), &document)
if err != nil {
return ResponseFailedToParseRequest
}
code := serverInstance.repository.DeleteDocument(databaseIdStr, collectionIdStr, documentIdStr)
if code != repositorymodels.StatusOk {
return repositoryStatusToResponseCode(code)
}
_, code = serverInstance.repository.CreateDocument(databaseIdStr, collectionIdStr, document)
return repositoryStatusToResponseCode(code)
}
//export DeleteDocument
func DeleteDocument(serverName *C.char, databaseId *C.char, collectionId *C.char, documentId *C.char) int {
serverNameStr := C.GoString(serverName)
databaseIdStr := C.GoString(databaseId)
collectionIdStr := C.GoString(collectionId)
documentIdStr := C.GoString(documentId)
var ok bool
var serverInstance *ServerInstance
if serverInstance, ok = getInstance(serverNameStr); !ok {
return ResponseServerInstanceNotFound
}
code := serverInstance.repository.DeleteDocument(databaseIdStr, collectionIdStr, documentIdStr)
return repositoryStatusToResponseCode(code)
}

86
sharedlibrary/shared.go Normal file
View File

@@ -0,0 +1,86 @@
package main
import (
"sync"
"github.com/pikami/cosmium/api"
"github.com/pikami/cosmium/internal/repositories"
repositorymodels "github.com/pikami/cosmium/internal/repository_models"
)
type ServerInstance struct {
server *api.ApiServer
repository *repositories.DataRepository
}
var serverInstances map[string]*ServerInstance
var mutex sync.RWMutex
const (
ResponseSuccess = 0
ResponseUnknown = 100
ResponseFailedToParseConfiguration = 101
ResponseFailedToLoadState = 102
ResponseFailedToParseRequest = 103
ResponseServerInstanceAlreadyExists = 104
ResponseServerInstanceNotFound = 105
ResponseRepositoryNotFound = 200
ResponseRepositoryConflict = 201
ResponseRepositoryBadRequest = 202
)
func getInstance(serverName string) (*ServerInstance, bool) {
mutex.RLock()
defer mutex.RUnlock()
if serverInstances == nil {
serverInstances = make(map[string]*ServerInstance)
}
var ok bool
var serverInstance *ServerInstance
if serverInstance, ok = serverInstances[serverName]; !ok {
return nil, false
}
return serverInstance, true
}
func addInstance(serverName string, serverInstance *ServerInstance) {
mutex.Lock()
defer mutex.Unlock()
if serverInstances == nil {
serverInstances = make(map[string]*ServerInstance)
}
serverInstances[serverName] = serverInstance
}
func removeInstance(serverName string) {
mutex.Lock()
defer mutex.Unlock()
if serverInstances == nil {
return
}
delete(serverInstances, serverName)
}
func repositoryStatusToResponseCode(status repositorymodels.RepositoryStatus) int {
switch status {
case repositorymodels.StatusOk:
return ResponseSuccess
case repositorymodels.StatusNotFound:
return ResponseRepositoryNotFound
case repositorymodels.Conflict:
return ResponseRepositoryConflict
case repositorymodels.BadRequest:
return ResponseRepositoryBadRequest
default:
return ResponseUnknown
}
}

View File

@@ -9,44 +9,82 @@ import (
"github.com/pikami/cosmium/internal/repositories" "github.com/pikami/cosmium/internal/repositories"
) )
var currentServer *api.Server //export CreateServerInstance
func CreateServerInstance(serverName *C.char, configurationJSON *C.char) int {
configStr := C.GoString(configurationJSON)
serverNameStr := C.GoString(serverName)
if _, ok := getInstance(serverNameStr); ok {
return ResponseServerInstanceAlreadyExists
}
//export Configure
func Configure(configurationJSON *C.char) bool {
var configuration config.ServerConfig var configuration config.ServerConfig
err := json.Unmarshal([]byte(C.GoString(configurationJSON)), &configuration) err := json.Unmarshal([]byte(configStr), &configuration)
if err != nil { if err != nil {
return false return ResponseFailedToParseConfiguration
} }
config.Config = configuration
return true configuration.PopulateCalculatedFields()
configuration.ApplyDefaultsToEmptyFields()
repository := repositories.NewDataRepository(repositories.RepositoryOptions{
InitialDataFilePath: configuration.InitialDataFilePath,
PersistDataFilePath: configuration.PersistDataFilePath,
})
server := api.NewApiServer(repository, configuration)
server.Start()
addInstance(serverNameStr, &ServerInstance{
server: server,
repository: repository,
})
return ResponseSuccess
} }
//export InitializeRepository //export StopServerInstance
func InitializeRepository() { func StopServerInstance(serverName *C.char) int {
repositories.InitializeRepository() serverNameStr := C.GoString(serverName)
}
//export StartAPI if serverInstance, ok := getInstance(serverNameStr); ok {
func StartAPI() { serverInstance.server.Stop()
currentServer = api.StartAPI() removeInstance(serverNameStr)
} return ResponseSuccess
//export StopAPI
func StopAPI() {
if currentServer == nil {
currentServer.StopServer <- true
currentServer = nil
} }
return ResponseServerInstanceNotFound
} }
//export GetState //export GetServerInstanceState
func GetState() *C.char { func GetServerInstanceState(serverName *C.char) *C.char {
stateJSON, err := repositories.GetState() serverNameStr := C.GoString(serverName)
if err != nil {
return nil if serverInstance, ok := getInstance(serverNameStr); ok {
stateJSON, err := serverInstance.repository.GetState()
if err != nil {
return nil
}
return C.CString(stateJSON)
} }
return C.CString(stateJSON)
return nil
}
//export LoadServerInstanceState
func LoadServerInstanceState(serverName *C.char, stateJSON *C.char) int {
serverNameStr := C.GoString(serverName)
stateJSONStr := C.GoString(stateJSON)
if serverInstance, ok := getInstance(serverNameStr); ok {
err := serverInstance.repository.LoadStateJSON(stateJSONStr)
if err != nil {
return ResponseFailedToLoadState
}
return ResponseSuccess
}
return ResponseServerInstanceNotFound
} }
func main() {} func main() {}

View File

@@ -0,0 +1,46 @@
#include "shared.h"
int test_CreateServerInstance();
int test_StopServerInstance();
int test_ServerInstanceStateMethods();
int test_Databases();
int main(int argc, char *argv[])
{
if (argc < 2)
{
fprintf(stderr, "Usage: %s <path_to_shared_library>\n", argv[0]);
return EXIT_FAILURE;
}
const char *libPath = argv[1];
handle = dlopen(libPath, RTLD_LAZY);
if (!handle)
{
fprintf(stderr, "Failed to load shared library: %s\n", dlerror());
return EXIT_FAILURE;
}
printf("Running tests for library: %s\n", libPath);
int results[] = {
test_CreateServerInstance(),
test_Databases(),
test_ServerInstanceStateMethods(),
test_StopServerInstance(),
};
int numTests = sizeof(results) / sizeof(results[0]);
int numPassed = 0;
for (int i = 0; i < numTests; i++)
{
if (results[i])
{
numPassed++;
}
}
printf("Tests passed: %d/%d\n", numPassed, numTests);
dlclose(handle);
return EXIT_SUCCESS;
}

View File

@@ -0,0 +1,36 @@
#include "shared.h"
void *handle = NULL;
void *load_function(const char *func_name)
{
void *func = dlsym(handle, func_name);
if (!func)
{
fprintf(stderr, "Failed to load function %s: %s\n", func_name, dlerror());
}
return func;
}
char *compact_json(const char *json)
{
size_t len = strlen(json);
char *compact = (char *)malloc(len + 1);
if (!compact)
{
fprintf(stderr, "Failed to allocate memory for compacted JSON\n");
return NULL;
}
char *dest = compact;
for (const char *src = json; *src != '\0'; ++src)
{
if (!isspace((unsigned char)*src)) // Skip spaces, newlines, tabs, etc.
{
*dest++ = *src;
}
}
*dest = '\0'; // Null-terminate the string
return compact;
}

View File

@@ -0,0 +1,15 @@
#ifndef SHARED_H
#define SHARED_H
#include <stdio.h>
#include <stdlib.h>
#include <dlfcn.h>
#include <string.h>
#include <ctype.h>
extern void *handle;
void *load_function(const char *func_name);
char *compact_json(const char *json);
#endif

View File

@@ -0,0 +1,29 @@
#include "shared.h"
int test_CreateServerInstance()
{
typedef int (*CreateServerInstanceFn)(char *, char *);
CreateServerInstanceFn CreateServerInstance = (CreateServerInstanceFn)load_function("CreateServerInstance");
if (!CreateServerInstance)
{
fprintf(stderr, "Failed to find CreateServerInstance function\n");
return 0;
}
char *serverName = "TestServer";
char *configJSON = "{\"host\":\"localhost\",\"port\":8080}";
int result = CreateServerInstance(serverName, configJSON);
if (result == 0)
{
printf("CreateServerInstance: SUCCESS\n");
}
else
{
printf("CreateServerInstance: FAILED (result = %d)\n", result);
return 0;
}
return 1;
}

View File

@@ -0,0 +1,47 @@
#include "shared.h"
int test_Databases()
{
typedef int (*CreateDatabaseFn)(char *, char *);
CreateDatabaseFn CreateDatabase = (CreateDatabaseFn)load_function("CreateDatabase");
if (!CreateDatabase)
{
fprintf(stderr, "Failed to find CreateDatabase function\n");
return 0;
}
char *serverName = "TestServer";
char *configJSON = "{\"id\":\"test-db\"}";
int result = CreateDatabase(serverName, configJSON);
if (result == 0)
{
printf("CreateDatabase: SUCCESS\n");
}
else
{
printf("CreateDatabase: FAILED (result = %d)\n", result);
return 0;
}
typedef char *(*GetDatabaseFn)(char *, char *);
GetDatabaseFn GetDatabase = (GetDatabaseFn)load_function("GetDatabase");
if (!GetDatabase)
{
fprintf(stderr, "Failed to find GetDatabase function\n");
return 0;
}
char *database = GetDatabase(serverName, "test-db");
if (database)
{
printf("GetDatabase: SUCCESS (database = %s)\n", database);
}
else
{
printf("GetDatabase: FAILED\n");
return 0;
}
return 1;
}

View File

@@ -0,0 +1,68 @@
#include "shared.h"
int test_ServerInstanceStateMethods()
{
typedef int (*LoadServerInstanceStateFn)(char *, char *);
LoadServerInstanceStateFn LoadServerInstanceState = (LoadServerInstanceStateFn)load_function("LoadServerInstanceState");
if (!LoadServerInstanceState)
{
fprintf(stderr, "Failed to find LoadServerInstanceState function\n");
return 0;
}
char *serverName = "TestServer";
char *stateJSON = "{\"databases\":{\"test-db\":{\"id\":\"test-db\"}}}";
int result = LoadServerInstanceState(serverName, stateJSON);
if (result == 0)
{
printf("LoadServerInstanceState: SUCCESS\n");
}
else
{
printf("LoadServerInstanceState: FAILED (result = %d)\n", result);
return 0;
}
typedef char *(*GetServerInstanceStateFn)(char *);
GetServerInstanceStateFn GetServerInstanceState = (GetServerInstanceStateFn)load_function("GetServerInstanceState");
if (!GetServerInstanceState)
{
fprintf(stderr, "Failed to find GetServerInstanceState function\n");
return 0;
}
char *state = GetServerInstanceState(serverName);
if (state)
{
printf("GetServerInstanceState: SUCCESS (state = %s)\n", state);
}
else
{
printf("GetServerInstanceState: FAILED\n");
return 0;
}
const char *expected_state = "{\"databases\":{\"test-db\":{\"id\":\"test-db\",\"_ts\":0,\"_rid\":\"\",\"_etag\":\"\",\"_self\":\"\"}},\"collections\":{\"test-db\":{}},\"documents\":{\"test-db\":{}}}";
char *compact_state = compact_json(state);
if (!compact_state)
{
free(state);
return 0;
}
if (strcmp(compact_state, expected_state) == 0)
{
printf("GetServerInstanceState: State matches expected value.\n");
}
else
{
printf("GetServerInstanceState: State does not match expected value.\n");
printf("Expected: %s\n", expected_state);
printf("Actual: %s\n", compact_state);
return 0;
}
free(state);
free(compact_state);
return 1;
}

View File

@@ -0,0 +1,27 @@
#include "shared.h"
int test_StopServerInstance()
{
typedef int (*StopServerInstanceFn)(char *);
StopServerInstanceFn StopServerInstance = (StopServerInstanceFn)load_function("StopServerInstance");
if (!StopServerInstance)
{
fprintf(stderr, "Failed to find StopServerInstance function\n");
return 0;
}
char *serverName = "TestServer";
int result = StopServerInstance(serverName);
if (result == 0)
{
printf("StopServerInstance: SUCCESS\n");
}
else
{
printf("StopServerInstance: FAILED (result = %d)\n", result);
return 0;
}
return 1;
}