mirror of
https://github.com/lordmathis/llamactl.git
synced 2025-11-06 00:54:23 +00:00
Refactor backend options to implement common interface and streamline validation
This commit is contained in:
@@ -14,9 +14,17 @@ const (
|
|||||||
BackendTypeLlamaCpp BackendType = "llama_cpp"
|
BackendTypeLlamaCpp BackendType = "llama_cpp"
|
||||||
BackendTypeMlxLm BackendType = "mlx_lm"
|
BackendTypeMlxLm BackendType = "mlx_lm"
|
||||||
BackendTypeVllm BackendType = "vllm"
|
BackendTypeVllm BackendType = "vllm"
|
||||||
// BackendTypeMlxVlm BackendType = "mlx_vlm" // Future expansion
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// backend is the set of operations Options needs from a backend-specific
// options value. Options.getBackend hands back the LlamaServerOptions,
// MlxServerOptions or VllmServerOptions pointer it holds as this interface.
type backend interface {
|
||||||
|
// BuildCommandArgs returns the arguments Options.BuildCommandArgs appends
// for native (non-Docker) execution.
BuildCommandArgs() []string
|
||||||
|
// BuildDockerArgs returns the arguments Options.BuildCommandArgs appends
// after the Docker image when Docker is enabled.
BuildDockerArgs() []string
|
||||||
|
// GetPort returns the backend's configured port.
GetPort() int
|
||||||
|
// SetPort overwrites the backend's configured port.
SetPort(int)
|
||||||
|
// GetHost returns the backend's configured host.
GetHost() string
|
||||||
|
// Validate checks the backend options and returns an error when they are
// not acceptable.
Validate() error
|
||||||
|
}
|
||||||
|
|
||||||
type Options struct {
|
type Options struct {
|
||||||
BackendType BackendType `json:"backend_type"`
|
BackendType BackendType `json:"backend_type"`
|
||||||
BackendOptions map[string]any `json:"backend_options,omitempty"`
|
BackendOptions map[string]any `json:"backend_options,omitempty"`
|
||||||
@@ -135,7 +143,7 @@ func (o *Options) MarshalJSON() ([]byte, error) {
|
|||||||
return json.Marshal(aux)
|
return json.Marshal(aux)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getBackendSettings(o *Options, backendConfig *config.BackendConfig) *config.BackendSettings {
|
func (o *Options) getBackendSettings(backendConfig *config.BackendConfig) *config.BackendSettings {
|
||||||
switch o.BackendType {
|
switch o.BackendType {
|
||||||
case BackendTypeLlamaCpp:
|
case BackendTypeLlamaCpp:
|
||||||
return &backendConfig.LlamaCpp
|
return &backendConfig.LlamaCpp
|
||||||
@@ -148,6 +156,20 @@ func getBackendSettings(o *Options, backendConfig *config.BackendConfig) *config
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getBackend returns the actual backend implementation
|
||||||
|
func (o *Options) getBackend() backend {
|
||||||
|
switch o.BackendType {
|
||||||
|
case BackendTypeLlamaCpp:
|
||||||
|
return o.LlamaServerOptions
|
||||||
|
case BackendTypeMlxLm:
|
||||||
|
return o.MlxServerOptions
|
||||||
|
case BackendTypeVllm:
|
||||||
|
return o.VllmServerOptions
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (o *Options) isDockerEnabled(backend *config.BackendSettings) bool {
|
func (o *Options) isDockerEnabled(backend *config.BackendSettings) bool {
|
||||||
if backend.Docker != nil && backend.Docker.Enabled && o.BackendType != BackendTypeMlxLm {
|
if backend.Docker != nil && backend.Docker.Enabled && o.BackendType != BackendTypeMlxLm {
|
||||||
return true
|
return true
|
||||||
@@ -156,14 +178,14 @@ func (o *Options) isDockerEnabled(backend *config.BackendSettings) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (o *Options) IsDockerEnabled(backendConfig *config.BackendConfig) bool {
|
func (o *Options) IsDockerEnabled(backendConfig *config.BackendConfig) bool {
|
||||||
backendSettings := getBackendSettings(o, backendConfig)
|
backendSettings := o.getBackendSettings(backendConfig)
|
||||||
return o.isDockerEnabled(backendSettings)
|
return o.isDockerEnabled(backendSettings)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCommand builds the command to run the backend
|
// GetCommand builds the command to run the backend
|
||||||
func (o *Options) GetCommand(backendConfig *config.BackendConfig) string {
|
func (o *Options) GetCommand(backendConfig *config.BackendConfig) string {
|
||||||
|
|
||||||
backendSettings := getBackendSettings(o, backendConfig)
|
backendSettings := o.getBackendSettings(backendConfig)
|
||||||
|
|
||||||
if o.isDockerEnabled(backendSettings) {
|
if o.isDockerEnabled(backendSettings) {
|
||||||
return "docker"
|
return "docker"
|
||||||
@@ -177,42 +199,22 @@ func (o *Options) BuildCommandArgs(backendConfig *config.BackendConfig) []string
|
|||||||
|
|
||||||
var args []string
|
var args []string
|
||||||
|
|
||||||
backendSettings := getBackendSettings(o, backendConfig)
|
backendSettings := o.getBackendSettings(backendConfig)
|
||||||
|
backend := o.getBackend()
|
||||||
|
if backend == nil {
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
if o.isDockerEnabled(backendSettings) {
|
if o.isDockerEnabled(backendSettings) {
|
||||||
// For Docker, start with Docker args
|
// For Docker, start with Docker args
|
||||||
args = append(args, backendSettings.Docker.Args...)
|
args = append(args, backendSettings.Docker.Args...)
|
||||||
args = append(args, backendSettings.Docker.Image)
|
args = append(args, backendSettings.Docker.Image)
|
||||||
|
args = append(args, backend.BuildDockerArgs()...)
|
||||||
switch o.BackendType {
|
|
||||||
case BackendTypeLlamaCpp:
|
|
||||||
if o.LlamaServerOptions != nil {
|
|
||||||
args = append(args, o.LlamaServerOptions.BuildDockerArgs()...)
|
|
||||||
}
|
|
||||||
case BackendTypeVllm:
|
|
||||||
if o.VllmServerOptions != nil {
|
|
||||||
args = append(args, o.VllmServerOptions.BuildDockerArgs()...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
// For native execution, start with backend args
|
// For native execution, start with backend args
|
||||||
args = append(args, backendSettings.Args...)
|
args = append(args, backendSettings.Args...)
|
||||||
|
args = append(args, backend.BuildCommandArgs()...)
|
||||||
switch o.BackendType {
|
|
||||||
case BackendTypeLlamaCpp:
|
|
||||||
if o.LlamaServerOptions != nil {
|
|
||||||
args = append(args, o.LlamaServerOptions.BuildCommandArgs()...)
|
|
||||||
}
|
|
||||||
case BackendTypeMlxLm:
|
|
||||||
if o.MlxServerOptions != nil {
|
|
||||||
args = append(args, o.MlxServerOptions.BuildCommandArgs()...)
|
|
||||||
}
|
|
||||||
case BackendTypeVllm:
|
|
||||||
if o.VllmServerOptions != nil {
|
|
||||||
args = append(args, o.VllmServerOptions.BuildCommandArgs()...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return args
|
return args
|
||||||
@@ -221,7 +223,7 @@ func (o *Options) BuildCommandArgs(backendConfig *config.BackendConfig) []string
|
|||||||
// BuildEnvironment builds the environment variables for the backend process
|
// BuildEnvironment builds the environment variables for the backend process
|
||||||
func (o *Options) BuildEnvironment(backendConfig *config.BackendConfig, environment map[string]string) map[string]string {
|
func (o *Options) BuildEnvironment(backendConfig *config.BackendConfig, environment map[string]string) map[string]string {
|
||||||
|
|
||||||
backendSettings := getBackendSettings(o, backendConfig)
|
backendSettings := o.getBackendSettings(backendConfig)
|
||||||
env := map[string]string{}
|
env := map[string]string{}
|
||||||
|
|
||||||
if backendSettings.Environment != nil {
|
if backendSettings.Environment != nil {
|
||||||
@@ -242,80 +244,39 @@ func (o *Options) BuildEnvironment(backendConfig *config.BackendConfig, environm
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (o *Options) GetPort() int {
|
func (o *Options) GetPort() int {
|
||||||
if o != nil {
|
backend := o.getBackend()
|
||||||
switch o.BackendType {
|
if backend != nil {
|
||||||
case BackendTypeLlamaCpp:
|
return backend.GetPort()
|
||||||
if o.LlamaServerOptions != nil {
|
|
||||||
return o.LlamaServerOptions.Port
|
|
||||||
}
|
|
||||||
case BackendTypeMlxLm:
|
|
||||||
if o.MlxServerOptions != nil {
|
|
||||||
return o.MlxServerOptions.Port
|
|
||||||
}
|
|
||||||
case BackendTypeVllm:
|
|
||||||
if o.VllmServerOptions != nil {
|
|
||||||
return o.VllmServerOptions.Port
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Options) SetPort(port int) {
|
func (o *Options) SetPort(port int) {
|
||||||
if o != nil {
|
backend := o.getBackend()
|
||||||
switch o.BackendType {
|
if backend != nil {
|
||||||
case BackendTypeLlamaCpp:
|
backend.SetPort(port)
|
||||||
if o.LlamaServerOptions != nil {
|
|
||||||
o.LlamaServerOptions.Port = port
|
|
||||||
}
|
|
||||||
case BackendTypeMlxLm:
|
|
||||||
if o.MlxServerOptions != nil {
|
|
||||||
o.MlxServerOptions.Port = port
|
|
||||||
}
|
|
||||||
case BackendTypeVllm:
|
|
||||||
if o.VllmServerOptions != nil {
|
|
||||||
o.VllmServerOptions.Port = port
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Options) GetHost() string {
|
func (o *Options) GetHost() string {
|
||||||
if o != nil {
|
backend := o.getBackend()
|
||||||
switch o.BackendType {
|
if backend != nil {
|
||||||
case BackendTypeLlamaCpp:
|
return backend.GetHost()
|
||||||
if o.LlamaServerOptions != nil {
|
|
||||||
return o.LlamaServerOptions.Host
|
|
||||||
}
|
|
||||||
case BackendTypeMlxLm:
|
|
||||||
if o.MlxServerOptions != nil {
|
|
||||||
return o.MlxServerOptions.Host
|
|
||||||
}
|
|
||||||
case BackendTypeVllm:
|
|
||||||
if o.VllmServerOptions != nil {
|
|
||||||
return o.VllmServerOptions.Host
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return "localhost"
|
return "localhost"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Options) GetResponseHeaders(backendConfig *config.BackendConfig) map[string]string {
|
func (o *Options) GetResponseHeaders(backendConfig *config.BackendConfig) map[string]string {
|
||||||
backendSettings := getBackendSettings(o, backendConfig)
|
backendSettings := o.getBackendSettings(backendConfig)
|
||||||
return backendSettings.ResponseHeaders
|
return backendSettings.ResponseHeaders
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateInstanceOptions performs validation based on backend type
|
// ValidateInstanceOptions performs validation based on backend type
|
||||||
func (o *Options) ValidateInstanceOptions() error {
|
func (o *Options) ValidateInstanceOptions() error {
|
||||||
// Validate based on backend type
|
backend := o.getBackend()
|
||||||
switch o.BackendType {
|
if backend == nil {
|
||||||
case BackendTypeLlamaCpp:
|
return validation.ValidationError(fmt.Errorf("backend options cannot be nil for backend type %s", o.BackendType))
|
||||||
return validateLlamaCppOptions(o.LlamaServerOptions)
|
|
||||||
case BackendTypeMlxLm:
|
|
||||||
return validateMlxOptions(o.MlxServerOptions)
|
|
||||||
case BackendTypeVllm:
|
|
||||||
return validateVllmOptions(o.VllmServerOptions)
|
|
||||||
default:
|
|
||||||
return validation.ValidationError(fmt.Errorf("unsupported backend type: %s", o.BackendType))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return backend.Validate()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -336,6 +336,36 @@ func (o *LlamaServerOptions) UnmarshalJSON(data []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPort returns the value of the Port field. The receiver must be non-nil.
func (o *LlamaServerOptions) GetPort() int {
|
||||||
|
return o.Port
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetPort stores port in the Port field. The receiver must be non-nil.
func (o *LlamaServerOptions) SetPort(port int) {
|
||||||
|
o.Port = port
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHost returns the value of the Host field. The receiver must be non-nil.
func (o *LlamaServerOptions) GetHost() string {
|
||||||
|
return o.Host
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate checks the llama.cpp server options.
//
// It returns a validation error when the receiver is nil or when Port lies
// outside 0..65535, and passes through any error reported by
// validation.ValidateStructStrings. It returns nil when every check passes.
func (o *LlamaServerOptions) Validate() error {
|
||||||
|
// A nil receiver is reported as an error rather than dereferenced.
if o == nil {
|
||||||
|
return validation.ValidationError(fmt.Errorf("llama server options cannot be nil for llama.cpp backend"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use reflection to check all string fields for injection patterns
|
||||||
|
if err := validation.ValidateStructStrings(o, ""); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Basic network validation for port
// (the lower bound is inclusive, so a Port of 0 is accepted)
|
||||||
|
if o.Port < 0 || o.Port > 65535 {
|
||||||
|
return validation.ValidationError(fmt.Errorf("invalid port range: %d", o.Port))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// BuildCommandArgs converts InstanceOptions to command line arguments
|
// BuildCommandArgs converts InstanceOptions to command line arguments
|
||||||
func (o *LlamaServerOptions) BuildCommandArgs() []string {
|
func (o *LlamaServerOptions) BuildCommandArgs() []string {
|
||||||
// Llama uses multiple flags for arrays by default (not comma-separated)
|
// Llama uses multiple flags for arrays by default (not comma-separated)
|
||||||
@@ -366,22 +396,3 @@ func ParseLlamaCommand(command string) (*LlamaServerOptions, error) {
|
|||||||
|
|
||||||
return &llamaOptions, nil
|
return &llamaOptions, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateLlamaCppOptions validates llama.cpp specific options
|
|
||||||
func validateLlamaCppOptions(options *LlamaServerOptions) error {
|
|
||||||
if options == nil {
|
|
||||||
return validation.ValidationError(fmt.Errorf("llama server options cannot be nil for llama.cpp backend"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use reflection to check all string fields for injection patterns
|
|
||||||
if err := validation.ValidateStructStrings(options, ""); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Basic network validation for port
|
|
||||||
if options.Port < 0 || options.Port > 65535 {
|
|
||||||
return validation.ValidationError(fmt.Errorf("invalid port range: %d", options.Port))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -31,12 +31,45 @@ type MlxServerOptions struct {
|
|||||||
MaxTokens int `json:"max_tokens,omitempty"`
|
MaxTokens int `json:"max_tokens,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPort returns the value of the Port field. The receiver must be non-nil.
func (o *MlxServerOptions) GetPort() int {
|
||||||
|
return o.Port
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetPort stores port in the Port field. The receiver must be non-nil.
func (o *MlxServerOptions) SetPort(port int) {
|
||||||
|
o.Port = port
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHost returns the value of the Host field. The receiver must be non-nil.
func (o *MlxServerOptions) GetHost() string {
|
||||||
|
return o.Host
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate checks the MLX server options.
//
// It returns a validation error when the receiver is nil or when Port lies
// outside 0..65535, and passes through any error reported by
// validation.ValidateStructStrings. It returns nil when every check passes.
func (o *MlxServerOptions) Validate() error {
|
||||||
|
// A nil receiver is reported as an error rather than dereferenced.
if o == nil {
|
||||||
|
return validation.ValidationError(fmt.Errorf("MLX server options cannot be nil for MLX backend"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// String-field checks are delegated to the validation package.
if err := validation.ValidateStructStrings(o, ""); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Basic network validation for port
// (the lower bound is inclusive, so a Port of 0 is accepted)
|
||||||
|
if o.Port < 0 || o.Port > 65535 {
|
||||||
|
return validation.ValidationError(fmt.Errorf("invalid port range: %d", o.Port))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// BuildCommandArgs converts to command line arguments
|
// BuildCommandArgs converts to command line arguments
|
||||||
func (o *MlxServerOptions) BuildCommandArgs() []string {
|
func (o *MlxServerOptions) BuildCommandArgs() []string {
|
||||||
multipleFlags := map[string]bool{} // MLX doesn't currently have []string fields
|
multipleFlags := map[string]bool{} // MLX doesn't currently have []string fields
|
||||||
return BuildCommandArgs(o, multipleFlags)
|
return BuildCommandArgs(o, multipleFlags)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BuildDockerArgs returns an empty, non-nil slice: the MLX options contribute
// no Docker arguments. The Docker check visible in Options.isDockerEnabled
// only returns true for backend types other than BackendTypeMlxLm.
func (o *MlxServerOptions) BuildDockerArgs() []string {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
|
||||||
// ParseMlxCommand parses a mlx_lm.server command string into MlxServerOptions
|
// ParseMlxCommand parses a mlx_lm.server command string into MlxServerOptions
|
||||||
// Supports multiple formats:
|
// Supports multiple formats:
|
||||||
// 1. Full command: "mlx_lm.server --model model/path"
|
// 1. Full command: "mlx_lm.server --model model/path"
|
||||||
@@ -55,21 +88,3 @@ func ParseMlxCommand(command string) (*MlxServerOptions, error) {
|
|||||||
|
|
||||||
return &mlxOptions, nil
|
return &mlxOptions, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateMlxOptions validates MLX backend specific options
|
|
||||||
func validateMlxOptions(options *MlxServerOptions) error {
|
|
||||||
if options == nil {
|
|
||||||
return validation.ValidationError(fmt.Errorf("MLX server options cannot be nil for MLX backend"))
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := validation.ValidateStructStrings(options, ""); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Basic network validation for port
|
|
||||||
if options.Port < 0 || options.Port > 65535 {
|
|
||||||
return validation.ValidationError(fmt.Errorf("invalid port range: %d", options.Port))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -140,6 +140,36 @@ type VllmServerOptions struct {
|
|||||||
OverrideKVCacheALIGNSize int `json:"override_kv_cache_align_size,omitempty"`
|
OverrideKVCacheALIGNSize int `json:"override_kv_cache_align_size,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPort returns the value of the Port field. The receiver must be non-nil.
func (o *VllmServerOptions) GetPort() int {
|
||||||
|
return o.Port
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetPort stores port in the Port field. The receiver must be non-nil.
func (o *VllmServerOptions) SetPort(port int) {
|
||||||
|
o.Port = port
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHost returns the value of the Host field. The receiver must be non-nil.
func (o *VllmServerOptions) GetHost() string {
|
||||||
|
return o.Host
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate checks the vLLM server options.
//
// It returns a validation error when the receiver is nil or when Port lies
// outside 0..65535, and passes through any error reported by
// validation.ValidateStructStrings. It returns nil when every check passes.
func (o *VllmServerOptions) Validate() error {
|
||||||
|
// A nil receiver is reported as an error rather than dereferenced.
if o == nil {
|
||||||
|
return validation.ValidationError(fmt.Errorf("vLLM server options cannot be nil for vLLM backend"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use reflection to check all string fields for injection patterns
|
||||||
|
if err := validation.ValidateStructStrings(o, ""); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Basic network validation for port
// (the lower bound is inclusive, so a Port of 0 is accepted)
|
||||||
|
if o.Port < 0 || o.Port > 65535 {
|
||||||
|
return validation.ValidationError(fmt.Errorf("invalid port range: %d", o.Port))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// BuildCommandArgs converts VllmServerOptions to command line arguments
|
// BuildCommandArgs converts VllmServerOptions to command line arguments
|
||||||
// For vLLM native, model is a positional argument after "serve"
|
// For vLLM native, model is a positional argument after "serve"
|
||||||
func (o *VllmServerOptions) BuildCommandArgs() []string {
|
func (o *VllmServerOptions) BuildCommandArgs() []string {
|
||||||
@@ -199,22 +229,3 @@ func ParseVllmCommand(command string) (*VllmServerOptions, error) {
|
|||||||
|
|
||||||
return &vllmOptions, nil
|
return &vllmOptions, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateVllmOptions validates vLLM backend specific options
|
|
||||||
func validateVllmOptions(options *VllmServerOptions) error {
|
|
||||||
if options == nil {
|
|
||||||
return validation.ValidationError(fmt.Errorf("vLLM server options cannot be nil for vLLM backend"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use reflection to check all string fields for injection patterns
|
|
||||||
if err := validation.ValidateStructStrings(options, ""); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Basic network validation for port
|
|
||||||
if options.Port < 0 || options.Port > 65535 {
|
|
||||||
return validation.ValidationError(fmt.Errorf("invalid port range: %d", options.Port))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user