Skip to content

Commit

Permalink
Switch colcon_core.extension_point to importlib.metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
cottsay committed Jul 11, 2023
1 parent e2a864c commit bc06873
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 50 deletions.
71 changes: 33 additions & 38 deletions colcon_core/extension_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,17 @@
import os
import traceback

try:
from importlib.metadata import distributions
from importlib.metadata import EntryPoint
from importlib.metadata import entry_points
except ImportError:
from importlib_metadata import distributions
from importlib_metadata import EntryPoint
from importlib_metadata import entry_points

from colcon_core.environment_variable import EnvironmentVariable
from colcon_core.logging import colcon_logger
from pkg_resources import EntryPoint
from pkg_resources import iter_entry_points
from pkg_resources import WorkingSet

"""Environment variable to block extensions"""
EXTENSION_BLOCKLIST_ENVIRONMENT_VARIABLE = EnvironmentVariable(
Expand Down Expand Up @@ -44,27 +50,25 @@ def get_all_extension_points():
colcon_extension_points.setdefault(EXTENSION_POINT_GROUP_NAME, None)

entry_points = defaultdict(dict)
working_set = WorkingSet()
for dist in sorted(working_set):
entry_map = dist.get_entry_map()
for group_name in entry_map.keys():
seen = set()
for dist in distributions():
if dist.name in seen:
continue
seen.add(dist.name)
for entry_point in dist.entry_points:
# skip groups which are not registered as extension points
if group_name not in colcon_extension_points:
if entry_point.group not in colcon_extension_points:
continue

group = entry_map[group_name]
for entry_point_name, entry_point in group.items():
if entry_point_name in entry_points[group_name]:
previous = entry_points[group_name][entry_point_name]
logger.error(
f"Entry point '{group_name}.{entry_point_name}' is "
f"declared multiple times, '{entry_point}' "
f"overwriting '{previous}'")
value = entry_point.module_name
if entry_point.attrs:
value += f":{'.'.join(entry_point.attrs)}"
entry_points[group_name][entry_point_name] = (
value, dist.project_name, getattr(dist, 'version', None))
if entry_point.name in entry_points[entry_point.group]:
previous = entry_points[entry_point.group][entry_point.name]
logger.error(
f"Entry point '{entry_point.group}.{entry_point.name}' is "
f"declared multiple times, '{entry_point.value}' "
f"from '{dist._path}' "
f"overwriting '{previous}'")
entry_points[entry_point.group][entry_point.name] = \
(entry_point.value, dist.name, dist.version)
return entry_points


Expand All @@ -76,19 +80,16 @@ def get_extension_points(group):
:returns: mapping of extension point names to extension point values
:rtype: dict
"""
entry_points = {}
for entry_point in iter_entry_points(group=group):
if entry_point.name in entry_points:
previous_entry_point = entry_points[entry_point.name]
extension_points = {}
for entry_point in entry_points(group=group):
if entry_point.name in extension_points:
previous_entry_point = extension_points[entry_point.name]
logger.error(
f"Entry point '{group}.{entry_point.name}' is declared "
f"multiple times, '{entry_point}' overwriting "
f"multiple times, '{entry_point.value}' overwriting "
f"'{previous_entry_point}'")
value = entry_point.module_name
if entry_point.attrs:
value += f":{'.'.join(entry_point.attrs)}"
entry_points[entry_point.name] = value
return entry_points
extension_points[entry_point.name] = entry_point.value
return extension_points


def load_extension_points(group, *, excludes=None):
Expand Down Expand Up @@ -146,10 +147,4 @@ def load_extension_point(name, value, group):
raise RuntimeError(
'The entry point name is listed in the environment variable '
f"'{EXTENSION_BLOCKLIST_ENVIRONMENT_VARIABLE.name}'")
if ':' in value:
module_name, attr = value.split(':', 1)
attrs = attr.split('.')
else:
module_name = value
attrs = ()
return EntryPoint(name, module_name, attrs).resolve()
return EntryPoint(name, value, group).load()
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ install_requires =
coloredlogs; sys_platform == 'win32'
distlib
EmPy
importlib-metadata; python_version < "3.8"
# the pytest dependency and its extensions are provided for convenience
# even though they are only conditional
pytest
Expand Down
2 changes: 1 addition & 1 deletion stdeb.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[colcon-core]
No-Python2:
Depends3: python3-distlib, python3-empy, python3-pytest, python3-setuptools
Depends3: python3-distlib, python3-empy, python3-pytest, python3-setuptools, python3 (>= 3.7) | python3-importlib-metadata
Recommends3: python3-pytest-cov
Suggests3: python3-pytest-repeat, python3-pytest-rerunfailures
Suite: bionic focal jammy stretch buster bullseye
Expand Down
2 changes: 1 addition & 1 deletion test/spell_check.words
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ apache
argcomplete
argparse
asyncio
attrs
autouse
basepath
bazqux
Expand Down Expand Up @@ -48,6 +47,7 @@ hardcodes
hookimpl
hookwrapper
https
importlib
isatty
iterdir
junit
Expand Down
24 changes: 14 additions & 10 deletions test/test_extension_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
from .environment_context import EnvironmentContext


Group1 = EntryPoint('group1', 'g1')
Group2 = EntryPoint('group2', 'g2')
Group1 = EntryPoint('group1', 'g1', EXTENSION_POINT_GROUP_NAME)
Group2 = EntryPoint('group2', 'g2', EXTENSION_POINT_GROUP_NAME)


class Dist():
Expand All @@ -40,8 +40,8 @@ def iter_entry_points(*, group):
if group == EXTENSION_POINT_GROUP_NAME:
return [Group1, Group2]
assert group == Group1.name
ep1 = EntryPoint('extA', 'eA')
ep2 = EntryPoint('extB', 'eB')
ep1 = EntryPoint('extA', 'eA', Group1.name)
ep2 = EntryPoint('extB', 'eB', Group1.name)
return [ep1, ep2]


Expand All @@ -50,14 +50,18 @@ def working_set():
Dist('group1', {
'group1': {ep.name: ep for ep in iter_entry_points(group='group1')}
}),
Dist('group2', {'group2': {'extC': EntryPoint('extC', 'eC')}}),
Dist('groupX', {'groupX': {'extD': EntryPoint('extD', 'eD')}}),
Dist('group2', {
'group2': {'extC': EntryPoint('extC', 'eC', Group2.name)}
}),
Dist('groupX', {
'groupX': {'extD': EntryPoint('extD', 'eD', 'groupX')}
}),
]


def test_all_extension_points():
with patch(
'colcon_core.extension_point.iter_entry_points',
'colcon_core.extension_point.entry_points',
side_effect=iter_entry_points
):
with patch(
Expand All @@ -75,7 +79,7 @@ def test_all_extension_points():
def test_extension_point_blocklist():
# successful loading of extension point without a blocklist
with patch(
'colcon_core.extension_point.iter_entry_points',
'colcon_core.extension_point.entry_points',
side_effect=iter_entry_points
):
with patch(
Expand Down Expand Up @@ -119,7 +123,7 @@ def test_extension_point_blocklist():
assert resolve.call_count == 0


def entry_point_resolve(self, *args, **kwargs):
def entry_point_load(self, *args, **kwargs):
if self.name == 'exception':
raise Exception('entry point raising exception')
if self.name == 'runtime_error':
Expand All @@ -129,7 +133,7 @@ def entry_point_resolve(self, *args, **kwargs):
return DEFAULT


@patch.object(EntryPoint, 'resolve', entry_point_resolve)
@patch.object(EntryPoint, 'load', entry_point_load)
@patch(
'colcon_core.extension_point.get_extension_points',
return_value={'exception': 'a', 'runtime_error': 'b', 'success': 'c'}
Expand Down

0 comments on commit bc06873

Please sign in to comment.