Skip to content

Commit

Permalink
Merge pull request #35 from Snawoot/skip_agent_resolve
Browse files Browse the repository at this point in the history
Skip upstream agent resolve
  • Loading branch information
Snawoot authored Mar 15, 2021
2 parents 3b09f31 + edd7230 commit 4faf6aa
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 31 deletions.
77 changes: 57 additions & 20 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,51 @@ package main

import (
"bufio"
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"net/http/httputil"
"net/url"
"strings"
"time"
)

type AuthProvider func() string

type ProxyHandler struct {
auth AuthProvider
upstream string
upstreamAddr string
tlsName string
logger *CondLogger
dialer *net.Dialer
httptransport http.RoundTripper
resolver *Resolver
}

func NewProxyHandler(upstream string, auth AuthProvider, resolver *Resolver, logger *CondLogger) *ProxyHandler {
proxyurl, err := url.Parse("https://" + upstream)
if err != nil {
panic(err)
func NewProxyHandler(upstream *Endpoint, auth AuthProvider, resolver *Resolver, logger *CondLogger) *ProxyHandler {
dialer := &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
netaddr := net.JoinHostPort(upstream.Host, fmt.Sprintf("%d", upstream.Port))
httptransport := &http.Transport{
Proxy: http.ProxyURL(proxyurl),
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
Proxy: http.ProxyURL(upstream.URL()),
DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
return dialer.DialContext(ctx, "tcp", netaddr)
},
}
return &ProxyHandler{
auth: auth,
upstream: upstream,
upstreamAddr: netaddr,
tlsName: upstream.TLSName,
logger: logger,
dialer: dialer,
httptransport: httptransport,
resolver: resolver,
}
Expand All @@ -48,17 +63,25 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
return
}

conn, err := tls.Dial("tcp", s.upstream, nil)
conn, err := s.dialer.DialContext(req.Context(), "tcp", s.upstreamAddr)
if err != nil {
s.logger.Error("Can't dial tls upstream: %v", err)
http.Error(wr, "Can't dial tls upstream", http.StatusBadGateway)
s.logger.Error("Can't dial upstream: %v", err)
http.Error(wr, "Can't dial upstream", http.StatusBadGateway)
return
}
defer conn.Close()

if s.tlsName != "" {
conn = tls.Client(conn, &tls.Config{
ServerName: s.tlsName,
})
defer conn.Close()
}

_, err = conn.Write(rawreq)
if err != nil {
s.logger.Error("Can't write tls upstream: %v", err)
http.Error(wr, "Can't write tls upstream", http.StatusBadGateway)
s.logger.Error("Can't write upstream: %v", err)
http.Error(wr, "Can't write upstream", http.StatusBadGateway)
return
}
bufrd := bufio.NewReader(conn)
Expand All @@ -74,14 +97,22 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
proxyResp.Header.Get("X-Hola-Error") == "Forbidden Host" {
s.logger.Info("Request %s denied by upstream. Rescuing it with resolve&rewrite workaround.",
req.URL.String())
conn.Close()
conn, err = tls.Dial("tcp", s.upstream, nil)

conn, err = s.dialer.DialContext(req.Context(), "tcp", s.upstreamAddr)
if err != nil {
s.logger.Error("Can't dial tls upstream: %v", err)
http.Error(wr, "Can't dial tls upstream", http.StatusBadGateway)
s.logger.Error("Can't dial upstream: %v", err)
http.Error(wr, "Can't dial upstream", http.StatusBadGateway)
return
}
defer conn.Close()

if s.tlsName != "" {
conn = tls.Client(conn, &tls.Config{
ServerName: s.tlsName,
})
defer conn.Close()
}

err = rewriteConnectReq(req, s.resolver)
if err != nil {
s.logger.Error("Can't rewrite request: %v", err)
Expand All @@ -101,7 +132,6 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
return
}
} else {
defer conn.Close()
responseBytes, err = httputil.DumpResponse(proxyResp, false)
if err != nil {
s.logger.Error("Can't dump response: %v", err)
Expand Down Expand Up @@ -160,15 +190,22 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
proxyReq.Header.Set("Proxy-Authorization", s.auth())
rawreq, _ := httputil.DumpRequest(proxyReq, false)

// Prepare upstream TLS conn
conn, err := tls.Dial("tcp", s.upstream, nil)
// Prepare upstream conn
conn, err := s.dialer.DialContext(req.Context(), "tcp", s.upstreamAddr)
if err != nil {
s.logger.Error("Can't dial tls upstream: %v", err)
http.Error(wr, "Can't dial tls upstream", http.StatusBadGateway)
return
}
defer conn.Close()

if s.tlsName != "" {
conn = tls.Client(conn, &tls.Config{
ServerName: s.tlsName,
})
defer conn.Close()
}

// Send proxy request
_, err = conn.Write(rawreq)
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions holaapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ func (c *FallbackConfig) ShuffleAgents() {

func (c *FallbackConfig) Clone() *FallbackConfig {
return &FallbackConfig{
Agents: append([]FallbackAgent(nil), c.Agents...),
Agents: append([]FallbackAgent(nil), c.Agents...),
UpdatedAt: c.UpdatedAt,
TTL: c.TTL,
TTL: c.TTL,
}
}

Expand Down Expand Up @@ -338,7 +338,7 @@ func httpClientWithProxy(agent *FallbackAgent) *http.Client {
t.Proxy = http.ProxyURL(agent.ToProxy())
addr := net.JoinHostPort(agent.IP, fmt.Sprintf("%d", agent.Port))
t.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
return dialer.DialContext(ctx, "tcp4", addr)
return dialer.DialContext(ctx, "tcp", addr)
}
}
return &http.Client{
Expand Down
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ func run() int {
logWriter.Close()
return 5
}
mainLogger.Info("Endpoint: %s", endpoint)
mainLogger.Info("Endpoint: %s", endpoint.URL().String())
mainLogger.Info("Starting proxy server...")
handler := NewProxyHandler(endpoint, auth, resolver, proxyLogger)
mainLogger.Info("Init complete.")
Expand Down
39 changes: 32 additions & 7 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,26 @@ import (
"time"
)

type Endpoint struct {
Host string
Port uint16
TLSName string
}

func (e *Endpoint) URL() *url.URL {
if e.TLSName == "" {
return &url.URL{
Scheme: "http",
Host: net.JoinHostPort(e.Host, fmt.Sprintf("%d", e.Port)),
}
} else {
return &url.URL{
Scheme: "https",
Host: net.JoinHostPort(e.TLSName, fmt.Sprintf("%d", e.Port)),
}
}
}

func basic_auth_header(login, password string) string {
return "basic " + base64.StdEncoding.EncodeToString(
[]byte(login+":"+password))
Expand Down Expand Up @@ -123,14 +143,15 @@ func print_proxies(country string, proxy_type string, limit uint, timeout time.D
return 0
}

func get_endpoint(tunnels *ZGetTunnelsResponse, typ string, trial bool, force_port_field string) (string, error) {
var hostname string
for k := range tunnels.IPList {
func get_endpoint(tunnels *ZGetTunnelsResponse, typ string, trial bool, force_port_field string) (*Endpoint, error) {
var hostname, ip string
for k, v := range tunnels.IPList {
hostname = k
ip = v
break
}
if hostname == "" {
return "", errors.New("No tunnels found in API response")
if hostname == "" || ip == "" {
return nil, errors.New("No tunnels found in API response")
}

var port uint16
Expand All @@ -157,10 +178,14 @@ func get_endpoint(tunnels *ZGetTunnelsResponse, typ string, trial bool, force_po
port = tunnels.Port.Peer
}
} else {
return "", errors.New("Unsupported port type")
return nil, errors.New("Unsupported port type")
}
}
return net.JoinHostPort(hostname, strconv.FormatUint(uint64(port), 10)), nil
return &Endpoint{
Host: ip,
Port: port,
TLSName: hostname,
}, nil
}

// Hop-by-hop headers. These are removed when sent to the backend.
Expand Down

0 comments on commit 4faf6aa

Please sign in to comment.