Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enforce constraints for unnamed enums #3884

Open
wants to merge 10 commits into from
18 changes: 18 additions & 0 deletions .changelog/4329788.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
---
applies_to: ["server"]
authors: ["drganjoo"]
references: ["smithy-rs#3880"]
breaking: true
new_feature: false
bug_fix: true
---
Unnamed enums now validate assigned values and will raise a `ConstraintViolation` if an unknown variant is set.
drganjoo marked this conversation as resolved.
Show resolved Hide resolved

The following is an example of an unnamed enum:
```smithy
@enum([
{ value: "MONDAY" },
{ value: "TUESDAY" }
])
string UnnamedDayOfWeek
```
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,37 @@ data class InfallibleEnumType(
)
}

override fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
drganjoo marked this conversation as resolved.
Show resolved Hide resolved
writable {
rustTemplate(
"""
impl<T> #{From}<T> for ${context.enumName} where T: #{AsRef}<str> {
fn from(s: T) -> Self {
${context.enumName}(s.as_ref().to_owned())
drganjoo marked this conversation as resolved.
Show resolved Hide resolved
}
}
""",
*preludeScope,
)
}

override fun implFromStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
writable {
// Add an infallible FromStr implementation for uniformity
rustTemplate(
"""
impl ::std::str::FromStr for ${context.enumName} {
type Err = ::std::convert::Infallible;

fn from_str(s: &str) -> #{Result}<Self, <Self as ::std::str::FromStr>::Err> {
#{Ok}(${context.enumName}::from(s))
}
}
""",
*preludeScope,
)
}

