Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add session context outbounds as slice #3356

Merged
merged 4 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions app/dispatcher/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,11 +218,12 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
if !destination.IsValid() {
panic("Dispatcher: Invalid destination.")
}
ob := session.OutboundFromContext(ctx)
if ob == nil {
ob = &session.Outbound{}
ctx = session.ContextWithOutbound(ctx, ob)
outbounds := session.OutboundsFromContext(ctx)
if len(outbounds) == 0 {
outbounds = []*session.Outbound{{}}
ctx = session.ContextWithOutbounds(ctx, outbounds)
}
ob := outbounds[len(outbounds) - 1]
ob.OriginalTarget = destination
ob.Target = destination
content := session.ContentFromContext(ctx)
Expand Down Expand Up @@ -274,11 +275,12 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
if !destination.IsValid() {
return newError("Dispatcher: Invalid destination.")
}
ob := session.OutboundFromContext(ctx)
if ob == nil {
ob = &session.Outbound{}
ctx = session.ContextWithOutbound(ctx, ob)
outbounds := session.OutboundsFromContext(ctx)
if len(outbounds) == 0 {
outbounds = []*session.Outbound{{}}
ctx = session.ContextWithOutbounds(ctx, outbounds)
}
ob := outbounds[len(outbounds) - 1]
ob.OriginalTarget = destination
ob.Target = destination
content := session.ContentFromContext(ctx)
Expand Down Expand Up @@ -368,7 +370,8 @@ func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, netw
return contentResult, contentErr
}
func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) {
ob := session.OutboundFromContext(ctx)
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
if hosts, ok := d.dns.(dns.HostsLookup); ok && destination.Address.Family().IsDomain() {
proxied := hosts.LookupHosts(ob.Target.String())
if proxied != nil {
Expand Down Expand Up @@ -425,6 +428,7 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.
return
}

ob.Tag = handler.Tag()
if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil {
if tag := handler.Tag(); tag != "" {
if inTag == "" {
Expand Down
11 changes: 6 additions & 5 deletions app/dispatcher/fakednssniffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,20 @@ func newFakeDNSSniffer(ctx context.Context) (protocolSnifferWithMetadata, error)
return protocolSnifferWithMetadata{}, errNotInit
}
return protocolSnifferWithMetadata{protocolSniffer: func(ctx context.Context, bytes []byte) (SniffResult, error) {
Target := session.OutboundFromContext(ctx).Target
if Target.Network == net.Network_TCP || Target.Network == net.Network_UDP {
domainFromFakeDNS := fakeDNSEngine.GetDomainFromFakeDNS(Target.Address)
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
if ob.Target.Network == net.Network_TCP || ob.Target.Network == net.Network_UDP {
domainFromFakeDNS := fakeDNSEngine.GetDomainFromFakeDNS(ob.Target.Address)
if domainFromFakeDNS != "" {
newError("fake dns got domain: ", domainFromFakeDNS, " for ip: ", Target.Address.String()).WriteToLog(session.ExportIDToError(ctx))
newError("fake dns got domain: ", domainFromFakeDNS, " for ip: ", ob.Target.Address.String()).WriteToLog(session.ExportIDToError(ctx))
return &fakeDNSSniffResult{domainName: domainFromFakeDNS}, nil
}
}

if ipAddressInRangeValueI := ctx.Value(ipAddressInRange); ipAddressInRangeValueI != nil {
ipAddressInRangeValue := ipAddressInRangeValueI.(*ipAddressInRangeOpt)
if fkr0, ok := fakeDNSEngine.(dns.FakeDNSEngineRev0); ok {
inPool := fkr0.IsIPInIPPool(Target.Address)
inPool := fkr0.IsIPInIPPool(ob.Target.Address)
ipAddressInRangeValue.addressInRange = &inPool
}
}
Expand Down
11 changes: 6 additions & 5 deletions app/proxyman/inbound/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (w *tcpWorker) callback(conn stat.Connection) {
sid := session.NewID()
ctx = session.ContextWithID(ctx, sid)

var outbound = &session.Outbound{}
outbounds := []*session.Outbound{{}}
if w.recvOrigDest {
var dest net.Destination
switch getTProxyType(w.stream) {
Expand All @@ -75,10 +75,10 @@ func (w *tcpWorker) callback(conn stat.Connection) {
dest = net.DestinationFromAddr(conn.LocalAddr())
}
if dest.IsValid() {
outbound.Target = dest
outbounds[0].Target = dest
}
}
ctx = session.ContextWithOutbound(ctx, outbound)
ctx = session.ContextWithOutbounds(ctx, outbounds)

if w.uplinkCounter != nil || w.downlinkCounter != nil {
conn = &stat.CounterConnection{
Expand Down Expand Up @@ -309,9 +309,10 @@ func (w *udpWorker) callback(b *buf.Buffer, source net.Destination, originalDest
ctx = session.ContextWithID(ctx, sid)

if originalDest.IsValid() {
ctx = session.ContextWithOutbound(ctx, &session.Outbound{
outbounds := []*session.Outbound{{
Target: originalDest,
})
}}
ctx = session.ContextWithOutbounds(ctx, outbounds)
}
ctx = session.ContextWithInbound(ctx, &session.Inbound{
Source: source,
Expand Down
38 changes: 18 additions & 20 deletions app/proxyman/outbound/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,11 @@ func (h *Handler) Tag() string {

// Dispatch implements proxy.Outbound.Dispatch.
func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
outbound := session.OutboundFromContext(ctx)
if outbound.Target.Network == net.Network_UDP && outbound.OriginalTarget.Address != nil && outbound.OriginalTarget.Address != outbound.Target.Address {
link.Reader = &buf.EndpointOverrideReader{Reader: link.Reader, Dest: outbound.Target.Address, OriginalDest: outbound.OriginalTarget.Address}
link.Writer = &buf.EndpointOverrideWriter{Writer: link.Writer, Dest: outbound.Target.Address, OriginalDest: outbound.OriginalTarget.Address}
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
if ob.Target.Network == net.Network_UDP && ob.OriginalTarget.Address != nil && ob.OriginalTarget.Address != ob.Target.Address {
link.Reader = &buf.EndpointOverrideReader{Reader: link.Reader, Dest: ob.Target.Address, OriginalDest: ob.OriginalTarget.Address}
link.Writer = &buf.EndpointOverrideWriter{Writer: link.Writer, Dest: ob.Target.Address, OriginalDest: ob.OriginalTarget.Address}
}
if h.mux != nil {
test := func(err error) {
Expand All @@ -183,7 +184,7 @@ func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
common.Interrupt(link.Writer)
}
}
if outbound.Target.Network == net.Network_UDP && outbound.Target.Port == 443 {
if ob.Target.Network == net.Network_UDP && ob.Target.Port == 443 {
switch h.udp443 {
case "reject":
test(newError("XUDP rejected UDP/443 traffic").AtInfo())
Expand All @@ -192,7 +193,7 @@ func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
goto out
}
}
if h.xudp != nil && outbound.Target.Network == net.Network_UDP {
if h.xudp != nil && ob.Target.Network == net.Network_UDP {
if !h.xudp.Enabled {
goto out
}
Expand Down Expand Up @@ -243,10 +244,11 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connecti
handler := h.outboundManager.GetHandler(tag)
if handler != nil {
newError("proxying to ", tag, " for dest ", dest).AtDebug().WriteToLog(session.ExportIDToError(ctx))
ctx = session.ContextWithOutbound(ctx, &session.Outbound{
outbounds := session.OutboundsFromContext(ctx)
ctx = session.ContextWithOutbounds(ctx, append(outbounds, &session.Outbound{
Target: dest,
})

Tag: tag,
})) // add another outbound in session ctx
opts := pipe.OptionsFromContext(ctx)
uplinkReader, uplinkWriter := pipe.New(opts...)
downlinkReader, downlinkWriter := pipe.New(opts...)
Expand All @@ -266,15 +268,12 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connecti
}

if h.senderSettings.Via != nil {
outbound := session.OutboundFromContext(ctx)
if outbound == nil {
outbound = new(session.Outbound)
ctx = session.ContextWithOutbound(ctx, outbound)
}
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
if h.senderSettings.ViaCidr == "" {
outbound.Gateway = h.senderSettings.Via.AsAddress()
ob.Gateway = h.senderSettings.Via.AsAddress()
} else { //Get a random address.
outbound.Gateway = ParseRandomIPv6(h.senderSettings.Via.AsAddress(), h.senderSettings.ViaCidr)
ob.Gateway = ParseRandomIPv6(h.senderSettings.Via.AsAddress(), h.senderSettings.ViaCidr)
}
}
}
Expand All @@ -285,10 +284,9 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connecti

conn, err := internet.Dial(ctx, dest, h.streamSettings)
conn = h.getStatCouterConnection(conn)
outbound := session.OutboundFromContext(ctx)
if outbound != nil {
outbound.Conn = conn
}
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
ob.Conn = conn
return conn, err
}

Expand Down
3 changes: 3 additions & 0 deletions app/proxyman/outbound/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/xtls/xray-core/app/stats"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/serial"
"github.com/xtls/xray-core/common/session"
core "github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/outbound"
"github.com/xtls/xray-core/proxy/freedom"
Expand Down Expand Up @@ -44,6 +45,7 @@ func TestOutboundWithoutStatCounter(t *testing.T) {
v, _ := core.New(config)
v.AddFeature((outbound.Manager)(new(Manager)))
ctx := context.WithValue(context.Background(), xrayKey, v)
ctx = session.ContextWithOutbounds(ctx, []*session.Outbound{{}})
h, _ := NewHandler(ctx, &core.OutboundHandlerConfig{
Tag: "tag",
ProxySettings: serial.ToTypedMessage(&freedom.Config{}),
Expand Down Expand Up @@ -73,6 +75,7 @@ func TestOutboundWithStatCounter(t *testing.T) {
v, _ := core.New(config)
v.AddFeature((outbound.Manager)(new(Manager)))
ctx := context.WithValue(context.Background(), xrayKey, v)
ctx = session.ContextWithOutbounds(ctx, []*session.Outbound{{}})
h, _ := NewHandler(ctx, &core.OutboundHandlerConfig{
Tag: "tag",
ProxySettings: serial.ToTypedMessage(&freedom.Config{}),
Expand Down
12 changes: 7 additions & 5 deletions app/reverse/portal.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,13 @@ func (p *Portal) Close() error {
}

func (p *Portal) HandleConnection(ctx context.Context, link *transport.Link) error {
outboundMeta := session.OutboundFromContext(ctx)
if outboundMeta == nil {
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
if ob == nil {
return newError("outbound metadata not found").AtError()
}

if isDomain(outboundMeta.Target, p.domain) {
if isDomain(ob.Target, p.domain) {
muxClient, err := mux.NewClientWorker(*link, mux.ClientStrategy{})
if err != nil {
return newError("failed to create mux client worker").Base(err).AtWarning()
Expand Down Expand Up @@ -206,9 +207,10 @@ func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) {
downlinkReader, downlinkWriter := pipe.New(opt...)

ctx := context.Background()
ctx = session.ContextWithOutbound(ctx, &session.Outbound{
outbounds := []*session.Outbound{{
Target: net.UDPDestination(net.DomainAddress(internalDomain), 0),
})
}}
ctx = session.ContextWithOutbounds(ctx, outbounds)
f := client.Dispatch(ctx, &transport.Link{
Reader: uplinkReader,
Writer: downlinkWriter,
Expand Down
20 changes: 15 additions & 5 deletions app/router/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ func TestSimpleRouter(t *testing.T) {
HandlerSelector: mockHs,
}, nil))

ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("example.com"), 80)})
ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{
Target: net.TCPDestination(net.DomainAddress("example.com"), 80),
}})
route, err := r.PickRoute(routing_session.AsRoutingContext(ctx))
common.Must(err)
if tag := route.GetOutboundTag(); tag != "test" {
Expand Down Expand Up @@ -86,7 +88,9 @@ func TestSimpleBalancer(t *testing.T) {
HandlerSelector: mockHs,
}, nil))

ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("example.com"), 80)})
ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{
Target: net.TCPDestination(net.DomainAddress("example.com"), 80),
}})
route, err := r.PickRoute(routing_session.AsRoutingContext(ctx))
common.Must(err)
if tag := route.GetOutboundTag(); tag != "test" {
Expand Down Expand Up @@ -174,7 +178,9 @@ func TestIPOnDemand(t *testing.T) {
r := new(Router)
common.Must(r.Init(context.TODO(), config, mockDNS, nil, nil))

ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("example.com"), 80)})
ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{
Target: net.TCPDestination(net.DomainAddress("example.com"), 80),
}})
route, err := r.PickRoute(routing_session.AsRoutingContext(ctx))
common.Must(err)
if tag := route.GetOutboundTag(); tag != "test" {
Expand Down Expand Up @@ -213,7 +219,9 @@ func TestIPIfNonMatchDomain(t *testing.T) {
r := new(Router)
common.Must(r.Init(context.TODO(), config, mockDNS, nil, nil))

ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("example.com"), 80)})
ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{
Target: net.TCPDestination(net.DomainAddress("example.com"), 80),
}})
route, err := r.PickRoute(routing_session.AsRoutingContext(ctx))
common.Must(err)
if tag := route.GetOutboundTag(); tag != "test" {
Expand Down Expand Up @@ -247,7 +255,9 @@ func TestIPIfNonMatchIP(t *testing.T) {
r := new(Router)
common.Must(r.Init(context.TODO(), config, mockDNS, nil, nil))

ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.LocalHostIP, 80)})
ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{
Target: net.TCPDestination(net.LocalHostIP, 80),
}})
route, err := r.PickRoute(routing_session.AsRoutingContext(ctx))
common.Must(err)
if tag := route.GetOutboundTag(); tag != "test" {
Expand Down
14 changes: 8 additions & 6 deletions common/mux/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,10 @@ func (f *DialingWorkerFactory) Create() (*ClientWorker, error) {
}

go func(p proxy.Outbound, d internet.Dialer, c common.Closable) {
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{
outbounds := []*session.Outbound{{
Target: net.TCPDestination(muxCoolAddress, muxCoolPort),
})
}}
ctx := session.ContextWithOutbounds(context.Background(), outbounds)
ctx, cancel := context.WithCancel(ctx)

if err := p.Process(ctx, &transport.Link{Reader: uplinkReader, Writer: downlinkWriter}, d); err != nil {
Expand Down Expand Up @@ -242,17 +243,18 @@ func writeFirstPayload(reader buf.Reader, writer *Writer) error {
}

func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
dest := session.OutboundFromContext(ctx).Target
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
transferType := protocol.TransferTypeStream
if dest.Network == net.Network_UDP {
if ob.Target.Network == net.Network_UDP {
transferType = protocol.TransferTypePacket
}
s.transferType = transferType
writer := NewWriter(s.ID, dest, output, transferType, xudp.GetGlobalID(ctx))
writer := NewWriter(s.ID, ob.Target, output, transferType, xudp.GetGlobalID(ctx))
defer s.Close(false)
defer writer.Close()

newError("dispatching request to ", dest).WriteToLog(session.ExportIDToError(ctx))
newError("dispatching request to ", ob.Target).WriteToLog(session.ExportIDToError(ctx))
if err := writeFirstPayload(s.input, writer); err != nil {
newError("failed to write first payload").Base(err).WriteToLog(session.ExportIDToError(ctx))
writer.hasError = true
Expand Down
8 changes: 4 additions & 4 deletions common/mux/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,9 @@ func TestClientWorkerClose(t *testing.T) {
}

tr1, tw1 := pipe.New(pipe.WithoutSizeLimit())
ctx1 := session.ContextWithOutbound(context.Background(), &session.Outbound{
ctx1 := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{
Target: net.TCPDestination(net.DomainAddress("www.example.com"), 80),
})
}})
common.Must(manager.Dispatch(ctx1, &transport.Link{
Reader: tr1,
Writer: tw1,
Expand All @@ -103,9 +103,9 @@ func TestClientWorkerClose(t *testing.T) {
}

tr2, tw2 := pipe.New(pipe.WithoutSizeLimit())
ctx2 := session.ContextWithOutbound(context.Background(), &session.Outbound{
ctx2 := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{
Target: net.TCPDestination(net.DomainAddress("www.example.com"), 80),
})
}})
common.Must(manager.Dispatch(ctx2, &transport.Link{
Reader: tr2,
Writer: tw2,
Expand Down
10 changes: 5 additions & 5 deletions common/session/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ func InboundFromContext(ctx context.Context) *Inbound {
return nil
}

func ContextWithOutbound(ctx context.Context, outbound *Outbound) context.Context {
return context.WithValue(ctx, outboundSessionKey, outbound)
func ContextWithOutbounds(ctx context.Context, outbounds []*Outbound) context.Context {
return context.WithValue(ctx, outboundSessionKey, outbounds)
}

func OutboundFromContext(ctx context.Context) *Outbound {
if outbound, ok := ctx.Value(outboundSessionKey).(*Outbound); ok {
return outbound
func OutboundsFromContext(ctx context.Context) []*Outbound {
if outbounds, ok := ctx.Value(outboundSessionKey).([]*Outbound); ok {
return outbounds
}
return nil
}
Expand Down
Loading