Skip to content

Commit

Permalink
Make enum forward-compatible
Browse files Browse the repository at this point in the history
This commit implements the suggested approach described in
#627 (comment)

The idea is that once the user writes a match expression against an enum
and assumes that an execution path comes to a particular match arm, we
should guarantee that when the user upgrades a version of SDK, the
execution path should come to the same match arm as before.
  • Loading branch information
Saito committed Nov 2, 2022
1 parent a026f6f commit cd61a8e
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ abstract class SymbolMetadataProvider(private val base: RustSymbolProvider) : Wr
is StringShape -> if (shape.hasTrait<EnumTrait>()) {
enumMeta(shape)
} else null

else -> null
}
return baseSymbol.toBuilder().meta(meta).build()
Expand Down Expand Up @@ -100,11 +101,13 @@ class BaseSymbolMetadataProvider(
)
}
}

container.isUnionShape ||
container.isListShape ||
container.isSetShape ||
container.isMapShape
-> RustMetadata(visibility = Visibility.PUBLIC)

else -> TODO("Unrecognized container type: $container")
}
}
Expand All @@ -120,9 +123,10 @@ class BaseSymbolMetadataProvider(
override fun enumMeta(stringShape: StringShape): RustMetadata {
return containerDefault.withDerives(
RuntimeType.std.member("hash::Hash"),
).withDerives( // enums can be eq because they can only contain strings
).withDerives(
// enums can be eq because the inner data also implements Eq
RuntimeType.std.member("cmp::Eq"),
// enums can be Ord because they can only contain strings
// enums can be Ord because the inner data also implements Ord
RuntimeType.std.member("cmp::PartialOrd"),
RuntimeType.std.member("cmp::Ord"),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ open class EnumGenerator(
/** Name of the generated unknown enum member name for enums with named members. */
const val UnknownVariant = "Unknown"

/** Name of the opaque struct that is inner data for the generated [UnknownVariant]. */
const val UnknownVariantValue = "UnknownVariantValue"

/** Name of the function on the enum impl to get a vec of value names */
const val Values = "values"
}
Expand All @@ -108,6 +111,10 @@ open class EnumGenerator(
// pub enum Blah { V1, V2, .. }
renderEnum()
writer.insertTrailingNewline()
if (target == CodegenTarget.CLIENT) {
renderUnknownVariantValue()
}
writer.insertTrailingNewline()
// impl From<str> for Blah { ... }
renderFromForStr()
// impl FromStr for Blah { ... }
Expand Down Expand Up @@ -168,8 +175,8 @@ open class EnumGenerator(
writer.rustBlock("enum $enumName") {
sortedMembers.forEach { member -> member.render(writer) }
if (target == CodegenTarget.CLIENT) {
docs("$UnknownVariant contains new variants that have been added since this code was generated.")
rust("$UnknownVariant(String)")
docs("`$UnknownVariant` contains new variants that have been added since this code was generated.")
rust("$UnknownVariant($UnknownVariantValue)")
}
}
}
Expand All @@ -183,7 +190,7 @@ open class EnumGenerator(
rust("""$enumName::${member.derivedName()} => ${member.value.dq()},""")
}
if (target == CodegenTarget.CLIENT) {
rust("$enumName::$UnknownVariant(s) => s.as_ref()")
rust("$enumName::$UnknownVariant(value) => value.as_str()")
}
}
}
Expand All @@ -198,14 +205,25 @@ open class EnumGenerator(
}
}

private fun renderUnknownVariantValue() {
meta.render(writer)
writer.write("struct $UnknownVariantValue(String);")
writer.rustBlock("impl $UnknownVariantValue") {
// The generated as_str is not pub as we need to prevent users from calling it on this opaque struct.
rustBlock("fn as_str(&self) -> &str") {
rust("&self.0")
}
}
}

protected open fun renderFromForStr() {
writer.rustBlock("impl #T<&str> for $enumName", RuntimeType.From) {
rustBlock("fn from(s: &str) -> Self") {
rustBlock("match s") {
sortedMembers.forEach { member ->
rust("""${member.value.dq()} => $enumName::${member.derivedName()},""")
}
rust("other => $enumName::$UnknownVariant(other.to_owned())")
rust("other => $enumName::$UnknownVariant($UnknownVariantValue(other.to_owned()))")
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class EnumGeneratorTest {
let instance = InstanceType::T2Micro;
assert_eq!(instance.as_str(), "t2.micro");
assert_eq!(InstanceType::from("t2.nano"), InstanceType::T2Nano);
assert_eq!(InstanceType::from("other"), InstanceType::Unknown("other".to_owned()));
assert_eq!(InstanceType::from("other"), InstanceType::Unknown(UnknownVariantValue("other".to_owned())));
// round trip unknown variants:
assert_eq!(InstanceType::from("other").as_str(), "other");
""",
Expand Down Expand Up @@ -250,7 +250,7 @@ class EnumGeneratorTest {
"""
assert_eq!(SomeEnum::from("Unknown"), SomeEnum::UnknownValue);
assert_eq!(SomeEnum::from("UnknownValue"), SomeEnum::UnknownValue_);
assert_eq!(SomeEnum::from("SomethingNew"), SomeEnum::Unknown("SomethingNew".into()));
assert_eq!(SomeEnum::from("SomethingNew"), SomeEnum::Unknown(UnknownVariantValue("SomethingNew".to_owned())));
""",
)
}
Expand All @@ -271,7 +271,9 @@ class EnumGeneratorTest {
val shape: StringShape = model.lookup("test#SomeEnum")
val trait = shape.expectTrait<EnumTrait>()
val provider = testSymbolProvider(model)
val rendered = RustWriter.forModule("model").also { EnumGenerator(model, provider, it, shape, trait).render() }.toString()
val rendered =
RustWriter.forModule("model").also { EnumGenerator(model, provider, it, shape, trait).render() }
.toString()

rendered shouldContain
"""
Expand All @@ -297,7 +299,9 @@ class EnumGeneratorTest {
val shape: StringShape = model.lookup("test#SomeEnum")
val trait = shape.expectTrait<EnumTrait>()
val provider = testSymbolProvider(model)
val rendered = RustWriter.forModule("model").also { EnumGenerator(model, provider, it, shape, trait).render() }.toString()
val rendered =
RustWriter.forModule("model").also { EnumGenerator(model, provider, it, shape, trait).render() }
.toString()

rendered shouldContain
"""
Expand Down Expand Up @@ -326,7 +330,7 @@ class EnumGeneratorTest {
writer.compileAndTest(
"""
assert_eq!(SomeEnum::from("other"), SomeEnum::SelfValue);
assert_eq!(SomeEnum::from("SomethingNew"), SomeEnum::Unknown("SomethingNew".into()));
assert_eq!(SomeEnum::from("SomethingNew"), SomeEnum::Unknown(UnknownVariantValue("SomethingNew".to_owned())));
""",
)
}
Expand Down

0 comments on commit cd61a8e

Please sign in to comment.