Skip to content

Commit

Permalink
Add cleaner way to load Data.from_line_format without TagReader, allo…
Browse files Browse the repository at this point in the history
…wing to remove unused data processing code in the future
  • Loading branch information
chiayewken committed Mar 28, 2022
1 parent e008103 commit 49bf770
Showing 1 changed file with 56 additions and 10 deletions.
66 changes: 56 additions & 10 deletions aste/data_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ast
import copy
import json
import os
Expand Down Expand Up @@ -197,7 +198,27 @@ def as_text(self) -> str:
tokens[t.t_end] = tokens[t.t_end] + "]"
return " ".join(tokens)

def as_line_format(self) -> str:
@classmethod
def from_line_format(cls, text: str):
front, back = text.split("#### #### ####")
tokens = front.split(" ")
triples = []

for a, b, label in ast.literal_eval(back):
t = SentimentTriple(
t_start=a[0],
t_end=a[0] if len(a) == 1 else a[-1],
o_start=b[0],
o_end=b[0] if len(b) == 1 else b[-1],
label=label,
)
triples.append(t)

return cls(
tokens=tokens, triples=triples, id=0, pos=[], weight=1, is_labeled=True
)

def to_line_format(self) -> str:
# ([1], [4], 'POS')
# ([1,2], [4], 'POS')
triplets = []
Expand All @@ -210,7 +231,11 @@ def as_line_format(self) -> str:
parts.append([start, end])
parts.append(f"{t.label}")
triplets.append(tuple(parts))
return " ".join(self.tokens) + "#### #### ####" + str(triplets) + "\n"

line = " ".join(self.tokens) + "#### #### ####" + str(triplets) + "\n"
assert self.from_line_format(line).tokens == self.tokens
assert self.from_line_format(line).triples == self.triples
return line


class Data(BaseModel):
Expand All @@ -227,13 +252,9 @@ def load(self):
path = self.root / f"{self.data_split}.txt"
if self.full_path:
path = self.full_path
instances = TagReader.read_inst(
file=path,
is_labeled=self.is_labeled,
number=self.num_instances,
opinion_offset=self.opinion_offset,
)
self.sentences = [Sentence.from_instance(x) for x in instances]

with open(path) as f:
self.sentences = [Sentence.from_line_format(line) for line in f]

@classmethod
def load_from_full_path(cls, path: str):
Expand All @@ -246,7 +267,7 @@ def save_to_path(self, path: str):
Path(path).parent.mkdir(exist_ok=True, parents=True)
with open(path, "w") as f:
for s in self.sentences:
f.write(s.as_line_format())
f.write(s.to_line_format())

data = Data.load_from_full_path(path)
assert data.sentences is not None
Expand Down Expand Up @@ -432,7 +453,32 @@ def analyze(self):
print("#" * 80)


def test_from_line_format(path: str = "aste/data/triplet_data/14lap/train.txt"):
print("\nCompare old TagReader with new Sentence.from_line_format")
instances = TagReader.read_inst(
file=path,
is_labeled=False,
number=-1,
opinion_offset=3,
)
a = Data(
root=Path(),
data_split=SplitEnum.test,
sentences=[Sentence.from_instance(x) for x in instances],
)

assert a.sentences is not None
with open(path) as f:
for i, line in enumerate(f):
s = Sentence.from_line_format(line)
assert s.tokens == a.sentences[i].tokens
set_a = set(t.json() for t in a.sentences[i].triples)
set_b = set(t.json() for t in s.triples)
assert set_a == set_b


def test_save_to_path(path: str = "aste/data/triplet_data/14lap/train.txt"):
print("\nEnsure that Data.save_to_path works properly")
path_temp = "temp.txt"
data = Data.load_from_full_path(path)
data.save_to_path(path_temp)
Expand Down

0 comments on commit 49bf770

Please sign in to comment.