diff --git a/pkg/backends/backend.go b/pkg/backends/backend.go new file mode 100644 index 0000000..c28a2cc --- /dev/null +++ b/pkg/backends/backend.go @@ -0,0 +1,7 @@ +package backends + +type BackendType string + +const ( + BackendTypeLlamaCpp BackendType = "llama_cpp" +) diff --git a/pkg/instance/instance.go b/pkg/instance/instance.go index 3780f84..fc5089c 100644 --- a/pkg/instance/instance.go +++ b/pkg/instance/instance.go @@ -5,7 +5,7 @@ import ( "encoding/json" "fmt" "io" - "llamactl/pkg/backends/llamacpp" + "llamactl/pkg/backends" "llamactl/pkg/config" "log" "net/http" @@ -29,52 +29,6 @@ func (realTimeProvider) Now() time.Time { return time.Now() } -type CreateInstanceOptions struct { - // Auto restart - AutoRestart *bool `json:"auto_restart,omitempty"` - MaxRestarts *int `json:"max_restarts,omitempty"` - RestartDelay *int `json:"restart_delay,omitempty"` - // On demand start - OnDemandStart *bool `json:"on_demand_start,omitempty"` - // Idle timeout - IdleTimeout *int `json:"idle_timeout,omitempty"` - // LlamaServerOptions contains the options for the llama server - llamacpp.LlamaServerOptions `json:",inline"` -} - -// UnmarshalJSON implements custom JSON unmarshaling for CreateInstanceOptions -// This is needed because the embedded LlamaServerOptions has its own UnmarshalJSON -// which can interfere with proper unmarshaling of the pointer fields -func (c *CreateInstanceOptions) UnmarshalJSON(data []byte) error { - // First, unmarshal into a temporary struct without the embedded type - type tempCreateOptions struct { - AutoRestart *bool `json:"auto_restart,omitempty"` - MaxRestarts *int `json:"max_restarts,omitempty"` - RestartDelay *int `json:"restart_delay,omitempty"` - OnDemandStart *bool `json:"on_demand_start,omitempty"` - IdleTimeout *int `json:"idle_timeout,omitempty"` - } - - var temp tempCreateOptions - if err := json.Unmarshal(data, &temp); err != nil { - return err - } - - // Copy the pointer fields - c.AutoRestart = temp.AutoRestart - 
c.MaxRestarts = temp.MaxRestarts - c.RestartDelay = temp.RestartDelay - c.OnDemandStart = temp.OnDemandStart - c.IdleTimeout = temp.IdleTimeout - - // Now unmarshal the embedded LlamaServerOptions - if err := json.Unmarshal(data, &c.LlamaServerOptions); err != nil { - return err - } - - return nil -} - // Process represents a running instance of the llama server type Process struct { Name string `json:"name"` @@ -110,101 +64,17 @@ type Process struct { timeProvider TimeProvider `json:"-"` // Time provider for testing } -// validateAndCopyOptions validates and creates a deep copy of the provided options -// It applies validation rules and returns a safe copy -func validateAndCopyOptions(name string, options *CreateInstanceOptions) *CreateInstanceOptions { - optionsCopy := &CreateInstanceOptions{} - - if options != nil { - // Copy the embedded LlamaServerOptions - optionsCopy.LlamaServerOptions = options.LlamaServerOptions - - // Copy and validate pointer fields - if options.AutoRestart != nil { - autoRestart := *options.AutoRestart - optionsCopy.AutoRestart = &autoRestart - } - - if options.MaxRestarts != nil { - maxRestarts := *options.MaxRestarts - if maxRestarts < 0 { - log.Printf("Instance %s MaxRestarts value (%d) cannot be negative, setting to 0", name, maxRestarts) - maxRestarts = 0 - } - optionsCopy.MaxRestarts = &maxRestarts - } - - if options.RestartDelay != nil { - restartDelay := *options.RestartDelay - if restartDelay < 0 { - log.Printf("Instance %s RestartDelay value (%d) cannot be negative, setting to 0 seconds", name, restartDelay) - restartDelay = 0 - } - optionsCopy.RestartDelay = &restartDelay - } - - if options.OnDemandStart != nil { - onDemandStart := *options.OnDemandStart - optionsCopy.OnDemandStart = &onDemandStart - } - - if options.IdleTimeout != nil { - idleTimeout := *options.IdleTimeout - if idleTimeout < 0 { - log.Printf("Instance %s IdleTimeout value (%d) cannot be negative, setting to 0 minutes", name, idleTimeout) - idleTimeout = 0 - 
} - optionsCopy.IdleTimeout = &idleTimeout - } - } - - return optionsCopy -} - -// applyDefaultOptions applies default values from global settings to any nil options -func applyDefaultOptions(options *CreateInstanceOptions, globalSettings *config.InstancesConfig) { - if globalSettings == nil { - return - } - - if options.AutoRestart == nil { - defaultAutoRestart := globalSettings.DefaultAutoRestart - options.AutoRestart = &defaultAutoRestart - } - - if options.MaxRestarts == nil { - defaultMaxRestarts := globalSettings.DefaultMaxRestarts - options.MaxRestarts = &defaultMaxRestarts - } - - if options.RestartDelay == nil { - defaultRestartDelay := globalSettings.DefaultRestartDelay - options.RestartDelay = &defaultRestartDelay - } - - if options.OnDemandStart == nil { - defaultOnDemandStart := globalSettings.DefaultOnDemandStart - options.OnDemandStart = &defaultOnDemandStart - } - - if options.IdleTimeout == nil { - defaultIdleTimeout := 0 - options.IdleTimeout = &defaultIdleTimeout - } -} - // NewInstance creates a new instance with the given name, log path, and options func NewInstance(name string, globalSettings *config.InstancesConfig, options *CreateInstanceOptions, onStatusChange func(oldStatus, newStatus InstanceStatus)) *Process { // Validate and copy options - optionsCopy := validateAndCopyOptions(name, options) - // Apply defaults - applyDefaultOptions(optionsCopy, globalSettings) + options.ValidateAndApplyDefaults(name, globalSettings) + // Create the instance logger logger := NewInstanceLogger(name, globalSettings.LogsDir) return &Process{ Name: name, - options: optionsCopy, + options: options, globalSettings: globalSettings, logger: logger, timeProvider: realTimeProvider{}, @@ -220,6 +90,30 @@ func (i *Process) GetOptions() *CreateInstanceOptions { return i.options } +func (i *Process) GetPort() int { + i.mu.RLock() + defer i.mu.RUnlock() + if i.options != nil { + switch i.options.BackendType { + case backends.BackendTypeLlamaCpp: + return 
i.options.LlamaServerOptions.Port + } + } + return 0 +} + +func (i *Process) GetHost() string { + i.mu.RLock() + defer i.mu.RUnlock() + if i.options != nil { + switch i.options.BackendType { + case backends.BackendTypeLlamaCpp: + return i.options.LlamaServerOptions.Host + } + } + return "" +} + func (i *Process) SetOptions(options *CreateInstanceOptions) { i.mu.Lock() defer i.mu.Unlock() @@ -229,11 +123,10 @@ func (i *Process) SetOptions(options *CreateInstanceOptions) { return } - // Validate and copy options and apply defaults - optionsCopy := validateAndCopyOptions(i.Name, options) - applyDefaultOptions(optionsCopy, i.globalSettings) + // Validate and copy options + options.ValidateAndApplyDefaults(i.Name, i.globalSettings) - i.options = optionsCopy + i.options = options // Clear the proxy so it gets recreated with new options i.proxy = nil } @@ -256,7 +149,15 @@ func (i *Process) GetProxy() (*httputil.ReverseProxy, error) { return nil, fmt.Errorf("instance %s has no options set", i.Name) } - targetURL, err := url.Parse(fmt.Sprintf("http://%s:%d", i.options.Host, i.options.Port)) + var host string + var port int + switch i.options.BackendType { + case backends.BackendTypeLlamaCpp: + host = i.options.LlamaServerOptions.Host + port = i.options.LlamaServerOptions.Port + } + + targetURL, err := url.Parse(fmt.Sprintf("http://%s:%d", host, port)) if err != nil { return nil, fmt.Errorf("failed to parse target URL for instance %s: %w", i.Name, err) } @@ -286,44 +187,36 @@ func (i *Process) MarshalJSON() ([]byte, error) { i.mu.RLock() defer i.mu.RUnlock() - // Create a temporary struct with exported fields for JSON marshalling - temp := struct { - Name string `json:"name"` + // Use anonymous struct to avoid recursion + type Alias Process + return json.Marshal(&struct { + *Alias Options *CreateInstanceOptions `json:"options,omitempty"` - Status InstanceStatus `json:"status"` - Created int64 `json:"created,omitempty"` }{ - Name: i.Name, + Alias: (*Alias)(i), Options: 
i.options, - Status: i.Status, - Created: i.Created, - } - - return json.Marshal(temp) + }) } // UnmarshalJSON implements json.Unmarshaler for Instance func (i *Process) UnmarshalJSON(data []byte) error { - // Create a temporary struct for unmarshalling - temp := struct { - Name string `json:"name"` + // Use anonymous struct to avoid recursion + type Alias Process + aux := &struct { + *Alias Options *CreateInstanceOptions `json:"options,omitempty"` - Status InstanceStatus `json:"status"` - Created int64 `json:"created,omitempty"` - }{} + }{ + Alias: (*Alias)(i), + } - if err := json.Unmarshal(data, &temp); err != nil { + if err := json.Unmarshal(data, aux); err != nil { return err } - // Set the fields - i.Name = temp.Name - i.Status = temp.Status - i.Created = temp.Created - - // Handle options with validation but no defaults - if temp.Options != nil { - i.options = validateAndCopyOptions(i.Name, temp.Options) + // Handle options with validation and defaults + if aux.Options != nil { + aux.Options.ValidateAndApplyDefaults(i.Name, i.globalSettings) + i.options = aux.Options } return nil diff --git a/pkg/instance/instance_test.go b/pkg/instance/instance_test.go index 4f30ab6..aa916b9 100644 --- a/pkg/instance/instance_test.go +++ b/pkg/instance/instance_test.go @@ -2,6 +2,7 @@ package instance_test import ( "encoding/json" + "llamactl/pkg/backends" "llamactl/pkg/backends/llamacpp" "llamactl/pkg/config" "llamactl/pkg/instance" @@ -18,7 +19,8 @@ func TestNewInstance(t *testing.T) { } options := &instance.CreateInstanceOptions{ - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", Port: 8080, }, @@ -27,22 +29,22 @@ func TestNewInstance(t *testing.T) { // Mock onStatusChange function mockOnStatusChange := func(oldStatus, newStatus instance.InstanceStatus) {} - instance := instance.NewInstance("test-instance", globalSettings, options, 
mockOnStatusChange) + inst := instance.NewInstance("test-instance", globalSettings, options, mockOnStatusChange) - if instance.Name != "test-instance" { - t.Errorf("Expected name 'test-instance', got %q", instance.Name) + if inst.Name != "test-instance" { + t.Errorf("Expected name 'test-instance', got %q", inst.Name) } - if instance.IsRunning() { + if inst.IsRunning() { t.Error("New instance should not be running") } // Check that options were properly set with defaults applied - opts := instance.GetOptions() - if opts.Model != "/path/to/model.gguf" { - t.Errorf("Expected model '/path/to/model.gguf', got %q", opts.Model) + opts := inst.GetOptions() + if opts.LlamaServerOptions.Model != "/path/to/model.gguf" { + t.Errorf("Expected model '/path/to/model.gguf', got %q", opts.LlamaServerOptions.Model) } - if opts.Port != 8080 { - t.Errorf("Expected port 8080, got %d", opts.Port) + if inst.GetPort() != 8080 { + t.Errorf("Expected port 8080, got %d", inst.GetPort()) } // Check that defaults were applied @@ -74,7 +76,8 @@ func TestNewInstance_WithRestartOptions(t *testing.T) { AutoRestart: &autoRestart, MaxRestarts: &maxRestarts, RestartDelay: &restartDelay, - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", }, } @@ -106,7 +109,8 @@ func TestSetOptions(t *testing.T) { } initialOptions := &instance.CreateInstanceOptions{ - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", Port: 8080, }, @@ -119,7 +123,8 @@ func TestSetOptions(t *testing.T) { // Update options newOptions := &instance.CreateInstanceOptions{ - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/new-model.gguf", Port: 8081, }, @@ -128,11 +133,11 @@ func 
TestSetOptions(t *testing.T) { inst.SetOptions(newOptions) opts := inst.GetOptions() - if opts.Model != "/path/to/new-model.gguf" { - t.Errorf("Expected updated model '/path/to/new-model.gguf', got %q", opts.Model) + if opts.LlamaServerOptions.Model != "/path/to/new-model.gguf" { + t.Errorf("Expected updated model '/path/to/new-model.gguf', got %q", opts.LlamaServerOptions.Model) } - if opts.Port != 8081 { - t.Errorf("Expected updated port 8081, got %d", opts.Port) + if inst.GetPort() != 8081 { + t.Errorf("Expected updated port 8081, got %d", inst.GetPort()) } // Check that defaults are still applied @@ -147,7 +152,8 @@ func TestGetProxy(t *testing.T) { } options := &instance.CreateInstanceOptions{ - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Host: "localhost", Port: 8080, }, @@ -186,7 +192,8 @@ func TestMarshalJSON(t *testing.T) { } options := &instance.CreateInstanceOptions{ - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", Port: 8080, }, @@ -225,8 +232,26 @@ func TestMarshalJSON(t *testing.T) { if !ok { t.Error("Expected options to be a map") } - if options_map["model"] != "/path/to/model.gguf" { - t.Errorf("Expected model '/path/to/model.gguf', got %v", options_map["model"]) + + // Check backend type + if options_map["backend_type"] != string(backends.BackendTypeLlamaCpp) { + t.Errorf("Expected backend_type '%s', got %v", backends.BackendTypeLlamaCpp, options_map["backend_type"]) + } + + // Check backend options + backend_options_data, ok := options_map["backend_options"] + if !ok { + t.Error("Expected backend_options to be included in JSON") + } + backend_options_map, ok := backend_options_data.(map[string]any) + if !ok { + t.Error("Expected backend_options to be a map") + } + if backend_options_map["model"] != "/path/to/model.gguf" 
{ + t.Errorf("Expected model '/path/to/model.gguf', got %v", backend_options_map["model"]) + } + if backend_options_map["port"] != float64(8080) { + t.Errorf("Expected port 8080, got %v", backend_options_map["port"]) } } @@ -235,10 +260,13 @@ func TestUnmarshalJSON(t *testing.T) { "name": "test-instance", "status": "running", "options": { - "model": "/path/to/model.gguf", - "port": 8080, "auto_restart": false, - "max_restarts": 5 + "max_restarts": 5, + "backend_type": "llama_cpp", + "backend_options": { + "model": "/path/to/model.gguf", + "port": 8080 + } } }` @@ -259,11 +287,17 @@ func TestUnmarshalJSON(t *testing.T) { if opts == nil { t.Fatal("Expected options to be set") } - if opts.Model != "/path/to/model.gguf" { - t.Errorf("Expected model '/path/to/model.gguf', got %q", opts.Model) + if opts.BackendType != backends.BackendTypeLlamaCpp { + t.Errorf("Expected backend_type '%s', got %s", backends.BackendTypeLlamaCpp, opts.BackendType) } - if opts.Port != 8080 { - t.Errorf("Expected port 8080, got %d", opts.Port) + if opts.LlamaServerOptions == nil { + t.Fatal("Expected LlamaServerOptions to be set") + } + if opts.LlamaServerOptions.Model != "/path/to/model.gguf" { + t.Errorf("Expected model '/path/to/model.gguf', got %q", opts.LlamaServerOptions.Model) + } + if inst.GetPort() != 8080 { + t.Errorf("Expected port 8080, got %d", inst.GetPort()) } if opts.AutoRestart == nil || *opts.AutoRestart { t.Error("Expected AutoRestart to be false") @@ -313,7 +347,8 @@ func TestCreateInstanceOptionsValidation(t *testing.T) { options := &instance.CreateInstanceOptions{ MaxRestarts: tt.maxRestarts, RestartDelay: tt.restartDelay, - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", }, } diff --git a/pkg/instance/lifecycle.go b/pkg/instance/lifecycle.go index 91bea6c..28a65b9 100644 --- a/pkg/instance/lifecycle.go +++ b/pkg/instance/lifecycle.go @@ -40,7 
+40,6 @@ func (i *Process) Start() error { } args := i.options.BuildCommandArgs() - i.ctx, i.cancel = context.WithCancel(context.Background()) i.cmd = exec.CommandContext(i.ctx, "llama-server", args...) @@ -173,11 +172,17 @@ func (i *Process) WaitForHealthy(timeout int) error { } // Build the health check URL directly - host := opts.Host + var host string + var port int + switch opts.BackendType { + case "llama_cpp": + host = opts.LlamaServerOptions.Host + port = opts.LlamaServerOptions.Port + } if host == "" { host = "localhost" } - healthURL := fmt.Sprintf("http://%s:%d/health", host, opts.Port) + healthURL := fmt.Sprintf("http://%s:%d/health", host, port) // Create a dedicated HTTP client for health checks client := &http.Client{ diff --git a/pkg/instance/options.go b/pkg/instance/options.go new file mode 100644 index 0000000..b9a2cca --- /dev/null +++ b/pkg/instance/options.go @@ -0,0 +1,141 @@ +package instance + +import ( + "encoding/json" + "fmt" + "llamactl/pkg/backends" + "llamactl/pkg/backends/llamacpp" + "llamactl/pkg/config" + "log" +) + +type CreateInstanceOptions struct { + // Auto restart + AutoRestart *bool `json:"auto_restart,omitempty"` + MaxRestarts *int `json:"max_restarts,omitempty"` + RestartDelay *int `json:"restart_delay,omitempty"` // seconds + // On demand start + OnDemandStart *bool `json:"on_demand_start,omitempty"` + // Idle timeout + IdleTimeout *int `json:"idle_timeout,omitempty"` // minutes + + BackendType backends.BackendType `json:"backend_type"` + BackendOptions map[string]any `json:"backend_options,omitempty"` + + // LlamaServerOptions contains the options for the llama server + LlamaServerOptions *llamacpp.LlamaServerOptions `json:"-"` +} + +// UnmarshalJSON implements custom JSON unmarshaling for CreateInstanceOptions +func (c *CreateInstanceOptions) UnmarshalJSON(data []byte) error { + // Use anonymous struct to avoid recursion + type Alias CreateInstanceOptions + aux := &struct { + *Alias + }{ + Alias: (*Alias)(c), + } + + if 
err := json.Unmarshal(data, aux); err != nil { + return err + } + + // Parse backend-specific options + switch c.BackendType { + case backends.BackendTypeLlamaCpp: + if c.BackendOptions != nil { + // Convert map to JSON and then unmarshal to LlamaServerOptions + optionsData, err := json.Marshal(c.BackendOptions) + if err != nil { + return fmt.Errorf("failed to marshal backend options: %w", err) + } + + c.LlamaServerOptions = &llamacpp.LlamaServerOptions{} + if err := json.Unmarshal(optionsData, c.LlamaServerOptions); err != nil { + return fmt.Errorf("failed to unmarshal llama.cpp options: %w", err) + } + } + default: + return fmt.Errorf("unknown backend type: %s", c.BackendType) + } + + return nil +} + +// MarshalJSON implements custom JSON marshaling for CreateInstanceOptions +func (c *CreateInstanceOptions) MarshalJSON() ([]byte, error) { + // Use anonymous struct to avoid recursion + type Alias CreateInstanceOptions + aux := struct { + *Alias + }{ + Alias: (*Alias)(c), + } + + // Convert LlamaServerOptions back to BackendOptions map for JSON + if c.BackendType == backends.BackendTypeLlamaCpp && c.LlamaServerOptions != nil { + data, err := json.Marshal(c.LlamaServerOptions) + if err != nil { + return nil, fmt.Errorf("failed to marshal llama server options: %w", err) + } + + var backendOpts map[string]any + if err := json.Unmarshal(data, &backendOpts); err != nil { + return nil, fmt.Errorf("failed to unmarshal to map: %w", err) + } + + aux.BackendOptions = backendOpts + } + + return json.Marshal(aux) +} + +// ValidateAndApplyDefaults validates the instance options and applies constraints +func (c *CreateInstanceOptions) ValidateAndApplyDefaults(name string, globalSettings *config.InstancesConfig) { + // Validate and apply constraints + if c.MaxRestarts != nil && *c.MaxRestarts < 0 { + log.Printf("Instance %s MaxRestarts value (%d) cannot be negative, setting to 0", name, *c.MaxRestarts) + *c.MaxRestarts = 0 + } + + if c.RestartDelay != nil && *c.RestartDelay < 0 { 
+ log.Printf("Instance %s RestartDelay value (%d) cannot be negative, setting to 0 seconds", name, *c.RestartDelay) + *c.RestartDelay = 0 + } + + if c.IdleTimeout != nil && *c.IdleTimeout < 0 { + log.Printf("Instance %s IdleTimeout value (%d) cannot be negative, setting to 0 minutes", name, *c.IdleTimeout) + *c.IdleTimeout = 0 + } + + // Apply defaults from global settings for nil fields + if globalSettings != nil { + if c.AutoRestart == nil { + c.AutoRestart = &globalSettings.DefaultAutoRestart + } + if c.MaxRestarts == nil { + c.MaxRestarts = &globalSettings.DefaultMaxRestarts + } + if c.RestartDelay == nil { + c.RestartDelay = &globalSettings.DefaultRestartDelay + } + if c.OnDemandStart == nil { + c.OnDemandStart = &globalSettings.DefaultOnDemandStart + } + if c.IdleTimeout == nil { + defaultIdleTimeout := 0 + c.IdleTimeout = &defaultIdleTimeout + } + } +} + +// BuildCommandArgs builds command line arguments for the backend +func (c *CreateInstanceOptions) BuildCommandArgs() []string { + switch c.BackendType { + case backends.BackendTypeLlamaCpp: + if c.LlamaServerOptions != nil { + return c.LlamaServerOptions.BuildCommandArgs() + } + } + return []string{} +} diff --git a/pkg/instance/timeout_test.go b/pkg/instance/timeout_test.go index 05abd04..c791bfb 100644 --- a/pkg/instance/timeout_test.go +++ b/pkg/instance/timeout_test.go @@ -1,6 +1,7 @@ package instance_test import ( + "llamactl/pkg/backends" "llamactl/pkg/backends/llamacpp" "llamactl/pkg/config" "llamactl/pkg/instance" @@ -37,7 +38,8 @@ func TestUpdateLastRequestTime(t *testing.T) { } options := &instance.CreateInstanceOptions{ - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", }, } @@ -59,7 +61,8 @@ func TestShouldTimeout_NotRunning(t *testing.T) { idleTimeout := 1 // 1 minute options := &instance.CreateInstanceOptions{ IdleTimeout: &idleTimeout, - LlamaServerOptions: 
llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", }, } @@ -96,7 +99,8 @@ func TestShouldTimeout_NoTimeoutConfigured(t *testing.T) { options := &instance.CreateInstanceOptions{ IdleTimeout: tt.idleTimeout, - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", }, } @@ -120,7 +124,8 @@ func TestShouldTimeout_WithinTimeLimit(t *testing.T) { idleTimeout := 5 // 5 minutes options := &instance.CreateInstanceOptions{ IdleTimeout: &idleTimeout, - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", }, } @@ -148,7 +153,8 @@ func TestShouldTimeout_ExceedsTimeLimit(t *testing.T) { idleTimeout := 1 // 1 minute options := &instance.CreateInstanceOptions{ IdleTimeout: &idleTimeout, - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", }, } @@ -194,7 +200,8 @@ func TestTimeoutConfiguration_Validation(t *testing.T) { t.Run(tt.name, func(t *testing.T) { options := &instance.CreateInstanceOptions{ IdleTimeout: tt.inputTimeout, - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", }, } diff --git a/pkg/manager/manager.go b/pkg/manager/manager.go index 390cefe..80652a8 100644 --- a/pkg/manager/manager.go +++ b/pkg/manager/manager.go @@ -248,8 +248,8 @@ func (im *instanceManager) loadInstance(name, path string) error { inst.SetStatus(persistedInstance.Status) // Check for port conflicts and add to maps - if inst.GetOptions() != nil && inst.GetOptions().Port > 0 { - port := inst.GetOptions().Port + 
if inst.GetPort() > 0 { + port := inst.GetPort() if im.ports[port] { return fmt.Errorf("port conflict: instance %s wants port %d which is already in use", name, port) } diff --git a/pkg/manager/manager_test.go b/pkg/manager/manager_test.go index a0d5492..c332739 100644 --- a/pkg/manager/manager_test.go +++ b/pkg/manager/manager_test.go @@ -2,6 +2,7 @@ package manager_test import ( "fmt" + "llamactl/pkg/backends" "llamactl/pkg/backends/llamacpp" "llamactl/pkg/config" "llamactl/pkg/instance" @@ -53,7 +54,8 @@ func TestPersistence(t *testing.T) { // Test instance persistence on creation manager1 := manager.NewInstanceManager(cfg) options := &instance.CreateInstanceOptions{ - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", Port: 8080, }, @@ -109,12 +111,13 @@ func TestConcurrentAccess(t *testing.T) { errChan := make(chan error, 10) // Concurrent instance creation - for i := 0; i < 5; i++ { + for i := range 5 { wg.Add(1) go func(index int) { defer wg.Done() options := &instance.CreateInstanceOptions{ - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", }, } @@ -150,7 +153,8 @@ func TestShutdown(t *testing.T) { // Create test instance options := &instance.CreateInstanceOptions{ - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", }, } diff --git a/pkg/manager/operations.go b/pkg/manager/operations.go index b347900..6f65680 100644 --- a/pkg/manager/operations.go +++ b/pkg/manager/operations.go @@ -2,6 +2,7 @@ package manager import ( "fmt" + "llamactl/pkg/backends" "llamactl/pkg/instance" "llamactl/pkg/validation" "os" @@ -52,19 +53,9 @@ func (im *instanceManager) CreateInstance(name string, options 
*instance.CreateI return nil, fmt.Errorf("instance with name %s already exists", name) } - // Assign a port if not specified - if options.Port == 0 { - port, err := im.getNextAvailablePort() - if err != nil { - return nil, fmt.Errorf("failed to get next available port: %w", err) - } - options.Port = port - } else { - // Validate the specified port - if _, exists := im.ports[options.Port]; exists { - return nil, fmt.Errorf("port %d is already in use", options.Port) - } - im.ports[options.Port] = true + // Assign and validate port for backend-specific options + if err := im.assignAndValidatePort(options); err != nil { + return nil, err } statusCallback := func(oldStatus, newStatus instance.InstanceStatus) { @@ -73,7 +64,6 @@ func (im *instanceManager) CreateInstance(name string, options *instance.CreateI inst := instance.NewInstance(name, &im.instancesConfig, options, statusCallback) im.instances[inst.Name] = inst - im.ports[options.Port] = true if err := im.persistInstance(inst); err != nil { return nil, fmt.Errorf("failed to persist instance %s: %w", name, err) @@ -157,7 +147,7 @@ func (im *instanceManager) DeleteInstance(name string) error { return fmt.Errorf("instance with name %s is still running, stop it before deleting", name) } - delete(im.ports, instance.GetOptions().Port) + delete(im.ports, instance.GetPort()) delete(im.instances, name) // Delete the instance's config file if persistence is enabled @@ -262,3 +252,49 @@ func (im *instanceManager) GetInstanceLogs(name string) (string, error) { // TODO: Implement actual log retrieval logic return fmt.Sprintf("Logs for instance %s", name), nil } + +// getPortFromOptions extracts the port from backend-specific options +func (im *instanceManager) getPortFromOptions(options *instance.CreateInstanceOptions) int { + switch options.BackendType { + case backends.BackendTypeLlamaCpp: + if options.LlamaServerOptions != nil { + return options.LlamaServerOptions.Port + } + } + return 0 +} + +// setPortInOptions sets the 
port in backend-specific options +func (im *instanceManager) setPortInOptions(options *instance.CreateInstanceOptions, port int) { + switch options.BackendType { + case backends.BackendTypeLlamaCpp: + if options.LlamaServerOptions != nil { + options.LlamaServerOptions.Port = port + } + } +} + +// assignAndValidatePort assigns a port if not specified and validates it's not in use +func (im *instanceManager) assignAndValidatePort(options *instance.CreateInstanceOptions) error { + currentPort := im.getPortFromOptions(options) + + if currentPort == 0 { + // Assign a port if not specified + port, err := im.getNextAvailablePort() + if err != nil { + return fmt.Errorf("failed to get next available port: %w", err) + } + im.setPortInOptions(options, port) + // Mark the port as used + im.ports[port] = true + } else { + // Validate the specified port + if _, exists := im.ports[currentPort]; exists { + return fmt.Errorf("port %d is already in use", currentPort) + } + // Mark the port as used + im.ports[currentPort] = true + } + + return nil +} diff --git a/pkg/manager/operations_test.go b/pkg/manager/operations_test.go index d045b81..7dd4889 100644 --- a/pkg/manager/operations_test.go +++ b/pkg/manager/operations_test.go @@ -1,6 +1,7 @@ package manager_test import ( + "llamactl/pkg/backends" "llamactl/pkg/backends/llamacpp" "llamactl/pkg/config" "llamactl/pkg/instance" @@ -13,7 +14,8 @@ func TestCreateInstance_Success(t *testing.T) { manager := createTestManager() options := &instance.CreateInstanceOptions{ - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", Port: 8080, }, @@ -30,8 +32,8 @@ func TestCreateInstance_Success(t *testing.T) { if inst.GetStatus() != instance.Stopped { t.Error("New instance should not be running") } - if inst.GetOptions().Port != 8080 { - t.Errorf("Expected port 8080, got %d", inst.GetOptions().Port) + if inst.GetPort() != 
8080 { + t.Errorf("Expected port 8080, got %d", inst.GetPort()) } } @@ -39,7 +41,8 @@ func TestCreateInstance_ValidationAndLimits(t *testing.T) { // Test duplicate names mngr := createTestManager() options := &instance.CreateInstanceOptions{ - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", }, } @@ -86,7 +89,8 @@ func TestPortManagement(t *testing.T) { // Test auto port assignment options1 := &instance.CreateInstanceOptions{ - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", }, } @@ -96,14 +100,15 @@ func TestPortManagement(t *testing.T) { t.Fatalf("CreateInstance failed: %v", err) } - port1 := inst1.GetOptions().Port + port1 := inst1.GetPort() if port1 < 8000 || port1 > 9000 { t.Errorf("Expected port in range 8000-9000, got %d", port1) } // Test port conflict detection options2 := &instance.CreateInstanceOptions{ - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model2.gguf", Port: port1, // Same port - should conflict }, @@ -120,7 +125,8 @@ func TestPortManagement(t *testing.T) { // Test port release on deletion specificPort := 8080 options3 := &instance.CreateInstanceOptions{ - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", Port: specificPort, }, @@ -147,7 +153,8 @@ func TestInstanceOperations(t *testing.T) { manager := createTestManager() options := &instance.CreateInstanceOptions{ - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", }, } @@ -169,7 
+176,8 @@ func TestInstanceOperations(t *testing.T) { // Update instance newOptions := &instance.CreateInstanceOptions{ - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/new-model.gguf", Port: 8081, }, @@ -179,8 +187,8 @@ func TestInstanceOperations(t *testing.T) { if err != nil { t.Fatalf("UpdateInstance failed: %v", err) } - if updated.GetOptions().Model != "/path/to/new-model.gguf" { - t.Errorf("Expected model '/path/to/new-model.gguf', got %q", updated.GetOptions().Model) + if updated.GetOptions().LlamaServerOptions.Model != "/path/to/new-model.gguf" { + t.Errorf("Expected model '/path/to/new-model.gguf', got %q", updated.GetOptions().LlamaServerOptions.Model) } // List instances diff --git a/pkg/manager/timeout_test.go b/pkg/manager/timeout_test.go index 41ca188..23143d2 100644 --- a/pkg/manager/timeout_test.go +++ b/pkg/manager/timeout_test.go @@ -1,6 +1,7 @@ package manager_test import ( + "llamactl/pkg/backends" "llamactl/pkg/backends/llamacpp" "llamactl/pkg/config" "llamactl/pkg/instance" @@ -31,7 +32,8 @@ func TestTimeoutFunctionality(t *testing.T) { idleTimeout := 1 // 1 minute options := &instance.CreateInstanceOptions{ IdleTimeout: &idleTimeout, - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", }, } @@ -79,7 +81,8 @@ func TestTimeoutFunctionality(t *testing.T) { // Test that instance without timeout doesn't timeout noTimeoutOptions := &instance.CreateInstanceOptions{ - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", }, // No IdleTimeout set @@ -109,19 +112,22 @@ func TestEvictLRUInstance_Success(t *testing.T) { // Create 3 instances with idle timeout enabled (value doesn't matter for LRU 
logic) options1 := &instance.CreateInstanceOptions{ - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model1.gguf", }, IdleTimeout: func() *int { timeout := 1; return &timeout }(), // Any value > 0 } options2 := &instance.CreateInstanceOptions{ - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model2.gguf", }, IdleTimeout: func() *int { timeout := 1; return &timeout }(), // Any value > 0 } options3 := &instance.CreateInstanceOptions{ - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model3.gguf", }, IdleTimeout: func() *int { timeout := 1; return &timeout }(), // Any value > 0 @@ -188,7 +194,8 @@ func TestEvictLRUInstance_NoEligibleInstances(t *testing.T) { // Helper function to create instances with different timeout configurations createInstanceWithTimeout := func(manager manager.InstanceManager, name, model string, timeout *int) *instance.Process { options := &instance.CreateInstanceOptions{ - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: model, }, IdleTimeout: timeout, diff --git a/pkg/validation/validation.go b/pkg/validation/validation.go index 9145cf6..77873ca 100644 --- a/pkg/validation/validation.go +++ b/pkg/validation/validation.go @@ -2,6 +2,7 @@ package validation import ( "fmt" + "llamactl/pkg/backends" "llamactl/pkg/instance" "reflect" "regexp" @@ -33,20 +34,35 @@ func validateStringForInjection(value string) error { return nil } -// ValidateInstanceOptions performs minimal security validation +// ValidateInstanceOptions performs validation based on backend type func ValidateInstanceOptions(options 
*instance.CreateInstanceOptions) error { if options == nil { return ValidationError(fmt.Errorf("options cannot be nil")) } + // Validate based on backend type + switch options.BackendType { + case backends.BackendTypeLlamaCpp: + return validateLlamaCppOptions(options) + default: + return ValidationError(fmt.Errorf("unsupported backend type: %s", options.BackendType)) + } +} + +// validateLlamaCppOptions validates llama.cpp specific options +func validateLlamaCppOptions(options *instance.CreateInstanceOptions) error { + if options.LlamaServerOptions == nil { + return ValidationError(fmt.Errorf("llama server options cannot be nil for llama.cpp backend")) + } + // Use reflection to check all string fields for injection patterns - if err := validateStructStrings(&options.LlamaServerOptions, ""); err != nil { + if err := validateStructStrings(options.LlamaServerOptions, ""); err != nil { return err } - // Basic network validation - only check for reasonable ranges - if options.Port < 0 || options.Port > 65535 { - return ValidationError(fmt.Errorf("invalid port range")) + // Basic network validation for port + if options.LlamaServerOptions.Port < 0 || options.LlamaServerOptions.Port > 65535 { + return ValidationError(fmt.Errorf("invalid port range: %d", options.LlamaServerOptions.Port)) } return nil diff --git a/pkg/validation/validation_test.go b/pkg/validation/validation_test.go index 3e12606..8d8c49e 100644 --- a/pkg/validation/validation_test.go +++ b/pkg/validation/validation_test.go @@ -1,6 +1,7 @@ package validation_test import ( + "llamactl/pkg/backends" "llamactl/pkg/backends/llamacpp" "llamactl/pkg/instance" "llamactl/pkg/testutil" @@ -83,7 +84,8 @@ func TestValidateInstanceOptions_PortValidation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { options := &instance.CreateInstanceOptions{ - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: 
&llamacpp.LlamaServerOptions{ Port: tt.port, }, } @@ -136,7 +138,8 @@ func TestValidateInstanceOptions_StringInjection(t *testing.T) { t.Run(tt.name, func(t *testing.T) { // Test with Model field (string field) options := &instance.CreateInstanceOptions{ - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: tt.value, }, } @@ -173,7 +176,8 @@ func TestValidateInstanceOptions_ArrayInjection(t *testing.T) { t.Run(tt.name, func(t *testing.T) { // Test with Lora field (array field) options := &instance.CreateInstanceOptions{ - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Lora: tt.array, }, } @@ -196,7 +200,8 @@ func TestValidateInstanceOptions_MultipleFieldInjection(t *testing.T) { { name: "injection in model field", options: &instance.CreateInstanceOptions{ - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "safe.gguf", HFRepo: "microsoft/model; curl evil.com", }, @@ -206,7 +211,8 @@ func TestValidateInstanceOptions_MultipleFieldInjection(t *testing.T) { { name: "injection in log file", options: &instance.CreateInstanceOptions{ - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "safe.gguf", LogFile: "/tmp/log.txt | tee /etc/passwd", }, @@ -216,7 +222,8 @@ func TestValidateInstanceOptions_MultipleFieldInjection(t *testing.T) { { name: "all safe fields", options: &instance.CreateInstanceOptions{ - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", HFRepo: "microsoft/DialoGPT-medium", LogFile: "/tmp/llama.log", @@ -244,7 +251,8 @@ func 
TestValidateInstanceOptions_NonStringFields(t *testing.T) { AutoRestart: testutil.BoolPtr(true), MaxRestarts: testutil.IntPtr(5), RestartDelay: testutil.IntPtr(10), - LlamaServerOptions: llamacpp.LlamaServerOptions{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ Port: 8080, GPULayers: 32, CtxSize: 4096, diff --git a/webui/src/__tests__/App.test.tsx b/webui/src/__tests__/App.test.tsx index 57f00a0..4358321 100644 --- a/webui/src/__tests__/App.test.tsx +++ b/webui/src/__tests__/App.test.tsx @@ -5,6 +5,7 @@ import App from '@/App' import { InstancesProvider } from '@/contexts/InstancesContext' import { instancesApi } from '@/lib/api' import type { Instance } from '@/types/instance' +import { BackendType } from '@/types/instance' import { AuthProvider } from '@/contexts/AuthContext' // Mock the API @@ -46,8 +47,8 @@ function renderApp() { describe('App Component - Critical Business Logic Only', () => { const mockInstances: Instance[] = [ - { name: 'test-instance-1', status: 'stopped', options: { model: 'model1.gguf' } }, - { name: 'test-instance-2', status: 'running', options: { model: 'model2.gguf' } } + { name: 'test-instance-1', status: 'stopped', options: { backend_type: BackendType.LLAMA_CPP, backend_options: { model: 'model1.gguf' } } }, + { name: 'test-instance-2', status: 'running', options: { backend_type: BackendType.LLAMA_CPP, backend_options: { model: 'model2.gguf' } } } ] beforeEach(() => { @@ -82,7 +83,7 @@ describe('App Component - Critical Business Logic Only', () => { const newInstance: Instance = { name: 'new-test-instance', status: 'stopped', - options: { model: 'new-model.gguf' } + options: { backend_type: BackendType.LLAMA_CPP, backend_options: { model: 'new-model.gguf' } } } vi.mocked(instancesApi.create).mockResolvedValue(newInstance) @@ -105,6 +106,7 @@ describe('App Component - Critical Business Logic Only', () => { await waitFor(() => { 
expect(instancesApi.create).toHaveBeenCalledWith('new-test-instance', { auto_restart: true, // Default value + backend_type: BackendType.LLAMA_CPP }) }) @@ -119,7 +121,7 @@ describe('App Component - Critical Business Logic Only', () => { const updatedInstance: Instance = { name: 'test-instance-1', status: 'stopped', - options: { model: 'updated-model.gguf' } + options: { backend_type: BackendType.LLAMA_CPP, backend_options: { model: 'updated-model.gguf' } } } vi.mocked(instancesApi.update).mockResolvedValue(updatedInstance) @@ -138,7 +140,8 @@ describe('App Component - Critical Business Logic Only', () => { // Verify correct API call with existing instance data await waitFor(() => { expect(instancesApi.update).toHaveBeenCalledWith('test-instance-1', { - model: "model1.gguf", // Pre-filled from existing instance + backend_type: BackendType.LLAMA_CPP, + backend_options: { model: "model1.gguf" } // Pre-filled from existing instance }) }) }) diff --git a/webui/src/components/BackendFormField.tsx b/webui/src/components/BackendFormField.tsx new file mode 100644 index 0000000..a210626 --- /dev/null +++ b/webui/src/components/BackendFormField.tsx @@ -0,0 +1,123 @@ +import React from 'react' +import { Input } from '@/components/ui/input' +import { Label } from '@/components/ui/label' +import { Checkbox } from '@/components/ui/checkbox' +import type { BackendOptions } from '@/schemas/instanceOptions' +import { getBackendFieldType, basicBackendFieldsConfig } from '@/lib/zodFormUtils' + +interface BackendFormFieldProps { + fieldKey: keyof BackendOptions + value: string | number | boolean | string[] | undefined + onChange: (key: string, value: string | number | boolean | string[] | undefined) => void +} + +const BackendFormField: React.FC = ({ fieldKey, value, onChange }) => { + // Get configuration for basic fields, or use field name for advanced fields + const config = basicBackendFieldsConfig[fieldKey as string] || { label: fieldKey } + + // Get type from Zod schema + const 
fieldType = getBackendFieldType(fieldKey) + + const handleChange = (newValue: string | number | boolean | string[] | undefined) => { + onChange(fieldKey as string, newValue) + } + + const renderField = () => { + switch (fieldType) { + case 'boolean': + return ( +
+ handleChange(checked)} + /> + +
+ ) + + case 'number': + return ( +
+ + { + const numValue = e.target.value ? parseFloat(e.target.value) : undefined + // Only update if the parsed value is valid or the input is empty + if (e.target.value === '' || (numValue !== undefined && !isNaN(numValue))) { + handleChange(numValue) + } + }} + placeholder={config.placeholder} + /> + {config.description && ( +

{config.description}

+ )} +
+ ) + + case 'array': + return ( +
+ + { + const arrayValue = e.target.value + ? e.target.value.split(',').map(s => s.trim()).filter(Boolean) + : undefined + handleChange(arrayValue) + }} + placeholder="item1, item2, item3" + /> + {config.description && ( +

{config.description}

+ )} +

Separate multiple values with commas

+
+ ) + + case 'text': + default: + return ( +
+ + handleChange(e.target.value || undefined)} + placeholder={config.placeholder} + /> + {config.description && ( +

{config.description}

+ )} +
+ ) + } + } + + return
{renderField()}
+} + +export default BackendFormField \ No newline at end of file diff --git a/webui/src/components/InstanceDialog.tsx b/webui/src/components/InstanceDialog.tsx index 56792e6..dc46e31 100644 --- a/webui/src/components/InstanceDialog.tsx +++ b/webui/src/components/InstanceDialog.tsx @@ -10,10 +10,11 @@ import { DialogHeader, DialogTitle, } from "@/components/ui/dialog"; -import type { CreateInstanceOptions, Instance } from "@/types/instance"; -import { getBasicFields, getAdvancedFields } from "@/lib/zodFormUtils"; +import { BackendType, type CreateInstanceOptions, type Instance } from "@/types/instance"; +import { getBasicFields, getAdvancedFields, getBasicBackendFields, getAdvancedBackendFields } from "@/lib/zodFormUtils"; import { ChevronDown, ChevronRight } from "lucide-react"; import ZodFormField from "@/components/ZodFormField"; +import BackendFormField from "@/components/BackendFormField"; interface InstanceDialogProps { open: boolean; @@ -38,6 +39,8 @@ const InstanceDialog: React.FC = ({ // Get field lists dynamically from the type const basicFields = getBasicFields(); const advancedFields = getAdvancedFields(); + const basicBackendFields = getBasicBackendFields(); + const advancedBackendFields = getAdvancedBackendFields(); // Reset form when dialog opens/closes or when instance changes useEffect(() => { @@ -51,6 +54,8 @@ const InstanceDialog: React.FC = ({ setInstanceName(""); setFormData({ auto_restart: true, // Default value + backend_type: BackendType.LLAMA_CPP, // Default backend type + backend_options: {}, }); } setShowAdvanced(false); // Always start with basic view @@ -65,6 +70,16 @@ const InstanceDialog: React.FC = ({ })); }; + const handleBackendFieldChange = (key: string, value: any) => { + setFormData((prev) => ({ + ...prev, + backend_options: { + ...prev.backend_options, + [key]: value, + }, + })); + }; + const handleNameChange = (name: string) => { setInstanceName(name); // Validate instance name @@ -89,7 +104,24 @@ const InstanceDialog: 
React.FC = ({ // Clean up undefined values to avoid sending empty fields const cleanOptions: CreateInstanceOptions = {}; Object.entries(formData).forEach(([key, value]) => { - if (value !== undefined && value !== "" && value !== null) { + if (key === 'backend_options' && value && typeof value === 'object') { + // Handle backend_options specially - clean nested object + const cleanBackendOptions: any = {}; + Object.entries(value).forEach(([backendKey, backendValue]) => { + if (backendValue !== undefined && backendValue !== null && (typeof backendValue !== 'string' || backendValue.trim() !== "")) { + // Handle arrays - don't include empty arrays + if (Array.isArray(backendValue) && backendValue.length === 0) { + return; + } + cleanBackendOptions[backendKey] = backendValue; + } + }); + + // Only include backend_options if it has content + if (Object.keys(cleanBackendOptions).length > 0) { + (cleanOptions as any)[key] = cleanBackendOptions; + } + } else if (value !== undefined && value !== null && (typeof value !== 'string' || value.trim() !== "")) { // Handle arrays - don't include empty arrays if (Array.isArray(value) && value.length === 0) { return; @@ -196,8 +228,9 @@ const InstanceDialog: React.FC = ({ (fieldKey) => fieldKey !== "auto_restart" && fieldKey !== "max_restarts" && - fieldKey !== "restart_delay" - ) // Exclude auto_restart, max_restarts, and restart_delay as they're handled above + fieldKey !== "restart_delay" && + fieldKey !== "backend_options" // backend_options is handled separately + ) .map((fieldKey) => ( = ({ ))} + {/* Backend Configuration Section */} +
+

Backend Configuration

+ + {/* Basic backend fields */} + {basicBackendFields.map((fieldKey) => ( + + ))} +
+ {/* Advanced Fields Toggle */}
diff --git a/webui/src/components/ZodFormField.tsx b/webui/src/components/ZodFormField.tsx index 2ee912d..f1ab226 100644 --- a/webui/src/components/ZodFormField.tsx +++ b/webui/src/components/ZodFormField.tsx @@ -3,6 +3,7 @@ import { Input } from '@/components/ui/input' import { Label } from '@/components/ui/label' import { Checkbox } from '@/components/ui/checkbox' import type { CreateInstanceOptions } from '@/types/instance' +import { BackendType } from '@/types/instance' import { getFieldType, basicFieldsConfig } from '@/lib/zodFormUtils' interface ZodFormFieldProps { @@ -23,6 +24,30 @@ const ZodFormField: React.FC = ({ fieldKey, value, onChange } } const renderField = () => { + // Special handling for backend_type field - render as dropdown + if (fieldKey === 'backend_type') { + return ( +
+ + + {config.description && ( +

{config.description}

+ )} +
+ ) + } + switch (fieldType) { case 'boolean': return ( diff --git a/webui/src/components/__tests__/InstanceCard.test.tsx b/webui/src/components/__tests__/InstanceCard.test.tsx index 5daebe4..e0c788a 100644 --- a/webui/src/components/__tests__/InstanceCard.test.tsx +++ b/webui/src/components/__tests__/InstanceCard.test.tsx @@ -3,6 +3,7 @@ import { render, screen } from '@testing-library/react' import userEvent from '@testing-library/user-event' import InstanceCard from '@/components/InstanceCard' import type { Instance } from '@/types/instance' +import { BackendType } from '@/types/instance' // Mock the health hook since we're not testing health logic here vi.mock('@/hooks/useInstanceHealth', () => ({ @@ -18,13 +19,13 @@ describe('InstanceCard - Instance Actions and State', () => { const stoppedInstance: Instance = { name: 'test-instance', status: 'stopped', - options: { model: 'test-model.gguf' } + options: { backend_type: BackendType.LLAMA_CPP, backend_options: { model: 'test-model.gguf' } } } const runningInstance: Instance = { name: 'running-instance', status: 'running', - options: { model: 'running-model.gguf' } + options: { backend_type: BackendType.LLAMA_CPP, backend_options: { model: 'running-model.gguf' } } } beforeEach(() => { diff --git a/webui/src/components/__tests__/InstanceList.test.tsx b/webui/src/components/__tests__/InstanceList.test.tsx index a38f873..cbd9e3f 100644 --- a/webui/src/components/__tests__/InstanceList.test.tsx +++ b/webui/src/components/__tests__/InstanceList.test.tsx @@ -5,6 +5,7 @@ import InstanceList from '@/components/InstanceList' import { InstancesProvider } from '@/contexts/InstancesContext' import { instancesApi } from '@/lib/api' import type { Instance } from '@/types/instance' +import { BackendType } from '@/types/instance' import { AuthProvider } from '@/contexts/AuthContext' // Mock the API @@ -44,9 +45,9 @@ describe('InstanceList - State Management and UI Logic', () => { const mockEditInstance = vi.fn() const 
mockInstances: Instance[] = [ - { name: 'instance-1', status: 'stopped', options: { model: 'model1.gguf' } }, - { name: 'instance-2', status: 'running', options: { model: 'model2.gguf' } }, - { name: 'instance-3', status: 'stopped', options: { model: 'model3.gguf' } } + { name: 'instance-1', status: 'stopped', options: { backend_type: BackendType.LLAMA_CPP, backend_options: { model: 'model1.gguf' } } }, + { name: 'instance-2', status: 'running', options: { backend_type: BackendType.LLAMA_CPP, backend_options: { model: 'model2.gguf' } } }, + { name: 'instance-3', status: 'stopped', options: { backend_type: BackendType.LLAMA_CPP, backend_options: { model: 'model3.gguf' } } } ] const DUMMY_API_KEY = 'test-api-key-123' diff --git a/webui/src/components/__tests__/InstanceModal.test.tsx b/webui/src/components/__tests__/InstanceModal.test.tsx index 8468379..0644c3c 100644 --- a/webui/src/components/__tests__/InstanceModal.test.tsx +++ b/webui/src/components/__tests__/InstanceModal.test.tsx @@ -3,6 +3,7 @@ import { render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import InstanceDialog from '@/components/InstanceDialog' import type { Instance } from '@/types/instance' +import { BackendType } from '@/types/instance' describe('InstanceModal - Form Logic and Validation', () => { const mockOnSave = vi.fn() @@ -91,6 +92,7 @@ afterEach(() => { expect(mockOnSave).toHaveBeenCalledWith('my-instance', { auto_restart: true, // Default value + backend_type: BackendType.LLAMA_CPP }) }) @@ -136,8 +138,8 @@ afterEach(() => { name: 'existing-instance', status: 'stopped', options: { - model: 'test-model.gguf', - gpu_layers: 10, + backend_type: BackendType.LLAMA_CPP, + backend_options: { model: 'test-model.gguf', gpu_layers: 10 }, auto_restart: false } } @@ -177,8 +179,8 @@ afterEach(() => { await user.click(screen.getByTestId('dialog-save-button')) expect(mockOnSave).toHaveBeenCalledWith('existing-instance', { - model: 
'test-model.gguf', - gpu_layers: 10, + backend_type: BackendType.LLAMA_CPP, + backend_options: { model: 'test-model.gguf', gpu_layers: 10 }, auto_restart: false }) }) @@ -271,6 +273,7 @@ afterEach(() => { expect(mockOnSave).toHaveBeenCalledWith('test-instance', { auto_restart: true, + backend_type: BackendType.LLAMA_CPP, max_restarts: 5, restart_delay: 10 }) @@ -321,6 +324,7 @@ afterEach(() => { // Should only include non-empty values expect(mockOnSave).toHaveBeenCalledWith('clean-instance', { auto_restart: true, // Only this default value should be included + backend_type: BackendType.LLAMA_CPP }) }) @@ -345,7 +349,8 @@ afterEach(() => { expect(mockOnSave).toHaveBeenCalledWith('numeric-test', { auto_restart: true, - gpu_layers: 15, // Should be number, not string + backend_type: BackendType.LLAMA_CPP, + backend_options: { gpu_layers: 15 }, // Should be number, not string }) }) }) diff --git a/webui/src/contexts/__tests__/InstancesContext.test.tsx b/webui/src/contexts/__tests__/InstancesContext.test.tsx index c271730..c60455f 100644 --- a/webui/src/contexts/__tests__/InstancesContext.test.tsx +++ b/webui/src/contexts/__tests__/InstancesContext.test.tsx @@ -4,6 +4,7 @@ import type { ReactNode } from "react"; import { InstancesProvider, useInstances } from "@/contexts/InstancesContext"; import { instancesApi } from "@/lib/api"; import type { Instance } from "@/types/instance"; +import { BackendType } from "@/types/instance"; import { AuthProvider } from "../AuthContext"; // Mock the API module @@ -47,13 +48,13 @@ function TestComponent() { {/* Action buttons for testing with specific instances */}