mirror of
https://github.com/lordmathis/llamactl.git
synced 2025-11-05 16:44:22 +00:00
Refactor backend options handling and validation
This commit is contained in:
@@ -1,5 +1,13 @@
|
|||||||
package backends
|
package backends
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"llamactl/pkg/config"
|
||||||
|
"llamactl/pkg/validation"
|
||||||
|
"maps"
|
||||||
|
)
|
||||||
|
|
||||||
type BackendType string
|
type BackendType string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -13,10 +21,301 @@ type Options struct {
|
|||||||
BackendType BackendType `json:"backend_type"`
|
BackendType BackendType `json:"backend_type"`
|
||||||
BackendOptions map[string]any `json:"backend_options,omitempty"`
|
BackendOptions map[string]any `json:"backend_options,omitempty"`
|
||||||
|
|
||||||
Nodes map[string]struct{} `json:"-"`
|
|
||||||
|
|
||||||
// Backend-specific options
|
// Backend-specific options
|
||||||
LlamaServerOptions *LlamaServerOptions `json:"-"`
|
LlamaServerOptions *LlamaServerOptions `json:"-"`
|
||||||
MlxServerOptions *MlxServerOptions `json:"-"`
|
MlxServerOptions *MlxServerOptions `json:"-"`
|
||||||
VllmServerOptions *VllmServerOptions `json:"-"`
|
VllmServerOptions *VllmServerOptions `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (o *Options) UnmarshalJSON(data []byte) error {
|
||||||
|
// Use anonymous struct to avoid recursion
|
||||||
|
type Alias Options
|
||||||
|
aux := &struct {
|
||||||
|
*Alias
|
||||||
|
}{
|
||||||
|
Alias: (*Alias)(o),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(data, aux); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse backend-specific options
|
||||||
|
switch o.BackendType {
|
||||||
|
case BackendTypeLlamaCpp:
|
||||||
|
if o.BackendOptions != nil {
|
||||||
|
// Convert map to JSON and then unmarshal to LlamaServerOptions
|
||||||
|
optionsData, err := json.Marshal(o.BackendOptions)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal backend options: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
o.LlamaServerOptions = &LlamaServerOptions{}
|
||||||
|
if err := json.Unmarshal(optionsData, o.LlamaServerOptions); err != nil {
|
||||||
|
return fmt.Errorf("failed to unmarshal llama.cpp options: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case BackendTypeMlxLm:
|
||||||
|
if o.BackendOptions != nil {
|
||||||
|
optionsData, err := json.Marshal(o.BackendOptions)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal backend options: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
o.MlxServerOptions = &MlxServerOptions{}
|
||||||
|
if err := json.Unmarshal(optionsData, o.MlxServerOptions); err != nil {
|
||||||
|
return fmt.Errorf("failed to unmarshal MLX options: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case BackendTypeVllm:
|
||||||
|
if o.BackendOptions != nil {
|
||||||
|
optionsData, err := json.Marshal(o.BackendOptions)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal backend options: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
o.VllmServerOptions = &VllmServerOptions{}
|
||||||
|
if err := json.Unmarshal(optionsData, o.VllmServerOptions); err != nil {
|
||||||
|
return fmt.Errorf("failed to unmarshal vLLM options: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Options) MarshalJSON() ([]byte, error) {
|
||||||
|
// Use anonymous struct to avoid recursion
|
||||||
|
type Alias Options
|
||||||
|
aux := &struct {
|
||||||
|
*Alias
|
||||||
|
}{
|
||||||
|
Alias: (*Alias)(o),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare BackendOptions map
|
||||||
|
if o.BackendOptions == nil {
|
||||||
|
o.BackendOptions = make(map[string]any)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Populate BackendOptions based on backend-specific options
|
||||||
|
switch o.BackendType {
|
||||||
|
case BackendTypeLlamaCpp:
|
||||||
|
if o.LlamaServerOptions != nil {
|
||||||
|
optionsData, err := json.Marshal(o.LlamaServerOptions)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal llama.cpp options: %w", err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(optionsData, &o.BackendOptions); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal llama.cpp options to map: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case BackendTypeMlxLm:
|
||||||
|
if o.MlxServerOptions != nil {
|
||||||
|
optionsData, err := json.Marshal(o.MlxServerOptions)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal MLX options: %w", err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(optionsData, &o.BackendOptions); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal MLX options to map: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case BackendTypeVllm:
|
||||||
|
if o.VllmServerOptions != nil {
|
||||||
|
optionsData, err := json.Marshal(o.VllmServerOptions)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal vLLM options: %w", err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(optionsData, &o.BackendOptions); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal vLLM options to map: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Marshal(aux)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getBackendSettings(o *Options, backendConfig *config.BackendConfig) *config.BackendSettings {
|
||||||
|
switch o.BackendType {
|
||||||
|
case BackendTypeLlamaCpp:
|
||||||
|
return &backendConfig.LlamaCpp
|
||||||
|
case BackendTypeMlxLm:
|
||||||
|
return &backendConfig.MLX
|
||||||
|
case BackendTypeVllm:
|
||||||
|
return &backendConfig.VLLM
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Options) isDockerEnabled(backend *config.BackendSettings) bool {
|
||||||
|
if backend.Docker != nil && backend.Docker.Enabled && o.BackendType != BackendTypeMlxLm {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Options) IsDockerEnabled(backendConfig *config.BackendConfig) bool {
|
||||||
|
backendSettings := getBackendSettings(o, backendConfig)
|
||||||
|
return o.isDockerEnabled(backendSettings)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCommand builds the command to run the backend
|
||||||
|
func (o *Options) GetCommand(backendConfig *config.BackendConfig) string {
|
||||||
|
|
||||||
|
backendSettings := getBackendSettings(o, backendConfig)
|
||||||
|
|
||||||
|
if o.isDockerEnabled(backendSettings) {
|
||||||
|
return "docker"
|
||||||
|
}
|
||||||
|
|
||||||
|
return backendSettings.Command
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildCommandArgs builds command line arguments for the backend
|
||||||
|
func (o *Options) BuildCommandArgs(backendConfig *config.BackendConfig) []string {
|
||||||
|
|
||||||
|
var args []string
|
||||||
|
|
||||||
|
backendSettings := getBackendSettings(o, backendConfig)
|
||||||
|
|
||||||
|
if o.isDockerEnabled(backendSettings) {
|
||||||
|
// For Docker, start with Docker args
|
||||||
|
args = append(args, backendSettings.Docker.Args...)
|
||||||
|
args = append(args, backendSettings.Docker.Image)
|
||||||
|
|
||||||
|
switch o.BackendType {
|
||||||
|
case BackendTypeLlamaCpp:
|
||||||
|
if o.LlamaServerOptions != nil {
|
||||||
|
args = append(args, o.LlamaServerOptions.BuildDockerArgs()...)
|
||||||
|
}
|
||||||
|
case BackendTypeVllm:
|
||||||
|
if o.VllmServerOptions != nil {
|
||||||
|
args = append(args, o.VllmServerOptions.BuildDockerArgs()...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// For native execution, start with backend args
|
||||||
|
args = append(args, backendSettings.Args...)
|
||||||
|
|
||||||
|
switch o.BackendType {
|
||||||
|
case BackendTypeLlamaCpp:
|
||||||
|
if o.LlamaServerOptions != nil {
|
||||||
|
args = append(args, o.LlamaServerOptions.BuildCommandArgs()...)
|
||||||
|
}
|
||||||
|
case BackendTypeMlxLm:
|
||||||
|
if o.MlxServerOptions != nil {
|
||||||
|
args = append(args, o.MlxServerOptions.BuildCommandArgs()...)
|
||||||
|
}
|
||||||
|
case BackendTypeVllm:
|
||||||
|
if o.VllmServerOptions != nil {
|
||||||
|
args = append(args, o.VllmServerOptions.BuildCommandArgs()...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildEnvironment builds the environment variables for the backend process
|
||||||
|
func (o *Options) BuildEnvironment(backendConfig *config.BackendConfig, environment map[string]string) map[string]string {
|
||||||
|
|
||||||
|
backendSettings := getBackendSettings(o, backendConfig)
|
||||||
|
env := map[string]string{}
|
||||||
|
|
||||||
|
if backendSettings.Environment != nil {
|
||||||
|
maps.Copy(env, backendSettings.Environment)
|
||||||
|
}
|
||||||
|
|
||||||
|
if o.isDockerEnabled(backendSettings) {
|
||||||
|
if backendSettings.Docker.Environment != nil {
|
||||||
|
maps.Copy(env, backendSettings.Docker.Environment)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if environment != nil {
|
||||||
|
maps.Copy(env, environment)
|
||||||
|
}
|
||||||
|
|
||||||
|
return env
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Options) GetPort() int {
|
||||||
|
if o != nil {
|
||||||
|
switch o.BackendType {
|
||||||
|
case BackendTypeLlamaCpp:
|
||||||
|
if o.LlamaServerOptions != nil {
|
||||||
|
return o.LlamaServerOptions.Port
|
||||||
|
}
|
||||||
|
case BackendTypeMlxLm:
|
||||||
|
if o.MlxServerOptions != nil {
|
||||||
|
return o.MlxServerOptions.Port
|
||||||
|
}
|
||||||
|
case BackendTypeVllm:
|
||||||
|
if o.VllmServerOptions != nil {
|
||||||
|
return o.VllmServerOptions.Port
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Options) SetPort(port int) {
|
||||||
|
if o != nil {
|
||||||
|
switch o.BackendType {
|
||||||
|
case BackendTypeLlamaCpp:
|
||||||
|
if o.LlamaServerOptions != nil {
|
||||||
|
o.LlamaServerOptions.Port = port
|
||||||
|
}
|
||||||
|
case BackendTypeMlxLm:
|
||||||
|
if o.MlxServerOptions != nil {
|
||||||
|
o.MlxServerOptions.Port = port
|
||||||
|
}
|
||||||
|
case BackendTypeVllm:
|
||||||
|
if o.VllmServerOptions != nil {
|
||||||
|
o.VllmServerOptions.Port = port
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Options) GetHost() string {
|
||||||
|
if o != nil {
|
||||||
|
switch o.BackendType {
|
||||||
|
case BackendTypeLlamaCpp:
|
||||||
|
if o.LlamaServerOptions != nil {
|
||||||
|
return o.LlamaServerOptions.Host
|
||||||
|
}
|
||||||
|
case BackendTypeMlxLm:
|
||||||
|
if o.MlxServerOptions != nil {
|
||||||
|
return o.MlxServerOptions.Host
|
||||||
|
}
|
||||||
|
case BackendTypeVllm:
|
||||||
|
if o.VllmServerOptions != nil {
|
||||||
|
return o.VllmServerOptions.Host
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "localhost"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Options) GetResponseHeaders(backendConfig *config.BackendConfig) map[string]string {
|
||||||
|
backendSettings := getBackendSettings(o, backendConfig)
|
||||||
|
return backendSettings.ResponseHeaders
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateInstanceOptions performs validation based on backend type
|
||||||
|
func (o *Options) ValidateInstanceOptions() error {
|
||||||
|
// Validate based on backend type
|
||||||
|
switch o.BackendType {
|
||||||
|
case BackendTypeLlamaCpp:
|
||||||
|
return validateLlamaCppOptions(o.LlamaServerOptions)
|
||||||
|
case BackendTypeMlxLm:
|
||||||
|
return validateMlxOptions(o.MlxServerOptions)
|
||||||
|
case BackendTypeVllm:
|
||||||
|
return validateVllmOptions(o.VllmServerOptions)
|
||||||
|
default:
|
||||||
|
return validation.ValidationError(fmt.Errorf("unsupported backend type: %s", o.BackendType))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package backends
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"llamactl/pkg/validation"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
@@ -364,3 +366,22 @@ func ParseLlamaCommand(command string) (*LlamaServerOptions, error) {
|
|||||||
|
|
||||||
return &llamaOptions, nil
|
return &llamaOptions, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateLlamaCppOptions validates llama.cpp specific options
|
||||||
|
func validateLlamaCppOptions(options *LlamaServerOptions) error {
|
||||||
|
if options == nil {
|
||||||
|
return validation.ValidationError(fmt.Errorf("llama server options cannot be nil for llama.cpp backend"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use reflection to check all string fields for injection patterns
|
||||||
|
if err := validation.ValidateStructStrings(options, ""); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Basic network validation for port
|
||||||
|
if options.Port < 0 || options.Port > 65535 {
|
||||||
|
return validation.ValidationError(fmt.Errorf("invalid port range: %d", options.Port))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,5 +1,10 @@
|
|||||||
package backends
|
package backends
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"llamactl/pkg/validation"
|
||||||
|
)
|
||||||
|
|
||||||
type MlxServerOptions struct {
|
type MlxServerOptions struct {
|
||||||
// Basic connection options
|
// Basic connection options
|
||||||
Model string `json:"model,omitempty"`
|
Model string `json:"model,omitempty"`
|
||||||
@@ -50,3 +55,21 @@ func ParseMlxCommand(command string) (*MlxServerOptions, error) {
|
|||||||
|
|
||||||
return &mlxOptions, nil
|
return &mlxOptions, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateMlxOptions validates MLX backend specific options
|
||||||
|
func validateMlxOptions(options *MlxServerOptions) error {
|
||||||
|
if options == nil {
|
||||||
|
return validation.ValidationError(fmt.Errorf("MLX server options cannot be nil for MLX backend"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := validation.ValidateStructStrings(options, ""); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Basic network validation for port
|
||||||
|
if options.Port < 0 || options.Port > 65535 {
|
||||||
|
return validation.ValidationError(fmt.Errorf("invalid port range: %d", options.Port))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,5 +1,10 @@
|
|||||||
package backends
|
package backends
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"llamactl/pkg/validation"
|
||||||
|
)
|
||||||
|
|
||||||
// vllmMultiValuedFlags defines flags that should be repeated for each value rather than comma-separated
|
// vllmMultiValuedFlags defines flags that should be repeated for each value rather than comma-separated
|
||||||
var vllmMultiValuedFlags = map[string]bool{
|
var vllmMultiValuedFlags = map[string]bool{
|
||||||
"api-key": true,
|
"api-key": true,
|
||||||
@@ -194,3 +199,22 @@ func ParseVllmCommand(command string) (*VllmServerOptions, error) {
|
|||||||
|
|
||||||
return &vllmOptions, nil
|
return &vllmOptions, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateVllmOptions validates vLLM backend specific options
|
||||||
|
func validateVllmOptions(options *VllmServerOptions) error {
|
||||||
|
if options == nil {
|
||||||
|
return validation.ValidationError(fmt.Errorf("vLLM server options cannot be nil for vLLM backend"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use reflection to check all string fields for injection patterns
|
||||||
|
if err := validation.ValidateStructStrings(options, ""); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Basic network validation for port
|
||||||
|
if options.Port < 0 || options.Port > 65535 {
|
||||||
|
return validation.ValidationError(fmt.Errorf("invalid port range: %d", options.Port))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package instance
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"llamactl/pkg/backends"
|
|
||||||
"llamactl/pkg/config"
|
"llamactl/pkg/config"
|
||||||
"log"
|
"log"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
@@ -124,48 +123,6 @@ func (i *Instance) IsRunning() bool {
|
|||||||
return i.status.isRunning()
|
return i.status.isRunning()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *Instance) GetPort() int {
|
|
||||||
opts := i.GetOptions()
|
|
||||||
if opts != nil {
|
|
||||||
switch opts.BackendType {
|
|
||||||
case backends.BackendTypeLlamaCpp:
|
|
||||||
if opts.LlamaServerOptions != nil {
|
|
||||||
return opts.LlamaServerOptions.Port
|
|
||||||
}
|
|
||||||
case backends.BackendTypeMlxLm:
|
|
||||||
if opts.MlxServerOptions != nil {
|
|
||||||
return opts.MlxServerOptions.Port
|
|
||||||
}
|
|
||||||
case backends.BackendTypeVllm:
|
|
||||||
if opts.VllmServerOptions != nil {
|
|
||||||
return opts.VllmServerOptions.Port
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *Instance) GetHost() string {
|
|
||||||
opts := i.GetOptions()
|
|
||||||
if opts != nil {
|
|
||||||
switch opts.BackendType {
|
|
||||||
case backends.BackendTypeLlamaCpp:
|
|
||||||
if opts.LlamaServerOptions != nil {
|
|
||||||
return opts.LlamaServerOptions.Host
|
|
||||||
}
|
|
||||||
case backends.BackendTypeMlxLm:
|
|
||||||
if opts.MlxServerOptions != nil {
|
|
||||||
return opts.MlxServerOptions.Host
|
|
||||||
}
|
|
||||||
case backends.BackendTypeVllm:
|
|
||||||
if opts.VllmServerOptions != nil {
|
|
||||||
return opts.VllmServerOptions.Host
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetOptions sets the options
|
// SetOptions sets the options
|
||||||
func (i *Instance) SetOptions(opts *Options) {
|
func (i *Instance) SetOptions(opts *Options) {
|
||||||
if opts == nil {
|
if opts == nil {
|
||||||
@@ -198,6 +155,20 @@ func (i *Instance) SetTimeProvider(tp TimeProvider) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (i *Instance) GetHost() string {
|
||||||
|
if i.options == nil {
|
||||||
|
return "localhost"
|
||||||
|
}
|
||||||
|
return i.options.GetHost()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *Instance) GetPort() int {
|
||||||
|
if i.options == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return i.options.GetPort()
|
||||||
|
}
|
||||||
|
|
||||||
// GetProxy returns the reverse proxy for this instance
|
// GetProxy returns the reverse proxy for this instance
|
||||||
func (i *Instance) GetProxy() (*httputil.ReverseProxy, error) {
|
func (i *Instance) GetProxy() (*httputil.ReverseProxy, error) {
|
||||||
if i.proxy == nil {
|
if i.proxy == nil {
|
||||||
@@ -266,39 +237,31 @@ func (i *Instance) ShouldTimeout() bool {
|
|||||||
return i.proxy.shouldTimeout()
|
return i.proxy.shouldTimeout()
|
||||||
}
|
}
|
||||||
|
|
||||||
// getBackendHostPort extracts the host and port from instance options
|
func (i *Instance) getCommand() string {
|
||||||
// Returns the configured host and port for the backend
|
|
||||||
func (i *Instance) getBackendHostPort() (string, int) {
|
|
||||||
opts := i.GetOptions()
|
opts := i.GetOptions()
|
||||||
if opts == nil {
|
if opts == nil {
|
||||||
return "localhost", 0
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
var host string
|
return opts.BackendOptions.GetCommand(i.globalBackendSettings)
|
||||||
var port int
|
}
|
||||||
switch opts.BackendType {
|
|
||||||
case backends.BackendTypeLlamaCpp:
|
func (i *Instance) buildCommandArgs() []string {
|
||||||
if opts.LlamaServerOptions != nil {
|
opts := i.GetOptions()
|
||||||
host = opts.LlamaServerOptions.Host
|
if opts == nil {
|
||||||
port = opts.LlamaServerOptions.Port
|
return nil
|
||||||
}
|
|
||||||
case backends.BackendTypeMlxLm:
|
|
||||||
if opts.MlxServerOptions != nil {
|
|
||||||
host = opts.MlxServerOptions.Host
|
|
||||||
port = opts.MlxServerOptions.Port
|
|
||||||
}
|
|
||||||
case backends.BackendTypeVllm:
|
|
||||||
if opts.VllmServerOptions != nil {
|
|
||||||
host = opts.VllmServerOptions.Host
|
|
||||||
port = opts.VllmServerOptions.Port
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if host == "" {
|
return opts.BackendOptions.BuildCommandArgs(i.globalBackendSettings)
|
||||||
host = "localhost"
|
}
|
||||||
|
|
||||||
|
func (i *Instance) buildEnvironment() map[string]string {
|
||||||
|
opts := i.GetOptions()
|
||||||
|
if opts == nil {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return host, port
|
return opts.BackendOptions.BuildEnvironment(i.globalBackendSettings, opts.Environment)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements json.Marshaler for Instance
|
// MarshalJSON implements json.Marshaler for Instance
|
||||||
@@ -307,21 +270,7 @@ func (i *Instance) MarshalJSON() ([]byte, error) {
|
|||||||
opts := i.GetOptions()
|
opts := i.GetOptions()
|
||||||
|
|
||||||
// Determine if docker is enabled for this instance's backend
|
// Determine if docker is enabled for this instance's backend
|
||||||
var dockerEnabled bool
|
dockerEnabled := opts.BackendOptions.IsDockerEnabled(i.globalBackendSettings)
|
||||||
if opts != nil {
|
|
||||||
switch opts.BackendType {
|
|
||||||
case backends.BackendTypeLlamaCpp:
|
|
||||||
if i.globalBackendSettings != nil && i.globalBackendSettings.LlamaCpp.Docker != nil && i.globalBackendSettings.LlamaCpp.Docker.Enabled {
|
|
||||||
dockerEnabled = true
|
|
||||||
}
|
|
||||||
case backends.BackendTypeVllm:
|
|
||||||
if i.globalBackendSettings != nil && i.globalBackendSettings.VLLM.Docker != nil && i.globalBackendSettings.VLLM.Docker.Enabled {
|
|
||||||
dockerEnabled = true
|
|
||||||
}
|
|
||||||
case backends.BackendTypeMlxLm:
|
|
||||||
// MLX does not support docker currently
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return json.Marshal(&struct {
|
return json.Marshal(&struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"llamactl/pkg/backends"
|
"llamactl/pkg/backends"
|
||||||
"llamactl/pkg/config"
|
"llamactl/pkg/config"
|
||||||
"log"
|
"log"
|
||||||
"maps"
|
|
||||||
"slices"
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
@@ -21,18 +20,12 @@ type Options struct {
|
|||||||
OnDemandStart *bool `json:"on_demand_start,omitempty"`
|
OnDemandStart *bool `json:"on_demand_start,omitempty"`
|
||||||
// Idle timeout
|
// Idle timeout
|
||||||
IdleTimeout *int `json:"idle_timeout,omitempty"` // minutes
|
IdleTimeout *int `json:"idle_timeout,omitempty"` // minutes
|
||||||
//Environment variables
|
// Environment variables
|
||||||
Environment map[string]string `json:"environment,omitempty"`
|
Environment map[string]string `json:"environment,omitempty"`
|
||||||
|
// Assigned nodes
|
||||||
BackendType backends.BackendType `json:"backend_type"`
|
|
||||||
BackendOptions map[string]any `json:"backend_options,omitempty"`
|
|
||||||
|
|
||||||
Nodes map[string]struct{} `json:"-"`
|
Nodes map[string]struct{} `json:"-"`
|
||||||
|
// Backend options
|
||||||
// Backend-specific options
|
BackendOptions backends.Options `json:"-"`
|
||||||
LlamaServerOptions *backends.LlamaServerOptions `json:"-"`
|
|
||||||
MlxServerOptions *backends.MlxServerOptions `json:"-"`
|
|
||||||
VllmServerOptions *backends.VllmServerOptions `json:"-"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// options wraps Options with thread-safe access (unexported).
|
// options wraps Options with thread-safe access (unexported).
|
||||||
@@ -62,6 +55,18 @@ func (o *options) set(opts *Options) {
|
|||||||
o.opts = opts
|
o.opts = opts
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (o *options) GetHost() string {
|
||||||
|
o.mu.RLock()
|
||||||
|
defer o.mu.RUnlock()
|
||||||
|
return o.opts.BackendOptions.GetHost()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *options) GetPort() int {
|
||||||
|
o.mu.RLock()
|
||||||
|
defer o.mu.RUnlock()
|
||||||
|
return o.opts.BackendOptions.GetPort()
|
||||||
|
}
|
||||||
|
|
||||||
// MarshalJSON implements json.Marshaler for options wrapper
|
// MarshalJSON implements json.Marshaler for options wrapper
|
||||||
func (o *options) MarshalJSON() ([]byte, error) {
|
func (o *options) MarshalJSON() ([]byte, error) {
|
||||||
o.mu.RLock()
|
o.mu.RLock()
|
||||||
@@ -85,7 +90,9 @@ func (c *Options) UnmarshalJSON(data []byte) error {
|
|||||||
// Use anonymous struct to avoid recursion
|
// Use anonymous struct to avoid recursion
|
||||||
type Alias Options
|
type Alias Options
|
||||||
aux := &struct {
|
aux := &struct {
|
||||||
Nodes []string `json:"nodes,omitempty"` // Accept JSON array
|
Nodes []string `json:"nodes,omitempty"`
|
||||||
|
BackendType backends.BackendType `json:"backend_type"`
|
||||||
|
BackendOptions map[string]any `json:"backend_options,omitempty"`
|
||||||
*Alias
|
*Alias
|
||||||
}{
|
}{
|
||||||
Alias: (*Alias)(c),
|
Alias: (*Alias)(c),
|
||||||
@@ -103,47 +110,27 @@ func (c *Options) UnmarshalJSON(data []byte) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse backend-specific options
|
// Create backend options struct and unmarshal
|
||||||
switch c.BackendType {
|
c.BackendOptions = backends.Options{
|
||||||
case backends.BackendTypeLlamaCpp:
|
BackendType: aux.BackendType,
|
||||||
if c.BackendOptions != nil {
|
BackendOptions: aux.BackendOptions,
|
||||||
// Convert map to JSON and then unmarshal to LlamaServerOptions
|
}
|
||||||
optionsData, err := json.Marshal(c.BackendOptions)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to marshal backend options: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.LlamaServerOptions = &backends.LlamaServerOptions{}
|
// Marshal the backend options to JSON for proper unmarshaling
|
||||||
if err := json.Unmarshal(optionsData, c.LlamaServerOptions); err != nil {
|
backendJson, err := json.Marshal(struct {
|
||||||
return fmt.Errorf("failed to unmarshal llama.cpp options: %w", err)
|
BackendType backends.BackendType `json:"backend_type"`
|
||||||
}
|
BackendOptions map[string]any `json:"backend_options,omitempty"`
|
||||||
}
|
}{
|
||||||
case backends.BackendTypeMlxLm:
|
BackendType: aux.BackendType,
|
||||||
if c.BackendOptions != nil {
|
BackendOptions: aux.BackendOptions,
|
||||||
optionsData, err := json.Marshal(c.BackendOptions)
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to marshal backend options: %w", err)
|
return fmt.Errorf("failed to marshal backend options: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.MlxServerOptions = &backends.MlxServerOptions{}
|
// Unmarshal into the backends.Options struct to trigger its custom unmarshaling
|
||||||
if err := json.Unmarshal(optionsData, c.MlxServerOptions); err != nil {
|
if err := json.Unmarshal(backendJson, &c.BackendOptions); err != nil {
|
||||||
return fmt.Errorf("failed to unmarshal MLX options: %w", err)
|
return fmt.Errorf("failed to unmarshal backend options: %w", err)
|
||||||
}
|
|
||||||
}
|
|
||||||
case backends.BackendTypeVllm:
|
|
||||||
if c.BackendOptions != nil {
|
|
||||||
optionsData, err := json.Marshal(c.BackendOptions)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to marshal backend options: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.VllmServerOptions = &backends.VllmServerOptions{}
|
|
||||||
if err := json.Unmarshal(optionsData, c.VllmServerOptions); err != nil {
|
|
||||||
return fmt.Errorf("failed to unmarshal vLLM options: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unknown backend type: %s", c.BackendType)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -154,7 +141,9 @@ func (c *Options) MarshalJSON() ([]byte, error) {
|
|||||||
// Use anonymous struct to avoid recursion
|
// Use anonymous struct to avoid recursion
|
||||||
type Alias Options
|
type Alias Options
|
||||||
aux := struct {
|
aux := struct {
|
||||||
Nodes []string `json:"nodes,omitempty"` // Output as JSON array
|
Nodes []string `json:"nodes,omitempty"` // Output as JSON array
|
||||||
|
BackendType backends.BackendType `json:"backend_type"`
|
||||||
|
BackendOptions map[string]any `json:"backend_options,omitempty"`
|
||||||
*Alias
|
*Alias
|
||||||
}{
|
}{
|
||||||
Alias: (*Alias)(c),
|
Alias: (*Alias)(c),
|
||||||
@@ -170,52 +159,25 @@ func (c *Options) MarshalJSON() ([]byte, error) {
|
|||||||
slices.Sort(aux.Nodes)
|
slices.Sort(aux.Nodes)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert backend-specific options back to BackendOptions map for JSON
|
// Set backend type
|
||||||
switch c.BackendType {
|
aux.BackendType = c.BackendOptions.BackendType
|
||||||
case backends.BackendTypeLlamaCpp:
|
|
||||||
if c.LlamaServerOptions != nil {
|
|
||||||
data, err := json.Marshal(c.LlamaServerOptions)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to marshal llama server options: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var backendOpts map[string]any
|
// Marshal the backends.Options struct to get the properly formatted backend options
|
||||||
if err := json.Unmarshal(data, &backendOpts); err != nil {
|
backendData, err := json.Marshal(c.BackendOptions)
|
||||||
return nil, fmt.Errorf("failed to unmarshal to map: %w", err)
|
if err != nil {
|
||||||
}
|
return nil, fmt.Errorf("failed to marshal backend options: %w", err)
|
||||||
|
|
||||||
aux.BackendOptions = backendOpts
|
|
||||||
}
|
|
||||||
case backends.BackendTypeMlxLm:
|
|
||||||
if c.MlxServerOptions != nil {
|
|
||||||
data, err := json.Marshal(c.MlxServerOptions)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to marshal MLX server options: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var backendOpts map[string]any
|
|
||||||
if err := json.Unmarshal(data, &backendOpts); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to unmarshal to map: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
aux.BackendOptions = backendOpts
|
|
||||||
}
|
|
||||||
case backends.BackendTypeVllm:
|
|
||||||
if c.VllmServerOptions != nil {
|
|
||||||
data, err := json.Marshal(c.VllmServerOptions)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to marshal vLLM server options: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var backendOpts map[string]any
|
|
||||||
if err := json.Unmarshal(data, &backendOpts); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to unmarshal to map: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
aux.BackendOptions = backendOpts
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Unmarshal into a temporary struct to extract the backend_options map
|
||||||
|
var tempBackend struct {
|
||||||
|
BackendOptions map[string]any `json:"backend_options,omitempty"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(backendData, &tempBackend); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal backend data: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
aux.BackendOptions = tempBackend.BackendOptions
|
||||||
|
|
||||||
return json.Marshal(aux)
|
return json.Marshal(aux)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -257,78 +219,3 @@ func (c *Options) validateAndApplyDefaults(name string, globalSettings *config.I
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// getCommand builds the command to run the backend
|
|
||||||
func (c *Options) getCommand(backendConfig *config.BackendSettings) string {
|
|
||||||
|
|
||||||
if backendConfig.Docker != nil && backendConfig.Docker.Enabled && c.BackendType != backends.BackendTypeMlxLm {
|
|
||||||
return "docker"
|
|
||||||
}
|
|
||||||
|
|
||||||
return backendConfig.Command
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildCommandArgs builds command line arguments for the backend
|
|
||||||
func (c *Options) buildCommandArgs(backendConfig *config.BackendSettings) []string {
|
|
||||||
|
|
||||||
var args []string
|
|
||||||
|
|
||||||
if backendConfig.Docker != nil && backendConfig.Docker.Enabled && c.BackendType != backends.BackendTypeMlxLm {
|
|
||||||
// For Docker, start with Docker args
|
|
||||||
args = append(args, backendConfig.Docker.Args...)
|
|
||||||
args = append(args, backendConfig.Docker.Image)
|
|
||||||
|
|
||||||
switch c.BackendType {
|
|
||||||
case backends.BackendTypeLlamaCpp:
|
|
||||||
if c.LlamaServerOptions != nil {
|
|
||||||
args = append(args, c.LlamaServerOptions.BuildDockerArgs()...)
|
|
||||||
}
|
|
||||||
case backends.BackendTypeVllm:
|
|
||||||
if c.VllmServerOptions != nil {
|
|
||||||
args = append(args, c.VllmServerOptions.BuildDockerArgs()...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} else {
|
|
||||||
// For native execution, start with backend args
|
|
||||||
args = append(args, backendConfig.Args...)
|
|
||||||
|
|
||||||
switch c.BackendType {
|
|
||||||
case backends.BackendTypeLlamaCpp:
|
|
||||||
if c.LlamaServerOptions != nil {
|
|
||||||
args = append(args, c.LlamaServerOptions.BuildCommandArgs()...)
|
|
||||||
}
|
|
||||||
case backends.BackendTypeMlxLm:
|
|
||||||
if c.MlxServerOptions != nil {
|
|
||||||
args = append(args, c.MlxServerOptions.BuildCommandArgs()...)
|
|
||||||
}
|
|
||||||
case backends.BackendTypeVllm:
|
|
||||||
if c.VllmServerOptions != nil {
|
|
||||||
args = append(args, c.VllmServerOptions.BuildCommandArgs()...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return args
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildEnvironment builds the environment variables for the backend process
|
|
||||||
func (c *Options) buildEnvironment(backendConfig *config.BackendSettings) map[string]string {
|
|
||||||
env := map[string]string{}
|
|
||||||
|
|
||||||
if backendConfig.Environment != nil {
|
|
||||||
maps.Copy(env, backendConfig.Environment)
|
|
||||||
}
|
|
||||||
|
|
||||||
if backendConfig.Docker != nil && backendConfig.Docker.Enabled && c.BackendType != backends.BackendTypeMlxLm {
|
|
||||||
if backendConfig.Docker.Environment != nil {
|
|
||||||
maps.Copy(env, backendConfig.Docker.Environment)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.Environment != nil {
|
|
||||||
maps.Copy(env, c.Environment)
|
|
||||||
}
|
|
||||||
|
|
||||||
return env
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -12,9 +12,6 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"llamactl/pkg/backends"
|
|
||||||
"llamactl/pkg/config"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// process manages the OS process lifecycle for a local instance.
|
// process manages the OS process lifecycle for a local instance.
|
||||||
@@ -216,7 +213,8 @@ func (p *process) waitForHealthy(timeout int) error {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Get host/port from instance
|
// Get host/port from instance
|
||||||
host, port := p.instance.getBackendHostPort()
|
host := p.instance.options.GetHost()
|
||||||
|
port := p.instance.options.GetPort()
|
||||||
healthURL := fmt.Sprintf("http://%s:%d/health", host, port)
|
healthURL := fmt.Sprintf("http://%s:%d/health", host, port)
|
||||||
|
|
||||||
// Create a dedicated HTTP client for health checks
|
// Create a dedicated HTTP client for health checks
|
||||||
@@ -386,26 +384,15 @@ func (p *process) handleAutoRestart(err error) {
|
|||||||
|
|
||||||
// buildCommand builds the command to execute using backend-specific logic
|
// buildCommand builds the command to execute using backend-specific logic
|
||||||
func (p *process) buildCommand() (*exec.Cmd, error) {
|
func (p *process) buildCommand() (*exec.Cmd, error) {
|
||||||
// Get options
|
|
||||||
opts := p.instance.GetOptions()
|
|
||||||
if opts == nil {
|
|
||||||
return nil, fmt.Errorf("instance options are nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get backend configuration
|
|
||||||
backendConfig, err := p.getBackendConfig()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build the environment variables
|
// Build the environment variables
|
||||||
env := opts.buildEnvironment(backendConfig)
|
env := p.instance.buildEnvironment()
|
||||||
|
|
||||||
// Get the command to execute
|
// Get the command to execute
|
||||||
command := opts.getCommand(backendConfig)
|
command := p.instance.getCommand()
|
||||||
|
|
||||||
// Build command arguments
|
// Build command arguments
|
||||||
args := opts.buildCommandArgs(backendConfig)
|
args := p.instance.buildCommandArgs()
|
||||||
|
|
||||||
// Create the exec.Cmd
|
// Create the exec.Cmd
|
||||||
cmd := exec.CommandContext(p.ctx, command, args...)
|
cmd := exec.CommandContext(p.ctx, command, args...)
|
||||||
@@ -420,27 +407,3 @@ func (p *process) buildCommand() (*exec.Cmd, error) {
|
|||||||
|
|
||||||
return cmd, nil
|
return cmd, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getBackendConfig resolves the backend configuration for the current instance
|
|
||||||
func (p *process) getBackendConfig() (*config.BackendSettings, error) {
|
|
||||||
opts := p.instance.GetOptions()
|
|
||||||
if opts == nil {
|
|
||||||
return nil, fmt.Errorf("instance options are nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
var backendTypeStr string
|
|
||||||
|
|
||||||
switch opts.BackendType {
|
|
||||||
case backends.BackendTypeLlamaCpp:
|
|
||||||
backendTypeStr = "llama-cpp"
|
|
||||||
case backends.BackendTypeMlxLm:
|
|
||||||
backendTypeStr = "mlx"
|
|
||||||
case backends.BackendTypeVllm:
|
|
||||||
backendTypeStr = "vllm"
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unsupported backend type: %s", opts.BackendType)
|
|
||||||
}
|
|
||||||
|
|
||||||
settings := p.instance.globalBackendSettings.GetBackendSettings(backendTypeStr)
|
|
||||||
return &settings, nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package instance
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"llamactl/pkg/backends"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"net/url"
|
"net/url"
|
||||||
@@ -68,8 +67,11 @@ func (p *proxy) build() (*httputil.ReverseProxy, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get host/port from process
|
// Get host/port from process
|
||||||
host, port := p.instance.getBackendHostPort()
|
host := p.instance.options.GetHost()
|
||||||
|
port := p.instance.options.GetPort()
|
||||||
|
if port == 0 {
|
||||||
|
return nil, fmt.Errorf("instance %s has no port assigned", p.instance.Name)
|
||||||
|
}
|
||||||
targetURL, err := url.Parse(fmt.Sprintf("http://%s:%d", host, port))
|
targetURL, err := url.Parse(fmt.Sprintf("http://%s:%d", host, port))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse target URL for instance %s: %w", p.instance.Name, err)
|
return nil, fmt.Errorf("failed to parse target URL for instance %s: %w", p.instance.Name, err)
|
||||||
@@ -78,15 +80,7 @@ func (p *proxy) build() (*httputil.ReverseProxy, error) {
|
|||||||
proxy := httputil.NewSingleHostReverseProxy(targetURL)
|
proxy := httputil.NewSingleHostReverseProxy(targetURL)
|
||||||
|
|
||||||
// Get response headers from backend config
|
// Get response headers from backend config
|
||||||
var responseHeaders map[string]string
|
responseHeaders := options.BackendOptions.GetResponseHeaders(p.instance.globalBackendSettings)
|
||||||
switch options.BackendType {
|
|
||||||
case backends.BackendTypeLlamaCpp:
|
|
||||||
responseHeaders = p.instance.globalBackendSettings.LlamaCpp.ResponseHeaders
|
|
||||||
case backends.BackendTypeVllm:
|
|
||||||
responseHeaders = p.instance.globalBackendSettings.VLLM.ResponseHeaders
|
|
||||||
case backends.BackendTypeMlxLm:
|
|
||||||
responseHeaders = p.instance.globalBackendSettings.MLX.ResponseHeaders
|
|
||||||
}
|
|
||||||
|
|
||||||
proxy.ModifyResponse = func(resp *http.Response) error {
|
proxy.ModifyResponse = func(resp *http.Response) error {
|
||||||
// Remove CORS headers from backend response to avoid conflicts
|
// Remove CORS headers from backend response to avoid conflicts
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package manager
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"llamactl/pkg/backends"
|
|
||||||
"llamactl/pkg/instance"
|
"llamactl/pkg/instance"
|
||||||
"llamactl/pkg/validation"
|
"llamactl/pkg/validation"
|
||||||
"os"
|
"os"
|
||||||
@@ -86,7 +85,7 @@ func (im *instanceManager) CreateInstance(name string, options *instance.Options
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = validation.ValidateInstanceOptions(options)
|
err = options.BackendOptions.ValidateInstanceOptions()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -232,7 +231,7 @@ func (im *instanceManager) UpdateInstance(name string, options *instance.Options
|
|||||||
return nil, fmt.Errorf("instance options cannot be nil")
|
return nil, fmt.Errorf("instance options cannot be nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
err := validation.ValidateInstanceOptions(options)
|
err := options.BackendOptions.ValidateInstanceOptions()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -493,39 +492,12 @@ func (im *instanceManager) GetInstanceLogs(name string, numLines int) (string, e
|
|||||||
|
|
||||||
// getPortFromOptions extracts the port from backend-specific options
|
// getPortFromOptions extracts the port from backend-specific options
|
||||||
func (im *instanceManager) getPortFromOptions(options *instance.Options) int {
|
func (im *instanceManager) getPortFromOptions(options *instance.Options) int {
|
||||||
switch options.BackendType {
|
return options.BackendOptions.GetPort()
|
||||||
case backends.BackendTypeLlamaCpp:
|
|
||||||
if options.LlamaServerOptions != nil {
|
|
||||||
return options.LlamaServerOptions.Port
|
|
||||||
}
|
|
||||||
case backends.BackendTypeMlxLm:
|
|
||||||
if options.MlxServerOptions != nil {
|
|
||||||
return options.MlxServerOptions.Port
|
|
||||||
}
|
|
||||||
case backends.BackendTypeVllm:
|
|
||||||
if options.VllmServerOptions != nil {
|
|
||||||
return options.VllmServerOptions.Port
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// setPortInOptions sets the port in backend-specific options
|
// setPortInOptions sets the port in backend-specific options
|
||||||
func (im *instanceManager) setPortInOptions(options *instance.Options, port int) {
|
func (im *instanceManager) setPortInOptions(options *instance.Options, port int) {
|
||||||
switch options.BackendType {
|
options.BackendOptions.SetPort(port)
|
||||||
case backends.BackendTypeLlamaCpp:
|
|
||||||
if options.LlamaServerOptions != nil {
|
|
||||||
options.LlamaServerOptions.Port = port
|
|
||||||
}
|
|
||||||
case backends.BackendTypeMlxLm:
|
|
||||||
if options.MlxServerOptions != nil {
|
|
||||||
options.MlxServerOptions.Port = port
|
|
||||||
}
|
|
||||||
case backends.BackendTypeVllm:
|
|
||||||
if options.VllmServerOptions != nil {
|
|
||||||
options.VllmServerOptions.Port = port
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// assignAndValidatePort assigns a port if not specified and validates it's not in use
|
// assignAndValidatePort assigns a port if not specified and validates it's not in use
|
||||||
|
|||||||
@@ -2,8 +2,6 @@ package validation
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"llamactl/pkg/backends"
|
|
||||||
"llamactl/pkg/instance"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
)
|
)
|
||||||
@@ -24,8 +22,8 @@ var (
|
|||||||
|
|
||||||
type ValidationError error
|
type ValidationError error
|
||||||
|
|
||||||
// validateStringForInjection checks if a string contains dangerous patterns
|
// ValidateStringForInjection checks if a string contains dangerous patterns
|
||||||
func validateStringForInjection(value string) error {
|
func ValidateStringForInjection(value string) error {
|
||||||
for _, pattern := range dangerousPatterns {
|
for _, pattern := range dangerousPatterns {
|
||||||
if pattern.MatchString(value) {
|
if pattern.MatchString(value) {
|
||||||
return ValidationError(fmt.Errorf("value contains potentially dangerous characters: %s", value))
|
return ValidationError(fmt.Errorf("value contains potentially dangerous characters: %s", value))
|
||||||
@@ -34,83 +32,8 @@ func validateStringForInjection(value string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateInstanceOptions performs validation based on backend type
|
// ValidateStructStrings recursively validates all string fields in a struct
|
||||||
func ValidateInstanceOptions(options *instance.Options) error {
|
func ValidateStructStrings(v any, fieldPath string) error {
|
||||||
if options == nil {
|
|
||||||
return ValidationError(fmt.Errorf("options cannot be nil"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate based on backend type
|
|
||||||
switch options.BackendType {
|
|
||||||
case backends.BackendTypeLlamaCpp:
|
|
||||||
return validateLlamaCppOptions(options)
|
|
||||||
case backends.BackendTypeMlxLm:
|
|
||||||
return validateMlxOptions(options)
|
|
||||||
case backends.BackendTypeVllm:
|
|
||||||
return validateVllmOptions(options)
|
|
||||||
default:
|
|
||||||
return ValidationError(fmt.Errorf("unsupported backend type: %s", options.BackendType))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateLlamaCppOptions validates llama.cpp specific options
|
|
||||||
func validateLlamaCppOptions(options *instance.Options) error {
|
|
||||||
if options.LlamaServerOptions == nil {
|
|
||||||
return ValidationError(fmt.Errorf("llama server options cannot be nil for llama.cpp backend"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use reflection to check all string fields for injection patterns
|
|
||||||
if err := validateStructStrings(options.LlamaServerOptions, ""); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Basic network validation for port
|
|
||||||
if options.LlamaServerOptions.Port < 0 || options.LlamaServerOptions.Port > 65535 {
|
|
||||||
return ValidationError(fmt.Errorf("invalid port range: %d", options.LlamaServerOptions.Port))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateMlxOptions validates MLX backend specific options
|
|
||||||
func validateMlxOptions(options *instance.Options) error {
|
|
||||||
if options.MlxServerOptions == nil {
|
|
||||||
return ValidationError(fmt.Errorf("MLX server options cannot be nil for MLX backend"))
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := validateStructStrings(options.MlxServerOptions, ""); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Basic network validation for port
|
|
||||||
if options.MlxServerOptions.Port < 0 || options.MlxServerOptions.Port > 65535 {
|
|
||||||
return ValidationError(fmt.Errorf("invalid port range: %d", options.MlxServerOptions.Port))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateVllmOptions validates vLLM backend specific options
|
|
||||||
func validateVllmOptions(options *instance.Options) error {
|
|
||||||
if options.VllmServerOptions == nil {
|
|
||||||
return ValidationError(fmt.Errorf("vLLM server options cannot be nil for vLLM backend"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use reflection to check all string fields for injection patterns
|
|
||||||
if err := validateStructStrings(options.VllmServerOptions, ""); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Basic network validation for port
|
|
||||||
if options.VllmServerOptions.Port < 0 || options.VllmServerOptions.Port > 65535 {
|
|
||||||
return ValidationError(fmt.Errorf("invalid port range: %d", options.VllmServerOptions.Port))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateStructStrings recursively validates all string fields in a struct
|
|
||||||
func validateStructStrings(v any, fieldPath string) error {
|
|
||||||
val := reflect.ValueOf(v)
|
val := reflect.ValueOf(v)
|
||||||
if val.Kind() == reflect.Ptr {
|
if val.Kind() == reflect.Ptr {
|
||||||
val = val.Elem()
|
val = val.Elem()
|
||||||
@@ -136,21 +59,21 @@ func validateStructStrings(v any, fieldPath string) error {
|
|||||||
|
|
||||||
switch field.Kind() {
|
switch field.Kind() {
|
||||||
case reflect.String:
|
case reflect.String:
|
||||||
if err := validateStringForInjection(field.String()); err != nil {
|
if err := ValidateStringForInjection(field.String()); err != nil {
|
||||||
return ValidationError(fmt.Errorf("field %s: %w", fieldName, err))
|
return ValidationError(fmt.Errorf("field %s: %w", fieldName, err))
|
||||||
}
|
}
|
||||||
|
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
if field.Type().Elem().Kind() == reflect.String {
|
if field.Type().Elem().Kind() == reflect.String {
|
||||||
for j := 0; j < field.Len(); j++ {
|
for j := 0; j < field.Len(); j++ {
|
||||||
if err := validateStringForInjection(field.Index(j).String()); err != nil {
|
if err := ValidateStringForInjection(field.Index(j).String()); err != nil {
|
||||||
return ValidationError(fmt.Errorf("field %s[%d]: %w", fieldName, j, err))
|
return ValidationError(fmt.Errorf("field %s[%d]: %w", fieldName, j, err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if err := validateStructStrings(field.Interface(), fieldName); err != nil {
|
if err := ValidateStructStrings(field.Interface(), fieldName); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user