From 2eda65b71566159557844ed69145ca9f22b42d21 Mon Sep 17 00:00:00 2001
From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date: Mon, 12 Aug 2024 10:22:47 +0400
Subject: [PATCH] Add new model (#32615)

* v1 - working version

* fix

* fix

* fix

* fix

* rename to correct name

* fix title

* fixup

* rename files

* fix

* add copied from on tests

* rename to `FalconMamba` everywhere and fix bugs

* fix quantization + accelerate

* fix copies

* add `torch.compile` support

* fix tests

* fix tests and add slow tests

* copies on config

* merge the latest changes

* fix tests

* add few lines about instruct

* Apply suggestions from code review

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix

* fix tests

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
---
 .../falcon_mamba/modeling_falcon_mamba.py | 55 ++++---------------
 .../test_modeling_falcon_mamba.py         | 39 +------------
 2 files changed, 13 insertions(+), 81 deletions(-)

diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py
index 07374fe1dfd7b5..4bcd0e9d467d12 100644
--- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py
+++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py
@@ -59,7 +59,7 @@
     (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
 )

-_CHECKPOINT_FOR_DOC = "tiiuae/falcon-mamba-7b"
+_CHECKPOINT_FOR_DOC = "tiiuae/falcon_mamba-7b"
 _CONFIG_FOR_DOC = "FalconMambaConfig"

@@ -155,7 +155,6 @@ def cuda_kernels_forward(
         hidden_states: torch.Tensor,
         cache_params: Optional[MambaCache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.LongTensor] = None,
     ):
         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states).transpose(1, 2)
@@ -180,9 +179,6 @@ def cuda_kernels_forward(
         else:
             hidden_states, gate = projected_states.chunk(2, dim=1)

-            if attention_mask is not None:
-                hidden_states = hidden_states * attention_mask.unsqueeze(1)
-
             # 2. Convolution sequence transformation
             conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
             if cache_params is not None and cache_position[0] > 0:
@@ -204,9 +200,6 @@ def cuda_kernels_forward(
                     hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
                 )

-            if attention_mask is not None:
-                hidden_states = hidden_states * attention_mask.unsqueeze(1)
-
             # 3. State Space Model sequence transformation
             # 3.a. input varying initialization of time_step, B and C
             ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
@@ -266,7 +259,6 @@ def slow_forward(
         input_states,
         cache_params: Optional[MambaCache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.LongTensor] = None,
     ):
         batch_size, seq_len, _ = input_states.shape
         dtype = input_states.dtype
@@ -274,9 +266,6 @@ def slow_forward(
         projected_states = self.in_proj(input_states).transpose(1, 2)  # [batch, 2 * intermediate_size, seq_len]
         hidden_states, gate = projected_states.chunk(2, dim=1)

-        if attention_mask is not None:
-            hidden_states = hidden_states * attention_mask.unsqueeze(1)
-
         # 2. Convolution sequence transformation
         if cache_params is not None:
             ssm_state = cache_params.ssm_states[self.layer_idx].clone()
@@ -305,9 +294,6 @@ def slow_forward(
             )
             hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])  # [batch, intermediate_size, seq_len]

-        if attention_mask is not None:
-            hidden_states = hidden_states * attention_mask.unsqueeze(1)
-
         # 3. State Space Model sequence transformation
         # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
         ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
@@ -369,11 +355,10 @@ def forward(
         hidden_states,
         cache_params: Optional[MambaCache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.LongTensor] = None,
     ):
         if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling():
-            return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
-        return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)
+            return self.cuda_kernels_forward(hidden_states, cache_params, cache_position)
+        return self.slow_forward(hidden_states, cache_params, cache_position)


 # Copied from transformers.models.mamba.modeling_mamba.MambaRMSNorm with Mamba->FalconMamba
@@ -411,16 +396,13 @@ def forward(
         hidden_states,
         cache_params: Optional[MambaCache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.LongTensor] = None,
     ):
         residual = hidden_states
         hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
         if self.residual_in_fp32:
             residual = residual.to(torch.float32)

-        hidden_states = self.mixer(
-            hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask
-        )
+        hidden_states = self.mixer(hidden_states, cache_params=cache_params, cache_position=cache_position)
         hidden_states = residual + hidden_states
         return hidden_states
@@ -619,13 +601,14 @@ def set_input_embeddings(self, new_embeddings):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,  # Ignored arg
         inputs_embeds: Optional[torch.LongTensor] = None,
         cache_params: Optional[MambaCache] = None,
         use_cache: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.LongTensor] = None,
+        **kwargs,  # `attention_mask` is passed by the tokenizer and we don't want it
     ) -> Union[Tuple, FalconMambaOutput]:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -666,15 +649,10 @@ def forward(
         for mixer_block in self.layers:
             if self.gradient_checkpointing and self.training:
                 hidden_states = self._gradient_checkpointing_func(
-                    mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask
+                    mixer_block.__call__, hidden_states, cache_params, cache_position
                 )
             else:
-                hidden_states = mixer_block(
-                    hidden_states,
-                    cache_params=cache_params,
-                    cache_position=cache_position,
-                    attention_mask=attention_mask,
-                )
+                hidden_states = mixer_block(hidden_states, cache_params=cache_params, cache_position=cache_position)

             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
@@ -734,13 +712,6 @@ def _update_model_kwargs_for_generation(
             and model_kwargs["cache_position"] is not None
         ):
             model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
-
-        if "attention_mask" in model_kwargs:
-            attention_mask = model_kwargs["attention_mask"]
-            model_kwargs["attention_mask"] = torch.cat(
-                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
-            )
-
         return model_kwargs

     def prepare_inputs_for_generation(
@@ -750,7 +721,6 @@ def prepare_inputs_for_generation(
         use_cache=None,
         cache_params: Optional[MambaCache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.LongTensor] = None,
         **kwargs,
     ):
         if use_cache:
@@ -763,10 +733,6 @@ def prepare_inputs_for_generation(
             )
             if cache_position[0] > 0:
                 input_ids = input_ids[:, -1].unsqueeze(-1)
-
-                if attention_mask is not None:
-                    attention_mask = None
-
         else:
             # we initialize the `cache_position` to full size of `conv_states` at prefill stage
             # considering padding will be applied when input length is shorter, and truncation
@@ -784,7 +750,6 @@ def prepare_inputs_for_generation(
                 "cache_params": cache_params,
                 "use_cache": use_cache,
                 "cache_position": cache_position,
-                "attention_mask": attention_mask,
             }
         )
         return model_inputs
@@ -795,10 +760,11 @@ def prepare_inputs_for_generation(
         output_type=FalconMambaCausalLMOutput,
         config_class=_CONFIG_FOR_DOC,
     )
+    # Ignore copy
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,  # Ignored copy
         inputs_embeds: Optional[torch.FloatTensor] = None,
         cache_params: Optional[MambaCache] = None,
         labels: Optional[torch.LongTensor] = None,
@@ -824,7 +790,6 @@ def forward(
             return_dict=return_dict,
             use_cache=use_cache,
             cache_position=cache_position,
-            attention_mask=attention_mask,
         )
         hidden_states = falcon_mamba_outputs[0]
diff --git a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py
index d75014f370d29f..8e7c456e4a383b 100644
--- a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py
+++ b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py
@@ -101,7 +101,6 @@ def prepare_config_and_inputs(
         self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
     ):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-        attention_mask = ids_tensor([self.batch_size, self.seq_length], 1)

         sequence_labels = None
         token_labels = None
@@ -120,7 +119,7 @@ def prepare_config_and_inputs(
         return (
             config,
             input_ids,
-            attention_mask,
+            None,
             sequence_labels,
             token_labels,
             choice_labels,
@@ -154,7 +153,6 @@ def prepare_config_and_inputs_for_decoder(self):
         (
             config,
             input_ids,
-            attention_mask,
             sequence_labels,
             token_labels,
             choice_labels,
@@ -163,7 +161,6 @@ def prepare_config_and_inputs_for_decoder(self):
         return (
             config,
             input_ids,
-            attention_mask,
             sequence_labels,
             token_labels,
             choice_labels,
@@ -256,12 +253,12 @@ def prepare_config_and_inputs_for_common(self):
         (
             config,
             input_ids,
-            attention_mask,
+            _,
             sequence_labels,
             token_labels,
             choice_labels,
         ) = self.prepare_config_and_inputs()
-        inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
+        inputs_dict = {"input_ids": input_ids}
         return config, inputs_dict
@@ -494,33 +491,3 @@ def test_generation_torch_compile(self):
             self.tokenizer.batch_decode(out, skip_special_tokens=False)[0],
             "Hello today I am going to show you how to make a simple and easy to make paper plane.\nStep",
         )
-
-    def test_batched_generation(self):
-        model_id = "tiiuae/falcon-mamba-7b"
-        tok = AutoTokenizer.from_pretrained(model_id)
-        tok.pad_token_id = tok.eos_token_id
-
-        texts = ["Hello today", "Hello my name is Younes and today"]
-
-        EXPECTED_OUTPUT = [
-            "Hello today I'm going to show you how to make a 3D model of a house.\n",
-            "Hello my name is Younes and today I will be talking about the topic of “The importance of the internet in our life”.\n",
-        ]
-
-        inputs = tok(texts, return_tensors="pt", padding=True, return_token_type_ids=False).to(torch_device)
-        model = AutoModelForCausalLM.from_pretrained(model_id, device_map=0, torch_dtype=torch.bfloat16)
-
-        out = model.generate(**inputs, max_new_tokens=20)
-        out = tok.batch_decode(out, skip_special_tokens=True)
-
-        self.assertListEqual(out, EXPECTED_OUTPUT)
-
-        # We test the same generations with inputs_embeds
-        with torch.no_grad():
-            inputs_embeds = model.get_input_embeddings()(inputs.pop("input_ids"))
-
-        inputs["inputs_embeds"] = inputs_embeds
-        out = model.generate(**inputs, max_new_tokens=20)
-        out = tok.batch_decode(out, skip_special_tokens=True)
-
-        self.assertListEqual(out, EXPECTED_OUTPUT)
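
A minimal usage sketch, illustrative only and not part of the patch: with `attention_mask` dropped from the FalconMamba forward path (any mask passed by the tokenizer is swallowed by `**kwargs`), generation is exercised on unpadded, single-sequence inputs, which is why the padded `test_batched_generation` test is removed above. The sketch assumes the public `tiiuae/falcon-mamba-7b` checkpoint and a CUDA device with bfloat16 support are available.

# Illustrative sketch (not from the patch): single-sequence generation without
# an attention mask, assuming the `tiiuae/falcon-mamba-7b` checkpoint and a GPU.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-mamba-7b"
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=0, torch_dtype=torch.bfloat16)

# Pass only `input_ids`: after this change, an `attention_mask` produced by the
# tokenizer would be ignored by FalconMambaModel.forward (it lands in `**kwargs`).
inputs = tok("Hello today", return_tensors="pt").to(model.device)
out = model.generate(inputs.input_ids, max_new_tokens=20)
print(tok.batch_decode(out, skip_special_tokens=True)[0])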