Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use RedisCredentials with HELLO command #3507

Merged
merged 4 commits into from
Aug 21, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 57 additions & 67 deletions src/main/java/redis/clients/jedis/Connection.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package redis.clients.jedis;

import static redis.clients.jedis.util.SafeEncoder.encode;

import java.io.Closeable;
import java.io.IOException;
import java.net.Socket;
Expand All @@ -9,7 +11,6 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;

import redis.clients.jedis.Protocol.Command;
Expand All @@ -25,7 +26,6 @@
import redis.clients.jedis.util.JedisMetaInfo;
import redis.clients.jedis.util.RedisInputStream;
import redis.clients.jedis.util.RedisOutputStream;
import redis.clients.jedis.util.SafeEncoder;

public class Connection implements Closeable {

Expand Down Expand Up @@ -270,14 +270,14 @@ public String getStatusCodeReply() {
if (null == resp) {
return null;
} else {
return SafeEncoder.encode(resp);
return encode(resp);
}
}

public String getBulkReply() {
final byte[] result = getBinaryBulkReply();
if (null != result) {
return SafeEncoder.encode(result);
return encode(result);
} else {
return null;
}
Expand Down Expand Up @@ -391,44 +391,27 @@ private static boolean validateClientInfo(String info) {
private void initializeFromClientConfig(JedisClientConfig config) {
try {
connect();
protocol = config.getRedisProtocol();

boolean doClientName = true;

/// HELLO and AUTH -->
if (protocol == RedisProtocol.RESP3 && config.getUser() != null) {

hello(protocol, config.getUser(), config.getPassword(), config.getClientName());
doClientName = false;

} else {

Supplier<RedisCredentials> credentialsProvider = config.getCredentialsProvider();
if (credentialsProvider instanceof RedisCredentialsProvider) {
try {
((RedisCredentialsProvider) credentialsProvider).prepare();
auth(credentialsProvider);
} finally {
((RedisCredentialsProvider) credentialsProvider).cleanUp();
}
} else {
auth(credentialsProvider);
}
protocol = config.getRedisProtocol();

if (protocol != null) {
hello(protocol);
final Supplier<RedisCredentials> credentialsProvider = config.getCredentialsProvider();
if (credentialsProvider instanceof RedisCredentialsProvider) {
final RedisCredentialsProvider redisCredentialsProvider = (RedisCredentialsProvider) credentialsProvider;
try {
redisCredentialsProvider.prepare();
helloOrAuth(protocol, redisCredentialsProvider.get());
} finally {
redisCredentialsProvider.cleanUp();
}
}

int dbIndex = config.getDatabase();
if (dbIndex > 0) {
select(dbIndex);
} else {
helloOrAuth(protocol, credentialsProvider != null ? credentialsProvider.get()
: new DefaultRedisCredentials(config.getUser(), config.getPassword()));
}

List<CommandArguments> fireAndForgetMsg = new ArrayList<>();

String clientName = config.getClientName();
if (doClientName && clientName != null && validateClientInfo(clientName)) {
if (clientName != null && validateClientInfo(clientName)) {
fireAndForgetMsg.add(new CommandArguments(Command.CLIENT).add(Keyword.SETNAME).add(clientName));
}

Expand All @@ -448,6 +431,12 @@ private void initializeFromClientConfig(JedisClientConfig config) {
sendCommand(arg);
}
getMany(fireAndForgetMsg.size());

int dbIndex = config.getDatabase();
if (dbIndex > 0) {
select(dbIndex);
}

} catch (JedisException je) {
try {
disconnect();
Expand All @@ -458,58 +447,59 @@ private void initializeFromClientConfig(JedisClientConfig config) {
}
}

private Map hello(final RedisProtocol protocol) {
sendCommand(Protocol.Command.HELLO, String.valueOf(protocol.version()));
Map reply = BuilderFactory.ENCODED_OBJECT_MAP.build(getOne());
// LoggerFactory.getLogger(Connection.class).info("HELLO reply: {}", reply);
return reply;
}
private void helloOrAuth(final RedisProtocol protocol, final RedisCredentials credentials) {

private Map hello(final RedisProtocol protocol, final String user, final String password,
final String clientName) {
if (clientName == null) {
sendCommand(Protocol.Command.HELLO, String.valueOf(protocol.version()),
Protocol.Keyword.AUTH.name(), user, password);
} else {
sendCommand(Protocol.Command.HELLO, String.valueOf(protocol.version()),
Protocol.Keyword.AUTH.name(), user, password,
Protocol.Keyword.SETNAME.name(), clientName);
if (credentials == null || credentials.getPassword() == null) {
if (protocol != null) {
sendCommand(Command.HELLO, encode(protocol.version()));
getOne();
}
return;
}
Map reply = BuilderFactory.ENCODED_OBJECT_MAP.build(getOne());
// LoggerFactory.getLogger(Connection.class).info("HELLO reply: {}", reply);
return reply;
}

private void auth(final Supplier<RedisCredentials> credentialsProvider) {
RedisCredentials credentials = credentialsProvider.get();
if (credentials == null || credentials.getPassword() == null) return;

// Source: https://stackoverflow.com/a/9670279/4021802
ByteBuffer passBuf = Protocol.CHARSET.encode(CharBuffer.wrap(credentials.getPassword()));
byte[] rawPass = Arrays.copyOfRange(passBuf.array(), passBuf.position(), passBuf.limit());
Arrays.fill(passBuf.array(), (byte) 0); // clear sensitive data

if (credentials.getUser() != null) {
sendCommand(Protocol.Command.AUTH, SafeEncoder.encode(credentials.getUser()), rawPass);
} else {
sendCommand(Protocol.Command.AUTH, rawPass);
}
try {
/// actual HELLO or AUTH -->
if (protocol != null) {
if (credentials.getUser() != null) {
sendCommand(Command.HELLO, encode(protocol.version()),
Keyword.AUTH.getRaw(), encode(credentials.getUser()), rawPass);
getOne(); // Map
} else {
sendCommand(Command.AUTH, rawPass);
getStatusCodeReply(); // OK
sendCommand(Command.HELLO, encode(protocol.version()));
getOne(); // Map
}
} else { // protocol == null
if (credentials.getUser() != null) {
sendCommand(Command.AUTH, encode(credentials.getUser()), rawPass);
} else {
sendCommand(Command.AUTH, rawPass);
}
getStatusCodeReply(); // OK
}
/// <-- actual HELLO or AUTH
} finally {

Arrays.fill(rawPass, (byte) 0); // clear sensitive data
Arrays.fill(rawPass, (byte) 0); // clear sensitive data
}

// clearing 'char[] credentials.getPassword()' should be
// handled in RedisCredentialsProvider.cleanUp()

getStatusCodeReply(); // OK
}

public String select(final int index) {
sendCommand(Protocol.Command.SELECT, Protocol.toByteArray(index));
sendCommand(Command.SELECT, Protocol.toByteArray(index));
return getStatusCodeReply();
}

public boolean ping() {
sendCommand(Protocol.Command.PING);
sendCommand(Command.PING);
String status = getStatusCodeReply();
if (!"PONG".equals(status)) {
throw new JedisException(status);
Expand Down