Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing checkpoint downloader #174

Merged
merged 12 commits into from
Dec 21, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ the section headers (Added/Changed/...) and incrementing the package version.
- ([#161](https://github.com/microsoft/hi-ml/pull/161)) Empty string as target folder for a dataset creates an invalid mounting path for the dataset in AzureML (fixes #160)
- ([#167](https://github.com/microsoft/hi-ml/pull/167)) Fix bugs in logging hyperparameters: logging as name/value
table, rather than one column per hyperparameter. Use string logging for all hyperparameters
- ([#174](https://github.com/microsoft/hi-ml/pull/174)) Fix bugs in returned local_checkpoint_path when downloading checkpoints from AML run

### Removed

Expand Down
28 changes: 18 additions & 10 deletions hi-ml-azure/src/health_azure/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ def from_string(self, x: str) -> List[str]:
class CheckpointDownloader:
def __init__(self, run_id: str, checkpoint_filename: str, azure_config_json_path: Path = None,
aml_workspace: Workspace = None, download_dir: PathOrString = "checkpoints",
remote_checkpoint_folder: PathOrString = "checkpoints") -> None:
remote_checkpoint_dir: PathOrString = "checkpoints") -> None:
vale-salvatelli marked this conversation as resolved.
Show resolved Hide resolved
"""
Utility class for downloading checkpoint files from an Azure ML run

Expand All @@ -480,25 +480,29 @@ def __init__(self, run_id: str, checkpoint_filename: str, azure_config_json_path
:param aml_workspace: An optional Azure ML Workspace object. If not running inside an AML Run, and no
azure_config_json_path is provided, this is required.
:param download_dir: The local directory in which to save the downloaded checkpoint files.
:param remote_checkpoint_folder: The remote folder from which to download the checkpoint file
:param remote_checkpoint_dir: The remote folder from which to download the checkpoint file
"""
self.azure_config_json_path = azure_config_json_path
self.aml_workspace = aml_workspace
self.run_id = run_id
self.checkpoint_filename = checkpoint_filename
self.download_dir = Path(download_dir)
self.remote_checkpoint_folder = Path(remote_checkpoint_folder)
self.remote_checkpoint_dir = Path(remote_checkpoint_dir)

@property
def local_checkpoint_path(self) -> Path:
def local_checkpoint_dir(self) -> Path:
# in case we run_id is a run recovery id, extract the run id
run_id_parts = self.run_id.split(":")
run_id = run_id_parts[-1]
return self.download_dir / run_id / self.checkpoint_filename
return self.download_dir / run_id

@property
def remote_checkpoint_path(self) -> Path:
return self.remote_checkpoint_folder / self.checkpoint_filename
return self.remote_checkpoint_dir / self.checkpoint_filename

@property
def local_checkpoint_path(self) -> Path:
return self.local_checkpoint_dir / self.remote_checkpoint_path
vale-salvatelli marked this conversation as resolved.
Show resolved Hide resolved

def download_checkpoint_if_necessary(self) -> Path:
"""Downloads the specified checkpoint if it does not already exist.
Expand All @@ -509,9 +513,9 @@ def download_checkpoint_if_necessary(self) -> Path:
workspace_config_path=self.azure_config_json_path)

if not self.local_checkpoint_path.exists():
local_checkpoint_dir = self.local_checkpoint_path.parent
local_checkpoint_dir.mkdir(exist_ok=True, parents=True)
download_checkpoints_from_run_id(self.run_id, str(self.remote_checkpoint_path), local_checkpoint_dir,
self.local_checkpoint_dir.mkdir(exist_ok=True, parents=True)
download_checkpoints_from_run_id(self.run_id, str(self.remote_checkpoint_path),
self.local_checkpoint_dir,
aml_workspace=workspace)
assert self.local_checkpoint_path.exists()

Expand Down Expand Up @@ -608,9 +612,12 @@ def get_workspace(aml_workspace: Optional[Workspace] = None, workspace_config_pa
else:
raise ValueError("No workspace config file given, nor can we find one.")

if workspace_config_path.is_file():
if not isinstance(workspace_config_path, Path):
raise ValueError("Workspace config path is not a path, check your input.")
elif workspace_config_path.is_file():
auth = get_authentication()
return Workspace.from_config(path=str(workspace_config_path), auth=auth)

raise ValueError("Workspace config file does not exist or cannot be read.")


Expand Down Expand Up @@ -1046,6 +1053,7 @@ def get_run_file_names(run: Run, prefix: str = "") -> List[str]:
:return: A list of paths within the Run's container
"""
all_files = run.get_file_names()
print(f"Selecting files with prefix {prefix}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you leave this print statement in by accident?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I made it a bit nicer and left it on purpose - I found it's useful to know which prefix is used to select among run files

return [f for f in all_files if f.startswith(prefix)] if prefix else all_files


Expand Down
6 changes: 6 additions & 0 deletions hi-ml-azure/testazure/testazure/test_azure_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,17 @@ def test_get_workspace(
with pytest.raises(ValueError) as ex:
util.get_workspace(None, None)
assert "No workspace config file given" in str(ex)

# Workspace config file is set to a file that does not exist
with pytest.raises(ValueError) as ex:
util.get_workspace(None, workspace_config_path=tmp_path / "does_not_exist")
assert "Workspace config file does not exist" in str(ex)

# Workspace config file is set to a wrong type
with pytest.raises(ValueError) as ex:
util.get_workspace(None, workspace_config_path=1) # type: ignore
assert "Workspace config path is not a path" in str(ex)


@patch("health_azure.utils.Run")
def test_create_run_recovery_id(mock_run: MagicMock) -> None:
Expand Down