Skip to content

Commit

Permalink
Improve union of multi markers
Browse files Browse the repository at this point in the history
  • Loading branch information
radoering committed Apr 1, 2022
1 parent d6d33de commit f8ae8bd
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 50 deletions.
127 changes: 95 additions & 32 deletions src/poetry/core/version/markers.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,14 +460,51 @@ def intersect(self, other: BaseMarker) -> BaseMarker:
return MultiMarker.of(*new_markers)

def union(self, other: BaseMarker) -> BaseMarker:
if other in self._markers:
return other

if isinstance(other, (SingleMarker, MultiMarker)):
return MarkerUnion.of(self, other)

return other.union(self)

def union_simplify(self, other: BaseMarker) -> BaseMarker | None:
"""
In contrast to the standard union method, which prefers to return
a MarkerUnion of MultiMarkers, this version prefers to return
a MultiMarker of MarkerUnions.
The rationale behind this approach is to find additional simplifications.
In order to avoid endless recursions, this method returns None
if it cannot find a simplification.
"""
if isinstance(other, SingleMarker):
new_markers = []
for marker in self._markers:
union = marker.union(other)
if not union.is_any():
new_markers.append(union)

if len(new_markers) == 1:
return new_markers[0]
if other in new_markers and all(
other == m or isinstance(m, MarkerUnion) and other in m.markers
for m in new_markers
):
return other

elif isinstance(other, MultiMarker):
markers = set(self._markers)
other_markers = set(other.markers)
common_markers = markers & other_markers
unique_markers = markers - common_markers
other_unique_markers = other_markers - common_markers
if common_markers:
unique_union = self.of(*unique_markers).union(
self.of(*other_unique_markers)
)
if not isinstance(unique_union, MarkerUnion):
return self.of(*common_markers).intersect(unique_union)

return None

def validate(self, environment: dict[str, Any]) -> bool:
return all(m.validate(environment) for m in self._markers)

Expand Down Expand Up @@ -543,42 +580,68 @@ def markers(self) -> list[BaseMarker]:

@classmethod
def of(cls, *markers: BaseMarker) -> BaseMarker:
flattened_markers = _flatten_markers(markers, MarkerUnion)
new_markers = _flatten_markers(markers, MarkerUnion)
old_markers: list[BaseMarker] = []

new_markers: list[BaseMarker] = []
for marker in flattened_markers:
if marker in new_markers:
continue
while old_markers != new_markers:
old_markers = new_markers
new_markers = []
for marker in old_markers:
if marker in new_markers or marker.is_empty():
continue

if isinstance(marker, SingleMarker):
included = False
for i, mark in enumerate(new_markers):
if isinstance(mark, SingleMarker) and (
mark.name == marker.name
or (
mark.name in PYTHON_VERSION_MARKERS
and marker.name in PYTHON_VERSION_MARKERS
)
):
union = mark.constraint.union(marker.constraint)
if union == mark.constraint:
included = True
break
elif union == marker.constraint:
new_markers[i] = marker
included = True
break
elif union.is_any():
return AnyMarker()
elif isinstance(union, VersionConstraint) and union.is_simple():
new_markers[i] = SingleMarker(mark.name, union)

if isinstance(marker, SingleMarker):
for i, mark in enumerate(new_markers):
if isinstance(mark, SingleMarker) and (
mark.name == marker.name
or (
mark.name in PYTHON_VERSION_MARKERS
and marker.name in PYTHON_VERSION_MARKERS
)
):
constraint_union = mark.constraint.union(marker.constraint)
if constraint_union == mark.constraint:
included = True
break
elif constraint_union == marker.constraint:
new_markers[i] = marker
included = True
break
elif constraint_union.is_any():
return AnyMarker()
elif (
isinstance(constraint_union, VersionConstraint)
and constraint_union.is_simple()
):
new_markers[i] = SingleMarker(
mark.name, constraint_union
)
included = True
break

elif isinstance(mark, MultiMarker):
union = mark.union_simplify(marker)
if union is not None:
new_markers[i] = union
included = True
break

elif isinstance(marker, MultiMarker):
included = False
for i, mark in enumerate(new_markers):
union = marker.union_simplify(mark)
if union is not None:
new_markers[i] = union
included = True
break

if included:
continue

new_markers.append(marker)
# flatten again because union_simplify may return a union
new_markers = _flatten_markers(new_markers, MarkerUnion)
else:
new_markers.append(marker)

if any(m.is_any() for m in new_markers):
return AnyMarker()
Expand Down
15 changes: 10 additions & 5 deletions tests/packages/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,27 @@

def test_convert_markers():
marker = parse_marker(
'sys_platform == "win32" and python_version < "3.6" or sys_platform == "win32"'
'sys_platform == "win32" and python_version < "3.6" or sys_platform == "linux"'
' and python_version < "3.6" and python_version >= "3.3" or sys_platform =='
' "win32" and python_version < "3.3"'
' "darwin" and python_version < "3.3"'
)

converted = convert_markers(marker)

assert converted["python_version"] == [
[("<", "3.6")],
[("<", "3.6"), (">=", "3.3")],
[("<", "3.3")],
]

marker = parse_marker('python_version == "2.7" or python_version == "2.6"')
marker = parse_marker(
'sys_platform == "win32" and python_version < "3.6" or sys_platform == "win32"'
' and python_version < "3.6" and python_version >= "3.3" or sys_platform =='
' "win32" and python_version < "3.3"'
)
converted = convert_markers(marker)
assert converted["python_version"] == [[("<", "3.6")]]

