diff --git a/include/net/mptcp.h b/include/net/mptcp.h index 80ca6a7cec779..2455b8138de86 100644 --- a/include/net/mptcp.h +++ b/include/net/mptcp.h @@ -825,6 +825,7 @@ void mptcp_sub_close_wq(struct work_struct *work); void mptcp_sub_close(struct sock *sk, unsigned long delay); struct sock *mptcp_select_ack_sock(const struct sock *meta_sk); void mptcp_fallback_meta_sk(struct sock *meta_sk); +void mptcp_prepare_for_backlog(struct sock *sk, struct sk_buff *skb); int mptcp_backlog_rcv(struct sock *meta_sk, struct sk_buff *skb); void mptcp_ack_handler(struct timer_list *t); bool mptcp_check_rtt(const struct tcp_sock *tp, int time); @@ -1395,6 +1396,7 @@ static inline bool mptcp_fallback_infinite(const struct sock *sk, int flag) return false; } static inline void mptcp_init_mp_opt(const struct mptcp_options_received *mopt) {} +static inline void mptcp_prepare_for_backlog(struct sock *sk, struct sk_buff *skb) {} static inline bool mptcp_check_rtt(const struct tcp_sock *tp, int time) { return false; diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c index 3bec4f0c2c322..a9e072a6913ea 100644 --- a/net/ipv4/tcp_ipv4.c +++ b/net/ipv4/tcp_ipv4.c @@ -1814,7 +1814,7 @@ int tcp_v4_rcv(struct sk_buff *skb) bh_lock_sock_nested(meta_sk); if (sock_owned_by_user(meta_sk)) - skb->sk = sk; + mptcp_prepare_for_backlog(sk, skb); } else { meta_sk = sk; bh_lock_sock_nested(sk); diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c index eea44f1c9084a..3d31592ad4dd5 100644 --- a/net/ipv6/tcp_ipv6.c +++ b/net/ipv6/tcp_ipv6.c @@ -1608,7 +1608,7 @@ static int tcp_v6_rcv(struct sk_buff *skb) bh_lock_sock_nested(meta_sk); if (sock_owned_by_user(meta_sk)) - skb->sk = sk; + mptcp_prepare_for_backlog(sk, skb); } else { meta_sk = sk; bh_lock_sock_nested(sk); diff --git a/net/mptcp/mptcp_ctrl.c b/net/mptcp/mptcp_ctrl.c index c9d8b5c115ebe..1c6b5d5e1bead 100644 --- a/net/mptcp/mptcp_ctrl.c +++ b/net/mptcp/mptcp_ctrl.c @@ -988,6 +988,16 @@ static void mptcp_sub_inherit_sockopts(const struct sock *meta_sk, struct sock * inet_sk(sub_sk)->recverr = 0; } +void mptcp_prepare_for_backlog(struct sock *sk, struct sk_buff *skb) +{ + /* In case of success (in mptcp_backlog_rcv) and error (in kfree_skb) of + * sk_add_backlog, we will decrement the sk refcount. + */ + sock_hold(sk); + skb->sk = sk; + skb->destructor = sock_efree; +} + int mptcp_backlog_rcv(struct sock *meta_sk, struct sk_buff *skb) { /* skb-sk may be NULL if we receive a packet immediatly after the @@ -996,13 +1006,17 @@ int mptcp_backlog_rcv(struct sock *meta_sk, struct sk_buff *skb) struct sock *sk = skb->sk ? skb->sk : meta_sk; int ret = 0; - skb->sk = NULL; - if (unlikely(!refcount_inc_not_zero(&sk->sk_refcnt))) { kfree_skb(skb); return 0; } + /* Decrement sk refcnt when calling the skb destructor. + * Refcnt is incremented and skb destructor is set in tcp_v{4,6}_rcv via + * mptcp_prepare_for_backlog() here above. + */ + skb_orphan(skb); + if (sk->sk_family == AF_INET) ret = tcp_v4_do_rcv(sk, skb); #if IS_ENABLED(CONFIG_IPV6)