Skip to content

Commit

Permalink
Merge pull request #3 from AlexTate/operator-expansion
Browse files Browse the repository at this point in the history
Add SIMD hamming distance calculation
  • Loading branch information
AlexTate authored Dec 6, 2023
2 parents 59dad61 + 2a0d267 commit af1a11d
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 26 deletions.
6 changes: 3 additions & 3 deletions ShortSeq/short_seq.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,19 @@ cdef class ShortSeq:
cdef inline object _new(char* sequence, size_t length):
if length == 0:
return empty
elif length <= 32:
elif length <= MAX_64_NT:
out64 = ShortSeq64.__new__(ShortSeq64)
length = <uint8_t> length
(<ShortSeq64> out64)._packed = _marshall_bytes_64(<uint8_t *> sequence, length)
(<ShortSeq64> out64)._length = length
return out64
elif length <= 64:
elif length <= MAX_128_NT:
out128 = ShortSeq128.__new__(ShortSeq128)
length = <uint8_t> length
(<ShortSeq128> out128)._packed = _marshall_bytes_128(<uint8_t *> sequence, length)
(<ShortSeq128> out128)._length = length
return out128
elif length <= 1024:
elif length <= MAX_VAR_NT:
outvar = ShortSeqVar.__new__(ShortSeqVar)
(<ShortSeqVar> outvar)._packed = _marshall_bytes_var(<uint8_t *> sequence, length)
(<ShortSeqVar> outvar)._length = length
Expand Down
14 changes: 14 additions & 0 deletions ShortSeq/short_seq_128.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,20 @@ cdef class ShortSeq128:
else:
raise TypeError(f"Invalid index type: {type(item)}")

def __xor__(self, ShortSeq128 other):
    """Return the Hamming distance between this sequence and `other`.

    The packed values are XORed, each 2-bit group of the result is reduced
    to a single bit, and the set bits are counted.

    Raises:
        ValueError: if the two sequences differ in length.
    """
    if self._length != other._length:
        raise ValueError("Hamming distance requires sequences of equal length")

    cdef uint128_t diff = self._packed ^ (<ShortSeq128> other)._packed

    # _popcnt64 works on 64-bit words, so split the 128-bit result in two
    cdef uint64_t lo = <uint64_t> diff
    cdef uint64_t hi = <uint64_t> (diff >> 64)

    # A mismatching 2-bit group can XOR to 0x1, 0x2 or 0x3. OR the high bit
    # of each group into its low bit and keep only the low bits so that
    # every mismatch contributes exactly one set bit.
    # NOTE(review): assumes unused high bits are identical in both operands
    # (e.g. zero padded) so they never count as mismatches -- confirm.
    lo = ((lo >> 1) | lo) & 0x5555555555555555LL
    hi = ((hi >> 1) | hi) & 0x5555555555555555LL

    return _popcnt64(lo) + _popcnt64(hi)

def __str__(self):
    """Return the text that `_unmarshall_bytes_128` produces from the packed value and length."""
    decoded = _unmarshall_bytes_128(self._packed, self._length)
    return decoded

Expand Down
8 changes: 8 additions & 0 deletions ShortSeq/short_seq_64.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@ cdef class ShortSeq64:
else:
raise TypeError(f"Invalid index type: {type(item)}")

def __xor__(self, ShortSeq64 other):
    """Return the Hamming distance between this sequence and `other`.

    The packed values are XORed, each 2-bit group of the result is reduced
    to a single bit, and the set bits are counted.

    Raises:
        ValueError: if the two sequences differ in length.
    """
    if self._length != other._length:
        raise ValueError("Hamming distance requires sequences of equal length")

    cdef uint64_t diff = self._packed ^ (<ShortSeq64> other)._packed

    # A mismatching 2-bit group can XOR to 0x1, 0x2 or 0x3. OR the high bit
    # of each group into its low bit and keep only the low bits so that
    # every mismatch contributes exactly one set bit.
    # NOTE(review): assumes unused high bits are identical in both operands
    # (e.g. zero padded) so they never count as mismatches -- confirm.
    diff = ((diff >> 1) | diff) & 0x5555555555555555LL
    return _popcnt64(diff)

def __str__(self):
    """Return the text that `_unmarshall_bytes_64` produces from the packed value and length."""
    decoded = _unmarshall_bytes_64(self._packed, self._length)
    return decoded

Expand Down
15 changes: 15 additions & 0 deletions ShortSeq/short_seq_var.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,21 @@ cdef class ShortSeqVar:
def __str__(self):
    """Return the text that `_unmarshall_bytes_var` produces from the packed blocks and length."""
    decoded = _unmarshall_bytes_var(self._packed, self._length)
    return decoded

