From dab23e487b186f2042dafb6031f5c30f4da9c901 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sat, 19 Jul 2025 16:20:44 +0200 Subject: [PATCH] Enhance instance creation with name validation and security checks --- server/pkg/handlers.go | 8 ++- server/pkg/manager.go | 41 +++++++------- server/pkg/options.go | 2 - server/pkg/routes.go | 4 +- server/pkg/validation.go | 115 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 146 insertions(+), 24 deletions(-) create mode 100644 server/pkg/validation.go diff --git a/server/pkg/handlers.go b/server/pkg/handlers.go index c03c467..5a00350 100644 --- a/server/pkg/handlers.go +++ b/server/pkg/handlers.go @@ -120,13 +120,19 @@ func (h *Handler) ListInstances() http.HandlerFunc { // @Router /instances [post] func (h *Handler) CreateInstance() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + name := chi.URLParam(r, "name") + if name == "" { + http.Error(w, "Instance name cannot be empty", http.StatusBadRequest) + return + } + var options InstanceOptions if err := json.NewDecoder(r.Body).Decode(&options); err != nil { http.Error(w, "Invalid request body", http.StatusBadRequest) return } - instance, err := h.InstanceManager.CreateInstance(&options) + instance, err := h.InstanceManager.CreateInstance(name, &options) if err != nil { http.Error(w, "Failed to create instance: "+err.Error(), http.StatusInternalServerError) return diff --git a/server/pkg/manager.go b/server/pkg/manager.go index 2bc6e28..139de70 100644 --- a/server/pkg/manager.go +++ b/server/pkg/manager.go @@ -7,7 +7,7 @@ import ( // InstanceManager defines the interface for managing instances of the llama server. type InstanceManager interface { ListInstances() ([]*Instance, error) - CreateInstance(options *InstanceOptions) (*Instance, error) + CreateInstance(name string, options *InstanceOptions) (*Instance, error) GetInstance(name string) (*Instance, error) UpdateInstance(name string, options *InstanceOptions) (*Instance, error) DeleteInstance(name string) error @@ -43,19 +43,24 @@ func (im *instanceManager) ListInstances() ([]*Instance, error) { // CreateInstance creates a new instance with the given options and returns it. // The instance is initially in a "stopped" state. -func (im *instanceManager) CreateInstance(options *InstanceOptions) (*Instance, error) { +func (im *instanceManager) CreateInstance(name string, options *InstanceOptions) (*Instance, error) { if options == nil { return nil, fmt.Errorf("instance options cannot be nil") } - // Check if name is provided - if options.Name == "" || !isValidInstanceName(options.Name) { - return nil, fmt.Errorf("invalid instance name: %s", options.Name) + err := ValidateInstanceName(name) + if err != nil { + return nil, err + } + + err = ValidateInstanceOptions(options) + if err != nil { + return nil, err } // Check if instance with this name already exists - if im.instances[options.Name] != nil { - return nil, fmt.Errorf("instance with name %s already exists", options.Name) + if im.instances[name] != nil { + return nil, fmt.Errorf("instance with name %s already exists", name) } // Assign a port if not specified @@ -67,23 +72,12 @@ func (im *instanceManager) CreateInstance(options *InstanceOptions) (*Instance, options.Port = port } - instance := NewInstance(options.Name, options) + instance := NewInstance(name, options) im.instances[instance.Name] = instance return instance, nil } -// isValidInstanceName checks if the instance name is valid. -func isValidInstanceName(name string) bool { - // A simple validation: name should only contain alphanumeric characters, dashes, and underscores - for _, char := range name { - if !(('a' <= char && char <= 'z') || ('A' <= char && char <= 'Z') || ('0' <= char && char <= '9') || char == '-' || char == '_') { - return false - } - } - return true -} - // GetInstance retrieves an instance by its name. func (im *instanceManager) GetInstance(name string) (*Instance, error) { instance, exists := im.instances[name] @@ -100,6 +94,15 @@ func (im *instanceManager) UpdateInstance(name string, options *InstanceOptions) return nil, fmt.Errorf("instance with name %s not found", name) } + if options == nil { + return nil, fmt.Errorf("instance options cannot be nil") + } + + err := ValidateInstanceOptions(options) + if err != nil { + return nil, err + } + instance.SetOptions(options) return instance, nil } diff --git a/server/pkg/options.go b/server/pkg/options.go index 1d053cb..d3113b0 100644 --- a/server/pkg/options.go +++ b/server/pkg/options.go @@ -33,8 +33,6 @@ func (d Duration) ToDuration() time.Duration { } type InstanceOptions struct { - Name string `json:"name,omitempty"` // Display name - // Auto restart AutoRestart bool `json:"auto_restart,omitempty"` MaxRestarts int `json:"max_restarts,omitempty"` diff --git a/server/pkg/routes.go b/server/pkg/routes.go index 502971b..0570913 100644 --- a/server/pkg/routes.go +++ b/server/pkg/routes.go @@ -26,12 +26,12 @@ func SetupRouter(handler *Handler) *chi.Mux { // Instance management endpoints r.Route("/instances", func(r chi.Router) { - r.Get("/", handler.ListInstances()) // List all instances - r.Post("/", handler.CreateInstance()) // Create and start new instance + r.Get("/", handler.ListInstances()) // List all instances r.Route("/{name}", func(r chi.Router) { // Instance management r.Get("/", handler.GetInstance()) // Get instance details + r.Post("/", handler.CreateInstance()) // Create and start new instance r.Put("/", handler.UpdateInstance()) // Update instance configuration r.Delete("/", handler.DeleteInstance()) // Stop and remove instance r.Post("/start", handler.StartInstance()) // Start stopped instance diff --git a/server/pkg/validation.go b/server/pkg/validation.go new file mode 100644 index 0000000..3d0bc23 --- /dev/null +++ b/server/pkg/validation.go @@ -0,0 +1,115 @@ +package llamactl + +import ( + "fmt" + "reflect" + "regexp" +) + +// Simple security validation that focuses only on actual injection risks +var ( + // Block shell metacharacters that could enable command injection + dangerousPatterns = []*regexp.Regexp{ + regexp.MustCompile(`[;&|$` + "`" + `]`), // Shell metacharacters + regexp.MustCompile(`\$\(.*\)`), // Command substitution $(...) + regexp.MustCompile("`.*`"), // Command substitution backticks + } + + // Simple validation for instance names + validNamePattern = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`) +) + +type ValidationError error + +// validateStringForInjection checks if a string contains dangerous patterns +func validateStringForInjection(value string) error { + for _, pattern := range dangerousPatterns { + if pattern.MatchString(value) { + return ValidationError(fmt.Errorf("value contains potentially dangerous characters: %s", value)) + } + } + return nil +} + +// ValidateInstanceOptions performs minimal security validation +func ValidateInstanceOptions(options *InstanceOptions) error { + if options == nil { + return ValidationError(fmt.Errorf("options cannot be nil")) + } + + // Use reflection to check all string fields for injection patterns + if err := validateStructStrings(&options.LlamaServerOptions, ""); err != nil { + return err + } + + // Basic network validation - only check for reasonable ranges + if options.Port < 0 || options.Port > 65535 { + return ValidationError(fmt.Errorf("invalid port range")) + } + + return nil +} + +// validateStructStrings recursively validates all string fields in a struct +func validateStructStrings(v any, fieldPath string) error { + val := reflect.ValueOf(v) + if val.Kind() == reflect.Ptr { + val = val.Elem() + } + + if val.Kind() != reflect.Struct { + return nil + } + + typ := val.Type() + for i := 0; i < val.NumField(); i++ { + field := val.Field(i) + fieldType := typ.Field(i) + + if !field.CanInterface() { + continue + } + + fieldName := fieldType.Name + if fieldPath != "" { + fieldName = fieldPath + "." + fieldName + } + + switch field.Kind() { + case reflect.String: + if err := validateStringForInjection(field.String()); err != nil { + return ValidationError(fmt.Errorf("field %s: %w", fieldName, err)) + } + + case reflect.Slice: + if field.Type().Elem().Kind() == reflect.String { + for j := 0; j < field.Len(); j++ { + if err := validateStringForInjection(field.Index(j).String()); err != nil { + return ValidationError(fmt.Errorf("field %s[%d]: %w", fieldName, j, err)) + } + } + } + + case reflect.Struct: + if err := validateStructStrings(field.Interface(), fieldName); err != nil { + return err + } + } + } + + return nil +} + +func ValidateInstanceName(name string) error { + // Validate instance name + if name == "" { + return ValidationError(fmt.Errorf("name cannot be empty")) + } + if !validNamePattern.MatchString(name) { + return ValidationError(fmt.Errorf("name contains invalid characters (only alphanumeric, hyphens, underscores allowed)")) + } + if len(name) > 50 { + return ValidationError(fmt.Errorf("name too long (max 50 characters)")) + } + return nil +}