diff --git a/Networking/Sources/MsQuicSwift/QuicEventHandler.swift b/Networking/Sources/MsQuicSwift/QuicEventHandler.swift index 6927c7ec..66a5927d 100644 --- a/Networking/Sources/MsQuicSwift/QuicEventHandler.swift +++ b/Networking/Sources/MsQuicSwift/QuicEventHandler.swift @@ -65,73 +65,3 @@ extension QuicEventHandler { public func closed(_: QuicStream, status _: QuicStatus, code _: QuicErrorCode) {} } - -public final class MockQuicEventHandler: QuicEventHandler { - public enum EventType { - case newConnection(listener: QuicListener, connection: QuicConnection, info: ConnectionInfo) - case shouldOpen(connection: QuicConnection, certificate: Data?) - case connected(connection: QuicConnection) - case shutdownInitiated(connection: QuicConnection, reason: ConnectionCloseReason) - case shutdownComplete(connection: QuicConnection) - case streamStarted(connection: QuicConnection, stream: QuicStream) - case dataReceived(stream: QuicStream, data: Data?) - case closed(stream: QuicStream, status: QuicStatus, code: QuicErrorCode) - } - - public let events: ThreadSafeContainer<[EventType]> = .init([]) - - public init() {} - - public func newConnection( - _ listener: QuicListener, connection: QuicConnection, info: ConnectionInfo - ) -> QuicStatus { - events.write { events in - events.append(.newConnection(listener: listener, connection: connection, info: info)) - } - - return .code(.success) - } - - public func shouldOpen(_ connection: QuicConnection, certificate: Data?) -> QuicStatus { - events.write { events in - events.append(.shouldOpen(connection: connection, certificate: certificate)) - } - return .code(.success) - } - - public func connected(_ connection: QuicConnection) { - events.write { events in - events.append(.connected(connection: connection)) - } - } - - public func shutdownInitiated(_ connection: QuicConnection, reason: ConnectionCloseReason) { - events.write { events in - events.append(.shutdownInitiated(connection: connection, reason: reason)) - } - } - - public func shutdownComplete(_ connection: QuicConnection) { - events.write { events in - events.append(.shutdownComplete(connection: connection)) - } - } - - public func streamStarted(_ connect: QuicConnection, stream: QuicStream) { - events.write { events in - events.append(.streamStarted(connection: connect, stream: stream)) - } - } - - public func dataReceived(_ stream: QuicStream, data: Data?) { - events.write { events in - events.append(.dataReceived(stream: stream, data: data)) - } - } - - public func closed(_ stream: QuicStream, status: QuicStatus, code: QuicErrorCode) { - events.write { events in - events.append(.closed(stream: stream, status: status, code: code)) - } - } -} diff --git a/Networking/Sources/Networking/Connection.swift b/Networking/Sources/Networking/Connection.swift index 998b2116..b02380ac 100644 --- a/Networking/Sources/Networking/Connection.swift +++ b/Networking/Sources/Networking/Connection.swift @@ -119,16 +119,10 @@ public final class Connection: Sendable, ConnectionInfoP public var isClosed: Bool { state.read { - switch $0 { - case .connecting: - false - case .connected: - false - case .closed: - true - case .reconnect: - false + if case .closed = $0 { + return true } + return false } } diff --git a/Networking/Sources/Networking/Peer.swift b/Networking/Sources/Networking/Peer.swift index eef08566..5ca00014 100644 --- a/Networking/Sources/Networking/Peer.swift +++ b/Networking/Sources/Networking/Peer.swift @@ -20,12 +20,7 @@ struct BackoffState { var attempt: Int var delay: TimeInterval - init() { - attempt = 0 - delay = 1 - } - - init(attempt: Int = 0, delay: TimeInterval = 1) { + init(_ attempt: Int = 0, _ delay: TimeInterval = 1) { self.attempt = attempt self.delay = delay } @@ -479,12 +474,10 @@ private struct PeerEventHandler: QuicEventHandler { connections.byId[connection.id] } guard let conn else { - logger.warning( - "Connected but connection is gone?", metadata: ["connectionId": "\(connection.id)"] - ) + logger.warning("Connected but connection is gone?", metadata: ["connectionId": "\(connection.id)"]) return } - // Check if the connection is already reconnected + impl.reconnectStates.write { reconnectStates in reconnectStates[conn.remoteAddress] = nil } @@ -511,17 +504,17 @@ private struct PeerEventHandler: QuicEventHandler { connections.byId[connection.id] } let needReconnect = impl.connections.write { connections in + var needReconnect = false if let conn = connections.byId[connection.id] { - let needReconnect = conn.needReconnect + needReconnect = conn.needReconnect if let publicKey = conn.publicKey { connections.byPublicKey.removeValue(forKey: publicKey) } connections.byId.removeValue(forKey: connection.id) connections.byAddr.removeValue(forKey: conn.remoteAddress) conn.closed() - return needReconnect } - return false + return needReconnect } if needReconnect, let address = conn?.remoteAddress, let role = conn?.role { do { @@ -533,10 +526,7 @@ private struct PeerEventHandler: QuicEventHandler { } func shutdownInitiated(_ connection: QuicConnection, reason: ConnectionCloseReason) { - logger.debug( - "Shutdown initiated", - metadata: ["connectionId": "\(connection.id)", "reason": "\(reason)"] - ) + logger.debug("Shutdown initiated", metadata: ["connectionId": "\(connection.id)", "reason": "\(reason)"]) if shouldReconnect(basedOn: reason) { impl.connections.write { connections in if let conn = connections.byId[connection.id] { @@ -609,15 +599,8 @@ private struct PeerEventHandler: QuicEventHandler { if let connection { connection.streamClosed(stream: stream, abort: !status.isSucceeded) if shouldReopenStream(connection: connection, stream: stream, status: status) { - do { - if let kind = stream.kind { - // impl.reopenUpStream(connection: connection, kind: kind); - do { - try connection.createPreistentStream(kind: kind) - } catch { - logger.error("Attempt to recreate the persistent stream failed: \(error)") - } - } + if let kind = stream.kind { + impl.reopenUpStream(connection: connection, kind: kind) } } } else { @@ -642,7 +625,7 @@ private struct PeerEventHandler: QuicEventHandler { case .connectionIdle, .badCert: return false default: - return !status.isSucceeded + return status.isSucceeded } } } diff --git a/Networking/Tests/MsQuicSwiftTests/QuicListenerTests.swift b/Networking/Tests/MsQuicSwiftTests/QuicListenerTests.swift index 13e28a33..34029236 100644 --- a/Networking/Tests/MsQuicSwiftTests/QuicListenerTests.swift +++ b/Networking/Tests/MsQuicSwiftTests/QuicListenerTests.swift @@ -40,6 +40,152 @@ struct QuicListenerTests { registration = try QuicRegistration() } + final class MockQuicEventHandler: QuicEventHandler { + enum EventType { + case newConnection(listener: QuicListener, connection: QuicConnection, info: ConnectionInfo) + case shouldOpen(connection: QuicConnection, certificate: Data?) + case connected(connection: QuicConnection) + case shutdownInitiated(connection: QuicConnection, reason: ConnectionCloseReason) + case shutdownComplete(connection: QuicConnection) + case streamStarted(connection: QuicConnection, stream: QuicStream) + case dataReceived(stream: QuicStream, data: Data?) + case closed(stream: QuicStream, status: QuicStatus, code: QuicErrorCode) + } + + let events: ThreadSafeContainer<[EventType]> = .init([]) + + init() {} + + func newConnection( + _ listener: QuicListener, connection: QuicConnection, info: ConnectionInfo + ) -> QuicStatus { + events.write { events in + events.append(.newConnection(listener: listener, connection: connection, info: info)) + } + return .code(.success) + } + + func shouldOpen(_ connection: QuicConnection, certificate: Data?) -> QuicStatus { + events.write { events in + events.append(.shouldOpen(connection: connection, certificate: certificate)) + } + return .code(.success) + } + + func connected(_ connection: QuicConnection) { + events.write { events in + events.append(.connected(connection: connection)) + } + } + + func streamStarted(_ connect: QuicConnection, stream: QuicStream) { + events.write { events in + events.append(.streamStarted(connection: connect, stream: stream)) + } + } + + func dataReceived(_ stream: QuicStream, data: Data?) { + events.write { events in + events.append(.dataReceived(stream: stream, data: data)) + } + } + } + + final class EmptyQuicEventHandler: QuicEventHandler {} + + @Test + func emptyQuicEventHandler() async throws { + let serverHandler = MockQuicEventHandler() + let clientHandler = EmptyQuicEventHandler() + + // create listener + + let quicSettings = QuicSettings.defaultSettings + let serverConfiguration = try QuicConfiguration( + registration: registration, + pkcs12: pkcs12Data, + alpns: [Data("testalpn".utf8)], + client: false, + settings: quicSettings + ) + + let listener = try QuicListener( + handler: serverHandler, + registration: registration, + configuration: serverConfiguration, + listenAddress: NetAddr(ipAddress: "127.0.0.1", port: 0)!, + alpns: [Data("testalpn".utf8)] + ) + + let listenAddress = try listener.listenAddress() + let (ipAddress, port) = listenAddress.getAddressAndPort() + #expect(ipAddress == "127.0.0.1") + #expect(port != 0) + + // create connection to listener + + let clientConfiguration = try QuicConfiguration( + registration: registration, + pkcs12: pkcs12Data, + alpns: [Data("testalpn".utf8)], + client: true, + settings: quicSettings + ) + + let clientConnection = try QuicConnection( + handler: clientHandler, + registration: registration, + configuration: clientConfiguration + ) + + try clientConnection.connect(to: listenAddress) + + let stream1 = try clientConnection.createStream() + + try stream1.send(data: Data("test data 1".utf8)) + + try? await Task.sleep(for: .milliseconds(100)) + let (serverConnection, info) = serverHandler.events.value.compactMap { + switch $0 { + case let .newConnection(_, connection, info): + (connection, info) as (QuicConnection, ConnectionInfo)? + default: + nil + } + }.first! + + let (ipAddress2, _) = info.remoteAddress.getAddressAndPort() + + #expect(info.negotiatedAlpn == Data("testalpn".utf8)) + #expect(info.serverName == "127.0.0.1") + #expect(info.localAddress == listenAddress) + #expect(ipAddress2 == "127.0.0.1") + + let stream2 = try serverConnection.createStream() + try stream2.send(data: Data("other test data 2".utf8)) + + try? await Task.sleep(for: .milliseconds(100)) + let receivedData = serverHandler.events.value.compactMap { + switch $0 { + case let .dataReceived(_, data): + data + default: + nil + } + } + + #expect(receivedData.count == 1) + #expect(receivedData[0] == Data("test data 1".utf8)) + try clientConnection.shutdown() + try? await Task.sleep(for: .milliseconds(1000)) + #expect(throws: Error.self) { + try serverConnection.connect(to: info.remoteAddress) + } + #expect(throws: Error.self) { + _ = try clientConnection.getRemoteAddress() + } + } + @Test func connectAndSendReceive() async throws { let serverHandler = MockQuicEventHandler() diff --git a/Networking/Tests/NetworkingTests/MockPeerEventTests.swift b/Networking/Tests/NetworkingTests/MockPeerEventTests.swift index f837c576..54106e72 100644 --- a/Networking/Tests/NetworkingTests/MockPeerEventTests.swift +++ b/Networking/Tests/NetworkingTests/MockPeerEventTests.swift @@ -19,7 +19,6 @@ final class MockPeerEventTests { case shutdownInitiated(connection: QuicConnection, reason: ConnectionCloseReason) case streamStarted(connection: QuicConnection, stream: QuicStream) case dataReceived(stream: QuicStream, data: Data?) - case closed(stream: QuicStream, status: QuicStatus, code: QuicErrorCode) } let events: ThreadSafeContainer<[EventType]> = .init([]) @@ -80,12 +79,6 @@ final class MockPeerEventTests { events.append(.dataReceived(stream: stream, data: data)) } } - - func closed(_ stream: QuicStream, status: QuicStatus, code: QuicErrorCode) { - events.write { events in - events.append(.closed(stream: stream, status: status, code: code)) - } - } } let registration: QuicRegistration @@ -161,12 +154,10 @@ final class MockPeerEventTests { try clientConnection.connect(to: listenAddress) try await Task.sleep(for: .milliseconds(100)) let (_, reason) = clientHandler.events.value.compactMap { - switch $0 { - case let .shutdownInitiated(connection, reason): - (connection, reason) as (QuicConnection, ConnectionCloseReason)? - default: - nil + if case let .shutdownInitiated(connection, reason) = $0 { + return (connection, reason) } + return nil }.first! #expect( reason @@ -176,6 +167,8 @@ final class MockPeerEventTests { ) } + final class MockQuicEventHandler: QuicEventHandler {} + @Test func connected() async throws { let serverHandler = MockPeerEventHandler() @@ -216,6 +209,8 @@ final class MockPeerEventTests { // Attempt to connect try clientConnection.connect(to: listenAddress) + let removeAddress = try clientConnection.getRemoteAddress() + #expect(removeAddress != nil) let stream1 = try clientConnection.createStream() try stream1.send(data: Data("test data 1".utf8)) @@ -236,6 +231,82 @@ final class MockPeerEventTests { #expect(ipAddress2 == "127.0.0.1") } + @Test + func mockTestCert() async throws { + let serverHandler = MockPeerEventHandler() + let clientHandler = MockQuicEventHandler() + + let serverConfiguration = try QuicConfiguration( + registration: registration, + pkcs12: badCertData, + alpns: [Data("testalpn".utf8)], + client: false, + settings: QuicSettings.defaultSettings + ) + + let listener = try QuicListener( + handler: serverHandler, + registration: registration, + configuration: serverConfiguration, + listenAddress: NetAddr(ipAddress: "127.0.0.1", port: 0)!, + alpns: [Data("testalpn".utf8)] + ) + + let listenAddress = try listener.listenAddress() + + // Client setup with bad certificate + let clientConfiguration = try QuicConfiguration( + registration: registration, + pkcs12: certData, + alpns: [Data("testalpn".utf8)], + client: true, + settings: QuicSettings.defaultSettings + ) + + let clientConnection = try QuicConnection( + handler: clientHandler, + registration: registration, + configuration: clientConfiguration + ) + try clientConnection.connect(to: listenAddress) + try await Task.sleep(for: .milliseconds(100)) + let stream1 = try clientConnection.createStream() + + try stream1.send(data: Data("test data 1".utf8)) + + try? await Task.sleep(for: .milliseconds(100)) + let (serverConnection, info) = serverHandler.events.value.compactMap { + switch $0 { + case let .newConnection(_, connection, info): + (connection, info) as (QuicConnection, ConnectionInfo)? + default: + nil + } + }.first! + let (ipAddress2, _) = info.remoteAddress.getAddressAndPort() + + #expect(info.negotiatedAlpn == Data("testalpn".utf8)) + #expect(info.serverName == "127.0.0.1") + #expect(info.localAddress == listenAddress) + #expect(ipAddress2 == "127.0.0.1") + + let stream2 = try serverConnection.createStream() + try stream2.send(data: Data("other test data 2".utf8)) + + try? await Task.sleep(for: .milliseconds(100)) + let receivedData = serverHandler.events.value.compactMap { + switch $0 { + case let .dataReceived(_, data): + data + default: + nil + } + } + + #expect(receivedData.count == 1) + #expect(receivedData[0] == Data("test data 1".utf8)) + } + @Test func rejectsConDueToBadClientCert() async throws { let serverHandler = MockPeerEventHandler() diff --git a/Networking/Tests/NetworkingTests/PKCS12Tests.swift b/Networking/Tests/NetworkingTests/PKCS12Tests.swift index 19b5801d..8523a531 100644 --- a/Networking/Tests/NetworkingTests/PKCS12Tests.swift +++ b/Networking/Tests/NetworkingTests/PKCS12Tests.swift @@ -27,8 +27,8 @@ struct PKCS12Tests { let registration = try QuicRegistration() - let serverHandler = MockQuicEventHandler() - let clientHandler = MockQuicEventHandler() + let serverHandler = MockPeerEventTests.MockPeerEventHandler() + let clientHandler = MockPeerEventTests.MockPeerEventHandler() // create listener @@ -71,26 +71,20 @@ struct PKCS12Tests { try? await Task.sleep(for: .milliseconds(50)) - let clientData = clientHandler.events.value.compactMap { - switch $0 { - case let .shouldOpen(_, certificate): - certificate as Data? - default: - nil + let clientConn = clientHandler.events.value.compactMap { + if case let .connected(connection: connection) = $0 { + return connection } - } - - #expect(clientData.first!.count > 0) + return nil + }.first! + #expect(clientConn != nil) - let serverData = serverHandler.events.value.compactMap { - switch $0 { - case let .shouldOpen(_, certificate): - certificate as Data? - default: - nil + let serverConn = serverHandler.events.value.compactMap { + if case let .connected(connection: connection) = $0 { + return connection } - } - - #expect(serverData.first!.count > 0) + return nil + }.first! + #expect(serverConn != nil) } } diff --git a/Networking/Tests/NetworkingTests/PeerTests.swift b/Networking/Tests/NetworkingTests/PeerTests.swift index 07079bcd..96bcaceb 100644 --- a/Networking/Tests/NetworkingTests/PeerTests.swift +++ b/Networking/Tests/NetworkingTests/PeerTests.swift @@ -7,13 +7,6 @@ import Utils @testable import Networking struct PeerTests { - struct MockMessage: MessageProtocol { - let data: Data - func encode() throws -> Data { - data - } - } - struct MockRequest: RequestProtocol { var kind: Kind var data: Data @@ -50,10 +43,6 @@ struct PeerTests { self.data = data return MockRequest(kind: kind, data: data) } - - func finish() -> Data? { - data - } } struct MockUniqueMessageDecoder: MessageDecoder { @@ -70,10 +59,6 @@ struct PeerTests { self.data = data return MockRequest(kind: kind, data: data) } - - func finish() -> Data? { - data - } } actor DataStorage { @@ -89,10 +74,6 @@ struct PeerTests { typealias Request = MockRequest private let dataStorage: PeerTests.DataStorage = DataStorage() - var lastReceivedData: Data? { - get async { await dataStorage.data.last } - } - func createDecoder(kind: StreamKind) -> any MessageDecoder { MockEphemeralMessageDecoder(kind: kind) } @@ -184,8 +165,8 @@ struct PeerTests { let con = try peer.connect(to: centerPeer.listenAddress(), role: .builder) try await con.ready() } - // Simulate close connections 3~5s - try? await Task.sleep(for: .milliseconds(5000)) + // Simulate close connections 5~8s + try? await Task.sleep(for: .milliseconds(8000)) centerPeer.broadcast(kind: .uniqueA, message: .init(kind: .uniqueA, data: Data("connection rotation strategy".utf8))) try? await Task.sleep(for: .milliseconds(1000)) var receivedCount = 0 @@ -239,6 +220,12 @@ struct PeerTests { let connection1 = try peer1.connect(to: listenAddress, role: .validator) try? await Task.sleep(for: .milliseconds(3000)) + #expect(throws: Error.self) { + _ = try connection1.createStream(kind: .typeA) + } + #expect(throws: Error.self) { + _ = try connection1.createStream(kind: .uniqueA) + } #expect(connection1.isClosed == true) } @@ -335,8 +322,8 @@ struct PeerTests { presistentStreams[.uniqueA] } stream!.close(abort: true) - // Wait to simulate downtime & reopen up stream 3~5s - try? await Task.sleep(for: .milliseconds(3000)) + // Wait to simulate downtime & reopen up stream 8s + try? await Task.sleep(for: .milliseconds(8000)) messageData = Data("reopen up stream data".utf8) peer1.broadcast( kind: .uniqueA, message: .init(kind: .uniqueA, data: messageData) @@ -435,13 +422,9 @@ struct PeerTests { let connection1 = try peer1.connect(to: peer2.listenAddress(), role: .validator) let connection2 = try peer2.connect(to: peer1.listenAddress(), role: .validator) try? await Task.sleep(for: .milliseconds(1000)) - if !connection1.isClosed { - let data = try await connection1.request(MockRequest(kind: .typeA, data: Data("hello world".utf8))) - try? await Task.sleep(for: .milliseconds(500)) - #expect(data == Data("hello world response".utf8)) - } - if !connection2.isClosed { - let data = try await connection2.request(MockRequest(kind: .typeA, data: Data("hello world".utf8))) + let connections = [connection1, connection2] + for connection in connections where !connection.isClosed { + let data = try await connection.request(MockRequest(kind: .typeA, data: Data("hello world".utf8))) try? await Task.sleep(for: .milliseconds(500)) #expect(data == Data("hello world response".utf8)) } @@ -452,7 +435,7 @@ struct PeerTests { let handler1 = MockPresentStreamHandler() let handler2 = MockPresentStreamHandler() // Define the data size, 5MB - let dataSize = 5 * 1024 * 1024 + let dataSize = 10 * 1024 * 1024 var largeData = Data(capacity: dataSize) // Generate random data @@ -494,25 +477,29 @@ struct PeerTests { try? await Task.sleep(for: .milliseconds(50)) let receivedData1 = try await connection1.request( - MockRequest(kind: .typeA, data: largeData) + MockRequest(kind: .typeA, data: largeData.prefix(dataSize / 2)) ) try? await Task.sleep(for: .milliseconds(100)) // Verify that the received data matches the original large data - #expect(receivedData1 == largeData + Data(" response".utf8)) - + #expect(receivedData1 == largeData.prefix(dataSize / 2) + Data(" response".utf8)) peer1.broadcast( - kind: .uniqueA, message: .init(kind: .uniqueA, data: largeData) + kind: .uniqueA, message: .init(kind: .uniqueA, data: largeData.prefix(dataSize / 2)) ) try? await Task.sleep(for: .milliseconds(100)) peer2.broadcast( - kind: .uniqueB, message: .init(kind: .uniqueB, data: largeData) + kind: .uniqueB, message: .init(kind: .uniqueB, data: largeData.prefix(dataSize / 2)) ) // Verify last received data try? await Task.sleep(for: .milliseconds(2000)) - await #expect(handler2.lastReceivedData == largeData) - await #expect(handler1.lastReceivedData == largeData) + await #expect(handler2.lastReceivedData == largeData.prefix(dataSize / 2)) + await #expect(handler1.lastReceivedData == largeData.prefix(dataSize / 2)) + await #expect(throws: Error.self) { + _ = try await connection1.request( + MockRequest(kind: .typeC, data: largeData) + ) + } } @Test