Implement authentication

This commit is contained in:
Pijus Kamandulis 2024-02-21 23:40:54 +02:00
parent 790192bf5a
commit 6a40492c7b
11 changed files with 227 additions and 3 deletions

View File

@ -5,6 +5,10 @@ import (
"fmt"
)
const (
DefaultAccountKey = "C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="
)
var Config = ServerConfig{}
func ParseFlags() {
@ -14,6 +18,8 @@ func ParseFlags() {
tlsCertificatePath := flag.String("Cert", "../example.crt", "Hostname")
tlsCertificateKey := flag.String("CertKey", "../example.key", "Hostname")
initialDataPath := flag.String("InitialData", "", "Path to JSON containing initial state")
accountKey := flag.String("AccountKey", DefaultAccountKey, "Account key for authentication")
disableAuthentication := flag.Bool("DisableAuth", false, "Disable authentication")
flag.Parse()
@ -23,8 +29,10 @@ func ParseFlags() {
Config.TLS_CertificatePath = *tlsCertificatePath
Config.TLS_CertificateKey = *tlsCertificateKey
Config.DataFilePath = *initialDataPath
Config.DisableAuth = *disableAuthentication
Config.DatabaseAccount = Config.Host
Config.DatabaseDomain = Config.Host
Config.DatabaseEndpoint = fmt.Sprintf("https://%s:%d/", Config.Host, Config.Port)
Config.AccountKey = *accountKey
}

View File

@ -4,6 +4,7 @@ type ServerConfig struct {
DatabaseAccount string
DatabaseDomain string
DatabaseEndpoint string
AccountKey string
ExplorerPath string
Port int
@ -11,4 +12,5 @@ type ServerConfig struct {
TLS_CertificatePath string
TLS_CertificateKey string
DataFilePath string
DisableAuth bool
}

View File

@ -0,0 +1,62 @@
package middleware
import (
"fmt"
"net/url"
"strings"
"github.com/gin-gonic/gin"
"github.com/pikami/cosmium/api/config"
"github.com/pikami/cosmium/internal/authentication"
)
func Authentication() gin.HandlerFunc {
return func(c *gin.Context) {
requestUrl := c.Request.URL.String()
if config.Config.DisableAuth || strings.HasPrefix(requestUrl, "/_explorer") {
return
}
var resourceType string
parts := strings.Split(requestUrl, "/")
switch len(parts) {
case 2, 3:
resourceType = parts[1]
case 4, 5:
resourceType = parts[3]
case 6, 7:
resourceType = parts[5]
}
databaseId, _ := c.Params.Get("databaseId")
collId, _ := c.Params.Get("collId")
docId, _ := c.Params.Get("docId")
var resourceId string
if databaseId != "" {
resourceId += "dbs/" + databaseId
}
if collId != "" {
resourceId += "/colls/" + collId
}
if docId != "" {
resourceId += "/docs/" + docId
}
authHeader := c.Request.Header.Get("authorization")
date := c.Request.Header.Get("x-ms-date")
expectedSignature := authentication.GenerateSignature(
c.Request.Method, resourceType, resourceId, date, config.Config.AccountKey)
decoded, _ := url.QueryUnescape(authHeader)
params, _ := url.ParseQuery(decoded)
clientSignature := strings.Replace(params.Get("sig"), " ", "+", -1)
if clientSignature != expectedSignature {
fmt.Printf("Got wrong signature from client.\n- Expected: %s\n- Got: %s\n", expectedSignature, clientSignature)
c.IndentedJSON(401, gin.H{
"code": "Unauthorized",
"message": "Wrong signature.",
})
c.Abort()
}
}
}

View File

@ -10,6 +10,7 @@ func CreateRouter() *gin.Engine {
router := gin.Default()
router.Use(middleware.RequestLogger())
router.Use(middleware.Authentication())
router.GET("/dbs/:databaseId/colls/:collId/pkranges", handlers.GetPartitionKeyRanges)

View File

@ -0,0 +1,87 @@
package tests_test
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"
"github.com/pikami/cosmium/api/config"
"github.com/pikami/cosmium/internal/repositories"
"github.com/stretchr/testify/assert"
)
func Test_Authentication(t *testing.T) {
ts := runTestServer()
defer ts.Close()
t.Run("Should get 200 when correct account key is used", func(t *testing.T) {
repositories.DeleteDatabase(testDatabaseName)
client, err := azcosmos.NewClientFromConnectionString(
fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, config.DefaultAccountKey),
&azcosmos.ClientOptions{},
)
assert.Nil(t, err)
createResponse, err := client.CreateDatabase(
context.TODO(),
azcosmos.DatabaseProperties{ID: testDatabaseName},
&azcosmos.CreateDatabaseOptions{})
assert.Nil(t, err)
assert.Equal(t, createResponse.DatabaseProperties.ID, testDatabaseName)
})
t.Run("Should get 200 when wrong account key is used, but authentication is dissabled", func(t *testing.T) {
config.Config.DisableAuth = true
repositories.DeleteDatabase(testDatabaseName)
client, err := azcosmos.NewClientFromConnectionString(
fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, "AAAA"),
&azcosmos.ClientOptions{},
)
assert.Nil(t, err)
createResponse, err := client.CreateDatabase(
context.TODO(),
azcosmos.DatabaseProperties{ID: testDatabaseName},
&azcosmos.CreateDatabaseOptions{})
assert.Nil(t, err)
assert.Equal(t, createResponse.DatabaseProperties.ID, testDatabaseName)
config.Config.DisableAuth = false
})
t.Run("Should get 401 when wrong account key is used", func(t *testing.T) {
repositories.DeleteDatabase(testDatabaseName)
client, err := azcosmos.NewClientFromConnectionString(
fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, "AAAA"),
&azcosmos.ClientOptions{},
)
assert.Nil(t, err)
_, err = client.CreateDatabase(
context.TODO(),
azcosmos.DatabaseProperties{ID: testDatabaseName},
&azcosmos.CreateDatabaseOptions{})
var respErr *azcore.ResponseError
if errors.As(err, &respErr) {
assert.Equal(t, respErr.StatusCode, http.StatusUnauthorized)
} else {
panic(err)
}
})
t.Run("Should allow unauthorized requests to /_explorer", func(t *testing.T) {
res, err := http.Get(ts.URL + "/_explorer/config.json")
assert.Nil(t, err)
defer res.Body.Close()
responseBody, err := io.ReadAll(res.Body)
assert.Nil(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Contains(t, string(responseBody), "BACKEND_ENDPOINT")
})
}

View File

@ -9,6 +9,7 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"
"github.com/pikami/cosmium/api/config"
"github.com/pikami/cosmium/internal/repositories"
repositorymodels "github.com/pikami/cosmium/internal/repository_models"
"github.com/stretchr/testify/assert"
@ -19,7 +20,7 @@ func Test_Collections(t *testing.T) {
defer ts.Close()
client, err := azcosmos.NewClientFromConnectionString(
fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, "asas"),
fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, config.Config.AccountKey),
&azcosmos.ClientOptions{},
)
assert.Nil(t, err)

