diff --git a/pkg/instance/proxy.go b/pkg/instance/proxy.go index d261ddf..a61dfb6 100644 --- a/pkg/instance/proxy.go +++ b/pkg/instance/proxy.go @@ -37,8 +37,9 @@ type proxy struct { proxyOnce sync.Once proxyErr error - lastRequestTime atomic.Int64 - timeProvider TimeProvider + lastRequestTime atomic.Int64 + inflightRequests atomic.Int32 + timeProvider TimeProvider } // newProxy creates a new Proxy for the given instance @@ -153,6 +154,31 @@ func (p *proxy) build() (*httputil.ReverseProxy, error) { return proxy, nil } +// serveHTTP handles HTTP requests with inflight tracking and shutting down state checks +func (p *proxy) serveHTTP(w http.ResponseWriter, r *http.Request) error { + // Check if instance is shutting down + status := p.instance.GetStatus() + if status == ShuttingDown { + w.WriteHeader(http.StatusServiceUnavailable) + w.Write([]byte("Instance is shutting down")) + return fmt.Errorf("instance is shutting down") + } + + // Get the reverse proxy + reverseProxy, err := p.get() + if err != nil { + return err + } + + // Track inflight requests + p.incInflightRequests() + defer p.decInflightRequests() + + // Serve the request + reverseProxy.ServeHTTP(w, r) + return nil +} + // clear resets the proxy, allowing it to be recreated when options change. func (p *proxy) clear() { p.mu.Lock() @@ -160,7 +186,7 @@ func (p *proxy) clear() { p.proxy = nil p.proxyErr = nil - p.proxyOnce = sync.Once{} // Reset Once for next GetProxy call + p.proxyOnce = sync.Once{} } // updateLastRequestTime updates the last request access time for the instance @@ -199,3 +225,18 @@ func (p *proxy) shouldTimeout() bool { func (p *proxy) setTimeProvider(tp TimeProvider) { p.timeProvider = tp } + +// incInflightRequests increments the inflight request counter +func (p *proxy) incInflightRequests() { + p.inflightRequests.Add(1) +} + +// decInflightRequests decrements the inflight request counter +func (p *proxy) decInflightRequests() { + p.inflightRequests.Add(-1) +} + +// getInflightRequests returns the current number of inflight requests +func (p *proxy) getInflightRequests() int32 { + return p.inflightRequests.Load() +}