override fun additionalEnumImpls(context: EnumGeneratorContext): Writable =
writable {
// `try_parse` isn't needed for unnamed enums
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ internal class ClientInstantiatorTest {
val shape = model.lookup<StringShape>("com.test#UnnamedEnum")
val sut = ClientInstantiator(codegenContext)
val data = Node.parse("t2.nano".dq())
// The client SDK should accept unknown variants as valid.
val notValidVariant = Node.parse("not-a-valid-variant".dq())

val project = TestWorkspace.testProject(symbolProvider)
project.moduleFor(shape) {
Expand All @@ -77,7 +79,11 @@ internal class ClientInstantiatorTest {
withBlock("let result = ", ";") {
sut.render(this, shape, data)
}
rust("""assert_eq!(result, UnnamedEnum("t2.nano".to_owned()));""")
rust("""assert_eq!(result, UnnamedEnum("$data".to_owned()));""")
withBlock("let result = ", ";") {
sut.render(this, shape, notValidVariant)
}
rust("""assert_eq!(result, UnnamedEnum("$notValidVariant".to_owned()));""")
}
}
project.compileAndTest()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ abstract class EnumType {
/** Returns a writable that implements `FromStr` for the enum */
abstract fun implFromStr(context: EnumGeneratorContext): Writable

/** Returns a writable that implements `From<&str>` and/or `TryFrom<&str>` for the unnamed enum */
abstract fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable

/** Returns a writable that implements `FromStr` for the unnamed enum */
abstract fun implFromStrForUnnamedEnum(context: EnumGeneratorContext): Writable

/** Optionally adds additional documentation to the `enum` docs */
open fun additionalDocs(context: EnumGeneratorContext): Writable = writable {}

Expand Down Expand Up @@ -237,32 +243,10 @@ open class EnumGenerator(
rust("&self.0")
},
)

// Add an infallible FromStr implementation for uniformity
rustTemplate(
"""
impl ::std::str::FromStr for ${context.enumName} {
type Err = ::std::convert::Infallible;

fn from_str(s: &str) -> #{Result}<Self, <Self as ::std::str::FromStr>::Err> {
#{Ok}(${context.enumName}::from(s))
}
}
""",
*preludeScope,
)

rustTemplate(
"""
impl<T> #{From}<T> for ${context.enumName} where T: #{AsRef}<str> {
fn from(s: T) -> Self {
${context.enumName}(s.as_ref().to_owned())
}
}

""",
*preludeScope,
)
// impl From<str> for Blah { ... }
enumType.implFromForStrForUnnamedEnum(context)(this)
// impl FromStr for Blah { ... }
enumType.implFromStrForUnnamedEnum(context)(this)
}

private fun RustWriter.renderEnum() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,16 @@ class EnumGeneratorTest {
// intentional no-op
}

override fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
writable {
// intentional no-op
}

override fun implFromStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
writable {
// intentional no-op
}

override fun additionalEnumMembers(context: EnumGeneratorContext): Writable =
writable {
rust("// additional enum members")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.util.dq

object TestEnumType : EnumType() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Huh, I wonder why we have this class at all i.e. why we don't test directly against InfallibleEnumType. It feels wrong to copy over the implementations from the "real" classes to this class.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with you on this. I'll see if it can be easily fixed in this PR, otherwise will raise another one for this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

InfallibleEnumType allows UnknownVariants, whereas TestEnumType panics when an unknown variant is encountered. I tried composing TestEnumType to use InfallibleEnumType internally, but InfallibleEnumType is defined in codegen-client, while TestEnumType is in codegen-core.

Expand Down Expand Up @@ -49,4 +50,35 @@ object TestEnumType : EnumType() {
""",
)
}

override fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
writable {
rustTemplate(
"""
impl<T> #{From}<T> for ${context.enumName} where T: #{AsRef}<str> {
fn from(s: T) -> Self {
${context.enumName}(s.as_ref().to_owned())
}
}
""",
*preludeScope,
)
}

override fun implFromStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
writable {
// Add an infallible FromStr implementation for uniformity
rustTemplate(
"""
impl ::std::str::FromStr for ${context.enumName} {
type Err = ::std::convert::Infallible;

fn from_str(s: &str) -> #{Result}<Self, <Self as ::std::str::FromStr>::Err> {
#{Ok}(${context.enumName}::from(s))
}
}
""",
*preludeScope,
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
package software.amazon.smithy.rust.codegen.server.smithy.generators

import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
Expand Down Expand Up @@ -39,16 +38,14 @@ open class ConstrainedEnum(
}
private val constraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(shape)
private val constraintViolationName = constraintViolationSymbol.name
private val codegenScope =
arrayOf(
"String" to RuntimeType.String,
)

override fun implFromForStr(context: EnumGeneratorContext): Writable =
writable {
withInlineModule(constraintViolationSymbol.module(), codegenContext.moduleDocProvider) {
rustTemplate(
"""
private fun generateConstraintViolation(
context: EnumGeneratorContext,
generateTryFromStrAndString: RustWriter.(EnumGeneratorContext) -> Unit,
) = writable {
withInlineModule(constraintViolationSymbol.module(), codegenContext.moduleDocProvider) {
rustTemplate(
"""
##[derive(Debug, PartialEq)]
pub struct $constraintViolationName(pub(crate) #{String});
Expand All @@ -60,47 +57,86 @@ open class ConstrainedEnum(
impl #{Error} for $constraintViolationName {}
""",
*codegenScope,
"Error" to RuntimeType.StdError,
"Display" to RuntimeType.Display,
)
*preludeScope,
"Error" to RuntimeType.StdError,
"Display" to RuntimeType.Display,
)

if (shape.isReachableFromOperationInput()) {
rustTemplate(
"""
if (shape.isReachableFromOperationInput()) {
rustTemplate(
"""
impl $constraintViolationName {
#{EnumShapeConstraintViolationImplBlock:W}
}
""",
"EnumShapeConstraintViolationImplBlock" to
validationExceptionConversionGenerator.enumShapeConstraintViolationImplBlock(
context.enumTrait,
),
)
}
"EnumShapeConstraintViolationImplBlock" to
validationExceptionConversionGenerator.enumShapeConstraintViolationImplBlock(
context.enumTrait,
),
)
}
rustBlock("impl #T<&str> for ${context.enumName}", RuntimeType.TryFrom) {
rust("type Error = #T;", constraintViolationSymbol)
rustBlockTemplate("fn try_from(s: &str) -> #{Result}<Self, <Self as #{TryFrom}<&str>>::Error>", *preludeScope) {
rustBlock("match s") {
context.sortedMembers.forEach { member ->
rust("${member.value.dq()} => Ok(${context.enumName}::${member.derivedName()}),")
}

generateTryFromStrAndString(context)
}

override fun implFromForStr(context: EnumGeneratorContext): Writable =
generateConstraintViolation(context) {
rustTemplate(
"""
impl #{TryFrom}<&str> for ${context.enumName} {
type Error = #{ConstraintViolation};
fn try_from(s: &str) -> #{Result}<Self, <Self as #{TryFrom}<&str>>::Error> {
match s {
#{MatchArms}
_ => Err(#{ConstraintViolation}(s.to_owned()))
}
rust("_ => Err(#T(s.to_owned()))", constraintViolationSymbol)
}
}
}
impl #{TryFrom}<#{String}> for ${context.enumName} {
type Error = #{ConstraintViolation};
fn try_from(s: #{String}) -> #{Result}<Self, <Self as #{TryFrom}<#{String}>>::Error> {
s.as_str().try_into()
}
}
""",
*preludeScope,
"ConstraintViolation" to constraintViolationSymbol,
"MatchArms" to
writable {
context.sortedMembers.forEach { member ->
rust("${member.value.dq()} => Ok(${context.enumName}::${member.derivedName()}),")
}
},
)
}

override fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
generateConstraintViolation(context) {
rustTemplate(
"""
impl #{TryFrom}<&str> for ${context.enumName} {
type Error = #{ConstraintViolation};
fn try_from(s: &str) -> #{Result}<Self, <Self as #{TryFrom}<&str>>::Error> {
s.to_owned().try_into()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is first converting to a heap-allocated String and only then matching on the enum values. So invalid enum values would unnecessarily be heap-allocated. We should make TryFrom<String> and TryFrom<&str> instead delegate to FromStr, which should only heap-allocate when the enum value is valid.

Copy link
Contributor Author

@drganjoo drganjoo Oct 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unlike named enums, which can use &str for comparison and directly return an enum variant, unnamed enums need to store an owned, heap-allocated String. The TryFrom<String> implementation already receives an owned String as a parameter, so delegating to FromStr would result in an unnecessary heap allocation. Additionally, ConstraintViolation takes ownership of a heap-allocated String. Therefore, both code paths (valid and invalid enum values) require a heap-allocated String. Calling to_owned within TryFrom<&str> only shifts the allocation slightly earlier without adding any additional allocation.

}
}
impl #{TryFrom}<#{String}> for ${context.enumName} {
type Error = #{ConstraintViolation};
fn try_from(s: #{String}) -> #{Result}<Self, <Self as #{TryFrom}<#{String}>>::Error> {
s.as_str().try_into()
match s.as_str() {
#{Values} => Ok(Self(s)),
_ => Err(#{ConstraintViolation}(s))
}
}
}
""",
*preludeScope,
"ConstraintViolation" to constraintViolationSymbol,
"Values" to
writable {
rust(context.sortedMembers.joinToString(" | ") { it.value.dq() })
},
)
}

Expand All @@ -118,6 +154,8 @@ open class ConstrainedEnum(
"ConstraintViolation" to constraintViolationSymbol,
)
}

override fun implFromStrForUnnamedEnum(context: EnumGeneratorContext) = implFromStr(context)
}

class ServerEnumGenerator(
Expand Down
Loading
Loading