mirror of https://github.com/pikami/cosmium.git
Implement authentication
This commit is contained in:
parent
790192bf5a
commit
6a40492c7b
|
@ -5,6 +5,10 @@ import (
|
|||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultAccountKey = "C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="
|
||||
)
|
||||
|
||||
var Config = ServerConfig{}
|
||||
|
||||
func ParseFlags() {
|
||||
|
@ -14,6 +18,8 @@ func ParseFlags() {
|
|||
tlsCertificatePath := flag.String("Cert", "../example.crt", "Hostname")
|
||||
tlsCertificateKey := flag.String("CertKey", "../example.key", "Hostname")
|
||||
initialDataPath := flag.String("InitialData", "", "Path to JSON containing initial state")
|
||||
accountKey := flag.String("AccountKey", DefaultAccountKey, "Account key for authentication")
|
||||
disableAuthentication := flag.Bool("DisableAuth", false, "Disable authentication")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
|
@ -23,8 +29,10 @@ func ParseFlags() {
|
|||
Config.TLS_CertificatePath = *tlsCertificatePath
|
||||
Config.TLS_CertificateKey = *tlsCertificateKey
|
||||
Config.DataFilePath = *initialDataPath
|
||||
Config.DisableAuth = *disableAuthentication
|
||||
|
||||
Config.DatabaseAccount = Config.Host
|
||||
Config.DatabaseDomain = Config.Host
|
||||
Config.DatabaseEndpoint = fmt.Sprintf("https://%s:%d/", Config.Host, Config.Port)
|
||||
Config.AccountKey = *accountKey
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ type ServerConfig struct {
|
|||
DatabaseAccount string
|
||||
DatabaseDomain string
|
||||
DatabaseEndpoint string
|
||||
AccountKey string
|
||||
|
||||
ExplorerPath string
|
||||
Port int
|
||||
|
@ -11,4 +12,5 @@ type ServerConfig struct {
|
|||
TLS_CertificatePath string
|
||||
TLS_CertificateKey string
|
||||
DataFilePath string
|
||||
DisableAuth bool
|
||||
}
|
||||
|
|
|
@ -0,0 +1,62 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pikami/cosmium/api/config"
|
||||
"github.com/pikami/cosmium/internal/authentication"
|
||||
)
|
||||
|
||||
func Authentication() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
requestUrl := c.Request.URL.String()
|
||||
if config.Config.DisableAuth || strings.HasPrefix(requestUrl, "/_explorer") {
|
||||
return
|
||||
}
|
||||
|
||||
var resourceType string
|
||||
parts := strings.Split(requestUrl, "/")
|
||||
switch len(parts) {
|
||||
case 2, 3:
|
||||
resourceType = parts[1]
|
||||
case 4, 5:
|
||||
resourceType = parts[3]
|
||||
case 6, 7:
|
||||
resourceType = parts[5]
|
||||
}
|
||||
|
||||
databaseId, _ := c.Params.Get("databaseId")
|
||||
collId, _ := c.Params.Get("collId")
|
||||
docId, _ := c.Params.Get("docId")
|
||||
var resourceId string
|
||||
if databaseId != "" {
|
||||
resourceId += "dbs/" + databaseId
|
||||
}
|
||||
if collId != "" {
|
||||
resourceId += "/colls/" + collId
|
||||
}
|
||||
if docId != "" {
|
||||
resourceId += "/docs/" + docId
|
||||
}
|
||||
|
||||
authHeader := c.Request.Header.Get("authorization")
|
||||
date := c.Request.Header.Get("x-ms-date")
|
||||
expectedSignature := authentication.GenerateSignature(
|
||||
c.Request.Method, resourceType, resourceId, date, config.Config.AccountKey)
|
||||
|
||||
decoded, _ := url.QueryUnescape(authHeader)
|
||||
params, _ := url.ParseQuery(decoded)
|
||||
clientSignature := strings.Replace(params.Get("sig"), " ", "+", -1)
|
||||
if clientSignature != expectedSignature {
|
||||
fmt.Printf("Got wrong signature from client.\n- Expected: %s\n- Got: %s\n", expectedSignature, clientSignature)
|
||||
c.IndentedJSON(401, gin.H{
|
||||
"code": "Unauthorized",
|
||||
"message": "Wrong signature.",
|
||||
})
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
}
|
|
@ -10,6 +10,7 @@ func CreateRouter() *gin.Engine {
|
|||
router := gin.Default()
|
||||
|
||||
router.Use(middleware.RequestLogger())
|
||||
router.Use(middleware.Authentication())
|
||||
|
||||
router.GET("/dbs/:databaseId/colls/:collId/pkranges", handlers.GetPartitionKeyRanges)
|
||||
|
||||
|
|
|
@ -0,0 +1,87 @@
|
|||
package tests_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"
|
||||
"github.com/pikami/cosmium/api/config"
|
||||
"github.com/pikami/cosmium/internal/repositories"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_Authentication(t *testing.T) {
|
||||
ts := runTestServer()
|
||||
defer ts.Close()
|
||||
|
||||
t.Run("Should get 200 when correct account key is used", func(t *testing.T) {
|
||||
repositories.DeleteDatabase(testDatabaseName)
|
||||
client, err := azcosmos.NewClientFromConnectionString(
|
||||
fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, config.DefaultAccountKey),
|
||||
&azcosmos.ClientOptions{},
|
||||
)
|
||||
assert.Nil(t, err)
|
||||
|
||||
createResponse, err := client.CreateDatabase(
|
||||
context.TODO(),
|
||||
azcosmos.DatabaseProperties{ID: testDatabaseName},
|
||||
&azcosmos.CreateDatabaseOptions{})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, createResponse.DatabaseProperties.ID, testDatabaseName)
|
||||
})
|
||||
|
||||
t.Run("Should get 200 when wrong account key is used, but authentication is dissabled", func(t *testing.T) {
|
||||
config.Config.DisableAuth = true
|
||||
repositories.DeleteDatabase(testDatabaseName)
|
||||
client, err := azcosmos.NewClientFromConnectionString(
|
||||
fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, "AAAA"),
|
||||
&azcosmos.ClientOptions{},
|
||||
)
|
||||
assert.Nil(t, err)
|
||||
|
||||
createResponse, err := client.CreateDatabase(
|
||||
context.TODO(),
|
||||
azcosmos.DatabaseProperties{ID: testDatabaseName},
|
||||
&azcosmos.CreateDatabaseOptions{})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, createResponse.DatabaseProperties.ID, testDatabaseName)
|
||||
config.Config.DisableAuth = false
|
||||
})
|
||||
|
||||
t.Run("Should get 401 when wrong account key is used", func(t *testing.T) {
|
||||
repositories.DeleteDatabase(testDatabaseName)
|
||||
client, err := azcosmos.NewClientFromConnectionString(
|
||||
fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, "AAAA"),
|
||||
&azcosmos.ClientOptions{},
|
||||
)
|
||||
assert.Nil(t, err)
|
||||
|
||||
_, err = client.CreateDatabase(
|
||||
context.TODO(),
|
||||
azcosmos.DatabaseProperties{ID: testDatabaseName},
|
||||
&azcosmos.CreateDatabaseOptions{})
|
||||
|
||||
var respErr *azcore.ResponseError
|
||||
if errors.As(err, &respErr) {
|
||||
assert.Equal(t, respErr.StatusCode, http.StatusUnauthorized)
|
||||
} else {
|
||||
panic(err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Should allow unauthorized requests to /_explorer", func(t *testing.T) {
|
||||
res, err := http.Get(ts.URL + "/_explorer/config.json")
|
||||
assert.Nil(t, err)
|
||||
defer res.Body.Close()
|
||||
responseBody, err := io.ReadAll(res.Body)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
assert.Contains(t, string(responseBody), "BACKEND_ENDPOINT")
|
||||
})
|
||||
}
|
|
@ -9,6 +9,7 @@ import (
|
|||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"
|
||||
"github.com/pikami/cosmium/api/config"
|
||||
"github.com/pikami/cosmium/internal/repositories"
|
||||
repositorymodels "github.com/pikami/cosmium/internal/repository_models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -19,7 +20,7 @@ func Test_Collections(t *testing.T) {
|
|||
defer ts.Close()
|
||||
|
||||
client, err := azcosmos.NewClientFromConnectionString(
|
||||
fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, "asas"),
|
||||
fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, config.Config.AccountKey),
|
||||
&azcosmos.ClientOptions{},
|
||||
)
|
||||
assert.Nil(t, err)
|
||||
|
|
|
@ -4,9 +4,12 @@ import (
|
|||
"net/http/httptest"
|
||||
|
||||
"github.com/pikami/cosmium/api"
|
||||
"github.com/pikami/cosmium/api/config"
|
||||
)
|
||||
|
||||
func runTestServer() *httptest.Server {
|
||||
config.Config.AccountKey = config.DefaultAccountKey
|
||||
|
||||
return httptest.NewServer(api.CreateRouter())
|
||||
}
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"
|
||||
"github.com/pikami/cosmium/api/config"
|
||||
"github.com/pikami/cosmium/internal/repositories"
|
||||
repositorymodels "github.com/pikami/cosmium/internal/repository_models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -19,13 +20,15 @@ func Test_Databases(t *testing.T) {
|
|||
defer ts.Close()
|
||||
|
||||
client, err := azcosmos.NewClientFromConnectionString(
|
||||
fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, "asas"),
|
||||
fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, config.Config.AccountKey),
|
||||
&azcosmos.ClientOptions{},
|
||||
)
|
||||
assert.Nil(t, err)
|
||||
|
||||
t.Run("Database Create", func(t *testing.T) {
|
||||
t.Run("Should create database", func(t *testing.T) {
|
||||
repositories.DeleteDatabase(testDatabaseName)
|
||||
|
||||
createResponse, err := client.CreateDatabase(context.TODO(), azcosmos.DatabaseProperties{
|
||||
ID: testDatabaseName,
|
||||
}, &azcosmos.CreateDatabaseOptions{})
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"
|
||||
"github.com/pikami/cosmium/api/config"
|
||||
"github.com/pikami/cosmium/internal/repositories"
|
||||
repositorymodels "github.com/pikami/cosmium/internal/repository_models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -66,7 +67,7 @@ func Test_Documents(t *testing.T) {
|
|||
defer ts.Close()
|
||||
|
||||
client, err := azcosmos.NewClientFromConnectionString(
|
||||
fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, "asas"),
|
||||
fmt.Sprintf("AccountEndpoint=%s;AccountKey=%s", ts.URL, config.Config.AccountKey),
|
||||
&azcosmos.ClientOptions{},
|
||||
)
|
||||
assert.Nil(t, err)
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
package authentication
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// https://learn.microsoft.com/en-us/rest/api/cosmos-db/access-control-on-cosmosdb-resources
|
||||
func GenerateSignature(verb string, resourceType string, resourceId string, date string, masterKey string) string {
|
||||
payload := fmt.Sprintf(
|
||||
"%s\n%s\n%s\n%s\n%s\n",
|
||||
strings.ToLower(verb),
|
||||
strings.ToLower(resourceType),
|
||||
resourceId,
|
||||
strings.ToLower(date),
|
||||
"")
|
||||
|
||||
masterKeyBytes, _ := base64.StdEncoding.DecodeString(masterKey)
|
||||
hash := hmac.New(sha256.New, masterKeyBytes)
|
||||
hash.Write([]byte(payload))
|
||||
signature := base64.StdEncoding.EncodeToString(hash.Sum(nil))
|
||||
return signature
|
||||
}
|
|
@ -0,0 +1,30 @@
|
|||
package authentication_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/pikami/cosmium/api/config"
|
||||
"github.com/pikami/cosmium/internal/authentication"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
testDate = "Fri, 17 Dec 1926 03:15:00 GMT"
|
||||
)
|
||||
|
||||
func Test_GenerateSignature(t *testing.T) {
|
||||
t.Run("Should generate GET signature", func(t *testing.T) {
|
||||
signature := authentication.GenerateSignature("GET", "colls", "dbs/Test Database/colls/Test Collection", testDate, config.DefaultAccountKey)
|
||||
assert.Equal(t, "cugjaA51bjCvxVi8LXg3XB+ZVKaFAZshILoJZF9nfEY=", signature)
|
||||
})
|
||||
|
||||
t.Run("Should generate POST signature", func(t *testing.T) {
|
||||
signature := authentication.GenerateSignature("POST", "colls", "dbs/Test Database", testDate, config.DefaultAccountKey)
|
||||
assert.Equal(t, "E92FgDG9JiNX+NfsI+edOFtgkZRDkrrJxIfl12Vsu8A=", signature)
|
||||
})
|
||||
|
||||
t.Run("Should generate DELETE signature", func(t *testing.T) {
|
||||
signature := authentication.GenerateSignature("DELETE", "dbs", "dbs/Test Database", testDate, config.DefaultAccountKey)
|
||||
assert.Equal(t, "LcuXXg0TcXxZG0kUCj9tZIWRy2yCzim3oiqGiHpRqGs=", signature)
|
||||
})
|
||||
}
|
Loading…
Reference in New Issue