Skip to content

Commit

Permalink
Require that all SpanGroup spans are from the current doc
Browse files Browse the repository at this point in the history
The restriction that only spans from the current doc may be added was already
implemented for all operations except `SpanGroup.__init__`.

Initialize copied spans for `SpanGroup.copy` with `Doc.char_span` in
order to validate the character offsets and to make it possible to copy
spans between documents with differing tokenization. Currently there is
no validation that the document texts are identical, but the span char
offsets must be valid spans in the target doc, which prevents you from
ending up with completely invalid spans.
  • Loading branch information
adrianeboyd committed Apr 24, 2023
1 parent 68da580 commit b9c3045
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 10 deletions.
3 changes: 3 additions & 0 deletions spacy/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,9 @@ class Errors(metaclass=ErrorsWithCodes):
E1050 = ("Port {port} is already in use. Please specify an available port with `displacy.serve(doc, port=port)` "
"or use `auto_select_port=True` to pick an available port automatically.")
E1051 = ("'allow_overlap' can only be False when max_positive is 1, but found 'max_positive': {max_positive}.")
E1052 = ("Unable to copy spans: the character offsets for the span at "
"index {i} in the span group do not align with the tokenization "
"in the target doc.")


# Deprecated model shortcuts, only used in errors and warnings
Expand Down
24 changes: 24 additions & 0 deletions spacy/tests/doc/test_span_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,21 @@ def test_span_group_copy(doc):
assert span_group.attrs["key"] == "value"
assert list(span_group) != list(clone)

# can't copy if the character offsets don't align to tokens
doc2 = Doc(doc.vocab, words=[t.text + "x" for t in doc])
with pytest.raises(ValueError):
span_group.copy(doc=doc2)

# can copy with valid character offsets despite different tokenization
doc3 = doc.copy()
with doc3.retokenize() as retokenizer:
retokenizer.merge(doc3[0:2])
retokenizer.merge(doc3[3:6])
span_group = SpanGroup(doc, spans=[doc[0:6], doc[3:6]])
for span1, span2 in zip(span_group, span_group.copy(doc=doc3)):
assert span1.start_char == span2.start_char
assert span1.end_char == span2.end_char


def test_span_group_set_item(doc, other_doc):
span_group = doc.spans["SPANS"]
Expand Down Expand Up @@ -253,3 +268,12 @@ def test_span_group_typing(doc: Doc):
for i, span in enumerate(span_group):
assert span == span_group[i] == spans[i]
filter_spans(span_group)


def test_span_group_init_doc(en_tokenizer):
    """Test that all spans must come from the specified doc."""
    first_doc = en_tokenizer("a b c")
    second_doc = en_tokenizer("a b c")
    # Spans drawn from the owning doc are accepted without error.
    SpanGroup(first_doc, spans=[first_doc[0:1], first_doc[1:2]])
    # A span from a different doc must be rejected, even when the texts match.
    with pytest.raises(ValueError):
        SpanGroup(first_doc, spans=[first_doc[0:1], second_doc[1:2]])
12 changes: 6 additions & 6 deletions spacy/tests/parser/test_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ def test_beam_overfitting_IO(neg_key):
# Try to unlearn the entity by using negative annotations
neg_doc = nlp.make_doc(test_text)
neg_ex = Example(neg_doc, neg_doc)
neg_ex.reference.spans[neg_key] = [Span(neg_doc, 2, 3, "LOC")]
neg_ex.reference.spans[neg_key] = [Span(neg_ex.reference, 2, 3, "LOC")]
neg_train_examples = [neg_ex]

for i in range(20):
Expand Down Expand Up @@ -728,9 +728,9 @@ def test_neg_annotation(neg_key):
ner.add_label("ORG")
example = Example.from_dict(neg_doc, {"entities": [(7, 17, "PERSON")]})
example.reference.spans[neg_key] = [
Span(neg_doc, 2, 4, "ORG"),
Span(neg_doc, 2, 3, "PERSON"),
Span(neg_doc, 1, 4, "PERSON"),
Span(example.reference, 2, 4, "ORG"),
Span(example.reference, 2, 3, "PERSON"),
Span(example.reference, 1, 4, "PERSON"),
]

