Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 85 additions & 2 deletions api/auth_middleware_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package api

import (
"context"
"encoding/base64"
"fmt"
"io"
"net/http/httptest"
"strings"
"testing"

"api.audius.co/database"
Expand Down Expand Up @@ -207,14 +209,24 @@ func TestGetApiSignerBasicAuth(t *testing.T) {
assert.Contains(t, string(body), "missing Authorization header")
})

t.Run("invalid basic auth format - not basic", func(t *testing.T) {
t.Run("invalid Bearer token", func(t *testing.T) {
req := httptest.NewRequest("POST", "/test", nil)
req.Header.Set("Authorization", "Bearer invalidtoken")
res, err := testApp.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, fiber.StatusInternalServerError, res.StatusCode)
body, _ := io.ReadAll(res.Body)
assert.Contains(t, string(body), "Authorization header is not Basic Auth")
assert.Contains(t, string(body), "invalid Bearer token")
})

t.Run("invalid Basic auth format - not Bearer or Basic", func(t *testing.T) {
req := httptest.NewRequest("POST", "/test", nil)
req.Header.Set("Authorization", "Digest some-credentials")
res, err := testApp.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, fiber.StatusInternalServerError, res.StatusCode)
body, _ := io.ReadAll(res.Body)
assert.Contains(t, string(body), "Authorization must be Bearer or Basic")
})

t.Run("invalid private key", func(t *testing.T) {
Expand Down Expand Up @@ -254,6 +266,77 @@ func TestGetApiSignerBasicAuth(t *testing.T) {
})
}

func TestGetApiSignerWithApiAccessKey(t *testing.T) {
app := emptyTestApp(t)
if app.writePool == nil {
t.Skip("writePool required for api_access_key lookup")
}

ctx := context.Background()
ensureApiKeysTables(t, app, ctx)

// Same private key as TestGetApiSignerBasicAuth - derives to 0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266
testPrivateKey := "ac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80"
parentApiKey := "0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266"
apiAccessKey := "test-access-key-123"

_, err := app.writePool.Exec(ctx, `
INSERT INTO api_keys (api_key, api_secret, rps, rpm)
VALUES ($1, $2, 10, 500000)
ON CONFLICT (api_key) DO UPDATE SET api_secret = EXCLUDED.api_secret
`, parentApiKey, testPrivateKey)
assert.NoError(t, err)

_, err = app.writePool.Exec(ctx, `
INSERT INTO api_access_keys (api_key, api_access_key, is_active)
VALUES ($1, $2, true)
ON CONFLICT (api_key, api_access_key) DO UPDATE SET is_active = true
`, parentApiKey, apiAccessKey)
assert.NoError(t, err)

testApp := fiber.New()
testApp.Post("/test", func(c *fiber.Ctx) error {
signer, err := app.getApiSigner(c)
if err != nil {
return fiber.NewError(fiber.StatusInternalServerError, err.Error())
}
return c.JSON(fiber.Map{
"address": signer.Address,
})
})

req := httptest.NewRequest("POST", "/test", nil)
req.Header.Set("Authorization", "Basic "+encodeBasicAuth("", apiAccessKey))
res, err := testApp.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, fiber.StatusOK, res.StatusCode)
body, _ := io.ReadAll(res.Body)
assert.True(t, strings.Contains(strings.ToLower(string(body)), strings.ToLower(parentApiKey)),
"body %s should contain address %s", string(body), parentApiKey)
}

// ensureApiKeysTables creates api_keys and api_access_keys if they do not exist.
func ensureApiKeysTables(t *testing.T, app *ApiServer, ctx context.Context) {
t.Helper()
_, err := app.writePool.Exec(ctx, `
CREATE TABLE IF NOT EXISTS api_keys (
api_key VARCHAR(255) NOT NULL PRIMARY KEY,
api_secret VARCHAR(255),
rps INTEGER NOT NULL DEFAULT 10,
rpm INTEGER NOT NULL DEFAULT 500000,
created_at TIMESTAMP NOT NULL DEFAULT NOW()
);
CREATE TABLE IF NOT EXISTS api_access_keys (
api_key VARCHAR(255) NOT NULL,
api_access_key VARCHAR(255) NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
is_active BOOLEAN NOT NULL DEFAULT true,
PRIMARY KEY (api_key, api_access_key)
);
`)
assert.NoError(t, err)
}

// Helper function to encode basic auth credentials
func encodeBasicAuth(username, password string) string {
auth := username + ":" + password
Expand Down
15 changes: 15 additions & 0 deletions api/dbv1/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

65 changes: 65 additions & 0 deletions api/frontend_auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package api

import (
"strings"

"github.com/ethereum/go-ethereum/crypto"
"github.com/gofiber/fiber/v2"
"go.uber.org/zap"
)

// requireFrontendAppAuth returns a middleware that validates Bearer token and checks
// that the given frontend app (identified by its private key secret) has a grant from the user.
// Must run after requireUserIdMiddleware.
func (app *ApiServer) requireFrontendAppAuth(secret string, appName string) fiber.Handler {
return func(c *fiber.Ctx) error {
if secret == "" {
app.logger.Error(appName+" secret not configured", zap.String("app", appName))
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
"error": appName + " not configured",
})
}

authHeader := c.Get("Authorization")
if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
"error": "Missing or invalid Authorization header. Use Bearer <oauth_token>",
})
}
token := strings.TrimPrefix(authHeader, "Bearer ")

pathUserId := app.getUserId(c)
if pathUserId == 0 {
return fiber.NewError(fiber.StatusBadRequest, "invalid userId")
}

jwtUserId, err := app.validateOAuthJWTTokenToUserId(c.Context(), token)
if err != nil {
return err
}

if int32(jwtUserId) != pathUserId {
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
"error": "Token userId does not match path userId",
})
}

// Derive app address from private key
privateKey, err := crypto.HexToECDSA(strings.TrimPrefix(secret, "0x"))
if err != nil {
app.logger.Error("Invalid "+appName+" secret", zap.Error(err))
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
"error": appName + " misconfigured",
})
}
appAddress := strings.ToLower(crypto.PubkeyToAddress(privateKey.PublicKey).Hex())

if !app.isAuthorizedRequest(c.Context(), pathUserId, appAddress) {
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
"error": "User has not granted the " + strings.ToLower(appName) + " write access. Log in with OAuth scope 'write'.",
})
}

return c.Next()
}
}
18 changes: 14 additions & 4 deletions api/metrics_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package api
import (
"context"
"runtime"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -96,13 +97,22 @@ func NewMetricsCollector(logger *zap.Logger, writePool *pgxpool.Pool) *MetricsCo
return collector
}

// Fiber middleware that collects metrics
func (rmc *MetricsCollector) Middleware() fiber.Handler {
// Fiber middleware that collects metrics. Pass apiServer to resolve identifier from Bearer or Basic Auth signer first; if nil or no signer, falls back to api_key/app_name query params.
func (rmc *MetricsCollector) Middleware(apiServer *ApiServer) fiber.Handler {
return func(c *fiber.Ctx) error {
err := c.Next()

apiKey := c.Query("api_key")
appName := c.Query("app_name")
var apiKey, appName string
if apiServer != nil {
signer, signerErr := apiServer.getApiSigner(c)
if signerErr == nil && signer != nil {
apiKey = fiberutils.CopyString(strings.ToLower(signer.Address))
}
}
if apiKey == "" && appName == "" {
apiKey = c.Query("api_key")
appName = c.Query("app_name")
}
ipAddress := utils.GetIP(c)

// Only record if we have some identifier
Expand Down
Loading