diff --git a/tests/unit/ast/nodes/test_fold_compare.py b/tests/unit/ast/nodes/test_fold_compare.py index aab8ac0b2d..fd9f65a7d3 100644 --- a/tests/unit/ast/nodes/test_fold_compare.py +++ b/tests/unit/ast/nodes/test_fold_compare.py @@ -110,3 +110,20 @@ def test_compare_type_mismatch(op): old_node = vyper_ast.body[0].value with pytest.raises(UnfoldableNode): old_node.get_folded_value() + + +@pytest.mark.parametrize("op", ["==", "!="]) +def test_compare_eq_bytes(get_contract, op): + left, right = "0xA1AAB33F", "0xa1aab33f" + source = f""" +@external +def foo(a: bytes4, b: bytes4) -> bool: + return a {op} b + """ + contract = get_contract(source) + + vyper_ast = parse_and_fold(f"{left} {op} {right}") + old_node = vyper_ast.body[0].value + new_node = old_node.get_folded_value() + + assert contract.foo(left, right) == new_node.value diff --git a/vyper/semantics/analysis/constant_folding.py b/vyper/semantics/analysis/constant_folding.py index 6e4166dc52..98cab0f8cb 100644 --- a/vyper/semantics/analysis/constant_folding.py +++ b/vyper/semantics/analysis/constant_folding.py @@ -180,8 +180,11 @@ def visit_Compare(self, node): raise UnfoldableNode( f"Invalid literal types for {node.op.description} comparison", node ) - - value = node.op._op(left.value, right.value) + lvalue, rvalue = left.value, right.value + if isinstance(left, vy_ast.Hex): + # Hex values are str, convert to be case-unsensitive. + lvalue, rvalue = lvalue.lower(), rvalue.lower() + value = node.op._op(lvalue, rvalue) return vy_ast.NameConstant.from_node(node, value=value) def visit_List(self, node) -> vy_ast.ExprNode: diff --git a/vyper/utils.py b/vyper/utils.py index f4f14a346e..3f19a9d15c 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -25,9 +25,10 @@ class OrderedSet(Generic[_T]): """ def __init__(self, iterable=None): - self._data = dict() - if iterable is not None: - self.update(iterable) + if iterable is None: + self._data = dict() + else: + self._data = dict.fromkeys(iterable) def __repr__(self): keys = ", ".join(repr(k) for k in self) @@ -57,6 +58,7 @@ def pop(self): def add(self, item: _T) -> None: self._data[item] = None + # NOTE to refactor: duplicate of self.update() def addmany(self, iterable): for item in iterable: self._data[item] = None @@ -109,11 +111,11 @@ def intersection(cls, *sets): if len(sets) == 0: raise ValueError("undefined: intersection of no sets") - ret = sets[0].copy() - for e in sets[0]: - if any(e not in s for s in sets[1:]): - ret.remove(e) - return ret + tmp = sets[0]._data.keys() + for s in sets[1:]: + tmp &= s._data.keys() + + return cls(tmp) class StringEnum(enum.Enum):