diff --git a/build.gradle b/build.gradle index 7887d5d..6df00dc 100644 --- a/build.gradle +++ b/build.gradle @@ -1,11 +1,12 @@ import org.jetbrains.kotlin.gradle.tasks.KotlinCompile buildscript { - ext { - kotlinVersion = '1.3.72' - protobufVersion = '3.12.2' - protobufGradleVersion = '0.8.12' - } + ext.kotlinVersion = '1.3.72' + ext.protobufVersion = '3.12.2' + ext.protobufGradleVersion = '0.8.12' + ext.kotlinCoroutinesVersion = '1.3.7' + ext.ktorVersion = '1.3.2' + ext.okhttpVersion = '4.8.0' } plugins { @@ -32,17 +33,29 @@ dependencies { implementation("org.jetbrains.kotlin:kotlin-stdlib-jdk8:$kotlinVersion") + api('tech.relaycorp:relaynet:1.16.1') + + // Handshake nonce signatures + implementation("org.bouncycastle:bcpkix-jdk15on:1.64") + + implementation("io.ktor:ktor-client-okhttp:$ktorVersion") + testImplementation("io.ktor:ktor-client-mock:$ktorVersion") + testImplementation("io.ktor:ktor-client-mock-jvm:$ktorVersion") + testImplementation("com.squareup.okhttp3:okhttp:$okhttpVersion") + testImplementation("com.squareup.okhttp3:mockwebserver:$okhttpVersion") + testImplementation("com.squareup.okio:okio:2.7.0") + // Protobuf - implementation "com.google.protobuf:protobuf-gradle-plugin:$protobufGradleVersion" - implementation "com.google.protobuf:protobuf-java:$protobufVersion" - implementation "com.google.protobuf:protobuf-java-util:$protobufVersion" + implementation("com.google.protobuf:protobuf-gradle-plugin:$protobufGradleVersion") + implementation("com.google.protobuf:protobuf-java:$protobufVersion") + implementation("com.google.protobuf:protobuf-java-util:$protobufVersion") testImplementation("org.jetbrains.kotlin:kotlin-test") - - // Use the Kotlin JUnit5 integration. testImplementation("org.junit.jupiter:junit-jupiter:5.6.2") testImplementation("org.junit.jupiter:junit-jupiter-params:5.6.2") testImplementation("org.jetbrains.kotlin:kotlin-test-junit5") + testImplementation("com.nhaarman.mockitokotlin2:mockito-kotlin:2.2.0") + testImplementation("org.mockito:mockito-inline:3.4.0") } java { @@ -50,18 +63,16 @@ java { withSourcesJar() } -tasks.withType(KotlinCompile).configureEach { - kotlinOptions.jvmTarget = "1.8" +tasks.withType(KotlinCompile).all { + kotlinOptions { + jvmTarget = JavaVersion.VERSION_1_8 + } } protobuf { protoc { artifact = "com.google.protobuf:protoc:$protobufVersion" } } -tasks.withType(KotlinCompile).configureEach { - kotlinOptions.jvmTarget = "1.8" -} - tasks.dokka { outputFormat = "html" outputDirectory = "$buildDir/docs/api" @@ -124,3 +135,17 @@ spotless { ktlint().userData(ktlintUserData) } } + +// Workaround for https://github.com/google/protobuf-gradle-plugin/issues/391 +configurations { + "compileProtoPath" { + attributes { + attribute(Usage.USAGE_ATTRIBUTE, objects.named(Usage, "java-runtime")) + } + } + "testCompileProtoPath" { + attributes { + attribute(Usage.USAGE_ATTRIBUTE, objects.named(Usage, "java-runtime")) + } + } +} diff --git a/jacoco.gradle b/jacoco.gradle index df8bc56..84a1d7c 100644 --- a/jacoco.gradle +++ b/jacoco.gradle @@ -41,7 +41,8 @@ jacocoTestCoverageVerification { limit { counter = "BRANCH" value = "MISSEDCOUNT" - maximum = "0".toBigDecimal() + // Workaround for https://github.com/jacoco/jacoco/issues/1036 + maximum = "1".toBigDecimal() } } } diff --git a/src/main/kotlin/tech/relaycorp/poweb/PoWebClient.kt b/src/main/kotlin/tech/relaycorp/poweb/PoWebClient.kt new file mode 100644 index 0000000..a57be0e --- /dev/null +++ b/src/main/kotlin/tech/relaycorp/poweb/PoWebClient.kt @@ -0,0 +1,67 @@ +package tech.relaycorp.poweb + +import io.ktor.client.HttpClient +import io.ktor.client.engine.okhttp.OkHttp +import io.ktor.client.features.websocket.DefaultClientWebSocketSession +import io.ktor.client.features.websocket.WebSockets +import io.ktor.client.features.websocket.webSocket +import io.ktor.http.cio.websocket.CloseReason +import io.ktor.http.cio.websocket.Frame +import io.ktor.http.cio.websocket.close +import io.ktor.http.cio.websocket.readBytes +import io.ktor.util.KtorExperimentalAPI +import tech.relaycorp.poweb.handshake.Challenge +import tech.relaycorp.poweb.handshake.InvalidMessageException +import tech.relaycorp.poweb.handshake.NonceSigner +import tech.relaycorp.poweb.handshake.Response +import java.io.Closeable + +class PoWebClient internal constructor( + internal val hostName: String, + internal val port: Int, + internal val useTls: Boolean +) : Closeable { + @KtorExperimentalAPI + internal var ktorClient = HttpClient(OkHttp) { + install(WebSockets) + } + + @KtorExperimentalAPI + override fun close() = ktorClient.close() + + private val wsUrl = "ws${if (useTls) "s" else ""}://$hostName:$port" + + @KtorExperimentalAPI + internal suspend fun wsConnect( + path: String, + block: suspend DefaultClientWebSocketSession.() -> Unit + ) = ktorClient.webSocket("${wsUrl}$path", block = block) + + companion object { + private const val defaultLocalPort = 276 + private const val defaultRemotePort = 443 + + fun initLocal(port: Int = defaultLocalPort) = + PoWebClient("127.0.0.1", port, false) + + fun initRemote(hostName: String, port: Int = defaultRemotePort) = + PoWebClient(hostName, port, true) + } +} + +@Throws(PoWebException::class) +internal suspend fun DefaultClientWebSocketSession.handshake(nonceSigners: Array) { + if (nonceSigners.isEmpty()) { + throw PoWebException("At least one nonce signer must be specified") + } + val challengeRaw = incoming.receive() + val challenge = try { + Challenge.deserialize(challengeRaw.readBytes()) + } catch (exc: InvalidMessageException) { + close(CloseReason(CloseReason.Codes.VIOLATED_POLICY, "")) + throw PoWebException("Server sent an invalid handshake challenge", exc) + } + val nonceSignatures = nonceSigners.map { it.sign(challenge.nonce) }.toTypedArray() + val response = Response(nonceSignatures) + outgoing.send(Frame.Binary(true, response.serialize())) +} diff --git a/src/main/kotlin/tech/relaycorp/poweb/PoWebException.kt b/src/main/kotlin/tech/relaycorp/poweb/PoWebException.kt new file mode 100644 index 0000000..9dc76dc --- /dev/null +++ b/src/main/kotlin/tech/relaycorp/poweb/PoWebException.kt @@ -0,0 +1,3 @@ +package tech.relaycorp.poweb + +class PoWebException(message: String, cause: Throwable? = null) : Exception(message, cause) diff --git a/src/main/kotlin/tech/relaycorp/poweb/handshake/NonceSigner.kt b/src/main/kotlin/tech/relaycorp/poweb/handshake/NonceSigner.kt new file mode 100644 index 0000000..a249311 --- /dev/null +++ b/src/main/kotlin/tech/relaycorp/poweb/handshake/NonceSigner.kt @@ -0,0 +1,33 @@ +package tech.relaycorp.poweb.handshake + +import org.bouncycastle.cms.CMSProcessableByteArray +import org.bouncycastle.cms.CMSSignedDataGenerator +import org.bouncycastle.cms.CMSTypedData +import org.bouncycastle.cms.jcajce.JcaSignerInfoGeneratorBuilder +import org.bouncycastle.operator.ContentSigner +import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder +import org.bouncycastle.operator.jcajce.JcaDigestCalculatorProviderBuilder +import tech.relaycorp.relaynet.wrappers.x509.Certificate +import java.security.PrivateKey + +class NonceSigner(internal val certificate: Certificate, private val privateKey: PrivateKey) { + fun sign(nonce: ByteArray): ByteArray { + val signedDataGenerator = CMSSignedDataGenerator() + + val signerBuilder = JcaContentSignerBuilder("SHA256withRSA") + val contentSigner: ContentSigner = signerBuilder.build(privateKey) + val signerInfoGenerator = JcaSignerInfoGeneratorBuilder( + JcaDigestCalculatorProviderBuilder() + .build() + ).build(contentSigner, certificate.certificateHolder) + signedDataGenerator.addSignerInfoGenerator( + signerInfoGenerator + ) + + signedDataGenerator.addCertificate(certificate.certificateHolder) + + val plaintextCms: CMSTypedData = CMSProcessableByteArray(nonce) + val cmsSignedData = signedDataGenerator.generate(plaintextCms, false) + return cmsSignedData.encoded + } +} diff --git a/src/test/kotlin/tech/relaycorp/poweb/PoWebClientTest.kt b/src/test/kotlin/tech/relaycorp/poweb/PoWebClientTest.kt new file mode 100644 index 0000000..665e5e7 --- /dev/null +++ b/src/test/kotlin/tech/relaycorp/poweb/PoWebClientTest.kt @@ -0,0 +1,334 @@ +package tech.relaycorp.poweb + +import com.nhaarman.mockitokotlin2.spy +import com.nhaarman.mockitokotlin2.verify +import io.ktor.client.HttpClient +import io.ktor.client.engine.mock.MockEngine +import io.ktor.client.engine.okhttp.OkHttpEngine +import io.ktor.client.features.websocket.DefaultClientWebSocketSession +import io.ktor.client.features.websocket.WebSockets +import io.ktor.client.request.HttpRequestData +import io.ktor.http.URLProtocol +import io.ktor.http.cio.websocket.CloseReason +import io.ktor.http.fullPath +import io.ktor.util.InternalAPI +import io.ktor.util.KtorExperimentalAPI +import kotlinx.coroutines.runBlocking +import okhttp3.Response +import okhttp3.WebSocket +import okhttp3.WebSocketListener +import okio.ByteString +import okio.ByteString.Companion.toByteString +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.Disabled +import org.junit.jupiter.api.Nested +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import tech.relaycorp.poweb.handshake.CMSUtils +import tech.relaycorp.poweb.handshake.Challenge +import tech.relaycorp.poweb.handshake.InvalidMessageException +import tech.relaycorp.poweb.handshake.NonceSigner +import tech.relaycorp.relaynet.issueEndpointCertificate +import tech.relaycorp.relaynet.wrappers.generateRSAKeyPair +import java.time.ZonedDateTime +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +@KtorExperimentalAPI +class PoWebClientTest { + @Nested + inner class Constructor { + @Nested + inner class InitLocal { + @Test + fun `Host name should be the localhost IP address`() { + val client = PoWebClient.initLocal() + + assertEquals("127.0.0.1", client.hostName) + } + + @Test + fun `TLS should not be used`() { + val client = PoWebClient.initLocal() + + assertFalse(client.useTls) + } + + @Test + fun `Port should default to 276`() { + val client = PoWebClient.initLocal() + + assertEquals(276, client.port) + } + + @Test + fun `Port should be overridable`() { + val customPort = 13276 + val client = PoWebClient.initLocal(customPort) + + assertEquals(customPort, client.port) + } + } + + @Nested + inner class InitRemote { + private val hostName = "gb.relaycorp.tech" + + @Test + fun `Specified host name should be honored`() { + val client = PoWebClient.initRemote(hostName) + + assertEquals(hostName, client.hostName) + } + + @Test + fun `TLS should be used`() { + val client = PoWebClient.initRemote(hostName) + + assertTrue(client.useTls) + } + + @Test + fun `Port should default to 443`() { + val client = PoWebClient.initRemote(hostName) + + assertEquals(443, client.port) + } + + @Test + fun `Port should be overridable`() { + val customPort = 1234 + val client = PoWebClient.initRemote(hostName, customPort) + + assertEquals(customPort, client.port) + } + } + } + + @InternalAPI + @Test + fun `OkHTTP should be the client engine`() { + val client = PoWebClient.initLocal() + + assertTrue(client.ktorClient.engine is OkHttpEngine) + } + + @Test + fun `Close method should close underlying Ktor client`() { + val client = PoWebClient.initLocal() + client.ktorClient = spy(client.ktorClient) + + client.close() + + verify(client.ktorClient).close() + } + + @Nested + inner class WebSocketConnection : WebSocketTestCase(false) { + private val hostName = "127.0.0.1" + private val port = 13276 + private val path = "/v1/the-endpoint" + + @Test + fun `Client should use WS if TLS is not required`() { + val wsRequest = wsConnect(false) {} + + assertEquals(URLProtocol.WS, wsRequest.url.protocol) + } + + @Test + fun `Client should use WSS if TLS is required`() { + val wsRequest = wsConnect(true) {} + + assertEquals(URLProtocol.WSS, wsRequest.url.protocol) + } + + @Test + fun `Client should connect to specified host`() { + val wsRequest = wsConnect {} + + assertEquals(hostName, wsRequest.url.host) + } + + @Test + fun `Client should connect to specified port`() { + val wsRequest = wsConnect {} + + assertEquals(port, wsRequest.url.port) + } + + @Test + fun `Client should connect to specified path`() { + val wsRequest = wsConnect {} + + assertEquals(path, wsRequest.url.fullPath) + } + + @Test + fun `Specified block should be called`(): Unit = runBlocking { + setWebSocketListener(object : WebSocketListener() { + override fun onOpen(webSocket: WebSocket, response: Response) { + webSocket.close(1000, "No-op") + } + }) + val client = PoWebClient.initLocal(mockWebServer.port) + + var wasBlockRun = false + client.wsConnect(path) { wasBlockRun = true } + + assertTrue(wasBlockRun) + } + + private fun wsConnect( + useTls: Boolean = false, + block: suspend DefaultClientWebSocketSession.() -> Unit + ): HttpRequestData { + val client = PoWebClient(hostName, port, useTls) + var connectionRequest: HttpRequestData? = null + client.ktorClient = HttpClient(MockEngine) { + install(WebSockets) + + engine { + addHandler { request -> + connectionRequest = request + error("Nothing to see here") + } + } + } + + assertThrows { runBlocking { client.wsConnect(path, block) } } + assertTrue(connectionRequest is HttpRequestData) + + return connectionRequest as HttpRequestData + } + } + + @Nested + inner class Handshake : WebSocketTestCase() { + private val nonce = "nonce".toByteArray() + private val challengeSerialized = Challenge(nonce).serialize().toByteString() + + // Compute client on demand because getting the server port will start the server + private val client by lazy { PoWebClient.initLocal(mockWebServer.port) } + + private val signer = generateDummySigner() + + @AfterEach + fun closeClient() = client.ktorClient.close() + + @Test + fun `Getting an invalid challenge should result in an exception`() { + var closeCode: Int? = null + + setWebSocketListener(object : WebSocketListener() { + override fun onOpen(webSocket: WebSocket, response: Response) { + webSocket.send("Not a valid challenge") + } + + override fun onClosing(webSocket: WebSocket, code: Int, reason: String) { + closeCode = code + super.onClosing(webSocket, code, reason) + } + }) + + client.use { + val exception = assertThrows { + runBlocking { client.wsConnect("/") { handshake(arrayOf(signer)) } } + } + + assertEquals("Server sent an invalid handshake challenge", exception.message) + assertTrue(exception.cause is InvalidMessageException) + } + assertEquals(CloseReason.Codes.VIOLATED_POLICY.code.toInt(), closeCode) + } + + @Test + fun `At least one nonce signer should be required`() { + var closeCode: Int? = null + + setWebSocketListener(object : WebSocketListener() { + override fun onOpen(webSocket: WebSocket, response: Response) { + webSocket.send(challengeSerialized) + } + + override fun onClosing(webSocket: WebSocket, code: Int, reason: String) { + closeCode = code + } + }) + + client.use { + val exception = assertThrows { + runBlocking { + client.wsConnect("/") { handshake(emptyArray()) } + } + } + + assertEquals("At least one nonce signer must be specified", exception.message) + } + assertEquals(CloseReason.Codes.NORMAL.code.toInt(), closeCode) + } + + @Test + fun `Challenge nonce should be signed with each signer`() { + var response: tech.relaycorp.poweb.handshake.Response? = null + + setWebSocketListener(object : WebSocketListener() { + override fun onOpen(webSocket: WebSocket, response: Response) { + webSocket.send(challengeSerialized) + } + + override fun onMessage(webSocket: WebSocket, bytes: ByteString) { + response = tech.relaycorp.poweb.handshake.Response.deserialize( + bytes.toByteArray() + ) + webSocket.close(1000, "") + } + }) + + val signer2 = generateDummySigner() + + client.use { + runBlocking { client.wsConnect("/") { handshake(arrayOf(signer, signer2)) } } + + assertTrue(response is tech.relaycorp.poweb.handshake.Response) + val nonceSignatures = response!!.nonceSignatures + assertEquals( + signer.certificate, + CMSUtils.verifySignature(nonceSignatures[0], nonce) + ) + assertEquals( + signer2.certificate, + CMSUtils.verifySignature(nonceSignatures[1], nonce) + ) + } + } + + private fun generateDummySigner(): NonceSigner { + val keyPair = generateRSAKeyPair() + val certificate = issueEndpointCertificate( + keyPair.public, + keyPair.private, + ZonedDateTime.now().plusDays(1)) + return NonceSigner(certificate, keyPair.private) + } + } + + @Nested + inner class ParcelCollection { + @Test + @Disabled + fun `Request should be made to the parcel collection endpoint`() { + } + + @Test + @Disabled + fun `Call should return if server closed connection normally`() { + } + + @Test + @Disabled + fun `An exception should be thrown if the server closes the connection with an error`() { + } + } +} diff --git a/src/test/kotlin/tech/relaycorp/poweb/WebSocketTestCase.kt b/src/test/kotlin/tech/relaycorp/poweb/WebSocketTestCase.kt new file mode 100644 index 0000000..052b687 --- /dev/null +++ b/src/test/kotlin/tech/relaycorp/poweb/WebSocketTestCase.kt @@ -0,0 +1,34 @@ +package tech.relaycorp.poweb + +import okhttp3.WebSocketListener +import okhttp3.mockwebserver.MockResponse +import okhttp3.mockwebserver.MockWebServer +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.BeforeEach +import java.io.IOException + +open class WebSocketTestCase(private val autoStartServer: Boolean = true) { + protected val mockWebServer = MockWebServer() + + @BeforeEach + fun startServer() { + if (autoStartServer) { + mockWebServer.start() + } + } + + @AfterEach + fun stopServer() { + try { + mockWebServer.shutdown() + } catch (exc: IOException) { + // Ignore the weird "Gave up waiting for queue to shut down" exception in + // MockWebServer when the code under test closes the connection explicitly + // TODO: Raise issue in OkHTTP repo + } + } + + protected fun setWebSocketListener(listener: WebSocketListener) { + mockWebServer.enqueue(MockResponse().withWebSocketUpgrade(listener)) + } +} diff --git a/src/test/kotlin/tech/relaycorp/poweb/handshake/CMSUtils.kt b/src/test/kotlin/tech/relaycorp/poweb/handshake/CMSUtils.kt new file mode 100644 index 0000000..b946df2 --- /dev/null +++ b/src/test/kotlin/tech/relaycorp/poweb/handshake/CMSUtils.kt @@ -0,0 +1,56 @@ +package tech.relaycorp.poweb.handshake + +import org.bouncycastle.asn1.ASN1InputStream +import org.bouncycastle.asn1.cms.ContentInfo +import org.bouncycastle.cert.X509CertificateHolder +import org.bouncycastle.cert.selector.X509CertificateHolderSelector +import org.bouncycastle.cms.CMSProcessableByteArray +import org.bouncycastle.cms.CMSSignedData +import org.bouncycastle.cms.SignerInformation +import org.bouncycastle.cms.jcajce.JcaSimpleSignerInfoVerifierBuilder +import org.bouncycastle.util.Selector +import tech.relaycorp.relaynet.wrappers.x509.Certificate + +object CMSUtils { + internal fun verifySignature( + cmsSignedData: ByteArray, + plaintext: ByteArray + ): Certificate { + val signedData = parseCmsSignedData(cmsSignedData, plaintext) + + val signerInfo = getSignerInfoFromSignedData(signedData) + + // We shouldn't have to force this type cast but this is the only way I could get the code to work and, based on + // what I found online, that's what others have had to do as well + @Suppress("UNCHECKED_CAST") val signerCertSelector = X509CertificateHolderSelector( + signerInfo.sid.issuer, + signerInfo.sid.serialNumber + ) as Selector + + val signerCertMatches = signedData.certificates.getMatches(signerCertSelector) + val signerCertificateHolder = signerCertMatches.first() + val verifier = JcaSimpleSignerInfoVerifierBuilder().build(signerCertificateHolder) + + signerInfo.verify(verifier) + + return Certificate(signerCertificateHolder) + } + + private fun parseCmsSignedData( + cmsSignedDataSerialized: ByteArray, + expectedPlaintext: ByteArray + ): CMSSignedData { + val asn1Stream = ASN1InputStream(cmsSignedDataSerialized) + val asn1Sequence = asn1Stream.readObject() + val contentInfo = ContentInfo.getInstance(asn1Sequence) + return CMSSignedData(CMSProcessableByteArray(expectedPlaintext), contentInfo) + } + + private fun getSignerInfoFromSignedData(signedData: CMSSignedData): SignerInformation { + val signersCount = signedData.signerInfos.size() + if (signersCount != 1) { + throw Exception("SignedData should contain exactly one SignerInfo (got $signersCount)") + } + return signedData.signerInfos.first() + } +} diff --git a/src/test/kotlin/tech/relaycorp/poweb/handshake/NonceSignerTest.kt b/src/test/kotlin/tech/relaycorp/poweb/handshake/NonceSignerTest.kt new file mode 100644 index 0000000..e37b65d --- /dev/null +++ b/src/test/kotlin/tech/relaycorp/poweb/handshake/NonceSignerTest.kt @@ -0,0 +1,197 @@ +package tech.relaycorp.poweb.handshake + +import org.bouncycastle.asn1.ASN1InputStream +import org.bouncycastle.asn1.ASN1ObjectIdentifier +import org.bouncycastle.asn1.ASN1Primitive +import org.bouncycastle.asn1.DEROctetString +import org.bouncycastle.asn1.cms.Attribute +import org.bouncycastle.asn1.cms.ContentInfo +import org.bouncycastle.cms.CMSSignedData +import org.bouncycastle.util.CollectionStore +import org.junit.jupiter.api.Nested +import org.junit.jupiter.api.Test +import tech.relaycorp.relaynet.issueEndpointCertificate +import tech.relaycorp.relaynet.wrappers.generateRSAKeyPair +import java.security.MessageDigest +import java.time.ZonedDateTime +import kotlin.test.assertEquals + +class NonceSignerTest { + private val nonce = "The nonce".toByteArray() + private val keyPair = generateRSAKeyPair() + private val certificate = issueEndpointCertificate( + keyPair.public, + keyPair.private, + ZonedDateTime.now().plusDays(1) + ) + private val signer = NonceSigner(certificate, keyPair.private) + + @Test + fun `Serialization should be DER-encoded`() { + val serialization = signer.sign(nonce) + + parseDer(serialization) + } + + @Test + fun `SignedData value should be wrapped in a ContentInfo value`() { + val serialization = signer.sign(nonce) + + ContentInfo.getInstance(parseDer(serialization)) + } + + @Test + fun `SignedData version should be set to 1`() { + val serialization = signer.sign(nonce) + + val cmsSignedData = parseCmsSignedData(serialization) + + assertEquals(1, cmsSignedData.version) + } + + @Test + fun `Plaintext should not be embedded`() { + val serialization = signer.sign(nonce) + + val cmsSignedData = parseCmsSignedData(serialization) + + assertEquals(null, cmsSignedData.signedContent) + } + + @Nested + inner class SignerInfo { + @Test + fun `There should only be one SignerInfo`() { + val serialization = signer.sign(nonce) + + val cmsSignedData = parseCmsSignedData(serialization) + + assertEquals(1, cmsSignedData.signerInfos.size()) + } + + @Test + fun `SignerInfo version should be set to 1`() { + val serialization = signer.sign(nonce) + + val cmsSignedData = parseCmsSignedData(serialization) + + val signerInfo = cmsSignedData.signerInfos.first() + assertEquals(1, signerInfo.version) + } + + @Test + fun `SignerIdentifier should be IssuerAndSerialNumber`() { + val serialization = signer.sign(nonce) + + val cmsSignedData = parseCmsSignedData(serialization) + + val signerInfo = cmsSignedData.signerInfos.first() + assertEquals(certificate.certificateHolder.issuer, signerInfo.sid.issuer) + assertEquals( + certificate.certificateHolder.serialNumber, + signerInfo.sid.serialNumber + ) + } + + @Nested + inner class SignedAttributes { + @Test + fun `Signed attributes should be present`() { + val serialization = signer.sign(nonce) + + val cmsSignedData = parseCmsSignedData(serialization) + + val signerInfo = cmsSignedData.signerInfos.first() + + assert(0 < signerInfo.signedAttributes.size()) + } + + @Test + fun `Content type attribute should be set to CMS Data`() { + val serialization = signer.sign(nonce) + + val cmsSignedData = parseCmsSignedData(serialization) + + val signerInfo = cmsSignedData.signerInfos.first() + + val cmsContentTypeAttrOid = ASN1ObjectIdentifier("1.2.840.113549.1.9.3") + val contentTypeAttrs = + signerInfo.signedAttributes.getAll(cmsContentTypeAttrOid) + assertEquals(1, contentTypeAttrs.size()) + val contentTypeAttr = contentTypeAttrs.get(0) as Attribute + assertEquals(1, contentTypeAttr.attributeValues.size) + val cmsDataOid = "1.2.840.113549.1.7.1" + assertEquals(cmsDataOid, contentTypeAttr.attributeValues[0].toString()) + } + + @Test + fun `Plaintext digest should be present`() { + val serialization = signer.sign(nonce) + + val cmsSignedData = parseCmsSignedData(serialization) + + val signerInfo = cmsSignedData.signerInfos.first() + + val cmsDigestAttributeOid = ASN1ObjectIdentifier("1.2.840.113549.1.9.4") + val digestAttrs = + signerInfo.signedAttributes.getAll(cmsDigestAttributeOid) + assertEquals(1, digestAttrs.size()) + val digestAttr = digestAttrs.get(0) as Attribute + assertEquals(1, digestAttr.attributeValues.size) + val digest = MessageDigest.getInstance("SHA-256").digest(nonce) + assertEquals( + digest.asList(), + (digestAttr.attributeValues[0] as DEROctetString).octets.asList() + ) + } + } + } + + @Test + fun `Signer certificate should be attached`() { + val serialization = signer.sign(nonce) + + val cmsSignedData = parseCmsSignedData(serialization) + + val attachedCerts = + (cmsSignedData.certificates as CollectionStore).asSequence().toList() + assertEquals(1, attachedCerts.size) + assertEquals(certificate.certificateHolder, attachedCerts[0]) + } + + @Test + fun `SHA-256 should be used`() { + val serialization = signer.sign(nonce) + + val cmsSignedData = parseCmsSignedData(serialization) + + assertEquals(1, cmsSignedData.digestAlgorithmIDs.size) + val sha256Oid = ASN1ObjectIdentifier("2.16.840.1.101.3.4.2.1") + assertEquals(sha256Oid, cmsSignedData.digestAlgorithmIDs.first().algorithm) + + val signerInfo = cmsSignedData.signerInfos.first() + + assertEquals(sha256Oid, signerInfo.digestAlgorithmID.algorithm) + } + + @Test + fun `Signature should verify`() { + val serialization = signer.sign(nonce) + + val signerCertificate = CMSUtils.verifySignature(serialization, nonce) + + assertEquals(certificate, signerCertificate) + } + + private fun parseDer(derSerialization: ByteArray): ASN1Primitive { + val asn1Stream = ASN1InputStream(derSerialization) + return asn1Stream.readObject() + } + + private fun parseCmsSignedData(serialization: ByteArray): CMSSignedData { + val contentInfo = ContentInfo.getInstance( + parseDer(serialization) + ) + return CMSSignedData(contentInfo) + } +}