pretrainedModel add gconfig #6915

Merged (5 commits, Sep 5, 2023)

Changes from 2 commits
21 changes: 12 additions & 9 deletions llm/predictor.py
@@ -41,6 +41,7 @@
PretrainedModel,
PretrainedTokenizer,
)
from paddlenlp.transformers.generation_utils import GenerationConfig
Collaborator: Prefer importing from the paddlenlp.generation path; paddlenlp.transformers.generation_utils will be gradually deprecated going forward.
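
As an illustration of the import the reviewer suggests, a minimal sketch, assuming the module layout introduced elsewhere in this PR (paddlenlp/generation/configuration_utils.py provides GenerationConfig, with the old module kept as a re-export shim):

    # Preferred going forward: import from the new generation package.
    from paddlenlp.generation.configuration_utils import GenerationConfig

    # Legacy path, still working through the re-export shim but slated for deprecation:
    # from paddlenlp.transformers.generation_utils import GenerationConfig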

from paddlenlp.utils.import_utils import import_module, is_paddlenlp_ops_available


@@ -199,15 +200,17 @@ def _infer(self, inputs: dict[str, paddle.Tensor]):
max_length = max(self.config.max_length - inputs["input_ids"].shape[-1], 1)
result = self.model.generate(
**inputs,
max_length=max_length,
bos_token_id=self.tokenizer.bos_token_id,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.pad_token_id,
decode_strategy=self.config.decode_strategy,
temperature=self.config.temperature,
top_k=self.config.top_k,
top_p=self.config.top_p,
repetition_penalty=self.config.repetition_penalty,
generation_config=GenerationConfig(
max_new_token=max_length,
bos_token_id=self.tokenizer.bos_token_id,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.pad_token_id,
decode_strategy=self.config.decode_strategy,
temperature=self.config.temperature,
top_k=self.config.top_k,
top_p=self.config.top_p,
repetition_penalty=self.config.repetition_penalty,
),
)
result = result[0]
return result
19 changes: 11 additions & 8 deletions llm/utils.py
@@ -28,6 +28,7 @@
from paddlenlp.datasets import InTokensIterableDataset
from paddlenlp.trainer import Trainer, TrainerCallback
from paddlenlp.trainer.trainer_utils import IterableDatasetShard, has_length
from paddlenlp.transformers.generation_utils import GenerationConfig
from paddlenlp.utils.log import logger


@@ -200,14 +201,16 @@ def prediction_step(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"] if "attention_mask" in inputs else None,
position_ids=inputs["position_ids"] if "position_ids" in inputs else None,
max_length=self.data_args.tgt_length,
decode_strategy="sampling",
top_k=self.gen_args.top_k,
top_p=self.gen_args.top_p,
bos_token_id=self.tokenizer.bos_token_id,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.pad_token_id,
use_cache=True,
generation_config=GenerationConfig(
max_new_token=self.data_args.tgt_length,
decode_strategy="sampling",
top_k=self.gen_args.top_k,
top_p=self.gen_args.top_p,
bos_token_id=self.tokenizer.bos_token_id,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.pad_token_id,
use_cache=True,
),
Collaborator: Keep the original API.

)[0]
all_preds = []
for pred_tokens in generated_tokens:
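On the reviewer's request above to keep the original API: judging from the generation_config.update(**kwargs) call added later in this PR (see the generate() hunk further down), both call styles should remain available. A hedged sketch, not taken verbatim from this PR (model and input_ids are placeholders):

    from paddlenlp.generation.configuration_utils import GenerationConfig

    # New style: bundle generation options into a GenerationConfig.
    output = model.generate(
        input_ids=input_ids,
        generation_config=GenerationConfig(max_new_token=64, decode_strategy="sampling", top_p=0.7),
    )

    # Original style: flat keyword arguments, merged into the config by generation_config.update(**kwargs).
    output = model.generate(
        input_ids=input_ids,
        max_new_token=64,
        decode_strategy="sampling",
        top_p=0.7,
    )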
@@ -23,18 +23,20 @@
from paddle.common_ops_import import convert_dtype

from paddlenlp import __version__
from paddlenlp.transformers.configuration_utils import PretrainedConfig
from paddlenlp.transformers.utils import resolve_cache_dir
from paddlenlp.utils.log import logger

from ...utils import GENERATION_CONFIG_NAME
from ...utils.downloader import (
from ..utils import GENERATION_CONFIG_NAME
from ..utils.downloader import (
COMMUNITY_MODEL_PREFIX,
get_path_from_url_with_filelock,
hf_file_exists,
is_url,
url_file_exists,
)
from ..configuration_utils import PretrainedConfig
from ..utils import resolve_cache_dir

DEFAULT_MAX_NEW_TOKEN = 20


def resolve_hf_generation_config_path(repo_id: str, cache_dir: str, subfolder=None) -> str:
@@ -143,7 +145,9 @@ def _get_generation_mode(self):

def __init__(self, **kwargs):
# Parameters that control the length of the output
self.max_length = kwargs.pop("max_length", 20)
self.max_new_token = kwargs.get("max_new_token", DEFAULT_MAX_NEW_TOKEN)
self.min_new_token = kwargs.pop("min_new_token", 0)
self.max_length = kwargs.pop("max_length", 0)
self.min_length = kwargs.pop("min_length", 0)
self.early_stopping = kwargs.pop("early_stopping", False)
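
A quick sketch of the new length controls added in this hunk (parameter names as they appear in the diff, including the singular max_new_token; values are illustrative):

    from paddlenlp.generation.configuration_utils import GenerationConfig

    # Preferred: bound the number of newly generated tokens directly.
    config = GenerationConfig(max_new_token=64, min_new_token=1, decode_strategy="sampling", top_p=0.9)

    # Legacy: max_length is still accepted; generate() maps it onto max_new_token with a
    # deprecation warning when max_new_token is left at its default (see the generate()
    # length handling further down in this PR).
    legacy_config = GenerationConfig(max_length=128)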

@@ -176,11 +180,6 @@ def __init__(self, **kwargs):
self._from_model_config = kwargs.pop("_from_model_config", False)
self.paddlenlp_version = kwargs.pop("paddlenlp_version", __version__)

# Parameters that control the generation strategy used
self.decode_strategy = kwargs.pop("decode_strategy", None)
if self.decode_strategy is None:
self.decode_strategy = self._get_generation_mode()

# Additional attributes without default values
if not self._from_model_config:
# we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a
@@ -192,6 +191,12 @@
logger.error(f"Can't set {key} with value {value} for {self}")
raise err

# Parameters that control the generation strategy used
if "decode_strategy" in kwargs:
self.decode_strategy = kwargs.pop("decode_strategy")
else:
self.decode_strategy = self._get_generation_mode()

# Validate the values of the attributes
self.validate(is_init=True)

@@ -202,7 +207,7 @@ def __eq__(self, other):
self_dict = self.__dict__.copy()
other_dict = other.__dict__.copy()
# ignore metadata
for metadata_field in "paddlenlp_version":
for metadata_field in ["_from_model_config", "paddlenlp_version"]:
self_dict.pop(metadata_field, None)
other_dict.pop(metadata_field, None)
return self_dict == other_dict
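
Note that the old loop iterated over the characters of the string "paddlenlp_version"; with the fix, equality ignores both metadata fields. A small sketch:

    a = GenerationConfig(top_k=50, _from_model_config=True)
    b = GenerationConfig(top_k=50, _from_model_config=False)
    assert a == b  # the configs differ only in ignored metadata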
@@ -432,7 +437,7 @@ def from_pretrained(
community_url = "/".join([COMMUNITY_MODEL_PREFIX, pretrained_model_name_or_path, GENERATION_CONFIG_NAME])
if url_file_exists(community_url):
resolved_config_file = get_path_from_url_with_filelock(
pretrained_model_name_or_path, cache_dir, check_exist=not force_download
community_url, cache_dir, check_exist=not force_download
)
else:
raise FileNotFoundError(f"configuration file<{GENERATION_CONFIG_NAME}> not found")
@@ -14,6 +14,7 @@
# limitations under the License.
from __future__ import annotations

import copy
from typing import Union

import paddle
@@ -37,7 +38,7 @@
from paddlenlp.transformers.utils import get_scale_by_dtype
from paddlenlp.utils.log import logger

from .configuration_utils import GenerationConfig
from .configuration_utils import DEFAULT_MAX_NEW_TOKEN, GenerationConfig
from .logits_process import (
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
@@ -737,8 +738,17 @@ def generate(
# ['是的', '嗯嗯']
"""
if generation_config is None:
generation_config = GenerationConfig.from_model_config(self.config)
if self.generation_config._from_model_config:
new_generation_config = GenerationConfig.from_model_config(self.config)
if new_generation_config != self.generation_config:
logger.warning(
"model.generation_config is in conflict with model.config, " "model.config is used."
)
self.generation_config = new_generation_config
generation_config = self.generation_config

# do not modify model.generation_config in place
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs)

assert generation_config.decode_strategy in [
@@ -892,8 +902,16 @@ def generate(
print("Setting `pad_token_id` to `eos_token_id`:{} for " "open-end generation.".format(eos_token_id))
pad_token_id = eos_token_id

max_length = generation_config.max_length
min_length = generation_config.min_length
if generation_config.max_length != 0 and generation_config.max_new_token == DEFAULT_MAX_NEW_TOKEN:
logger.warning("`max_length` will be deprecated in future, use" " `max_new_token` instead.")
generation_config.max_new_token = generation_config.max_length

if generation_config.min_length != 0 and generation_config.min_new_token == 0:
logger.warning("`min_length` will be deprecated in future, use" " `min_new_token` instead.")
Collaborator suggested change:

Before:
logger.warning("`max_length` will be deprecated in future, use" " `max_new_token` instead.")
generation_config.max_new_token = generation_config.max_length
if generation_config.min_length != 0 and generation_config.min_new_token == 0:
logger.warning("`min_length` will be deprecated in future, use" " `min_new_token` instead.")

After:
logger.warning("`max_length` will be deprecated in future releases, use `max_new_token` instead.")
generation_config.max_new_token = generation_config.max_length
if generation_config.min_length != 0 and generation_config.min_new_token == 0:
logger.warning("`min_length` will be deprecated in future releases, use `min_new_token` instead.")

generation_config.min_new_token = generation_config.min_length

max_length = generation_config.max_new_token
min_length = generation_config.min_new_token
if is_tracing and not paddle.is_tensor(max_length):
if hasattr(paddle.framework, "_no_check_dy2st_diff"):
# TODO(daisiming): _no_check_dy2st_diff is used to turn off the checking of behavior
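One practical consequence of the copy.deepcopy added near the top of generate() in this PR: per-call keyword overrides are merged into a copy of the stored config, so they should no longer leak into model.generation_config. A hedged sketch (model and input_ids are placeholders):

    # Per-call kwargs update a deep copy of model.generation_config...
    model.generate(input_ids=input_ids, top_p=0.5, max_new_token=16)

    # ...so the config attached to the model keeps its original values afterwards.
    print(model.generation_config.top_p)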
10 changes: 5 additions & 5 deletions paddlenlp/transformers/generation_utils.py
@@ -12,8 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddlenlp.transformers.generation.configuration_utils import * # noqa: F401, F403
from paddlenlp.transformers.generation.logits_process import * # noqa: F401, F403
from paddlenlp.transformers.generation.stopping_criteria import * # noqa: F401, F403
from paddlenlp.transformers.generation.streamers import * # noqa: F401, F403
from paddlenlp.transformers.generation.utils import * # noqa: F401, F403
from paddlenlp.generation.configuration_utils import * # noqa: F401, F403
from paddlenlp.generation.logits_process import * # noqa: F401, F403
from paddlenlp.generation.stopping_criteria import * # noqa: F401, F403
from paddlenlp.generation.streamers import * # noqa: F401, F403
from paddlenlp.generation.utils import * # noqa: F401, F403
Collaborator: Could we deprecate transformers/generation_utils outright, and import everything from transformers/generation going forward?

24 changes: 20 additions & 4 deletions paddlenlp/transformers/model_utils.py
@@ -61,7 +61,7 @@
from ..utils import device_guard
from .configuration_utils import PretrainedConfig
from .conversion_utils import ConversionMixin
from .generation_utils import GenerationMixin
from .generation_utils import GenerationConfig, GenerationMixin
from .utils import ( # convert_ndarray_dtype,
ContextManagers,
InitTrackerMeta,
@@ -863,6 +863,7 @@ def __init__(self, *args, **kwargs):
if config is not None:
self.config: PretrainedConfig = config
self.model_config_file = CONFIG_NAME
self.generation_config = GenerationConfig.from_model_config(self.config) if self.can_generate() else None
return

# extract config from kwargs
@@ -876,6 +877,7 @@
raise TypeError("config parameter should be the instance of PretrainedConfig")

self.config: PretrainedConfig = kwargs["config"]
self.generation_config = GenerationConfig.from_model_config(self.config) if self.can_generate() else None
self.model_config_file = CONFIG_NAME
self.warnings_issued = {}

@@ -1993,6 +1995,22 @@ def from_pretrained(
keep_in_fp32_modules=keep_in_fp32_modules,
)

# load generation_config.json
if model.can_generate() and pretrained_model_name_or_path is not None:
try:
model.generation_config = GenerationConfig.from_pretrained(
pretrained_model_name_or_path,
cache_dir=cache_dir,
force_download=force_download,
subfolder=subfolder,
**kwargs,
)
except OSError:
logger.info(
"Generation config file not found, using a generation config created from the model config."
)
pass

if paddle.in_dynamic_mode():
return model

@@ -2086,9 +2104,7 @@ def save_pretrained(
if is_main_process:
config_to_save.save_pretrained(save_directory)
if self.can_generate():
# to do support generation_config
pass
# model_to_save.generation_config.save_pretrained(save_directory)
model_to_save.generation_config.save_pretrained(save_directory)

# Handle the case where some state_dict keys shouldn't be saved
if self._keys_to_ignore_on_save is not None:
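Taken together, the from_pretrained and save_pretrained changes in this file should let the generation config round-trip with a checkpoint. A sketch under that assumption (the Auto class and checkpoint name are placeholders, not taken from this PR):

    from paddlenlp.transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("org/placeholder-model")

    # Loaded from generation_config.json when the checkpoint ships one; otherwise
    # created from the model config via the OSError fallback shown above.
    print(model.generation_config)

    # save_pretrained now also writes generation_config.json alongside the model config.
    model.save_pretrained("./exported")
    reloaded = AutoModelForCausalLM.from_pretrained("./exported")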
File renamed without changes.
@@ -35,7 +35,7 @@ def ids_tensor(shape, vocab_size, rng=None, name=None):
return paddle.to_tensor(data=values).reshape(shape)


from paddlenlp.transformers.generation.logits_process import (
from paddlenlp.generation.logits_process import (
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
HammingDiversityLogitsProcessor,
@@ -18,7 +18,7 @@

import paddle

from paddlenlp.transformers.generation.stopping_criteria import (
from paddlenlp.generation.stopping_criteria import (
MaxLengthCriteria,
MaxTimeCriteria,
StoppingCriteriaList,
@@ -22,8 +22,7 @@
from paddlenlp.transformers.generation_utils import TextIteratorStreamer, TextStreamer
from paddlenlp.transformers.utils import CaptureStd
from tests.testing_utils import slow

from ..test_modeling_common import ids_tensor
from tests.transformers.test_modeling_common import ids_tensor


class StreamerTester(unittest.TestCase):
8 changes: 4 additions & 4 deletions tests/transformers/test_generation_utils.py
@@ -235,7 +235,7 @@ def _greedy_generate(
input_ids,
attention_mask=attention_mask,
generation_config=GenerationConfig(
max_length=max_length,
max_new_token=max_length,
decode_strategy="greedy_search",
**logits_process_kwargs,
),
@@ -277,7 +277,7 @@ def _sample_generate(
input_ids,
attention_mask=attention_mask,
generation_config=GenerationConfig(
max_length=max_length,
max_new_token=max_length,
decode_strategy="sampling",
num_return_sequences=num_return_sequences,
top_k=1,
@@ -332,7 +332,7 @@ def _beam_search_generate(
attention_mask=attention_mask,
generation_config=GenerationConfig(
decode_strategy="beam_search",
max_length=max_length,
max_new_token=max_length,
**beam_kwargs,
**logits_process_kwargs,
),
@@ -390,7 +390,7 @@ def _group_beam_search_generate(
attention_mask=attention_mask,
generation_config=GenerationConfig(
decode_strategy="beam_search",
max_length=max_length,
max_new_token=max_length,
**beam_kwargs,
**logits_process_kwargs,
),