Skip to content

Commit

Permalink
KAFKA-4090: Validate SSL connection in client
Browse files Browse the repository at this point in the history
  • Loading branch information
gurinderu committed Jan 11, 2022
1 parent 2e3396e commit 68cfa4e
Show file tree
Hide file tree
Showing 9 changed files with 502 additions and 109 deletions.
3 changes: 2 additions & 1 deletion checkstyle/checkstyle.xml
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@
<module name="Header">
<property name="headerFile" value="${headerFile}" />
</module>
<module name="SuppressWarningsFilter" />

<module name="TreeWalker">

<module name="SuppressWarningsHolder" />
<!-- code cleanup -->
<module name="UnusedImports">
<property name="processJavadoc" value="true" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@
import java.io.EOFException;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.ScatteringByteChannel;
import java.util.concurrent.atomic.AtomicInteger;

/**
* A size delimited Receive that consists of a 4 byte network-ordered size N followed by N bytes of content
* A size delimited Receive that consists of a 4 byte network-ordered size N
* followed by N bytes of content.
*/
public class NetworkReceive implements Receive {

Expand All @@ -36,129 +39,189 @@ public class NetworkReceive implements Receive {
private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocate(0);

private final String source;
private final ByteBuffer size;
private final ByteBuffer sizeBuf;
private final ByteBuffer minBuf;
private final int maxSize;
private final MemoryPool memoryPool;
private final AtomicInteger byteCount;
private int requestedBufferSize = -1;
private ByteBuffer buffer;
private ByteBuffer payloadBuffer = null;
private volatile ReadState readState = ReadState.READ_SIZE;

enum ReadState {
READ_SIZE, VALIDATE_SIZE, ALLOCATE_BUFFER, READ_PAYLOAD, COMPLETE
}

public NetworkReceive(String source, ByteBuffer buffer) {
this.source = source;
this.buffer = buffer;
this.size = null;
this.maxSize = UNLIMITED;
this.memoryPool = MemoryPool.NONE;
public NetworkReceive() {
this(UNKNOWN_SOURCE);
}

public NetworkReceive(String source) {
this.source = source;
this.size = ByteBuffer.allocate(4);
this.buffer = null;
this.maxSize = UNLIMITED;
this.memoryPool = MemoryPool.NONE;
this(UNLIMITED, source);
}

public NetworkReceive(String source, ByteBuffer buffer) {
this(source);
this.payloadBuffer = buffer;
}

public NetworkReceive(int maxSize, String source) {
this.source = source;
this.size = ByteBuffer.allocate(4);
this.buffer = null;
this.maxSize = maxSize;
this.memoryPool = MemoryPool.NONE;
this(maxSize, source, MemoryPool.NONE);
}

public NetworkReceive(int maxSize, String source, MemoryPool memoryPool) {
this.source = source;
this.size = ByteBuffer.allocate(4);
this.buffer = null;
this.maxSize = maxSize;
this.memoryPool = memoryPool;
}

public NetworkReceive() {
this(UNKNOWN_SOURCE);
}

@Override
public String source() {
return source;
}

@Override
public boolean complete() {
return !size.hasRemaining() && buffer != null && !buffer.hasRemaining();
this.minBuf = (ByteBuffer) ByteBuffer.allocate(SslUtils.SSL_RECORD_HEADER_LENGTH).position(4);
this.sizeBuf = (ByteBuffer) this.minBuf.duplicate().position(0).limit(4);
this.byteCount = new AtomicInteger(0);
}

@SuppressWarnings("fallthrough")
public long readFrom(ScatteringByteChannel channel) throws IOException {
int read = 0;
if (size.hasRemaining()) {
int bytesRead = channel.read(size);
if (bytesRead < 0)
throw new EOFException();
read += bytesRead;
if (!size.hasRemaining()) {
size.rewind();
int receiveSize = size.getInt();
if (receiveSize < 0)
throw new InvalidReceiveException("Invalid receive (size = " + receiveSize + ")");
if (maxSize != UNLIMITED && receiveSize > maxSize)
throw new InvalidReceiveException("Invalid receive (size = " + receiveSize + " larger than " + maxSize + ")");
requestedBufferSize = receiveSize; //may be 0 for some payloads (SASL)
if (receiveSize == 0) {
buffer = EMPTY_BUFFER;

switch (readState) {
case READ_SIZE:
read += readRequestedBufferSize(channel);
if (this.sizeBuf.hasRemaining()) {
break;
}
}
this.readState = ReadState.VALIDATE_SIZE;
/** FALLTHROUGH TO NEXT STATE */
case VALIDATE_SIZE:
if (this.requestedBufferSize != 0) {
read += validateRequestedBufferSize(channel);
if (this.minBuf.hasRemaining()) {
break;
}
}
this.readState = ReadState.ALLOCATE_BUFFER;
/** FALLTHROUGH */
case ALLOCATE_BUFFER:
if (this.requestedBufferSize == 0) {
this.payloadBuffer = EMPTY_BUFFER;
} else {
this.payloadBuffer = tryAllocateBuffer(this.requestedBufferSize);
if (this.payloadBuffer == null) {
break;
} else {
// Copy any bytes that were already consumed
this.minBuf.position(this.sizeBuf.limit());
this.payloadBuffer.put(this.minBuf);
}
}
this.readState = ReadState.READ_PAYLOAD;
/** FALLTHROUGH TO NEXT STATE */
case READ_PAYLOAD:
final int payloadRead = channel.read(payloadBuffer);
if (payloadRead < 0)
throw new EOFException();
read += payloadRead;
if (!this.payloadBuffer.hasRemaining()) {
this.readState = ReadState.COMPLETE;
}
break;
case COMPLETE:
break;
}
if (buffer == null && requestedBufferSize != -1) { //we know the size we want but havent been able to allocate it yet
buffer = memoryPool.tryAllocate(requestedBufferSize);
if (buffer == null)
log.trace("Broker low on memory - could not allocate buffer of size {} for source {}", requestedBufferSize, source);

this.byteCount.addAndGet(read);

return read;
}

private int validateRequestedBufferSize(final ScatteringByteChannel channel)
throws IOException {
int minRead = channel.read(this.minBuf);
if (minRead < 0) {
throw new EOFException();
}
if (buffer != null) {
int bytesRead = channel.read(buffer);
if (bytesRead < 0)
throw new EOFException();
read += bytesRead;
if (!this.minBuf.hasRemaining()) {
final boolean isEncrypted =
SslUtils.isEncrypted((ByteBuffer) this.minBuf.duplicate().rewind());
if (isEncrypted) {
throw new InvalidReceiveException(
"Recieved an unexpected SSL packet from the server. "
+ "Please ensure the client is properly configured with SSL enabled.");
}
if (this.requestedBufferSize < 0)
throw new InvalidReceiveException(
"Invalid receive (size = " + this.requestedBufferSize + ")");
if (maxSize != UNLIMITED && this.requestedBufferSize > maxSize)
throw new InvalidReceiveException("Invalid receive (size = "
+ this.requestedBufferSize + " larger than " + maxSize + ")");
}

return read;
return minRead;
}

private ByteBuffer tryAllocateBuffer(final int bufSize) {
final ByteBuffer bb = memoryPool.tryAllocate(bufSize);
if (bb == null) {
log.trace("Broker low on memory - could not allocate buffer of size {} for source {}",
requestedBufferSize, source);
}
return bb;
}

private int readRequestedBufferSize(final ReadableByteChannel channel) throws IOException {
final int sizeRead = channel.read(sizeBuf);
if (sizeRead < 0) {
throw new EOFException();
}
if (sizeBuf.hasRemaining()) {
return sizeRead;
}
sizeBuf.rewind();
this.requestedBufferSize = sizeBuf.getInt();
return sizeRead;
}

@Override
public boolean requiredMemoryAmountKnown() {
return requestedBufferSize != -1;
return this.readState.ordinal() > ReadState.VALIDATE_SIZE.ordinal();
}

@Override
public boolean memoryAllocated() {
return buffer != null;
return this.readState.ordinal() >= ReadState.READ_PAYLOAD.ordinal();
}

@Override
public boolean complete() {
return this.readState == ReadState.COMPLETE;
}

@Override
public void close() throws IOException {
if (buffer != null && buffer != EMPTY_BUFFER) {
memoryPool.release(buffer);
buffer = null;
if (payloadBuffer != null && payloadBuffer != EMPTY_BUFFER) {
memoryPool.release(payloadBuffer);
payloadBuffer = null;
}
}

@Override
public String source() {
return source;
}

public ByteBuffer payload() {
return this.buffer;
return this.payloadBuffer;
}

public int bytesRead() {
if (buffer == null)
return size.position();
return buffer.position() + size.position();
return this.byteCount.get();
}

/**
* Returns the total size of the receive including payload and size buffer
* for use in metrics. This is consistent with {@link NetworkSend#size()}
*/
public int size() {
return payload().limit() + size.limit();
return payload().limit() + sizeBuf.limit();
}

}
Loading

0 comments on commit 68cfa4e

Please sign in to comment.