[Feature Request] LLaMA task implementation for TokenClassification #26521

Closed
coen22 opened this issue Oct 1, 2023 · 5 comments

Comments

coen22 commented Oct 1, 2023

Feature request

Hi,

I was trying to compare LLaMA 2 to the RoBERTa-based model I used in a (soon to be published) study.
For RoBERTa, I implemented a version of token classification that outputs one-hot encodings per token.
However, it doesn't work with LLaMA because of optimisations done elsewhere in the code.
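
For concreteness, the heads below treat each token as a multi-label problem, so the labels are multi-hot vectors rather than single class indices. A minimal shape sketch (the sizes are made up for illustration):

import torch
from torch.nn import BCEWithLogitsLoss

batch_size, seq_len, num_labels = 2, 8, 5  # illustrative sizes only

# Multi-hot targets: each token may carry several labels at once.
labels = torch.randint(0, 2, (batch_size, seq_len, num_labels)).float()

# Per-token logits, as produced by a linear head over the hidden states.
logits = torch.randn(batch_size, seq_len, num_labels)

loss = BCEWithLogitsLoss()(logits, labels)  # the loss used in the classes below
probs = torch.sigmoid(logits)               # per-label probabilities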

This works:

from typing import Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.roberta.modeling_roberta import RobertaModel, RobertaPreTrainedModel


class RobertaForMultiLabelTokenClassification(RobertaPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.roberta = RobertaModel(config, add_pooling_layer=False)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = BCEWithLogitsLoss()
            target: torch.LongTensor = labels.view(logits.size())
            loss = loss_fct(logits, target.float())

        logits = torch.sigmoid(logits)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

But this doesn't:

from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.llama.modeling_llama import LlamaModel, LlamaPreTrainedModel


class LlamaForTokenClassification(LlamaPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    classifier_dropout = 0.1

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.llama = LlamaModel(config)

        self.dropout = nn.Dropout(self.classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.post_init()

    def forward(
            self,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.llama(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = nn.BCEWithLogitsLoss()
            target: torch.LongTensor = labels.view(logits.size())
            loss = loss_fct(logits, target.float())

        logits = torch.sigmoid(logits)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

Motivation

It would be nice to have access to larger models for this task.

In the attached code, I've put an example of what I would like to have:
LlamaForTokenClassification.zip

Your contribution

Here's my attempt at doing it :)

When I run it using the default Trainer, I get a CUDA error:

../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [52,0,0], thread: [30,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
  File "/mnt/e/Comparison-QoC/code/LlamaForTokenClassification.py", line 152, in forward
    outputs = self.llama(
  File "/home/coen/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/coen/.local/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/coen/.local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 708, in forward
    layer_outputs = decoder_layer(
  File "/home/coen/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/coen/.local/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/coen/.local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 424, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/coen/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/coen/.local/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/coen/.local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 321, in forward
    query_states = self.q_proj(hidden_states)
  File "/home/coen/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/coen/.local/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/coen/.local/lib/python3.10/site-packages/bitsandbytes/nn/modules.py", line 248, in forward
    out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
  File "/home/coen/.local/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 579, in matmul_4bit
    return MatMul4Bit.apply(A, B, out, bias, quant_state)
  File "/home/coen/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/coen/.local/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 516, in forward
    output = torch.nn.functional.linear(A, F.dequantize_4bit(B, state).to(A.dtype).t(), bias)
RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling `cublasCreate(handle)`

The error indicates a size mismatch, but I don't see where it occurs.

Using the SFTTrainer, I get a NotImplemented exception.
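
For completeness, a minimal sketch of how the class above might be wired into the default Trainer. The checkpoint name, num_labels, and train_dataset are placeholders, and the pad-token handling is only a guess at a common cause of the srcIndex < srcSelectDimSize assertion (LLaMA ships without a pad token), not a confirmed fix:

from transformers import AutoConfig, AutoTokenizer, Trainer, TrainingArguments

base_model_name = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint

tokenizer = AutoTokenizer.from_pretrained(base_model_name)
if tokenizer.pad_token is None:
    # Reuse EOS as pad so the vocabulary does not grow; adding a new token
    # would require model.resize_token_embeddings(len(tokenizer)) afterwards.
    tokenizer.pad_token = tokenizer.eos_token

config = AutoConfig.from_pretrained(base_model_name, num_labels=5)  # placeholder label count
model = LlamaForTokenClassification.from_pretrained(base_model_name, config=config)
model.config.pad_token_id = tokenizer.pad_token_id

args = TrainingArguments(output_dir="out", per_device_train_batch_size=1)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)  # train_dataset: a hypothetical tokenized dataset
trainer.train()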

github-actions bot commented Nov 1, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

amyeroberts (Collaborator) commented

Hi @coen22, thanks for raising an issue!

This is a question best placed in our forums. We try to reserve the GitHub issues for feature requests and bug reports.

github-actions bot closed this as completed Dec 4, 2023
KoichiYasuoka (Contributor) commented Dec 23, 2023

I've just written a tentative LlamaForTokenClassification based on the idea of #22209:

from typing import List, Optional, Tuple, Union
import torch
from torch import nn
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.file_utils import add_start_docstrings_to_model_forward
from transformers.models.llama.modeling_llama import LlamaModel, LlamaPreTrainedModel, LLAMA_INPUTS_DOCSTRING

class LlamaForTokenClassification(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = LlamaModel(config)
        if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
            classifier_dropout = config.classifier_dropout
        elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        hidden_states = self.dropout(hidden_states)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + transformer_outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions
        )
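
For reference, a minimal inference sketch using the class above (the checkpoint path is a placeholder and is assumed to contain a fine-tuned classifier head):

import torch
from transformers import AutoTokenizer

checkpoint = "path/to/finetuned-llama-token-classifier"  # placeholder

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = LlamaForTokenClassification.from_pretrained(checkpoint)
model.eval()

inputs = tokenizer("Hello world", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # (batch, seq_len, num_labels)

pred_ids = logits.argmax(dim=-1)[0].tolist()
labels = [model.config.id2label.get(i, str(i)) for i in pred_ids]
print(list(zip(tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]), labels)))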

Does this work well, @coen22 and @lewtun?

KoichiYasuoka (Contributor) commented

Hi @coen22 and @lewtun, this issue is now continued at #29940.
