
Commit

use lists for dataset get item
SpirinEgor committed Sep 8, 2021
1 parent e678f8b commit 84a8d01
Showing 4 changed files with 43 additions and 35 deletions.
code2seq/data/path_context.py (40 changes: 27 additions & 13 deletions)
@@ -1,37 +1,47 @@
 from dataclasses import dataclass
-from typing import Iterable, Tuple, Optional, Sequence
+from typing import Iterable, Tuple, Optional, Sequence, List, cast

 import torch


 @dataclass
 class Path:
-    from_token: torch.Tensor  # [max token parts]
-    path_node: torch.Tensor  # [path length]
-    to_token: torch.Tensor  # [max token parts]
+    from_token: List[int]  # [max token parts]
+    path_node: List[int]  # [path length]
+    to_token: List[int]  # [max token parts]


 @dataclass
 class LabeledPathContext:
-    label: torch.Tensor  # [max label parts]
+    label: List[int]  # [max label parts]
     path_contexts: Sequence[Path]


+def transpose(list_of_lists: List[List[int]]) -> List[List[int]]:
+    return [cast(List[int], it) for it in zip(*list_of_lists)]
+
+
 class BatchedLabeledPathContext:
     def __init__(self, all_samples: Sequence[Optional[LabeledPathContext]]):
         samples = [s for s in all_samples if s is not None]

         # [max label parts; batch size]
-        self.labels = torch.cat([s.label.unsqueeze(1) for s in samples], dim=1)
+        self.labels = torch.tensor(transpose([s.label for s in samples]), dtype=torch.long)
         # [batch size]
         self.contexts_per_label = torch.tensor([len(s.path_contexts) for s in samples])

         # [max token parts; n contexts]
-        self.from_token = torch.cat([path.from_token.unsqueeze(1) for s in samples for path in s.path_contexts], dim=1)
+        self.from_token = torch.tensor(
+            transpose([path.from_token for s in samples for path in s.path_contexts]), dtype=torch.long
+        )
         # [path length; n contexts]
-        self.path_nodes = torch.cat([path.path_node.unsqueeze(1) for s in samples for path in s.path_contexts], dim=1)
+        self.path_nodes = torch.tensor(
+            transpose([path.path_node for s in samples for path in s.path_contexts]), dtype=torch.long
+        )
         # [max token parts; n contexts]
-        self.to_token = torch.cat([path.to_token.unsqueeze(1) for s in samples for path in s.path_contexts], dim=1)
+        self.to_token = torch.tensor(
+            transpose([path.to_token for s in samples for path in s.path_contexts]), dtype=torch.long
+        )

     def __len__(self) -> int:
         return len(self.contexts_per_label)

@@ -53,8 +63,8 @@ def move_to_device(self, device: torch.device):

 @dataclass
 class TypedPath(Path):
-    from_type: torch.Tensor  # [max type parts]
-    to_type: torch.Tensor  # [max type parts]
+    from_type: List[int]  # [max type parts]
+    to_type: List[int]  # [max type parts]


 @dataclass

@@ -67,6 +77,10 @@ def __init__(self, all_samples: Sequence[Optional[LabeledTypedPathContext]]):
         super().__init__(all_samples)
         samples = [s for s in all_samples if s is not None]
         # [max type parts; n contexts]
-        self.from_type = torch.cat([path.from_type.unsqueeze(1) for s in samples for path in s.path_contexts], dim=1)
+        self.from_type = torch.tensor(
+            transpose([path.from_type for s in samples for path in s.path_contexts]), dtype=torch.long
+        )
         # [max type parts; n contexts]
-        self.to_type = torch.cat([path.to_type.unsqueeze(1) for s in samples for path in s.path_contexts], dim=1)
+        self.to_type = torch.tensor(
+            transpose([path.to_type for s in samples for path in s.path_contexts]), dtype=torch.long
+        )
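
The key change in this file is the new transpose helper: every per-sample field is now an equal-length List[int], so regrouping the lists by position and handing the result to torch.tensor reproduces the [max parts; batch size] layout that torch.cat previously built from per-sample tensors. A minimal standalone sketch of that idea (toy ids, not the repository's data):

from typing import List, cast

import torch


def transpose(list_of_lists: List[List[int]]) -> List[List[int]]:
    # Regroup by position: row i of the output collects element i of every input list.
    return [cast(List[int], it) for it in zip(*list_of_lists)]


# Three toy "labels", each already padded to the same length (max label parts = 4).
labels = [[2, 4, 3, 0], [2, 5, 3, 0], [2, 1, 3, 0]]

batched = torch.tensor(transpose(labels), dtype=torch.long)
print(batched.shape)  # torch.Size([4, 3]), i.e. [max label parts; batch size]
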
code2seq/data/path_context_dataset.py (24 changes: 9 additions & 15 deletions)
@@ -2,7 +2,6 @@
 from random import shuffle
 from typing import Dict, List, Optional

-import torch
 from commode_utils.filesystem import get_lines_offsets, get_line_by_offset
 from omegaconf import DictConfig
 from torch.utils.data import Dataset

@@ -63,34 +62,29 @@ def __getitem__(self, index) -> Optional[LabeledPathContext]:
         return LabeledPathContext(label, paths)

     @staticmethod
-    def tokenize_class(raw_class: str, vocab: Dict[str, int]) -> torch.Tensor:
-        return torch.tensor([vocab[raw_class]], dtype=torch.long)
+    def tokenize_class(raw_class: str, vocab: Dict[str, int]) -> List[int]:
+        return [vocab[raw_class]]

     @staticmethod
-    def tokenize_label(raw_label: str, vocab: Dict[str, int], max_parts: Optional[int]) -> torch.Tensor:
+    def tokenize_label(raw_label: str, vocab: Dict[str, int], max_parts: Optional[int]) -> List[int]:
         sublabels = raw_label.split(PathContextDataset._separator)
         max_parts = max_parts or len(sublabels)
         label_unk = vocab[Vocabulary.UNK]

-        label = torch.full((max_parts + 1,), vocab[Vocabulary.PAD], dtype=torch.long)
-        label[0] = vocab[Vocabulary.SOS]
-        sub_tokens_ids = [vocab.get(st, label_unk) for st in sublabels[:max_parts]]
-        label[1 : len(sub_tokens_ids) + 1] = torch.tensor(sub_tokens_ids)
-
+        label = [vocab[Vocabulary.SOS]] + [vocab.get(st, label_unk) for st in sublabels[:max_parts]]
         if len(sublabels) < max_parts:
-            label[len(sublabels) + 1] = vocab[Vocabulary.EOS]
-
+            label.append(vocab[Vocabulary.EOS])
+        label += [vocab[Vocabulary.PAD]] * (max_parts + 1 - len(label))
         return label

     @staticmethod
-    def tokenize_token(token: str, vocab: Dict[str, int], max_parts: Optional[int]) -> torch.Tensor:
+    def tokenize_token(token: str, vocab: Dict[str, int], max_parts: Optional[int]) -> List[int]:
         sub_tokens = token.split(PathContextDataset._separator)
         max_parts = max_parts or len(sub_tokens)
         token_unk = vocab[Vocabulary.UNK]

-        result = torch.full((max_parts,), vocab[Vocabulary.PAD], dtype=torch.long)
-        sub_tokens_ids = [vocab.get(st, token_unk) for st in sub_tokens[:max_parts]]
-        result[: len(sub_tokens_ids)] = torch.tensor(sub_tokens_ids)
+        result = [vocab.get(st, token_unk) for st in sub_tokens[:max_parts]]
+        result += [vocab[Vocabulary.PAD]] * (max_parts - len(result))
         return result

     def _get_path(self, raw_path: List[str]) -> Path:
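
With tensors gone from the tokenizers, padding and the special tokens are handled with ordinary list operations. A small worked example of the list-based scheme, using a toy vocabulary whose ids mirror the ones assumed by the tests below (the string keys and this standalone helper are illustrative, not the repository's API):

from typing import Dict, List, Optional

# Toy vocabulary; ids 0-3 follow the PAD/UNK/SOS/EOS layout implied by the tests.
vocab: Dict[str, int] = {"<PAD>": 0, "<UNK>": 1, "<SOS>": 2, "<EOS>": 3, "my": 4, "super": 5}


def tokenize_label(raw_label: str, vocab: Dict[str, int], max_parts: Optional[int]) -> List[int]:
    # Same shape of logic as the patched dataset method: SOS, sub-token ids,
    # EOS only if the label is shorter than max_parts, then PAD up to max_parts + 1.
    sublabels = raw_label.split("|")
    max_parts = max_parts or len(sublabels)
    unk = vocab["<UNK>"]
    label = [vocab["<SOS>"]] + [vocab.get(st, unk) for st in sublabels[:max_parts]]
    if len(sublabels) < max_parts:
        label.append(vocab["<EOS>"])
    label += [vocab["<PAD>"]] * (max_parts + 1 - len(label))
    return label


print(tokenize_label("my|super|label", vocab, 5))  # [2, 4, 5, 1, 3, 0]

For "my|super|label" with max_parts=5 this yields [2, 4, 5, 1, 3, 0], i.e. <SOS> my super <UNK> <EOS> <PAD>, matching the updated test expectation.
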
setup.py (2 changes: 1 addition & 1 deletion)
@@ -1,6 +1,6 @@
 from setuptools import setup, find_packages

-VERSION = "1.0.0"
+VERSION = "1.0.1"

 with open("README.md") as readme_file:
     readme = readme_file.read()
tests/test_tokenization.py (12 changes: 6 additions & 6 deletions)
@@ -13,23 +13,23 @@ def test_tokenize_label(self):
         raw_label = "my|super|label"
         tokenized = PathContextDataset.tokenize_label(raw_label, self.vocab, 5)
         # <SOS> my super <UNK> <EOS> <PAD>
-        correct = torch.tensor([2, 4, 5, 1, 3, 0], dtype=torch.long)
+        correct = [2, 4, 5, 1, 3, 0]

-        torch.testing.assert_equal(tokenized, correct)
+        self.assertListEqual(tokenized, correct)

     def test_tokenize_class(self):
         raw_class = "super"
         tokenized = PathContextDataset.tokenize_class(raw_class, self.vocab)
-        correct = torch.tensor([5], dtype=torch.long)
+        correct = [5]

-        torch.testing.assert_equal(tokenized, correct)
+        self.assertListEqual(tokenized, correct)

     def test_tokenize_token(self):
         raw_token = "my|super|token"
         tokenized = PathContextDataset.tokenize_token(raw_token, self.vocab, 5)
-        correct = torch.tensor([4, 5, 1, 0, 0], dtype=torch.long)
+        correct = [4, 5, 1, 0, 0]

-        torch.testing.assert_equal(tokenized, correct)
+        self.assertListEqual(tokenized, correct)


 if __name__ == "__main__":
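
Since __getitem__ now returns plain lists, tensor construction happens once per batch inside BatchedLabeledPathContext. A hedged sketch of how that could be wired into a standard DataLoader via collate_fn; the toy dataset class is illustrative, and the project's real data module may set this up differently:

from torch.utils.data import DataLoader, Dataset

from code2seq.data.path_context import BatchedLabeledPathContext, LabeledPathContext, Path


class ToyPathContextDataset(Dataset):
    # Stand-in for PathContextDataset: __getitem__ hands back list-valued samples only.
    def __len__(self) -> int:
        return 4

    def __getitem__(self, index: int) -> LabeledPathContext:
        path = Path(from_token=[4, 0], path_node=[7, 8, 9], to_token=[5, 0])
        return LabeledPathContext(label=[2, 4, 3, 0], path_contexts=[path])


# The batching class consumes the list of samples and builds the batched tensors itself.
loader = DataLoader(ToyPathContextDataset(), batch_size=2, collate_fn=BatchedLabeledPathContext)

for batch in loader:
    print(batch.labels.shape)        # torch.Size([4, 2]) -> [max label parts; batch size]
    print(batch.from_token.shape)    # torch.Size([2, 2]) -> [max token parts; n contexts]
    print(batch.contexts_per_label)  # tensor([1, 1])
    break

Building each batched tensor from whole Python lists in the collate step avoids creating and concatenating many small per-sample tensors, which is presumably the motivation for this change.
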
