Skip to content

Commit

Permalink
[ISSUE apache#4701]fix use tcp protocol client send message, it throw…
Browse files Browse the repository at this point in the history
… a DecoderException
  • Loading branch information
mxsm committed Jan 1, 2024
1 parent bd2aeb9 commit 95ac2e0
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 120 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageCodec;
import io.netty.handler.codec.MessageToByteEncoder;
import io.netty.handler.codec.ReplayingDecoder;


import com.fasterxml.jackson.core.JsonProcessingException;
Expand All @@ -48,139 +46,122 @@
@Slf4j
public class Codec extends ByteToMessageCodec<Package> {

private static final int FRAME_MAX_LENGTH = 1024 * 1024 * 4;
private static final int FRAME_MAX_LENGTH = 1024 * 1024 * 4; // 4M

private static final byte[] CONSTANT_MAGIC_FLAG = serializeBytes("EventMesh");
private static final byte[] VERSION = serializeBytes("0000");

private Encoder encoder = new Encoder();
private Decoder decoder = new Decoder();

@Override
protected void encode(ChannelHandlerContext ctx, Package pkg, ByteBuf out) throws Exception {
encoder.encode(ctx, pkg, out);
}
Preconditions.checkNotNull(pkg, "TcpPackage cannot be null");
final Header header = pkg.getHeader();
Preconditions.checkNotNull(header, "TcpPackage header cannot be null", header);
if (log.isDebugEnabled()) {
log.debug("Encoder pkg={}", JsonUtils.toJSONString(pkg));
}

@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
decoder.decode(ctx, in, out);
}
final byte[] headerData = JsonUtils.toJSONBytes(header);
final byte[] bodyData;

public static class Encoder extends MessageToByteEncoder<Package> {

@Override
public void encode(ChannelHandlerContext ctx, Package pkg, ByteBuf out) throws Exception {
Preconditions.checkNotNull(pkg, "TcpPackage cannot be null");
final Header header = pkg.getHeader();
Preconditions.checkNotNull(header, "TcpPackage header cannot be null", header);
if (log.isDebugEnabled()) {
log.debug("Encoder pkg={}", JsonUtils.toJSONString(pkg));
}

final byte[] headerData = JsonUtils.toJSONBytes(header);
final byte[] bodyData;

if (StringUtils.equals(Constants.CLOUD_EVENTS_PROTOCOL_NAME, header.getStringProperty(Constants.PROTOCOL_TYPE))) {
bodyData = (byte[]) pkg.getBody();
} else {
bodyData = JsonUtils.toJSONBytes(pkg.getBody());
}

int headerLength = ArrayUtils.getLength(headerData);
int bodyLength = ArrayUtils.getLength(bodyData);

final int length = CONSTANT_MAGIC_FLAG.length + VERSION.length + headerLength + bodyLength;

if (length > FRAME_MAX_LENGTH) {
throw new IllegalArgumentException("message size is exceed limit!");
}
/**
* Header + Body, Format:
* <pre>
* ┌───────────────┬─────────────┬──────────────────┬──────────────────┬──────────────────┬─────────────────┐
* │ MAGIC_FLAG │ VERSION │ package length │ Header length │ Header │ body │
* │ (9bytes) │ (4bytes) │ (4bytes) │ (4bytes) │ (header bytes) │ (body bytes) │
* └───────────────┴─────────────┴──────────────────┴──────────────────┴──────────────────┴─────────────────┘
* </pre>
*/
out.writeBytes(CONSTANT_MAGIC_FLAG);
out.writeBytes(VERSION);
out.writeInt(length);
out.writeInt(headerLength);
if (headerData != null) {
out.writeBytes(headerData);
}
if (bodyData != null) {
out.writeBytes(bodyData);
}
if (StringUtils.equals(Constants.CLOUD_EVENTS_PROTOCOL_NAME, header.getStringProperty(Constants.PROTOCOL_TYPE))) {
bodyData = (byte[]) pkg.getBody();
} else {
bodyData = JsonUtils.toJSONBytes(pkg.getBody());
}
}

public static class Decoder extends ReplayingDecoder<Package> {

@Override
public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
try {
if (null == in) {
return;
}

byte[] flagBytes = parseFlag(in);
byte[] versionBytes = parseVersion(in);
validateFlag(flagBytes, versionBytes, ctx);

final int length = in.readInt();
final int headerLength = in.readInt();
final int bodyLength = length - CONSTANT_MAGIC_FLAG.length - VERSION.length - headerLength;
Header header = parseHeader(in, headerLength);
Object body = parseBody(in, header, bodyLength);

Package pkg = new Package(header, body);
out.add(pkg);
} catch (Exception e) {
log.error("decode error| received data: {}.", deserializeBytes(in.array()), e);
throw e;
}
}
int headerLength = ArrayUtils.getLength(headerData);
int bodyLength = ArrayUtils.getLength(bodyData);

final int length = CONSTANT_MAGIC_FLAG.length + VERSION.length + headerLength + bodyLength;

private byte[] parseFlag(ByteBuf in) {
final byte[] flagBytes = new byte[CONSTANT_MAGIC_FLAG.length];
in.readBytes(flagBytes);
return flagBytes;
if (length > FRAME_MAX_LENGTH) {
throw new IllegalArgumentException("message size is exceed limit!");
}
/**
* Header + Body, Format:
* <pre>
* ┌───────────────┬─────────────┬──────────────────┬──────────────────┬──────────────────┬─────────────────┐
* │ MAGIC_FLAG │ VERSION │ package length │ Header length │ Header │ body │
* │ (9bytes) │ (4bytes) │ (4bytes) │ (4bytes) │ (header bytes) │ (body bytes) │
* └───────────────┴─────────────┴──────────────────┴──────────────────┴──────────────────┴─────────────────┘
* </pre>
*/
out.writeBytes(CONSTANT_MAGIC_FLAG);
out.writeBytes(VERSION);
out.writeInt(length);
out.writeInt(headerLength);
if (headerData != null) {
out.writeBytes(headerData);
}
if (bodyData != null) {
out.writeBytes(bodyData);
}
}

private byte[] parseVersion(ByteBuf in) {
final byte[] versionBytes = new byte[VERSION.length];
in.readBytes(versionBytes);
return versionBytes;
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
if (in == null) {
return;
}
if (in.readableBytes() < CONSTANT_MAGIC_FLAG.length + VERSION.length + 4 + 4) {
// Not enough data to read the package length and header length
return;
}
byte[] flagBytes = parseFlag(in);
byte[] versionBytes = parseVersion(in);
validateFlagAndVersion(flagBytes, versionBytes, ctx);
final int packageLength = in.readInt();
final int headerLength = in.readInt();
if (in.readableBytes() < packageLength - 13) {
// Not enough data yet, reset the reader index and wait for more data
in.resetReaderIndex();
return;
}
final int bodyLength = packageLength - CONSTANT_MAGIC_FLAG.length - VERSION.length - headerLength;
Header header = parseHeader(in, headerLength);
Object body = parseBody(in, header, bodyLength);

private Header parseHeader(ByteBuf in, int headerLength) throws JsonProcessingException {
if (headerLength <= 0) {
return null;
}
final byte[] headerData = new byte[headerLength];
in.readBytes(headerData);
LogUtils.debug(log, "Decode headerJson={}", deserializeBytes(headerData));
return JsonUtils.parseObject(headerData, Header.class);
Package pkg = new Package(header, body);
out.add(pkg);
}

private byte[] parseFlag(ByteBuf in) {
final byte[] flagBytes = new byte[CONSTANT_MAGIC_FLAG.length];
in.readBytes(flagBytes);
return flagBytes;
}

private byte[] parseVersion(ByteBuf in) {
final byte[] versionBytes = new byte[VERSION.length];
in.readBytes(versionBytes);
return versionBytes;
}

private Header parseHeader(ByteBuf in, int headerLength) throws JsonProcessingException {
if (headerLength <= 0) {
return null;
}
final byte[] headerData = new byte[headerLength];
in.readBytes(headerData);
LogUtils.debug(log, "Decode headerJson={}", deserializeBytes(headerData));
return JsonUtils.parseObject(headerData, Header.class);
}

private Object parseBody(ByteBuf in, Header header, int bodyLength) throws JsonProcessingException {
if (bodyLength <= 0 || header == null) {
return null;
}
final byte[] bodyData = new byte[bodyLength];
in.readBytes(bodyData);
LogUtils.debug(log, "Decode bodyJson={}", deserializeBytes(bodyData));
return deserializeBody(deserializeBytes(bodyData), header);
private Object parseBody(ByteBuf in, Header header, int bodyLength) throws JsonProcessingException {
if (bodyLength <= 0 || header == null) {
return null;
}
final byte[] bodyData = new byte[bodyLength];
in.readBytes(bodyData);
LogUtils.debug(log, "Decode bodyJson={}", deserializeBytes(bodyData));
return deserializeBody(deserializeBytes(bodyData), header);
}

private void validateFlag(byte[] flagBytes, byte[] versionBytes, ChannelHandlerContext ctx) {
if (!Arrays.equals(flagBytes, CONSTANT_MAGIC_FLAG) || !Arrays.equals(versionBytes, VERSION)) {
String errorMsg = String.format("invalid magic flag or version|flag=%s|version=%s|remoteAddress=%s",
deserializeBytes(flagBytes), deserializeBytes(versionBytes), ctx.channel().remoteAddress());
throw new IllegalArgumentException(errorMsg);
}
private void validateFlagAndVersion(byte[] flagBytes, byte[] versionBytes, ChannelHandlerContext ctx) {
if (!Arrays.equals(flagBytes, CONSTANT_MAGIC_FLAG) || !Arrays.equals(versionBytes, VERSION)) {
String errorMsg = String.format("invalid magic flag or version|flag=%s|version=%s|remoteAddress=%s",
deserializeBytes(flagBytes), deserializeBytes(versionBytes), ctx.channel().remoteAddress());
throw new IllegalArgumentException(errorMsg);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ public void testCodec() throws Exception {
header.setCmd(Command.HELLO_REQUEST);
Package testP = new Package(header);
testP.setBody(new Object());
Codec.Encoder ce = new Codec.Encoder();
Codec ce = new Codec();
ByteBuf buf = PooledByteBufAllocator.DEFAULT.buffer();
ce.encode(null, testP, buf);
Codec.Decoder cd = new Codec.Decoder();
Codec cd = new Codec();
ArrayList<Object> result = new ArrayList<>();
cd.decode(null, buf, result);
Assertions.assertNotNull(result.get(0));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,7 @@ private class TcpServerInitializer extends ChannelInitializer<SocketChannel> {
protected void initChannel(SocketChannel ch) {
globalTrafficShapingHandler = newGTSHandler(tcpThreadPoolGroup.getScheduler(), eventMeshTCPConfiguration.getCtc().getReadLimit());
ch.pipeline()
.addLast(getWorkerGroup(), new Codec.Encoder())
.addLast(getWorkerGroup(), new Codec.Decoder())
.addLast(getWorkerGroup(), new Codec())
.addLast(getWorkerGroup(), "global-traffic-shaping", globalTrafficShapingHandler)
.addLast(getWorkerGroup(), "channel-traffic-shaping", newCTSHandler(eventMeshTCPConfiguration.getCtc().getReadLimit()))
.addLast(getWorkerGroup(), tcpConnectionHandler)
Expand Down

0 comments on commit 95ac2e0

Please sign in to comment.