marker = parse_marker('python_version == "2.7" or python_version == "2.6"')
converted = convert_markers(marker)
assert converted["python_version"] == [[("==", "2.7")], [("==", "2.6")]]


Expand Down
107 changes: 94 additions & 13 deletions tests/version/test_markers.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,60 @@ def test_single_marker_union_with_multi_duplicate():
assert str(union) == 'sys_platform == "darwin" and python_version >= "3.6"'


@pytest.mark.parametrize(
("single_marker", "multi_marker", "expected"),
[
(
'python_version >= "3.6"',
'python_version >= "3.7" and sys_platform == "win32"',
'python_version >= "3.6"',
),
(
'sys_platform == "linux"',
'sys_platform != "linux" and sys_platform != "win32"',
'sys_platform != "win32"',
),
],
)
def test_single_marker_union_with_multi_is_single_marker(
single_marker: str, multi_marker: str, expected: str
):
m = parse_marker(single_marker)
union = m.union(parse_marker(multi_marker))
assert str(union) == expected


def test_single_marker_union_with_multi_cannot_be_simplified():
m = parse_marker('python_version >= "3.7"')
union = m.union(parse_marker('python_version >= "3.6" and sys_platform == "win32"'))
assert (
str(union)
== 'python_version >= "3.6" and sys_platform == "win32" or python_version >='
' "3.7"'
)


def test_single_marker_union_with_multi_is_union_of_single_markers():
m = parse_marker('python_version >= "3.6"')
union = m.union(parse_marker('python_version < "3.6" and sys_platform == "win32"'))
assert str(union) == 'sys_platform == "win32" or python_version >= "3.6"'


def test_single_marker_union_with_multi_union_is_union_of_single_markers():
m = parse_marker('python_version >= "3.6"')
union = m.union(
parse_marker(
'python_version < "3.6" and sys_platform == "win32" or python_version <'
' "3.6" and sys_platform == "linux"'
)
)
assert (
str(union)
== 'sys_platform == "win32" or sys_platform == "linux" or python_version >='
' "3.6"'
)


def test_single_marker_union_with_union():
m = parse_marker('sys_platform == "darwin"')

Expand Down Expand Up @@ -367,29 +421,60 @@ def test_multi_marker_intersect_with_multi_union_leads_to_empty_in_two_steps():
def test_multi_marker_union_multi():
m = parse_marker('sys_platform == "darwin" and implementation_name == "cpython"')

intersection = m.union(
parse_marker('python_version >= "3.6" and os_name == "Windows"')
)
union = m.union(parse_marker('python_version >= "3.6" and os_name == "Windows"'))
assert (
str(intersection)
str(union)
== 'sys_platform == "darwin" and implementation_name == "cpython" '
'or python_version >= "3.6" and os_name == "Windows"'
)


def test_multi_marker_union_multi_is_single_marker():
m = parse_marker('python_version >= "3" and sys_platform == "win32"')
m2 = parse_marker('sys_platform != "win32" and python_version >= "3"')
assert str(m.union(m2)) == 'python_version >= "3"'
assert str(m2.union(m)) == 'python_version >= "3"'


def test_multi_marker_union_multi_is_multi():
m = parse_marker('python_version >= "3" and sys_platform == "win32"')
m2 = parse_marker(
'python_version >= "3" and sys_platform != "win32" and sys_platform != "linux"'
)
assert str(m.union(m2)) == 'python_version >= "3" and sys_platform != "linux"'
assert str(m2.union(m)) == 'python_version >= "3" and sys_platform != "linux"'


def test_multi_marker_union_with_union():
m = parse_marker('sys_platform == "darwin" and implementation_name == "cpython"')

intersection = m.union(
parse_marker('python_version >= "3.6" or os_name == "Windows"')
)
union = m.union(parse_marker('python_version >= "3.6" or os_name == "Windows"'))
assert (
str(intersection)
str(union)
== 'python_version >= "3.6" or os_name == "Windows"'
' or sys_platform == "darwin" and implementation_name == "cpython"'
)


def test_multi_marker_union_with_multi_union_is_single_marker():
m = parse_marker('sys_platform == "darwin" and python_version == "3"')
m2 = parse_marker(
'sys_platform == "darwin" and python_version < "3" or sys_platform == "darwin"'
' and python_version > "3"'
)
assert str(m.union(m2)) == 'sys_platform == "darwin"'
assert str(m2.union(m)) == 'sys_platform == "darwin"'


def test_multi_marker_union_with_union_multi_is_single_marker():
m = parse_marker('sys_platform == "darwin" and python_version == "3"')
m2 = parse_marker(
'sys_platform == "darwin" and (python_version < "3" or python_version > "3")'
)
assert str(m.union(m2)) == 'sys_platform == "darwin"'
assert str(m2.union(m)) == 'sys_platform == "darwin"'


def test_marker_union():
m = parse_marker('sys_platform == "darwin" or implementation_name == "cpython"')

Expand Down Expand Up @@ -440,11 +525,7 @@ def test_marker_union_intersect_single_with_overlapping_constraints():

m = parse_marker('sys_platform == "darwin" or python_version < "3.4"')
intersection = m.intersect(parse_marker('sys_platform == "darwin"'))
assert (
str(intersection)
== 'sys_platform == "darwin" or python_version < "3.4" and sys_platform =='
' "darwin"'
)
assert str(intersection) == 'sys_platform == "darwin"'


def test_marker_union_intersect_marker_union():
Expand Down

0 comments on commit f8ae8bd

Please sign in to comment.