diff --git a/bert_score/scorer.py b/bert_score/scorer.py index 48e7054..9c1b2f1 100644 --- a/bert_score/scorer.py +++ b/bert_score/scorer.py @@ -88,7 +88,11 @@ def __init__( self._model_type = model_type if num_layers is None: - self._num_layers = model2layers[self.model_type] + if '/' in model_type: + real_model_type = model_type.split('/')[-1] + self._num_layers = model2layers[real_model_type] + else: + self._num_layers = model2layers[self.model_type] else: self._num_layers = num_layers @@ -106,10 +110,16 @@ def __init__( self.baseline_path = baseline_path self.use_custom_baseline = self.baseline_path is not None if self.baseline_path is None: - self.baseline_path = os.path.join( - os.path.dirname(__file__), - f"rescale_baseline/{self.lang}/{self.model_type}.tsv", - ) + if '/' in model_type: + self.baseline_path = os.path.join( + os.path.dirname(__file__), + f"rescale_baseline/{self.lang}/{real_model_type}.tsv", + ) + else: + self.baseline_path = os.path.join( + os.path.dirname(__file__), + f"rescale_baseline/{self.lang}/{self.model_type}.tsv", + ) @property def lang(self):