Skip to content

Commit

Permalink
Fix recursive constraint violations with paths over list and map shapes
Browse files Browse the repository at this point in the history
There is a widespread assumption throughout the generation of constraint
violations that does not hold true all the time, namely, that a
recursive constraint violation graph has the same requirements with
regards to boxing as the regular shape graph.

Some types corresponding to recursive shapes are boxed to introduce
indirection and thus not generate an infinitely recursive type. The
algorithm however does not superfluously introduce boxes when the cycle
goes through a list shape or a map shape. Why list shapes and map
shapes? List shapes and map shapes get rendered in Rust as `Vec<T>` and
`HashMap<K, V>`, respectively, they're the only Smithy shapes that
"organically" introduce indirection (via a pointer to the heap) in the
recursive path. For other recursive paths, we thus have to introduce the
indirection artificially ourselves using `Box`. This is done in the
`RecursiveShapeBoxer` model transform.

However, the constraint violation graph needs to box types in recursive
paths more often. Since we don't collect constraint violations
(yet, see #2040), the constraint violation graph never holds
`Vec<T>`s or `HashMap<K, V>`s, only simple types. Indeed, the following simple
recursive model:

```smithy
union Recursive {
    list: List
}

@Length(min: 69)
list List {
    member: Recursive
}
```

has a cycle that goes through a list shape, so no shapes in it need
boxing in the regular shape graph. However, the constraint violation
graph is infinitely recursive if we don't introduce boxing somewhere:

```rust
pub mod model {
    pub mod list {
        pub enum ConstraintViolation {
            Length(usize),
            Member(
                usize,
                crate::model::recursive::ConstraintViolation,
            ),
        }
    }

    pub mod recursive {
        pub enum ConstraintViolation {
            List(crate::model::list::ConstraintViolation),
        }
    }
}
```

This commit fixes things by making the `RecursiveShapeBoxer` model
transform configurable so that the "cycles through lists and maps
introduce indirection" assumption can be lifted. This allows a server
model transform, `RecursiveConstraintViolationBoxer`, to tag member
shapes along recursive paths with a new trait,
`ConstraintViolationRustBoxTrait`, that the constraint violation type
generation then utilizes to ensure that no infinitely recursive
constraint violation types get generated.

For example, for the above model, the generated Rust code would now look
like:

```rust
pub mod model {
    pub mod list {
        pub enum ConstraintViolation {
            Length(usize),
            Member(
                usize,
                std::boxed::Box(crate::model::recursive::ConstraintViolation),
            ),
        }
    }

    pub mod recursive {
        pub enum ConstraintViolation {
            List(crate::model::list::ConstraintViolation),
        }
    }
}
```

Likewise, places where constraint violations are handled (like where
unconstrained types are converted to constrained types) have been
updated to account for the scenario where they now are or need to be
boxed.

Parametrized tests have been added to exhaustively test combinations of
models exercising recursive paths going through (sparse and non-sparse)
list and map shapes, as well as union and structure shapes
(`RecursiveConstraintViolationsTest`). These tests even assert that the
specific member shapes along the cycles are tagged as expected
(`RecursiveConstraintViolationBoxerTest`).
  • Loading branch information
david-perez committed Feb 15, 2023
1 parent d7f8130 commit 8d79352
Show file tree
Hide file tree
Showing 29 changed files with 451 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,14 @@ class ClientCodegenVisitor(
// Add errors attached at the service level to the models
.let { ModelTransformer.create().copyServiceErrorsToOperations(it, settings.getService(it)) }
// Add `Box<T>` to recursive shapes as necessary
.let(RecursiveShapeBoxer::transform)
.let(RecursiveShapeBoxer()::transform)
// Normalize the `message` field on errors when enabled in settings (default: true)
.letIf(settings.codegenConfig.addMessageToErrors, AddErrorMessage::transform)
// NormalizeOperations by ensuring every operation has an input & output shape
.let(OperationNormalizer::transform)
// Drop unsupported event stream operations from the model
.let { RemoveEventStreamOperations.transform(it, settings) }
// - Normalize event stream operations
// Normalize event stream operations
.let(EventStreamNormalizer::transform)

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ internal class ResiliencyConfigCustomizationTest {

@Test
fun `generates a valid config`() {
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel))
val project = TestWorkspace.testProject()
val codegenContext = testCodegenContext(model, settings = project.rustSettings())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ import software.amazon.smithy.model.traits.Trait
/**
* Trait indicating that this shape should be represented with `Box<T>` when converted into Rust
*
* This is used to handle recursive shapes. See RecursiveShapeBoxer.
* This is used to handle recursive shapes.
* See [software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer].
*
* This trait is synthetic, applied during code generation, and never used in actual models.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,29 @@ package software.amazon.smithy.rust.codegen.core.smithy.transformers

import software.amazon.smithy.codegen.core.TopologicalIndex
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.ListShape
import software.amazon.smithy.model.shapes.CollectionShape
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.SetShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.transform.ModelTransformer
import software.amazon.smithy.rust.codegen.core.smithy.traits.RustBoxTrait
import software.amazon.smithy.rust.codegen.core.util.hasTrait

object RecursiveShapeBoxer {
class RecursiveShapeBoxer(
private val containsIndirectionPredicate: (Collection<Shape>) -> Boolean = ::containsIndirection,
private val boxShapeFn: (MemberShape) -> MemberShape = ::addRustBoxTrait,
) {
/**
* Transform a model which may contain recursive shapes into a model annotated with [RustBoxTrait]
* Transform a model which may contain recursive shapes into a model annotated with [RustBoxTrait].
*
* When recursive shapes do NOT go through a List, Map, or Set, they must be boxed in Rust. This function will
* iteratively find loops & add the `RustBox` trait in a deterministic way until it reaches a fixed point.
* When recursive shapes do NOT go through a `CollectionShape` or a `MapShape` shape, they must be boxed in Rust.
* This function will iteratively find loops and add the [RustBoxTrait] trait in a deterministic way until it
* reaches a fixed point.
*
* Why `CollectionShape`s and `MapShape`s? Note that `CollectionShape`s get rendered in Rust as `Vec<T>`, and
* `MapShape`s as `HashMap<String, T>`; they're the only Smithy shapes that "organically" introduce indirection
* (via a pointer to the heap) in the recursive path. For other recursive paths, we thus have to introduce the
* indirection artificially ourselves using `Box`.
*
* This function MUST be deterministic (always choose the same shapes to `Box`). If it is not, that is a bug. Even so
* this function may cause backward compatibility issues in certain pathological cases where a changes to recursive
Expand All @@ -41,12 +49,12 @@ object RecursiveShapeBoxer {
* If [model] contains no loops, return null.
*/
private fun transformInner(model: Model): Model? {
// Execute 1-step of the boxing algorithm in the path to reaching a fixed point
// 1. Find all the shapes that are part of a cycle
// 2. Find all the loops that those shapes are part of
// 3. Filter out the loops that go through a layer of indirection
// 3. Pick _just one_ of the remaining loops to fix
// 4. Select the member shape in that loop with the earliest shape id
// Execute 1 step of the boxing algorithm in the path to reaching a fixed point:
// 1. Find all the shapes that are part of a cycle.
// 2. Find all the loops that those shapes are part of.
// 3. Filter out the loops that go through a layer of indirection.
// 3. Pick _just one_ of the remaining loops to fix.
// 4. Select the member shape in that loop with the earliest shape id.
// 5. Box it.
// (External to this function) Go back to 1.
val index = TopologicalIndex.of(model)
Expand All @@ -58,34 +66,32 @@ object RecursiveShapeBoxer {
// Flatten the connections into shapes.
loops.map { it.shapes }
}
val loopToFix = loops.firstOrNull { !containsIndirection(it) }
val loopToFix = loops.firstOrNull { !containsIndirectionPredicate(it) }

return loopToFix?.let { loop: List<Shape> ->
check(loop.isNotEmpty())
// pick the shape to box in a deterministic way
// Pick the shape to box in a deterministic way.
val shapeToBox = loop.filterIsInstance<MemberShape>().minByOrNull { it.id }!!
ModelTransformer.create().mapShapes(model) { shape ->
if (shape == shapeToBox) {
shape.asMemberShape().get().toBuilder().addTrait(RustBoxTrait()).build()
boxShapeFn(shape.asMemberShape().get())
} else {
shape
}
}
}
}
}

/**
* Check if a List<Shape> contains a shape which will use a pointer when represented in Rust, avoiding the
* need to add more Boxes
*/
private fun containsIndirection(loop: List<Shape>): Boolean {
return loop.find {
when (it) {
is ListShape,
is MapShape,
is SetShape, -> true
else -> it.hasTrait<RustBoxTrait>()
}
} != null
/**
* Check if a `List<Shape>` contains a shape which will use a pointer when represented in Rust, avoiding the
* need to add more `Box`es.
*/
private fun containsIndirection(loop: Collection<Shape>): Boolean = loop.find {
when (it) {
is CollectionShape, is MapShape -> true
else -> it.hasTrait<RustBoxTrait>()
}
}
} != null

private fun addRustBoxTrait(shape: MemberShape): MemberShape = shape.toBuilder().addTrait(RustBoxTrait()).build()
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class InstantiatorTest {
@required
num: Integer
}
""".asSmithyModel().let { RecursiveShapeBoxer.transform(it) }
""".asSmithyModel().let { RecursiveShapeBoxer().transform(it) }

private val codegenContext = testCodegenContext(model)
private val symbolProvider = codegenContext.symbolProvider
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ class StructureGeneratorTest {
@Test
fun `it generates accessor methods`() {
val testModel =
RecursiveShapeBoxer.transform(
RecursiveShapeBoxer().transform(
"""
namespace test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class AwsQueryParserGeneratorTest {

@Test
fun `it modifies operation parsing to include Response and Result tags`() {
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel))
val codegenContext = testCodegenContext(model)
val symbolProvider = codegenContext.symbolProvider
val parserGenerator = AwsQueryParserGenerator(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class Ec2QueryParserGeneratorTest {

@Test
fun `it modifies operation parsing to include Response and Result tags`() {
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel))
val codegenContext = testCodegenContext(model)
val symbolProvider = codegenContext.symbolProvider
val parserGenerator = Ec2QueryParserGenerator(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class JsonParserGeneratorTest {

@Test
fun `generates valid deserializers`() {
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel))
val codegenContext = testCodegenContext(model)
val symbolProvider = codegenContext.symbolProvider
fun builderSymbol(shape: StructureShape): Symbol =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ internal class XmlBindingTraitParserGeneratorTest {

@Test
fun `generates valid parsers`() {
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel))
val codegenContext = testCodegenContext(model)
val symbolProvider = codegenContext.symbolProvider
val parserGenerator = XmlBindingTraitParserGenerator(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class AwsQuerySerializerGeneratorTest {
true -> CodegenTarget.CLIENT
false -> CodegenTarget.SERVER
}
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel))
val codegenContext = testCodegenContext(model, codegenTarget = codegenTarget)
val symbolProvider = codegenContext.symbolProvider
val parserGenerator = AwsQuerySerializerGenerator(testCodegenContext(model, codegenTarget = codegenTarget))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class Ec2QuerySerializerGeneratorTest {

@Test
fun `generates valid serializers`() {
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel))
val codegenContext = testCodegenContext(model)
val symbolProvider = codegenContext.symbolProvider
val parserGenerator = Ec2QuerySerializerGenerator(codegenContext)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class JsonSerializerGeneratorTest {

@Test
fun `generates valid serializers`() {
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel))
val codegenContext = testCodegenContext(model)
val symbolProvider = codegenContext.symbolProvider
val parserSerializer = JsonSerializerGenerator(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ internal class XmlBindingTraitSerializerGeneratorTest {

@Test
fun `generates valid serializers`() {
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel))
val codegenContext = testCodegenContext(model)
val symbolProvider = codegenContext.symbolProvider
val parserGenerator = XmlBindingTraitSerializerGenerator(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ internal class RecursiveShapeBoxerTest {
hello: Hello
}
""".asSmithyModel()
RecursiveShapeBoxer.transform(model) shouldBe model
RecursiveShapeBoxer().transform(model) shouldBe model
}

@Test
Expand All @@ -43,7 +43,7 @@ internal class RecursiveShapeBoxerTest {
anotherField: Boolean
}
""".asSmithyModel()
val transformed = RecursiveShapeBoxer.transform(model)
val transformed = RecursiveShapeBoxer().transform(model)
val member: MemberShape = transformed.lookup("com.example#Recursive\$RecursiveStruct")
member.expectTrait<RustBoxTrait>()
}
Expand All @@ -70,7 +70,7 @@ internal class RecursiveShapeBoxerTest {
third: SecondTree
}
""".asSmithyModel()
val transformed = RecursiveShapeBoxer.transform(model)
val transformed = RecursiveShapeBoxer().transform(model)
val boxed = transformed.shapes().filter { it.hasTrait<RustBoxTrait>() }.toList()
boxed.map { it.id.toString().removePrefix("com.example#") }.toSet() shouldBe setOf(
"Atom\$add",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class RecursiveShapesIntegrationTest {
}
output.message shouldContain "has infinite size"

val fixedProject = check(RecursiveShapeBoxer.transform(model))
val fixedProject = check(RecursiveShapeBoxer().transform(model))
fixedProject.compileAndTest()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.Ser
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerProtocolLoader
import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput
import software.amazon.smithy.rust.codegen.server.smithy.transformers.AttachValidationExceptionToConstrainedOperationInputsInAllowList
import software.amazon.smithy.rust.codegen.server.smithy.transformers.RecursiveConstraintViolationBoxer
import software.amazon.smithy.rust.codegen.server.smithy.transformers.RemoveEbsModelValidationException
import software.amazon.smithy.rust.codegen.server.smithy.transformers.ShapesReachableFromOperationInputTagger
import java.util.logging.Logger
Expand Down Expand Up @@ -162,7 +163,9 @@ open class ServerCodegenVisitor(
// Add errors attached at the service level to the models
.let { ModelTransformer.create().copyServiceErrorsToOperations(it, settings.getService(it)) }
// Add `Box<T>` to recursive shapes as necessary
.let(RecursiveShapeBoxer::transform)
.let(RecursiveShapeBoxer()::transform)
// Add `Box<T>` to recursive constraint violations as necessary
.let(RecursiveConstraintViolationBoxer::transform)
// Normalize operations by adding synthetic input and output shapes to every operation
.let(OperationNormalizer::transform)
// Remove the EBS model's own `ValidationException`, which collides with `smithy.framework#ValidationException`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@ import software.amazon.smithy.model.shapes.CollectionShape
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Visibility
import software.amazon.smithy.rust.codegen.core.rustlang.join
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.smithy.makeRustBoxed
import software.amazon.smithy.rust.codegen.core.smithy.module
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.letIf
import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape
import software.amazon.smithy.rust.codegen.server.smithy.traits.ConstraintViolationRustBoxTrait
import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput

class CollectionConstraintViolationGenerator(
Expand All @@ -38,16 +41,22 @@ class CollectionConstraintViolationGenerator(
private val constraintsInfo: List<TraitInfo> = collectionConstraintsInfo.map { it.toTraitInfo() }

fun render() {
val memberShape = model.expectShape(shape.member.target)
val targetShape = model.expectShape(shape.member.target)
val constraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(shape)
val constraintViolationName = constraintViolationSymbol.name
val isMemberConstrained = memberShape.canReachConstrainedShape(model, symbolProvider)
val isMemberConstrained = targetShape.canReachConstrainedShape(model, symbolProvider)
val constraintViolationVisibility = Visibility.publicIf(publicConstrainedTypes, Visibility.PUBCRATE)

modelsModuleWriter.withInlineModule(constraintViolationSymbol.module()) {
val constraintViolationVariants = constraintsInfo.map { it.constraintViolationVariant }.toMutableList()
if (isMemberConstrained) {
constraintViolationVariants += {
val memberConstraintViolationSymbol =
constraintViolationSymbolProvider.toSymbol(targetShape).letIf(
shape.member.hasTrait<ConstraintViolationRustBoxTrait>(),
) {
it.makeRustBoxed()
}
rustTemplate(
"""
/// Constraint violation error when an element doesn't satisfy its own constraints.
Expand All @@ -56,7 +65,7 @@ class CollectionConstraintViolationGenerator(
##[doc(hidden)]
Member(usize, #{MemberConstraintViolationSymbol})
""",
"MemberConstraintViolationSymbol" to constraintViolationSymbolProvider.toSymbol(memberShape),
"MemberConstraintViolationSymbol" to memberConstraintViolationSymbol,
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@ import software.amazon.smithy.model.traits.LengthTrait
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Visibility
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.smithy.makeRustBoxed
import software.amazon.smithy.rust.codegen.core.smithy.module
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.letIf
import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.traits.ConstraintViolationRustBoxTrait
import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput

class MapConstraintViolationGenerator(
Expand Down Expand Up @@ -47,7 +50,14 @@ class MapConstraintViolationGenerator(
constraintViolationCodegenScopeMutableList.add("KeyConstraintViolationSymbol" to constraintViolationSymbolProvider.toSymbol(keyShape))
}
if (isValueConstrained(valueShape, model, symbolProvider)) {
constraintViolationCodegenScopeMutableList.add("ValueConstraintViolationSymbol" to constraintViolationSymbolProvider.toSymbol(valueShape))
constraintViolationCodegenScopeMutableList.add(
"ValueConstraintViolationSymbol" to
constraintViolationSymbolProvider.toSymbol(valueShape).letIf(
shape.value.hasTrait<ConstraintViolationRustBoxTrait>(),
) {
it.makeRustBoxed()
},
)
constraintViolationCodegenScopeMutableList.add("KeySymbol" to constrainedShapeSymbolProvider.toSymbol(keyShape))
}
val constraintViolationCodegenScope = constraintViolationCodegenScopeMutableList.toTypedArray()
Expand Down
Loading

0 comments on commit 8d79352

Please sign in to comment.