diff --git a/api/api_server.go b/api/api_server.go new file mode 100644 index 0000000..a22158b --- /dev/null +++ b/api/api_server.go @@ -0,0 +1,35 @@ +package api + +import ( + "github.com/gin-gonic/gin" + "github.com/pikami/cosmium/api/config" + "github.com/pikami/cosmium/internal/repositories" +) + +type ApiServer struct { + stopServer chan interface{} + isActive bool + router *gin.Engine + config config.ServerConfig +} + +func NewApiServer(dataRepository *repositories.DataRepository, config config.ServerConfig) *ApiServer { + stopChan := make(chan interface{}) + + apiServer := &ApiServer{ + stopServer: stopChan, + config: config, + } + + apiServer.CreateRouter(dataRepository) + + return apiServer +} + +func (s *ApiServer) GetRouter() *gin.Engine { + return s.router +} + +func (s *ApiServer) Stop() { + s.stopServer <- true +} diff --git a/api/config/config.go b/api/config/config.go index f74dc9b..3553361 100644 --- a/api/config/config.go +++ b/api/config/config.go @@ -5,6 +5,8 @@ import ( "fmt" "os" "strings" + + "github.com/pikami/cosmium/internal/logger" ) const ( @@ -13,9 +15,7 @@ const ( ExplorerBaseUrlLocation = "/_explorer" ) -var Config = ServerConfig{} - -func ParseFlags() { +func ParseFlags() ServerConfig { host := flag.String("Host", "localhost", "Hostname") port := flag.Int("Port", 8081, "Listen port") explorerPath := flag.String("ExplorerDir", "", "Path to cosmos-explorer files") @@ -31,22 +31,30 @@ func ParseFlags() { flag.Parse() setFlagsFromEnvironment() - Config.Host = *host - Config.Port = *port - Config.ExplorerPath = *explorerPath - Config.TLS_CertificatePath = *tlsCertificatePath - Config.TLS_CertificateKey = *tlsCertificateKey - Config.InitialDataFilePath = *initialDataPath - Config.PersistDataFilePath = *persistDataPath - Config.DisableAuth = *disableAuthentication - Config.DisableTls = *disableTls - Config.Debug = *debug + config := ServerConfig{} + config.Host = *host + config.Port = *port + config.ExplorerPath = *explorerPath + config.TLS_CertificatePath = *tlsCertificatePath + config.TLS_CertificateKey = *tlsCertificateKey + config.InitialDataFilePath = *initialDataPath + config.PersistDataFilePath = *persistDataPath + config.DisableAuth = *disableAuthentication + config.DisableTls = *disableTls + config.Debug = *debug + config.AccountKey = *accountKey - Config.DatabaseAccount = Config.Host - Config.DatabaseDomain = Config.Host - Config.DatabaseEndpoint = fmt.Sprintf("https://%s:%d/", Config.Host, Config.Port) - Config.AccountKey = *accountKey - Config.ExplorerBaseUrlLocation = ExplorerBaseUrlLocation + config.PopulateCalculatedFields() + + return config +} + +func (c *ServerConfig) PopulateCalculatedFields() { + c.DatabaseAccount = c.Host + c.DatabaseDomain = c.Host + c.DatabaseEndpoint = fmt.Sprintf("https://%s:%d/", c.Host, c.Port) + c.ExplorerBaseUrlLocation = ExplorerBaseUrlLocation + logger.EnableDebugOutput = c.Debug } func setFlagsFromEnvironment() (err error) { diff --git a/api/handlers/collections.go b/api/handlers/collections.go index 294c2c3..f2e3836 100644 --- a/api/handlers/collections.go +++ b/api/handlers/collections.go @@ -5,16 +5,15 @@ import ( "net/http" "github.com/gin-gonic/gin" - "github.com/pikami/cosmium/internal/repositories" repositorymodels "github.com/pikami/cosmium/internal/repository_models" ) -func GetAllCollections(c *gin.Context) { +func (h *Handlers) GetAllCollections(c *gin.Context) { databaseId := c.Param("databaseId") - collections, status := repositories.GetAllCollections(databaseId) + collections, status := h.repository.GetAllCollections(databaseId) if status == repositorymodels.StatusOk { - database, _ := repositories.GetDatabase(databaseId) + database, _ := h.repository.GetDatabase(databaseId) c.Header("x-ms-item-count", fmt.Sprintf("%d", len(collections))) c.IndentedJSON(http.StatusOK, gin.H{ @@ -28,11 +27,11 @@ func GetAllCollections(c *gin.Context) { c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"}) } -func GetCollection(c *gin.Context) { +func (h *Handlers) GetCollection(c *gin.Context) { databaseId := c.Param("databaseId") id := c.Param("collId") - collection, status := repositories.GetCollection(databaseId, id) + collection, status := h.repository.GetCollection(databaseId, id) if status == repositorymodels.StatusOk { c.IndentedJSON(http.StatusOK, collection) return @@ -46,11 +45,11 @@ func GetCollection(c *gin.Context) { c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"}) } -func DeleteCollection(c *gin.Context) { +func (h *Handlers) DeleteCollection(c *gin.Context) { databaseId := c.Param("databaseId") id := c.Param("collId") - status := repositories.DeleteCollection(databaseId, id) + status := h.repository.DeleteCollection(databaseId, id) if status == repositorymodels.StatusOk { c.Status(http.StatusNoContent) return @@ -64,7 +63,7 @@ func DeleteCollection(c *gin.Context) { c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"}) } -func CreateCollection(c *gin.Context) { +func (h *Handlers) CreateCollection(c *gin.Context) { databaseId := c.Param("databaseId") var newCollection repositorymodels.Collection @@ -78,7 +77,7 @@ func CreateCollection(c *gin.Context) { return } - createdCollection, status := repositories.CreateCollection(databaseId, newCollection) + createdCollection, status := h.repository.CreateCollection(databaseId, newCollection) if status == repositorymodels.Conflict { c.IndentedJSON(http.StatusConflict, gin.H{"message": "Conflict"}) return diff --git a/api/handlers/cosmium.go b/api/handlers/cosmium.go index cb8512c..2ad9366 100644 --- a/api/handlers/cosmium.go +++ b/api/handlers/cosmium.go @@ -4,11 +4,10 @@ import ( "net/http" "github.com/gin-gonic/gin" - "github.com/pikami/cosmium/internal/repositories" ) -func CosmiumExport(c *gin.Context) { - repositoryState, err := repositories.GetState() +func (h *Handlers) CosmiumExport(c *gin.Context) { + repositoryState, err := h.repository.GetState() if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return diff --git a/api/handlers/databases.go b/api/handlers/databases.go index 9764259..8a8733a 100644 --- a/api/handlers/databases.go +++ b/api/handlers/databases.go @@ -5,12 +5,11 @@ import ( "net/http" "github.com/gin-gonic/gin" - "github.com/pikami/cosmium/internal/repositories" repositorymodels "github.com/pikami/cosmium/internal/repository_models" ) -func GetAllDatabases(c *gin.Context) { - databases, status := repositories.GetAllDatabases() +func (h *Handlers) GetAllDatabases(c *gin.Context) { + databases, status := h.repository.GetAllDatabases() if status == repositorymodels.StatusOk { c.Header("x-ms-item-count", fmt.Sprintf("%d", len(databases))) c.IndentedJSON(http.StatusOK, gin.H{ @@ -24,10 +23,10 @@ func GetAllDatabases(c *gin.Context) { c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"}) } -func GetDatabase(c *gin.Context) { +func (h *Handlers) GetDatabase(c *gin.Context) { id := c.Param("databaseId") - database, status := repositories.GetDatabase(id) + database, status := h.repository.GetDatabase(id) if status == repositorymodels.StatusOk { c.IndentedJSON(http.StatusOK, database) return @@ -41,10 +40,10 @@ func GetDatabase(c *gin.Context) { c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"}) } -func DeleteDatabase(c *gin.Context) { +func (h *Handlers) DeleteDatabase(c *gin.Context) { id := c.Param("databaseId") - status := repositories.DeleteDatabase(id) + status := h.repository.DeleteDatabase(id) if status == repositorymodels.StatusOk { c.Status(http.StatusNoContent) return @@ -58,7 +57,7 @@ func DeleteDatabase(c *gin.Context) { c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"}) } -func CreateDatabase(c *gin.Context) { +func (h *Handlers) CreateDatabase(c *gin.Context) { var newDatabase repositorymodels.Database if err := c.BindJSON(&newDatabase); err != nil { @@ -71,7 +70,7 @@ func CreateDatabase(c *gin.Context) { return } - createdDatabase, status := repositories.CreateDatabase(newDatabase) + createdDatabase, status := h.repository.CreateDatabase(newDatabase) if status == repositorymodels.Conflict { c.IndentedJSON(http.StatusConflict, gin.H{"message": "Conflict"}) return diff --git a/api/handlers/documents.go b/api/handlers/documents.go index 964ee92..07ea6d7 100644 --- a/api/handlers/documents.go +++ b/api/handlers/documents.go @@ -10,17 +10,16 @@ import ( "github.com/gin-gonic/gin" "github.com/pikami/cosmium/internal/constants" "github.com/pikami/cosmium/internal/logger" - "github.com/pikami/cosmium/internal/repositories" repositorymodels "github.com/pikami/cosmium/internal/repository_models" ) -func GetAllDocuments(c *gin.Context) { +func (h *Handlers) GetAllDocuments(c *gin.Context) { databaseId := c.Param("databaseId") collectionId := c.Param("collId") - documents, status := repositories.GetAllDocuments(databaseId, collectionId) + documents, status := h.repository.GetAllDocuments(databaseId, collectionId) if status == repositorymodels.StatusOk { - collection, _ := repositories.GetCollection(databaseId, collectionId) + collection, _ := h.repository.GetCollection(databaseId, collectionId) c.Header("x-ms-item-count", fmt.Sprintf("%d", len(documents))) c.IndentedJSON(http.StatusOK, gin.H{ @@ -34,12 +33,12 @@ func GetAllDocuments(c *gin.Context) { c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"}) } -func GetDocument(c *gin.Context) { +func (h *Handlers) GetDocument(c *gin.Context) { databaseId := c.Param("databaseId") collectionId := c.Param("collId") documentId := c.Param("docId") - document, status := repositories.GetDocument(databaseId, collectionId, documentId) + document, status := h.repository.GetDocument(databaseId, collectionId, documentId) if status == repositorymodels.StatusOk { c.IndentedJSON(http.StatusOK, document) return @@ -53,12 +52,12 @@ func GetDocument(c *gin.Context) { c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"}) } -func DeleteDocument(c *gin.Context) { +func (h *Handlers) DeleteDocument(c *gin.Context) { databaseId := c.Param("databaseId") collectionId := c.Param("collId") documentId := c.Param("docId") - status := repositories.DeleteDocument(databaseId, collectionId, documentId) + status := h.repository.DeleteDocument(databaseId, collectionId, documentId) if status == repositorymodels.StatusOk { c.Status(http.StatusNoContent) return @@ -73,7 +72,7 @@ func DeleteDocument(c *gin.Context) { } // TODO: Maybe move "replace" logic to repository -func ReplaceDocument(c *gin.Context) { +func (h *Handlers) ReplaceDocument(c *gin.Context) { databaseId := c.Param("databaseId") collectionId := c.Param("collId") documentId := c.Param("docId") @@ -84,13 +83,13 @@ func ReplaceDocument(c *gin.Context) { return } - status := repositories.DeleteDocument(databaseId, collectionId, documentId) + status := h.repository.DeleteDocument(databaseId, collectionId, documentId) if status == repositorymodels.StatusNotFound { c.IndentedJSON(http.StatusNotFound, gin.H{"message": "NotFound"}) return } - createdDocument, status := repositories.CreateDocument(databaseId, collectionId, requestBody) + createdDocument, status := h.repository.CreateDocument(databaseId, collectionId, requestBody) if status == repositorymodels.Conflict { c.IndentedJSON(http.StatusConflict, gin.H{"message": "Conflict"}) return @@ -104,12 +103,12 @@ func ReplaceDocument(c *gin.Context) { c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"}) } -func PatchDocument(c *gin.Context) { +func (h *Handlers) PatchDocument(c *gin.Context) { databaseId := c.Param("databaseId") collectionId := c.Param("collId") documentId := c.Param("docId") - document, status := repositories.GetDocument(databaseId, collectionId, documentId) + document, status := h.repository.GetDocument(databaseId, collectionId, documentId) if status == repositorymodels.StatusNotFound { c.IndentedJSON(http.StatusNotFound, gin.H{"message": "NotFound"}) return @@ -160,13 +159,13 @@ func PatchDocument(c *gin.Context) { return } - status = repositories.DeleteDocument(databaseId, collectionId, documentId) + status = h.repository.DeleteDocument(databaseId, collectionId, documentId) if status == repositorymodels.StatusNotFound { c.IndentedJSON(http.StatusNotFound, gin.H{"message": "NotFound"}) return } - createdDocument, status := repositories.CreateDocument(databaseId, collectionId, modifiedDocument) + createdDocument, status := h.repository.CreateDocument(databaseId, collectionId, modifiedDocument) if status == repositorymodels.Conflict { c.IndentedJSON(http.StatusConflict, gin.H{"message": "Conflict"}) return @@ -180,7 +179,7 @@ func PatchDocument(c *gin.Context) { c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"}) } -func DocumentsPost(c *gin.Context) { +func (h *Handlers) DocumentsPost(c *gin.Context) { databaseId := c.Param("databaseId") collectionId := c.Param("collId") @@ -202,14 +201,14 @@ func DocumentsPost(c *gin.Context) { queryParameters = parametersToMap(paramsArray) } - docs, status := repositories.ExecuteQueryDocuments(databaseId, collectionId, query.(string), queryParameters) + docs, status := h.repository.ExecuteQueryDocuments(databaseId, collectionId, query.(string), queryParameters) if status != repositorymodels.StatusOk { // TODO: Currently we return everything if the query fails - GetAllDocuments(c) + h.GetAllDocuments(c) return } - collection, _ := repositories.GetCollection(databaseId, collectionId) + collection, _ := h.repository.GetCollection(databaseId, collectionId) c.Header("x-ms-item-count", fmt.Sprintf("%d", len(docs))) c.IndentedJSON(http.StatusOK, gin.H{ "_rid": collection.ResourceID, @@ -226,10 +225,10 @@ func DocumentsPost(c *gin.Context) { isUpsert, _ := strconv.ParseBool(c.GetHeader("x-ms-documentdb-is-upsert")) if isUpsert { - repositories.DeleteDocument(databaseId, collectionId, requestBody["id"].(string)) + h.repository.DeleteDocument(databaseId, collectionId, requestBody["id"].(string)) } - createdDocument, status := repositories.CreateDocument(databaseId, collectionId, requestBody) + createdDocument, status := h.repository.CreateDocument(databaseId, collectionId, requestBody) if status == repositorymodels.Conflict { c.IndentedJSON(http.StatusConflict, gin.H{"message": "Conflict"}) return diff --git a/api/handlers/explorer.go b/api/handlers/explorer.go index 65393f5..ed43e87 100644 --- a/api/handlers/explorer.go +++ b/api/handlers/explorer.go @@ -4,15 +4,14 @@ import ( "fmt" "github.com/gin-gonic/gin" - "github.com/pikami/cosmium/api/config" ) -func RegisterExplorerHandlers(router *gin.Engine) { - explorer := router.Group(config.Config.ExplorerBaseUrlLocation) +func (h *Handlers) RegisterExplorerHandlers(router *gin.Engine) { + explorer := router.Group(h.config.ExplorerBaseUrlLocation) { explorer.Use(func(ctx *gin.Context) { if ctx.Param("filepath") == "/config.json" { - endpoint := fmt.Sprintf("https://%s:%d", config.Config.Host, config.Config.Port) + endpoint := fmt.Sprintf("https://%s:%d", h.config.Host, h.config.Port) ctx.JSON(200, gin.H{ "BACKEND_ENDPOINT": endpoint, "MONGO_BACKEND_ENDPOINT": endpoint, @@ -25,8 +24,8 @@ func RegisterExplorerHandlers(router *gin.Engine) { } }) - if config.Config.ExplorerPath != "" { - explorer.Static("/", config.Config.ExplorerPath) + if h.config.ExplorerPath != "" { + explorer.Static("/", h.config.ExplorerPath) } } } diff --git a/api/handlers/handlers.go b/api/handlers/handlers.go new file mode 100644 index 0000000..5e97442 --- /dev/null +++ b/api/handlers/handlers.go @@ -0,0 +1,18 @@ +package handlers + +import ( + "github.com/pikami/cosmium/api/config" + "github.com/pikami/cosmium/internal/repositories" +) + +type Handlers struct { + repository *repositories.DataRepository + config config.ServerConfig +} + +func NewHandlers(dataRepository *repositories.DataRepository, config config.ServerConfig) *Handlers { + return &Handlers{ + repository: dataRepository, + config: config, + } +} diff --git a/api/handlers/middleware/authentication.go b/api/handlers/middleware/authentication.go index 4c95b9d..5f94ce5 100644 --- a/api/handlers/middleware/authentication.go +++ b/api/handlers/middleware/authentication.go @@ -10,11 +10,11 @@ import ( "github.com/pikami/cosmium/internal/logger" ) -func Authentication() gin.HandlerFunc { +func Authentication(config config.ServerConfig) gin.HandlerFunc { return func(c *gin.Context) { requestUrl := c.Request.URL.String() - if config.Config.DisableAuth || - strings.HasPrefix(requestUrl, config.Config.ExplorerBaseUrlLocation) || + if config.DisableAuth || + strings.HasPrefix(requestUrl, config.ExplorerBaseUrlLocation) || strings.HasPrefix(requestUrl, "/cosmium") { return } @@ -25,7 +25,7 @@ func Authentication() gin.HandlerFunc { authHeader := c.Request.Header.Get("authorization") date := c.Request.Header.Get("x-ms-date") expectedSignature := authentication.GenerateSignature( - c.Request.Method, resourceType, resourceId, date, config.Config.AccountKey) + c.Request.Method, resourceType, resourceId, date, config.AccountKey) decoded, _ := url.QueryUnescape(authHeader) params, _ := url.ParseQuery(decoded) diff --git a/api/handlers/middleware/strip_trailing_slashes.go b/api/handlers/middleware/strip_trailing_slashes.go index 18c091d..8a6f82c 100644 --- a/api/handlers/middleware/strip_trailing_slashes.go +++ b/api/handlers/middleware/strip_trailing_slashes.go @@ -7,10 +7,10 @@ import ( "github.com/pikami/cosmium/api/config" ) -func StripTrailingSlashes(r *gin.Engine) gin.HandlerFunc { +func StripTrailingSlashes(r *gin.Engine, config config.ServerConfig) gin.HandlerFunc { return func(c *gin.Context) { path := c.Request.URL.Path - if len(path) > 1 && path[len(path)-1] == '/' && !strings.Contains(path, config.Config.ExplorerBaseUrlLocation) { + if len(path) > 1 && path[len(path)-1] == '/' && !strings.Contains(path, config.ExplorerBaseUrlLocation) { c.Request.URL.Path = path[:len(path)-1] r.HandleContext(c) c.Abort() diff --git a/api/handlers/partition_key_ranges.go b/api/handlers/partition_key_ranges.go index 865c13a..1f60ef6 100644 --- a/api/handlers/partition_key_ranges.go +++ b/api/handlers/partition_key_ranges.go @@ -5,11 +5,10 @@ import ( "net/http" "github.com/gin-gonic/gin" - "github.com/pikami/cosmium/internal/repositories" repositorymodels "github.com/pikami/cosmium/internal/repository_models" ) -func GetPartitionKeyRanges(c *gin.Context) { +func (h *Handlers) GetPartitionKeyRanges(c *gin.Context) { databaseId := c.Param("databaseId") collectionId := c.Param("collId") @@ -18,7 +17,7 @@ func GetPartitionKeyRanges(c *gin.Context) { return } - partitionKeyRanges, status := repositories.GetPartitionKeyRanges(databaseId, collectionId) + partitionKeyRanges, status := h.repository.GetPartitionKeyRanges(databaseId, collectionId) if status == repositorymodels.StatusOk { c.Header("etag", "\"420\"") c.Header("lsn", "420") @@ -27,7 +26,7 @@ func GetPartitionKeyRanges(c *gin.Context) { c.Header("x-ms-item-count", fmt.Sprintf("%d", len(partitionKeyRanges))) collectionRid := collectionId - collection, _ := repositories.GetCollection(databaseId, collectionId) + collection, _ := h.repository.GetCollection(databaseId, collectionId) if collection.ResourceID != "" { collectionRid = collection.ResourceID } diff --git a/api/handlers/server_info.go b/api/handlers/server_info.go index 23ed253..fd1c31d 100644 --- a/api/handlers/server_info.go +++ b/api/handlers/server_info.go @@ -5,27 +5,26 @@ import ( "net/http" "github.com/gin-gonic/gin" - "github.com/pikami/cosmium/api/config" ) -func GetServerInfo(c *gin.Context) { +func (h *Handlers) GetServerInfo(c *gin.Context) { c.IndentedJSON(http.StatusOK, gin.H{ "_self": "", - "id": config.Config.DatabaseAccount, - "_rid": fmt.Sprintf("%s.%s", config.Config.DatabaseAccount, config.Config.DatabaseDomain), + "id": h.config.DatabaseAccount, + "_rid": fmt.Sprintf("%s.%s", h.config.DatabaseAccount, h.config.DatabaseDomain), "media": "//media/", "addresses": "//addresses/", "_dbs": "//dbs/", "writableLocations": []map[string]interface{}{ { "name": "South Central US", - "databaseAccountEndpoint": config.Config.DatabaseEndpoint, + "databaseAccountEndpoint": h.config.DatabaseEndpoint, }, }, "readableLocations": []map[string]interface{}{ { "name": "South Central US", - "databaseAccountEndpoint": config.Config.DatabaseEndpoint, + "databaseAccountEndpoint": h.config.DatabaseEndpoint, }, }, "enableMultipleWriteLocations": false, diff --git a/api/handlers/stored_procedures.go b/api/handlers/stored_procedures.go index b2b29e3..0c6639b 100644 --- a/api/handlers/stored_procedures.go +++ b/api/handlers/stored_procedures.go @@ -5,15 +5,14 @@ import ( "net/http" "github.com/gin-gonic/gin" - "github.com/pikami/cosmium/internal/repositories" repositorymodels "github.com/pikami/cosmium/internal/repository_models" ) -func GetAllStoredProcedures(c *gin.Context) { +func (h *Handlers) GetAllStoredProcedures(c *gin.Context) { databaseId := c.Param("databaseId") collectionId := c.Param("collId") - sps, status := repositories.GetAllStoredProcedures(databaseId, collectionId) + sps, status := h.repository.GetAllStoredProcedures(databaseId, collectionId) if status == repositorymodels.StatusOk { c.Header("x-ms-item-count", fmt.Sprintf("%d", len(sps))) diff --git a/api/handlers/triggers.go b/api/handlers/triggers.go index a9bb6a7..3ea0c9c 100644 --- a/api/handlers/triggers.go +++ b/api/handlers/triggers.go @@ -5,15 +5,14 @@ import ( "net/http" "github.com/gin-gonic/gin" - "github.com/pikami/cosmium/internal/repositories" repositorymodels "github.com/pikami/cosmium/internal/repository_models" ) -func GetAllTriggers(c *gin.Context) { +func (h *Handlers) GetAllTriggers(c *gin.Context) { databaseId := c.Param("databaseId") collectionId := c.Param("collId") - triggers, status := repositories.GetAllTriggers(databaseId, collectionId) + triggers, status := h.repository.GetAllTriggers(databaseId, collectionId) if status == repositorymodels.StatusOk { c.Header("x-ms-item-count", fmt.Sprintf("%d", len(triggers))) diff --git a/api/handlers/user_defined_functions.go b/api/handlers/user_defined_functions.go index ac2086a..688c38f 100644 --- a/api/handlers/user_defined_functions.go +++ b/api/handlers/user_defined_functions.go @@ -5,15 +5,14 @@ import ( "net/http" "github.com/gin-gonic/gin" - "github.com/pikami/cosmium/internal/repositories" repositorymodels "github.com/pikami/cosmium/internal/repository_models" ) -func GetAllUserDefinedFunctions(c *gin.Context) { +func (h *Handlers) GetAllUserDefinedFunctions(c *gin.Context) { databaseId := c.Param("databaseId") collectionId := c.Param("collId") - udfs, status := repositories.GetAllUserDefinedFunctions(databaseId, collectionId) + udfs, status := h.repository.GetAllUserDefinedFunctions(databaseId, collectionId) if status == repositorymodels.StatusOk { c.Header("x-ms-item-count", fmt.Sprintf("%d", len(udfs))) diff --git a/api/router.go b/api/router.go index 8ac3138..79bdcba 100644 --- a/api/router.go +++ b/api/router.go @@ -6,78 +6,75 @@ import ( "net/http" "github.com/gin-gonic/gin" - "github.com/pikami/cosmium/api/config" "github.com/pikami/cosmium/api/handlers" "github.com/pikami/cosmium/api/handlers/middleware" "github.com/pikami/cosmium/internal/logger" + "github.com/pikami/cosmium/internal/repositories" tlsprovider "github.com/pikami/cosmium/internal/tls_provider" ) -type Server struct { - StopServer chan interface{} -} +func (s *ApiServer) CreateRouter(repository *repositories.DataRepository) { + routeHandlers := handlers.NewHandlers(repository, s.config) -func CreateRouter() *gin.Engine { router := gin.Default(func(e *gin.Engine) { e.RedirectTrailingSlash = false }) - if config.Config.Debug { + if s.config.Debug { router.Use(middleware.RequestLogger()) } - router.Use(middleware.StripTrailingSlashes(router)) - router.Use(middleware.Authentication()) + router.Use(middleware.StripTrailingSlashes(router, s.config)) + router.Use(middleware.Authentication(s.config)) - router.GET("/dbs/:databaseId/colls/:collId/pkranges", handlers.GetPartitionKeyRanges) + router.GET("/dbs/:databaseId/colls/:collId/pkranges", routeHandlers.GetPartitionKeyRanges) - router.POST("/dbs/:databaseId/colls/:collId/docs", handlers.DocumentsPost) - router.GET("/dbs/:databaseId/colls/:collId/docs", handlers.GetAllDocuments) - router.GET("/dbs/:databaseId/colls/:collId/docs/:docId", handlers.GetDocument) - router.PUT("/dbs/:databaseId/colls/:collId/docs/:docId", handlers.ReplaceDocument) - router.PATCH("/dbs/:databaseId/colls/:collId/docs/:docId", handlers.PatchDocument) - router.DELETE("/dbs/:databaseId/colls/:collId/docs/:docId", handlers.DeleteDocument) + router.POST("/dbs/:databaseId/colls/:collId/docs", routeHandlers.DocumentsPost) + router.GET("/dbs/:databaseId/colls/:collId/docs", routeHandlers.GetAllDocuments) + router.GET("/dbs/:databaseId/colls/:collId/docs/:docId", routeHandlers.GetDocument) + router.PUT("/dbs/:databaseId/colls/:collId/docs/:docId", routeHandlers.ReplaceDocument) + router.PATCH("/dbs/:databaseId/colls/:collId/docs/:docId", routeHandlers.PatchDocument) + router.DELETE("/dbs/:databaseId/colls/:collId/docs/:docId", routeHandlers.DeleteDocument) - router.POST("/dbs/:databaseId/colls", handlers.CreateCollection) - router.GET("/dbs/:databaseId/colls", handlers.GetAllCollections) - router.GET("/dbs/:databaseId/colls/:collId", handlers.GetCollection) - router.DELETE("/dbs/:databaseId/colls/:collId", handlers.DeleteCollection) + router.POST("/dbs/:databaseId/colls", routeHandlers.CreateCollection) + router.GET("/dbs/:databaseId/colls", routeHandlers.GetAllCollections) + router.GET("/dbs/:databaseId/colls/:collId", routeHandlers.GetCollection) + router.DELETE("/dbs/:databaseId/colls/:collId", routeHandlers.DeleteCollection) - router.POST("/dbs", handlers.CreateDatabase) - router.GET("/dbs", handlers.GetAllDatabases) - router.GET("/dbs/:databaseId", handlers.GetDatabase) - router.DELETE("/dbs/:databaseId", handlers.DeleteDatabase) + router.POST("/dbs", routeHandlers.CreateDatabase) + router.GET("/dbs", routeHandlers.GetAllDatabases) + router.GET("/dbs/:databaseId", routeHandlers.GetDatabase) + router.DELETE("/dbs/:databaseId", routeHandlers.DeleteDatabase) - router.GET("/dbs/:databaseId/colls/:collId/udfs", handlers.GetAllUserDefinedFunctions) - router.GET("/dbs/:databaseId/colls/:collId/sprocs", handlers.GetAllStoredProcedures) - router.GET("/dbs/:databaseId/colls/:collId/triggers", handlers.GetAllTriggers) + router.GET("/dbs/:databaseId/colls/:collId/udfs", routeHandlers.GetAllUserDefinedFunctions) + router.GET("/dbs/:databaseId/colls/:collId/sprocs", routeHandlers.GetAllStoredProcedures) + router.GET("/dbs/:databaseId/colls/:collId/triggers", routeHandlers.GetAllTriggers) router.GET("/offers", handlers.GetOffers) - router.GET("/", handlers.GetServerInfo) + router.GET("/", routeHandlers.GetServerInfo) - router.GET("/cosmium/export", handlers.CosmiumExport) + router.GET("/cosmium/export", routeHandlers.CosmiumExport) - handlers.RegisterExplorerHandlers(router) + routeHandlers.RegisterExplorerHandlers(router) - return router + s.router = router } -func StartAPI() *Server { - if !config.Config.Debug { +func (s *ApiServer) Start() { + if !s.config.Debug { gin.SetMode(gin.ReleaseMode) } - router := CreateRouter() - listenAddress := fmt.Sprintf(":%d", config.Config.Port) - stopChan := make(chan interface{}) + listenAddress := fmt.Sprintf(":%d", s.config.Port) + s.isActive = true server := &http.Server{ Addr: listenAddress, - Handler: router.Handler(), + Handler: s.router.Handler(), } go func() { - <-stopChan + <-s.stopServer logger.Info("Shutting down server...") err := server.Shutdown(context.TODO()) if err != nil { @@ -86,24 +83,22 @@ func StartAPI() *Server { }() go func() { - if config.Config.DisableTls { + if s.config.DisableTls { logger.Infof("Listening and serving HTTP on %s\n", server.Addr) err := server.ListenAndServe() if err != nil { logger.Error("Failed to start HTTP server:", err) } - return - } - - if config.Config.TLS_CertificatePath != "" && config.Config.TLS_CertificateKey != "" { + s.isActive = false + } else if s.config.TLS_CertificatePath != "" && s.config.TLS_CertificateKey != "" { logger.Infof("Listening and serving HTTPS on %s\n", server.Addr) err := server.ListenAndServeTLS( - config.Config.TLS_CertificatePath, - config.Config.TLS_CertificateKey) + s.config.TLS_CertificatePath, + s.config.TLS_CertificateKey) if err != nil { logger.Error("Failed to start HTTPS server:", err) } - return + s.isActive = false } else { tlsConfig := tlsprovider.GetDefaultTlsConfig() server.TLSConfig = tlsConfig @@ -113,9 +108,7 @@ func StartAPI() *Server { if err != nil { logger.Error("Failed to start HTTPS server:", err) } - return + s.isActive = false } }() - - return &Server{StopServer: stopChan} } diff --git a/api/tests/authentication_test.go b/api/tests/authentication_test.go index acc30b6..e519d07 100644 --- a/api/tests/authentication_test.go +++ b/api/tests/authentication_test.go @@ -11,16 +11,15 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" "github.com/pikami/cosmium/api/config" - "github.com/pikami/cosmium/internal/repositories" "github.com/stretchr/testify/assert" ) func Test_Authentication(t *testing.T) { ts := runTestServer() - defer ts.Close() + defer ts.Server.Close() t.Run("Should get 200 when correct account key is used", func(t *testing.T) { - repositories.DeleteDatabase(testDatabaseName) + ts.Repository.DeleteDatabase(testDatabaseName) client, err := azcosmos.NewClientFromConnectionString( fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, config.DefaultAccountKey), &azcosmos.ClientOptions{}, @@ -35,26 +34,8 @@ func Test_Authentication(t *testing.T) { assert.Equal(t, createResponse.DatabaseProperties.ID, testDatabaseName) }) - t.Run("Should get 200 when wrong account key is used, but authentication is dissabled", func(t *testing.T) { - config.Config.DisableAuth = true - repositories.DeleteDatabase(testDatabaseName) - client, err := azcosmos.NewClientFromConnectionString( - fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, "AAAA"), - &azcosmos.ClientOptions{}, - ) - assert.Nil(t, err) - - createResponse, err := client.CreateDatabase( - context.TODO(), - azcosmos.DatabaseProperties{ID: testDatabaseName}, - &azcosmos.CreateDatabaseOptions{}) - assert.Nil(t, err) - assert.Equal(t, createResponse.DatabaseProperties.ID, testDatabaseName) - config.Config.DisableAuth = false - }) - t.Run("Should get 401 when wrong account key is used", func(t *testing.T) { - repositories.DeleteDatabase(testDatabaseName) + ts.Repository.DeleteDatabase(testDatabaseName) client, err := azcosmos.NewClientFromConnectionString( fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, "AAAA"), &azcosmos.ClientOptions{}, @@ -85,3 +66,29 @@ func Test_Authentication(t *testing.T) { assert.Contains(t, string(responseBody), "BACKEND_ENDPOINT") }) } + +func Test_Authentication_Disabled(t *testing.T) { + ts := runTestServerCustomConfig(config.ServerConfig{ + AccountKey: config.DefaultAccountKey, + ExplorerPath: "/tmp/nothing", + ExplorerBaseUrlLocation: config.ExplorerBaseUrlLocation, + DisableAuth: true, + }) + defer ts.Server.Close() + + t.Run("Should get 200 when wrong account key is used, but authentication is dissabled", func(t *testing.T) { + ts.Repository.DeleteDatabase(testDatabaseName) + client, err := azcosmos.NewClientFromConnectionString( + fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, "AAAA"), + &azcosmos.ClientOptions{}, + ) + assert.Nil(t, err) + + createResponse, err := client.CreateDatabase( + context.TODO(), + azcosmos.DatabaseProperties{ID: testDatabaseName}, + &azcosmos.CreateDatabaseOptions{}) + assert.Nil(t, err) + assert.Equal(t, createResponse.DatabaseProperties.ID, testDatabaseName) + }) +} diff --git a/api/tests/collections_test.go b/api/tests/collections_test.go index 2563988..68db1e4 100644 --- a/api/tests/collections_test.go +++ b/api/tests/collections_test.go @@ -10,22 +10,21 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" "github.com/pikami/cosmium/api/config" - "github.com/pikami/cosmium/internal/repositories" repositorymodels "github.com/pikami/cosmium/internal/repository_models" "github.com/stretchr/testify/assert" ) func Test_Collections(t *testing.T) { ts := runTestServer() - defer ts.Close() + defer ts.Server.Close() client, err := azcosmos.NewClientFromConnectionString( - fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, config.Config.AccountKey), + fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, config.DefaultAccountKey), &azcosmos.ClientOptions{}, ) assert.Nil(t, err) - repositories.CreateDatabase(repositorymodels.Database{ID: testDatabaseName}) + ts.Repository.CreateDatabase(repositorymodels.Database{ID: testDatabaseName}) databaseClient, err := client.NewDatabase(testDatabaseName) assert.Nil(t, err) @@ -40,7 +39,7 @@ func Test_Collections(t *testing.T) { }) t.Run("Should return conflict when collection exists", func(t *testing.T) { - repositories.CreateCollection(testDatabaseName, repositorymodels.Collection{ + ts.Repository.CreateCollection(testDatabaseName, repositorymodels.Collection{ ID: testCollectionName, }) @@ -60,7 +59,7 @@ func Test_Collections(t *testing.T) { t.Run("Collection Read", func(t *testing.T) { t.Run("Should read collection", func(t *testing.T) { - repositories.CreateCollection(testDatabaseName, repositorymodels.Collection{ + ts.Repository.CreateCollection(testDatabaseName, repositorymodels.Collection{ ID: testCollectionName, }) @@ -74,7 +73,7 @@ func Test_Collections(t *testing.T) { }) t.Run("Should return not found when collection does not exist", func(t *testing.T) { - repositories.DeleteCollection(testDatabaseName, testCollectionName) + ts.Repository.DeleteCollection(testDatabaseName, testCollectionName) collectionResponse, err := databaseClient.NewContainer(testCollectionName) assert.Nil(t, err) @@ -93,7 +92,7 @@ func Test_Collections(t *testing.T) { t.Run("Collection Delete", func(t *testing.T) { t.Run("Should delete collection", func(t *testing.T) { - repositories.CreateCollection(testDatabaseName, repositorymodels.Collection{ + ts.Repository.CreateCollection(testDatabaseName, repositorymodels.Collection{ ID: testCollectionName, }) @@ -106,7 +105,7 @@ func Test_Collections(t *testing.T) { }) t.Run("Should return not found when collection does not exist", func(t *testing.T) { - repositories.DeleteCollection(testDatabaseName, testCollectionName) + ts.Repository.DeleteCollection(testDatabaseName, testCollectionName) collectionResponse, err := databaseClient.NewContainer(testCollectionName) assert.Nil(t, err) diff --git a/api/tests/config_test.go b/api/tests/config_test.go index 065ecc3..2fe2e72 100644 --- a/api/tests/config_test.go +++ b/api/tests/config_test.go @@ -5,14 +5,37 @@ import ( "github.com/pikami/cosmium/api" "github.com/pikami/cosmium/api/config" + "github.com/pikami/cosmium/internal/repositories" ) -func runTestServer() *httptest.Server { - config.Config.AccountKey = config.DefaultAccountKey - config.Config.ExplorerPath = "/tmp/nothing" - config.Config.ExplorerBaseUrlLocation = config.ExplorerBaseUrlLocation +type TestServer struct { + Server *httptest.Server + Repository *repositories.DataRepository + URL string +} - return httptest.NewServer(api.CreateRouter()) +func runTestServerCustomConfig(config config.ServerConfig) *TestServer { + repository := repositories.NewDataRepository(repositories.RepositoryOptions{}) + + api := api.NewApiServer(repository, config) + + server := httptest.NewServer(api.GetRouter()) + + return &TestServer{ + Server: server, + Repository: repository, + URL: server.URL, + } +} + +func runTestServer() *TestServer { + config := config.ServerConfig{ + AccountKey: config.DefaultAccountKey, + ExplorerPath: "/tmp/nothing", + ExplorerBaseUrlLocation: config.ExplorerBaseUrlLocation, + } + + return runTestServerCustomConfig(config) } const ( diff --git a/api/tests/databases_test.go b/api/tests/databases_test.go index 012c2ed..0c0bec1 100644 --- a/api/tests/databases_test.go +++ b/api/tests/databases_test.go @@ -10,24 +10,23 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" "github.com/pikami/cosmium/api/config" - "github.com/pikami/cosmium/internal/repositories" repositorymodels "github.com/pikami/cosmium/internal/repository_models" "github.com/stretchr/testify/assert" ) func Test_Databases(t *testing.T) { ts := runTestServer() - defer ts.Close() + defer ts.Server.Close() client, err := azcosmos.NewClientFromConnectionString( - fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, config.Config.AccountKey), + fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, config.DefaultAccountKey), &azcosmos.ClientOptions{}, ) assert.Nil(t, err) t.Run("Database Create", func(t *testing.T) { t.Run("Should create database", func(t *testing.T) { - repositories.DeleteDatabase(testDatabaseName) + ts.Repository.DeleteDatabase(testDatabaseName) createResponse, err := client.CreateDatabase(context.TODO(), azcosmos.DatabaseProperties{ ID: testDatabaseName, @@ -38,7 +37,7 @@ func Test_Databases(t *testing.T) { }) t.Run("Should return conflict when database exists", func(t *testing.T) { - repositories.CreateDatabase(repositorymodels.Database{ + ts.Repository.CreateDatabase(repositorymodels.Database{ ID: testDatabaseName, }) @@ -58,7 +57,7 @@ func Test_Databases(t *testing.T) { t.Run("Database Read", func(t *testing.T) { t.Run("Should read database", func(t *testing.T) { - repositories.CreateDatabase(repositorymodels.Database{ + ts.Repository.CreateDatabase(repositorymodels.Database{ ID: testDatabaseName, }) @@ -72,7 +71,7 @@ func Test_Databases(t *testing.T) { }) t.Run("Should return not found when database does not exist", func(t *testing.T) { - repositories.DeleteDatabase(testDatabaseName) + ts.Repository.DeleteDatabase(testDatabaseName) databaseResponse, err := client.NewDatabase(testDatabaseName) assert.Nil(t, err) @@ -91,7 +90,7 @@ func Test_Databases(t *testing.T) { t.Run("Database Delete", func(t *testing.T) { t.Run("Should delete database", func(t *testing.T) { - repositories.CreateDatabase(repositorymodels.Database{ + ts.Repository.CreateDatabase(repositorymodels.Database{ ID: testDatabaseName, }) @@ -104,7 +103,7 @@ func Test_Databases(t *testing.T) { }) t.Run("Should return not found when database does not exist", func(t *testing.T) { - repositories.DeleteDatabase(testDatabaseName) + ts.Repository.DeleteDatabase(testDatabaseName) databaseResponse, err := client.NewDatabase(testDatabaseName) assert.Nil(t, err) diff --git a/api/tests/documents_test.go b/api/tests/documents_test.go index cf6c08e..1499ee0 100644 --- a/api/tests/documents_test.go +++ b/api/tests/documents_test.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "net/http" - "net/http/httptest" "reflect" "sync" "testing" @@ -15,7 +14,6 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" "github.com/pikami/cosmium/api/config" - "github.com/pikami/cosmium/internal/repositories" repositorymodels "github.com/pikami/cosmium/internal/repository_models" "github.com/stretchr/testify/assert" ) @@ -55,9 +53,11 @@ func testCosmosQuery(t *testing.T, } } -func documents_InitializeDb(t *testing.T) (*httptest.Server, *azcosmos.ContainerClient) { - repositories.CreateDatabase(repositorymodels.Database{ID: testDatabaseName}) - repositories.CreateCollection(testDatabaseName, repositorymodels.Collection{ +func documents_InitializeDb(t *testing.T) (*TestServer, *azcosmos.ContainerClient) { + ts := runTestServer() + + ts.Repository.CreateDatabase(repositorymodels.Database{ID: testDatabaseName}) + ts.Repository.CreateCollection(testDatabaseName, repositorymodels.Collection{ ID: testCollectionName, PartitionKey: struct { Paths []string "json:\"paths\"" @@ -67,13 +67,11 @@ func documents_InitializeDb(t *testing.T) (*httptest.Server, *azcosmos.Container Paths: []string{"/pk"}, }, }) - repositories.CreateDocument(testDatabaseName, testCollectionName, map[string]interface{}{"id": "12345", "pk": "123", "isCool": false, "arr": []int{1, 2, 3}}) - repositories.CreateDocument(testDatabaseName, testCollectionName, map[string]interface{}{"id": "67890", "pk": "456", "isCool": true, "arr": []int{6, 7, 8}}) - - ts := runTestServer() + ts.Repository.CreateDocument(testDatabaseName, testCollectionName, map[string]interface{}{"id": "12345", "pk": "123", "isCool": false, "arr": []int{1, 2, 3}}) + ts.Repository.CreateDocument(testDatabaseName, testCollectionName, map[string]interface{}{"id": "67890", "pk": "456", "isCool": true, "arr": []int{6, 7, 8}}) client, err := azcosmos.NewClientFromConnectionString( - fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, config.Config.AccountKey), + fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, config.DefaultAccountKey), &azcosmos.ClientOptions{}, ) assert.Nil(t, err) @@ -86,7 +84,7 @@ func documents_InitializeDb(t *testing.T) (*httptest.Server, *azcosmos.Container func Test_Documents(t *testing.T) { ts, collectionClient := documents_InitializeDb(t) - defer ts.Close() + defer ts.Server.Close() t.Run("Should query document", func(t *testing.T) { testCosmosQuery(t, collectionClient, @@ -218,7 +216,7 @@ func Test_Documents(t *testing.T) { func Test_Documents_Patch(t *testing.T) { ts, collectionClient := documents_InitializeDb(t) - defer ts.Close() + defer ts.Server.Close() t.Run("Should PATCH document", func(t *testing.T) { context := context.TODO() diff --git a/api/tests/documents_trailingslash_test.go b/api/tests/documents_trailingslash_test.go index 04d461f..a82acde 100644 --- a/api/tests/documents_trailingslash_test.go +++ b/api/tests/documents_trailingslash_test.go @@ -15,14 +15,14 @@ import ( // Request document with trailing slash like python cosmosdb client does. func Test_Documents_Read_Trailing_Slash(t *testing.T) { ts, _ := documents_InitializeDb(t) - defer ts.Close() + defer ts.Server.Close() t.Run("Read doc with client that appends slash to path", func(t *testing.T) { resourceIdTemplate := "dbs/%s/colls/%s/docs/%s" path := fmt.Sprintf(resourceIdTemplate, testDatabaseName, testCollectionName, "12345") testUrl := ts.URL + "/" + path + "/" date := time.Now().Format(time.RFC1123) - signature := authentication.GenerateSignature("GET", "docs", path, date, config.Config.AccountKey) + signature := authentication.GenerateSignature("GET", "docs", path, date, config.DefaultAccountKey) httpClient := &http.Client{} req, _ := http.NewRequest("GET", testUrl, nil) req.Header.Add("x-ms-date", date) diff --git a/cmd/server/server.go b/cmd/server/server.go index 5cfb8cc..6b8bdae 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -11,16 +11,20 @@ import ( ) func main() { - config.ParseFlags() + configuration := config.ParseFlags() - repositories.InitializeRepository() + repository := repositories.NewDataRepository(repositories.RepositoryOptions{ + InitialDataFilePath: configuration.InitialDataFilePath, + PersistDataFilePath: configuration.PersistDataFilePath, + }) - server := api.StartAPI() + server := api.NewApiServer(repository, configuration) + server.Start() - waitForExit(server) + waitForExit(server, repository, configuration) } -func waitForExit(server *api.Server) { +func waitForExit(server *api.ApiServer, repository *repositories.DataRepository, config config.ServerConfig) { sigs := make(chan os.Signal, 1) signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) @@ -28,9 +32,9 @@ func waitForExit(server *api.Server) { <-sigs // Stop the server - server.StopServer <- true + server.Stop() - if config.Config.PersistDataFilePath != "" { - repositories.SaveStateFS(config.Config.PersistDataFilePath) + if config.PersistDataFilePath != "" { + repository.SaveStateFS(config.PersistDataFilePath) } } diff --git a/internal/logger/logger.go b/internal/logger/logger.go index 01590a3..f2d8710 100644 --- a/internal/logger/logger.go +++ b/internal/logger/logger.go @@ -3,22 +3,22 @@ package logger import ( "log" "os" - - "github.com/pikami/cosmium/api/config" ) +var EnableDebugOutput = false + var DebugLogger = log.New(os.Stdout, "", log.Ldate|log.Ltime|log.Lshortfile) var InfoLogger = log.New(os.Stdout, "", log.Ldate|log.Ltime) var ErrorLogger = log.New(os.Stderr, "", log.Ldate|log.Ltime|log.Lshortfile) func Debug(v ...any) { - if config.Config.Debug { + if EnableDebugOutput { DebugLogger.Println(v...) } } func Debugf(format string, v ...any) { - if config.Config.Debug { + if EnableDebugOutput { DebugLogger.Printf(format, v...) } } diff --git a/internal/repositories/collections.go b/internal/repositories/collections.go index 54e3705..95bd4b3 100644 --- a/internal/repositories/collections.go +++ b/internal/repositories/collections.go @@ -11,60 +11,60 @@ import ( "golang.org/x/exp/maps" ) -func GetAllCollections(databaseId string) ([]repositorymodels.Collection, repositorymodels.RepositoryStatus) { - storeState.RLock() - defer storeState.RUnlock() +func (r *DataRepository) GetAllCollections(databaseId string) ([]repositorymodels.Collection, repositorymodels.RepositoryStatus) { + r.storeState.RLock() + defer r.storeState.RUnlock() - if _, ok := storeState.Databases[databaseId]; !ok { + if _, ok := r.storeState.Databases[databaseId]; !ok { return make([]repositorymodels.Collection, 0), repositorymodels.StatusNotFound } - return maps.Values(storeState.Collections[databaseId]), repositorymodels.StatusOk + return maps.Values(r.storeState.Collections[databaseId]), repositorymodels.StatusOk } -func GetCollection(databaseId string, collectionId string) (repositorymodels.Collection, repositorymodels.RepositoryStatus) { - storeState.RLock() - defer storeState.RUnlock() +func (r *DataRepository) GetCollection(databaseId string, collectionId string) (repositorymodels.Collection, repositorymodels.RepositoryStatus) { + r.storeState.RLock() + defer r.storeState.RUnlock() - if _, ok := storeState.Databases[databaseId]; !ok { + if _, ok := r.storeState.Databases[databaseId]; !ok { return repositorymodels.Collection{}, repositorymodels.StatusNotFound } - if _, ok := storeState.Collections[databaseId][collectionId]; !ok { + if _, ok := r.storeState.Collections[databaseId][collectionId]; !ok { return repositorymodels.Collection{}, repositorymodels.StatusNotFound } - return storeState.Collections[databaseId][collectionId], repositorymodels.StatusOk + return r.storeState.Collections[databaseId][collectionId], repositorymodels.StatusOk } -func DeleteCollection(databaseId string, collectionId string) repositorymodels.RepositoryStatus { - storeState.Lock() - defer storeState.Unlock() +func (r *DataRepository) DeleteCollection(databaseId string, collectionId string) repositorymodels.RepositoryStatus { + r.storeState.Lock() + defer r.storeState.Unlock() - if _, ok := storeState.Databases[databaseId]; !ok { + if _, ok := r.storeState.Databases[databaseId]; !ok { return repositorymodels.StatusNotFound } - if _, ok := storeState.Collections[databaseId][collectionId]; !ok { + if _, ok := r.storeState.Collections[databaseId][collectionId]; !ok { return repositorymodels.StatusNotFound } - delete(storeState.Collections[databaseId], collectionId) + delete(r.storeState.Collections[databaseId], collectionId) return repositorymodels.StatusOk } -func CreateCollection(databaseId string, newCollection repositorymodels.Collection) (repositorymodels.Collection, repositorymodels.RepositoryStatus) { - storeState.Lock() - defer storeState.Unlock() +func (r *DataRepository) CreateCollection(databaseId string, newCollection repositorymodels.Collection) (repositorymodels.Collection, repositorymodels.RepositoryStatus) { + r.storeState.Lock() + defer r.storeState.Unlock() var ok bool var database repositorymodels.Database - if database, ok = storeState.Databases[databaseId]; !ok { + if database, ok = r.storeState.Databases[databaseId]; !ok { return repositorymodels.Collection{}, repositorymodels.StatusNotFound } - if _, ok = storeState.Collections[databaseId][newCollection.ID]; ok { + if _, ok = r.storeState.Collections[databaseId][newCollection.ID]; ok { return repositorymodels.Collection{}, repositorymodels.Conflict } @@ -75,8 +75,8 @@ func CreateCollection(databaseId string, newCollection repositorymodels.Collecti newCollection.ETag = fmt.Sprintf("\"%s\"", uuid.New()) newCollection.Self = fmt.Sprintf("dbs/%s/colls/%s/", database.ResourceID, newCollection.ResourceID) - storeState.Collections[databaseId][newCollection.ID] = newCollection - storeState.Documents[databaseId][newCollection.ID] = make(map[string]repositorymodels.Document) + r.storeState.Collections[databaseId][newCollection.ID] = newCollection + r.storeState.Documents[databaseId][newCollection.ID] = make(map[string]repositorymodels.Document) return newCollection, repositorymodels.StatusOk } diff --git a/internal/repositories/databases.go b/internal/repositories/databases.go index abd609a..6a5b7be 100644 --- a/internal/repositories/databases.go +++ b/internal/repositories/databases.go @@ -10,42 +10,42 @@ import ( "golang.org/x/exp/maps" ) -func GetAllDatabases() ([]repositorymodels.Database, repositorymodels.RepositoryStatus) { - storeState.RLock() - defer storeState.RUnlock() +func (r *DataRepository) GetAllDatabases() ([]repositorymodels.Database, repositorymodels.RepositoryStatus) { + r.storeState.RLock() + defer r.storeState.RUnlock() - return maps.Values(storeState.Databases), repositorymodels.StatusOk + return maps.Values(r.storeState.Databases), repositorymodels.StatusOk } -func GetDatabase(id string) (repositorymodels.Database, repositorymodels.RepositoryStatus) { - storeState.RLock() - defer storeState.RUnlock() +func (r *DataRepository) GetDatabase(id string) (repositorymodels.Database, repositorymodels.RepositoryStatus) { + r.storeState.RLock() + defer r.storeState.RUnlock() - if database, ok := storeState.Databases[id]; ok { + if database, ok := r.storeState.Databases[id]; ok { return database, repositorymodels.StatusOk } return repositorymodels.Database{}, repositorymodels.StatusNotFound } -func DeleteDatabase(id string) repositorymodels.RepositoryStatus { - storeState.Lock() - defer storeState.Unlock() +func (r *DataRepository) DeleteDatabase(id string) repositorymodels.RepositoryStatus { + r.storeState.Lock() + defer r.storeState.Unlock() - if _, ok := storeState.Databases[id]; !ok { + if _, ok := r.storeState.Databases[id]; !ok { return repositorymodels.StatusNotFound } - delete(storeState.Databases, id) + delete(r.storeState.Databases, id) return repositorymodels.StatusOk } -func CreateDatabase(newDatabase repositorymodels.Database) (repositorymodels.Database, repositorymodels.RepositoryStatus) { - storeState.Lock() - defer storeState.Unlock() +func (r *DataRepository) CreateDatabase(newDatabase repositorymodels.Database) (repositorymodels.Database, repositorymodels.RepositoryStatus) { + r.storeState.Lock() + defer r.storeState.Unlock() - if _, ok := storeState.Databases[newDatabase.ID]; ok { + if _, ok := r.storeState.Databases[newDatabase.ID]; ok { return repositorymodels.Database{}, repositorymodels.Conflict } @@ -54,9 +54,9 @@ func CreateDatabase(newDatabase repositorymodels.Database) (repositorymodels.Dat newDatabase.ETag = fmt.Sprintf("\"%s\"", uuid.New()) newDatabase.Self = fmt.Sprintf("dbs/%s/", newDatabase.ResourceID) - storeState.Databases[newDatabase.ID] = newDatabase - storeState.Collections[newDatabase.ID] = make(map[string]repositorymodels.Collection) - storeState.Documents[newDatabase.ID] = make(map[string]map[string]repositorymodels.Document) + r.storeState.Databases[newDatabase.ID] = newDatabase + r.storeState.Collections[newDatabase.ID] = make(map[string]repositorymodels.Collection) + r.storeState.Documents[newDatabase.ID] = make(map[string]map[string]repositorymodels.Document) return newDatabase, repositorymodels.StatusOk } diff --git a/internal/repositories/documents.go b/internal/repositories/documents.go index 0a6feb6..187840d 100644 --- a/internal/repositories/documents.go +++ b/internal/repositories/documents.go @@ -14,64 +14,64 @@ import ( "golang.org/x/exp/maps" ) -func GetAllDocuments(databaseId string, collectionId string) ([]repositorymodels.Document, repositorymodels.RepositoryStatus) { - storeState.RLock() - defer storeState.RUnlock() +func (r *DataRepository) GetAllDocuments(databaseId string, collectionId string) ([]repositorymodels.Document, repositorymodels.RepositoryStatus) { + r.storeState.RLock() + defer r.storeState.RUnlock() - if _, ok := storeState.Databases[databaseId]; !ok { + if _, ok := r.storeState.Databases[databaseId]; !ok { return make([]repositorymodels.Document, 0), repositorymodels.StatusNotFound } - if _, ok := storeState.Collections[databaseId][collectionId]; !ok { + if _, ok := r.storeState.Collections[databaseId][collectionId]; !ok { return make([]repositorymodels.Document, 0), repositorymodels.StatusNotFound } - return maps.Values(storeState.Documents[databaseId][collectionId]), repositorymodels.StatusOk + return maps.Values(r.storeState.Documents[databaseId][collectionId]), repositorymodels.StatusOk } -func GetDocument(databaseId string, collectionId string, documentId string) (repositorymodels.Document, repositorymodels.RepositoryStatus) { - storeState.RLock() - defer storeState.RUnlock() +func (r *DataRepository) GetDocument(databaseId string, collectionId string, documentId string) (repositorymodels.Document, repositorymodels.RepositoryStatus) { + r.storeState.RLock() + defer r.storeState.RUnlock() - if _, ok := storeState.Databases[databaseId]; !ok { + if _, ok := r.storeState.Databases[databaseId]; !ok { return repositorymodels.Document{}, repositorymodels.StatusNotFound } - if _, ok := storeState.Collections[databaseId][collectionId]; !ok { + if _, ok := r.storeState.Collections[databaseId][collectionId]; !ok { return repositorymodels.Document{}, repositorymodels.StatusNotFound } - if _, ok := storeState.Documents[databaseId][collectionId][documentId]; !ok { + if _, ok := r.storeState.Documents[databaseId][collectionId][documentId]; !ok { return repositorymodels.Document{}, repositorymodels.StatusNotFound } - return storeState.Documents[databaseId][collectionId][documentId], repositorymodels.StatusOk + return r.storeState.Documents[databaseId][collectionId][documentId], repositorymodels.StatusOk } -func DeleteDocument(databaseId string, collectionId string, documentId string) repositorymodels.RepositoryStatus { - storeState.Lock() - defer storeState.Unlock() +func (r *DataRepository) DeleteDocument(databaseId string, collectionId string, documentId string) repositorymodels.RepositoryStatus { + r.storeState.Lock() + defer r.storeState.Unlock() - if _, ok := storeState.Databases[databaseId]; !ok { + if _, ok := r.storeState.Databases[databaseId]; !ok { return repositorymodels.StatusNotFound } - if _, ok := storeState.Collections[databaseId][collectionId]; !ok { + if _, ok := r.storeState.Collections[databaseId][collectionId]; !ok { return repositorymodels.StatusNotFound } - if _, ok := storeState.Documents[databaseId][collectionId][documentId]; !ok { + if _, ok := r.storeState.Documents[databaseId][collectionId][documentId]; !ok { return repositorymodels.StatusNotFound } - delete(storeState.Documents[databaseId][collectionId], documentId) + delete(r.storeState.Documents[databaseId][collectionId], documentId) return repositorymodels.StatusOk } -func CreateDocument(databaseId string, collectionId string, document map[string]interface{}) (repositorymodels.Document, repositorymodels.RepositoryStatus) { - storeState.Lock() - defer storeState.Unlock() +func (r *DataRepository) CreateDocument(databaseId string, collectionId string, document map[string]interface{}) (repositorymodels.Document, repositorymodels.RepositoryStatus) { + r.storeState.Lock() + defer r.storeState.Unlock() var ok bool var documentId string @@ -82,15 +82,15 @@ func CreateDocument(databaseId string, collectionId string, document map[string] document["id"] = documentId } - if database, ok = storeState.Databases[databaseId]; !ok { + if database, ok = r.storeState.Databases[databaseId]; !ok { return repositorymodels.Document{}, repositorymodels.StatusNotFound } - if collection, ok = storeState.Collections[databaseId][collectionId]; !ok { + if collection, ok = r.storeState.Collections[databaseId][collectionId]; !ok { return repositorymodels.Document{}, repositorymodels.StatusNotFound } - if _, ok := storeState.Documents[databaseId][collectionId][documentId]; ok { + if _, ok := r.storeState.Documents[databaseId][collectionId][documentId]; ok { return repositorymodels.Document{}, repositorymodels.Conflict } @@ -99,19 +99,19 @@ func CreateDocument(databaseId string, collectionId string, document map[string] document["_etag"] = fmt.Sprintf("\"%s\"", uuid.New()) document["_self"] = fmt.Sprintf("dbs/%s/colls/%s/docs/%s/", database.ResourceID, collection.ResourceID, document["_rid"]) - storeState.Documents[databaseId][collectionId][documentId] = document + r.storeState.Documents[databaseId][collectionId][documentId] = document return document, repositorymodels.StatusOk } -func ExecuteQueryDocuments(databaseId string, collectionId string, query string, queryParameters map[string]interface{}) ([]memoryexecutor.RowType, repositorymodels.RepositoryStatus) { +func (r *DataRepository) ExecuteQueryDocuments(databaseId string, collectionId string, query string, queryParameters map[string]interface{}) ([]memoryexecutor.RowType, repositorymodels.RepositoryStatus) { parsedQuery, err := nosql.Parse("", []byte(query)) if err != nil { log.Printf("Failed to parse query: %s\nerr: %v", query, err) return nil, repositorymodels.BadRequest } - collectionDocuments, status := GetAllDocuments(databaseId, collectionId) + collectionDocuments, status := r.GetAllDocuments(databaseId, collectionId) if status != repositorymodels.StatusOk { return nil, status } diff --git a/internal/repositories/partition_key_ranges.go b/internal/repositories/partition_key_ranges.go index 4b1ff91..8424ed1 100644 --- a/internal/repositories/partition_key_ranges.go +++ b/internal/repositories/partition_key_ranges.go @@ -9,19 +9,19 @@ import ( ) // I have no idea what this is tbh -func GetPartitionKeyRanges(databaseId string, collectionId string) ([]repositorymodels.PartitionKeyRange, repositorymodels.RepositoryStatus) { - storeState.RLock() - defer storeState.RUnlock() +func (r *DataRepository) GetPartitionKeyRanges(databaseId string, collectionId string) ([]repositorymodels.PartitionKeyRange, repositorymodels.RepositoryStatus) { + r.storeState.RLock() + defer r.storeState.RUnlock() databaseRid := databaseId collectionRid := collectionId var timestamp int64 = 0 - if database, ok := storeState.Databases[databaseId]; !ok { + if database, ok := r.storeState.Databases[databaseId]; !ok { databaseRid = database.ResourceID } - if collection, ok := storeState.Collections[databaseId][collectionId]; !ok { + if collection, ok := r.storeState.Collections[databaseId][collectionId]; !ok { collectionRid = collection.ResourceID timestamp = collection.TimeStamp } diff --git a/internal/repositories/repositories.go b/internal/repositories/repositories.go new file mode 100644 index 0000000..34fae53 --- /dev/null +++ b/internal/repositories/repositories.go @@ -0,0 +1,37 @@ +package repositories + +import repositorymodels "github.com/pikami/cosmium/internal/repository_models" + +type DataRepository struct { + storedProcedures []repositorymodels.StoredProcedure + triggers []repositorymodels.Trigger + userDefinedFunctions []repositorymodels.UserDefinedFunction + storeState repositorymodels.State + + initialDataFilePath string + persistDataFilePath string +} + +type RepositoryOptions struct { + InitialDataFilePath string + PersistDataFilePath string +} + +func NewDataRepository(options RepositoryOptions) *DataRepository { + repository := &DataRepository{ + storedProcedures: []repositorymodels.StoredProcedure{}, + triggers: []repositorymodels.Trigger{}, + userDefinedFunctions: []repositorymodels.UserDefinedFunction{}, + storeState: repositorymodels.State{ + Databases: make(map[string]repositorymodels.Database), + Collections: make(map[string]map[string]repositorymodels.Collection), + Documents: make(map[string]map[string]map[string]repositorymodels.Document), + }, + initialDataFilePath: options.InitialDataFilePath, + persistDataFilePath: options.PersistDataFilePath, + } + + repository.InitializeRepository() + + return repository +} diff --git a/internal/repositories/state.go b/internal/repositories/state.go index 063f436..0c0438b 100644 --- a/internal/repositories/state.go +++ b/internal/repositories/state.go @@ -6,28 +6,18 @@ import ( "os" "reflect" - "github.com/pikami/cosmium/api/config" "github.com/pikami/cosmium/internal/logger" repositorymodels "github.com/pikami/cosmium/internal/repository_models" ) -var storedProcedures = []repositorymodels.StoredProcedure{} -var triggers = []repositorymodels.Trigger{} -var userDefinedFunctions = []repositorymodels.UserDefinedFunction{} -var storeState = repositorymodels.State{ - Databases: make(map[string]repositorymodels.Database), - Collections: make(map[string]map[string]repositorymodels.Collection), - Documents: make(map[string]map[string]map[string]repositorymodels.Document), -} - -func InitializeRepository() { - if config.Config.InitialDataFilePath != "" { - LoadStateFS(config.Config.InitialDataFilePath) +func (r *DataRepository) InitializeRepository() { + if r.initialDataFilePath != "" { + r.LoadStateFS(r.initialDataFilePath) return } - if config.Config.PersistDataFilePath != "" { - stat, err := os.Stat(config.Config.PersistDataFilePath) + if r.persistDataFilePath != "" { + stat, err := os.Stat(r.persistDataFilePath) if err != nil { return } @@ -37,12 +27,12 @@ func InitializeRepository() { os.Exit(1) } - LoadStateFS(config.Config.PersistDataFilePath) + r.LoadStateFS(r.persistDataFilePath) return } } -func LoadStateFS(filePath string) { +func (r *DataRepository) LoadStateFS(filePath string) { data, err := os.ReadFile(filePath) if err != nil { log.Fatalf("Error reading state JSON file: %v", err) @@ -60,16 +50,16 @@ func LoadStateFS(filePath string) { logger.Infof("Collections: %d\n", getLength(state.Collections)) logger.Infof("Documents: %d\n", getLength(state.Documents)) - storeState = state + r.storeState = state - ensureStoreStateNoNullReferences() + r.ensureStoreStateNoNullReferences() } -func SaveStateFS(filePath string) { - storeState.RLock() - defer storeState.RUnlock() +func (r *DataRepository) SaveStateFS(filePath string) { + r.storeState.RLock() + defer r.storeState.RUnlock() - data, err := json.MarshalIndent(storeState, "", "\t") + data, err := json.MarshalIndent(r.storeState, "", "\t") if err != nil { logger.Errorf("Failed to save state: %v\n", err) return @@ -78,16 +68,16 @@ func SaveStateFS(filePath string) { os.WriteFile(filePath, data, os.ModePerm) logger.Info("Saved state:") - logger.Infof("Databases: %d\n", getLength(storeState.Databases)) - logger.Infof("Collections: %d\n", getLength(storeState.Collections)) - logger.Infof("Documents: %d\n", getLength(storeState.Documents)) + logger.Infof("Databases: %d\n", getLength(r.storeState.Databases)) + logger.Infof("Collections: %d\n", getLength(r.storeState.Collections)) + logger.Infof("Documents: %d\n", getLength(r.storeState.Documents)) } -func GetState() (string, error) { - storeState.RLock() - defer storeState.RUnlock() +func (r *DataRepository) GetState() (string, error) { + r.storeState.RLock() + defer r.storeState.RUnlock() - data, err := json.MarshalIndent(storeState, "", "\t") + data, err := json.MarshalIndent(r.storeState, "", "\t") if err != nil { logger.Errorf("Failed to serialize state: %v\n", err) return "", err @@ -121,36 +111,36 @@ func getLength(v interface{}) int { return count } -func ensureStoreStateNoNullReferences() { - if storeState.Databases == nil { - storeState.Databases = make(map[string]repositorymodels.Database) +func (r *DataRepository) ensureStoreStateNoNullReferences() { + if r.storeState.Databases == nil { + r.storeState.Databases = make(map[string]repositorymodels.Database) } - if storeState.Collections == nil { - storeState.Collections = make(map[string]map[string]repositorymodels.Collection) + if r.storeState.Collections == nil { + r.storeState.Collections = make(map[string]map[string]repositorymodels.Collection) } - if storeState.Documents == nil { - storeState.Documents = make(map[string]map[string]map[string]repositorymodels.Document) + if r.storeState.Documents == nil { + r.storeState.Documents = make(map[string]map[string]map[string]repositorymodels.Document) } - for database := range storeState.Databases { - if storeState.Collections[database] == nil { - storeState.Collections[database] = make(map[string]repositorymodels.Collection) + for database := range r.storeState.Databases { + if r.storeState.Collections[database] == nil { + r.storeState.Collections[database] = make(map[string]repositorymodels.Collection) } - if storeState.Documents[database] == nil { - storeState.Documents[database] = make(map[string]map[string]repositorymodels.Document) + if r.storeState.Documents[database] == nil { + r.storeState.Documents[database] = make(map[string]map[string]repositorymodels.Document) } - for collection := range storeState.Collections[database] { - if storeState.Documents[database][collection] == nil { - storeState.Documents[database][collection] = make(map[string]repositorymodels.Document) + for collection := range r.storeState.Collections[database] { + if r.storeState.Documents[database][collection] == nil { + r.storeState.Documents[database][collection] = make(map[string]repositorymodels.Document) } - for document := range storeState.Documents[database][collection] { - if storeState.Documents[database][collection][document] == nil { - delete(storeState.Documents[database][collection], document) + for document := range r.storeState.Documents[database][collection] { + if r.storeState.Documents[database][collection][document] == nil { + delete(r.storeState.Documents[database][collection], document) } } } diff --git a/internal/repositories/stored_procedures.go b/internal/repositories/stored_procedures.go index d54f62f..64d2e69 100644 --- a/internal/repositories/stored_procedures.go +++ b/internal/repositories/stored_procedures.go @@ -2,6 +2,6 @@ package repositories import repositorymodels "github.com/pikami/cosmium/internal/repository_models" -func GetAllStoredProcedures(databaseId string, collectionId string) ([]repositorymodels.StoredProcedure, repositorymodels.RepositoryStatus) { - return storedProcedures, repositorymodels.StatusOk +func (r *DataRepository) GetAllStoredProcedures(databaseId string, collectionId string) ([]repositorymodels.StoredProcedure, repositorymodels.RepositoryStatus) { + return r.storedProcedures, repositorymodels.StatusOk } diff --git a/internal/repositories/triggers.go b/internal/repositories/triggers.go index c1baa8f..4b0a75b 100644 --- a/internal/repositories/triggers.go +++ b/internal/repositories/triggers.go @@ -2,6 +2,6 @@ package repositories import repositorymodels "github.com/pikami/cosmium/internal/repository_models" -func GetAllTriggers(databaseId string, collectionId string) ([]repositorymodels.Trigger, repositorymodels.RepositoryStatus) { - return triggers, repositorymodels.StatusOk +func (r *DataRepository) GetAllTriggers(databaseId string, collectionId string) ([]repositorymodels.Trigger, repositorymodels.RepositoryStatus) { + return r.triggers, repositorymodels.StatusOk } diff --git a/internal/repositories/user_defined_functions.go b/internal/repositories/user_defined_functions.go index 0374460..3bc6ca7 100644 --- a/internal/repositories/user_defined_functions.go +++ b/internal/repositories/user_defined_functions.go @@ -2,6 +2,6 @@ package repositories import repositorymodels "github.com/pikami/cosmium/internal/repository_models" -func GetAllUserDefinedFunctions(databaseId string, collectionId string) ([]repositorymodels.UserDefinedFunction, repositorymodels.RepositoryStatus) { - return userDefinedFunctions, repositorymodels.StatusOk +func (r *DataRepository) GetAllUserDefinedFunctions(databaseId string, collectionId string) ([]repositorymodels.UserDefinedFunction, repositorymodels.RepositoryStatus) { + return r.userDefinedFunctions, repositorymodels.StatusOk } diff --git a/sharedlibrary/sharedlibrary.go b/sharedlibrary/sharedlibrary.go index bbacea6..88cd0cf 100644 --- a/sharedlibrary/sharedlibrary.go +++ b/sharedlibrary/sharedlibrary.go @@ -9,44 +9,77 @@ import ( "github.com/pikami/cosmium/internal/repositories" ) -var currentServer *api.Server +type ServerInstance struct { + server *api.ApiServer + repository *repositories.DataRepository +} + +var serverInstances map[string]*ServerInstance + +const ( + ResponseSuccess = 0 + ResponseUnknown = 1 + ResponseServerInstanceAlreadyExists = 2 + ResponseFailedToParseConfiguration = 3 + ResponseServerInstanceNotFound = 4 +) + +//export CreateServerInstance +func CreateServerInstance(serverName *C.char, configurationJSON *C.char) int { + if serverInstances == nil { + serverInstances = make(map[string]*ServerInstance) + } + + if _, ok := serverInstances[C.GoString(serverName)]; ok { + return ResponseServerInstanceAlreadyExists + } -//export Configure -func Configure(configurationJSON *C.char) bool { var configuration config.ServerConfig err := json.Unmarshal([]byte(C.GoString(configurationJSON)), &configuration) if err != nil { - return false + return ResponseFailedToParseConfiguration } - config.Config = configuration - return true -} -//export InitializeRepository -func InitializeRepository() { - repositories.InitializeRepository() -} + configuration.PopulateCalculatedFields() -//export StartAPI -func StartAPI() { - currentServer = api.StartAPI() -} + repository := repositories.NewDataRepository(repositories.RepositoryOptions{ + InitialDataFilePath: configuration.InitialDataFilePath, + PersistDataFilePath: configuration.PersistDataFilePath, + }) -//export StopAPI -func StopAPI() { - if currentServer == nil { - currentServer.StopServer <- true - currentServer = nil + server := api.NewApiServer(repository, configuration) + server.Start() + + serverInstances[C.GoString(serverName)] = &ServerInstance{ + server: server, + repository: repository, } + + return ResponseSuccess } -//export GetState -func GetState() *C.char { - stateJSON, err := repositories.GetState() - if err != nil { - return nil +//export StopServerInstance +func StopServerInstance(serverName *C.char) int { + if serverInstance, ok := serverInstances[C.GoString(serverName)]; ok { + serverInstance.server.Stop() + delete(serverInstances, C.GoString(serverName)) + return ResponseSuccess } - return C.CString(stateJSON) + + return ResponseServerInstanceNotFound +} + +//export GetServerInstanceState +func GetServerInstanceState(serverName *C.char) *C.char { + if serverInstance, ok := serverInstances[C.GoString(serverName)]; ok { + stateJSON, err := serverInstance.repository.GetState() + if err != nil { + return nil + } + return C.CString(stateJSON) + } + + return nil } func main() {}