Added the ability for MarkDuplicatesSpark to accept multiple inputs #5430

Merged · 9 commits · Jan 3, 2019
@@ -1,8 +1,9 @@
package org.broadinstitute.hellbender.engine.spark;

import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMSequenceDictionary;
import com.google.common.annotations.VisibleForTesting;
import htsjdk.samtools.*;
import htsjdk.samtools.reference.ReferenceSequenceFileFactory;
import htsjdk.samtools.util.CloseableIterator;
import htsjdk.samtools.util.GZIIndex;
import htsjdk.samtools.util.IOUtil;
import htsjdk.variant.vcf.VCFHeaderLine;
@@ -25,6 +26,7 @@
import org.broadinstitute.hellbender.engine.filters.WellformedReadFilter;
import org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSink;
import org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSource;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.tools.walkers.annotator.Annotation;
import org.broadinstitute.hellbender.utils.SequenceDictionaryUtils;
@@ -121,7 +123,7 @@ public abstract class GATKSparkTool extends SparkCommandLineProgram {

private ReadsSparkSource readsSource;
private SAMFileHeader readsHeader;
private String readInput;
private LinkedHashMap<String, SAMFileHeader> readInputs;
private ReferenceMultiSparkSource referenceSource;
private SAMSequenceDictionary referenceDictionary;
private List<SimpleInterval> userIntervals;
@@ -158,6 +160,20 @@ public boolean requiresReads() {
return false;
}

/**
* Does this tool support multiple inputs? Tools that do should override this method with the desired {@link ReadInputMergingPolicy}.
*
* @return doNotMerge by default
*/
public ReadInputMergingPolicy getReadInputMergingPolicy() {
return ReadInputMergingPolicy.doNotMerge;
Contributor: Why not default to concatMerge and support multiple inputs for all Spark tools?

Collaborator Author: I debated this with @lbergelson. It seems undesirable for tools that tailor their behavior to the header sort order to union the RDDs of multiple inputs, potentially invalidating any assumptions about input ordering.

Member: I'm in favor of tools opting in.

Contributor: OK

}

public static enum ReadInputMergingPolicy {
doNotMerge,
concatMerge
}
Contributor: The mergeAndSort policy is not used or implemented, so it shouldn't be added yet. The walker version doesn't have such a policy as far as I can see; it effectively implements concatMerge, so just implementing the equivalent would be sufficient, no?
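
For context, opting in is a one-line override in the tool itself. A minimal sketch under this design (the tool class and its name are hypothetical, and the usual @CommandLineProgramProperties boilerplate and imports for JavaSparkContext, JavaRDD, and GATKRead are omitted; MarkDuplicatesSpark does the same thing further down in this PR):

// Hypothetical example tool, for illustration only.
public final class MyMultiInputSparkTool extends GATKSparkTool {
    private static final long serialVersionUID = 1L;

    @Override
    public boolean requiresReads() {
        return true;
    }

    @Override
    public ReadInputMergingPolicy getReadInputMergingPolicy() {
        // Opt in: getUnfilteredReads() unions the per-input RDDs instead of rejecting multiple -I arguments.
        return ReadInputMergingPolicy.concatMerge;
    }

    @Override
    protected void runTool(final JavaSparkContext ctx) {
        // getReads() now spans every input, in the order the inputs were given on the command line.
        final JavaRDD<GATKRead> reads = getReads();
        // ... tool-specific logic ...
    }
}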


/**
* Does this tool require intervals? Tools that do should override to return true.
*
@@ -274,24 +290,39 @@ public JavaRDD<GATKRead> getUnfilteredReads() {
}
traversalParameters = new TraversalParameters(getIntervals(), traverseUnmapped);
} else {
traversalParameters = null; // no intervals were specified so return all reads (mapped and unmapped)
traversalParameters = null;
}

// TODO: This if statement is a temporary hack until #959 gets resolved.
if (readInput.endsWith(".adam")) {
JavaRDD<GATKRead> output = null;
ReadsSparkSource source = readsSource;
for (String input : readInputs.keySet()) {
if (output == null) {
output = getGatkReadJavaRDD(traversalParameters, source, input);
} else {
output = output.union(getGatkReadJavaRDD(traversalParameters, source, input));
}
}
return output;
}

protected JavaRDD<GATKRead> getGatkReadJavaRDD(TraversalParameters traversalParameters, ReadsSparkSource source, String input) {
JavaRDD<GATKRead> output;
// TODO: This if statement is a temporary hack until #959 gets resolved
if (input.endsWith(".adam")) {
try {
return readsSource.getADAMReads(readInput, traversalParameters, getHeaderForReads());
output = source.getADAMReads(input, traversalParameters, getHeaderForReads());
} catch (IOException e) {
throw new UserException("Failed to read ADAM file " + readInput, e);
throw new UserException("Failed to read ADAM file " + input, e);
}

} else {
if (hasCramInput() && !hasReference()){
throw new UserException.MissingReference("A reference file is required when using CRAM files.");
}
final String refPath = hasReference() ? referenceArguments.getReferenceFileName() : null;
return readsSource.getParallelReads(readInput, refPath, traversalParameters, bamPartitionSplitSize);
output = source.getParallelReads(input, refPath, traversalParameters, bamPartitionSplitSize);
}
return output;
}

/**
@@ -334,7 +365,8 @@ public int getRecommendedNumReducers() {
if (numReducers != 0) {
return numReducers;
}
return 1 + (int) (BucketUtils.dirSize(getReadSourceName()) / getTargetPartitionSize());
long size = readInputs.keySet().stream().mapToLong(BucketUtils::dirSize).sum();
return 1 + (int) (size / getTargetPartitionSize());
}

/**
@@ -445,8 +477,18 @@ public Collection<Annotation> makeVariantAnnotations() {
/**
* Returns the names of the sources of reads data. Each can be a file name or URL.
*/
protected String getReadSourceName(){
return readInput;
protected List<String> getReadSourceName(){
if (readInputs.size() > 1) {
throw new GATKException("Multiple ReadsDataSources specificed but a single source requested by the tool");
Member: Not ideal maybe, but I don't know what else to do about this method... maybe change it to getReadSourceNames and return a list?

Collaborator Author: Right, I wanted to avoid returning a list because there are a bunch of places in the code where tools expect there to be only one read source kicking around, and I didn't want to risk breaking something or having to uproot everything... I agree it's pretty gross... Theoretically it shouldn't be a problem for most tools, which don't accept multiple inputs anyway.

Collaborator @cmnbroad (Jan 2, 2019): I think we should change it to a list, or even get rid of it. All of the existing consumers except one use it to generate an output name for saving metrics (relatively easy to change); the other one (getRecommendedNumReducers) uses it to determine the number of reducers, which won't work correctly on multiple inputs with this implementation. I think getRecommendedNumReducers should be updated either way.

Collaborator Author: @cmnbroad Alright, I'll return a list. getRecommendedNumReducers has already been updated in this branch to sum over all the read input files.

}
return new ArrayList<>(readInputs.keySet());
}

/**
* Returns a map of read input to header.
*/
protected LinkedHashMap<String, SAMFileHeader> getReadSouceHeaderMap(){
return readInputs;
}

/**
@@ -489,15 +531,37 @@ private void initializeReads(final JavaSparkContext sparkContext) {
return;
}

if ( readArguments.getReadFilesNames().size() != 1 ) {
throw new UserException("Sorry, we only support a single reads input for spark tools for now.");
if (getReadInputMergingPolicy() == ReadInputMergingPolicy.doNotMerge && readArguments.getReadFilesNames().size() != 1 ) {
throw new UserException("Sorry, we only support a single reads input for for this spark tool.");
}

readInput = readArguments.getReadFilesNames().get(0);
readInputs = new LinkedHashMap<>();
readsSource = new ReadsSparkSource(sparkContext, readArguments.getReadValidationStringency());
readsHeader = readsSource.getHeader(
readInput,
hasReference() ? referenceArguments.getReferenceFileName() : null);
for (String input : readArguments.getReadFilesNames()) {
readInputs.put(input, readsSource.getHeader(
input, hasReference() ? referenceArguments.getReferenceFileName() : null));
}
readsHeader = createHeaderMerger().getMergedHeader();
}

/**
* Create a header merger from the individual SAM/BAM headers in our readers
*
* @return a header merger containing all individual headers in this data source
*/
private SamFileHeaderMerger createHeaderMerger() {
return new SamFileHeaderMerger(identifySortOrder(readInputs.values()), readInputs.values(), true);
}
Member: There should be a newline here (suggested change: add a blank line after this closing brace).

@VisibleForTesting
static SAMFileHeader.SortOrder identifySortOrder(final Collection<SAMFileHeader> headers){
final Set<SAMFileHeader.SortOrder> sortOrders = headers.stream().map(SAMFileHeader::getSortOrder).collect(Collectors.toSet());
Member: Clever way to check this.

final SAMFileHeader.SortOrder order;
if (sortOrders.size() == 1) {
order = sortOrders.iterator().next();
} else {
order = SAMFileHeader.SortOrder.unsorted;
}
return order;
}
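
A quick illustration of that fallback (a hypothetical TestNG-style sketch, not part of this PR; it assumes the test lives in the same package as GATKSparkTool, since identifySortOrder is package-private, and the usual imports for SAMFileHeader, Arrays, org.testng.Assert, and @Test):

@Test
public void testIdentifySortOrderFallsBackToUnsorted() {
    final SAMFileHeader coordinateSorted = new SAMFileHeader();
    coordinateSorted.setSortOrder(SAMFileHeader.SortOrder.coordinate);
    final SAMFileHeader queryNameSorted = new SAMFileHeader();
    queryNameSorted.setSortOrder(SAMFileHeader.SortOrder.queryname);

    // All inputs agree on a sort order -> the merged header keeps it.
    Assert.assertEquals(GATKSparkTool.identifySortOrder(Arrays.asList(coordinateSorted, coordinateSorted)),
            SAMFileHeader.SortOrder.coordinate);
    // Any disagreement -> the merged header is marked unsorted.
    Assert.assertEquals(GATKSparkTool.identifySortOrder(Arrays.asList(coordinateSorted, queryNameSorted)),
            SAMFileHeader.SortOrder.unsorted);
}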

@@ -226,7 +226,7 @@ public void addToMetricsFile(final MetricsFile<BaseDistributionByCycleMetrics, ?
protected void runTool(final JavaSparkContext ctx) {
final JavaRDD<GATKRead> reads = getReads();
final MetricsFile<BaseDistributionByCycleMetrics, Integer> metricsFile = calculateBaseDistributionByCycle(reads);
saveResults(metricsFile, getHeaderForReads(), getReadSourceName());
saveResults(metricsFile, getHeaderForReads(), getReadSourceName().get(0));
}

@@ -193,7 +193,7 @@ protected void runTool( final JavaSparkContext ctx ) {
unFilteredReads.filter(r -> readFilter.test(r)),
getHeaderForReads()
);
metricsCollector.saveMetrics(getReadSourceName());
metricsCollector.saveMetrics(getReadSourceName().get(0));
}
}

@@ -197,7 +197,7 @@ HistogramGeneratorPair merge(final HistogramGeneratorPair other){
protected void runTool(final JavaSparkContext ctx) {
final JavaRDD<GATKRead> reads = getReads();
final MetricsFile<?, Integer> metricsFile = calculateMeanQualityByCycle(reads);
saveResults(metricsFile, getHeaderForReads(), getReadSourceName());
saveResults(metricsFile, getHeaderForReads(), getReadSourceName().get(0));
}

@@ -70,7 +70,7 @@ protected void runTool( JavaSparkContext ctx ) {
getHeaderForReads().getSortOrder(),
getExpectedSortOrder(),
false,
getReadSourceName()
getReadSourceName().get(0)
);

// Execute the collector lifecycle
@@ -84,7 +84,7 @@ protected void runTool( JavaSparkContext ctx ) {
initialize(collectorArgs, getHeaderForReads(), getDefaultHeaders());
final JavaRDD<GATKRead> filteredReads = getReads();
collectMetrics(filteredReads, getHeaderForReads());
saveMetrics(getReadSourceName());
saveMetrics(getReadSourceName().get(0));
}

}
@@ -134,7 +134,7 @@ protected void runTool(final JavaSparkContext ctx) {
(counts1, counts2) -> counts1.merge(counts2));

final MetricsFile<?, Byte> metrics = makeMetrics(result);
saveResults(metrics, getHeaderForReads(), getReadSourceName());
saveResults(metrics, getHeaderForReads(), getReadSourceName().get(0));
}

//Convert the count object into a metrics object to save in a report
@@ -17,7 +17,9 @@
import org.broadinstitute.hellbender.engine.filters.ReadFilter;
import org.broadinstitute.hellbender.engine.filters.ReadFilterLibrary;
import org.broadinstitute.hellbender.engine.spark.GATKSparkTool;
import org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSource;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.read.ReadUtils;
@@ -69,6 +71,11 @@ public List<ReadFilter> getDefaultReadFilters() {
// Reads with this marker will be treated and marked as optical duplicates
public static int OPTICAL_DUPLICATE_MARKER = -2;

@Override
public ReadInputMergingPolicy getReadInputMergingPolicy() {
return ReadInputMergingPolicy.concatMerge;
}

/**
* Main method for marking duplicates, takes an JavaRDD of GATKRead and an associated SAMFileHeader with corresponding
* sorting information and returns a new JavaRDD\<GATKRead\> in which all read templates have been marked as duplicates
@@ -215,6 +222,14 @@ public int getPartition(Object key) {

@Override
protected void runTool(final JavaSparkContext ctx) {
// When multiple inputs are provided, check that every input header is queryname-grouped
Map<String, SAMFileHeader> headerMap = getReadSouceHeaderMap();
if (headerMap.size() > 1) {
headerMap.entrySet().stream().forEach(h -> {if(!ReadUtils.isReadNameGroupedBam(h.getValue())) {
throw new UserException("Multiple inputs to MarkDuplicatesSpark detected but input "+h.getKey()+" was sorted in "+h.getValue().getSortOrder()+" order");
}});
}

JavaRDD<GATKRead> reads = getReads();
final OpticalDuplicateFinder finder = opticalDuplicatesArgumentCollection.READ_NAME_REGEX != null ?
new OpticalDuplicateFinder(opticalDuplicatesArgumentCollection.READ_NAME_REGEX, opticalDuplicatesArgumentCollection.OPTICAL_DUPLICATE_PIXEL_DISTANCE, null) : null;