Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Adds auto object store creation to get_file #1750

Merged
merged 9 commits into from
Nov 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 9 additions & 55 deletions composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from copy import deepcopy
from pathlib import Path
from typing import Any, Callable, ContextManager, Dict, Iterable, List, Optional, Sequence, TextIO, Tuple, Union, cast
from urllib.parse import urlparse

import coolname
import torch
Expand All @@ -44,9 +43,11 @@
from composer.trainer._scale_schedule import scale_pytorch_scheduler
from composer.trainer._scaler import ClosureGradScaler
from composer.trainer.dist_strategy import DDPSyncStrategy, ddp_sync_context, prepare_ddp_module, prepare_fsdp_module
from composer.utils import (ExportFormat, MissingConditionalImportError, ObjectStore, S3ObjectStore, Transform,
checkpoint, dist, ensure_tuple, export_with_logger, format_name_with_dist, get_device,
get_file, is_tpu_installed, map_collection, model_eval_mode, reproducibility)
from composer.utils import (ExportFormat, MissingConditionalImportError, ObjectStore, Transform, checkpoint, dist,
ensure_tuple, export_with_logger, format_name_with_dist, get_device, get_file,
is_tpu_installed, map_collection, maybe_create_object_store_from_uri,
maybe_create_remote_uploader_downloader_from_uri, model_eval_mode, parse_uri,
reproducibility)

if is_tpu_installed():
import torch_xla.core.xla_model as xm
Expand Down Expand Up @@ -285,53 +286,6 @@ def _generate_run_name() -> str:
return generated_run_name


def _maybe_create_object_store_from_uri(uri: str) -> Optional[ObjectStore]:
    """Build an ``ObjectStore`` for ``uri`` when it names a supported cloud backend.

    Returns ``None`` for plain local paths (no URI scheme); raises
    ``NotImplementedError`` for any scheme other than ``s3``.
    """
    scheme, bucket, _ = _parse_uri(uri)
    if not scheme:
        # A local filesystem path -- no object store is required.
        return None
    if scheme == 's3':
        return S3ObjectStore(bucket=bucket)
    if scheme == 'wandb':
        raise NotImplementedError(f'There is no implementation for WandB load_object_store via URI. Please use '
                                  'WandBLogger')
    raise NotImplementedError(f'There is no implementation for the cloud backend {scheme} via URI. Please use '
                              's3 or one of the supported object stores')


def _maybe_create_remote_uploader_downloader_from_uri(
        uri: str, loggers: List[LoggerDestination]) -> Optional[RemoteUploaderDownloader]:
    """Build a ``RemoteUploaderDownloader`` for ``uri`` unless one already covers it.

    Returns ``None`` for local paths, or (with a warning) when an existing
    ``RemoteUploaderDownloader`` in ``loggers`` already targets the same
    backend/bucket pair. Raises ``NotImplementedError`` for unsupported backends.
    """
    scheme, bucket, _ = _parse_uri(uri)
    if scheme == '':
        return None
    # Do not register a second uploader for the same backend/bucket pair.
    for dest in loggers:
        if not isinstance(dest, RemoteUploaderDownloader):
            continue
        if dest.remote_backend_name == scheme and dest.remote_bucket_name == bucket:
            warnings.warn(
                f'There already exists a RemoteUploaderDownloader object to handle the uri: {uri} you specified')
            return None
    if scheme == 's3':
        return RemoteUploaderDownloader(bucket_uri=f'{scheme}://{bucket}')
    if scheme == 'wandb':
        raise NotImplementedError(f'There is no implementation for WandB via URI. Please use '
                                  'WandBLogger with log_artifacts set to True')
    raise NotImplementedError(f'There is no implementation for the cloud backend {scheme} via URI. Please use '
                              's3 or one of the supported RemoteUploaderDownloader object stores')


def _parse_uri(uri: str) -> Tuple[str, str, str]:
parse_result = urlparse(uri)
backend, bucket_name, path = parse_result.scheme, parse_result.netloc, parse_result.path
if backend == '' and bucket_name == '':
return backend, bucket_name, path
else:
return backend, bucket_name, path.lstrip('/')