View File

@ -4,9 +4,12 @@ import (
"net/http/httptest"
"github.com/pikami/cosmium/api"
"github.com/pikami/cosmium/api/config"
)
func runTestServer() *httptest.Server {
config.Config.AccountKey = config.DefaultAccountKey
return httptest.NewServer(api.CreateRouter())
}

View File

@ -9,6 +9,7 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"
"github.com/pikami/cosmium/api/config"
"github.com/pikami/cosmium/internal/repositories"
repositorymodels "github.com/pikami/cosmium/internal/repository_models"
"github.com/stretchr/testify/assert"
@ -19,13 +20,15 @@ func Test_Databases(t *testing.T) {
defer ts.Close()
client, err := azcosmos.NewClientFromConnectionString(
fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, "asas"),
fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, config.Config.AccountKey),
&azcosmos.ClientOptions{},
)
assert.Nil(t, err)
t.Run("Database Create", func(t *testing.T) {
t.Run("Should create database", func(t *testing.T) {
repositories.DeleteDatabase(testDatabaseName)
createResponse, err := client.CreateDatabase(context.TODO(), azcosmos.DatabaseProperties{
ID: testDatabaseName,
}, &azcosmos.CreateDatabaseOptions{})

View File

@ -8,6 +8,7 @@ import (
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"
"github.com/pikami/cosmium/api/config"
"github.com/pikami/cosmium/internal/repositories"
repositorymodels "github.com/pikami/cosmium/internal/repository_models"
"github.com/stretchr/testify/assert"
@ -66,7 +67,7 @@ func Test_Documents(t *testing.T) {
defer ts.Close()
client, err := azcosmos.NewClientFromConnectionString(
fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, "asas"),
fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, config.Config.AccountKey),
&azcosmos.ClientOptions{},
)
assert.Nil(t, err)

View File

@ -0,0 +1,26 @@
package authentication
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"fmt"
"strings"
)
// https://learn.microsoft.com/en-us/rest/api/cosmos-db/access-control-on-cosmosdb-resources
func GenerateSignature(verb string, resourceType string, resourceId string, date string, masterKey string) string {
payload := fmt.Sprintf(
"%s\n%s\n%s\n%s\n%s\n",
strings.ToLower(verb),
strings.ToLower(resourceType),
resourceId,
strings.ToLower(date),
"")
masterKeyBytes, _ := base64.StdEncoding.DecodeString(masterKey)
hash := hmac.New(sha256.New, masterKeyBytes)
hash.Write([]byte(payload))
signature := base64.StdEncoding.EncodeToString(hash.Sum(nil))
return signature
}

View File

@ -0,0 +1,30 @@
package authentication_test
import (
"testing"
"github.com/pikami/cosmium/api/config"
"github.com/pikami/cosmium/internal/authentication"
"github.com/stretchr/testify/assert"
)
const (
testDate = "Fri, 17 Dec 1926 03:15:00 GMT"
)
func Test_GenerateSignature(t *testing.T) {
t.Run("Should generate GET signature", func(t *testing.T) {
signature := authentication.GenerateSignature("GET", "colls", "dbs/Test Database/colls/Test Collection", testDate, config.DefaultAccountKey)
assert.Equal(t, "cugjaA51bjCvxVi8LXg3XB+ZVKaFAZshILoJZF9nfEY=", signature)
})
t.Run("Should generate POST signature", func(t *testing.T) {
signature := authentication.GenerateSignature("POST", "colls", "dbs/Test Database", testDate, config.DefaultAccountKey)
assert.Equal(t, "E92FgDG9JiNX+NfsI+edOFtgkZRDkrrJxIfl12Vsu8A=", signature)
})
t.Run("Should generate DELETE signature", func(t *testing.T) {
signature := authentication.GenerateSignature("DELETE", "dbs", "dbs/Test Database", testDate, config.DefaultAccountKey)
assert.Equal(t, "LcuXXg0TcXxZG0kUCj9tZIWRy2yCzim3oiqGiHpRqGs=", signature)
})
}