mirror of
https://github.com/lordmathis/llamactl.git
synced 2025-11-06 09:04:27 +00:00
Implement middleware tests
This commit is contained in:
354
pkg/middleware_test.go
Normal file
354
pkg/middleware_test.go
Normal file
@@ -0,0 +1,354 @@
|
||||
package llamactl_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
llamactl "llamactl/pkg"
|
||||
)
|
||||
|
||||
func TestAuthMiddleware(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keyType llamactl.KeyType
|
||||
inferenceKeys []string
|
||||
managementKeys []string
|
||||
requestKey string
|
||||
method string
|
||||
expectedStatus int
|
||||
}{
|
||||
// Valid key tests
|
||||
{
|
||||
name: "valid inference key for inference",
|
||||
keyType: llamactl.KeyTypeInference,
|
||||
inferenceKeys: []string{"sk-inference-valid123"},
|
||||
requestKey: "sk-inference-valid123",
|
||||
method: "GET",
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "valid management key for inference", // Management keys work for inference
|
||||
keyType: llamactl.KeyTypeInference,
|
||||
managementKeys: []string{"sk-management-admin123"},
|
||||
requestKey: "sk-management-admin123",
|
||||
method: "GET",
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "valid management key for management",
|
||||
keyType: llamactl.KeyTypeManagement,
|
||||
managementKeys: []string{"sk-management-admin123"},
|
||||
requestKey: "sk-management-admin123",
|
||||
method: "GET",
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
|
||||
// Invalid key tests
|
||||
{
|
||||
name: "inference key for management should fail",
|
||||
keyType: llamactl.KeyTypeManagement,
|
||||
inferenceKeys: []string{"sk-inference-user123"},
|
||||
requestKey: "sk-inference-user123",
|
||||
method: "GET",
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "invalid inference key",
|
||||
keyType: llamactl.KeyTypeInference,
|
||||
inferenceKeys: []string{"sk-inference-valid123"},
|
||||
requestKey: "sk-inference-invalid",
|
||||
method: "GET",
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "missing inference key",
|
||||
keyType: llamactl.KeyTypeInference,
|
||||
inferenceKeys: []string{"sk-inference-valid123"},
|
||||
requestKey: "",
|
||||
method: "GET",
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "invalid management key",
|
||||
keyType: llamactl.KeyTypeManagement,
|
||||
managementKeys: []string{"sk-management-valid123"},
|
||||
requestKey: "sk-management-invalid",
|
||||
method: "GET",
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "missing management key",
|
||||
keyType: llamactl.KeyTypeManagement,
|
||||
managementKeys: []string{"sk-management-valid123"},
|
||||
requestKey: "",
|
||||
method: "GET",
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
},
|
||||
|
||||
// OPTIONS requests should always pass
|
||||
{
|
||||
name: "OPTIONS request bypasses inference auth",
|
||||
keyType: llamactl.KeyTypeInference,
|
||||
inferenceKeys: []string{"sk-inference-valid123"},
|
||||
requestKey: "",
|
||||
method: "OPTIONS",
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "OPTIONS request bypasses management auth",
|
||||
keyType: llamactl.KeyTypeManagement,
|
||||
managementKeys: []string{"sk-management-valid123"},
|
||||
requestKey: "",
|
||||
method: "OPTIONS",
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
|
||||
// Cross-key-type validation
|
||||
{
|
||||
name: "management key works for inference endpoint",
|
||||
keyType: llamactl.KeyTypeInference,
|
||||
inferenceKeys: []string{},
|
||||
managementKeys: []string{"sk-management-admin"},
|
||||
requestKey: "sk-management-admin",
|
||||
method: "POST",
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := llamactl.AuthConfig{
|
||||
InferenceKeys: tt.inferenceKeys,
|
||||
ManagementKeys: tt.managementKeys,
|
||||
}
|
||||
middleware := llamactl.NewAPIAuthMiddleware(config)
|
||||
|
||||
// Create test request
|
||||
req := httptest.NewRequest(tt.method, "/test", nil)
|
||||
if tt.requestKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+tt.requestKey)
|
||||
}
|
||||
|
||||
// Create test handler using the appropriate middleware
|
||||
var handler http.Handler
|
||||
if tt.keyType == llamactl.KeyTypeInference {
|
||||
handler = middleware.AuthMiddleware(llamactl.KeyTypeInference)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
} else {
|
||||
handler = middleware.AuthMiddleware(llamactl.KeyTypeManagement)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
}
|
||||
|
||||
// Execute request
|
||||
recorder := httptest.NewRecorder()
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != tt.expectedStatus {
|
||||
t.Errorf("AuthMiddleware() status = %v, expected %v", recorder.Code, tt.expectedStatus)
|
||||
}
|
||||
|
||||
// Check that unauthorized responses have proper format
|
||||
if recorder.Code == http.StatusUnauthorized {
|
||||
contentType := recorder.Header().Get("Content-Type")
|
||||
if contentType != "application/json" {
|
||||
t.Errorf("Unauthorized response Content-Type = %v, expected application/json", contentType)
|
||||
}
|
||||
|
||||
body := recorder.Body.String()
|
||||
if !strings.Contains(body, `"type": "authentication_error"`) {
|
||||
t.Errorf("Unauthorized response missing proper error type: %v", body)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateAPIKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keyType llamactl.KeyType
|
||||
}{
|
||||
{"inference key generation", llamactl.KeyTypeInference},
|
||||
{"management key generation", llamactl.KeyTypeManagement},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test auto-generation by creating config that will trigger it
|
||||
var config llamactl.AuthConfig
|
||||
if tt.keyType == llamactl.KeyTypeInference {
|
||||
config.RequireInferenceAuth = true
|
||||
config.InferenceKeys = []string{} // Empty to trigger generation
|
||||
} else {
|
||||
config.RequireManagementAuth = true
|
||||
config.ManagementKeys = []string{} // Empty to trigger generation
|
||||
}
|
||||
|
||||
// Create middleware - this should trigger key generation
|
||||
middleware := llamactl.NewAPIAuthMiddleware(config)
|
||||
|
||||
// Test that auth is required (meaning a key was generated)
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
var handler http.Handler
|
||||
if tt.keyType == llamactl.KeyTypeInference {
|
||||
handler = middleware.AuthMiddleware(llamactl.KeyTypeInference)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
} else {
|
||||
handler = middleware.AuthMiddleware(llamactl.KeyTypeManagement)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
}
|
||||
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
// Should be unauthorized without a key (proving that a key was generated and auth is working)
|
||||
if recorder.Code != http.StatusUnauthorized {
|
||||
t.Errorf("Expected unauthorized without key, got status %v", recorder.Code)
|
||||
}
|
||||
|
||||
// Test uniqueness by creating another middleware instance
|
||||
middleware2 := llamactl.NewAPIAuthMiddleware(config)
|
||||
|
||||
req2 := httptest.NewRequest("GET", "/", nil)
|
||||
recorder2 := httptest.NewRecorder()
|
||||
|
||||
if tt.keyType == llamactl.KeyTypeInference {
|
||||
handler2 := middleware2.AuthMiddleware(llamactl.KeyTypeInference)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
handler2.ServeHTTP(recorder2, req2)
|
||||
} else {
|
||||
handler2 := middleware2.AuthMiddleware(llamactl.KeyTypeManagement)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
handler2.ServeHTTP(recorder2, req2)
|
||||
}
|
||||
|
||||
// Both should require auth (proving keys were generated for both instances)
|
||||
if recorder2.Code != http.StatusUnauthorized {
|
||||
t.Errorf("Expected unauthorized for second middleware without key, got status %v", recorder2.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAutoGeneration(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
requireInference bool
|
||||
requireManagement bool
|
||||
providedInference []string
|
||||
providedManagement []string
|
||||
shouldGenerateInf bool // Whether inference key should be generated
|
||||
shouldGenerateMgmt bool // Whether management key should be generated
|
||||
}{
|
||||
{
|
||||
name: "inference auth required, keys provided - no generation",
|
||||
requireInference: true,
|
||||
requireManagement: false,
|
||||
providedInference: []string{"sk-inference-provided"},
|
||||
providedManagement: []string{},
|
||||
shouldGenerateInf: false,
|
||||
shouldGenerateMgmt: false,
|
||||
},
|
||||
{
|
||||
name: "inference auth required, no keys - should auto-generate",
|
||||
requireInference: true,
|
||||
requireManagement: false,
|
||||
providedInference: []string{},
|
||||
providedManagement: []string{},
|
||||
shouldGenerateInf: true,
|
||||
shouldGenerateMgmt: false,
|
||||
},
|
||||
{
|
||||
name: "management auth required, keys provided - no generation",
|
||||
requireInference: false,
|
||||
requireManagement: true,
|
||||
providedInference: []string{},
|
||||
providedManagement: []string{"sk-management-provided"},
|
||||
shouldGenerateInf: false,
|
||||
shouldGenerateMgmt: false,
|
||||
},
|
||||
{
|
||||
name: "management auth required, no keys - should auto-generate",
|
||||
requireInference: false,
|
||||
requireManagement: true,
|
||||
providedInference: []string{},
|
||||
providedManagement: []string{},
|
||||
shouldGenerateInf: false,
|
||||
shouldGenerateMgmt: true,
|
||||
},
|
||||
{
|
||||
name: "both required, both provided - no generation",
|
||||
requireInference: true,
|
||||
requireManagement: true,
|
||||
providedInference: []string{"sk-inference-provided"},
|
||||
providedManagement: []string{"sk-management-provided"},
|
||||
shouldGenerateInf: false,
|
||||
shouldGenerateMgmt: false,
|
||||
},
|
||||
{
|
||||
name: "both required, none provided - should auto-generate both",
|
||||
requireInference: true,
|
||||
requireManagement: true,
|
||||
providedInference: []string{},
|
||||
providedManagement: []string{},
|
||||
shouldGenerateInf: true,
|
||||
shouldGenerateMgmt: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := llamactl.AuthConfig{
|
||||
RequireInferenceAuth: tt.requireInference,
|
||||
RequireManagementAuth: tt.requireManagement,
|
||||
InferenceKeys: tt.providedInference,
|
||||
ManagementKeys: tt.providedManagement,
|
||||
}
|
||||
|
||||
middleware := llamactl.NewAPIAuthMiddleware(config)
|
||||
|
||||
// Test inference behavior if inference auth is required
|
||||
if tt.requireInference {
|
||||
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler := middleware.AuthMiddleware(llamactl.KeyTypeInference)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
// Should always be unauthorized without a key (since middleware assumes auth is required)
|
||||
if recorder.Code != http.StatusUnauthorized {
|
||||
t.Errorf("Expected unauthorized for inference without key, got status %v", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Test management behavior if management auth is required
|
||||
if tt.requireManagement {
|
||||
req := httptest.NewRequest("GET", "/api/v1/instances", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler := middleware.AuthMiddleware(llamactl.KeyTypeManagement)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
// Should always be unauthorized without a key (since middleware assumes auth is required)
|
||||
if recorder.Code != http.StatusUnauthorized {
|
||||
t.Errorf("Expected unauthorized for management without key, got status %v", recorder.Code)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user