Skip to content

Commit

Permalink
Handle ResponseWriteUtils Fail Correctly (microsoft#433)
Browse files Browse the repository at this point in the history
* RespWriteUtils code cleanup

* SendAndReset fixes

* validate write response does not require response buffer draining

* bump release version

* addressing comments
  • Loading branch information
vazois authored May 31, 2024
1 parent 36f25c4 commit 3cc65bd
Show file tree
Hide file tree
Showing 8 changed files with 135 additions and 132 deletions.
2 changes: 1 addition & 1 deletion .azure/pipelines/azure-pipelines-external-release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# 1) update the name: string below (line 6) -- this is the version for the nuget package (e.g. 1.0.0)
# 2) update \libs\host\GarnetServer.cs readonly string version (~line 45) -- NOTE - these two values need to be the same
######################################
name: 1.0.11
name: 1.0.12
trigger:
branches:
include:
Expand Down
89 changes: 44 additions & 45 deletions libs/common/RespWriteUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ public static unsafe class RespWriteUtils
/// </summary>
public static bool WriteMapLength(int len, ref byte* curr, byte* end)
{
int numDigits = NumUtils.NumDigits(len);
int totalLen = 1 + numDigits + 2;
var numDigits = NumUtils.NumDigits(len);
var totalLen = 1 + numDigits + 2;
if (totalLen > (int)(end - curr))
return false;
*curr++ = (byte)'%';
Expand All @@ -34,8 +34,8 @@ public static bool WriteMapLength(int len, ref byte* curr, byte* end)
/// </summary>
public static bool WritePushLength(int len, ref byte* curr, byte* end)
{
int numDigits = NumUtils.NumDigits(len);
int totalLen = 1 + numDigits + 2;
var numDigits = NumUtils.NumDigits(len);
var totalLen = 1 + numDigits + 2;
if (totalLen > (int)(end - curr))
return false;
*curr++ = (byte)'>';
Expand All @@ -49,8 +49,8 @@ public static bool WritePushLength(int len, ref byte* curr, byte* end)
/// </summary>
public static bool WriteArrayLength(int len, ref byte* curr, byte* end)
{
int numDigits = NumUtils.NumDigits(len);
int totalLen = 1 + numDigits + 2;
var numDigits = NumUtils.NumDigits(len);
var totalLen = 1 + numDigits + 2;
if (totalLen > (int)(end - curr))
return false;
*curr++ = (byte)'*';
Expand All @@ -64,12 +64,12 @@ public static bool WriteArrayLength(int len, ref byte* curr, byte* end)
/// </summary>
public static bool WriteArrayItem(long integer, ref byte* curr, byte* end)
{
int integerLen = NumUtils.NumDigitsInLong(integer);
byte sign = (byte)(integer < 0 ? 1 : 0);
int integerLenLen = NumUtils.NumDigits(sign + integerLen);
var integerLen = NumUtils.NumDigitsInLong(integer);
var sign = (byte)(integer < 0 ? 1 : 0);
var integerLenLen = NumUtils.NumDigits(sign + integerLen);

//$[integerLen]\r\n[integer]\r\n
int totalLen = 1 + integerLenLen + 2 + sign + integerLen + 2;
// $[integerLen]\r\n[integer]\r\n
var totalLen = 1 + integerLenLen + 2 + sign + integerLen + 2;
if (totalLen > (int)(end - curr))
return false;

Expand Down Expand Up @@ -115,7 +115,7 @@ public static bool WriteNullArray(ref byte* curr, byte* end)
public static bool WriteSimpleString(ReadOnlySpan<byte> simpleString, ref byte* curr, byte* end)
{
// Simple strings are of the form "+OK\r\n"
int totalLen = 1 + simpleString.Length + 2;
var totalLen = 1 + simpleString.Length + 2;
if (totalLen > (int)(end - curr))
return false;

Expand Down Expand Up @@ -144,15 +144,14 @@ public static bool WriteSimpleString(ReadOnlySpan<char> simpleString, ref byte*
return true;
}


/// <summary>
/// Write a long as a simple string
/// </summary>
public static bool WriteLongAsSimpleString(long value, ref byte* curr, byte* end)
{
// Simple strings are of the form "+cc\r\n"
int longLength = NumUtils.NumDigitsInLong(value);
int totalLen = 1 + longLength + 2;
var longLength = NumUtils.NumDigitsInLong(value);
var totalLen = 1 + longLength + 2;
if (totalLen > (int)(end - curr))
return false;

Expand All @@ -168,7 +167,7 @@ public static bool WriteLongAsSimpleString(long value, ref byte* curr, byte* end
/// <param name="errorString">An ASCII encoded error string. The string mustn't contain a CR (\r) or LF (\n) bytes.</param>
public static bool WriteError(ReadOnlySpan<byte> errorString, ref byte* curr, byte* end)
{
int totalLen = 1 + errorString.Length + 2;
var totalLen = 1 + errorString.Length + 2;
if (totalLen > (int)(end - curr))
return false;

Expand All @@ -185,12 +184,12 @@ public static bool WriteError(ReadOnlySpan<byte> errorString, ref byte* curr, by
/// <param name="errorString">An ASCII error string. The string mustn't contain a CR (\r) or LF (\n) characters.</param>
public static bool WriteError(ReadOnlySpan<char> errorString, ref byte* curr, byte* end)
{
int totalLen = 1 + errorString.Length + 2;
var totalLen = 1 + errorString.Length + 2;
if (totalLen > (int)(end - curr))
return false;

*curr++ = (byte)'-';
int bytesWritten = Encoding.ASCII.GetBytes(errorString, new Span<byte>(curr, errorString.Length));
var bytesWritten = Encoding.ASCII.GetBytes(errorString, new Span<byte>(curr, errorString.Length));
curr += bytesWritten;
WriteNewline(ref curr);
return true;
Expand Down Expand Up @@ -219,7 +218,7 @@ public static bool WriteAsciiDirect(ReadOnlySpan<char> span, ref byte* curr, byt
if (span.Length > (int)(end - curr))
return false;

int bytesWritten = Encoding.ASCII.GetBytes(span, new Span<byte>(curr, span.Length));
var bytesWritten = Encoding.ASCII.GetBytes(span, new Span<byte>(curr, span.Length));
curr += bytesWritten;
return true;
}
Expand Down Expand Up @@ -262,14 +261,14 @@ public static bool WriteBulkString(ReadOnlySpan<byte> item, ref byte* curr, byte
public static bool WriteAsciiBulkString(ReadOnlySpan<char> chars, ref byte* curr, byte* end)
{
var itemDigits = NumUtils.NumDigits(chars.Length);
int totalLen = 1 + itemDigits + 2 + chars.Length + 2;
var totalLen = 1 + itemDigits + 2 + chars.Length + 2;
if (totalLen > (int)(end - curr))
return false;

*curr++ = (byte)'$';
NumUtils.IntToBytes(chars.Length, itemDigits, ref curr);
WriteNewline(ref curr);
int bytesWritten = Encoding.ASCII.GetBytes(chars, new Span<byte>(curr, chars.Length));
var bytesWritten = Encoding.ASCII.GetBytes(chars, new Span<byte>(curr, chars.Length));
curr += bytesWritten;
WriteNewline(ref curr);
return true;
Expand All @@ -281,17 +280,17 @@ public static bool WriteAsciiBulkString(ReadOnlySpan<char> chars, ref byte* curr
public static bool WriteUtf8BulkString(ReadOnlySpan<char> chars, ref byte* curr, byte* end)
{
// Calculate the amount of bytes it takes to encoded the UTF16 string as UTF8
int encodedByteCount = Encoding.UTF8.GetByteCount(chars);
var encodedByteCount = Encoding.UTF8.GetByteCount(chars);

var itemDigits = NumUtils.NumDigits(encodedByteCount);
int totalLen = 1 + itemDigits + 2 + encodedByteCount + 2;
var totalLen = 1 + itemDigits + 2 + encodedByteCount + 2;
if (totalLen > (int)(end - curr))
return false;

*curr++ = (byte)'$';
NumUtils.IntToBytes(encodedByteCount, itemDigits, ref curr);
WriteNewline(ref curr);
int bytesWritten = Encoding.UTF8.GetBytes(chars, new Span<byte>(curr, encodedByteCount));
var bytesWritten = Encoding.UTF8.GetBytes(chars, new Span<byte>(curr, encodedByteCount));
curr += bytesWritten;
WriteNewline(ref curr);
return true;
Expand All @@ -308,11 +307,11 @@ public static int GetBulkStringLength(int length)
/// </summary>
public static bool WriteInteger(int integer, ref byte* curr, byte* end)
{
int integerLen = NumUtils.NumDigitsInLong(integer);
byte sign = (byte)(integer < 0 ? 1 : 0);
var integerLen = NumUtils.NumDigitsInLong(integer);
var sign = (byte)(integer < 0 ? 1 : 0);

//:integer\r\n
int totalLen = 1 + sign + integerLen + 2;
var totalLen = 1 + sign + integerLen + 2;
if (totalLen > (int)(end - curr))
return false;

Expand All @@ -327,11 +326,11 @@ public static bool WriteInteger(int integer, ref byte* curr, byte* end)
/// </summary>
public static bool WriteInteger(long integer, ref byte* curr, byte* end)
{
int integerLen = NumUtils.NumDigitsInLong(integer);
byte sign = (byte)(integer < 0 ? 1 : 0);
var integerLen = NumUtils.NumDigitsInLong(integer);
var sign = (byte)(integer < 0 ? 1 : 0);

//:integer\r\n
int totalLen = 1 + sign + integerLen + 2;
var totalLen = 1 + sign + integerLen + 2;
if (totalLen > (int)(end - curr))
return false;

Expand Down Expand Up @@ -381,13 +380,13 @@ public static bool WriteIntegerFromBytes(ReadOnlySpan<byte> integerBytes, ref by
/// </summary>
public static bool WriteIntegerAsBulkString(int integer, ref byte* curr, byte* end)
{
int integerLen = NumUtils.NumDigitsInLong(integer);
byte sign = (byte)(integer < 0 ? 1 : 0);
var integerLen = NumUtils.NumDigitsInLong(integer);
var sign = (byte)(integer < 0 ? 1 : 0);

int integerLenSize = NumUtils.NumDigits(integerLen + sign);
var integerLenSize = NumUtils.NumDigits(integerLen + sign);

//$size\r\ninteger\r\n
int totalLen = 1 + integerLenSize + 2 + sign + integerLen + 2;
var totalLen = 1 + integerLenSize + 2 + sign + integerLen + 2;
if (totalLen > (int)(end - curr))
return false;

Expand All @@ -404,10 +403,10 @@ public static bool WriteIntegerAsBulkString(int integer, ref byte* curr, byte* e
/// </summary>
public static bool WriteIntegerAsBulkString(long integer, ref byte* curr, byte* end, out int totalLen)
{
int integerLen = NumUtils.NumDigitsInLong(integer);
byte sign = (byte)(integer < 0 ? 1 : 0);
var integerLen = NumUtils.NumDigitsInLong(integer);
var sign = (byte)(integer < 0 ? 1 : 0);

int integerLenSize = NumUtils.NumDigits(integerLen + sign);
var integerLenSize = NumUtils.NumDigits(integerLen + sign);

//$size\r\ninteger\r\n
totalLen = 1 + integerLenSize + 2 + sign + integerLen + 2;
Expand All @@ -427,16 +426,15 @@ public static bool WriteIntegerAsBulkString(long integer, ref byte* curr, byte*
/// </summary>
public static int GetIntegerAsBulkStringLength(int integer)
{
int integerLen = NumUtils.NumDigitsInLong(integer);
byte sign = (byte)(integer < 0 ? 1 : 0);
var integerLen = NumUtils.NumDigitsInLong(integer);
var sign = (byte)(integer < 0 ? 1 : 0);

int integerLenSize = NumUtils.NumDigits(integerLen + sign);
var integerLenSize = NumUtils.NumDigits(integerLen + sign);

//$size\r\ninteger\r\n
return 1 + integerLenSize + 2 + sign + integerLen + 2;
}


/// <summary>
/// Create header for *Scan output
/// *scan commands have an array of two elements
Expand Down Expand Up @@ -475,8 +473,8 @@ public static bool WriteEmptyArray(ref byte* curr, byte* end)
/// </summary>
public static bool WriteArrayWithNullElements(int len, ref byte* curr, byte* end)
{
int numDigits = NumUtils.NumDigits(len);
int totalLen = 1 + numDigits + 2;
var numDigits = NumUtils.NumDigits(len);
var totalLen = 1 + numDigits + 2;
totalLen += len * 5; // 5 is the length of $-1\r\n

if (totalLen > (int)(end - curr))
Expand All @@ -485,9 +483,10 @@ public static bool WriteArrayWithNullElements(int len, ref byte* curr, byte* end
*curr++ = (byte)'*';
NumUtils.IntToBytes(len, numDigits, ref curr);
WriteNewline(ref curr);
for (int i = 0; i < len; i++)
for (var i = 0; i < len; i++)
{
WriteNull(ref curr, end);
if (!WriteNull(ref curr, end))
return false;
}
return true;
}
Expand Down
2 changes: 1 addition & 1 deletion libs/host/GarnetServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public class GarnetServer : IDisposable
protected StoreWrapper storeWrapper;

// IMPORTANT: Keep the version in sync with .azure\pipelines\azure-pipelines-external-release.yml line ~6.
readonly string version = "1.0.11";
readonly string version = "1.0.12";

/// <summary>
/// Resp protocol version
Expand Down
Loading

0 comments on commit 3cc65bd

Please sign in to comment.