From 92285e4cb996b787b1321e0567d1507dda1d07ef Mon Sep 17 00:00:00 2001 From: Dennis Brakhane Date: Tue, 15 Feb 2022 21:26:47 +0100 Subject: [PATCH] Fix handling of Enums in Literal types --- src/cattr/converters.py | 14 +++++++++--- tests/test_structure_attrs.py | 41 ++++++++++++++++++++++++++++++++++- 2 files changed, 51 insertions(+), 4 deletions(-) diff --git a/src/cattr/converters.py b/src/cattr/converters.py index aa08ab35..18d9eb89 100644 --- a/src/cattr/converters.py +++ b/src/cattr/converters.py @@ -405,9 +405,17 @@ def _structure_call(obj, cl): @staticmethod def _structure_literal(val, type): - if val not in type.__args__: - raise Exception(f"{val} not in literal {type}") - return val + vals = set(type.__args__) + enums = {x for x in vals if isinstance(x, Enum)} + literal_vals = vals.difference(enums) + if val not in literal_vals: + enum_vals = {x.value: x for x in enums} + if val not in enum_vals: + raise Exception(f"{val} not in literal {type}") + else: + return enum_vals[val] + else: + return val # Attrs classes. diff --git a/tests/test_structure_attrs.py b/tests/test_structure_attrs.py index dbd5a6c1..82116108 100644 --- a/tests/test_structure_attrs.py +++ b/tests/test_structure_attrs.py @@ -1,4 +1,5 @@ """Loading of attrs classes.""" +from enum import Enum from ipaddress import IPv4Address, IPv6Address, ip_address from typing import Union from unittest.mock import Mock @@ -164,6 +165,27 @@ class ClassWithLiteral: ) == ClassWithLiteral(4) +@pytest.mark.skipif(is_py37, reason="Not supported on 3.7") +@pytest.mark.parametrize("converter_cls", [Converter, GenConverter]) +def test_structure_literal_enum(converter_cls): + """Structuring a class with a literal field works.""" + from typing import Literal + + converter = converter_cls() + + class Foo(Enum): + FOO = 1 + BAR = 2 + + @define + class ClassWithLiteral: + literal_field: Literal[Foo.FOO] = Foo.FOO + + assert converter.structure( + {"literal_field": 1}, ClassWithLiteral + ) == ClassWithLiteral(Foo.FOO) + + @pytest.mark.skipif(is_py37, reason="Not supported on 3.7") @pytest.mark.parametrize("converter_cls", [Converter, GenConverter]) def test_structure_literal_multiple(converter_cls): @@ -172,9 +194,17 @@ def test_structure_literal_multiple(converter_cls): converter = converter_cls() + class Foo(Enum): + FOO = 1 + FOOFOO = 2 + + class Bar(Enum): + BAR = 8 + BARBAR = 9 + @define class ClassWithLiteral: - literal_field: Literal[4, 5] = 4 + literal_field: Literal[4, 5, Literal[6], Foo.FOO, Bar.BARBAR] = 4 assert converter.structure( {"literal_field": 4}, ClassWithLiteral @@ -182,6 +212,15 @@ class ClassWithLiteral: assert converter.structure( {"literal_field": 5}, ClassWithLiteral ) == ClassWithLiteral(5) + assert converter.structure( + {"literal_field": 6}, ClassWithLiteral + ) == ClassWithLiteral(6) + assert converter.structure( + {"literal_field": 1}, ClassWithLiteral + ) == ClassWithLiteral(Foo.FOO) + assert converter.structure( + {"literal_field": 9}, ClassWithLiteral + ) == ClassWithLiteral(Bar.BARBAR) @pytest.mark.skipif(is_py37, reason="Not supported on 3.7")