serve request paths with trailing slashes, as sent by python client

This commit is contained in:
Erik Zentveld 2024-10-28 13:29:26 +01:00
parent 20af73ee9c
commit 9abef691d6
3 changed files with 76 additions and 7 deletions

View File

@ -0,0 +1,17 @@
package middleware
import (
"strings"
"github.com/gin-gonic/gin"
)
func TrailingSlashStripper() gin.HandlerFunc {
return func(c *gin.Context) {
if (len(c.Request.URL.Path)) > 1 { //dont strip root dir slash, path="/"
var stripped_path = strings.TrimSuffix(c.Request.URL.Path, "/")
c.Request.URL.Path = stripped_path
}
c.Next()
}
}

View File

@ -13,7 +13,10 @@ import (
)
func CreateRouter() *gin.Engine {
router := gin.Default()
router := gin.Default(func(e *gin.Engine) {
e.RemoveExtraSlash = true
})
if config.Config.Debug {
router.Use(middleware.RequestLogger())
@ -22,38 +25,51 @@ func CreateRouter() *gin.Engine {
router.Use(middleware.Authentication())
router.GET("/dbs/:databaseId/colls/:collId/pkranges", handlers.GetPartitionKeyRanges)
router.POST("/dbs/:databaseId/colls/:collId/docs", handlers.DocumentsPost)
router.GET("/dbs/:databaseId/colls/:collId/docs", handlers.GetAllDocuments)
router.GET("/dbs/:databaseId/colls/:collId/docs/:docId", handlers.GetDocument)
router.PUT("/dbs/:databaseId/colls/:collId/docs/:docId", handlers.ReplaceDocument)
router.PATCH("/dbs/:databaseId/colls/:collId/docs/:docId", handlers.PatchDocument)
router.DELETE("/dbs/:databaseId/colls/:collId/docs/:docId", handlers.DeleteDocument)
router.POST("/dbs/:databaseId/colls", handlers.CreateCollection)
router.GET("/dbs/:databaseId/colls", handlers.GetAllCollections)
router.GET("/dbs/:databaseId/colls/:collId", handlers.GetCollection)
router.DELETE("/dbs/:databaseId/colls/:collId", handlers.DeleteCollection)
router.POST("/dbs", handlers.CreateDatabase)
router.GET("/dbs", handlers.GetAllDatabases)
router.GET("/dbs/:databaseId", handlers.GetDatabase)
router.DELETE("/dbs/:databaseId", handlers.DeleteDatabase)
router.GET("/dbs/:databaseId/colls/:collId/udfs", handlers.GetAllUserDefinedFunctions)
router.GET("/dbs/:databaseId/colls/:collId/sprocs", handlers.GetAllStoredProcedures)
router.GET("/dbs/:databaseId/colls/:collId/triggers", handlers.GetAllTriggers)
router.GET("/offers", handlers.GetOffers)
router.GET("/", handlers.GetServerInfo)
router.GET("/cosmium/export", handlers.CosmiumExport)
addRoutesForTrailingSlashes(router)
handlers.RegisterExplorerHandlers(router)
return router
}
func addRoutesForTrailingSlashes(router *gin.Engine) {
trailingSlashGroup := router.Group("/")
//prepend, so slash is stripped before authentication middleware reads path
trailingSlashGroup.Handlers = prepend(trailingSlashGroup.Handlers, middleware.TrailingSlashStripper())
for _, route := range router.Routes() {
if route.Path != "/" { //don't append slash to root path, already handled by RemoveExtraSlash
trailingSlashGroup.Handle(route.Method, route.Path+"/", route.HandlerFunc)
}
}
}
func prepend[T any](a []T, e T) []T {
a = append([]T{e}, a...)
return a
}
func StartAPI() {
if !config.Config.Debug {
gin.SetMode(gin.ReleaseMode)

View File

@ -0,0 +1,36 @@
package tests_test
import (
"fmt"
"net/http"
"net/url"
"testing"
"time"
"github.com/pikami/cosmium/api/config"
"github.com/pikami/cosmium/internal/authentication"
"github.com/stretchr/testify/assert"
)
// Request document with trailing slash like python cosmosdb client does.
func Test_Documents_Read_Trailing_Slash(t *testing.T) {
ts, _ := documents_InitializeDb(t)
defer ts.Close()
t.Run("Read doc with client that appends slash to path", func(t *testing.T) {
resourceIdTemplate := "dbs/%s/colls/%s/docs/%s"
path := fmt.Sprintf(resourceIdTemplate, testDatabaseName, testCollectionName, "12345")
testUrl := ts.URL + "/" + path + "/"
date := time.Now().Format(time.RFC1123)
signature := authentication.GenerateSignature("GET", "docs", path, date, config.Config.AccountKey)
httpClient := &http.Client{}
req, _ := http.NewRequest("GET", testUrl, nil)
req.Header.Add("x-ms-date", date)
req.Header.Add("authorization", "sig="+url.QueryEscape(signature))
_, err := httpClient.Do(req)
assert.Nil(t, err)
})
}