diff --git a/libcst/_nodes/statement.py b/libcst/_nodes/statement.py index 6cad7684..1cb9221f 100644 --- a/libcst/_nodes/statement.py +++ b/libcst/_nodes/statement.py @@ -2854,17 +2854,16 @@ def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "CSTNode": self, "whitespace_after_case", self.whitespace_after_case, visitor ), pattern=visit_required(self, "pattern", self.pattern, visitor), - # pyre-fixme[6]: Expected `SimpleWhitespace` for 4th param but got - # `Optional[SimpleWhitespace]`. - whitespace_before_if=visit_optional( + whitespace_before_if=visit_required( self, "whitespace_before_if", self.whitespace_before_if, visitor ), - # pyre-fixme[6]: Expected `SimpleWhitespace` for 5th param but got - # `Optional[SimpleWhitespace]`. - whitespace_after_if=visit_optional( + whitespace_after_if=visit_required( self, "whitespace_after_if", self.whitespace_after_if, visitor ), guard=visit_optional(self, "guard", self.guard, visitor), + whitespace_before_colon=visit_required( + self, "whitespace_before_colon", self.whitespace_before_colon, visitor + ), body=visit_required(self, "body", self.body, visitor), ) @@ -3382,6 +3381,7 @@ def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "MatchClass": whitespace_after_kwds=visit_required( self, "whitespace_after_kwds", self.whitespace_after_kwds, visitor ), + rpar=visit_sequence(self, "rpar", self.rpar, visitor), ) def _codegen_impl(self, state: CodegenState) -> None: diff --git a/libcst/tests/test_roundtrip.py b/libcst/tests/test_roundtrip.py index e3a7a35b..d5da81f2 100644 --- a/libcst/tests/test_roundtrip.py +++ b/libcst/tests/test_roundtrip.py @@ -3,25 +3,43 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. + from pathlib import Path from unittest import TestCase -from libcst import parse_module +from libcst import CSTTransformer, parse_module from libcst._parser.entrypoints import is_native fixtures: Path = Path(__file__).parent.parent.parent / "native/libcst/tests/fixtures" +class NOOPTransformer(CSTTransformer): + pass + + class RoundTripTests(TestCase): - def test_clean_roundtrip(self) -> None: + def _get_fixtures(self) -> list[Path]: if not is_native(): self.skipTest("pure python parser doesn't work with this") self.assertTrue(fixtures.exists(), f"{fixtures} should exist") files = list(fixtures.iterdir()) self.assertGreater(len(files), 0) - for file in files: + return files + + def test_clean_roundtrip(self) -> None: + for file in self._get_fixtures(): with self.subTest(file=str(file)): src = file.read_text(encoding="utf-8") mod = parse_module(src) self.maxDiff = None self.assertEqual(mod.code, src) + + def test_transform_roundtrip(self) -> None: + transformer = NOOPTransformer() + self.maxDiff = None + for file in self._get_fixtures(): + with self.subTest(file=str(file)): + src = file.read_text(encoding="utf-8") + mod = parse_module(src) + new_mod = mod.visit(transformer) + self.assertEqual(src, new_mod.code) diff --git a/native/libcst/tests/fixtures/malicious_match.py b/native/libcst/tests/fixtures/malicious_match.py index 8c46571f..54840022 100644 --- a/native/libcst/tests/fixtures/malicious_match.py +++ b/native/libcst/tests/fixtures/malicious_match.py @@ -37,4 +37,6 @@ case x,y , * more :pass case y.z: pass case 1, 2: pass + case ( Foo ( ) ) : pass + case (lol) if ( True , ) :pass