Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update serde implementation to support out of range floats #3825

Merged
merged 2 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import software.amazon.smithy.model.shapes.BlobShape
import software.amazon.smithy.model.shapes.BooleanShape
import software.amazon.smithy.model.shapes.CollectionShape
import software.amazon.smithy.model.shapes.DocumentShape
import software.amazon.smithy.model.shapes.DoubleShape
import software.amazon.smithy.model.shapes.FloatShape
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.NumberShape
Expand Down Expand Up @@ -208,13 +210,45 @@ class SerializeImplGenerator(private val codegenContext: CodegenContext) {
* For enums, it adds `as_str()` to convert it into a string directly.
*/
private fun serializeNumber(shape: NumberShape): RuntimeType {
val numericType = SimpleShapes.getValue(shape::class)
return when (shape) {
is FloatShape, is DoubleShape -> serializeFloat(shape)
else ->
RuntimeType.forInlineFun(
numericType.toString(),
PrimitiveShapesModule,
) {
implSerializeConfigured(symbolBuilder(shape, numericType).build()) {
rustTemplate("self.value.serialize(serializer)")
}
}
}
}

private fun serializeFloat(shape: NumberShape): RuntimeType {
val numericType = SimpleShapes.getValue(shape::class)
return RuntimeType.forInlineFun(
numericType.toString(),
PrimitiveShapesModule,
) {
implSerializeConfigured(symbolBuilder(shape, numericType).build()) {
rustTemplate("self.value.serialize(serializer)")
rustTemplate(
"""
if !self.settings.out_of_range_floats_as_strings {
return self.value.serialize(serializer)
}
if self.value.is_nan() {
serializer.serialize_str("NaN")
} else if *self.value == #{ty}::INFINITY {
serializer.serialize_str("Infinity")
} else if *self.value == #{ty}::NEG_INFINITY {
serializer.serialize_str("-Infinity")
} else {
self.value.serialize(serializer)
}
""",
"ty" to numericType,
)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ object SupportStructures {
{
use #{serde}::Serialize;
value
.serialize_ref(&#{SerializationSettings} { redact_sensitive_fields: true })
.serialize_ref(&#{SerializationSettings}::redact_sensitive_fields())
.serialize(serializer)
}
""",
Expand All @@ -70,7 +70,7 @@ object SupportStructures {
{
use #{serde}::Serialize;
value
.serialize_ref(&#{SerializationSettings} { redact_sensitive_fields: false })
.serialize_ref(&#{SerializationSettings}::leak_sensitive_fields())
.serialize(serializer)
}
""",
Expand Down Expand Up @@ -211,7 +211,6 @@ object SupportStructures {

private fun serializationSettings() =
RuntimeType.forInlineFun("SerializationSettings", supportModule) {
// TODO(serde): Consider removing `derive(Default)`
rustTemplate(
"""
/// Settings for use when serializing structures
Expand All @@ -220,17 +219,23 @@ object SupportStructures {
pub struct SerializationSettings {
/// Replace all sensitive fields with `<redacted>` during serialization
pub redact_sensitive_fields: bool,

/// Serialize Nan, infinity and negative infinity as strings.
///
/// For protocols like JSON, this avoids the loss-of-information that occurs when these out-of-range values
/// are serialized as null.
pub out_of_range_floats_as_strings: bool,
}

impl SerializationSettings {
/// Replace all `@sensitive` fields with `<redacted>` when serializing.
///
/// Note: This may alter the type of the serialized output and make it impossible to deserialize as
/// numerical fields will be replaced with strings.
pub const fn redact_sensitive_fields() -> Self { Self { redact_sensitive_fields: true } }
pub const fn redact_sensitive_fields() -> Self { Self { redact_sensitive_fields: true, out_of_range_floats_as_strings: false } }

/// Preserve the contents of sensitive fields during serializing
pub const fn leak_sensitive_fields() -> Self { Self { redact_sensitive_fields: false } }
pub const fn leak_sensitive_fields() -> Self { Self { redact_sensitive_fields: false, out_of_range_floats_as_strings: false } }
}
""",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ class SerdeDecoratorTest {
blob: SensitiveBlob,
constrained: Constrained,
recursive: Recursive,
map: EnumKeyedMap
map: EnumKeyedMap,
float: Float,
double: Double
}

structure Constrained {
Expand Down Expand Up @@ -134,6 +136,8 @@ class SerdeDecoratorTest {
structure Nested {
@required
int: Integer,
float: Float,
double: Double,
sensitive: Timestamps,
notSensitive: AlsoTimestamps,
manyEnums: TestEnumList,
Expand Down Expand Up @@ -202,8 +206,12 @@ class SerdeDecoratorTest {
.e(Some(TestEnum::A))
.document(Some(Document::String("hello!".into())))
.blob(Some(Blob::new("hello")))
.float(Some(f32::INFINITY))
.double(Some(f64::NAN))
.nested(Some(Nested::builder()
.int(5)
.float(Some(f32::NEG_INFINITY))
.double(Some(f64::NEG_INFINITY))
.sensitive(Some(sensitive_map.clone()))
.not_sensitive(Some(sensitive_map))
.many_enums(Some(vec![TestEnum::A]))
Expand Down Expand Up @@ -274,6 +282,8 @@ class SerdeDecoratorTest {
"e": "A",
"nested": {
"int": 5,
"float": "-Infinity",
"double": "-Infinity",
"sensitive": {
"a": "1970-01-01T00:00:00Z"
},
Expand All @@ -289,7 +299,9 @@ class SerdeDecoratorTest {
"enum": "B"
},
"document": "hello!",
"blob": "aGVsbG8="
"blob": "aGVsbG8=",
"float": "Infinity",
"double": "NaN"
}""".replace("\\s".toRegex(), "")

private val expectedRedacted =
Expand All @@ -298,6 +310,8 @@ class SerdeDecoratorTest {
"e": "<redacted>",
"nested": {
"int": 5,
"float": "-Infinity",
"double": "-Infinity",
"sensitive": {
"a": "<redacted>"
},
Expand All @@ -311,7 +325,9 @@ class SerdeDecoratorTest {
},
"union": "<redacted>",
"document": "hello!",
"blob": "<redacted>"
"blob": "<redacted>",
"float": "Infinity",
"double": "NaN"
}
""".replace("\\s".toRegex(), "")

Expand Down Expand Up @@ -343,8 +359,12 @@ class SerdeDecoratorTest {
.e("A".into())
.document(Document::String("hello!".into()))
.blob(Blob::new("hello"))
.float(f32::INFINITY)
.double(f64::NAN)
.nested(Nested::builder()
.int(5)
.float(f32::NEG_INFINITY)
.double(f64::NEG_INFINITY)
.sensitive("a", DateTime::from(UNIX_EPOCH))
.not_sensitive("a", DateTime::from(UNIX_EPOCH))
.many_enums("A".into())
Expand All @@ -355,11 +375,15 @@ class SerdeDecoratorTest {
.build()
.unwrap();
let mut settings = #{crate}::serde::SerializationSettings::default();
settings.out_of_range_floats_as_strings = true;
let serialized = #{serde_json}::to_string(&input.serialize_ref(&settings)).expect("failed to serialize");
assert_eq!(serialized, ${expectedNoRedactions.dq()});
settings.redact_sensitive_fields = true;
let serialized = #{serde_json}::to_string(&input.serialize_ref(&settings)).expect("failed to serialize");
assert_eq!(serialized, ${expectedRedacted.dq()});
settings.out_of_range_floats_as_strings = false;
let serialized = #{serde_json}::to_string(&input.serialize_ref(&settings)).expect("failed to serialize");
assert_ne!(serialized, ${expectedRedacted.dq()});
""",
*codegenScope,
)
Expand Down
Loading