Refactor backend options handling and validation

This commit is contained in:
2025-10-19 17:41:08 +02:00
parent 2a7010d0e1
commit 55f671c354
10 changed files with 480 additions and 425 deletions

View File

@@ -1,5 +1,13 @@
package backends package backends
import (
"encoding/json"
"fmt"
"llamactl/pkg/config"
"llamactl/pkg/validation"
"maps"
)
type BackendType string type BackendType string
const ( const (
@@ -13,10 +21,301 @@ type Options struct {
BackendType BackendType `json:"backend_type"` BackendType BackendType `json:"backend_type"`
BackendOptions map[string]any `json:"backend_options,omitempty"` BackendOptions map[string]any `json:"backend_options,omitempty"`
Nodes map[string]struct{} `json:"-"`
// Backend-specific options // Backend-specific options
LlamaServerOptions *LlamaServerOptions `json:"-"` LlamaServerOptions *LlamaServerOptions `json:"-"`
MlxServerOptions *MlxServerOptions `json:"-"` MlxServerOptions *MlxServerOptions `json:"-"`
VllmServerOptions *VllmServerOptions `json:"-"` VllmServerOptions *VllmServerOptions `json:"-"`
} }
func (o *Options) UnmarshalJSON(data []byte) error {
// Use anonymous struct to avoid recursion
type Alias Options
aux := &struct {
*Alias
}{
Alias: (*Alias)(o),
}
if err := json.Unmarshal(data, aux); err != nil {
return err
}
// Parse backend-specific options
switch o.BackendType {
case BackendTypeLlamaCpp:
if o.BackendOptions != nil {
// Convert map to JSON and then unmarshal to LlamaServerOptions
optionsData, err := json.Marshal(o.BackendOptions)
if err != nil {
return fmt.Errorf("failed to marshal backend options: %w", err)
}
o.LlamaServerOptions = &LlamaServerOptions{}
if err := json.Unmarshal(optionsData, o.LlamaServerOptions); err != nil {
return fmt.Errorf("failed to unmarshal llama.cpp options: %w", err)
}
}
case BackendTypeMlxLm:
if o.BackendOptions != nil {
optionsData, err := json.Marshal(o.BackendOptions)
if err != nil {
return fmt.Errorf("failed to marshal backend options: %w", err)
}
o.MlxServerOptions = &MlxServerOptions{}
if err := json.Unmarshal(optionsData, o.MlxServerOptions); err != nil {
return fmt.Errorf("failed to unmarshal MLX options: %w", err)
}
}
case BackendTypeVllm:
if o.BackendOptions != nil {
optionsData, err := json.Marshal(o.BackendOptions)
if err != nil {
return fmt.Errorf("failed to marshal backend options: %w", err)
}
o.VllmServerOptions = &VllmServerOptions{}
if err := json.Unmarshal(optionsData, o.VllmServerOptions); err != nil {
return fmt.Errorf("failed to unmarshal vLLM options: %w", err)
}
}
}
return nil
}
func (o *Options) MarshalJSON() ([]byte, error) {
// Use anonymous struct to avoid recursion
type Alias Options
aux := &struct {
*Alias
}{
Alias: (*Alias)(o),
}
// Prepare BackendOptions map
if o.BackendOptions == nil {
o.BackendOptions = make(map[string]any)
}
// Populate BackendOptions based on backend-specific options
switch o.BackendType {
case BackendTypeLlamaCpp:
if o.LlamaServerOptions != nil {
optionsData, err := json.Marshal(o.LlamaServerOptions)
if err != nil {
return nil, fmt.Errorf("failed to marshal llama.cpp options: %w", err)
}
if err := json.Unmarshal(optionsData, &o.BackendOptions); err != nil {
return nil, fmt.Errorf("failed to unmarshal llama.cpp options to map: %w", err)
}
}
case BackendTypeMlxLm:
if o.MlxServerOptions != nil {
optionsData, err := json.Marshal(o.MlxServerOptions)
if err != nil {
return nil, fmt.Errorf("failed to marshal MLX options: %w", err)
}
if err := json.Unmarshal(optionsData, &o.BackendOptions); err != nil {
return nil, fmt.Errorf("failed to unmarshal MLX options to map: %w", err)
}
}
case BackendTypeVllm:
if o.VllmServerOptions != nil {
optionsData, err := json.Marshal(o.VllmServerOptions)
if err != nil {
return nil, fmt.Errorf("failed to marshal vLLM options: %w", err)
}
if err := json.Unmarshal(optionsData, &o.BackendOptions); err != nil {
return nil, fmt.Errorf("failed to unmarshal vLLM options to map: %w", err)
}
}
}
return json.Marshal(aux)
}
func getBackendSettings(o *Options, backendConfig *config.BackendConfig) *config.BackendSettings {
switch o.BackendType {
case BackendTypeLlamaCpp:
return &backendConfig.LlamaCpp
case BackendTypeMlxLm:
return &backendConfig.MLX
case BackendTypeVllm:
return &backendConfig.VLLM
default:
return nil
}
}
func (o *Options) isDockerEnabled(backend *config.BackendSettings) bool {
if backend.Docker != nil && backend.Docker.Enabled && o.BackendType != BackendTypeMlxLm {
return true
}
return false
}
func (o *Options) IsDockerEnabled(backendConfig *config.BackendConfig) bool {
backendSettings := getBackendSettings(o, backendConfig)
return o.isDockerEnabled(backendSettings)
}
// GetCommand builds the command to run the backend
func (o *Options) GetCommand(backendConfig *config.BackendConfig) string {
backendSettings := getBackendSettings(o, backendConfig)
if o.isDockerEnabled(backendSettings) {
return "docker"
}
return backendSettings.Command
}
// buildCommandArgs builds command line arguments for the backend
func (o *Options) BuildCommandArgs(backendConfig *config.BackendConfig) []string {
var args []string
backendSettings := getBackendSettings(o, backendConfig)
if o.isDockerEnabled(backendSettings) {
// For Docker, start with Docker args
args = append(args, backendSettings.Docker.Args...)
args = append(args, backendSettings.Docker.Image)
switch o.BackendType {
case BackendTypeLlamaCpp:
if o.LlamaServerOptions != nil {
args = append(args, o.LlamaServerOptions.BuildDockerArgs()...)
}
case BackendTypeVllm:
if o.VllmServerOptions != nil {
args = append(args, o.VllmServerOptions.BuildDockerArgs()...)
}
}
} else {
// For native execution, start with backend args
args = append(args, backendSettings.Args...)
switch o.BackendType {
case BackendTypeLlamaCpp:
if o.LlamaServerOptions != nil {
args = append(args, o.LlamaServerOptions.BuildCommandArgs()...)
}
case BackendTypeMlxLm:
if o.MlxServerOptions != nil {
args = append(args, o.MlxServerOptions.BuildCommandArgs()...)
}
case BackendTypeVllm:
if o.VllmServerOptions != nil {
args = append(args, o.VllmServerOptions.BuildCommandArgs()...)
}
}
}
return args
}
// BuildEnvironment builds the environment variables for the backend process
func (o *Options) BuildEnvironment(backendConfig *config.BackendConfig, environment map[string]string) map[string]string {
backendSettings := getBackendSettings(o, backendConfig)
env := map[string]string{}
if backendSettings.Environment != nil {
maps.Copy(env, backendSettings.Environment)
}
if o.isDockerEnabled(backendSettings) {
if backendSettings.Docker.Environment != nil {
maps.Copy(env, backendSettings.Docker.Environment)
}
}
if environment != nil {
maps.Copy(env, environment)
}
return env
}
func (o *Options) GetPort() int {
if o != nil {
switch o.BackendType {
case BackendTypeLlamaCpp:
if o.LlamaServerOptions != nil {
return o.LlamaServerOptions.Port
}
case BackendTypeMlxLm:
if o.MlxServerOptions != nil {
return o.MlxServerOptions.Port
}
case BackendTypeVllm:
if o.VllmServerOptions != nil {
return o.VllmServerOptions.Port
}
}
}
return 0
}
func (o *Options) SetPort(port int) {
if o != nil {
switch o.BackendType {
case BackendTypeLlamaCpp:
if o.LlamaServerOptions != nil {
o.LlamaServerOptions.Port = port
}
case BackendTypeMlxLm:
if o.MlxServerOptions != nil {
o.MlxServerOptions.Port = port
}
case BackendTypeVllm:
if o.VllmServerOptions != nil {
o.VllmServerOptions.Port = port
}
}
}
}
func (o *Options) GetHost() string {
if o != nil {
switch o.BackendType {
case BackendTypeLlamaCpp:
if o.LlamaServerOptions != nil {
return o.LlamaServerOptions.Host
}
case BackendTypeMlxLm:
if o.MlxServerOptions != nil {
return o.MlxServerOptions.Host
}
case BackendTypeVllm:
if o.VllmServerOptions != nil {
return o.VllmServerOptions.Host
}
}
}
return "localhost"
}
func (o *Options) GetResponseHeaders(backendConfig *config.BackendConfig) map[string]string {
backendSettings := getBackendSettings(o, backendConfig)
return backendSettings.ResponseHeaders
}
// ValidateInstanceOptions performs validation based on backend type
func (o *Options) ValidateInstanceOptions() error {
// Validate based on backend type
switch o.BackendType {
case BackendTypeLlamaCpp:
return validateLlamaCppOptions(o.LlamaServerOptions)
case BackendTypeMlxLm:
return validateMlxOptions(o.MlxServerOptions)
case BackendTypeVllm:
return validateVllmOptions(o.VllmServerOptions)
default:
return validation.ValidationError(fmt.Errorf("unsupported backend type: %s", o.BackendType))
}
}

View File

@@ -2,6 +2,8 @@ package backends
import ( import (
"encoding/json" "encoding/json"
"fmt"
"llamactl/pkg/validation"
"reflect" "reflect"
"strconv" "strconv"
) )
@@ -364,3 +366,22 @@ func ParseLlamaCommand(command string) (*LlamaServerOptions, error) {
return &llamaOptions, nil return &llamaOptions, nil
} }
// validateLlamaCppOptions validates llama.cpp specific options
func validateLlamaCppOptions(options *LlamaServerOptions) error {
if options == nil {
return validation.ValidationError(fmt.Errorf("llama server options cannot be nil for llama.cpp backend"))
}
// Use reflection to check all string fields for injection patterns
if err := validation.ValidateStructStrings(options, ""); err != nil {
return err
}
// Basic network validation for port
if options.Port < 0 || options.Port > 65535 {
return validation.ValidationError(fmt.Errorf("invalid port range: %d", options.Port))
}
return nil
}

