HF: switch conditional checks to self.backend from `AUTO_MODEL_CLASS` (EleutherAI#2353)

* switch conditional checks to `self.backend`

* nit

* nit

* commit feedback

* fix test; update precommit hooks

* add escape hatch for custom self.AUTO_MODEL_CLASS

* add escape hatch for custom self.AUTO_MODEL_CLASS

* fix

* move assertion

* add logging messages

* update AUTO_MODEL_CLASS behavior in _get_backend

---------

Co-authored-by: haileyschoelkopf <hailey@eleuther.ai>
2 people authored and mariagrandury committed Oct 9, 2024
1 parent 4501702 commit 4f35c89
Showing 1 changed file with 40 additions and 40 deletions.
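The change applies one pattern throughout `lm_eval/models/huggingface.py`: call sites stop comparing `self.AUTO_MODEL_CLASS` against `transformers` auto-classes and branch on the string flag `self.backend` instead, with `AUTO_MODEL_CLASS` derived from (or preset alongside) that flag. A minimal, hypothetical sketch of the shape of the change (not the harness's actual class):

import transformers

# Before: call sites branched on the resolved auto-class, e.g.
#     if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: ...
# After: they branch on a plain string flag set once during init, e.g.
#     if self.backend == "causal": ...

class BackendSketch:
    """Hypothetical stand-in for HFLM's backend bookkeeping."""

    def __init__(self, backend: str = "causal"):
        assert backend in ("causal", "seq2seq")
        # single source of truth for decoder-only vs. encoder-decoder behavior
        self.backend = backend
        # the auto-class is derived from the flag rather than the reverse
        self.AUTO_MODEL_CLASS = (
            transformers.AutoModelForCausalLM
            if backend == "causal"
            else transformers.AutoModelForSeq2SeqLM
        )

    def echoes_prompt(self) -> bool:
        # example call site: only decoder-only models return the prompt from generate()
        return self.backend == "causal"
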
80 changes: 40 additions & 40 deletions lm_eval/models/huggingface.py
@@ -55,7 +55,7 @@ class HFLM(TemplateLM):
def __init__(
self,
pretrained: Union[str, transformers.PreTrainedModel],
backend: Optional[Literal["default", "causal", "seq2seq"]] = "default",
backend: Literal["default", "causal", "seq2seq"] = "default",
# override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
revision: Optional[str] = "main",
subfolder: Optional[str] = None,
@@ -90,7 +90,6 @@ def __init__(
**kwargs,
) -> None:
super().__init__()

# optionally: take in an already-initialized transformers.PreTrainedModel
if not isinstance(pretrained, str):
eval_logger.warning(
@@ -164,7 +163,7 @@ def __init__(
trust_remote_code=trust_remote_code,
)

# determine which of 'causal' and 'seq2seq' backends to use
# determine which of 'causal' and 'seq2seq' backends to use for HF models
self._get_backend(
config=self.config, backend=backend, trust_remote_code=trust_remote_code
)
@@ -287,7 +286,7 @@ def __init__(

def _get_accelerate_args(
self,
parallelize: bool = None,
parallelize: Optional[bool] = None,
device_map: Optional[str] = "auto",
max_memory_per_gpu: Optional[Union[int, str]] = None,
max_cpu_memory: Optional[Union[int, str]] = None,
@@ -441,31 +440,26 @@ def tokenizer_name(self) -> str:
def _get_backend(
self,
config: Union[transformers.PretrainedConfig, transformers.AutoConfig],
backend: Optional[Literal["default", "causal", "seq2seq"]] = "default",
backend: Literal["default", "causal", "seq2seq"] = "default",
trust_remote_code: Optional[bool] = False,
) -> None:
"""
Helper method during initialization.
Determines the backend ("causal" (decoder-only) or "seq2seq" (encoder-decoder))
model type to be used.
Determines the backend ("causal" (decoder-only) or "seq2seq" (encoder-decoder)) model type to be used.
sets `self.AUTO_MODEL_CLASS` appropriately if not already set.
**If not calling HFLM.__init__() or HFLM._get_backend() within a subclass of HFLM,
user must set `self.backend` to be either "causal" or "seq2seq" manually!**
"""
# escape hatch: if we're using a subclass that shouldn't follow
# the default _get_backend logic,
# then skip over the method.
# TODO: this seems very much undesirable in some cases--our code in HFLM
# references AutoModelForCausalLM at times to check for equality
if self.AUTO_MODEL_CLASS is not None:
return

assert backend in ["default", "causal", "seq2seq"]

if backend != "default":
# if we've settled on non-default backend, use that manually
if backend == "causal":
self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
self.backend = backend
elif backend == "seq2seq":
self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
self.backend = backend
eval_logger.info(
f"Overrode HF model backend type, and using type '{backend}'"
)
@@ -478,33 +472,40 @@ def _get_backend(
# first check if model type is listed under seq2seq models, since some
# models like MBart are listed in both seq2seq and causal mistakenly in HF transformers.
# these special cases should be treated as seq2seq models.
self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
self.backend = "seq2seq"
eval_logger.info(f"Using model type '{backend}'")
elif (
getattr(self.config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
):
self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
self.backend = "causal"
eval_logger.info(f"Using model type '{backend}'")
else:
if not trust_remote_code:
eval_logger.warning(
"HF model type is neither marked as CausalLM or Seq2SeqLM. \
This is expected if your model requires `trust_remote_code=True` but may be an error otherwise."
"Setting backend to causal"
)
# if model type is neither in HF transformers causal or seq2seq model registries
# then we default to AutoModelForCausalLM
self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
# then we default to assuming AutoModelForCausalLM
self.backend = "causal"
eval_logger.info(
f"Model type cannot be determined. Using default model type '{backend}'"
)

assert self.AUTO_MODEL_CLASS in [
transformers.AutoModelForCausalLM,
transformers.AutoModelForSeq2SeqLM,
]
return None
if self.AUTO_MODEL_CLASS is None:
if self.backend == "causal":
self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
elif self.backend == "seq2seq":
self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM

def _get_config(
self,
pretrained: str,
revision: str = "main",
trust_remote_code: bool = False,
) -> None:
"""Return the model config for HuggingFace models"""
self._config = transformers.AutoConfig.from_pretrained(
pretrained,
revision=revision,
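
Restating the resolution order that the `_get_backend` hunk above implements, as a standalone sketch grounded in the diff (the real method also logs its choice and only derives `AUTO_MODEL_CLASS` from the flag when a subclass has not preset it):

from transformers.models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
)

def resolve_backend(model_type: str, backend: str = "default") -> str:
    """Sketch of the decision order: explicit override, then seq2seq, then causal."""
    if backend != "default":
        return backend  # user override always wins
    if model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        # checked first: some types (e.g. MBart) are registered under both mappings
        return "seq2seq"
    if model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return "causal"
    # unknown model type (e.g. trust_remote_code models): assume decoder-only
    return "causal"

# e.g. resolve_backend("t5") == "seq2seq"; resolve_backend("llama") == "causal"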
@@ -703,7 +704,7 @@ def _detect_batch_size(self, requests=None, pos: int = 0):
# if OOM, then halves batch_size and tries again
@find_executable_batch_size(starting_batch_size=self.max_batch_size)
def forward_batch(batch_size):
if self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
if self.backend == "seq2seq":
length = max(max_context_enc, max_cont_enc)
batched_conts = torch.ones(
(batch_size, length), device=self.device
@@ -754,7 +755,7 @@ def tok_encode(

# by default for CausalLM - false or self.add_bos_token is set
if add_special_tokens is None:
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
if self.backend == "causal":
special_tokens_kwargs = {
"add_special_tokens": False or self.add_bos_token
}
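
For the causal branch above, `False or self.add_bos_token` simply evaluates to `self.add_bos_token`, so a decoder-only model only gets special tokens (typically BOS) when `add_bos_token` was requested. Assuming the parts of the method not shown in this hunk leave seq2seq inputs to the tokenizer's defaults, the resolution looks roughly like this sketch:

def resolve_special_tokens_kwargs(backend, add_bos_token, add_special_tokens=None):
    # Sketch of the default chosen above (branches not shown in the hunk are assumed).
    if add_special_tokens is None:
        if backend == "causal":
            # decoder-only: no special tokens unless a BOS token was asked for
            return {"add_special_tokens": False or add_bos_token}
        return {}  # seq2seq: keep the tokenizer's own defaults
    # caller explicitly chose a value
    return {"add_special_tokens": add_special_tokens}

# e.g. resolve_special_tokens_kwargs("causal", add_bos_token=True)
#      == {"add_special_tokens": True}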
@@ -782,7 +783,7 @@ def tok_batch_encode(
self.tokenizer.padding_side = padding_side

add_special_tokens = {}
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
if self.backend == "causal":
add_special_tokens = {"add_special_tokens": False or self.add_bos_token}

encoding = self.tokenizer(
@@ -860,14 +861,14 @@ def _model_generate(self, context, max_length, stop, **generation_kwargs):
def _select_cont_toks(
self, logits: torch.Tensor, contlen: int = None, inplen: int = None
) -> torch.Tensor:
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
if self.backend == "causal":
assert (
contlen and inplen
), "Must pass input len and cont. len to select scored logits for causal LM"
# discard right-padding.
# also discard the input/context tokens. we'll only score continuations.
logits = logits[inplen - contlen : inplen]
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
elif self.backend == "seq2seq":
assert (
contlen and not inplen
), "Selecting scored logits for Seq2SeqLM requires only cont. len"
@@ -990,8 +991,7 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
requests,
sort_fn=_collate,
group_by="contexts"
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
and self.logits_cache
if self.backend == "causal" and self.logits_cache
else None,
group_fn=_lookup_one_token_cont,
)
@@ -1048,14 +1048,14 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
# cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice

# when too long to fit in context, truncate from the left
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
if self.backend == "causal":
inp = torch.tensor(
(context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
dtype=torch.long,
device=self.device,
)
(inplen,) = inp.shape
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
elif self.backend == "seq2seq":
inp = torch.tensor(
(context_enc)[-self.max_length :],
dtype=torch.long,
@@ -1095,11 +1095,11 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):

# create encoder attn mask and batched conts, if seq2seq
call_kwargs = {}
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
if self.backend == "causal":
batched_inps = pad_and_concat(
padding_len_inp, inps, padding_side="right"
) # [batch, padding_len_inp]
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
elif self.backend == "seq2seq":
# TODO: left-pad encoder inps and mask?
batched_inps = pad_and_concat(
padding_len_inp, inps
@@ -1130,7 +1130,7 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
# from prompt/prefix tuning tokens, if applicable
ctx_len = (
inplen + (logits.shape[0] - padding_len_inp)
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
if self.backend == "causal"
else None
)
logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
@@ -1265,10 +1265,10 @@ def _collate(req: Tuple[str, dict]):
max_gen_toks = self.max_gen_toks

# set the max length in tokens of inputs ("context_enc")
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
if self.backend == "causal":
# max len for inputs = max length, minus room to generate the max new tokens
max_ctx_len = self.max_length - max_gen_toks
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
elif self.backend == "seq2seq":
# max len for inputs = encoder's whole max_length
max_ctx_len = self.max_length
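
The branch above splits one token budget two ways: a decoder-only model has to fit both the prompt and the generated tokens inside the same window, while an encoder-decoder hands the encoder the whole window. With hypothetical numbers:

max_length, max_gen_toks = 2048, 256   # hypothetical window and generation budget

causal_max_ctx_len = max_length - max_gen_toks   # 1792 prompt tokens at most
seq2seq_max_ctx_len = max_length                 # full 2048 tokens for the encoder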

@@ -1295,7 +1295,7 @@ def _collate(req: Tuple[str, dict]):
cont_toks_list = cont.tolist()
for cont_toks, context in zip(cont_toks_list, contexts):
# discard context + left-padding toks if using causal decoder-only LM
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
if self.backend == "causal":
cont_toks = cont_toks[context_enc.shape[1] :]

s = self.tok_decode(cont_toks)
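
The slice above is needed because `generate()` on a decoder-only model returns the (possibly left-padded) prompt followed by the new tokens, whereas an encoder-decoder returns only decoder output. A minimal sketch of that post-processing step:

import torch

context_enc = torch.tensor([[11, 12, 13]])   # batched prompt ids, shape [1, 3]
generated = [11, 12, 13, 41, 42]             # decoder-only generate() echoes the prompt

backend = "causal"
if backend == "causal":
    # drop the echoed context (including any left-padding counted in context_enc)
    cont_toks = generated[context_enc.shape[1]:]
else:
    cont_toks = generated                    # seq2seq output is already continuation-only

assert cont_toks == [41, 42]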