Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/addr #135

Merged
merged 22 commits into from
Oct 7, 2024
15 changes: 12 additions & 3 deletions src/network/peer.zig
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ const protocol = @import("./protocol/lib.zig");
const wire = @import("./wire/lib.zig");
const Config = @import("../config/config.zig").Config;
const MessageUtils = @import("./message/utils.zig");
const NetworkAddress = @import("./protocol/types/NetworkAddress.zig").NetworkAddress;

pub const Boundness = enum {
inbound,
Expand Down Expand Up @@ -80,7 +81,7 @@ pub const Peer = struct {
switch (received_message) {
.version => |vm| {
self.protocol_version = @min(self.config.protocol_version, vm.version);
self.services = vm.trans_services;
self.services = vm.addr_from.services;
},

.verack => return,
Expand All @@ -95,8 +96,16 @@ pub const Peer = struct {

const message = protocol.messages.VersionMessage.new(
self.config.protocol_version,
.{ .ip = std.mem.zeroes([16]u8), .port = 0, .services = self.config.services },
.{ .ip = address.in6.sa.addr, .port = address.in6.getPort(), .services = 0 },
NetworkAddress{
.services = self.config.services,
.ip = std.mem.zeroes([16]u8),
.port = 0,
},
NetworkAddress{
.services = 0,
.ip = address.in6.sa.addr,
.port = address.in6.getPort(),
},
std.crypto.random.int(u64),
self.config.bestBlock(),
);
Expand Down
143 changes: 143 additions & 0 deletions src/network/protocol/messages/addr.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
const std = @import("std");
const protocol = @import("../lib.zig");
const genericChecksum = @import("lib.zig").genericChecksum;
const NetworkAddress = @import("../types/NetworkAddress.zig").NetworkAddress;
const genericDeserializeSlice = @import("lib.zig").genericDeserializeSlice;
const genericSerialize = @import("lib.zig").genericSerialize;

const Endian = std.builtin.Endian;
const Sha256 = std.crypto.hash.sha2.Sha256;

const CompactSizeUint = @import("bitcoin-primitives").types.CompatSizeUint;

pub const NetworkIPAddr = struct {
oxlime marked this conversation as resolved.
Show resolved Hide resolved
time: u32, // Unix epoch time
address: NetworkAddress,

// NetworkIPAddr eql
pub fn eql(self: *const NetworkIPAddr, other: *const NetworkIPAddr) bool {
return self.time == other.time and self.address.eql(&other.address);
}

pub fn serializeToWriter(self: *const NetworkIPAddr, writer: anytype) !void {
try writer.writeInt(u32, self.time, .little);
try self.address.serializeToWriter(writer);
}

pub fn deserializeReader(reader: anytype) !NetworkIPAddr {
return NetworkIPAddr{
.time = try reader.readInt(u32, .little),
.address = try NetworkAddress.deserializeReader(reader),
};
}
};

/// AddrMessage represents the "addr" message
///
/// https://developer.bitcoin.org/reference/p2p_networking.html#addr
pub const AddrMessage = struct {
ip_addresses: []NetworkIPAddr,

const Self = @This();

pub inline fn name() *const [12]u8 {
return protocol.CommandNames.ADDR ++ [_]u8{0} ** 8;
}

/// Returns the message checksum
///
/// Computed as `Sha256(Sha256(self.serialize()))[0..4]`
pub fn checksum(self: *const AddrMessage) [4]u8 {
return genericChecksum(self);
}

/// Free the `user_agent` if there is one
pub fn deinit(self: AddrMessage, allocator: std.mem.Allocator) void {
allocator.free(self.ip_addresses);
}

/// Serialize the message as bytes and write them to the Writer.
///
/// `w` should be a valid `Writer`.
pub fn serializeToWriter(self: *const AddrMessage, w: anytype) !void {
try CompactSizeUint.new(self.ip_addresses.len).encodeToWriter(w);
for (self.ip_addresses) |*addr| {
try addr.serializeToWriter(w);
}
}

/// Serialize a message as bytes and return them.
pub fn serialize(self: *const AddrMessage, allocator: std.mem.Allocator) ![]u8 {
oxlime marked this conversation as resolved.
Show resolved Hide resolved
return genericSerialize(self, allocator);
}

/// Deserialize a Reader bytes as a `AddrMessage`
pub fn deserializeReader(allocator: std.mem.Allocator, r: anytype) !AddrMessage {
const ip_address_count = try CompactSizeUint.decodeReader(r);

// Allocate space for IP addresses
const ip_addresses = try allocator.alloc(NetworkIPAddr, ip_address_count.value());
errdefer allocator.free(ip_addresses);

for (ip_addresses) |*ip_address| {
ip_address.* = try NetworkIPAddr.deserializeReader(r);
}

return AddrMessage{
.ip_addresses = ip_addresses,
};
}

/// Deserialize bytes into a `AddrMessage`
pub fn deserializeSlice(allocator: std.mem.Allocator, bytes: []const u8) !AddrMessage {
oxlime marked this conversation as resolved.
Show resolved Hide resolved
return genericDeserializeSlice(Self, allocator, bytes);
}

pub fn hintSerializedLen(self: AddrMessage) usize {
// 4 + 8 + 16 + 2
const fixed_length_per_ip = 30;
const count = CompactSizeUint.new(self.ip_addresses.len).hint_encoded_len();
return count + self.ip_addresses.len * fixed_length_per_ip;
}

pub fn eql(self: *const AddrMessage, other: *const AddrMessage) bool {
if (self.ip_addresses.len != other.ip_addresses.len) return false;

const count = @as(usize, self.ip_addresses.len);
for (0..count) |i| {
if (!self.ip_addresses[i].eql(&other.ip_addresses[i])) return false;
}

return true;
}
};

// TESTS
test "ok_full_flow_AddrMessage" {
const test_allocator = std.testing.allocator;
{
const ip_addresses = try test_allocator.alloc(NetworkIPAddr, 1);
defer test_allocator.free(ip_addresses);

ip_addresses[0] = NetworkIPAddr{ .time = 1414012889, .address = NetworkAddress{
.services = 1,
.ip = [16]u8{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 192, 0, 2, 51 },
.port = 8080,
} };
const am = AddrMessage{
.ip_addresses = ip_addresses[0..],
};

// Serialize
const payload = try am.serialize(test_allocator);
defer test_allocator.free(payload);

// Deserialize
const deserialized_am = try AddrMessage.deserializeSlice(test_allocator, payload);

// Test equality
try std.testing.expect(am.eql(&deserialized_am));

defer test_allocator.free(deserialized_am.ip_addresses);
}
}
4 changes: 2 additions & 2 deletions src/network/protocol/messages/filterload.zig
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ pub const FilterLoadMessage = struct {
const filter = try allocator.alloc(u8, filter_len);
errdefer allocator.free(filter);
try r.readNoEof(filter);

const hash_func = try r.readInt(u32, .little);
const tweak = try r.readInt(u32, .little);
const flags = try r.readInt(u8, .little);
Expand All @@ -78,7 +78,7 @@ pub const FilterLoadMessage = struct {
pub fn hintSerializedLen(self: *const Self) usize {
const fixed_length = 4 + 4 + 1; // hash_func (4 bytes) + tweak (4 bytes) + flags (1 byte)
const compact_filter_len = CompactSizeUint.new(self.filter.len).hint_encoded_len();
return compact_filter_len + self.filter.len + fixed_length;
return compact_filter_len + self.filter.len + fixed_length;
}

pub fn deinit(self: *Self, allocator: std.mem.Allocator) void {
Expand Down
11 changes: 10 additions & 1 deletion src/network/protocol/messages/lib.zig
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pub const BlockMessage = @import("block.zig").BlockMessage;
pub const GetblocksMessage = @import("getblocks.zig").GetblocksMessage;
pub const PingMessage = @import("ping.zig").PingMessage;
pub const PongMessage = @import("pong.zig").PongMessage;
pub const AddrMessage = @import("addr.zig").AddrMessage;
pub const MerkleBlockMessage = @import("merkleblock.zig").MerkleBlockMessage;
pub const FeeFilterMessage = @import("feefilter.zig").FeeFilterMessage;
pub const SendCmpctMessage = @import("sendcmpct.zig").SendCmpctMessage;
Expand All @@ -29,6 +30,7 @@ pub const MessageTypes = enum {
getblocks,
ping,
pong,
addr,
merkleblock,
sendcmpct,
feefilter,
Expand All @@ -52,6 +54,7 @@ pub const Message = union(MessageTypes) {
getblocks: GetblocksMessage,
ping: PingMessage,
pong: PongMessage,
addr: AddrMessage,
merkleblock: MerkleBlockMessage,
sendcmpct: SendCmpctMessage,
feefilter: FeeFilterMessage,
Expand All @@ -74,6 +77,7 @@ pub const Message = union(MessageTypes) {
.getblocks => |m| @TypeOf(m).name(),
.ping => |m| @TypeOf(m).name(),
.pong => |m| @TypeOf(m).name(),
.addr => |m| @TypeOf(m).name(),
.merkleblock => |m| @TypeOf(m).name(),
.sendcmpct => |m| @TypeOf(m).name(),
.feefilter => |m| @TypeOf(m).name(),
Expand All @@ -93,6 +97,9 @@ pub const Message = union(MessageTypes) {
switch (self.*) {
.version => |*m| m.deinit(allocator),
.getblocks => |*m| m.deinit(allocator),
.ping => {},
.pong => {},
.addr => |m| m.deinit(allocator),
.merkleblock => |*m| m.deinit(allocator),
.block => |*m| m.deinit(allocator),
.filteradd => |*m| m.deinit(allocator),
Expand Down Expand Up @@ -121,6 +128,7 @@ pub const Message = union(MessageTypes) {
.notfound => |*m| m.checksum(),
.sendheaders => |*m| m.checksum(),
.filterload => |*m| m.checksum(),
.addr => |*m| m.checksum(),
.getdata => |*m| m.checksum(),
.headers => |*m| m.checksum(),
.cmpctblock => |*m| m.checksum(),
Expand All @@ -145,6 +153,7 @@ pub const Message = union(MessageTypes) {
.notfound => |m| m.hintSerializedLen(),
.sendheaders => |m| m.hintSerializedLen(),
.filterload => |*m| m.hintSerializedLen(),
.addr => |*m| m.hintSerializedLen(),
.getdata => |m| m.hintSerializedLen(),
.headers => |*m| m.hintSerializedLen(),
.cmpctblock => |*m| m.hintSerializedLen(),
Expand Down Expand Up @@ -194,4 +203,4 @@ pub fn genericDeserializeSlice(comptime T: type, allocator: std.mem.Allocator, b
const reader = fbs.reader();

return try T.deserializeReader(allocator, reader);
}
}
Loading
Loading