Implement middleware tests

This commit is contained in:
2025-07-30 21:20:50 +02:00
parent b3540d5b3e
commit bedec089ef
3 changed files with 372 additions and 45 deletions

View File

@@ -99,15 +99,10 @@ func generateAPIKey(keyType KeyType) string {
return fmt.Sprintf("%s-%s", prefix, hex.EncodeToString(randomBytes))
}
// InferenceMiddleware returns middleware for OpenAI inference endpoints
func (a *APIAuthMiddleware) InferenceMiddleware() func(http.Handler) http.Handler {
// AuthMiddleware returns a middleware that checks API keys for the given key type
func (a *APIAuthMiddleware) AuthMiddleware(keyType KeyType) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !a.requireInferenceAuth {
next.ServeHTTP(w, r)
return
}
if r.Method == "OPTIONS" {
next.ServeHTTP(w, r)
return
@@ -119,9 +114,18 @@ func (a *APIAuthMiddleware) InferenceMiddleware() func(http.Handler) http.Handle
return
}
// Check if key is valid for OpenAI access
// Management keys also work for OpenAI endpoints (higher privilege)
if !a.isValidKey(apiKey, KeyTypeInference) && !a.isValidKey(apiKey, KeyTypeManagement) {
var isValid bool
switch keyType {
case KeyTypeInference:
// Management keys also work for OpenAI endpoints (higher privilege)
isValid = a.isValidKey(apiKey, KeyTypeInference) || a.isValidKey(apiKey, KeyTypeManagement)
case KeyTypeManagement:
isValid = a.isValidKey(apiKey, KeyTypeManagement)
default:
isValid = false
}
if !isValid {
a.unauthorized(w, "Invalid API key")
return
}
@@ -131,43 +135,12 @@ func (a *APIAuthMiddleware) InferenceMiddleware() func(http.Handler) http.Handle
}
}
// ManagementMiddleware returns middleware for management endpoints
func (a *APIAuthMiddleware) ManagementMiddleware() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !a.requireManagementAuth {
next.ServeHTTP(w, r)
return
}
if r.Method == "OPTIONS" {
next.ServeHTTP(w, r)
return
}
apiKey := a.extractAPIKey(r)
if apiKey == "" {
a.unauthorized(w, "Missing API key")
return
}
// Only management keys work for management endpoints
if !a.isValidKey(apiKey, KeyTypeManagement) {
a.unauthorized(w, "Insufficient privileges - management key required")
return
}
next.ServeHTTP(w, r)
})
}
}
// extractAPIKey extracts the API key from the request
func (a *APIAuthMiddleware) extractAPIKey(r *http.Request) string {
// Check Authorization header: "Bearer sk-..."
if auth := r.Header.Get("Authorization"); auth != "" {
if strings.HasPrefix(auth, "Bearer ") {
return strings.TrimPrefix(auth, "Bearer ")
if after, ok := strings.CutPrefix(auth, "Bearer "); ok {
return after
}
}