Mirror of https://github.com/lordmathis/llamactl.git, synced 2025-11-06 09:04:27 +00:00
Split large package into subpackages
pkg/backends/llamacpp/llama.go (new file, 365 lines)
@@ -0,0 +1,365 @@
package llamacpp

import (
    "encoding/json"
    "reflect"
    "strconv"
    "strings"
)

type LlamaServerOptions struct {
    // Common params
    VerbosePrompt bool `json:"verbose_prompt,omitempty"`
    Threads int `json:"threads,omitempty"`
    ThreadsBatch int `json:"threads_batch,omitempty"`
    CPUMask string `json:"cpu_mask,omitempty"`
    CPURange string `json:"cpu_range,omitempty"`
    CPUStrict int `json:"cpu_strict,omitempty"`
    Priority int `json:"priority,omitempty"`
    Poll int `json:"poll,omitempty"`
    CPUMaskBatch string `json:"cpu_mask_batch,omitempty"`
    CPURangeBatch string `json:"cpu_range_batch,omitempty"`
    CPUStrictBatch int `json:"cpu_strict_batch,omitempty"`
    PriorityBatch int `json:"priority_batch,omitempty"`
    PollBatch int `json:"poll_batch,omitempty"`
    CtxSize int `json:"ctx_size,omitempty"`
    Predict int `json:"predict,omitempty"`
    BatchSize int `json:"batch_size,omitempty"`
    UBatchSize int `json:"ubatch_size,omitempty"`
    Keep int `json:"keep,omitempty"`
    FlashAttn bool `json:"flash_attn,omitempty"`
    NoPerf bool `json:"no_perf,omitempty"`
    Escape bool `json:"escape,omitempty"`
    NoEscape bool `json:"no_escape,omitempty"`
    RopeScaling string `json:"rope_scaling,omitempty"`
    RopeScale float64 `json:"rope_scale,omitempty"`
    RopeFreqBase float64 `json:"rope_freq_base,omitempty"`
    RopeFreqScale float64 `json:"rope_freq_scale,omitempty"`
    YarnOrigCtx int `json:"yarn_orig_ctx,omitempty"`
    YarnExtFactor float64 `json:"yarn_ext_factor,omitempty"`
    YarnAttnFactor float64 `json:"yarn_attn_factor,omitempty"`
    YarnBetaSlow float64 `json:"yarn_beta_slow,omitempty"`
    YarnBetaFast float64 `json:"yarn_beta_fast,omitempty"`
    DumpKVCache bool `json:"dump_kv_cache,omitempty"`
    NoKVOffload bool `json:"no_kv_offload,omitempty"`
    CacheTypeK string `json:"cache_type_k,omitempty"`
    CacheTypeV string `json:"cache_type_v,omitempty"`
    DefragThold float64 `json:"defrag_thold,omitempty"`
    Parallel int `json:"parallel,omitempty"`
    Mlock bool `json:"mlock,omitempty"`
    NoMmap bool `json:"no_mmap,omitempty"`
    Numa string `json:"numa,omitempty"`
    Device string `json:"device,omitempty"`
    OverrideTensor []string `json:"override_tensor,omitempty"`
    GPULayers int `json:"gpu_layers,omitempty"`
    SplitMode string `json:"split_mode,omitempty"`
    TensorSplit string `json:"tensor_split,omitempty"`
    MainGPU int `json:"main_gpu,omitempty"`
    CheckTensors bool `json:"check_tensors,omitempty"`
    OverrideKV []string `json:"override_kv,omitempty"`
    Lora []string `json:"lora,omitempty"`
    LoraScaled []string `json:"lora_scaled,omitempty"`
    ControlVector []string `json:"control_vector,omitempty"`
    ControlVectorScaled []string `json:"control_vector_scaled,omitempty"`
    ControlVectorLayerRange string `json:"control_vector_layer_range,omitempty"`
    Model string `json:"model,omitempty"`
    ModelURL string `json:"model_url,omitempty"`
    HFRepo string `json:"hf_repo,omitempty"`
    HFRepoDraft string `json:"hf_repo_draft,omitempty"`
    HFFile string `json:"hf_file,omitempty"`
    HFRepoV string `json:"hf_repo_v,omitempty"`
    HFFileV string `json:"hf_file_v,omitempty"`
    HFToken string `json:"hf_token,omitempty"`
    LogDisable bool `json:"log_disable,omitempty"`
    LogFile string `json:"log_file,omitempty"`
    LogColors bool `json:"log_colors,omitempty"`
    Verbose bool `json:"verbose,omitempty"`
    Verbosity int `json:"verbosity,omitempty"`
    LogPrefix bool `json:"log_prefix,omitempty"`
    LogTimestamps bool `json:"log_timestamps,omitempty"`

    // Sampling params
    Samplers string `json:"samplers,omitempty"`
    Seed int `json:"seed,omitempty"`
    SamplingSeq string `json:"sampling_seq,omitempty"`
    IgnoreEOS bool `json:"ignore_eos,omitempty"`
    Temperature float64 `json:"temperature,omitempty"`
    TopK int `json:"top_k,omitempty"`
    TopP float64 `json:"top_p,omitempty"`
    MinP float64 `json:"min_p,omitempty"`
    XTCProbability float64 `json:"xtc_probability,omitempty"`
    XTCThreshold float64 `json:"xtc_threshold,omitempty"`
    Typical float64 `json:"typical,omitempty"`
    RepeatLastN int `json:"repeat_last_n,omitempty"`
    RepeatPenalty float64 `json:"repeat_penalty,omitempty"`
    PresencePenalty float64 `json:"presence_penalty,omitempty"`
    FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
    DryMultiplier float64 `json:"dry_multiplier,omitempty"`
    DryBase float64 `json:"dry_base,omitempty"`
    DryAllowedLength int `json:"dry_allowed_length,omitempty"`
    DryPenaltyLastN int `json:"dry_penalty_last_n,omitempty"`
    DrySequenceBreaker []string `json:"dry_sequence_breaker,omitempty"`
    DynatempRange float64 `json:"dynatemp_range,omitempty"`
    DynatempExp float64 `json:"dynatemp_exp,omitempty"`
    Mirostat int `json:"mirostat,omitempty"`
    MirostatLR float64 `json:"mirostat_lr,omitempty"`
    MirostatEnt float64 `json:"mirostat_ent,omitempty"`
    LogitBias []string `json:"logit_bias,omitempty"`
    Grammar string `json:"grammar,omitempty"`
    GrammarFile string `json:"grammar_file,omitempty"`
    JSONSchema string `json:"json_schema,omitempty"`
    JSONSchemaFile string `json:"json_schema_file,omitempty"`

    // Server/Example-specific params
    NoContextShift bool `json:"no_context_shift,omitempty"`
    Special bool `json:"special,omitempty"`
    NoWarmup bool `json:"no_warmup,omitempty"`
    SPMInfill bool `json:"spm_infill,omitempty"`
    Pooling string `json:"pooling,omitempty"`
    ContBatching bool `json:"cont_batching,omitempty"`
    NoContBatching bool `json:"no_cont_batching,omitempty"`
    MMProj string `json:"mmproj,omitempty"`
    MMProjURL string `json:"mmproj_url,omitempty"`
    NoMMProj bool `json:"no_mmproj,omitempty"`
    NoMMProjOffload bool `json:"no_mmproj_offload,omitempty"`
    Alias string `json:"alias,omitempty"`
    Host string `json:"host,omitempty"`
    Port int `json:"port,omitempty"`
    Path string `json:"path,omitempty"`
    NoWebUI bool `json:"no_webui,omitempty"`
    Embedding bool `json:"embedding,omitempty"`
    Reranking bool `json:"reranking,omitempty"`
    APIKey string `json:"api_key,omitempty"`
    APIKeyFile string `json:"api_key_file,omitempty"`
    SSLKeyFile string `json:"ssl_key_file,omitempty"`
    SSLCertFile string `json:"ssl_cert_file,omitempty"`
    ChatTemplateKwargs string `json:"chat_template_kwargs,omitempty"`
    Timeout int `json:"timeout,omitempty"`
    ThreadsHTTP int `json:"threads_http,omitempty"`
    CacheReuse int `json:"cache_reuse,omitempty"`
    Metrics bool `json:"metrics,omitempty"`
    Slots bool `json:"slots,omitempty"`
    Props bool `json:"props,omitempty"`
    NoSlots bool `json:"no_slots,omitempty"`
    SlotSavePath string `json:"slot_save_path,omitempty"`
    Jinja bool `json:"jinja,omitempty"`
    ReasoningFormat string `json:"reasoning_format,omitempty"`
    ReasoningBudget int `json:"reasoning_budget,omitempty"`
    ChatTemplate string `json:"chat_template,omitempty"`
    ChatTemplateFile string `json:"chat_template_file,omitempty"`
    NoPrefillAssistant bool `json:"no_prefill_assistant,omitempty"`
    SlotPromptSimilarity float64 `json:"slot_prompt_similarity,omitempty"`
    LoraInitWithoutApply bool `json:"lora_init_without_apply,omitempty"`

    // Speculative decoding params
    DraftMax int `json:"draft_max,omitempty"`
    DraftMin int `json:"draft_min,omitempty"`
    DraftPMin float64 `json:"draft_p_min,omitempty"`
    CtxSizeDraft int `json:"ctx_size_draft,omitempty"`
    DeviceDraft string `json:"device_draft,omitempty"`
    GPULayersDraft int `json:"gpu_layers_draft,omitempty"`
    ModelDraft string `json:"model_draft,omitempty"`
    CacheTypeKDraft string `json:"cache_type_k_draft,omitempty"`
    CacheTypeVDraft string `json:"cache_type_v_draft,omitempty"`

    // Audio/TTS params
    ModelVocoder string `json:"model_vocoder,omitempty"`
    TTSUseGuideTokens bool `json:"tts_use_guide_tokens,omitempty"`

    // Default model params
    EmbdBGESmallEnDefault bool `json:"embd_bge_small_en_default,omitempty"`
    EmbdE5SmallEnDefault bool `json:"embd_e5_small_en_default,omitempty"`
    EmbdGTESmallDefault bool `json:"embd_gte_small_default,omitempty"`
    FIMQwen1_5BDefault bool `json:"fim_qwen_1_5b_default,omitempty"`
    FIMQwen3BDefault bool `json:"fim_qwen_3b_default,omitempty"`
    FIMQwen7BDefault bool `json:"fim_qwen_7b_default,omitempty"`
    FIMQwen7BSpec bool `json:"fim_qwen_7b_spec,omitempty"`
    FIMQwen14BSpec bool `json:"fim_qwen_14b_spec,omitempty"`
}

// UnmarshalJSON implements custom JSON unmarshaling to support multiple field names
func (o *LlamaServerOptions) UnmarshalJSON(data []byte) error {
    // First unmarshal into a map to handle multiple field names
    var raw map[string]any
    if err := json.Unmarshal(data, &raw); err != nil {
        return err
    }

    // Create a temporary struct for standard unmarshaling
    type tempOptions LlamaServerOptions
    temp := tempOptions{}

    // Standard unmarshal first
    if err := json.Unmarshal(data, &temp); err != nil {
        return err
    }

    // Copy to our struct
    *o = LlamaServerOptions(temp)

    // Handle alternative field names
    fieldMappings := map[string]string{
        // Official llama-server short forms from the documentation
        "t": "threads", // -t, --threads N
        "tb": "threads_batch", // -tb, --threads-batch N
        "C": "cpu_mask", // -C, --cpu-mask M
        "Cr": "cpu_range", // -Cr, --cpu-range lo-hi
        "Cb": "cpu_mask_batch", // -Cb, --cpu-mask-batch M
        "Crb": "cpu_range_batch", // -Crb, --cpu-range-batch lo-hi
        "c": "ctx_size", // -c, --ctx-size N
        "n": "predict", // -n, --predict, --n-predict N
        "b": "batch_size", // -b, --batch-size N
        "ub": "ubatch_size", // -ub, --ubatch-size N
        "fa": "flash_attn", // -fa, --flash-attn
        "e": "escape", // -e, --escape
        "dkvc": "dump_kv_cache", // -dkvc, --dump-kv-cache
        "nkvo": "no_kv_offload", // -nkvo, --no-kv-offload
        "ctk": "cache_type_k", // -ctk, --cache-type-k TYPE
        "ctv": "cache_type_v", // -ctv, --cache-type-v TYPE
        "dt": "defrag_thold", // -dt, --defrag-thold N
        "np": "parallel", // -np, --parallel N
        "dev": "device", // -dev, --device <dev1,dev2,..>
        "ot": "override_tensor", // --override-tensor, -ot
        "ngl": "gpu_layers", // -ngl, --gpu-layers, --n-gpu-layers N
        "sm": "split_mode", // -sm, --split-mode
        "ts": "tensor_split", // -ts, --tensor-split N0,N1,N2,...
        "mg": "main_gpu", // -mg, --main-gpu INDEX
        "m": "model", // -m, --model FNAME
        "mu": "model_url", // -mu, --model-url MODEL_URL
        "hf": "hf_repo", // -hf, -hfr, --hf-repo
        "hfr": "hf_repo", // -hf, -hfr, --hf-repo
        "hfd": "hf_repo_draft", // -hfd, -hfrd, --hf-repo-draft
        "hfrd": "hf_repo_draft", // -hfd, -hfrd, --hf-repo-draft
        "hff": "hf_file", // -hff, --hf-file FILE
        "hfv": "hf_repo_v", // -hfv, -hfrv, --hf-repo-v
        "hfrv": "hf_repo_v", // -hfv, -hfrv, --hf-repo-v
        "hffv": "hf_file_v", // -hffv, --hf-file-v FILE
        "hft": "hf_token", // -hft, --hf-token TOKEN
        "v": "verbose", // -v, --verbose, --log-verbose
        "lv": "verbosity", // -lv, --verbosity, --log-verbosity N
        "s": "seed", // -s, --seed SEED
        "temp": "temperature", // --temp N
        "l": "logit_bias", // -l, --logit-bias
        "j": "json_schema", // -j, --json-schema SCHEMA
        "jf": "json_schema_file", // -jf, --json-schema-file FILE
        "sp": "special", // -sp, --special
        "cb": "cont_batching", // -cb, --cont-batching
        "nocb": "no_cont_batching", // -nocb, --no-cont-batching
        "a": "alias", // -a, --alias STRING
        "to": "timeout", // -to, --timeout N
        "sps": "slot_prompt_similarity", // -sps, --slot-prompt-similarity
        "cd": "ctx_size_draft", // -cd, --ctx-size-draft N
        "devd": "device_draft", // -devd, --device-draft
        "ngld": "gpu_layers_draft", // -ngld, --gpu-layers-draft
        "md": "model_draft", // -md, --model-draft FNAME
        "ctkd": "cache_type_k_draft", // -ctkd, --cache-type-k-draft TYPE
        "ctvd": "cache_type_v_draft", // -ctvd, --cache-type-v-draft TYPE
        "mv": "model_vocoder", // -mv, --model-vocoder FNAME
    }

    // Process alternative field names
    for altName, canonicalName := range fieldMappings {
        if value, exists := raw[altName]; exists {
            // Use reflection to set the field value
            v := reflect.ValueOf(o).Elem()
            field := v.FieldByNameFunc(func(fieldName string) bool {
                field, _ := v.Type().FieldByName(fieldName)
                jsonTag := field.Tag.Get("json")
                return jsonTag == canonicalName+",omitempty" || jsonTag == canonicalName
            })

            if field.IsValid() && field.CanSet() {
                switch field.Kind() {
                case reflect.Int:
                    if intVal, ok := value.(float64); ok {
                        field.SetInt(int64(intVal))
                    } else if strVal, ok := value.(string); ok {
                        if intVal, err := strconv.Atoi(strVal); err == nil {
                            field.SetInt(int64(intVal))
                        }
                    }
                case reflect.Float64:
                    if floatVal, ok := value.(float64); ok {
                        field.SetFloat(floatVal)
                    } else if strVal, ok := value.(string); ok {
                        if floatVal, err := strconv.ParseFloat(strVal, 64); err == nil {
                            field.SetFloat(floatVal)
                        }
                    }
                case reflect.String:
                    if strVal, ok := value.(string); ok {
                        field.SetString(strVal)
                    }
                case reflect.Bool:
                    if boolVal, ok := value.(bool); ok {
                        field.SetBool(boolVal)
                    }
                }
            }
        }
    }

    return nil
}

// BuildCommandArgs converts LlamaServerOptions to command line arguments
func (o *LlamaServerOptions) BuildCommandArgs() []string {
    var args []string

    v := reflect.ValueOf(o).Elem()
    t := v.Type()

    for i := 0; i < v.NumField(); i++ {
        field := v.Field(i)
        fieldType := t.Field(i)

        // Skip unexported fields
        if !field.CanInterface() {
            continue
        }

        // Get the JSON tag to determine the flag name
        jsonTag := fieldType.Tag.Get("json")
        if jsonTag == "" || jsonTag == "-" {
            continue
        }

        // Remove ",omitempty" from the tag
        flagName := jsonTag
        if commaIndex := strings.Index(jsonTag, ","); commaIndex != -1 {
            flagName = jsonTag[:commaIndex]
        }

        // Convert snake_case to kebab-case for CLI flags
        flagName = strings.ReplaceAll(flagName, "_", "-")

        // Add the appropriate arguments based on field type and value
        switch field.Kind() {
        case reflect.Bool:
            if field.Bool() {
                args = append(args, "--"+flagName)
            }
        case reflect.Int:
            if field.Int() != 0 {
                args = append(args, "--"+flagName, strconv.FormatInt(field.Int(), 10))
            }
        case reflect.Float64:
            if field.Float() != 0 {
                args = append(args, "--"+flagName, strconv.FormatFloat(field.Float(), 'f', -1, 64))
            }
        case reflect.String:
            if field.String() != "" {
                args = append(args, "--"+flagName, field.String())
            }
        case reflect.Slice:
            if field.Type().Elem().Kind() == reflect.String {
                // Handle []string fields
                for j := 0; j < field.Len(); j++ {
                    args = append(args, "--"+flagName, field.Index(j).String())
                }
            }
        }
    }

    return args
}
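A minimal usage sketch (not part of the commit; the package path is taken from the test file's import): short-form keys such as "m", "c", "ngl", and "temp" are resolved to their canonical fields by the custom UnmarshalJSON, and BuildCommandArgs then emits the populated fields as kebab-case llama-server flags in struct declaration order.

package main

import (
    "encoding/json"
    "fmt"

    "llamactl/pkg/backends/llamacpp"
)

func main() {
    // Short-form keys are accepted alongside the canonical snake_case names.
    raw := []byte(`{"m": "/models/llama.gguf", "c": 4096, "ngl": 32, "temp": 0.7}`)

    var opts llamacpp.LlamaServerOptions
    if err := json.Unmarshal(raw, &opts); err != nil {
        panic(err)
    }

    // Prints flags in struct field order:
    // [--ctx-size 4096 --gpu-layers 32 --model /models/llama.gguf --temperature 0.7]
    fmt.Println(opts.BuildCommandArgs())
}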
pkg/backends/llamacpp/llama_test.go (new file, 396 lines)
@@ -0,0 +1,396 @@
package llamacpp_test

import (
    "encoding/json"
    "fmt"
    "llamactl/pkg/backends/llamacpp"
    "reflect"
    "slices"
    "testing"
)

func TestBuildCommandArgs_BasicFields(t *testing.T) {
    options := llamacpp.LlamaServerOptions{
        Model: "/path/to/model.gguf",
        Port: 8080,
        Host: "localhost",
        Verbose: true,
        CtxSize: 4096,
        GPULayers: 32,
    }

    args := options.BuildCommandArgs()

    // Check individual arguments
    expectedPairs := map[string]string{
        "--model": "/path/to/model.gguf",
        "--port": "8080",
        "--host": "localhost",
        "--ctx-size": "4096",
        "--gpu-layers": "32",
    }

    for flag, expectedValue := range expectedPairs {
        if !containsFlagWithValue(args, flag, expectedValue) {
            t.Errorf("Expected %s %s, not found in %v", flag, expectedValue, args)
        }
    }

    // Check standalone boolean flag
    if !contains(args, "--verbose") {
        t.Errorf("Expected --verbose flag not found in %v", args)
    }
}

func TestBuildCommandArgs_BooleanFields(t *testing.T) {
    tests := []struct {
        name string
        options llamacpp.LlamaServerOptions
        expected []string
        excluded []string
    }{
        {
            name: "verbose true",
            options: llamacpp.LlamaServerOptions{
                Verbose: true,
            },
            expected: []string{"--verbose"},
        },
        {
            name: "verbose false",
            options: llamacpp.LlamaServerOptions{
                Verbose: false,
            },
            excluded: []string{"--verbose"},
        },
        {
            name: "multiple booleans",
            options: llamacpp.LlamaServerOptions{
                Verbose: true,
                FlashAttn: true,
                Mlock: false,
                NoMmap: true,
            },
            expected: []string{"--verbose", "--flash-attn", "--no-mmap"},
            excluded: []string{"--mlock"},
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            args := tt.options.BuildCommandArgs()

            for _, expectedArg := range tt.expected {
                if !contains(args, expectedArg) {
                    t.Errorf("Expected argument %q not found in %v", expectedArg, args)
                }
            }

            for _, excludedArg := range tt.excluded {
                if contains(args, excludedArg) {
                    t.Errorf("Excluded argument %q found in %v", excludedArg, args)
                }
            }
        })
    }
}

func TestBuildCommandArgs_NumericFields(t *testing.T) {
    options := llamacpp.LlamaServerOptions{
        Port: 8080,
        Threads: 4,
        CtxSize: 2048,
        GPULayers: 16,
        Temperature: 0.7,
        TopK: 40,
        TopP: 0.9,
    }

    args := options.BuildCommandArgs()

    expectedPairs := map[string]string{
        "--port": "8080",
        "--threads": "4",
        "--ctx-size": "2048",
        "--gpu-layers": "16",
        "--temperature": "0.7",
        "--top-k": "40",
        "--top-p": "0.9",
    }

    for flag, expectedValue := range expectedPairs {
        if !containsFlagWithValue(args, flag, expectedValue) {
            t.Errorf("Expected %s %s, not found in %v", flag, expectedValue, args)
        }
    }
}

func TestBuildCommandArgs_ZeroValues(t *testing.T) {
    options := llamacpp.LlamaServerOptions{
        Port: 0, // Should be excluded
        Threads: 0, // Should be excluded
        Temperature: 0, // Should be excluded
        Model: "", // Should be excluded
        Verbose: false, // Should be excluded
    }

    args := options.BuildCommandArgs()

    // Zero values should not appear in arguments
    excludedArgs := []string{
        "--port", "0",
        "--threads", "0",
        "--temperature", "0",
        "--model", "",
        "--verbose",
    }

    for _, excludedArg := range excludedArgs {
        if contains(args, excludedArg) {
            t.Errorf("Zero value argument %q should not be present in %v", excludedArg, args)
        }
    }
}

func TestBuildCommandArgs_ArrayFields(t *testing.T) {
    options := llamacpp.LlamaServerOptions{
        Lora: []string{"adapter1.bin", "adapter2.bin"},
        OverrideTensor: []string{"tensor1", "tensor2", "tensor3"},
        DrySequenceBreaker: []string{".", "!", "?"},
    }

    args := options.BuildCommandArgs()

    // Check that each array value appears with its flag
    expectedOccurrences := map[string][]string{
        "--lora": {"adapter1.bin", "adapter2.bin"},
        "--override-tensor": {"tensor1", "tensor2", "tensor3"},
        "--dry-sequence-breaker": {".", "!", "?"},
    }

    for flag, values := range expectedOccurrences {
        for _, value := range values {
            if !containsFlagWithValue(args, flag, value) {
                t.Errorf("Expected %s %s, not found in %v", flag, value, args)
            }
        }
    }
}

func TestBuildCommandArgs_EmptyArrays(t *testing.T) {
    options := llamacpp.LlamaServerOptions{
        Lora: []string{}, // Empty array should not generate args
        OverrideTensor: []string{}, // Empty array should not generate args
    }

    args := options.BuildCommandArgs()

    excludedArgs := []string{"--lora", "--override-tensor"}
    for _, excludedArg := range excludedArgs {
        if contains(args, excludedArg) {
            t.Errorf("Empty array should not generate argument %q in %v", excludedArg, args)
        }
    }
}

func TestBuildCommandArgs_FieldNameConversion(t *testing.T) {
    // Test snake_case to kebab-case conversion
    options := llamacpp.LlamaServerOptions{
        CtxSize: 4096,
        GPULayers: 32,
        ThreadsBatch: 2,
        FlashAttn: true,
        TopK: 40,
        TopP: 0.9,
    }

    args := options.BuildCommandArgs()

    // Check that field names are properly converted
    expectedFlags := []string{
        "--ctx-size", // ctx_size -> ctx-size
        "--gpu-layers", // gpu_layers -> gpu-layers
        "--threads-batch", // threads_batch -> threads-batch
        "--flash-attn", // flash_attn -> flash-attn
        "--top-k", // top_k -> top-k
        "--top-p", // top_p -> top-p
    }

    for _, flag := range expectedFlags {
        if !contains(args, flag) {
            t.Errorf("Expected flag %q not found in %v", flag, args)
        }
    }
}

func TestUnmarshalJSON_StandardFields(t *testing.T) {
    jsonData := `{
        "model": "/path/to/model.gguf",
        "port": 8080,
        "host": "localhost",
        "verbose": true,
        "ctx_size": 4096,
        "gpu_layers": 32,
        "temperature": 0.7
    }`

    var options llamacpp.LlamaServerOptions
    err := json.Unmarshal([]byte(jsonData), &options)
    if err != nil {
        t.Fatalf("Unmarshal failed: %v", err)
    }

    if options.Model != "/path/to/model.gguf" {
        t.Errorf("Expected model '/path/to/model.gguf', got %q", options.Model)
    }
    if options.Port != 8080 {
        t.Errorf("Expected port 8080, got %d", options.Port)
    }
    if options.Host != "localhost" {
        t.Errorf("Expected host 'localhost', got %q", options.Host)
    }
    if !options.Verbose {
        t.Error("Expected verbose to be true")
    }
    if options.CtxSize != 4096 {
        t.Errorf("Expected ctx_size 4096, got %d", options.CtxSize)
    }
    if options.GPULayers != 32 {
        t.Errorf("Expected gpu_layers 32, got %d", options.GPULayers)
    }
    if options.Temperature != 0.7 {
        t.Errorf("Expected temperature 0.7, got %f", options.Temperature)
    }
}

func TestUnmarshalJSON_AlternativeFieldNames(t *testing.T) {
    tests := []struct {
        name string
        jsonData string
        checkFn func(llamacpp.LlamaServerOptions) error
    }{
        {
            name: "threads alternatives",
            jsonData: `{"t": 4, "tb": 2}`,
            checkFn: func(opts llamacpp.LlamaServerOptions) error {
                if opts.Threads != 4 {
                    return fmt.Errorf("expected threads 4, got %d", opts.Threads)
                }
                if opts.ThreadsBatch != 2 {
                    return fmt.Errorf("expected threads_batch 2, got %d", opts.ThreadsBatch)
                }
                return nil
            },
        },
        {
            name: "context size alternatives",
            jsonData: `{"c": 2048}`,
            checkFn: func(opts llamacpp.LlamaServerOptions) error {
                if opts.CtxSize != 2048 {
                    return fmt.Errorf("expected ctx_size 2048, got %d", opts.CtxSize)
                }
                return nil
            },
        },
        {
            name: "gpu layers alternatives",
            jsonData: `{"ngl": 16}`,
            checkFn: func(opts llamacpp.LlamaServerOptions) error {
                if opts.GPULayers != 16 {
                    return fmt.Errorf("expected gpu_layers 16, got %d", opts.GPULayers)
                }
                return nil
            },
        },
        {
            name: "model alternatives",
            jsonData: `{"m": "/path/model.gguf"}`,
            checkFn: func(opts llamacpp.LlamaServerOptions) error {
                if opts.Model != "/path/model.gguf" {
                    return fmt.Errorf("expected model '/path/model.gguf', got %q", opts.Model)
                }
                return nil
            },
        },
        {
            name: "temperature alternatives",
            jsonData: `{"temp": 0.8}`,
            checkFn: func(opts llamacpp.LlamaServerOptions) error {
                if opts.Temperature != 0.8 {
                    return fmt.Errorf("expected temperature 0.8, got %f", opts.Temperature)
                }
                return nil
            },
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            var options llamacpp.LlamaServerOptions
            err := json.Unmarshal([]byte(tt.jsonData), &options)
            if err != nil {
                t.Fatalf("Unmarshal failed: %v", err)
            }

            if err := tt.checkFn(options); err != nil {
                t.Error(err)
            }
        })
    }
}

func TestUnmarshalJSON_InvalidJSON(t *testing.T) {
    invalidJSON := `{"port": "not-a-number", "invalid": syntax}`

    var options llamacpp.LlamaServerOptions
    err := json.Unmarshal([]byte(invalidJSON), &options)
    if err == nil {
        t.Error("Expected error for invalid JSON")
    }
}

func TestUnmarshalJSON_ArrayFields(t *testing.T) {
    jsonData := `{
        "lora": ["adapter1.bin", "adapter2.bin"],
        "override_tensor": ["tensor1", "tensor2"],
        "dry_sequence_breaker": [".", "!", "?"]
    }`

    var options llamacpp.LlamaServerOptions
    err := json.Unmarshal([]byte(jsonData), &options)
    if err != nil {
        t.Fatalf("Unmarshal failed: %v", err)
    }

    expectedLora := []string{"adapter1.bin", "adapter2.bin"}
    if !reflect.DeepEqual(options.Lora, expectedLora) {
        t.Errorf("Expected lora %v, got %v", expectedLora, options.Lora)
    }

    expectedTensors := []string{"tensor1", "tensor2"}
    if !reflect.DeepEqual(options.OverrideTensor, expectedTensors) {
        t.Errorf("Expected override_tensor %v, got %v", expectedTensors, options.OverrideTensor)
    }

    expectedBreakers := []string{".", "!", "?"}
    if !reflect.DeepEqual(options.DrySequenceBreaker, expectedBreakers) {
        t.Errorf("Expected dry_sequence_breaker %v, got %v", expectedBreakers, options.DrySequenceBreaker)
    }
}

// Helper functions
func contains(slice []string, item string) bool {
    return slices.Contains(slice, item)
}

func containsFlagWithValue(args []string, flag, value string) bool {
    for i, arg := range args {
        if arg == flag {
            // Check if there's a next argument and it matches the expected value
            if i+1 < len(args) && args[i+1] == value {
                return true
            }
        }
    }
    return false
}
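These tests should run with the standard toolchain (go test ./pkg/backends/llamacpp from the module root); note that the stdlib slices package used by the contains helper requires Go 1.21 or newer.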