diff --git a/include/net/mptcp.h b/include/net/mptcp.h
index 222b551bb58b5..818a0fa3a2a96 100644
--- a/include/net/mptcp.h
+++ b/include/net/mptcp.h
@@ -834,6 +834,7 @@ void mptcp_sub_close_wq(struct work_struct *work);
 void mptcp_sub_close(struct sock *sk, unsigned long delay);
 struct sock *mptcp_select_ack_sock(const struct sock *meta_sk);
 void mptcp_fallback_meta_sk(struct sock *meta_sk);
+void mptcp_prepare_for_backlog(struct sock *sk, struct sk_buff *skb);
 int mptcp_backlog_rcv(struct sock *meta_sk, struct sk_buff *skb);
 void mptcp_ack_handler(unsigned long);
 bool mptcp_check_rtt(const struct tcp_sock *tp, int time);
@@ -1438,6 +1439,7 @@ static inline bool mptcp_fallback_infinite(const struct sock *sk, int flag)
 	return false;
 }
 static inline void mptcp_init_mp_opt(const struct mptcp_options_received *mopt) {}
+static inline void mptcp_prepare_for_backlog(struct sock *sk, struct sk_buff *skb) {}
 static inline bool mptcp_check_rtt(const struct tcp_sock *tp, int time)
 {
 	return false;
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index 830849e683e62..15b90061a38f4 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -1742,7 +1742,7 @@ int tcp_v4_rcv(struct sk_buff *skb)
 		}
 
 		if (sock_owned_by_user(sk)) {
-			skb->sk = sk;
+			mptcp_prepare_for_backlog(sk, skb);
 			if (unlikely(sk_add_backlog(sk, skb,
 						    sk->sk_rcvbuf + sk->sk_sndbuf))) {
 				reqsk_put(req);
@@ -1819,7 +1819,7 @@ int tcp_v4_rcv(struct sk_buff *skb)
 
 		bh_lock_sock_nested(meta_sk);
 		if (sock_owned_by_user(meta_sk))
-			skb->sk = sk;
+			mptcp_prepare_for_backlog(sk, skb);
 	} else {
 		meta_sk = sk;
 		bh_lock_sock_nested(sk);
diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c
index dd86c448c5afb..a2c0187d83662 100644
--- a/net/ipv6/tcp_ipv6.c
+++ b/net/ipv6/tcp_ipv6.c
@@ -1530,7 +1530,7 @@ static int tcp_v6_rcv(struct sk_buff *skb)
 		}
 
 		if (sock_owned_by_user(sk)) {
-			skb->sk = sk;
+			mptcp_prepare_for_backlog(sk, skb);
 			if (unlikely(sk_add_backlog(sk, skb,
 						    sk->sk_rcvbuf + sk->sk_sndbuf))) {
 				reqsk_put(req);
@@ -1606,7 +1606,7 @@ static int tcp_v6_rcv(struct sk_buff *skb)
 
 		bh_lock_sock_nested(meta_sk);
 		if (sock_owned_by_user(meta_sk))
-			skb->sk = sk;
+			mptcp_prepare_for_backlog(sk, skb);
 	} else {
 		meta_sk = sk;
 		bh_lock_sock_nested(sk);
diff --git a/net/mptcp/mptcp_ctrl.c b/net/mptcp/mptcp_ctrl.c
index faf8a6cdbcc71..b46682eee745f 100644
--- a/net/mptcp/mptcp_ctrl.c
+++ b/net/mptcp/mptcp_ctrl.c
@@ -984,6 +984,16 @@ static void mptcp_sub_inherit_sockopts(const struct sock *meta_sk, struct sock *
 	inet_sk(sub_sk)->recverr = 0;
 }
 
+void mptcp_prepare_for_backlog(struct sock *sk, struct sk_buff *skb)
+{
+	/* In case of success (in mptcp_backlog_rcv) or error (in kfree_skb) of
+	 * sk_add_backlog, we will decrement the sk refcount.
+	 */
+	sock_hold(sk);
+	skb->sk = sk;
+	skb->destructor = sock_efree;
+}
+
 int mptcp_backlog_rcv(struct sock *meta_sk, struct sk_buff *skb)
 {
 	/* skb->sk may be NULL if we receive a packet immediately after the
@@ -992,13 +1002,17 @@ int mptcp_backlog_rcv(struct sock *meta_sk, struct sk_buff *skb)
 	struct sock *sk = skb->sk ? skb->sk : meta_sk;
 	int ret = 0;
 
-	skb->sk = NULL;
-
 	if (unlikely(!refcount_inc_not_zero(&sk->sk_refcnt))) {
 		kfree_skb(skb);
 		return 0;
 	}
 
+	/* Decrement the sk refcnt when calling the skb destructor.
+	 * The refcnt is incremented and the skb destructor is set in
+	 * tcp_v{4,6}_rcv via mptcp_prepare_for_backlog() above.
+	 */
+	skb_orphan(skb);
+
 	if (sk->sk_family == AF_INET)
 		ret = tcp_v4_do_rcv(sk, skb);
 #if IS_ENABLED(CONFIG_IPV6)
diff --git a/net/mptcp/mptcp_input.c b/net/mptcp/mptcp_input.c
index 1e73b4e857f8c..b5344c176f753 100644
--- a/net/mptcp/mptcp_input.c
+++ b/net/mptcp/mptcp_input.c
@@ -1184,7 +1184,6 @@ int mptcp_lookup_join(struct sk_buff *skb, struct inet_timewait_sock *tw)
 	 */
 	bh_lock_sock_nested(meta_sk);
 	if (sock_owned_by_user(meta_sk)) {
-		skb->sk = meta_sk;
 		if (unlikely(sk_add_backlog(meta_sk, skb,
 					    meta_sk->sk_rcvbuf + meta_sk->sk_sndbuf))) {
 			bh_unlock_sock(meta_sk);
@@ -1257,7 +1256,6 @@ int mptcp_do_join_short(struct sk_buff *skb,
 	}
 
 	if (sock_owned_by_user(meta_sk)) {
-		skb->sk = meta_sk;
 		if (unlikely(sk_add_backlog(meta_sk, skb,
 					    meta_sk->sk_rcvbuf + meta_sk->sk_sndbuf)))
 			__NET_INC_STATS(net, LINUX_MIB_TCPBACKLOGDROP);
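
For reference, the ownership handoff this patch introduces is the standard sock_hold()/sock_efree() pairing: mptcp_prepare_for_backlog() pins the subflow socket before the skb is queued to the backlog and installs sock_efree() as the skb destructor; the reference is then dropped either by skb_orphan() in mptcp_backlog_rcv() on the success path, or by kfree_skb() on the drop path, since both invoke the destructor. Below is a minimal, self-contained userspace sketch of that refcount lifecycle; all mock_* names are illustrative stand-ins for the kernel primitives, not the kernel API itself.

/* Standalone model of the sock_hold()/sock_efree()/skb_orphan() handoff.
 * Hypothetical stand-ins only; compile with: cc -o backlog_model backlog_model.c
 */
#include <stdio.h>
#include <stdlib.h>

struct mock_sock {
	int refcnt;
};

struct mock_skb {
	struct mock_sock *sk;
	void (*destructor)(struct mock_skb *skb);
};

static void mock_sock_hold(struct mock_sock *sk)
{
	sk->refcnt++;		/* models sock_hold(): pin the socket */
}

static void mock_sock_put(struct mock_sock *sk)
{
	if (--sk->refcnt == 0)
		free(sk);	/* last reference gone: socket may be freed */
}

/* Models sock_efree(): the skb destructor drops the socket reference. */
static void mock_sock_efree(struct mock_skb *skb)
{
	mock_sock_put(skb->sk);
}

/* Models mptcp_prepare_for_backlog(): take a reference before the skb
 * sits in the backlog, so the subflow cannot vanish underneath it.
 */
static void mock_prepare_for_backlog(struct mock_sock *sk, struct mock_skb *skb)
{
	mock_sock_hold(sk);
	skb->sk = sk;
	skb->destructor = mock_sock_efree;
}

/* Models skb_orphan(): run the destructor and detach the socket, which
 * releases the reference taken in mock_prepare_for_backlog(). kfree_skb()
 * reaches the destructor the same way on the error path.
 */
static void mock_skb_orphan(struct mock_skb *skb)
{
	if (skb->destructor)
		skb->destructor(skb);
	skb->sk = NULL;
	skb->destructor = NULL;
}

int main(void)
{
	struct mock_sock *sk = calloc(1, sizeof(*sk));
	struct mock_skb skb = { 0 };

	sk->refcnt = 1;				/* the owner's reference */
	mock_prepare_for_backlog(sk, &skb);	/* refcnt: 1 -> 2 */
	printf("refcnt after backlog prep: %d\n", sk->refcnt);

	mock_skb_orphan(&skb);			/* refcnt: 2 -> 1 */
	printf("refcnt after orphan: %d\n", sk->refcnt);

	mock_sock_put(sk);			/* drop the owner's reference */
	return 0;
}

This also shows why the patch can delete the bare skb->sk assignments in mptcp_lookup_join() and mptcp_do_join_short(): setting skb->sk without taking a reference and without a destructor left a window where the socket could be freed while the skb sat in the backlog.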