Skip to content

Commit

Permalink
Merge branch 'fix sockmap + stream af_unix memleak'
Browse files Browse the repository at this point in the history
John Fastabend says:

====================
There was a memleak when streaming af_unix sockets were inserted into
multiple sockmap slots and/or maps. This is because each insert would
call a proto update operatino and these must be allowed to be called
multiple times. The streaming af_unix implementation recently added
a refcnt to handle a use after free issue, however it introduced a
memleak when inserted into multiple maps.

This series fixes the memleak, adds a note in the code so we remember
that proto updates need to support this. And then we add three tests
for each of the slightly different iterations of adding sockets into
multiple maps. I kept them as 3 independent test cases here. I have
some slight preference for this they could however be a single test,
but then you don't get to run them independently which was sort of
useful while debugging.
====================

Signed-off-by: Martin KaFai Lau <martin.lau@kernel.org>
  • Loading branch information
Martin KaFai Lau committed Jan 4, 2024
2 parents b456005 + bdbca46 commit 417fa6d
Show file tree
Hide file tree
Showing 3 changed files with 236 additions and 4 deletions.
5 changes: 5 additions & 0 deletions include/linux/skmsg.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ struct sk_psock {
void (*saved_close)(struct sock *sk, long timeout);
void (*saved_write_space)(struct sock *sk);
void (*saved_data_ready)(struct sock *sk);
/* psock_update_sk_prot may be called with restore=false many times
* so the handler must be safe for this case. It will be called
* exactly once with restore=true when the psock is being destroyed
* and psock refcnt is zero, but before an RCU grace period.
*/
int (*psock_update_sk_prot)(struct sock *sk, struct sk_psock *psock,
bool restore);
struct proto *sk_proto;
Expand Down
21 changes: 18 additions & 3 deletions net/unix/unix_bpf.c
Original file line number Diff line number Diff line change
Expand Up @@ -161,15 +161,30 @@ int unix_stream_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool r
{
struct sock *sk_pair;

/* Restore does not decrement the sk_pair reference yet because we must
* keep the a reference to the socket until after an RCU grace period
* and any pending sends have completed.
*/
if (restore) {
sk->sk_write_space = psock->saved_write_space;
sock_replace_proto(sk, psock->sk_proto);
return 0;
}

sk_pair = unix_peer(sk);
sock_hold(sk_pair);
psock->sk_pair = sk_pair;
/* psock_update_sk_prot can be called multiple times if psock is
* added to multiple maps and/or slots in the same map. There is
* also an edge case where replacing a psock with itself can trigger
* an extra psock_update_sk_prot during the insert process. So it
* must be safe to do multiple calls. Here we need to ensure we don't
* increment the refcnt through sock_hold many times. There will only
* be a single matching destroy operation.
*/
if (!psock->sk_pair) {
sk_pair = unix_peer(sk);
sock_hold(sk_pair);
psock->sk_pair = sk_pair;
}

unix_stream_bpf_check_needs_rebuild(psock->sk_proto);
sock_replace_proto(sk, &unix_stream_bpf_prot);
return 0;
Expand Down
214 changes: 213 additions & 1 deletion tools/testing/selftests/bpf/prog_tests/sockmap_basic.c
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,213 @@ static void test_sockmap_unconnected_unix(void)
close(dgram);
}

static void test_sockmap_many_socket(void)
{
struct test_sockmap_pass_prog *skel;
int stream[2], dgram, udp, tcp;
int i, err, map, entry = 0;

skel = test_sockmap_pass_prog__open_and_load();
if (!ASSERT_OK_PTR(skel, "open_and_load"))
return;

map = bpf_map__fd(skel->maps.sock_map_rx);

dgram = xsocket(AF_UNIX, SOCK_DGRAM, 0);
if (dgram < 0) {
test_sockmap_pass_prog__destroy(skel);
return;
}

tcp = connected_socket_v4();
if (!ASSERT_GE(tcp, 0, "connected_socket_v4")) {
close(dgram);
test_sockmap_pass_prog__destroy(skel);
return;
}

udp = xsocket(AF_INET, SOCK_DGRAM | SOCK_NONBLOCK, 0);
if (udp < 0) {
close(dgram);
close(tcp);
test_sockmap_pass_prog__destroy(skel);
return;
}

err = socketpair(AF_UNIX, SOCK_STREAM, 0, stream);
ASSERT_OK(err, "socketpair(af_unix, sock_stream)");
if (err)
goto out;

for (i = 0; i < 2; i++, entry++) {
err = bpf_map_update_elem(map, &entry, &stream[0], BPF_ANY);
ASSERT_OK(err, "bpf_map_update_elem(stream)");
}
for (i = 0; i < 2; i++, entry++) {
err = bpf_map_update_elem(map, &entry, &dgram, BPF_ANY);
ASSERT_OK(err, "bpf_map_update_elem(dgram)");
}
for (i = 0; i < 2; i++, entry++) {
err = bpf_map_update_elem(map, &entry, &udp, BPF_ANY);
ASSERT_OK(err, "bpf_map_update_elem(udp)");
}
for (i = 0; i < 2; i++, entry++) {
err = bpf_map_update_elem(map, &entry, &tcp, BPF_ANY);
ASSERT_OK(err, "bpf_map_update_elem(tcp)");
}
for (entry--; entry >= 0; entry--) {
err = bpf_map_delete_elem(map, &entry);
ASSERT_OK(err, "bpf_map_delete_elem(entry)");
}

close(stream[0]);
close(stream[1]);
out:
close(dgram);
close(tcp);
close(udp);
test_sockmap_pass_prog__destroy(skel);
}

static void test_sockmap_many_maps(void)
{
struct test_sockmap_pass_prog *skel;
int stream[2], dgram, udp, tcp;
int i, err, map[2], entry = 0;

skel = test_sockmap_pass_prog__open_and_load();
if (!ASSERT_OK_PTR(skel, "open_and_load"))
return;

map[0] = bpf_map__fd(skel->maps.sock_map_rx);
map[1] = bpf_map__fd(skel->maps.sock_map_tx);

dgram = xsocket(AF_UNIX, SOCK_DGRAM, 0);
if (dgram < 0) {
test_sockmap_pass_prog__destroy(skel);
return;
}

tcp = connected_socket_v4();
if (!ASSERT_GE(tcp, 0, "connected_socket_v4")) {
close(dgram);
test_sockmap_pass_prog__destroy(skel);
return;
}

udp = xsocket(AF_INET, SOCK_DGRAM | SOCK_NONBLOCK, 0);
if (udp < 0) {
close(dgram);
close(tcp);
test_sockmap_pass_prog__destroy(skel);
return;
}

err = socketpair(AF_UNIX, SOCK_STREAM, 0, stream);
ASSERT_OK(err, "socketpair(af_unix, sock_stream)");
if (err)
goto out;

for (i = 0; i < 2; i++, entry++) {
err = bpf_map_update_elem(map[i], &entry, &stream[0], BPF_ANY);
ASSERT_OK(err, "bpf_map_update_elem(stream)");
}
for (i = 0; i < 2; i++, entry++) {
err = bpf_map_update_elem(map[i], &entry, &dgram, BPF_ANY);
ASSERT_OK(err, "bpf_map_update_elem(dgram)");
}
for (i = 0; i < 2; i++, entry++) {
err = bpf_map_update_elem(map[i], &entry, &udp, BPF_ANY);
ASSERT_OK(err, "bpf_map_update_elem(udp)");
}
for (i = 0; i < 2; i++, entry++) {
err = bpf_map_update_elem(map[i], &entry, &tcp, BPF_ANY);
ASSERT_OK(err, "bpf_map_update_elem(tcp)");
}
for (entry--; entry >= 0; entry--) {
err = bpf_map_delete_elem(map[1], &entry);
entry--;
ASSERT_OK(err, "bpf_map_delete_elem(entry)");
err = bpf_map_delete_elem(map[0], &entry);
ASSERT_OK(err, "bpf_map_delete_elem(entry)");
}

close(stream[0]);
close(stream[1]);
out:
close(dgram);
close(tcp);
close(udp);
test_sockmap_pass_prog__destroy(skel);
}

static void test_sockmap_same_sock(void)
{
struct test_sockmap_pass_prog *skel;
int stream[2], dgram, udp, tcp;
int i, err, map, zero = 0;

skel = test_sockmap_pass_prog__open_and_load();
if (!ASSERT_OK_PTR(skel, "open_and_load"))
return;

map = bpf_map__fd(skel->maps.sock_map_rx);

dgram = xsocket(AF_UNIX, SOCK_DGRAM, 0);
if (dgram < 0) {
test_sockmap_pass_prog__destroy(skel);
return;
}

tcp = connected_socket_v4();
if (!ASSERT_GE(tcp, 0, "connected_socket_v4")) {
close(dgram);
test_sockmap_pass_prog__destroy(skel);
return;
}

udp = xsocket(AF_INET, SOCK_DGRAM | SOCK_NONBLOCK, 0);
if (udp < 0) {
close(dgram);
close(tcp);
test_sockmap_pass_prog__destroy(skel);
return;
}

err = socketpair(AF_UNIX, SOCK_STREAM, 0, stream);
ASSERT_OK(err, "socketpair(af_unix, sock_stream)");
if (err)
goto out;

for (i = 0; i < 2; i++) {
err = bpf_map_update_elem(map, &zero, &stream[0], BPF_ANY);
ASSERT_OK(err, "bpf_map_update_elem(stream)");
}
for (i = 0; i < 2; i++) {
err = bpf_map_update_elem(map, &zero, &dgram, BPF_ANY);
ASSERT_OK(err, "bpf_map_update_elem(dgram)");
}
for (i = 0; i < 2; i++) {
err = bpf_map_update_elem(map, &zero, &udp, BPF_ANY);
ASSERT_OK(err, "bpf_map_update_elem(udp)");
}
for (i = 0; i < 2; i++) {
err = bpf_map_update_elem(map, &zero, &tcp, BPF_ANY);
ASSERT_OK(err, "bpf_map_update_elem(tcp)");
}

err = bpf_map_delete_elem(map, &zero);
ASSERT_OK(err, "bpf_map_delete_elem(entry)");

close(stream[0]);
close(stream[1]);
out:
close(dgram);
close(tcp);
close(udp);
test_sockmap_pass_prog__destroy(skel);
}

void test_sockmap_basic(void)
{
if (test__start_subtest("sockmap create_update_free"))
Expand Down Expand Up @@ -597,7 +804,12 @@ void test_sockmap_basic(void)
test_sockmap_skb_verdict_fionread(false);
if (test__start_subtest("sockmap skb_verdict msg_f_peek"))
test_sockmap_skb_verdict_peek();

if (test__start_subtest("sockmap unconnected af_unix"))
test_sockmap_unconnected_unix();
if (test__start_subtest("sockmap one socket to many map entries"))
test_sockmap_many_socket();
if (test__start_subtest("sockmap one socket to many maps"))
test_sockmap_many_maps();
if (test__start_subtest("sockmap same socket replace"))
test_sockmap_same_sock();
}

0 comments on commit 417fa6d

Please sign in to comment.