Change parametrization #5

Open · wants to merge 7 commits into base: main
77 changes: 57 additions & 20 deletions bertalign/aligner.py
@@ -1,52 +1,89 @@
 import numpy as np
 
 
 from bertalign import model
 from bertalign.corelib import *
 from bertalign.utils import *
 
 class Bertalign:
     def __init__(self,
-                 src,
-                 tgt,
+                 src_raw,
+                 tgt_raw,
                  max_align=5,
                  top_k=3,
                  win=5,
                  skip=-0.1,
                  margin=True,
                  len_penalty=True,
-                 is_split=False,
+                 input_type='raw',
+                 src_lang=None,
+                 tgt_lang=None,
                ):
 
         self.max_align = max_align
         self.top_k = top_k
         self.win = win
         self.skip = skip
         self.margin = margin
         self.len_penalty = len_penalty
 
-        src = clean_text(src)
-        tgt = clean_text(tgt)
-        src_lang = detect_lang(src)
-        tgt_lang = detect_lang(tgt)
-
-        if is_split:
+        input_types = ['raw', 'lines', 'tokenized']
+        if input_type not in input_types:
+            raise ValueError("Invalid input type '%s'. Expected one of: %s" % (input_type, input_types))
+
+        if input_type == 'lines':
+            # need to split
+            src = clean_text(src_raw)
+            tgt = clean_text(tgt_raw)
             src_sents = src.splitlines()
             tgt_sents = tgt.splitlines()
-        else:
+
+            if not src_lang:
+                src_lang = detect_lang(src)
+            if not tgt_lang:
+                tgt_lang = detect_lang(tgt)
+
+        elif input_type == 'raw':
+            src = clean_text(src_raw)
+            tgt = clean_text(tgt_raw)
+
+            if not src_lang:
+                src_lang = detect_lang(src)
+            if not tgt_lang:
+                tgt_lang = detect_lang(tgt)
+
             src_sents = split_sents(src, src_lang)
             tgt_sents = split_sents(tgt, tgt_lang)
+
+        elif input_type == 'tokenized':
+
+            if not src_lang:
+                src_lang = detect_lang(src)
+            if not tgt_lang:
+                tgt_lang = detect_lang(tgt)
+
+            src_sents = src_raw
+            tgt_sents = tgt_raw
+
+            if not src_lang:
+                src_lang = detect_lang(' '.join(src_sents))
+            if not tgt_lang:
+                tgt_lang = detect_lang(' '.join(tgt_sents))
 
         src_num = len(src_sents)
         tgt_num = len(tgt_sents)
 
         src_lang = LANG.ISO[src_lang]
         tgt_lang = LANG.ISO[tgt_lang]
 
         print("Source language: {}, Number of sentences: {}".format(src_lang, src_num))
         print("Target language: {}, Number of sentences: {}".format(tgt_lang, tgt_num))
 
-        print("Embedding source and target text using {} ...".format(model.model_name))
+        print("Embedding source text using {} ...".format(model.model_name))
         src_vecs, src_lens = model.transform(src_sents, max_align - 1)
+        print("Embedding target text using {} ...".format(model.model_name))
         tgt_vecs, tgt_lens = model.transform(tgt_sents, max_align - 1)
 
         char_ratio = np.sum(src_lens[0,]) / np.sum(tgt_lens[0,])
@@ -62,7 +99,7 @@ def __init__(self,
         self.char_ratio = char_ratio
         self.src_vecs = src_vecs
         self.tgt_vecs = tgt_vecs
 
     def align_sents(self):
 
         print("Performing first-step alignment ...")
@@ -71,18 +108,18 @@ def align_sents(self):
         first_w, first_path = find_first_search_path(self.src_num, self.tgt_num)
         first_pointers = first_pass_align(self.src_num, self.tgt_num, first_w, first_path, first_alignment_types, D, I)
         first_alignment = first_back_track(self.src_num, self.tgt_num, first_pointers, first_path, first_alignment_types)
 
         print("Performing second-step alignment ...")
         second_alignment_types = get_alignment_types(self.max_align)
         second_w, second_path = find_second_search_path(first_alignment, self.win, self.src_num, self.tgt_num)
         second_pointers = second_pass_align(self.src_vecs, self.tgt_vecs, self.src_lens, self.tgt_lens,
                                             second_w, second_path, second_alignment_types,
                                             self.char_ratio, self.skip, margin=self.margin, len_penalty=self.len_penalty)
         second_alignment = second_back_track(self.src_num, self.tgt_num, second_pointers, second_path, second_alignment_types)
 
         print("Finished! Successfully aligning {} {} sentences to {} {} sentences\n".format(self.src_num, self.src_lang, self.tgt_num, self.tgt_lang))
         self.result = second_alignment
 
     def print_sents(self):
         for bead in (self.result):
             src_line = self._get_line(bead[0], self.src_sents)
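For reference, a minimal usage sketch of the new parametrization (illustrative only, not part of this diff; the sample sentences are invented). The first two positional arguments are now the raw source and target strings, input_type selects how they are split ('raw', 'lines' or 'tokenized'), and src_lang/tgt_lang bypass automatic language detection:

from bertalign import Bertalign

# Pre-split input, one sentence per line, with the languages given explicitly
# (this mirrors aligner_spec_explicit in tests/test_results.py below).
src_raw = "Der Hund schläft.\nDie Katze miaut."
tgt_raw = "Le chien dort.\nLe chat miaule."

aligner = Bertalign(src_raw, tgt_raw, input_type="lines", src_lang="de", tgt_lang="fr")
aligner.align_sents()
aligner.print_sents()

# With the default input_type="raw", the text is sentence-split internally and
# the languages are auto-detected whenever src_lang/tgt_lang are omitted.
aligner = Bertalign(src_raw, tgt_raw)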
37 changes: 37 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,37 @@
import pytest
import json
import os

def load_json(fpath):
    with open(fpath) as json_file:
        data = json.load(json_file)
    return data


@pytest.fixture
def text_and_berg_expected_results():
    """Fixture for the Text und Berg expected result."""

    cur_dir = os.path.dirname(os.path.realpath(__file__))
    fname = 'gold_standard_text_und_berg.json'
    fpath = os.path.join(cur_dir, fname)
    data = load_json(fpath)
    yield data


@pytest.fixture
def text_and_berg_inputs():
    r"""Input data for Text and Berg."""

    src_dir = 'text+berg/de'
    tgt_dir = 'text+berg/fr'
    gold_dir = 'text+berg/gold'

    data = []
    for file in os.listdir(src_dir):
        src_file = os.path.join(src_dir, file).replace("\\", "/")
        tgt_file = os.path.join(tgt_dir, file).replace("\\", "/")
        data.append((file, src_file, tgt_file, gold_dir))

    yield data
58 changes: 58 additions & 0 deletions tests/gold_standard_text_und_berg.json
@@ -0,0 +1,58 @@
{
"002": {
"recall_strict": 0.9588477366255144,
"recall_lax": 0.9917695473251029,
"precision_strict": 0.9505703422053232,
"precision_lax": 0.9847908745247148,
"f1_strict": 0.9546910980176844,
"f1_lax": 0.9882678910702977
},
"006": {
"recall_strict": 0.9694444444444444,
"recall_lax": 0.9944444444444445,
"precision_strict": 0.9607329842931938,
"precision_lax": 0.9869109947643979,
"f1_strict": 0.9650690556740179,
"f1_lax": 0.9906633978772443
},
"001": {
"recall_strict": 0.9553191489361702,
"recall_lax": 0.9957446808510638,
"precision_strict": 0.9496981891348089,
"precision_lax": 0.9879275653923542,
"f1_strict": 0.9525003764104154,
"f1_lax": 0.991820720553515
},
"005": {
"recall_strict": 0.9502982107355865,
"recall_lax": 0.9960238568588469,
"precision_strict": 0.9453860640301318,
"precision_lax": 0.9887005649717514,
"f1_strict": 0.9478357731413087,
"f1_lax": 0.9923487000713064
},
"007": {
"recall_strict": 0.937592867756315,
"recall_lax": 0.9910846953937593,
"precision_strict": 0.9265536723163842,
"precision_lax": 0.9830508474576272,
"f1_strict": 0.9320405838088075,
"f1_lax": 0.9870514243433223
},
"004": {
"recall_strict": 0.9404145077720207,
"recall_lax": 0.9896373056994818,
"precision_strict": 0.9320148331273177,
"precision_lax": 0.9851668726823238,
"f1_strict": 0.9361958300767388,
"f1_lax": 0.9873970292534215
},
"003": {
"recall_strict": 0.9405594405594405,
"recall_lax": 0.9906759906759907,
"precision_strict": 0.9319955406911928,
"precision_lax": 0.9866220735785953,
"f1_strict": 0.9362579076540054,
"f1_lax": 0.9886448763947483
}
}
1 change: 1 addition & 0 deletions tests/requirements.txt
@@ -0,0 +1 @@
pytest
78 changes: 78 additions & 0 deletions tests/test_results.py
@@ -0,0 +1,78 @@
import os

import pytest

from bertalign import Bertalign
from bertalign.eval import read_alignments
from bertalign.eval import score_multiple
from bertalign.eval import log_final_scores


def align_text_and_berg(filespec, aligner_spec):
    r"""Align the Text+Berg files with the given aligner parametrization."""

    test_alignments = []
    gold_alignments = []

    results = {}

    for test_data in filespec:

        file, src_file, tgt_file, gold_dir = test_data
        src = open(src_file, "rt", encoding="utf-8").read()
        tgt = open(tgt_file, "rt", encoding="utf-8").read()

        print("Start aligning {} to {}".format(src_file, tgt_file))
        # aligner = Bertalign(src, tgt, is_split=True)
        aligner = Bertalign(src, tgt, **aligner_spec)
        aligner.align_sents()
        test_alignments.append(aligner.result)

        gold_file = os.path.join(gold_dir, file)
        gold_alignments.append(read_alignments(gold_file))

        scores = score_multiple(gold_list=gold_alignments, test_list=test_alignments)
        log_final_scores(scores)
        results[file] = scores

    return results


@pytest.mark.skip(reason="the is_split parameter has been removed")
def test_aligner_original(text_and_berg_expected_results, text_and_berg_inputs):
    r"""Test results for the original aligner using is_split."""

    aligner_spec = {"is_split": True}
    result = align_text_and_berg(text_and_berg_inputs, aligner_spec)

    for file in result:
        expected = text_and_berg_expected_results[file]
        calculated = result[file]
        for metric in expected:
            assert expected[metric] == calculated[metric], "Result mismatch"


aligner_spec_explicit = {
    "input_type": "lines",
    "src_lang": "de",
    "tgt_lang": "fr",
}


aligner_spec_detect = {
    "input_type": "lines",
}


# @pytest.mark.parametrize("aligner_spec", [aligner_spec_detect])
@pytest.mark.parametrize("aligner_spec", [aligner_spec_explicit])
def test_aligner_altered_parametrization(
    text_and_berg_expected_results, text_and_berg_inputs, aligner_spec
):
    r"""Test results for the aligner using input_type and explicit languages."""

    result = align_text_and_berg(text_and_berg_inputs, aligner_spec)

    for file in result:
        expected = text_and_berg_expected_results[file]
        calculated = result[file]
        for metric in expected:
            assert expected[metric] == calculated[metric], "Result mismatch"
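To reproduce these numbers locally, a minimal sketch (assuming the suite is run from the repository root so that the relative text+berg/de, text+berg/fr and text+berg/gold paths in tests/conftest.py resolve, and that pytest from tests/requirements.txt is installed):

# Equivalent to running `pytest tests/ -v` from the repository root.
import sys

import pytest

if __name__ == "__main__":
    sys.exit(pytest.main(["tests/", "-v"]))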