Commit

Added the ability for MarkDuplicatesSpark to accept multiple inputs (#5430)

* Added the ability for MarkDuplicatesSpark to take multiple input BAMs if they are all queryname-sorted, and extended this support to GATK Spark tools that opt in.
jamesemery authored Jan 3, 2019
1 parent 5d1d572 commit 1fb980c
Showing 10 changed files with 148 additions and 43 deletions.
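For illustration, the kind of invocation this change enables (file names hypothetical; -I, -O, and -M assumed to be the tool's standard input, output, and metrics flags):

gatk MarkDuplicatesSpark -I shard1.bam -I shard2.bam -O marked.bam -M marked_metrics.txt

Every input must be queryname-sorted (or queryname-grouped), otherwise the tool rejects the merge.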
GATKSparkTool.java
@@ -1,8 +1,9 @@
package org.broadinstitute.hellbender.engine.spark;

import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMSequenceDictionary;
import com.google.common.annotations.VisibleForTesting;
import htsjdk.samtools.*;
import htsjdk.samtools.reference.ReferenceSequenceFileFactory;
import htsjdk.samtools.util.CloseableIterator;
import htsjdk.samtools.util.GZIIndex;
import htsjdk.samtools.util.IOUtil;
import htsjdk.variant.vcf.VCFHeaderLine;
@@ -25,6 +26,7 @@
import org.broadinstitute.hellbender.engine.filters.WellformedReadFilter;
import org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSink;
import org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSource;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.tools.walkers.annotator.Annotation;
import org.broadinstitute.hellbender.utils.SequenceDictionaryUtils;
@@ -121,7 +123,7 @@ public abstract class GATKSparkTool extends SparkCommandLineProgram {

private ReadsSparkSource readsSource;
private SAMFileHeader readsHeader;
private String readInput;
private LinkedHashMap<String, SAMFileHeader> readInputs;
private ReferenceMultiSparkSource referenceSource;
private SAMSequenceDictionary referenceDictionary;
private List<SimpleInterval> userIntervals;
@@ -158,6 +160,20 @@ public boolean requiresReads() {
return false;
}

/**
* Does this tool support multiple read inputs? Tools that do should override this method to return the desired {@link ReadInputMergingPolicy}.
*
* @return doNotMerge by default
*/
public ReadInputMergingPolicy getReadInputMergingPolicy() {
return ReadInputMergingPolicy.doNotMerge;
}

public enum ReadInputMergingPolicy {
doNotMerge,
concatMerge
}
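
A minimal sketch (hypothetical tool, not part of this commit) of how a Spark tool opts in to multiple read inputs via this hook; the @CommandLineProgramProperties annotation a real tool needs is omitted for brevity:

import org.apache.spark.api.java.JavaSparkContext;
import org.broadinstitute.hellbender.engine.spark.GATKSparkTool;

// Hypothetical example: a tool that accepts several queryname-sorted inputs.
public final class ExampleConcatReadsTool extends GATKSparkTool {
    private static final long serialVersionUID = 1L;

    @Override
    public boolean requiresReads() {
        return true;
    }

    @Override
    public ReadInputMergingPolicy getReadInputMergingPolicy() {
        // opt in: getUnfilteredReads() will union all -I inputs in order
        return ReadInputMergingPolicy.concatMerge;
    }

    @Override
    protected void runTool(final JavaSparkContext ctx) {
        // all inputs arrive here as one concatenated RDD
        System.out.println("Total reads across inputs: " + getReads().count());
    }
}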

/**
* Does this tool require intervals? Tools that do should override to return true.
*
@@ -274,24 +290,39 @@ public JavaRDD<GATKRead> getUnfilteredReads() {
}
traversalParameters = new TraversalParameters(getIntervals(), traverseUnmapped);
} else {
traversalParameters = null; // no intervals were specified so return all reads (mapped and unmapped)
traversalParameters = null;
}

// TODO: This if statement is a temporary hack until #959 gets resolved.
if (readInput.endsWith(".adam")) {
JavaRDD<GATKRead> output = null;
ReadsSparkSource source = readsSource;
for (String input : readInputs.keySet()) {
if (output == null) {
output = getGatkReadJavaRDD(traversalParameters, source, input);
} else {
output = output.union(getGatkReadJavaRDD(traversalParameters, source, input));
}
}
return output;
}

protected JavaRDD<GATKRead> getGatkReadJavaRDD(TraversalParameters traversalParameters, ReadsSparkSource source, String input) {
JavaRDD<GATKRead> output;
// TODO: This if statement is a temporary hack until #959 gets resolved.
if (input.endsWith(".adam")) {
try {
return readsSource.getADAMReads(readInput, traversalParameters, getHeaderForReads());
output = source.getADAMReads(input, traversalParameters, getHeaderForReads());
} catch (IOException e) {
throw new UserException("Failed to read ADAM file " + readInput, e);
throw new UserException("Failed to read ADAM file " + input, e);
}

} else {
if (hasCramInput() && !hasReference()){
throw new UserException.MissingReference("A reference file is required when using CRAM files.");
}
final String refPath = hasReference() ? referenceArguments.getReferenceFileName() : null;
return readsSource.getParallelReads(readInput, refPath, traversalParameters, bamPartitionSplitSize);
output = source.getParallelReads(input, refPath, traversalParameters, bamPartitionSplitSize);
}
return output;
}
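
A self-contained sketch (illustrative, with plain strings standing in for reads) of why the union-based concatenation above preserves each input's internal order:

import java.util.Arrays;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

public final class UnionOrderDemo {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("UnionOrderDemo");
        try (JavaSparkContext ctx = new JavaSparkContext(conf)) {
            JavaRDD<String> first = ctx.parallelize(Arrays.asList("a1", "a2"));
            JavaRDD<String> second = ctx.parallelize(Arrays.asList("b1", "b2"));
            // union concatenates partitions without a shuffle, so each input's
            // records stay contiguous and ordered -- which is why queryname-grouped
            // inputs remain grouped after the merge
            JavaRDD<String> merged = first.union(second);
            System.out.println(merged.collect()); // prints [a1, a2, b1, b2]
        }
    }
}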

/**
@@ -334,7 +365,8 @@ public int getRecommendedNumReducers() {
if (numReducers != 0) {
return numReducers;
}
return 1 + (int) (BucketUtils.dirSize(getReadSourceName()) / getTargetPartitionSize());
long size = readInputs.keySet().stream().mapToLong(BucketUtils::dirSize).sum();
return 1 + (int) (size / getTargetPartitionSize());
}
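
As a worked example with illustrative numbers: if numReducers is unset, two inputs of 3 GiB and 1 GiB with a 128 MiB target partition size give 1 + (4096 MiB / 128 MiB) = 33 reducers.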

/**
@@ -445,8 +477,18 @@ public Collection<Annotation> makeVariantAnnotations() {
/**
* Returns the names of the sources of reads data. Each can be a file name or URL. Only valid when the tool has a single read input.
*/
protected String getReadSourceName(){
return readInput;
protected List<String> getReadSourceName(){
if (readInputs.size() > 1) {
throw new GATKException("Multiple ReadsDataSources specificed but a single source requested by the tool");
}
return new ArrayList<>(readInputs.keySet());
}

/**
* Returns a map of read input to header.
*/
protected LinkedHashMap<String, SAMFileHeader> getReadSourceHeaderMap(){
return readInputs;
}

/**
@@ -489,15 +531,37 @@ private void initializeReads(final JavaSparkContext sparkContext) {
return;
}

if ( readArguments.getReadFilesNames().size() != 1 ) {
throw new UserException("Sorry, we only support a single reads input for spark tools for now.");
if (getReadInputMergingPolicy() == ReadInputMergingPolicy.doNotMerge && readArguments.getReadFilesNames().size() != 1 ) {
throw new UserException("Sorry, we only support a single reads input for for this spark tool.");
}

readInput = readArguments.getReadFilesNames().get(0);
readInputs = new LinkedHashMap<>();
readsSource = new ReadsSparkSource(sparkContext, readArguments.getReadValidationStringency());
readsHeader = readsSource.getHeader(
readInput,
hasReference() ? referenceArguments.getReferenceFileName() : null);
for (String input : readArguments.getReadFilesNames()) {
readInputs.put(input, readsSource.getHeader(
input, hasReference() ? referenceArguments.getReferenceFileName() : null));
}
readsHeader = createHeaderMerger().getMergedHeader();
}

/**
* Create a header merger from the individual SAM/BAM headers in our readers
*
* @return a header merger containing all individual headers in this data source
*/
private SamFileHeaderMerger createHeaderMerger() {
return new SamFileHeaderMerger(identifySortOrder(readInputs.values()), readInputs.values(), true); // true = merge the sequence dictionaries as well
}

@VisibleForTesting
static SAMFileHeader.SortOrder identifySortOrder(final Collection<SAMFileHeader> headers){
final Set<SAMFileHeader.SortOrder> sortOrders = headers.stream().map(SAMFileHeader::getSortOrder).collect(Collectors.toSet());
final SAMFileHeader.SortOrder order;
if (sortOrders.size() == 1) {
order = sortOrders.iterator().next();
} else {
order = SAMFileHeader.SortOrder.unsorted;
}
return order;
}
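
A minimal sketch of the intended identifySortOrder semantics (hypothetical assertions, written as if in a test in the same package):

final SAMFileHeader first = new SAMFileHeader();
final SAMFileHeader second = new SAMFileHeader();
first.setSortOrder(SAMFileHeader.SortOrder.queryname);
second.setSortOrder(SAMFileHeader.SortOrder.queryname);
// a shared sort order is preserved in the merged header...
assert GATKSparkTool.identifySortOrder(Arrays.asList(first, second)) == SAMFileHeader.SortOrder.queryname;
second.setSortOrder(SAMFileHeader.SortOrder.coordinate);
// ...while disagreeing orders collapse to unsorted
assert GATKSparkTool.identifySortOrder(Arrays.asList(first, second)) == SAMFileHeader.SortOrder.unsorted;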

/**
----
CollectBaseDistributionByCycleSpark.java

@@ -226,7 +226,7 @@ public void addToMetricsFile(final MetricsFile<BaseDistributionByCycleMetrics, ?
protected void runTool(final JavaSparkContext ctx) {
final JavaRDD<GATKRead> reads = getReads();
final MetricsFile<BaseDistributionByCycleMetrics, Integer> metricsFile = calculateBaseDistributionByCycle(reads);
saveResults(metricsFile, getHeaderForReads(), getReadSourceName());
saveResults(metricsFile, getHeaderForReads(), getReadSourceName().get(0));
}

/**
----
CollectQualityYieldMetricsSpark.java

@@ -193,7 +193,7 @@ protected void runTool( final JavaSparkContext ctx ) {
unFilteredReads.filter(r -> readFilter.test(r)),
getHeaderForReads()
);
metricsCollector.saveMetrics(getReadSourceName());
metricsCollector.saveMetrics(getReadSourceName().get(0));
}
}

----
MeanQualityByCycleSpark.java

@@ -197,7 +197,7 @@ HistogramGeneratorPair merge(final HistogramGeneratorPair other){
protected void runTool(final JavaSparkContext ctx) {
final JavaRDD<GATKRead> reads = getReads();
final MetricsFile<?, Integer> metricsFile = calculateMeanQualityByCycle(reads);
saveResults(metricsFile, getHeaderForReads(), getReadSourceName());
saveResults(metricsFile, getHeaderForReads(), getReadSourceName().get(0));
}

/**
----
MetricsCollectorSparkTool.java

@@ -70,7 +70,7 @@ protected void runTool( JavaSparkContext ctx ) {
getHeaderForReads().getSortOrder(),
getExpectedSortOrder(),
false,
getReadSourceName()
getReadSourceName().get(0)
);

// Execute the collector lifecycle
Expand All @@ -84,7 +84,7 @@ protected void runTool( JavaSparkContext ctx ) {
initialize(collectorArgs, getHeaderForReads(), getDefaultHeaders());
final JavaRDD<GATKRead> filteredReads = getReads();
collectMetrics(filteredReads, getHeaderForReads());
saveMetrics(getReadSourceName());
saveMetrics(getReadSourceName().get(0));
}

}
----
QualityScoreDistributionSpark.java

@@ -134,7 +134,7 @@ protected void runTool(final JavaSparkContext ctx) {
(counts1, counts2) -> counts1.merge(counts2));

final MetricsFile<?, Byte> metrics = makeMetrics(result);
saveResults(metrics, getHeaderForReads(), getReadSourceName());
saveResults(metrics, getHeaderForReads(), getReadSourceName().get(0));
}

// Convert the count object into a metrics object to save in a report
----
MarkDuplicatesSpark.java

@@ -17,7 +17,9 @@
import org.broadinstitute.hellbender.engine.filters.ReadFilter;
import org.broadinstitute.hellbender.engine.filters.ReadFilterLibrary;
import org.broadinstitute.hellbender.engine.spark.GATKSparkTool;
import org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSource;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.read.ReadUtils;
@@ -69,6 +71,11 @@ public List<ReadFilter> getDefaultReadFilters() {
// Reads with this marker will be treated and marked as optical duplicates
public static int OPTICAL_DUPLICATE_MARKER = -2;

@Override
public ReadInputMergingPolicy getReadInputMergingPolicy() {
return ReadInputMergingPolicy.concatMerge;
}

/**
* Main method for marking duplicates, takes an JavaRDD of GATKRead and an associated SAMFileHeader with corresponding
* sorting information and returns a new JavaRDD\<GATKRead\> in which all read templates have been marked as duplicates
@@ -215,6 +222,14 @@ public int getPartition(Object key) {

@Override
protected void runTool(final JavaSparkContext ctx) {
// If multiple inputs were provided, check that every header is queryname-grouped
Map<String, SAMFileHeader> headerMap = getReadSourceHeaderMap();
if (headerMap.size() > 1) {
    headerMap.forEach((input, header) -> {
        if (!ReadUtils.isReadNameGroupedBam(header)) {
            throw new UserException("Multiple inputs to MarkDuplicatesSpark detected but input " + input + " was sorted in " + header.getSortOrder() + " order");
        }
    });
}

JavaRDD<GATKRead> reads = getReads();
final OpticalDuplicateFinder finder = opticalDuplicatesArgumentCollection.READ_NAME_REGEX != null ?
new OpticalDuplicateFinder(opticalDuplicatesArgumentCollection.READ_NAME_REGEX, opticalDuplicatesArgumentCollection.OPTICAL_DUPLICATE_PIXEL_DISTANCE, null) : null;