From b3540d5b3e0899bb4338b260daf9b29850131d27 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Wed, 30 Jul 2025 20:15:14 +0200 Subject: [PATCH 1/4] Implement api key auth --- pkg/config.go | 52 +++++++++++ pkg/middleware.go | 215 ++++++++++++++++++++++++++++++++++++++++++++++ pkg/routes.go | 44 +++++++--- 3 files changed, 298 insertions(+), 13 deletions(-) create mode 100644 pkg/middleware.go diff --git a/pkg/config.go b/pkg/config.go index 7e7256a..d5b4571 100644 --- a/pkg/config.go +++ b/pkg/config.go @@ -14,6 +14,7 @@ import ( type Config struct { Server ServerConfig `yaml:"server"` Instances InstancesConfig `yaml:"instances"` + Auth AuthConfig `yaml:"auth"` } // ServerConfig contains HTTP server configuration @@ -26,6 +27,9 @@ type ServerConfig struct { // Allowed origins for CORS (e.g., "http://localhost:3000") AllowedOrigins []string `yaml:"allowed_origins"` + + // Enable Swagger UI for API documentation + EnableSwagger bool `yaml:"enable_swagger"` } // InstancesConfig contains instance management configuration @@ -52,6 +56,22 @@ type InstancesConfig struct { DefaultRestartDelay int `yaml:"default_restart_delay"` } +// AuthConfig contains authentication settings +type AuthConfig struct { + + // Require authentication for OpenAI compatible inference endpoints + RequireInferenceAuth bool `yaml:"require_inference_auth"` + + // List of keys for OpenAI compatible inference endpoints + InferenceKeys []string `yaml:"inference_keys"` + + // Require authentication for management endpoints + RequireManagementAuth bool `yaml:"require_management_auth"` + + // List of keys for management endpoints + ManagementKeys []string `yaml:"management_keys"` +} + // LoadConfig loads configuration with the following precedence: // 1. Hardcoded defaults // 2. 
Config file @@ -63,6 +83,7 @@ func LoadConfig(configPath string) (Config, error) { Host: "0.0.0.0", Port: 8080, AllowedOrigins: []string{"*"}, // Default to allow all origins + EnableSwagger: false, }, Instances: InstancesConfig{ PortRange: [2]int{8000, 9000}, @@ -73,6 +94,12 @@ func LoadConfig(configPath string) (Config, error) { DefaultMaxRestarts: 3, DefaultRestartDelay: 5, }, + Auth: AuthConfig{ + RequireInferenceAuth: true, + InferenceKeys: []string{}, + RequireManagementAuth: true, + ManagementKeys: []string{}, + }, } // 2. Load from config file @@ -121,6 +148,14 @@ func loadEnvVars(cfg *Config) { cfg.Server.Port = p } } + if allowedOrigins := os.Getenv("LLAMACTL_ALLOWED_ORIGINS"); allowedOrigins != "" { + cfg.Server.AllowedOrigins = strings.Split(allowedOrigins, ",") + } + if enableSwagger := os.Getenv("LLAMACTL_ENABLE_SWAGGER"); enableSwagger != "" { + if b, err := strconv.ParseBool(enableSwagger); err == nil { + cfg.Server.EnableSwagger = b + } + } // Instance config if portRange := os.Getenv("LLAMACTL_INSTANCE_PORT_RANGE"); portRange != "" { @@ -154,6 +189,23 @@ func loadEnvVars(cfg *Config) { cfg.Instances.DefaultRestartDelay = seconds } } + // Auth config + if requireInferenceAuth := os.Getenv("LLAMACTL_REQUIRE_INFERENCE_AUTH"); requireInferenceAuth != "" { + if b, err := strconv.ParseBool(requireInferenceAuth); err == nil { + cfg.Auth.RequireInferenceAuth = b + } + } + if inferenceKeys := os.Getenv("LLAMACTL_INFERENCE_KEYS"); inferenceKeys != "" { + cfg.Auth.InferenceKeys = strings.Split(inferenceKeys, ",") + } + if requireManagementAuth := os.Getenv("LLAMACTL_REQUIRE_MANAGEMENT_AUTH"); requireManagementAuth != "" { + if b, err := strconv.ParseBool(requireManagementAuth); err == nil { + cfg.Auth.RequireManagementAuth = b + } + } + if managementKeys := os.Getenv("LLAMACTL_MANAGEMENT_KEYS"); managementKeys != "" { + cfg.Auth.ManagementKeys = strings.Split(managementKeys, ",") + } } // ParsePortRange parses port range from string formats like "8000-9000" 
or "8000,9000" diff --git a/pkg/middleware.go b/pkg/middleware.go new file mode 100644 index 0000000..13cc800 --- /dev/null +++ b/pkg/middleware.go @@ -0,0 +1,215 @@ +package llamactl + +import ( + "crypto/rand" + "crypto/subtle" + "encoding/hex" + "fmt" + "log" + "net/http" + "os" + "strings" +) + +type KeyType int + +const ( + KeyTypeInference KeyType = iota + KeyTypeManagement +) + +type APIAuthMiddleware struct { + requireInferenceAuth bool + inferenceKeys map[string]bool + requireManagementAuth bool + managementKeys map[string]bool +} + +// NewAPIAuthMiddleware creates a new APIAuthMiddleware with the given configuration +func NewAPIAuthMiddleware(config AuthConfig) *APIAuthMiddleware { + + var generated bool = false + + inferenceAPIKeys := make(map[string]bool) + managementAPIKeys := make(map[string]bool) + + const banner = "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + + if config.RequireManagementAuth && len(config.ManagementKeys) == 0 { + key := generateAPIKey(KeyTypeManagement) + managementAPIKeys[key] = true + generated = true + fmt.Printf("%s\n⚠️ MANAGEMENT AUTHENTICATION REQUIRED\n%s\n", banner, banner) + fmt.Printf("🔑 Generated Management API Key:\n\n %s\n\n", key) + } + for _, key := range config.ManagementKeys { + managementAPIKeys[key] = true + } + + if config.RequireInferenceAuth && len(config.InferenceKeys) == 0 { + key := generateAPIKey(KeyTypeInference) + inferenceAPIKeys[key] = true + generated = true + fmt.Printf("%s\n⚠️ INFERENCE AUTHENTICATION REQUIRED\n%s\n", banner, banner) + fmt.Printf("🔑 Generated Inference API Key:\n\n %s\n\n", key) + } + for _, key := range config.InferenceKeys { + inferenceAPIKeys[key] = true + } + + if generated { + fmt.Printf("%s\n⚠️ IMPORTANT\n%s\n", banner, banner) + fmt.Println("• These keys are auto-generated and will change on restart") + fmt.Println("• For production, add explicit keys to your configuration") + fmt.Println("• Copy these keys before they disappear from the terminal") + fmt.Println(banner) + 
} + + return &APIAuthMiddleware{ + requireInferenceAuth: config.RequireInferenceAuth, + inferenceKeys: inferenceAPIKeys, + requireManagementAuth: config.RequireManagementAuth, + managementKeys: managementAPIKeys, + } +} + +// generateAPIKey creates a cryptographically secure API key +func generateAPIKey(keyType KeyType) string { + // Generate 32 random bytes (256 bits) + randomBytes := make([]byte, 32) + + var prefix string + + switch keyType { + case KeyTypeInference: + prefix = "sk-inference" + case KeyTypeManagement: + prefix = "sk-management" + default: + prefix = "sk-unknown" + } + + if _, err := rand.Read(randomBytes); err != nil { + log.Printf("Warning: Failed to generate secure random key, using fallback") + // Fallback to a less secure method if crypto/rand fails + return fmt.Sprintf("%s-fallback-%d", prefix, os.Getpid()) + } + + // Convert to hex and add prefix + return fmt.Sprintf("%s-%s", prefix, hex.EncodeToString(randomBytes)) +} + +// InferenceMiddleware returns middleware for OpenAI inference endpoints +func (a *APIAuthMiddleware) InferenceMiddleware() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !a.requireInferenceAuth { + next.ServeHTTP(w, r) + return + } + + if r.Method == "OPTIONS" { + next.ServeHTTP(w, r) + return + } + + apiKey := a.extractAPIKey(r) + if apiKey == "" { + a.unauthorized(w, "Missing API key") + return + } + + // Check if key is valid for OpenAI access + // Management keys also work for OpenAI endpoints (higher privilege) + if !a.isValidKey(apiKey, KeyTypeInference) && !a.isValidKey(apiKey, KeyTypeManagement) { + a.unauthorized(w, "Invalid API key") + return + } + + next.ServeHTTP(w, r) + }) + } +} + +// ManagementMiddleware returns middleware for management endpoints +func (a *APIAuthMiddleware) ManagementMiddleware() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return 
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !a.requireManagementAuth { + next.ServeHTTP(w, r) + return + } + + if r.Method == "OPTIONS" { + next.ServeHTTP(w, r) + return + } + + apiKey := a.extractAPIKey(r) + if apiKey == "" { + a.unauthorized(w, "Missing API key") + return + } + + // Only management keys work for management endpoints + if !a.isValidKey(apiKey, KeyTypeManagement) { + a.unauthorized(w, "Insufficient privileges - management key required") + return + } + + next.ServeHTTP(w, r) + }) + } +} + +// extractAPIKey extracts the API key from the request +func (a *APIAuthMiddleware) extractAPIKey(r *http.Request) string { + // Check Authorization header: "Bearer sk-..." + if auth := r.Header.Get("Authorization"); auth != "" { + if strings.HasPrefix(auth, "Bearer ") { + return strings.TrimPrefix(auth, "Bearer ") + } + } + + // Check X-API-Key header + if apiKey := r.Header.Get("X-API-Key"); apiKey != "" { + return apiKey + } + + // Check query parameter + if apiKey := r.URL.Query().Get("api_key"); apiKey != "" { + return apiKey + } + + return "" +} + +// isValidKey checks if the provided API key is valid for the given key type +func (a *APIAuthMiddleware) isValidKey(providedKey string, keyType KeyType) bool { + var validKeys map[string]bool + + switch keyType { + case KeyTypeInference: + validKeys = a.inferenceKeys + case KeyTypeManagement: + validKeys = a.managementKeys + default: + return false + } + + for validKey := range validKeys { + if len(providedKey) == len(validKey) && + subtle.ConstantTimeCompare([]byte(providedKey), []byte(validKey)) == 1 { + return true + } + } + return false +} + +// unauthorized sends an unauthorized response +func (a *APIAuthMiddleware) unauthorized(w http.ResponseWriter, message string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + response := fmt.Sprintf(`{"error": {"message": "%s", "type": "authentication_error"}}`, message) + 
w.Write([]byte(response)) +} diff --git a/pkg/routes.go b/pkg/routes.go index 4685b85..998af38 100644 --- a/pkg/routes.go +++ b/pkg/routes.go @@ -26,12 +26,22 @@ func SetupRouter(handler *Handler) *chi.Mux { MaxAge: 300, })) - r.Get("/swagger/*", httpSwagger.Handler( - httpSwagger.URL("/swagger/doc.json"), - )) + // Add API authentication middleware + authMiddleware := NewAPIAuthMiddleware(handler.config.Auth) + + if handler.config.Server.EnableSwagger { + r.Get("/swagger/*", httpSwagger.Handler( + httpSwagger.URL("/swagger/doc.json"), + )) + } // Define routes r.Route("/api/v1", func(r chi.Router) { + + if authMiddleware != nil && handler.config.Auth.RequireManagementAuth { + r.Use(authMiddleware.ManagementMiddleware()) + } + r.Route("/server", func(r chi.Router) { r.Get("/help", handler.HelpHandler()) r.Get("/version", handler.VersionHandler()) @@ -61,17 +71,25 @@ func SetupRouter(handler *Handler) *chi.Mux { }) }) - r.Get(("/v1/models"), handler.OpenAIListInstances()) // List instances in OpenAI-compatible format + r.Route(("/v1"), func(r chi.Router) { - // OpenAI-compatible proxy endpoint - // Handles all POST requests to /v1/*, including: - // - /v1/completions - // - /v1/chat/completions - // - /v1/embeddings - // - /v1/rerank - // - /v1/reranking - // The instance/model to use is determined by the request body. - r.Post("/v1/*", handler.OpenAIProxy()) + if authMiddleware != nil && handler.config.Auth.RequireInferenceAuth { + r.Use(authMiddleware.InferenceMiddleware()) + } + + r.Get(("/models"), handler.OpenAIListInstances()) // List instances in OpenAI-compatible format + + // OpenAI-compatible proxy endpoint + // Handles all POST requests to /v1/*, including: + // - /v1/completions + // - /v1/chat/completions + // - /v1/embeddings + // - /v1/rerank + // - /v1/reranking + // The instance/model to use is determined by the request body. 
+ r.Post("/*", handler.OpenAIProxy()) + + }) // Serve WebUI files if err := webui.SetupWebUI(r); err != nil { From bedec089ef3c5a9c13e3abd84cb1646d138d96bb Mon Sep 17 00:00:00 2001 From: LordMathis Date: Wed, 30 Jul 2025 21:20:50 +0200 Subject: [PATCH 2/4] Implement middleware tests --- pkg/middleware.go | 59 ++----- pkg/middleware_test.go | 354 +++++++++++++++++++++++++++++++++++++++++ pkg/routes.go | 4 +- 3 files changed, 372 insertions(+), 45 deletions(-) create mode 100644 pkg/middleware_test.go diff --git a/pkg/middleware.go b/pkg/middleware.go index 13cc800..6ed3195 100644 --- a/pkg/middleware.go +++ b/pkg/middleware.go @@ -99,15 +99,10 @@ func generateAPIKey(keyType KeyType) string { return fmt.Sprintf("%s-%s", prefix, hex.EncodeToString(randomBytes)) } -// InferenceMiddleware returns middleware for OpenAI inference endpoints -func (a *APIAuthMiddleware) InferenceMiddleware() func(http.Handler) http.Handler { +// AuthMiddleware returns a middleware that checks API keys for the given key type +func (a *APIAuthMiddleware) AuthMiddleware(keyType KeyType) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if !a.requireInferenceAuth { - next.ServeHTTP(w, r) - return - } - if r.Method == "OPTIONS" { next.ServeHTTP(w, r) return @@ -119,9 +114,18 @@ func (a *APIAuthMiddleware) InferenceMiddleware() func(http.Handler) http.Handle return } - // Check if key is valid for OpenAI access - // Management keys also work for OpenAI endpoints (higher privilege) - if !a.isValidKey(apiKey, KeyTypeInference) && !a.isValidKey(apiKey, KeyTypeManagement) { + var isValid bool + switch keyType { + case KeyTypeInference: + // Management keys also work for OpenAI endpoints (higher privilege) + isValid = a.isValidKey(apiKey, KeyTypeInference) || a.isValidKey(apiKey, KeyTypeManagement) + case KeyTypeManagement: + isValid = a.isValidKey(apiKey, KeyTypeManagement) + default: + isValid = 
false + } + + if !isValid { a.unauthorized(w, "Invalid API key") return } @@ -131,43 +135,12 @@ func (a *APIAuthMiddleware) InferenceMiddleware() func(http.Handler) http.Handle } } -// ManagementMiddleware returns middleware for management endpoints -func (a *APIAuthMiddleware) ManagementMiddleware() func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if !a.requireManagementAuth { - next.ServeHTTP(w, r) - return - } - - if r.Method == "OPTIONS" { - next.ServeHTTP(w, r) - return - } - - apiKey := a.extractAPIKey(r) - if apiKey == "" { - a.unauthorized(w, "Missing API key") - return - } - - // Only management keys work for management endpoints - if !a.isValidKey(apiKey, KeyTypeManagement) { - a.unauthorized(w, "Insufficient privileges - management key required") - return - } - - next.ServeHTTP(w, r) - }) - } -} - // extractAPIKey extracts the API key from the request func (a *APIAuthMiddleware) extractAPIKey(r *http.Request) string { // Check Authorization header: "Bearer sk-..." 
if auth := r.Header.Get("Authorization"); auth != "" { - if strings.HasPrefix(auth, "Bearer ") { - return strings.TrimPrefix(auth, "Bearer ") + if after, ok := strings.CutPrefix(auth, "Bearer "); ok { + return after } } diff --git a/pkg/middleware_test.go b/pkg/middleware_test.go new file mode 100644 index 0000000..2e16d1a --- /dev/null +++ b/pkg/middleware_test.go @@ -0,0 +1,354 @@ +package llamactl_test + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + llamactl "llamactl/pkg" +) + +func TestAuthMiddleware(t *testing.T) { + tests := []struct { + name string + keyType llamactl.KeyType + inferenceKeys []string + managementKeys []string + requestKey string + method string + expectedStatus int + }{ + // Valid key tests + { + name: "valid inference key for inference", + keyType: llamactl.KeyTypeInference, + inferenceKeys: []string{"sk-inference-valid123"}, + requestKey: "sk-inference-valid123", + method: "GET", + expectedStatus: http.StatusOK, + }, + { + name: "valid management key for inference", // Management keys work for inference + keyType: llamactl.KeyTypeInference, + managementKeys: []string{"sk-management-admin123"}, + requestKey: "sk-management-admin123", + method: "GET", + expectedStatus: http.StatusOK, + }, + { + name: "valid management key for management", + keyType: llamactl.KeyTypeManagement, + managementKeys: []string{"sk-management-admin123"}, + requestKey: "sk-management-admin123", + method: "GET", + expectedStatus: http.StatusOK, + }, + + // Invalid key tests + { + name: "inference key for management should fail", + keyType: llamactl.KeyTypeManagement, + inferenceKeys: []string{"sk-inference-user123"}, + requestKey: "sk-inference-user123", + method: "GET", + expectedStatus: http.StatusUnauthorized, + }, + { + name: "invalid inference key", + keyType: llamactl.KeyTypeInference, + inferenceKeys: []string{"sk-inference-valid123"}, + requestKey: "sk-inference-invalid", + method: "GET", + expectedStatus: http.StatusUnauthorized, + 
}, + { + name: "missing inference key", + keyType: llamactl.KeyTypeInference, + inferenceKeys: []string{"sk-inference-valid123"}, + requestKey: "", + method: "GET", + expectedStatus: http.StatusUnauthorized, + }, + { + name: "invalid management key", + keyType: llamactl.KeyTypeManagement, + managementKeys: []string{"sk-management-valid123"}, + requestKey: "sk-management-invalid", + method: "GET", + expectedStatus: http.StatusUnauthorized, + }, + { + name: "missing management key", + keyType: llamactl.KeyTypeManagement, + managementKeys: []string{"sk-management-valid123"}, + requestKey: "", + method: "GET", + expectedStatus: http.StatusUnauthorized, + }, + + // OPTIONS requests should always pass + { + name: "OPTIONS request bypasses inference auth", + keyType: llamactl.KeyTypeInference, + inferenceKeys: []string{"sk-inference-valid123"}, + requestKey: "", + method: "OPTIONS", + expectedStatus: http.StatusOK, + }, + { + name: "OPTIONS request bypasses management auth", + keyType: llamactl.KeyTypeManagement, + managementKeys: []string{"sk-management-valid123"}, + requestKey: "", + method: "OPTIONS", + expectedStatus: http.StatusOK, + }, + + // Cross-key-type validation + { + name: "management key works for inference endpoint", + keyType: llamactl.KeyTypeInference, + inferenceKeys: []string{}, + managementKeys: []string{"sk-management-admin"}, + requestKey: "sk-management-admin", + method: "POST", + expectedStatus: http.StatusOK, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := llamactl.AuthConfig{ + InferenceKeys: tt.inferenceKeys, + ManagementKeys: tt.managementKeys, + } + middleware := llamactl.NewAPIAuthMiddleware(config) + + // Create test request + req := httptest.NewRequest(tt.method, "/test", nil) + if tt.requestKey != "" { + req.Header.Set("Authorization", "Bearer "+tt.requestKey) + } + + // Create test handler using the appropriate middleware + var handler http.Handler + if tt.keyType == llamactl.KeyTypeInference { + 
handler = middleware.AuthMiddleware(llamactl.KeyTypeInference)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + } else { + handler = middleware.AuthMiddleware(llamactl.KeyTypeManagement)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + } + + // Execute request + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, req) + + if recorder.Code != tt.expectedStatus { + t.Errorf("AuthMiddleware() status = %v, expected %v", recorder.Code, tt.expectedStatus) + } + + // Check that unauthorized responses have proper format + if recorder.Code == http.StatusUnauthorized { + contentType := recorder.Header().Get("Content-Type") + if contentType != "application/json" { + t.Errorf("Unauthorized response Content-Type = %v, expected application/json", contentType) + } + + body := recorder.Body.String() + if !strings.Contains(body, `"type": "authentication_error"`) { + t.Errorf("Unauthorized response missing proper error type: %v", body) + } + } + }) + } +} + +func TestGenerateAPIKey(t *testing.T) { + tests := []struct { + name string + keyType llamactl.KeyType + }{ + {"inference key generation", llamactl.KeyTypeInference}, + {"management key generation", llamactl.KeyTypeManagement}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test auto-generation by creating config that will trigger it + var config llamactl.AuthConfig + if tt.keyType == llamactl.KeyTypeInference { + config.RequireInferenceAuth = true + config.InferenceKeys = []string{} // Empty to trigger generation + } else { + config.RequireManagementAuth = true + config.ManagementKeys = []string{} // Empty to trigger generation + } + + // Create middleware - this should trigger key generation + middleware := llamactl.NewAPIAuthMiddleware(config) + + // Test that auth is required (meaning a key was generated) + req := httptest.NewRequest("GET", "/", nil) + recorder := 
httptest.NewRecorder() + + var handler http.Handler + if tt.keyType == llamactl.KeyTypeInference { + handler = middleware.AuthMiddleware(llamactl.KeyTypeInference)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + } else { + handler = middleware.AuthMiddleware(llamactl.KeyTypeManagement)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + } + + handler.ServeHTTP(recorder, req) + + // Should be unauthorized without a key (proving that a key was generated and auth is working) + if recorder.Code != http.StatusUnauthorized { + t.Errorf("Expected unauthorized without key, got status %v", recorder.Code) + } + + // Test uniqueness by creating another middleware instance + middleware2 := llamactl.NewAPIAuthMiddleware(config) + + req2 := httptest.NewRequest("GET", "/", nil) + recorder2 := httptest.NewRecorder() + + if tt.keyType == llamactl.KeyTypeInference { + handler2 := middleware2.AuthMiddleware(llamactl.KeyTypeInference)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + handler2.ServeHTTP(recorder2, req2) + } else { + handler2 := middleware2.AuthMiddleware(llamactl.KeyTypeManagement)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + handler2.ServeHTTP(recorder2, req2) + } + + // Both should require auth (proving keys were generated for both instances) + if recorder2.Code != http.StatusUnauthorized { + t.Errorf("Expected unauthorized for second middleware without key, got status %v", recorder2.Code) + } + }) + } +} + +func TestAutoGeneration(t *testing.T) { + tests := []struct { + name string + requireInference bool + requireManagement bool + providedInference []string + providedManagement []string + shouldGenerateInf bool // Whether inference key should be generated + shouldGenerateMgmt bool // Whether management key should be generated + }{ + { + name: 
"inference auth required, keys provided - no generation", + requireInference: true, + requireManagement: false, + providedInference: []string{"sk-inference-provided"}, + providedManagement: []string{}, + shouldGenerateInf: false, + shouldGenerateMgmt: false, + }, + { + name: "inference auth required, no keys - should auto-generate", + requireInference: true, + requireManagement: false, + providedInference: []string{}, + providedManagement: []string{}, + shouldGenerateInf: true, + shouldGenerateMgmt: false, + }, + { + name: "management auth required, keys provided - no generation", + requireInference: false, + requireManagement: true, + providedInference: []string{}, + providedManagement: []string{"sk-management-provided"}, + shouldGenerateInf: false, + shouldGenerateMgmt: false, + }, + { + name: "management auth required, no keys - should auto-generate", + requireInference: false, + requireManagement: true, + providedInference: []string{}, + providedManagement: []string{}, + shouldGenerateInf: false, + shouldGenerateMgmt: true, + }, + { + name: "both required, both provided - no generation", + requireInference: true, + requireManagement: true, + providedInference: []string{"sk-inference-provided"}, + providedManagement: []string{"sk-management-provided"}, + shouldGenerateInf: false, + shouldGenerateMgmt: false, + }, + { + name: "both required, none provided - should auto-generate both", + requireInference: true, + requireManagement: true, + providedInference: []string{}, + providedManagement: []string{}, + shouldGenerateInf: true, + shouldGenerateMgmt: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := llamactl.AuthConfig{ + RequireInferenceAuth: tt.requireInference, + RequireManagementAuth: tt.requireManagement, + InferenceKeys: tt.providedInference, + ManagementKeys: tt.providedManagement, + } + + middleware := llamactl.NewAPIAuthMiddleware(config) + + // Test inference behavior if inference auth is required + if 
tt.requireInference { + req := httptest.NewRequest("GET", "/v1/models", nil) + recorder := httptest.NewRecorder() + + handler := middleware.AuthMiddleware(llamactl.KeyTypeInference)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + handler.ServeHTTP(recorder, req) + + // Should always be unauthorized without a key (since middleware assumes auth is required) + if recorder.Code != http.StatusUnauthorized { + t.Errorf("Expected unauthorized for inference without key, got status %v", recorder.Code) + } + } + + // Test management behavior if management auth is required + if tt.requireManagement { + req := httptest.NewRequest("GET", "/api/v1/instances", nil) + recorder := httptest.NewRecorder() + + handler := middleware.AuthMiddleware(llamactl.KeyTypeManagement)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + handler.ServeHTTP(recorder, req) + + // Should always be unauthorized without a key (since middleware assumes auth is required) + if recorder.Code != http.StatusUnauthorized { + t.Errorf("Expected unauthorized for management without key, got status %v", recorder.Code) + } + } + }) + } +} diff --git a/pkg/routes.go b/pkg/routes.go index 998af38..f0c3c54 100644 --- a/pkg/routes.go +++ b/pkg/routes.go @@ -39,7 +39,7 @@ func SetupRouter(handler *Handler) *chi.Mux { r.Route("/api/v1", func(r chi.Router) { if authMiddleware != nil && handler.config.Auth.RequireManagementAuth { - r.Use(authMiddleware.ManagementMiddleware()) + r.Use(authMiddleware.AuthMiddleware(KeyTypeManagement)) } r.Route("/server", func(r chi.Router) { @@ -74,7 +74,7 @@ func SetupRouter(handler *Handler) *chi.Mux { r.Route(("/v1"), func(r chi.Router) { if authMiddleware != nil && handler.config.Auth.RequireInferenceAuth { - r.Use(authMiddleware.InferenceMiddleware()) + r.Use(authMiddleware.AuthMiddleware(KeyTypeInference)) } r.Get(("/models"), handler.OpenAIListInstances()) // List instances 
in OpenAI-compatible format From 4d06bc487a2b04c4716763ede3dbd96b978119ff Mon Sep 17 00:00:00 2001 From: LordMathis Date: Wed, 30 Jul 2025 21:31:20 +0200 Subject: [PATCH 3/4] Update README for api key auth --- README.md | 44 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 1a9cc8d..36f6891 100644 --- a/README.md +++ b/README.md @@ -54,12 +54,14 @@ go build -o llamactl ./cmd/server ## Configuration + llamactl can be configured via configuration files or environment variables. Configuration is loaded in the following order of precedence: 1. Hardcoded defaults 2. Configuration file 3. Environment variables + ### Configuration Files Configuration files are searched in the following locations: @@ -76,19 +78,35 @@ Configuration files are searched in the following locations: You can specify the path to config file with `LLAMACTL_CONFIG_PATH` environment variable +## API Key Authentication + +llamactl now supports API Key authentication for both management and inference (OpenAI-compatible) endpoints. There are separate keys for management and inference APIs. Management keys grant full access; inference keys grant access to OpenAI-compatible endpoints. + +**How to Use:** +- Pass your API key in requests using one of: + - `Authorization: Bearer ` header + - `X-API-Key: ` header + - `api_key=` query parameter + +**Auto-generated keys**: If no keys are set and authentication is required, a key will be generated and printed to the terminal at startup. For production, set your own keys in config or environment variables. 
+ ### Configuration Options #### Server Configuration ```yaml server: - host: "" # Server host to bind to (default: "") - port: 8080 # Server port to bind to (default: 8080) + host: "0.0.0.0" # Server host to bind to (default: "0.0.0.0") + port: 8080 # Server port to bind to (default: 8080) + allowed_origins: ["*"] # CORS allowed origins (default: ["*"]) + enable_swagger: false # Enable Swagger UI (default: false) ``` **Environment Variables:** - `LLAMACTL_HOST` - Server host - `LLAMACTL_PORT` - Server port +- `LLAMACTL_ALLOWED_ORIGINS` - Comma-separated CORS origins +- `LLAMACTL_ENABLE_SWAGGER` - Enable Swagger UI (true/false) #### Instance Configuration @@ -112,6 +130,22 @@ instances: - `LLAMACTL_DEFAULT_MAX_RESTARTS` - Default maximum restarts - `LLAMACTL_DEFAULT_RESTART_DELAY` - Default restart delay in seconds +#### Auth Configuration + +```yaml +auth: + require_inference_auth: true # Require API key for OpenAI endpoints (default: true) + inference_keys: [] # List of valid inference API keys + require_management_auth: true # Require API key for management endpoints (default: true) + management_keys: [] # List of valid management API keys +``` + +**Environment Variables:** +- `LLAMACTL_REQUIRE_INFERENCE_AUTH` - Require auth for OpenAI endpoints (true/false) +- `LLAMACTL_INFERENCE_KEYS` - Comma-separated inference API keys +- `LLAMACTL_REQUIRE_MANAGEMENT_AUTH` - Require auth for management endpoints (true/false) +- `LLAMACTL_MANAGEMENT_KEYS` - Comma-separated management API keys + ### Example Configuration ```yaml @@ -127,6 +161,12 @@ instances: default_auto_restart: true default_max_restarts: 5 default_restart_delay: 10 + +auth: + require_inference_auth: true + inference_keys: ["sk-inference-abc123"] + require_management_auth: true + management_keys: ["sk-management-xyz456"] ``` ## Usage From 8e8056f07182929e047e759b8d79c1ff7eedf047 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Wed, 30 Jul 2025 21:34:46 +0200 Subject: [PATCH 4/4] Update swagger docs --- 
docs/docs.go | 82 ++++++++++++++++++++++++++++++++++++++++++++++- docs/swagger.json | 82 ++++++++++++++++++++++++++++++++++++++++++++++- docs/swagger.yaml | 35 +++++++++++++++++++- pkg/handlers.go | 17 +++++++++- 4 files changed, 212 insertions(+), 4 deletions(-) diff --git a/docs/docs.go b/docs/docs.go index 6248bd8..0448820 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -21,6 +21,11 @@ const docTemplate = `{ "paths": { "/instances": { "get": { + "security": [ + { + "ApiKeyAuth": [] + } + ], "description": "Returns a list of all instances managed by the server", "tags": [ "instances" @@ -47,6 +52,11 @@ const docTemplate = `{ }, "/instances/{name}": { "get": { + "security": [ + { + "ApiKeyAuth": [] + } + ], "description": "Returns the details of a specific instance by name", "tags": [ "instances" @@ -83,6 +93,11 @@ const docTemplate = `{ } }, "put": { + "security": [ + { + "ApiKeyAuth": [] + } + ], "description": "Updates the configuration of a specific instance by name", "consumes": [ "application/json" @@ -131,6 +146,11 @@ const docTemplate = `{ } }, "post": { + "security": [ + { + "ApiKeyAuth": [] + } + ], "description": "Creates a new instance with the provided configuration options", "consumes": [ "application/json" @@ -179,6 +199,11 @@ const docTemplate = `{ } }, "delete": { + "security": [ + { + "ApiKeyAuth": [] + } + ], "description": "Stops and removes a specific instance by name", "tags": [ "instances" @@ -214,6 +239,11 @@ const docTemplate = `{ }, "/instances/{name}/logs": { "get": { + "security": [ + { + "ApiKeyAuth": [] + } + ], "description": "Returns the logs from a specific instance by name with optional line limit", "tags": [ "instances" @@ -258,6 +288,11 @@ const docTemplate = `{ }, "/instances/{name}/proxy": { "get": { + "security": [ + { + "ApiKeyAuth": [] + } + ], "description": "Forwards HTTP requests to the llama-server instance running on a specific port", "tags": [ "instances" @@ -297,6 +332,11 @@ const docTemplate = `{ } }, "post": { + 
"security": [ + { + "ApiKeyAuth": [] + } + ], "description": "Forwards HTTP requests to the llama-server instance running on a specific port", "tags": [ "instances" @@ -338,6 +378,11 @@ const docTemplate = `{ }, "/instances/{name}/restart": { "post": { + "security": [ + { + "ApiKeyAuth": [] + } + ], "description": "Restarts a specific instance by name", "tags": [ "instances" @@ -376,6 +421,11 @@ const docTemplate = `{ }, "/instances/{name}/start": { "post": { + "security": [ + { + "ApiKeyAuth": [] + } + ], "description": "Starts a specific instance by name", "tags": [ "instances" @@ -414,6 +464,11 @@ const docTemplate = `{ }, "/instances/{name}/stop": { "post": { + "security": [ + { + "ApiKeyAuth": [] + } + ], "description": "Stops a specific instance by name", "tags": [ "instances" @@ -452,6 +507,11 @@ const docTemplate = `{ }, "/server/devices": { "get": { + "security": [ + { + "ApiKeyAuth": [] + } + ], "description": "Returns a list of available devices for the llama server", "tags": [ "server" @@ -475,6 +535,11 @@ const docTemplate = `{ }, "/server/help": { "get": { + "security": [ + { + "ApiKeyAuth": [] + } + ], "description": "Returns the help text for the llama server command", "tags": [ "server" @@ -498,6 +563,11 @@ const docTemplate = `{ }, "/server/version": { "get": { + "security": [ + { + "ApiKeyAuth": [] + } + ], "description": "Returns the version of the llama server command", "tags": [ "server" @@ -521,7 +591,12 @@ const docTemplate = `{ }, "/v1/": { "post": { - "description": "Handles all POST requests to /v1/*, routing to the appropriate instance based on the request body", + "security": [ + { + "ApiKeyAuth": [] + } + ], + "description": "Handles all POST requests to /v1/*, routing to the appropriate instance based on the request body. 
Requires API key authentication via the ` + "`" + `Authorization` + "`" + ` header.", "consumes": [ "application/json" ], @@ -550,6 +625,11 @@ const docTemplate = `{ }, "/v1/models": { "get": { + "security": [ + { + "ApiKeyAuth": [] + } + ], "description": "Returns a list of instances in a format compatible with OpenAI API", "tags": [ "openai" diff --git a/docs/swagger.json b/docs/swagger.json index ae75018..87168b7 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -14,6 +14,11 @@ "paths": { "/instances": { "get": { + "security": [ + { + "ApiKeyAuth": [] + } + ], "description": "Returns a list of all instances managed by the server", "tags": [ "instances" @@ -40,6 +45,11 @@ }, "/instances/{name}": { "get": { + "security": [ + { + "ApiKeyAuth": [] + } + ], "description": "Returns the details of a specific instance by name", "tags": [ "instances" @@ -76,6 +86,11 @@ } }, "put": { + "security": [ + { + "ApiKeyAuth": [] + } + ], "description": "Updates the configuration of a specific instance by name", "consumes": [ "application/json" @@ -124,6 +139,11 @@ } }, "post": { + "security": [ + { + "ApiKeyAuth": [] + } + ], "description": "Creates a new instance with the provided configuration options", "consumes": [ "application/json" @@ -172,6 +192,11 @@ } }, "delete": { + "security": [ + { + "ApiKeyAuth": [] + } + ], "description": "Stops and removes a specific instance by name", "tags": [ "instances" @@ -207,6 +232,11 @@ }, "/instances/{name}/logs": { "get": { + "security": [ + { + "ApiKeyAuth": [] + } + ], "description": "Returns the logs from a specific instance by name with optional line limit", "tags": [ "instances" @@ -251,6 +281,11 @@ }, "/instances/{name}/proxy": { "get": { + "security": [ + { + "ApiKeyAuth": [] + } + ], "description": "Forwards HTTP requests to the llama-server instance running on a specific port", "tags": [ "instances" @@ -290,6 +325,11 @@ } }, "post": { + "security": [ + { + "ApiKeyAuth": [] + } + ], "description": "Forwards HTTP requests 
to the llama-server instance running on a specific port", "tags": [ "instances" @@ -331,6 +371,11 @@ }, "/instances/{name}/restart": { "post": { + "security": [ + { + "ApiKeyAuth": [] + } + ], "description": "Restarts a specific instance by name", "tags": [ "instances" @@ -369,6 +414,11 @@ }, "/instances/{name}/start": { "post": { + "security": [ + { + "ApiKeyAuth": [] + } + ], "description": "Starts a specific instance by name", "tags": [ "instances" @@ -407,6 +457,11 @@ }, "/instances/{name}/stop": { "post": { + "security": [ + { + "ApiKeyAuth": [] + } + ], "description": "Stops a specific instance by name", "tags": [ "instances" @@ -445,6 +500,11 @@ }, "/server/devices": { "get": { + "security": [ + { + "ApiKeyAuth": [] + } + ], "description": "Returns a list of available devices for the llama server", "tags": [ "server" @@ -468,6 +528,11 @@ }, "/server/help": { "get": { + "security": [ + { + "ApiKeyAuth": [] + } + ], "description": "Returns the help text for the llama server command", "tags": [ "server" @@ -491,6 +556,11 @@ }, "/server/version": { "get": { + "security": [ + { + "ApiKeyAuth": [] + } + ], "description": "Returns the version of the llama server command", "tags": [ "server" @@ -514,7 +584,12 @@ }, "/v1/": { "post": { - "description": "Handles all POST requests to /v1/*, routing to the appropriate instance based on the request body", + "security": [ + { + "ApiKeyAuth": [] + } + ], + "description": "Handles all POST requests to /v1/*, routing to the appropriate instance based on the request body. 
Requires API key authentication via the `Authorization` header.", "consumes": [ "application/json" ], @@ -543,6 +618,11 @@ }, "/v1/models": { "get": { + "security": [ + { + "ApiKeyAuth": [] + } + ], "description": "Returns a list of instances in a format compatible with OpenAI API", "tags": [ "openai" diff --git a/docs/swagger.yaml b/docs/swagger.yaml index cf45768..d6a7433 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -399,6 +399,8 @@ paths: description: Internal Server Error schema: type: string + security: + - ApiKeyAuth: [] summary: List all instances tags: - instances @@ -422,6 +424,8 @@ paths: description: Internal Server Error schema: type: string + security: + - ApiKeyAuth: [] summary: Delete an instance tags: - instances @@ -446,6 +450,8 @@ paths: description: Internal Server Error schema: type: string + security: + - ApiKeyAuth: [] summary: Get details of a specific instance tags: - instances @@ -478,6 +484,8 @@ paths: description: Internal Server Error schema: type: string + security: + - ApiKeyAuth: [] summary: Create and start a new instance tags: - instances @@ -510,6 +518,8 @@ paths: description: Internal Server Error schema: type: string + security: + - ApiKeyAuth: [] summary: Update an instance's configuration tags: - instances @@ -540,6 +550,8 @@ paths: description: Internal Server Error schema: type: string + security: + - ApiKeyAuth: [] summary: Get logs from a specific instance tags: - instances @@ -568,6 +580,8 @@ paths: description: Instance is not running schema: type: string + security: + - ApiKeyAuth: [] summary: Proxy requests to a specific instance tags: - instances @@ -595,6 +609,8 @@ paths: description: Instance is not running schema: type: string + security: + - ApiKeyAuth: [] summary: Proxy requests to a specific instance tags: - instances @@ -620,6 +636,8 @@ paths: description: Internal Server Error schema: type: string + security: + - ApiKeyAuth: [] summary: Restart a running instance tags: - instances @@ -645,6 +663,8 
@@ paths: description: Internal Server Error schema: type: string + security: + - ApiKeyAuth: [] summary: Start a stopped instance tags: - instances @@ -670,6 +690,8 @@ paths: description: Internal Server Error schema: type: string + security: + - ApiKeyAuth: [] summary: Stop a running instance tags: - instances @@ -685,6 +707,8 @@ paths: description: Internal Server Error schema: type: string + security: + - ApiKeyAuth: [] summary: List available devices for llama server tags: - server @@ -700,6 +724,8 @@ paths: description: Internal Server Error schema: type: string + security: + - ApiKeyAuth: [] summary: Get help for llama server tags: - server @@ -715,6 +741,8 @@ paths: description: Internal Server Error schema: type: string + security: + - ApiKeyAuth: [] summary: Get version of llama server tags: - server @@ -723,7 +751,8 @@ paths: consumes: - application/json description: Handles all POST requests to /v1/*, routing to the appropriate - instance based on the request body + instance based on the request body. Requires API key authentication via the + `Authorization` header. 
responses: "200": description: OpenAI response @@ -735,6 +764,8 @@ paths: description: Internal Server Error schema: type: string + security: + - ApiKeyAuth: [] summary: OpenAI-compatible proxy endpoint tags: - openai @@ -751,6 +782,8 @@ paths: description: Internal Server Error schema: type: string + security: + - ApiKeyAuth: [] summary: List instances in OpenAI-compatible format tags: - openai diff --git a/pkg/handlers.go b/pkg/handlers.go index afaee99..fe2088b 100644 --- a/pkg/handlers.go +++ b/pkg/handlers.go @@ -29,6 +29,7 @@ func NewHandler(im InstanceManager, config Config) *Handler { // @Summary Get help for llama server // @Description Returns the help text for the llama server command // @Tags server +// @Security ApiKeyAuth // @Produces text/plain // @Success 200 {string} string "Help text" // @Failure 500 {string} string "Internal Server Error" @@ -50,6 +51,7 @@ func (h *Handler) HelpHandler() http.HandlerFunc { // @Summary Get version of llama server // @Description Returns the version of the llama server command // @Tags server +// @Security ApiKeyAuth // @Produces text/plain // @Success 200 {string} string "Version information" // @Failure 500 {string} string "Internal Server Error" @@ -71,6 +73,7 @@ func (h *Handler) VersionHandler() http.HandlerFunc { // @Summary List available devices for llama server // @Description Returns a list of available devices for the llama server // @Tags server +// @Security ApiKeyAuth // @Produces text/plain // @Success 200 {string} string "List of devices" // @Failure 500 {string} string "Internal Server Error" @@ -92,6 +95,7 @@ func (h *Handler) ListDevicesHandler() http.HandlerFunc { // @Summary List all instances // @Description Returns a list of all instances managed by the server // @Tags instances +// @Security ApiKeyAuth // @Produces json // @Success 200 {array} Instance "List of instances" // @Failure 500 {string} string "Internal Server Error" @@ -116,6 +120,7 @@ func (h *Handler) ListInstances() 
http.HandlerFunc { // @Summary Create and start a new instance // @Description Creates a new instance with the provided configuration options // @Tags instances +// @Security ApiKeyAuth // @Accept json // @Produces json // @Param name path string true "Instance Name" @@ -157,6 +162,7 @@ func (h *Handler) CreateInstance() http.HandlerFunc { // @Summary Get details of a specific instance // @Description Returns the details of a specific instance by name // @Tags instances +// @Security ApiKeyAuth // @Produces json // @Param name path string true "Instance Name" // @Success 200 {object} Instance "Instance details" @@ -189,6 +195,7 @@ func (h *Handler) GetInstance() http.HandlerFunc { // @Summary Update an instance's configuration // @Description Updates the configuration of a specific instance by name // @Tags instances +// @Security ApiKeyAuth // @Accept json // @Produces json // @Param name path string true "Instance Name" @@ -229,6 +236,7 @@ func (h *Handler) UpdateInstance() http.HandlerFunc { // @Summary Start a stopped instance // @Description Starts a specific instance by name // @Tags instances +// @Security ApiKeyAuth // @Produces json // @Param name path string true "Instance Name" // @Success 200 {object} Instance "Started instance details" @@ -261,6 +269,7 @@ func (h *Handler) StartInstance() http.HandlerFunc { // @Summary Stop a running instance // @Description Stops a specific instance by name // @Tags instances +// @Security ApiKeyAuth // @Produces json // @Param name path string true "Instance Name" // @Success 200 {object} Instance "Stopped instance details" @@ -293,6 +302,7 @@ func (h *Handler) StopInstance() http.HandlerFunc { // @Summary Restart a running instance // @Description Restarts a specific instance by name // @Tags instances +// @Security ApiKeyAuth // @Produces json // @Param name path string true "Instance Name" // @Success 200 {object} Instance "Restarted instance details" @@ -325,6 +335,7 @@ func (h *Handler) RestartInstance() 
http.HandlerFunc { // @Summary Delete an instance // @Description Stops and removes a specific instance by name // @Tags instances +// @Security ApiKeyAuth // @Param name path string true "Instance Name" // @Success 204 "No Content" // @Failure 400 {string} string "Invalid name format" @@ -351,6 +362,7 @@ func (h *Handler) DeleteInstance() http.HandlerFunc { // @Summary Get logs from a specific instance // @Description Returns the logs from a specific instance by name with optional line limit // @Tags instances +// @Security ApiKeyAuth // @Param name path string true "Instance Name" // @Param lines query string false "Number of lines to retrieve (default: all lines)" // @Produces text/plain @@ -398,6 +410,7 @@ func (h *Handler) GetInstanceLogs() http.HandlerFunc { // @Summary Proxy requests to a specific instance // @Description Forwards HTTP requests to the llama-server instance running on a specific port // @Tags instances +// @Security ApiKeyAuth // @Param name path string true "Instance Name" // @Success 200 "Request successfully proxied to instance" // @Failure 400 {string} string "Invalid name format" @@ -462,6 +475,7 @@ func (h *Handler) ProxyToInstance() http.HandlerFunc { // @Summary List instances in OpenAI-compatible format // @Description Returns a list of instances in a format compatible with OpenAI API // @Tags openai +// @Security ApiKeyAuth // @Produces json // @Success 200 {object} OpenAIListInstancesResponse "List of OpenAI-compatible instances" // @Failure 500 {string} string "Internal Server Error" @@ -499,8 +513,9 @@ func (h *Handler) OpenAIListInstances() http.HandlerFunc { // OpenAIProxy godoc // @Summary OpenAI-compatible proxy endpoint -// @Description Handles all POST requests to /v1/*, routing to the appropriate instance based on the request body +// @Description Handles all POST requests to /v1/*, routing to the appropriate instance based on the request body. Requires API key authentication via the `Authorization` header. 
// @Tags openai +// @Security ApiKeyAuth // @Accept json // @Produces json // @Success 200 "OpenAI response"