diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt index 45fdbc1ac9..20121535a8 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolTestGenerator.kt @@ -85,9 +85,10 @@ abstract class ProtocolTestGenerator { /** The entry point to render the protocol tests, invoked by the code generators. */ fun render(writer: RustWriter) { - val allTests = allMatchingTestCases().flatMap { - fixBrokenTestCase(it) - } + val allTests = + allMatchingTestCases().flatMap { + fixBrokenTestCase(it) + } if (allTests.isEmpty()) { return } @@ -104,50 +105,55 @@ abstract class ProtocolTestGenerator { * If the test is broken, we synthesize it in two versions: the original broken test with a `#[should_panic]` * attribute, so get alerted if the test now passes, and the fixed version, which should pass. */ - private fun fixBrokenTestCase(it: TestCase): List = if (!it.isBroken()) { - listOf(it) - } else { - assert(it.expectFail()) - - val brokenTest = it.findInBroken()!! - var fixed = brokenTest.fixIt(it) - - val intro = "The hotfix function for broken test case ${it.kind} ${it.id}" - val moreInfo = - """This test case was identified to be broken in at least these Smithy versions: [${brokenTest.inAtLeast.joinToString()}]. - |We are tracking things here: [${brokenTest.trackedIn.joinToString()}].""".trimMargin() - - // Something must change... - if (it == fixed) { - PANIC( - """$intro did not make any modifications. It is likely that the test case was + private fun fixBrokenTestCase(it: TestCase): List = + if (!it.isBroken()) { + listOf(it) + } else { + assert(it.expectFail()) + + val brokenTest = it.findInBroken()!! + var fixed = brokenTest.fixIt(it) + + val intro = "The hotfix function for broken test case ${it.kind} ${it.id}" + val moreInfo = + """This test case was identified to be broken in at least these Smithy versions: [${brokenTest.inAtLeast.joinToString()}]. + |We are tracking things here: [${brokenTest.trackedIn.joinToString()}]. + """.trimMargin() + + // Something must change... + if (it == fixed) { + PANIC( + """$intro did not make any modifications. It is likely that the test case was |fixed upstream, and you're now updating the Smithy version; in this case, remove the hotfix |function, as the test is no longer broken. - |$moreInfo""".trimMargin(), - ) - } + |$moreInfo + """.trimMargin(), + ) + } - // ... but the hotfix function is not allowed to change the test case kind... - if (it.kind != fixed.kind) { - PANIC( - """$intro changed the test case kind. This is not allowed. - |$moreInfo""".trimMargin(), - ) - } + // ... but the hotfix function is not allowed to change the test case kind... + if (it.kind != fixed.kind) { + PANIC( + """$intro changed the test case kind. This is not allowed. + |$moreInfo + """.trimMargin(), + ) + } - // ... nor its id. - if (it.id != fixed.id) { - PANIC( - """$intro changed the test case id. This is not allowed. - |$moreInfo""".trimMargin(), - ) - } + // ... nor its id. + if (it.id != fixed.id) { + PANIC( + """$intro changed the test case id. This is not allowed. + |$moreInfo + """.trimMargin(), + ) + } - // The latter is because we're going to generate the fixed version with an identifiable suffix. - fixed = fixed.suffixIdWith("_hotfixed") + // The latter is because we're going to generate the fixed version with an identifiable suffix. + fixed = fixed.suffixIdWith("_hotfixed") - listOf(it, fixed) - } + listOf(it, fixed) + } /** Implementors should describe how to render the test cases. **/ abstract fun RustWriter.renderAllTestCases(allTests: List) @@ -161,11 +167,12 @@ abstract class ProtocolTestGenerator { this.filter { testCase -> runOnly.contains(testCase.id) } } - private fun TestCase.toFailingTest(): FailingTest = when (this) { - is TestCase.MalformedRequestTest -> FailingTest.MalformedRequestTest(serviceShapeId.toString(), this.id) - is TestCase.RequestTest -> FailingTest.RequestTest(serviceShapeId.toString(), this.id) - is TestCase.ResponseTest -> FailingTest.ResponseTest(serviceShapeId.toString(), this.id) - } + private fun TestCase.toFailingTest(): FailingTest = + when (this) { + is TestCase.MalformedRequestTest -> FailingTest.MalformedRequestTest(serviceShapeId.toString(), this.id) + is TestCase.RequestTest -> FailingTest.RequestTest(serviceShapeId.toString(), this.id) + is TestCase.ResponseTest -> FailingTest.ResponseTest(serviceShapeId.toString(), this.id) + } /** Do we expect this test case to fail? */ private fun TestCase.expectFail(): Boolean = this.isBroken() || expectFail.contains(this.toFailingTest()) @@ -173,11 +180,12 @@ abstract class ProtocolTestGenerator { /** Is this test case broken? */ private fun TestCase.isBroken(): Boolean = this.findInBroken() != null - private fun TestCase.findInBroken(): BrokenTest? = brokenTests.find { brokenTest -> - (this is TestCase.RequestTest && brokenTest is BrokenTest.RequestTest && this.id == brokenTest.id) || - (this is TestCase.ResponseTest && brokenTest is BrokenTest.ResponseTest && this.id == brokenTest.id) || - (this is TestCase.MalformedRequestTest && brokenTest is BrokenTest.MalformedRequestTest && this.id == brokenTest.id) - } + private fun TestCase.findInBroken(): BrokenTest? = + brokenTests.find { brokenTest -> + (this is TestCase.RequestTest && brokenTest is BrokenTest.RequestTest && this.id == brokenTest.id) || + (this is TestCase.ResponseTest && brokenTest is BrokenTest.ResponseTest && this.id == brokenTest.id) || + (this is TestCase.MalformedRequestTest && brokenTest is BrokenTest.MalformedRequestTest && this.id == brokenTest.id) + } fun requestTestCases(): List { val requestTests = @@ -355,14 +363,13 @@ abstract class ProtocolTestGenerator { sealed class BrokenTest( open val serviceShapeId: String, open val id: String, - /** A non-exhaustive set of Smithy versions where the test was found to be broken. */ open val inAtLeast: Set, /** * GitHub URLs related to the test brokenness, like a GitHub issue in Smithy where we reported the test was broken, * or a PR where we fixed it. **/ - open val trackedIn: Set + open val trackedIn: Set, ) { data class RequestTest( override val serviceShapeId: String, @@ -411,8 +418,10 @@ object ServiceShapeId { sealed class FailingTest(open val serviceShapeId: String, open val id: String) { data class RequestTest(override val serviceShapeId: String, override val id: String) : FailingTest(serviceShapeId, id) + data class ResponseTest(override val serviceShapeId: String, override val id: String) : FailingTest(serviceShapeId, id) + data class MalformedRequestTest(override val serviceShapeId: String, override val id: String) : FailingTest(serviceShapeId, id) } @@ -455,19 +464,22 @@ sealed class TestCase { override fun equals(other: Any?): Boolean { if (this === other) return true if (other !is MalformedRequestTest) return false - return this.protocol == other.protocol && this.id == other.id && this.documentation == other.documentation && this.testCase.request.toNode() - .equals(other.testCase.request.toNode()) && this.testCase.response.toNode() - .equals(other.testCase.response.toNode()) + return this.protocol == other.protocol && this.id == other.id && this.documentation == other.documentation && + this.testCase.request.toNode() + .equals(other.testCase.request.toNode()) && + this.testCase.response.toNode() + .equals(other.testCase.response.toNode()) } override fun hashCode(): Int = testCase.hashCode() } - fun suffixIdWith(suffix: String): TestCase = when (this) { - is RequestTest -> RequestTest(this.testCase.suffixIdWith(suffix)) - is MalformedRequestTest -> MalformedRequestTest(this.testCase.suffixIdWith(suffix)) - is ResponseTest -> ResponseTest(this.testCase.suffixIdWith(suffix), this.targetShape) - } + fun suffixIdWith(suffix: String): TestCase = + when (this) { + is RequestTest -> RequestTest(this.testCase.suffixIdWith(suffix)) + is MalformedRequestTest -> MalformedRequestTest(this.testCase.suffixIdWith(suffix)) + is ResponseTest -> ResponseTest(this.testCase.suffixIdWith(suffix), this.targetShape) + } private fun HttpRequestTestCase.suffixIdWith(suffix: String): HttpRequestTestCase = this.toBuilder().id(this.id + suffix).build() diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index 453a8605af..cb03181d7c 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -157,12 +157,13 @@ class ServerProtocolTestGenerator( "RestJsonMalformedPatternReDOSString", howToFixItFn = ::fixRestJsonMalformedPatternReDOSString, inAtLeast = setOf("1.26.2", "1.49.0"), - trackedIn = setOf( - // TODO(https://github.com/awslabs/smithy/issues/1506) - "https://github.com/awslabs/smithy/issues/1506", - // TODO(https://github.com/smithy-lang/smithy/pull/2340) - "https://github.com/smithy-lang/smithy/pull/2340", - ), + trackedIn = + setOf( + // TODO(https://github.com/awslabs/smithy/issues/1506) + "https://github.com/awslabs/smithy/issues/1506", + // TODO(https://github.com/smithy-lang/smithy/pull/2340) + "https://github.com/smithy-lang/smithy/pull/2340", + ), ), ) @@ -188,7 +189,9 @@ class ServerProtocolTestGenerator( "S3PreservesEmbeddedDotSegmentInUriLabel", ) - private fun fixRestJsonMalformedPatternReDOSString(testCase: TestCase.MalformedRequestTest): TestCase.MalformedRequestTest { + private fun fixRestJsonMalformedPatternReDOSString( + testCase: TestCase.MalformedRequestTest, + ): TestCase.MalformedRequestTest { val brokenResponse = testCase.testCase.response val brokenBody = brokenResponse.body.get() val fixedBody = @@ -321,7 +324,7 @@ class ServerProtocolTestGenerator( if (!protocolSupport.responseSerialization || ( !protocolSupport.errorSerialization && shape.hasTrait() - ) + ) ) { rust("/* test case disabled for this protocol (not yet supported) */") return