Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixed #5 : fixed merging incorrectly at overflowings field. #8

Merged
merged 1 commit into from
Oct 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 65 additions & 57 deletions encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,29 +324,52 @@ func (e *Encoding) Merge(encodings []Encoding, growingOffsets bool) (retVal *Enc

// MergeWith merges the current encoding with other (pair) encoding
func (e *Encoding) MergeWith(pair *Encoding, growingOffsets bool) (retVal *Encoding) {
// Merge overflowing
overflowings := make([]Encoding, 0)
var (
en Encoding = *e
pen Encoding = *pair
enOverflowings []Encoding = e.Overflowing
penOverflowings []Encoding = pair.Overflowing
)
en.Overflowing = []Encoding{}
pen.Overflowing = []Encoding{}

// Keep a copy before merging overflowings
ids := e.Ids
tokens := e.Tokens
wordIds := e.Words
offsets := e.Offsets
typeIds := e.TypeIds
specialTokenMask := e.SpecialTokenMask
attentionMask := e.AttentionMask
// 1. All our overflowings with all other overflowings
for _, o := range enOverflowings {
nEncoding := o
// 1.1. The pair itself
merge := mergeEncoding(nEncoding, pen, growingOffsets)
overflowings = append(overflowings, merge)

// 1.2. Its overflowings
for _, otherO := range penOverflowings {
oEncoding := otherO
merge := mergeEncoding(nEncoding, oEncoding, growingOffsets)
overflowings = append(overflowings, merge)
}
}

// Overflow
e.mergeOverflow(pair)
// 2. Ourself with all the other overflowings
for _, otherO := range penOverflowings {
oEncoding := otherO
merge := mergeEncoding(en, oEncoding, growingOffsets)
overflowings = append(overflowings, merge)
}

e.Overflowing = overflowings

// Others
e.Ids = append(ids, pair.Ids...)
e.Tokens = append(tokens, pair.Tokens...)
e.Words = append(wordIds, pair.Words...)
e.TypeIds = append(typeIds, pair.TypeIds...)
e.SpecialTokenMask = append(specialTokenMask, pair.SpecialTokenMask...)
e.AttentionMask = append(attentionMask, pair.AttentionMask...)
// Merging others
e.Ids = append(e.Ids, pair.Ids...)
e.Tokens = append(e.Tokens, pair.Tokens...)
e.Words = append(e.Words, pair.Words...)
e.TypeIds = append(e.TypeIds, pair.TypeIds...)
e.SpecialTokenMask = append(e.SpecialTokenMask, pair.SpecialTokenMask...)
e.AttentionMask = append(e.AttentionMask, pair.AttentionMask...)

// Offsets
var startingOffset int = 0
offsets := e.Offsets
if growingOffsets {
if len(offsets) > 0 {
last := offsets[len(offsets)-1]
Expand All @@ -366,57 +389,42 @@ func (e *Encoding) MergeWith(pair *Encoding, growingOffsets bool) (retVal *Encod
return e
}

// mergeOverflow merges overflowings of curent encoding with the pair.
//
// NOTE: this is a hacking solution created specifically for
// public method `MergeWith`. The merging have side-effects
// on other fields of Encoding.
func (e *Encoding) mergeOverflow(pair *Encoding) *Encoding {
// Merge overflowing
overflowings := make([]Encoding, 0)
// 1. All current overflowing with all other overflowing
for _, o := range e.Overflowing {
currO := o
// 1.1. The pair itself
currO.mergeOverflow(pair) // recursively call
overflowings = append(overflowings, currO)
currO = o // reset

// 1.2. The pair's overflowing
for _, otherO := range pair.Overflowing {
currO.mergeOverflow(&otherO)
overflowings = append(overflowings, currO)
currO = o // reset
}
}

// 2. Current encoding with all other overflowing
for _, otherO := range pair.Overflowing {
newE := e
newE.mergeOverflow(&otherO)
overflowings = append(overflowings, *newE)
// mergeEncoding merges 2 encodings those have `Overflowing` field empty.
// Otherwise, it will be panic.
func mergeEncoding(en1, en2 Encoding, growingOffsets bool) Encoding {
if len(en1.Overflowing) > 0 || len(en2.Overflowing) > 0 {
log.Fatalf("Invalid input encodings. Input encodings must have 'Overflowing' field empty.\n")
}

// 3. Current encoding and other encoding
e.Ids = append(e.Ids, pair.Ids...)
e.TypeIds = append(e.TypeIds, pair.TypeIds...)
e.Tokens = append(e.Tokens, pair.Tokens...)
e.SpecialTokenMask = append(e.SpecialTokenMask, pair.SpecialTokenMask...)
e.AttentionMask = append(e.AttentionMask, pair.AttentionMask...)
e.Overflowing = overflowings
e.Words = append(e.Words, pair.Words...)
var merge Encoding
merge.Overflowing = make([]Encoding, 0)
merge.Ids = append(en1.Ids, en2.Ids...)
merge.TypeIds = append(en1.TypeIds, en2.TypeIds...)
merge.Words = append(en1.Words, en2.Words...)
merge.Tokens = append(en1.Tokens, en2.Tokens...)
merge.SpecialTokenMask = append(en1.SpecialTokenMask, en2.SpecialTokenMask...)
merge.AttentionMask = append(en1.AttentionMask, en2.AttentionMask...)

// Offsets
offsets := en1.Offsets
var startingOffset int = 0
for _, o := range pair.Offsets {
if growingOffsets {
if len(offsets) > 0 {
last := offsets[len(offsets)-1]
startingOffset = last[1]
}
}

for _, o := range en2.Offsets {
adjustedO := []int{
o[0] + startingOffset,
o[1] + startingOffset,
}
e.Offsets = append(e.Offsets, adjustedO)
offsets = append(offsets, adjustedO)
}
merge.Offsets = offsets

return e
return merge
}

// Pad pads current encoding with given length, values to either Left or Right direction
Expand Down
2 changes: 1 addition & 1 deletion example/truncation/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func getBert() (retVal *tokenizer.Tokenizer) {
tk.WithPreTokenizer(bertPreTokenizer)

truncParams := tokenizer.TruncationParams{
MaxLength: 34,
MaxLength: 25,
Strategy: tokenizer.OnlySecond,
Stride: 0,
}
Expand Down