Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(protocol): add get data message #118

Merged
merged 16 commits into from
Oct 4, 2024
Merged
4 changes: 3 additions & 1 deletion src/network/protocol/lib.zig
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
pub const messages = @import("./messages/lib.zig");
pub const NetworkAddress = @import("NetworkAddress.zig");
pub const NetworkAddress = @import("types/NetworkAddress.zig");
pub const InventoryItem = @import("types/InventoryItem.zig");

/// Network services
pub const ServiceFlags = struct {
pub const NODE_NETWORK: u64 = 0x1;
Expand Down
147 changes: 147 additions & 0 deletions src/network/protocol/messages/getdata.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
const std = @import("std");
const CompactSizeUint = @import("bitcoin-primitives").types.CompatSizeUint;
const message = @import("./lib.zig");
const genericChecksum = @import("lib.zig").genericChecksum;

const Sha256 = std.crypto.hash.sha2.Sha256;

const protocol = @import("../lib.zig");

pub const GetdataMessage = struct {
inventory: []const protocol.InventoryItem,

pub inline fn name() *const [12]u8 {
return protocol.CommandNames.GETDATA ++ [_]u8{0} ** 5;
}

/// Returns the message checksum
///
/// Computed as `Sha256(Sha256(self.serialize()))[0..4]`
pub fn checksum(self: *const GetdataMessage) [4]u8 {
return genericChecksum(self);
}

/// Free the `inventory`
pub fn deinit(self: GetdataMessage, allocator: std.mem.Allocator) void {
allocator.free(self.inventory);
}

/// Serialize the message as bytes and write them to the Writer.
///
/// `w` should be a valid `Writer`.
pub fn serializeToWriter(self: *const GetdataMessage, w: anytype) !void {
const count = CompactSizeUint.new(self.inventory.len);
try count.encodeToWriter(w);

for (self.inventory) |item| {
try item.encodeToWriter(w);
}
}

pub fn serialize(self: *const GetdataMessage, allocator: std.mem.Allocator) ![]u8 {
const serialized_len = self.hintSerializedLen();

const ret = try allocator.alloc(u8, serialized_len);
errdefer allocator.free(ret);

try self.serializeToSlice(ret);

return ret;
}

/// Serialize a message as bytes and write them to the buffer.
///
/// buffer.len must be >= than self.hintSerializedLen()
pub fn serializeToSlice(self: *const GetdataMessage, buffer: []u8) !void {
var fbs = std.io.fixedBufferStream(buffer);
const writer = fbs.writer();
try self.serializeToWriter(writer);
}

pub fn hintSerializedLen(self: *const GetdataMessage) usize {
var length: usize = 0;

// Adding the length of CompactSizeUint for the count
const count = CompactSizeUint.new(self.inventory.len);
length += count.hint_encoded_len();

// Adding the length of each inventory item
length += self.inventory.len * (4 + 32); // Type (4 bytes) + Hash (32 bytes)

return length;
}

pub fn deserializeReader(allocator: std.mem.Allocator, r: anytype) !GetdataMessage {

const compact_count = try CompactSizeUint.decodeReader(r);
const count = compact_count.value();
tdelabro marked this conversation as resolved.
Show resolved Hide resolved
if (count == 0) {
return GetdataMessage{
.inventory = &[_]protocol.InventoryItem{},
};
}

const inventory = try allocator.alloc(protocol.InventoryItem, count);
errdefer allocator.free(inventory);

for (inventory) |*item| {
item.* = try protocol.InventoryItem.decodeReader(r);
}

return GetdataMessage{
.inventory = inventory,
};
}

/// Deserialize bytes into a `GetdataMessage`
pub fn deserializeSlice(allocator: std.mem.Allocator, bytes: []const u8) !GetdataMessage {
var fbs = std.io.fixedBufferStream(bytes);
const reader = fbs.reader();
return try GetdataMessage.deserializeReader(allocator, reader);
}


pub fn eql(self: *const GetdataMessage, other: *const GetdataMessage) bool {
if (self.inventory.len != other.inventory.len) return false;

for (0..self.inventory.len) |i| {
const item_self = self.inventory[i];
const item_other = other.inventory[i];
if (!item_self.eql(&item_other)) {
return false;
}
}

return true;
}
};


// TESTS

test "ok_full_flow_GetdataMessage" {
const allocator = std.testing.allocator;

// With some inventory items
{
const inventory_items = [_]protocol.InventoryItem{
.{ .type = 1, .hash = [_]u8{0xab} ** 32 },
.{ .type = 2, .hash = [_]u8{0xcd} ** 32 },
.{ .type = 2, .hash = [_]u8{0xef} ** 32 },
};

const gd = GetdataMessage{
.inventory = inventory_items[0..],
};

const payload = try gd.serialize(allocator);
defer allocator.free(payload);

const deserialized_gd = try GetdataMessage.deserializeSlice(allocator, payload);

try std.testing.expect(gd.eql(&deserialized_gd));

// Free allocated memory for deserialized inventory
defer allocator.free(deserialized_gd.inventory);
}
}
50 changes: 9 additions & 41 deletions src/network/protocol/messages/lib.zig
Original file line number Diff line number Diff line change
Expand Up @@ -11,43 +11,14 @@ pub const MerkleBlockMessage = @import("merkleblock.zig").MerkleBlockMessage;
pub const FeeFilterMessage = @import("feefilter.zig").FeeFilterMessage;
pub const SendCmpctMessage = @import("sendcmpct.zig").SendCmpctMessage;
pub const FilterClearMessage = @import("filterclear.zig").FilterClearMessage;
pub const GetdataMessage = @import("getdata.zig").GetdataMessage;
pub const Block = @import("block.zig").BlockMessage;
pub const FilterAddMessage = @import("filteradd.zig").FilterAddMessage;
const Sha256 = std.crypto.hash.sha2.Sha256;
pub const NotFoundMessage = @import("notfound.zig").NotFoundMessage;
pub const SendHeadersMessage = @import("sendheaders.zig").SendHeadersMessage;
pub const FilterLoadMessage = @import("filterload.zig").FilterLoadMessage;
pub const HeadersMessage = @import("headers.zig").HeadersMessage;

pub const InventoryVector = struct {
type: u32,
hash: [32]u8,

pub fn serializeToWriter(self: InventoryVector, writer: anytype) !void {
comptime {
if (!std.meta.hasFn(@TypeOf(writer), "writeInt")) @compileError("Expects writer to have fn 'writeInt'.");
if (!std.meta.hasFn(@TypeOf(writer), "writeAll")) @compileError("Expects writer to have fn 'writeAll'.");
}
try writer.writeInt(u32, self.type, .little);
try writer.writeAll(&self.hash);
}

pub fn deserializeReader(r: anytype) !InventoryVector {
comptime {
if (!std.meta.hasFn(@TypeOf(r), "readInt")) @compileError("Expects r to have fn 'readInt'.");
if (!std.meta.hasFn(@TypeOf(r), "readBytesNoEof")) @compileError("Expects r to have fn 'readBytesNoEof'.");
}

const type_value = try r.readInt(u32, .little);
var hash: [32]u8 = undefined;
try r.readNoEof(&hash);

return InventoryVector{
.type = type_value,
.hash = hash,
};
}
};
pub const CmpctBlockMessage = @import("cmpctblock.zig").CmpctBlockMessage;

pub const MessageTypes = enum {
Expand All @@ -67,10 +38,12 @@ pub const MessageTypes = enum {
notfound,
sendheaders,
filterload,
getdata,
headers,
cmpctblock,
};


pub const Message = union(MessageTypes) {
version: VersionMessage,
verack: VerackMessage,
Expand All @@ -88,6 +61,7 @@ pub const Message = union(MessageTypes) {
notfound: NotFoundMessage,
sendheaders: SendHeadersMessage,
filterload: FilterLoadMessage,
getdata: GetdataMessage,
headers: HeadersMessage,
cmpctblock: CmpctBlockMessage,

Expand All @@ -109,6 +83,7 @@ pub const Message = union(MessageTypes) {
.notfound => |m| @TypeOf(m).name(),
.sendheaders => |m| @TypeOf(m).name(),
.filterload => |m| @TypeOf(m).name(),
.getdata => |m| @TypeOf(m).name(),
.headers => |m| @TypeOf(m).name(),
.cmpctblock => |m| @TypeOf(m).name(),
};
Expand All @@ -117,23 +92,14 @@ pub const Message = union(MessageTypes) {
pub fn deinit(self: *Message, allocator: std.mem.Allocator) void {
switch (self.*) {
.version => |*m| m.deinit(allocator),
.verack => {},
.mempool => {},
.getaddr => {},
.getblocks => |*m| m.deinit(allocator),
.ping => {},
.pong => {},
.merkleblock => |*m| m.deinit(allocator),
.sendcmpct => {},
.feefilter => {},
.filterclear => {},
.block => |*m| m.deinit(allocator),
.filteradd => |*m| m.deinit(allocator),
.notfound => {},
.getdata => |*m| m.deinit(allocator),
.cmpctblock => |*m| m.deinit(allocator),
.sendheaders => {},
.filterload => {},
.headers => |*m| m.deinit(allocator),
else => {}
}
}

Expand All @@ -155,6 +121,7 @@ pub const Message = union(MessageTypes) {
.notfound => |*m| m.checksum(),
.sendheaders => |*m| m.checksum(),
.filterload => |*m| m.checksum(),
.getdata => |*m| m.checksum(),
.headers => |*m| m.checksum(),
.cmpctblock => |*m| m.checksum(),
};
Expand All @@ -178,6 +145,7 @@ pub const Message = union(MessageTypes) {
.notfound => |m| m.hintSerializedLen(),
.sendheaders => |m| m.hintSerializedLen(),
.filterload => |*m| m.hintSerializedLen(),
.getdata => |m| m.hintSerializedLen(),
.headers => |*m| m.hintSerializedLen(),
.cmpctblock => |*m| m.hintSerializedLen(),
};
Expand Down
13 changes: 6 additions & 7 deletions src/network/protocol/messages/notfound.zig
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
const std = @import("std");
const protocol = @import("../lib.zig");
const Sha256 = std.crypto.hash.sha2.Sha256;
const InventoryVector = @import("lib.zig").InventoryVector;

/// NotFoundMessage represents the "notfound" message
///
/// https://developer.bitcoin.org/reference/p2p_networking.html#notfound
pub const NotFoundMessage = struct {
inventory: []const InventoryVector,
inventory: []const protocol.InventoryItem,

const Self = @This();

Expand Down Expand Up @@ -42,7 +41,7 @@ pub const NotFoundMessage = struct {
pub fn serializeToWriter(self: *const Self, writer: anytype) !void {
try writer.writeInt(u32, @intCast(self.inventory.len), .little);
for (self.inventory) |inv| {
try InventoryVector.serializeToWriter(inv, writer);
try inv.encodeToWriter(writer);
}
}

Expand All @@ -65,11 +64,11 @@ pub const NotFoundMessage = struct {
}

const count = try r.readInt(u32, .little);
const inventory = try allocator.alloc(InventoryVector, count);
const inventory = try allocator.alloc(protocol.InventoryItem, count);
errdefer allocator.free(inventory);

for (inventory) |*inv| {
inv.* = try InventoryVector.deserializeReader(r);
inv.* = try protocol.InventoryItem.decodeReader(r);
}

return Self{
Expand All @@ -93,7 +92,7 @@ pub const NotFoundMessage = struct {
allocator.free(self.inventory);
}

pub fn new(inventory: []const InventoryVector) Self {
pub fn new(inventory: []const protocol.InventoryItem) Self {
return .{
.inventory = inventory,
};
Expand All @@ -106,7 +105,7 @@ test "ok_fullflow_notfound_message" {
const allocator = std.testing.allocator;

{
const inventory = [_]InventoryVector{
const inventory = [_]protocol.InventoryItem{
.{ .type = 1, .hash = [_]u8{0xab} ** 32 },
.{ .type = 2, .hash = [_]u8{0xcd} ** 32 },
};
Expand Down
33 changes: 33 additions & 0 deletions src/network/protocol/types/InventoryItem.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
const std = @import("std");

type: u32,
hash: [32]u8,
tdelabro marked this conversation as resolved.
Show resolved Hide resolved

pub fn encodeToWriter(self: *const @This(), w: anytype) !void {
comptime {
if (!std.meta.hasFn(@TypeOf(w), "writeInt")) @compileError("Expects r to have fn 'writeInt'.");
if (!std.meta.hasFn(@TypeOf(w), "writeAll")) @compileError("Expects r to have fn 'writeAll'.");
}
try w.writeInt(u32, self.type, .little);
try w.writeAll(&self.hash);
}

pub fn decodeReader(r: anytype) !@This() {
comptime {
if (!std.meta.hasFn(@TypeOf(r), "readInt")) @compileError("Expects reader to have fn 'readInt'.");
if (!std.meta.hasFn(@TypeOf(r), "readNoEof")) @compileError("Expects reader to have fn 'readNoEof'.");
}

const item_type = try r.readInt(u32, .little);
var hash: [32]u8 = undefined;
try r.readNoEof(&hash);

return @This(){
.type = item_type,
.hash = hash,
};
}

pub fn eql(self: *const @This(), other: *const @This()) bool {
return self.type == other.type and std.mem.eql(u8, &self.hash, &other.hash);
}
Loading
Loading