diff --git a/build.gradle b/build.gradle index 0c89ece..798b34a 100644 --- a/build.gradle +++ b/build.gradle @@ -4,7 +4,7 @@ buildscript { ext.kotlinVersion = '1.4.0' ext.protobufVersion = '3.13.0' ext.protobufGradleVersion = '0.8.12' - ext.kotlinCoroutinesVersion = '1.3.7' + ext.kotlinCoroutinesVersion = '1.3.8' ext.ktorVersion = '1.4.0' ext.okhttpVersion = '4.8.1' } @@ -33,7 +33,7 @@ dependencies { implementation("org.jetbrains.kotlin:kotlin-stdlib-jdk8:$kotlinVersion") - api('tech.relaycorp:relaynet:1.34.1') + api('tech.relaycorp:relaynet:1.35.0') // Handshake nonce signatures implementation("org.bouncycastle:bcpkix-jdk15on:1.66") @@ -57,6 +57,7 @@ dependencies { testImplementation("org.jetbrains.kotlin:kotlin-test-junit5") testImplementation("com.nhaarman.mockitokotlin2:mockito-kotlin:2.2.0") testImplementation("org.mockito:mockito-inline:3.5.2") + testImplementation("org.jetbrains.kotlinx:kotlinx-coroutines-test:$kotlinCoroutinesVersion") } java { diff --git a/src/main/kotlin/tech/relaycorp/poweb/PoWebClient.kt b/src/main/kotlin/tech/relaycorp/poweb/PoWebClient.kt index 9b2e986..bf59992 100644 --- a/src/main/kotlin/tech/relaycorp/poweb/PoWebClient.kt +++ b/src/main/kotlin/tech/relaycorp/poweb/PoWebClient.kt @@ -5,61 +5,175 @@ import io.ktor.client.engine.okhttp.OkHttp import io.ktor.client.features.websocket.DefaultClientWebSocketSession import io.ktor.client.features.websocket.WebSockets import io.ktor.client.features.websocket.webSocket +import io.ktor.client.request.header import io.ktor.http.cio.websocket.CloseReason import io.ktor.http.cio.websocket.Frame import io.ktor.http.cio.websocket.close import io.ktor.http.cio.websocket.readBytes import io.ktor.util.KtorExperimentalAPI +import kotlinx.coroutines.channels.ClosedReceiveChannelException +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.FlowCollector +import kotlinx.coroutines.flow.flow import tech.relaycorp.poweb.handshake.Challenge -import tech.relaycorp.poweb.handshake.InvalidMessageException -import tech.relaycorp.poweb.handshake.NonceSigner +import tech.relaycorp.poweb.handshake.InvalidChallengeException import tech.relaycorp.poweb.handshake.Response +import tech.relaycorp.relaynet.bindings.pdc.NonceSigner +import tech.relaycorp.relaynet.bindings.pdc.ParcelCollection +import tech.relaycorp.relaynet.bindings.pdc.StreamingMode +import tech.relaycorp.relaynet.messages.InvalidMessageException +import tech.relaycorp.relaynet.messages.control.ParcelDelivery +import tech.relaycorp.relaynet.wrappers.x509.Certificate import java.io.Closeable +import java.io.EOFException +import java.net.ConnectException +/** + * PoWeb client. + * + * @param hostName The IP address or domain for the PoWeb server + * @param port The port for the PoWeb server + * @param useTls Whether the PoWeb server uses TLS + * + * The underlying connection is created lazily. + */ +@KtorExperimentalAPI public class PoWebClient internal constructor( internal val hostName: String, internal val port: Int, internal val useTls: Boolean ) : Closeable { - @KtorExperimentalAPI internal var ktorClient = HttpClient(OkHttp) { install(WebSockets) } - @KtorExperimentalAPI + private val wsScheme = if (useTls) "wss" else "ws" + + /** + * Close the underlying connection to the server (if any). + */ override fun close(): Unit = ktorClient.close() - private val wsUrl = "ws${if (useTls) "s" else ""}://$hostName:$port" + /** + * Collect parcels on behalf of the specified nodes. + * + * @param nonceSigners The nonce signers for each node whose parcels should be collected + * @param streamingMode Which streaming mode to ask the server to use + */ + @Throws( + ServerConnectionException::class, + InvalidServerMessageException::class, + NonceSignerException::class + ) + public suspend fun collectParcels( + nonceSigners: Array, + streamingMode: StreamingMode = StreamingMode.KeepAlive + ): Flow = flow { + if (nonceSigners.isEmpty()) { + throw NonceSignerException("At least one nonce signer must be specified") + } + + val trustedCertificates = nonceSigners.map { it.certificate } + val streamingModeHeader = Pair(StreamingMode.HEADER_NAME, streamingMode.headerValue) + wsConnect(PARCEL_COLLECTION_ENDPOINT_PATH, listOf(streamingModeHeader)) { + try { + handshake(nonceSigners) + } catch (exc: ClosedReceiveChannelException) { + // Alert the client to the fact that the server closed the connection before + // completing the handshake. Otherwise, the client will assume that the operation + // succeeded and there were no parcels to collect. + throw ServerConnectionException( + "Server closed the connection during the handshake", + exc + ) + } + collectAndAckParcels(this, this@flow, trustedCertificates) + + // The server must've closed the connection for us to get here, since we're consuming + // all incoming messages indefinitely. + val reason = closeReason.await()!! + if (reason.code != CloseReason.Codes.NORMAL.code) { + throw ServerConnectionException( + "Server closed the connection unexpectedly " + + "(code: ${reason.code}, reason: ${reason.message})" + ) + } + } + } + + @Throws(PoWebException::class) + private suspend fun collectAndAckParcels( + webSocketSession: DefaultClientWebSocketSession, + flowCollector: FlowCollector, + trustedCertificates: List + ) { + for (frame in webSocketSession.incoming) { + val delivery = try { + ParcelDelivery.deserialize(frame.readBytes()) + } catch (exc: InvalidMessageException) { + webSocketSession.close( + CloseReason(CloseReason.Codes.VIOLATED_POLICY, "Invalid parcel delivery") + ) + throw InvalidServerMessageException("Received invalid message from server", exc) + } + val collector = ParcelCollection(delivery.parcelSerialized, trustedCertificates) { + webSocketSession.outgoing.send(Frame.Text(delivery.deliveryId)) + } + flowCollector.emit(collector) + } + } - @KtorExperimentalAPI internal suspend fun wsConnect( path: String, + headers: List>? = null, block: suspend DefaultClientWebSocketSession.() -> Unit - ) = ktorClient.webSocket("${wsUrl}$path", block = block) + ) = try { + ktorClient.webSocket( + "$wsScheme://$hostName:$port$path", + { headers?.forEach { header(it.first, it.second) } }, + block + ) + } catch (exc: ConnectException) { + throw ServerConnectionException("Server is unreachable", exc) + } catch (exc: EOFException) { + throw ServerConnectionException("Connection was closed abruptly", exc) + } public companion object { - private const val defaultLocalPort = 276 - private const val defaultRemotePort = 443 + internal const val PARCEL_COLLECTION_ENDPOINT_PATH = "/v1/parcel-collection" - public fun initLocal(port: Int = defaultLocalPort): PoWebClient = - PoWebClient("127.0.0.1", port, false) + private const val DEFAULT_LOCAL_PORT = 276 + private const val DEFAULT_REMOTE_PORT = 443 - public fun initRemote(hostName: String, port: Int = defaultRemotePort): PoWebClient = - PoWebClient(hostName, port, true) + /** + * Connect to a private gateway from a private endpoint. + * + * @param port The port for the PoWeb server + * + * TLS won't be used. + */ + public fun initLocal(port: Int = DEFAULT_LOCAL_PORT): PoWebClient = + PoWebClient("127.0.0.1", port, false) + + /** + * Connect to a public gateway from a private gateway via TLS. + * + * @param hostName The IP address or domain for the PoWeb server + * @param port The port for the PoWeb server + */ + public fun initRemote(hostName: String, port: Int = DEFAULT_REMOTE_PORT): PoWebClient = + PoWebClient(hostName, port, true) } } @Throws(PoWebException::class) -internal suspend fun DefaultClientWebSocketSession.handshake(nonceSigners: Array) { - if (nonceSigners.isEmpty()) { - throw PoWebException("At least one nonce signer must be specified") - } +private suspend fun DefaultClientWebSocketSession.handshake(nonceSigners: Array) { val challengeRaw = incoming.receive() val challenge = try { Challenge.deserialize(challengeRaw.readBytes()) - } catch (exc: InvalidMessageException) { + } catch (exc: InvalidChallengeException) { close(CloseReason(CloseReason.Codes.VIOLATED_POLICY, "")) - throw PoWebException("Server sent an invalid handshake challenge", exc) + throw InvalidServerMessageException("Server sent an invalid handshake challenge", exc) } val nonceSignatures = nonceSigners.map { it.sign(challenge.nonce) }.toTypedArray() val response = Response(nonceSignatures) diff --git a/src/main/kotlin/tech/relaycorp/poweb/PoWebException.kt b/src/main/kotlin/tech/relaycorp/poweb/PoWebException.kt index e43712c..865012b 100644 --- a/src/main/kotlin/tech/relaycorp/poweb/PoWebException.kt +++ b/src/main/kotlin/tech/relaycorp/poweb/PoWebException.kt @@ -1,3 +1,33 @@ package tech.relaycorp.poweb -public class PoWebException(message: String, cause: Throwable? = null) : Exception(message, cause) +public abstract class PoWebException internal constructor( + message: String, + cause: Throwable? = null +) : Exception(message, cause) + +/** + * Base class for connectivity errors and errors caused by the server. + */ +public sealed class ServerException(message: String, cause: Throwable?) : + PoWebException(message, cause) + +/** + * Error before or while connected to the server. + * + * The client should retry later. + */ +public class ServerConnectionException(message: String, cause: Throwable? = null) : + ServerException(message, cause) + +/** + * The server sent an invalid message. + * + * The server didn't adhere to the protocol. Retrying later is unlikely to make a difference. + */ +public class InvalidServerMessageException(message: String, cause: Throwable) : + ServerException(message, cause) + +/** + * The client made a mistake while specifying the nonce signer(s). + */ +public class NonceSignerException(message: String) : PoWebException(message) diff --git a/src/main/kotlin/tech/relaycorp/poweb/handshake/Challenge.kt b/src/main/kotlin/tech/relaycorp/poweb/handshake/Challenge.kt index d2ae248..759b2f6 100644 --- a/src/main/kotlin/tech/relaycorp/poweb/handshake/Challenge.kt +++ b/src/main/kotlin/tech/relaycorp/poweb/handshake/Challenge.kt @@ -16,7 +16,7 @@ public class Challenge(public val nonce: ByteArray) { val pbChallenge = try { PBChallenge.parseFrom(serialization) } catch (_: InvalidProtocolBufferException) { - throw InvalidMessageException("Message is not a valid challenge") + throw InvalidChallengeException("Message is not a valid challenge") } return Challenge(pbChallenge.gatewayNonce.toByteArray()) } diff --git a/src/main/kotlin/tech/relaycorp/poweb/handshake/InvalidChallengeException.kt b/src/main/kotlin/tech/relaycorp/poweb/handshake/InvalidChallengeException.kt new file mode 100644 index 0000000..249d63a --- /dev/null +++ b/src/main/kotlin/tech/relaycorp/poweb/handshake/InvalidChallengeException.kt @@ -0,0 +1,3 @@ +package tech.relaycorp.poweb.handshake + +public class InvalidChallengeException(message: String) : Exception(message) diff --git a/src/main/kotlin/tech/relaycorp/poweb/handshake/InvalidMessageException.kt b/src/main/kotlin/tech/relaycorp/poweb/handshake/InvalidMessageException.kt deleted file mode 100644 index b2829ce..0000000 --- a/src/main/kotlin/tech/relaycorp/poweb/handshake/InvalidMessageException.kt +++ /dev/null @@ -1,3 +0,0 @@ -package tech.relaycorp.poweb.handshake - -public class InvalidMessageException(message: String) : Exception(message) diff --git a/src/main/kotlin/tech/relaycorp/poweb/handshake/InvalidResponseException.kt b/src/main/kotlin/tech/relaycorp/poweb/handshake/InvalidResponseException.kt new file mode 100644 index 0000000..a3e7e4c --- /dev/null +++ b/src/main/kotlin/tech/relaycorp/poweb/handshake/InvalidResponseException.kt @@ -0,0 +1,3 @@ +package tech.relaycorp.poweb.handshake + +public class InvalidResponseException(message: String) : Exception(message) diff --git a/src/main/kotlin/tech/relaycorp/poweb/handshake/NonceSigner.kt b/src/main/kotlin/tech/relaycorp/poweb/handshake/NonceSigner.kt deleted file mode 100644 index e589cbb..0000000 --- a/src/main/kotlin/tech/relaycorp/poweb/handshake/NonceSigner.kt +++ /dev/null @@ -1,15 +0,0 @@ -package tech.relaycorp.poweb.handshake - -import tech.relaycorp.relaynet.messages.control.NonceSignature -import tech.relaycorp.relaynet.wrappers.x509.Certificate -import java.security.PrivateKey - -public class NonceSigner( - internal val certificate: Certificate, - private val privateKey: PrivateKey -) { - public fun sign(nonce: ByteArray): ByteArray { - val signature = NonceSignature(nonce, certificate) - return signature.serialize(privateKey) - } -} diff --git a/src/main/kotlin/tech/relaycorp/poweb/handshake/Response.kt b/src/main/kotlin/tech/relaycorp/poweb/handshake/Response.kt index 4d0362d..ee9e33e 100644 --- a/src/main/kotlin/tech/relaycorp/poweb/handshake/Response.kt +++ b/src/main/kotlin/tech/relaycorp/poweb/handshake/Response.kt @@ -7,8 +7,8 @@ import tech.relaycorp.poweb.internal.protobuf_messages.handshake.Response as PBR public class Response(public val nonceSignatures: Array) { public fun serialize(): ByteArray { val pbResponse = PBResponse.newBuilder() - .addAllGatewayNonceSignatures(nonceSignatures.map { ByteString.copyFrom(it) }) - .build() + .addAllGatewayNonceSignatures(nonceSignatures.map { ByteString.copyFrom(it) }) + .build() return pbResponse.toByteArray() } @@ -17,7 +17,7 @@ public class Response(public val nonceSignatures: Array) { val pbResponse = try { PBResponse.parseFrom(serialization) } catch (_: InvalidProtocolBufferException) { - throw InvalidMessageException("Message is not a valid response") + throw InvalidResponseException("Message is not a valid response") } val nonceSignatures = pbResponse.gatewayNonceSignaturesList.map { it.toByteArray() } return Response(nonceSignatures.toTypedArray()) diff --git a/src/test/kotlin/tech/relaycorp/poweb/ParcelCollectionTest.kt b/src/test/kotlin/tech/relaycorp/poweb/ParcelCollectionTest.kt new file mode 100644 index 0000000..c4f1c86 --- /dev/null +++ b/src/test/kotlin/tech/relaycorp/poweb/ParcelCollectionTest.kt @@ -0,0 +1,407 @@ +package tech.relaycorp.poweb + +import io.ktor.http.cio.websocket.CloseReason +import io.ktor.util.KtorExperimentalAPI +import kotlinx.coroutines.channels.ClosedReceiveChannelException +import kotlinx.coroutines.flow.collect +import kotlinx.coroutines.flow.first +import kotlinx.coroutines.flow.take +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.runBlocking +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.Nested +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import tech.relaycorp.poweb.handshake.InvalidChallengeException +import tech.relaycorp.poweb.websocket.ActionSequence +import tech.relaycorp.poweb.websocket.ChallengeAction +import tech.relaycorp.poweb.websocket.CloseConnectionAction +import tech.relaycorp.poweb.websocket.MockKtorClientManager +import tech.relaycorp.poweb.websocket.ParcelDeliveryAction +import tech.relaycorp.poweb.websocket.SendTextMessageAction +import tech.relaycorp.poweb.websocket.WebSocketTestCase +import tech.relaycorp.relaynet.bindings.pdc.NonceSigner +import tech.relaycorp.relaynet.bindings.pdc.StreamingMode +import tech.relaycorp.relaynet.issueEndpointCertificate +import tech.relaycorp.relaynet.messages.InvalidMessageException +import tech.relaycorp.relaynet.messages.control.NonceSignature +import tech.relaycorp.relaynet.wrappers.generateRSAKeyPair +import java.nio.charset.Charset +import java.time.ZonedDateTime +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +@KtorExperimentalAPI +class ParcelCollectionTest : WebSocketTestCase() { + private val nonce = "nonce".toByteArray() + + // Compute client on demand because getting the server port will start the server + private val client by lazy { PoWebClient.initLocal(mockWebServer.port) } + + private val signer = generateDummySigner() + + private val deliveryId = "the delivery id" + private val parcelSerialized = "the parcel serialized".toByteArray() + + @AfterEach + fun closeClient() = client.ktorClient.close() + + @Test + fun `Request should be made to the parcel collection endpoint`() = runBlocking { + val mockClient = PoWebClient.initLocal() + val ktorClientManager = MockKtorClientManager() + mockClient.ktorClient = ktorClientManager.ktorClient + + ktorClientManager.useClient { + mockClient.collectParcels(arrayOf(signer)).toList() + } + + assertEquals( + PoWebClient.PARCEL_COLLECTION_ENDPOINT_PATH, + ktorClientManager.request.url.encodedPath + ) + } + + @Nested + inner class Handshake { + @Test + fun `Server closing connection during handshake should throw exception`() { + setListenerActions(CloseConnectionAction()) + + client.use { + val exception = assertThrows { + runBlocking { client.collectParcels(arrayOf(signer)).first() } + } + + assertEquals( + "Server closed the connection during the handshake", + exception.message + ) + assertTrue(exception.cause is ClosedReceiveChannelException) + } + + awaitForConnectionClosure() + assertEquals(CloseReason.Codes.NORMAL, listener!!.closingCode) + } + + @Test + fun `Getting an invalid challenge should throw an exception`() { + setListenerActions(SendTextMessageAction("Not a valid challenge")) + + client.use { + val exception = assertThrows { + runBlocking { client.collectParcels(arrayOf(signer)).first() } + } + + assertEquals("Server sent an invalid handshake challenge", exception.message) + assertTrue(exception.cause is InvalidChallengeException) + } + + awaitForConnectionClosure() + assertEquals(CloseReason.Codes.VIOLATED_POLICY, listener!!.closingCode) + } + + @Test + fun `At least one nonce signer should be required`() { + setListenerActions() + + client.use { + val exception = assertThrows { + runBlocking { client.collectParcels(emptyArray()).first() } + } + + assertEquals("At least one nonce signer must be specified", exception.message) + } + + assertFalse(listener!!.connected) + } + + @Test + fun `Challenge nonce should be signed with each signer`() { + setListenerActions(ChallengeAction(nonce), CloseConnectionAction()) + + val signer2 = generateDummySigner() + + client.use { + runBlocking { client.collectParcels(arrayOf(signer, signer2)).toList() } + } + + awaitForConnectionClosure() + + assertEquals(1, listener!!.receivedMessages.size) + val response = tech.relaycorp.poweb.handshake.Response.deserialize( + listener!!.receivedMessages.first() + ) + val nonceSignatures = response.nonceSignatures + val signature1 = NonceSignature.deserialize(nonceSignatures[0]) + assertEquals(nonce.asList(), signature1.nonce.asList()) + assertEquals(signer.certificate, signature1.signerCertificate) + val signature2 = NonceSignature.deserialize(nonceSignatures[1]) + assertEquals(nonce.asList(), signature2.nonce.asList()) + assertEquals(signer2.certificate, signature2.signerCertificate) + } + } + + @Test + fun `Call should return if server closed connection normally after the handshake`(): Unit = + runBlocking { + setListenerActions(ChallengeAction(nonce), CloseConnectionAction()) + + client.use { + client.collectParcels(arrayOf(signer)).collect { } + } + + awaitForConnectionClosure() + assertEquals(CloseReason.Codes.NORMAL, listener!!.closingCode) + } + + @Test + fun `Exception should be thrown if server closes connection with error`(): Unit = + runBlocking { + val code = CloseReason.Codes.VIOLATED_POLICY + val reason = "Whoops" + setListenerActions(ChallengeAction(nonce), CloseConnectionAction(code, reason)) + + client.use { + val exception = assertThrows { + runBlocking { client.collectParcels(arrayOf(signer)).toList() } + } + + assertEquals( + "Server closed the connection unexpectedly " + + "(code: ${code.code}, reason: $reason)", + exception.message + ) + } + } + + @Test + fun `Cancelling the flow should close the connection normally`(): Unit = runBlocking { + val undeliveredAction = + ParcelDeliveryAction("second delivery id", "second parcel".toByteArray()) + setListenerActions( + ChallengeAction(nonce), + ParcelDeliveryAction(deliveryId, parcelSerialized), + undeliveredAction + ) + + client.use { + val deliveries = client.collectParcels(arrayOf(signer)).take(1).toList() + + assertEquals(1, deliveries.size) + } + + awaitForConnectionClosure() + assertEquals(CloseReason.Codes.NORMAL, listener!!.closingCode) + assertFalse(undeliveredAction.wasRun) + } + + @Test + fun `Malformed deliveries should be refused`(): Unit = runBlocking { + setListenerActions(ChallengeAction(nonce), SendTextMessageAction("invalid")) + + client.use { + val exception = assertThrows { + runBlocking { client.collectParcels(arrayOf(signer)).toList() } + } + + assertEquals("Received invalid message from server", exception.message) + assertTrue(exception.cause is InvalidMessageException) + } + + awaitForConnectionClosure() + assertEquals(CloseReason.Codes.VIOLATED_POLICY, listener!!.closingCode!!) + assertEquals("Invalid parcel delivery", listener!!.closingReason!!) + } + + @Nested + inner class StreamingModeHeader { + @Test + fun `Streaming mode should be Keep-Alive by default`(): Unit = runBlocking { + setListenerActions(ChallengeAction(nonce), CloseConnectionAction()) + + client.use { + client.collectParcels(arrayOf(signer)).toList() + } + + awaitForConnectionClosure() + assertEquals( + StreamingMode.KeepAlive.headerValue, + listener!!.request!!.header(StreamingMode.HEADER_NAME) + ) + } + + @Test + fun `Streaming mode can be changed on request`(): Unit = runBlocking { + setListenerActions(ChallengeAction(nonce), CloseConnectionAction()) + + client.use { + client.collectParcels(arrayOf(signer), StreamingMode.CloseUponCompletion).toList() + } + + awaitForConnectionClosure() + assertEquals( + StreamingMode.CloseUponCompletion.headerValue, + listener!!.request!!.header(StreamingMode.HEADER_NAME) + ) + } + } + + @Nested + inner class Collector { + @Test + fun `No collectors should be output if the server doesn't deliver anything`(): Unit = + runBlocking { + setListenerActions(ChallengeAction(nonce), CloseConnectionAction()) + + client.use { + val deliveries = client.collectParcels(arrayOf(signer)).toList() + + assertEquals(0, deliveries.size) + } + } + + @Test + fun `One collector should be output if there is one delivery`(): Unit = + runBlocking { + setListenerActions( + ChallengeAction(nonce), + ActionSequence( + ParcelDeliveryAction(deliveryId, parcelSerialized), + CloseConnectionAction() + ) + ) + + client.use { + val deliveries = client.collectParcels(arrayOf(signer)).toList() + + assertEquals(1, deliveries.size) + assertEquals( + parcelSerialized.asList(), + deliveries.first().parcelSerialized.asList() + ) + } + } + + @Test + fun `Multiple collectors should be output if there are multiple deliveries`(): Unit = + runBlocking { + val parcelSerialized2 = "second parcel".toByteArray() + setListenerActions( + ChallengeAction(nonce), + ActionSequence( + ParcelDeliveryAction(deliveryId, parcelSerialized), + ParcelDeliveryAction("second delivery id", parcelSerialized2), + CloseConnectionAction() + ) + ) + + client.use { + val deliveries = client.collectParcels(arrayOf(signer)).toList() + + assertEquals(2, deliveries.size) + assertEquals( + parcelSerialized.asList(), + deliveries.first().parcelSerialized.asList() + ) + assertEquals( + parcelSerialized2.asList(), + deliveries[1].parcelSerialized.asList() + ) + } + } + } + + @Nested + inner class CollectorTrustedCerts { + @Test + fun `Collector should use trusted certificates from nonce signers`() = runBlocking { + setListenerActions( + ChallengeAction(nonce), + ActionSequence( + ParcelDeliveryAction(deliveryId, parcelSerialized), + CloseConnectionAction() + ) + ) + + client.use { + val deliveries = client.collectParcels(arrayOf(signer)).toList() + + assertEquals(1, deliveries.size) + assertEquals( + listOf(signer.certificate), + deliveries.first().trustedCertificates.toList() + ) + } + } + } + + @Nested + inner class CollectorACK { + @Test + fun `Each ACK should be passed on to the server`(): Unit = runBlocking { + setListenerActions( + ChallengeAction(nonce), + ParcelDeliveryAction(deliveryId, parcelSerialized), + CloseConnectionAction() + ) + + client.use { + client.collectParcels(arrayOf(signer)).collect { it.ack() } + } + + awaitForConnectionClosure() + // The server should've got two messages: The handshake response and the ACK + assertEquals(2, listener!!.receivedMessages.size) + assertEquals( + deliveryId, + listener!!.receivedMessages[1].toString(Charset.defaultCharset()) + ) + } + + @Test + fun `Missing ACKs should be honored`(): Unit = runBlocking { + // The server will deliver 2 parcels but the client will only ACK the first one + val additionalParcelDelivery = + ParcelDeliveryAction("second delivery id", "parcel".toByteArray()) + setListenerActions( + ChallengeAction(nonce), + ParcelDeliveryAction(deliveryId, parcelSerialized), + ActionSequence( + additionalParcelDelivery, + CloseConnectionAction() + ) + ) + + client.use { + var wasFirstCollectionAcknowledged = false + client.collectParcels(arrayOf(signer)).collect { + // Only acknowledge the first collection + if (!wasFirstCollectionAcknowledged) { + it.ack() + wasFirstCollectionAcknowledged = true + } + } + } + + awaitForConnectionClosure() + // The server should've got two messages: The handshake response and the first ACK + assertEquals(2, listener!!.receivedMessages.size) + assertEquals( + deliveryId, + listener!!.receivedMessages[1].toString(Charset.defaultCharset()) + ) + assertTrue(additionalParcelDelivery.wasRun) + } + } + + private fun generateDummySigner(): NonceSigner { + val keyPair = generateRSAKeyPair() + val certificate = issueEndpointCertificate( + keyPair.public, + keyPair.private, + ZonedDateTime.now().plusDays(1)) + return NonceSigner(certificate, keyPair.private) + } +} diff --git a/src/test/kotlin/tech/relaycorp/poweb/PoWebClientTest.kt b/src/test/kotlin/tech/relaycorp/poweb/PoWebClientTest.kt index bd1dd7a..6e85f3b 100644 --- a/src/test/kotlin/tech/relaycorp/poweb/PoWebClientTest.kt +++ b/src/test/kotlin/tech/relaycorp/poweb/PoWebClientTest.kt @@ -2,36 +2,24 @@ package tech.relaycorp.poweb import com.nhaarman.mockitokotlin2.spy import com.nhaarman.mockitokotlin2.verify -import io.ktor.client.HttpClient -import io.ktor.client.engine.mock.MockEngine import io.ktor.client.engine.okhttp.OkHttpEngine import io.ktor.client.features.websocket.DefaultClientWebSocketSession -import io.ktor.client.features.websocket.WebSockets import io.ktor.client.request.HttpRequestData import io.ktor.http.URLProtocol -import io.ktor.http.cio.websocket.CloseReason -import io.ktor.http.fullPath import io.ktor.util.InternalAPI import io.ktor.util.KtorExperimentalAPI +import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.runBlocking -import okhttp3.Response -import okhttp3.WebSocket -import okhttp3.WebSocketListener -import okio.ByteString -import okio.ByteString.Companion.toByteString -import org.awaitility.Awaitility.await -import org.junit.jupiter.api.AfterEach -import org.junit.jupiter.api.Disabled +import kotlinx.coroutines.test.runBlockingTest import org.junit.jupiter.api.Nested import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows -import tech.relaycorp.poweb.handshake.Challenge -import tech.relaycorp.poweb.handshake.InvalidMessageException -import tech.relaycorp.poweb.handshake.NonceSigner -import tech.relaycorp.relaynet.issueEndpointCertificate -import tech.relaycorp.relaynet.messages.control.NonceSignature -import tech.relaycorp.relaynet.wrappers.generateRSAKeyPair -import java.time.ZonedDateTime +import tech.relaycorp.poweb.websocket.CloseConnectionAction +import tech.relaycorp.poweb.websocket.MockKtorClientManager +import tech.relaycorp.poweb.websocket.ServerShutdownAction +import tech.relaycorp.poweb.websocket.WebSocketTestCase +import java.io.EOFException +import java.net.ConnectException import kotlin.test.assertEquals import kotlin.test.assertFalse import kotlin.test.assertTrue @@ -39,6 +27,7 @@ import kotlin.test.assertTrue @KtorExperimentalAPI class PoWebClientTest { @Nested + @Suppress("RedundantInnerClassModifier") inner class Constructor { @Nested inner class InitLocal { @@ -126,53 +115,90 @@ class PoWebClientTest { } @Nested + @Suppress("RedundantInnerClassModifier") + @ExperimentalCoroutinesApi inner class WebSocketConnection : WebSocketTestCase(false) { private val hostName = "127.0.0.1" private val port = 13276 private val path = "/v1/the-endpoint" @Test - fun `Client should use WS if TLS is not required`() { - val wsRequest = wsConnect(false) {} + fun `Failing to connect to the server should throw an exception`() { + // Connect to an invalid port + val client = PoWebClient.initLocal(mockWebServer.port - 1) + + client.use { + val exception = assertThrows { + runBlocking { client.wsConnect(path) {} } + } + + assertEquals("Server is unreachable", exception.message) + assertTrue(exception.cause is ConnectException) + } + } + + @Test + fun `Losing the connection abruptly should throw an exception`(): Unit = runBlocking { + val client = PoWebClient.initLocal(mockWebServer.port) + setListenerActions(ServerShutdownAction()) + + client.use { + val exception = assertThrows { + runBlocking { + client.wsConnect(path) { + incoming.receive() + } + } + } + + assertEquals("Connection was closed abruptly", exception.message) + assertTrue(exception.cause is EOFException) + } + } + + @Test + fun `Client should use WS if TLS is not required`() = runBlockingTest { + val wsRequest = mockWSConnect(false) {} assertEquals(URLProtocol.WS, wsRequest.url.protocol) } @Test - fun `Client should use WSS if TLS is required`() { - val wsRequest = wsConnect(true) {} + fun `Client should use WSS if TLS is required`() = runBlockingTest { + val wsRequest = mockWSConnect(true) {} assertEquals(URLProtocol.WSS, wsRequest.url.protocol) } @Test - fun `Client should connect to specified host`() { - val wsRequest = wsConnect {} + fun `Client should connect to specified host and port`(): Unit = runBlockingTest { + val wsRequest = mockWSConnect(true) {} assertEquals(hostName, wsRequest.url.host) + assertEquals(port, wsRequest.url.port) } @Test - fun `Client should connect to specified port`() { - val wsRequest = wsConnect {} + fun `Client should connect to specified path`() = runBlocking { + val wsRequest = mockWSConnect {} - assertEquals(port, wsRequest.url.port) + assertEquals(path, wsRequest.url.encodedPath) } @Test - fun `Client should connect to specified path`() { - val wsRequest = wsConnect {} + fun `Request headers should be honored`() = runBlocking { + val header1 = Pair("x-h1", "value1") + val header2 = Pair("x-h2", "value2") - assertEquals(path, wsRequest.url.fullPath) + val wsRequest = mockWSConnect(headers = listOf(header1, header2)) {} + + assertEquals(header1.second, wsRequest.headers[header1.first]) + assertEquals(header2.second, wsRequest.headers[header2.first]) } @Test fun `Specified block should be called`(): Unit = runBlocking { - setWebSocketListener(object : WebSocketListener() { - override fun onOpen(webSocket: WebSocket, response: Response) { - webSocket.close(1000, "No-op") - } - }) + setListenerActions(CloseConnectionAction()) val client = PoWebClient.initLocal(mockWebServer.port) var wasBlockRun = false @@ -181,156 +207,20 @@ class PoWebClientTest { assertTrue(wasBlockRun) } - private fun wsConnect( + private suspend fun mockWSConnect( useTls: Boolean = false, + headers: List>? = null, block: suspend DefaultClientWebSocketSession.() -> Unit ): HttpRequestData { val client = PoWebClient(hostName, port, useTls) - var connectionRequest: HttpRequestData? = null - client.ktorClient = HttpClient(MockEngine) { - install(WebSockets) - - engine { - addHandler { request -> - connectionRequest = request - error("Nothing to see here") - } - } - } - - assertThrows { runBlocking { client.wsConnect(path, block) } } - assertTrue(connectionRequest is HttpRequestData) - - return connectionRequest as HttpRequestData - } - } - - @Nested - inner class Handshake : WebSocketTestCase() { - private val nonce = "nonce".toByteArray() - private val challengeSerialized = Challenge(nonce).serialize().toByteString() + val ktorClientManager = MockKtorClientManager() + client.ktorClient = ktorClientManager.ktorClient - // Compute client on demand because getting the server port will start the server - private val client by lazy { PoWebClient.initLocal(mockWebServer.port) } - - private val signer = generateDummySigner() - - @AfterEach - fun closeClient() = client.ktorClient.close() - - @Test - fun `Getting an invalid challenge should result in an exception`() { - var closeCode: Int? = null - - setWebSocketListener(object : WebSocketListener() { - override fun onOpen(webSocket: WebSocket, response: Response) { - webSocket.send("Not a valid challenge") - } - - override fun onClosing(webSocket: WebSocket, code: Int, reason: String) { - closeCode = code - super.onClosing(webSocket, code, reason) - } - }) - - client.use { - val exception = assertThrows { - runBlocking { client.wsConnect("/") { handshake(arrayOf(signer)) } } - } - - assertEquals("Server sent an invalid handshake challenge", exception.message) - assertTrue(exception.cause is InvalidMessageException) + ktorClientManager.useClient { + client.wsConnect(path, headers, block) } - await().until { closeCode is Int } - assertEquals(CloseReason.Codes.VIOLATED_POLICY.code.toInt(), closeCode) - } - - @Test - fun `At least one nonce signer should be required`() { - var closeCode: Int? = null - - setWebSocketListener(object : WebSocketListener() { - override fun onOpen(webSocket: WebSocket, response: Response) { - webSocket.send(challengeSerialized) - } - - override fun onClosing(webSocket: WebSocket, code: Int, reason: String) { - closeCode = code - } - }) - - client.use { - val exception = assertThrows { - runBlocking { - client.wsConnect("/") { handshake(emptyArray()) } - } - } - - assertEquals("At least one nonce signer must be specified", exception.message) - } - await().until { closeCode is Int } - assertEquals(CloseReason.Codes.NORMAL.code.toInt(), closeCode) - } - - @Test - fun `Challenge nonce should be signed with each signer`() { - var response: tech.relaycorp.poweb.handshake.Response? = null - - setWebSocketListener(object : WebSocketListener() { - override fun onOpen(webSocket: WebSocket, response: Response) { - webSocket.send(challengeSerialized) - } - override fun onMessage(webSocket: WebSocket, bytes: ByteString) { - response = tech.relaycorp.poweb.handshake.Response.deserialize( - bytes.toByteArray() - ) - webSocket.close(1000, "") - } - }) - - val signer2 = generateDummySigner() - - client.use { - runBlocking { client.wsConnect("/") { handshake(arrayOf(signer, signer2)) } } - - await().until { response is tech.relaycorp.poweb.handshake.Response } - - val nonceSignatures = response!!.nonceSignatures - val signature1 = NonceSignature.deserialize(nonceSignatures[0]) - assertEquals(nonce.asList(), signature1.nonce.asList()) - assertEquals(signer.certificate, signature1.signerCertificate) - val signature2 = NonceSignature.deserialize(nonceSignatures[1]) - assertEquals(nonce.asList(), signature2.nonce.asList()) - assertEquals(signer2.certificate, signature2.signerCertificate) - } - } - - private fun generateDummySigner(): NonceSigner { - val keyPair = generateRSAKeyPair() - val certificate = issueEndpointCertificate( - keyPair.public, - keyPair.private, - ZonedDateTime.now().plusDays(1)) - return NonceSigner(certificate, keyPair.private) - } - } - - @Nested - inner class ParcelCollection { - @Test - @Disabled - fun `Request should be made to the parcel collection endpoint`() { - } - - @Test - @Disabled - fun `Call should return if server closed connection normally`() { - } - - @Test - @Disabled - fun `An exception should be thrown if the server closes the connection with an error`() { + return ktorClientManager.request } } } diff --git a/src/test/kotlin/tech/relaycorp/poweb/handshake/ChallengeTest.kt b/src/test/kotlin/tech/relaycorp/poweb/handshake/ChallengeTest.kt index 9c1a7ee..fc8e333 100644 --- a/src/test/kotlin/tech/relaycorp/poweb/handshake/ChallengeTest.kt +++ b/src/test/kotlin/tech/relaycorp/poweb/handshake/ChallengeTest.kt @@ -29,7 +29,7 @@ class ChallengeTest { fun `Invalid serialization should be refused`() { val serialization = "This is invalid".toByteArray() - val exception = assertThrows { + val exception = assertThrows { Challenge.deserialize(serialization) } diff --git a/src/test/kotlin/tech/relaycorp/poweb/handshake/NonceSignerTest.kt b/src/test/kotlin/tech/relaycorp/poweb/handshake/NonceSignerTest.kt deleted file mode 100644 index 7c1938d..0000000 --- a/src/test/kotlin/tech/relaycorp/poweb/handshake/NonceSignerTest.kt +++ /dev/null @@ -1,35 +0,0 @@ -package tech.relaycorp.poweb.handshake - -import org.junit.jupiter.api.Test -import tech.relaycorp.relaynet.issueEndpointCertificate -import tech.relaycorp.relaynet.messages.control.NonceSignature -import tech.relaycorp.relaynet.wrappers.generateRSAKeyPair -import java.time.ZonedDateTime -import kotlin.test.assertEquals - -class NonceSignerTest { - private val nonce = "The nonce".toByteArray() - private val keyPair = generateRSAKeyPair() - private val certificate = issueEndpointCertificate( - keyPair.public, - keyPair.private, - ZonedDateTime.now().plusDays(1) - ) - private val signer = NonceSigner(certificate, keyPair.private) - - @Test - fun `Nonce should be honored`() { - val serialization = signer.sign(nonce) - - val signature = NonceSignature.deserialize(serialization) - assertEquals(nonce.asList(), signature.nonce.asList()) - } - - @Test - fun `Signer certificate should be honored`() { - val serialization = signer.sign(nonce) - - val signature = NonceSignature.deserialize(serialization) - assertEquals(certificate, signature.signerCertificate) - } -} diff --git a/src/test/kotlin/tech/relaycorp/poweb/handshake/ResponseTest.kt b/src/test/kotlin/tech/relaycorp/poweb/handshake/ResponseTest.kt index dddb1d9..e369fe6 100644 --- a/src/test/kotlin/tech/relaycorp/poweb/handshake/ResponseTest.kt +++ b/src/test/kotlin/tech/relaycorp/poweb/handshake/ResponseTest.kt @@ -58,7 +58,7 @@ class ResponseTest { fun `Invalid serialization should be refused`() { val serialization = "This is invalid".toByteArray() - val exception = assertThrows { + val exception = assertThrows { Response.deserialize(serialization) } diff --git a/src/test/kotlin/tech/relaycorp/poweb/websocket/MockKtorClientManager.kt b/src/test/kotlin/tech/relaycorp/poweb/websocket/MockKtorClientManager.kt new file mode 100644 index 0000000..a1fd4e7 --- /dev/null +++ b/src/test/kotlin/tech/relaycorp/poweb/websocket/MockKtorClientManager.kt @@ -0,0 +1,48 @@ +package tech.relaycorp.poweb.websocket + +import io.ktor.client.HttpClient +import io.ktor.client.engine.mock.MockEngine +import io.ktor.client.features.websocket.WebSockets +import io.ktor.client.request.HttpRequestData +import io.ktor.util.KtorExperimentalAPI +import kotlin.test.junit5.JUnit5Asserter.fail + +/** + * Workaround to use Ktor's MockEngine with a WebSocket connection, which is currently unsupported. + * + * We're not actually implementing a mock WebSocket server with this: We're just recording the + * requests made. + */ +@KtorExperimentalAPI +class MockKtorClientManager { + private val requests = mutableListOf() + + val ktorClient = mockClient(requests) + + val request: HttpRequestData + get() = requests.single() + + suspend fun useClient(block: suspend () -> Unit) { + try { + block() + } catch (_: SkipHandlerException) { + return + } + fail("Mock handler was not reached") + } + + companion object { + private fun mockClient(requests: MutableList) = HttpClient(MockEngine) { + install(WebSockets) + + engine { + addHandler { request -> + requests.add(request) + throw SkipHandlerException() + } + } + } + + private class SkipHandlerException : Exception() + } +} diff --git a/src/test/kotlin/tech/relaycorp/poweb/websocket/MockWebSocketAction.kt b/src/test/kotlin/tech/relaycorp/poweb/websocket/MockWebSocketAction.kt new file mode 100644 index 0000000..a94df19 --- /dev/null +++ b/src/test/kotlin/tech/relaycorp/poweb/websocket/MockWebSocketAction.kt @@ -0,0 +1,59 @@ +package tech.relaycorp.poweb.websocket + +import io.ktor.http.cio.websocket.CloseReason +import okhttp3.WebSocket +import okhttp3.mockwebserver.MockWebServer +import okio.ByteString.Companion.toByteString +import tech.relaycorp.poweb.handshake.Challenge +import tech.relaycorp.relaynet.messages.control.ParcelDelivery + +sealed class MockWebSocketAction { + var wasRun = false + + open fun run(webSocket: WebSocket, mockWebServer: MockWebServer) { + wasRun = true + } +} + +open class SendBinaryMessageAction(private val message: ByteArray) : MockWebSocketAction() { + override fun run(webSocket: WebSocket, mockWebServer: MockWebServer) { + webSocket.send(message.toByteString()) + super.run(webSocket, mockWebServer) + } +} + +class SendTextMessageAction(private val message: String) : MockWebSocketAction() { + override fun run(webSocket: WebSocket, mockWebServer: MockWebServer) { + webSocket.send(message) + super.run(webSocket, mockWebServer) + } +} + +class ChallengeAction(nonce: ByteArray) : SendBinaryMessageAction(Challenge(nonce).serialize()) + +class ParcelDeliveryAction(deliveryId: String, parcelSerialized: ByteArray) : + SendBinaryMessageAction(ParcelDelivery(deliveryId, parcelSerialized).serialize()) + +class CloseConnectionAction( + private val code: CloseReason.Codes = CloseReason.Codes.NORMAL, + private val reason: String? = null +) : MockWebSocketAction() { + override fun run(webSocket: WebSocket, mockWebServer: MockWebServer) { + webSocket.close(code.code.toInt(), reason) + super.run(webSocket, mockWebServer) + } +} + +class ServerShutdownAction : MockWebSocketAction() { + override fun run(webSocket: WebSocket, mockWebServer: MockWebServer) { + mockWebServer.shutdown() + super.run(webSocket, mockWebServer) + } +} + +class ActionSequence(private vararg val actions: MockWebSocketAction) : MockWebSocketAction() { + override fun run(webSocket: WebSocket, mockWebServer: MockWebServer) { + actions.forEach { it.run(webSocket, mockWebServer) } + super.run(webSocket, mockWebServer) + } +} diff --git a/src/test/kotlin/tech/relaycorp/poweb/websocket/MockWebSocketListener.kt b/src/test/kotlin/tech/relaycorp/poweb/websocket/MockWebSocketListener.kt new file mode 100644 index 0000000..b7ffcda --- /dev/null +++ b/src/test/kotlin/tech/relaycorp/poweb/websocket/MockWebSocketListener.kt @@ -0,0 +1,57 @@ +package tech.relaycorp.poweb.websocket + +import io.ktor.http.cio.websocket.CloseReason +import okhttp3.Request +import okhttp3.Response +import okhttp3.WebSocket +import okhttp3.WebSocketListener +import okhttp3.mockwebserver.MockWebServer +import okio.ByteString +import kotlin.test.assertFalse + +class MockWebSocketListener( + private val actions: MutableList, + private val mockWebServer: MockWebServer +) : WebSocketListener() { + var request: Request? = null + var connectionOpen = false + var connected = false + + val receivedMessages = mutableListOf() + + internal var closingCode: CloseReason.Codes? = null + internal var closingReason: String? = null + + override fun onOpen(webSocket: WebSocket, response: Response) { + assertFalse(connected, "Listener cannot be reused") + request = webSocket.request() + connectionOpen = true + connected = true + + runNextAction(webSocket) + } + + override fun onMessage(webSocket: WebSocket, bytes: ByteString) { + receivedMessages.add(bytes.toByteArray()) + + runNextAction(webSocket) + } + + override fun onMessage(webSocket: WebSocket, text: String) { + receivedMessages.add(text.toByteArray()) + + runNextAction(webSocket) + } + + private fun runNextAction(webSocket: WebSocket) { + val action = actions.removeFirst() + action.run(webSocket, mockWebServer) + } + + override fun onClosing(webSocket: WebSocket, code: Int, reason: String) { + closingCode = CloseReason.Codes.byCode(code.toShort()) + closingReason = reason + + connectionOpen = false + } +} diff --git a/src/test/kotlin/tech/relaycorp/poweb/WebSocketTestCase.kt b/src/test/kotlin/tech/relaycorp/poweb/websocket/WebSocketTestCase.kt similarity index 57% rename from src/test/kotlin/tech/relaycorp/poweb/WebSocketTestCase.kt rename to src/test/kotlin/tech/relaycorp/poweb/websocket/WebSocketTestCase.kt index 052b687..b9836e0 100644 --- a/src/test/kotlin/tech/relaycorp/poweb/WebSocketTestCase.kt +++ b/src/test/kotlin/tech/relaycorp/poweb/websocket/WebSocketTestCase.kt @@ -1,20 +1,25 @@ -package tech.relaycorp.poweb +package tech.relaycorp.poweb.websocket -import okhttp3.WebSocketListener import okhttp3.mockwebserver.MockResponse import okhttp3.mockwebserver.MockWebServer +import org.awaitility.Awaitility.await import org.junit.jupiter.api.AfterEach import org.junit.jupiter.api.BeforeEach import java.io.IOException +import kotlin.test.assertTrue open class WebSocketTestCase(private val autoStartServer: Boolean = true) { protected val mockWebServer = MockWebServer() + protected var listener: MockWebSocketListener? = null + @BeforeEach fun startServer() { if (autoStartServer) { mockWebServer.start() } + + listener = null } @AfterEach @@ -28,7 +33,16 @@ open class WebSocketTestCase(private val autoStartServer: Boolean = true) { } } - protected fun setWebSocketListener(listener: WebSocketListener) { - mockWebServer.enqueue(MockResponse().withWebSocketUpgrade(listener)) + protected fun setListenerActions(vararg actions: MockWebSocketAction) { + listener = MockWebSocketListener(actions.toMutableList(), mockWebServer) + mockWebServer.enqueue(MockResponse().withWebSocketUpgrade(listener!!)) + } + + /** + * Wait until the connection to the server has been closed. + */ + protected fun awaitForConnectionClosure() { + assertTrue(listener!!.connected, "The server must've got at least one connection") + await().until { !listener!!.connectionOpen } } }