Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: fix test_force_editable_mode #161

Merged
merged 1 commit into from
Sep 4, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 49 additions & 50 deletions tests/dependency/test_dependency.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import sys
import tempfile
import uuid
from pathlib import Path

Expand All @@ -15,7 +14,7 @@
from substrafl.nodes.node import InputIdentifiers
from substrafl.nodes.node import OutputIdentifiers
from substrafl.remote import remote_data
from substrafl.remote.register.register import _create_substra_function_files
from substrafl.remote.register.register import register_function

CURRENT_FILE = Path(__file__)

Expand Down Expand Up @@ -51,54 +50,53 @@ class TestLocalDependency:
def _register_function(self, my_algo, algo_deps, client, session_dir):
"""Register a train function"""
data_op = my_algo.train(data_samples=list(), shared_state=None)
operation_dir = Path(tempfile.mkdtemp(dir=session_dir))
archive_path, description_path = _create_substra_function_files(
data_op.remote_struct,

inputs = [
substra.schemas.FunctionInputSpec(
identifier=InputIdentifiers.datasamples,
kind=substra.schemas.AssetKind.data_sample.value,
optional=False,
multiple=True,
),
substra.schemas.FunctionInputSpec(
identifier=InputIdentifiers.opener,
kind=substra.schemas.AssetKind.data_manager.value,
optional=False,
multiple=False,
),
substra.schemas.FunctionInputSpec(
identifier=InputIdentifiers.local,
kind=substra.schemas.AssetKind.model.value,
optional=True,
multiple=False,
),
substra.schemas.FunctionInputSpec(
identifier=InputIdentifiers.shared,
kind=substra.schemas.AssetKind.model.value,
optional=True,
multiple=False,
),
]

outputs = [
substra.schemas.FunctionOutputSpec(
identifier=OutputIdentifiers.local, kind=substra.schemas.AssetKind.model.value, multiple=False
),
substra.schemas.FunctionOutputSpec(
identifier=OutputIdentifiers.shared, kind=substra.schemas.AssetKind.model.value, multiple=False
),
]
permissions = substra.schemas.Permissions(public=True, authorized_ids=list())

function_key = register_function(
client=client,
remote_struct=data_op.remote_struct,
permissions=permissions,
inputs=inputs,
outputs=outputs,
dependencies=algo_deps,
install_libraries=client.backend_mode != substra.BackendType.LOCAL_SUBPROCESS,
operation_dir=operation_dir,
)
algo_query = substra.schemas.FunctionSpec(
name="algo_test_deps",
inputs=[
substra.schemas.FunctionInputSpec(
identifier=InputIdentifiers.datasamples,
kind=substra.schemas.AssetKind.data_sample.value,
optional=False,
multiple=True,
),
substra.schemas.FunctionInputSpec(
identifier=InputIdentifiers.opener,
kind=substra.schemas.AssetKind.data_manager.value,
optional=False,
multiple=False,
),
substra.schemas.FunctionInputSpec(
identifier=InputIdentifiers.local,
kind=substra.schemas.AssetKind.model.value,
optional=True,
multiple=False,
),
substra.schemas.FunctionInputSpec(
identifier=InputIdentifiers.shared,
kind=substra.schemas.AssetKind.model.value,
optional=True,
multiple=False,
),
],
outputs=[
substra.schemas.FunctionOutputSpec(
identifier=OutputIdentifiers.local, kind=substra.schemas.AssetKind.model.value, multiple=False
),
substra.schemas.FunctionOutputSpec(
identifier=OutputIdentifiers.shared, kind=substra.schemas.AssetKind.model.value, multiple=False
),
],
description=description_path,
file=archive_path,
permissions=substra.schemas.Permissions(public=True, authorized_ids=list()),
)
function_key = client.add_function(algo_query)

return function_key

def _register_train_task(self, function_key, dataset_key, data_sample_key, client):
Expand Down Expand Up @@ -311,17 +309,18 @@ def test_force_editable_mode(
mocker.patch("substrafl.dependency.manage_dependencies.compile_requirements")

client = network.clients[0]
monkeypatch.setenv("SUBSTRA_FORCE_EDITABLE_MODE", str(True))

algo_deps = Dependency(pypi_dependencies=["pytest"], editable_mode=False)

monkeypatch.setenv("SUBSTRA_FORCE_EDITABLE_MODE", str(True))
self._register_function(dummy_algo_class(), algo_deps, client, session_dir)
assert substrafl.dependency.manage_dependencies.local_lib_wheels.call_count == 1

substrafl.dependency.manage_dependencies.local_lib_wheels.reset_mock()

monkeypatch.setenv("SUBSTRA_FORCE_EDITABLE_MODE", str(False))
self._register_function(dummy_algo_class(), algo_deps, client, session_dir)
assert substrafl.remote.register.register.local_lib_wheels.call_count == 0
assert substrafl.dependency.manage_dependencies.local_lib_wheels.call_count == 0


def test_get_compute():
Expand Down
Loading