Strip trailing slash using middleware

This commit is contained in:
Pijus Kamandulis 2024-10-28 20:20:52 +02:00
parent 827046f634
commit 0e98e3481a
4 changed files with 27 additions and 40 deletions

View File

@ -0,0 +1,18 @@
package middleware
import (
"github.com/gin-gonic/gin"
)
func StripTrailingSlashes(r *gin.Engine) gin.HandlerFunc {
return func(c *gin.Context) {
path := c.Request.URL.Path
if len(path) > 1 && path[len(path)-1] == '/' {
c.Request.URL.Path = path[:len(path)-1]
r.HandleContext(c)
c.Abort()
return
}
c.Next()
}
}

View File

@ -1,17 +0,0 @@
package middleware
import (
"strings"
"github.com/gin-gonic/gin"
)
func TrailingSlashStripper() gin.HandlerFunc {
return func(c *gin.Context) {
if (len(c.Request.URL.Path)) > 1 { //dont strip root dir slash, path="/"
var stripped_path = strings.TrimSuffix(c.Request.URL.Path, "/")
c.Request.URL.Path = stripped_path
}
c.Next()
}
}

View File

@ -13,15 +13,15 @@ import (
)
func CreateRouter() *gin.Engine {
router := gin.Default(func(e *gin.Engine) {
e.RemoveExtraSlash = true
e.RedirectTrailingSlash = false
})
if config.Config.Debug {
router.Use(middleware.RequestLogger())
}
router.Use(middleware.StripTrailingSlashes(router))
router.Use(middleware.Authentication())
router.GET("/dbs/:databaseId/colls/:collId/pkranges", handlers.GetPartitionKeyRanges)
@ -52,30 +52,11 @@ func CreateRouter() *gin.Engine {
router.GET("/cosmium/export", handlers.CosmiumExport)
addRoutesForTrailingSlashes(router)
handlers.RegisterExplorerHandlers(router)
return router
}
func addRoutesForTrailingSlashes(router *gin.Engine) {
trailingSlashGroup := router.Group("/")
//prepend, so slash is stripped before authentication middleware reads path
trailingSlashGroup.Handlers = prepend(trailingSlashGroup.Handlers, middleware.TrailingSlashStripper())
for _, route := range router.Routes() {
if route.Path != "/" { //don't append slash to root path, already handled by RemoveExtraSlash
trailingSlashGroup.Handle(route.Method, route.Path+"/", route.HandlerFunc)
}
}
}
func prepend[T any](a []T, e T) []T {
a = append([]T{e}, a...)
return a
}
func StartAPI() {
if !config.Config.Debug {
gin.SetMode(gin.ReleaseMode)

View File

@ -27,10 +27,15 @@ func Test_Documents_Read_Trailing_Slash(t *testing.T) {
req, _ := http.NewRequest("GET", testUrl, nil)
req.Header.Add("x-ms-date", date)
req.Header.Add("authorization", "sig="+url.QueryEscape(signature))
_, err := httpClient.Do(req)
res, err := httpClient.Do(req)
assert.Nil(t, err)
})
if res != nil {
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode, "Expected HTTP status 200 OK")
} else {
t.FailNow()
}
})
}