From 9abef691d61795c0eb425400aeac5b979c43b5bc Mon Sep 17 00:00:00 2001 From: Erik Zentveld Date: Mon, 28 Oct 2024 13:29:26 +0100 Subject: [PATCH] serve request paths with trailing slashes, as sent by python client --- .../middleware/trailing_slash_stripper.go | 17 +++++++++ api/router.go | 30 ++++++++++++---- api/tests/documents_trailingslash_test.go | 36 +++++++++++++++++++ 3 files changed, 76 insertions(+), 7 deletions(-) create mode 100644 api/handlers/middleware/trailing_slash_stripper.go create mode 100644 api/tests/documents_trailingslash_test.go diff --git a/api/handlers/middleware/trailing_slash_stripper.go b/api/handlers/middleware/trailing_slash_stripper.go new file mode 100644 index 0000000..8877f32 --- /dev/null +++ b/api/handlers/middleware/trailing_slash_stripper.go @@ -0,0 +1,17 @@ +package middleware + +import ( + "strings" + + "github.com/gin-gonic/gin" +) + +func TrailingSlashStripper() gin.HandlerFunc { + return func(c *gin.Context) { + if (len(c.Request.URL.Path)) > 1 { //dont strip root dir slash, path="/" + var stripped_path = strings.TrimSuffix(c.Request.URL.Path, "/") + c.Request.URL.Path = stripped_path + } + c.Next() + } +} diff --git a/api/router.go b/api/router.go index 9f1fc76..96ded93 100644 --- a/api/router.go +++ b/api/router.go @@ -13,7 +13,10 @@ import ( ) func CreateRouter() *gin.Engine { - router := gin.Default() + + router := gin.Default(func(e *gin.Engine) { + e.RemoveExtraSlash = true + }) if config.Config.Debug { router.Use(middleware.RequestLogger()) @@ -22,38 +25,51 @@ func CreateRouter() *gin.Engine { router.Use(middleware.Authentication()) router.GET("/dbs/:databaseId/colls/:collId/pkranges", handlers.GetPartitionKeyRanges) - router.POST("/dbs/:databaseId/colls/:collId/docs", handlers.DocumentsPost) router.GET("/dbs/:databaseId/colls/:collId/docs", handlers.GetAllDocuments) router.GET("/dbs/:databaseId/colls/:collId/docs/:docId", handlers.GetDocument) router.PUT("/dbs/:databaseId/colls/:collId/docs/:docId", handlers.ReplaceDocument) router.PATCH("/dbs/:databaseId/colls/:collId/docs/:docId", handlers.PatchDocument) router.DELETE("/dbs/:databaseId/colls/:collId/docs/:docId", handlers.DeleteDocument) - router.POST("/dbs/:databaseId/colls", handlers.CreateCollection) router.GET("/dbs/:databaseId/colls", handlers.GetAllCollections) router.GET("/dbs/:databaseId/colls/:collId", handlers.GetCollection) router.DELETE("/dbs/:databaseId/colls/:collId", handlers.DeleteCollection) - router.POST("/dbs", handlers.CreateDatabase) router.GET("/dbs", handlers.GetAllDatabases) router.GET("/dbs/:databaseId", handlers.GetDatabase) router.DELETE("/dbs/:databaseId", handlers.DeleteDatabase) - router.GET("/dbs/:databaseId/colls/:collId/udfs", handlers.GetAllUserDefinedFunctions) router.GET("/dbs/:databaseId/colls/:collId/sprocs", handlers.GetAllStoredProcedures) router.GET("/dbs/:databaseId/colls/:collId/triggers", handlers.GetAllTriggers) - router.GET("/offers", handlers.GetOffers) router.GET("/", handlers.GetServerInfo) - router.GET("/cosmium/export", handlers.CosmiumExport) + addRoutesForTrailingSlashes(router) + handlers.RegisterExplorerHandlers(router) return router } +func addRoutesForTrailingSlashes(router *gin.Engine) { + trailingSlashGroup := router.Group("/") + //prepend, so slash is stripped before authentication middleware reads path + trailingSlashGroup.Handlers = prepend(trailingSlashGroup.Handlers, middleware.TrailingSlashStripper()) + + for _, route := range router.Routes() { + if route.Path != "/" { //don't append slash to root path, already handled by RemoveExtraSlash + trailingSlashGroup.Handle(route.Method, route.Path+"/", route.HandlerFunc) + } + } +} + +func prepend[T any](a []T, e T) []T { + a = append([]T{e}, a...) + return a +} + func StartAPI() { if !config.Config.Debug { gin.SetMode(gin.ReleaseMode) diff --git a/api/tests/documents_trailingslash_test.go b/api/tests/documents_trailingslash_test.go new file mode 100644 index 0000000..0557266 --- /dev/null +++ b/api/tests/documents_trailingslash_test.go @@ -0,0 +1,36 @@ +package tests_test + +import ( + "fmt" + "net/http" + "net/url" + "testing" + "time" + + "github.com/pikami/cosmium/api/config" + "github.com/pikami/cosmium/internal/authentication" + "github.com/stretchr/testify/assert" +) + +// Request document with trailing slash like python cosmosdb client does. +func Test_Documents_Read_Trailing_Slash(t *testing.T) { + ts, _ := documents_InitializeDb(t) + defer ts.Close() + + t.Run("Read doc with client that appends slash to path", func(t *testing.T) { + resourceIdTemplate := "dbs/%s/colls/%s/docs/%s" + path := fmt.Sprintf(resourceIdTemplate, testDatabaseName, testCollectionName, "12345") + testUrl := ts.URL + "/" + path + "/" + date := time.Now().Format(time.RFC1123) + signature := authentication.GenerateSignature("GET", "docs", path, date, config.Config.AccountKey) + httpClient := &http.Client{} + req, _ := http.NewRequest("GET", testUrl, nil) + req.Header.Add("x-ms-date", date) + req.Header.Add("authorization", "sig="+url.QueryEscape(signature)) + _, err := httpClient.Do(req) + + assert.Nil(t, err) + + }) + +}