class Trainer:
"""Train models with Composer algorithms.

Expand Down Expand Up @@ -1010,7 +964,7 @@ def __init__(
))

if save_folder is not None:
remote_ud = _maybe_create_remote_uploader_downloader_from_uri(save_folder, loggers)
remote_ud = maybe_create_remote_uploader_downloader_from_uri(save_folder, loggers)
if remote_ud is not None:
loggers.append(remote_ud)

Expand All @@ -1035,7 +989,7 @@ def __init__(
self._checkpoint_saver = None
latest_remote_file_name = None
if save_folder is not None:
_, _, parsed_save_folder = _parse_uri(save_folder)
_, _, parsed_save_folder = parse_uri(save_folder)

# If user passes a URI with s3:// and a bucket_name, but no other
# path then we assume they just want their checkpoints saved directly in their
Expand Down Expand Up @@ -1258,12 +1212,12 @@ def __init__(
# Actually load the checkpoint from potentially updated arguments
if load_path is not None:
if load_object_store is None:
load_object_store = _maybe_create_object_store_from_uri(load_path)
load_object_store = maybe_create_object_store_from_uri(load_path)
if isinstance(load_object_store, WandBLogger):
import wandb
if wandb.run is None:
load_object_store.init(self.state, self.logger)
_, _, parsed_load_path = _parse_uri(load_path)
_, _, parsed_load_path = parse_uri(load_path)
self._rng_state = checkpoint.load_checkpoint(
state=self.state,
logger=self.logger,
Expand Down
6 changes: 5 additions & 1 deletion composer/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from composer.utils.file_helpers import (FORMAT_NAME_WITH_DIST_AND_TIME_TABLE, FORMAT_NAME_WITH_DIST_TABLE,
create_symlink_file, ensure_folder_has_no_conflicting_files,
ensure_folder_is_empty, format_name_with_dist, format_name_with_dist_and_time,
get_file, is_tar)
get_file, is_tar, maybe_create_object_store_from_uri,
maybe_create_remote_uploader_downloader_from_uri, parse_uri)
from composer.utils.import_helpers import MissingConditionalImportError, import_object
from composer.utils.inference import ExportFormat, Transform, export_for_inference, export_with_logger, quantize_dynamic
from composer.utils.iter_helpers import IteratorFileStream, ensure_tuple, map_collection
Expand Down Expand Up @@ -71,6 +72,9 @@ def warn_streaming_dataset_deprecation(old_version: int, new_version: int) -> No
'format_name_with_dist',
'format_name_with_dist_and_time',
'is_tar',
'maybe_create_object_store_from_uri',
'maybe_create_remote_uploader_downloader_from_uri',
'parse_uri',
'batch_get',
'batch_set',
'configure_excepthook',
Expand Down
107 changes: 96 additions & 11 deletions composer/utils/file_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,27 @@
import re
import tempfile
import uuid
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse

import requests
import tqdm

from composer.utils import dist
from composer.utils.iter_helpers import iterate_with_callback
from composer.utils.object_store import ObjectStore
from composer.utils.object_store import ObjectStore, S3ObjectStore

if TYPE_CHECKING:
from composer.core import Timestamp
from composer.loggers import LoggerDestination
from composer.loggers import LoggerDestination, RemoteUploaderDownloader

log = logging.getLogger(__name__)

__all__ = [
'get_file',
'ensure_folder_is_empty',
'ensure_folder_has_no_conflicting_files',
'format_name_with_dist',
'format_name_with_dist_and_time',
'is_tar',
'create_symlink_file',
'get_file', 'ensure_folder_is_empty', 'ensure_folder_has_no_conflicting_files', 'format_name_with_dist',
'format_name_with_dist_and_time', 'is_tar', 'create_symlink_file', 'maybe_create_object_store_from_uri',
'maybe_create_remote_uploader_downloader_from_uri', 'parse_uri'
]


Expand All @@ -42,7 +40,7 @@ def _get_dist_config(strict: bool = True) -> Dict[str, Any]:

If ``strict=True``, will error if a setting is not available (e.g. the
environment variable is not set). Otherwise, will only return settings
that are availalbe.
that are available.
"""
settings = {
'rank': dist.get_global_rank,
Expand Down Expand Up @@ -306,6 +304,86 @@ def format_name_with_dist_and_time(
"""


def parse_uri(uri: str) -> Tuple[str, str, str]:
    """Uses :py:func:`urllib.parse.urlparse` to parse the provided URI.

    Args:
        uri (str): The provided URI string

    Returns:
        Tuple[str, str, str]: A tuple containing the backend (e.g. s3), bucket name, and path.
            Backend and bucket name will be empty strings if the input is a local path
    """
    result = urlparse(uri)
    if result.scheme == '' and result.netloc == '':
        # Local filesystem path: return the components exactly as parsed.
        return result.scheme, result.netloc, result.path
    # Cloud URI: strip the leading '/' so the path is relative to the bucket.
    return result.scheme, result.netloc, result.path.lstrip('/')


def maybe_create_object_store_from_uri(uri: str) -> Optional[ObjectStore]:
    """Automatically creates an :class:`composer.utils.ObjectStore` from supported URI formats.

    Args:
        uri (str): The path to (maybe) create an :class:`composer.utils.ObjectStore` from

    Raises:
        NotImplementedError: Raises when the URI format is not supported.

    Returns:
        Optional[ObjectStore]: Returns an :class:`composer.utils.ObjectStore` if the URI is of a supported format, otherwise None
    """
    scheme, bucket, _ = parse_uri(uri)
    if not scheme:
        # No URI scheme: treat as a local path, which needs no object store.
        return None
    if scheme == 's3':
        return S3ObjectStore(bucket=bucket)
    if scheme == 'wandb':
        raise NotImplementedError(f'There is no implementation for WandB load_object_store via URI. Please use '
                                  'WandBLogger')
    raise NotImplementedError(f'There is no implementation for the cloud backend {scheme} via URI. Please use '
                              's3 or one of the supported object stores')


def maybe_create_remote_uploader_downloader_from_uri(
        uri: str, loggers: List[LoggerDestination]) -> Optional['RemoteUploaderDownloader']:
    """Automatically creates a :class:`composer.loggers.RemoteUploaderDownloader` from supported URI formats.

    Args:
        uri (str): The path to (maybe) create a :class:`composer.loggers.RemoteUploaderDownloader` from
        loggers (List[:class:`composer.loggers.LoggerDestination`]): List of the existing :class:`composer.loggers.LoggerDestination` s so as to not create a duplicate

    Raises:
        NotImplementedError: Raises when the URI format is not supported.

    Returns:
        Optional[RemoteUploaderDownloader]: Returns a :class:`composer.loggers.RemoteUploaderDownloader` if the URI is of a supported format, otherwise None
    """
    # Imported here to avoid a circular import with composer.loggers.
    from composer.loggers import RemoteUploaderDownloader
    scheme, bucket, _ = parse_uri(uri)
    if scheme == '':
        return None
    # Do not create a second uploader for a backend/bucket pair that an
    # existing logger destination already handles.
    for dest in loggers:
        if not isinstance(dest, RemoteUploaderDownloader):
            continue
        if dest.remote_backend_name == scheme and dest.remote_bucket_name == bucket:
            warnings.warn(
                f'There already exists a RemoteUploaderDownloader object to handle the uri: {uri} you specified')
            return None
    if scheme == 's3':
        return RemoteUploaderDownloader(bucket_uri=f'{scheme}://{bucket}')
    if scheme == 'wandb':
        raise NotImplementedError(f'There is no implementation for WandB via URI. Please use '
                                  'WandBLogger with log_artifacts set to True')
    raise NotImplementedError(f'There is no implementation for the cloud backend {scheme} via URI. Please use '
                              's3 or one of the supported RemoteUploaderDownloader object stores')


def get_file(
path: str,
destination: str,
Expand All @@ -324,6 +402,9 @@ def get_file(
* If ``object_store`` is not specified but the ``path`` begins with ``http://`` or ``https://``,
the object at this URL will be downloaded.

* If ``object_store`` is not specified, but the ``path`` begins with ``s3://``, an :class:`composer.utils.S3ObjectStore`
will be created and used.

* Otherwise, ``path`` is presumed to be a local filepath.

destination (str): The destination filepath.
Expand All @@ -347,6 +428,10 @@ def get_file(
Raises:
FileNotFoundError: If the ``path`` does not exist.
"""
if object_store is None and not (path.lower().startswith('http://') or path.lower().startswith('https://')):
object_store = maybe_create_object_store_from_uri(path)
_, _, path = parse_uri(path)

if path.endswith('.symlink'):
with tempfile.TemporaryDirectory() as tmpdir:
symlink_file_name = os.path.join(tmpdir, 'file.symlink')
Expand Down
8 changes: 5 additions & 3 deletions tests/loggers/test_remote_uploader_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,18 @@
class DummyObjectStore(ObjectStore):
"""Dummy ObjectStore implementation that is backed by a local directory."""

def __init__(self, dir: pathlib.Path, always_fail: bool = False, **kwargs: Dict[str, Any]) -> None:
self.dir = str(dir)
def __init__(self, dir: Optional[pathlib.Path] = None, always_fail: bool = False, **kwargs: Dict[str, Any]) -> None:
self.dir = str(dir) if dir is not None else kwargs['bucket']
self.always_fail = always_fail
assert isinstance(self.dir, str)
os.makedirs(self.dir, exist_ok=True)

def get_uri(self, object_name: str) -> str:
return 'local://' + object_name

def _get_abs_path(self, object_name: str):
return self.dir + '/' + object_name
assert isinstance(self.dir, str)
return os.path.abspath(self.dir + '/' + object_name)

def upload_object(
self,
Expand Down
20 changes: 18 additions & 2 deletions tests/utils/test_file_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import datetime
import os
import pathlib
from unittest.mock import patch

import pytest
import pytest_httpserver
Expand All @@ -13,9 +14,10 @@
format_name_with_dist, format_name_with_dist_and_time, get_file, is_tar)
from composer.utils.object_store.libcloud_object_store import LibcloudObjectStore
from tests.common.markers import world_size
from tests.loggers.test_remote_uploader_downloader import DummyObjectStore


@pytest.mark.xfail(reason='Occassionally hits the timeout. Should refactor to use a local webserver.')
@pytest.mark.xfail(reason='Occasionally hits the timeout. Should refactor to use a local webserver.')
def test_get_file_uri(tmp_path: pathlib.Path, httpserver: pytest_httpserver.HTTPServer):
httpserver.expect_request('/hi').respond_with_data('hi')
get_file(
Expand All @@ -27,7 +29,7 @@ def test_get_file_uri(tmp_path: pathlib.Path, httpserver: pytest_httpserver.HTTP
assert f.readline().startswith('<!')


@pytest.mark.xfail(reason='Occassionally hits the timeout. Should refactor to use a local webserver.')
@pytest.mark.xfail(reason='Occasionally hits the timeout. Should refactor to use a local webserver.')
def test_get_file_uri_not_found(tmp_path: pathlib.Path, httpserver: pytest_httpserver.HTTPServer):
with pytest.raises(FileNotFoundError):
get_file(
Expand Down Expand Up @@ -60,6 +62,20 @@ def test_get_file_object_store(tmp_path: pathlib.Path, monkeypatch: pytest.Monke
assert f.read() == b'checkpoint1'


def test_get_file_auto_object_store(tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch):
    """``get_file`` should auto-create an object store from an ``s3://`` URI.

    ``S3ObjectStore`` is patched with ``DummyObjectStore`` (backed by a local
    directory) so no real S3 access happens.
    """
    # Run inside tmp_path so the dummy bucket directory ('my-test-bucket'),
    # which DummyObjectStore creates relative to the cwd, lands in the test's
    # temporary directory instead of polluting the repository working directory.
    # (Previously `monkeypatch` was unused and the directory leaked into cwd.)
    monkeypatch.chdir(tmp_path)
    with patch('composer.utils.file_helpers.S3ObjectStore', DummyObjectStore):
        object_store = DummyObjectStore(pathlib.Path('my-test-bucket'))
        with open(str(tmp_path / 'test-file.txt'), 'w') as _txt_file:
            _txt_file.write('testing')
        object_store.upload_object('test-file.txt', str(tmp_path / 'test-file.txt'))
        # get_file must parse the URI, build the (patched) S3ObjectStore, and
        # download the object to the destination path.
        get_file('s3://my-test-bucket/test-file.txt', str(tmp_path / 'loaded-test-file.txt'))

        with open(str(tmp_path / 'loaded-test-file.txt')) as _txt_file:
            loaded_content = _txt_file.read()

        # Exact-match assertion: the file was written with exactly 'testing'.
        assert loaded_content == 'testing'


def test_get_file_object_store_with_symlink(tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch):
pytest.importorskip('libcloud')

Expand Down