-
Notifications
You must be signed in to change notification settings - Fork 0
/
preprocessing.py
63 lines (49 loc) · 1.9 KB
/
preprocessing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
def _agnews_labels(frame):
    """Return the rows' ``Class Index`` values minus one, as a list."""
    # Subtract 1 so labels start at 0 (assumes the CSV's ``Class Index`` is
    # 1-based, as in the common AG News release -- confirm for other data).
    return (frame["Class Index"] - 1).tolist()


def _agnews_texts(frame, use_title):
    """Return one input string per row of *frame*.

    The text is the ``Description`` column, optionally prefixed by ``Title``
    and a single space. Missing cells become empty strings so every element
    is a ``str``. (pandas reads empty CSV cells as NaN, which would otherwise
    leak into the list as a float or turn the concatenated text into NaN.)
    """
    description = frame["Description"].fillna("").astype(str)
    if not use_title:
        return description.tolist()
    title = frame["Title"].fillna("").astype(str)
    return (title + " " + description).tolist()


def preprocess_agnews(
    data_name: str,
    data_type: str = "train",
    use_agnews_title: bool = False,
    train_size: float = 0.8,
    random_state: int = 42,
    data_dir: str = "data",
):
    """Load an AG News-style CSV and return texts and 0-based labels.

    Reads ``{data_dir}/{data_name}/{data_type}.csv``. The file must contain
    the columns ``Class Index`` and ``Description``, plus ``Title`` when
    ``use_agnews_title`` is true.

    Args:
        data_name: Sub-directory of ``data_dir`` holding the CSV files.
        data_type: CSV file stem. ``"train"`` is split into train/validation;
            any other value (e.g. ``"test"``) is returned whole.
        use_agnews_title: If true, each text is ``Title + " " + Description``;
            otherwise it is ``Description`` alone.
        train_size: Fraction of rows kept for training when splitting the
            ``"train"`` file; forwarded to ``train_test_split``.
        random_state: Seed forwarded to ``train_test_split`` so the split is
            reproducible.
        data_dir: Root directory containing the dataset folders. Defaults to
            ``"data"``, which is relative to the current working directory.

    Returns:
        For ``data_type == "train"``: ``(train_text, train_label, val_text,
        val_label)``. Otherwise: ``(test_text, test_label)``. Texts are lists
        of ``str``; labels are lists of ``Class Index - 1``.
    """
    df = pd.read_csv(f"{data_dir}/{data_name}/{data_type}.csv")

    if data_type != "train":
        return _agnews_texts(df, use_agnews_title), _agnews_labels(df)

    # Only the training file is split; the validation set is carved out of it.
    train_df, val_df = train_test_split(
        df, train_size=train_size, random_state=random_state
    )
    return (
        _agnews_texts(train_df, use_agnews_title),
        _agnews_labels(train_df),
        _agnews_texts(val_df, use_agnews_title),
        _agnews_labels(val_df),
    )
class AGNewsDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing per-example encoded fields with labels.

    ``encodings`` is a mapping of field name to a sequence indexable by
    example position. This is presumably tokenizer output such as
    ``input_ids`` / ``attention_mask``; confirm against callers.
    ``labels`` is an indexable sequence whose length defines the dataset size.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        # Dataset size is taken from the labels, not from the encodings.
        return len(self.labels)

    def __getitem__(self, idx):
        """Return a dict of tensors for example *idx*, including ``labels``."""
        sample = {}
        for field, values in self.encodings.items():
            sample[field] = torch.tensor(values[idx])
        # Written last, so it replaces any encoding field also named 'labels'.
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample
if __name__ == "__main__":
    # Smoke test: load the AG News test split with titles prepended.
    texts, labels = preprocess_agnews(
        data_name="agnews",
        data_type="test",
        use_agnews_title=True,
    )