Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
loadams authored Sep 6, 2023
2 parents d447a39 + 55d9964 commit 9f7cd71
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 21 deletions.
1 change: 0 additions & 1 deletion .github/workflows/nv-inference.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ jobs:
run: |
git clone https://github.com/huggingface/transformers
cd transformers
git checkout ccb92be
git rev-parse --short HEAD
pip install .
Expand Down
25 changes: 14 additions & 11 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2759,26 +2759,29 @@ def _load_checkpoint(self,

self.loaded_checkpoint_dp_world_size = checkpoint['dp_world_size']

optim_checkpoint = None
if load_module_only:
deepspeed_states = ['module']
if self.optimizer is not None and self.fp16_enabled():
self.optimizer.refresh_fp32_params()
else:
if self.has_moe_layers:
largest_group_name = groups._get_max_expert_size_name()
expp_rank = groups._get_expert_parallel_rank(largest_group_name)
optim_load_path = self._get_optimizer_ckpt_name(load_dir, tag, expp_rank)
optim_checkpoint = self.checkpoint_engine.load(optim_load_path, map_location=torch.device('cpu'))
else:
optim_checkpoint = checkpoint

has_zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled()
if load_optimizer_states and self.optimizer is not None and not has_zero_optimizer_state:
if self.fp16_enabled():
if self.has_moe_layers:
largest_group_name = groups._get_max_expert_size_name()
expp_rank = groups._get_expert_parallel_rank(largest_group_name)
optim_load_path = self._get_optimizer_ckpt_name(load_dir, tag, expp_rank)
optim_checkpoint = self.checkpoint_engine.load(optim_load_path, map_location=torch.device('cpu'))
else:
optim_checkpoint = checkpoint

if self.fp16_enabled() or self.bfloat16_enabled():
self.optimizer.load_state_dict(optim_checkpoint['optimizer'],
load_optimizer_states=load_optimizer_states)
else:
self.optimizer.load_state_dict(optim_checkpoint['optimizer'])
optim_checkpoint = checkpoint

self.optimizer.load_state_dict(optim_checkpoint['optimizer'])

if load_lr_scheduler_states and self.lr_scheduler is not None:
self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
Expand Down Expand Up @@ -2835,7 +2838,7 @@ def get_sparse_tensor_module_names(original_set, loaded_set, original_parameters

client_state = {key: value for key, value in checkpoint.items() if not key in deepspeed_states}

if not load_optimizer_states and not load_module_only:
if optim_checkpoint is not None:
client_state['optimizer'] = optim_checkpoint['optimizer']

return load_path, client_state
Expand Down
17 changes: 12 additions & 5 deletions tests/unit/checkpoint/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

from unit.simple_model import *
from unittest.mock import MagicMock, patch


def compare_deepspeed_states(saved_model, loaded_model):
Expand Down Expand Up @@ -209,11 +210,17 @@ def checkpoint_correctness_verification(config_dict,
loaded_model = create_deepspeed_model(config_dict=config_dict, model=models[1], base_optimizer=base_optimizers[1])
assert list(trained_model.parameters())[0].dtype == list(loaded_model.parameters())[0].dtype

loaded_model.load_checkpoint(save_folder,
tag=save_tag,
load_optimizer_states=load_optimizer_states,
load_lr_scheduler_states=load_lr_scheduler_states,
load_module_only=load_module_only)
context = patch.object(loaded_model, "_get_optimizer_ckpt_name",
wraps=loaded_model._get_optimizer_ckpt_name) if not load_optimizer_states else MagicMock()
with context as optim_load_state_dict_mock:
loaded_model.load_checkpoint(save_folder,
tag=save_tag,
load_optimizer_states=load_optimizer_states,
load_lr_scheduler_states=load_lr_scheduler_states,
load_module_only=load_module_only)
if not load_optimizer_states:
# should not attempt to get the file name to load it
optim_load_state_dict_mock.assert_not_called()

compare_model_states(trained_model,
loaded_model,
Expand Down
6 changes: 2 additions & 4 deletions tests/unit/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,7 @@ def __init__(self,
if isinstance(vocab_size_or_config_json_file, str):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
self.__dict__.update(json_config)
elif isinstance(vocab_size_or_config_json_file, int):
self.vocab_size = vocab_size_or_config_json_file
self.hidden_size = hidden_size
Expand All @@ -323,8 +322,7 @@ def __init__(self,
def from_dict(cls, json_object):
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
config = BertConfig(vocab_size_or_config_json_file=-1)
for key, value in json_object.items():
config.__dict__[key] = value
config.__dict__.update(json_object)
return config

@classmethod
Expand Down

0 comments on commit 9f7cd71

Please sign in to comment.