diff --git a/package.json b/package.json index 4d1ff80..f99773f 100644 --- a/package.json +++ b/package.json @@ -68,7 +68,7 @@ }, "dependencies": { "@libp2p/crypto": "^1.0.0", - "@libp2p/interface-connection-encrypter": "^2.0.1", + "@libp2p/interface-connection-encrypter": "^3.0.0", "@libp2p/interface-keys": "^1.0.2", "@libp2p/interface-peer-id": "^1.0.2", "@libp2p/logger": "^2.0.0", @@ -88,8 +88,8 @@ }, "devDependencies": { "@libp2p/daemon-client": "^3.0.1", - "@libp2p/daemon-server": "^3.0.0", - "@libp2p/interface-connection-encrypter-compliance-tests": "^2.0.1", + "@libp2p/daemon-server": "^3.0.1", + "@libp2p/interface-connection-encrypter-compliance-tests": "^2.0.3", "@libp2p/interop": "^3.0.1", "@libp2p/mplex": "^5.0.0", "@libp2p/peer-id-factory": "^1.0.8", @@ -100,7 +100,7 @@ "execa": "^6.1.0", "go-libp2p": "^0.0.6", "iso-random-stream": "^2.0.2", - "libp2p": "0.39.2", + "libp2p": "0.39.4", "mkdirp": "^1.0.4", "p-defer": "^4.0.0", "protons": "^5.1.0", @@ -111,4 +111,4 @@ "./dist/src/alloc-unsafe.js": "./dist/src/alloc-unsafe-browser.js", "util": false } -} +} \ No newline at end of file diff --git a/src/@types/handshake-interface.ts b/src/@types/handshake-interface.ts index 485a11a..8fe4072 100644 --- a/src/@types/handshake-interface.ts +++ b/src/@types/handshake-interface.ts @@ -1,11 +1,12 @@ import type { PeerId } from '@libp2p/interface-peer-id' import type { bytes } from './basic.js' import type { NoiseSession } from './handshake.js' +import type { NoiseExtensions } from '../proto/payload.js' export interface IHandshake { session: NoiseSession remotePeer: PeerId - remoteEarlyData: bytes + remoteExtensions: NoiseExtensions encrypt: (plaintext: bytes, session: NoiseSession) => bytes - decrypt: (ciphertext: bytes, session: NoiseSession) => {plaintext: bytes, valid: boolean} + decrypt: (ciphertext: bytes, session: NoiseSession) => { plaintext: bytes, valid: boolean } } diff --git a/src/@types/libp2p.ts b/src/@types/libp2p.ts index 4730297..16596cf 100644 --- a/src/@types/libp2p.ts +++ b/src/@types/libp2p.ts @@ -1,11 +1,10 @@ import type { ConnectionEncrypter } from '@libp2p/interface-connection-encrypter' -import type { bytes, bytes32 } from './basic.js' +import type { NoiseExtensions } from '../proto/payload.js' +import type { bytes32 } from './basic.js' export interface KeyPair { publicKey: bytes32 privateKey: bytes32 } -export interface INoiseConnection extends ConnectionEncrypter { - remoteEarlyData?: () => bytes -} +export interface INoiseConnection extends ConnectionEncrypter {} diff --git a/src/handshake-xx.ts b/src/handshake-xx.ts index 9139b12..2f2275b 100644 --- a/src/handshake-xx.ts +++ b/src/handshake-xx.ts @@ -21,12 +21,13 @@ import { getPeerIdFromPayload, verifySignedPayload } from './utils.js' +import type { NoiseExtensions } from './proto/payload.js' export class XXHandshake implements IHandshake { public isInitiator: boolean public session: NoiseSession public remotePeer!: PeerId - public remoteEarlyData: bytes + public remoteExtensions: NoiseExtensions = { webtransportCerthashes: [] } protected payload: bytes protected connection: ProtobufStream @@ -55,7 +56,6 @@ export class XXHandshake implements IHandshake { } this.xx = handshake ?? new XX(crypto) this.session = this.xx.initSession(this.isInitiator, this.prologue, this.staticKeypair) - this.remoteEarlyData = new Uint8Array(0) } // stage 0 @@ -97,7 +97,7 @@ export class XXHandshake implements IHandshake { const decodedPayload = decodePayload(plaintext) this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload) await verifySignedPayload(this.session.hs.rs, decodedPayload, this.remotePeer) - this.setRemoteEarlyData(decodedPayload.data) + this.setRemoteNoiseExtension(decodedPayload.extensions) } catch (e) { const err = e as Error throw new UnexpectedPeerError(`Error occurred while verifying signed payload: ${err.message}`) @@ -132,7 +132,7 @@ export class XXHandshake implements IHandshake { const decodedPayload = decodePayload(plaintext) this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload) await verifySignedPayload(this.session.hs.rs, decodedPayload, this.remotePeer) - this.setRemoteEarlyData(decodedPayload.data) + this.setRemoteNoiseExtension(decodedPayload.extensions) } catch (e) { const err = e as Error throw new UnexpectedPeerError(`Error occurred while verifying signed payload: ${err.message}`) @@ -147,7 +147,7 @@ export class XXHandshake implements IHandshake { return this.xx.encryptWithAd(cs, new Uint8Array(0), plaintext) } - public decrypt (ciphertext: Uint8Array, session: NoiseSession): {plaintext: bytes, valid: boolean} { + public decrypt (ciphertext: Uint8Array, session: NoiseSession): { plaintext: bytes, valid: boolean } { const cs = this.getCS(session, false) return this.xx.decryptWithAd(cs, new Uint8Array(0), ciphertext) @@ -169,9 +169,9 @@ export class XXHandshake implements IHandshake { } } - protected setRemoteEarlyData (data: Uint8Array|null|undefined): void { - if (data) { - this.remoteEarlyData = data + protected setRemoteNoiseExtension (e: NoiseExtensions | null | undefined): void { + if (e) { + this.remoteExtensions = e } } } diff --git a/src/noise.ts b/src/noise.ts index b2d08c2..e22bb39 100644 --- a/src/noise.ts +++ b/src/noise.ts @@ -15,6 +15,7 @@ import { decryptStream, encryptStream } from './crypto/streaming.js' import { uint16BEDecode, uint16BEEncode } from './encoder.js' import { XXHandshake } from './handshake-xx.js' import { getPayload } from './utils.js' +import type { NoiseExtensions } from './proto/payload.js' interface HandshakeParams { connection: ProtobufStream @@ -29,15 +30,15 @@ export class Noise implements INoiseConnection { private readonly prologue: Uint8Array private readonly staticKeys: KeyPair - private readonly earlyData?: bytes + private readonly extensions?: NoiseExtensions /** * @param {bytes} staticNoiseKey - x25519 private key, reuse for faster handshakes - * @param {bytes} earlyData + * @param {NoiseExtensions} extensions */ - constructor (staticNoiseKey?: bytes, earlyData?: bytes, crypto: ICryptoInterface = stablelib, prologueBytes?: Uint8Array) { - this.earlyData = earlyData ?? new Uint8Array(0) + constructor (staticNoiseKey?: bytes, extensions?: NoiseExtensions, crypto: ICryptoInterface = stablelib, prologueBytes?: Uint8Array) { this.crypto = crypto + this.extensions = extensions if (staticNoiseKey) { // accepts x25519 private key of length 32 @@ -56,7 +57,7 @@ export class Noise implements INoiseConnection { * @param {PeerId} remotePeer - PeerId of the remote peer. Used to validate the integrity of the remote peer. * @returns {Promise} */ - public async secureOutbound (localPeer: PeerId, connection: Duplex, remotePeer?: PeerId): Promise { + public async secureOutbound (localPeer: PeerId, connection: Duplex, remotePeer?: PeerId): Promise> { const wrappedConnection = pbStream( connection, { @@ -75,7 +76,7 @@ export class Noise implements INoiseConnection { return { conn, - remoteEarlyData: handshake.remoteEarlyData, + remoteExtensions: handshake.remoteExtensions, remotePeer: handshake.remotePeer } } @@ -88,7 +89,7 @@ export class Noise implements INoiseConnection { * @param {PeerId} remotePeer - optional PeerId of the initiating peer, if known. This may only exist during transport upgrades. * @returns {Promise} */ - public async secureInbound (localPeer: PeerId, connection: Duplex, remotePeer?: PeerId): Promise { + public async secureInbound (localPeer: PeerId, connection: Duplex, remotePeer?: PeerId): Promise> { const wrappedConnection = pbStream( connection, { @@ -107,8 +108,8 @@ export class Noise implements INoiseConnection { return { conn, - remoteEarlyData: handshake.remoteEarlyData, - remotePeer: handshake.remotePeer + remotePeer: handshake.remotePeer, + remoteExtensions: handshake.remoteExtensions } } @@ -119,7 +120,7 @@ export class Noise implements INoiseConnection { * @param {HandshakeParams} params */ private async performHandshake (params: HandshakeParams): Promise { - const payload = await getPayload(params.localPeer, this.staticKeys.publicKey, this.earlyData) + const payload = await getPayload(params.localPeer, this.staticKeys.publicKey, this.extensions) // run XX handshake return await this.performXXHandshake(params, payload) diff --git a/src/proto/payload.proto b/src/proto/payload.proto index 05a78c6..cdb2383 100644 --- a/src/proto/payload.proto +++ b/src/proto/payload.proto @@ -1,8 +1,11 @@ syntax = "proto3"; -package pb; -message NoiseHandshakePayload { - bytes identity_key = 1; - bytes identity_sig = 2; - bytes data = 3; +message NoiseExtensions { + repeated bytes webtransport_certhashes = 1; } + +message NoiseHandshakePayload { + bytes identity_key = 1; + bytes identity_sig = 2; + optional NoiseExtensions extensions = 4; +} \ No newline at end of file diff --git a/src/proto/payload.ts b/src/proto/payload.ts index 0d1af33..3a28281 100644 --- a/src/proto/payload.ts +++ b/src/proto/payload.ts @@ -5,100 +5,153 @@ import { encodeMessage, decodeMessage, message } from 'protons-runtime' import type { Uint8ArrayList } from 'uint8arraylist' import type { Codec } from 'protons-runtime' -export namespace pb { - export interface NoiseHandshakePayload { - identityKey: Uint8Array - identitySig: Uint8Array - data: Uint8Array - } +export interface NoiseExtensions { + webtransportCerthashes: Uint8Array[] +} - export namespace NoiseHandshakePayload { - let _codec: Codec +export namespace NoiseExtensions { + let _codec: Codec - export const codec = (): Codec => { - if (_codec == null) { - _codec = message((obj, writer, opts = {}) => { - if (opts.lengthDelimited !== false) { - writer.fork() - } + export const codec = (): Codec => { + if (_codec == null) { + _codec = message((obj, writer, opts = {}) => { + if (opts.lengthDelimited !== false) { + writer.fork() + } - if (obj.identityKey != null) { + if (obj.webtransportCerthashes != null) { + for (const value of obj.webtransportCerthashes) { writer.uint32(10) - writer.bytes(obj.identityKey) - } else { - throw new Error('Protocol error: required field "identityKey" was not found in object') + writer.bytes(value) } - - if (obj.identitySig != null) { - writer.uint32(18) - writer.bytes(obj.identitySig) - } else { - throw new Error('Protocol error: required field "identitySig" was not found in object') + } else { + throw new Error('Protocol error: required field "webtransportCerthashes" was not found in object') + } + + if (opts.lengthDelimited !== false) { + writer.ldelim() + } + }, (reader, length) => { + const obj: any = { + webtransportCerthashes: [] + } + + const end = length == null ? reader.len : reader.pos + length + + while (reader.pos < end) { + const tag = reader.uint32() + + switch (tag >>> 3) { + case 1: + obj.webtransportCerthashes.push(reader.bytes()) + break + default: + reader.skipType(tag & 7) + break } + } - if (obj.data != null) { - writer.uint32(26) - writer.bytes(obj.data) - } else { - throw new Error('Protocol error: required field "data" was not found in object') - } + return obj + }) + } - if (opts.lengthDelimited !== false) { - writer.ldelim() - } - }, (reader, length) => { - const obj: any = { - identityKey: new Uint8Array(0), - identitySig: new Uint8Array(0), - data: new Uint8Array(0) - } + return _codec + } - const end = length == null ? reader.len : reader.pos + length - - while (reader.pos < end) { - const tag = reader.uint32() - - switch (tag >>> 3) { - case 1: - obj.identityKey = reader.bytes() - break - case 2: - obj.identitySig = reader.bytes() - break - case 3: - obj.data = reader.bytes() - break - default: - reader.skipType(tag & 7) - break - } - } + export const encode = (obj: NoiseExtensions): Uint8Array => { + return encodeMessage(obj, NoiseExtensions.codec()) + } - if (obj.identityKey == null) { - throw new Error('Protocol error: value for required field "identityKey" was not found in protobuf') - } + export const decode = (buf: Uint8Array | Uint8ArrayList): NoiseExtensions => { + return decodeMessage(buf, NoiseExtensions.codec()) + } +} - if (obj.identitySig == null) { - throw new Error('Protocol error: value for required field "identitySig" was not found in protobuf') - } +export interface NoiseHandshakePayload { + identityKey: Uint8Array + identitySig: Uint8Array + extensions?: NoiseExtensions +} - if (obj.data == null) { - throw new Error('Protocol error: value for required field "data" was not found in protobuf') +export namespace NoiseHandshakePayload { + let _codec: Codec + + export const codec = (): Codec => { + if (_codec == null) { + _codec = message((obj, writer, opts = {}) => { + if (opts.lengthDelimited !== false) { + writer.fork() + } + + if (obj.identityKey != null) { + writer.uint32(10) + writer.bytes(obj.identityKey) + } else { + throw new Error('Protocol error: required field "identityKey" was not found in object') + } + + if (obj.identitySig != null) { + writer.uint32(18) + writer.bytes(obj.identitySig) + } else { + throw new Error('Protocol error: required field "identitySig" was not found in object') + } + + if (obj.extensions != null) { + writer.uint32(34) + NoiseExtensions.codec().encode(obj.extensions, writer) + } + + if (opts.lengthDelimited !== false) { + writer.ldelim() + } + }, (reader, length) => { + const obj: any = { + identityKey: new Uint8Array(0), + identitySig: new Uint8Array(0) + } + + const end = length == null ? reader.len : reader.pos + length + + while (reader.pos < end) { + const tag = reader.uint32() + + switch (tag >>> 3) { + case 1: + obj.identityKey = reader.bytes() + break + case 2: + obj.identitySig = reader.bytes() + break + case 4: + obj.extensions = NoiseExtensions.codec().decode(reader, reader.uint32()) + break + default: + reader.skipType(tag & 7) + break } + } - return obj - }) - } + if (obj.identityKey == null) { + throw new Error('Protocol error: value for required field "identityKey" was not found in protobuf') + } - return _codec - } + if (obj.identitySig == null) { + throw new Error('Protocol error: value for required field "identitySig" was not found in protobuf') + } - export const encode = (obj: NoiseHandshakePayload): Uint8Array => { - return encodeMessage(obj, NoiseHandshakePayload.codec()) + return obj + }) } - export const decode = (buf: Uint8Array | Uint8ArrayList): NoiseHandshakePayload => { - return decodeMessage(buf, NoiseHandshakePayload.codec()) - } + return _codec + } + + export const encode = (obj: NoiseHandshakePayload): Uint8Array => { + return encodeMessage(obj, NoiseHandshakePayload.codec()) + } + + export const decode = (buf: Uint8Array | Uint8ArrayList): NoiseHandshakePayload => { + return decodeMessage(buf, NoiseHandshakePayload.codec()) } } diff --git a/src/utils.ts b/src/utils.ts index 77d5388..6eadfce 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -4,17 +4,14 @@ import { peerIdFromKeys } from '@libp2p/peer-id' import { concat as uint8ArrayConcat } from 'uint8arrays/concat' import { fromString as uint8ArrayFromString } from 'uint8arrays/from-string' import type { bytes } from './@types/basic.js' -import { pb } from './proto/payload.js' - -const NoiseHandshakePayloadProto = pb.NoiseHandshakePayload +import { NoiseExtensions, NoiseHandshakePayload } from './proto/payload.js' export async function getPayload ( localPeer: PeerId, staticPublicKey: bytes, - earlyData?: bytes + extensions?: NoiseExtensions ): Promise { const signedPayload = await signPayload(localPeer, getHandshakePayload(staticPublicKey)) - const earlyDataPayload = earlyData ?? new Uint8Array(0) if (localPeer.publicKey == null) { throw new Error('PublicKey was missing from local PeerId') @@ -23,19 +20,19 @@ export async function getPayload ( return createHandshakePayload( localPeer.publicKey, signedPayload, - earlyDataPayload + extensions ) } export function createHandshakePayload ( libp2pPublicKey: Uint8Array, signedPayload: Uint8Array, - earlyData?: Uint8Array + extensions?: NoiseExtensions ): bytes { - return NoiseHandshakePayloadProto.encode({ + return NoiseHandshakePayload.encode({ identityKey: libp2pPublicKey, identitySig: signedPayload, - data: earlyData ?? new Uint8Array(0) + extensions: extensions ?? { webtransportCerthashes: [] } }).subarray() } @@ -49,12 +46,12 @@ export async function signPayload (peerId: PeerId, payload: bytes): Promise { +export async function getPeerIdFromPayload (payload: NoiseHandshakePayload): Promise { return await peerIdFromKeys(payload.identityKey) } -export function decodePayload (payload: bytes|Uint8Array): pb.NoiseHandshakePayload { - return NoiseHandshakePayloadProto.decode(payload) +export function decodePayload (payload: bytes | Uint8Array): NoiseHandshakePayload { + return NoiseHandshakePayload.decode(payload) } export function getHandshakePayload (publicKey: bytes): bytes { @@ -72,7 +69,7 @@ export function getHandshakePayload (publicKey: bytes): bytes { */ export async function verifySignedPayload ( noiseStaticKey: bytes, - payload: pb.NoiseHandshakePayload, + payload: NoiseHandshakePayload, remotePeer: PeerId ): Promise { // Unmarshaling from PublicKey protobuf diff --git a/test/interop.ts b/test/interop.ts index 65b23a5..ff3a4d0 100644 --- a/test/interop.ts +++ b/test/interop.ts @@ -23,8 +23,8 @@ async function createGoPeer (options: SpawnOptions): Promise { const log = logger(`go-libp2p:${controlPort}`) const opts = [ - `-listen=${apiAddr.toString()}`, - '-hostAddrs=/ip4/0.0.0.0/tcp/0' + `-listen=${apiAddr.toString()}`, + '-hostAddrs=/ip4/0.0.0.0/tcp/0' ] if (options.noise === true) { @@ -78,6 +78,7 @@ async function createJsPeer (options: SpawnOptions): Promise { }, transports: [new TCP()], streamMuxers: [new Mplex()], + // @ts-expect-error libp2p options is still referencing the old connection encrypter interface https://github.com/libp2p/js-libp2p/pull/1402 connectionEncryption: [new Noise()] } diff --git a/test/noise.spec.ts b/test/noise.spec.ts index b446650..77cfd43 100644 --- a/test/noise.spec.ts +++ b/test/noise.spec.ts @@ -162,13 +162,14 @@ describe('Noise', () => { } }) - it('should accept and return early data from remote peer', async () => { + it('should accept and return Noise extension from remote peer', async () => { try { - const localPeerEarlyData = Buffer.from('early data') + const certhashInit = Buffer.from('certhash data from init') const staticKeysInitiator = stablelib.generateX25519KeyPair() - const noiseInit = new Noise(staticKeysInitiator.privateKey, localPeerEarlyData) + const noiseInit = new Noise(staticKeysInitiator.privateKey, { webtransportCerthashes: [certhashInit] }) const staticKeysResponder = stablelib.generateX25519KeyPair() - const noiseResp = new Noise(staticKeysResponder.privateKey) + const certhashResp = Buffer.from('certhash data from respon') + const noiseResp = new Noise(staticKeysResponder.privateKey, { webtransportCerthashes: [certhashResp] }) const [inboundConnection, outboundConnection] = duplexPair() const [outbound, inbound] = await Promise.all([ @@ -176,8 +177,8 @@ describe('Noise', () => { noiseResp.secureInbound(remotePeer, inboundConnection) ]) - assert(uint8ArrayEquals(inbound.remoteEarlyData, localPeerEarlyData)) - assert(uint8ArrayEquals(outbound.remoteEarlyData, Buffer.alloc(0))) + assert(uint8ArrayEquals(inbound.remoteExtensions?.webtransportCerthashes[0] ?? new Uint8Array(), certhashInit)) + assert(uint8ArrayEquals(outbound.remoteExtensions?.webtransportCerthashes[0] ?? new Uint8Array(), certhashResp)) } catch (e) { const err = e as Error assert(false, err.message)