From 44d34a6ab8529301712d4707c9a27a1a310971af Mon Sep 17 00:00:00 2001 From: M Sazzadul Hoque <7600764+sazzad16@users.noreply.github.com> Date: Tue, 9 Apr 2024 19:03:49 +0600 Subject: [PATCH] PubSub handle array of messages for RESP2 (#3811) * PubSub handle array of messages for RESP2 only * Modify test and add binary mode test * Edit --- .../redis/clients/jedis/JedisPubSubBase.java | 9 ++- .../jedis/PublishSubscribeCommandsTest.java | 73 +++++++++++++++++++ 2 files changed, 79 insertions(+), 3 deletions(-) diff --git a/src/main/java/redis/clients/jedis/JedisPubSubBase.java b/src/main/java/redis/clients/jedis/JedisPubSubBase.java index 552310e4de..bf9d0a32c5 100644 --- a/src/main/java/redis/clients/jedis/JedisPubSubBase.java +++ b/src/main/java/redis/clients/jedis/JedisPubSubBase.java @@ -133,10 +133,13 @@ private void process() { onUnsubscribe(enchannel, subscribedChannels); } else if (Arrays.equals(MESSAGE.getRaw(), resp)) { final byte[] bchannel = (byte[]) listReply.get(1); - final byte[] bmesg = (byte[]) listReply.get(2); + final Object mesg = listReply.get(2); final T enchannel = (bchannel == null) ? null : encode(bchannel); - final T enmesg = (bmesg == null) ? null : encode(bmesg); - onMessage(enchannel, enmesg); + if (mesg instanceof List) { + ((List) mesg).forEach(bmesg -> onMessage(enchannel, encode(bmesg))); + } else { + onMessage(enchannel, (mesg == null) ? null : encode((byte[]) mesg)); + } } else if (Arrays.equals(PMESSAGE.getRaw(), resp)) { final byte[] bpattern = (byte[]) listReply.get(1); final byte[] bchannel = (byte[]) listReply.get(2); diff --git a/src/test/java/redis/clients/jedis/commands/jedis/PublishSubscribeCommandsTest.java b/src/test/java/redis/clients/jedis/commands/jedis/PublishSubscribeCommandsTest.java index 75968fdc04..240c198716 100644 --- a/src/test/java/redis/clients/jedis/commands/jedis/PublishSubscribeCommandsTest.java +++ b/src/test/java/redis/clients/jedis/commands/jedis/PublishSubscribeCommandsTest.java @@ -4,7 +4,9 @@ import static org.hamcrest.Matchers.hasItems; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static redis.clients.jedis.Protocol.Command.CLIENT; import java.io.IOException; import java.net.UnknownHostException; @@ -15,7 +17,9 @@ import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Consumer; +import org.hamcrest.Matchers; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -537,4 +541,73 @@ private String makeLargeString(int size) { return sb.toString(); } + + @Test(timeout = 5000) + public void subscribeCacheInvalidateChannel() { + org.junit.Assume.assumeThat(protocol, Matchers.not(RedisProtocol.RESP3)); + + final String cacheInvalidate = "__redis__:invalidate"; + final AtomicBoolean onMessage = new AtomicBoolean(false); + final JedisPubSub pubsub = new JedisPubSub() { + @Override public void onMessage(String channel, String message) { + onMessage.set(true); + assertEquals(cacheInvalidate, channel); + if (message != null) { + assertEquals("foo", message); + consumeJedis(j -> j.flushAll()); + } else { + unsubscribe(channel); + } + } + + @Override public void onSubscribe(String channel, int subscribedChannels) { + assertEquals(cacheInvalidate, channel); + consumeJedis(j -> j.set("foo", "bar")); + } + }; + + try (Jedis subscriber = createJedis()) { + long clientId = subscriber.clientId(); + subscriber.sendCommand(CLIENT, "TRACKING", "ON", "REDIRECT", Long.toString(clientId), "BCAST"); + subscriber.subscribe(pubsub, cacheInvalidate); + assertTrue("Subscriber didn't get any message.", onMessage.get()); + } + } + + @Test(timeout = 5000) + public void subscribeCacheInvalidateChannelBinary() { + org.junit.Assume.assumeThat(protocol, Matchers.not(RedisProtocol.RESP3)); + + final byte[] cacheInvalidate = "__redis__:invalidate".getBytes(); + final AtomicBoolean onMessage = new AtomicBoolean(false); + final BinaryJedisPubSub pubsub = new BinaryJedisPubSub() { + @Override public void onMessage(byte[] channel, byte[] message) { + onMessage.set(true); + assertArrayEquals(cacheInvalidate, channel); + if (message != null) { + assertArrayEquals("foo".getBytes(), message); + consumeJedis(j -> j.flushAll()); + } else { + unsubscribe(channel); + } + } + + @Override public void onSubscribe(byte[] channel, int subscribedChannels) { + assertArrayEquals(cacheInvalidate, channel); + consumeJedis(j -> j.set("foo".getBytes(), "bar".getBytes())); + } + }; + + try (Jedis subscriber = createJedis()) { + long clientId = subscriber.clientId(); + subscriber.sendCommand(CLIENT, "TRACKING", "ON", "REDIRECT", Long.toString(clientId), "BCAST"); + subscriber.subscribe(pubsub, cacheInvalidate); + assertTrue("Subscriber didn't get any message.", onMessage.get()); + } + } + + private void consumeJedis(Consumer consumer) { + Thread t = new Thread(() -> consumer.accept(jedis)); + t.start(); + } }