Error in _prepare_generated_length #32911

Closed
3 of 4 tasks
jefferyZhan opened this issue Aug 21, 2024 · 8 comments · Fixed by #32932

@jefferyZhan

System Info

  • transformers version: 4.45.0.dev0
  • Platform: Linux-3.10.0-1160.el7.x86_64-x86_64-with-glibc2.17
  • Python version: 3.10.13
  • Huggingface_hub version: 0.24.5
  • Safetensors version: 0.4.2
  • Accelerate version: 0.33.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.3.1 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: Yes

Who can help?

@ArthurZucker @zucchini-nlp
There might be a logic error in the generation utils. When inputs_embeds is passed instead of input_ids, input_ids.shape[-1] is 0, so max_length is set to max_new_tokens rather than generation_config.max_new_tokens + input_ids_length. This later causes a size mismatch in _prepare_4d_causal_attention_mask_with_cache_position, because the target_length passed in ends up smaller than the mask_length used here:

    padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]

    input_ids_length = input_ids.shape[-1]
    has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
    has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
    generation_config = self._prepare_generated_length(
        generation_config=generation_config,
        has_default_max_length=has_default_max_length,
        has_default_min_length=has_default_min_length,
        model_input_name=model_input_name,
        inputs_tensor=inputs_tensor,
        input_ids_length=input_ids_length,
    )

...

    if generation_config.max_new_tokens is not None:

        if not has_default_max_length and generation_config.max_length is not None:
            logger.warning(
                f"Both max_new_tokens (={generation_config.max_new_tokens}) and max_length(="
                f"{generation_config.max_length}) seem to have been set. max_new_tokens will take precedence. "
                "Please refer to the documentation for more information. "
                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
            )
        generation_config.max_length = generation_config.max_new_tokens + input_ids_length
    # if both `inputs_embeds` and `input_ids` are passed, we do not correct the length
    # otherwise we need total length [inputs-embeds-len + new-tokens-len] to not go beyond indicated `max_length``
    elif (
        model_input_name == "inputs_embeds"
        and input_ids_length != inputs_tensor.shape[1]
        and not self.config.is_encoder_decoder
    ):
        generation_config.max_length -= inputs_tensor.shape[1]
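
To make the failure mode concrete, here is a short numeric sketch based on the shapes in this report (a 778-token prompt passed as embeddings, max_new_tokens=128); the variable names are illustrative, not the library's own:

    # Illustrative numbers only; this mirrors the arithmetic in the snippet above.
    input_ids_length = 0          # input_ids is empty when only inputs_embeds is passed
    inputs_embeds_length = 778    # prompt length, i.e. inputs_tensor.shape[1]
    max_new_tokens = 128

    # What _prepare_generated_length currently computes:
    max_length = max_new_tokens + input_ids_length                # 128

    # What the static-cache mask preparation actually needs:
    expected_max_length = max_new_tokens + inputs_embeds_length   # 906

    # Downstream, causal_mask is sized from max_length (128) while attention_mask
    # covers the 778 prompt positions, so the broadcast in
    # _prepare_4d_causal_attention_mask_with_cache_position fails with:
    # RuntimeError: The size of tensor a (128) must match the size of tensor b (778)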

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Simply replace input_ids with inputs_embeds when calling generate on Gemma2.
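
A minimal sketch of that reproduction (the checkpoint name and prompt are illustrative assumptions, not taken from the report):

    # Hypothetical minimal reproduction: build inputs_embeds from a tokenized prompt
    # and generate with Gemma2, which uses a static-style (Hybrid) cache by default.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "google/gemma-2-9b"  # assumed checkpoint; any Gemma2 model should behave the same
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

    inputs = tokenizer("Describe this image in detail.", return_tensors="pt")
    inputs_embeds = model.get_input_embeddings()(inputs.input_ids)

    # Passing inputs_embeds instead of input_ids hits the size mismatch in
    # _prepare_4d_causal_attention_mask_with_cache_position.
    out = model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=inputs.attention_mask,
        max_new_tokens=128,
    )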

Expected behavior

Support passing inputs_embeds the same way input_ids is supported.

@ArthurZucker
Collaborator

Hey! Could you share a snippet that reproduces the error? 🤗

@jefferyZhan
Author

jefferyZhan commented Aug 21, 2024

Of course!
I called Gemma2's generate in my script with max_new_tokens=128 and inputs_embeds of shape [1, 778, *]:
    Traceback (most recent call last):
      File "/public/*/miniconda3/envs/awesome_env2.0/lib/python3.10/runpy.py", line 196, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "/public/*/miniconda3/envs/awesome_env2.0/lib/python3.10/runpy.py", line 86, in _run_code
        exec(code, run_globals)
      File "/public/*/LLaVA/llava/eval/model_vqa_loader.py", line 142, in <module>
        eval_model(args)
      File "/public/*/LLaVA/llava/eval/model_vqa_loader.py", line 99, in eval_model
        output_ids = model.generate(
      File "/public/*/miniconda3/envs/awesome_env2.0/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
        return func(*args, **kwargs)
      File "/public/*/LLaVA/llava/model/language_model/llava_gemma.py", line 175, in generate
        return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
      File "/public/*/miniconda3/envs/awesome_env2.0/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
        return func(*args, **kwargs)
      File "/public/*/transformers/src/transformers/generation/utils.py", line 1996, in generate
        result = self._sample(
      File "/public/*/transformers/src/transformers/generation/utils.py", line 2916, in _sample
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
      File "/public/*LLaVA/llava/model/language_model/llava_gemma.py", line 143, in prepare_inputs_for_generation
        _inputs = super().prepare_inputs_for_generation(
      File "/public/*/src/transformers/models/gemma2/modeling_gemma2.py", line 1112, in prepare_inputs_for_generation
        attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
      File "/public/*transformers/src/transformers/models/gemma2/modeling_gemma2.py", line 103, in _prepare_4d_causal_attention_mask_with_cache_position
        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
    RuntimeError: The size of tensor a (128) must match the size of tensor b (778) at non-singleton dimension 3

@zucchini-nlp
Member

Indeed there's an issue when we generate from embeds with static cache, since we rely on an already-adjusted max_length. I found that this wasn't caught by our CI because the tests were running with GemmaForCausalLM.

I can open a PR to fix this @ArthurZucker

@ArthurZucker
Collaborator

🫠 Let's first make sure this does not affect Gemma2 / Gemma on the v4.44 release, then fix it on main.

@ArthurZucker
Collaborator

ArthurZucker commented Aug 21, 2024

I am not sure we test static cache + input embeds, let's add this

@jefferyZhan
Author

jefferyZhan commented Aug 21, 2024

I have also tried the v4.44 release; the issue occurs there as well. I switched to 4.45.dev because of PR #32493, which fixes a shape-access error present in v4.44:

if model_inputs["inputs_embeds"] is not None:
            batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
            device = model_inputs["inputs_embeds"].device
        else:
            batch_size, sequence_length = model_inputs["input_ids"].shape
            device = model_inputs["input_ids"].device

@zucchini-nlp
Member

Yes, it is also in v4.44, as static cache never worked with input embeddings. I'll add a test, okay.

@jefferyZhan
Author

Thank you for your updates. The PR fixes the problem, but it has a small typo at line 1824 of generation/utils.py (a missing comma), which raises:

    RuntimeError: Failed to import transformers.generation.utils because of the following error (look up to see its traceback):
    invalid syntax. Perhaps you forgot a comma? (utils.py, line 1824)

    model_kwargs[cache_name] = self._get_cache(
        cache_implementation=generation_config.cache_implementation,
        batch_size=max(generation_config.num_return_sequences, generation_config.num_beams) * batch_size  # <- missing trailing comma
        max_cache_len=max_cache_length,
        device=device,
        model_kwargs=model_kwargs,
    )
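
For reference, the call parses once the trailing comma is restored; this is just the snippet above with the comma added, not necessarily the final code in the PR:

    model_kwargs[cache_name] = self._get_cache(
        cache_implementation=generation_config.cache_implementation,
        batch_size=max(generation_config.num_return_sequences, generation_config.num_beams) * batch_size,
        max_cache_len=max_cache_length,
        device=device,
        model_kwargs=model_kwargs,
    )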
