Add SpanModel with scikit-learn-style methods for easy usage (fit, predict, score)

chiayewken committed Mar 31, 2022
1 parent 49bf770 commit 7cbf035
Showing 5 changed files with 354 additions and 0 deletions.
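
For context, a minimal usage sketch of the new scikit-learn-style interface introduced in this commit; the save_dir, seed, and data paths below are illustrative placeholders, not part of the diff:

    from wrapper import SpanModel

    # Hypothetical paths following the repo's triplet data layout.
    model = SpanModel(save_dir="outputs/14res", random_seed=0)
    model.fit(path_train="aste/data/triplet_data/14res/train.txt",
              path_dev="aste/data/triplet_data/14res/dev.txt")
    model.predict(path_in="aste/data/triplet_data/14res/test.txt",
                  path_out="outputs/14res/pred.txt")
    print(model.score(path_pred="outputs/14res/pred.txt",
                      path_gold="aste/data/triplet_data/14res/test.txt"))
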
1 change: 1 addition & 0 deletions .gitignore
@@ -10,3 +10,4 @@ wordnet.zip
aste/data/
models/
model_outputs/
outputs/
4 changes: 4 additions & 0 deletions aste/data_utils.py
@@ -54,6 +54,10 @@ class SentimentTriple(BaseModel):
    t_end: int
    label: LabelEnum

    @classmethod
    def make_dummy(cls):
        # Placeholder triple for when the span model requires non-empty triples.
        return cls(o_start=0, o_end=0, t_start=0, t_end=0, label=LabelEnum.neutral)

    @property
    def opinion(self) -> Tuple[int, int]:
        return self.o_start, self.o_end
29 changes: 29 additions & 0 deletions aste/utils.py
@@ -142,5 +142,34 @@ def clean_up_many(pattern: str = "data/triplet_data/*/*.txt"):
        clean_up_triplet_data(str(path))


def merge_data(
    folders_in: List[str] = [
        "aste/data/triplet_data/14res/",
        "aste/data/triplet_data/15res/",
        "aste/data/triplet_data/16res/",
    ],
    folder_out: str = "aste/data/triplet_data/res_all/",
):
    # Concatenate each split across the input folders into one combined dataset.
    for name in ["train.txt", "dev.txt", "test.txt"]:
        outputs = []
        for folder in folders_in:
            path = Path(folder) / name
            with open(path) as f:
                for line in f:
                    assert line.endswith("\n")
                    outputs.append(line)

        path_out = Path(folder_out) / name
        path_out.parent.mkdir(exist_ok=True, parents=True)
        with open(path_out, "w") as f:
            f.write("".join(outputs))


def safe_divide(a: float, b: float) -> float:
    # Guard against division by zero (and short-circuit zero numerators).
    if a == 0 or b == 0:
        return 0
    return a / b


if __name__ == "__main__":
    Fire()
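
Since the module exposes its functions through Fire, merge_data can presumably be invoked with its defaults, which concatenate the 14res/15res/16res splits into res_all; a sketch, assuming the default data folders exist:

    from utils import merge_data, safe_divide

    merge_data()  # writes combined train/dev/test files under res_all/
    assert safe_divide(1.0, 0.0) == 0  # zero denominator yields 0, not an error
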
168 changes: 168 additions & 0 deletions aste/wrapper.py
@@ -0,0 +1,168 @@
import json
import os
from pathlib import Path
from typing import List

import _jsonnet
from fire import Fire
from pydantic import BaseModel
from tqdm import tqdm

from data_utils import Data, SentimentTriple, SplitEnum
from main import SpanModelData, SpanModelPrediction
from utils import Shell, safe_divide


class SpanModel(BaseModel):
    save_dir: str
    random_seed: int
    path_config_base: str = "training_config/config.jsonnet"

    def save_temp_data(self, path_in: str, name: str, is_test: bool = False) -> Path:
        path_temp = Path(self.save_dir) / "temp_data" / f"{name}.json"
        path_temp = path_temp.resolve()
        path_temp.parent.mkdir(exist_ok=True, parents=True)
        data = Data.load_from_full_path(path_in)

        if is_test:
            # The span model raises an error if s.triples is an empty list
            assert data.sentences is not None
            for s in data.sentences:
                s.triples = [SentimentTriple.make_dummy()]

        span_data = SpanModelData.from_data(data)
        span_data.dump(path_temp)
        return path_temp

    def fit(self, path_train: str, path_dev: str):
        weights_dir = Path(self.save_dir) / "weights"
        weights_dir.mkdir(exist_ok=True, parents=True)
        print(dict(weights_dir=weights_dir))
        path_config = Path(self.save_dir) / "config.jsonnet"
        config = json.loads(_jsonnet.evaluate_file(self.path_config_base))

        # Override all three seed fields in the base config.
        for key in ["random_seed", "pytorch_seed", "numpy_seed"]:
            assert key in config.keys()
            config[key] = self.random_seed
            print({key: self.random_seed})

        # Point the config at converted copies of the train/dev data.
        for name, path in dict(
            train=path_train, validation=path_dev, test=path_dev
        ).items():
            key = f"{name}_data_path"
            assert key in config.keys()
            path_temp = self.save_temp_data(path, name)
            config[key] = str(path_temp)
            print({key: path_temp})

        with open(path_config, "w") as f:
            f.write(json.dumps(config, indent=2))
        print(dict(path_config=path_config))

        shell = Shell()
        work_dir = Path(".").resolve()
        shell.run(
            f"cd {work_dir} && allennlp train {path_config}",
            serialization_dir=str(weights_dir),
            include_package="span_model",
        )

    def predict(self, path_in: str, path_out: str):
        work_dir = Path(".").resolve()
        path_model = Path(self.save_dir) / "weights" / "model.tar.gz"
        path_temp_in = self.save_temp_data(path_in, "pred_in", is_test=True)
        path_temp_out = Path(self.save_dir) / "temp_data" / "pred_out.json"
        if path_temp_out.exists():
            os.remove(path_temp_out)

        shell = Shell()
        shell.run(
            f"cd {work_dir} && allennlp predict {path_model}",
            str(path_temp_in),
            predictor="span_model",
            include_package="span_model",
            use_dataset_reader="",
            output_file=str(path_temp_out),
            cuda_device=0,
            silent="",
        )

        # Parse the JSON-lines predictions back into the repo's data format.
        with open(path_temp_out) as f:
            preds = [SpanModelPrediction(**json.loads(line.strip())) for line in f]
        data = Data(
            root=Path(),
            data_split=SplitEnum.test,
            sentences=[p.to_sentence() for p in preds],
        )
        data.save_to_path(path_out)

    def score(self, path_pred: str, path_gold: str) -> dict:
        pred = Data.load_from_full_path(path_pred)
        gold = Data.load_from_full_path(path_gold)
        assert pred.sentences is not None
        assert gold.sentences is not None
        assert len(pred.sentences) == len(gold.sentences)
        num_pred = 0
        num_gold = 0
        num_correct = 0

        # A predicted triple counts as correct only on an exact match with gold.
        for i in range(len(gold.sentences)):
            num_pred += len(pred.sentences[i].triples)
            num_gold += len(gold.sentences[i].triples)
            for p in pred.sentences[i].triples:
                for g in gold.sentences[i].triples:
                    if p.dict() == g.dict():
                        num_correct += 1

        precision = safe_divide(num_correct, num_pred)
        recall = safe_divide(num_correct, num_gold)

        info = dict(
            path_pred=path_pred,
            path_gold=path_gold,
            precision=precision,
            recall=recall,
            score=safe_divide(2 * precision * recall, precision + recall),
        )
        return info


def run_train(path_train: str, path_dev: str, save_dir: str, random_seed: int):
    print(dict(run_train=locals()))
    if Path(save_dir).exists():
        return

    model = SpanModel(save_dir=save_dir, random_seed=random_seed)
    model.fit(path_train, path_dev)


def run_train_many(save_dir_template: str, random_seeds: List[int], **kwargs):
    for seed in tqdm(random_seeds):
        save_dir = save_dir_template.format(seed)
        run_train(save_dir=save_dir, random_seed=seed, **kwargs)


def run_eval(path_test: str, save_dir: str):
    print(dict(run_eval=locals()))
    model = SpanModel(save_dir=save_dir, random_seed=0)
    path_pred = str(Path(save_dir) / "pred.txt")
    model.predict(path_test, path_pred)
    results = model.score(path_pred, path_test)
    print(results)
    return results


def run_eval_many(save_dir_template: str, random_seeds: List[int], **kwargs):
    results = []
    for seed in tqdm(random_seeds):
        save_dir = save_dir_template.format(seed)
        results.append(run_eval(save_dir=save_dir, **kwargs))

    # Average precision/recall across seeds, then compute F1 from the averages.
    precision = sum(r["precision"] for r in results) / len(random_seeds)
    recall = sum(r["recall"] for r in results) / len(random_seeds)
    score = safe_divide(2 * precision * recall, precision + recall)
    print(dict(precision=precision, recall=recall, score=score))


if __name__ == "__main__":
    Fire()
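
Since the module ends with Fire(), these helpers double as a CLI; in Python, a multi-seed sweep could presumably be driven like this (the template and paths are illustrative):

    from wrapper import run_train_many, run_eval_many

    # Hypothetical sweep; "{}" in the template is filled with each seed.
    seeds = [0, 1, 2]
    run_train_many("outputs/14res_seed{}", seeds,
                   path_train="aste/data/triplet_data/14res/train.txt",
                   path_dev="aste/data/triplet_data/14res/dev.txt")
    run_eval_many("outputs/14res_seed{}", seeds,
                  path_test="aste/data/triplet_data/14res/test.txt")
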
152 changes: 152 additions & 0 deletions training_config/config.jsonnet
@@ -0,0 +1,152 @@
{
  "data_loader": {
    "sampler": {
      "type": "random"
    }
  },
  "dataset_reader": {
    "max_span_width": 8,
    "token_indexers": {
      "bert": {
        "max_length": 512,
        "model_name": "bert-base-uncased",
        "type": "pretrained_transformer_mismatched"
      }
    },
    "type": "span_model"
  },
  "model": {
    "embedder": {
      "token_embedders": {
        "bert": {
          "max_length": 512,
          "model_name": "bert-base-uncased",
          "type": "pretrained_transformer_mismatched"
        }
      }
    },
    "feature_size": 20,
    "feedforward_params": {
      "dropout": 0.4,
      "hidden_dims": 150,
      "num_layers": 2
    },
    "initializer": {
      "regexes": [
        [
          "_span_width_embedding.weight",
          {
            "type": "xavier_normal"
          }
        ]
      ]
    },
    "loss_weights": {
      "ner": 1.0,
      "relation": 1
    },
    "max_span_width": 8,
    "module_initializer": {
      "regexes": [
        [
          ".*weight",
          {
            "type": "xavier_normal"
          }
        ],
        [
          ".*weight_matrix",
          {
            "type": "xavier_normal"
          }
        ]
      ]
    },
    "modules": {
      "ner": {
        "focal_loss_gamma": 2,
        "neg_class_weight": -1,
        "use_bi_affine": false,
        "use_double_scorer": false,
        "use_focal_loss": false,
        "use_gold_for_train_prune_scores": false,
        "use_single_pool": false
      },
      "relation": {
        "focal_loss_gamma": 2,
        "neg_class_weight": -1,
        "span_length_loss_weight_gamma": 0,
        "spans_per_word": 0.5,
        "use_bag_pair_scorer": false,
        "use_bi_affine_classifier": false,
        "use_bi_affine_pruner": false,
        "use_bi_affine_v2": false,
        "use_classify_mask_pruner": false,
        "use_distance_embeds": true,
        "use_focal_loss": false,
        "use_ner_scores_for_prune": false,
        "use_ope_down_project": false,
        "use_pair_feature_cls": false,
        "use_pair_feature_maxpool": false,
        "use_pair_feature_multiply": false,
        "use_pairwise_down_project": false,
        "use_pruning": true,
        "use_single_pool": false,
        "use_span_loss_for_pruners": false,
        "use_span_pair_aux_task": false,
        "use_span_pair_aux_task_after_prune": false
      }
    },
    "relation_head_type": "proper",
    "span_extractor_type": "endpoint",
    "target_task": "relation",
    "type": "span_model",
    "use_bilstm_after_embedder": false,
    "use_double_mix_embedder": false,
    "use_ner_embeds": false,
    "use_span_width_embeds": true
  },
  "trainer": {
    "checkpointer": {
      "num_serialized_models_to_keep": 1
    },
    "cuda_device": 0,
    "grad_norm": 5,
    "learning_rate_scheduler": {
      "type": "slanted_triangular"
    },
    "num_epochs": 10,
    "optimizer": {
      "lr": 0.001,
      "parameter_groups": [
        [
          [
            "_matched_embedder"
          ],
          {
            "finetune": true,
            "lr": 5e-05,
            "weight_decay": 0.01
          }
        ],
        [
          [
            "scalar_parameters"
          ],
          {
            "lr": 0.01
          }
        ]
      ],
      "type": "adamw",
      "weight_decay": 0
    },
    "validation_metric": "+MEAN__relation_f1"
  },
  "numpy_seed": 0,
  "pytorch_seed": 0,
  "random_seed": 0,
  "test_data_path": "",
  "train_data_path": "",
  "validation_data_path": ""
}
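
Note that the seeds and data paths above are placeholders: SpanModel.fit evaluates this file with _jsonnet and patches those fields before training. A minimal sketch of that mechanism (the seed value is illustrative):

    import json
    import _jsonnet

    # Evaluate the jsonnet config to plain JSON, then patch the seed fields.
    config = json.loads(_jsonnet.evaluate_file("training_config/config.jsonnet"))
    for key in ["random_seed", "pytorch_seed", "numpy_seed"]:
        config[key] = 42  # illustrative seed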
