diff --git a/pkg/config.go b/pkg/config.go index 7e7256a..d5b4571 100644 --- a/pkg/config.go +++ b/pkg/config.go @@ -14,6 +14,7 @@ import ( type Config struct { Server ServerConfig `yaml:"server"` Instances InstancesConfig `yaml:"instances"` + Auth AuthConfig `yaml:"auth"` } // ServerConfig contains HTTP server configuration @@ -26,6 +27,9 @@ type ServerConfig struct { // Allowed origins for CORS (e.g., "http://localhost:3000") AllowedOrigins []string `yaml:"allowed_origins"` + + // Enable Swagger UI for API documentation + EnableSwagger bool `yaml:"enable_swagger"` } // InstancesConfig contains instance management configuration @@ -52,6 +56,22 @@ type InstancesConfig struct { DefaultRestartDelay int `yaml:"default_restart_delay"` } +// AuthConfig contains authentication settings +type AuthConfig struct { + + // Require authentication for OpenAI compatible inference endpoints + RequireInferenceAuth bool `yaml:"require_inference_auth"` + + // List of keys for OpenAI compatible inference endpoints + InferenceKeys []string `yaml:"inference_keys"` + + // Require authentication for management endpoints + RequireManagementAuth bool `yaml:"require_management_auth"` + + // List of keys for management endpoints + ManagementKeys []string `yaml:"management_keys"` +} + // LoadConfig loads configuration with the following precedence: // 1. Hardcoded defaults // 2. Config file @@ -63,6 +83,7 @@ func LoadConfig(configPath string) (Config, error) { Host: "0.0.0.0", Port: 8080, AllowedOrigins: []string{"*"}, // Default to allow all origins + EnableSwagger: false, }, Instances: InstancesConfig{ PortRange: [2]int{8000, 9000}, @@ -73,6 +94,12 @@ func LoadConfig(configPath string) (Config, error) { DefaultMaxRestarts: 3, DefaultRestartDelay: 5, }, + Auth: AuthConfig{ + RequireInferenceAuth: true, + InferenceKeys: []string{}, + RequireManagementAuth: true, + ManagementKeys: []string{}, + }, } // 2. Load from config file @@ -121,6 +148,14 @@ func loadEnvVars(cfg *Config) { cfg.Server.Port = p } } + if allowedOrigins := os.Getenv("LLAMACTL_ALLOWED_ORIGINS"); allowedOrigins != "" { + cfg.Server.AllowedOrigins = strings.Split(allowedOrigins, ",") + } + if enableSwagger := os.Getenv("LLAMACTL_ENABLE_SWAGGER"); enableSwagger != "" { + if b, err := strconv.ParseBool(enableSwagger); err == nil { + cfg.Server.EnableSwagger = b + } + } // Instance config if portRange := os.Getenv("LLAMACTL_INSTANCE_PORT_RANGE"); portRange != "" { @@ -154,6 +189,23 @@ func loadEnvVars(cfg *Config) { cfg.Instances.DefaultRestartDelay = seconds } } + // Auth config + if requireInferenceAuth := os.Getenv("LLAMACTL_REQUIRE_INFERENCE_AUTH"); requireInferenceAuth != "" { + if b, err := strconv.ParseBool(requireInferenceAuth); err == nil { + cfg.Auth.RequireInferenceAuth = b + } + } + if inferenceKeys := os.Getenv("LLAMACTL_INFERENCE_KEYS"); inferenceKeys != "" { + cfg.Auth.InferenceKeys = strings.Split(inferenceKeys, ",") + } + if requireManagementAuth := os.Getenv("LLAMACTL_REQUIRE_MANAGEMENT_AUTH"); requireManagementAuth != "" { + if b, err := strconv.ParseBool(requireManagementAuth); err == nil { + cfg.Auth.RequireManagementAuth = b + } + } + if managementKeys := os.Getenv("LLAMACTL_MANAGEMENT_KEYS"); managementKeys != "" { + cfg.Auth.ManagementKeys = strings.Split(managementKeys, ",") + } } // ParsePortRange parses port range from string formats like "8000-9000" or "8000,9000" diff --git a/pkg/middleware.go b/pkg/middleware.go new file mode 100644 index 0000000..13cc800 --- /dev/null +++ b/pkg/middleware.go @@ -0,0 +1,215 @@ +package llamactl + +import ( + "crypto/rand" + "crypto/subtle" + "encoding/hex" + "fmt" + "log" + "net/http" + "os" + "strings" +) + +type KeyType int + +const ( + KeyTypeInference KeyType = iota + KeyTypeManagement +) + +type APIAuthMiddleware struct { + requireInferenceAuth bool + inferenceKeys map[string]bool + requireManagementAuth bool + managementKeys map[string]bool +} + +// NewAPIAuthMiddleware creates a new APIAuthMiddleware with the given configuration +func NewAPIAuthMiddleware(config AuthConfig) *APIAuthMiddleware { + + var generated bool = false + + inferenceAPIKeys := make(map[string]bool) + managementAPIKeys := make(map[string]bool) + + const banner = "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + + if config.RequireManagementAuth && len(config.ManagementKeys) == 0 { + key := generateAPIKey(KeyTypeManagement) + managementAPIKeys[key] = true + generated = true + fmt.Printf("%s\n⚠️ MANAGEMENT AUTHENTICATION REQUIRED\n%s\n", banner, banner) + fmt.Printf("🔑 Generated Management API Key:\n\n %s\n\n", key) + } + for _, key := range config.ManagementKeys { + managementAPIKeys[key] = true + } + + if config.RequireInferenceAuth && len(config.InferenceKeys) == 0 { + key := generateAPIKey(KeyTypeInference) + inferenceAPIKeys[key] = true + generated = true + fmt.Printf("%s\n⚠️ INFERENCE AUTHENTICATION REQUIRED\n%s\n", banner, banner) + fmt.Printf("🔑 Generated Inference API Key:\n\n %s\n\n", key) + } + for _, key := range config.InferenceKeys { + inferenceAPIKeys[key] = true + } + + if generated { + fmt.Printf("%s\n⚠️ IMPORTANT\n%s\n", banner, banner) + fmt.Println("• These keys are auto-generated and will change on restart") + fmt.Println("• For production, add explicit keys to your configuration") + fmt.Println("• Copy these keys before they disappear from the terminal") + fmt.Println(banner) + } + + return &APIAuthMiddleware{ + requireInferenceAuth: config.RequireInferenceAuth, + inferenceKeys: inferenceAPIKeys, + requireManagementAuth: config.RequireManagementAuth, + managementKeys: managementAPIKeys, + } +} + +// generateAPIKey creates a cryptographically secure API key +func generateAPIKey(keyType KeyType) string { + // Generate 32 random bytes (256 bits) + randomBytes := make([]byte, 32) + + var prefix string + + switch keyType { + case KeyTypeInference: + prefix = "sk-inference" + case KeyTypeManagement: + prefix = "sk-management" + default: + prefix = "sk-unknown" + } + + if _, err := rand.Read(randomBytes); err != nil { + log.Printf("Warning: Failed to generate secure random key, using fallback") + // Fallback to a less secure method if crypto/rand fails + return fmt.Sprintf("%s-fallback-%d", prefix, os.Getpid()) + } + + // Convert to hex and add prefix + return fmt.Sprintf("%s-%s", prefix, hex.EncodeToString(randomBytes)) +} + +// InferenceMiddleware returns middleware for OpenAI inference endpoints +func (a *APIAuthMiddleware) InferenceMiddleware() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !a.requireInferenceAuth { + next.ServeHTTP(w, r) + return + } + + if r.Method == "OPTIONS" { + next.ServeHTTP(w, r) + return + } + + apiKey := a.extractAPIKey(r) + if apiKey == "" { + a.unauthorized(w, "Missing API key") + return + } + + // Check if key is valid for OpenAI access + // Management keys also work for OpenAI endpoints (higher privilege) + if !a.isValidKey(apiKey, KeyTypeInference) && !a.isValidKey(apiKey, KeyTypeManagement) { + a.unauthorized(w, "Invalid API key") + return + } + + next.ServeHTTP(w, r) + }) + } +} + +// ManagementMiddleware returns middleware for management endpoints +func (a *APIAuthMiddleware) ManagementMiddleware() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !a.requireManagementAuth { + next.ServeHTTP(w, r) + return + } + + if r.Method == "OPTIONS" { + next.ServeHTTP(w, r) + return + } + + apiKey := a.extractAPIKey(r) + if apiKey == "" { + a.unauthorized(w, "Missing API key") + return + } + + // Only management keys work for management endpoints + if !a.isValidKey(apiKey, KeyTypeManagement) { + a.unauthorized(w, "Insufficient privileges - management key required") + return + } + + next.ServeHTTP(w, r) + }) + } +} + +// extractAPIKey extracts the API key from the request +func (a *APIAuthMiddleware) extractAPIKey(r *http.Request) string { + // Check Authorization header: "Bearer sk-..." + if auth := r.Header.Get("Authorization"); auth != "" { + if strings.HasPrefix(auth, "Bearer ") { + return strings.TrimPrefix(auth, "Bearer ") + } + } + + // Check X-API-Key header + if apiKey := r.Header.Get("X-API-Key"); apiKey != "" { + return apiKey + } + + // Check query parameter + if apiKey := r.URL.Query().Get("api_key"); apiKey != "" { + return apiKey + } + + return "" +} + +// isValidKey checks if the provided API key is valid for the given key type +func (a *APIAuthMiddleware) isValidKey(providedKey string, keyType KeyType) bool { + var validKeys map[string]bool + + switch keyType { + case KeyTypeInference: + validKeys = a.inferenceKeys + case KeyTypeManagement: + validKeys = a.managementKeys + default: + return false + } + + for validKey := range validKeys { + if len(providedKey) == len(validKey) && + subtle.ConstantTimeCompare([]byte(providedKey), []byte(validKey)) == 1 { + return true + } + } + return false +} + +// unauthorized sends an unauthorized response +func (a *APIAuthMiddleware) unauthorized(w http.ResponseWriter, message string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + response := fmt.Sprintf(`{"error": {"message": "%s", "type": "authentication_error"}}`, message) + w.Write([]byte(response)) +} diff --git a/pkg/routes.go b/pkg/routes.go index 4685b85..998af38 100644 --- a/pkg/routes.go +++ b/pkg/routes.go @@ -26,12 +26,22 @@ func SetupRouter(handler *Handler) *chi.Mux { MaxAge: 300, })) - r.Get("/swagger/*", httpSwagger.Handler( - httpSwagger.URL("/swagger/doc.json"), - )) + // Add API authentication middleware + authMiddleware := NewAPIAuthMiddleware(handler.config.Auth) + + if handler.config.Server.EnableSwagger { + r.Get("/swagger/*", httpSwagger.Handler( + httpSwagger.URL("/swagger/doc.json"), + )) + } // Define routes r.Route("/api/v1", func(r chi.Router) { + + if authMiddleware != nil && handler.config.Auth.RequireManagementAuth { + r.Use(authMiddleware.ManagementMiddleware()) + } + r.Route("/server", func(r chi.Router) { r.Get("/help", handler.HelpHandler()) r.Get("/version", handler.VersionHandler()) @@ -61,17 +71,25 @@ func SetupRouter(handler *Handler) *chi.Mux { }) }) - r.Get(("/v1/models"), handler.OpenAIListInstances()) // List instances in OpenAI-compatible format + r.Route(("/v1"), func(r chi.Router) { - // OpenAI-compatible proxy endpoint - // Handles all POST requests to /v1/*, including: - // - /v1/completions - // - /v1/chat/completions - // - /v1/embeddings - // - /v1/rerank - // - /v1/reranking - // The instance/model to use is determined by the request body. - r.Post("/v1/*", handler.OpenAIProxy()) + if authMiddleware != nil && handler.config.Auth.RequireInferenceAuth { + r.Use(authMiddleware.InferenceMiddleware()) + } + + r.Get(("/models"), handler.OpenAIListInstances()) // List instances in OpenAI-compatible format + + // OpenAI-compatible proxy endpoint + // Handles all POST requests to /v1/*, including: + // - /v1/completions + // - /v1/chat/completions + // - /v1/embeddings + // - /v1/rerank + // - /v1/reranking + // The instance/model to use is determined by the request body. + r.Post("/*", handler.OpenAIProxy()) + + }) // Serve WebUI files if err := webui.SetupWebUI(r); err != nil {