pretrainedModel add gconfig #6915
Changes from 2 commits
@@ -28,6 +28,7 @@
 from paddlenlp.datasets import InTokensIterableDataset
 from paddlenlp.trainer import Trainer, TrainerCallback
 from paddlenlp.trainer.trainer_utils import IterableDatasetShard, has_length
+from paddlenlp.transformers.generation_utils import GenerationConfig
 from paddlenlp.utils.log import logger
@@ -200,14 +201,16 @@ def prediction_step(
             input_ids=inputs["input_ids"],
             attention_mask=inputs["attention_mask"] if "attention_mask" in inputs else None,
             position_ids=inputs["position_ids"] if "position_ids" in inputs else None,
-            max_length=self.data_args.tgt_length,
-            decode_strategy="sampling",
-            top_k=self.gen_args.top_k,
-            top_p=self.gen_args.top_p,
-            bos_token_id=self.tokenizer.bos_token_id,
-            eos_token_id=self.tokenizer.eos_token_id,
-            pad_token_id=self.tokenizer.pad_token_id,
-            use_cache=True,
+            generation_config=GenerationConfig(
+                max_new_token=self.data_args.tgt_length,
+                decode_strategy="sampling",
+                top_k=self.gen_args.top_k,
+                top_p=self.gen_args.top_p,
+                bos_token_id=self.tokenizer.bos_token_id,
+                eos_token_id=self.tokenizer.eos_token_id,
+                pad_token_id=self.tokenizer.pad_token_id,
+                use_cache=True,
+            ),
         )[0]
         all_preds = []
         for pred_tokens in generated_tokens:

Review comment: Keep the original API.
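For context, here is a minimal sketch of the call-site change this hunk makes. The `model`, `tokenizer`, and tensor arguments are stand-ins, the sampling values are illustrative, and the field names (`max_new_token`, `decode_strategy`) follow this PR's diff rather than a settled public API:

```python
import paddle

from paddlenlp.transformers.generation_utils import GenerationConfig


def sample_predictions(model, tokenizer, input_ids: paddle.Tensor, tgt_length: int):
    """Mirrors the updated prediction_step: decoding options travel in one
    GenerationConfig object instead of a long list of keyword arguments."""
    gen_config = GenerationConfig(
        max_new_token=tgt_length,  # in this PR, replaces the old `max_length` kwarg
        decode_strategy="sampling",
        top_k=50,                  # illustrative values; the script reads them from gen_args
        top_p=0.9,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        use_cache=True,
    )
    # generate() returns a tuple; the first element holds the sampled token ids.
    return model.generate(input_ids=input_ids, generation_config=gen_config)[0]
```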
@@ -14,6 +14,7 @@
 # limitations under the License.
 from __future__ import annotations

+import copy
 from typing import Union

 import paddle

@@ -37,7 +38,7 @@
 from paddlenlp.transformers.utils import get_scale_by_dtype
 from paddlenlp.utils.log import logger

-from .configuration_utils import GenerationConfig
+from .configuration_utils import DEFAULT_MAX_NEW_TOKEN, GenerationConfig
 from .logits_process import (
     ForcedBOSTokenLogitsProcessor,
     ForcedEOSTokenLogitsProcessor,
@@ -737,8 +738,17 @@ def generate(
         # ['是的', '嗯嗯']
         """
         if generation_config is None:
-            generation_config = GenerationConfig.from_model_config(self.config)
+            if self.generation_config._from_model_config:
+                new_generation_config = GenerationConfig.from_model_config(self.config)
+                if new_generation_config != self.generation_config:
+                    logger.warning(
+                        "model.generation_config is in conflict with model.config, " "model.config is used."
+                    )
+                    self.generation_config = new_generation_config
+            generation_config = self.generation_config

+        # without update model.generation_config
+        generation_config = copy.deepcopy(generation_config)
         model_kwargs = generation_config.update(**kwargs)

         assert generation_config.decode_strategy in [
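Pulled out of the diff, the new prologue establishes a clear precedence: an explicitly passed config wins, otherwise `model.generation_config` is used (re-derived from `model.config` when the two have drifted apart), and per-call kwargs are applied last to a deep copy. A standalone sketch of that logic, with the warning emitted via `print` instead of the library logger:

```python
import copy

from paddlenlp.transformers.generation_utils import GenerationConfig


def resolve_generation_config(model, generation_config=None, **kwargs):
    if generation_config is None:
        if model.generation_config._from_model_config:
            # The stored config was only derived from model.config; if the two
            # have since diverged, re-derive it and warn, preferring model.config.
            new_config = GenerationConfig.from_model_config(model.config)
            if new_config != model.generation_config:
                print("model.generation_config is in conflict with model.config, model.config is used.")
                model.generation_config = new_config
        generation_config = model.generation_config

    # Deep-copy before applying per-call overrides so that a one-off kwarg such
    # as top_p=0.5 does not leak into the model's persistent generation_config.
    generation_config = copy.deepcopy(generation_config)
    remaining_kwargs = generation_config.update(**kwargs)
    return generation_config, remaining_kwargs
```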
@@ -892,8 +902,16 @@ def generate(
             print("Setting `pad_token_id` to `eos_token_id`:{} for " "open-end generation.".format(eos_token_id))
             pad_token_id = eos_token_id

-        max_length = generation_config.max_length
-        min_length = generation_config.min_length
+        if generation_config.max_length != 0 and generation_config.max_new_token == DEFAULT_MAX_NEW_TOKEN:
+            logger.warning("`max_length` will be deprecated in future, use" " `max_new_token` instead.")
+            generation_config.max_new_token = generation_config.max_length
+
+        if generation_config.min_length != 0 and generation_config.min_new_token == 0:
+            logger.warning("`min_length` will be deprecated in future, use" " `min_new_token` instead.")
+            generation_config.min_new_token = generation_config.min_length
+
+        max_length = generation_config.max_new_token
+        min_length = generation_config.min_new_token
         if is_tracing and not paddle.is_tensor(max_length):
             if hasattr(paddle.framework, "_no_check_dy2st_diff"):
                 # TODO(daisiming): _no_check_dy2st_diff is used to turn off the checking of behavior
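The same back-compat mapping as a self-contained sketch. `DEFAULT_MAX_NEW_TOKEN` is the sentinel imported earlier in this diff; its value here is an assumption, and a small dataclass stands in for GenerationConfig:

```python
from dataclasses import dataclass

DEFAULT_MAX_NEW_TOKEN = 20  # assumed value; the real constant lives in configuration_utils


@dataclass
class _Config:  # minimal stand-in for GenerationConfig
    max_length: int = 0
    min_length: int = 0
    max_new_token: int = DEFAULT_MAX_NEW_TOKEN
    min_new_token: int = 0


def apply_length_back_compat(cfg: _Config) -> tuple[int, int]:
    """Forward legacy max_length/min_length into the new fields, but only when
    the caller has not already set the new fields themselves."""
    if cfg.max_length != 0 and cfg.max_new_token == DEFAULT_MAX_NEW_TOKEN:
        print("`max_length` will be deprecated in future, use `max_new_token` instead.")
        cfg.max_new_token = cfg.max_length
    if cfg.min_length != 0 and cfg.min_new_token == 0:
        print("`min_length` will be deprecated in future, use `min_new_token` instead.")
        cfg.min_new_token = cfg.min_length
    return cfg.max_new_token, cfg.min_new_token


# Example: a caller still using the legacy field gets it forwarded.
assert apply_length_back_compat(_Config(max_length=64)) == (64, 0)
```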
@@ -12,8 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from paddlenlp.transformers.generation.configuration_utils import *  # noqa: F401, F403
-from paddlenlp.transformers.generation.logits_process import *  # noqa: F401, F403
-from paddlenlp.transformers.generation.stopping_criteria import *  # noqa: F401, F403
-from paddlenlp.transformers.generation.streamers import *  # noqa: F401, F403
-from paddlenlp.transformers.generation.utils import *  # noqa: F401, F403
+from paddlenlp.generation.configuration_utils import *  # noqa: F401, F403
+from paddlenlp.generation.logits_process import *  # noqa: F401, F403
+from paddlenlp.generation.stopping_criteria import *  # noqa: F401, F403
+from paddlenlp.generation.streamers import *  # noqa: F401, F403
+from paddlenlp.generation.utils import *  # noqa: F401, F403
Review comment: Could we deprecate transformers/generation_utils outright, and have everything import from transformers/generation going forward?

Reply: Prefer importing from the paddlenlp.generation path; paddlenlp.transformers.generation_utils will be phased out gradually.
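In line with that discussion, the edited module becomes a pure re-export shim: legacy imports keep resolving while new code targets paddlenlp.generation. A quick sketch of what the star re-export implies, assuming both paths exist in the installed version:

```python
# New, preferred path, imported directly from the canonical package.
from paddlenlp.generation.configuration_utils import GenerationConfig as CanonicalGC

# Legacy path, now a star re-export of paddlenlp.generation.*; it keeps old
# call sites working during the deprecation window.
from paddlenlp.transformers.generation_utils import GenerationConfig as LegacyGC

# Both names resolve to the same class object, so isinstance checks and configs
# created under one import path remain compatible with the other.
assert LegacyGC is CanonicalGC
```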