def __xor__(self, ShortSeqVar other):
    """Return the Hamming distance between this sequence and `other`.

    The packed 64-bit blocks are XORed pairwise, each 2-bit group of the
    result is reduced to a single bit, and the set bits are summed over
    all blocks.

    Raises:
        ValueError: if the two sequences differ in length.
    """
    cdef uint64_t block_comp
    cdef size_t n_blocks
    cdef size_t pop_cnt = 0
    cdef size_t i

    # The block count is derived from self only, so without this check a
    # shorter `other` would be indexed beyond its own blocks and a longer
    # one would be silently truncated.
    if self._length != other._length:
        raise ValueError("Hamming distance requires sequences of equal length")

    n_blocks = _length_to_block_num(self._length)

    for i in range(n_blocks):
        block_comp = self._packed[i] ^ other._packed[i]
        # A mismatching 2-bit group can XOR to 0x1, 0x2 or 0x3; collapse
        # each group to its low bit so every mismatch counts exactly once.
        # NOTE(review): assumes unused bits of the last block are identical
        # in both operands (e.g. zero padded) -- confirm.
        block_comp = ((block_comp >> 1) | block_comp) & 0x5555555555555555LL
        pop_cnt += _popcnt64(block_comp)

    return pop_cnt

def __repr__(self):
# Clips the sequence to MAX_REPR_LEN characters to avoid overwhelming the debugger
cdef unicode clipped_seq = _unmarshall_bytes_var(self._packed, MAX_REPR_LEN)
Expand Down
61 changes: 48 additions & 13 deletions ShortSeq/tests/unit_tests_main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import unittest

from random import randint

from ShortSeq import ShortSeq, ShortSeq64, ShortSeq128, ShortSeqVar
from ShortSeq import MIN_VAR_NT, MAX_VAR_NT, MIN_64_NT, MAX_64_NT, MIN_128_NT, MAX_128_NT
from ShortSeq.tests.util import rand_sequence, print_var_seq_pext_chunks
Expand Down Expand Up @@ -70,7 +72,7 @@ def test_length_range(self):
length = None
try:
for length in range(MIN_64_NT, MAX_64_NT):
sample = rand_sequence(length, no_range=True)
sample = rand_sequence(length)
sq = ShortSeq.from_str(sample)

self.assertIsInstance(sq, ShortSeq64)
Expand All @@ -84,7 +86,7 @@ def test_length_range(self):
length = None
try:
for length in range(MIN_128_NT, MAX_128_NT):
sample = rand_sequence(length, no_range=True)
sample = rand_sequence(length)
sq = ShortSeq.from_str(sample)

self.assertIsInstance(sq, ShortSeq128)
Expand All @@ -101,7 +103,7 @@ def test_subscript(self):
sample64, length, i, sq = None, None, None, None
try:
for length in range(MIN_64_NT, MAX_64_NT):
sample64 = rand_sequence(length, no_range=True)
sample64 = rand_sequence(length)
sq = ShortSeq.from_str(sample64)
for i in range(len(sample64)):
self.assertEqual(sq[i], sample64[i])
Expand All @@ -118,7 +120,7 @@ def test_subscript(self):
sample128, length, i, sq = None, None, None, None
try:
for length in range(MIN_128_NT, MAX_128_NT):
sample128 = rand_sequence(length, no_range=True)
sample128 = rand_sequence(length)
sq = ShortSeq.from_str(sample128)
for i in range(len(sample128)):
self.assertEqual(sq[i], sample128[i])
Expand All @@ -131,12 +133,22 @@ def test_subscript(self):
with self.assertRaises(IndexError):
_ = sq[oob]

"""Does the Hamming distance between two ShortSeqs work as expected?"""

def test_hamming_distance(self):
    """Check `a ^ b` against a string-based Hamming distance for lengths 0 through MAX_128_NT."""

    def str_ham(a, b):
        # Reference implementation: count the positions whose characters differ
        return sum(a_nt != b_nt for a_nt, b_nt in zip(a, b))

    # range() excludes its stop value, so add 1 to exercise MAX_128_NT itself
    for length in range(0, MAX_128_NT + 1):
        a = rand_sequence(length)
        b = rand_sequence(length)

        # The inputs are random, so record them to make a failure reproducible
        with self.subTest(length=length, a=a, b=b):
            self.assertEqual(ShortSeq.from_str(a) ^ ShortSeq.from_str(b), str_ham(a, b))

"""Can fixed width ShortSeqs be sliced like strings?"""

def test_slice(self):
#ShortSeq64
sample = rand_sequence(MAX_64_NT, no_range=True)
sample = rand_sequence(MAX_64_NT)
sq = ShortSeq.from_str(sample)
self.assertEqual(sq[:], sample)
for i in range(len(sample)):
Expand All @@ -146,7 +158,7 @@ def test_slice(self):
self.assertEqual(sq[-i:], sample[-i:])

# ShortSeq128
sample = rand_sequence(MAX_128_NT, no_range=True)
sample = rand_sequence(MAX_128_NT)
sq = ShortSeq.from_str(sample)
self.assertEqual(sq[:], sample)
for i in range(len(sample)):
Expand All @@ -167,7 +179,7 @@ def test_min_length(self):
n_samples = 3

