Refactor backend options marshaling/unmarshaling

This commit is contained in:
2025-10-19 20:48:05 +02:00
parent d8e0da9cf8
commit b25ad48605

View File

@@ -25,6 +25,12 @@ type backend interface {
Validate() error Validate() error
} }
var backendConstructors = map[BackendType]func() backend{
BackendTypeLlamaCpp: func() backend { return &LlamaServerOptions{} },
BackendTypeMlxLm: func() backend { return &MlxServerOptions{} },
BackendTypeVllm: func() backend { return &VllmServerOptions{} },
}
type Options struct { type Options struct {
BackendType BackendType `json:"backend_type"` BackendType BackendType `json:"backend_type"`
BackendOptions map[string]any `json:"backend_options,omitempty"` BackendOptions map[string]any `json:"backend_options,omitempty"`
@@ -36,7 +42,6 @@ type Options struct {
} }
func (o *Options) UnmarshalJSON(data []byte) error { func (o *Options) UnmarshalJSON(data []byte) error {
// Use anonymous struct to avoid recursion
type Alias Options type Alias Options
aux := &struct { aux := &struct {
*Alias *Alias
@@ -48,52 +53,31 @@ func (o *Options) UnmarshalJSON(data []byte) error {
return err return err
} }
// Parse backend-specific options // Create backend from constructor map
switch o.BackendType {
case BackendTypeLlamaCpp:
if o.BackendOptions != nil { if o.BackendOptions != nil {
// Convert map to JSON and then unmarshal to LlamaServerOptions constructor, exists := backendConstructors[o.BackendType]
if !exists {
return fmt.Errorf("unsupported backend type: %s", o.BackendType)
}
backend := constructor()
optionsData, err := json.Marshal(o.BackendOptions) optionsData, err := json.Marshal(o.BackendOptions)
if err != nil { if err != nil {
return fmt.Errorf("failed to marshal backend options: %w", err) return fmt.Errorf("failed to marshal backend options: %w", err)
} }
o.LlamaServerOptions = &LlamaServerOptions{} if err := json.Unmarshal(optionsData, backend); err != nil {
if err := json.Unmarshal(optionsData, o.LlamaServerOptions); err != nil { return fmt.Errorf("failed to unmarshal backend options: %w", err)
return fmt.Errorf("failed to unmarshal llama.cpp options: %w", err)
}
}
case BackendTypeMlxLm:
if o.BackendOptions != nil {
optionsData, err := json.Marshal(o.BackendOptions)
if err != nil {
return fmt.Errorf("failed to marshal backend options: %w", err)
} }
o.MlxServerOptions = &MlxServerOptions{} // Store in the appropriate typed field for backward compatibility
if err := json.Unmarshal(optionsData, o.MlxServerOptions); err != nil { o.setBackendOptions(backend)
return fmt.Errorf("failed to unmarshal MLX options: %w", err)
}
}
case BackendTypeVllm:
if o.BackendOptions != nil {
optionsData, err := json.Marshal(o.BackendOptions)
if err != nil {
return fmt.Errorf("failed to marshal backend options: %w", err)
}
o.VllmServerOptions = &VllmServerOptions{}
if err := json.Unmarshal(optionsData, o.VllmServerOptions); err != nil {
return fmt.Errorf("failed to unmarshal vLLM options: %w", err)
}
}
} }
return nil return nil
} }
func (o *Options) MarshalJSON() ([]byte, error) { func (o *Options) MarshalJSON() ([]byte, error) {
// Use anonymous struct to avoid recursion
type Alias Options type Alias Options
aux := &struct { aux := &struct {
*Alias *Alias
@@ -101,48 +85,33 @@ func (o *Options) MarshalJSON() ([]byte, error) {
Alias: (*Alias)(o), Alias: (*Alias)(o),
} }
// Prepare BackendOptions map // Get backend and marshal it
if o.BackendOptions == nil { backend := o.getBackend()
o.BackendOptions = make(map[string]any) if backend != nil {
} optionsData, err := json.Marshal(backend)
// Populate BackendOptions based on backend-specific options
switch o.BackendType {
case BackendTypeLlamaCpp:
if o.LlamaServerOptions != nil {
optionsData, err := json.Marshal(o.LlamaServerOptions)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to marshal llama.cpp options: %w", err) return nil, fmt.Errorf("failed to marshal backend options: %w", err)
}
if err := json.Unmarshal(optionsData, &o.BackendOptions); err != nil {
return nil, fmt.Errorf("failed to unmarshal llama.cpp options to map: %w", err)
}
}
case BackendTypeMlxLm:
if o.MlxServerOptions != nil {
optionsData, err := json.Marshal(o.MlxServerOptions)
if err != nil {
return nil, fmt.Errorf("failed to marshal MLX options: %w", err)
}
if err := json.Unmarshal(optionsData, &o.BackendOptions); err != nil {
return nil, fmt.Errorf("failed to unmarshal MLX options to map: %w", err)
}
}
case BackendTypeVllm:
if o.VllmServerOptions != nil {
optionsData, err := json.Marshal(o.VllmServerOptions)
if err != nil {
return nil, fmt.Errorf("failed to marshal vLLM options: %w", err)
}
if err := json.Unmarshal(optionsData, &o.BackendOptions); err != nil {
return nil, fmt.Errorf("failed to unmarshal vLLM options to map: %w", err)
} }
if err := json.Unmarshal(optionsData, &aux.BackendOptions); err != nil {
return nil, fmt.Errorf("failed to unmarshal backend options to map: %w", err)
} }
} }
return json.Marshal(aux) return json.Marshal(aux)
} }
// setBackendOptions stores the backend in the appropriate typed field
func (o *Options) setBackendOptions(bcknd backend) {
switch v := bcknd.(type) {
case *LlamaServerOptions:
o.LlamaServerOptions = v
case *MlxServerOptions:
o.MlxServerOptions = v
case *VllmServerOptions:
o.VllmServerOptions = v
}
}
func (o *Options) getBackendSettings(backendConfig *config.BackendConfig) *config.BackendSettings { func (o *Options) getBackendSettings(backendConfig *config.BackendConfig) *config.BackendSettings {
switch o.BackendType { switch o.BackendType {
case BackendTypeLlamaCpp: case BackendTypeLlamaCpp: