diff options
Diffstat (limited to 'internal/handlers/claudomator_proxy.go')
| -rw-r--r-- | internal/handlers/claudomator_proxy.go | 126 |
1 files changed, 126 insertions, 0 deletions
diff --git a/internal/handlers/claudomator_proxy.go b/internal/handlers/claudomator_proxy.go new file mode 100644 index 0000000..bfbbabc --- /dev/null +++ b/internal/handlers/claudomator_proxy.go @@ -0,0 +1,126 @@ +package handlers + +import ( + "io" + "net" + "net/http" + "net/http/httputil" + "net/url" + "strings" +) + +// NewClaudomatorProxy returns an http.Handler that reverse-proxies requests to +// targetURL, stripping the "/claudomator" prefix from the path. WebSocket +// upgrade requests are handled via raw TCP hijacking to support long-lived +// connections. +func NewClaudomatorProxy(targetURL string) http.Handler { + target, err := url.Parse(targetURL) + if err != nil { + panic("claudomator: invalid target URL: " + err.Error()) + } + + rp := &httputil.ReverseProxy{ + Director: func(req *http.Request) { + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host + + // Strip /claudomator prefix + stripped := strings.TrimPrefix(req.URL.Path, "/claudomator") + if stripped == "" { + stripped = "/" + } + req.URL.Path = stripped + + if req.URL.RawPath != "" { + rawStripped := strings.TrimPrefix(req.URL.RawPath, "/claudomator") + if rawStripped == "" { + rawStripped = "/" + } + req.URL.RawPath = rawStripped + } + }, + ModifyResponse: func(resp *http.Response) error { + // Preserve Service-Worker-Allowed header + if swa := resp.Header.Get("Service-Worker-Allowed"); swa != "" { + resp.Header.Set("Service-Worker-Allowed", swa) + } + return nil + }, + } + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") { + proxyWebSocket(w, r, target) + return + } + rp.ServeHTTP(w, r) + }) +} + +// proxyWebSocket handles WebSocket upgrade via raw TCP hijacking. +func proxyWebSocket(w http.ResponseWriter, r *http.Request, target *url.URL) { + // Determine host:port for dialing + host := target.Host + if target.Port() == "" { + switch target.Scheme { + case "https": + host += ":443" + default: + host += ":80" + } + } + + upstream, err := net.Dial("tcp", host) + if err != nil { + http.Error(w, "bad gateway", http.StatusBadGateway) + return + } + defer upstream.Close() + + // Rewrite path on the request before forwarding + r.URL.Scheme = target.Scheme + r.URL.Host = target.Host + stripped := strings.TrimPrefix(r.URL.Path, "/claudomator") + if stripped == "" { + stripped = "/" + } + r.URL.Path = stripped + if r.URL.RawPath != "" { + rawStripped := strings.TrimPrefix(r.URL.RawPath, "/claudomator") + if rawStripped == "" { + rawStripped = "/" + } + r.URL.RawPath = rawStripped + } + r.RequestURI = r.URL.RequestURI() + + // Write the HTTP request to the upstream connection + if err := r.Write(upstream); err != nil { + http.Error(w, "bad gateway", http.StatusBadGateway) + return + } + + // Hijack the client connection + hijacker, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "websocket not supported", http.StatusInternalServerError) + return + } + clientConn, _, err := hijacker.Hijack() + if err != nil { + return + } + defer clientConn.Close() + + // Bidirectional copy — no deadlines so long-lived WS connections survive + done := make(chan struct{}, 2) + go func() { + _, _ = io.Copy(upstream, clientConn) + done <- struct{}{} + }() + go func() { + _, _ = io.Copy(clientConn, upstream) + done <- struct{}{} + }() + <-done +} |
