diff --git a/build.gradle b/build.gradle index 7034465..912c4be 100644 --- a/build.gradle +++ b/build.gradle @@ -3,7 +3,7 @@ import org.jetbrains.kotlin.gradle.tasks.KotlinCompile buildscript { ext.kotlinVersion = '1.5.20' ext.kotlinCoroutinesVersion = '1.5.0' - ext.ktorVersion = '1.4.1' + ext.ktorVersion = '1.6.1' ext.okhttpVersion = '4.9.1' } @@ -62,7 +62,6 @@ dependencies { testImplementation("org.jetbrains.kotlin:kotlin-test-junit5") testImplementation("com.nhaarman.mockitokotlin2:mockito-kotlin:2.2.0") testImplementation("org.mockito:mockito-inline:3.11.2") - testImplementation("org.jetbrains.kotlinx:kotlinx-coroutines-test:$kotlinCoroutinesVersion") } java { diff --git a/src/main/kotlin/tech/relaycorp/poweb/PoWebClient.kt b/src/main/kotlin/tech/relaycorp/poweb/PoWebClient.kt index e2921c9..c9286a8 100644 --- a/src/main/kotlin/tech/relaycorp/poweb/PoWebClient.kt +++ b/src/main/kotlin/tech/relaycorp/poweb/PoWebClient.kt @@ -3,6 +3,9 @@ package tech.relaycorp.poweb import io.ktor.client.HttpClient import io.ktor.client.engine.HttpClientEngine import io.ktor.client.engine.okhttp.OkHttp +import io.ktor.client.features.ClientRequestException +import io.ktor.client.features.RedirectResponseException +import io.ktor.client.features.ServerResponseException import io.ktor.client.features.websocket.DefaultClientWebSocketSession import io.ktor.client.features.websocket.WebSockets import io.ktor.client.features.websocket.webSocket @@ -19,7 +22,6 @@ import io.ktor.http.content.ByteArrayContent import io.ktor.http.content.OutgoingContent import io.ktor.http.content.TextContent import io.ktor.http.contentType -import io.ktor.util.KtorExperimentalAPI import io.ktor.util.toByteArray import kotlinx.coroutines.channels.ClosedReceiveChannelException import kotlinx.coroutines.flow.Flow @@ -62,7 +64,6 @@ import java.time.Duration * * The underlying connection is created lazily. */ -@OptIn(KtorExperimentalAPI::class) public class PoWebClient internal constructor( internal val hostName: String, internal val port: Int, @@ -81,7 +82,8 @@ public class PoWebClient internal constructor( private val urlScheme = if (useTls) "https" else "http" private val wsScheme = if (useTls) "wss" else "ws" - internal val baseURL: String = "$urlScheme://$hostName:$port/v1" + internal val baseHttpUrl: String = "$urlScheme://$hostName:$port/v1" + internal val baseWsUrl: String = "$wsScheme://$hostName:$port/v1" /** * Close the underlying connection to the server (if any). @@ -250,8 +252,9 @@ public class PoWebClient internal constructor( requestBody: OutgoingContent, authorizationHeader: String? = null ): HttpResponse { - val url = "$baseURL$path" - val response: HttpResponse = try { + val url = "$baseHttpUrl$path" + + return try { ktorClient.post(url) { if (authorizationHeader != null) { header("Authorization", authorizationHeader) @@ -259,20 +262,20 @@ public class PoWebClient internal constructor( body = requestBody } } catch (exc: UnknownHostException) { - throw ServerConnectionException("Failed to resolve DNS for $baseURL", exc) + throw ServerConnectionException("Failed to resolve DNS for $baseHttpUrl", exc) } catch (exc: IOException) { throw ServerConnectionException("Failed to connect to $url", exc) - } - - if (response.status.value in 200..299) { - return response - } - throw when (response.status.value) { - in 400..499 -> PoWebClientException(response.status) - in 500..599 -> ServerConnectionException( - "The server was unable to fulfil the request (${response.status})" + } catch (exc: RedirectResponseException) { + // HTTP 3XX response + throw ServerBindingException("Unexpected redirect (${exc.response.status})") + } catch (exc: ClientRequestException) { + // HTTP 4XX response + throw PoWebClientException(exc.response.status) + } catch (exc: ServerResponseException) { + // HTTP 5XX response + throw ServerConnectionException( + "The server was unable to fulfil the request (${exc.response.status})" ) - else -> ServerBindingException("Received unexpected status (${response.status})") } } @@ -293,7 +296,7 @@ public class PoWebClient internal constructor( block: suspend DefaultClientWebSocketSession.() -> Unit ) = try { ktorClient.webSocket( - "$wsScheme://$hostName:$port$path", + "$baseWsUrl$path", { headers?.forEach { header(it.first, it.second) } }, block ) @@ -304,7 +307,7 @@ public class PoWebClient internal constructor( } public companion object { - internal const val PARCEL_COLLECTION_ENDPOINT_PATH = "/v1/parcel-collection" + internal const val PARCEL_COLLECTION_ENDPOINT_PATH = "/parcel-collection" private const val DEFAULT_LOCAL_PORT = 276 private const val DEFAULT_REMOTE_PORT = 443 diff --git a/src/test/kotlin/tech/relaycorp/poweb/ParcelCollectionTest.kt b/src/test/kotlin/tech/relaycorp/poweb/ParcelCollectionTest.kt index 39f79b0..21a98d6 100644 --- a/src/test/kotlin/tech/relaycorp/poweb/ParcelCollectionTest.kt +++ b/src/test/kotlin/tech/relaycorp/poweb/ParcelCollectionTest.kt @@ -1,7 +1,6 @@ package tech.relaycorp.poweb import io.ktor.http.cio.websocket.CloseReason -import io.ktor.util.KtorExperimentalAPI import kotlinx.coroutines.channels.ClosedReceiveChannelException import kotlinx.coroutines.flow.collect import kotlinx.coroutines.flow.first @@ -15,7 +14,6 @@ import org.junit.jupiter.api.assertThrows import tech.relaycorp.poweb.websocket.ActionSequence import tech.relaycorp.poweb.websocket.ChallengeAction import tech.relaycorp.poweb.websocket.CloseConnectionAction -import tech.relaycorp.poweb.websocket.MockKtorClientManager import tech.relaycorp.poweb.websocket.ParcelDeliveryAction import tech.relaycorp.poweb.websocket.SendTextMessageAction import tech.relaycorp.poweb.websocket.WebSocketTestCase @@ -36,7 +34,6 @@ import kotlin.test.assertEquals import kotlin.test.assertFalse import kotlin.test.assertTrue -@KtorExperimentalAPI class ParcelCollectionTest : WebSocketTestCase() { private val nonce = "nonce".toByteArray() @@ -53,18 +50,19 @@ class ParcelCollectionTest : WebSocketTestCase() { fun closeClient() = client.ktorClient.close() @Test - fun `Request should be made to the parcel collection endpoint`() = runBlocking { - val mockClient = PoWebClient.initLocal() - val ktorClientManager = MockKtorClientManager() - mockClient.ktorClient = ktorClientManager.ktorClient + fun `Request should be made to the parcel collection endpoint`() { + val client = PoWebClient.initLocal(mockWebServer.port) + client.ktorClient = ktorWSClient + setListenerActions(CloseConnectionAction()) - ktorClientManager.useClient { - mockClient.collectParcels(arrayOf(signer)).toList() + assertThrows { + runBlocking { client.collectParcels(arrayOf(signer)).toList() } } + val request = mockWebServer.takeRequest() assertEquals( - PoWebClient.PARCEL_COLLECTION_ENDPOINT_PATH, - ktorClientManager.request.url.encodedPath + "/v1${PoWebClient.PARCEL_COLLECTION_ENDPOINT_PATH}", + request.path ) } diff --git a/src/test/kotlin/tech/relaycorp/poweb/ParcelDeliveryTest.kt b/src/test/kotlin/tech/relaycorp/poweb/ParcelDeliveryTest.kt index 03158c0..90fd7e8 100644 --- a/src/test/kotlin/tech/relaycorp/poweb/ParcelDeliveryTest.kt +++ b/src/test/kotlin/tech/relaycorp/poweb/ParcelDeliveryTest.kt @@ -7,9 +7,7 @@ import io.ktor.client.request.HttpRequestData import io.ktor.http.HttpMethod import io.ktor.http.HttpStatusCode import io.ktor.http.content.OutgoingContent -import io.ktor.util.KtorExperimentalAPI -import kotlinx.coroutines.ExperimentalCoroutinesApi -import kotlinx.coroutines.test.runBlockingTest +import kotlinx.coroutines.runBlocking import org.bouncycastle.util.encoders.Base64 import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows @@ -24,15 +22,13 @@ import kotlin.test.assertEquals import kotlin.test.assertNotNull import kotlin.test.assertTrue -@ExperimentalCoroutinesApi -@KtorExperimentalAPI class ParcelDeliveryTest { private val parcelSerialized = "Let's say I'm the serialization of a parcel".toByteArray() private val signer = Signer(PDACertPath.PRIVATE_ENDPOINT, KeyPairSet.PRIVATE_ENDPOINT.private) @Test - fun `Request should be made with HTTP POST`() = runBlockingTest { + fun `Request should be made with HTTP POST`() = runBlocking { var method: HttpMethod? = null val client = makeTestClient { request: HttpRequestData -> method = request.method @@ -45,7 +41,7 @@ class ParcelDeliveryTest { } @Test - fun `Endpoint should be the one for parcels`() = runBlockingTest { + fun `Endpoint should be the one for parcels`() = runBlocking { var endpointURL: String? = null val client = makeTestClient { request: HttpRequestData -> endpointURL = request.url.toString() @@ -54,11 +50,11 @@ class ParcelDeliveryTest { client.use { client.deliverParcel(parcelSerialized, signer) } - assertEquals("${client.baseURL}/parcels", endpointURL) + assertEquals("${client.baseHttpUrl}/parcels", endpointURL) } @Test - fun `Request content type should be the appropriate value`() = runBlockingTest { + fun `Request content type should be the appropriate value`() = runBlocking { var contentType: String? = null val client = makeTestClient { request: HttpRequestData -> contentType = request.body.contentType.toString() @@ -71,7 +67,7 @@ class ParcelDeliveryTest { } @Test - fun `Request body should be the parcel serialized`() = runBlockingTest { + fun `Request body should be the parcel serialized`() = runBlocking { var requestBody: ByteArray? = null val client = makeTestClient { request: HttpRequestData -> assertTrue(request.body is OutgoingContent.ByteArrayContent) @@ -85,7 +81,7 @@ class ParcelDeliveryTest { } @Test - fun `Delivery signature should be in the request headers`() = runBlockingTest { + fun `Delivery signature should be in the request headers`(): Unit = runBlocking { var authorizationHeader: String? = null val client = makeTestClient { request: HttpRequestData -> authorizationHeader = request.headers["Authorization"] @@ -106,7 +102,7 @@ class ParcelDeliveryTest { } @Test - fun `HTTP 20X should be regarded a successful delivery`() = runBlockingTest { + fun `HTTP 20X should be regarded a successful delivery`() = runBlocking { val client = makeTestClient { respond("", HttpStatusCode.Accepted) } client.use { client.deliverParcel(parcelSerialized, signer) } @@ -118,7 +114,7 @@ class ParcelDeliveryTest { client.use { val exception = assertThrows { - runBlockingTest { client.deliverParcel(parcelSerialized, signer) } + runBlocking { client.deliverParcel(parcelSerialized, signer) } } assertEquals("The server rejected the parcel", exception.message) @@ -132,7 +128,7 @@ class ParcelDeliveryTest { client.use { val exception = assertThrows { - runBlockingTest { client.deliverParcel(parcelSerialized, signer) } + runBlocking { client.deliverParcel(parcelSerialized, signer) } } assertEquals("The server returned a $status response", exception.message) diff --git a/src/test/kotlin/tech/relaycorp/poweb/PoWebClientTest.kt b/src/test/kotlin/tech/relaycorp/poweb/PoWebClientTest.kt index 6a69469..b90c7b7 100644 --- a/src/test/kotlin/tech/relaycorp/poweb/PoWebClientTest.kt +++ b/src/test/kotlin/tech/relaycorp/poweb/PoWebClientTest.kt @@ -11,19 +11,14 @@ import io.ktor.client.request.HttpRequestData import io.ktor.http.ContentType import io.ktor.http.HttpMethod import io.ktor.http.HttpStatusCode -import io.ktor.http.URLProtocol import io.ktor.http.content.ByteArrayContent import io.ktor.http.content.OutgoingContent import io.ktor.util.InternalAPI -import io.ktor.util.KtorExperimentalAPI -import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.runBlocking -import kotlinx.coroutines.test.runBlockingTest import org.junit.jupiter.api.Nested import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows import tech.relaycorp.poweb.websocket.CloseConnectionAction -import tech.relaycorp.poweb.websocket.MockKtorClientManager import tech.relaycorp.poweb.websocket.ServerShutdownAction import tech.relaycorp.poweb.websocket.WebSocketTestCase import tech.relaycorp.relaynet.bindings.pdc.ServerBindingException @@ -38,8 +33,6 @@ import kotlin.test.assertFalse import kotlin.test.assertNull import kotlin.test.assertTrue -@ExperimentalCoroutinesApi -@KtorExperimentalAPI @Suppress("RedundantInnerClassModifier") class PoWebClientTest { @Nested @@ -79,7 +72,7 @@ class PoWebClientTest { fun `Correct HTTP URL should be set when not using TLS`() { val client = PoWebClient.initLocal() - assertEquals("http://127.0.0.1:276/v1", client.baseURL) + assertEquals("http://127.0.0.1:276/v1", client.baseHttpUrl) } @InternalAPI @@ -128,7 +121,7 @@ class PoWebClientTest { fun `Correct HTTPS URL should be set when using TLS`() { val client = PoWebClient.initRemote(hostName) - assertEquals("https://$hostName:443/v1", client.baseURL) + assertEquals("https://$hostName:443/v1", client.baseHttpUrl) } @InternalAPI @@ -167,7 +160,7 @@ class PoWebClientTest { @Nested inner class Request { @Test - fun `Request should be made with HTTP POST`() = runBlockingTest { + fun `Request should be made with HTTP POST`() = runBlocking { var method: HttpMethod? = null val client = makeTestClient { request: HttpRequestData -> method = request.method @@ -180,7 +173,7 @@ class PoWebClientTest { } @Test - fun `Specified path should be honored`() = runBlockingTest { + fun `Specified path should be honored`() = runBlocking { var endpointURL: String? = null val client = makeTestClient { request: HttpRequestData -> endpointURL = request.url.toString() @@ -189,11 +182,11 @@ class PoWebClientTest { client.use { client.post(path, body) } - assertEquals("${client.baseURL}$path", endpointURL) + assertEquals("${client.baseHttpUrl}$path", endpointURL) } @Test - fun `Specified Content-Type should be honored`() = runBlockingTest { + fun `Specified Content-Type should be honored`() = runBlocking { var contentType: String? = null val client = makeTestClient { request: HttpRequestData -> contentType = request.body.contentType.toString() @@ -206,7 +199,7 @@ class PoWebClientTest { } @Test - fun `Request body should be the parcel serialized`() = runBlockingTest { + fun `Request body should be the parcel serialized`() = runBlocking { var requestBody: ByteArray? = null val client = makeTestClient { request: HttpRequestData -> assertTrue(request.body is OutgoingContent.ByteArrayContent) @@ -220,7 +213,7 @@ class PoWebClientTest { } @Test - fun `No Authorization header should be set by default`() = runBlockingTest { + fun `No Authorization header should be set by default`() = runBlocking { var authorizationHeader: String? = null val client = makeTestClient { request: HttpRequestData -> authorizationHeader = request.headers["Authorization"] @@ -233,7 +226,7 @@ class PoWebClientTest { } @Test - fun `Authorization should be set if requested`() = runBlockingTest { + fun `Authorization should be set if requested`() = runBlocking { val expectedAuthorizationHeader = "Foo bar" var actualAuthorizationHeader: String? = null val client = makeTestClient { request: HttpRequestData -> @@ -250,7 +243,7 @@ class PoWebClientTest { @Nested inner class Response { @Test - fun `HTTP 20X should be regarded a successful delivery`() = runBlockingTest { + fun `HTTP 20X should be regarded a successful delivery`(): Unit = runBlocking { val client = makeTestClient { respond("", HttpStatusCode.Accepted) } client.use { client.post(path, body) } @@ -262,11 +255,11 @@ class PoWebClientTest { client.use { val exception = assertThrows { - runBlockingTest { client.post(path, body) } + runBlocking { client.post(path, body) } } assertEquals( - "Received unexpected status (${HttpStatusCode.Found})", + "Unexpected redirect (${HttpStatusCode.Found})", exception.message ) } @@ -279,7 +272,7 @@ class PoWebClientTest { client.use { val exception = assertThrows { - runBlockingTest { client.post(path, body) } + runBlocking { client.post(path, body) } } assertEquals(status, exception.responseStatus) @@ -292,7 +285,7 @@ class PoWebClientTest { client.use { val exception = assertThrows { - runBlockingTest { client.post(path, body) } + runBlocking { client.post(path, body) } } assertEquals( @@ -313,7 +306,7 @@ class PoWebClientTest { runBlocking { client.post(path, body) } } - assertEquals("Failed to resolve DNS for ${client.baseURL}", exception.message) + assertEquals("Failed to resolve DNS for ${client.baseHttpUrl}", exception.message) assertTrue(exception.cause is UnknownHostException) } } @@ -327,7 +320,7 @@ class PoWebClientTest { runBlocking { client.post(path, body) } } - assertEquals("Failed to connect to ${client.baseURL}$path", exception.message) + assertEquals("Failed to connect to ${client.baseHttpUrl}$path", exception.message) assertTrue(exception.cause is SocketException) } } @@ -336,8 +329,7 @@ class PoWebClientTest { @Nested inner class WebSocketConnection : WebSocketTestCase(false) { private val hostName = "127.0.0.1" - private val port = 13276 - private val path = "/v1/the-endpoint" + private val path = "/the-endpoint" @Test fun `Failing to connect to the server should throw an exception`() { @@ -370,88 +362,75 @@ class PoWebClientTest { @Test fun `Losing the connection abruptly should throw an exception`(): Unit = runBlocking { - val client = PoWebClient.initLocal(mockWebServer.port) setListenerActions(ServerShutdownAction()) - client.use { - val exception = assertThrows { - runBlocking { - client.wsConnect(path) { - incoming.receive() - } - } + val exception = assertThrows { + mockWSConnect { + incoming.receive() } - - assertEquals("Connection was closed abruptly", exception.message) - assertTrue(exception.cause is EOFException) } - } - - @Test - fun `Client should use WS if TLS is not required`() = runBlockingTest { - val wsRequest = mockWSConnect(false) {} - assertEquals(URLProtocol.WS, wsRequest.url.protocol) + assertEquals("Connection was closed abruptly", exception.message) + assertTrue(exception.cause is EOFException) } @Test - fun `Client should use WSS if TLS is required`() = runBlockingTest { - val wsRequest = mockWSConnect(true) {} + fun `Client should use WS if TLS is not required`() = runBlocking { + val client = PoWebClient(hostName, mockWebServer.port, false) - assertEquals(URLProtocol.WSS, wsRequest.url.protocol) + assertTrue(client.baseWsUrl.startsWith("ws:"), "Actual URL: ${client.baseWsUrl}") } @Test - fun `Client should connect to specified host and port`(): Unit = runBlockingTest { - val wsRequest = mockWSConnect(true) {} + fun `Client should use WSS if TLS is required`() = runBlocking { + val client = PoWebClient(hostName, mockWebServer.port, true) - assertEquals(hostName, wsRequest.url.host) - assertEquals(port, wsRequest.url.port) + assertTrue(client.baseWsUrl.startsWith("wss:")) } @Test fun `Client should connect to specified path`() = runBlocking { - val wsRequest = mockWSConnect {} + setListenerActions(CloseConnectionAction()) + + mockWSConnect {} - assertEquals(path, wsRequest.url.encodedPath) + val request = mockWebServer.takeRequest() + assertEquals("/v1$path", request.path) } @Test fun `Request headers should be honored`() = runBlocking { val header1 = Pair("x-h1", "value1") val header2 = Pair("x-h2", "value2") + setListenerActions(CloseConnectionAction()) - val wsRequest = mockWSConnect(headers = listOf(header1, header2)) {} + mockWSConnect(listOf(header1, header2)) {} - assertEquals(header1.second, wsRequest.headers[header1.first]) - assertEquals(header2.second, wsRequest.headers[header2.first]) + val request = mockWebServer.takeRequest() + assertEquals(header1.second, request.headers[header1.first]) + assertEquals(header2.second, request.headers[header2.first]) } @Test fun `Specified block should be called`(): Unit = runBlocking { setListenerActions(CloseConnectionAction()) - val client = PoWebClient.initLocal(mockWebServer.port) var wasBlockRun = false - client.wsConnect(path) { wasBlockRun = true } + mockWSConnect { wasBlockRun = true } assertTrue(wasBlockRun) } private suspend fun mockWSConnect( - useTls: Boolean = false, headers: List>? = null, block: suspend DefaultClientWebSocketSession.() -> Unit - ): HttpRequestData { - val client = PoWebClient(hostName, port, useTls) - val ktorClientManager = MockKtorClientManager() - client.ktorClient = ktorClientManager.ktorClient + ) { + val client = PoWebClient(hostName, mockWebServer.port, false) + client.ktorClient = ktorWSClient - ktorClientManager.useClient { + client.use { client.wsConnect(path, headers, block) } - - return ktorClientManager.request } } } diff --git a/src/test/kotlin/tech/relaycorp/poweb/RegistrationTest.kt b/src/test/kotlin/tech/relaycorp/poweb/RegistrationTest.kt index 6a7a304..8d5f63b 100644 --- a/src/test/kotlin/tech/relaycorp/poweb/RegistrationTest.kt +++ b/src/test/kotlin/tech/relaycorp/poweb/RegistrationTest.kt @@ -7,9 +7,7 @@ import io.ktor.http.HttpMethod import io.ktor.http.HttpStatusCode import io.ktor.http.content.OutgoingContent import io.ktor.http.headersOf -import io.ktor.util.KtorExperimentalAPI -import kotlinx.coroutines.ExperimentalCoroutinesApi -import kotlinx.coroutines.test.runBlockingTest +import kotlinx.coroutines.runBlocking import org.junit.jupiter.api.Nested import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows @@ -25,8 +23,6 @@ import kotlin.test.assertEquals import kotlin.test.assertTrue @Suppress("RedundantInnerClassModifier") -@ExperimentalCoroutinesApi -@KtorExperimentalAPI class RegistrationTest { @Nested inner class PreRegistration { @@ -35,7 +31,7 @@ class RegistrationTest { headersOf("Content-Type", ContentTypes.NODE_REGISTRATION_AUTHORIZATION.value) @Test - fun `Request method should be POST`() = runBlockingTest { + fun `Request method should be POST`() = runBlocking { var method: HttpMethod? = null val client = makeTestClient { request: HttpRequestData -> method = request.method @@ -48,7 +44,7 @@ class RegistrationTest { } @Test - fun `Request should be made to the appropriate endpoint`() = runBlockingTest { + fun `Request should be made to the appropriate endpoint`() = runBlocking { var endpointURL: String? = null val client = makeTestClient { request: HttpRequestData -> endpointURL = request.url.toString() @@ -57,11 +53,11 @@ class RegistrationTest { client.use { client.preRegisterNode(publicKey) } - assertEquals("${client.baseURL}/pre-registrations", endpointURL) + assertEquals("${client.baseHttpUrl}/pre-registrations", endpointURL) } @Test - fun `Request Content-Type should be plain text`() = runBlockingTest { + fun `Request Content-Type should be plain text`() = runBlocking { var contentType: ContentType? = null val client = makeTestClient { request: HttpRequestData -> contentType = request.body.contentType @@ -74,7 +70,7 @@ class RegistrationTest { } @Test - fun `Request body should be SHA-256 digest of the node public key`() = runBlockingTest { + fun `Request body should be SHA-256 digest of the node public key`() = runBlocking { var requestBody: ByteArray? = null val client = makeTestClient { request: HttpRequestData -> assertTrue(request.body is OutgoingContent.ByteArrayContent) @@ -101,7 +97,7 @@ class RegistrationTest { } val exception = assertThrows { - runBlockingTest { + runBlocking { client.use { client.preRegisterNode(publicKey) } } } @@ -120,7 +116,7 @@ class RegistrationTest { } val exception = assertThrows { - runBlockingTest { + runBlocking { client.use { client.preRegisterNode(publicKey) } } } @@ -133,7 +129,7 @@ class RegistrationTest { @Test fun `Registration request should be output if pre-registration succeeds`() = - runBlockingTest { + runBlocking { val authorizationSerialized = "This is the PNRA".toByteArray() val client = makeTestClient { respond(authorizationSerialized, headers = responseHeaders) @@ -161,7 +157,7 @@ class RegistrationTest { private val registrationSerialized = registration.serialize() @Test - fun `Request method should be POST`() = runBlockingTest { + fun `Request method should be POST`() = runBlocking { var method: HttpMethod? = null val client = makeTestClient { request: HttpRequestData -> method = request.method @@ -174,7 +170,7 @@ class RegistrationTest { } @Test - fun `Request should be made to the appropriate endpoint`() = runBlockingTest { + fun `Request should be made to the appropriate endpoint`() = runBlocking { var endpointURL: String? = null val client = makeTestClient { request: HttpRequestData -> endpointURL = request.url.toString() @@ -183,11 +179,11 @@ class RegistrationTest { client.use { client.registerNode(pnrrSerialized) } - assertEquals("${client.baseURL}/nodes", endpointURL) + assertEquals("${client.baseHttpUrl}/nodes", endpointURL) } @Test - fun `Request Content-Type should be a PNRR`() = runBlockingTest { + fun `Request Content-Type should be a PNRR`() = runBlocking { var contentType: ContentType? = null val client = makeTestClient { request: HttpRequestData -> contentType = request.body.contentType @@ -200,7 +196,7 @@ class RegistrationTest { } @Test - fun `Request body should be the PNRR serialized`() = runBlockingTest { + fun `Request body should be the PNRR serialized`() = runBlocking { var requestBody: ByteArray? = null val client = makeTestClient { request: HttpRequestData -> assertTrue(request.body is OutgoingContent.ByteArrayContent) @@ -224,7 +220,7 @@ class RegistrationTest { } val exception = assertThrows { - runBlockingTest { + runBlocking { client.use { client.registerNode(pnrrSerialized) } } } @@ -242,7 +238,7 @@ class RegistrationTest { } val exception = assertThrows { - runBlockingTest { + runBlocking { client.use { client.registerNode(pnrrSerialized) } } } @@ -252,13 +248,13 @@ class RegistrationTest { } @Test - fun `Exception should be thrown if server reports we violated binding`() = runBlockingTest { + fun `Exception should be thrown if server reports we violated binding`() = runBlocking { val client = makeTestClient { respond("{}", status = HttpStatusCode.Forbidden) } val exception = assertThrows { - runBlockingTest { + runBlocking { client.use { client.registerNode(pnrrSerialized) } } } @@ -270,7 +266,7 @@ class RegistrationTest { } @Test - fun `Registration should be output if request succeeds`() = runBlockingTest { + fun `Registration should be output if request succeeds`() = runBlocking { val client = makeTestClient { respond(registrationSerialized, headers = responseHeaders) } diff --git a/src/test/kotlin/tech/relaycorp/poweb/Utils.kt b/src/test/kotlin/tech/relaycorp/poweb/Utils.kt index 4a1aba2..8e51f17 100644 --- a/src/test/kotlin/tech/relaycorp/poweb/Utils.kt +++ b/src/test/kotlin/tech/relaycorp/poweb/Utils.kt @@ -2,12 +2,8 @@ package tech.relaycorp.poweb import io.ktor.client.engine.mock.MockEngine import io.ktor.client.engine.mock.MockRequestHandler -import io.ktor.util.KtorExperimentalAPI import java.security.MessageDigest -internal const val NON_ROUTABLE_IP_ADDRESS = "192.0.2.1" - -@KtorExperimentalAPI internal fun makeTestClient(handler: MockRequestHandler): PoWebClient { val ktorEngine = MockEngine.create { addHandler(handler) diff --git a/src/test/kotlin/tech/relaycorp/poweb/websocket/MockKtorClientManager.kt b/src/test/kotlin/tech/relaycorp/poweb/websocket/MockKtorClientManager.kt deleted file mode 100644 index a1fd4e7..0000000 --- a/src/test/kotlin/tech/relaycorp/poweb/websocket/MockKtorClientManager.kt +++ /dev/null @@ -1,48 +0,0 @@ -package tech.relaycorp.poweb.websocket - -import io.ktor.client.HttpClient -import io.ktor.client.engine.mock.MockEngine -import io.ktor.client.features.websocket.WebSockets -import io.ktor.client.request.HttpRequestData -import io.ktor.util.KtorExperimentalAPI -import kotlin.test.junit5.JUnit5Asserter.fail - -/** - * Workaround to use Ktor's MockEngine with a WebSocket connection, which is currently unsupported. - * - * We're not actually implementing a mock WebSocket server with this: We're just recording the - * requests made. - */ -@KtorExperimentalAPI -class MockKtorClientManager { - private val requests = mutableListOf() - - val ktorClient = mockClient(requests) - - val request: HttpRequestData - get() = requests.single() - - suspend fun useClient(block: suspend () -> Unit) { - try { - block() - } catch (_: SkipHandlerException) { - return - } - fail("Mock handler was not reached") - } - - companion object { - private fun mockClient(requests: MutableList) = HttpClient(MockEngine) { - install(WebSockets) - - engine { - addHandler { request -> - requests.add(request) - throw SkipHandlerException() - } - } - } - - private class SkipHandlerException : Exception() - } -} diff --git a/src/test/kotlin/tech/relaycorp/poweb/websocket/WebSocketTestCase.kt b/src/test/kotlin/tech/relaycorp/poweb/websocket/WebSocketTestCase.kt index b9836e0..1c4ea09 100644 --- a/src/test/kotlin/tech/relaycorp/poweb/websocket/WebSocketTestCase.kt +++ b/src/test/kotlin/tech/relaycorp/poweb/websocket/WebSocketTestCase.kt @@ -1,5 +1,10 @@ package tech.relaycorp.poweb.websocket +import io.ktor.client.HttpClient +import io.ktor.client.engine.HttpClientEngine +import io.ktor.client.engine.okhttp.OkHttp +import io.ktor.client.features.websocket.WebSockets +import okhttp3.OkHttpClient import okhttp3.mockwebserver.MockResponse import okhttp3.mockwebserver.MockWebServer import org.awaitility.Awaitility.await @@ -13,6 +18,13 @@ open class WebSocketTestCase(private val autoStartServer: Boolean = true) { protected var listener: MockWebSocketListener? = null + private val okhttpEngine: HttpClientEngine = OkHttp.create { + preconfigured = OkHttpClient.Builder().build() + } + protected val ktorWSClient = HttpClient(okhttpEngine) { + install(WebSockets) + } + @BeforeEach fun startServer() { if (autoStartServer) {