for _ in range(n_samples):
sample = rand_sequence(sample_len, no_range=True)
sample = rand_sequence(sample_len)
sq = ShortSeq.from_str(sample)

self.assertIsInstance(sq, ShortSeqVar)
Expand All @@ -191,8 +203,8 @@ def test_max_length(self):
def test_length_range(self):
length = None
try:
for length in range(MIN_VAR_NT, MAX_VAR_NT-1):
sample = rand_sequence(length, no_range=True)
for length in range(MIN_VAR_NT, MAX_VAR_NT):
sample = rand_sequence(length)
sq = ShortSeq.from_str(sample)

self.assertIsInstance(sq, ShortSeqVar)
Expand All @@ -207,8 +219,8 @@ def test_length_range(self):
def test_subscript(self):
length, i = None, None
try:
for length in range(MIN_VAR_NT, MAX_VAR_NT-1):
sample = rand_sequence(length, no_range=True)
for length in range(MIN_VAR_NT, MAX_VAR_NT):
sample = rand_sequence(length)
sq = ShortSeq.from_str(sample)
for i in range(len(sample)):
self.assertEqual(sq[i], sample[i])
Expand All @@ -225,7 +237,7 @@ def test_subscript(self):

def test_slice(self):
# Min length
sample = rand_sequence(MIN_VAR_NT, no_range=True)
sample = rand_sequence(MIN_VAR_NT)
sq = ShortSeq.from_str(sample)
self.assertEqual(sq[:], sample)
for i in range(len(sample)):
Expand All @@ -235,7 +247,7 @@ def test_slice(self):
self.assertEqual(sq[-i:], sample[-i:])

# Max length
sample = rand_sequence(MAX_VAR_NT, no_range=True)
sample = rand_sequence(MAX_VAR_NT)
sq = ShortSeq.from_str(sample)
self.assertEqual(sq[:], sample)
for i in range(len(sample)):
Expand All @@ -244,6 +256,29 @@ def test_slice(self):
self.assertEqual(sq[i:], sample[i:])
self.assertEqual(sq[-i:], sample[-i:])

"""Just slice the heck out of the darn thing"""

def test_stochastic_slice(self):
    """Compare 10000 randomly chosen slices of a MAX_VAR_NT-long sequence against str slicing."""
    reference = rand_sequence(MAX_VAR_NT)
    packed = ShortSeq.from_str(reference)

    for _ in range(10000):
        # The start falls in the first half; the stop may equal the start (empty slice)
        start = randint(0, MAX_VAR_NT // 2)
        span = randint(1, MAX_VAR_NT - start)
        stop = randint(start, start + span)
        self.assertEqual(packed[start:stop], reference[start:stop])

"""Does the Hamming distance between two ShortSeqs work as expected?"""

def test_hamming_distance(self):
    """Check `a ^ b` against a string-based Hamming distance for lengths MIN_VAR_NT through MAX_VAR_NT."""

    def str_ham(a, b):
        # Reference implementation: count the positions whose characters differ
        return sum(a_nt != b_nt for a_nt, b_nt in zip(a, b))

    # range() excludes its stop value, so add 1 to exercise MAX_VAR_NT itself
    for length in range(MIN_VAR_NT, MAX_VAR_NT + 1):
        a = rand_sequence(length)
        b = rand_sequence(length)

        # The inputs are random, so record them to make a failure reproducible
        with self.subTest(length=length, a=a, b=b):
            self.assertEqual(ShortSeq.from_str(a) ^ ShortSeq.from_str(b), str_ham(a, b))



if __name__ == '__main__':
unittest.main()
18 changes: 8 additions & 10 deletions ShortSeq/tests/util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import numpy as np
import random
import math


Expand All @@ -24,18 +24,16 @@ def print_var_seq_pext_chunks(seq):
print(" -> ".join(out))


def rand_sequence(min_length=None, max_length=None, as_bytes=False):
    """Return a random sequence of the bases A, C, T and G.

    Args:
        min_length: Shortest allowed length. When `max_length` is omitted,
            this is the exact length of the result.
        max_length: Longest allowed length (inclusive). When `min_length`
            is omitted, this is the exact length of the result.
        as_bytes: Return the sequence as `bytes` instead of `str`.

    Returns:
        A sequence whose length lies in `[min_length, max_length]`.

    Raises:
        AssertionError: if neither length is given, or if
            `min_length > max_length`.
    """

    assert (min_length, max_length) != (None, None)
    bases = ("A", "C", "T", "G")

    # A single bound means "exactly this length"
    if min_length is None:
        min_length = max_length
    if max_length is None:
        max_length = min_length
    assert min_length <= max_length

    # Choose the length within the range rather than using the bounds as
    # range(min, max), which would yield max - min characters instead.
    length = random.randint(min_length, max_length)
    seq = ''.join(random.choices(bases, k=length))

    return seq.encode() if as_bytes else seq

0 comments on commit af1a11d

Please sign in to comment.