dataset.py (forked from infocusp/varformers)

from typing import Any

import torch
from torch.utils.data import Dataset

class SentimentDataset(Dataset):
    """Sentiment-classification dataset over a pandas-style DataFrame with
    "text" and "sentiment" columns. Each item is the tokenized text with a
    [CLS] token prepended; padding is deferred to collate_fn."""

    def __init__(self, ds, tokenizer) -> None:
        super().__init__()
        self.ds = ds
        self.tokenizer = tokenizer
        # Special-token ids looked up from the tokenizer's vocabulary.
        self.cls_token = torch.tensor([tokenizer.token_to_id('[CLS]')], dtype=torch.int64)
        self.pad_token = torch.tensor([tokenizer.token_to_id('[PAD]')], dtype=torch.int64)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, index) -> Any:
        text = self.ds.loc[index, "text"]
        label = self.ds.loc[index, "sentiment"]
        text_enc = self.tokenizer.encode(text).ids
        # [CLS] followed by the token ids of the text.
        encoder_input = torch.cat(
            [
                self.cls_token,
                torch.tensor(text_enc, dtype=torch.int64),
            ]
        )
        label = torch.tensor(label, dtype=torch.int64)
        return {
            "encoder_input": encoder_input,
            # "decoder_input": decoder_input,
            # "encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(),  # (1, 1, seq_len)
            # "decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).int() & causal_mask(decoder_input.size(0)),  # (1, seq_len) & (1, seq_len, seq_len)
            "label": label,
            # "src_text": src_text,
            # "tgt_text": tgt_text
        }
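

# Illustrative sketch (an assumption, not taken from this repo): one way to build a
# tokenizer whose token_to_id()/encode().ids interface matches the calls above, using
# the HuggingFace `tokenizers` library. With the special-token order below, [PAD]
# would receive id 1, which lines up with the pad value hard-coded in collate_fn.
def build_word_level_tokenizer(texts):
    """Trains a whitespace word-level tokenizer with [UNK]/[PAD]/[CLS] specials."""
    from tokenizers import Tokenizer
    from tokenizers.models import WordLevel
    from tokenizers.pre_tokenizers import Whitespace
    from tokenizers.trainers import WordLevelTrainer

    tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = Whitespace()
    trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[CLS]"])
    tokenizer.train_from_iterator(texts, trainer=trainer)
    return tokenizer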


def collate_fn(batch):
    """Truncates/pads every sequence in the batch to a common length and
    builds a padding mask alongside the stacked inputs and labels."""
    # Restrict the max seq_len in a batch to 500.
    max_len = min(max(item["encoder_input"].size(dim=0) for item in batch), 500)
    for item in batch:
        item["encoder_input"] = item["encoder_input"][:max_len]
        # Pad with token id 1, which is assumed to be the [PAD] id
        # (the dataset's self.pad_token holds the actual id).
        item["encoder_input"] = torch.cat(
            [
                item["encoder_input"],
                torch.ones(max_len - item["encoder_input"].size(dim=0), dtype=torch.int64),
            ]
        )
    encoder_input = torch.stack([item["encoder_input"] for item in batch])  # (batch, max_len)
    label = torch.stack([item["label"] for item in batch]).unsqueeze(-1)    # (batch, 1)
    # Boolean mask, True at padded positions: (batch, max_len, 1).
    mask = (encoder_input == 1).unsqueeze(-1)
    return encoder_input, label.float(), mask
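

# Illustrative usage sketch under assumed toy data (the real training pipeline lives
# elsewhere in the repo): wiring the dataset and collate_fn into a DataLoader and
# inspecting the shapes of one batch.
if __name__ == "__main__":
    import pandas as pd
    from torch.utils.data import DataLoader

    df = pd.DataFrame(
        {
            "text": ["a fine movie", "a dull movie", "great fun"],
            "sentiment": [1, 0, 1],
        }
    )
    tokenizer = build_word_level_tokenizer(df["text"])
    dataset = SentimentDataset(df, tokenizer)
    loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)

    encoder_input, label, mask = next(iter(loader))
    print(encoder_input.shape)  # (2, max_len) int64 token ids
    print(label.shape)          # (2, 1) float labels
    print(mask.shape)           # (2, max_len, 1), True at padded positions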