Skip to content

Commit

Permalink
restore google cloud object store test (#3538)
Browse files Browse the repository at this point in the history
  • Loading branch information
bigning authored Aug 9, 2024
1 parent bad3f0c commit e9aee74
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions tests/utils/object_store/test_gs_object_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,22 @@
from botocore.exceptions import ClientError
from torch.utils.data import DataLoader

from composer.loggers import RemoteUploaderDownloader
from composer.optim import DecoupledSGDW
from composer.trainer import Trainer
from composer.utils import GCSObjectStore
from tests.common import RandomClassificationDataset, SimpleModel


def get_gcs_os_from_trainer(trainer: Trainer) -> GCSObjectStore:
rud = [dest for dest in trainer.logger.destinations if isinstance(dest, RemoteUploaderDownloader)][0]
gcs_os = rud.remote_backend
assert trainer._checkpoint_saver is not None
assert trainer._checkpoint_saver.remote_uploader is not None
gcs_os = trainer._checkpoint_saver.remote_uploader.remote_backend
assert isinstance(gcs_os, GCSObjectStore)
return gcs_os


@pytest.mark.gpu # json auth is hard to set up on github actions / CPU tests
@pytest.mark.remote
@pytest.mark.skip(reason='Waiting for new GCP key to be approved')
def test_gs_object_store_integration_hmac_auth(expected_use_gcs_sdk_val=False, client_should_be_none=True):
model = SimpleModel()
train_dataset = RandomClassificationDataset()
Expand All @@ -35,7 +34,7 @@ def test_gs_object_store_integration_hmac_auth(expected_use_gcs_sdk_val=False, c
model=model,
optimizers=optimizer,
train_dataloader=train_dataloader,
save_folder='gs://mosaicml-internal-integration-testing/checkpoints/{run_name}',
save_folder='gs://mosaicml-runtime-internal-integration-testing/checkpoints/{run_name}',
save_filename='test-model.pt',
max_duration='1ba',
precision='amp_bf16',
Expand All @@ -54,7 +53,7 @@ def test_gs_object_store_integration_hmac_auth(expected_use_gcs_sdk_val=False, c
model=model,
optimizers=optimizer,
train_dataloader=train_dataloader,
load_path=f'gs://mosaicml-internal-integration-testing/checkpoints/{run_name}/test-model.pt',
load_path=f'gs://mosaicml-runtime-internal-integration-testing/checkpoints/{run_name}/test-model.pt',
max_duration='2ba',
precision='amp_bf16',
)
Expand All @@ -64,7 +63,6 @@ def test_gs_object_store_integration_hmac_auth(expected_use_gcs_sdk_val=False, c

@pytest.mark.gpu
@pytest.mark.remote
@pytest.mark.skip(reason='Waiting for new GCP key to be approved')
def test_gs_object_store_integration_json_auth():
with mock.patch.dict(os.environ):
if 'GCS_KEY' in os.environ:
Expand Down

0 comments on commit e9aee74

Please sign in to comment.