Skip to content

Commit

Permalink
isort/black
Browse files Browse the repository at this point in the history
  • Loading branch information
nicola-decao committed Jun 7, 2022
1 parent 9a863bf commit c04fbeb
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 9 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,9 @@ After importing and loading the model and a prefix tree (trie), you would genera

```python
import pickle
from genre.trie import Trie

from genre.fairseq_model import GENRE
from genre.trie import Trie

# load the prefix tree (trie)
with open("../data/kilt_titles_trie_dict.pkl", "rb") as f:
Expand Down Expand Up @@ -97,8 +98,9 @@ Making predictions with mGENRE is very similar, but we additionally need to map

```python
import pickle
from genre.trie import Trie, MarisaTrie

from genre.fairseq_model import mGENRE
from genre.trie import MarisaTrie, Trie

with open("../data/lang_title2wikidataID-normalized_with_redirect.pkl", "rb") as f:
lang_title2wikidataID = pickle.load(f)
Expand Down
2 changes: 1 addition & 1 deletion genre/fairseq_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def encode(self, sentence) -> torch.LongTensor:
else:
return tokens


class GENRE(BARTModel):
@classmethod
def from_pretrained(
Expand Down
3 changes: 2 additions & 1 deletion genre/hf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
from typing import Dict, List

import torch
from genre.utils import chunk_it
from transformers import BartForConditionalGeneration, BartTokenizer

from genre.utils import chunk_it

logger = logging.getLogger(__name__)


Expand Down
2 changes: 1 addition & 1 deletion requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
torch
requests
pytest
pytest
9 changes: 8 additions & 1 deletion scripts_genre/convert_kilt_to_fairseq.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,14 @@ def convert_kilt_to_fairseq(dataset):
for prov in out["provenance"]
if prov.get("bleu_score", 1) > 0.5
):
source.append(create_input(doc, max_length=384, start_delimiter="[START_ENT]", end_delimiter="[END_ENT]"))
source.append(
create_input(
doc,
max_length=384,
start_delimiter="[START_ENT]",
end_delimiter="[END_ENT]",
)
)
target.append(title)
if "meta" in doc and "template_questions" in doc["meta"]:
for template_question in doc["meta"]["template_questions"]:
Expand Down
4 changes: 2 additions & 2 deletions scripts_mgenre/preprocess_sentencepiece.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ def initializer(self):
sp = spm.SentencePieceProcessor(model_file=self.args.model)
old2new = None
if self.args.product_vocab_size is not None:
assert sp.vocab_size() <= self.args.product_vocab_size ** 2
assert sp.vocab_size() <= self.args.product_vocab_size**2
rand = random.Random(self.args.seed)
old2new = [
(x // self.args.product_vocab_size, x % self.args.product_vocab_size)
for x in rand.sample(
range(self.args.product_vocab_size ** 2), sp.vocab_size()
range(self.args.product_vocab_size**2), sp.vocab_size()
)
]

Expand Down
2 changes: 1 addition & 1 deletion tests/test_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

import pytest

from genre.trie import Trie
from genre.fairseq_model import GENRE, GENREHubInterface
from genre.trie import Trie


@pytest.fixture(scope="session")
Expand Down

0 comments on commit c04fbeb

Please sign in to comment.