Refactor to support multiple server instances in shared library

This commit is contained in:
Pijus Kamandulis 2024-12-18 19:39:57 +02:00
parent 84c33e3c8e
commit 777034181f
34 changed files with 503 additions and 369 deletions

35
api/api_server.go Normal file
View File

@ -0,0 +1,35 @@
package api
import (
"github.com/gin-gonic/gin"
"github.com/pikami/cosmium/api/config"
"github.com/pikami/cosmium/internal/repositories"
)
type ApiServer struct {
stopServer chan interface{}
isActive bool
router *gin.Engine
config config.ServerConfig
}
func NewApiServer(dataRepository *repositories.DataRepository, config config.ServerConfig) *ApiServer {
stopChan := make(chan interface{})
apiServer := &ApiServer{
stopServer: stopChan,
config: config,
}
apiServer.CreateRouter(dataRepository)
return apiServer
}
func (s *ApiServer) GetRouter() *gin.Engine {
return s.router
}
func (s *ApiServer) Stop() {
s.stopServer <- true
}

View File

@ -5,6 +5,8 @@ import (
"fmt" "fmt"
"os" "os"
"strings" "strings"
"github.com/pikami/cosmium/internal/logger"
) )
const ( const (
@ -13,9 +15,7 @@ const (
ExplorerBaseUrlLocation = "/_explorer" ExplorerBaseUrlLocation = "/_explorer"
) )
var Config = ServerConfig{} func ParseFlags() ServerConfig {
func ParseFlags() {
host := flag.String("Host", "localhost", "Hostname") host := flag.String("Host", "localhost", "Hostname")
port := flag.Int("Port", 8081, "Listen port") port := flag.Int("Port", 8081, "Listen port")
explorerPath := flag.String("ExplorerDir", "", "Path to cosmos-explorer files") explorerPath := flag.String("ExplorerDir", "", "Path to cosmos-explorer files")
@ -31,22 +31,30 @@ func ParseFlags() {
flag.Parse() flag.Parse()
setFlagsFromEnvironment() setFlagsFromEnvironment()
Config.Host = *host config := ServerConfig{}
Config.Port = *port config.Host = *host
Config.ExplorerPath = *explorerPath config.Port = *port
Config.TLS_CertificatePath = *tlsCertificatePath config.ExplorerPath = *explorerPath
Config.TLS_CertificateKey = *tlsCertificateKey config.TLS_CertificatePath = *tlsCertificatePath
Config.InitialDataFilePath = *initialDataPath config.TLS_CertificateKey = *tlsCertificateKey
Config.PersistDataFilePath = *persistDataPath config.InitialDataFilePath = *initialDataPath
Config.DisableAuth = *disableAuthentication config.PersistDataFilePath = *persistDataPath
Config.DisableTls = *disableTls config.DisableAuth = *disableAuthentication
Config.Debug = *debug config.DisableTls = *disableTls
config.Debug = *debug
config.AccountKey = *accountKey
Config.DatabaseAccount = Config.Host config.PopulateCalculatedFields()
Config.DatabaseDomain = Config.Host
Config.DatabaseEndpoint = fmt.Sprintf("https://%s:%d/", Config.Host, Config.Port) return config
Config.AccountKey = *accountKey }
Config.ExplorerBaseUrlLocation = ExplorerBaseUrlLocation
func (c *ServerConfig) PopulateCalculatedFields() {
c.DatabaseAccount = c.Host
c.DatabaseDomain = c.Host
c.DatabaseEndpoint = fmt.Sprintf("https://%s:%d/", c.Host, c.Port)
c.ExplorerBaseUrlLocation = ExplorerBaseUrlLocation
logger.EnableDebugOutput = c.Debug
} }
func setFlagsFromEnvironment() (err error) { func setFlagsFromEnvironment() (err error) {

View File

@ -5,16 +5,15 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pikami/cosmium/internal/repositories"
repositorymodels "github.com/pikami/cosmium/internal/repository_models" repositorymodels "github.com/pikami/cosmium/internal/repository_models"
) )
func GetAllCollections(c *gin.Context) { func (h *Handlers) GetAllCollections(c *gin.Context) {
databaseId := c.Param("databaseId") databaseId := c.Param("databaseId")
collections, status := repositories.GetAllCollections(databaseId) collections, status := h.repository.GetAllCollections(databaseId)
if status == repositorymodels.StatusOk { if status == repositorymodels.StatusOk {
database, _ := repositories.GetDatabase(databaseId) database, _ := h.repository.GetDatabase(databaseId)
c.Header("x-ms-item-count", fmt.Sprintf("%d", len(collections))) c.Header("x-ms-item-count", fmt.Sprintf("%d", len(collections)))
c.IndentedJSON(http.StatusOK, gin.H{ c.IndentedJSON(http.StatusOK, gin.H{
@ -28,11 +27,11 @@ func GetAllCollections(c *gin.Context) {
c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"}) c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"})
} }
func GetCollection(c *gin.Context) { func (h *Handlers) GetCollection(c *gin.Context) {
databaseId := c.Param("databaseId") databaseId := c.Param("databaseId")
id := c.Param("collId") id := c.Param("collId")
collection, status := repositories.GetCollection(databaseId, id) collection, status := h.repository.GetCollection(databaseId, id)
if status == repositorymodels.StatusOk { if status == repositorymodels.StatusOk {
c.IndentedJSON(http.StatusOK, collection) c.IndentedJSON(http.StatusOK, collection)
return return
@ -46,11 +45,11 @@ func GetCollection(c *gin.Context) {
c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"}) c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"})
} }
func DeleteCollection(c *gin.Context) { func (h *Handlers) DeleteCollection(c *gin.Context) {
databaseId := c.Param("databaseId") databaseId := c.Param("databaseId")
id := c.Param("collId") id := c.Param("collId")
status := repositories.DeleteCollection(databaseId, id) status := h.repository.DeleteCollection(databaseId, id)
if status == repositorymodels.StatusOk { if status == repositorymodels.StatusOk {
c.Status(http.StatusNoContent) c.Status(http.StatusNoContent)
return return
@ -64,7 +63,7 @@ func DeleteCollection(c *gin.Context) {
c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"}) c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"})
} }
func CreateCollection(c *gin.Context) { func (h *Handlers) CreateCollection(c *gin.Context) {
databaseId := c.Param("databaseId") databaseId := c.Param("databaseId")
var newCollection repositorymodels.Collection var newCollection repositorymodels.Collection
@ -78,7 +77,7 @@ func CreateCollection(c *gin.Context) {
return return
} }
createdCollection, status := repositories.CreateCollection(databaseId, newCollection) createdCollection, status := h.repository.CreateCollection(databaseId, newCollection)
if status == repositorymodels.Conflict { if status == repositorymodels.Conflict {
c.IndentedJSON(http.StatusConflict, gin.H{"message": "Conflict"}) c.IndentedJSON(http.StatusConflict, gin.H{"message": "Conflict"})
return return

View File

@ -4,11 +4,10 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pikami/cosmium/internal/repositories"
) )
func CosmiumExport(c *gin.Context) { func (h *Handlers) CosmiumExport(c *gin.Context) {
repositoryState, err := repositories.GetState() repositoryState, err := h.repository.GetState()
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return

View File

@ -5,12 +5,11 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pikami/cosmium/internal/repositories"
repositorymodels "github.com/pikami/cosmium/internal/repository_models" repositorymodels "github.com/pikami/cosmium/internal/repository_models"
) )
func GetAllDatabases(c *gin.Context) { func (h *Handlers) GetAllDatabases(c *gin.Context) {
databases, status := repositories.GetAllDatabases() databases, status := h.repository.GetAllDatabases()
if status == repositorymodels.StatusOk { if status == repositorymodels.StatusOk {
c.Header("x-ms-item-count", fmt.Sprintf("%d", len(databases))) c.Header("x-ms-item-count", fmt.Sprintf("%d", len(databases)))
c.IndentedJSON(http.StatusOK, gin.H{ c.IndentedJSON(http.StatusOK, gin.H{
@ -24,10 +23,10 @@ func GetAllDatabases(c *gin.Context) {
c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"}) c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"})
} }
func GetDatabase(c *gin.Context) { func (h *Handlers) GetDatabase(c *gin.Context) {
id := c.Param("databaseId") id := c.Param("databaseId")
database, status := repositories.GetDatabase(id) database, status := h.repository.GetDatabase(id)
if status == repositorymodels.StatusOk { if status == repositorymodels.StatusOk {
c.IndentedJSON(http.StatusOK, database) c.IndentedJSON(http.StatusOK, database)
return return
@ -41,10 +40,10 @@ func GetDatabase(c *gin.Context) {
c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"}) c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"})
} }
func DeleteDatabase(c *gin.Context) { func (h *Handlers) DeleteDatabase(c *gin.Context) {
id := c.Param("databaseId") id := c.Param("databaseId")
status := repositories.DeleteDatabase(id) status := h.repository.DeleteDatabase(id)
if status == repositorymodels.StatusOk { if status == repositorymodels.StatusOk {
c.Status(http.StatusNoContent) c.Status(http.StatusNoContent)
return return
@ -58,7 +57,7 @@ func DeleteDatabase(c *gin.Context) {
c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"}) c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"})
} }
func CreateDatabase(c *gin.Context) { func (h *Handlers) CreateDatabase(c *gin.Context) {
var newDatabase repositorymodels.Database var newDatabase repositorymodels.Database
if err := c.BindJSON(&newDatabase); err != nil { if err := c.BindJSON(&newDatabase); err != nil {
@ -71,7 +70,7 @@ func CreateDatabase(c *gin.Context) {
return return
} }
createdDatabase, status := repositories.CreateDatabase(newDatabase) createdDatabase, status := h.repository.CreateDatabase(newDatabase)
if status == repositorymodels.Conflict { if status == repositorymodels.Conflict {
c.IndentedJSON(http.StatusConflict, gin.H{"message": "Conflict"}) c.IndentedJSON(http.StatusConflict, gin.H{"message": "Conflict"})
return return

View File

@ -10,17 +10,16 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pikami/cosmium/internal/constants" "github.com/pikami/cosmium/internal/constants"
"github.com/pikami/cosmium/internal/logger" "github.com/pikami/cosmium/internal/logger"
"github.com/pikami/cosmium/internal/repositories"
repositorymodels "github.com/pikami/cosmium/internal/repository_models" repositorymodels "github.com/pikami/cosmium/internal/repository_models"
) )
func GetAllDocuments(c *gin.Context) { func (h *Handlers) GetAllDocuments(c *gin.Context) {
databaseId := c.Param("databaseId") databaseId := c.Param("databaseId")
collectionId := c.Param("collId") collectionId := c.Param("collId")
documents, status := repositories.GetAllDocuments(databaseId, collectionId) documents, status := h.repository.GetAllDocuments(databaseId, collectionId)
if status == repositorymodels.StatusOk { if status == repositorymodels.StatusOk {
collection, _ := repositories.GetCollection(databaseId, collectionId) collection, _ := h.repository.GetCollection(databaseId, collectionId)
c.Header("x-ms-item-count", fmt.Sprintf("%d", len(documents))) c.Header("x-ms-item-count", fmt.Sprintf("%d", len(documents)))
c.IndentedJSON(http.StatusOK, gin.H{ c.IndentedJSON(http.StatusOK, gin.H{
@ -34,12 +33,12 @@ func GetAllDocuments(c *gin.Context) {
c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"}) c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"})
} }
func GetDocument(c *gin.Context) { func (h *Handlers) GetDocument(c *gin.Context) {
databaseId := c.Param("databaseId") databaseId := c.Param("databaseId")
collectionId := c.Param("collId") collectionId := c.Param("collId")
documentId := c.Param("docId") documentId := c.Param("docId")
document, status := repositories.GetDocument(databaseId, collectionId, documentId) document, status := h.repository.GetDocument(databaseId, collectionId, documentId)
if status == repositorymodels.StatusOk { if status == repositorymodels.StatusOk {
c.IndentedJSON(http.StatusOK, document) c.IndentedJSON(http.StatusOK, document)
return return
@ -53,12 +52,12 @@ func GetDocument(c *gin.Context) {
c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"}) c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"})
} }
func DeleteDocument(c *gin.Context) { func (h *Handlers) DeleteDocument(c *gin.Context) {
databaseId := c.Param("databaseId") databaseId := c.Param("databaseId")
collectionId := c.Param("collId") collectionId := c.Param("collId")
documentId := c.Param("docId") documentId := c.Param("docId")
status := repositories.DeleteDocument(databaseId, collectionId, documentId) status := h.repository.DeleteDocument(databaseId, collectionId, documentId)
if status == repositorymodels.StatusOk { if status == repositorymodels.StatusOk {
c.Status(http.StatusNoContent) c.Status(http.StatusNoContent)
return return
@ -73,7 +72,7 @@ func DeleteDocument(c *gin.Context) {
} }
// TODO: Maybe move "replace" logic to repository // TODO: Maybe move "replace" logic to repository
func ReplaceDocument(c *gin.Context) { func (h *Handlers) ReplaceDocument(c *gin.Context) {
databaseId := c.Param("databaseId") databaseId := c.Param("databaseId")
collectionId := c.Param("collId") collectionId := c.Param("collId")
documentId := c.Param("docId") documentId := c.Param("docId")
@ -84,13 +83,13 @@ func ReplaceDocument(c *gin.Context) {
return return
} }
status := repositories.DeleteDocument(databaseId, collectionId, documentId) status := h.repository.DeleteDocument(databaseId, collectionId, documentId)
if status == repositorymodels.StatusNotFound { if status == repositorymodels.StatusNotFound {
c.IndentedJSON(http.StatusNotFound, gin.H{"message": "NotFound"}) c.IndentedJSON(http.StatusNotFound, gin.H{"message": "NotFound"})
return return
} }
createdDocument, status := repositories.CreateDocument(databaseId, collectionId, requestBody) createdDocument, status := h.repository.CreateDocument(databaseId, collectionId, requestBody)
if status == repositorymodels.Conflict { if status == repositorymodels.Conflict {
c.IndentedJSON(http.StatusConflict, gin.H{"message": "Conflict"}) c.IndentedJSON(http.StatusConflict, gin.H{"message": "Conflict"})
return return
@ -104,12 +103,12 @@ func ReplaceDocument(c *gin.Context) {
c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"}) c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"})
} }
func PatchDocument(c *gin.Context) { func (h *Handlers) PatchDocument(c *gin.Context) {
databaseId := c.Param("databaseId") databaseId := c.Param("databaseId")
collectionId := c.Param("collId") collectionId := c.Param("collId")
documentId := c.Param("docId") documentId := c.Param("docId")
document, status := repositories.GetDocument(databaseId, collectionId, documentId) document, status := h.repository.GetDocument(databaseId, collectionId, documentId)
if status == repositorymodels.StatusNotFound { if status == repositorymodels.StatusNotFound {
c.IndentedJSON(http.StatusNotFound, gin.H{"message": "NotFound"}) c.IndentedJSON(http.StatusNotFound, gin.H{"message": "NotFound"})
return return
@ -160,13 +159,13 @@ func PatchDocument(c *gin.Context) {
return return
} }
status = repositories.DeleteDocument(databaseId, collectionId, documentId) status = h.repository.DeleteDocument(databaseId, collectionId, documentId)
if status == repositorymodels.StatusNotFound { if status == repositorymodels.StatusNotFound {
c.IndentedJSON(http.StatusNotFound, gin.H{"message": "NotFound"}) c.IndentedJSON(http.StatusNotFound, gin.H{"message": "NotFound"})
return return
} }
createdDocument, status := repositories.CreateDocument(databaseId, collectionId, modifiedDocument) createdDocument, status := h.repository.CreateDocument(databaseId, collectionId, modifiedDocument)
if status == repositorymodels.Conflict { if status == repositorymodels.Conflict {
c.IndentedJSON(http.StatusConflict, gin.H{"message": "Conflict"}) c.IndentedJSON(http.StatusConflict, gin.H{"message": "Conflict"})
return return
@ -180,7 +179,7 @@ func PatchDocument(c *gin.Context) {
c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"}) c.IndentedJSON(http.StatusInternalServerError, gin.H{"message": "Unknown error"})
} }
func DocumentsPost(c *gin.Context) { func (h *Handlers) DocumentsPost(c *gin.Context) {
databaseId := c.Param("databaseId") databaseId := c.Param("databaseId")
collectionId := c.Param("collId") collectionId := c.Param("collId")
@ -202,14 +201,14 @@ func DocumentsPost(c *gin.Context) {
queryParameters = parametersToMap(paramsArray) queryParameters = parametersToMap(paramsArray)
} }
docs, status := repositories.ExecuteQueryDocuments(databaseId, collectionId, query.(string), queryParameters) docs, status := h.repository.ExecuteQueryDocuments(databaseId, collectionId, query.(string), queryParameters)
if status != repositorymodels.StatusOk { if status != repositorymodels.StatusOk {
// TODO: Currently we return everything if the query fails // TODO: Currently we return everything if the query fails
GetAllDocuments(c) h.GetAllDocuments(c)
return return
} }
collection, _ := repositories.GetCollection(databaseId, collectionId) collection, _ := h.repository.GetCollection(databaseId, collectionId)
c.Header("x-ms-item-count", fmt.Sprintf("%d", len(docs))) c.Header("x-ms-item-count", fmt.Sprintf("%d", len(docs)))
c.IndentedJSON(http.StatusOK, gin.H{ c.IndentedJSON(http.StatusOK, gin.H{
"_rid": collection.ResourceID, "_rid": collection.ResourceID,
@ -226,10 +225,10 @@ func DocumentsPost(c *gin.Context) {
isUpsert, _ := strconv.ParseBool(c.GetHeader("x-ms-documentdb-is-upsert")) isUpsert, _ := strconv.ParseBool(c.GetHeader("x-ms-documentdb-is-upsert"))
if isUpsert { if isUpsert {
repositories.DeleteDocument(databaseId, collectionId, requestBody["id"].(string)) h.repository.DeleteDocument(databaseId, collectionId, requestBody["id"].(string))
} }
createdDocument, status := repositories.CreateDocument(databaseId, collectionId, requestBody) createdDocument, status := h.repository.CreateDocument(databaseId, collectionId, requestBody)
if status == repositorymodels.Conflict { if status == repositorymodels.Conflict {
c.IndentedJSON(http.StatusConflict, gin.H{"message": "Conflict"}) c.IndentedJSON(http.StatusConflict, gin.H{"message": "Conflict"})
return return

View File

@ -4,15 +4,14 @@ import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pikami/cosmium/api/config"
) )
func RegisterExplorerHandlers(router *gin.Engine) { func (h *Handlers) RegisterExplorerHandlers(router *gin.Engine) {
explorer := router.Group(config.Config.ExplorerBaseUrlLocation) explorer := router.Group(h.config.ExplorerBaseUrlLocation)
{ {
explorer.Use(func(ctx *gin.Context) { explorer.Use(func(ctx *gin.Context) {
if ctx.Param("filepath") == "/config.json" { if ctx.Param("filepath") == "/config.json" {
endpoint := fmt.Sprintf("https://%s:%d", config.Config.Host, config.Config.Port) endpoint := fmt.Sprintf("https://%s:%d", h.config.Host, h.config.Port)
ctx.JSON(200, gin.H{ ctx.JSON(200, gin.H{
"BACKEND_ENDPOINT": endpoint, "BACKEND_ENDPOINT": endpoint,
"MONGO_BACKEND_ENDPOINT": endpoint, "MONGO_BACKEND_ENDPOINT": endpoint,
@ -25,8 +24,8 @@ func RegisterExplorerHandlers(router *gin.Engine) {
} }
}) })
if config.Config.ExplorerPath != "" { if h.config.ExplorerPath != "" {
explorer.Static("/", config.Config.ExplorerPath) explorer.Static("/", h.config.ExplorerPath)
} }
} }
} }

18
api/handlers/handlers.go Normal file
View File

@ -0,0 +1,18 @@
package handlers
import (
"github.com/pikami/cosmium/api/config"
"github.com/pikami/cosmium/internal/repositories"
)
type Handlers struct {
repository *repositories.DataRepository
config config.ServerConfig
}
func NewHandlers(dataRepository *repositories.DataRepository, config config.ServerConfig) *Handlers {
return &Handlers{
repository: dataRepository,
config: config,
}
}

View File

@ -10,11 +10,11 @@ import (
"github.com/pikami/cosmium/internal/logger" "github.com/pikami/cosmium/internal/logger"
) )
func Authentication() gin.HandlerFunc { func Authentication(config config.ServerConfig) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
requestUrl := c.Request.URL.String() requestUrl := c.Request.URL.String()
if config.Config.DisableAuth || if config.DisableAuth ||
strings.HasPrefix(requestUrl, config.Config.ExplorerBaseUrlLocation) || strings.HasPrefix(requestUrl, config.ExplorerBaseUrlLocation) ||
strings.HasPrefix(requestUrl, "/cosmium") { strings.HasPrefix(requestUrl, "/cosmium") {
return return
} }
@ -25,7 +25,7 @@ func Authentication() gin.HandlerFunc {
authHeader := c.Request.Header.Get("authorization") authHeader := c.Request.Header.Get("authorization")
date := c.Request.Header.Get("x-ms-date") date := c.Request.Header.Get("x-ms-date")
expectedSignature := authentication.GenerateSignature( expectedSignature := authentication.GenerateSignature(
c.Request.Method, resourceType, resourceId, date, config.Config.AccountKey) c.Request.Method, resourceType, resourceId, date, config.AccountKey)
decoded, _ := url.QueryUnescape(authHeader) decoded, _ := url.QueryUnescape(authHeader)
params, _ := url.ParseQuery(decoded) params, _ := url.ParseQuery(decoded)

View File

@ -7,10 +7,10 @@ import (
"github.com/pikami/cosmium/api/config" "github.com/pikami/cosmium/api/config"
) )
func StripTrailingSlashes(r *gin.Engine) gin.HandlerFunc { func StripTrailingSlashes(r *gin.Engine, config config.ServerConfig) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
path := c.Request.URL.Path path := c.Request.URL.Path
if len(path) > 1 && path[len(path)-1] == '/' && !strings.Contains(path, config.Config.ExplorerBaseUrlLocation) { if len(path) > 1 && path[len(path)-1] == '/' && !strings.Contains(path, config.ExplorerBaseUrlLocation) {
c.Request.URL.Path = path[:len(path)-1] c.Request.URL.Path = path[:len(path)-1]
r.HandleContext(c) r.HandleContext(c)
c.Abort() c.Abort()

View File

@ -5,11 +5,10 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pikami/cosmium/internal/repositories"
repositorymodels "github.com/pikami/cosmium/internal/repository_models" repositorymodels "github.com/pikami/cosmium/internal/repository_models"
) )
func GetPartitionKeyRanges(c *gin.Context) { func (h *Handlers) GetPartitionKeyRanges(c *gin.Context) {
databaseId := c.Param("databaseId") databaseId := c.Param("databaseId")
collectionId := c.Param("collId") collectionId := c.Param("collId")
@ -18,7 +17,7 @@ func GetPartitionKeyRanges(c *gin.Context) {
return return
} }
partitionKeyRanges, status := repositories.GetPartitionKeyRanges(databaseId, collectionId) partitionKeyRanges, status := h.repository.GetPartitionKeyRanges(databaseId, collectionId)
if status == repositorymodels.StatusOk { if status == repositorymodels.StatusOk {
c.Header("etag", "\"420\"") c.Header("etag", "\"420\"")
c.Header("lsn", "420") c.Header("lsn", "420")
@ -27,7 +26,7 @@ func GetPartitionKeyRanges(c *gin.Context) {
c.Header("x-ms-item-count", fmt.Sprintf("%d", len(partitionKeyRanges))) c.Header("x-ms-item-count", fmt.Sprintf("%d", len(partitionKeyRanges)))
collectionRid := collectionId collectionRid := collectionId
collection, _ := repositories.GetCollection(databaseId, collectionId) collection, _ := h.repository.GetCollection(databaseId, collectionId)
if collection.ResourceID != "" { if collection.ResourceID != "" {
collectionRid = collection.ResourceID collectionRid = collection.ResourceID
} }

View File

@ -5,27 +5,26 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pikami/cosmium/api/config"
) )
func GetServerInfo(c *gin.Context) { func (h *Handlers) GetServerInfo(c *gin.Context) {
c.IndentedJSON(http.StatusOK, gin.H{ c.IndentedJSON(http.StatusOK, gin.H{
"_self": "", "_self": "",
"id": config.Config.DatabaseAccount, "id": h.config.DatabaseAccount,
"_rid": fmt.Sprintf("%s.%s", config.Config.DatabaseAccount, config.Config.DatabaseDomain), "_rid": fmt.Sprintf("%s.%s", h.config.DatabaseAccount, h.config.DatabaseDomain),
"media": "//media/", "media": "//media/",
"addresses": "//addresses/", "addresses": "//addresses/",
"_dbs": "//dbs/", "_dbs": "//dbs/",
"writableLocations": []map[string]interface{}{ "writableLocations": []map[string]interface{}{
{ {
"name": "South Central US", "name": "South Central US",
"databaseAccountEndpoint": config.Config.DatabaseEndpoint, "databaseAccountEndpoint": h.config.DatabaseEndpoint,
}, },
}, },
"readableLocations": []map[string]interface{}{ "readableLocations": []map[string]interface{}{
{ {
"name": "South Central US", "name": "South Central US",
"databaseAccountEndpoint": config.Config.DatabaseEndpoint, "databaseAccountEndpoint": h.config.DatabaseEndpoint,
}, },
}, },
"enableMultipleWriteLocations": false, "enableMultipleWriteLocations": false,

View File

@ -5,15 +5,14 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pikami/cosmium/internal/repositories"
repositorymodels "github.com/pikami/cosmium/internal/repository_models" repositorymodels "github.com/pikami/cosmium/internal/repository_models"
) )
func GetAllStoredProcedures(c *gin.Context) { func (h *Handlers) GetAllStoredProcedures(c *gin.Context) {
databaseId := c.Param("databaseId") databaseId := c.Param("databaseId")
collectionId := c.Param("collId") collectionId := c.Param("collId")
sps, status := repositories.GetAllStoredProcedures(databaseId, collectionId) sps, status := h.repository.GetAllStoredProcedures(databaseId, collectionId)
if status == repositorymodels.StatusOk { if status == repositorymodels.StatusOk {
c.Header("x-ms-item-count", fmt.Sprintf("%d", len(sps))) c.Header("x-ms-item-count", fmt.Sprintf("%d", len(sps)))

View File

@ -5,15 +5,14 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pikami/cosmium/internal/repositories"
repositorymodels "github.com/pikami/cosmium/internal/repository_models" repositorymodels "github.com/pikami/cosmium/internal/repository_models"
) )
func GetAllTriggers(c *gin.Context) { func (h *Handlers) GetAllTriggers(c *gin.Context) {
databaseId := c.Param("databaseId") databaseId := c.Param("databaseId")
collectionId := c.Param("collId") collectionId := c.Param("collId")
triggers, status := repositories.GetAllTriggers(databaseId, collectionId) triggers, status := h.repository.GetAllTriggers(databaseId, collectionId)
if status == repositorymodels.StatusOk { if status == repositorymodels.StatusOk {
c.Header("x-ms-item-count", fmt.Sprintf("%d", len(triggers))) c.Header("x-ms-item-count", fmt.Sprintf("%d", len(triggers)))

View File

@ -5,15 +5,14 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pikami/cosmium/internal/repositories"
repositorymodels "github.com/pikami/cosmium/internal/repository_models" repositorymodels "github.com/pikami/cosmium/internal/repository_models"
) )
func GetAllUserDefinedFunctions(c *gin.Context) { func (h *Handlers) GetAllUserDefinedFunctions(c *gin.Context) {
databaseId := c.Param("databaseId") databaseId := c.Param("databaseId")
collectionId := c.Param("collId") collectionId := c.Param("collId")
udfs, status := repositories.GetAllUserDefinedFunctions(databaseId, collectionId) udfs, status := h.repository.GetAllUserDefinedFunctions(databaseId, collectionId)
if status == repositorymodels.StatusOk { if status == repositorymodels.StatusOk {
c.Header("x-ms-item-count", fmt.Sprintf("%d", len(udfs))) c.Header("x-ms-item-count", fmt.Sprintf("%d", len(udfs)))

View File

@ -6,78 +6,75 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pikami/cosmium/api/config"
"github.com/pikami/cosmium/api/handlers" "github.com/pikami/cosmium/api/handlers"
"github.com/pikami/cosmium/api/handlers/middleware" "github.com/pikami/cosmium/api/handlers/middleware"
"github.com/pikami/cosmium/internal/logger" "github.com/pikami/cosmium/internal/logger"
"github.com/pikami/cosmium/internal/repositories"
tlsprovider "github.com/pikami/cosmium/internal/tls_provider" tlsprovider "github.com/pikami/cosmium/internal/tls_provider"
) )
type Server struct { func (s *ApiServer) CreateRouter(repository *repositories.DataRepository) {
StopServer chan interface{} routeHandlers := handlers.NewHandlers(repository, s.config)
}
func CreateRouter() *gin.Engine {
router := gin.Default(func(e *gin.Engine) { router := gin.Default(func(e *gin.Engine) {
e.RedirectTrailingSlash = false e.RedirectTrailingSlash = false
}) })
if config.Config.Debug { if s.config.Debug {
router.Use(middleware.RequestLogger()) router.Use(middleware.RequestLogger())
} }
router.Use(middleware.StripTrailingSlashes(router)) router.Use(middleware.StripTrailingSlashes(router, s.config))
router.Use(middleware.Authentication()) router.Use(middleware.Authentication(s.config))
router.GET("/dbs/:databaseId/colls/:collId/pkranges", handlers.GetPartitionKeyRanges) router.GET("/dbs/:databaseId/colls/:collId/pkranges", routeHandlers.GetPartitionKeyRanges)
router.POST("/dbs/:databaseId/colls/:collId/docs", handlers.DocumentsPost) router.POST("/dbs/:databaseId/colls/:collId/docs", routeHandlers.DocumentsPost)
router.GET("/dbs/:databaseId/colls/:collId/docs", handlers.GetAllDocuments) router.GET("/dbs/:databaseId/colls/:collId/docs", routeHandlers.GetAllDocuments)
router.GET("/dbs/:databaseId/colls/:collId/docs/:docId", handlers.GetDocument) router.GET("/dbs/:databaseId/colls/:collId/docs/:docId", routeHandlers.GetDocument)
router.PUT("/dbs/:databaseId/colls/:collId/docs/:docId", handlers.ReplaceDocument) router.PUT("/dbs/:databaseId/colls/:collId/docs/:docId", routeHandlers.ReplaceDocument)
router.PATCH("/dbs/:databaseId/colls/:collId/docs/:docId", handlers.PatchDocument) router.PATCH("/dbs/:databaseId/colls/:collId/docs/:docId", routeHandlers.PatchDocument)
router.DELETE("/dbs/:databaseId/colls/:collId/docs/:docId", handlers.DeleteDocument) router.DELETE("/dbs/:databaseId/colls/:collId/docs/:docId", routeHandlers.DeleteDocument)
router.POST("/dbs/:databaseId/colls", handlers.CreateCollection) router.POST("/dbs/:databaseId/colls", routeHandlers.CreateCollection)
router.GET("/dbs/:databaseId/colls", handlers.GetAllCollections) router.GET("/dbs/:databaseId/colls", routeHandlers.GetAllCollections)
router.GET("/dbs/:databaseId/colls/:collId", handlers.GetCollection) router.GET("/dbs/:databaseId/colls/:collId", routeHandlers.GetCollection)
router.DELETE("/dbs/:databaseId/colls/:collId", handlers.DeleteCollection) router.DELETE("/dbs/:databaseId/colls/:collId", routeHandlers.DeleteCollection)
router.POST("/dbs", handlers.CreateDatabase) router.POST("/dbs", routeHandlers.CreateDatabase)
router.GET("/dbs", handlers.GetAllDatabases) router.GET("/dbs", routeHandlers.GetAllDatabases)
router.GET("/dbs/:databaseId", handlers.GetDatabase) router.GET("/dbs/:databaseId", routeHandlers.GetDatabase)
router.DELETE("/dbs/:databaseId", handlers.DeleteDatabase) router.DELETE("/dbs/:databaseId", routeHandlers.DeleteDatabase)
router.GET("/dbs/:databaseId/colls/:collId/udfs", handlers.GetAllUserDefinedFunctions) router.GET("/dbs/:databaseId/colls/:collId/udfs", routeHandlers.GetAllUserDefinedFunctions)
router.GET("/dbs/:databaseId/colls/:collId/sprocs", handlers.GetAllStoredProcedures) router.GET("/dbs/:databaseId/colls/:collId/sprocs", routeHandlers.GetAllStoredProcedures)
router.GET("/dbs/:databaseId/colls/:collId/triggers", handlers.GetAllTriggers) router.GET("/dbs/:databaseId/colls/:collId/triggers", routeHandlers.GetAllTriggers)
router.GET("/offers", handlers.GetOffers) router.GET("/offers", handlers.GetOffers)
router.GET("/", handlers.GetServerInfo) router.GET("/", routeHandlers.GetServerInfo)
router.GET("/cosmium/export", handlers.CosmiumExport) router.GET("/cosmium/export", routeHandlers.CosmiumExport)
handlers.RegisterExplorerHandlers(router) routeHandlers.RegisterExplorerHandlers(router)
return router s.router = router
} }
func StartAPI() *Server { func (s *ApiServer) Start() {
if !config.Config.Debug { if !s.config.Debug {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
} }
router := CreateRouter() listenAddress := fmt.Sprintf(":%d", s.config.Port)
listenAddress := fmt.Sprintf(":%d", config.Config.Port) s.isActive = true
stopChan := make(chan interface{})
server := &http.Server{ server := &http.Server{
Addr: listenAddress, Addr: listenAddress,
Handler: router.Handler(), Handler: s.router.Handler(),
} }
go func() { go func() {
<-stopChan <-s.stopServer
logger.Info("Shutting down server...") logger.Info("Shutting down server...")
err := server.Shutdown(context.TODO()) err := server.Shutdown(context.TODO())
if err != nil { if err != nil {
@ -86,24 +83,22 @@ func StartAPI() *Server {
}() }()
go func() { go func() {
if config.Config.DisableTls { if s.config.DisableTls {
logger.Infof("Listening and serving HTTP on %s\n", server.Addr) logger.Infof("Listening and serving HTTP on %s\n", server.Addr)
err := server.ListenAndServe() err := server.ListenAndServe()
if err != nil { if err != nil {
logger.Error("Failed to start HTTP server:", err) logger.Error("Failed to start HTTP server:", err)
} }
return s.isActive = false
} } else if s.config.TLS_CertificatePath != "" && s.config.TLS_CertificateKey != "" {
if config.Config.TLS_CertificatePath != "" && config.Config.TLS_CertificateKey != "" {
logger.Infof("Listening and serving HTTPS on %s\n", server.Addr) logger.Infof("Listening and serving HTTPS on %s\n", server.Addr)
err := server.ListenAndServeTLS( err := server.ListenAndServeTLS(
config.Config.TLS_CertificatePath, s.config.TLS_CertificatePath,
config.Config.TLS_CertificateKey) s.config.TLS_CertificateKey)
if err != nil { if err != nil {
logger.Error("Failed to start HTTPS server:", err) logger.Error("Failed to start HTTPS server:", err)
} }
return s.isActive = false
} else { } else {
tlsConfig := tlsprovider.GetDefaultTlsConfig() tlsConfig := tlsprovider.GetDefaultTlsConfig()
server.TLSConfig = tlsConfig server.TLSConfig = tlsConfig
@ -113,9 +108,7 @@ func StartAPI() *Server {
if err != nil { if err != nil {
logger.Error("Failed to start HTTPS server:", err) logger.Error("Failed to start HTTPS server:", err)
} }
return s.isActive = false
} }
}() }()
return &Server{StopServer: stopChan}
} }

View File

@ -11,16 +11,15 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"
"github.com/pikami/cosmium/api/config" "github.com/pikami/cosmium/api/config"
"github.com/pikami/cosmium/internal/repositories"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func Test_Authentication(t *testing.T) { func Test_Authentication(t *testing.T) {
ts := runTestServer() ts := runTestServer()
defer ts.Close() defer ts.Server.Close()
t.Run("Should get 200 when correct account key is used", func(t *testing.T) { t.Run("Should get 200 when correct account key is used", func(t *testing.T) {
repositories.DeleteDatabase(testDatabaseName) ts.Repository.DeleteDatabase(testDatabaseName)
client, err := azcosmos.NewClientFromConnectionString( client, err := azcosmos.NewClientFromConnectionString(
fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, config.DefaultAccountKey), fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, config.DefaultAccountKey),
&azcosmos.ClientOptions{}, &azcosmos.ClientOptions{},
@ -35,26 +34,8 @@ func Test_Authentication(t *testing.T) {
assert.Equal(t, createResponse.DatabaseProperties.ID, testDatabaseName) assert.Equal(t, createResponse.DatabaseProperties.ID, testDatabaseName)
}) })
t.Run("Should get 200 when wrong account key is used, but authentication is dissabled", func(t *testing.T) {
config.Config.DisableAuth = true
repositories.DeleteDatabase(testDatabaseName)
client, err := azcosmos.NewClientFromConnectionString(
fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, "AAAA"),
&azcosmos.ClientOptions{},
)
assert.Nil(t, err)
createResponse, err := client.CreateDatabase(
context.TODO(),
azcosmos.DatabaseProperties{ID: testDatabaseName},
&azcosmos.CreateDatabaseOptions{})
assert.Nil(t, err)
assert.Equal(t, createResponse.DatabaseProperties.ID, testDatabaseName)
config.Config.DisableAuth = false
})
t.Run("Should get 401 when wrong account key is used", func(t *testing.T) { t.Run("Should get 401 when wrong account key is used", func(t *testing.T) {
repositories.DeleteDatabase(testDatabaseName) ts.Repository.DeleteDatabase(testDatabaseName)
client, err := azcosmos.NewClientFromConnectionString( client, err := azcosmos.NewClientFromConnectionString(
fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, "AAAA"), fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, "AAAA"),
&azcosmos.ClientOptions{}, &azcosmos.ClientOptions{},
@ -85,3 +66,29 @@ func Test_Authentication(t *testing.T) {
assert.Contains(t, string(responseBody), "BACKEND_ENDPOINT") assert.Contains(t, string(responseBody), "BACKEND_ENDPOINT")
}) })
} }
func Test_Authentication_Disabled(t *testing.T) {
ts := runTestServerCustomConfig(config.ServerConfig{
AccountKey: config.DefaultAccountKey,
ExplorerPath: "/tmp/nothing",
ExplorerBaseUrlLocation: config.ExplorerBaseUrlLocation,
DisableAuth: true,
})
defer ts.Server.Close()
t.Run("Should get 200 when wrong account key is used, but authentication is dissabled", func(t *testing.T) {
ts.Repository.DeleteDatabase(testDatabaseName)
client, err := azcosmos.NewClientFromConnectionString(
fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, "AAAA"),
&azcosmos.ClientOptions{},
)
assert.Nil(t, err)
createResponse, err := client.CreateDatabase(
context.TODO(),
azcosmos.DatabaseProperties{ID: testDatabaseName},
&azcosmos.CreateDatabaseOptions{})
assert.Nil(t, err)
assert.Equal(t, createResponse.DatabaseProperties.ID, testDatabaseName)
})
}

View File

@ -10,22 +10,21 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"
"github.com/pikami/cosmium/api/config" "github.com/pikami/cosmium/api/config"
"github.com/pikami/cosmium/internal/repositories"
repositorymodels "github.com/pikami/cosmium/internal/repository_models" repositorymodels "github.com/pikami/cosmium/internal/repository_models"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func Test_Collections(t *testing.T) { func Test_Collections(t *testing.T) {
ts := runTestServer() ts := runTestServer()
defer ts.Close() defer ts.Server.Close()
client, err := azcosmos.NewClientFromConnectionString( client, err := azcosmos.NewClientFromConnectionString(
fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, config.Config.AccountKey), fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, config.DefaultAccountKey),
&azcosmos.ClientOptions{}, &azcosmos.ClientOptions{},
) )
assert.Nil(t, err) assert.Nil(t, err)
repositories.CreateDatabase(repositorymodels.Database{ID: testDatabaseName}) ts.Repository.CreateDatabase(repositorymodels.Database{ID: testDatabaseName})
databaseClient, err := client.NewDatabase(testDatabaseName) databaseClient, err := client.NewDatabase(testDatabaseName)
assert.Nil(t, err) assert.Nil(t, err)
@ -40,7 +39,7 @@ func Test_Collections(t *testing.T) {
}) })
t.Run("Should return conflict when collection exists", func(t *testing.T) { t.Run("Should return conflict when collection exists", func(t *testing.T) {
repositories.CreateCollection(testDatabaseName, repositorymodels.Collection{ ts.Repository.CreateCollection(testDatabaseName, repositorymodels.Collection{
ID: testCollectionName, ID: testCollectionName,
}) })
@ -60,7 +59,7 @@ func Test_Collections(t *testing.T) {
t.Run("Collection Read", func(t *testing.T) { t.Run("Collection Read", func(t *testing.T) {
t.Run("Should read collection", func(t *testing.T) { t.Run("Should read collection", func(t *testing.T) {
repositories.CreateCollection(testDatabaseName, repositorymodels.Collection{ ts.Repository.CreateCollection(testDatabaseName, repositorymodels.Collection{
ID: testCollectionName, ID: testCollectionName,
}) })
@ -74,7 +73,7 @@ func Test_Collections(t *testing.T) {
}) })
t.Run("Should return not found when collection does not exist", func(t *testing.T) { t.Run("Should return not found when collection does not exist", func(t *testing.T) {
repositories.DeleteCollection(testDatabaseName, testCollectionName) ts.Repository.DeleteCollection(testDatabaseName, testCollectionName)
collectionResponse, err := databaseClient.NewContainer(testCollectionName) collectionResponse, err := databaseClient.NewContainer(testCollectionName)
assert.Nil(t, err) assert.Nil(t, err)
@ -93,7 +92,7 @@ func Test_Collections(t *testing.T) {
t.Run("Collection Delete", func(t *testing.T) { t.Run("Collection Delete", func(t *testing.T) {
t.Run("Should delete collection", func(t *testing.T) { t.Run("Should delete collection", func(t *testing.T) {
repositories.CreateCollection(testDatabaseName, repositorymodels.Collection{ ts.Repository.CreateCollection(testDatabaseName, repositorymodels.Collection{
ID: testCollectionName, ID: testCollectionName,
}) })
@ -106,7 +105,7 @@ func Test_Collections(t *testing.T) {
}) })
t.Run("Should return not found when collection does not exist", func(t *testing.T) { t.Run("Should return not found when collection does not exist", func(t *testing.T) {
repositories.DeleteCollection(testDatabaseName, testCollectionName) ts.Repository.DeleteCollection(testDatabaseName, testCollectionName)
collectionResponse, err := databaseClient.NewContainer(testCollectionName) collectionResponse, err := databaseClient.NewContainer(testCollectionName)
assert.Nil(t, err) assert.Nil(t, err)

View File

@ -5,14 +5,37 @@ import (
"github.com/pikami/cosmium/api" "github.com/pikami/cosmium/api"
"github.com/pikami/cosmium/api/config" "github.com/pikami/cosmium/api/config"
"github.com/pikami/cosmium/internal/repositories"
) )
func runTestServer() *httptest.Server { type TestServer struct {
config.Config.AccountKey = config.DefaultAccountKey Server *httptest.Server
config.Config.ExplorerPath = "/tmp/nothing" Repository *repositories.DataRepository
config.Config.ExplorerBaseUrlLocation = config.ExplorerBaseUrlLocation URL string
}
return httptest.NewServer(api.CreateRouter()) func runTestServerCustomConfig(config config.ServerConfig) *TestServer {
repository := repositories.NewDataRepository(repositories.RepositoryOptions{})
api := api.NewApiServer(repository, config)
server := httptest.NewServer(api.GetRouter())
return &TestServer{
Server: server,
Repository: repository,
URL: server.URL,
}
}
func runTestServer() *TestServer {
config := config.ServerConfig{
AccountKey: config.DefaultAccountKey,
ExplorerPath: "/tmp/nothing",
ExplorerBaseUrlLocation: config.ExplorerBaseUrlLocation,
}
return runTestServerCustomConfig(config)
} }
const ( const (

View File

@ -10,24 +10,23 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"
"github.com/pikami/cosmium/api/config" "github.com/pikami/cosmium/api/config"
"github.com/pikami/cosmium/internal/repositories"
repositorymodels "github.com/pikami/cosmium/internal/repository_models" repositorymodels "github.com/pikami/cosmium/internal/repository_models"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func Test_Databases(t *testing.T) { func Test_Databases(t *testing.T) {
ts := runTestServer() ts := runTestServer()
defer ts.Close() defer ts.Server.Close()
client, err := azcosmos.NewClientFromConnectionString( client, err := azcosmos.NewClientFromConnectionString(
fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, config.Config.AccountKey), fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, config.DefaultAccountKey),
&azcosmos.ClientOptions{}, &azcosmos.ClientOptions{},
) )
assert.Nil(t, err) assert.Nil(t, err)
t.Run("Database Create", func(t *testing.T) { t.Run("Database Create", func(t *testing.T) {
t.Run("Should create database", func(t *testing.T) { t.Run("Should create database", func(t *testing.T) {
repositories.DeleteDatabase(testDatabaseName) ts.Repository.DeleteDatabase(testDatabaseName)
createResponse, err := client.CreateDatabase(context.TODO(), azcosmos.DatabaseProperties{ createResponse, err := client.CreateDatabase(context.TODO(), azcosmos.DatabaseProperties{
ID: testDatabaseName, ID: testDatabaseName,
@ -38,7 +37,7 @@ func Test_Databases(t *testing.T) {
}) })
t.Run("Should return conflict when database exists", func(t *testing.T) { t.Run("Should return conflict when database exists", func(t *testing.T) {
repositories.CreateDatabase(repositorymodels.Database{ ts.Repository.CreateDatabase(repositorymodels.Database{
ID: testDatabaseName, ID: testDatabaseName,
}) })
@ -58,7 +57,7 @@ func Test_Databases(t *testing.T) {
t.Run("Database Read", func(t *testing.T) { t.Run("Database Read", func(t *testing.T) {
t.Run("Should read database", func(t *testing.T) { t.Run("Should read database", func(t *testing.T) {
repositories.CreateDatabase(repositorymodels.Database{ ts.Repository.CreateDatabase(repositorymodels.Database{
ID: testDatabaseName, ID: testDatabaseName,
}) })
@ -72,7 +71,7 @@ func Test_Databases(t *testing.T) {
}) })
t.Run("Should return not found when database does not exist", func(t *testing.T) { t.Run("Should return not found when database does not exist", func(t *testing.T) {
repositories.DeleteDatabase(testDatabaseName) ts.Repository.DeleteDatabase(testDatabaseName)
databaseResponse, err := client.NewDatabase(testDatabaseName) databaseResponse, err := client.NewDatabase(testDatabaseName)
assert.Nil(t, err) assert.Nil(t, err)
@ -91,7 +90,7 @@ func Test_Databases(t *testing.T) {
t.Run("Database Delete", func(t *testing.T) { t.Run("Database Delete", func(t *testing.T) {
t.Run("Should delete database", func(t *testing.T) { t.Run("Should delete database", func(t *testing.T) {
repositories.CreateDatabase(repositorymodels.Database{ ts.Repository.CreateDatabase(repositorymodels.Database{
ID: testDatabaseName, ID: testDatabaseName,
}) })
@ -104,7 +103,7 @@ func Test_Databases(t *testing.T) {
}) })
t.Run("Should return not found when database does not exist", func(t *testing.T) { t.Run("Should return not found when database does not exist", func(t *testing.T) {
repositories.DeleteDatabase(testDatabaseName) ts.Repository.DeleteDatabase(testDatabaseName)
databaseResponse, err := client.NewDatabase(testDatabaseName) databaseResponse, err := client.NewDatabase(testDatabaseName)
assert.Nil(t, err) assert.Nil(t, err)

View File

@ -6,7 +6,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest"
"reflect" "reflect"
"sync" "sync"
"testing" "testing"
@ -15,7 +14,6 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"
"github.com/pikami/cosmium/api/config" "github.com/pikami/cosmium/api/config"
"github.com/pikami/cosmium/internal/repositories"
repositorymodels "github.com/pikami/cosmium/internal/repository_models" repositorymodels "github.com/pikami/cosmium/internal/repository_models"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -55,9 +53,11 @@ func testCosmosQuery(t *testing.T,
} }
} }
func documents_InitializeDb(t *testing.T) (*httptest.Server, *azcosmos.ContainerClient) { func documents_InitializeDb(t *testing.T) (*TestServer, *azcosmos.ContainerClient) {
repositories.CreateDatabase(repositorymodels.Database{ID: testDatabaseName}) ts := runTestServer()
repositories.CreateCollection(testDatabaseName, repositorymodels.Collection{
ts.Repository.CreateDatabase(repositorymodels.Database{ID: testDatabaseName})
ts.Repository.CreateCollection(testDatabaseName, repositorymodels.Collection{
ID: testCollectionName, ID: testCollectionName,
PartitionKey: struct { PartitionKey: struct {
Paths []string "json:\"paths\"" Paths []string "json:\"paths\""
@ -67,13 +67,11 @@ func documents_InitializeDb(t *testing.T) (*httptest.Server, *azcosmos.Container
Paths: []string{"/pk"}, Paths: []string{"/pk"},
}, },
}) })
repositories.CreateDocument(testDatabaseName, testCollectionName, map[string]interface{}{"id": "12345", "pk": "123", "isCool": false, "arr": []int{1, 2, 3}}) ts.Repository.CreateDocument(testDatabaseName, testCollectionName, map[string]interface{}{"id": "12345", "pk": "123", "isCool": false, "arr": []int{1, 2, 3}})
repositories.CreateDocument(testDatabaseName, testCollectionName, map[string]interface{}{"id": "67890", "pk": "456", "isCool": true, "arr": []int{6, 7, 8}}) ts.Repository.CreateDocument(testDatabaseName, testCollectionName, map[string]interface{}{"id": "67890", "pk": "456", "isCool": true, "arr": []int{6, 7, 8}})
ts := runTestServer()
client, err := azcosmos.NewClientFromConnectionString( client, err := azcosmos.NewClientFromConnectionString(
fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, config.Config.AccountKey), fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, config.DefaultAccountKey),
&azcosmos.ClientOptions{}, &azcosmos.ClientOptions{},
) )
assert.Nil(t, err) assert.Nil(t, err)
@ -86,7 +84,7 @@ func documents_InitializeDb(t *testing.T) (*httptest.Server, *azcosmos.Container
func Test_Documents(t *testing.T) { func Test_Documents(t *testing.T) {
ts, collectionClient := documents_InitializeDb(t) ts, collectionClient := documents_InitializeDb(t)
defer ts.Close() defer ts.Server.Close()
t.Run("Should query document", func(t *testing.T) { t.Run("Should query document", func(t *testing.T) {
testCosmosQuery(t, collectionClient, testCosmosQuery(t, collectionClient,
@ -218,7 +216,7 @@ func Test_Documents(t *testing.T) {
func Test_Documents_Patch(t *testing.T) { func Test_Documents_Patch(t *testing.T) {
ts, collectionClient := documents_InitializeDb(t) ts, collectionClient := documents_InitializeDb(t)
defer ts.Close() defer ts.Server.Close()
t.Run("Should PATCH document", func(t *testing.T) { t.Run("Should PATCH document", func(t *testing.T) {
context := context.TODO() context := context.TODO()

View File

@ -15,14 +15,14 @@ import (
// Request document with trailing slash like python cosmosdb client does. // Request document with trailing slash like python cosmosdb client does.
func Test_Documents_Read_Trailing_Slash(t *testing.T) { func Test_Documents_Read_Trailing_Slash(t *testing.T) {
ts, _ := documents_InitializeDb(t) ts, _ := documents_InitializeDb(t)
defer ts.Close() defer ts.Server.Close()
t.Run("Read doc with client that appends slash to path", func(t *testing.T) { t.Run("Read doc with client that appends slash to path", func(t *testing.T) {
resourceIdTemplate := "dbs/%s/colls/%s/docs/%s" resourceIdTemplate := "dbs/%s/colls/%s/docs/%s"
path := fmt.Sprintf(resourceIdTemplate, testDatabaseName, testCollectionName, "12345") path := fmt.Sprintf(resourceIdTemplate, testDatabaseName, testCollectionName, "12345")
testUrl := ts.URL + "/" + path + "/" testUrl := ts.URL + "/" + path + "/"
date := time.Now().Format(time.RFC1123) date := time.Now().Format(time.RFC1123)
signature := authentication.GenerateSignature("GET", "docs", path, date, config.Config.AccountKey) signature := authentication.GenerateSignature("GET", "docs", path, date, config.DefaultAccountKey)
httpClient := &http.Client{} httpClient := &http.Client{}
req, _ := http.NewRequest("GET", testUrl, nil) req, _ := http.NewRequest("GET", testUrl, nil)
req.Header.Add("x-ms-date", date) req.Header.Add("x-ms-date", date)

View File

@ -11,16 +11,20 @@ import (
) )
func main() { func main() {
config.ParseFlags() configuration := config.ParseFlags()
repositories.InitializeRepository() repository := repositories.NewDataRepository(repositories.RepositoryOptions{
InitialDataFilePath: configuration.InitialDataFilePath,
PersistDataFilePath: configuration.PersistDataFilePath,
})
server := api.StartAPI() server := api.NewApiServer(repository, configuration)
server.Start()
waitForExit(server) waitForExit(server, repository, configuration)
} }
func waitForExit(server *api.Server) { func waitForExit(server *api.ApiServer, repository *repositories.DataRepository, config config.ServerConfig) {
sigs := make(chan os.Signal, 1) sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
@ -28,9 +32,9 @@ func waitForExit(server *api.Server) {
<-sigs <-sigs
// Stop the server // Stop the server
server.StopServer <- true server.Stop()
if config.Config.PersistDataFilePath != "" { if config.PersistDataFilePath != "" {
repositories.SaveStateFS(config.Config.PersistDataFilePath) repository.SaveStateFS(config.PersistDataFilePath)
} }
} }

View File

@ -3,22 +3,22 @@ package logger
import ( import (
"log" "log"
"os" "os"
"github.com/pikami/cosmium/api/config"
) )
var EnableDebugOutput = false
var DebugLogger = log.New(os.Stdout, "", log.Ldate|log.Ltime|log.Lshortfile) var DebugLogger = log.New(os.Stdout, "", log.Ldate|log.Ltime|log.Lshortfile)
var InfoLogger = log.New(os.Stdout, "", log.Ldate|log.Ltime) var InfoLogger = log.New(os.Stdout, "", log.Ldate|log.Ltime)
var ErrorLogger = log.New(os.Stderr, "", log.Ldate|log.Ltime|log.Lshortfile) var ErrorLogger = log.New(os.Stderr, "", log.Ldate|log.Ltime|log.Lshortfile)
func Debug(v ...any) { func Debug(v ...any) {
if config.Config.Debug { if EnableDebugOutput {
DebugLogger.Println(v...) DebugLogger.Println(v...)
} }
} }
func Debugf(format string, v ...any) { func Debugf(format string, v ...any) {
if config.Config.Debug { if EnableDebugOutput {
DebugLogger.Printf(format, v...) DebugLogger.Printf(format, v...)
} }
} }

View File

@ -11,60 +11,60 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
) )
func GetAllCollections(databaseId string) ([]repositorymodels.Collection, repositorymodels.RepositoryStatus) { func (r *DataRepository) GetAllCollections(databaseId string) ([]repositorymodels.Collection, repositorymodels.RepositoryStatus) {
storeState.RLock() r.storeState.RLock()
defer storeState.RUnlock() defer r.storeState.RUnlock()
if _, ok := storeState.Databases[databaseId]; !ok { if _, ok := r.storeState.Databases[databaseId]; !ok {
return make([]repositorymodels.Collection, 0), repositorymodels.StatusNotFound return make([]repositorymodels.Collection, 0), repositorymodels.StatusNotFound
} }
return maps.Values(storeState.Collections[databaseId]), repositorymodels.StatusOk return maps.Values(r.storeState.Collections[databaseId]), repositorymodels.StatusOk
} }
func GetCollection(databaseId string, collectionId string) (repositorymodels.Collection, repositorymodels.RepositoryStatus) { func (r *DataRepository) GetCollection(databaseId string, collectionId string) (repositorymodels.Collection, repositorymodels.RepositoryStatus) {
storeState.RLock() r.storeState.RLock()
defer storeState.RUnlock() defer r.storeState.RUnlock()
if _, ok := storeState.Databases[databaseId]; !ok { if _, ok := r.storeState.Databases[databaseId]; !ok {
return repositorymodels.Collection{}, repositorymodels.StatusNotFound return repositorymodels.Collection{}, repositorymodels.StatusNotFound
} }
if _, ok := storeState.Collections[databaseId][collectionId]; !ok { if _, ok := r.storeState.Collections[databaseId][collectionId]; !ok {
return repositorymodels.Collection{}, repositorymodels.StatusNotFound return repositorymodels.Collection{}, repositorymodels.StatusNotFound
} }
return storeState.Collections[databaseId][collectionId], repositorymodels.StatusOk return r.storeState.Collections[databaseId][collectionId], repositorymodels.StatusOk
} }
func DeleteCollection(databaseId string, collectionId string) repositorymodels.RepositoryStatus { func (r *DataRepository) DeleteCollection(databaseId string, collectionId string) repositorymodels.RepositoryStatus {
storeState.Lock() r.storeState.Lock()
defer storeState.Unlock() defer r.storeState.Unlock()
if _, ok := storeState.Databases[databaseId]; !ok { if _, ok := r.storeState.Databases[databaseId]; !ok {
return repositorymodels.StatusNotFound return repositorymodels.StatusNotFound
} }
if _, ok := storeState.Collections[databaseId][collectionId]; !ok { if _, ok := r.storeState.Collections[databaseId][collectionId]; !ok {
return repositorymodels.StatusNotFound return repositorymodels.StatusNotFound
} }
delete(storeState.Collections[databaseId], collectionId) delete(r.storeState.Collections[databaseId], collectionId)
return repositorymodels.StatusOk return repositorymodels.StatusOk
} }
func CreateCollection(databaseId string, newCollection repositorymodels.Collection) (repositorymodels.Collection, repositorymodels.RepositoryStatus) { func (r *DataRepository) CreateCollection(databaseId string, newCollection repositorymodels.Collection) (repositorymodels.Collection, repositorymodels.RepositoryStatus) {
storeState.Lock() r.storeState.Lock()
defer storeState.Unlock() defer r.storeState.Unlock()
var ok bool var ok bool
var database repositorymodels.Database var database repositorymodels.Database
if database, ok = storeState.Databases[databaseId]; !ok { if database, ok = r.storeState.Databases[databaseId]; !ok {
return repositorymodels.Collection{}, repositorymodels.StatusNotFound return repositorymodels.Collection{}, repositorymodels.StatusNotFound
} }
if _, ok = storeState.Collections[databaseId][newCollection.ID]; ok { if _, ok = r.storeState.Collections[databaseId][newCollection.ID]; ok {
return repositorymodels.Collection{}, repositorymodels.Conflict return repositorymodels.Collection{}, repositorymodels.Conflict
} }
@ -75,8 +75,8 @@ func CreateCollection(databaseId string, newCollection repositorymodels.Collecti
newCollection.ETag = fmt.Sprintf("\"%s\"", uuid.New()) newCollection.ETag = fmt.Sprintf("\"%s\"", uuid.New())
newCollection.Self = fmt.Sprintf("dbs/%s/colls/%s/", database.ResourceID, newCollection.ResourceID) newCollection.Self = fmt.Sprintf("dbs/%s/colls/%s/", database.ResourceID, newCollection.ResourceID)
storeState.Collections[databaseId][newCollection.ID] = newCollection r.storeState.Collections[databaseId][newCollection.ID] = newCollection
storeState.Documents[databaseId][newCollection.ID] = make(map[string]repositorymodels.Document) r.storeState.Documents[databaseId][newCollection.ID] = make(map[string]repositorymodels.Document)
return newCollection, repositorymodels.StatusOk return newCollection, repositorymodels.StatusOk
} }

View File

@ -10,42 +10,42 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
) )
func GetAllDatabases() ([]repositorymodels.Database, repositorymodels.RepositoryStatus) { func (r *DataRepository) GetAllDatabases() ([]repositorymodels.Database, repositorymodels.RepositoryStatus) {
storeState.RLock() r.storeState.RLock()
defer storeState.RUnlock() defer r.storeState.RUnlock()
return maps.Values(storeState.Databases), repositorymodels.StatusOk return maps.Values(r.storeState.Databases), repositorymodels.StatusOk
} }
func GetDatabase(id string) (repositorymodels.Database, repositorymodels.RepositoryStatus) { func (r *DataRepository) GetDatabase(id string) (repositorymodels.Database, repositorymodels.RepositoryStatus) {
storeState.RLock() r.storeState.RLock()
defer storeState.RUnlock() defer r.storeState.RUnlock()
if database, ok := storeState.Databases[id]; ok { if database, ok := r.storeState.Databases[id]; ok {
return database, repositorymodels.StatusOk return database, repositorymodels.StatusOk
} }
return repositorymodels.Database{}, repositorymodels.StatusNotFound return repositorymodels.Database{}, repositorymodels.StatusNotFound
} }
func DeleteDatabase(id string) repositorymodels.RepositoryStatus { func (r *DataRepository) DeleteDatabase(id string) repositorymodels.RepositoryStatus {
storeState.Lock() r.storeState.Lock()
defer storeState.Unlock() defer r.storeState.Unlock()
if _, ok := storeState.Databases[id]; !ok { if _, ok := r.storeState.Databases[id]; !ok {
return repositorymodels.StatusNotFound return repositorymodels.StatusNotFound
} }
delete(storeState.Databases, id) delete(r.storeState.Databases, id)
return repositorymodels.StatusOk return repositorymodels.StatusOk
} }
func CreateDatabase(newDatabase repositorymodels.Database) (repositorymodels.Database, repositorymodels.RepositoryStatus) { func (r *DataRepository) CreateDatabase(newDatabase repositorymodels.Database) (repositorymodels.Database, repositorymodels.RepositoryStatus) {
storeState.Lock() r.storeState.Lock()
defer storeState.Unlock() defer r.storeState.Unlock()
if _, ok := storeState.Databases[newDatabase.ID]; ok { if _, ok := r.storeState.Databases[newDatabase.ID]; ok {
return repositorymodels.Database{}, repositorymodels.Conflict return repositorymodels.Database{}, repositorymodels.Conflict
} }
@ -54,9 +54,9 @@ func CreateDatabase(newDatabase repositorymodels.Database) (repositorymodels.Dat
newDatabase.ETag = fmt.Sprintf("\"%s\"", uuid.New()) newDatabase.ETag = fmt.Sprintf("\"%s\"", uuid.New())
newDatabase.Self = fmt.Sprintf("dbs/%s/", newDatabase.ResourceID) newDatabase.Self = fmt.Sprintf("dbs/%s/", newDatabase.ResourceID)
storeState.Databases[newDatabase.ID] = newDatabase r.storeState.Databases[newDatabase.ID] = newDatabase
storeState.Collections[newDatabase.ID] = make(map[string]repositorymodels.Collection) r.storeState.Collections[newDatabase.ID] = make(map[string]repositorymodels.Collection)
storeState.Documents[newDatabase.ID] = make(map[string]map[string]repositorymodels.Document) r.storeState.Documents[newDatabase.ID] = make(map[string]map[string]repositorymodels.Document)
return newDatabase, repositorymodels.StatusOk return newDatabase, repositorymodels.StatusOk
} }

View File

@ -14,64 +14,64 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
) )
func GetAllDocuments(databaseId string, collectionId string) ([]repositorymodels.Document, repositorymodels.RepositoryStatus) { func (r *DataRepository) GetAllDocuments(databaseId string, collectionId string) ([]repositorymodels.Document, repositorymodels.RepositoryStatus) {
storeState.RLock() r.storeState.RLock()
defer storeState.RUnlock() defer r.storeState.RUnlock()
if _, ok := storeState.Databases[databaseId]; !ok { if _, ok := r.storeState.Databases[databaseId]; !ok {
return make([]repositorymodels.Document, 0), repositorymodels.StatusNotFound return make([]repositorymodels.Document, 0), repositorymodels.StatusNotFound
} }
if _, ok := storeState.Collections[databaseId][collectionId]; !ok { if _, ok := r.storeState.Collections[databaseId][collectionId]; !ok {
return make([]repositorymodels.Document, 0), repositorymodels.StatusNotFound return make([]repositorymodels.Document, 0), repositorymodels.StatusNotFound
} }
return maps.Values(storeState.Documents[databaseId][collectionId]), repositorymodels.StatusOk return maps.Values(r.storeState.Documents[databaseId][collectionId]), repositorymodels.StatusOk
} }
func GetDocument(databaseId string, collectionId string, documentId string) (repositorymodels.Document, repositorymodels.RepositoryStatus) { func (r *DataRepository) GetDocument(databaseId string, collectionId string, documentId string) (repositorymodels.Document, repositorymodels.RepositoryStatus) {
storeState.RLock() r.storeState.RLock()
defer storeState.RUnlock() defer r.storeState.RUnlock()
if _, ok := storeState.Databases[databaseId]; !ok { if _, ok := r.storeState.Databases[databaseId]; !ok {
return repositorymodels.Document{}, repositorymodels.StatusNotFound return repositorymodels.Document{}, repositorymodels.StatusNotFound
} }
if _, ok := storeState.Collections[databaseId][collectionId]; !ok { if _, ok := r.storeState.Collections[databaseId][collectionId]; !ok {
return repositorymodels.Document{}, repositorymodels.StatusNotFound return repositorymodels.Document{}, repositorymodels.StatusNotFound
} }
if _, ok := storeState.Documents[databaseId][collectionId][documentId]; !ok { if _, ok := r.storeState.Documents[databaseId][collectionId][documentId]; !ok {
return repositorymodels.Document{}, repositorymodels.StatusNotFound return repositorymodels.Document{}, repositorymodels.StatusNotFound
} }
return storeState.Documents[databaseId][collectionId][documentId], repositorymodels.StatusOk return r.storeState.Documents[databaseId][collectionId][documentId], repositorymodels.StatusOk
} }
func DeleteDocument(databaseId string, collectionId string, documentId string) repositorymodels.RepositoryStatus { func (r *DataRepository) DeleteDocument(databaseId string, collectionId string, documentId string) repositorymodels.RepositoryStatus {
storeState.Lock() r.storeState.Lock()
defer storeState.Unlock() defer r.storeState.Unlock()
if _, ok := storeState.Databases[databaseId]; !ok { if _, ok := r.storeState.Databases[databaseId]; !ok {
return repositorymodels.StatusNotFound return repositorymodels.StatusNotFound
} }
if _, ok := storeState.Collections[databaseId][collectionId]; !ok { if _, ok := r.storeState.Collections[databaseId][collectionId]; !ok {
return repositorymodels.StatusNotFound return repositorymodels.StatusNotFound
} }
if _, ok := storeState.Documents[databaseId][collectionId][documentId]; !ok { if _, ok := r.storeState.Documents[databaseId][collectionId][documentId]; !ok {
return repositorymodels.StatusNotFound return repositorymodels.StatusNotFound
} }
delete(storeState.Documents[databaseId][collectionId], documentId) delete(r.storeState.Documents[databaseId][collectionId], documentId)
return repositorymodels.StatusOk return repositorymodels.StatusOk
} }
func CreateDocument(databaseId string, collectionId string, document map[string]interface{}) (repositorymodels.Document, repositorymodels.RepositoryStatus) { func (r *DataRepository) CreateDocument(databaseId string, collectionId string, document map[string]interface{}) (repositorymodels.Document, repositorymodels.RepositoryStatus) {
storeState.Lock() r.storeState.Lock()
defer storeState.Unlock() defer r.storeState.Unlock()
var ok bool var ok bool
var documentId string var documentId string
@ -82,15 +82,15 @@ func CreateDocument(databaseId string, collectionId string, document map[string]
document["id"] = documentId document["id"] = documentId
} }
if database, ok = storeState.Databases[databaseId]; !ok { if database, ok = r.storeState.Databases[databaseId]; !ok {
return repositorymodels.Document{}, repositorymodels.StatusNotFound return repositorymodels.Document{}, repositorymodels.StatusNotFound
} }
if collection, ok = storeState.Collections[databaseId][collectionId]; !ok { if collection, ok = r.storeState.Collections[databaseId][collectionId]; !ok {
return repositorymodels.Document{}, repositorymodels.StatusNotFound return repositorymodels.Document{}, repositorymodels.StatusNotFound
} }
if _, ok := storeState.Documents[databaseId][collectionId][documentId]; ok { if _, ok := r.storeState.Documents[databaseId][collectionId][documentId]; ok {
return repositorymodels.Document{}, repositorymodels.Conflict return repositorymodels.Document{}, repositorymodels.Conflict
} }
@ -99,19 +99,19 @@ func CreateDocument(databaseId string, collectionId string, document map[string]
document["_etag"] = fmt.Sprintf("\"%s\"", uuid.New()) document["_etag"] = fmt.Sprintf("\"%s\"", uuid.New())
document["_self"] = fmt.Sprintf("dbs/%s/colls/%s/docs/%s/", database.ResourceID, collection.ResourceID, document["_rid"]) document["_self"] = fmt.Sprintf("dbs/%s/colls/%s/docs/%s/", database.ResourceID, collection.ResourceID, document["_rid"])
storeState.Documents[databaseId][collectionId][documentId] = document r.storeState.Documents[databaseId][collectionId][documentId] = document
return document, repositorymodels.StatusOk return document, repositorymodels.StatusOk
} }
func ExecuteQueryDocuments(databaseId string, collectionId string, query string, queryParameters map[string]interface{}) ([]memoryexecutor.RowType, repositorymodels.RepositoryStatus) { func (r *DataRepository) ExecuteQueryDocuments(databaseId string, collectionId string, query string, queryParameters map[string]interface{}) ([]memoryexecutor.RowType, repositorymodels.RepositoryStatus) {
parsedQuery, err := nosql.Parse("", []byte(query)) parsedQuery, err := nosql.Parse("", []byte(query))
if err != nil { if err != nil {
log.Printf("Failed to parse query: %s\nerr: %v", query, err) log.Printf("Failed to parse query: %s\nerr: %v", query, err)
return nil, repositorymodels.BadRequest return nil, repositorymodels.BadRequest
} }
collectionDocuments, status := GetAllDocuments(databaseId, collectionId) collectionDocuments, status := r.GetAllDocuments(databaseId, collectionId)
if status != repositorymodels.StatusOk { if status != repositorymodels.StatusOk {
return nil, status return nil, status
} }

View File

@ -9,19 +9,19 @@ import (
) )
// I have no idea what this is tbh // I have no idea what this is tbh
func GetPartitionKeyRanges(databaseId string, collectionId string) ([]repositorymodels.PartitionKeyRange, repositorymodels.RepositoryStatus) { func (r *DataRepository) GetPartitionKeyRanges(databaseId string, collectionId string) ([]repositorymodels.PartitionKeyRange, repositorymodels.RepositoryStatus) {
storeState.RLock() r.storeState.RLock()
defer storeState.RUnlock() defer r.storeState.RUnlock()
databaseRid := databaseId databaseRid := databaseId
collectionRid := collectionId collectionRid := collectionId
var timestamp int64 = 0 var timestamp int64 = 0
if database, ok := storeState.Databases[databaseId]; !ok { if database, ok := r.storeState.Databases[databaseId]; !ok {
databaseRid = database.ResourceID databaseRid = database.ResourceID
} }
if collection, ok := storeState.Collections[databaseId][collectionId]; !ok { if collection, ok := r.storeState.Collections[databaseId][collectionId]; !ok {
collectionRid = collection.ResourceID collectionRid = collection.ResourceID
timestamp = collection.TimeStamp timestamp = collection.TimeStamp
} }

View File

@ -0,0 +1,37 @@
package repositories
import repositorymodels "github.com/pikami/cosmium/internal/repository_models"
type DataRepository struct {
storedProcedures []repositorymodels.StoredProcedure
triggers []repositorymodels.Trigger
userDefinedFunctions []repositorymodels.UserDefinedFunction
storeState repositorymodels.State
initialDataFilePath string
persistDataFilePath string
}
type RepositoryOptions struct {
InitialDataFilePath string
PersistDataFilePath string
}
func NewDataRepository(options RepositoryOptions) *DataRepository {
repository := &DataRepository{
storedProcedures: []repositorymodels.StoredProcedure{},
triggers: []repositorymodels.Trigger{},
userDefinedFunctions: []repositorymodels.UserDefinedFunction{},
storeState: repositorymodels.State{
Databases: make(map[string]repositorymodels.Database),
Collections: make(map[string]map[string]repositorymodels.Collection),
Documents: make(map[string]map[string]map[string]repositorymodels.Document),
},
initialDataFilePath: options.InitialDataFilePath,
persistDataFilePath: options.PersistDataFilePath,
}
repository.InitializeRepository()
return repository
}

View File

@ -6,28 +6,18 @@ import (
"os" "os"
"reflect" "reflect"
"github.com/pikami/cosmium/api/config"
"github.com/pikami/cosmium/internal/logger" "github.com/pikami/cosmium/internal/logger"
repositorymodels "github.com/pikami/cosmium/internal/repository_models" repositorymodels "github.com/pikami/cosmium/internal/repository_models"
) )
var storedProcedures = []repositorymodels.StoredProcedure{} func (r *DataRepository) InitializeRepository() {
var triggers = []repositorymodels.Trigger{} if r.initialDataFilePath != "" {
var userDefinedFunctions = []repositorymodels.UserDefinedFunction{} r.LoadStateFS(r.initialDataFilePath)
var storeState = repositorymodels.State{
Databases: make(map[string]repositorymodels.Database),
Collections: make(map[string]map[string]repositorymodels.Collection),
Documents: make(map[string]map[string]map[string]repositorymodels.Document),
}
func InitializeRepository() {
if config.Config.InitialDataFilePath != "" {
LoadStateFS(config.Config.InitialDataFilePath)
return return
} }
if config.Config.PersistDataFilePath != "" { if r.persistDataFilePath != "" {
stat, err := os.Stat(config.Config.PersistDataFilePath) stat, err := os.Stat(r.persistDataFilePath)
if err != nil { if err != nil {
return return
} }
@ -37,12 +27,12 @@ func InitializeRepository() {
os.Exit(1) os.Exit(1)
} }
LoadStateFS(config.Config.PersistDataFilePath) r.LoadStateFS(r.persistDataFilePath)
return return
} }
} }
func LoadStateFS(filePath string) { func (r *DataRepository) LoadStateFS(filePath string) {
data, err := os.ReadFile(filePath) data, err := os.ReadFile(filePath)
if err != nil { if err != nil {
log.Fatalf("Error reading state JSON file: %v", err) log.Fatalf("Error reading state JSON file: %v", err)
@ -60,16 +50,16 @@ func LoadStateFS(filePath string) {
logger.Infof("Collections: %d\n", getLength(state.Collections)) logger.Infof("Collections: %d\n", getLength(state.Collections))
logger.Infof("Documents: %d\n", getLength(state.Documents)) logger.Infof("Documents: %d\n", getLength(state.Documents))
storeState = state r.storeState = state
ensureStoreStateNoNullReferences() r.ensureStoreStateNoNullReferences()
} }
func SaveStateFS(filePath string) { func (r *DataRepository) SaveStateFS(filePath string) {
storeState.RLock() r.storeState.RLock()
defer storeState.RUnlock() defer r.storeState.RUnlock()
data, err := json.MarshalIndent(storeState, "", "\t") data, err := json.MarshalIndent(r.storeState, "", "\t")
if err != nil { if err != nil {
logger.Errorf("Failed to save state: %v\n", err) logger.Errorf("Failed to save state: %v\n", err)
return return
@ -78,16 +68,16 @@ func SaveStateFS(filePath string) {
os.WriteFile(filePath, data, os.ModePerm) os.WriteFile(filePath, data, os.ModePerm)
logger.Info("Saved state:") logger.Info("Saved state:")
logger.Infof("Databases: %d\n", getLength(storeState.Databases)) logger.Infof("Databases: %d\n", getLength(r.storeState.Databases))
logger.Infof("Collections: %d\n", getLength(storeState.Collections)) logger.Infof("Collections: %d\n", getLength(r.storeState.Collections))
logger.Infof("Documents: %d\n", getLength(storeState.Documents)) logger.Infof("Documents: %d\n", getLength(r.storeState.Documents))
} }
func GetState() (string, error) { func (r *DataRepository) GetState() (string, error) {
storeState.RLock() r.storeState.RLock()
defer storeState.RUnlock() defer r.storeState.RUnlock()
data, err := json.MarshalIndent(storeState, "", "\t") data, err := json.MarshalIndent(r.storeState, "", "\t")
if err != nil { if err != nil {
logger.Errorf("Failed to serialize state: %v\n", err) logger.Errorf("Failed to serialize state: %v\n", err)
return "", err return "", err
@ -121,36 +111,36 @@ func getLength(v interface{}) int {
return count return count
} }
func ensureStoreStateNoNullReferences() { func (r *DataRepository) ensureStoreStateNoNullReferences() {
if storeState.Databases == nil { if r.storeState.Databases == nil {
storeState.Databases = make(map[string]repositorymodels.Database) r.storeState.Databases = make(map[string]repositorymodels.Database)
} }
if storeState.Collections == nil { if r.storeState.Collections == nil {
storeState.Collections = make(map[string]map[string]repositorymodels.Collection) r.storeState.Collections = make(map[string]map[string]repositorymodels.Collection)
} }
if storeState.Documents == nil { if r.storeState.Documents == nil {
storeState.Documents = make(map[string]map[string]map[string]repositorymodels.Document) r.storeState.Documents = make(map[string]map[string]map[string]repositorymodels.Document)
} }
for database := range storeState.Databases { for database := range r.storeState.Databases {
if storeState.Collections[database] == nil { if r.storeState.Collections[database] == nil {
storeState.Collections[database] = make(map[string]repositorymodels.Collection) r.storeState.Collections[database] = make(map[string]repositorymodels.Collection)
} }
if storeState.Documents[database] == nil { if r.storeState.Documents[database] == nil {
storeState.Documents[database] = make(map[string]map[string]repositorymodels.Document) r.storeState.Documents[database] = make(map[string]map[string]repositorymodels.Document)
} }
for collection := range storeState.Collections[database] { for collection := range r.storeState.Collections[database] {
if storeState.Documents[database][collection] == nil { if r.storeState.Documents[database][collection] == nil {
storeState.Documents[database][collection] = make(map[string]repositorymodels.Document) r.storeState.Documents[database][collection] = make(map[string]repositorymodels.Document)
} }
for document := range storeState.Documents[database][collection] { for document := range r.storeState.Documents[database][collection] {
if storeState.Documents[database][collection][document] == nil { if r.storeState.Documents[database][collection][document] == nil {
delete(storeState.Documents[database][collection], document) delete(r.storeState.Documents[database][collection], document)
} }
} }
} }

View File

@ -2,6 +2,6 @@ package repositories
import repositorymodels "github.com/pikami/cosmium/internal/repository_models" import repositorymodels "github.com/pikami/cosmium/internal/repository_models"
func GetAllStoredProcedures(databaseId string, collectionId string) ([]repositorymodels.StoredProcedure, repositorymodels.RepositoryStatus) { func (r *DataRepository) GetAllStoredProcedures(databaseId string, collectionId string) ([]repositorymodels.StoredProcedure, repositorymodels.RepositoryStatus) {
return storedProcedures, repositorymodels.StatusOk return r.storedProcedures, repositorymodels.StatusOk
} }

View File

@ -2,6 +2,6 @@ package repositories
import repositorymodels "github.com/pikami/cosmium/internal/repository_models" import repositorymodels "github.com/pikami/cosmium/internal/repository_models"
func GetAllTriggers(databaseId string, collectionId string) ([]repositorymodels.Trigger, repositorymodels.RepositoryStatus) { func (r *DataRepository) GetAllTriggers(databaseId string, collectionId string) ([]repositorymodels.Trigger, repositorymodels.RepositoryStatus) {
return triggers, repositorymodels.StatusOk return r.triggers, repositorymodels.StatusOk
} }

View File

@ -2,6 +2,6 @@ package repositories
import repositorymodels "github.com/pikami/cosmium/internal/repository_models" import repositorymodels "github.com/pikami/cosmium/internal/repository_models"
func GetAllUserDefinedFunctions(databaseId string, collectionId string) ([]repositorymodels.UserDefinedFunction, repositorymodels.RepositoryStatus) { func (r *DataRepository) GetAllUserDefinedFunctions(databaseId string, collectionId string) ([]repositorymodels.UserDefinedFunction, repositorymodels.RepositoryStatus) {
return userDefinedFunctions, repositorymodels.StatusOk return r.userDefinedFunctions, repositorymodels.StatusOk
} }

View File

@ -9,44 +9,77 @@ import (
"github.com/pikami/cosmium/internal/repositories" "github.com/pikami/cosmium/internal/repositories"
) )
var currentServer *api.Server type ServerInstance struct {
server *api.ApiServer
repository *repositories.DataRepository
}
var serverInstances map[string]*ServerInstance
const (
ResponseSuccess = 0
ResponseUnknown = 1
ResponseServerInstanceAlreadyExists = 2
ResponseFailedToParseConfiguration = 3
ResponseServerInstanceNotFound = 4
)
//export CreateServerInstance
func CreateServerInstance(serverName *C.char, configurationJSON *C.char) int {
if serverInstances == nil {
serverInstances = make(map[string]*ServerInstance)
}
if _, ok := serverInstances[C.GoString(serverName)]; ok {
return ResponseServerInstanceAlreadyExists
}
//export Configure
func Configure(configurationJSON *C.char) bool {
var configuration config.ServerConfig var configuration config.ServerConfig
err := json.Unmarshal([]byte(C.GoString(configurationJSON)), &configuration) err := json.Unmarshal([]byte(C.GoString(configurationJSON)), &configuration)
if err != nil { if err != nil {
return false return ResponseFailedToParseConfiguration
} }
config.Config = configuration
return true
}
//export InitializeRepository configuration.PopulateCalculatedFields()
func InitializeRepository() {
repositories.InitializeRepository()
}
//export StartAPI repository := repositories.NewDataRepository(repositories.RepositoryOptions{
func StartAPI() { InitialDataFilePath: configuration.InitialDataFilePath,
currentServer = api.StartAPI() PersistDataFilePath: configuration.PersistDataFilePath,
} })
//export StopAPI server := api.NewApiServer(repository, configuration)
func StopAPI() { server.Start()
if currentServer == nil {
currentServer.StopServer <- true serverInstances[C.GoString(serverName)] = &ServerInstance{
currentServer = nil server: server,
repository: repository,
} }
return ResponseSuccess
} }
//export GetState //export StopServerInstance
func GetState() *C.char { func StopServerInstance(serverName *C.char) int {
stateJSON, err := repositories.GetState() if serverInstance, ok := serverInstances[C.GoString(serverName)]; ok {
if err != nil { serverInstance.server.Stop()
return nil delete(serverInstances, C.GoString(serverName))
return ResponseSuccess
} }
return C.CString(stateJSON)
return ResponseServerInstanceNotFound
}
//export GetServerInstanceState
func GetServerInstanceState(serverName *C.char) *C.char {
if serverInstance, ok := serverInstances[C.GoString(serverName)]; ok {
stateJSON, err := serverInstance.repository.GetState()
if err != nil {
return nil
}
return C.CString(stateJSON)
}
return nil
} }
func main() {} func main() {}