
Added an option to ReadsSparkSink specifying whether to sort the reads on output. #4874

Merged 10 commits on Feb 12, 2019
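For orientation before the diffs, here is a minimal, illustrative sketch of calling the new nine-argument `ReadsSparkSink.writeReads` overload that this PR introduces. The wrapper class, method, variable names, and import paths are assumptions made for the example, not part of the change itself:

```java
import java.io.IOException;

import htsjdk.samtools.SAMFileHeader;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSink;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.read.ReadsWriteFormat;

// Hypothetical caller, shown only to illustrate the new trailing flag.
public final class WriteReadsExample {

    public static void writeUnsorted(final JavaSparkContext ctx,
                                     final JavaRDD<GATKRead> reads,
                                     final SAMFileHeader header) throws IOException {
        // The last argument is the new 'sortReadsToHeader' flag:
        //   true  - sort the reads to the header's sort order before writing (the previous, always-sorting behavior)
        //   false - write the reads in their current order
        ReadsSparkSink.writeReads(ctx, "output.bam", /* referenceFile */ null, reads, header,
                ReadsWriteFormat.SINGLE, /* numReducers (0 = default) */ 0,
                /* outputPartsDir (null = default) */ null, /* sortReadsToHeader */ false);
    }
}
```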
@@ -346,7 +346,7 @@ protected JavaRDD<GATKRead> getGatkReadJavaRDD(TraversalParameters traversalPara
* @param reads reads to write.
*/
public void writeReads(final JavaSparkContext ctx, final String outputFile, JavaRDD<GATKRead> reads) {
writeReads(ctx, outputFile, reads, readsHeader);
writeReads(ctx, outputFile, reads, readsHeader, true);
}

/**
@@ -356,12 +356,12 @@ public void writeReads(final JavaSparkContext ctx, final String outputFile, Java
* @param reads reads to write.
* @param header the header to write.
*/
public void writeReads(final JavaSparkContext ctx, final String outputFile, JavaRDD<GATKRead> reads, SAMFileHeader header) {
public void writeReads(final JavaSparkContext ctx, final String outputFile, JavaRDD<GATKRead> reads, SAMFileHeader header, final boolean sortReadsToHeader) {
try {
ReadsSparkSink.writeReads(ctx, outputFile,
hasReference() ? referenceArguments.getReferencePath().toAbsolutePath().toUri().toString() : null,
reads, header, shardedOutput ? ReadsWriteFormat.SHARDED : ReadsWriteFormat.SINGLE,
getRecommendedNumReducers(), shardedPartsDir, createOutputBamIndex, createOutputBamSplittingIndex);
getRecommendedNumReducers(), shardedPartsDir, createOutputBamIndex, createOutputBamSplittingIndex, sortReadsToHeader);
} catch (IOException e) {
throw new UserException.CouldNotCreateOutputFile(outputFile,"writing failed", e);
}
@@ -52,7 +52,7 @@ public final class ReadsSparkSink {
public static void writeReads(
final JavaSparkContext ctx, final String outputFile, final String referenceFile, final JavaRDD<GATKRead> reads,
final SAMFileHeader header, ReadsWriteFormat format) throws IOException {
writeReads(ctx, outputFile, referenceFile, reads, header, format, 0, null);
writeReads(ctx, outputFile, referenceFile, reads, header, format, 0, null, true);
Review comment (Member): If you think this behavior is important to bwaspark, push it down there. This is going to bite us in the future if we leave it in.

Review comment (Member): Oh, you did, you just forgot the comment.

}

/**
@@ -66,11 +66,12 @@ public static void writeReads(
* @param numReducers the number of reducers to use when writing a single file. A value of zero indicates that the default
* should be used.
* @param outputPartsDir directory for temporary files for SINGLE output format, should be null for default value of filename + .output
* @param sortReadsToHeader if true, the writer will sort the reads according to the header's sort order before writing
*/
public static void writeReads(
final JavaSparkContext ctx, final String outputFile, final String referenceFile, final JavaRDD<GATKRead> reads,
final SAMFileHeader header, ReadsWriteFormat format, final int numReducers, final String outputPartsDir) throws IOException {
writeReads(ctx, outputFile, referenceFile, reads, header, format, numReducers, outputPartsDir, true, true);
final SAMFileHeader header, ReadsWriteFormat format, final int numReducers, final String outputPartsDir, final boolean sortReadsToHeader) throws IOException {
writeReads(ctx, outputFile, referenceFile, reads, header, format, numReducers, outputPartsDir, true, true, sortReadsToHeader);
}

/**
@@ -86,11 +87,12 @@ public static void writeReads(
* @param outputPartsDir directory for temporary files for SINGLE output format, should be null for default value of filename + .output
* @param writeBai whether to write a BAI file (when writing BAM format)
* @param writeSbi whether to write an SBI file (when writing BAM format)
* @param sortReadsToHeader whether to sort the reads in the underlying RDD to match the header's sort order before writing
*/
public static void writeReads(
final JavaSparkContext ctx, final String outputFile, final String referenceFile, final JavaRDD<GATKRead> reads,
final SAMFileHeader header, ReadsWriteFormat format, final int numReducers, final String outputPartsDir,
final boolean writeBai, final boolean writeSbi) throws IOException {
final boolean writeBai, final boolean writeSbi, final boolean sortReadsToHeader) throws IOException {

String absoluteOutputFile = BucketUtils.makeFilePathAbsolute(outputFile);
String absoluteReferenceFile = referenceFile != null ?
@@ -103,6 +105,7 @@ public static void writeReads(
// SAMRecords, this will effectively be a no-op. The SAMRecords will be headerless
// for efficient serialization.
final JavaRDD<SAMRecord> samReads = reads.map(read -> read.convertToSAMRecord(null));
final JavaRDD<SAMRecord> readsToOutput = sortReadsToHeader ? sortSamRecordsToMatchHeader(samReads, header, numReducers) : samReads;

if (format == ReadsWriteFormat.SINGLE) {
FileCardinalityWriteOption fileCardinalityWriteOption = FileCardinalityWriteOption.SINGLE;
@@ -114,12 +117,12 @@ public static void writeReads(
absoluteOutputFile.endsWith(CramIO.CRAM_FILE_EXTENSION) ||
absoluteOutputFile.endsWith(IOUtil.SAM_FILE_EXTENSION)) {
// don't specify a write option for format since it is inferred from the extension in the path
writeReads(ctx, absoluteOutputFile, absoluteReferenceFile, samReads, header, numReducers,
writeReads(ctx, absoluteOutputFile, absoluteReferenceFile, readsToOutput, header, numReducers,
fileCardinalityWriteOption, tempPartsDirectoryWriteOption, baiWriteOption, sbiWriteOption);
} else {
// default to BAM
ReadsFormatWriteOption formatWriteOption = ReadsFormatWriteOption.BAM;
writeReads(ctx, absoluteOutputFile, absoluteReferenceFile, samReads, header, numReducers, formatWriteOption,
writeReads(ctx, absoluteOutputFile, absoluteReferenceFile, readsToOutput, header, numReducers, formatWriteOption,
fileCardinalityWriteOption, tempPartsDirectoryWriteOption, baiWriteOption, sbiWriteOption);
}
} else if (format == ReadsWriteFormat.SHARDED) {
@@ -128,22 +131,21 @@ public static void writeReads(
}
ReadsFormatWriteOption formatWriteOption = ReadsFormatWriteOption.BAM; // use BAM if output file is a directory
FileCardinalityWriteOption fileCardinalityWriteOption = FileCardinalityWriteOption.MULTIPLE;
writeReads(ctx, absoluteOutputFile, absoluteReferenceFile, samReads, header, numReducers, formatWriteOption, fileCardinalityWriteOption);
writeReads(ctx, absoluteOutputFile, absoluteReferenceFile, readsToOutput, header, numReducers, formatWriteOption, fileCardinalityWriteOption);
} else if (format == ReadsWriteFormat.ADAM) {
if (outputPartsDir!=null) {
throw new GATKException(String.format("You specified the bam output parts directory %s, but requested an ADAM output format which does not use this option",outputPartsDir));
}
writeReadsADAM(ctx, absoluteOutputFile, samReads, header);
writeReadsADAM(ctx, absoluteOutputFile, readsToOutput, header);
}
}

private static void writeReads(
final JavaSparkContext ctx, final String outputFile, final String referenceFile, final JavaRDD<SAMRecord> reads,
final SAMFileHeader header, final int numReducers, final WriteOption... writeOptions) throws IOException {

final JavaRDD<SAMRecord> sortedReads = sortSamRecordsToMatchHeader(reads, header, numReducers);
Broadcast<SAMFileHeader> headerBroadcast = ctx.broadcast(header);
final JavaRDD<SAMRecord> sortedReadsWithHeader = sortedReads.map(read -> {
final JavaRDD<SAMRecord> sortedReadsWithHeader = reads.map(read -> {
read.setHeaderStrict(headerBroadcast.getValue());
return read;
});
@@ -299,7 +299,7 @@ protected void runTool(JavaSparkContext ctx) {
final String key = rmap.getKey();
JavaRDD<GATKRead> filteredreads = rmap.getKey()==null? readsReverted :
readsReverted.filter(r -> r.getReadGroup().equals(key));
writeReads(ctx, rmap.getValue().toString(), filteredreads, headerMap.get(rmap.getKey())); //TODO proper header map
writeReads(ctx, rmap.getValue().toString(), filteredreads, headerMap.get(rmap.getKey()), false); //TODO proper header map
}
}

@@ -73,8 +73,10 @@ protected void runTool(final JavaSparkContext ctx) {
reads = bwaEngine.alignPaired(getUnfilteredReads()).filter(filter::test);
}
try {
// NOTE, we must include 'shardedOutput' as the input to `sortReadsToMatchHeader` to preserve the old default behavior for writing BwaSpark output
// where we would not sort the bam if outputting to a sharded output.
ReadsSparkSink.writeReads(ctx, output, null, reads, bwaEngine.getHeader(),
shardedOutput ? ReadsWriteFormat.SHARDED : ReadsWriteFormat.SINGLE);
shardedOutput ? ReadsWriteFormat.SHARDED : ReadsWriteFormat.SINGLE, getRecommendedNumReducers(), shardedPartsDir, shardedOutput);
Review comment (Member): Maybe comment here so people know that something weird is going on?

} catch (final IOException e) {
throw new GATKException("Unable to write aligned reads", e);
}
@@ -180,7 +180,7 @@ private void writeBam(final JavaRDD<GATKRead> reads, final String inputBamPath,
try {
ReadsSparkSink.writeReads(ctx, outputPath, bwaArgs.referencePath, reads, header,
shardedOutput ? ReadsWriteFormat.SHARDED : ReadsWriteFormat.SINGLE,
PSUtils.pathseqGetRecommendedNumReducers(inputBamPath, numReducers, getTargetPartitionSize()), shardedPartsDir);
PSUtils.pathseqGetRecommendedNumReducers(inputBamPath, numReducers, getTargetPartitionSize()), shardedPartsDir, true);
} catch (final IOException e) {
throw new UserException.CouldNotCreateOutputFile(outputPath, "Writing failed", e);
}
@@ -167,13 +167,13 @@ protected void runTool(final JavaSparkContext ctx) {

if (!pairedReads.isEmpty()) {
header.setSortOrder(SAMFileHeader.SortOrder.queryname);
writeReads(ctx, outputPaired, pairedReads, header);
writeReads(ctx, outputPaired, pairedReads, header, true);
} else {
logger.info("No paired reads to write - BAM will not be written.");
}
if (!unpairedReads.isEmpty()) {
header.setSortOrder(SAMFileHeader.SortOrder.unsorted);
writeReads(ctx, outputUnpaired, unpairedReads, header);
writeReads(ctx, outputUnpaired, unpairedReads, header, true);
} else {
logger.info("No unpaired reads to write - BAM will not be written.");
}
@@ -289,7 +289,7 @@ protected void runTool(final JavaSparkContext ctx) {
final int numPartitions = Math.max(1, (int) (numTotalReads / readsPerPartitionOutput));
final JavaRDD<GATKRead> readsFinalRepartitioned = readsFinal.coalesce(numPartitions, false);
ReadsSparkSink.writeReads(ctx, outputPath, null, readsFinalRepartitioned, header,
shardedOutput ? ReadsWriteFormat.SHARDED : ReadsWriteFormat.SINGLE, numPartitions, shardedPartsDir);
shardedOutput ? ReadsWriteFormat.SHARDED : ReadsWriteFormat.SINGLE, numPartitions, shardedPartsDir, true);
} catch (final IOException e) {
throw new UserException.CouldNotCreateOutputFile(outputPath, "writing failed", e);
}
@@ -221,7 +221,7 @@ protected void runTool(final JavaSparkContext ctx) {
if (outputPath != null) {
try {
ReadsSparkSink.writeReads(ctx, outputPath, null, readsFinal, header,
shardedOutput ? ReadsWriteFormat.SHARDED : ReadsWriteFormat.SINGLE, recommendedNumReducers, shardedPartsDir);
shardedOutput ? ReadsWriteFormat.SHARDED : ReadsWriteFormat.SINGLE, recommendedNumReducers, shardedPartsDir, true);
} catch (final IOException e) {
throw new UserException.CouldNotCreateOutputFile(outputPath, "writing failed", e);
}
@@ -72,7 +72,7 @@ protected void runTool(final JavaSparkContext ctx) {
referenceArguments.getReferencePath().toAbsolutePath().toUri().toString(),
markedReads, bwaEngine.getHeader(),
shardedOutput ? ReadsWriteFormat.SHARDED : ReadsWriteFormat.SINGLE,
getRecommendedNumReducers(), shardedPartsDir);
getRecommendedNumReducers(), shardedPartsDir, true);
} catch (IOException e) {
throw new GATKException("unable to write bam: " + e);
}
@@ -32,6 +32,6 @@ public final class PrintReadsSpark extends GATKSparkTool {
protected void runTool(final JavaSparkContext ctx) {

final JavaRDD<GATKRead> reads = getReads();
writeReads(ctx, output, reads);
writeReads(ctx, output, reads, getHeaderForReads(), false);
}
}
@@ -208,7 +208,7 @@ protected void runTool(final JavaSparkContext ctx) {
final JavaRDD<GATKRead> finalReads = ApplyBQSRSparkFn.apply(sortedMarkedReads, reportBroadcast, getHeaderForReads(), applyBqsrArgs.toApplyBQSRArgumentCollection(bqsrArgs.PRESERVE_QSCORES_LESS_THAN));

if (outputBam != null) { // only write output of BQSR if output BAM is specified
writeReads(ctx, outputBam, finalReads, header);
writeReads(ctx, outputBam, finalReads, header, true);
}

// Run Haplotype Caller
@@ -73,12 +73,6 @@ protected void runTool(final JavaSparkContext ctx) {
final SAMFileHeader header = getHeaderForReads();
header.setSortOrder(sortOrder.getSamOrder());

final JavaRDD<GATKRead> readsToWrite;
if (shardedOutput) {
readsToWrite = SparkUtils.sortReadsAccordingToHeader(reads, header, numReducers);
} else {
readsToWrite = reads;
}
writeReads(ctx, outputFile, readsToWrite, header);
writeReads(ctx, outputFile, reads, header, true);
}
}
@@ -90,7 +90,7 @@ protected void runTool( final JavaSparkContext ctx ) {
final Function<GATKRead, Boolean> predicate = getGatkReadBooleanFunction(namesToLookForBroadcast, invertFilter);

final JavaRDD<GATKRead> reads = getUnfilteredReads().filter(predicate).cache();
writeReads(ctx, outputSAM, reads, getHeaderForReads());
writeReads(ctx, outputSAM, reads, getHeaderForReads(), false);

logger.info("Found " + reads.count() + " alignment records for " +
namesToLookForBroadcast.getValue().size() + " unique read names.");
@@ -317,7 +317,7 @@ protected void runTool(final JavaSparkContext ctx) {
}

header.setSortOrder(SAMFileHeader.SortOrder.coordinate);
writeReads(ctx, output, readsForWriting, header);
writeReads(ctx, output, readsForWriting, header, true);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,8 @@ protected void runTool(final JavaSparkContext ctx) {
SAMFileHeader headerForwrite = bHeader.getValue();
headerForwrite.setAttribute("in","original read file source");

writeReads(ctx, output, firstReads.filter(read -> nameSet.value().contains(read.getName())), headerForwrite);
writeReads(ctx, output2, secondReads.filter(read -> nameSet.value().contains(read.getName())), headerForwrite);
writeReads(ctx, output, firstReads.filter(read -> nameSet.value().contains(read.getName())), headerForwrite, true);
writeReads(ctx, output2, secondReads.filter(read -> nameSet.value().contains(read.getName())), headerForwrite, true);
}

// Produces an RDD of MatchTypes, e.g., EQUAL, DIFFERENT_REPRESENTATIVE_READ, etc. per MarkDuplicates key,
@@ -58,8 +58,7 @@ public void printReadSparkOnDataproc(final String input) throws IOException {
DataprocTestUtils.launchGatkTool(PrintReadsSpark.class.getSimpleName(), argBuilder.getArgsList(), clusterName);
final File expected = copyLocally(gcsInputPath, "expected");
final File actual = copyLocally(outputPath, "actual");
IntegrationTestSpec.assertMatchingFiles(Collections.singletonList(actual), Collections.singletonList(expected.toString()), true, ValidationStringency.LENIENT);
assertReadsAreInCoordinatishOrder(actual);
IntegrationTestSpec.assertMatchingFiles(Collections.singletonList(actual), Collections.singletonList(expected.toString()), false, ValidationStringency.LENIENT);
}

private static void assertReadsAreInCoordinatishOrder(final File bam) {
@@ -117,7 +116,6 @@ public void markDuplicatesSparkOnDataproc() throws IOException {
final File actual = copyLocally(bamOut, "actual");

//assert that the output has the right number of reads and they're ordered correctly
assertReadsAreInCoordinatishOrder(actual);
try( ReadsDataSource reader = new ReadsDataSource(actual.toPath())){
Assert.assertEquals(Iterators.size(reader.iterator()), 1838);
}