mirror of
https://github.com/lordmathis/llamactl.git
synced 2025-12-23 09:34:23 +00:00
Add permission checks to proxies
This commit is contained in:
@@ -33,9 +33,8 @@ type KeyPermission struct {
|
|||||||
CanViewLogs bool
|
CanViewLogs bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateKey generates a cryptographically secure inference API key
|
// GenerateKey generates a cryptographically secure API key with the given prefix
|
||||||
// Format: sk-inference-<64-hex-chars>
|
func GenerateKey(prefix string) (string, error) {
|
||||||
func GenerateKey() (string, error) {
|
|
||||||
// Generate 32 random bytes
|
// Generate 32 random bytes
|
||||||
bytes := make([]byte, 32)
|
bytes := make([]byte, 32)
|
||||||
_, err := rand.Read(bytes)
|
_, err := rand.Read(bytes)
|
||||||
@@ -46,6 +45,5 @@ func GenerateKey() (string, error) {
|
|||||||
// Convert to hex (64 characters)
|
// Convert to hex (64 characters)
|
||||||
hexStr := hex.EncodeToString(bytes)
|
hexStr := hex.EncodeToString(bytes)
|
||||||
|
|
||||||
// Prefix with "sk-inference-"
|
return fmt.Sprintf("%s-%s", prefix, hexStr), nil
|
||||||
return fmt.Sprintf("sk-inference-%s", hexStr), nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,6 +24,38 @@ type CreateKeyRequest struct {
|
|||||||
InstancePermissions []InstancePermission
|
InstancePermissions []InstancePermission
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type CreateKeyResponse struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
PermissionMode auth.PermissionMode `json:"permission_mode"`
|
||||||
|
ExpiresAt *int64 `json:"expires_at"`
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
CreatedAt int64 `json:"created_at"`
|
||||||
|
UpdatedAt int64 `json:"updated_at"`
|
||||||
|
LastUsedAt *int64 `json:"last_used_at"`
|
||||||
|
Key string `json:"key"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type KeyResponse struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
PermissionMode auth.PermissionMode `json:"permission_mode"`
|
||||||
|
ExpiresAt *int64 `json:"expires_at"`
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
CreatedAt int64 `json:"created_at"`
|
||||||
|
UpdatedAt int64 `json:"updated_at"`
|
||||||
|
LastUsedAt *int64 `json:"last_used_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type KeyPermissionResponse struct {
|
||||||
|
InstanceID int `json:"instance_id"`
|
||||||
|
InstanceName string `json:"instance_name"`
|
||||||
|
CanInfer bool `json:"can_infer"`
|
||||||
|
CanViewLogs bool `json:"can_view_logs"`
|
||||||
|
}
|
||||||
|
|
||||||
// CreateInferenceKey handles POST /api/v1/keys
|
// CreateInferenceKey handles POST /api/v1/keys
|
||||||
func (h *Handler) CreateInferenceKey() http.HandlerFunc {
|
func (h *Handler) CreateInferenceKey() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -76,7 +108,7 @@ func (h *Handler) CreateInferenceKey() http.HandlerFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Generate plain-text key
|
// Generate plain-text key
|
||||||
plainTextKey, err := auth.GenerateKey()
|
plainTextKey, err := auth.GenerateKey("llamactl-")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
writeError(w, http.StatusInternalServerError, "key_generation_failed", "Failed to generate API key")
|
writeError(w, http.StatusInternalServerError, "key_generation_failed", "Failed to generate API key")
|
||||||
return
|
return
|
||||||
@@ -118,18 +150,20 @@ func (h *Handler) CreateInferenceKey() http.HandlerFunc {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
writeError(w, http.StatusInternalServerError, "creation_failed", fmt.Sprintf("Failed to create API key: %v", err))
|
writeError(w, http.StatusInternalServerError, "creation_failed", fmt.Sprintf("Failed to create API key: %v", err))
|
||||||
return
|
return
|
||||||
} // Return response with plain-text key (only shown once)
|
}
|
||||||
response := map[string]interface{}{
|
|
||||||
"id": apiKey.ID,
|
// Return response with plain-text key (only shown once)
|
||||||
"name": apiKey.Name,
|
response := CreateKeyResponse{
|
||||||
"user_id": apiKey.UserID,
|
ID: apiKey.ID,
|
||||||
"permission_mode": apiKey.PermissionMode,
|
Name: apiKey.Name,
|
||||||
"expires_at": apiKey.ExpiresAt,
|
UserID: apiKey.UserID,
|
||||||
"enabled": apiKey.Enabled,
|
PermissionMode: apiKey.PermissionMode,
|
||||||
"created_at": apiKey.CreatedAt,
|
ExpiresAt: apiKey.ExpiresAt,
|
||||||
"updated_at": apiKey.UpdatedAt,
|
Enabled: apiKey.Enabled,
|
||||||
"last_used_at": apiKey.LastUsedAt,
|
CreatedAt: apiKey.CreatedAt,
|
||||||
"key": plainTextKey, // Only returned on creation
|
UpdatedAt: apiKey.UpdatedAt,
|
||||||
|
LastUsedAt: apiKey.LastUsedAt,
|
||||||
|
Key: plainTextKey,
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
@@ -148,18 +182,18 @@ func (h *Handler) ListInferenceKeys() http.HandlerFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Remove key_hash from all keys
|
// Remove key_hash from all keys
|
||||||
var response []map[string]interface{}
|
response := make([]KeyResponse, 0, len(keys))
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
response = append(response, map[string]interface{}{
|
response = append(response, KeyResponse{
|
||||||
"id": key.ID,
|
ID: key.ID,
|
||||||
"name": key.Name,
|
Name: key.Name,
|
||||||
"user_id": key.UserID,
|
UserID: key.UserID,
|
||||||
"permission_mode": key.PermissionMode,
|
PermissionMode: key.PermissionMode,
|
||||||
"expires_at": key.ExpiresAt,
|
ExpiresAt: key.ExpiresAt,
|
||||||
"enabled": key.Enabled,
|
Enabled: key.Enabled,
|
||||||
"created_at": key.CreatedAt,
|
CreatedAt: key.CreatedAt,
|
||||||
"updated_at": key.UpdatedAt,
|
UpdatedAt: key.UpdatedAt,
|
||||||
"last_used_at": key.LastUsedAt,
|
LastUsedAt: key.LastUsedAt,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -189,16 +223,16 @@ func (h *Handler) GetInferenceKey() http.HandlerFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Remove key_hash from response
|
// Remove key_hash from response
|
||||||
response := map[string]interface{}{
|
response := KeyResponse{
|
||||||
"id": key.ID,
|
ID: key.ID,
|
||||||
"name": key.Name,
|
Name: key.Name,
|
||||||
"user_id": key.UserID,
|
UserID: key.UserID,
|
||||||
"permission_mode": key.PermissionMode,
|
PermissionMode: key.PermissionMode,
|
||||||
"expires_at": key.ExpiresAt,
|
ExpiresAt: key.ExpiresAt,
|
||||||
"enabled": key.Enabled,
|
Enabled: key.Enabled,
|
||||||
"created_at": key.CreatedAt,
|
CreatedAt: key.CreatedAt,
|
||||||
"updated_at": key.UpdatedAt,
|
UpdatedAt: key.UpdatedAt,
|
||||||
"last_used_at": key.LastUsedAt,
|
LastUsedAt: key.LastUsedAt,
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
@@ -268,13 +302,13 @@ func (h *Handler) GetInferenceKeyPermissions() http.HandlerFunc {
|
|||||||
instanceNameMap[inst.ID] = inst.Name
|
instanceNameMap[inst.ID] = inst.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
var response []map[string]interface{}
|
response := make([]KeyPermissionResponse, 0, len(permissions))
|
||||||
for _, perm := range permissions {
|
for _, perm := range permissions {
|
||||||
response = append(response, map[string]interface{}{
|
response = append(response, KeyPermissionResponse{
|
||||||
"instance_id": perm.InstanceID,
|
InstanceID: perm.InstanceID,
|
||||||
"instance_name": instanceNameMap[perm.InstanceID],
|
InstanceName: instanceNameMap[perm.InstanceID],
|
||||||
"can_infer": perm.CanInfer,
|
CanInfer: perm.CanInfer,
|
||||||
"can_view_logs": perm.CanViewLogs,
|
CanViewLogs: perm.CanViewLogs,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -109,6 +109,12 @@ func (h *Handler) LlamaCppProxy() http.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check instance permissions
|
||||||
|
if err := h.authMiddleware.CheckInstancePermission(r.Context(), inst.ID); err != nil {
|
||||||
|
writeError(w, http.StatusForbidden, "permission_denied", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Check if instance is shutting down before autostart logic
|
// Check if instance is shutting down before autostart logic
|
||||||
if inst.GetStatus() == instance.ShuttingDown {
|
if inst.GetStatus() == instance.ShuttingDown {
|
||||||
writeError(w, http.StatusServiceUnavailable, "instance_shutting_down", "Instance is shutting down")
|
writeError(w, http.StatusServiceUnavailable, "instance_shutting_down", "Instance is shutting down")
|
||||||
|
|||||||
@@ -327,6 +327,12 @@ func (h *Handler) InstanceProxy() http.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check instance permissions
|
||||||
|
if err := h.authMiddleware.CheckInstancePermission(r.Context(), inst.ID); err != nil {
|
||||||
|
writeError(w, http.StatusForbidden, "permission_denied", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if !inst.IsRunning() {
|
if !inst.IsRunning() {
|
||||||
writeError(w, http.StatusServiceUnavailable, "instance_not_running", "Instance is not running")
|
writeError(w, http.StatusServiceUnavailable, "instance_not_running", "Instance is not running")
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -107,6 +107,12 @@ func (h *Handler) OpenAIProxy() http.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check instance permissions
|
||||||
|
if err := h.authMiddleware.CheckInstancePermission(r.Context(), inst.ID); err != nil {
|
||||||
|
writeError(w, http.StatusForbidden, "permission_denied", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Check if instance is shutting down before autostart logic
|
// Check if instance is shutting down before autostart logic
|
||||||
if inst.GetStatus() == instance.ShuttingDown {
|
if inst.GetStatus() == instance.ShuttingDown {
|
||||||
writeError(w, http.StatusServiceUnavailable, "instance_shutting_down", "Instance is shutting down")
|
writeError(w, http.StatusServiceUnavailable, "instance_shutting_down", "Instance is shutting down")
|
||||||
|
|||||||
@@ -2,9 +2,7 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
|
||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
"encoding/hex"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"llamactl/pkg/auth"
|
"llamactl/pkg/auth"
|
||||||
"llamactl/pkg/config"
|
"llamactl/pkg/config"
|
||||||
@@ -16,13 +14,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type KeyType int
|
|
||||||
|
|
||||||
const (
|
|
||||||
KeyTypeInference KeyType = iota
|
|
||||||
KeyTypeManagement
|
|
||||||
)
|
|
||||||
|
|
||||||
// contextKey is a custom type for context keys to avoid collisions
|
// contextKey is a custom type for context keys to avoid collisions
|
||||||
type contextKey string
|
type contextKey string
|
||||||
|
|
||||||
@@ -56,7 +47,12 @@ func NewAPIAuthMiddleware(authCfg config.AuthConfig, authStore database.AuthStor
|
|||||||
const banner = "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
const banner = "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||||
|
|
||||||
if authCfg.RequireManagementAuth && len(authCfg.ManagementKeys) == 0 {
|
if authCfg.RequireManagementAuth && len(authCfg.ManagementKeys) == 0 {
|
||||||
key := generateAPIKey(KeyTypeManagement)
|
key, err := auth.GenerateKey("llamactl-mgmt-")
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Warning: Failed to generate management key: %v", err)
|
||||||
|
// Fallback to PID-based key for safety
|
||||||
|
key = fmt.Sprintf("sk-management-fallback-%d", os.Getpid())
|
||||||
|
}
|
||||||
managementKeys[key] = true
|
managementKeys[key] = true
|
||||||
generated = true
|
generated = true
|
||||||
fmt.Printf("%s\n⚠️ MANAGEMENT AUTHENTICATION REQUIRED\n%s\n", banner, banner)
|
fmt.Printf("%s\n⚠️ MANAGEMENT AUTHENTICATION REQUIRED\n%s\n", banner, banner)
|
||||||
@@ -79,32 +75,6 @@ func NewAPIAuthMiddleware(authCfg config.AuthConfig, authStore database.AuthStor
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateAPIKey creates a cryptographically secure API key
|
|
||||||
func generateAPIKey(keyType KeyType) string {
|
|
||||||
// Generate 32 random bytes (256 bits)
|
|
||||||
randomBytes := make([]byte, 32)
|
|
||||||
|
|
||||||
var prefix string
|
|
||||||
|
|
||||||
switch keyType {
|
|
||||||
case KeyTypeInference:
|
|
||||||
prefix = "sk-inference"
|
|
||||||
case KeyTypeManagement:
|
|
||||||
prefix = "sk-management"
|
|
||||||
default:
|
|
||||||
prefix = "sk-unknown"
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := rand.Read(randomBytes); err != nil {
|
|
||||||
log.Printf("Warning: Failed to generate secure random key, using fallback")
|
|
||||||
// Fallback to a less secure method if crypto/rand fails
|
|
||||||
return fmt.Sprintf("%s-fallback-%d", prefix, os.Getpid())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert to hex and add prefix
|
|
||||||
return fmt.Sprintf("%s-%s", prefix, hex.EncodeToString(randomBytes))
|
|
||||||
}
|
|
||||||
|
|
||||||
// InferenceAuthMiddleware returns middleware for inference endpoints
|
// InferenceAuthMiddleware returns middleware for inference endpoints
|
||||||
func (a *APIAuthMiddleware) InferenceAuthMiddleware() func(http.Handler) http.Handler {
|
func (a *APIAuthMiddleware) InferenceAuthMiddleware() func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
@@ -123,7 +93,7 @@ func (a *APIAuthMiddleware) InferenceAuthMiddleware() func(http.Handler) http.Ha
|
|||||||
|
|
||||||
// Try database authentication first
|
// Try database authentication first
|
||||||
var foundKey *auth.APIKey
|
var foundKey *auth.APIKey
|
||||||
if a.requireInferenceAuth {
|
if a.requireInferenceAuth && a.authStore != nil {
|
||||||
activeKeys, err := a.authStore.GetActiveKeys(r.Context())
|
activeKeys, err := a.authStore.GetActiveKeys(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to get active inference keys: %v", err)
|
log.Printf("Failed to get active inference keys: %v", err)
|
||||||
@@ -208,80 +178,16 @@ func (a *APIAuthMiddleware) CheckInstancePermission(ctx context.Context, instanc
|
|||||||
// Check per-instance permissions
|
// Check per-instance permissions
|
||||||
canInfer, err := a.authStore.HasPermission(ctx, apiKey.ID, instanceID)
|
canInfer, err := a.authStore.HasPermission(ctx, apiKey.ID, instanceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to check permission: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !canInfer {
|
if !canInfer {
|
||||||
return http.ErrBodyNotAllowed // Use this as a generic error to indicate permission denied
|
return fmt.Errorf("permission denied: key does not have access to this instance")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthMiddleware returns a middleware that checks API keys for the given key type (legacy support)
|
|
||||||
func (a *APIAuthMiddleware) AuthMiddleware(keyType KeyType) func(http.Handler) http.Handler {
|
|
||||||
return func(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.Method == "OPTIONS" {
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
apiKey := a.extractAPIKey(r)
|
|
||||||
if apiKey == "" {
|
|
||||||
a.unauthorized(w, "Missing API key")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var isValid bool
|
|
||||||
switch keyType {
|
|
||||||
case KeyTypeInference:
|
|
||||||
// Try database authentication first
|
|
||||||
if a.requireInferenceAuth {
|
|
||||||
activeKeys, err := a.authStore.GetActiveKeys(r.Context())
|
|
||||||
if err == nil {
|
|
||||||
for _, key := range activeKeys {
|
|
||||||
if auth.VerifyKey(apiKey, key.KeyHash) {
|
|
||||||
foundKey := key
|
|
||||||
// Async update last_used_at
|
|
||||||
go func(keyID int) {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := a.authStore.TouchKey(ctx, keyID); err != nil {
|
|
||||||
log.Printf("Failed to update last used timestamp for key %d: %v", keyID, err)
|
|
||||||
}
|
|
||||||
}(key.ID)
|
|
||||||
|
|
||||||
// Add APIKey to context for permission checking
|
|
||||||
ctx := context.WithValue(r.Context(), apiKeyContextKey, foundKey)
|
|
||||||
r = r.WithContext(ctx)
|
|
||||||
isValid = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If no database key found, try management key (higher privilege)
|
|
||||||
if !isValid {
|
|
||||||
isValid = a.isValidManagementKey(apiKey)
|
|
||||||
}
|
|
||||||
case KeyTypeManagement:
|
|
||||||
isValid = a.isValidManagementKey(apiKey)
|
|
||||||
default:
|
|
||||||
isValid = false
|
|
||||||
}
|
|
||||||
|
|
||||||
if !isValid {
|
|
||||||
a.unauthorized(w, "Invalid API key")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractAPIKey extracts the API key from the request
|
// extractAPIKey extracts the API key from the request
|
||||||
func (a *APIAuthMiddleware) extractAPIKey(r *http.Request) string {
|
func (a *APIAuthMiddleware) extractAPIKey(r *http.Request) string {
|
||||||
// Check Authorization header: "Bearer sk-..."
|
// Check Authorization header: "Bearer sk-..."
|
||||||
|
|||||||
@@ -9,99 +9,44 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAuthMiddleware(t *testing.T) {
|
func TestInferenceAuthMiddleware(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
keyType server.KeyType
|
|
||||||
inferenceKeys []string
|
|
||||||
managementKeys []string
|
managementKeys []string
|
||||||
requestKey string
|
requestKey string
|
||||||
method string
|
method string
|
||||||
expectedStatus int
|
expectedStatus int
|
||||||
}{
|
}{
|
||||||
// Valid key tests - using management keys only since config-based inference keys are deprecated
|
|
||||||
{
|
{
|
||||||
name: "valid management key for inference", // Management keys work for inference
|
name: "valid management key for inference",
|
||||||
keyType: server.KeyTypeInference,
|
|
||||||
managementKeys: []string{"sk-management-admin123"},
|
managementKeys: []string{"sk-management-admin123"},
|
||||||
requestKey: "sk-management-admin123",
|
requestKey: "sk-management-admin123",
|
||||||
method: "GET",
|
method: "GET",
|
||||||
expectedStatus: http.StatusOK,
|
expectedStatus: http.StatusOK,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "valid management key for management",
|
name: "invalid key",
|
||||||
keyType: server.KeyTypeManagement,
|
|
||||||
managementKeys: []string{"sk-management-admin123"},
|
|
||||||
requestKey: "sk-management-admin123",
|
|
||||||
method: "GET",
|
|
||||||
expectedStatus: http.StatusOK,
|
|
||||||
},
|
|
||||||
|
|
||||||
// Invalid key tests
|
|
||||||
{
|
|
||||||
name: "inference key for management should fail",
|
|
||||||
keyType: server.KeyTypeManagement,
|
|
||||||
inferenceKeys: []string{"sk-inference-user123"},
|
|
||||||
requestKey: "sk-inference-user123",
|
|
||||||
method: "GET",
|
|
||||||
expectedStatus: http.StatusUnauthorized,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid inference key",
|
|
||||||
keyType: server.KeyTypeInference,
|
|
||||||
inferenceKeys: []string{"sk-inference-valid123"},
|
|
||||||
requestKey: "sk-inference-invalid",
|
|
||||||
method: "GET",
|
|
||||||
expectedStatus: http.StatusUnauthorized,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "missing inference key",
|
|
||||||
keyType: server.KeyTypeInference,
|
|
||||||
inferenceKeys: []string{"sk-inference-valid123"},
|
|
||||||
requestKey: "",
|
|
||||||
method: "GET",
|
|
||||||
expectedStatus: http.StatusUnauthorized,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid management key",
|
|
||||||
keyType: server.KeyTypeManagement,
|
|
||||||
managementKeys: []string{"sk-management-valid123"},
|
managementKeys: []string{"sk-management-valid123"},
|
||||||
requestKey: "sk-management-invalid",
|
requestKey: "sk-management-invalid",
|
||||||
method: "GET",
|
method: "GET",
|
||||||
expectedStatus: http.StatusUnauthorized,
|
expectedStatus: http.StatusUnauthorized,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "missing management key",
|
name: "missing key",
|
||||||
keyType: server.KeyTypeManagement,
|
|
||||||
managementKeys: []string{"sk-management-valid123"},
|
managementKeys: []string{"sk-management-valid123"},
|
||||||
requestKey: "",
|
requestKey: "",
|
||||||
method: "GET",
|
method: "GET",
|
||||||
expectedStatus: http.StatusUnauthorized,
|
expectedStatus: http.StatusUnauthorized,
|
||||||
},
|
},
|
||||||
|
|
||||||
// OPTIONS requests should always pass
|
|
||||||
{
|
{
|
||||||
name: "OPTIONS request bypasses inference auth",
|
name: "OPTIONS request bypasses auth",
|
||||||
keyType: server.KeyTypeInference,
|
|
||||||
inferenceKeys: []string{"sk-inference-valid123"},
|
|
||||||
requestKey: "",
|
|
||||||
method: "OPTIONS",
|
|
||||||
expectedStatus: http.StatusOK,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "OPTIONS request bypasses management auth",
|
|
||||||
keyType: server.KeyTypeManagement,
|
|
||||||
managementKeys: []string{"sk-management-valid123"},
|
managementKeys: []string{"sk-management-valid123"},
|
||||||
requestKey: "",
|
requestKey: "",
|
||||||
method: "OPTIONS",
|
method: "OPTIONS",
|
||||||
expectedStatus: http.StatusOK,
|
expectedStatus: http.StatusOK,
|
||||||
},
|
},
|
||||||
|
|
||||||
// Cross-key-type validation
|
|
||||||
{
|
{
|
||||||
name: "management key works for inference endpoint",
|
name: "management key works for inference endpoint",
|
||||||
keyType: server.KeyTypeInference,
|
|
||||||
inferenceKeys: []string{},
|
|
||||||
managementKeys: []string{"sk-management-admin"},
|
managementKeys: []string{"sk-management-admin"},
|
||||||
requestKey: "sk-management-admin",
|
requestKey: "sk-management-admin",
|
||||||
method: "POST",
|
method: "POST",
|
||||||
@@ -112,7 +57,7 @@ func TestAuthMiddleware(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
cfg := config.AuthConfig{
|
cfg := config.AuthConfig{
|
||||||
InferenceKeys: tt.inferenceKeys,
|
RequireInferenceAuth: true,
|
||||||
ManagementKeys: tt.managementKeys,
|
ManagementKeys: tt.managementKeys,
|
||||||
}
|
}
|
||||||
middleware := server.NewAPIAuthMiddleware(cfg, nil)
|
middleware := server.NewAPIAuthMiddleware(cfg, nil)
|
||||||
@@ -123,24 +68,17 @@ func TestAuthMiddleware(t *testing.T) {
|
|||||||
req.Header.Set("Authorization", "Bearer "+tt.requestKey)
|
req.Header.Set("Authorization", "Bearer "+tt.requestKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create test handler using appropriate middleware
|
// Create test handler
|
||||||
var handler http.Handler
|
handler := middleware.InferenceAuthMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if tt.keyType == server.KeyTypeInference {
|
|
||||||
handler = middleware.AuthMiddleware(server.KeyTypeInference)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
}))
|
}))
|
||||||
} else {
|
|
||||||
handler = middleware.AuthMiddleware(server.KeyTypeManagement)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute request
|
// Execute request
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
handler.ServeHTTP(recorder, req)
|
handler.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
if recorder.Code != tt.expectedStatus {
|
if recorder.Code != tt.expectedStatus {
|
||||||
t.Errorf("AuthMiddleware() status = %v, expected %v", recorder.Code, tt.expectedStatus)
|
t.Errorf("InferenceAuthMiddleware() status = %v, expected %v", recorder.Code, tt.expectedStatus)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check that unauthorized responses have proper format
|
// Check that unauthorized responses have proper format
|
||||||
@@ -159,25 +97,92 @@ func TestAuthMiddleware(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGenerateAPIKey(t *testing.T) {
|
func TestManagementAuthMiddleware(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
keyType server.KeyType
|
managementKeys []string
|
||||||
|
requestKey string
|
||||||
|
method string
|
||||||
|
expectedStatus int
|
||||||
}{
|
}{
|
||||||
{"inference key generation", server.KeyTypeInference},
|
{
|
||||||
{"management key generation", server.KeyTypeManagement},
|
name: "valid management key",
|
||||||
|
managementKeys: []string{"sk-management-admin123"},
|
||||||
|
requestKey: "sk-management-admin123",
|
||||||
|
method: "GET",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid management key",
|
||||||
|
managementKeys: []string{"sk-management-valid123"},
|
||||||
|
requestKey: "sk-management-invalid",
|
||||||
|
method: "GET",
|
||||||
|
expectedStatus: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing management key",
|
||||||
|
managementKeys: []string{"sk-management-valid123"},
|
||||||
|
requestKey: "",
|
||||||
|
method: "GET",
|
||||||
|
expectedStatus: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "OPTIONS request bypasses management auth",
|
||||||
|
managementKeys: []string{"sk-management-valid123"},
|
||||||
|
requestKey: "",
|
||||||
|
method: "OPTIONS",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
// Test auto-generation by creating config that will trigger it
|
cfg := config.AuthConfig{
|
||||||
var config config.AuthConfig
|
RequireManagementAuth: true,
|
||||||
if tt.keyType == server.KeyTypeInference {
|
ManagementKeys: tt.managementKeys,
|
||||||
config.RequireInferenceAuth = true
|
}
|
||||||
config.InferenceKeys = []string{} // Empty to trigger generation
|
middleware := server.NewAPIAuthMiddleware(cfg, nil)
|
||||||
} else {
|
|
||||||
config.RequireManagementAuth = true
|
// Create test request
|
||||||
config.ManagementKeys = []string{} // Empty to trigger generation
|
req := httptest.NewRequest(tt.method, "/test", nil)
|
||||||
|
if tt.requestKey != "" {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+tt.requestKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create test handler
|
||||||
|
handler := middleware.ManagementAuthMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Execute request
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
if recorder.Code != tt.expectedStatus {
|
||||||
|
t.Errorf("ManagementAuthMiddleware() status = %v, expected %v", recorder.Code, tt.expectedStatus)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that unauthorized responses have proper format
|
||||||
|
if recorder.Code == http.StatusUnauthorized {
|
||||||
|
contentType := recorder.Header().Get("Content-Type")
|
||||||
|
if contentType != "application/json" {
|
||||||
|
t.Errorf("Unauthorized response Content-Type = %v, expected application/json", contentType)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := recorder.Body.String()
|
||||||
|
if !strings.Contains(body, `"type": "authentication_error"`) {
|
||||||
|
t.Errorf("Unauthorized response missing proper error type: %v", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagementKeyAutoGeneration(t *testing.T) {
|
||||||
|
// Test auto-generation for management keys
|
||||||
|
config := config.AuthConfig{
|
||||||
|
RequireManagementAuth: true,
|
||||||
|
ManagementKeys: []string{}, // Empty to trigger generation
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create middleware - this should trigger key generation
|
// Create middleware - this should trigger key generation
|
||||||
@@ -187,16 +192,9 @@ func TestGenerateAPIKey(t *testing.T) {
|
|||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
var handler http.Handler
|
handler := middleware.ManagementAuthMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if tt.keyType == server.KeyTypeInference {
|
|
||||||
handler = middleware.AuthMiddleware(server.KeyTypeInference)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
}))
|
}))
|
||||||
} else {
|
|
||||||
handler = middleware.AuthMiddleware(server.KeyTypeManagement)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
handler.ServeHTTP(recorder, req)
|
handler.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
@@ -211,126 +209,59 @@ func TestGenerateAPIKey(t *testing.T) {
|
|||||||
req2 := httptest.NewRequest("GET", "/", nil)
|
req2 := httptest.NewRequest("GET", "/", nil)
|
||||||
recorder2 := httptest.NewRecorder()
|
recorder2 := httptest.NewRecorder()
|
||||||
|
|
||||||
if tt.keyType == server.KeyTypeInference {
|
handler2 := middleware2.ManagementAuthMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
handler2 := middleware2.AuthMiddleware(server.KeyTypeInference)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
}))
|
}))
|
||||||
handler2.ServeHTTP(recorder2, req2)
|
handler2.ServeHTTP(recorder2, req2)
|
||||||
} else {
|
|
||||||
handler2 := middleware2.AuthMiddleware(server.KeyTypeManagement)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
}))
|
|
||||||
handler2.ServeHTTP(recorder2, req2)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Both should require auth (proving keys were generated for both instances)
|
// Both should require auth (proving keys were generated for both instances)
|
||||||
if recorder2.Code != http.StatusUnauthorized {
|
if recorder2.Code != http.StatusUnauthorized {
|
||||||
t.Errorf("Expected unauthorized for second middleware without key, got status %v", recorder2.Code)
|
t.Errorf("Expected unauthorized for second middleware without key, got status %v", recorder2.Code)
|
||||||
}
|
}
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAutoGeneration(t *testing.T) {
|
func TestAutoGenerationScenarios(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
requireInference bool
|
|
||||||
requireManagement bool
|
requireManagement bool
|
||||||
providedInference []string
|
|
||||||
providedManagement []string
|
providedManagement []string
|
||||||
shouldGenerateInf bool // Whether inference key should be generated
|
shouldGenerate bool
|
||||||
shouldGenerateMgmt bool // Whether management key should be generated
|
|
||||||
}{
|
}{
|
||||||
{
|
|
||||||
name: "inference auth required, keys provided - no generation",
|
|
||||||
requireInference: true,
|
|
||||||
requireManagement: false,
|
|
||||||
providedInference: []string{"sk-inference-provided"},
|
|
||||||
providedManagement: []string{},
|
|
||||||
shouldGenerateInf: false,
|
|
||||||
shouldGenerateMgmt: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "inference auth required, no keys - should auto-generate",
|
|
||||||
requireInference: true,
|
|
||||||
requireManagement: false,
|
|
||||||
providedInference: []string{},
|
|
||||||
providedManagement: []string{},
|
|
||||||
shouldGenerateInf: true,
|
|
||||||
shouldGenerateMgmt: false,
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
name: "management auth required, keys provided - no generation",
|
name: "management auth required, keys provided - no generation",
|
||||||
requireInference: false,
|
|
||||||
requireManagement: true,
|
requireManagement: true,
|
||||||
providedInference: []string{},
|
|
||||||
providedManagement: []string{"sk-management-provided"},
|
providedManagement: []string{"sk-management-provided"},
|
||||||
shouldGenerateInf: false,
|
shouldGenerate: false,
|
||||||
shouldGenerateMgmt: false,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "management auth required, no keys - should auto-generate",
|
name: "management auth required, no keys - should auto-generate",
|
||||||
requireInference: false,
|
|
||||||
requireManagement: true,
|
requireManagement: true,
|
||||||
providedInference: []string{},
|
|
||||||
providedManagement: []string{},
|
providedManagement: []string{},
|
||||||
shouldGenerateInf: false,
|
shouldGenerate: true,
|
||||||
shouldGenerateMgmt: true,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "both required, both provided - no generation",
|
name: "management auth not required - no generation",
|
||||||
requireInference: true,
|
requireManagement: false,
|
||||||
requireManagement: true,
|
|
||||||
providedInference: []string{"sk-inference-provided"},
|
|
||||||
providedManagement: []string{"sk-management-provided"},
|
|
||||||
shouldGenerateInf: false,
|
|
||||||
shouldGenerateMgmt: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "both required, none provided - should auto-generate both",
|
|
||||||
requireInference: true,
|
|
||||||
requireManagement: true,
|
|
||||||
providedInference: []string{},
|
|
||||||
providedManagement: []string{},
|
providedManagement: []string{},
|
||||||
shouldGenerateInf: true,
|
shouldGenerate: false,
|
||||||
shouldGenerateMgmt: true,
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
cfg := config.AuthConfig{
|
cfg := config.AuthConfig{
|
||||||
RequireInferenceAuth: tt.requireInference,
|
|
||||||
RequireManagementAuth: tt.requireManagement,
|
RequireManagementAuth: tt.requireManagement,
|
||||||
InferenceKeys: tt.providedInference,
|
|
||||||
ManagementKeys: tt.providedManagement,
|
ManagementKeys: tt.providedManagement,
|
||||||
}
|
}
|
||||||
|
|
||||||
middleware := server.NewAPIAuthMiddleware(cfg, nil)
|
middleware := server.NewAPIAuthMiddleware(cfg, nil)
|
||||||
|
|
||||||
// Test inference behavior if inference auth is required
|
|
||||||
if tt.requireInference {
|
|
||||||
req := httptest.NewRequest("GET", "/v1/models", nil)
|
|
||||||
recorder := httptest.NewRecorder()
|
|
||||||
|
|
||||||
handler := middleware.AuthMiddleware(server.KeyTypeInference)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
}))
|
|
||||||
|
|
||||||
handler.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
// Should always be unauthorized without a key (since middleware assumes auth is required)
|
|
||||||
if recorder.Code != http.StatusUnauthorized {
|
|
||||||
t.Errorf("Expected unauthorized for inference without key, got status %v", recorder.Code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test management behavior if management auth is required
|
// Test management behavior if management auth is required
|
||||||
if tt.requireManagement {
|
if tt.requireManagement {
|
||||||
req := httptest.NewRequest("GET", "/api/v1/instances", nil)
|
req := httptest.NewRequest("GET", "/api/v1/instances", nil)
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
handler := middleware.AuthMiddleware(server.KeyTypeManagement)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := middleware.ManagementAuthMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
@@ -344,3 +275,16 @@ func TestAutoGeneration(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConfigBasedInferenceKeysDeprecationWarning(t *testing.T) {
|
||||||
|
// Test that config-based inference keys trigger a warning (captured in logs)
|
||||||
|
cfg := config.AuthConfig{
|
||||||
|
InferenceKeys: []string{"sk-inference-old"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Creating middleware should log a warning, but shouldn't fail
|
||||||
|
_ = server.NewAPIAuthMiddleware(cfg, nil)
|
||||||
|
|
||||||
|
// If we get here without panic, the test passes
|
||||||
|
// The warning is logged but not returned as an error
|
||||||
|
}
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ func SetupRouter(handler *Handler) *chi.Mux {
|
|||||||
r.Route("/api/v1", func(r chi.Router) {
|
r.Route("/api/v1", func(r chi.Router) {
|
||||||
|
|
||||||
if authMiddleware != nil && handler.cfg.Auth.RequireManagementAuth {
|
if authMiddleware != nil && handler.cfg.Auth.RequireManagementAuth {
|
||||||
r.Use(authMiddleware.AuthMiddleware(KeyTypeManagement))
|
r.Use(authMiddleware.ManagementAuthMiddleware())
|
||||||
}
|
}
|
||||||
|
|
||||||
r.Get("/version", handler.VersionHandler())
|
r.Get("/version", handler.VersionHandler())
|
||||||
@@ -108,7 +108,7 @@ func SetupRouter(handler *Handler) *chi.Mux {
|
|||||||
r.Route("/v1", func(r chi.Router) {
|
r.Route("/v1", func(r chi.Router) {
|
||||||
|
|
||||||
if authMiddleware != nil && handler.cfg.Auth.RequireInferenceAuth {
|
if authMiddleware != nil && handler.cfg.Auth.RequireInferenceAuth {
|
||||||
r.Use(authMiddleware.AuthMiddleware(KeyTypeInference))
|
r.Use(authMiddleware.InferenceAuthMiddleware())
|
||||||
}
|
}
|
||||||
|
|
||||||
r.Get("/models", handler.OpenAIListInstances()) // List instances in OpenAI-compatible format
|
r.Get("/models", handler.OpenAIListInstances()) // List instances in OpenAI-compatible format
|
||||||
@@ -136,7 +136,7 @@ func SetupRouter(handler *Handler) *chi.Mux {
|
|||||||
r.Group(func(r chi.Router) {
|
r.Group(func(r chi.Router) {
|
||||||
|
|
||||||
if authMiddleware != nil && handler.cfg.Auth.RequireInferenceAuth {
|
if authMiddleware != nil && handler.cfg.Auth.RequireInferenceAuth {
|
||||||
r.Use(authMiddleware.AuthMiddleware(KeyTypeInference))
|
r.Use(authMiddleware.InferenceAuthMiddleware())
|
||||||
}
|
}
|
||||||
|
|
||||||
// This handler auto starts the server if it's not running
|
// This handler auto starts the server if it's not running
|
||||||
|
|||||||
Reference in New Issue
Block a user