Add MLX backend support with configuration and parsing enhancements

This commit is contained in:
2025-09-16 22:38:39 +02:00
parent cc5d8acd92
commit 587be68077
7 changed files with 214 additions and 37 deletions

View File

@@ -1,4 +1,15 @@
import { type CreateInstanceOptions, type BackendOptions, getAllFieldKeys, getAllBackendFieldKeys } from '@/schemas/instanceOptions'
import {
type CreateInstanceOptions,
type LlamaCppBackendOptions,
type MlxBackendOptions,
LlamaCppBackendOptionsSchema,
MlxBackendOptionsSchema,
getAllFieldKeys,
getAllLlamaCppFieldKeys,
getAllMlxFieldKeys,
getLlamaCppFieldType,
getMlxFieldType
} from '@/schemas/instanceOptions'
// Instance-level basic fields (not backend-specific)
export const basicFieldsConfig: Record<string, {
@@ -36,8 +47,8 @@ export const basicFieldsConfig: Record<string, {
}
}
// Backend-specific basic fields (these go in backend_options)
export const basicBackendFieldsConfig: Record<string, {
// LlamaCpp backend-specific basic fields
const basicLlamaCppFieldsConfig: Record<string, {
label: string
description?: string
placeholder?: string
@@ -46,7 +57,8 @@ export const basicBackendFieldsConfig: Record<string, {
model: {
label: 'Model Path',
placeholder: '/path/to/model.gguf',
description: 'Path to the model file'
description: 'Path to the model file',
required: true
},
hf_repo: {
label: 'Hugging Face Repository',
@@ -65,13 +77,55 @@ export const basicBackendFieldsConfig: Record<string, {
}
}
export function isBasicField(key: keyof CreateInstanceOptions): boolean {
// MLX backend-specific basic fields
const basicMlxFieldsConfig: Record<string, {
label: string
description?: string
placeholder?: string
required?: boolean
}> = {
model: {
label: 'Model',
placeholder: 'mlx-community/Mistral-7B-Instruct-v0.3-4bit',
description: 'The path to the MLX model weights, tokenizer, and config',
required: true
},
python_path: {
label: 'Python Virtual Environment Path',
placeholder: '/path/to/venv',
description: 'Path to Python virtual environment (optional)'
},
temp: {
label: 'Temperature',
placeholder: '0.0',
description: 'Default sampling temperature (default: 0.0)'
},
top_p: {
label: 'Top-P',
placeholder: '1.0',
description: 'Default nucleus sampling top-p (default: 1.0)'
},
top_k: {
label: 'Top-K',
placeholder: '0',
description: 'Default top-k sampling (default: 0, disables top-k)'
},
min_p: {
label: 'Min-P',
placeholder: '0.0',
description: 'Default min-p sampling (default: 0.0, disables min-p)'
},
max_tokens: {
label: 'Max Tokens',
placeholder: '512',
description: 'Default maximum number of tokens to generate (default: 512)'
}
}
function isBasicField(key: keyof CreateInstanceOptions): boolean {
return key in basicFieldsConfig
}
export function isBasicBackendField(key: keyof BackendOptions): boolean {
return key in basicBackendFieldsConfig
}
export function getBasicFields(): (keyof CreateInstanceOptions)[] {
return Object.keys(basicFieldsConfig) as (keyof CreateInstanceOptions)[]
@@ -81,13 +135,61 @@ export function getAdvancedFields(): (keyof CreateInstanceOptions)[] {
return getAllFieldKeys().filter(key => !isBasicField(key))
}
export function getBasicBackendFields(): (keyof BackendOptions)[] {
return Object.keys(basicBackendFieldsConfig) as (keyof BackendOptions)[]
export function getBasicBackendFields(backendType?: string): string[] {
if (backendType === 'mlx_lm') {
return Object.keys(basicMlxFieldsConfig)
} else if (backendType === 'llama_cpp') {
return Object.keys(basicLlamaCppFieldsConfig)
}
// Default to LlamaCpp for backward compatibility
return Object.keys(basicLlamaCppFieldsConfig)
}
export function getAdvancedBackendFields(): (keyof BackendOptions)[] {
return getAllBackendFieldKeys().filter(key => !isBasicBackendField(key))
export function getAdvancedBackendFields(backendType?: string): string[] {
if (backendType === 'mlx_lm') {
return getAllMlxFieldKeys().filter(key => !(key in basicMlxFieldsConfig))
} else if (backendType === 'llama_cpp') {
return getAllLlamaCppFieldKeys().filter(key => !(key in basicLlamaCppFieldsConfig))
}
// Default to LlamaCpp for backward compatibility
return getAllLlamaCppFieldKeys().filter(key => !(key in basicLlamaCppFieldsConfig))
}
// Combined backend fields config for use in BackendFormField
export const basicBackendFieldsConfig: Record<string, {
label: string
description?: string
placeholder?: string
required?: boolean
}> = {
...basicLlamaCppFieldsConfig,
...basicMlxFieldsConfig
}
// Get field type for any backend option (union type)
export function getBackendFieldType(key: string): 'text' | 'number' | 'boolean' | 'array' {
// Try to get type from LlamaCpp schema first
try {
if (LlamaCppBackendOptionsSchema.shape && key in LlamaCppBackendOptionsSchema.shape) {
return getLlamaCppFieldType(key as keyof LlamaCppBackendOptions)
}
} catch {
// Schema might not be available
}
// Try MLX schema
try {
if (MlxBackendOptionsSchema.shape && key in MlxBackendOptionsSchema.shape) {
return getMlxFieldType(key as keyof MlxBackendOptions)
}
} catch {
// Schema might not be available
}
// Default fallback
return 'text'
}
// Re-export the Zod-based functions
export { getFieldType, getBackendFieldType } from '@/schemas/instanceOptions'
export { getFieldType } from '@/schemas/instanceOptions'