diff --git a/handler.go b/handler.go index 495d1ec..153fe23 100644 --- a/handler.go +++ b/handler.go @@ -2,36 +2,51 @@ package main import ( "bufio" + "context" "crypto/tls" + "fmt" "io" + "net" "net/http" "net/http/httputil" - "net/url" "strings" + "time" ) type AuthProvider func() string type ProxyHandler struct { auth AuthProvider - upstream string + upstreamAddr string + tlsName string logger *CondLogger + dialer *net.Dialer httptransport http.RoundTripper resolver *Resolver } -func NewProxyHandler(upstream string, auth AuthProvider, resolver *Resolver, logger *CondLogger) *ProxyHandler { - proxyurl, err := url.Parse("https://" + upstream) - if err != nil { - panic(err) +func NewProxyHandler(upstream *Endpoint, auth AuthProvider, resolver *Resolver, logger *CondLogger) *ProxyHandler { + dialer := &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, } + netaddr := net.JoinHostPort(upstream.Host, fmt.Sprintf("%d", upstream.Port)) httptransport := &http.Transport{ - Proxy: http.ProxyURL(proxyurl), + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + Proxy: http.ProxyURL(upstream.URL()), + DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { + return dialer.DialContext(ctx, "tcp", netaddr) + }, } return &ProxyHandler{ auth: auth, - upstream: upstream, + upstreamAddr: netaddr, + tlsName: upstream.TLSName, logger: logger, + dialer: dialer, httptransport: httptransport, resolver: resolver, } @@ -48,17 +63,25 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) { return } - conn, err := tls.Dial("tcp", s.upstream, nil) + conn, err := s.dialer.DialContext(req.Context(), "tcp", s.upstreamAddr) if err != nil { - s.logger.Error("Can't dial tls upstream: %v", err) - http.Error(wr, "Can't dial tls upstream", http.StatusBadGateway) + s.logger.Error("Can't dial upstream: %v", err) + http.Error(wr, "Can't dial upstream", http.StatusBadGateway) return } + defer conn.Close() + + if s.tlsName != "" { + conn = tls.Client(conn, &tls.Config{ + ServerName: s.tlsName, + }) + defer conn.Close() + } _, err = conn.Write(rawreq) if err != nil { - s.logger.Error("Can't write tls upstream: %v", err) - http.Error(wr, "Can't write tls upstream", http.StatusBadGateway) + s.logger.Error("Can't write upstream: %v", err) + http.Error(wr, "Can't write upstream", http.StatusBadGateway) return } bufrd := bufio.NewReader(conn) @@ -74,14 +97,22 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) { proxyResp.Header.Get("X-Hola-Error") == "Forbidden Host" { s.logger.Info("Request %s denied by upstream. Rescuing it with resolve&rewrite workaround.", req.URL.String()) - conn.Close() - conn, err = tls.Dial("tcp", s.upstream, nil) + + conn, err = s.dialer.DialContext(req.Context(), "tcp", s.upstreamAddr) if err != nil { - s.logger.Error("Can't dial tls upstream: %v", err) - http.Error(wr, "Can't dial tls upstream", http.StatusBadGateway) + s.logger.Error("Can't dial upstream: %v", err) + http.Error(wr, "Can't dial upstream", http.StatusBadGateway) return } defer conn.Close() + + if s.tlsName != "" { + conn = tls.Client(conn, &tls.Config{ + ServerName: s.tlsName, + }) + defer conn.Close() + } + err = rewriteConnectReq(req, s.resolver) if err != nil { s.logger.Error("Can't rewrite request: %v", err) @@ -101,7 +132,6 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) { return } } else { - defer conn.Close() responseBytes, err = httputil.DumpResponse(proxyResp, false) if err != nil { s.logger.Error("Can't dump response: %v", err) @@ -160,8 +190,8 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) { proxyReq.Header.Set("Proxy-Authorization", s.auth()) rawreq, _ := httputil.DumpRequest(proxyReq, false) - // Prepare upstream TLS conn - conn, err := tls.Dial("tcp", s.upstream, nil) + // Prepare upstream conn + conn, err := s.dialer.DialContext(req.Context(), "tcp", s.upstreamAddr) if err != nil { s.logger.Error("Can't dial tls upstream: %v", err) http.Error(wr, "Can't dial tls upstream", http.StatusBadGateway) @@ -169,6 +199,13 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) { } defer conn.Close() + if s.tlsName != "" { + conn = tls.Client(conn, &tls.Config{ + ServerName: s.tlsName, + }) + defer conn.Close() + } + // Send proxy request _, err = conn.Write(rawreq) if err != nil { diff --git a/holaapi.go b/holaapi.go index 8ade758..3d145fc 100644 --- a/holaapi.go +++ b/holaapi.go @@ -106,9 +106,9 @@ func (c *FallbackConfig) ShuffleAgents() { func (c *FallbackConfig) Clone() *FallbackConfig { return &FallbackConfig{ - Agents: append([]FallbackAgent(nil), c.Agents...), + Agents: append([]FallbackAgent(nil), c.Agents...), UpdatedAt: c.UpdatedAt, - TTL: c.TTL, + TTL: c.TTL, } } @@ -338,7 +338,7 @@ func httpClientWithProxy(agent *FallbackAgent) *http.Client { t.Proxy = http.ProxyURL(agent.ToProxy()) addr := net.JoinHostPort(agent.IP, fmt.Sprintf("%d", agent.Port)) t.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { - return dialer.DialContext(ctx, "tcp4", addr) + return dialer.DialContext(ctx, "tcp", addr) } } return &http.Client{ diff --git a/main.go b/main.go index 7a06b57..c58989e 100644 --- a/main.go +++ b/main.go @@ -126,7 +126,7 @@ func run() int { logWriter.Close() return 5 } - mainLogger.Info("Endpoint: %s", endpoint) + mainLogger.Info("Endpoint: %s", endpoint.URL().String()) mainLogger.Info("Starting proxy server...") handler := NewProxyHandler(endpoint, auth, resolver, proxyLogger) mainLogger.Info("Init complete.") diff --git a/utils.go b/utils.go index 826636b..8ea99e1 100644 --- a/utils.go +++ b/utils.go @@ -18,6 +18,26 @@ import ( "time" ) +type Endpoint struct { + Host string + Port uint16 + TLSName string +} + +func (e *Endpoint) URL() *url.URL { + if e.TLSName == "" { + return &url.URL{ + Scheme: "http", + Host: net.JoinHostPort(e.Host, fmt.Sprintf("%d", e.Port)), + } + } else { + return &url.URL{ + Scheme: "https", + Host: net.JoinHostPort(e.TLSName, fmt.Sprintf("%d", e.Port)), + } + } +} + func basic_auth_header(login, password string) string { return "basic " + base64.StdEncoding.EncodeToString( []byte(login+":"+password)) @@ -123,14 +143,15 @@ func print_proxies(country string, proxy_type string, limit uint, timeout time.D return 0 } -func get_endpoint(tunnels *ZGetTunnelsResponse, typ string, trial bool, force_port_field string) (string, error) { - var hostname string - for k := range tunnels.IPList { +func get_endpoint(tunnels *ZGetTunnelsResponse, typ string, trial bool, force_port_field string) (*Endpoint, error) { + var hostname, ip string + for k, v := range tunnels.IPList { hostname = k + ip = v break } - if hostname == "" { - return "", errors.New("No tunnels found in API response") + if hostname == "" || ip == "" { + return nil, errors.New("No tunnels found in API response") } var port uint16 @@ -157,10 +178,14 @@ func get_endpoint(tunnels *ZGetTunnelsResponse, typ string, trial bool, force_po port = tunnels.Port.Peer } } else { - return "", errors.New("Unsupported port type") + return nil, errors.New("Unsupported port type") } } - return net.JoinHostPort(hostname, strconv.FormatUint(uint64(port), 10)), nil + return &Endpoint{ + Host: ip, + Port: port, + TLSName: hostname, + }, nil } // Hop-by-hop headers. These are removed when sent to the backend.