diff --git a/channel.go b/channel.go index 1b76277d..cca873de 100644 --- a/channel.go +++ b/channel.go @@ -94,6 +94,13 @@ type ChannelOptions struct { // This is an unstable API - breaking changes are likely. RelayTimerVerification bool + // EnableMPTCP enables MPTCP for TCP network connection to increase reliability. + // It requires underlying operating system support MPTCP. + // If EnableMPTCP is false or no MPTCP support, the connection will use normal TCP. + // It's set to false by default. + // If a Dialer is passed as option, this value will be ignored. + EnableMPTCP bool + // The reporter to use for reporting stats for this channel. StatsReporter StatsReporter @@ -184,6 +191,7 @@ type Channel struct { relayMaxConnTimeout time.Duration relayMaxTombs uint64 relayTimerVerify bool + enableMPTCP bool internalHandlers *handlerMap handler Handler onPeerStatusChanged func(*Peer) @@ -275,8 +283,12 @@ func NewChannel(serviceName string, opts *ChannelOptions) (*Channel, error) { return nil, err } - // Default to dialContext if dialer is not passed in as an option + // Default to dialContext or dialMPTCPContex + // if dialer is not passed in as an option dialCtx := dialContext + if opts.EnableMPTCP { + dialCtx = dialMPTCPContext + } if opts.Dialer != nil { dialCtx = func(ctx context.Context, hostPort string) (net.Conn, error) { return opts.Dialer(ctx, "tcp", hostPort) @@ -306,6 +318,7 @@ func NewChannel(serviceName string, opts *ChannelOptions) (*Channel, error) { relayMaxConnTimeout: opts.RelayMaxConnectionTimeout, relayMaxTombs: opts.RelayMaxTombs, relayTimerVerify: opts.RelayTimerVerification, + enableMPTCP: opts.EnableMPTCP, dialer: dialCtx, connContext: opts.ConnContext, closed: make(chan struct{}), @@ -402,7 +415,9 @@ func (ch *Channel) ListenAndServe(hostPort string) error { return errAlreadyListening } - l, err := net.Listen("tcp", hostPort) + lc := net.ListenConfig{} + lc.SetMultipathTCP(ch.enableMPTCP) + l, err := lc.Listen(context.Background(), "tcp", hostPort) if err != nil { mutable.RUnlock() return err diff --git a/channel_test.go b/channel_test.go index 3757a922..0c09d94c 100644 --- a/channel_test.go +++ b/channel_test.go @@ -44,24 +44,27 @@ func toMap(fields LogFields) map[string]interface{} { } func TestNewChannel(t *testing.T) { - ch, err := NewChannel("svc", &ChannelOptions{ - ProcessName: "pname", - }) - require.NoError(t, err, "NewChannel failed") - - assert.Equal(t, LocalPeerInfo{ - ServiceName: "svc", - PeerInfo: PeerInfo{ + for _, mptcp := range []bool{true, false} { + ch, err := NewChannel("svc", &ChannelOptions{ ProcessName: "pname", - HostPort: ephemeralHostPort, - IsEphemeral: true, - Version: PeerVersion{ - Language: "go", - LanguageVersion: strings.TrimPrefix(runtime.Version(), "go"), - TChannelVersion: VersionInfo, + EnableMPTCP: mptcp, + }) + require.NoError(t, err, "NewChannel failed") + + assert.Equal(t, LocalPeerInfo{ + ServiceName: "svc", + PeerInfo: PeerInfo{ + ProcessName: "pname", + HostPort: ephemeralHostPort, + IsEphemeral: true, + Version: PeerVersion{ + Language: "go", + LanguageVersion: strings.TrimPrefix(runtime.Version(), "go"), + TChannelVersion: VersionInfo, + }, }, - }, - }, ch.PeerInfo(), "Wrong local peer info") + }, ch.PeerInfo(), "Wrong local peer info") + } } func TestLoggers(t *testing.T) { diff --git a/dial_17.go b/dial_17.go index 313a754a..b1d72c20 100644 --- a/dial_17.go +++ b/dial_17.go @@ -32,3 +32,9 @@ func dialContext(ctx context.Context, hostPort string) (net.Conn, error) { d := net.Dialer{} return d.DialContext(ctx, "tcp", hostPort) } + +func dialMPTCPContext(ctx context.Context, hostPort string) (net.Conn, error) { + d := net.Dialer{} + d.SetMultipathTCP(true) + return d.DialContext(ctx, "tcp", hostPort) +}