Enhance instance creation with name validation and security checks

This commit is contained in:
2025-07-19 16:20:44 +02:00
parent 37107f76d5
commit dab23e487b
5 changed files with 146 additions and 24 deletions

View File

@@ -120,13 +120,19 @@ func (h *Handler) ListInstances() http.HandlerFunc {
// @Router /instances [post]
func (h *Handler) CreateInstance() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
name := chi.URLParam(r, "name")
if name == "" {
http.Error(w, "Instance name cannot be empty", http.StatusBadRequest)
return
}
var options InstanceOptions
if err := json.NewDecoder(r.Body).Decode(&options); err != nil {
http.Error(w, "Invalid request body", http.StatusBadRequest)
return
}
instance, err := h.InstanceManager.CreateInstance(&options)
instance, err := h.InstanceManager.CreateInstance(name, &options)
if err != nil {
http.Error(w, "Failed to create instance: "+err.Error(), http.StatusInternalServerError)
return

View File

@@ -7,7 +7,7 @@ import (
// InstanceManager defines the interface for managing instances of the llama server.
type InstanceManager interface {
ListInstances() ([]*Instance, error)
CreateInstance(options *InstanceOptions) (*Instance, error)
CreateInstance(name string, options *InstanceOptions) (*Instance, error)
GetInstance(name string) (*Instance, error)
UpdateInstance(name string, options *InstanceOptions) (*Instance, error)
DeleteInstance(name string) error
@@ -43,19 +43,24 @@ func (im *instanceManager) ListInstances() ([]*Instance, error) {
// CreateInstance creates a new instance with the given options and returns it.
// The instance is initially in a "stopped" state.
func (im *instanceManager) CreateInstance(options *InstanceOptions) (*Instance, error) {
func (im *instanceManager) CreateInstance(name string, options *InstanceOptions) (*Instance, error) {
if options == nil {
return nil, fmt.Errorf("instance options cannot be nil")
}
// Check if name is provided
if options.Name == "" || !isValidInstanceName(options.Name) {
return nil, fmt.Errorf("invalid instance name: %s", options.Name)
err := ValidateInstanceName(name)
if err != nil {
return nil, err
}
err = ValidateInstanceOptions(options)
if err != nil {
return nil, err
}
// Check if instance with this name already exists
if im.instances[options.Name] != nil {
return nil, fmt.Errorf("instance with name %s already exists", options.Name)
if im.instances[name] != nil {
return nil, fmt.Errorf("instance with name %s already exists", name)
}
// Assign a port if not specified
@@ -67,23 +72,12 @@ func (im *instanceManager) CreateInstance(options *InstanceOptions) (*Instance,
options.Port = port
}
instance := NewInstance(options.Name, options)
instance := NewInstance(name, options)
im.instances[instance.Name] = instance
return instance, nil
}
// isValidInstanceName checks if the instance name is valid.
func isValidInstanceName(name string) bool {
// A simple validation: name should only contain alphanumeric characters, dashes, and underscores
for _, char := range name {
if !(('a' <= char && char <= 'z') || ('A' <= char && char <= 'Z') || ('0' <= char && char <= '9') || char == '-' || char == '_') {
return false
}
}
return true
}
// GetInstance retrieves an instance by its name.
func (im *instanceManager) GetInstance(name string) (*Instance, error) {
instance, exists := im.instances[name]
@@ -100,6 +94,15 @@ func (im *instanceManager) UpdateInstance(name string, options *InstanceOptions)
return nil, fmt.Errorf("instance with name %s not found", name)
}
if options == nil {
return nil, fmt.Errorf("instance options cannot be nil")
}
err := ValidateInstanceOptions(options)
if err != nil {
return nil, err
}
instance.SetOptions(options)
return instance, nil
}

View File

@@ -33,8 +33,6 @@ func (d Duration) ToDuration() time.Duration {
}
type InstanceOptions struct {
Name string `json:"name,omitempty"` // Display name
// Auto restart
AutoRestart bool `json:"auto_restart,omitempty"`
MaxRestarts int `json:"max_restarts,omitempty"`

View File

@@ -26,12 +26,12 @@ func SetupRouter(handler *Handler) *chi.Mux {
// Instance management endpoints
r.Route("/instances", func(r chi.Router) {
r.Get("/", handler.ListInstances()) // List all instances
r.Post("/", handler.CreateInstance()) // Create and start new instance
r.Get("/", handler.ListInstances()) // List all instances
r.Route("/{name}", func(r chi.Router) {
// Instance management
r.Get("/", handler.GetInstance()) // Get instance details
r.Post("/", handler.CreateInstance()) // Create and start new instance
r.Put("/", handler.UpdateInstance()) // Update instance configuration
r.Delete("/", handler.DeleteInstance()) // Stop and remove instance
r.Post("/start", handler.StartInstance()) // Start stopped instance

115
server/pkg/validation.go Normal file
View File

@@ -0,0 +1,115 @@
package llamactl
import (
"fmt"
"reflect"
"regexp"
)
// Simple security validation that focuses only on actual injection risks
var (
// Block shell metacharacters that could enable command injection
dangerousPatterns = []*regexp.Regexp{
regexp.MustCompile(`[;&|$` + "`" + `]`), // Shell metacharacters
regexp.MustCompile(`\$\(.*\)`), // Command substitution $(...)
regexp.MustCompile("`.*`"), // Command substitution backticks
}
// Simple validation for instance names
validNamePattern = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
)
type ValidationError error
// validateStringForInjection checks if a string contains dangerous patterns
func validateStringForInjection(value string) error {
for _, pattern := range dangerousPatterns {
if pattern.MatchString(value) {
return ValidationError(fmt.Errorf("value contains potentially dangerous characters: %s", value))
}
}
return nil
}
// ValidateInstanceOptions performs minimal security validation
func ValidateInstanceOptions(options *InstanceOptions) error {
if options == nil {
return ValidationError(fmt.Errorf("options cannot be nil"))
}
// Use reflection to check all string fields for injection patterns
if err := validateStructStrings(&options.LlamaServerOptions, ""); err != nil {
return err
}
// Basic network validation - only check for reasonable ranges
if options.Port < 0 || options.Port > 65535 {
return ValidationError(fmt.Errorf("invalid port range"))
}
return nil
}
// validateStructStrings recursively validates all string fields in a struct
func validateStructStrings(v any, fieldPath string) error {
val := reflect.ValueOf(v)
if val.Kind() == reflect.Ptr {
val = val.Elem()
}
if val.Kind() != reflect.Struct {
return nil
}
typ := val.Type()
for i := 0; i < val.NumField(); i++ {
field := val.Field(i)
fieldType := typ.Field(i)
if !field.CanInterface() {
continue
}
fieldName := fieldType.Name
if fieldPath != "" {
fieldName = fieldPath + "." + fieldName
}
switch field.Kind() {
case reflect.String:
if err := validateStringForInjection(field.String()); err != nil {
return ValidationError(fmt.Errorf("field %s: %w", fieldName, err))
}
case reflect.Slice:
if field.Type().Elem().Kind() == reflect.String {
for j := 0; j < field.Len(); j++ {
if err := validateStringForInjection(field.Index(j).String()); err != nil {
return ValidationError(fmt.Errorf("field %s[%d]: %w", fieldName, j, err))
}
}
}
case reflect.Struct:
if err := validateStructStrings(field.Interface(), fieldName); err != nil {
return err
}
}
}
return nil
}
func ValidateInstanceName(name string) error {
// Validate instance name
if name == "" {
return ValidationError(fmt.Errorf("name cannot be empty"))
}
if !validNamePattern.MatchString(name) {
return ValidationError(fmt.Errorf("name contains invalid characters (only alphanumeric, hyphens, underscores allowed)"))
}
if len(name) > 50 {
return ValidationError(fmt.Errorf("name too long (max 50 characters)"))
}
return nil
}