pretrainedModel add gconfig (#6915)
* pretrainedModel add gconfig

* max/min_new_token, bug fix

* deprecate generation_utils

* code refinement

* code refinement
wtmlon authored Sep 5, 2023
1 parent 5fe97bd commit e183825
Showing 16 changed files with 115 additions and 76 deletions.
2 changes: 1 addition & 1 deletion paddlenlp/experimental/transformers/generation_utils.py
@@ -28,7 +28,7 @@
 except:
     from paddlenlp_ops import top_p_sampling
 
-from paddlenlp.transformers.generation_utils import GenerationMixin
+from ...generation import GenerationMixin
 
 __all__ = ["GenerationInferenceModel"]
33 changes: 33 additions & 0 deletions paddlenlp/generation/__init__.py
@@ -0,0 +1,33 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_utils import GenerationConfig
from .logits_process import (
    ForcedBOSTokenLogitsProcessor,
    ForcedEOSTokenLogitsProcessor,
    HammingDiversityLogitsProcessor,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    TopKProcess,
    TopPProcess,
)
from .stopping_criteria import (
    MaxLengthCriteria,
    MaxTimeCriteria,
    StoppingCriteria,
    StoppingCriteriaList,
    validate_stopping_criteria,
)
from .streamers import BaseStreamer, TextIteratorStreamer, TextStreamer
from .utils import BeamSearchScorer, GenerationMixin, get_unfinished_flag
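
The new package consolidates the generation helpers that previously lived under `paddlenlp.transformers`. A minimal sketch of the import migration for downstream code (illustrative, not part of the diff):

```python
# Before this commit (now deprecated):
# from paddlenlp.transformers.generation_utils import GenerationMixin, TextStreamer

# After this commit, via the new paddlenlp/generation package:
from paddlenlp.generation import GenerationConfig, GenerationMixin, TextStreamer
```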
paddlenlp/generation/configuration_utils.py
@@ -23,18 +23,20 @@
 from paddle.common_ops_import import convert_dtype
 
 from paddlenlp import __version__
+from paddlenlp.transformers.configuration_utils import PretrainedConfig
+from paddlenlp.transformers.utils import resolve_cache_dir
 from paddlenlp.utils.log import logger
 
-from ...utils import GENERATION_CONFIG_NAME
-from ...utils.downloader import (
+from ..utils import GENERATION_CONFIG_NAME
+from ..utils.downloader import (
     COMMUNITY_MODEL_PREFIX,
     get_path_from_url_with_filelock,
     hf_file_exists,
     is_url,
     url_file_exists,
 )
-from ..configuration_utils import PretrainedConfig
-from ..utils import resolve_cache_dir
 
+DEFAULT_MAX_NEW_TOKEN = 20
+
 
 def resolve_hf_generation_config_path(repo_id: str, cache_dir: str, subfolder=None) -> str:
@@ -143,7 +145,9 @@ def _get_generation_mode(self):
 
     def __init__(self, **kwargs):
         # Parameters that control the length of the output
-        self.max_length = kwargs.pop("max_length", 20)
+        self.max_new_token = kwargs.get("max_new_token", DEFAULT_MAX_NEW_TOKEN)
+        self.min_new_token = kwargs.pop("min_new_token", 0)
+        self.max_length = kwargs.pop("max_length", 0)
         self.min_length = kwargs.pop("min_length", 0)
         self.early_stopping = kwargs.pop("early_stopping", False)
 
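
A rough sketch of the new length knobs (attribute names taken from the diff, values arbitrary): `max_new_token`/`min_new_token` bound the number of newly generated tokens, while the legacy `max_length`/`min_length` now default to 0 so that `generate()` can tell when they were set explicitly:

```python
from paddlenlp.generation import GenerationConfig

# Bound the freshly generated tokens directly (64 and 1 are arbitrary values).
config = GenerationConfig(max_new_token=64, min_new_token=1)

# Legacy fields still exist but now default to 0, i.e. "unset".
assert config.max_length == 0 and config.min_length == 0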
@@ -176,11 +180,6 @@ def __init__(self, **kwargs):
         self._from_model_config = kwargs.pop("_from_model_config", False)
         self.paddlenlp_version = kwargs.pop("paddlenlp_version", __version__)
 
-        # Parameters that control the generation strategy used
-        self.decode_strategy = kwargs.pop("decode_strategy", None)
-        if self.decode_strategy is None:
-            self.decode_strategy = self._get_generation_mode()
-
         # Additional attributes without default values
         if not self._from_model_config:
             # we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a
@@ -192,6 +191,12 @@
                 logger.error(f"Can't set {key} with value {value} for {self}")
                 raise err
 
+        # Parameters that control the generation strategy used
+        if "decode_strategy" in kwargs:
+            self.decode_strategy = kwargs.pop("decode_strategy")
+        else:
+            self.decode_strategy = self._get_generation_mode()
+
         # Validate the values of the attributes
         self.validate(is_init=True)
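
Note that the strategy inference was moved below the attribute-assignment loop, presumably so that `_get_generation_mode()` sees any user-supplied attributes before picking a default; an explicitly passed `decode_strategy` still takes precedence.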

@@ -202,7 +207,7 @@ def __eq__(self, other):
         self_dict = self.__dict__.copy()
         other_dict = other.__dict__.copy()
         # ignore metadata
-        for metadata_field in "paddlenlp_version":
+        for metadata_field in ["_from_model_config", "paddlenlp_version"]:
             self_dict.pop(metadata_field, None)
             other_dict.pop(metadata_field, None)
         return self_dict == other_dict
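
The old `__eq__` had a subtle bug: iterating over the string `"paddlenlp_version"` yields single characters, not field names, so the metadata field was never stripped before comparison. A standalone sketch of the pitfall:

```python
d = {"paddlenlp_version": "2.6.0", "top_k": 50}

# Buggy: iterating a string yields characters ("p", "a", "d", ...),
# none of which match a real key, so nothing is removed.
for field in "paddlenlp_version":
    d.pop(field, None)
assert "paddlenlp_version" in d  # still present

# Fixed: iterate over a list of field names.
for field in ["_from_model_config", "paddlenlp_version"]:
    d.pop(field, None)
assert "paddlenlp_version" not in d
```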
@@ -432,7 +437,7 @@ def from_pretrained(
             community_url = "/".join([COMMUNITY_MODEL_PREFIX, pretrained_model_name_or_path, GENERATION_CONFIG_NAME])
             if url_file_exists(community_url):
                 resolved_config_file = get_path_from_url_with_filelock(
-                    pretrained_model_name_or_path, cache_dir, check_exist=not force_download
+                    community_url, cache_dir, check_exist=not force_download
                 )
             else:
                 raise FileNotFoundError(f"configuration file<{GENERATION_CONFIG_NAME}> not found")
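
This hunk is a genuine bug fix: the downloader was previously handed the bare model name rather than the resolved `community_url`, so a community-hosted `generation_config.json` could never actually be fetched.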
@@ -483,7 +488,7 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "GenerationConfig":
         config = cls(**{**config_dict, **kwargs})
         unused_kwargs = config.update(**kwargs)
 
-        logger.info(f"Generate config {config}")
+        # logger.info(f"Generate config {config}")
         if return_unused_kwargs:
             return config, unused_kwargs
         else:
File renamed without changes.
File renamed without changes.
paddlenlp/generation/streamers.py
@@ -46,7 +46,7 @@ class TextStreamer(BaseStreamer):
     ```python
     >>> from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer
-    >>> from paddlenlp.transformers.generation_utils import TextStreamer
+    >>> from paddlenlp.generation import TextStreamer
 
     >>> tok = AutoTokenizer.from_pretrained("gpt2")
     >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
@@ -167,7 +167,7 @@ class TextIteratorStreamer(TextStreamer):
     ```python
     >>> from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer
-    >>> from paddlenlp.transformers.generation_utils import TextIteratorStreamer
+    >>> from paddlenlp.generation import TextIteratorStreamer
     >>> from threading import Thread
 
     >>> tok = AutoTokenizer.from_pretrained("gpt2")
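
For reference, a hedged sketch of how the iterator streamer is consumed, following the docstring pattern above (the `gpt2` checkpoint and the `streamer` keyword are taken from that example, not re-verified here):

```python
from threading import Thread

from paddlenlp.generation import TextIteratorStreamer
from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tok(["An increasing sequence: one,"], return_tensors="pd")

streamer = TextIteratorStreamer(tok)
# Generation runs in a background thread so the main thread can consume
# decoded text chunks as they are produced.
thread = Thread(target=model.generate, kwargs=dict(**inputs, streamer=streamer))
thread.start()

generated = ""
for new_text in streamer:  # the iterator ends when generation finishes
    generated += new_text
thread.join()
```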
paddlenlp/generation/utils.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 from __future__ import annotations
 
+import copy
 from typing import Union
 
 import paddle
@@ -37,7 +38,7 @@
 from paddlenlp.transformers.utils import get_scale_by_dtype
 from paddlenlp.utils.log import logger
 
-from .configuration_utils import GenerationConfig
+from .configuration_utils import DEFAULT_MAX_NEW_TOKEN, GenerationConfig
 from .logits_process import (
     ForcedBOSTokenLogitsProcessor,
     ForcedEOSTokenLogitsProcessor,
@@ -737,8 +738,17 @@ def generate(
         # ['是的', '嗯嗯']
         """
         if generation_config is None:
-            generation_config = GenerationConfig.from_model_config(self.config)
+            if self.generation_config._from_model_config:
+                new_generation_config = GenerationConfig.from_model_config(self.config)
+                if new_generation_config != self.generation_config:
+                    logger.warning(
+                        "model.generation_config is in conflict with model.config, " "model.config is used."
+                    )
+                    self.generation_config = new_generation_config
+            generation_config = self.generation_config
 
+        # without update model.generation_config
+        generation_config = copy.deepcopy(generation_config)
         model_kwargs = generation_config.update(**kwargs)
 
         assert generation_config.decode_strategy in [
@@ -892,8 +902,16 @@ def generate(
                 print("Setting `pad_token_id` to `eos_token_id`:{} for " "open-end generation.".format(eos_token_id))
             pad_token_id = eos_token_id
 
-        max_length = generation_config.max_length
-        min_length = generation_config.min_length
+        if generation_config.max_length != 0 and generation_config.max_new_token == DEFAULT_MAX_NEW_TOKEN:
+            logger.warning("`max_length` will be deprecated in future releases, use `max_new_token` instead.")
+            generation_config.max_new_token = generation_config.max_length
+
+        if generation_config.min_length != 0 and generation_config.min_new_token == 0:
+            logger.warning("`min_length` will be deprecated in future releases, use `min_new_token` instead.")
+            generation_config.min_new_token = generation_config.min_length
+
+        max_length = generation_config.max_new_token
+        min_length = generation_config.min_new_token
         if is_tracing and not paddle.is_tensor(max_length):
             if hasattr(paddle.framework, "_no_check_dy2st_diff"):
                 # TODO(daisiming): _no_check_dy2st_diff is used to turn off the checking of behavior
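
Taken together, these two hunks change `generate()` in two user-visible ways: the method now works on a deep copy of the config, so per-call `kwargs` no longer mutate `model.generation_config`, and legacy `max_length`/`min_length` values are funneled into `max_new_token`/`min_new_token` with a deprecation warning. A hedged usage sketch (model name illustrative):

```python
from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
ids = tok("Hello", return_tensors="pd")["input_ids"]

# Preferred: bound the number of newly generated tokens directly.
out = model.generate(ids, max_new_token=32)

# Still accepted: a non-zero max_length warns and is mapped onto max_new_token.
out = model.generate(ids, max_length=32)

# Per-call overrides act on a deep copy, so the stored config is untouched.
before = model.generation_config.max_new_token
model.generate(ids, max_new_token=99)
assert model.generation_config.max_new_token == before
```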
2 changes: 1 addition & 1 deletion paddlenlp/transformers/dallebart/modeling.py
@@ -22,10 +22,10 @@
 import paddle.nn.functional as F
 from paddle.common_ops_import import convert_dtype
 
+from ...generation import BeamSearchScorer
 from ...transformers import PretrainedModel, register_base_model
 from ...utils.env import CONFIG_NAME
 from ...utils.log import logger
-from ..generation_utils import BeamSearchScorer
 from .configuration import (
     DALLEBART_PRETRAINED_INIT_CONFIGURATION,
     DALLEBART_PRETRAINED_RESOURCE_FILES_MAP,
19 changes: 0 additions & 19 deletions paddlenlp/transformers/generation_utils.py

This file was deleted.

24 changes: 20 additions & 4 deletions paddlenlp/transformers/model_utils.py
@@ -58,10 +58,10 @@
 )
 from paddlenlp.utils.log import logger
 
+from ..generation import GenerationConfig, GenerationMixin
 from ..utils import device_guard
 from .configuration_utils import PretrainedConfig
 from .conversion_utils import ConversionMixin
-from .generation_utils import GenerationMixin
 from .utils import (  # convert_ndarray_dtype,
     ContextManagers,
     InitTrackerMeta,
@@ -863,6 +863,7 @@ def __init__(self, *args, **kwargs):
         if config is not None:
             self.config: PretrainedConfig = config
             self.model_config_file = CONFIG_NAME
+            self.generation_config = GenerationConfig.from_model_config(self.config) if self.can_generate() else None
             return
 
         # extract config from kwargs
@@ -876,6 +877,7 @@
             raise TypeError("config parameter should be the instance of PretrainedConfig")
 
         self.config: PretrainedConfig = kwargs["config"]
+        self.generation_config = GenerationConfig.from_model_config(self.config) if self.can_generate() else None
         self.model_config_file = CONFIG_NAME
         self.warnings_issued = {}
 
@@ -1993,6 +1995,22 @@ def from_pretrained(
             keep_in_fp32_modules=keep_in_fp32_modules,
         )
 
+        # load generation_config.json
+        if model.can_generate() and pretrained_model_name_or_path is not None:
+            try:
+                model.generation_config = GenerationConfig.from_pretrained(
+                    pretrained_model_name_or_path,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    subfolder=subfolder,
+                    **kwargs,
+                )
+            except OSError:
+                logger.info(
+                    "Generation config file not found, using a generation config created from the model config."
+                )
+                pass
+
         if paddle.in_dynamic_mode():
             return model
 
@@ -2086,9 +2104,7 @@ def save_pretrained(
         if is_main_process:
             config_to_save.save_pretrained(save_directory)
             if self.can_generate():
-                # to do support generation_config
-                pass
-                # model_to_save.generation_config.save_pretrained(save_directory)
+                model_to_save.generation_config.save_pretrained(save_directory)
 
         # Handle the case where some state_dict keys shouldn't be saved
         if self._keys_to_ignore_on_save is not None:
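
Net effect of the `model_utils.py` changes: the generation config now round-trips with the model. It is synthesized from `model.config` as a fallback, loaded from `generation_config.json` when one exists, and written back out by `save_pretrained()`. A brief sketch (paths illustrative):

```python
from paddlenlp.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

# Populated either from generation_config.json or, as a fallback, from model.config.
print(model.generation_config)

# save_pretrained() now also writes generation_config.json alongside the weights.
model.save_pretrained("./my-model")  # illustrative path

# On reload, from_pretrained() picks the saved generation config back up.
reloaded = AutoModelForCausalLM.from_pretrained("./my-model")
```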
File renamed without changes.
@@ -35,7 +35,7 @@ def ids_tensor(shape, vocab_size, rng=None, name=None):
     return paddle.to_tensor(data=values).reshape(shape)
 
 
-from paddlenlp.transformers.generation.logits_process import (
+from paddlenlp.generation.logits_process import (
     ForcedBOSTokenLogitsProcessor,
     ForcedEOSTokenLogitsProcessor,
     HammingDiversityLogitsProcessor,
@@ -18,7 +18,7 @@
 
 import paddle
 
-from paddlenlp.transformers.generation.stopping_criteria import (
+from paddlenlp.generation.stopping_criteria import (
     MaxLengthCriteria,
     MaxTimeCriteria,
     StoppingCriteriaList,
@@ -18,12 +18,11 @@
 
 import paddle
 
+from paddlenlp.generation import TextIteratorStreamer, TextStreamer
 from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer
-from paddlenlp.transformers.generation_utils import TextIteratorStreamer, TextStreamer
 from paddlenlp.transformers.utils import CaptureStd
 from tests.testing_utils import slow
 
-from ..test_modeling_common import ids_tensor
+from tests.transformers.test_modeling_common import ids_tensor
 
 
 class StreamerTester(unittest.TestCase):
13 changes: 0 additions & 13 deletions tests/transformers/generation/__init__.py

This file was deleted.
