Implement api key auth

This commit is contained in:
2025-07-30 20:15:14 +02:00
parent 72ba008d1e
commit b3540d5b3e
3 changed files with 298 additions and 13 deletions

View File

@@ -14,6 +14,7 @@ import (
type Config struct {
Server ServerConfig `yaml:"server"`
Instances InstancesConfig `yaml:"instances"`
Auth AuthConfig `yaml:"auth"`
}
// ServerConfig contains HTTP server configuration
@@ -26,6 +27,9 @@ type ServerConfig struct {
// Allowed origins for CORS (e.g., "http://localhost:3000")
AllowedOrigins []string `yaml:"allowed_origins"`
// Enable Swagger UI for API documentation
EnableSwagger bool `yaml:"enable_swagger"`
}
// InstancesConfig contains instance management configuration
@@ -52,6 +56,22 @@ type InstancesConfig struct {
DefaultRestartDelay int `yaml:"default_restart_delay"`
}
// AuthConfig contains authentication settings
type AuthConfig struct {
// Require authentication for OpenAI compatible inference endpoints
RequireInferenceAuth bool `yaml:"require_inference_auth"`
// List of keys for OpenAI compatible inference endpoints
InferenceKeys []string `yaml:"inference_keys"`
// Require authentication for management endpoints
RequireManagementAuth bool `yaml:"require_management_auth"`
// List of keys for management endpoints
ManagementKeys []string `yaml:"management_keys"`
}
// LoadConfig loads configuration with the following precedence:
// 1. Hardcoded defaults
// 2. Config file
@@ -63,6 +83,7 @@ func LoadConfig(configPath string) (Config, error) {
Host: "0.0.0.0",
Port: 8080,
AllowedOrigins: []string{"*"}, // Default to allow all origins
EnableSwagger: false,
},
Instances: InstancesConfig{
PortRange: [2]int{8000, 9000},
@@ -73,6 +94,12 @@ func LoadConfig(configPath string) (Config, error) {
DefaultMaxRestarts: 3,
DefaultRestartDelay: 5,
},
Auth: AuthConfig{
RequireInferenceAuth: true,
InferenceKeys: []string{},
RequireManagementAuth: true,
ManagementKeys: []string{},
},
}
// 2. Load from config file
@@ -121,6 +148,14 @@ func loadEnvVars(cfg *Config) {
cfg.Server.Port = p
}
}
if allowedOrigins := os.Getenv("LLAMACTL_ALLOWED_ORIGINS"); allowedOrigins != "" {
cfg.Server.AllowedOrigins = strings.Split(allowedOrigins, ",")
}
if enableSwagger := os.Getenv("LLAMACTL_ENABLE_SWAGGER"); enableSwagger != "" {
if b, err := strconv.ParseBool(enableSwagger); err == nil {
cfg.Server.EnableSwagger = b
}
}
// Instance config
if portRange := os.Getenv("LLAMACTL_INSTANCE_PORT_RANGE"); portRange != "" {
@@ -154,6 +189,23 @@ func loadEnvVars(cfg *Config) {
cfg.Instances.DefaultRestartDelay = seconds
}
}
// Auth config
if requireInferenceAuth := os.Getenv("LLAMACTL_REQUIRE_INFERENCE_AUTH"); requireInferenceAuth != "" {
if b, err := strconv.ParseBool(requireInferenceAuth); err == nil {
cfg.Auth.RequireInferenceAuth = b
}
}
if inferenceKeys := os.Getenv("LLAMACTL_INFERENCE_KEYS"); inferenceKeys != "" {
cfg.Auth.InferenceKeys = strings.Split(inferenceKeys, ",")
}
if requireManagementAuth := os.Getenv("LLAMACTL_REQUIRE_MANAGEMENT_AUTH"); requireManagementAuth != "" {
if b, err := strconv.ParseBool(requireManagementAuth); err == nil {
cfg.Auth.RequireManagementAuth = b
}
}
if managementKeys := os.Getenv("LLAMACTL_MANAGEMENT_KEYS"); managementKeys != "" {
cfg.Auth.ManagementKeys = strings.Split(managementKeys, ",")
}
}
// ParsePortRange parses port range from string formats like "8000-9000" or "8000,9000"