diff --git a/src/main/java/redis/clients/jedis/Transaction.java b/src/main/java/redis/clients/jedis/Transaction.java index ce8269d0af..278e5fd4bd 100644 --- a/src/main/java/redis/clients/jedis/Transaction.java +++ b/src/main/java/redis/clients/jedis/Transaction.java @@ -42,18 +42,26 @@ protected final void processPipelinedResponses() { @Override public final List exec() { - List ret = super.exec(); - if (jedis != null) { - jedis.resetState(); + List ret; + try { + ret = super.exec(); + } finally { + if (jedis != null) { + jedis.resetState(); + } } return ret; } @Override public final String discard() { - String ret = super.discard(); - if (jedis != null) { - jedis.resetState(); + String ret; + try { + ret = super.discard(); + } finally { + if (jedis != null) { + jedis.resetState(); + } } return ret; } diff --git a/src/main/java/redis/clients/jedis/TransactionBase.java b/src/main/java/redis/clients/jedis/TransactionBase.java index 6054def615..e17885e48d 100644 --- a/src/main/java/redis/clients/jedis/TransactionBase.java +++ b/src/main/java/redis/clients/jedis/TransactionBase.java @@ -18,6 +18,7 @@ import redis.clients.jedis.commands.PipelineCommands; import redis.clients.jedis.commands.ProtocolCommand; import redis.clients.jedis.commands.RedisModulePipelineCommands; +import redis.clients.jedis.exceptions.JedisConnectionException; import redis.clients.jedis.exceptions.JedisDataException; import redis.clients.jedis.json.JsonSetParams; import redis.clients.jedis.json.Path; @@ -34,6 +35,7 @@ public abstract class TransactionBase extends Queable implements PipelineCommand protected final Connection connection; private final CommandObjects commandObjects; + private boolean broken = false; private boolean inWatch = false; private boolean inMulti = false; @@ -93,6 +95,9 @@ public final void close() { } public final void clear() { + if (broken) { + return; + } if (inMulti) { discard(); } else if (inWatch) { @@ -103,41 +108,56 @@ public final void clear() { protected abstract void processPipelinedResponses(); public List exec() { - if (!inMulti) throw new IllegalStateException("EXEC without MULTI"); - // ignore QUEUED or ERROR -// connection.getMany(1 + getPipelinedResponseLength()); - processPipelinedResponses(); - connection.sendCommand(EXEC); - inMulti = false; - inWatch = false; - - List unformatted = connection.getObjectMultiBulkReply(); - if (unformatted == null) { - clean(); - return null; + if (!inMulti) { + throw new IllegalStateException("EXEC without MULTI"); } - List formatted = new ArrayList<>(unformatted.size()); - for (Object o : unformatted) { - try { - formatted.add(generateResponse(o).get()); - } catch (JedisDataException e) { - formatted.add(e); + + try { + processPipelinedResponses(); + connection.sendCommand(EXEC); + + List unformatted = connection.getObjectMultiBulkReply(); + if (unformatted == null) { + clean(); + return null; + } + + List formatted = new ArrayList<>(unformatted.size()); + for (Object o : unformatted) { + try { + formatted.add(generateResponse(o).get()); + } catch (JedisDataException e) { + formatted.add(e); + } } + return formatted; + } catch (JedisConnectionException jce) { + broken = true; + throw jce; + } finally { + inMulti = false; + inWatch = false; + clean(); } - return formatted; } public String discard() { - if (!inMulti) throw new IllegalStateException("DISCARD without MULTI"); - // ignore QUEUED or ERROR -// connection.getMany(1 + getPipelinedResponseLength()); - processPipelinedResponses(); - connection.sendCommand(DISCARD); - String status = connection.getStatusCodeReply(); // OK - inMulti = false; - inWatch = false; - clean(); - return status; + if (!inMulti) { + throw new IllegalStateException("DISCARD without MULTI"); + } + + try { + processPipelinedResponses(); + connection.sendCommand(DISCARD); + return connection.getStatusCodeReply(); + } catch (JedisConnectionException jce) { + broken = true; + throw jce; + } finally { + inMulti = false; + inWatch = false; + clean(); + } } @Override diff --git a/src/test/java/redis/clients/jedis/commands/jedis/TransactionCommandsTest.java b/src/test/java/redis/clients/jedis/commands/jedis/TransactionCommandsTest.java index c9a0b08181..a2153e6862 100644 --- a/src/test/java/redis/clients/jedis/commands/jedis/TransactionCommandsTest.java +++ b/src/test/java/redis/clients/jedis/commands/jedis/TransactionCommandsTest.java @@ -2,7 +2,7 @@ import static org.junit.Assert.*; -import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.any; import static redis.clients.jedis.Protocol.Command.INCR; import static redis.clients.jedis.Protocol.Command.GET; import static redis.clients.jedis.Protocol.Command.SET; @@ -15,11 +15,12 @@ import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.mockito.MockedStatic; import org.mockito.Mockito; -import redis.clients.jedis.Connection; import redis.clients.jedis.DefaultJedisClientConfig; import redis.clients.jedis.Jedis; +import redis.clients.jedis.Protocol; import redis.clients.jedis.Response; import redis.clients.jedis.Transaction; import redis.clients.jedis.exceptions.JedisConnectionException; @@ -151,14 +152,44 @@ public void discard() { assertEquals("OK", status); } - @Test(expected = JedisConnectionException.class) - public void discardMock() { - Connection mock = Mockito.spy(jedis.getConnection()); - Mockito.doThrow(new JedisConnectionException("mock")).when(mock).getMany(anyInt()); - Transaction trans = new Jedis(mock).multi(); + @Test + public void discardFail() { + Transaction trans = jedis.multi(); + trans.set("a", "a"); + trans.set("b", "b"); + + try (MockedStatic protocol = Mockito.mockStatic(Protocol.class)) { + protocol.when(() -> Protocol.read(any())).thenThrow(JedisConnectionException.class); + + trans.discard(); + fail("Should get mocked JedisConnectionException."); + } catch (JedisConnectionException jce) { + // should be here + } finally { + // close() should pass + trans.close(); + } + assertTrue(jedis.isBroken()); + } + + @Test + public void execFail() { + Transaction trans = jedis.multi(); trans.set("a", "a"); trans.set("b", "b"); - trans.discard(); + + try (MockedStatic protocol = Mockito.mockStatic(Protocol.class)) { + protocol.when(() -> Protocol.read(any())).thenThrow(JedisConnectionException.class); + + trans.exec(); + fail("Should get mocked JedisConnectionException."); + } catch (JedisConnectionException jce) { + // should be here + } finally { + // close() should pass + trans.close(); + } + assertTrue(jedis.isBroken()); } @Test