Skip to content

Commit

Permalink
Fix StackOverflowError in Transaction (#2827)
Browse files Browse the repository at this point in the history
* Fix StackOverflowError in Transaction

* format import

* Modify exec() and discard() in Transaction

* add mockito-inline in pom
  • Loading branch information
sazzad16 authored Jan 19, 2022
1 parent dc6a891 commit 0f4bed9
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 30 deletions.
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<artifactId>mockito-inline</artifactId>
<version>3.12.4</version>
<scope>test</scope>
</dependency>
Expand Down
78 changes: 49 additions & 29 deletions src/main/java/redis/clients/jedis/TransactionBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import redis.clients.jedis.commands.PipelineCommands;
import redis.clients.jedis.commands.ProtocolCommand;
import redis.clients.jedis.commands.RedisModulePipelineCommands;
import redis.clients.jedis.exceptions.JedisConnectionException;
import redis.clients.jedis.exceptions.JedisDataException;
import redis.clients.jedis.json.JsonSetParams;
import redis.clients.jedis.json.Path;
Expand All @@ -34,6 +35,7 @@ public abstract class TransactionBase extends Queable implements PipelineCommand
protected final Connection connection;
private final CommandObjects commandObjects;

private boolean broken = false;
private boolean inWatch = false;
private boolean inMulti = false;

Expand Down Expand Up @@ -93,6 +95,9 @@ public final void close() {
}

public final void clear() {
if (broken) {
return;
}
if (inMulti) {
discard();
} else if (inWatch) {
Expand All @@ -103,41 +108,56 @@ public final void clear() {
protected abstract void processPipelinedResponses();

public List<Object> exec() {
if (!inMulti) throw new IllegalStateException("EXEC without MULTI");
// ignore QUEUED or ERROR
// connection.getMany(1 + getPipelinedResponseLength());
processPipelinedResponses();
connection.sendCommand(EXEC);
inMulti = false;
inWatch = false;

List<Object> unformatted = connection.getObjectMultiBulkReply();
if (unformatted == null) {
clean();
return null;
if (!inMulti) {
throw new IllegalStateException("EXEC without MULTI");
}
List<Object> formatted = new ArrayList<>(unformatted.size());
for (Object o : unformatted) {
try {
formatted.add(generateResponse(o).get());
} catch (JedisDataException e) {
formatted.add(e);

try {
processPipelinedResponses();
connection.sendCommand(EXEC);

List<Object> unformatted = connection.getObjectMultiBulkReply();
if (unformatted == null) {
clean();
return null;
}

List<Object> formatted = new ArrayList<>(unformatted.size());
for (Object o : unformatted) {
try {
formatted.add(generateResponse(o).get());
} catch (JedisDataException e) {
formatted.add(e);
}
}
return formatted;
} catch (JedisConnectionException jce) {
broken = true;
throw jce;
} finally {
inMulti = false;
inWatch = false;
clean();
}
return formatted;
}

public String discard() {
if (!inMulti) throw new IllegalStateException("DISCARD without MULTI");
// ignore QUEUED or ERROR
// connection.getMany(1 + getPipelinedResponseLength());
processPipelinedResponses();
connection.sendCommand(DISCARD);
String status = connection.getStatusCodeReply(); // OK
inMulti = false;
inWatch = false;
clean();
return status;
if (!inMulti) {
throw new IllegalStateException("DISCARD without MULTI");
}

try {
processPipelinedResponses();
connection.sendCommand(DISCARD);
return connection.getStatusCodeReply();
} catch (JedisConnectionException jce) {
broken = true;
throw jce;
} finally {
inMulti = false;
inWatch = false;
clean();
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import static org.junit.Assert.*;

import static org.mockito.ArgumentMatchers.any;
import static redis.clients.jedis.Protocol.Command.INCR;
import static redis.clients.jedis.Protocol.Command.GET;
import static redis.clients.jedis.Protocol.Command.SET;
Expand All @@ -14,11 +15,15 @@
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.mockito.MockedStatic;
import org.mockito.Mockito;

import redis.clients.jedis.DefaultJedisClientConfig;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.Protocol;
import redis.clients.jedis.Response;
import redis.clients.jedis.Transaction;
import redis.clients.jedis.exceptions.JedisConnectionException;
import redis.clients.jedis.exceptions.JedisDataException;
import redis.clients.jedis.util.SafeEncoder;

Expand Down Expand Up @@ -147,6 +152,46 @@ public void discard() {
assertEquals("OK", status);
}

@Test
public void discardFail() {
Transaction trans = jedis.multi();
trans.set("a", "a");
trans.set("b", "b");

try (MockedStatic<Protocol> protocol = Mockito.mockStatic(Protocol.class)) {
protocol.when(() -> Protocol.read(any())).thenThrow(JedisConnectionException.class);

trans.discard();
fail("Should get mocked JedisConnectionException.");
} catch (JedisConnectionException jce) {
// should be here
} finally {
// close() should pass
trans.close();
}
assertTrue(jedis.isBroken());
}

@Test
public void execFail() {
Transaction trans = jedis.multi();
trans.set("a", "a");
trans.set("b", "b");

try (MockedStatic<Protocol> protocol = Mockito.mockStatic(Protocol.class)) {
protocol.when(() -> Protocol.read(any())).thenThrow(JedisConnectionException.class);

trans.exec();
fail("Should get mocked JedisConnectionException.");
} catch (JedisConnectionException jce) {
// should be here
} finally {
// close() should pass
trans.close();
}
assertTrue(jedis.isBroken());
}

@Test
public void transactionResponse() {
jedis.set("string", "foo");
Expand Down

0 comments on commit 0f4bed9

Please sign in to comment.