Skip to content

Commit

Permalink
updated types
Browse files Browse the repository at this point in the history
Signed-off-by: Ketan Umare <ketan.umare@gmail.com>
  • Loading branch information
kumare3 committed May 3, 2023
1 parent 312a436 commit a68f8c4
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions plugins/flytekit-kf-pytorch/tests/test_elastic_task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import typing
from dataclasses import dataclass

import pytest
Expand Down Expand Up @@ -28,7 +29,7 @@ def dist_communicate() -> int:
return tensor.item()


def train(config: Config) -> tuple[str, Config, torch.nn.Module, int]:
def train(config: Config) -> typing.Tuple[str, Config, torch.nn.Module, int]:
"""Mock training a model using torch-elastic for test purposes."""
dist.init_process_group(backend="gloo")

Expand All @@ -50,7 +51,7 @@ def test_end_to_end(start_method: str) -> None:
train_task = task(train, task_config=Elastic(nnodes=1, nproc_per_node=world_size, start_method=start_method))

@workflow
def wf(config: Config = Config()) -> tuple[str, Config, torch.nn.Module, int]:
def wf(config: Config = Config()) -> typing.Tuple[str, Config, torch.nn.Module, int]:
return train_task(config=config)

r, cfg, m, distributed_result = wf()
Expand Down

0 comments on commit a68f8c4

Please sign in to comment.