diff --git a/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py index bf3c6d1794..2ca6c9cc65 100644 --- a/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py +++ b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py @@ -1,4 +1,5 @@ import os +import typing from dataclasses import dataclass import pytest @@ -28,7 +29,7 @@ def dist_communicate() -> int: return tensor.item() -def train(config: Config) -> tuple[str, Config, torch.nn.Module, int]: +def train(config: Config) -> typing.Tuple[str, Config, torch.nn.Module, int]: """Mock training a model using torch-elastic for test purposes.""" dist.init_process_group(backend="gloo") @@ -50,7 +51,7 @@ def test_end_to_end(start_method: str) -> None: train_task = task(train, task_config=Elastic(nnodes=1, nproc_per_node=world_size, start_method=start_method)) @workflow - def wf(config: Config = Config()) -> tuple[str, Config, torch.nn.Module, int]: + def wf(config: Config = Config()) -> typing.Tuple[str, Config, torch.nn.Module, int]: return train_task(config=config) r, cfg, m, distributed_result = wf()