diff --git a/pkg/instance/instance.go b/pkg/instance/instance.go index 3781fcf..ee6757d 100644 --- a/pkg/instance/instance.go +++ b/pkg/instance/instance.go @@ -1,255 +1,315 @@ package instance import ( - "context" "encoding/json" "fmt" - "io" "llamactl/pkg/backends" "llamactl/pkg/config" "log" - "net/http" "net/http/httputil" - "net/url" - "os/exec" - "sync" - "sync/atomic" "time" ) -// TimeProvider interface allows for testing with mock time -type TimeProvider interface { - Now() time.Time -} +// Instance represents a running instance of the llama server +type Instance struct { + Name string `json:"name"` + Created int64 `json:"created,omitempty"` // Unix timestamp when the instance was created -// realTimeProvider implements TimeProvider using the actual time -type realTimeProvider struct{} - -func (realTimeProvider) Now() time.Time { - return time.Now() -} - -// Process represents a running instance of the llama server -type Process struct { - Name string `json:"name"` - options *CreateInstanceOptions `json:"-"` + // Global configuration globalInstanceSettings *config.InstancesConfig globalBackendSettings *config.BackendConfig - localNodeName string `json:"-"` // Name of the local node for remote detection + localNodeName string `json:"-"` // Name of the local node for remote detection - // Status - Status InstanceStatus `json:"status"` - onStatusChange func(oldStatus, newStatus InstanceStatus) + status *status `json:"-"` + options *options `json:"-"` - // Creation time - Created int64 `json:"created,omitempty"` // Unix timestamp when the instance was created - - // Logging file - logger *InstanceLogger `json:"-"` - - // internal - cmd *exec.Cmd `json:"-"` // Command to run the instance - ctx context.Context `json:"-"` // Context for managing the instance lifecycle - cancel context.CancelFunc `json:"-"` // Function to cancel the context - stdout io.ReadCloser `json:"-"` // Standard output stream - stderr io.ReadCloser `json:"-"` // Standard error stream - mu 
sync.RWMutex `json:"-"` // RWMutex for better read/write separation - restarts int `json:"-"` // Number of restarts - proxy *httputil.ReverseProxy `json:"-"` // Reverse proxy for this instance - - // Restart control - restartCancel context.CancelFunc `json:"-"` // Cancel function for pending restarts - monitorDone chan struct{} `json:"-"` // Channel to signal monitor goroutine completion - - // Timeout management - lastRequestTime atomic.Int64 // Unix timestamp of last request - timeProvider TimeProvider `json:"-"` // Time provider for testing + // Components (can be nil for remote instances) + process *process `json:"-"` + proxy *proxy `json:"-"` + logger *logger `json:"-"` } -// NewInstance creates a new instance with the given name, log path, and options -func NewInstance(name string, globalBackendSettings *config.BackendConfig, globalInstanceSettings *config.InstancesConfig, options *CreateInstanceOptions, localNodeName string, onStatusChange func(oldStatus, newStatus InstanceStatus)) *Process { +// New creates a new instance with the given name, log path, options and local node name +func New(name string, globalBackendSettings *config.BackendConfig, globalInstanceSettings *config.InstancesConfig, opts *Options, localNodeName string, onStatusChange func(oldStatus, newStatus Status)) *Instance { // Validate and copy options - options.ValidateAndApplyDefaults(name, globalInstanceSettings) + opts.validateAndApplyDefaults(name, globalInstanceSettings) - // Create the instance logger - logger := NewInstanceLogger(name, globalInstanceSettings.LogsDir) + // Create status wrapper + status := newStatus(Stopped) + status.onStatusChange = onStatusChange - return &Process{ + // Create options wrapper + options := newOptions(opts) + + instance := &Instance{ Name: name, options: options, globalInstanceSettings: globalInstanceSettings, globalBackendSettings: globalBackendSettings, localNodeName: localNodeName, - logger: logger, - timeProvider: realTimeProvider{}, Created: 
time.Now().Unix(), - Status: Stopped, - onStatusChange: onStatusChange, + status: status, + } + + // Only create logger, proxy, and process for local instances + if !instance.IsRemote() { + instance.logger = newLogger(name, globalInstanceSettings.LogsDir) + instance.proxy = newProxy(instance) + instance.process = newProcess(instance) + } + + return instance +} + +// Start starts the instance +func (i *Instance) Start() error { + if i.process == nil { + return fmt.Errorf("instance %s has no process component (remote instances cannot be started locally)", i.Name) + } + return i.process.start() +} + +// Stop stops the instance +func (i *Instance) Stop() error { + if i.process == nil { + return fmt.Errorf("instance %s has no process component (remote instances cannot be stopped locally)", i.Name) + } + return i.process.stop() +} + +// Restart restarts the instance +func (i *Instance) Restart() error { + if i.process == nil { + return fmt.Errorf("instance %s has no process component (remote instances cannot be restarted locally)", i.Name) + } + return i.process.restart() +} + +// WaitForHealthy waits for the instance to become healthy +func (i *Instance) WaitForHealthy(timeout int) error { + if i.process == nil { + return fmt.Errorf("instance %s has no process component (remote instances cannot be health checked locally)", i.Name) + } + return i.process.waitForHealthy(timeout) +} + +// GetOptions returns the current options +func (i *Instance) GetOptions() *Options { + if i.options == nil { + return nil + } + return i.options.get() +} + +// GetStatus returns the current status +func (i *Instance) GetStatus() Status { + if i.status == nil { + return Stopped + } + return i.status.get() +} + +// SetStatus sets the status +func (i *Instance) SetStatus(s Status) { + if i.status != nil { + i.status.set(s) } } -func (i *Process) GetOptions() *CreateInstanceOptions { - i.mu.RLock() - defer i.mu.RUnlock() - return i.options +// IsRunning returns true if the status is Running 
+func (i *Instance) IsRunning() bool { + if i.status == nil { + return false + } + return i.status.isRunning() } -func (i *Process) GetPort() int { - i.mu.RLock() - defer i.mu.RUnlock() - if i.options != nil { - switch i.options.BackendType { +func (i *Instance) GetPort() int { + opts := i.GetOptions() + if opts != nil { + switch opts.BackendType { case backends.BackendTypeLlamaCpp: - if i.options.LlamaServerOptions != nil { - return i.options.LlamaServerOptions.Port + if opts.LlamaServerOptions != nil { + return opts.LlamaServerOptions.Port } case backends.BackendTypeMlxLm: - if i.options.MlxServerOptions != nil { - return i.options.MlxServerOptions.Port + if opts.MlxServerOptions != nil { + return opts.MlxServerOptions.Port } case backends.BackendTypeVllm: - if i.options.VllmServerOptions != nil { - return i.options.VllmServerOptions.Port + if opts.VllmServerOptions != nil { + return opts.VllmServerOptions.Port } } } return 0 } -func (i *Process) GetHost() string { - i.mu.RLock() - defer i.mu.RUnlock() - if i.options != nil { - switch i.options.BackendType { +func (i *Instance) GetHost() string { + opts := i.GetOptions() + if opts != nil { + switch opts.BackendType { case backends.BackendTypeLlamaCpp: - if i.options.LlamaServerOptions != nil { - return i.options.LlamaServerOptions.Host + if opts.LlamaServerOptions != nil { + return opts.LlamaServerOptions.Host } case backends.BackendTypeMlxLm: - if i.options.MlxServerOptions != nil { - return i.options.MlxServerOptions.Host + if opts.MlxServerOptions != nil { + return opts.MlxServerOptions.Host } case backends.BackendTypeVllm: - if i.options.VllmServerOptions != nil { - return i.options.VllmServerOptions.Host + if opts.VllmServerOptions != nil { + return opts.VllmServerOptions.Host } } } return "" } -func (i *Process) SetOptions(options *CreateInstanceOptions) { - i.mu.Lock() - defer i.mu.Unlock() - - if options == nil { +// SetOptions sets the options +func (i *Instance) SetOptions(opts *Options) { + if opts == 
nil { log.Println("Warning: Attempted to set nil options on instance", i.Name) return } // Preserve the original nodes to prevent changing instance location - if i.options != nil && i.options.Nodes != nil { - options.Nodes = i.options.Nodes + if i.options != nil && i.options.get() != nil { + opts.Nodes = i.options.get().Nodes } // Validate and copy options - options.ValidateAndApplyDefaults(i.Name, i.globalInstanceSettings) + opts.validateAndApplyDefaults(i.Name, i.globalInstanceSettings) + + if i.options != nil { + i.options.set(opts) + } - i.options = options // Clear the proxy so it gets recreated with new options - i.proxy = nil + if i.proxy != nil { + i.proxy.clear() + } } // SetTimeProvider sets a custom time provider for testing -func (i *Process) SetTimeProvider(tp TimeProvider) { - i.timeProvider = tp +func (i *Instance) SetTimeProvider(tp TimeProvider) { + if i.proxy != nil { + i.proxy.setTimeProvider(tp) + } } -// GetProxy returns the reverse proxy for this instance, creating it if needed -func (i *Process) GetProxy() (*httputil.ReverseProxy, error) { - i.mu.Lock() - defer i.mu.Unlock() - - if i.proxy != nil { - return i.proxy, nil - } - - if i.options == nil { - return nil, fmt.Errorf("instance %s has no options set", i.Name) +// GetProxy returns the reverse proxy for this instance +func (i *Instance) GetProxy() (*httputil.ReverseProxy, error) { + if i.proxy == nil { + return nil, fmt.Errorf("instance %s has no proxy component", i.Name) } // Remote instances should not use local proxy - they are handled by RemoteInstanceProxy - if len(i.options.Nodes) > 0 && i.options.Nodes[0] != i.localNodeName { - return nil, fmt.Errorf("instance %s is a remote instance and should not use local proxy", i.Name) + opts := i.GetOptions() + if opts != nil && len(opts.Nodes) > 0 { + if _, isLocal := opts.Nodes[i.localNodeName]; !isLocal { + return nil, fmt.Errorf("instance %s is a remote instance and should not use local proxy", i.Name) + } + } + + return i.proxy.get() +} 
+ +func (i *Instance) IsRemote() bool { + opts := i.GetOptions() + if opts == nil { + return false + } + + // If no nodes specified, it's a local instance + if len(opts.Nodes) == 0 { + return false + } + + // If the local node is in the nodes map, treat it as a local instance + if _, isLocal := opts.Nodes[i.localNodeName]; isLocal { + return false + } + + // Otherwise, it's a remote instance + return true +} + +// GetLogs retrieves the last n lines of logs from the instance +func (i *Instance) GetLogs(num_lines int) (string, error) { + if i.logger == nil { + return "", fmt.Errorf("instance %s has no logger (remote instances don't have logs)", i.Name) + } + return i.logger.getLogs(num_lines) +} + +// LastRequestTime returns the last request time as a Unix timestamp +func (i *Instance) LastRequestTime() int64 { + if i.proxy == nil { + return 0 + } + return i.proxy.getLastRequestTime() +} + +// UpdateLastRequestTime updates the last request access time for the instance via proxy +func (i *Instance) UpdateLastRequestTime() { + if i.proxy != nil { + i.proxy.updateLastRequestTime() + } +} + +// ShouldTimeout checks if the instance should timeout based on idle time +func (i *Instance) ShouldTimeout() bool { + if i.proxy == nil { + return false + } + return i.proxy.shouldTimeout() +} + +// getBackendHostPort extracts the host and port from instance options +// Returns the configured host and port for the backend +func (i *Instance) getBackendHostPort() (string, int) { + opts := i.GetOptions() + if opts == nil { + return "localhost", 0 } var host string var port int - switch i.options.BackendType { + switch opts.BackendType { case backends.BackendTypeLlamaCpp: - if i.options.LlamaServerOptions != nil { - host = i.options.LlamaServerOptions.Host - port = i.options.LlamaServerOptions.Port + if opts.LlamaServerOptions != nil { + host = opts.LlamaServerOptions.Host + port = opts.LlamaServerOptions.Port } case backends.BackendTypeMlxLm: - if i.options.MlxServerOptions != nil { - 
host = i.options.MlxServerOptions.Host - port = i.options.MlxServerOptions.Port + if opts.MlxServerOptions != nil { + host = opts.MlxServerOptions.Host + port = opts.MlxServerOptions.Port } case backends.BackendTypeVllm: - if i.options.VllmServerOptions != nil { - host = i.options.VllmServerOptions.Host - port = i.options.VllmServerOptions.Port + if opts.VllmServerOptions != nil { + host = opts.VllmServerOptions.Host + port = opts.VllmServerOptions.Port } } - targetURL, err := url.Parse(fmt.Sprintf("http://%s:%d", host, port)) - if err != nil { - return nil, fmt.Errorf("failed to parse target URL for instance %s: %w", i.Name, err) + if host == "" { + host = "localhost" } - proxy := httputil.NewSingleHostReverseProxy(targetURL) - - var responseHeaders map[string]string - switch i.options.BackendType { - case backends.BackendTypeLlamaCpp: - responseHeaders = i.globalBackendSettings.LlamaCpp.ResponseHeaders - case backends.BackendTypeVllm: - responseHeaders = i.globalBackendSettings.VLLM.ResponseHeaders - case backends.BackendTypeMlxLm: - responseHeaders = i.globalBackendSettings.MLX.ResponseHeaders - } - proxy.ModifyResponse = func(resp *http.Response) error { - // Remove CORS headers from llama-server response to avoid conflicts - // llamactl will add its own CORS headers - resp.Header.Del("Access-Control-Allow-Origin") - resp.Header.Del("Access-Control-Allow-Methods") - resp.Header.Del("Access-Control-Allow-Headers") - resp.Header.Del("Access-Control-Allow-Credentials") - resp.Header.Del("Access-Control-Max-Age") - resp.Header.Del("Access-Control-Expose-Headers") - - for key, value := range responseHeaders { - resp.Header.Set(key, value) - } - return nil - } - - i.proxy = proxy - - return i.proxy, nil + return host, port } // MarshalJSON implements json.Marshaler for Instance -func (i *Process) MarshalJSON() ([]byte, error) { - // Use read lock since we're only reading data - i.mu.RLock() - defer i.mu.RUnlock() +func (i *Instance) MarshalJSON() ([]byte, error) { + 
// Get options + opts := i.GetOptions() // Determine if docker is enabled for this instance's backend var dockerEnabled bool - if i.options != nil { - switch i.options.BackendType { + if opts != nil { + switch opts.BackendType { case backends.BackendTypeLlamaCpp: if i.globalBackendSettings != nil && i.globalBackendSettings.LlamaCpp.Docker != nil && i.globalBackendSettings.LlamaCpp.Docker.Enabled { dockerEnabled = true @@ -263,69 +323,69 @@ func (i *Process) MarshalJSON() ([]byte, error) { } } - // Use anonymous struct to avoid recursion - type Alias Process return json.Marshal(&struct { - *Alias - Options *CreateInstanceOptions `json:"options,omitempty"` - DockerEnabled bool `json:"docker_enabled,omitempty"` + Name string `json:"name"` + Status *status `json:"status"` + Created int64 `json:"created,omitempty"` + Options *options `json:"options,omitempty"` + DockerEnabled bool `json:"docker_enabled,omitempty"` }{ - Alias: (*Alias)(i), + Name: i.Name, + Status: i.status, + Created: i.Created, Options: i.options, DockerEnabled: dockerEnabled, }) } // UnmarshalJSON implements json.Unmarshaler for Instance -func (i *Process) UnmarshalJSON(data []byte) error { - // Use anonymous struct to avoid recursion - type Alias Process +func (i *Instance) UnmarshalJSON(data []byte) error { + // Explicitly deserialize to match MarshalJSON format aux := &struct { - *Alias - Options *CreateInstanceOptions `json:"options,omitempty"` - }{ - Alias: (*Alias)(i), - } + Name string `json:"name"` + Status *status `json:"status"` + Created int64 `json:"created,omitempty"` + Options *options `json:"options,omitempty"` + }{} if err := json.Unmarshal(data, aux); err != nil { return err } + // Set the fields + i.Name = aux.Name + i.Created = aux.Created + i.status = aux.Status + i.options = aux.Options + // Handle options with validation and defaults - if aux.Options != nil { - aux.Options.ValidateAndApplyDefaults(i.Name, i.globalInstanceSettings) - i.options = aux.Options + if i.options != nil { 
+ opts := i.options.get() + if opts != nil { + opts.validateAndApplyDefaults(i.Name, i.globalInstanceSettings) + } } - // Initialize fields that are not serialized - if i.timeProvider == nil { - i.timeProvider = realTimeProvider{} + // Initialize fields that are not serialized or may be nil + if i.status == nil { + i.status = newStatus(Stopped) } - if i.logger == nil && i.globalInstanceSettings != nil { - i.logger = NewInstanceLogger(i.Name, i.globalInstanceSettings.LogsDir) + if i.options == nil { + i.options = newOptions(&Options{}) + } + + // Only create logger, proxy, and process for non-remote instances + if !i.IsRemote() { + if i.logger == nil && i.globalInstanceSettings != nil { + i.logger = newLogger(i.Name, i.globalInstanceSettings.LogsDir) + } + if i.proxy == nil { + i.proxy = newProxy(i) + } + if i.process == nil { + i.process = newProcess(i) + } } return nil } - -func (i *Process) IsRemote() bool { - i.mu.RLock() - defer i.mu.RUnlock() - - if i.options == nil { - return false - } - - // If no nodes specified, it's a local instance - if len(i.options.Nodes) == 0 { - return false - } - - // If the first node is the local node, treat it as a local instance - if i.options.Nodes[0] == i.localNodeName { - return false - } - - // Otherwise, it's a remote instance - return true -} diff --git a/pkg/instance/instance_test.go b/pkg/instance/instance_test.go index 0402ac4..375c210 100644 --- a/pkg/instance/instance_test.go +++ b/pkg/instance/instance_test.go @@ -8,6 +8,7 @@ import ( "llamactl/pkg/instance" "llamactl/pkg/testutil" "testing" + "time" ) func TestNewInstance(t *testing.T) { @@ -33,7 +34,7 @@ func TestNewInstance(t *testing.T) { DefaultRestartDelay: 5, } - options := &instance.CreateInstanceOptions{ + options := &instance.Options{ BackendType: backends.BackendTypeLlamaCpp, LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", @@ -42,9 +43,9 @@ func TestNewInstance(t *testing.T) { } // Mock onStatusChange function - 
mockOnStatusChange := func(oldStatus, newStatus instance.InstanceStatus) {} + mockOnStatusChange := func(oldStatus, newStatus instance.Status) {} - inst := instance.NewInstance("test-instance", backendConfig, globalSettings, options, "main", mockOnStatusChange) + inst := instance.New("test-instance", backendConfig, globalSettings, options, "main", mockOnStatusChange) if inst.Name != "test-instance" { t.Errorf("Expected name 'test-instance', got %q", inst.Name) @@ -102,7 +103,7 @@ func TestNewInstance_WithRestartOptions(t *testing.T) { maxRestarts := 10 restartDelay := 15 - options := &instance.CreateInstanceOptions{ + options := &instance.Options{ AutoRestart: &autoRestart, MaxRestarts: &maxRestarts, RestartDelay: &restartDelay, @@ -113,9 +114,9 @@ func TestNewInstance_WithRestartOptions(t *testing.T) { } // Mock onStatusChange function - mockOnStatusChange := func(oldStatus, newStatus instance.InstanceStatus) {} + mockOnStatusChange := func(oldStatus, newStatus instance.Status) {} - instance := instance.NewInstance("test-instance", backendConfig, globalSettings, options, "main", mockOnStatusChange) + instance := instance.New("test-instance", backendConfig, globalSettings, options, "main", mockOnStatusChange) opts := instance.GetOptions() // Check that explicit values override defaults @@ -153,7 +154,7 @@ func TestSetOptions(t *testing.T) { DefaultRestartDelay: 5, } - initialOptions := &instance.CreateInstanceOptions{ + initialOptions := &instance.Options{ BackendType: backends.BackendTypeLlamaCpp, LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", @@ -162,12 +163,12 @@ func TestSetOptions(t *testing.T) { } // Mock onStatusChange function - mockOnStatusChange := func(oldStatus, newStatus instance.InstanceStatus) {} + mockOnStatusChange := func(oldStatus, newStatus instance.Status) {} - inst := instance.NewInstance("test-instance", backendConfig, globalSettings, initialOptions, "main", mockOnStatusChange) + inst := 
instance.New("test-instance", backendConfig, globalSettings, initialOptions, "main", mockOnStatusChange) // Update options - newOptions := &instance.CreateInstanceOptions{ + newOptions := &instance.Options{ BackendType: backends.BackendTypeLlamaCpp, LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/new-model.gguf", @@ -207,22 +208,22 @@ func TestSetOptions_PreservesNodes(t *testing.T) { } // Create instance with initial nodes - initialOptions := &instance.CreateInstanceOptions{ + initialOptions := &instance.Options{ BackendType: backends.BackendTypeLlamaCpp, - Nodes: []string{"worker1"}, + Nodes: map[string]struct{}{"worker1": {}}, LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", Port: 8080, }, } - mockOnStatusChange := func(oldStatus, newStatus instance.InstanceStatus) {} - inst := instance.NewInstance("test-instance", backendConfig, globalSettings, initialOptions, "main", mockOnStatusChange) + mockOnStatusChange := func(oldStatus, newStatus instance.Status) {} + inst := instance.New("test-instance", backendConfig, globalSettings, initialOptions, "main", mockOnStatusChange) // Try to update with different nodes - updatedOptions := &instance.CreateInstanceOptions{ + updatedOptions := &instance.Options{ BackendType: backends.BackendTypeLlamaCpp, - Nodes: []string{"worker2"}, // Attempt to change node + Nodes: map[string]struct{}{"worker2": {}}, // Attempt to change node LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/new-model.gguf", Port: 8081, @@ -233,8 +234,8 @@ func TestSetOptions_PreservesNodes(t *testing.T) { opts := inst.GetOptions() // Nodes should remain unchanged - if len(opts.Nodes) != 1 || opts.Nodes[0] != "worker1" { - t.Errorf("Expected nodes to remain ['worker1'], got %v", opts.Nodes) + if _, exists := opts.Nodes["worker1"]; len(opts.Nodes) != 1 || !exists { + t.Errorf("Expected nodes to contain 'worker1', got %v", opts.Nodes) } // Other options should be updated @@ -263,7 +264,7 @@ 
func TestGetProxy(t *testing.T) { LogsDir: "/tmp/test", } - options := &instance.CreateInstanceOptions{ + options := &instance.Options{ BackendType: backends.BackendTypeLlamaCpp, LlamaServerOptions: &llamacpp.LlamaServerOptions{ Host: "localhost", @@ -272,9 +273,9 @@ func TestGetProxy(t *testing.T) { } // Mock onStatusChange function - mockOnStatusChange := func(oldStatus, newStatus instance.InstanceStatus) {} + mockOnStatusChange := func(oldStatus, newStatus instance.Status) {} - inst := instance.NewInstance("test-instance", backendConfig, globalSettings, options, "main", mockOnStatusChange) + inst := instance.New("test-instance", backendConfig, globalSettings, options, "main", mockOnStatusChange) // Get proxy for the first time proxy1, err := inst.GetProxy() @@ -318,7 +319,7 @@ func TestMarshalJSON(t *testing.T) { DefaultRestartDelay: 5, } - options := &instance.CreateInstanceOptions{ + options := &instance.Options{ BackendType: backends.BackendTypeLlamaCpp, LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", @@ -327,9 +328,9 @@ func TestMarshalJSON(t *testing.T) { } // Mock onStatusChange function - mockOnStatusChange := func(oldStatus, newStatus instance.InstanceStatus) {} + mockOnStatusChange := func(oldStatus, newStatus instance.Status) {} - instance := instance.NewInstance("test-instance", backendConfig, globalSettings, options, "main", mockOnStatusChange) + instance := instance.New("test-instance", backendConfig, globalSettings, options, "main", mockOnStatusChange) data, err := json.Marshal(instance) if err != nil { @@ -397,7 +398,7 @@ func TestUnmarshalJSON(t *testing.T) { } }` - var inst instance.Process + var inst instance.Instance err := json.Unmarshal([]byte(jsonData), &inst) if err != nil { t.Fatalf("JSON unmarshal failed: %v", err) @@ -434,7 +435,7 @@ func TestUnmarshalJSON(t *testing.T) { } } -func TestCreateInstanceOptionsValidation(t *testing.T) { +func TestCreateOptionsValidation(t *testing.T) { tests := []struct { 
name string maxRestarts *int @@ -486,7 +487,7 @@ func TestCreateInstanceOptionsValidation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - options := &instance.CreateInstanceOptions{ + options := &instance.Options{ MaxRestarts: tt.maxRestarts, RestartDelay: tt.restartDelay, BackendType: backends.BackendTypeLlamaCpp, @@ -496,9 +497,9 @@ func TestCreateInstanceOptionsValidation(t *testing.T) { } // Mock onStatusChange function - mockOnStatusChange := func(oldStatus, newStatus instance.InstanceStatus) {} + mockOnStatusChange := func(oldStatus, newStatus instance.Status) {} - instance := instance.NewInstance("test", backendConfig, globalSettings, options, "main", mockOnStatusChange) + instance := instance.New("test", backendConfig, globalSettings, options, "main", mockOnStatusChange) opts := instance.GetOptions() if opts.MaxRestarts == nil { @@ -515,3 +516,303 @@ func TestCreateInstanceOptionsValidation(t *testing.T) { }) } } + +func TestStatusChangeCallback(t *testing.T) { + backendConfig := &config.BackendConfig{ + LlamaCpp: config.BackendSettings{Command: "llama-server"}, + } + globalSettings := &config.InstancesConfig{LogsDir: "/tmp/test"} + options := &instance.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, + } + + var callbackOldStatus, callbackNewStatus instance.Status + callbackCalled := false + + onStatusChange := func(oldStatus, newStatus instance.Status) { + callbackOldStatus = oldStatus + callbackNewStatus = newStatus + callbackCalled = true + } + + inst := instance.New("test", backendConfig, globalSettings, options, "main", onStatusChange) + + inst.SetStatus(instance.Running) + + if !callbackCalled { + t.Error("Expected status change callback to be called") + } + if callbackOldStatus != instance.Stopped { + t.Errorf("Expected old status Stopped, got %v", callbackOldStatus) + } + if callbackNewStatus != instance.Running { + 
t.Errorf("Expected new status Running, got %v", callbackNewStatus) + } +} + +func TestSetOptions_NodesPreserved(t *testing.T) { + backendConfig := &config.BackendConfig{ + LlamaCpp: config.BackendSettings{Command: "llama-server"}, + } + globalSettings := &config.InstancesConfig{LogsDir: "/tmp/test"} + + tests := []struct { + name string + initialNodes map[string]struct{} + updateNodes map[string]struct{} + expectedNodes map[string]struct{} + }{ + { + name: "nil nodes preserved as nil", + initialNodes: nil, + updateNodes: map[string]struct{}{"worker1": {}}, + expectedNodes: nil, + }, + { + name: "empty nodes preserved as empty", + initialNodes: map[string]struct{}{}, + updateNodes: map[string]struct{}{"worker1": {}}, + expectedNodes: map[string]struct{}{}, + }, + { + name: "existing nodes preserved", + initialNodes: map[string]struct{}{"worker1": {}, "worker2": {}}, + updateNodes: map[string]struct{}{"worker3": {}}, + expectedNodes: map[string]struct{}{"worker1": {}, "worker2": {}}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + options := &instance.Options{ + BackendType: backends.BackendTypeLlamaCpp, + Nodes: tt.initialNodes, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, + } + + inst := instance.New("test", backendConfig, globalSettings, options, "main", nil) + + // Attempt to update nodes (should be ignored) + updateOptions := &instance.Options{ + BackendType: backends.BackendTypeLlamaCpp, + Nodes: tt.updateNodes, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ + Model: "/path/to/new-model.gguf", + }, + } + inst.SetOptions(updateOptions) + + opts := inst.GetOptions() + + // Verify nodes are preserved + if len(opts.Nodes) != len(tt.expectedNodes) { + t.Errorf("Expected %d nodes, got %d", len(tt.expectedNodes), len(opts.Nodes)) + } + for node := range tt.expectedNodes { + if _, exists := opts.Nodes[node]; !exists { + t.Errorf("Expected node %s to exist", node) + } + } + + // Verify 
other options were updated + if opts.LlamaServerOptions.Model != "/path/to/new-model.gguf" { + t.Errorf("Expected model to be updated to '/path/to/new-model.gguf', got %q", opts.LlamaServerOptions.Model) + } + }) + } +} + +func TestProcessErrorCases(t *testing.T) { + backendConfig := &config.BackendConfig{ + LlamaCpp: config.BackendSettings{Command: "llama-server"}, + } + globalSettings := &config.InstancesConfig{LogsDir: "/tmp/test"} + options := &instance.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, + } + + inst := instance.New("test", backendConfig, globalSettings, options, "main", nil) + + // Stop when not running should return error + err := inst.Stop() + if err == nil { + t.Error("Expected error when stopping non-running instance") + } + + // Simulate running state + inst.SetStatus(instance.Running) + + // Start when already running should return error + err = inst.Start() + if err == nil { + t.Error("Expected error when starting already running instance") + } +} + +func TestRemoteInstanceOperations(t *testing.T) { + backendConfig := &config.BackendConfig{ + LlamaCpp: config.BackendSettings{Command: "llama-server"}, + } + globalSettings := &config.InstancesConfig{LogsDir: "/tmp/test"} + options := &instance.Options{ + BackendType: backends.BackendTypeLlamaCpp, + Nodes: map[string]struct{}{"remote-node": {}}, // Remote instance + LlamaServerOptions: &llamacpp.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, + } + + inst := instance.New("remote-test", backendConfig, globalSettings, options, "main", nil) + + if !inst.IsRemote() { + t.Error("Expected instance to be remote") + } + + // Start should fail for remote instance + if err := inst.Start(); err == nil { + t.Error("Expected error when starting remote instance") + } + + // Stop should fail for remote instance + if err := inst.Stop(); err == nil { + t.Error("Expected error when stopping remote instance") 
+ } + + // Restart should fail for remote instance + if err := inst.Restart(); err == nil { + t.Error("Expected error when restarting remote instance") + } + + // GetProxy should fail for remote instance + if _, err := inst.GetProxy(); err == nil { + t.Error("Expected error when getting proxy for remote instance") + } + + // GetLogs should fail for remote instance + if _, err := inst.GetLogs(10); err == nil { + t.Error("Expected error when getting logs for remote instance") + } +} + +func TestProxyClearOnOptionsChange(t *testing.T) { + backendConfig := &config.BackendConfig{ + LlamaCpp: config.BackendSettings{Command: "llama-server"}, + } + globalSettings := &config.InstancesConfig{LogsDir: "/tmp/test"} + options := &instance.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ + Host: "localhost", + Port: 8080, + }, + } + + inst := instance.New("test", backendConfig, globalSettings, options, "main", nil) + + // Get initial proxy + proxy1, err := inst.GetProxy() + if err != nil { + t.Fatalf("Failed to get initial proxy: %v", err) + } + + // Update options (should clear proxy) + newOptions := &instance.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ + Host: "localhost", + Port: 8081, // Different port + }, + } + inst.SetOptions(newOptions) + + // Get proxy again - should be recreated with new port + proxy2, err := inst.GetProxy() + if err != nil { + t.Fatalf("Failed to get proxy after options change: %v", err) + } + + // Proxies should be different instances (recreated) + if proxy1 == proxy2 { + t.Error("Expected proxy to be recreated after options change") + } +} + +func TestIdleTimeout(t *testing.T) { + backendConfig := &config.BackendConfig{ + LlamaCpp: config.BackendSettings{Command: "llama-server"}, + } + globalSettings := &config.InstancesConfig{LogsDir: "/tmp/test"} + + t.Run("not running never times out", func(t *testing.T) { + timeout := 1 + inst 
:= instance.New("test", backendConfig, globalSettings, &instance.Options{ + BackendType: backends.BackendTypeLlamaCpp, + IdleTimeout: &timeout, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, + }, "main", nil) + + if inst.ShouldTimeout() { + t.Error("Non-running instance should never timeout") + } + }) + + t.Run("no timeout configured", func(t *testing.T) { + inst := instance.New("test", backendConfig, globalSettings, &instance.Options{ + BackendType: backends.BackendTypeLlamaCpp, + IdleTimeout: nil, // No timeout + LlamaServerOptions: &llamacpp.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, + }, "main", nil) + inst.SetStatus(instance.Running) + + if inst.ShouldTimeout() { + t.Error("Instance with no timeout configured should not timeout") + } + }) + + t.Run("timeout exceeded", func(t *testing.T) { + timeout := 1 // 1 minute + inst := instance.New("test", backendConfig, globalSettings, &instance.Options{ + BackendType: backends.BackendTypeLlamaCpp, + IdleTimeout: &timeout, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, + }, "main", nil) + inst.SetStatus(instance.Running) + + // Use mock time provider + mockTime := &mockTimeProvider{currentTime: time.Now().Unix()} + inst.SetTimeProvider(mockTime) + + // Set last request time to now + inst.UpdateLastRequestTime() + + // Advance time by 2 minutes (exceeds 1 minute timeout) + mockTime.currentTime = time.Now().Add(2 * time.Minute).Unix() + + if !inst.ShouldTimeout() { + t.Error("Instance should timeout when idle time exceeds configured timeout") + } + }) +} + +// mockTimeProvider for timeout testing +type mockTimeProvider struct { + currentTime int64 // Unix timestamp +} + +func (m *mockTimeProvider) Now() time.Time { + return time.Unix(m.currentTime, 0) +} diff --git a/pkg/instance/lifecycle.go b/pkg/instance/lifecycle.go deleted file mode 100644 index fa37dc3..0000000 --- a/pkg/instance/lifecycle.go +++ /dev/null @@ -1,417 
+0,0 @@ -package instance - -import ( - "context" - "fmt" - "log" - "net/http" - "os" - "os/exec" - "runtime" - "syscall" - "time" - - "llamactl/pkg/backends" - "llamactl/pkg/config" -) - -// Start starts the llama server instance and returns an error if it fails. -func (i *Process) Start() error { - i.mu.Lock() - defer i.mu.Unlock() - - if i.IsRunning() { - return fmt.Errorf("instance %s is already running", i.Name) - } - - // Safety check: ensure options are valid - if i.options == nil { - return fmt.Errorf("instance %s has no options set", i.Name) - } - - // Reset restart counter when manually starting (not during auto-restart) - // We can detect auto-restart by checking if restartCancel is set - if i.restartCancel == nil { - i.restarts = 0 - } - - // Initialize last request time to current time when starting - i.lastRequestTime.Store(i.timeProvider.Now().Unix()) - - // Create context before building command (needed for CommandContext) - i.ctx, i.cancel = context.WithCancel(context.Background()) - - // Create log files - if err := i.logger.Create(); err != nil { - return fmt.Errorf("failed to create log files: %w", err) - } - - // Build command using backend-specific methods - cmd, cmdErr := i.buildCommand() - if cmdErr != nil { - return fmt.Errorf("failed to build command: %w", cmdErr) - } - i.cmd = cmd - - if runtime.GOOS != "windows" { - setProcAttrs(i.cmd) - } - - var err error - i.stdout, err = i.cmd.StdoutPipe() - if err != nil { - i.logger.Close() - return fmt.Errorf("failed to get stdout pipe: %w", err) - } - i.stderr, err = i.cmd.StderrPipe() - if err != nil { - i.stdout.Close() - i.logger.Close() - return fmt.Errorf("failed to get stderr pipe: %w", err) - } - - if err := i.cmd.Start(); err != nil { - return fmt.Errorf("failed to start instance %s: %w", i.Name, err) - } - - i.SetStatus(Running) - - // Create channel for monitor completion signaling - i.monitorDone = make(chan struct{}) - - go i.logger.readOutput(i.stdout) - go 
i.logger.readOutput(i.stderr) - - go i.monitorProcess() - - return nil -} - -// Stop terminates the subprocess -func (i *Process) Stop() error { - i.mu.Lock() - - if !i.IsRunning() { - // Even if not running, cancel any pending restart - if i.restartCancel != nil { - i.restartCancel() - i.restartCancel = nil - log.Printf("Cancelled pending restart for instance %s", i.Name) - } - i.mu.Unlock() - return fmt.Errorf("instance %s is not running", i.Name) - } - - // Cancel any pending restart - if i.restartCancel != nil { - i.restartCancel() - i.restartCancel = nil - } - - // Set status to stopped first to signal intentional stop - i.SetStatus(Stopped) - - // Clean up the proxy - i.proxy = nil - - // Get the monitor done channel before releasing the lock - monitorDone := i.monitorDone - - i.mu.Unlock() - - // Stop the process with SIGINT if cmd exists - if i.cmd != nil && i.cmd.Process != nil { - if err := i.cmd.Process.Signal(syscall.SIGINT); err != nil { - log.Printf("Failed to send SIGINT to instance %s: %v", i.Name, err) - } - } - - // If no process exists, we can return immediately - if i.cmd == nil || monitorDone == nil { - i.logger.Close() - return nil - } - - select { - case <-monitorDone: - // Process exited normally - case <-time.After(30 * time.Second): - // Force kill if it doesn't exit within 30 seconds - if i.cmd != nil && i.cmd.Process != nil { - killErr := i.cmd.Process.Kill() - if killErr != nil { - log.Printf("Failed to force kill instance %s: %v", i.Name, killErr) - } - log.Printf("Instance %s did not stop in time, force killed", i.Name) - - // Wait a bit more for the monitor to finish after force kill - select { - case <-monitorDone: - // Monitor completed after force kill - case <-time.After(2 * time.Second): - log.Printf("Warning: Monitor goroutine did not complete after force kill for instance %s", i.Name) - } - } - } - - i.logger.Close() - - return nil -} - -func (i *Process) LastRequestTime() int64 { - return i.lastRequestTime.Load() -} - -func 
(i *Process) WaitForHealthy(timeout int) error { - if !i.IsRunning() { - return fmt.Errorf("instance %s is not running", i.Name) - } - - if timeout <= 0 { - timeout = 30 // Default to 30 seconds if no timeout is specified - } - - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second) - defer cancel() - - // Get instance options to build the health check URL - opts := i.GetOptions() - if opts == nil { - return fmt.Errorf("instance %s has no options set", i.Name) - } - - // Build the health check URL directly - var host string - var port int - switch opts.BackendType { - case backends.BackendTypeLlamaCpp: - if opts.LlamaServerOptions != nil { - host = opts.LlamaServerOptions.Host - port = opts.LlamaServerOptions.Port - } - case backends.BackendTypeMlxLm: - if opts.MlxServerOptions != nil { - host = opts.MlxServerOptions.Host - port = opts.MlxServerOptions.Port - } - case backends.BackendTypeVllm: - if opts.VllmServerOptions != nil { - host = opts.VllmServerOptions.Host - port = opts.VllmServerOptions.Port - } - } - if host == "" { - host = "localhost" - } - healthURL := fmt.Sprintf("http://%s:%d/health", host, port) - - // Create a dedicated HTTP client for health checks - client := &http.Client{ - Timeout: 5 * time.Second, // 5 second timeout per request - } - - // Helper function to check health directly - checkHealth := func() bool { - req, err := http.NewRequestWithContext(ctx, "GET", healthURL, nil) - if err != nil { - return false - } - - resp, err := client.Do(req) - if err != nil { - return false - } - defer resp.Body.Close() - - return resp.StatusCode == http.StatusOK - } - - // Try immediate check first - if checkHealth() { - return nil // Instance is healthy - } - - // If immediate check failed, start polling - ticker := time.NewTicker(1 * time.Second) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return fmt.Errorf("timeout waiting for instance %s to become healthy after %d seconds", i.Name, 
timeout) - case <-ticker.C: - if checkHealth() { - return nil // Instance is healthy - } - // Continue polling - } - } -} - -func (i *Process) monitorProcess() { - defer func() { - i.mu.Lock() - if i.monitorDone != nil { - close(i.monitorDone) - i.monitorDone = nil - } - i.mu.Unlock() - }() - - err := i.cmd.Wait() - - i.mu.Lock() - - // Check if the instance was intentionally stopped - if !i.IsRunning() { - i.mu.Unlock() - return - } - - i.SetStatus(Stopped) - i.logger.Close() - - // Cancel any existing restart context since we're handling a new exit - if i.restartCancel != nil { - i.restartCancel() - i.restartCancel = nil - } - - // Log the exit - if err != nil { - log.Printf("Instance %s crashed with error: %v", i.Name, err) - // Handle restart while holding the lock, then release it - i.handleRestart() - } else { - log.Printf("Instance %s exited cleanly", i.Name) - i.mu.Unlock() - } -} - -// handleRestart manages the restart process while holding the lock -func (i *Process) handleRestart() { - // Validate restart conditions and get safe parameters - shouldRestart, maxRestarts, restartDelay := i.validateRestartConditions() - if !shouldRestart { - i.SetStatus(Failed) - i.mu.Unlock() - return - } - - i.restarts++ - log.Printf("Auto-restarting instance %s (attempt %d/%d) in %v", - i.Name, i.restarts, maxRestarts, time.Duration(restartDelay)*time.Second) - - // Create a cancellable context for the restart delay - restartCtx, cancel := context.WithCancel(context.Background()) - i.restartCancel = cancel - - // Release the lock before sleeping - i.mu.Unlock() - - // Use context-aware sleep so it can be cancelled - select { - case <-time.After(time.Duration(restartDelay) * time.Second): - // Sleep completed normally, continue with restart - case <-restartCtx.Done(): - // Restart was cancelled - log.Printf("Restart cancelled for instance %s", i.Name) - return - } - - // Restart the instance - if err := i.Start(); err != nil { - log.Printf("Failed to restart instance %s: 
%v", i.Name, err) - } else { - log.Printf("Successfully restarted instance %s", i.Name) - // Clear the cancel function - i.mu.Lock() - i.restartCancel = nil - i.mu.Unlock() - } -} - -// validateRestartConditions checks if the instance should be restarted and returns the parameters -func (i *Process) validateRestartConditions() (shouldRestart bool, maxRestarts int, restartDelay int) { - if i.options == nil { - log.Printf("Instance %s not restarting: options are nil", i.Name) - return false, 0, 0 - } - - if i.options.AutoRestart == nil || !*i.options.AutoRestart { - log.Printf("Instance %s not restarting: AutoRestart is disabled", i.Name) - return false, 0, 0 - } - - if i.options.MaxRestarts == nil { - log.Printf("Instance %s not restarting: MaxRestarts is nil", i.Name) - return false, 0, 0 - } - - if i.options.RestartDelay == nil { - log.Printf("Instance %s not restarting: RestartDelay is nil", i.Name) - return false, 0, 0 - } - - // Values are already validated during unmarshaling/SetOptions - maxRestarts = *i.options.MaxRestarts - restartDelay = *i.options.RestartDelay - - if i.restarts >= maxRestarts { - log.Printf("Instance %s exceeded max restart attempts (%d)", i.Name, maxRestarts) - return false, 0, 0 - } - - return true, maxRestarts, restartDelay -} - -// buildCommand builds the command to execute using backend-specific logic -func (i *Process) buildCommand() (*exec.Cmd, error) { - // Get backend configuration - backendConfig, err := i.getBackendConfig() - if err != nil { - return nil, err - } - - // Build the environment variables - env := i.options.BuildEnvironment(backendConfig) - - // Get the command to execute - command := i.options.GetCommand(backendConfig) - - // Build command arguments - args := i.options.BuildCommandArgs(backendConfig) - - // Create the exec.Cmd - cmd := exec.CommandContext(i.ctx, command, args...) 
- - // Start with host environment variables - cmd.Env = os.Environ() - - // Add/override with backend-specific environment variables - for k, v := range env { - cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, v)) - } - - return cmd, nil -} - -// getBackendConfig resolves the backend configuration for the current instance -func (i *Process) getBackendConfig() (*config.BackendSettings, error) { - var backendTypeStr string - - switch i.options.BackendType { - case backends.BackendTypeLlamaCpp: - backendTypeStr = "llama-cpp" - case backends.BackendTypeMlxLm: - backendTypeStr = "mlx" - case backends.BackendTypeVllm: - backendTypeStr = "vllm" - default: - return nil, fmt.Errorf("unsupported backend type: %s", i.options.BackendType) - } - - settings := i.globalBackendSettings.GetBackendSettings(backendTypeStr) - return &settings, nil -} diff --git a/pkg/instance/logging.go b/pkg/instance/logger.go similarity index 77% rename from pkg/instance/logging.go rename to pkg/instance/logger.go index 5432556..f836411 100644 --- a/pkg/instance/logging.go +++ b/pkg/instance/logger.go @@ -6,25 +6,30 @@ import ( "io" "os" "strings" + "sync" "time" ) -type InstanceLogger struct { +type logger struct { name string logDir string logFile *os.File logFilePath string + mu sync.RWMutex } -func NewInstanceLogger(name string, logDir string) *InstanceLogger { - return &InstanceLogger{ +func newLogger(name string, logDir string) *logger { + return &logger{ name: name, logDir: logDir, } } -// Create creates and opens the log files for stdout and stderr -func (i *InstanceLogger) Create() error { +// create creates and opens the log files for stdout and stderr +func (i *logger) create() error { + i.mu.Lock() + defer i.mu.Unlock() + if i.logDir == "" { return fmt.Errorf("logDir is empty for instance %s", i.name) } @@ -51,17 +56,16 @@ func (i *InstanceLogger) Create() error { return nil } -// GetLogs retrieves the last n lines of logs from the instance -func (i *Process) GetLogs(num_lines int) 
(string, error) { +// getLogs retrieves the last n lines of logs from the instance +func (i *logger) getLogs(num_lines int) (string, error) { i.mu.RLock() - logFileName := i.logger.logFilePath - i.mu.RUnlock() + defer i.mu.RUnlock() - if logFileName == "" { - return "", fmt.Errorf("log file not created for instance %s", i.Name) + if i.logFilePath == "" { + return "", fmt.Errorf("log file not created for instance %s", i.name) } - file, err := os.Open(logFileName) + file, err := os.Open(i.logFilePath) if err != nil { return "", fmt.Errorf("failed to open log file: %w", err) } @@ -93,8 +97,11 @@ func (i *Process) GetLogs(num_lines int) (string, error) { return strings.Join(lines[start:], "\n"), nil } -// closeLogFile closes the log files -func (i *InstanceLogger) Close() { +// close closes the log files +func (i *logger) close() { + i.mu.Lock() + defer i.mu.Unlock() + if i.logFile != nil { timestamp := time.Now().Format("2006-01-02 15:04:05") fmt.Fprintf(i.logFile, "=== Instance %s stopped at %s ===\n\n", i.name, timestamp) @@ -104,7 +111,7 @@ func (i *InstanceLogger) Close() { } // readOutput reads from the given reader and writes lines to the log file -func (i *InstanceLogger) readOutput(reader io.ReadCloser) { +func (i *logger) readOutput(reader io.ReadCloser) { defer reader.Close() scanner := bufio.NewScanner(reader) diff --git a/pkg/instance/options.go b/pkg/instance/options.go index 439f426..1dddb15 100644 --- a/pkg/instance/options.go +++ b/pkg/instance/options.go @@ -10,9 +10,12 @@ import ( "llamactl/pkg/config" "log" "maps" + "slices" + "sync" ) -type CreateInstanceOptions struct { +// Options contains the actual configuration (exported - this is the public API). 
+type Options struct { // Auto restart AutoRestart *bool `json:"auto_restart,omitempty"` MaxRestarts *int `json:"max_restarts,omitempty"` @@ -27,7 +30,7 @@ type CreateInstanceOptions struct { BackendType backends.BackendType `json:"backend_type"` BackendOptions map[string]any `json:"backend_options,omitempty"` - Nodes []string `json:"nodes,omitempty"` + Nodes map[string]struct{} `json:"-"` // Backend-specific options LlamaServerOptions *llamacpp.LlamaServerOptions `json:"-"` @@ -35,11 +38,57 @@ type CreateInstanceOptions struct { VllmServerOptions *vllm.VllmServerOptions `json:"-"` } -// UnmarshalJSON implements custom JSON unmarshaling for CreateInstanceOptions -func (c *CreateInstanceOptions) UnmarshalJSON(data []byte) error { +// options wraps Options with thread-safe access (unexported). +type options struct { + mu sync.RWMutex + opts *Options +} + +// newOptions creates a new options wrapper with the given Options +func newOptions(opts *Options) *options { + return &options{ + opts: opts, + } +} + +// get returns a copy of the current options +func (o *options) get() *Options { + o.mu.RLock() + defer o.mu.RUnlock() + return o.opts +} + +// set updates the options +func (o *options) set(opts *Options) { + o.mu.Lock() + defer o.mu.Unlock() + o.opts = opts +} + +// MarshalJSON implements json.Marshaler for options wrapper +func (o *options) MarshalJSON() ([]byte, error) { + o.mu.RLock() + defer o.mu.RUnlock() + return o.opts.MarshalJSON() +} + +// UnmarshalJSON implements json.Unmarshaler for options wrapper +func (o *options) UnmarshalJSON(data []byte) error { + o.mu.Lock() + defer o.mu.Unlock() + + if o.opts == nil { + o.opts = &Options{} + } + return o.opts.UnmarshalJSON(data) +} + +// UnmarshalJSON implements custom JSON unmarshaling for Options +func (c *Options) UnmarshalJSON(data []byte) error { // Use anonymous struct to avoid recursion - type Alias CreateInstanceOptions + type Alias Options aux := &struct { + Nodes []string `json:"nodes,omitempty"` // 
Accept JSON array *Alias }{ Alias: (*Alias)(c), @@ -49,6 +98,14 @@ func (c *CreateInstanceOptions) UnmarshalJSON(data []byte) error { return err } + // Convert nodes array to map + if len(aux.Nodes) > 0 { + c.Nodes = make(map[string]struct{}, len(aux.Nodes)) + for _, node := range aux.Nodes { + c.Nodes[node] = struct{}{} + } + } + // Parse backend-specific options switch c.BackendType { case backends.BackendTypeLlamaCpp: @@ -95,16 +152,27 @@ func (c *CreateInstanceOptions) UnmarshalJSON(data []byte) error { return nil } -// MarshalJSON implements custom JSON marshaling for CreateInstanceOptions -func (c *CreateInstanceOptions) MarshalJSON() ([]byte, error) { +// MarshalJSON implements custom JSON marshaling for Options +func (c *Options) MarshalJSON() ([]byte, error) { // Use anonymous struct to avoid recursion - type Alias CreateInstanceOptions + type Alias Options aux := struct { + Nodes []string `json:"nodes,omitempty"` // Output as JSON array *Alias }{ Alias: (*Alias)(c), } + // Convert nodes map to array (sorted for consistency) + if len(c.Nodes) > 0 { + aux.Nodes = make([]string, 0, len(c.Nodes)) + for node := range c.Nodes { + aux.Nodes = append(aux.Nodes, node) + } + // Sort for consistent output + slices.Sort(aux.Nodes) + } + // Convert backend-specific options back to BackendOptions map for JSON switch c.BackendType { case backends.BackendTypeLlamaCpp: @@ -154,8 +222,8 @@ func (c *CreateInstanceOptions) MarshalJSON() ([]byte, error) { return json.Marshal(aux) } -// ValidateAndApplyDefaults validates the instance options and applies constraints -func (c *CreateInstanceOptions) ValidateAndApplyDefaults(name string, globalSettings *config.InstancesConfig) { +// validateAndApplyDefaults validates the instance options and applies constraints +func (c *Options) validateAndApplyDefaults(name string, globalSettings *config.InstancesConfig) { // Validate and apply constraints if c.MaxRestarts != nil && *c.MaxRestarts < 0 { log.Printf("Instance %s MaxRestarts value 
(%d) cannot be negative, setting to 0", name, *c.MaxRestarts) @@ -193,7 +261,8 @@ func (c *CreateInstanceOptions) ValidateAndApplyDefaults(name string, globalSett } } -func (c *CreateInstanceOptions) GetCommand(backendConfig *config.BackendSettings) string { +// getCommand builds the command to run the backend +func (c *Options) getCommand(backendConfig *config.BackendSettings) string { if backendConfig.Docker != nil && backendConfig.Docker.Enabled && c.BackendType != backends.BackendTypeMlxLm { return "docker" @@ -202,8 +271,8 @@ func (c *CreateInstanceOptions) GetCommand(backendConfig *config.BackendSettings return backendConfig.Command } -// BuildCommandArgs builds command line arguments for the backend -func (c *CreateInstanceOptions) BuildCommandArgs(backendConfig *config.BackendSettings) []string { +// buildCommandArgs builds command line arguments for the backend +func (c *Options) buildCommandArgs(backendConfig *config.BackendSettings) []string { var args []string @@ -246,7 +315,8 @@ func (c *CreateInstanceOptions) BuildCommandArgs(backendConfig *config.BackendSe return args } -func (c *CreateInstanceOptions) BuildEnvironment(backendConfig *config.BackendSettings) map[string]string { +// buildEnvironment builds the environment variables for the backend process +func (c *Options) buildEnvironment(backendConfig *config.BackendSettings) map[string]string { env := map[string]string{} if backendConfig.Environment != nil { diff --git a/pkg/instance/process.go b/pkg/instance/process.go new file mode 100644 index 0000000..9c7cfec --- /dev/null +++ b/pkg/instance/process.go @@ -0,0 +1,446 @@ +package instance + +import ( + "context" + "fmt" + "io" + "log" + "net/http" + "os" + "os/exec" + "runtime" + "sync" + "syscall" + "time" + + "llamactl/pkg/backends" + "llamactl/pkg/config" +) + +// process manages the OS process lifecycle for a local instance. +// process owns its complete lifecycle including auto-restart logic. 
+type process struct { + instance *Instance // Back-reference for SetStatus, GetOptions + + mu sync.RWMutex + cmd *exec.Cmd + ctx context.Context + cancel context.CancelFunc + stdout io.ReadCloser + stderr io.ReadCloser + restarts int + restartCancel context.CancelFunc + monitorDone chan struct{} +} + +// newProcess creates a new process component for the given instance +func newProcess(instance *Instance) *process { + return &process{ + instance: instance, + } +} + +// start starts the OS process and returns an error if it fails. +func (p *process) start() error { + p.mu.Lock() + defer p.mu.Unlock() + + if p.instance.IsRunning() { + return fmt.Errorf("instance %s is already running", p.instance.Name) + } + + // Safety check: ensure options are valid + if p.instance.options == nil { + return fmt.Errorf("instance %s has no options set", p.instance.Name) + } + + // Reset restart counter when manually starting (not during auto-restart) + // We can detect auto-restart by checking if restartCancel is set + if p.restartCancel == nil { + p.restarts = 0 + } + + // Initialize last request time to current time when starting + if p.instance.proxy != nil { + p.instance.proxy.updateLastRequestTime() + } + + // Create context before building command (needed for CommandContext) + p.ctx, p.cancel = context.WithCancel(context.Background()) + + // Create log files + if err := p.instance.logger.create(); err != nil { + return fmt.Errorf("failed to create log files: %w", err) + } + + // Build command using backend-specific methods + cmd, cmdErr := p.buildCommand() + if cmdErr != nil { + return fmt.Errorf("failed to build command: %w", cmdErr) + } + p.cmd = cmd + + if runtime.GOOS != "windows" { + setProcAttrs(p.cmd) + } + + var err error + p.stdout, err = p.cmd.StdoutPipe() + if err != nil { + p.instance.logger.close() + return fmt.Errorf("failed to get stdout pipe: %w", err) + } + p.stderr, err = p.cmd.StderrPipe() + if err != nil { + p.stdout.Close() + p.instance.logger.close() + 
return fmt.Errorf("failed to get stderr pipe: %w", err) + } + + if err := p.cmd.Start(); err != nil { + return fmt.Errorf("failed to start instance %s: %w", p.instance.Name, err) + } + + p.instance.SetStatus(Running) + + // Create channel for monitor completion signaling + p.monitorDone = make(chan struct{}) + + go p.instance.logger.readOutput(p.stdout) + go p.instance.logger.readOutput(p.stderr) + + go p.monitorProcess() + + return nil +} + +// stop terminates the subprocess without restarting +func (p *process) stop() error { + p.mu.Lock() + + if !p.instance.IsRunning() { + // Even if not running, cancel any pending restart + if p.restartCancel != nil { + p.restartCancel() + p.restartCancel = nil + log.Printf("Cancelled pending restart for instance %s", p.instance.Name) + } + p.mu.Unlock() + return fmt.Errorf("instance %s is not running", p.instance.Name) + } + + // Cancel any pending restart + if p.restartCancel != nil { + p.restartCancel() + p.restartCancel = nil + } + + // Set status to stopped first to signal intentional stop + p.instance.SetStatus(Stopped) + + // Get the monitor done channel before releasing the lock + monitorDone := p.monitorDone + + p.mu.Unlock() + + // Stop the process with SIGINT if cmd exists + if p.cmd != nil && p.cmd.Process != nil { + if err := p.cmd.Process.Signal(syscall.SIGINT); err != nil { + log.Printf("Failed to send SIGINT to instance %s: %v", p.instance.Name, err) + } + } + + // If no process exists, we can return immediately + if p.cmd == nil || monitorDone == nil { + p.instance.logger.close() + return nil + } + + select { + case <-monitorDone: + // Process exited normally + case <-time.After(30 * time.Second): + // Force kill if it doesn't exit within 30 seconds + if p.cmd != nil && p.cmd.Process != nil { + killErr := p.cmd.Process.Kill() + if killErr != nil { + log.Printf("Failed to force kill instance %s: %v", p.instance.Name, killErr) + } + log.Printf("Instance %s did not stop in time, force killed", p.instance.Name) + + 
// Wait a bit more for the monitor to finish after force kill + select { + case <-monitorDone: + // Monitor completed after force kill + case <-time.After(2 * time.Second): + log.Printf("Warning: Monitor goroutine did not complete after force kill for instance %s", p.instance.Name) + } + } + } + + p.instance.logger.close() + + return nil +} + +// restart manually restarts the process (resets restart counter) +func (p *process) restart() error { + // Stop the process first + if err := p.stop(); err != nil { + // If it's not running, that's ok - we'll just start it + if err.Error() != fmt.Sprintf("instance %s is not running", p.instance.Name) { + return fmt.Errorf("failed to stop instance during restart: %w", err) + } + } + + // Reset restart counter for manual restart + p.mu.Lock() + p.restarts = 0 + p.mu.Unlock() + + // Start the process + return p.start() +} + +// waitForHealthy waits for the process to become healthy +func (p *process) waitForHealthy(timeout int) error { + if !p.instance.IsRunning() { + return fmt.Errorf("instance %s is not running", p.instance.Name) + } + + if timeout <= 0 { + timeout = 30 // Default to 30 seconds if no timeout is specified + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second) + defer cancel() + + // Get host/port from instance + host, port := p.instance.getBackendHostPort() + healthURL := fmt.Sprintf("http://%s:%d/health", host, port) + + // Create a dedicated HTTP client for health checks + client := &http.Client{ + Timeout: 5 * time.Second, // 5 second timeout per request + } + + // Helper function to check health directly + checkHealth := func() bool { + req, err := http.NewRequestWithContext(ctx, "GET", healthURL, nil) + if err != nil { + return false + } + + resp, err := client.Do(req) + if err != nil { + return false + } + defer resp.Body.Close() + + return resp.StatusCode == http.StatusOK + } + + // Try immediate check first + if checkHealth() { + return nil // Instance is 
healthy + } + + // If immediate check failed, start polling + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return fmt.Errorf("timeout waiting for instance %s to become healthy after %d seconds", p.instance.Name, timeout) + case <-ticker.C: + if checkHealth() { + return nil // Instance is healthy + } + // Continue polling + } + } +} + +// monitorProcess monitors the OS process and handles crashes/exits +func (p *process) monitorProcess() { + defer func() { + p.mu.Lock() + if p.monitorDone != nil { + close(p.monitorDone) + p.monitorDone = nil + } + p.mu.Unlock() + }() + + err := p.cmd.Wait() + + p.mu.Lock() + + // Check if the instance was intentionally stopped + if !p.instance.IsRunning() { + p.mu.Unlock() + return + } + + p.instance.SetStatus(Stopped) + p.instance.logger.close() + + // Cancel any existing restart context since we're handling a new exit + if p.restartCancel != nil { + p.restartCancel() + p.restartCancel = nil + } + + // Log the exit + if err != nil { + log.Printf("Instance %s crashed with error: %v", p.instance.Name, err) + // Handle auto-restart logic + p.handleAutoRestart(err) + } else { + log.Printf("Instance %s exited cleanly", p.instance.Name) + p.mu.Unlock() + } +} + +// shouldAutoRestart checks if the process should auto-restart +func (p *process) shouldAutoRestart() bool { + opts := p.instance.GetOptions() + if opts == nil { + log.Printf("Instance %s not restarting: options are nil", p.instance.Name) + return false + } + + if opts.AutoRestart == nil || !*opts.AutoRestart { + log.Printf("Instance %s not restarting: AutoRestart is disabled", p.instance.Name) + return false + } + + if opts.MaxRestarts == nil { + log.Printf("Instance %s not restarting: MaxRestarts is nil", p.instance.Name) + return false + } + + maxRestarts := *opts.MaxRestarts + if p.restarts >= maxRestarts { + log.Printf("Instance %s exceeded max restart attempts (%d)", p.instance.Name, maxRestarts) + return false 
+ } + + return true +} + +// handleAutoRestart manages the auto-restart process +func (p *process) handleAutoRestart(err error) { + // Check if should restart + if !p.shouldAutoRestart() { + p.instance.SetStatus(Failed) + p.mu.Unlock() + return + } + + // Get restart parameters + opts := p.instance.GetOptions() + if opts.RestartDelay == nil { + log.Printf("Instance %s not restarting: RestartDelay is nil", p.instance.Name) + p.instance.SetStatus(Failed) + p.mu.Unlock() + return + } + + restartDelay := *opts.RestartDelay + maxRestarts := *opts.MaxRestarts + + p.restarts++ + log.Printf("Auto-restarting instance %s (attempt %d/%d) in %v", + p.instance.Name, p.restarts, maxRestarts, time.Duration(restartDelay)*time.Second) + + // Create a cancellable context for the restart delay + restartCtx, cancel := context.WithCancel(context.Background()) + p.restartCancel = cancel + + // Release the lock before sleeping + p.mu.Unlock() + + // Use context-aware sleep so it can be cancelled + select { + case <-time.After(time.Duration(restartDelay) * time.Second): + // Sleep completed normally, continue with restart + case <-restartCtx.Done(): + // Restart was cancelled + log.Printf("Restart cancelled for instance %s", p.instance.Name) + return + } + + // Restart the instance + if err := p.start(); err != nil { + log.Printf("Failed to restart instance %s: %v", p.instance.Name, err) + } else { + log.Printf("Successfully restarted instance %s", p.instance.Name) + // Clear the cancel function + p.mu.Lock() + p.restartCancel = nil + p.mu.Unlock() + } +} + +// buildCommand builds the command to execute using backend-specific logic +func (p *process) buildCommand() (*exec.Cmd, error) { + // Get options + opts := p.instance.GetOptions() + if opts == nil { + return nil, fmt.Errorf("instance options are nil") + } + + // Get backend configuration + backendConfig, err := p.getBackendConfig() + if err != nil { + return nil, err + } + + // Build the environment variables + env := 
opts.buildEnvironment(backendConfig) + + // Get the command to execute + command := opts.getCommand(backendConfig) + + // Build command arguments + args := opts.buildCommandArgs(backendConfig) + + // Create the exec.Cmd + cmd := exec.CommandContext(p.ctx, command, args...) + + // Start with host environment variables + cmd.Env = os.Environ() + + // Add/override with backend-specific environment variables + for k, v := range env { + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, v)) + } + + return cmd, nil +} + +// getBackendConfig resolves the backend configuration for the current instance +func (p *process) getBackendConfig() (*config.BackendSettings, error) { + opts := p.instance.GetOptions() + if opts == nil { + return nil, fmt.Errorf("instance options are nil") + } + + var backendTypeStr string + + switch opts.BackendType { + case backends.BackendTypeLlamaCpp: + backendTypeStr = "llama-cpp" + case backends.BackendTypeMlxLm: + backendTypeStr = "mlx" + case backends.BackendTypeVllm: + backendTypeStr = "vllm" + default: + return nil, fmt.Errorf("unsupported backend type: %s", opts.BackendType) + } + + settings := p.instance.globalBackendSettings.GetBackendSettings(backendTypeStr) + return &settings, nil +} diff --git a/pkg/instance/proxy.go b/pkg/instance/proxy.go new file mode 100644 index 0000000..321095b --- /dev/null +++ b/pkg/instance/proxy.go @@ -0,0 +1,155 @@ +package instance + +import ( + "fmt" + "llamactl/pkg/backends" + "net/http" + "net/http/httputil" + "net/url" + "sync" + "sync/atomic" + "time" +) + +// TimeProvider interface allows for testing with mock time +type TimeProvider interface { + Now() time.Time +} + +// realTimeProvider implements TimeProvider using the actual time +type realTimeProvider struct{} + +func (realTimeProvider) Now() time.Time { + return time.Now() +} + +// proxy manages HTTP reverse proxy and request tracking for an instance. 
+type proxy struct { + instance *Instance + + mu sync.RWMutex + proxy *httputil.ReverseProxy + proxyOnce sync.Once + proxyErr error + lastRequestTime atomic.Int64 + timeProvider TimeProvider +} + +// newProxy creates a new Proxy for the given instance +func newProxy(instance *Instance) *proxy { + return &proxy{ + instance: instance, + timeProvider: realTimeProvider{}, + } +} + +// get returns the reverse proxy for this instance, creating it if needed. +// Uses sync.Once to ensure thread-safe one-time initialization. +func (p *proxy) get() (*httputil.ReverseProxy, error) { + // sync.Once guarantees buildProxy() is called exactly once + // Other callers block until first initialization completes + p.proxyOnce.Do(func() { + p.proxy, p.proxyErr = p.build() + }) + + return p.proxy, p.proxyErr +} + +// build creates the reverse proxy based on instance options +func (p *proxy) build() (*httputil.ReverseProxy, error) { + options := p.instance.GetOptions() + if options == nil { + return nil, fmt.Errorf("instance %s has no options set", p.instance.Name) + } + + // Remote instances should not use local proxy - they are handled by RemoteInstanceProxy + if len(options.Nodes) > 0 { + return nil, fmt.Errorf("instance %s is a remote instance and should not use local proxy", p.instance.Name) + } + + // Get host/port from process + host, port := p.instance.getBackendHostPort() + + targetURL, err := url.Parse(fmt.Sprintf("http://%s:%d", host, port)) + if err != nil { + return nil, fmt.Errorf("failed to parse target URL for instance %s: %w", p.instance.Name, err) + } + + proxy := httputil.NewSingleHostReverseProxy(targetURL) + + // Get response headers from backend config + var responseHeaders map[string]string + switch options.BackendType { + case backends.BackendTypeLlamaCpp: + responseHeaders = p.instance.globalBackendSettings.LlamaCpp.ResponseHeaders + case backends.BackendTypeVllm: + responseHeaders = p.instance.globalBackendSettings.VLLM.ResponseHeaders + case 
backends.BackendTypeMlxLm: + responseHeaders = p.instance.globalBackendSettings.MLX.ResponseHeaders + } + + proxy.ModifyResponse = func(resp *http.Response) error { + // Remove CORS headers from backend response to avoid conflicts + // llamactl will add its own CORS headers + resp.Header.Del("Access-Control-Allow-Origin") + resp.Header.Del("Access-Control-Allow-Methods") + resp.Header.Del("Access-Control-Allow-Headers") + resp.Header.Del("Access-Control-Allow-Credentials") + resp.Header.Del("Access-Control-Max-Age") + resp.Header.Del("Access-Control-Expose-Headers") + + for key, value := range responseHeaders { + resp.Header.Set(key, value) + } + return nil + } + + return proxy, nil +} + +// clear resets the proxy, allowing it to be recreated when options change. +func (p *proxy) clear() { + p.mu.Lock() + defer p.mu.Unlock() + + p.proxy = nil + p.proxyErr = nil + p.proxyOnce = sync.Once{} // Reset Once for next GetProxy call +} + +// updateLastRequestTime updates the last request access time for the instance +func (p *proxy) updateLastRequestTime() { + lastRequestTime := p.timeProvider.Now().Unix() + p.lastRequestTime.Store(lastRequestTime) +} + +// getLastRequestTime returns the last request time as a Unix timestamp +func (p *proxy) getLastRequestTime() int64 { + return p.lastRequestTime.Load() +} + +// shouldTimeout checks if the instance should timeout based on idle time +func (p *proxy) shouldTimeout() bool { + if !p.instance.IsRunning() { + return false + } + + options := p.instance.GetOptions() + if options == nil || options.IdleTimeout == nil || *options.IdleTimeout <= 0 { + return false + } + + // Check if the last request time exceeds the idle timeout + lastRequest := p.lastRequestTime.Load() + idleTimeoutMinutes := *options.IdleTimeout + + // Convert timeout from minutes to seconds for comparison + idleTimeoutSeconds := int64(idleTimeoutMinutes * 60) + + return (p.timeProvider.Now().Unix() - lastRequest) > idleTimeoutSeconds +} + +// setTimeProvider sets 
a custom time provider for testing +func (p *proxy) setTimeProvider(tp TimeProvider) { + p.timeProvider = tp +} diff --git a/pkg/instance/status.go b/pkg/instance/status.go index e07fe03..92e8669 100644 --- a/pkg/instance/status.go +++ b/pkg/instance/status.go @@ -3,48 +3,32 @@ package instance import ( "encoding/json" "log" + "sync" ) -// Enum for instance status -type InstanceStatus int +// Status is the enum for status values (exported). +type Status int const ( - Stopped InstanceStatus = iota + Stopped Status = iota Running Failed ) -var nameToStatus = map[string]InstanceStatus{ +var nameToStatus = map[string]Status{ "stopped": Stopped, "running": Running, "failed": Failed, } -var statusToName = map[InstanceStatus]string{ +var statusToName = map[Status]string{ Stopped: "stopped", Running: "running", Failed: "failed", } -func (p *Process) SetStatus(status InstanceStatus) { - oldStatus := p.Status - p.Status = status - - if p.onStatusChange != nil { - p.onStatusChange(oldStatus, status) - } -} - -func (p *Process) GetStatus() InstanceStatus { - return p.Status -} - -// IsRunning returns true if the status is Running -func (p *Process) IsRunning() bool { - return p.Status == Running -} - -func (s InstanceStatus) MarshalJSON() ([]byte, error) { +// Status enum JSON marshaling methods +func (s Status) MarshalJSON() ([]byte, error) { name, ok := statusToName[s] if !ok { name = "stopped" // Default to "stopped" for unknown status @@ -52,8 +36,8 @@ func (s InstanceStatus) MarshalJSON() ([]byte, error) { return json.Marshal(name) } -// UnmarshalJSON implements json.Unmarshaler -func (s *InstanceStatus) UnmarshalJSON(data []byte) error { +// UnmarshalJSON implements json.Unmarshaler for Status enum +func (s *Status) UnmarshalJSON(data []byte) error { var str string if err := json.Unmarshal(data, &str); err != nil { return err @@ -68,3 +52,61 @@ func (s *InstanceStatus) UnmarshalJSON(data []byte) error { *s = status return nil } + +// status represents the instance status 
with thread-safe access (unexported). +type status struct { + mu sync.RWMutex + s Status + + // Callback for status changes + onStatusChange func(oldStatus, newStatus Status) +} + +// newStatus creates a new status wrapper with the given initial status +func newStatus(initial Status) *status { + return &status{ + s: initial, + } +} + +// get returns the current status +func (st *status) get() Status { + st.mu.RLock() + defer st.mu.RUnlock() + return st.s +} + +// set updates the status and triggers the onStatusChange callback if set +func (st *status) set(newStatus Status) { + st.mu.Lock() + oldStatus := st.s + st.s = newStatus + callback := st.onStatusChange + st.mu.Unlock() + + // Call the callback outside the lock to avoid potential deadlocks + if callback != nil { + callback(oldStatus, newStatus) + } +} + +// isRunning returns true if the status is Running +func (st *status) isRunning() bool { + st.mu.RLock() + defer st.mu.RUnlock() + return st.s == Running +} + +// MarshalJSON implements json.Marshaler for status wrapper +func (st *status) MarshalJSON() ([]byte, error) { + st.mu.RLock() + defer st.mu.RUnlock() + return st.s.MarshalJSON() +} + +// UnmarshalJSON implements json.Unmarshaler for status wrapper +func (st *status) UnmarshalJSON(data []byte) error { + st.mu.Lock() + defer st.mu.Unlock() + return st.s.UnmarshalJSON(data) +} diff --git a/pkg/instance/timeout.go b/pkg/instance/timeout.go deleted file mode 100644 index 94cdc16..0000000 --- a/pkg/instance/timeout.go +++ /dev/null @@ -1,28 +0,0 @@ -package instance - -// UpdateLastRequestTime updates the last request access time for the instance via proxy -func (i *Process) UpdateLastRequestTime() { - i.mu.Lock() - defer i.mu.Unlock() - - lastRequestTime := i.timeProvider.Now().Unix() - i.lastRequestTime.Store(lastRequestTime) -} - -func (i *Process) ShouldTimeout() bool { - i.mu.RLock() - defer i.mu.RUnlock() - - if !i.IsRunning() || i.options.IdleTimeout == nil || *i.options.IdleTimeout <= 0 { - return 
false - } - - // Check if the last request time exceeds the idle timeout - lastRequest := i.lastRequestTime.Load() - idleTimeoutMinutes := *i.options.IdleTimeout - - // Convert timeout from minutes to seconds for comparison - idleTimeoutSeconds := int64(idleTimeoutMinutes * 60) - - return (i.timeProvider.Now().Unix() - lastRequest) > idleTimeoutSeconds -} diff --git a/pkg/instance/timeout_test.go b/pkg/instance/timeout_test.go deleted file mode 100644 index c4cf6ae..0000000 --- a/pkg/instance/timeout_test.go +++ /dev/null @@ -1,274 +0,0 @@ -package instance_test - -import ( - "llamactl/pkg/backends" - "llamactl/pkg/backends/llamacpp" - "llamactl/pkg/config" - "llamactl/pkg/instance" - "llamactl/pkg/testutil" - "sync/atomic" - "testing" - "time" -) - -// MockTimeProvider implements TimeProvider for testing -type MockTimeProvider struct { - currentTime atomic.Int64 // Unix timestamp -} - -func NewMockTimeProvider(t time.Time) *MockTimeProvider { - m := &MockTimeProvider{} - m.currentTime.Store(t.Unix()) - return m -} - -func (m *MockTimeProvider) Now() time.Time { - return time.Unix(m.currentTime.Load(), 0) -} - -func (m *MockTimeProvider) SetTime(t time.Time) { - m.currentTime.Store(t.Unix()) -} - -// Timeout-related tests - -func TestUpdateLastRequestTime(t *testing.T) { - backendConfig := &config.BackendConfig{ - LlamaCpp: config.BackendSettings{ - Command: "llama-server", - }, - MLX: config.BackendSettings{ - Command: "mlx_lm.server", - }, - } - - globalSettings := &config.InstancesConfig{ - LogsDir: "/tmp/test", - } - - options := &instance.CreateInstanceOptions{ - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/model.gguf", - }, - } - - // Mock onStatusChange function - mockOnStatusChange := func(oldStatus, newStatus instance.InstanceStatus) {} - - inst := instance.NewInstance("test-instance", backendConfig, globalSettings, options, "main", mockOnStatusChange) - - // Test that 
UpdateLastRequestTime doesn't panic - inst.UpdateLastRequestTime() -} - -func TestShouldTimeout_NotRunning(t *testing.T) { - backendConfig := &config.BackendConfig{ - LlamaCpp: config.BackendSettings{ - Command: "llama-server", - }, - MLX: config.BackendSettings{ - Command: "mlx_lm.server", - }, - } - - globalSettings := &config.InstancesConfig{ - LogsDir: "/tmp/test", - } - - idleTimeout := 1 // 1 minute - options := &instance.CreateInstanceOptions{ - IdleTimeout: &idleTimeout, - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/model.gguf", - }, - } - - // Mock onStatusChange function - mockOnStatusChange := func(oldStatus, newStatus instance.InstanceStatus) {} - - inst := instance.NewInstance("test-instance", backendConfig, globalSettings, options, "main", mockOnStatusChange) - - // Instance is not running, should not timeout regardless of configuration - if inst.ShouldTimeout() { - t.Error("Non-running instance should never timeout") - } -} - -func TestShouldTimeout_NoTimeoutConfigured(t *testing.T) { - backendConfig := &config.BackendConfig{ - LlamaCpp: config.BackendSettings{ - Command: "llama-server", - }, - MLX: config.BackendSettings{ - Command: "mlx_lm.server", - }, - } - - globalSettings := &config.InstancesConfig{ - LogsDir: "/tmp/test", - } - - tests := []struct { - name string - idleTimeout *int - }{ - {"nil timeout", nil}, - {"zero timeout", testutil.IntPtr(0)}, - {"negative timeout", testutil.IntPtr(-5)}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Mock onStatusChange function - mockOnStatusChange := func(oldStatus, newStatus instance.InstanceStatus) {} - - options := &instance.CreateInstanceOptions{ - IdleTimeout: tt.idleTimeout, - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/model.gguf", - }, - } - - inst := instance.NewInstance("test-instance", backendConfig, globalSettings, 
options, "main", mockOnStatusChange) - // Simulate running state - inst.SetStatus(instance.Running) - - if inst.ShouldTimeout() { - t.Errorf("Instance with %s should not timeout", tt.name) - } - }) - } -} - -func TestShouldTimeout_WithinTimeLimit(t *testing.T) { - backendConfig := &config.BackendConfig{ - LlamaCpp: config.BackendSettings{ - Command: "llama-server", - }, - MLX: config.BackendSettings{ - Command: "mlx_lm.server", - }, - } - - globalSettings := &config.InstancesConfig{ - LogsDir: "/tmp/test", - } - - idleTimeout := 5 // 5 minutes - options := &instance.CreateInstanceOptions{ - IdleTimeout: &idleTimeout, - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/model.gguf", - }, - } - - // Mock onStatusChange function - mockOnStatusChange := func(oldStatus, newStatus instance.InstanceStatus) {} - - inst := instance.NewInstance("test-instance", backendConfig, globalSettings, options, "main", mockOnStatusChange) - inst.SetStatus(instance.Running) - - // Update last request time to now - inst.UpdateLastRequestTime() - - // Should not timeout immediately - if inst.ShouldTimeout() { - t.Error("Instance should not timeout when last request was recent") - } -} - -func TestShouldTimeout_ExceedsTimeLimit(t *testing.T) { - backendConfig := &config.BackendConfig{ - LlamaCpp: config.BackendSettings{ - Command: "llama-server", - }, - MLX: config.BackendSettings{ - Command: "mlx_lm.server", - }, - } - - globalSettings := &config.InstancesConfig{ - LogsDir: "/tmp/test", - } - - idleTimeout := 1 // 1 minute - options := &instance.CreateInstanceOptions{ - IdleTimeout: &idleTimeout, - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/model.gguf", - }, - } - - // Mock onStatusChange function - mockOnStatusChange := func(oldStatus, newStatus instance.InstanceStatus) {} - - inst := instance.NewInstance("test-instance", backendConfig, globalSettings, 
options, "main", mockOnStatusChange) - inst.SetStatus(instance.Running) - - // Use MockTimeProvider to simulate old last request time - mockTime := NewMockTimeProvider(time.Now()) - inst.SetTimeProvider(mockTime) - - // Set last request time to now - inst.UpdateLastRequestTime() - - // Advance time by 2 minutes (exceeds 1 minute timeout) - mockTime.SetTime(time.Now().Add(2 * time.Minute)) - - if !inst.ShouldTimeout() { - t.Error("Instance should timeout when last request exceeds idle timeout") - } -} - -func TestTimeoutConfiguration_Validation(t *testing.T) { - backendConfig := &config.BackendConfig{ - LlamaCpp: config.BackendSettings{ - Command: "llama-server", - }, - MLX: config.BackendSettings{ - Command: "mlx_lm.server", - }, - } - - globalSettings := &config.InstancesConfig{ - LogsDir: "/tmp/test", - } - - tests := []struct { - name string - inputTimeout *int - expectedTimeout int - }{ - {"default value when nil", nil, 0}, - {"positive value", testutil.IntPtr(10), 10}, - {"zero value", testutil.IntPtr(0), 0}, - {"negative value gets corrected", testutil.IntPtr(-5), 0}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - options := &instance.CreateInstanceOptions{ - IdleTimeout: tt.inputTimeout, - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/model.gguf", - }, - } - - // Mock onStatusChange function - mockOnStatusChange := func(oldStatus, newStatus instance.InstanceStatus) {} - - inst := instance.NewInstance("test-instance", backendConfig, globalSettings, options, "main", mockOnStatusChange) - opts := inst.GetOptions() - - if opts.IdleTimeout == nil || *opts.IdleTimeout != tt.expectedTimeout { - t.Errorf("Expected IdleTimeout %d, got %v", tt.expectedTimeout, opts.IdleTimeout) - } - }) - } -} diff --git a/pkg/manager/manager.go b/pkg/manager/manager.go index c402659..894b8df 100644 --- a/pkg/manager/manager.go +++ b/pkg/manager/manager.go @@ -16,35 +16,35 @@ import ( // 
InstanceManager defines the interface for managing instances of the llama server. type InstanceManager interface { - ListInstances() ([]*instance.Process, error) - CreateInstance(name string, options *instance.CreateInstanceOptions) (*instance.Process, error) - GetInstance(name string) (*instance.Process, error) - UpdateInstance(name string, options *instance.CreateInstanceOptions) (*instance.Process, error) + ListInstances() ([]*instance.Instance, error) + CreateInstance(name string, options *instance.Options) (*instance.Instance, error) + GetInstance(name string) (*instance.Instance, error) + UpdateInstance(name string, options *instance.Options) (*instance.Instance, error) DeleteInstance(name string) error - StartInstance(name string) (*instance.Process, error) + StartInstance(name string) (*instance.Instance, error) IsMaxRunningInstancesReached() bool - StopInstance(name string) (*instance.Process, error) + StopInstance(name string) (*instance.Instance, error) EvictLRUInstance() error - RestartInstance(name string) (*instance.Process, error) + RestartInstance(name string) (*instance.Instance, error) GetInstanceLogs(name string, numLines int) (string, error) Shutdown() } type RemoteManager interface { - ListRemoteInstances(node *config.NodeConfig) ([]*instance.Process, error) - CreateRemoteInstance(node *config.NodeConfig, name string, options *instance.CreateInstanceOptions) (*instance.Process, error) - GetRemoteInstance(node *config.NodeConfig, name string) (*instance.Process, error) - UpdateRemoteInstance(node *config.NodeConfig, name string, options *instance.CreateInstanceOptions) (*instance.Process, error) + ListRemoteInstances(node *config.NodeConfig) ([]*instance.Instance, error) + CreateRemoteInstance(node *config.NodeConfig, name string, options *instance.Options) (*instance.Instance, error) + GetRemoteInstance(node *config.NodeConfig, name string) (*instance.Instance, error) + UpdateRemoteInstance(node *config.NodeConfig, name string, options 
*instance.Options) (*instance.Instance, error) DeleteRemoteInstance(node *config.NodeConfig, name string) error - StartRemoteInstance(node *config.NodeConfig, name string) (*instance.Process, error) - StopRemoteInstance(node *config.NodeConfig, name string) (*instance.Process, error) - RestartRemoteInstance(node *config.NodeConfig, name string) (*instance.Process, error) + StartRemoteInstance(node *config.NodeConfig, name string) (*instance.Instance, error) + StopRemoteInstance(node *config.NodeConfig, name string) (*instance.Instance, error) + RestartRemoteInstance(node *config.NodeConfig, name string) (*instance.Instance, error) GetRemoteInstanceLogs(node *config.NodeConfig, name string, numLines int) (string, error) } type instanceManager struct { mu sync.RWMutex - instances map[string]*instance.Process + instances map[string]*instance.Instance runningInstances map[string]struct{} ports map[int]bool instancesConfig config.InstancesConfig @@ -58,9 +58,9 @@ type instanceManager struct { isShutdown bool // Remote instance management - httpClient *http.Client - instanceNodeMap map[string]*config.NodeConfig // Maps instance name to its node config - nodeConfigMap map[string]*config.NodeConfig // Maps node name to node config for quick lookup + httpClient *http.Client + instanceNodeMap map[string]*config.NodeConfig // Maps instance name to its node config + nodeConfigMap map[string]*config.NodeConfig // Maps node name to node config for quick lookup } // NewInstanceManager creates a new instance of InstanceManager. 
@@ -77,7 +77,7 @@ func NewInstanceManager(backendsConfig config.BackendConfig, instancesConfig con } im := &instanceManager{ - instances: make(map[string]*instance.Process), + instances: make(map[string]*instance.Instance), runningInstances: make(map[string]struct{}), ports: make(map[int]bool), instancesConfig: instancesConfig, @@ -132,7 +132,7 @@ func (im *instanceManager) getNextAvailablePort() (int, error) { } // persistInstance saves an instance to its JSON file -func (im *instanceManager) persistInstance(instance *instance.Process) error { +func (im *instanceManager) persistInstance(instance *instance.Instance) error { if im.instancesConfig.InstancesDir == "" { return nil // Persistence disabled } @@ -174,7 +174,7 @@ func (im *instanceManager) Shutdown() { close(im.shutdownChan) // Create a list of running instances to stop - var runningInstances []*instance.Process + var runningInstances []*instance.Instance var runningNames []string for name, inst := range im.instances { if inst.IsRunning() { @@ -199,7 +199,7 @@ func (im *instanceManager) Shutdown() { wg.Add(len(runningInstances)) for i, inst := range runningInstances { - go func(name string, inst *instance.Process) { + go func(name string, inst *instance.Instance) { defer wg.Done() fmt.Printf("Stopping instance %s...\n", name) // Attempt to stop the instance gracefully @@ -263,7 +263,7 @@ func (im *instanceManager) loadInstance(name, path string) error { return fmt.Errorf("failed to read instance file: %w", err) } - var persistedInstance instance.Process + var persistedInstance instance.Instance if err := json.Unmarshal(data, &persistedInstance); err != nil { return fmt.Errorf("failed to unmarshal instance: %w", err) } @@ -275,28 +275,37 @@ func (im *instanceManager) loadInstance(name, path string) error { options := persistedInstance.GetOptions() - // Check if this is a remote instance - // An instance is remote if Nodes is specified AND the first node is not the local node - isRemote := options != nil && 
len(options.Nodes) > 0 && options.Nodes[0] != im.localNodeName + // Check if this is a remote instance (local node not in the Nodes set) + var isRemote bool + var nodeName string + if options != nil { + if _, isLocal := options.Nodes[im.localNodeName]; !isLocal { + // Get the first node from the set + for node := range options.Nodes { + nodeName = node + isRemote = true + break + } + } + } - var statusCallback func(oldStatus, newStatus instance.InstanceStatus) + var statusCallback func(oldStatus, newStatus instance.Status) if !isRemote { // Only set status callback for local instances - statusCallback = func(oldStatus, newStatus instance.InstanceStatus) { + statusCallback = func(oldStatus, newStatus instance.Status) { im.onStatusChange(persistedInstance.Name, oldStatus, newStatus) } } // Create new inst using NewInstance (handles validation, defaults, setup) - inst := instance.NewInstance(name, &im.backendsConfig, &im.instancesConfig, options, im.localNodeName, statusCallback) + inst := instance.New(name, &im.backendsConfig, &im.instancesConfig, options, im.localNodeName, statusCallback) // Restore persisted fields that NewInstance doesn't set inst.Created = persistedInstance.Created - inst.SetStatus(persistedInstance.Status) + inst.SetStatus(persistedInstance.GetStatus()) // Handle remote instance mapping if isRemote { - nodeName := options.Nodes[0] nodeConfig, exists := im.nodeConfigMap[nodeName] if !exists { return fmt.Errorf("node %s not found for remote instance %s", nodeName, name) @@ -321,8 +330,8 @@ func (im *instanceManager) loadInstance(name, path string) error { // For instances with auto-restart disabled, it sets their status to Stopped func (im *instanceManager) autoStartInstances() { im.mu.RLock() - var instancesToStart []*instance.Process - var instancesToStop []*instance.Process + var instancesToStart []*instance.Instance + var instancesToStop []*instance.Instance for _, inst := range im.instances { if inst.IsRunning() && // Was running when 
persisted inst.GetOptions() != nil && @@ -364,7 +373,7 @@ func (im *instanceManager) autoStartInstances() { } } -func (im *instanceManager) onStatusChange(name string, oldStatus, newStatus instance.InstanceStatus) { +func (im *instanceManager) onStatusChange(name string, oldStatus, newStatus instance.Status) { im.mu.Lock() defer im.mu.Unlock() @@ -377,7 +386,7 @@ func (im *instanceManager) onStatusChange(name string, oldStatus, newStatus inst // getNodeForInstance returns the node configuration for a remote instance // Returns nil if the instance is not remote or the node is not found -func (im *instanceManager) getNodeForInstance(inst *instance.Process) *config.NodeConfig { +func (im *instanceManager) getNodeForInstance(inst *instance.Instance) *config.NodeConfig { if !inst.IsRemote() { return nil } diff --git a/pkg/manager/manager_test.go b/pkg/manager/manager_test.go index e4a6329..531e9e2 100644 --- a/pkg/manager/manager_test.go +++ b/pkg/manager/manager_test.go @@ -70,7 +70,7 @@ func TestPersistence(t *testing.T) { // Test instance persistence on creation manager1 := manager.NewInstanceManager(backendConfig, cfg, map[string]config.NodeConfig{}, "main") - options := &instance.CreateInstanceOptions{ + options := &instance.Options{ BackendType: backends.BackendTypeLlamaCpp, LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", @@ -132,7 +132,7 @@ func TestConcurrentAccess(t *testing.T) { wg.Add(1) go func(index int) { defer wg.Done() - options := &instance.CreateInstanceOptions{ + options := &instance.Options{ BackendType: backends.BackendTypeLlamaCpp, LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", @@ -169,7 +169,7 @@ func TestShutdown(t *testing.T) { mgr := createTestManager() // Create test instance - options := &instance.CreateInstanceOptions{ + options := &instance.Options{ BackendType: backends.BackendTypeLlamaCpp, LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", @@ 
-230,7 +230,7 @@ func TestAutoRestartDisabledInstanceStatus(t *testing.T) { manager1 := manager.NewInstanceManager(backendConfig, cfg, map[string]config.NodeConfig{}, "main") autoRestart := false - options := &instance.CreateInstanceOptions{ + options := &instance.Options{ BackendType: backends.BackendTypeLlamaCpp, AutoRestart: &autoRestart, LlamaServerOptions: &llamacpp.LlamaServerOptions{ diff --git a/pkg/manager/operations.go b/pkg/manager/operations.go index 0fbbbcb..7129794 100644 --- a/pkg/manager/operations.go +++ b/pkg/manager/operations.go @@ -3,7 +3,6 @@ package manager import ( "fmt" "llamactl/pkg/backends" - "llamactl/pkg/config" "llamactl/pkg/instance" "llamactl/pkg/validation" "os" @@ -14,7 +13,7 @@ type MaxRunningInstancesError error // updateLocalInstanceFromRemote updates the local stub instance with data from the remote instance // while preserving the Nodes field to maintain remote instance tracking -func (im *instanceManager) updateLocalInstanceFromRemote(localInst *instance.Process, remoteInst *instance.Process) { +func (im *instanceManager) updateLocalInstanceFromRemote(localInst *instance.Instance, remoteInst *instance.Instance) { if localInst == nil || remoteInst == nil { return } @@ -27,10 +26,12 @@ func (im *instanceManager) updateLocalInstanceFromRemote(localInst *instance.Pro // Preserve the Nodes field from the local instance localOptions := localInst.GetOptions() - var preservedNodes []string + var preservedNodes map[string]struct{} if localOptions != nil && len(localOptions.Nodes) > 0 { - preservedNodes = make([]string, len(localOptions.Nodes)) - copy(preservedNodes, localOptions.Nodes) + preservedNodes = make(map[string]struct{}, len(localOptions.Nodes)) + for node := range localOptions.Nodes { + preservedNodes[node] = struct{}{} + } } // Create a copy of remote options and restore the Nodes field @@ -39,15 +40,15 @@ func (im *instanceManager) updateLocalInstanceFromRemote(localInst *instance.Pro // Update the local instance with all 
remote data localInst.SetOptions(&updatedOptions) - localInst.Status = remoteInst.Status + localInst.SetStatus(remoteInst.GetStatus()) localInst.Created = remoteInst.Created } // ListInstances returns a list of all instances managed by the instance manager. // For remote instances, this fetches the live state from remote nodes and updates local stubs. -func (im *instanceManager) ListInstances() ([]*instance.Process, error) { +func (im *instanceManager) ListInstances() ([]*instance.Instance, error) { im.mu.RLock() - localInstances := make([]*instance.Process, 0, len(im.instances)) + localInstances := make([]*instance.Instance, 0, len(im.instances)) for _, inst := range im.instances { localInstances = append(localInstances, inst) } @@ -75,7 +76,7 @@ func (im *instanceManager) ListInstances() ([]*instance.Process, error) { // CreateInstance creates a new instance with the given options and returns it. // The instance is initially in a "stopped" state. -func (im *instanceManager) CreateInstance(name string, options *instance.CreateInstanceOptions) (*instance.Process, error) { +func (im *instanceManager) CreateInstance(name string, options *instance.Options) (*instance.Instance, error) { if options == nil { return nil, fmt.Errorf("instance options cannot be nil") } @@ -98,16 +99,17 @@ func (im *instanceManager) CreateInstance(name string, options *instance.CreateI return nil, fmt.Errorf("instance with name %s already exists", name) } - // Check if this is a remote instance - // An instance is remote if Nodes is specified AND the first node is not the local node - isRemote := len(options.Nodes) > 0 && options.Nodes[0] != im.localNodeName - var nodeConfig *config.NodeConfig + // Check if this is a remote instance (local node not in the Nodes set) + if _, isLocal := options.Nodes[im.localNodeName]; !isLocal && len(options.Nodes) > 0 { + // Get the first node from the set + var nodeName string + for node := range options.Nodes { + nodeName = node + break + } - if isRemote { 
// Validate that the node exists - nodeName := options.Nodes[0] // Use first node for now - var exists bool - nodeConfig, exists = im.nodeConfigMap[nodeName] + nodeConfig, exists := im.nodeConfigMap[nodeName] if !exists { return nil, fmt.Errorf("node %s not found", nodeName) } @@ -120,7 +122,7 @@ func (im *instanceManager) CreateInstance(name string, options *instance.CreateI // Create a local stub that preserves the Nodes field for tracking // We keep the original options (with Nodes) so IsRemote() works correctly - inst := instance.NewInstance(name, &im.backendsConfig, &im.instancesConfig, options, im.localNodeName, nil) + inst := instance.New(name, &im.backendsConfig, &im.instancesConfig, options, im.localNodeName, nil) // Update the local stub with all remote data (preserving Nodes) im.updateLocalInstanceFromRemote(inst, remoteInst) @@ -149,11 +151,11 @@ func (im *instanceManager) CreateInstance(name string, options *instance.CreateI return nil, err } - statusCallback := func(oldStatus, newStatus instance.InstanceStatus) { + statusCallback := func(oldStatus, newStatus instance.Status) { im.onStatusChange(name, oldStatus, newStatus) } - inst := instance.NewInstance(name, &im.backendsConfig, &im.instancesConfig, options, im.localNodeName, statusCallback) + inst := instance.New(name, &im.backendsConfig, &im.instancesConfig, options, im.localNodeName, statusCallback) im.instances[inst.Name] = inst if err := im.persistInstance(inst); err != nil { @@ -165,7 +167,7 @@ func (im *instanceManager) CreateInstance(name string, options *instance.CreateI // GetInstance retrieves an instance by its name. // For remote instances, this fetches the live state from the remote node and updates the local stub. 
-func (im *instanceManager) GetInstance(name string) (*instance.Process, error) { +func (im *instanceManager) GetInstance(name string) (*instance.Instance, error) { im.mu.RLock() inst, exists := im.instances[name] im.mu.RUnlock() @@ -195,7 +197,7 @@ func (im *instanceManager) GetInstance(name string) (*instance.Process, error) { // UpdateInstance updates the options of an existing instance and returns it. // If the instance is running, it will be restarted to apply the new options. -func (im *instanceManager) UpdateInstance(name string, options *instance.CreateInstanceOptions) (*instance.Process, error) { +func (im *instanceManager) UpdateInstance(name string, options *instance.Options) (*instance.Instance, error) { im.mu.RLock() inst, exists := im.instances[name] im.mu.RUnlock() @@ -327,7 +329,7 @@ func (im *instanceManager) DeleteInstance(name string) error { // StartInstance starts a stopped instance and returns it. // If the instance is already running, it returns an error. -func (im *instanceManager) StartInstance(name string) (*instance.Process, error) { +func (im *instanceManager) StartInstance(name string) (*instance.Instance, error) { im.mu.RLock() inst, exists := im.instances[name] im.mu.RUnlock() @@ -396,7 +398,7 @@ func (im *instanceManager) IsMaxRunningInstancesReached() bool { } // StopInstance stops a running instance and returns it. -func (im *instanceManager) StopInstance(name string) (*instance.Process, error) { +func (im *instanceManager) StopInstance(name string) (*instance.Instance, error) { im.mu.RLock() inst, exists := im.instances[name] im.mu.RUnlock() @@ -439,7 +441,7 @@ func (im *instanceManager) StopInstance(name string) (*instance.Process, error) } // RestartInstance stops and then starts an instance, returning the updated instance. 
-func (im *instanceManager) RestartInstance(name string) (*instance.Process, error) { +func (im *instanceManager) RestartInstance(name string) (*instance.Instance, error) { im.mu.RLock() inst, exists := im.instances[name] im.mu.RUnlock() @@ -490,7 +492,7 @@ func (im *instanceManager) GetInstanceLogs(name string, numLines int) (string, e } // getPortFromOptions extracts the port from backend-specific options -func (im *instanceManager) getPortFromOptions(options *instance.CreateInstanceOptions) int { +func (im *instanceManager) getPortFromOptions(options *instance.Options) int { switch options.BackendType { case backends.BackendTypeLlamaCpp: if options.LlamaServerOptions != nil { @@ -509,7 +511,7 @@ func (im *instanceManager) getPortFromOptions(options *instance.CreateInstanceOp } // setPortInOptions sets the port in backend-specific options -func (im *instanceManager) setPortInOptions(options *instance.CreateInstanceOptions, port int) { +func (im *instanceManager) setPortInOptions(options *instance.Options, port int) { switch options.BackendType { case backends.BackendTypeLlamaCpp: if options.LlamaServerOptions != nil { @@ -527,7 +529,7 @@ func (im *instanceManager) setPortInOptions(options *instance.CreateInstanceOpti } // assignAndValidatePort assigns a port if not specified and validates it's not in use -func (im *instanceManager) assignAndValidatePort(options *instance.CreateInstanceOptions) error { +func (im *instanceManager) assignAndValidatePort(options *instance.Options) error { currentPort := im.getPortFromOptions(options) if currentPort == 0 { diff --git a/pkg/manager/operations_test.go b/pkg/manager/operations_test.go index 6743a4b..56b8b3b 100644 --- a/pkg/manager/operations_test.go +++ b/pkg/manager/operations_test.go @@ -13,7 +13,7 @@ import ( func TestCreateInstance_Success(t *testing.T) { manager := createTestManager() - options := &instance.CreateInstanceOptions{ + options := &instance.Options{ BackendType: backends.BackendTypeLlamaCpp, 
LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", @@ -40,7 +40,7 @@ func TestCreateInstance_Success(t *testing.T) { func TestCreateInstance_ValidationAndLimits(t *testing.T) { // Test duplicate names mngr := createTestManager() - options := &instance.CreateInstanceOptions{ + options := &instance.Options{ BackendType: backends.BackendTypeLlamaCpp, LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", @@ -96,7 +96,7 @@ func TestPortManagement(t *testing.T) { manager := createTestManager() // Test auto port assignment - options1 := &instance.CreateInstanceOptions{ + options1 := &instance.Options{ BackendType: backends.BackendTypeLlamaCpp, LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", @@ -114,7 +114,7 @@ func TestPortManagement(t *testing.T) { } // Test port conflict detection - options2 := &instance.CreateInstanceOptions{ + options2 := &instance.Options{ BackendType: backends.BackendTypeLlamaCpp, LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model2.gguf", @@ -132,7 +132,7 @@ func TestPortManagement(t *testing.T) { // Test port release on deletion specificPort := 8080 - options3 := &instance.CreateInstanceOptions{ + options3 := &instance.Options{ BackendType: backends.BackendTypeLlamaCpp, LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", @@ -160,7 +160,7 @@ func TestPortManagement(t *testing.T) { func TestInstanceOperations(t *testing.T) { manager := createTestManager() - options := &instance.CreateInstanceOptions{ + options := &instance.Options{ BackendType: backends.BackendTypeLlamaCpp, LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", @@ -183,7 +183,7 @@ func TestInstanceOperations(t *testing.T) { } // Update instance - newOptions := &instance.CreateInstanceOptions{ + newOptions := &instance.Options{ BackendType: backends.BackendTypeLlamaCpp, LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: 
"/path/to/new-model.gguf", diff --git a/pkg/manager/remote_ops.go b/pkg/manager/remote_ops.go index e98e396..143c317 100644 --- a/pkg/manager/remote_ops.go +++ b/pkg/manager/remote_ops.go @@ -66,13 +66,13 @@ func parseRemoteResponse(resp *http.Response, result any) error { } // ListRemoteInstances lists all instances on the remote node -func (im *instanceManager) ListRemoteInstances(nodeConfig *config.NodeConfig) ([]*instance.Process, error) { +func (im *instanceManager) ListRemoteInstances(nodeConfig *config.NodeConfig) ([]*instance.Instance, error) { resp, err := im.makeRemoteRequest(nodeConfig, "GET", "/api/v1/instances/", nil) if err != nil { return nil, err } - var instances []*instance.Process + var instances []*instance.Instance if err := parseRemoteResponse(resp, &instances); err != nil { return nil, err } @@ -81,7 +81,7 @@ func (im *instanceManager) ListRemoteInstances(nodeConfig *config.NodeConfig) ([ } // CreateRemoteInstance creates a new instance on the remote node -func (im *instanceManager) CreateRemoteInstance(nodeConfig *config.NodeConfig, name string, options *instance.CreateInstanceOptions) (*instance.Process, error) { +func (im *instanceManager) CreateRemoteInstance(nodeConfig *config.NodeConfig, name string, options *instance.Options) (*instance.Instance, error) { path := fmt.Sprintf("/api/v1/instances/%s/", name) resp, err := im.makeRemoteRequest(nodeConfig, "POST", path, options) @@ -89,7 +89,7 @@ func (im *instanceManager) CreateRemoteInstance(nodeConfig *config.NodeConfig, n return nil, err } - var inst instance.Process + var inst instance.Instance if err := parseRemoteResponse(resp, &inst); err != nil { return nil, err } @@ -98,14 +98,14 @@ func (im *instanceManager) CreateRemoteInstance(nodeConfig *config.NodeConfig, n } // GetRemoteInstance retrieves an instance by name from the remote node -func (im *instanceManager) GetRemoteInstance(nodeConfig *config.NodeConfig, name string) (*instance.Process, error) { +func (im *instanceManager) 
GetRemoteInstance(nodeConfig *config.NodeConfig, name string) (*instance.Instance, error) { path := fmt.Sprintf("/api/v1/instances/%s/", name) resp, err := im.makeRemoteRequest(nodeConfig, "GET", path, nil) if err != nil { return nil, err } - var inst instance.Process + var inst instance.Instance if err := parseRemoteResponse(resp, &inst); err != nil { return nil, err } @@ -114,7 +114,7 @@ func (im *instanceManager) GetRemoteInstance(nodeConfig *config.NodeConfig, name } // UpdateRemoteInstance updates an existing instance on the remote node -func (im *instanceManager) UpdateRemoteInstance(nodeConfig *config.NodeConfig, name string, options *instance.CreateInstanceOptions) (*instance.Process, error) { +func (im *instanceManager) UpdateRemoteInstance(nodeConfig *config.NodeConfig, name string, options *instance.Options) (*instance.Instance, error) { path := fmt.Sprintf("/api/v1/instances/%s/", name) resp, err := im.makeRemoteRequest(nodeConfig, "PUT", path, options) @@ -122,7 +122,7 @@ func (im *instanceManager) UpdateRemoteInstance(nodeConfig *config.NodeConfig, n return nil, err } - var inst instance.Process + var inst instance.Instance if err := parseRemoteResponse(resp, &inst); err != nil { return nil, err } @@ -142,14 +142,14 @@ func (im *instanceManager) DeleteRemoteInstance(nodeConfig *config.NodeConfig, n } // StartRemoteInstance starts an instance on the remote node -func (im *instanceManager) StartRemoteInstance(nodeConfig *config.NodeConfig, name string) (*instance.Process, error) { +func (im *instanceManager) StartRemoteInstance(nodeConfig *config.NodeConfig, name string) (*instance.Instance, error) { path := fmt.Sprintf("/api/v1/instances/%s/start", name) resp, err := im.makeRemoteRequest(nodeConfig, "POST", path, nil) if err != nil { return nil, err } - var inst instance.Process + var inst instance.Instance if err := parseRemoteResponse(resp, &inst); err != nil { return nil, err } @@ -158,14 +158,14 @@ func (im *instanceManager) 
StartRemoteInstance(nodeConfig *config.NodeConfig, na } // StopRemoteInstance stops an instance on the remote node -func (im *instanceManager) StopRemoteInstance(nodeConfig *config.NodeConfig, name string) (*instance.Process, error) { +func (im *instanceManager) StopRemoteInstance(nodeConfig *config.NodeConfig, name string) (*instance.Instance, error) { path := fmt.Sprintf("/api/v1/instances/%s/stop", name) resp, err := im.makeRemoteRequest(nodeConfig, "POST", path, nil) if err != nil { return nil, err } - var inst instance.Process + var inst instance.Instance if err := parseRemoteResponse(resp, &inst); err != nil { return nil, err } @@ -174,14 +174,14 @@ func (im *instanceManager) StopRemoteInstance(nodeConfig *config.NodeConfig, nam } // RestartRemoteInstance restarts an instance on the remote node -func (im *instanceManager) RestartRemoteInstance(nodeConfig *config.NodeConfig, name string) (*instance.Process, error) { +func (im *instanceManager) RestartRemoteInstance(nodeConfig *config.NodeConfig, name string) (*instance.Instance, error) { path := fmt.Sprintf("/api/v1/instances/%s/restart", name) resp, err := im.makeRemoteRequest(nodeConfig, "POST", path, nil) if err != nil { return nil, err } - var inst instance.Process + var inst instance.Instance if err := parseRemoteResponse(resp, &inst); err != nil { return nil, err } diff --git a/pkg/manager/timeout.go b/pkg/manager/timeout.go index 50b1c10..2e0314a 100644 --- a/pkg/manager/timeout.go +++ b/pkg/manager/timeout.go @@ -37,7 +37,7 @@ func (im *instanceManager) checkAllTimeouts() { // EvictLRUInstance finds and stops the least recently used running instance. 
func (im *instanceManager) EvictLRUInstance() error { im.mu.RLock() - var lruInstance *instance.Process + var lruInstance *instance.Instance for name := range im.runningInstances { inst := im.instances[name] diff --git a/pkg/manager/timeout_test.go b/pkg/manager/timeout_test.go index 4992370..8c30d5d 100644 --- a/pkg/manager/timeout_test.go +++ b/pkg/manager/timeout_test.go @@ -34,7 +34,7 @@ func TestTimeoutFunctionality(t *testing.T) { defer testManager.Shutdown() idleTimeout := 1 // 1 minute - options := &instance.CreateInstanceOptions{ + options := &instance.Options{ IdleTimeout: &idleTimeout, BackendType: backends.BackendTypeLlamaCpp, LlamaServerOptions: &llamacpp.LlamaServerOptions{ @@ -84,7 +84,7 @@ func TestTimeoutFunctionality(t *testing.T) { inst.SetStatus(instance.Stopped) // Test that instance without timeout doesn't timeout - noTimeoutOptions := &instance.CreateInstanceOptions{ + noTimeoutOptions := &instance.Options{ BackendType: backends.BackendTypeLlamaCpp, LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", @@ -115,21 +115,21 @@ func TestEvictLRUInstance_Success(t *testing.T) { // Don't defer manager.Shutdown() - we'll handle cleanup manually // Create 3 instances with idle timeout enabled (value doesn't matter for LRU logic) - options1 := &instance.CreateInstanceOptions{ + options1 := &instance.Options{ BackendType: backends.BackendTypeLlamaCpp, LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model1.gguf", }, IdleTimeout: func() *int { timeout := 1; return &timeout }(), // Any value > 0 } - options2 := &instance.CreateInstanceOptions{ + options2 := &instance.Options{ BackendType: backends.BackendTypeLlamaCpp, LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model2.gguf", }, IdleTimeout: func() *int { timeout := 1; return &timeout }(), // Any value > 0 } - options3 := &instance.CreateInstanceOptions{ + options3 := &instance.Options{ BackendType: backends.BackendTypeLlamaCpp, 
LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model3.gguf", @@ -196,8 +196,8 @@ func TestEvictLRUInstance_Success(t *testing.T) { func TestEvictLRUInstance_NoEligibleInstances(t *testing.T) { // Helper function to create instances with different timeout configurations - createInstanceWithTimeout := func(manager manager.InstanceManager, name, model string, timeout *int) *instance.Process { - options := &instance.CreateInstanceOptions{ + createInstanceWithTimeout := func(manager manager.InstanceManager, name, model string, timeout *int) *instance.Instance { + options := &instance.Options{ BackendType: backends.BackendTypeLlamaCpp, LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: model, @@ -236,7 +236,7 @@ func TestEvictLRUInstance_NoEligibleInstances(t *testing.T) { inst3 := createInstanceWithTimeout(manager, "no-timeout-3", "/path/to/model3.gguf", nil) // Set instances to running - instances := []*instance.Process{inst1, inst2, inst3} + instances := []*instance.Instance{inst1, inst2, inst3} for _, inst := range instances { inst.SetStatus(instance.Running) } @@ -276,7 +276,7 @@ func TestEvictLRUInstance_NoEligibleInstances(t *testing.T) { instNoTimeout2 := createInstanceWithTimeout(manager, "no-timeout-2", "/path/to/model-no-timeout2.gguf", nil) // Set all instances to running - instances := []*instance.Process{instWithTimeout, instNoTimeout1, instNoTimeout2} + instances := []*instance.Instance{instWithTimeout, instNoTimeout1, instNoTimeout2} for _, inst := range instances { inst.SetStatus(instance.Running) inst.UpdateLastRequestTime() diff --git a/pkg/server/handlers_backends.go b/pkg/server/handlers_backends.go index 7d6cab0..6fa833c 100644 --- a/pkg/server/handlers_backends.go +++ b/pkg/server/handlers_backends.go @@ -106,7 +106,7 @@ func (h *Handler) LlamaCppProxy(onDemandStart bool) http.HandlerFunc { // @Accept json // @Produce json // @Param request body ParseCommandRequest true "Command to parse" -// @Success 200 {object} 
instance.CreateInstanceOptions "Parsed options" +// @Success 200 {object} instance.Options "Parsed options" // @Failure 400 {object} map[string]string "Invalid request or command" // @Failure 500 {object} map[string]string "Internal Server Error" // @Router /backends/llama-cpp/parse-command [post] @@ -135,7 +135,7 @@ func (h *Handler) ParseLlamaCommand() http.HandlerFunc { writeError(w, http.StatusBadRequest, "parse_error", err.Error()) return } - options := &instance.CreateInstanceOptions{ + options := &instance.Options{ BackendType: backends.BackendTypeLlamaCpp, LlamaServerOptions: llamaOptions, } @@ -154,7 +154,7 @@ func (h *Handler) ParseLlamaCommand() http.HandlerFunc { // @Accept json // @Produce json // @Param request body ParseCommandRequest true "Command to parse" -// @Success 200 {object} instance.CreateInstanceOptions "Parsed options" +// @Success 200 {object} instance.Options "Parsed options" // @Failure 400 {object} map[string]string "Invalid request or command" // @Router /backends/mlx/parse-command [post] func (h *Handler) ParseMlxCommand() http.HandlerFunc { @@ -188,7 +188,7 @@ func (h *Handler) ParseMlxCommand() http.HandlerFunc { // Currently only support mlx_lm backend type backendType := backends.BackendTypeMlxLm - options := &instance.CreateInstanceOptions{ + options := &instance.Options{ BackendType: backendType, MlxServerOptions: mlxOptions, } @@ -208,7 +208,7 @@ func (h *Handler) ParseMlxCommand() http.HandlerFunc { // @Accept json // @Produce json // @Param request body ParseCommandRequest true "Command to parse" -// @Success 200 {object} instance.CreateInstanceOptions "Parsed options" +// @Success 200 {object} instance.Options "Parsed options" // @Failure 400 {object} map[string]string "Invalid request or command" // @Router /backends/vllm/parse-command [post] func (h *Handler) ParseVllmCommand() http.HandlerFunc { @@ -241,7 +241,7 @@ func (h *Handler) ParseVllmCommand() http.HandlerFunc { backendType := backends.BackendTypeVllm - options 
:= &instance.CreateInstanceOptions{ + options := &instance.Options{ BackendType: backendType, VllmServerOptions: vllmOptions, } diff --git a/pkg/server/handlers_instances.go b/pkg/server/handlers_instances.go index be3cf4a..9e41190 100644 --- a/pkg/server/handlers_instances.go +++ b/pkg/server/handlers_instances.go @@ -47,7 +47,7 @@ func (h *Handler) ListInstances() http.HandlerFunc { // @Accept json // @Produces json // @Param name path string true "Instance Name" -// @Param options body instance.CreateInstanceOptions true "Instance configuration options" +// @Param options body instance.Options true "Instance configuration options" -// @Success 201 {object} instance.Process "Created instance details" +// @Success 201 {object} instance.Instance "Created instance details" // @Failure 400 {string} string "Invalid request body" // @Failure 500 {string} string "Internal Server Error" @@ -60,7 +60,7 @@ func (h *Handler) CreateInstance() http.HandlerFunc { return } - var options instance.CreateInstanceOptions + var options instance.Options if err := json.NewDecoder(r.Body).Decode(&options); err != nil { http.Error(w, "Invalid request body", http.StatusBadRequest) return @@ -122,7 +122,7 @@ func (h *Handler) GetInstance() http.HandlerFunc { // @Accept json // @Produces json // @Param name path string true "Instance Name" -// @Param options body instance.CreateInstanceOptions true "Instance configuration options" +// @Param options body instance.Options true "Instance configuration options" -// @Success 200 {object} instance.Process "Updated instance details" +// @Success 200 {object} instance.Instance "Updated instance details" // @Failure 400 {string} string "Invalid name format" // @Failure 500 {string} string "Internal Server Error" @@ -135,7 +135,7 @@ func (h *Handler) UpdateInstance() http.HandlerFunc { return } - var options instance.CreateInstanceOptions + var options instance.Options if err := json.NewDecoder(r.Body).Decode(&options); err != nil { http.Error(w, "Invalid request body", http.StatusBadRequest) return @@ -391,15 +391,24 @@ func (h *Handler) ProxyToInstance() http.HandlerFunc { } //
RemoteInstanceProxy proxies requests to a remote instance -func (h *Handler) RemoteInstanceProxy(w http.ResponseWriter, r *http.Request, name string, inst *instance.Process) { +func (h *Handler) RemoteInstanceProxy(w http.ResponseWriter, r *http.Request, name string, inst *instance.Instance) { // Get the node name from instance options options := inst.GetOptions() - if options == nil || len(options.Nodes) == 0 { - http.Error(w, "Instance has no node configured", http.StatusInternalServerError) + if options == nil { + http.Error(w, "Instance has no options configured", http.StatusInternalServerError) return } - nodeName := options.Nodes[0] + // Get the first node from the set + var nodeName string + for node := range options.Nodes { + nodeName = node + break + } + if nodeName == "" { + http.Error(w, "Instance has no node configured", http.StatusInternalServerError) + return + } // Check if we have a cached proxy for this node h.remoteProxiesMu.RLock() diff --git a/pkg/server/handlers_openai.go b/pkg/server/handlers_openai.go index c6e56e9..075f651 100644 --- a/pkg/server/handlers_openai.go +++ b/pkg/server/handlers_openai.go @@ -152,15 +152,24 @@ func (h *Handler) OpenAIProxy() http.HandlerFunc { } // RemoteOpenAIProxy proxies OpenAI-compatible requests to a remote instance -func (h *Handler) RemoteOpenAIProxy(w http.ResponseWriter, r *http.Request, modelName string, inst *instance.Process) { +func (h *Handler) RemoteOpenAIProxy(w http.ResponseWriter, r *http.Request, modelName string, inst *instance.Instance) { // Get the node name from instance options options := inst.GetOptions() - if options == nil || len(options.Nodes) == 0 { - http.Error(w, "Instance has no node configured", http.StatusInternalServerError) + if options == nil { + http.Error(w, "Instance has no options configured", http.StatusInternalServerError) return } - nodeName := options.Nodes[0] + // Get the first node from the set + var nodeName string + for node := range options.Nodes { + nodeName = 
node + break + } + if nodeName == "" { + http.Error(w, "Instance has no node configured", http.StatusInternalServerError) + return + } // Check if we have a cached proxy for this node h.remoteProxiesMu.RLock() diff --git a/pkg/validation/validation.go b/pkg/validation/validation.go index 638e5d2..6d6638d 100644 --- a/pkg/validation/validation.go +++ b/pkg/validation/validation.go @@ -35,7 +35,7 @@ func validateStringForInjection(value string) error { } // ValidateInstanceOptions performs validation based on backend type -func ValidateInstanceOptions(options *instance.CreateInstanceOptions) error { +func ValidateInstanceOptions(options *instance.Options) error { if options == nil { return ValidationError(fmt.Errorf("options cannot be nil")) } @@ -54,7 +54,7 @@ func ValidateInstanceOptions(options *instance.CreateInstanceOptions) error { } // validateLlamaCppOptions validates llama.cpp specific options -func validateLlamaCppOptions(options *instance.CreateInstanceOptions) error { +func validateLlamaCppOptions(options *instance.Options) error { if options.LlamaServerOptions == nil { return ValidationError(fmt.Errorf("llama server options cannot be nil for llama.cpp backend")) } @@ -73,7 +73,7 @@ func validateLlamaCppOptions(options *instance.CreateInstanceOptions) error { } // validateMlxOptions validates MLX backend specific options -func validateMlxOptions(options *instance.CreateInstanceOptions) error { +func validateMlxOptions(options *instance.Options) error { if options.MlxServerOptions == nil { return ValidationError(fmt.Errorf("MLX server options cannot be nil for MLX backend")) } @@ -91,7 +91,7 @@ func validateMlxOptions(options *instance.CreateInstanceOptions) error { } // validateVllmOptions validates vLLM backend specific options -func validateVllmOptions(options *instance.CreateInstanceOptions) error { +func validateVllmOptions(options *instance.Options) error { if options.VllmServerOptions == nil { return ValidationError(fmt.Errorf("vLLM server options 
cannot be nil for vLLM backend")) } diff --git a/pkg/validation/validation_test.go b/pkg/validation/validation_test.go index 8d8c49e..759ebc3 100644 --- a/pkg/validation/validation_test.go +++ b/pkg/validation/validation_test.go @@ -83,7 +83,7 @@ func TestValidateInstanceOptions_PortValidation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - options := &instance.CreateInstanceOptions{ + options := &instance.Options{ BackendType: backends.BackendTypeLlamaCpp, LlamaServerOptions: &llamacpp.LlamaServerOptions{ Port: tt.port, @@ -137,7 +137,7 @@ func TestValidateInstanceOptions_StringInjection(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Test with Model field (string field) - options := &instance.CreateInstanceOptions{ + options := &instance.Options{ BackendType: backends.BackendTypeLlamaCpp, LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: tt.value, @@ -175,7 +175,7 @@ func TestValidateInstanceOptions_ArrayInjection(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Test with Lora field (array field) - options := &instance.CreateInstanceOptions{ + options := &instance.Options{ BackendType: backends.BackendTypeLlamaCpp, LlamaServerOptions: &llamacpp.LlamaServerOptions{ Lora: tt.array, @@ -194,12 +194,12 @@ func TestValidateInstanceOptions_MultipleFieldInjection(t *testing.T) { // Test that injection in any field is caught tests := []struct { name string - options *instance.CreateInstanceOptions + options *instance.Options wantErr bool }{ { name: "injection in model field", - options: &instance.CreateInstanceOptions{ + options: &instance.Options{ BackendType: backends.BackendTypeLlamaCpp, LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "safe.gguf", @@ -210,7 +210,7 @@ func TestValidateInstanceOptions_MultipleFieldInjection(t *testing.T) { }, { name: "injection in log file", - options: &instance.CreateInstanceOptions{ + options: &instance.Options{ 
BackendType: backends.BackendTypeLlamaCpp, LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "safe.gguf", @@ -221,7 +221,7 @@ func TestValidateInstanceOptions_MultipleFieldInjection(t *testing.T) { }, { name: "all safe fields", - options: &instance.CreateInstanceOptions{ + options: &instance.Options{ BackendType: backends.BackendTypeLlamaCpp, LlamaServerOptions: &llamacpp.LlamaServerOptions{ Model: "/path/to/model.gguf", @@ -247,7 +247,7 @@ func TestValidateInstanceOptions_MultipleFieldInjection(t *testing.T) { func TestValidateInstanceOptions_NonStringFields(t *testing.T) { // Test that non-string fields don't interfere with validation - options := &instance.CreateInstanceOptions{ + options := &instance.Options{ AutoRestart: testutil.BoolPtr(true), MaxRestarts: testutil.IntPtr(5), RestartDelay: testutil.IntPtr(10),