View File

@@ -1,5 +1,10 @@
package backends package backends
import (
"fmt"
"llamactl/pkg/validation"
)
type MlxServerOptions struct { type MlxServerOptions struct {
// Basic connection options // Basic connection options
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
@@ -50,3 +55,21 @@ func ParseMlxCommand(command string) (*MlxServerOptions, error) {
return &mlxOptions, nil return &mlxOptions, nil
} }
// validateMlxOptions validates MLX backend specific options
func validateMlxOptions(options *MlxServerOptions) error {
if options == nil {
return validation.ValidationError(fmt.Errorf("MLX server options cannot be nil for MLX backend"))
}
if err := validation.ValidateStructStrings(options, ""); err != nil {
return err
}
// Basic network validation for port
if options.Port < 0 || options.Port > 65535 {
return validation.ValidationError(fmt.Errorf("invalid port range: %d", options.Port))
}
return nil
}

View File

@@ -1,5 +1,10 @@
package backends package backends
import (
"fmt"
"llamactl/pkg/validation"
)
// vllmMultiValuedFlags defines flags that should be repeated for each value rather than comma-separated // vllmMultiValuedFlags defines flags that should be repeated for each value rather than comma-separated
var vllmMultiValuedFlags = map[string]bool{ var vllmMultiValuedFlags = map[string]bool{
"api-key": true, "api-key": true,
@@ -194,3 +199,22 @@ func ParseVllmCommand(command string) (*VllmServerOptions, error) {
return &vllmOptions, nil return &vllmOptions, nil
} }
// validateVllmOptions validates vLLM backend specific options
func validateVllmOptions(options *VllmServerOptions) error {
if options == nil {
return validation.ValidationError(fmt.Errorf("vLLM server options cannot be nil for vLLM backend"))
}
// Use reflection to check all string fields for injection patterns
if err := validation.ValidateStructStrings(options, ""); err != nil {
return err
}
// Basic network validation for port
if options.Port < 0 || options.Port > 65535 {
return validation.ValidationError(fmt.Errorf("invalid port range: %d", options.Port))
}
return nil
}

View File

@@ -3,7 +3,6 @@ package instance
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"llamactl/pkg/backends"
"llamactl/pkg/config" "llamactl/pkg/config"
"log" "log"
"net/http/httputil" "net/http/httputil"
@@ -124,48 +123,6 @@ func (i *Instance) IsRunning() bool {
return i.status.isRunning() return i.status.isRunning()
} }
func (i *Instance) GetPort() int {
opts := i.GetOptions()
if opts != nil {
switch opts.BackendType {
case backends.BackendTypeLlamaCpp:
if opts.LlamaServerOptions != nil {
return opts.LlamaServerOptions.Port
}
case backends.BackendTypeMlxLm:
if opts.MlxServerOptions != nil {
return opts.MlxServerOptions.Port
}
case backends.BackendTypeVllm:
if opts.VllmServerOptions != nil {
return opts.VllmServerOptions.Port
}
}
}
return 0
}
func (i *Instance) GetHost() string {
opts := i.GetOptions()
if opts != nil {
switch opts.BackendType {
case backends.BackendTypeLlamaCpp:
if opts.LlamaServerOptions != nil {
return opts.LlamaServerOptions.Host
}
case backends.BackendTypeMlxLm:
if opts.MlxServerOptions != nil {
return opts.MlxServerOptions.Host
}
case backends.BackendTypeVllm:
if opts.VllmServerOptions != nil {
return opts.VllmServerOptions.Host
}
}
}
return ""
}
// SetOptions sets the options // SetOptions sets the options
func (i *Instance) SetOptions(opts *Options) { func (i *Instance) SetOptions(opts *Options) {
if opts == nil { if opts == nil {
@@ -198,6 +155,20 @@ func (i *Instance) SetTimeProvider(tp TimeProvider) {
} }
} }
func (i *Instance) GetHost() string {
if i.options == nil {
return "localhost"
}
return i.options.GetHost()
}
func (i *Instance) GetPort() int {
if i.options == nil {
return 0
}
return i.options.GetPort()
}
// GetProxy returns the reverse proxy for this instance // GetProxy returns the reverse proxy for this instance
func (i *Instance) GetProxy() (*httputil.ReverseProxy, error) { func (i *Instance) GetProxy() (*httputil.ReverseProxy, error) {
if i.proxy == nil { if i.proxy == nil {
@@ -266,39 +237,31 @@ func (i *Instance) ShouldTimeout() bool {
return i.proxy.shouldTimeout() return i.proxy.shouldTimeout()
} }
// getBackendHostPort extracts the host and port from instance options func (i *Instance) getCommand() string {
// Returns the configured host and port for the backend
func (i *Instance) getBackendHostPort() (string, int) {
opts := i.GetOptions() opts := i.GetOptions()
if opts == nil { if opts == nil {
return "localhost", 0 return ""
} }
var host string return opts.BackendOptions.GetCommand(i.globalBackendSettings)
var port int
switch opts.BackendType {
case backends.BackendTypeLlamaCpp:
if opts.LlamaServerOptions != nil {
host = opts.LlamaServerOptions.Host
port = opts.LlamaServerOptions.Port
}
case backends.BackendTypeMlxLm:
if opts.MlxServerOptions != nil {
host = opts.MlxServerOptions.Host
port = opts.MlxServerOptions.Port
}
case backends.BackendTypeVllm:
if opts.VllmServerOptions != nil {
host = opts.VllmServerOptions.Host
port = opts.VllmServerOptions.Port
}
} }
if host == "" { func (i *Instance) buildCommandArgs() []string {
host = "localhost" opts := i.GetOptions()
if opts == nil {
return nil
} }
return host, port return opts.BackendOptions.BuildCommandArgs(i.globalBackendSettings)
}
func (i *Instance) buildEnvironment() map[string]string {
opts := i.GetOptions()
if opts == nil {
return nil
}
return opts.BackendOptions.BuildEnvironment(i.globalBackendSettings, opts.Environment)
} }
// MarshalJSON implements json.Marshaler for Instance // MarshalJSON implements json.Marshaler for Instance
@@ -307,21 +270,7 @@ func (i *Instance) MarshalJSON() ([]byte, error) {
opts := i.GetOptions() opts := i.GetOptions()
// Determine if docker is enabled for this instance's backend // Determine if docker is enabled for this instance's backend
var dockerEnabled bool dockerEnabled := opts.BackendOptions.IsDockerEnabled(i.globalBackendSettings)
if opts != nil {
switch opts.BackendType {
case backends.BackendTypeLlamaCpp:
if i.globalBackendSettings != nil && i.globalBackendSettings.LlamaCpp.Docker != nil && i.globalBackendSettings.LlamaCpp.Docker.Enabled {
dockerEnabled = true
}
case backends.BackendTypeVllm:
if i.globalBackendSettings != nil && i.globalBackendSettings.VLLM.Docker != nil && i.globalBackendSettings.VLLM.Docker.Enabled {
dockerEnabled = true
}
case backends.BackendTypeMlxLm:
// MLX does not support docker currently
}
}
return json.Marshal(&struct { return json.Marshal(&struct {
Name string `json:"name"` Name string `json:"name"`

View File

@@ -6,7 +6,6 @@ import (
"llamactl/pkg/backends" "llamactl/pkg/backends"
"llamactl/pkg/config" "llamactl/pkg/config"
"log" "log"
"maps"
"slices" "slices"
"sync" "sync"
) )
@@ -23,16 +22,10 @@ type Options struct {
IdleTimeout *int `json:"idle_timeout,omitempty"` // minutes IdleTimeout *int `json:"idle_timeout,omitempty"` // minutes
// Environment variables // Environment variables
Environment map[string]string `json:"environment,omitempty"` Environment map[string]string `json:"environment,omitempty"`
// Assigned nodes
BackendType backends.BackendType `json:"backend_type"`
BackendOptions map[string]any `json:"backend_options,omitempty"`
Nodes map[string]struct{} `json:"-"` Nodes map[string]struct{} `json:"-"`
// Backend options
// Backend-specific options BackendOptions backends.Options `json:"-"`
LlamaServerOptions *backends.LlamaServerOptions `json:"-"`
MlxServerOptions *backends.MlxServerOptions `json:"-"`
VllmServerOptions *backends.VllmServerOptions `json:"-"`
} }
// options wraps Options with thread-safe access (unexported). // options wraps Options with thread-safe access (unexported).
@@ -62,6 +55,18 @@ func (o *options) set(opts *Options) {
o.opts = opts o.opts = opts
} }
func (o *options) GetHost() string {
o.mu.RLock()
defer o.mu.RUnlock()
return o.opts.BackendOptions.GetHost()
}
func (o *options) GetPort() int {
o.mu.RLock()
defer o.mu.RUnlock()
return o.opts.BackendOptions.GetPort()
}
// MarshalJSON implements json.Marshaler for options wrapper // MarshalJSON implements json.Marshaler for options wrapper
func (o *options) MarshalJSON() ([]byte, error) { func (o *options) MarshalJSON() ([]byte, error) {
o.mu.RLock() o.mu.RLock()
@@ -85,7 +90,9 @@ func (c *Options) UnmarshalJSON(data []byte) error {
// Use anonymous struct to avoid recursion // Use anonymous struct to avoid recursion
type Alias Options type Alias Options
aux := &struct { aux := &struct {
Nodes []string `json:"nodes,omitempty"` // Accept JSON array Nodes []string `json:"nodes,omitempty"`
BackendType backends.BackendType `json:"backend_type"`
BackendOptions map[string]any `json:"backend_options,omitempty"`
*Alias *Alias
}{ }{
Alias: (*Alias)(c), Alias: (*Alias)(c),
@@ -103,47 +110,27 @@ func (c *Options) UnmarshalJSON(data []byte) error {
} }
} }
// Parse backend-specific options // Create backend options struct and unmarshal
switch c.BackendType { c.BackendOptions = backends.Options{
case backends.BackendTypeLlamaCpp: BackendType: aux.BackendType,
if c.BackendOptions != nil { BackendOptions: aux.BackendOptions,
// Convert map to JSON and then unmarshal to LlamaServerOptions }
optionsData, err := json.Marshal(c.BackendOptions)
// Marshal the backend options to JSON for proper unmarshaling
backendJson, err := json.Marshal(struct {
BackendType backends.BackendType `json:"backend_type"`
BackendOptions map[string]any `json:"backend_options,omitempty"`
}{
BackendType: aux.BackendType,
BackendOptions: aux.BackendOptions,
})
if err != nil { if err != nil {
return fmt.Errorf("failed to marshal backend options: %w", err) return fmt.Errorf("failed to marshal backend options: %w", err)
} }
c.LlamaServerOptions = &backends.LlamaServerOptions{} // Unmarshal into the backends.Options struct to trigger its custom unmarshaling
if err := json.Unmarshal(optionsData, c.LlamaServerOptions); err != nil { if err := json.Unmarshal(backendJson, &c.BackendOptions); err != nil {
return fmt.Errorf("failed to unmarshal llama.cpp options: %w", err) return fmt.Errorf("failed to unmarshal backend options: %w", err)
}
}
case backends.BackendTypeMlxLm:
if c.BackendOptions != nil {
optionsData, err := json.Marshal(c.BackendOptions)
if err != nil {
return fmt.Errorf("failed to marshal backend options: %w", err)
}
c.MlxServerOptions = &backends.MlxServerOptions{}
if err := json.Unmarshal(optionsData, c.MlxServerOptions); err != nil {
return fmt.Errorf("failed to unmarshal MLX options: %w", err)
}
}
case backends.BackendTypeVllm:
if c.BackendOptions != nil {
optionsData, err := json.Marshal(c.BackendOptions)
if err != nil {
return fmt.Errorf("failed to marshal backend options: %w", err)
}
c.VllmServerOptions = &backends.VllmServerOptions{}
if err := json.Unmarshal(optionsData, c.VllmServerOptions); err != nil {
return fmt.Errorf("failed to unmarshal vLLM options: %w", err)
}
}
default:
return fmt.Errorf("unknown backend type: %s", c.BackendType)
} }
return nil return nil
@@ -155,6 +142,8 @@ func (c *Options) MarshalJSON() ([]byte, error) {
type Alias Options type Alias Options
aux := struct { aux := struct {
Nodes []string `json:"nodes,omitempty"` // Output as JSON array Nodes []string `json:"nodes,omitempty"` // Output as JSON array
BackendType backends.BackendType `json:"backend_type"`
BackendOptions map[string]any `json:"backend_options,omitempty"`
*Alias *Alias
}{ }{
Alias: (*Alias)(c), Alias: (*Alias)(c),
@@ -170,51 +159,24 @@ func (c *Options) MarshalJSON() ([]byte, error) {
slices.Sort(aux.Nodes) slices.Sort(aux.Nodes)
} }
// Convert backend-specific options back to BackendOptions map for JSON // Set backend type
switch c.BackendType { aux.BackendType = c.BackendOptions.BackendType
case backends.BackendTypeLlamaCpp:
if c.LlamaServerOptions != nil { // Marshal the backends.Options struct to get the properly formatted backend options
data, err := json.Marshal(c.LlamaServerOptions) backendData, err := json.Marshal(c.BackendOptions)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to marshal llama server options: %w", err) return nil, fmt.Errorf("failed to marshal backend options: %w", err)
} }
var backendOpts map[string]any // Unmarshal into a temporary struct to extract the backend_options map
if err := json.Unmarshal(data, &backendOpts); err != nil { var tempBackend struct {
return nil, fmt.Errorf("failed to unmarshal to map: %w", err) BackendOptions map[string]any `json:"backend_options,omitempty"`
}
if err := json.Unmarshal(backendData, &tempBackend); err != nil {
return nil, fmt.Errorf("failed to unmarshal backend data: %w", err)
} }
aux.BackendOptions = backendOpts aux.BackendOptions = tempBackend.BackendOptions
}
case backends.BackendTypeMlxLm:
if c.MlxServerOptions != nil {
data, err := json.Marshal(c.MlxServerOptions)
if err != nil {
return nil, fmt.Errorf("failed to marshal MLX server options: %w", err)
}
var backendOpts map[string]any
if err := json.Unmarshal(data, &backendOpts); err != nil {
return nil, fmt.Errorf("failed to unmarshal to map: %w", err)
}
aux.BackendOptions = backendOpts
}
case backends.BackendTypeVllm:
if c.VllmServerOptions != nil {
data, err := json.Marshal(c.VllmServerOptions)
if err != nil {
return nil, fmt.Errorf("failed to marshal vLLM server options: %w", err)
}
var backendOpts map[string]any
if err := json.Unmarshal(data, &backendOpts); err != nil {
return nil, fmt.Errorf("failed to unmarshal to map: %w", err)
}
aux.BackendOptions = backendOpts
}
}
return json.Marshal(aux) return json.Marshal(aux)
} }
@@ -257,78 +219,3 @@ func (c *Options) validateAndApplyDefaults(name string, globalSettings *config.I
} }
} }
} }
// getCommand builds the command to run the backend
func (c *Options) getCommand(backendConfig *config.BackendSettings) string {
if backendConfig.Docker != nil && backendConfig.Docker.Enabled && c.BackendType != backends.BackendTypeMlxLm {
return "docker"
}
return backendConfig.Command
}
// buildCommandArgs builds command line arguments for the backend
func (c *Options) buildCommandArgs(backendConfig *config.BackendSettings) []string {
var args []string
if backendConfig.Docker != nil && backendConfig.Docker.Enabled && c.BackendType != backends.BackendTypeMlxLm {
// For Docker, start with Docker args
args = append(args, backendConfig.Docker.Args...)
args = append(args, backendConfig.Docker.Image)
switch c.BackendType {
case backends.BackendTypeLlamaCpp:
if c.LlamaServerOptions != nil {
args = append(args, c.LlamaServerOptions.BuildDockerArgs()...)
}
case backends.BackendTypeVllm:
if c.VllmServerOptions != nil {
args = append(args, c.VllmServerOptions.BuildDockerArgs()...)
}
}
} else {
// For native execution, start with backend args
args = append(args, backendConfig.Args...)
switch c.BackendType {
case backends.BackendTypeLlamaCpp:
if c.LlamaServerOptions != nil {
args = append(args, c.LlamaServerOptions.BuildCommandArgs()...)
}
case backends.BackendTypeMlxLm:
if c.MlxServerOptions != nil {
args = append(args, c.MlxServerOptions.BuildCommandArgs()...)
}
case backends.BackendTypeVllm:
if c.VllmServerOptions != nil {
args = append(args, c.VllmServerOptions.BuildCommandArgs()...)
}
}
}
return args
}
// buildEnvironment builds the environment variables for the backend process
func (c *Options) buildEnvironment(backendConfig *config.BackendSettings) map[string]string {
env := map[string]string{}
if backendConfig.Environment != nil {
maps.Copy(env, backendConfig.Environment)
}
if backendConfig.Docker != nil && backendConfig.Docker.Enabled && c.BackendType != backends.BackendTypeMlxLm {
if backendConfig.Docker.Environment != nil {
maps.Copy(env, backendConfig.Docker.Environment)
}
}
if c.Environment != nil {
maps.Copy(env, c.Environment)
}
return env
}

View File

@@ -12,9 +12,6 @@ import (
"sync" "sync"
"syscall" "syscall"
"time" "time"
"llamactl/pkg/backends"
"llamactl/pkg/config"
) )
// process manages the OS process lifecycle for a local instance. // process manages the OS process lifecycle for a local instance.
@@ -216,7 +213,8 @@ func (p *process) waitForHealthy(timeout int) error {
defer cancel() defer cancel()
// Get host/port from instance // Get host/port from instance
host, port := p.instance.getBackendHostPort() host := p.instance.options.GetHost()
port := p.instance.options.GetPort()
healthURL := fmt.Sprintf("http://%s:%d/health", host, port) healthURL := fmt.Sprintf("http://%s:%d/health", host, port)
// Create a dedicated HTTP client for health checks // Create a dedicated HTTP client for health checks
@@ -386,26 +384,15 @@ func (p *process) handleAutoRestart(err error) {
// buildCommand builds the command to execute using backend-specific logic // buildCommand builds the command to execute using backend-specific logic
func (p *process) buildCommand() (*exec.Cmd, error) { func (p *process) buildCommand() (*exec.Cmd, error) {
// Get options
opts := p.instance.GetOptions()
if opts == nil {
return nil, fmt.Errorf("instance options are nil")
}
// Get backend configuration
backendConfig, err := p.getBackendConfig()
if err != nil {
return nil, err
}
// Build the environment variables // Build the environment variables
env := opts.buildEnvironment(backendConfig) env := p.instance.buildEnvironment()
// Get the command to execute // Get the command to execute
command := opts.getCommand(backendConfig) command := p.instance.getCommand()
// Build command arguments // Build command arguments
args := opts.buildCommandArgs(backendConfig) args := p.instance.buildCommandArgs()
// Create the exec.Cmd // Create the exec.Cmd
cmd := exec.CommandContext(p.ctx, command, args...) cmd := exec.CommandContext(p.ctx, command, args...)
@@ -420,27 +407,3 @@ func (p *process) buildCommand() (*exec.Cmd, error) {
return cmd, nil return cmd, nil
} }
// getBackendConfig resolves the backend configuration for the current instance
func (p *process) getBackendConfig() (*config.BackendSettings, error) {
opts := p.instance.GetOptions()
if opts == nil {
return nil, fmt.Errorf("instance options are nil")
}
var backendTypeStr string
switch opts.BackendType {
case backends.BackendTypeLlamaCpp:
backendTypeStr = "llama-cpp"
case backends.BackendTypeMlxLm:
backendTypeStr = "mlx"
case backends.BackendTypeVllm:
backendTypeStr = "vllm"
default:
return nil, fmt.Errorf("unsupported backend type: %s", opts.BackendType)
}
settings := p.instance.globalBackendSettings.GetBackendSettings(backendTypeStr)
return &settings, nil
}

View File

@@ -2,7 +2,6 @@ package instance
import ( import (
"fmt" "fmt"
"llamactl/pkg/backends"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"net/url" "net/url"
@@ -68,8 +67,11 @@ func (p *proxy) build() (*httputil.ReverseProxy, error) {
} }
// Get host/port from process // Get host/port from process
host, port := p.instance.getBackendHostPort() host := p.instance.options.GetHost()
port := p.instance.options.GetPort()
if port == 0 {
return nil, fmt.Errorf("instance %s has no port assigned", p.instance.Name)
}
targetURL, err := url.Parse(fmt.Sprintf("http://%s:%d", host, port)) targetURL, err := url.Parse(fmt.Sprintf("http://%s:%d", host, port))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse target URL for instance %s: %w", p.instance.Name, err) return nil, fmt.Errorf("failed to parse target URL for instance %s: %w", p.instance.Name, err)
@@ -78,15 +80,7 @@ func (p *proxy) build() (*httputil.ReverseProxy, error) {
proxy := httputil.NewSingleHostReverseProxy(targetURL) proxy := httputil.NewSingleHostReverseProxy(targetURL)
// Get response headers from backend config // Get response headers from backend config
var responseHeaders map[string]string responseHeaders := options.BackendOptions.GetResponseHeaders(p.instance.globalBackendSettings)
switch options.BackendType {
case backends.BackendTypeLlamaCpp:
responseHeaders = p.instance.globalBackendSettings.LlamaCpp.ResponseHeaders
case backends.BackendTypeVllm:
responseHeaders = p.instance.globalBackendSettings.VLLM.ResponseHeaders
case backends.BackendTypeMlxLm:
responseHeaders = p.instance.globalBackendSettings.MLX.ResponseHeaders
}
proxy.ModifyResponse = func(resp *http.Response) error { proxy.ModifyResponse = func(resp *http.Response) error {
// Remove CORS headers from backend response to avoid conflicts // Remove CORS headers from backend response to avoid conflicts

View File

@@ -2,7 +2,6 @@ package manager
import ( import (
"fmt" "fmt"
"llamactl/pkg/backends"
"llamactl/pkg/instance" "llamactl/pkg/instance"
"llamactl/pkg/validation" "llamactl/pkg/validation"
"os" "os"
@@ -86,7 +85,7 @@ func (im *instanceManager) CreateInstance(name string, options *instance.Options
return nil, err return nil, err
} }
err = validation.ValidateInstanceOptions(options) err = options.BackendOptions.ValidateInstanceOptions()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -232,7 +231,7 @@ func (im *instanceManager) UpdateInstance(name string, options *instance.Options
return nil, fmt.Errorf("instance options cannot be nil") return nil, fmt.Errorf("instance options cannot be nil")
} }
err := validation.ValidateInstanceOptions(options) err := options.BackendOptions.ValidateInstanceOptions()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -493,39 +492,12 @@ func (im *instanceManager) GetInstanceLogs(name string, numLines int) (string, e
// getPortFromOptions extracts the port from backend-specific options // getPortFromOptions extracts the port from backend-specific options
func (im *instanceManager) getPortFromOptions(options *instance.Options) int { func (im *instanceManager) getPortFromOptions(options *instance.Options) int {
switch options.BackendType { return options.BackendOptions.GetPort()
case backends.BackendTypeLlamaCpp:
if options.LlamaServerOptions != nil {
return options.LlamaServerOptions.Port
}
case backends.BackendTypeMlxLm:
if options.MlxServerOptions != nil {
return options.MlxServerOptions.Port
}
case backends.BackendTypeVllm:
if options.VllmServerOptions != nil {
return options.VllmServerOptions.Port
}
}
return 0
} }
// setPortInOptions sets the port in backend-specific options // setPortInOptions sets the port in backend-specific options
func (im *instanceManager) setPortInOptions(options *instance.Options, port int) { func (im *instanceManager) setPortInOptions(options *instance.Options, port int) {
switch options.BackendType { options.BackendOptions.SetPort(port)
case backends.BackendTypeLlamaCpp:
if options.LlamaServerOptions != nil {
options.LlamaServerOptions.Port = port
}
case backends.BackendTypeMlxLm:
if options.MlxServerOptions != nil {
options.MlxServerOptions.Port = port
}
case backends.BackendTypeVllm:
if options.VllmServerOptions != nil {
options.VllmServerOptions.Port = port
}
}
} }
// assignAndValidatePort assigns a port if not specified and validates it's not in use // assignAndValidatePort assigns a port if not specified and validates it's not in use

View File

@@ -2,8 +2,6 @@ package validation
import ( import (
"fmt" "fmt"
"llamactl/pkg/backends"
"llamactl/pkg/instance"
"reflect" "reflect"
"regexp" "regexp"
) )
@@ -24,8 +22,8 @@ var (
type ValidationError error type ValidationError error
// validateStringForInjection checks if a string contains dangerous patterns // ValidateStringForInjection checks if a string contains dangerous patterns
func validateStringForInjection(value string) error { func ValidateStringForInjection(value string) error {
for _, pattern := range dangerousPatterns { for _, pattern := range dangerousPatterns {
if pattern.MatchString(value) { if pattern.MatchString(value) {
return ValidationError(fmt.Errorf("value contains potentially dangerous characters: %s", value)) return ValidationError(fmt.Errorf("value contains potentially dangerous characters: %s", value))
@@ -34,83 +32,8 @@ func validateStringForInjection(value string) error {
return nil return nil
} }
// ValidateInstanceOptions performs validation based on backend type // ValidateStructStrings recursively validates all string fields in a struct
func ValidateInstanceOptions(options *instance.Options) error { func ValidateStructStrings(v any, fieldPath string) error {
if options == nil {
return ValidationError(fmt.Errorf("options cannot be nil"))
}
// Validate based on backend type
switch options.BackendType {
case backends.BackendTypeLlamaCpp:
return validateLlamaCppOptions(options)
case backends.BackendTypeMlxLm:
return validateMlxOptions(options)
case backends.BackendTypeVllm:
return validateVllmOptions(options)
default:
return ValidationError(fmt.Errorf("unsupported backend type: %s", options.BackendType))
}
}
// validateLlamaCppOptions validates llama.cpp specific options
func validateLlamaCppOptions(options *instance.Options) error {
if options.LlamaServerOptions == nil {
return ValidationError(fmt.Errorf("llama server options cannot be nil for llama.cpp backend"))
}
// Use reflection to check all string fields for injection patterns
if err := validateStructStrings(options.LlamaServerOptions, ""); err != nil {
return err
}
// Basic network validation for port
if options.LlamaServerOptions.Port < 0 || options.LlamaServerOptions.Port > 65535 {
return ValidationError(fmt.Errorf("invalid port range: %d", options.LlamaServerOptions.Port))
}
return nil
}
// validateMlxOptions validates MLX backend specific options
func validateMlxOptions(options *instance.Options) error {
if options.MlxServerOptions == nil {
return ValidationError(fmt.Errorf("MLX server options cannot be nil for MLX backend"))
}
if err := validateStructStrings(options.MlxServerOptions, ""); err != nil {
return err
}
// Basic network validation for port
if options.MlxServerOptions.Port < 0 || options.MlxServerOptions.Port > 65535 {
return ValidationError(fmt.Errorf("invalid port range: %d", options.MlxServerOptions.Port))
}
return nil
}
// validateVllmOptions validates vLLM backend specific options
func validateVllmOptions(options *instance.Options) error {
if options.VllmServerOptions == nil {
return ValidationError(fmt.Errorf("vLLM server options cannot be nil for vLLM backend"))
}
// Use reflection to check all string fields for injection patterns
if err := validateStructStrings(options.VllmServerOptions, ""); err != nil {
return err
}
// Basic network validation for port
if options.VllmServerOptions.Port < 0 || options.VllmServerOptions.Port > 65535 {
return ValidationError(fmt.Errorf("invalid port range: %d", options.VllmServerOptions.Port))
}
return nil
}
// validateStructStrings recursively validates all string fields in a struct
func validateStructStrings(v any, fieldPath string) error {
val := reflect.ValueOf(v) val := reflect.ValueOf(v)
if val.Kind() == reflect.Ptr { if val.Kind() == reflect.Ptr {
val = val.Elem() val = val.Elem()
@@ -136,21 +59,21 @@ func validateStructStrings(v any, fieldPath string) error {
switch field.Kind() { switch field.Kind() {
case reflect.String: case reflect.String:
if err := validateStringForInjection(field.String()); err != nil { if err := ValidateStringForInjection(field.String()); err != nil {
return ValidationError(fmt.Errorf("field %s: %w", fieldName, err)) return ValidationError(fmt.Errorf("field %s: %w", fieldName, err))
} }
case reflect.Slice: case reflect.Slice:
if field.Type().Elem().Kind() == reflect.String { if field.Type().Elem().Kind() == reflect.String {
for j := 0; j < field.Len(); j++ { for j := 0; j < field.Len(); j++ {
if err := validateStringForInjection(field.Index(j).String()); err != nil { if err := ValidateStringForInjection(field.Index(j).String()); err != nil {
return ValidationError(fmt.Errorf("field %s[%d]: %w", fieldName, j, err)) return ValidationError(fmt.Errorf("field %s[%d]: %w", fieldName, j, err))
} }
} }
} }
case reflect.Struct: case reflect.Struct:
if err := validateStructStrings(field.Interface(), fieldName); err != nil { if err := ValidateStructStrings(field.Interface(), fieldName); err != nil {
return err return err
} }
} }