
Pydantic v2 migration #5167

Merged
merged 68 commits on Aug 22, 2024
Changes from 4 commits
Commits
68 commits
99323fa
initial changes to migrate to pydantic V2
mrwyattii Feb 21, 2024
3e0979c
update requirements
mrwyattii Feb 21, 2024
4571701
fix migration bug
mrwyattii Feb 21, 2024
643ae42
Merge branch 'master' into mrwyattii/pydantic-2-support
loadams Feb 21, 2024
96fee35
fix inference config type annotations
mrwyattii Feb 21, 2024
dfe47eb
update RTD reqs
mrwyattii Feb 21, 2024
a6f8651
fix error in offload config
mrwyattii Feb 21, 2024
e780745
final fixes and updates to remove deprecated warnings from pydantic
mrwyattii Feb 21, 2024
9037e3c
Merge branch 'master' into mrwyattii/pydantic-2-support
loadams Feb 22, 2024
c12cfe6
Merge branch 'master' into mrwyattii/pydantic-2-support
loadams Feb 26, 2024
e2d075a
Merge branch 'master' into mrwyattii/pydantic-2-support
loadams Feb 27, 2024
fea7c1d
Test with updating thinc version - fixes pydantic on a6000
loadams Feb 27, 2024
5266568
Remove thinc
loadams Feb 27, 2024
65be824
Confirm uninstall of thinc
loadams Feb 27, 2024
ed08718
Also uninstall spacy
loadams Feb 27, 2024
a97e569
Reverting testing commits
loadams Feb 27, 2024
b398ba6
Update packages to support latest pydantic
loadams Feb 27, 2024
43e6367
further changes to support MII
mrwyattii Feb 28, 2024
1e8ba21
Merge branch 'master' into mrwyattii/pydantic-2-support
mrwyattii Mar 2, 2024
b9781e1
Merge branch 'master' into mrwyattii/pydantic-2-support
loadams Mar 4, 2024
ae5fd5b
Merge branch 'master' into mrwyattii/pydantic-2-support
loadams Mar 8, 2024
4969551
Merge branch 'master' into mrwyattii/pydantic-2-support
loadams Apr 4, 2024
cf7bee9
Merge branch 'master' into mrwyattii/pydantic-2-support
loadams Apr 4, 2024
91789b5
Update file that was modified in #5234
loadams Apr 4, 2024
93d3d6a
Merge branch 'master' into mrwyattii/pydantic-2-support
loadams Apr 4, 2024
203f5b7
Update container to newer version rather than updating specific packages
loadams Apr 5, 2024
aea6795
Revert "Update container to newer version rather than updating specif…
loadams Apr 5, 2024
a8658ca
Add comment
loadams Apr 5, 2024
55193a5
Merge branch 'master' into mrwyattii/pydantic-2-support
loadams Apr 8, 2024
4161028
Merge branch 'master' into mrwyattii/pydantic-2-support
loadams Apr 24, 2024
2d5327d
Merge branch 'master' into mrwyattii/pydantic-2-support
tjruwase Apr 27, 2024
0978380
Merge branch 'master' into mrwyattii/pydantic-2-support
loadams May 7, 2024
a7ddc5e
Merge branch 'master' into mrwyattii/pydantic-2-support
loadams May 9, 2024
55d39c0
Merge branch 'master' into mrwyattii/pydantic-2-support
loadams May 16, 2024
fcee6a7
Merge branch 'master' into mrwyattii/pydantic-2-support
loadams May 20, 2024
d80508d
Merge branch 'master' into mrwyattii/pydantic-2-support
loadams May 22, 2024
8c0b98f
Merge branch 'master' into mrwyattii/pydantic-2-support
adk9 May 24, 2024
ace913b
Fix a couple of failing CI tests
adk9 May 28, 2024
aee5f9d
Merge branch 'master' into mrwyattii/pydantic-2-support
loadams May 28, 2024
62ca5f2
Merge branch 'master' into mrwyattii/pydantic-2-support
loadams May 28, 2024
4cb7ac3
Correct fix for dtype validation in DeepSpeedInferenceConfig
adk9 May 28, 2024
45a9c25
Rename model_config to model_conf
adk9 May 28, 2024
96edbbf
Revert "Rename model_config to model_conf"
adk9 May 28, 2024
8c982d2
Merge branch 'master' into mrwyattii/pydantic-2-support
lekurile May 30, 2024
a04de7f
Temporarily checkout PR branch in the nv-accelerate-v100 pipeline
adk9 May 30, 2024
08c16c1
Merge branch 'master' into mrwyattii/pydantic-2-support
adk9 Jun 3, 2024
75640e3
PR 2814 is now merged into accelerate/master
adk9 Jun 6, 2024
d72db03
Merge branch 'master' into mrwyattii/pydantic-2-support
adk9 Jun 6, 2024
b97d514
Merge branch 'master' into mrwyattii/pydantic-2-support
adk9 Jun 10, 2024
0ac9533
Merge branch 'master' into mrwyattii/pydantic-2-support
adk9 Jun 12, 2024
437ecee
Merge branch 'master' into mrwyattii/pydantic-2-support
adk9 Jun 14, 2024
ca9c8ef
Merge branch 'master' into mrwyattii/pydantic-2-support
adk9 Jun 18, 2024
670ac94
Merge branch 'master' into mrwyattii/pydantic-2-support
loadams Jun 26, 2024
f973393
Merge branch 'master' into mrwyattii/pydantic-2-support
loadams Jun 26, 2024
09fa6b5
Merge branch 'master' into mrwyattii/pydantic-2-support
adk9 Jun 27, 2024
5d8fb2d
Merge branch 'master' into mrwyattii/pydantic-2-support
loadams Jul 1, 2024
1cbf3e1
Merge branch 'master' into mrwyattii/pydantic-2-support
loadams Jul 1, 2024
b3804ad
Merge branch 'master' into mrwyattii/pydantic-2-support
loadams Jul 12, 2024
41fc635
Merge branch 'master' into mrwyattii/pydantic-2-support
adk9 Jul 16, 2024
9f65563
Merge branch 'master' into mrwyattii/pydantic-2-support
loadams Jul 23, 2024
79c0835
Merge branch 'master' into mrwyattii/pydantic-2-support
loadams Jul 23, 2024
295a806
Merge branch 'master' into mrwyattii/pydantic-2-support
loadams Jul 25, 2024
1e3925e
Merge branch 'master' into mrwyattii/pydantic-2-support
adk9 Aug 1, 2024
1eec90f
Merge branch 'master' into mrwyattii/pydantic-2-support
loadams Aug 14, 2024
c82a73c
Merge branch 'master' into mrwyattii/pydantic-2-support
loadams Aug 15, 2024
75a9288
Merge branch 'master' into mrwyattii/pydantic-2-support
loadams Aug 19, 2024
e557489
Merge branch 'master' into mrwyattii/pydantic-2-support
loadams Aug 20, 2024
628cf25
Merge branch 'master' into mrwyattii/pydantic-2-support
loadams Aug 21, 2024
2 changes: 1 addition & 1 deletion .github/workflows/python.yml
@@ -21,7 +21,7 @@ jobs:
   unit-tests:
     strategy:
       matrix:
-        pyVersion: ["3.6", "3.7", "3.8", "3.9", "3.10"]
+        pyVersion: ["3.7", "3.8", "3.9", "3.10"]
       fail-fast: false
 
     runs-on: ubuntu-20.04
5 changes: 3 additions & 2 deletions deepspeed/comm/config.py
@@ -3,14 +3,15 @@

 # DeepSpeed Team
 
+from pydantic import BaseModel
+
 from .constants import *
-from ..pydantic_v1 import BaseModel
 
 
 class CommsConfig(BaseModel):
 
     class Config:
-        validate_all = True
+        validate_default = True
         validate_assignment = True
         use_enum_values = True
         extra = 'forbid'
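This hunk keeps the nested `class Config` but adopts the v2 option name (`validate_all` became `validate_default`). As a hedged illustration (model and field names invented, not from the PR), the same options in the `ConfigDict` style that Pydantic v2 recommends:

```python
from pydantic import BaseModel, ConfigDict, ValidationError

class CommsConfigSketch(BaseModel):
    # v1 -> v2 renames: validate_all -> validate_default, and the
    # nested `class Config` -> a `model_config` ConfigDict.
    model_config = ConfigDict(
        validate_default=True,
        validate_assignment=True,
        use_enum_values=True,
        extra="forbid",
    )
    verbose: bool = False

cfg = CommsConfigSketch(verbose=True)
```

Assigning to `cfg.verbose` re-runs validation because of `validate_assignment=True`, and unknown keys are rejected by `extra="forbid"`, matching the v1 semantics the original Config class requested.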
16 changes: 8 additions & 8 deletions deepspeed/inference/config.py
@@ -5,7 +5,7 @@

 import torch
 import deepspeed
-from deepspeed.pydantic_v1 import Field, validator
+from pydantic import Field, field_validator
 from deepspeed.runtime.config_utils import DeepSpeedConfigModel
 from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
 from typing import Dict, Union
@@ -91,24 +91,24 @@ class QuantTypeEnum(str, Enum):


 class BaseQuantConfig(DeepSpeedConfigModel):
-    enabled = True
-    num_bits = 8
+    enabled: bool = True
+    num_bits: int = 8
     q_type: QuantTypeEnum = QuantTypeEnum.sym
     q_groups: int = 1
 
 
 class WeightQuantConfig(BaseQuantConfig):
-    enabled = True
+    enabled: bool = True
     quantized_initialization: Dict = {}
     post_init_quant: Dict = {}
 
 
 class ActivationQuantConfig(BaseQuantConfig):
-    enabled = True
+    enabled: bool = True
 
 
 class QKVQuantConfig(DeepSpeedConfigModel):
-    enabled = True
+    enabled: bool = True
 
 
 class QuantizationConfig(DeepSpeedConfigModel):
@@ -287,13 +287,13 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
     moe_experts: list = Field([1], deprecated=True, new_param="moe.moe_experts")
     moe_type: MoETypeEnum = Field(MoETypeEnum.standard, deprecated=True, new_param="moe.type")
 
-    @validator("moe")
+    @field_validator("moe")
     def moe_backward_compat(cls, field_value, values):
         if isinstance(field_value, bool):
             return DeepSpeedMoEConfig(moe=field_value)
         return field_value
 
-    @validator("use_triton")
+    @field_validator("use_triton")
     def has_triton(cls, field_value, values):
         if field_value and not deepspeed.HAS_TRITON:
             raise ValueError('Triton needs to be installed to use deepspeed with triton kernels')
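Note that v1's `validator` passed a `values` dict of previously validated fields, while v2's `field_validator` passes a `ValidationInfo` object (the dict now lives at `info.data`); the signatures above were tightened further in later commits of this PR. A minimal sketch of the backward-compat pattern with invented stand-in models:

```python
from typing import Union
from pydantic import BaseModel, field_validator, ValidationInfo

class MoEConfigSketch(BaseModel):
    enabled: bool = True

class InferenceConfigSketch(BaseModel):
    moe: Union[MoEConfigSketch, bool] = False

    @field_validator("moe")
    @classmethod
    def moe_backward_compat(cls, v, info: ValidationInfo):
        # Backward compat: a bare bool is promoted to a full sub-config.
        if isinstance(v, bool):
            return MoEConfigSketch(enabled=v)
        return v

cfg = InferenceConfigSketch(moe=True)
```

Because no `validate_default` is set here, the validator only runs for explicitly supplied values, so the `False` default stays a plain bool.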
2 changes: 1 addition & 1 deletion deepspeed/inference/v2/config_v2.py
@@ -3,7 +3,7 @@

 # DeepSpeed Team
 
-from deepspeed.pydantic_v1 import Field
+from pydantic import Field
 
 from deepspeed.runtime.config_utils import DeepSpeedConfigModel
 from .ragged import DSStateManagerConfig
5 changes: 3 additions & 2 deletions deepspeed/inference/v2/ragged/manager_configs.py
@@ -6,7 +6,7 @@
 from enum import Enum
 from typing import Tuple
 
-from deepspeed.pydantic_v1 import PositiveInt, validator
+from pydantic import PositiveInt, field_validator
 
 from deepspeed.runtime.config_utils import DeepSpeedConfigModel
 from ..inference_utils import DtypeEnum
@@ -173,7 +173,8 @@ class DSStateManagerConfig(DeepSpeedConfigModel):
     Enable tracking for offloading KV-cache to host memory. Currently unsupported.
     """
 
-    @validator("max_ragged_sequence_count")
+    @field_validator("max_ragged_sequence_count")
+    @classmethod
     def max_ragged_sequence_count_validator(cls, v: int, values: dict):
         # If the attributes below failed their validation they won't appear in the values dict.
         if "max_tracked_sequences" in values and v > values["max_tracked_sequences"]:
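As the inline comment notes, fields that failed validation never reach the cross-field check; the same holds for v2's `info.data`, which also only contains fields declared *before* the one being validated. A hedged sketch with invented limits (not the PR's real defaults):

```python
from pydantic import (BaseModel, PositiveInt, ValidationError,
                      ValidationInfo, field_validator)

class RaggedSketch(BaseModel):
    # Declaration order matters: this field must come first so it is
    # available in info.data when the next field is validated.
    max_tracked_sequences: PositiveInt = 2048
    max_ragged_sequence_count: PositiveInt = 512

    @field_validator("max_ragged_sequence_count")
    @classmethod
    def check_count(cls, v: int, info: ValidationInfo) -> int:
        # Earlier fields that validated successfully appear in info.data.
        if "max_tracked_sequences" in info.data and v > info.data["max_tracked_sequences"]:
            raise ValueError("max_ragged_sequence_count must not exceed max_tracked_sequences")
        return v

ok = RaggedSketch(max_tracked_sequences=64, max_ragged_sequence_count=32)
```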
19 changes: 11 additions & 8 deletions deepspeed/monitor/config.py
@@ -3,7 +3,10 @@

 # DeepSpeed Team
 
-from deepspeed.pydantic_v1 import root_validator
+from typing import Optional
+
+from pydantic import model_validator
+
 from deepspeed.runtime.config_utils import DeepSpeedConfigModel
 
 
@@ -34,10 +34,10 @@ class WandbConfig(DeepSpeedConfigModel):
     enabled: bool = False
     """ Whether logging to WandB is enabled. Requires `wandb` package is installed. """
 
-    group: str = None
+    group: Optional[str] = None
     """ Name for the WandB group. This can be used to group together runs. """
 
-    team: str = None
+    team: Optional[str] = None
     """ Name for the WandB team. """
 
     project: str = "deepspeed"
@@ -72,8 +75,8 @@ class DeepSpeedMonitorConfig(DeepSpeedConfigModel):
     csv_monitor: CSVConfig = {}
     """ Local CSV output of monitoring data. """
 
-    @root_validator
-    def check_enabled(cls, values):
-        values["enabled"] = values.get("tensorboard").enabled or values.get("wandb").enabled or values.get(
-            "csv_monitor").enabled
-        return values
+    @model_validator(mode="after")
+    def check_enabled(self):
+        enabled = self.tensorboard.enabled or self.wandb.enabled or self.csv_monitor.enabled
+        self.__dict__["enabled"] = enabled
+        return self
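With `mode="after"`, the validator receives the fully constructed model as `self` instead of v1's raw `values` dict. A simplified sketch of the derived-flag pattern (sub-config names shortened, not the PR's exact classes); writing through `__dict__`, as the PR does, avoids re-triggering assignment validation:

```python
from pydantic import BaseModel, model_validator

class SubMonitorSketch(BaseModel):
    enabled: bool = False

class MonitorSketch(BaseModel):
    tensorboard: SubMonitorSketch = SubMonitorSketch()
    csv_monitor: SubMonitorSketch = SubMonitorSketch()
    enabled: bool = False

    @model_validator(mode="after")
    def check_enabled(self):
        # Derived flag: on if any sub-monitor is enabled.
        self.__dict__["enabled"] = self.tensorboard.enabled or self.csv_monitor.enabled
        return self

cfg = MonitorSketch(csv_monitor={"enabled": True})
```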
16 changes: 0 additions & 16 deletions deepspeed/pydantic_v1.py

This file was deleted.

7 changes: 4 additions & 3 deletions deepspeed/runtime/compiler.py
@@ -6,7 +6,7 @@
 from typing import Union, Callable, Dict, Any
 import importlib
 import torch
-from ..pydantic_v1 import validator
+from pydantic import field_validator
 from .config_utils import DeepSpeedConfigModel
 
 COMPILE_CONFIG = "compile"
@@ -76,8 +76,9 @@ class CompileConfig(DeepSpeedConfigModel):
         Passed to `kwargs` argument of torch.compile.
     """
 
-    @validator("enabled")
-    def validate_enabled(cls, field_value, values):
+    @field_validator("enabled")
+    @classmethod
+    def validate_enabled(cls, field_value):
         if field_value and not is_compile_supported():
             raise ValueError("torch.compile is not supported on this version of PyTorch.")
         return field_value
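One subtlety behind validators like this: in Pydantic v2, field validators do not run on *default* values unless `validate_default` is enabled (the shared `DeepSpeedConfigModel` base turns it on; a plain `BaseModel` does not). A hedged sketch with invented names:

```python
from pydantic import BaseModel, ConfigDict, ValidationError, field_validator

class LaxSketch(BaseModel):
    enabled: bool = True

    @field_validator("enabled")
    @classmethod
    def reject_enabled(cls, v: bool) -> bool:
        # Stand-in for an environment check such as is_compile_supported().
        if v:
            raise ValueError("feature unavailable in this sketch")
        return v

class StrictSketch(LaxSketch):
    # Only this subclass validates the default value.
    model_config = ConfigDict(validate_default=True)

lax = LaxSketch()  # the default True is NOT validated, so this constructs fine
```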
70 changes: 36 additions & 34 deletions deepspeed/runtime/config_utils.py
@@ -5,11 +5,11 @@
"""
Collection of DeepSpeed configuration utilities
"""
import json
import collections
import collections.abc
import json
from functools import reduce
from deepspeed.pydantic_v1 import BaseModel
from pydantic import BaseModel, ConfigDict

from deepspeed.utils import logger


@@ -54,67 +54,69 @@ def __init__(self, strict=False, **data):
         if (not strict):  # This is temporary until we refactor all DS configs, allows HF to load models
             data = {k: v for k, v in data.items() if (v != "auto" or k == "replace_method")}
         super().__init__(**data)
-        self._deprecated_fields_check(self)
+        self._deprecated_fields_check()
 
-    def _process_deprecated_field(self, pydantic_config, field):
+    def _process_deprecated_field(self, dep_field):
         # Get information about the deprecated field
-        fields_set = pydantic_config.__fields_set__
-        dep_param = field.name
-        kwargs = field.field_info.extra
+        fields_set = self.__fields_set__
+        kwargs = self.__fields__[dep_field].json_schema_extra
         new_param_fn = kwargs.get("new_param_fn", lambda x: x)
-        param_value = new_param_fn(getattr(pydantic_config, dep_param))
-        new_param = kwargs.get("new_param", "")
+        param_value = new_param_fn(getattr(self, dep_field))
+        new_field = kwargs.get("new_param", "")
         dep_msg = kwargs.get("deprecated_msg", "")
-        if dep_param in fields_set:
-            logger.warning(f"Config parameter {dep_param} is deprecated" +
-                           (f" use {new_param} instead" if new_param else "") + (f". {dep_msg}" if dep_msg else ""))
+        if dep_field in fields_set:
+            logger.warning(f"Config parameter {dep_field} is deprecated" +
+                           (f" use {new_field} instead" if new_field else "") + (f". {dep_msg}" if dep_msg else ""))
             # Check if there is a new param and if it should be set with a value
-            if new_param and kwargs.get("set_new_param", True):
+            if new_field and kwargs.get("set_new_param", True):
                 # Remove the deprecate field if there is a replacing field
                 try:
-                    delattr(pydantic_config, dep_param)
+                    delattr(self, dep_field)
                 except Exception as e:
-                    logger.error(f"Tried removing deprecated '{dep_param}' from config")
+                    logger.error(f"Tried removing deprecated '{dep_field}' from config")
                     raise e
 
                 # Set new param value
-                new_param_nested = new_param.split(".")
+                new_param_nested = new_field.split(".")
                 if len(new_param_nested) > 1:
                     # If the new param exists in a subconfig, we need to get
                     # the fields set for that subconfig
+                    pydantic_config = self
                     pydantic_config = reduce(getattr, new_param_nested[:-1], pydantic_config)
                     fields_set = pydantic_config.__fields_set__
                 new_param_name = new_param_nested[-1]
                 assert (
                     new_param_name not in fields_set
-                ), f"Cannot provide deprecated parameter '{dep_param}' and replacing parameter '{new_param}' together"
+                ), f"Cannot provide deprecated parameter '{dep_field}' and replacing parameter '{new_field}' together"
                 # A custom function for converting the old param value to new param value can be provided
                 try:
-                    setattr(pydantic_config, new_param_name, param_value)
+                    setattr(self, new_param_name, param_value)
                 except Exception as e:
-                    logger.error(f"Tried setting value for '{new_param}' with value from deprecated '{dep_param}'")
+                    logger.error(f"Tried setting value for '{new_field}' with value from deprecated '{dep_field}'")
                     raise e
 
-    def _deprecated_fields_check(self, pydantic_config):
-        fields = pydantic_config.__fields__
-        for field in fields.values():
-            if field.field_info.extra.get("deprecated", False):
-                self._process_deprecated_field(pydantic_config, field)
+    def _deprecated_fields_check(self):
+        fields = self.__fields__
+        for field_name, field_info in fields.items():
+            if field_info.json_schema_extra and field_info.json_schema_extra.get("deprecated", False):
+                self._process_deprecated_field(field_name)
 
-    class Config:
-        validate_all = True
-        validate_assignment = True
-        use_enum_values = True
-        allow_population_by_field_name = True
-        extra = "forbid"
-        arbitrary_types_allowed = True
+    model_config = ConfigDict(
+        validate_default=True,
+        validate_assignment=True,
+        use_enum_values=True,
+        populate_by_name=True,
+        extra="forbid",
+        arbitrary_types_allowed=True,
+        protected_namespaces=(),
+    )
 
 
 def get_config_default(config, field_name):
     assert field_name in config.__fields__, f"'{field_name}' is not a field in {config}"
     assert not config.__fields__.get(
-        field_name).required, f"'{field_name}' is a required field and does not have a default value"
-    return config.__fields__.get(field_name).default
+        field_name).is_required(), f"'{field_name}' is a required field and does not have a default value"
+    return config.__fields__.get(field_name).get_default()
 
 
 class pp_int(int):
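The key v2 change this hunk tracks: extra metadata attached to a field now lives in `FieldInfo.json_schema_extra` (v1 exposed it as `field_info.extra`), and the `required`/`default` attributes became the `is_required()`/`get_default()` methods. A small sketch with illustrative names, passing the metadata explicitly rather than as bare `Field()` kwargs:

```python
from pydantic import BaseModel, Field

class DeprecationSketch(BaseModel):
    # Arbitrary metadata travels in json_schema_extra in Pydantic v2.
    old_knob: int = Field(0, json_schema_extra={"deprecated": True, "new_param": "new_knob"})
    new_knob: int = 0

# Inspect the field metadata the way _deprecated_fields_check does.
meta = DeprecationSketch.model_fields["old_knob"].json_schema_extra
default = DeprecationSketch.model_fields["old_knob"].get_default()
required = DeprecationSketch.model_fields["old_knob"].is_required()
```

The PR itself still passes `deprecated=True` directly to `Field()`; v2 deprecates that spelling but routes such kwargs into `json_schema_extra`, which is why the lookup above matches the rewritten `_process_deprecated_field`.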
35 changes: 17 additions & 18 deletions deepspeed/runtime/zero/config.py
@@ -6,7 +6,7 @@
 import sys
 from typing import Optional
 from enum import Enum
-from deepspeed.pydantic_v1 import Field, validator, root_validator
+from pydantic import Field, model_validator
 from deepspeed.runtime.config_utils import get_scalar_param, pp_int, DeepSpeedConfigModel
 from deepspeed.utils import logger
 from .offload_config import DeepSpeedZeroOffloadParamConfig, DeepSpeedZeroOffloadOptimizerConfig, OffloadDeviceEnum
@@ -29,7 +29,7 @@
"reduce_bucket_size": 500000000,
"load_from_fp32_weights": [true|false],
"cpu_offload": [true|false] (deprecated),
"cpu_offload_params" : [true|false] (deprecated),
"cpu_offload_param" : [true|false] (deprecated),
"cpu_offload_use_pin_memory": [true|false] (deprecated),
"sub_group_size" : 1000000000000,
"offload_param": {...},
@@ -127,7 +127,7 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
     the allgather for large model sizes
     """
 
-    overlap_comm: bool = None  # None for dynamic default value (see validator `overlap_comm_valid` below)
+    overlap_comm: Optional[bool] = None  # None for dynamic default value (see validator `overlap_comm_valid` below)
     """
     Attempts to overlap the reduction of the gradients with backward computation
     """
@@ -167,23 +167,23 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
     parameters). Used by ZeRO3-Offload and ZeRO-Infinity
     """
 
-    cpu_offload_param: bool = Field(
+    cpu_offload_param: Optional[bool] = Field(
         None,
         deprecated=True,
         new_param="offload_param",
         new_param_fn=(lambda val: DeepSpeedZeroOffloadParamConfig(device=OffloadDeviceEnum.cpu) if val else None),
     )
     """ Deprecated, please use ``offload_param`` """
 
-    cpu_offload_use_pin_memory: bool = Field(
+    cpu_offload_use_pin_memory: Optional[bool] = Field(
         None,
         deprecated=True,
         new_param="offload_param or offload_optimizer",
         set_new_param=False,
     )
     """ Deprecated, please use ``offload_param`` or ``offload_optimizer`` """
 
-    cpu_offload: bool = Field(
+    cpu_offload: Optional[bool] = Field(
         None,
         deprecated=True,
         new_param="offload_optimizer",
@@ -302,16 +302,15 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
"""

# Validators
@validator("overlap_comm")
def overlap_comm_valid(cls, field_value, values):
if field_value is None:
assert ("stage" in values), "DeepSpeedZeroConfig: 'stage' must be defined before 'overlap_comm'"
field_value = values["stage"] == ZeroStageEnum.weights
return field_value

@root_validator
def offload_ratio_check(cls, values):
offload_config = getattr(values, "offload_optimizer", {})
@model_validator(mode="after")
def overlap_comm_valid(self):
if self.overlap_comm is None:
self.overlap_comm = self.stage == ZeroStageEnum.weights
return self

@model_validator(mode="after")
def offload_ratio_check(self):
offload_config = self.offload_optimizer
if offload_config and offload_config.ratio < 1.0:
assert values.get("stage") == ZeroStageEnum.weights, "Partial offloading only supported for ZeRO Stage 3."
return values
assert self.stage == ZeroStageEnum.weights, "Partial offloading only supported for ZeRO Stage 3."
return self
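The v1 field validator that peeked at sibling fields through `values` becomes a `mode="after"` model validator that reads and writes the finished instance directly. A sketch of the dynamic-default pattern, under the assumption that stage 3 stands in for `ZeroStageEnum.weights`:

```python
from typing import Optional
from pydantic import BaseModel, model_validator

class ZeroSketch(BaseModel):
    stage: int = 0
    overlap_comm: Optional[bool] = None  # None means "decide after construction"

    @model_validator(mode="after")
    def overlap_comm_valid(self):
        # Resolve the dynamic default once the whole model is built;
        # stage 3 here plays the role of ZeroStageEnum.weights.
        if self.overlap_comm is None:
            self.overlap_comm = self.stage == 3
        return self
```

An explicitly supplied `overlap_comm` is left untouched; only the `None` sentinel is replaced.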