diff --git a/pkg/server/handlers.go b/pkg/server/handlers.go index 4ddbfea..9e31df9 100644 --- a/pkg/server/handlers.go +++ b/pkg/server/handlers.go @@ -4,6 +4,8 @@ import ( "llamactl/pkg/config" "llamactl/pkg/manager" "net/http" + "net/http/httputil" + "sync" "time" ) @@ -11,6 +13,8 @@ type Handler struct { InstanceManager manager.InstanceManager cfg config.AppConfig httpClient *http.Client + remoteProxies map[string]*httputil.ReverseProxy // Cache of remote proxies by instance name + remoteProxiesMu sync.RWMutex } func NewHandler(im manager.InstanceManager, cfg config.AppConfig) *Handler { @@ -20,5 +24,6 @@ func NewHandler(im manager.InstanceManager, cfg config.AppConfig) *Handler { httpClient: &http.Client{ Timeout: 30 * time.Second, }, + remoteProxies: make(map[string]*httputil.ReverseProxy), } } diff --git a/pkg/server/handlers_instances.go b/pkg/server/handlers_instances.go index 6a90cba..90a00f9 100644 --- a/pkg/server/handlers_instances.go +++ b/pkg/server/handlers_instances.go @@ -3,11 +3,12 @@ package server import ( "encoding/json" "fmt" - "io" "llamactl/pkg/config" "llamactl/pkg/instance" "llamactl/pkg/manager" "net/http" + "net/http/httputil" + "net/url" "strconv" "strings" @@ -361,6 +362,12 @@ func (h *Handler) ProxyToInstance() http.HandlerFunc { return } + // Check if this is a remote instance + if inst.IsRemote() { + h.RemoteInstanceProxy(w, r, name, inst) + return + } + if !inst.IsRunning() { http.Error(w, "Instance is not running", http.StatusServiceUnavailable) return @@ -399,59 +406,53 @@ func (h *Handler) RemoteInstanceProxy(w http.ResponseWriter, r *http.Request, na } nodeName := options.Nodes[0] - var nodeConfig *config.NodeConfig - for i := range h.cfg.Nodes { - if h.cfg.Nodes[i].Name == nodeName { - nodeConfig = &h.cfg.Nodes[i] - break + + // Check if we have a cached proxy for this instance + h.remoteProxiesMu.RLock() + proxy, exists := h.remoteProxies[name] + h.remoteProxiesMu.RUnlock() + + if !exists { + // Find node configuration + var nodeConfig *config.NodeConfig + for i := range h.cfg.Nodes { + if h.cfg.Nodes[i].Name == nodeName { + nodeConfig = &h.cfg.Nodes[i] + break + } } - } - if nodeConfig == nil { - http.Error(w, fmt.Sprintf("Node %s not found", nodeName), http.StatusInternalServerError) - return - } - - // Strip the "/api/v1/instances//proxy" prefix from the request URL - prefix := fmt.Sprintf("/api/v1/instances/%s/proxy", name) - proxyPath := r.URL.Path[len(prefix):] - - // Build the remote URL - remoteURL := fmt.Sprintf("%s/api/v1/instances/%s/proxy%s", nodeConfig.Address, name, proxyPath) - - // Create a new request to the remote node - req, err := http.NewRequest(r.Method, remoteURL, r.Body) - if err != nil { - http.Error(w, "Failed to create remote request: "+err.Error(), http.StatusInternalServerError) - return - } - - // Copy headers - req.Header = r.Header.Clone() - - // Add API key if configured - if nodeConfig.APIKey != "" { - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", nodeConfig.APIKey)) - } - - // Forward the request - resp, err := h.httpClient.Do(req) - if err != nil { - http.Error(w, "Failed to proxy to remote instance: "+err.Error(), http.StatusBadGateway) - return - } - defer resp.Body.Close() - - // Copy response headers - for key, values := range resp.Header { - for _, value := range values { - w.Header().Add(key, value) + if nodeConfig == nil { + http.Error(w, fmt.Sprintf("Node %s not found", nodeName), http.StatusInternalServerError) + return } + + // Create reverse proxy to remote node + targetURL, err := url.Parse(nodeConfig.Address) + if err != nil { + http.Error(w, "Failed to parse node address: "+err.Error(), http.StatusInternalServerError) + return + } + + proxy = httputil.NewSingleHostReverseProxy(targetURL) + + // Modify request before forwarding + originalDirector := proxy.Director + apiKey := nodeConfig.APIKey // Capture for closure + proxy.Director = func(req *http.Request) { + originalDirector(req) + // Add API key if configured + if apiKey != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey)) + } + } + + // Cache the proxy + h.remoteProxiesMu.Lock() + h.remoteProxies[name] = proxy + h.remoteProxiesMu.Unlock() } - // Copy status code - w.WriteHeader(resp.StatusCode) - - // Copy response body - io.Copy(w, resp.Body) + // Forward the request using the cached proxy + proxy.ServeHTTP(w, r) }