+ * This is an {@link AutoCloseable} mean to be use in try-with-resource constructs.
+ *
+ *
+ * If you don't do so, please remember to call {@link #close} at the end to free
+ * resources.
+ *
+ */
+public class SingleSequenceReferenceAligner implements AutoCloseable {
+
+ private final File image;
+ private final BwaMemAligner aligner;
+ private final BwaMemIndex index;
+ private final List refNames;
+ private boolean closed = false;
+
+ @FunctionalInterface
+ public interface TriFunction{
+ W apply(T t, U u, V v);
+ }
+
+ public static final Predicate super BwaMemAlignment> NO_FILTER = x -> true;
+
+ private final Function super T, List> basesOf;
+ private final TriFunction super T, List>, List, ? extends U> alignmentOf;
+ private final Predicate super BwaMemAlignment> alignmentFilter;
+
+ public SingleSequenceReferenceAligner(final String name, final byte[] bases,
+ final Function super T, List> basesOf,
+ final TriFunction super T, List>, List, ? extends U> alignmentOf) {
+ this(name, bases, basesOf, alignmentOf, NO_FILTER);
+ }
+
+ public SingleSequenceReferenceAligner(final String name, final byte[] bases,
+ final Function super T, List> basesOf,
+ final TriFunction super T, List>, List, ? extends U> alignmentOf,
+ final Predicate super BwaMemAlignment> alignmentFilter) {
+ Utils.nonNull(name, "the input reference name cannot be null");
+ Utils.nonNull(bases, "the input bases cannot be null");
+ this.basesOf = Utils.nonNull(basesOf);
+ this.alignmentOf = Utils.nonNull(alignmentOf);
+ this.alignmentFilter = Utils.nonNull(alignmentFilter);
+ Utils.validate(bases.length > 0, "the reference contig bases sequence must have at least one base");
+ try {
+ final File fasta = File.createTempFile("ssvh-temp", ".fasta");
+ fasta.deleteOnExit();
+ image = new File(fasta.getParentFile(), fasta.getName().replace(".fasta", ".img"));
+ image.deleteOnExit();
+ FastaReferenceWriter.writeSingleSequenceReference(fasta.toPath(), false, false, name, null, bases);
+ BwaMemIndex.createIndexImageFromFastaFile(fasta.toString(), image.toString());
+ fasta.delete(); // we don't need the fasta around.
+ index = new BwaMemIndex(image.toString());
+ aligner = new BwaMemAligner(index);
+ } catch (final IOException ex) {
+ throw new GATKException("could not create index files", ex);
+ }
+ refNames = Collections.singletonList(name);
+ }
+
+ /**
+ * Gives access to the underlying aligner so that you can modify its options.
+ *
+ * You could align sequences directly thru the return object but in that case you will lose the
+ * {@link BwaMemAlignment} to {@link AlignmentInterval} translation.
+ *
+ *
+ * @return never {@code null}.
+ */
+ public BwaMemAligner getAligner() {
+ return aligner;
+ }
+
+ public List align(final Iterable extends T> inputs) {
+ Utils.nonNull(inputs);
+ return align(Utils.stream(inputs).collect(Collectors.toList()));
+ }
+
+ /**
+ * Aligns the input object returning a list of the outputs in the corresponding order.
+ * @param inputs
+ * @return
+ */
+ public List align(final List extends T> inputs) {
+ checkNotClosed();
+ Utils.nonNull(inputs, "the input sequence array cannot be null");
+ final List> seqs = inputs.stream().map(basesOf).collect(Collectors.toList());
+ final List flattenSeqs = seqs.stream().flatMap(Collection::stream).collect(Collectors.toList());
+ final List> alignments = aligner.alignSeqs(flattenSeqs);
+ if (alignments.size() != flattenSeqs.size()) { // paranoiah??
+ throw new IllegalStateException("something went terribly wrong and the number of returned alignment list does " +
+ "not correspond to the number of input sequences: " + alignments.size() + " != " + flattenSeqs.size());
+ }
+ final List result = new ArrayList<>(inputs.size());
+ int nextAlignmentIndex = 0;
+ for (int i = 0; i < inputs.size(); i++) {
+ final T inputObject = inputs.get(i);
+ final List sequences = seqs.get(i);
+ final List> relevantAlignments =
+ alignments.subList(nextAlignmentIndex, nextAlignmentIndex += sequences.size());
+ final List> filteredAlignments;
+ if (alignmentFilter == NO_FILTER) {
+ filteredAlignments = relevantAlignments;
+ } else {
+ filteredAlignments = new ArrayList<>(relevantAlignments.size());
+ for (int j = 0; j < relevantAlignments.size(); j++) {
+ filteredAlignments.add(relevantAlignments.get(j).stream().filter(alignmentFilter).collect(Collectors.toList()));
+ }
+ }
+ final U outputObject = alignmentOf.apply(inputObject, filteredAlignments, refNames);
+ result.add(outputObject);
+ }
+ return result;
+ }
+
+ /**
+ * Composes a map of the aligned sequences.
+ *
+ * The key of such a map would be determined by the input object and the output alignment.
+ *
+ *
+ * Iterations over the entries, keys and values of the resulting tree will follow the order
+ * of the input objects. In case of key collisions (more than one input object, alignment result in the same key)
+ * then we only keep the first occurrence of such key.
+ *
+ * @param inputs the input objects to align.
+ * @param keyOf function that composed the key given the input and output objects.
+ * @param
+ * @return never {@code null}.
+ */
+ public Map align(final List extends T> inputs, final BiFunction super T, ? super U, ? extends V> keyOf) {
+ final List outs = align(inputs);
+ final LinkedHashMap result = new LinkedHashMap<>(outs.size());
+ for (int i = 0; i < outs.size(); i++) {
+ result.putIfAbsent(keyOf.apply(inputs.get(i), outs.get(i)), outs.get(i));
+ }
+ return result;
+ }
+
+ /**
+ * Composes a contig aligner from an arbitrary input type given contig name and base sequence generation functions.
+ *
+ * Supplementary alignments will be ignored.
+ *
+ *
+ * @param refName the name of the only reference sequence.
+ * @param refBases the bases for the only reference sequence.
+ * @param nameOf function that produces the aligned contig name based on the input object.
+ * @param basesOf function that produces the contig bases sequence based on the input object.
+ * @param the type-parameter of the input object.
+ * @return never {@code null}.
+ */
+ public static SingleSequenceReferenceAligner contigsAligner(final String refName, final byte[] refBases,
+ final Function super T, String> nameOf,
+ final Function super T, byte[]> basesOf) {
+ return new SingleSequenceReferenceAligner<>(refName, refBases,
+ t -> Collections.singletonList(basesOf.apply(t)),
+ (t, bma, refNames) -> {
+ final String name = nameOf.apply(t);
+ final byte[] bases = basesOf.apply(t);
+ final List intervals = bma.get(0).stream().map(b -> new AlignmentInterval(b, refNames, bases.length)).collect(Collectors.toList());
+ return new AlignedContig(name, bases, intervals);
+ },
+ bma -> bma.getRefId() >= 0 && SAMFlag.SECONDARY_ALIGNMENT.isUnset(bma.getSamFlag()));
+ }
+
+ private void checkNotClosed() {
+ if (closed) {
+ throw new IllegalStateException("operation not allowed once the aligner is closed");
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ if (!closed) {
+ aligner.close();
+ closed = true;
+ try {
+ index.close();
+ } finally {
+ image.delete();
+ }
+ }
+ }
+}
diff --git a/src/test/java/org/broadinstitute/hellbender/tools/spark/sv/utils/SingleContigReferenceAlignerUnitTest.java b/src/test/java/org/broadinstitute/hellbender/tools/spark/sv/utils/SingleContigReferenceAlignerUnitTest.java
new file mode 100644
index 00000000000..fbd20d012d0
--- /dev/null
+++ b/src/test/java/org/broadinstitute/hellbender/tools/spark/sv/utils/SingleContigReferenceAlignerUnitTest.java
@@ -0,0 +1,145 @@
+package org.broadinstitute.hellbender.tools.spark.sv.utils;
+
+import htsjdk.samtools.Cigar;
+import htsjdk.samtools.CigarElement;
+import htsjdk.samtools.CigarOperator;
+import htsjdk.samtools.TextCigarCodec;
+import htsjdk.samtools.util.SequenceUtil;
+import org.apache.commons.math3.util.Pair;
+import org.broadinstitute.hellbender.tools.spark.sv.discovery.alignment.AlignedContig;
+import org.broadinstitute.hellbender.tools.spark.sv.discovery.alignment.AlignmentInterval;
+import org.broadinstitute.hellbender.utils.RandomDNA;
+import org.broadinstitute.hellbender.utils.SimpleInterval;
+import org.broadinstitute.hellbender.utils.read.CigarUtils;
+import org.broadinstitute.hellbender.utils.test.BaseTest;
+import org.testng.Assert;
+import org.testng.annotations.DataProvider;
+import org.testng.annotations.Test;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.stream.Collectors;
+
+/**
+ * Unit tests for {@link SingleSequenceReferenceAligner}
+ */
+public class SingleContigReferenceAlignerUnitTest extends BaseTest {
+
+ private static final int READ_LENGTH = 250;
+ private static final int NUM_ALIGNS = 1000;
+ private static final String REF_NAME = "ref00";
+
+ @Test(expectedExceptions = IllegalStateException.class, expectedExceptionsMessageRegExp = "^.*closed.*$")
+ public void testClosedException() {
+ final RandomDNA rdn = new RandomDNA(131);
+ final byte[] refBases = rdn.nextBases(1000);
+ final SingleSequenceReferenceAligner aligner = SingleSequenceReferenceAligner.contigsAligner("test", refBases,
+ a -> "contig", a -> a);
+ try {
+ aligner.close();
+ } catch (final IOException ex) {
+ Assert.fail("unexpected exception when closing");
+ }
+ aligner.align(Collections.singletonList(Arrays.copyOfRange(refBases, 100, 200)));
+ }
+
+
+ @Test(dataProvider = "testAlignmentData")
+ public void testAlignment(final boolean pairAlignment, final byte[] reference, final String referenceName)
+ throws IOException
+ {
+ try (final SingleSequenceReferenceAligner aligner = SingleSequenceReferenceAligner.contigsAligner(referenceName, reference,
+ a -> "ctg", a -> a);) {
+ Assert.assertNotNull(aligner.getAligner());
+ if (pairAlignment) {
+ aligner.getAligner().alignPairs();
+ }
+ final Random rdn = new Random(13111);
+ final RandomDNA rdnDNA = new RandomDNA(rdn);
+ final List> expected = new ArrayList<>(NUM_ALIGNS << 1);
+ final List seqs = new ArrayList<>(NUM_ALIGNS << 1);
+ for (int i = 0; i < NUM_ALIGNS; i++) {
+ final int start = rdn.nextInt(reference.length - READ_LENGTH) + 1;
+ final boolean forward = rdn.nextBoolean();
+ final boolean insert = rdn.nextDouble() < 0.1;
+ final boolean deletion = !insert && rdn.nextDouble() < 0.1;
+ final int indelLength = insert || deletion ? rdn.nextInt(10) + 10 : 0;
+ final int indelStart = insert || deletion ? rdn.nextInt((int) (READ_LENGTH * .50)) + 25 : -1;
+ final int end = start + READ_LENGTH - 1;
+ final byte[] templateSeq = Arrays.copyOfRange(reference, start - 1, end);
+ final byte[] actualSeq;
+ if (insert) {
+ actualSeq = Arrays.copyOf(templateSeq, templateSeq.length + indelLength);
+ System.arraycopy(actualSeq, indelStart - 1, actualSeq, indelStart - 1 + indelLength, templateSeq.length - indelStart + 1);
+ rdnDNA.nextBases(actualSeq, indelStart - 1, indelLength);
+ } else if (deletion) {
+ actualSeq = Arrays.copyOf(templateSeq, templateSeq.length - indelLength);
+ System.arraycopy(templateSeq, indelStart - 1 + indelLength, actualSeq, indelStart - 1, templateSeq.length - indelStart + 1 - indelLength);
+ } else {
+ actualSeq = templateSeq.clone();
+ }
+
+ while (insert && actualSeq[indelStart - 1] == actualSeq[indelStart + indelLength - 1]) {
+ actualSeq[indelStart + indelLength - 1] = rdnDNA.nextBase();
+ }
+ if (!forward) {
+ SequenceUtil.reverseComplement(actualSeq);
+ }
+ seqs.add(actualSeq);
+ final Cigar cigar = (!insert && !deletion) ? TextCigarCodec.decode(READ_LENGTH + "M"):
+ (insert ? TextCigarCodec.decode( "" + (indelStart - 1) + "M" + indelLength + "I" + (READ_LENGTH - indelStart + 1) + "M")
+ : TextCigarCodec.decode( "" + (indelStart - 1) + "M" + indelLength + "D" + (READ_LENGTH - indelStart + 1 - indelLength) + "M"));
+
+ expected.add(Collections.singletonList(new AlignmentInterval(new SimpleInterval(referenceName, start, end), 1, actualSeq.length, !forward ? CigarUtils.invertCigar(cigar) : cigar
+ , forward, 0, 0, 0, null)));
+ }
+ final List results = aligner.align(seqs);
+ final Map mapResult = aligner.align(seqs, (b, a) -> b);
+ Assert.assertEquals(results, new ArrayList<>(mapResult.values()));
+ Assert.assertEquals(new ArrayList<>(mapResult.keySet()), mapResult.values().stream().map(AlignedContig::getContigSequence).collect(Collectors.toList()));
+ for (int i = 0; i < NUM_ALIGNS; i++) {
+ final List actualValue = results.get(i).getAlignments();
+ final List expectedValue = expected.get(i);
+ Assert.assertEquals(actualValue.size(), 1);
+ Assert.assertEquals(actualValue.get(0).forwardStrand, expectedValue.get(0).forwardStrand);
+ Assert.assertEquals(actualValue.get(0).referenceSpan, expectedValue.get(0).referenceSpan, expectedValue.get(0).cigarAlong5to3DirectionOfContig.toString());
+ Assert.assertEquals(actualValue.get(0).startInAssembledContig, expectedValue.get(0).startInAssembledContig);
+ Assert.assertEquals(actualValue.get(0).endInAssembledContig, expectedValue.get(0).endInAssembledContig);
+ final Cigar expectedCigar = expectedValue.get(0).cigarAlong5to3DirectionOfContig;
+ final Cigar actualCigar = actualValue.get(0).cigarAlong5to3DirectionOfContig;
+ if (!expectedCigar.equals(actualCigar)) { // small differences may occur due to ambiguous indel location. So we check that they are small differences indeed:
+ Assert.assertEquals(expectedCigar.numCigarElements(), actualCigar.numCigarElements()); // same number of elements
+ Assert.assertEquals(expectedCigar.getCigarElements().stream().map(CigarElement::getOperator).collect(Collectors.toList()),
+ actualCigar.getCigarElements().stream().map(CigarElement::getOperator).collect(Collectors.toList())); // same operators sequence.
+ // then we check the total lengths per operator (must be the same):
+ final Map expectedLengthByOperator = expectedCigar.getCigarElements().stream()
+ .collect(Collectors.groupingBy(CigarElement::getOperator,
+ Collectors.reducing(0, CigarElement::getLength, (a, b) -> a + b)));
+ final Map actualLengthByOperator = actualCigar.getCigarElements().stream()
+ .collect(Collectors.groupingBy(CigarElement::getOperator,
+ Collectors.reducing(0, CigarElement::getLength, (a, b) -> a + b)));
+ Assert.assertEquals(actualLengthByOperator, expectedLengthByOperator);
+ // finally we don't allow more than 5 bases length difference for any given element.
+ for (int j = 0; j < expectedCigar.numCigarElements(); j++) {
+ Assert.assertTrue(Math.abs(expectedCigar.getCigarElement(j).getLength() - actualCigar.getCigarElement(j).getLength()) < 10, "actual: " + actualCigar + " != expected: " + expectedCigar);
+ }
+ }
+ }
+ }
+ }
+
+ @DataProvider(name="testAlignmentData")
+ public Object[][] testAlignmentData() {
+ final List