summaryrefslogtreecommitdiff
path: root/internal/handlers/claudomator_proxy.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/handlers/claudomator_proxy.go')
-rw-r--r--internal/handlers/claudomator_proxy.go126
1 files changed, 126 insertions, 0 deletions
diff --git a/internal/handlers/claudomator_proxy.go b/internal/handlers/claudomator_proxy.go
new file mode 100644
index 0000000..bfbbabc
--- /dev/null
+++ b/internal/handlers/claudomator_proxy.go
@@ -0,0 +1,126 @@
+package handlers
+
+import (
+ "io"
+ "net"
+ "net/http"
+ "net/http/httputil"
+ "net/url"
+ "strings"
+)
+
+// NewClaudomatorProxy returns an http.Handler that reverse-proxies requests to
+// targetURL, stripping the "/claudomator" prefix from the path. WebSocket
+// upgrade requests are handled via raw TCP hijacking to support long-lived
+// connections.
+func NewClaudomatorProxy(targetURL string) http.Handler {
+ target, err := url.Parse(targetURL)
+ if err != nil {
+ panic("claudomator: invalid target URL: " + err.Error())
+ }
+
+ rp := &httputil.ReverseProxy{
+ Director: func(req *http.Request) {
+ req.URL.Scheme = target.Scheme
+ req.URL.Host = target.Host
+
+ // Strip /claudomator prefix
+ stripped := strings.TrimPrefix(req.URL.Path, "/claudomator")
+ if stripped == "" {
+ stripped = "/"
+ }
+ req.URL.Path = stripped
+
+ if req.URL.RawPath != "" {
+ rawStripped := strings.TrimPrefix(req.URL.RawPath, "/claudomator")
+ if rawStripped == "" {
+ rawStripped = "/"
+ }
+ req.URL.RawPath = rawStripped
+ }
+ },
+ ModifyResponse: func(resp *http.Response) error {
+ // Preserve Service-Worker-Allowed header
+ if swa := resp.Header.Get("Service-Worker-Allowed"); swa != "" {
+ resp.Header.Set("Service-Worker-Allowed", swa)
+ }
+ return nil
+ },
+ }
+
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
+ proxyWebSocket(w, r, target)
+ return
+ }
+ rp.ServeHTTP(w, r)
+ })
+}
+
+// proxyWebSocket handles WebSocket upgrade via raw TCP hijacking.
+func proxyWebSocket(w http.ResponseWriter, r *http.Request, target *url.URL) {
+ // Determine host:port for dialing
+ host := target.Host
+ if target.Port() == "" {
+ switch target.Scheme {
+ case "https":
+ host += ":443"
+ default:
+ host += ":80"
+ }
+ }
+
+ upstream, err := net.Dial("tcp", host)
+ if err != nil {
+ http.Error(w, "bad gateway", http.StatusBadGateway)
+ return
+ }
+ defer upstream.Close()
+
+ // Rewrite path on the request before forwarding
+ r.URL.Scheme = target.Scheme
+ r.URL.Host = target.Host
+ stripped := strings.TrimPrefix(r.URL.Path, "/claudomator")
+ if stripped == "" {
+ stripped = "/"
+ }
+ r.URL.Path = stripped
+ if r.URL.RawPath != "" {
+ rawStripped := strings.TrimPrefix(r.URL.RawPath, "/claudomator")
+ if rawStripped == "" {
+ rawStripped = "/"
+ }
+ r.URL.RawPath = rawStripped
+ }
+ r.RequestURI = r.URL.RequestURI()
+
+ // Write the HTTP request to the upstream connection
+ if err := r.Write(upstream); err != nil {
+ http.Error(w, "bad gateway", http.StatusBadGateway)
+ return
+ }
+
+ // Hijack the client connection
+ hijacker, ok := w.(http.Hijacker)
+ if !ok {
+ http.Error(w, "websocket not supported", http.StatusInternalServerError)
+ return
+ }
+ clientConn, _, err := hijacker.Hijack()
+ if err != nil {
+ return
+ }
+ defer clientConn.Close()
+
+ // Bidirectional copy — no deadlines so long-lived WS connections survive
+ done := make(chan struct{}, 2)
+ go func() {
+ _, _ = io.Copy(upstream, clientConn)
+ done <- struct{}{}
+ }()
+ go func() {
+ _, _ = io.Copy(clientConn, upstream)
+ done <- struct{}{}
+ }()
+ <-done
+}