From 6298b03636fa407fa69f0c947a10980990b1595d Mon Sep 17 00:00:00 2001 From: LordMathis Date: Tue, 7 Oct 2025 18:57:08 +0200 Subject: [PATCH] Refactor RemoteOpenAIProxy to use cached proxies and restore request body handling --- pkg/server/handlers_instances.go | 8 +-- pkg/server/handlers_openai.go | 101 +++++++++++++++---------------- 2 files changed, 54 insertions(+), 55 deletions(-) diff --git a/pkg/server/handlers_instances.go b/pkg/server/handlers_instances.go index 90a00f9..8c325ae 100644 --- a/pkg/server/handlers_instances.go +++ b/pkg/server/handlers_instances.go @@ -407,9 +407,9 @@ func (h *Handler) RemoteInstanceProxy(w http.ResponseWriter, r *http.Request, na nodeName := options.Nodes[0] - // Check if we have a cached proxy for this instance + // Check if we have a cached proxy for this node h.remoteProxiesMu.RLock() - proxy, exists := h.remoteProxies[name] + proxy, exists := h.remoteProxies[nodeName] h.remoteProxiesMu.RUnlock() if !exists { @@ -447,9 +447,9 @@ func (h *Handler) RemoteInstanceProxy(w http.ResponseWriter, r *http.Request, na } } - // Cache the proxy + // Cache the proxy by node name h.remoteProxiesMu.Lock() - h.remoteProxies[name] = proxy + h.remoteProxies[nodeName] = proxy h.remoteProxiesMu.Unlock() } diff --git a/pkg/server/handlers_openai.go b/pkg/server/handlers_openai.go index 07196f0..eea8440 100644 --- a/pkg/server/handlers_openai.go +++ b/pkg/server/handlers_openai.go @@ -8,6 +8,8 @@ import ( "llamactl/pkg/config" "llamactl/pkg/instance" "net/http" + "net/http/httputil" + "net/url" ) // OpenAIListInstances godoc @@ -93,7 +95,9 @@ func (h *Handler) OpenAIProxy() http.HandlerFunc { // Check if this is a remote instance if inst.IsRemote() { - h.RemoteOpenAIProxy(w, r, modelName, inst, bodyBytes) + // Restore the body for the remote proxy + r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + h.RemoteOpenAIProxy(w, r, modelName, inst) return } @@ -149,7 +153,7 @@ func (h *Handler) OpenAIProxy() http.HandlerFunc { } // RemoteOpenAIProxy proxies OpenAI-compatible requests to a remote instance -func (h *Handler) RemoteOpenAIProxy(w http.ResponseWriter, r *http.Request, modelName string, inst *instance.Process, bodyBytes []byte) { +func (h *Handler) RemoteOpenAIProxy(w http.ResponseWriter, r *http.Request, modelName string, inst *instance.Process) { // Get the node name from instance options options := inst.GetOptions() if options == nil || len(options.Nodes) == 0 { @@ -158,58 +162,53 @@ func (h *Handler) RemoteOpenAIProxy(w http.ResponseWriter, r *http.Request, mode } nodeName := options.Nodes[0] - var nodeConfig *config.NodeConfig - for i := range h.cfg.Nodes { - if h.cfg.Nodes[i].Name == nodeName { - nodeConfig = &h.cfg.Nodes[i] - break + + // Check if we have a cached proxy for this node + h.remoteProxiesMu.RLock() + proxy, exists := h.remoteProxies[nodeName] + h.remoteProxiesMu.RUnlock() + + if !exists { + // Find node configuration + var nodeConfig *config.NodeConfig + for i := range h.cfg.Nodes { + if h.cfg.Nodes[i].Name == nodeName { + nodeConfig = &h.cfg.Nodes[i] + break + } } - } - if nodeConfig == nil { - http.Error(w, fmt.Sprintf("Node %s not found", nodeName), http.StatusInternalServerError) - return - } - - // Build the remote URL - forward to the same OpenAI endpoint on the remote node - remoteURL := fmt.Sprintf("%s%s", nodeConfig.Address, r.URL.Path) - if r.URL.RawQuery != "" { - remoteURL += "?" + r.URL.RawQuery - } - - // Create a new request to the remote node - req, err := http.NewRequest(r.Method, remoteURL, bytes.NewReader(bodyBytes)) - if err != nil { - http.Error(w, "Failed to create remote request: "+err.Error(), http.StatusInternalServerError) - return - } - - // Copy headers - req.Header = r.Header.Clone() - - // Add API key if configured - if nodeConfig.APIKey != "" { - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", nodeConfig.APIKey)) - } - - // Forward the request - resp, err := h.httpClient.Do(req) - if err != nil { - http.Error(w, "Failed to proxy to remote instance: "+err.Error(), http.StatusBadGateway) - return - } - defer resp.Body.Close() - - // Copy response headers - for key, values := range resp.Header { - for _, value := range values { - w.Header().Add(key, value) + if nodeConfig == nil { + http.Error(w, fmt.Sprintf("Node %s not found", nodeName), http.StatusInternalServerError) + return } + + // Create reverse proxy to remote node + targetURL, err := url.Parse(nodeConfig.Address) + if err != nil { + http.Error(w, "Failed to parse node address: "+err.Error(), http.StatusInternalServerError) + return + } + + proxy = httputil.NewSingleHostReverseProxy(targetURL) + + // Modify request before forwarding + originalDirector := proxy.Director + apiKey := nodeConfig.APIKey // Capture for closure + proxy.Director = func(req *http.Request) { + originalDirector(req) + // Add API key if configured + if apiKey != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey)) + } + } + + // Cache the proxy + h.remoteProxiesMu.Lock() + h.remoteProxies[nodeName] = proxy + h.remoteProxiesMu.Unlock() } - // Copy status code - w.WriteHeader(resp.StatusCode) - - // Copy response body - io.Copy(w, resp.Body) + // Forward the request using the cached proxy + proxy.ServeHTTP(w, r) }