diff --git a/pkg/backends/backend.go b/pkg/backends/backend.go index fc32549..ae22b1c 100644 --- a/pkg/backends/backend.go +++ b/pkg/backends/backend.go @@ -25,6 +25,12 @@ type backend interface { Validate() error } +var backendConstructors = map[BackendType]func() backend{ + BackendTypeLlamaCpp: func() backend { return &LlamaServerOptions{} }, + BackendTypeMlxLm: func() backend { return &MlxServerOptions{} }, + BackendTypeVllm: func() backend { return &VllmServerOptions{} }, +} + type Options struct { BackendType BackendType `json:"backend_type"` BackendOptions map[string]any `json:"backend_options,omitempty"` @@ -36,7 +42,6 @@ type Options struct { } func (o *Options) UnmarshalJSON(data []byte) error { - // Use anonymous struct to avoid recursion type Alias Options aux := &struct { *Alias @@ -48,52 +53,31 @@ func (o *Options) UnmarshalJSON(data []byte) error { return err } - // Parse backend-specific options - switch o.BackendType { - case BackendTypeLlamaCpp: - if o.BackendOptions != nil { - // Convert map to JSON and then unmarshal to LlamaServerOptions - optionsData, err := json.Marshal(o.BackendOptions) - if err != nil { - return fmt.Errorf("failed to marshal backend options: %w", err) - } - - o.LlamaServerOptions = &LlamaServerOptions{} - if err := json.Unmarshal(optionsData, o.LlamaServerOptions); err != nil { - return fmt.Errorf("failed to unmarshal llama.cpp options: %w", err) - } + // Create backend from constructor map + if o.BackendOptions != nil { + constructor, exists := backendConstructors[o.BackendType] + if !exists { + return fmt.Errorf("unsupported backend type: %s", o.BackendType) } - case BackendTypeMlxLm: - if o.BackendOptions != nil { - optionsData, err := json.Marshal(o.BackendOptions) - if err != nil { - return fmt.Errorf("failed to marshal backend options: %w", err) - } - o.MlxServerOptions = &MlxServerOptions{} - if err := json.Unmarshal(optionsData, o.MlxServerOptions); err != nil { - return fmt.Errorf("failed to unmarshal MLX options: %w", err) - } + backend := constructor() + optionsData, err := json.Marshal(o.BackendOptions) + if err != nil { + return fmt.Errorf("failed to marshal backend options: %w", err) } - case BackendTypeVllm: - if o.BackendOptions != nil { - optionsData, err := json.Marshal(o.BackendOptions) - if err != nil { - return fmt.Errorf("failed to marshal backend options: %w", err) - } - o.VllmServerOptions = &VllmServerOptions{} - if err := json.Unmarshal(optionsData, o.VllmServerOptions); err != nil { - return fmt.Errorf("failed to unmarshal vLLM options: %w", err) - } + if err := json.Unmarshal(optionsData, backend); err != nil { + return fmt.Errorf("failed to unmarshal backend options: %w", err) } + + // Store in the appropriate typed field for backward compatibility + o.setBackendOptions(backend) } return nil } func (o *Options) MarshalJSON() ([]byte, error) { - // Use anonymous struct to avoid recursion type Alias Options aux := &struct { *Alias @@ -101,48 +85,33 @@ func (o *Options) MarshalJSON() ([]byte, error) { Alias: (*Alias)(o), } - // Prepare BackendOptions map - if o.BackendOptions == nil { - o.BackendOptions = make(map[string]any) - } - - // Populate BackendOptions based on backend-specific options - switch o.BackendType { - case BackendTypeLlamaCpp: - if o.LlamaServerOptions != nil { - optionsData, err := json.Marshal(o.LlamaServerOptions) - if err != nil { - return nil, fmt.Errorf("failed to marshal llama.cpp options: %w", err) - } - if err := json.Unmarshal(optionsData, &o.BackendOptions); err != nil { - return nil, fmt.Errorf("failed to unmarshal llama.cpp options to map: %w", err) - } + // Get backend and marshal it + backend := o.getBackend() + if backend != nil { + optionsData, err := json.Marshal(backend) + if err != nil { + return nil, fmt.Errorf("failed to marshal backend options: %w", err) } - case BackendTypeMlxLm: - if o.MlxServerOptions != nil { - optionsData, err := json.Marshal(o.MlxServerOptions) - if err != nil { - return nil, fmt.Errorf("failed to marshal MLX options: %w", err) - } - if err := json.Unmarshal(optionsData, &o.BackendOptions); err != nil { - return nil, fmt.Errorf("failed to unmarshal MLX options to map: %w", err) - } - } - case BackendTypeVllm: - if o.VllmServerOptions != nil { - optionsData, err := json.Marshal(o.VllmServerOptions) - if err != nil { - return nil, fmt.Errorf("failed to marshal vLLM options: %w", err) - } - if err := json.Unmarshal(optionsData, &o.BackendOptions); err != nil { - return nil, fmt.Errorf("failed to unmarshal vLLM options to map: %w", err) - } + if err := json.Unmarshal(optionsData, &aux.BackendOptions); err != nil { + return nil, fmt.Errorf("failed to unmarshal backend options to map: %w", err) } } return json.Marshal(aux) } +// setBackendOptions stores the backend in the appropriate typed field +func (o *Options) setBackendOptions(bcknd backend) { + switch v := bcknd.(type) { + case *LlamaServerOptions: + o.LlamaServerOptions = v + case *MlxServerOptions: + o.MlxServerOptions = v + case *VllmServerOptions: + o.VllmServerOptions = v + } +} + func (o *Options) getBackendSettings(backendConfig *config.BackendConfig) *config.BackendSettings { switch o.BackendType { case BackendTypeLlamaCpp: