From c8eb2c90b979391d069ebde7021ba9380377ac34 Mon Sep 17 00:00:00 2001 From: yuhan6665 <1588741+yuhan6665@users.noreply.github.com> Date: Sun, 12 May 2024 15:57:18 -0400 Subject: [PATCH 1/4] Add session context outbounds as slice slice is needed for dialer proxy where two outbounds work on top of each other There are two sets of target addr for example It also enable Xtls to correctly do splice copy by checking both outbounds are ready to do direct copy --- app/dispatcher/default.go | 21 ++++++++------ app/dispatcher/fakednssniffer.go | 11 ++++---- app/proxyman/inbound/worker.go | 11 ++++---- app/proxyman/outbound/handler.go | 37 +++++++++++-------------- app/reverse/portal.go | 12 ++++---- app/router/router_test.go | 20 +++++++++---- common/mux/client.go | 14 ++++++---- common/mux/client_test.go | 8 +++--- common/session/context.go | 10 +++---- common/session/session.go | 14 ++++------ common/singbridge/dialer.go | 11 ++++++-- features/routing/session/context.go | 4 ++- proxy/blackhole/blackhole.go | 7 ++--- proxy/dns/dns.go | 11 ++++---- proxy/dokodemo/dokodemo.go | 15 ++++++---- proxy/freedom/freedom.go | 14 +++++----- proxy/http/client.go | 21 +++++++------- proxy/http/server.go | 2 +- proxy/loopback/loopback.go | 9 +++--- proxy/shadowsocks/client.go | 14 ++++------ proxy/shadowsocks/server.go | 2 +- proxy/shadowsocks_2022/inbound.go | 2 +- proxy/shadowsocks_2022/inbound_multi.go | 2 +- proxy/shadowsocks_2022/inbound_relay.go | 2 +- proxy/shadowsocks_2022/outbound.go | 11 ++++---- proxy/socks/client.go | 14 ++++------ proxy/socks/server.go | 2 +- proxy/trojan/client.go | 14 ++++------ proxy/trojan/server.go | 2 +- proxy/vless/inbound/inbound.go | 4 +-- proxy/vless/outbound/outbound.go | 20 ++++++------- proxy/vmess/inbound/inbound.go | 2 +- proxy/vmess/outbound/outbound.go | 14 ++++------ proxy/wireguard/client.go | 14 ++++------ proxy/wireguard/server.go | 8 ++++-- transport/internet/dialer.go | 9 ++++-- transport/internet/grpc/dial.go | 2 +- transport/internet/http/dialer.go | 2 +- 38 files changed, 204 insertions(+), 188 deletions(-) diff --git a/app/dispatcher/default.go b/app/dispatcher/default.go index b8131b8f89de..fa8c5d006f4b 100644 --- a/app/dispatcher/default.go +++ b/app/dispatcher/default.go @@ -218,11 +218,12 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin if !destination.IsValid() { panic("Dispatcher: Invalid destination.") } - ob := session.OutboundFromContext(ctx) - if ob == nil { - ob = &session.Outbound{} - ctx = session.ContextWithOutbound(ctx, ob) + outbounds := session.OutboundsFromContext(ctx) + if len(outbounds) == 0 { + outbounds = []*session.Outbound{{}} + ctx = session.ContextWithOutbounds(ctx, outbounds) } + ob := outbounds[len(outbounds) - 1] ob.OriginalTarget = destination ob.Target = destination content := session.ContentFromContext(ctx) @@ -274,11 +275,12 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De if !destination.IsValid() { return newError("Dispatcher: Invalid destination.") } - ob := session.OutboundFromContext(ctx) - if ob == nil { - ob = &session.Outbound{} - ctx = session.ContextWithOutbound(ctx, ob) + outbounds := session.OutboundsFromContext(ctx) + if len(outbounds) == 0 { + outbounds = []*session.Outbound{{}} + ctx = session.ContextWithOutbounds(ctx, outbounds) } + ob := outbounds[len(outbounds) - 1] ob.OriginalTarget = destination ob.Target = destination content := session.ContentFromContext(ctx) @@ -368,7 +370,8 @@ func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, netw return contentResult, contentErr } func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) { - ob := session.OutboundFromContext(ctx) + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] if hosts, ok := d.dns.(dns.HostsLookup); ok && destination.Address.Family().IsDomain() { proxied := hosts.LookupHosts(ob.Target.String()) if proxied != nil { diff --git a/app/dispatcher/fakednssniffer.go b/app/dispatcher/fakednssniffer.go index ad879daf5d10..8d0804de1421 100644 --- a/app/dispatcher/fakednssniffer.go +++ b/app/dispatcher/fakednssniffer.go @@ -26,11 +26,12 @@ func newFakeDNSSniffer(ctx context.Context) (protocolSnifferWithMetadata, error) return protocolSnifferWithMetadata{}, errNotInit } return protocolSnifferWithMetadata{protocolSniffer: func(ctx context.Context, bytes []byte) (SniffResult, error) { - Target := session.OutboundFromContext(ctx).Target - if Target.Network == net.Network_TCP || Target.Network == net.Network_UDP { - domainFromFakeDNS := fakeDNSEngine.GetDomainFromFakeDNS(Target.Address) + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if ob.Target.Network == net.Network_TCP || ob.Target.Network == net.Network_UDP { + domainFromFakeDNS := fakeDNSEngine.GetDomainFromFakeDNS(ob.Target.Address) if domainFromFakeDNS != "" { - newError("fake dns got domain: ", domainFromFakeDNS, " for ip: ", Target.Address.String()).WriteToLog(session.ExportIDToError(ctx)) + newError("fake dns got domain: ", domainFromFakeDNS, " for ip: ", ob.Target.Address.String()).WriteToLog(session.ExportIDToError(ctx)) return &fakeDNSSniffResult{domainName: domainFromFakeDNS}, nil } } @@ -38,7 +39,7 @@ func newFakeDNSSniffer(ctx context.Context) (protocolSnifferWithMetadata, error) if ipAddressInRangeValueI := ctx.Value(ipAddressInRange); ipAddressInRangeValueI != nil { ipAddressInRangeValue := ipAddressInRangeValueI.(*ipAddressInRangeOpt) if fkr0, ok := fakeDNSEngine.(dns.FakeDNSEngineRev0); ok { - inPool := fkr0.IsIPInIPPool(Target.Address) + inPool := fkr0.IsIPInIPPool(ob.Target.Address) ipAddressInRangeValue.addressInRange = &inPool } } diff --git a/app/proxyman/inbound/worker.go b/app/proxyman/inbound/worker.go index 1fe866552908..9a6499f1290e 100644 --- a/app/proxyman/inbound/worker.go +++ b/app/proxyman/inbound/worker.go @@ -60,7 +60,7 @@ func (w *tcpWorker) callback(conn stat.Connection) { sid := session.NewID() ctx = session.ContextWithID(ctx, sid) - var outbound = &session.Outbound{} + outbounds := []*session.Outbound{{}} if w.recvOrigDest { var dest net.Destination switch getTProxyType(w.stream) { @@ -75,10 +75,10 @@ func (w *tcpWorker) callback(conn stat.Connection) { dest = net.DestinationFromAddr(conn.LocalAddr()) } if dest.IsValid() { - outbound.Target = dest + outbounds[0].Target = dest } } - ctx = session.ContextWithOutbound(ctx, outbound) + ctx = session.ContextWithOutbounds(ctx, outbounds) if w.uplinkCounter != nil || w.downlinkCounter != nil { conn = &stat.CounterConnection{ @@ -309,9 +309,10 @@ func (w *udpWorker) callback(b *buf.Buffer, source net.Destination, originalDest ctx = session.ContextWithID(ctx, sid) if originalDest.IsValid() { - ctx = session.ContextWithOutbound(ctx, &session.Outbound{ + outbounds := []*session.Outbound{{ Target: originalDest, - }) + }} + ctx = session.ContextWithOutbounds(ctx, outbounds) } ctx = session.ContextWithInbound(ctx, &session.Inbound{ Source: source, diff --git a/app/proxyman/outbound/handler.go b/app/proxyman/outbound/handler.go index 792ac24971a6..8150c7fa9aad 100644 --- a/app/proxyman/outbound/handler.go +++ b/app/proxyman/outbound/handler.go @@ -169,10 +169,11 @@ func (h *Handler) Tag() string { // Dispatch implements proxy.Outbound.Dispatch. func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) { - outbound := session.OutboundFromContext(ctx) - if outbound.Target.Network == net.Network_UDP && outbound.OriginalTarget.Address != nil && outbound.OriginalTarget.Address != outbound.Target.Address { - link.Reader = &buf.EndpointOverrideReader{Reader: link.Reader, Dest: outbound.Target.Address, OriginalDest: outbound.OriginalTarget.Address} - link.Writer = &buf.EndpointOverrideWriter{Writer: link.Writer, Dest: outbound.Target.Address, OriginalDest: outbound.OriginalTarget.Address} + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if ob.Target.Network == net.Network_UDP && ob.OriginalTarget.Address != nil && ob.OriginalTarget.Address != ob.Target.Address { + link.Reader = &buf.EndpointOverrideReader{Reader: link.Reader, Dest: ob.Target.Address, OriginalDest: ob.OriginalTarget.Address} + link.Writer = &buf.EndpointOverrideWriter{Writer: link.Writer, Dest: ob.Target.Address, OriginalDest: ob.OriginalTarget.Address} } if h.mux != nil { test := func(err error) { @@ -183,7 +184,7 @@ func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) { common.Interrupt(link.Writer) } } - if outbound.Target.Network == net.Network_UDP && outbound.Target.Port == 443 { + if ob.Target.Network == net.Network_UDP && ob.Target.Port == 443 { switch h.udp443 { case "reject": test(newError("XUDP rejected UDP/443 traffic").AtInfo()) @@ -192,7 +193,7 @@ func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) { goto out } } - if h.xudp != nil && outbound.Target.Network == net.Network_UDP { + if h.xudp != nil && ob.Target.Network == net.Network_UDP { if !h.xudp.Enabled { goto out } @@ -243,10 +244,8 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connecti handler := h.outboundManager.GetHandler(tag) if handler != nil { newError("proxying to ", tag, " for dest ", dest).AtDebug().WriteToLog(session.ExportIDToError(ctx)) - ctx = session.ContextWithOutbound(ctx, &session.Outbound{ - Target: dest, - }) - + outbounds := session.OutboundsFromContext(ctx) + ctx = session.ContextWithOutbounds(ctx, append(outbounds, &session.Outbound{Target: dest})) // add another outbound in session ctx opts := pipe.OptionsFromContext(ctx) uplinkReader, uplinkWriter := pipe.New(opts...) downlinkReader, downlinkWriter := pipe.New(opts...) @@ -266,15 +265,12 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connecti } if h.senderSettings.Via != nil { - outbound := session.OutboundFromContext(ctx) - if outbound == nil { - outbound = new(session.Outbound) - ctx = session.ContextWithOutbound(ctx, outbound) - } + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] if h.senderSettings.ViaCidr == "" { - outbound.Gateway = h.senderSettings.Via.AsAddress() + ob.Gateway = h.senderSettings.Via.AsAddress() } else { //Get a random address. - outbound.Gateway = ParseRandomIPv6(h.senderSettings.Via.AsAddress(), h.senderSettings.ViaCidr) + ob.Gateway = ParseRandomIPv6(h.senderSettings.Via.AsAddress(), h.senderSettings.ViaCidr) } } } @@ -285,10 +281,9 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connecti conn, err := internet.Dial(ctx, dest, h.streamSettings) conn = h.getStatCouterConnection(conn) - outbound := session.OutboundFromContext(ctx) - if outbound != nil { - outbound.Conn = conn - } + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + ob.Conn = conn return conn, err } diff --git a/app/reverse/portal.go b/app/reverse/portal.go index fb0b693002aa..456de550db24 100644 --- a/app/reverse/portal.go +++ b/app/reverse/portal.go @@ -62,12 +62,13 @@ func (p *Portal) Close() error { } func (p *Portal) HandleConnection(ctx context.Context, link *transport.Link) error { - outboundMeta := session.OutboundFromContext(ctx) - if outboundMeta == nil { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if ob == nil { return newError("outbound metadata not found").AtError() } - if isDomain(outboundMeta.Target, p.domain) { + if isDomain(ob.Target, p.domain) { muxClient, err := mux.NewClientWorker(*link, mux.ClientStrategy{}) if err != nil { return newError("failed to create mux client worker").Base(err).AtWarning() @@ -206,9 +207,10 @@ func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) { downlinkReader, downlinkWriter := pipe.New(opt...) ctx := context.Background() - ctx = session.ContextWithOutbound(ctx, &session.Outbound{ + outbounds := []*session.Outbound{{ Target: net.UDPDestination(net.DomainAddress(internalDomain), 0), - }) + }} + ctx = session.ContextWithOutbounds(ctx, outbounds) f := client.Dispatch(ctx, &transport.Link{ Reader: uplinkReader, Writer: downlinkWriter, diff --git a/app/router/router_test.go b/app/router/router_test.go index 4c6bfc63f2f3..2c33aae1c6ea 100644 --- a/app/router/router_test.go +++ b/app/router/router_test.go @@ -45,7 +45,9 @@ func TestSimpleRouter(t *testing.T) { HandlerSelector: mockHs, }, nil)) - ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("example.com"), 80)}) + ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{ + Target: net.TCPDestination(net.DomainAddress("example.com"), 80), + }}) route, err := r.PickRoute(routing_session.AsRoutingContext(ctx)) common.Must(err) if tag := route.GetOutboundTag(); tag != "test" { @@ -86,7 +88,9 @@ func TestSimpleBalancer(t *testing.T) { HandlerSelector: mockHs, }, nil)) - ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("example.com"), 80)}) + ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{ + Target: net.TCPDestination(net.DomainAddress("example.com"), 80), + }}) route, err := r.PickRoute(routing_session.AsRoutingContext(ctx)) common.Must(err) if tag := route.GetOutboundTag(); tag != "test" { @@ -174,7 +178,9 @@ func TestIPOnDemand(t *testing.T) { r := new(Router) common.Must(r.Init(context.TODO(), config, mockDNS, nil, nil)) - ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("example.com"), 80)}) + ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{ + Target: net.TCPDestination(net.DomainAddress("example.com"), 80), + }}) route, err := r.PickRoute(routing_session.AsRoutingContext(ctx)) common.Must(err) if tag := route.GetOutboundTag(); tag != "test" { @@ -213,7 +219,9 @@ func TestIPIfNonMatchDomain(t *testing.T) { r := new(Router) common.Must(r.Init(context.TODO(), config, mockDNS, nil, nil)) - ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("example.com"), 80)}) + ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{ + Target: net.TCPDestination(net.DomainAddress("example.com"), 80), + }}) route, err := r.PickRoute(routing_session.AsRoutingContext(ctx)) common.Must(err) if tag := route.GetOutboundTag(); tag != "test" { @@ -247,7 +255,9 @@ func TestIPIfNonMatchIP(t *testing.T) { r := new(Router) common.Must(r.Init(context.TODO(), config, mockDNS, nil, nil)) - ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.LocalHostIP, 80)}) + ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{ + Target: net.TCPDestination(net.LocalHostIP, 80), + }}) route, err := r.PickRoute(routing_session.AsRoutingContext(ctx)) common.Must(err) if tag := route.GetOutboundTag(); tag != "test" { diff --git a/common/mux/client.go b/common/mux/client.go index 88621be0f0ef..2537f02b6d3c 100644 --- a/common/mux/client.go +++ b/common/mux/client.go @@ -148,9 +148,10 @@ func (f *DialingWorkerFactory) Create() (*ClientWorker, error) { } go func(p proxy.Outbound, d internet.Dialer, c common.Closable) { - ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{ + outbounds := []*session.Outbound{{ Target: net.TCPDestination(muxCoolAddress, muxCoolPort), - }) + }} + ctx := session.ContextWithOutbounds(context.Background(), outbounds) ctx, cancel := context.WithCancel(ctx) if err := p.Process(ctx, &transport.Link{Reader: uplinkReader, Writer: downlinkWriter}, d); err != nil { @@ -242,17 +243,18 @@ func writeFirstPayload(reader buf.Reader, writer *Writer) error { } func fetchInput(ctx context.Context, s *Session, output buf.Writer) { - dest := session.OutboundFromContext(ctx).Target + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] transferType := protocol.TransferTypeStream - if dest.Network == net.Network_UDP { + if ob.Target.Network == net.Network_UDP { transferType = protocol.TransferTypePacket } s.transferType = transferType - writer := NewWriter(s.ID, dest, output, transferType, xudp.GetGlobalID(ctx)) + writer := NewWriter(s.ID, ob.Target, output, transferType, xudp.GetGlobalID(ctx)) defer s.Close(false) defer writer.Close() - newError("dispatching request to ", dest).WriteToLog(session.ExportIDToError(ctx)) + newError("dispatching request to ", ob.Target).WriteToLog(session.ExportIDToError(ctx)) if err := writeFirstPayload(s.input, writer); err != nil { newError("failed to write first payload").Base(err).WriteToLog(session.ExportIDToError(ctx)) writer.hasError = true diff --git a/common/mux/client_test.go b/common/mux/client_test.go index 7837a86e7024..9626e2a276af 100644 --- a/common/mux/client_test.go +++ b/common/mux/client_test.go @@ -86,9 +86,9 @@ func TestClientWorkerClose(t *testing.T) { } tr1, tw1 := pipe.New(pipe.WithoutSizeLimit()) - ctx1 := session.ContextWithOutbound(context.Background(), &session.Outbound{ + ctx1 := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{ Target: net.TCPDestination(net.DomainAddress("www.example.com"), 80), - }) + }}) common.Must(manager.Dispatch(ctx1, &transport.Link{ Reader: tr1, Writer: tw1, @@ -103,9 +103,9 @@ func TestClientWorkerClose(t *testing.T) { } tr2, tw2 := pipe.New(pipe.WithoutSizeLimit()) - ctx2 := session.ContextWithOutbound(context.Background(), &session.Outbound{ + ctx2 := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{ Target: net.TCPDestination(net.DomainAddress("www.example.com"), 80), - }) + }}) common.Must(manager.Dispatch(ctx2, &transport.Link{ Reader: tr2, Writer: tw2, diff --git a/common/session/context.go b/common/session/context.go index 87586169e666..fc37bd72e04d 100644 --- a/common/session/context.go +++ b/common/session/context.go @@ -51,13 +51,13 @@ func InboundFromContext(ctx context.Context) *Inbound { return nil } -func ContextWithOutbound(ctx context.Context, outbound *Outbound) context.Context { - return context.WithValue(ctx, outboundSessionKey, outbound) +func ContextWithOutbounds(ctx context.Context, outbounds []*Outbound) context.Context { + return context.WithValue(ctx, outboundSessionKey, outbounds) } -func OutboundFromContext(ctx context.Context) *Outbound { - if outbound, ok := ctx.Value(outboundSessionKey).(*Outbound); ok { - return outbound +func OutboundsFromContext(ctx context.Context) []*Outbound { + if outbounds, ok := ctx.Value(outboundSessionKey).([]*Outbound); ok { + return outbounds } return nil } diff --git a/common/session/session.go b/common/session/session.go index 38ffa7bda80d..d8ab1ec423c6 100644 --- a/common/session/session.go +++ b/common/session/session.go @@ -50,18 +50,11 @@ type Inbound struct { Conn net.Conn // Timer of the inbound buf copier. May be nil. Timer *signal.ActivityTimer - // CanSpliceCopy is a property for this connection, set by both inbound and outbound + // CanSpliceCopy is a property for this connection // 1 = can, 2 = after processing protocol info should be able to, 3 = cannot CanSpliceCopy int } -func(i *Inbound) SetCanSpliceCopy(canSpliceCopy int) int { - if canSpliceCopy > i.CanSpliceCopy { - i.CanSpliceCopy = canSpliceCopy - } - return i.CanSpliceCopy -} - // Outbound is the metadata of an outbound connection. type Outbound struct { // Target address of the outbound connection. @@ -70,10 +63,15 @@ type Outbound struct { RouteTarget net.Destination // Gateway address Gateway net.Address + // Tag of the outbound proxy that handles the connection. + Tag string // Name of the outbound proxy that handles the connection. Name string // Conn is actually internet.Connection. May be nil. It is currently nil for outbound with proxySettings Conn net.Conn + // CanSpliceCopy is a property for this connection + // 1 = can, 2 = after processing protocol info should be able to, 3 = cannot + CanSpliceCopy int } // SniffingRequest controls the behavior of content sniffing. diff --git a/common/singbridge/dialer.go b/common/singbridge/dialer.go index 896c97fee532..6be83036e85b 100644 --- a/common/singbridge/dialer.go +++ b/common/singbridge/dialer.go @@ -43,9 +43,14 @@ func NewOutboundDialer(outbound proxy.Outbound, dialer internet.Dialer) *XrayOut } func (d *XrayOutboundDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { - ctx = session.ContextWithOutbound(ctx, &session.Outbound{ - Target: ToDestination(destination, ToNetwork(network)), - }) + outbounds := session.OutboundsFromContext(ctx) + if len(outbounds) == 0 { + outbounds = []*session.Outbound{{}} + ctx = session.ContextWithOutbounds(ctx, outbounds) + } + ob := outbounds[len(outbounds) - 1] + ob.Target = ToDestination(destination, ToNetwork(network)) + opts := []pipe.Option{pipe.WithSizeLimit(64 * 1024)} uplinkReader, uplinkWriter := pipe.New(opts...) downlinkReader, downlinkWriter := pipe.New(opts...) diff --git a/features/routing/session/context.go b/features/routing/session/context.go index c900219dc472..3c9764b3f917 100644 --- a/features/routing/session/context.go +++ b/features/routing/session/context.go @@ -124,9 +124,11 @@ func (ctx *Context) GetSkipDNSResolve() bool { // AsRoutingContext creates a context from context.context with session info. func AsRoutingContext(ctx context.Context) routing.Context { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] return &Context{ Inbound: session.InboundFromContext(ctx), - Outbound: session.OutboundFromContext(ctx), + Outbound: ob, Content: session.ContentFromContext(ctx), } } diff --git a/proxy/blackhole/blackhole.go b/proxy/blackhole/blackhole.go index 4b8194172e70..23c9c2919d85 100644 --- a/proxy/blackhole/blackhole.go +++ b/proxy/blackhole/blackhole.go @@ -31,10 +31,9 @@ func New(ctx context.Context, config *Config) (*Handler, error) { // Process implements OutboundHandler.Dispatch(). func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound != nil { - outbound.Name = "blackhole" - } + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + ob.Name = "blackhole" nBytes := h.response.WriteTo(link.Writer) if nBytes > 0 { diff --git a/proxy/dns/dns.go b/proxy/dns/dns.go index 2cf21a429fa6..86790f7674de 100644 --- a/proxy/dns/dns.go +++ b/proxy/dns/dns.go @@ -96,15 +96,16 @@ func parseIPQuery(b []byte) (r bool, domain string, id uint16, qType dnsmessage. // Process implements proxy.Outbound. func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if !ob.Target.IsValid() { return newError("invalid outbound") } - outbound.Name = "dns" + ob.Name = "dns" - srcNetwork := outbound.Target.Network + srcNetwork := ob.Target.Network - dest := outbound.Target + dest := ob.Target if h.server.Network != net.Network_Unknown { dest.Network = h.server.Network } diff --git a/proxy/dokodemo/dokodemo.go b/proxy/dokodemo/dokodemo.go index 1c59fe6231ce..5a07df5cbee1 100644 --- a/proxy/dokodemo/dokodemo.go +++ b/proxy/dokodemo/dokodemo.go @@ -86,10 +86,15 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st destinationOverridden := false if d.config.FollowRedirect { - if outbound := session.OutboundFromContext(ctx); outbound != nil && outbound.Target.IsValid() { - dest = outbound.Target - destinationOverridden = true - } else if handshake, ok := conn.(hasHandshakeAddressContext); ok { + outbounds := session.OutboundsFromContext(ctx) + if len(outbounds) > 0 { + ob := outbounds[len(outbounds) - 1] + if ob.Target.IsValid() { + dest = ob.Target + destinationOverridden = true + } + } + if handshake, ok := conn.(hasHandshakeAddressContext); ok && !destinationOverridden { addr := handshake.HandshakeAddressContext(ctx) if addr != nil { dest.Address = addr @@ -103,7 +108,7 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st inbound := session.InboundFromContext(ctx) inbound.Name = "dokodemo-door" - inbound.SetCanSpliceCopy(1) + inbound.CanSpliceCopy = 1 inbound.User = &protocol.MemoryUser{ Level: d.config.UserLevel, } diff --git a/proxy/freedom/freedom.go b/proxy/freedom/freedom.go index 0176929cc453..9e6afc9d1a73 100644 --- a/proxy/freedom/freedom.go +++ b/proxy/freedom/freedom.go @@ -106,16 +106,16 @@ func isValidAddress(addr *net.IPOrDomain) bool { // Process implements proxy.Outbound. func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if !ob.Target.IsValid() { return newError("target not specified.") } - outbound.Name = "freedom" + ob.Name = "freedom" + ob.CanSpliceCopy = 1 inbound := session.InboundFromContext(ctx) - if inbound != nil { - inbound.SetCanSpliceCopy(1) - } - destination := outbound.Target + + destination := ob.Target UDPOverride := net.UDPDestination(nil, 0) if h.config.DestinationOverride != nil { server := h.config.DestinationOverride.Server diff --git a/proxy/http/client.go b/proxy/http/client.go index 72060c4d22fd..80a0328a76fc 100644 --- a/proxy/http/client.go +++ b/proxy/http/client.go @@ -69,16 +69,14 @@ func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) { // Process implements proxy.Outbound.Process. We first create a socket tunnel via HTTP CONNECT method, then redirect all inbound traffic to that tunnel. func (c *Client) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if !ob.Target.IsValid() { return newError("target not specified.") } - outbound.Name = "http" - inbound := session.InboundFromContext(ctx) - if inbound != nil { - inbound.SetCanSpliceCopy(2) - } - target := outbound.Target + ob.Name = "http" + ob.CanSpliceCopy = 2 + target := ob.Target targetAddr := target.NetAddr() if target.Network == net.Network_UDP { @@ -175,9 +173,10 @@ func fillRequestHeader(ctx context.Context, header []*Header) ([]*Header, error) } inbound := session.InboundFromContext(ctx) - outbound := session.OutboundFromContext(ctx) + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] - if inbound == nil || outbound == nil { + if inbound == nil || ob == nil { return nil, newError("missing inbound or outbound metadata from context") } @@ -186,7 +185,7 @@ func fillRequestHeader(ctx context.Context, header []*Header) ([]*Header, error) Target net.Destination }{ Source: inbound.Source, - Target: outbound.Target, + Target: ob.Target, } filled := make([]*Header, len(header)) diff --git a/proxy/http/server.go b/proxy/http/server.go index 511d9b08c3ae..a7df317dc05d 100644 --- a/proxy/http/server.go +++ b/proxy/http/server.go @@ -85,7 +85,7 @@ type readerOnly struct { func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { inbound := session.InboundFromContext(ctx) inbound.Name = "http" - inbound.SetCanSpliceCopy(2) + inbound.CanSpliceCopy = 2 inbound.User = &protocol.MemoryUser{ Level: s.config.UserLevel, } diff --git a/proxy/loopback/loopback.go b/proxy/loopback/loopback.go index 30c39bd96409..f3be5a95d78a 100644 --- a/proxy/loopback/loopback.go +++ b/proxy/loopback/loopback.go @@ -22,12 +22,13 @@ type Loopback struct { } func (l *Loopback) Process(ctx context.Context, link *transport.Link, _ internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if !ob.Target.IsValid() { return newError("target not specified.") } - outbound.Name = "loopback" - destination := outbound.Target + ob.Name = "loopback" + destination := ob.Target newError("opening connection to ", destination).WriteToLog(session.ExportIDToError(ctx)) diff --git a/proxy/shadowsocks/client.go b/proxy/shadowsocks/client.go index 57d8f81c7c6c..8ebe7631b70c 100644 --- a/proxy/shadowsocks/client.go +++ b/proxy/shadowsocks/client.go @@ -49,16 +49,14 @@ func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) { // Process implements OutboundHandler.Process(). func (c *Client) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if !ob.Target.IsValid() { return newError("target not specified") } - outbound.Name = "shadowsocks" - inbound := session.InboundFromContext(ctx) - if inbound != nil { - inbound.SetCanSpliceCopy(3) - } - destination := outbound.Target + ob.Name = "shadowsocks" + ob.CanSpliceCopy = 3 + destination := ob.Target network := destination.Network var server *protocol.ServerSpec diff --git a/proxy/shadowsocks/server.go b/proxy/shadowsocks/server.go index 2975ba70aec0..8253506a6fff 100644 --- a/proxy/shadowsocks/server.go +++ b/proxy/shadowsocks/server.go @@ -73,7 +73,7 @@ func (s *Server) Network() []net.Network { func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { inbound := session.InboundFromContext(ctx) inbound.Name = "shadowsocks" - inbound.SetCanSpliceCopy(3) + inbound.CanSpliceCopy = 3 switch network { case net.Network_TCP: diff --git a/proxy/shadowsocks_2022/inbound.go b/proxy/shadowsocks_2022/inbound.go index 00314c90ec97..f1eb76a5a2e7 100644 --- a/proxy/shadowsocks_2022/inbound.go +++ b/proxy/shadowsocks_2022/inbound.go @@ -66,7 +66,7 @@ func (i *Inbound) Network() []net.Network { func (i *Inbound) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error { inbound := session.InboundFromContext(ctx) inbound.Name = "shadowsocks-2022" - inbound.SetCanSpliceCopy(3) + inbound.CanSpliceCopy = 3 var metadata M.Metadata if inbound.Source.IsValid() { diff --git a/proxy/shadowsocks_2022/inbound_multi.go b/proxy/shadowsocks_2022/inbound_multi.go index df837894d8ce..f80ec6d11b6f 100644 --- a/proxy/shadowsocks_2022/inbound_multi.go +++ b/proxy/shadowsocks_2022/inbound_multi.go @@ -155,7 +155,7 @@ func (i *MultiUserInbound) Network() []net.Network { func (i *MultiUserInbound) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error { inbound := session.InboundFromContext(ctx) inbound.Name = "shadowsocks-2022-multi" - inbound.SetCanSpliceCopy(3) + inbound.CanSpliceCopy = 3 var metadata M.Metadata if inbound.Source.IsValid() { diff --git a/proxy/shadowsocks_2022/inbound_relay.go b/proxy/shadowsocks_2022/inbound_relay.go index 7317f8dd27c5..1c4b824870c7 100644 --- a/proxy/shadowsocks_2022/inbound_relay.go +++ b/proxy/shadowsocks_2022/inbound_relay.go @@ -87,7 +87,7 @@ func (i *RelayInbound) Network() []net.Network { func (i *RelayInbound) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error { inbound := session.InboundFromContext(ctx) inbound.Name = "shadowsocks-2022-relay" - inbound.SetCanSpliceCopy(3) + inbound.CanSpliceCopy = 3 var metadata M.Metadata if inbound.Source.IsValid() { diff --git a/proxy/shadowsocks_2022/outbound.go b/proxy/shadowsocks_2022/outbound.go index bc1eb556f069..cac9a91bbb78 100644 --- a/proxy/shadowsocks_2022/outbound.go +++ b/proxy/shadowsocks_2022/outbound.go @@ -65,15 +65,16 @@ func (o *Outbound) Process(ctx context.Context, link *transport.Link, dialer int inbound := session.InboundFromContext(ctx) if inbound != nil { inboundConn = inbound.Conn - inbound.SetCanSpliceCopy(3) } - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if !ob.Target.IsValid() { return newError("target not specified") } - outbound.Name = "shadowsocks-2022" - destination := outbound.Target + ob.Name = "shadowsocks-2022" + ob.CanSpliceCopy = 3 + destination := ob.Target network := destination.Network newError("tunneling request to ", destination, " via ", o.server.NetAddr()).WriteToLog(session.ExportIDToError(ctx)) diff --git a/proxy/socks/client.go b/proxy/socks/client.go index 82591be4321f..b283eb6506b4 100644 --- a/proxy/socks/client.go +++ b/proxy/socks/client.go @@ -57,17 +57,15 @@ func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) { // Process implements proxy.Outbound.Process. func (c *Client) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if !ob.Target.IsValid() { return newError("target not specified.") } - outbound.Name = "socks" - inbound := session.InboundFromContext(ctx) - if inbound != nil { - inbound.SetCanSpliceCopy(2) - } + ob.Name = "socks" + ob.CanSpliceCopy = 2 // Destination of the inner request. - destination := outbound.Target + destination := ob.Target // Outbound server. var server *protocol.ServerSpec diff --git a/proxy/socks/server.go b/proxy/socks/server.go index 2f7897572b0e..0109d5b4225b 100644 --- a/proxy/socks/server.go +++ b/proxy/socks/server.go @@ -65,7 +65,7 @@ func (s *Server) Network() []net.Network { func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { inbound := session.InboundFromContext(ctx) inbound.Name = "socks" - inbound.SetCanSpliceCopy(2) + inbound.CanSpliceCopy = 2 inbound.User = &protocol.MemoryUser{ Level: s.config.UserLevel, } diff --git a/proxy/trojan/client.go b/proxy/trojan/client.go index d6b95fc0b154..3a4d838ae575 100644 --- a/proxy/trojan/client.go +++ b/proxy/trojan/client.go @@ -50,16 +50,14 @@ func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) { // Process implements OutboundHandler.Process(). func (c *Client) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if !ob.Target.IsValid() { return newError("target not specified") } - outbound.Name = "trojan" - inbound := session.InboundFromContext(ctx) - if inbound != nil { - inbound.SetCanSpliceCopy(3) - } - destination := outbound.Target + ob.Name = "trojan" + ob.CanSpliceCopy = 3 + destination := ob.Target network := destination.Network var server *protocol.ServerSpec diff --git a/proxy/trojan/server.go b/proxy/trojan/server.go index 5c3fcd9113f5..bc52c2b14682 100644 --- a/proxy/trojan/server.go +++ b/proxy/trojan/server.go @@ -215,7 +215,7 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con inbound := session.InboundFromContext(ctx) inbound.Name = "trojan" - inbound.SetCanSpliceCopy(3) + inbound.CanSpliceCopy = 3 inbound.User = user sessionPolicy = s.policyManager.ForLevel(user.Level) diff --git a/proxy/vless/inbound/inbound.go b/proxy/vless/inbound/inbound.go index 0ffa61d2ae87..9f7096e16d41 100644 --- a/proxy/vless/inbound/inbound.go +++ b/proxy/vless/inbound/inbound.go @@ -449,7 +449,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s switch requestAddons.Flow { case vless.XRV: if account.Flow == requestAddons.Flow { - inbound.SetCanSpliceCopy(2) + inbound.CanSpliceCopy = 2 switch request.Command { case protocol.RequestCommandUDP: return newError(requestAddons.Flow + " doesn't support UDP").AtWarning() @@ -479,7 +479,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s return newError(account.ID.String() + " is not able to use " + requestAddons.Flow).AtWarning() } case "": - inbound.SetCanSpliceCopy(3) + inbound.CanSpliceCopy = 3 if account.Flow == vless.XRV && (request.Command == protocol.RequestCommandTCP || isMuxAndNotXUDP(request, first)) { return newError(account.ID.String() + " is not able to use \"\". Note that the pure TLS proxy has certain TLS in TLS characters.").AtWarning() } diff --git a/proxy/vless/outbound/outbound.go b/proxy/vless/outbound/outbound.go index a93688139b2e..495dd74b42aa 100644 --- a/proxy/vless/outbound/outbound.go +++ b/proxy/vless/outbound/outbound.go @@ -70,12 +70,12 @@ func New(ctx context.Context, config *Config) (*Handler, error) { // Process implements proxy.Outbound.Process(). func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if !ob.Target.IsValid() { return newError("target not specified").AtError() } - outbound.Name = "vless" - inbound := session.InboundFromContext(ctx) + ob.Name = "vless" var rec *protocol.ServerSpec var conn stat.Connection @@ -96,7 +96,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte if statConn, ok := iConn.(*stat.CounterConnection); ok { iConn = statConn.Connection } - target := outbound.Target + target := ob.Target newError("tunneling request to ", target, " via ", rec.Destination().NetAddr()).AtInfo().WriteToLog(session.ExportIDToError(ctx)) command := protocol.RequestCommandTCP @@ -130,9 +130,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte requestAddons.Flow = requestAddons.Flow[:16] fallthrough case vless.XRV: - if inbound != nil { - inbound.SetCanSpliceCopy(2) - } + ob.CanSpliceCopy = 2 switch request.Command { case protocol.RequestCommandUDP: if !allowUDP443 && request.Port == 443 { @@ -161,9 +159,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte rawInput = (*bytes.Buffer)(unsafe.Pointer(p + r.Offset)) } default: - if inbound != nil { - inbound.SetCanSpliceCopy(3) - } + ob.CanSpliceCopy = 3 } var newCtx context.Context @@ -238,7 +234,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte return newError(`failed to use `+requestAddons.Flow+`, found outer tls version `, utlsConn.ConnectionState().Version).AtWarning() } } - ctx1 := session.ContextWithOutbound(ctx, nil) // TODO enable splice + ctx1 := session.ContextWithInbound(ctx, nil) // TODO enable splice err = encoding.XtlsWrite(clientReader, serverWriter, timer, conn, trafficState, ctx1) } else { // from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBufer diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index 679ea5da688e..f5340f20c418 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -257,7 +257,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s inbound := session.InboundFromContext(ctx) inbound.Name = "vmess" - inbound.SetCanSpliceCopy(3) + inbound.CanSpliceCopy = 3 inbound.User = request.User sessionPolicy = h.policyManager.ForLevel(request.User.Level) diff --git a/proxy/vmess/outbound/outbound.go b/proxy/vmess/outbound/outbound.go index c3c55d956fcd..8f102dbb1129 100644 --- a/proxy/vmess/outbound/outbound.go +++ b/proxy/vmess/outbound/outbound.go @@ -60,15 +60,13 @@ func New(ctx context.Context, config *Config) (*Handler, error) { // Process implements proxy.Outbound.Process(). func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if !ob.Target.IsValid() { return newError("target not specified").AtError() } - outbound.Name = "vmess" - inbound := session.InboundFromContext(ctx) - if inbound != nil { - inbound.SetCanSpliceCopy(3) - } + ob.Name = "vmess" + ob.CanSpliceCopy = 3 var rec *protocol.ServerSpec var conn stat.Connection @@ -87,7 +85,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte } defer conn.Close() - target := outbound.Target + target := ob.Target newError("tunneling request to ", target, " via ", rec.Destination().NetAddr()).WriteToLog(session.ExportIDToError(ctx)) command := protocol.RequestCommandTCP diff --git a/proxy/wireguard/client.go b/proxy/wireguard/client.go index 4136525e69ac..00a6fa51a76b 100644 --- a/proxy/wireguard/client.go +++ b/proxy/wireguard/client.go @@ -127,22 +127,20 @@ func (h *Handler) processWireGuard(dialer internet.Dialer) (err error) { // Process implements OutboundHandler.Dispatch(). func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if !ob.Target.IsValid() { return newError("target not specified") } - outbound.Name = "wireguard" - inbound := session.InboundFromContext(ctx) - if inbound != nil { - inbound.SetCanSpliceCopy(3) - } + ob.Name = "wireguard" + ob.CanSpliceCopy = 3 if err := h.processWireGuard(dialer); err != nil { return err } // Destination of the inner request. - destination := outbound.Target + destination := ob.Target command := protocol.RequestCommandTCP if destination.Network == net.Network_UDP { command = protocol.RequestCommandUDP diff --git a/proxy/wireguard/server.go b/proxy/wireguard/server.go index bdb4e8018c6b..3d3b584cecef 100644 --- a/proxy/wireguard/server.go +++ b/proxy/wireguard/server.go @@ -79,13 +79,15 @@ func (*Server) Network() []net.Network { func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { inbound := session.InboundFromContext(ctx) inbound.Name = "wireguard" - inbound.SetCanSpliceCopy(3) + inbound.CanSpliceCopy = 3 + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] s.info = routingInfo{ ctx: core.ToBackgroundDetachedContext(ctx), dispatcher: dispatcher, inboundTag: session.InboundFromContext(ctx), - outboundTag: session.OutboundFromContext(ctx), + outboundTag: ob, contentTag: session.ContentFromContext(ctx), } @@ -145,7 +147,7 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) { ctx = session.ContextWithInbound(ctx, s.info.inboundTag) } if s.info.outboundTag != nil { - ctx = session.ContextWithOutbound(ctx, s.info.outboundTag) + ctx = session.ContextWithOutbounds(ctx, []*session.Outbound{s.info.outboundTag}) } if s.info.contentTag != nil { ctx = session.ContextWithContent(ctx, s.info.contentTag) diff --git a/transport/internet/dialer.go b/transport/internet/dialer.go index 3d5d046f7ac9..8231f03263e7 100644 --- a/transport/internet/dialer.go +++ b/transport/internet/dialer.go @@ -112,7 +112,8 @@ func canLookupIP(ctx context.Context, dst net.Destination, sockopt *SocketConfig func redirect(ctx context.Context, dst net.Destination, obt string) net.Conn { newError("redirecting request " + dst.String() + " to " + obt).WriteToLog(session.ExportIDToError(ctx)) h := obm.GetHandler(obt) - ctx = session.ContextWithOutbound(ctx, &session.Outbound{Target: dst, Gateway: nil}) + outbounds := session.OutboundsFromContext(ctx) + ctx = session.ContextWithOutbounds(ctx, append(outbounds, &session.Outbound{Target: dst, Gateway: nil})) // add another outbound in session ctx if h != nil { ur, uw := pipe.New(pipe.OptionsFromContext(ctx)...) dr, dw := pipe.New(pipe.OptionsFromContext(ctx)...) @@ -131,8 +132,10 @@ func redirect(ctx context.Context, dst net.Destination, obt string) net.Conn { // DialSystem calls system dialer to create a network connection. func DialSystem(ctx context.Context, dest net.Destination, sockopt *SocketConfig) (net.Conn, error) { var src net.Address - if outbound := session.OutboundFromContext(ctx); outbound != nil { - src = outbound.Gateway + outbounds := session.OutboundsFromContext(ctx) + if len(outbounds) > 0 { + ob := outbounds[len(outbounds) - 1] + src = ob.Gateway } if sockopt == nil { return effectiveSystemDialer.Dial(ctx, src, dest, sockopt) diff --git a/transport/internet/grpc/dial.go b/transport/internet/grpc/dial.go index 5d5789b4773a..a4b03cedc8ee 100644 --- a/transport/internet/grpc/dial.go +++ b/transport/internet/grpc/dial.go @@ -118,7 +118,7 @@ func getGrpcClient(ctx context.Context, dest net.Destination, streamSettings *in address := net.ParseAddress(rawHost) gctx = session.ContextWithID(gctx, session.IDFromContext(ctx)) - gctx = session.ContextWithOutbound(gctx, session.OutboundFromContext(ctx)) + gctx = session.ContextWithOutbounds(gctx, session.OutboundsFromContext(ctx)) gctx = session.ContextWithTimeoutOnly(gctx, true) c, err := internet.DialSystem(gctx, net.TCPDestination(address, port), sockopt) diff --git a/transport/internet/http/dialer.go b/transport/internet/http/dialer.go index acccd0b72ad6..0148658c5b1d 100644 --- a/transport/internet/http/dialer.go +++ b/transport/internet/http/dialer.go @@ -68,7 +68,7 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in address := net.ParseAddress(rawHost) hctx = session.ContextWithID(hctx, session.IDFromContext(ctx)) - hctx = session.ContextWithOutbound(hctx, session.OutboundFromContext(ctx)) + hctx = session.ContextWithOutbounds(hctx, session.OutboundsFromContext(ctx)) hctx = session.ContextWithTimeoutOnly(hctx, true) pconn, err := internet.DialSystem(hctx, net.TCPDestination(address, port), sockopt) From 3f5aea2d0584e0383ed97f443af9fc256ae1f856 Mon Sep 17 00:00:00 2001 From: yuhan6665 <1588741+yuhan6665@users.noreply.github.com> Date: Sun, 12 May 2024 16:05:50 -0400 Subject: [PATCH 2/4] Fill outbound tag info --- app/dispatcher/default.go | 1 + app/proxyman/outbound/handler.go | 5 ++++- transport/internet/dialer.go | 6 +++++- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/app/dispatcher/default.go b/app/dispatcher/default.go index fa8c5d006f4b..26019bbe9e0e 100644 --- a/app/dispatcher/default.go +++ b/app/dispatcher/default.go @@ -428,6 +428,7 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport. return } + ob.Tag = handler.Tag() if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil { if tag := handler.Tag(); tag != "" { if inTag == "" { diff --git a/app/proxyman/outbound/handler.go b/app/proxyman/outbound/handler.go index 8150c7fa9aad..4262c76a5eb7 100644 --- a/app/proxyman/outbound/handler.go +++ b/app/proxyman/outbound/handler.go @@ -245,7 +245,10 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connecti if handler != nil { newError("proxying to ", tag, " for dest ", dest).AtDebug().WriteToLog(session.ExportIDToError(ctx)) outbounds := session.OutboundsFromContext(ctx) - ctx = session.ContextWithOutbounds(ctx, append(outbounds, &session.Outbound{Target: dest})) // add another outbound in session ctx + ctx = session.ContextWithOutbounds(ctx, append(outbounds, &session.Outbound{ + Target: dest, + Tag: tag, + })) // add another outbound in session ctx opts := pipe.OptionsFromContext(ctx) uplinkReader, uplinkWriter := pipe.New(opts...) downlinkReader, downlinkWriter := pipe.New(opts...) diff --git a/transport/internet/dialer.go b/transport/internet/dialer.go index 8231f03263e7..ffa868a30314 100644 --- a/transport/internet/dialer.go +++ b/transport/internet/dialer.go @@ -113,7 +113,11 @@ func redirect(ctx context.Context, dst net.Destination, obt string) net.Conn { newError("redirecting request " + dst.String() + " to " + obt).WriteToLog(session.ExportIDToError(ctx)) h := obm.GetHandler(obt) outbounds := session.OutboundsFromContext(ctx) - ctx = session.ContextWithOutbounds(ctx, append(outbounds, &session.Outbound{Target: dst, Gateway: nil})) // add another outbound in session ctx + ctx = session.ContextWithOutbounds(ctx, append(outbounds, &session.Outbound{ + Target: dst, + Gateway: nil, + Tag: obt, + })) // add another outbound in session ctx if h != nil { ur, uw := pipe.New(pipe.OptionsFromContext(ctx)...) dr, dw := pipe.New(pipe.OptionsFromContext(ctx)...) From 75891918f5e041ed04bfe8736a089f2a736f0328 Mon Sep 17 00:00:00 2001 From: yuhan6665 <1588741+yuhan6665@users.noreply.github.com> Date: Sun, 12 May 2024 17:20:57 -0400 Subject: [PATCH 3/4] Splice now checks capalibility from all outbounds --- proxy/proxy.go | 100 ++++++++++++++++++++----------- proxy/vless/encoding/encoding.go | 20 +++++-- proxy/vless/inbound/inbound.go | 6 +- proxy/vless/outbound/outbound.go | 4 +- 4 files changed, 84 insertions(+), 46 deletions(-) diff --git a/proxy/proxy.go b/proxy/proxy.go index 6a5a1798a350..2507d0296282 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -474,45 +474,73 @@ func CopyRawConnIfExist(ctx context.Context, readerConn net.Conn, writerConn net readerConn, readCounter, _ := UnwrapRawConn(readerConn) writerConn, _, writeCounter := UnwrapRawConn(writerConn) reader := buf.NewReader(readerConn) - if inbound := session.InboundFromContext(ctx); inbound != nil { - if tc, ok := writerConn.(*net.TCPConn); ok && readerConn != nil && writerConn != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") { - for inbound.CanSpliceCopy != 3 { - if inbound.CanSpliceCopy == 1 { - newError("CopyRawConn splice").WriteToLog(session.ExportIDToError(ctx)) - statWriter, _ := writer.(*dispatcher.SizeStatWriter) - //runtime.Gosched() // necessary - time.Sleep(time.Millisecond) // without this, there will be a rare ssl error for freedom splice - w, err := tc.ReadFrom(readerConn) - if readCounter != nil { - readCounter.Add(w) // outbound stats - } - if writeCounter != nil { - writeCounter.Add(w) // inbound stats - } - if statWriter != nil { - statWriter.Counter.Add(w) // user stats - } - if err != nil && errors.Cause(err) != io.EOF { - return err - } - return nil - } - buffer, err := reader.ReadMultiBuffer() - if !buffer.IsEmpty() { - if readCounter != nil { - readCounter.Add(int64(buffer.Len())) - } - timer.Update() - if werr := writer.WriteMultiBuffer(buffer); werr != nil { - return werr - } - } - if err != nil { - return err - } + if runtime.GOOS != "linux" && runtime.GOOS != "android" { + return readV(ctx, reader, writer, timer, readCounter) + } + tc, ok := writerConn.(*net.TCPConn) + if !ok || readerConn == nil || writerConn == nil { + return readV(ctx, reader, writer, timer, readCounter) + } + inbound := session.InboundFromContext(ctx) + if inbound == nil || inbound.CanSpliceCopy == 3 { + return readV(ctx, reader, writer, timer, readCounter) + } + outbounds := session.OutboundsFromContext(ctx) + if len(outbounds) == 0 { + return readV(ctx, reader, writer, timer, readCounter) + } + for _, ob := range outbounds { + if ob.CanSpliceCopy == 3 { + return readV(ctx, reader, writer, timer, readCounter) + } + } + + for { + inbound := session.InboundFromContext(ctx) + outbounds := session.OutboundsFromContext(ctx) + var splice = inbound.CanSpliceCopy == 1 + for _, ob := range outbounds { + if ob.CanSpliceCopy != 1 { + splice = false + } + } + if splice { + newError("CopyRawConn splice").WriteToLog(session.ExportIDToError(ctx)) + statWriter, _ := writer.(*dispatcher.SizeStatWriter) + //runtime.Gosched() // necessary + time.Sleep(time.Millisecond) // without this, there will be a rare ssl error for freedom splice + w, err := tc.ReadFrom(readerConn) + if readCounter != nil { + readCounter.Add(w) // outbound stats + } + if writeCounter != nil { + writeCounter.Add(w) // inbound stats + } + if statWriter != nil { + statWriter.Counter.Add(w) // user stats + } + if err != nil && errors.Cause(err) != io.EOF { + return err } + return nil + } + buffer, err := reader.ReadMultiBuffer() + if !buffer.IsEmpty() { + if readCounter != nil { + readCounter.Add(int64(buffer.Len())) + } + timer.Update() + if werr := writer.WriteMultiBuffer(buffer); werr != nil { + return werr + } + } + if err != nil { + return err } } +} + +func readV(ctx context.Context, reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, readCounter stats.Counter) error { newError("CopyRawConn readv").WriteToLog(session.ExportIDToError(ctx)) if err := buf.Copy(reader, writer, buf.UpdateActivity(timer), buf.AddToStatCounter(readCounter)); err != nil { return newError("failed to process response").Base(err) diff --git a/proxy/vless/encoding/encoding.go b/proxy/vless/encoding/encoding.go index 5956389af799..2976be749a23 100644 --- a/proxy/vless/encoding/encoding.go +++ b/proxy/vless/encoding/encoding.go @@ -174,15 +174,18 @@ func DecodeResponseHeader(reader io.Reader, request *protocol.RequestHeader) (*A } // XtlsRead filter and read xtls protocol -func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, input *bytes.Reader, rawInput *bytes.Buffer, trafficState *proxy.TrafficState, ctx context.Context) error { +func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, input *bytes.Reader, rawInput *bytes.Buffer, trafficState *proxy.TrafficState, ob *session.Outbound, ctx context.Context) error { err := func() error { for { if trafficState.ReaderSwitchToDirectCopy { var writerConn net.Conn - if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Conn != nil { + if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Conn != nil && ob != nil { writerConn = inbound.Conn if inbound.CanSpliceCopy == 2 { - inbound.CanSpliceCopy = 1 // force the value to 1, don't use setter + inbound.CanSpliceCopy = 1 + } + if ob.CanSpliceCopy == 2 { // ob need to be passed in due to context can change + ob.CanSpliceCopy = 1 } } return proxy.CopyRawConnIfExist(ctx, conn, writerConn, writer, timer) @@ -219,14 +222,19 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater } // XtlsWrite filter and write xtls protocol -func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, trafficState *proxy.TrafficState, ctx context.Context) error { +func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, trafficState *proxy.TrafficState, ob *session.Outbound, ctx context.Context) error { err := func() error { var ct stats.Counter for { buffer, err := reader.ReadMultiBuffer() if trafficState.WriterSwitchToDirectCopy { - if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.CanSpliceCopy == 2 { - inbound.CanSpliceCopy = 1 // force the value to 1, don't use setter + if inbound := session.InboundFromContext(ctx); inbound != nil && ob != nil { + if inbound.CanSpliceCopy == 2 { + inbound.CanSpliceCopy = 1 + } + if ob.CanSpliceCopy == 2 { + ob.CanSpliceCopy = 1 + } } rawConn, _, writerCounter := proxy.UnwrapRawConn(conn) writer = buf.NewWriter(rawConn) diff --git a/proxy/vless/inbound/inbound.go b/proxy/vless/inbound/inbound.go index 9f7096e16d41..7d2dd5071819 100644 --- a/proxy/vless/inbound/inbound.go +++ b/proxy/vless/inbound/inbound.go @@ -523,7 +523,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s if requestAddons.Flow == vless.XRV { ctx1 := session.ContextWithInbound(ctx, nil) // TODO enable splice clientReader = proxy.NewVisionReader(clientReader, trafficState, ctx1) - err = encoding.XtlsRead(clientReader, serverWriter, timer, connection, input, rawInput, trafficState, ctx1) + err = encoding.XtlsRead(clientReader, serverWriter, timer, connection, input, rawInput, trafficState, nil, ctx1) } else { // from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBufer err = buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer)) @@ -560,7 +560,9 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s var err error if requestAddons.Flow == vless.XRV { - err = encoding.XtlsWrite(serverReader, clientWriter, timer, connection, trafficState, ctx) + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + err = encoding.XtlsWrite(serverReader, clientWriter, timer, connection, trafficState, ob, ctx) } else { // from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBufer err = buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer)) diff --git a/proxy/vless/outbound/outbound.go b/proxy/vless/outbound/outbound.go index 495dd74b42aa..bf98253b4547 100644 --- a/proxy/vless/outbound/outbound.go +++ b/proxy/vless/outbound/outbound.go @@ -235,7 +235,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte } } ctx1 := session.ContextWithInbound(ctx, nil) // TODO enable splice - err = encoding.XtlsWrite(clientReader, serverWriter, timer, conn, trafficState, ctx1) + err = encoding.XtlsWrite(clientReader, serverWriter, timer, conn, trafficState, ob, ctx1) } else { // from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBufer err = buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer)) @@ -273,7 +273,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte } if requestAddons.Flow == vless.XRV { - err = encoding.XtlsRead(serverReader, clientWriter, timer, conn, input, rawInput, trafficState, ctx) + err = encoding.XtlsRead(serverReader, clientWriter, timer, conn, input, rawInput, trafficState, ob, ctx) } else { // from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBufer err = buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer)) From faede2bff2cbee2fceab6a3c61f66de2c5a14908 Mon Sep 17 00:00:00 2001 From: yuhan6665 <1588741+yuhan6665@users.noreply.github.com> Date: Mon, 13 May 2024 21:33:49 -0400 Subject: [PATCH 4/4] Fix unit tests --- app/proxyman/outbound/handler_test.go | 3 +++ proxy/blackhole/blackhole_test.go | 6 ++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/app/proxyman/outbound/handler_test.go b/app/proxyman/outbound/handler_test.go index e5b67308ccfc..3f7ef28e0710 100644 --- a/app/proxyman/outbound/handler_test.go +++ b/app/proxyman/outbound/handler_test.go @@ -14,6 +14,7 @@ import ( "github.com/xtls/xray-core/app/stats" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/serial" + "github.com/xtls/xray-core/common/session" core "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/outbound" "github.com/xtls/xray-core/proxy/freedom" @@ -44,6 +45,7 @@ func TestOutboundWithoutStatCounter(t *testing.T) { v, _ := core.New(config) v.AddFeature((outbound.Manager)(new(Manager))) ctx := context.WithValue(context.Background(), xrayKey, v) + ctx = session.ContextWithOutbounds(ctx, []*session.Outbound{{}}) h, _ := NewHandler(ctx, &core.OutboundHandlerConfig{ Tag: "tag", ProxySettings: serial.ToTypedMessage(&freedom.Config{}), @@ -73,6 +75,7 @@ func TestOutboundWithStatCounter(t *testing.T) { v, _ := core.New(config) v.AddFeature((outbound.Manager)(new(Manager))) ctx := context.WithValue(context.Background(), xrayKey, v) + ctx = session.ContextWithOutbounds(ctx, []*session.Outbound{{}}) h, _ := NewHandler(ctx, &core.OutboundHandlerConfig{ Tag: "tag", ProxySettings: serial.ToTypedMessage(&freedom.Config{}), diff --git a/proxy/blackhole/blackhole_test.go b/proxy/blackhole/blackhole_test.go index 8e487e0c10fe..6a9cb8e84e80 100644 --- a/proxy/blackhole/blackhole_test.go +++ b/proxy/blackhole/blackhole_test.go @@ -7,13 +7,15 @@ import ( "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/serial" + "github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/proxy/blackhole" "github.com/xtls/xray-core/transport" "github.com/xtls/xray-core/transport/pipe" ) func TestBlackholeHTTPResponse(t *testing.T) { - handler, err := blackhole.New(context.Background(), &blackhole.Config{ + ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{}}) + handler, err := blackhole.New(ctx, &blackhole.Config{ Response: serial.ToTypedMessage(&blackhole.HTTPResponse{}), }) common.Must(err) @@ -32,7 +34,7 @@ func TestBlackholeHTTPResponse(t *testing.T) { Reader: reader, Writer: writer, } - common.Must(handler.Process(context.Background(), &link, nil)) + common.Must(handler.Process(ctx, &link, nil)) common.Must(rerr) if mb.IsEmpty() { t.Error("expect http response, but nothing")