optimizer = nlp.initialize()
Expand All @@ -755,7 +755,7 @@ def test_neg_annotation_conflict(neg_key):
ner.add_label("PERSON")
ner.add_label("LOC")
example = Example.from_dict(neg_doc, {"entities": [(7, 17, "PERSON")]})
example.reference.spans[neg_key] = [Span(neg_doc, 2, 4, "PERSON")]
example.reference.spans[neg_key] = [Span(example.reference, 2, 4, "PERSON")]
assert len(example.reference.ents) == 1
assert example.reference.ents[0].text == "Shaka Khan"
assert example.reference.ents[0].label_ == "PERSON"
Expand Down Expand Up @@ -788,7 +788,7 @@ def test_beam_valid_parse(neg_key):

doc = Doc(nlp.vocab, words=tokens)
example = Example.from_dict(doc, {"ner": iob})
neg_span = Span(doc, 50, 53, "ORG")
neg_span = Span(example.reference, 50, 53, "ORG")
example.reference.spans[neg_key] = [neg_span]

optimizer = nlp.initialize()
Expand Down
4 changes: 2 additions & 2 deletions spacy/tests/test_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,14 +423,14 @@ def span_getter(doc, span_key):
return doc.spans[span_key]

# Predict exactly the same, but overlapping spans will be discarded
pred.spans[key] = spans
pred.spans[key] = gold.spans[key].copy(doc=pred)
eg = Example(pred, gold)
scores = Scorer.score_spans([eg], attr=key, getter=span_getter)
assert scores[f"{key}_p"] == 1.0
assert scores[f"{key}_r"] < 1.0

# Allow overlapping, now both precision and recall should be 100%
pred.spans[key] = spans
pred.spans[key] = gold.spans[key].copy(doc=pred)
eg = Example(pred, gold)
scores = Scorer.score_spans([eg], attr=key, getter=span_getter, allow_overlap=True)
assert scores[f"{key}_p"] == 1.0
Expand Down
4 changes: 3 additions & 1 deletion spacy/tokens/doc.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1264,12 +1264,14 @@ cdef class Doc:
other.user_span_hooks = dict(self.user_span_hooks)
other.length = self.length
other.max_length = self.max_length
other.spans = self.spans.copy(doc=other)
buff_size = other.max_length + (PADDING*2)
assert buff_size > 0
tokens = <TokenC*>other.mem.alloc(buff_size, sizeof(TokenC))
memcpy(tokens, self.c - PADDING, buff_size * sizeof(TokenC))
other.c = &tokens[PADDING]
# copy spans after setting tokens so that SpanGroup.copy can verify
# that the start/end offsets are valid
other.spans = self.spans.copy(doc=other)
return other

def to_disk(self, path, *, exclude=tuple()):
Expand Down
15 changes: 14 additions & 1 deletion spacy/tokens/span_group.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ cdef class SpanGroup:
if len(spans) :
self.c.reserve(len(spans))
for span in spans:
if doc is not span.doc:
raise ValueError(Errors.E855.format(obj="span"))
self.push_back(span.c)

def __repr__(self):
Expand Down Expand Up @@ -261,11 +263,22 @@ cdef class SpanGroup:
"""
if doc is None:
doc = self.doc
if doc is self.doc:
spans = list(self)
else:
spans = [doc.char_span(span.start_char, span.end_char, label=span.label_, kb_id=span.kb_id, span_id=span.id) for span in self]
for i, span in enumerate(spans):
if span is None:
raise ValueError(Errors.E1052.format(i=i))
if span.kb_id in self.doc.vocab.strings:
doc.vocab.strings.add(span.kb_id_)
if span.id in span.doc.vocab.strings:
doc.vocab.strings.add(span.id_)
return SpanGroup(
doc,
name=self.name,
attrs=deepcopy(self.attrs),
spans=list(self),
spans=spans,
)
def _concat(
Expand Down

0 comments on commit b9c3045

Please sign in to comment.