
Commit 9c22c34

Added an option to ReadsSparkSink specifying whether to sort the reads on output. (#4874)
jamesemery authored Feb 12, 2019
1 parent 566e97c commit 9c22c34
Showing 20 changed files with 505 additions and 40 deletions.
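
For context, a minimal sketch of how a caller might use the new flag (the output path and argument values below are hypothetical; the parameter order follows the new ReadsSparkSink.writeReads signature introduced in this commit):

    // Write an RDD of reads as a single BAM, asking the sink to sort the records
    // to match the header's sort order before writing (sortReadsToHeader = true).
    // Real callers wrap this in try/catch because writeReads throws IOException.
    ReadsSparkSink.writeReads(
            ctx,                      // JavaSparkContext for the tool
            "output.bam",             // hypothetical output path
            null,                     // reference path; may be null, needed when writing CRAM
            reads,                    // JavaRDD<GATKRead> to write
            header,                   // SAMFileHeader whose sort order is honored
            ReadsWriteFormat.SINGLE,  // write a single file rather than shards
            0,                        // numReducers: zero means use the default
            null,                     // outputPartsDir: null means the default of filename + .output
            true);                    // sortReadsToHeader: sort reads to the header's order before writing

Tools whose reads are already in the desired order (for example, PrintReadsSpark below) pass false instead and skip the extra sort.
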
@@ -346,7 +346,7 @@ protected JavaRDD<GATKRead> getGatkReadJavaRDD(TraversalParameters traversalPara
* @param reads reads to write.
*/
public void writeReads(final JavaSparkContext ctx, final String outputFile, JavaRDD<GATKRead> reads) {
writeReads(ctx, outputFile, reads, readsHeader);
writeReads(ctx, outputFile, reads, readsHeader, true);
}

/**
@@ -356,12 +356,12 @@ public void writeReads(final JavaSparkContext ctx, final String outputFile, Java
* @param reads reads to write.
* @param header the header to write.
*/
public void writeReads(final JavaSparkContext ctx, final String outputFile, JavaRDD<GATKRead> reads, SAMFileHeader header) {
public void writeReads(final JavaSparkContext ctx, final String outputFile, JavaRDD<GATKRead> reads, SAMFileHeader header, final boolean sortReadsToHeader) {
try {
ReadsSparkSink.writeReads(ctx, outputFile,
hasReference() ? referenceArguments.getReferencePath().toAbsolutePath().toUri().toString() : null,
reads, header, shardedOutput ? ReadsWriteFormat.SHARDED : ReadsWriteFormat.SINGLE,
getRecommendedNumReducers(), shardedPartsDir, createOutputBamIndex, createOutputBamSplittingIndex);
getRecommendedNumReducers(), shardedPartsDir, createOutputBamIndex, createOutputBamSplittingIndex, sortReadsToHeader);
} catch (IOException e) {
throw new UserException.CouldNotCreateOutputFile(outputFile,"writing failed", e);
}
@@ -52,7 +52,7 @@ public final class ReadsSparkSink {
public static void writeReads(
final JavaSparkContext ctx, final String outputFile, final String referenceFile, final JavaRDD<GATKRead> reads,
final SAMFileHeader header, ReadsWriteFormat format) throws IOException {
writeReads(ctx, outputFile, referenceFile, reads, header, format, 0, null);
writeReads(ctx, outputFile, referenceFile, reads, header, format, 0, null, true);
}

/**
@@ -66,11 +66,12 @@ public static void writeReads(
* @param numReducers the number of reducers to use when writing a single file. A value of zero indicates that the default
* should be used.
* @param outputPartsDir directory for temporary files for SINGLE output format, should be null for default value of filename + .output
* @param sortReadsToHeader if true, the writer will sort the reads to match the header's sort order before writing
*/
public static void writeReads(
final JavaSparkContext ctx, final String outputFile, final String referenceFile, final JavaRDD<GATKRead> reads,
final SAMFileHeader header, ReadsWriteFormat format, final int numReducers, final String outputPartsDir) throws IOException {
writeReads(ctx, outputFile, referenceFile, reads, header, format, numReducers, outputPartsDir, true, true);
final SAMFileHeader header, ReadsWriteFormat format, final int numReducers, final String outputPartsDir, final boolean sortReadsToHeader) throws IOException {
writeReads(ctx, outputFile, referenceFile, reads, header, format, numReducers, outputPartsDir, true, true, sortReadsToHeader);
}

/**
@@ -86,11 +87,12 @@ public static void writeReads(
* @param outputPartsDir directory for temporary files for SINGLE output format, should be null for default value of filename + .output
* @param writeBai whether to write a BAI file (when writing BAM format)
* @param writeSbi whether to write an SBI file (when writing BAM format)
* @param sortReadsToHeader whether to sort the reads in the underlying RDD to match the header's sort order before writing
*/
public static void writeReads(
final JavaSparkContext ctx, final String outputFile, final String referenceFile, final JavaRDD<GATKRead> reads,
final SAMFileHeader header, ReadsWriteFormat format, final int numReducers, final String outputPartsDir,
final boolean writeBai, final boolean writeSbi) throws IOException {
final boolean writeBai, final boolean writeSbi, final boolean sortReadsToHeader) throws IOException {

String absoluteOutputFile = BucketUtils.makeFilePathAbsolute(outputFile);
String absoluteReferenceFile = referenceFile != null ?
Expand All @@ -103,6 +105,7 @@ public static void writeReads(
// SAMRecords, this will effectively be a no-op. The SAMRecords will be headerless
// for efficient serialization.
final JavaRDD<SAMRecord> samReads = reads.map(read -> read.convertToSAMRecord(null));
final JavaRDD<SAMRecord> readsToOutput = sortReadsToHeader ? sortSamRecordsToMatchHeader(samReads, header, numReducers) : samReads;

if (format == ReadsWriteFormat.SINGLE) {
FileCardinalityWriteOption fileCardinalityWriteOption = FileCardinalityWriteOption.SINGLE;
@@ -114,12 +117,12 @@ public static void writeReads(
absoluteOutputFile.endsWith(CramIO.CRAM_FILE_EXTENSION) ||
absoluteOutputFile.endsWith(IOUtil.SAM_FILE_EXTENSION)) {
// don't specify a write option for format since it is inferred from the extension in the path
writeReads(ctx, absoluteOutputFile, absoluteReferenceFile, samReads, header, numReducers,
writeReads(ctx, absoluteOutputFile, absoluteReferenceFile, readsToOutput, header, numReducers,
fileCardinalityWriteOption, tempPartsDirectoryWriteOption, baiWriteOption, sbiWriteOption);
} else {
// default to BAM
ReadsFormatWriteOption formatWriteOption = ReadsFormatWriteOption.BAM;
writeReads(ctx, absoluteOutputFile, absoluteReferenceFile, samReads, header, numReducers, formatWriteOption,
writeReads(ctx, absoluteOutputFile, absoluteReferenceFile, readsToOutput, header, numReducers, formatWriteOption,
fileCardinalityWriteOption, tempPartsDirectoryWriteOption, baiWriteOption, sbiWriteOption);
}
} else if (format == ReadsWriteFormat.SHARDED) {
@@ -128,22 +131,21 @@ public static void writeReads(
}
ReadsFormatWriteOption formatWriteOption = ReadsFormatWriteOption.BAM; // use BAM if output file is a directory
FileCardinalityWriteOption fileCardinalityWriteOption = FileCardinalityWriteOption.MULTIPLE;
writeReads(ctx, absoluteOutputFile, absoluteReferenceFile, samReads, header, numReducers, formatWriteOption, fileCardinalityWriteOption);
writeReads(ctx, absoluteOutputFile, absoluteReferenceFile, readsToOutput, header, numReducers, formatWriteOption, fileCardinalityWriteOption);
} else if (format == ReadsWriteFormat.ADAM) {
if (outputPartsDir!=null) {
throw new GATKException(String.format("You specified the bam output parts directory %s, but requested an ADAM output format which does not use this option",outputPartsDir));
}
writeReadsADAM(ctx, absoluteOutputFile, samReads, header);
writeReadsADAM(ctx, absoluteOutputFile, readsToOutput, header);
}
}

private static void writeReads(
final JavaSparkContext ctx, final String outputFile, final String referenceFile, final JavaRDD<SAMRecord> reads,
final SAMFileHeader header, final int numReducers, final WriteOption... writeOptions) throws IOException {

final JavaRDD<SAMRecord> sortedReads = sortSamRecordsToMatchHeader(reads, header, numReducers);
Broadcast<SAMFileHeader> headerBroadcast = ctx.broadcast(header);
final JavaRDD<SAMRecord> sortedReadsWithHeader = sortedReads.map(read -> {
final JavaRDD<SAMRecord> sortedReadsWithHeader = reads.map(read -> {
read.setHeaderStrict(headerBroadcast.getValue());
return read;
});
@@ -299,7 +299,7 @@ protected void runTool(JavaSparkContext ctx) {
final String key = rmap.getKey();
JavaRDD<GATKRead> filteredreads = rmap.getKey()==null? readsReverted :
readsReverted.filter(r -> r.getReadGroup().equals(key));
writeReads(ctx, rmap.getValue().toString(), filteredreads, headerMap.get(rmap.getKey())); //TODO proper header map
writeReads(ctx, rmap.getValue().toString(), filteredreads, headerMap.get(rmap.getKey()), false); //TODO proper header map
}
}

@@ -73,8 +73,10 @@ protected void runTool(final JavaSparkContext ctx) {
reads = bwaEngine.alignPaired(getUnfilteredReads()).filter(filter::test);
}
try {
// NOTE: we pass 'shardedOutput' as the sortReadsToHeader argument to preserve the old default behavior for writing BwaSpark output,
// where the bam was not sorted when writing to a sharded output.
ReadsSparkSink.writeReads(ctx, output, null, reads, bwaEngine.getHeader(),
shardedOutput ? ReadsWriteFormat.SHARDED : ReadsWriteFormat.SINGLE);
shardedOutput ? ReadsWriteFormat.SHARDED : ReadsWriteFormat.SINGLE, getRecommendedNumReducers(), shardedPartsDir, shardedOutput);
} catch (final IOException e) {
throw new GATKException("Unable to write aligned reads", e);
}
@@ -180,7 +180,7 @@ private void writeBam(final JavaRDD<GATKRead> reads, final String inputBamPath,
try {
ReadsSparkSink.writeReads(ctx, outputPath, bwaArgs.referencePath, reads, header,
shardedOutput ? ReadsWriteFormat.SHARDED : ReadsWriteFormat.SINGLE,
PSUtils.pathseqGetRecommendedNumReducers(inputBamPath, numReducers, getTargetPartitionSize()), shardedPartsDir);
PSUtils.pathseqGetRecommendedNumReducers(inputBamPath, numReducers, getTargetPartitionSize()), shardedPartsDir, true);
} catch (final IOException e) {
throw new UserException.CouldNotCreateOutputFile(outputPath, "Writing failed", e);
}
@@ -167,13 +167,13 @@ protected void runTool(final JavaSparkContext ctx) {

if (!pairedReads.isEmpty()) {
header.setSortOrder(SAMFileHeader.SortOrder.queryname);
writeReads(ctx, outputPaired, pairedReads, header);
writeReads(ctx, outputPaired, pairedReads, header, true);
} else {
logger.info("No paired reads to write - BAM will not be written.");
}
if (!unpairedReads.isEmpty()) {
header.setSortOrder(SAMFileHeader.SortOrder.unsorted);
writeReads(ctx, outputUnpaired, unpairedReads, header);
writeReads(ctx, outputUnpaired, unpairedReads, header, true);
} else {
logger.info("No unpaired reads to write - BAM will not be written.");
}
@@ -289,7 +289,7 @@ protected void runTool(final JavaSparkContext ctx) {
final int numPartitions = Math.max(1, (int) (numTotalReads / readsPerPartitionOutput));
final JavaRDD<GATKRead> readsFinalRepartitioned = readsFinal.coalesce(numPartitions, false);
ReadsSparkSink.writeReads(ctx, outputPath, null, readsFinalRepartitioned, header,
shardedOutput ? ReadsWriteFormat.SHARDED : ReadsWriteFormat.SINGLE, numPartitions, shardedPartsDir);
shardedOutput ? ReadsWriteFormat.SHARDED : ReadsWriteFormat.SINGLE, numPartitions, shardedPartsDir, true);
} catch (final IOException e) {
throw new UserException.CouldNotCreateOutputFile(outputPath, "writing failed", e);
}
@@ -221,7 +221,7 @@ protected void runTool(final JavaSparkContext ctx) {
if (outputPath != null) {
try {
ReadsSparkSink.writeReads(ctx, outputPath, null, readsFinal, header,
shardedOutput ? ReadsWriteFormat.SHARDED : ReadsWriteFormat.SINGLE, recommendedNumReducers, shardedPartsDir);
shardedOutput ? ReadsWriteFormat.SHARDED : ReadsWriteFormat.SINGLE, recommendedNumReducers, shardedPartsDir, true);
} catch (final IOException e) {
throw new UserException.CouldNotCreateOutputFile(outputPath, "writing failed", e);
}
@@ -72,7 +72,7 @@ protected void runTool(final JavaSparkContext ctx) {
referenceArguments.getReferencePath().toAbsolutePath().toUri().toString(),
markedReads, bwaEngine.getHeader(),
shardedOutput ? ReadsWriteFormat.SHARDED : ReadsWriteFormat.SINGLE,
getRecommendedNumReducers(), shardedPartsDir);
getRecommendedNumReducers(), shardedPartsDir, true);
} catch (IOException e) {
throw new GATKException("unable to write bam: " + e);
}
@@ -32,6 +32,6 @@ public final class PrintReadsSpark extends GATKSparkTool {
protected void runTool(final JavaSparkContext ctx) {

final JavaRDD<GATKRead> reads = getReads();
writeReads(ctx, output, reads);
writeReads(ctx, output, reads, getHeaderForReads(), false);
}
}
@@ -208,7 +208,7 @@ protected void runTool(final JavaSparkContext ctx) {
final JavaRDD<GATKRead> finalReads = ApplyBQSRSparkFn.apply(sortedMarkedReads, reportBroadcast, getHeaderForReads(), applyBqsrArgs.toApplyBQSRArgumentCollection(bqsrArgs.PRESERVE_QSCORES_LESS_THAN));

if (outputBam != null) { // only write output of BQSR if output BAM is specified
writeReads(ctx, outputBam, finalReads, header);
writeReads(ctx, outputBam, finalReads, header, true);
}

// Run Haplotype Caller
@@ -73,12 +73,6 @@ protected void runTool(final JavaSparkContext ctx) {
final SAMFileHeader header = getHeaderForReads();
header.setSortOrder(sortOrder.getSamOrder());

final JavaRDD<GATKRead> readsToWrite;
if (shardedOutput) {
readsToWrite = SparkUtils.sortReadsAccordingToHeader(reads, header, numReducers);
} else {
readsToWrite = reads;
}
writeReads(ctx, outputFile, readsToWrite, header);
writeReads(ctx, outputFile, reads, header, true);
}
}
@@ -90,7 +90,7 @@ protected void runTool( final JavaSparkContext ctx ) {
final Function<GATKRead, Boolean> predicate = getGatkReadBooleanFunction(namesToLookForBroadcast, invertFilter);

final JavaRDD<GATKRead> reads = getUnfilteredReads().filter(predicate).cache();
writeReads(ctx, outputSAM, reads, getHeaderForReads());
writeReads(ctx, outputSAM, reads, getHeaderForReads(), false);

logger.info("Found " + reads.count() + " alignment records for " +
namesToLookForBroadcast.getValue().size() + " unique read names.");
@@ -317,7 +317,7 @@ protected void runTool(final JavaSparkContext ctx) {
}

header.setSortOrder(SAMFileHeader.SortOrder.coordinate);
writeReads(ctx, output, readsForWriting, header);
writeReads(ctx, output, readsForWriting, header, true);
}

}
@@ -216,8 +216,8 @@ protected void runTool(final JavaSparkContext ctx) {
SAMFileHeader headerForwrite = bHeader.getValue();
headerForwrite.setAttribute("in","original read file source");

writeReads(ctx, output, firstReads.filter(read -> nameSet.value().contains(read.getName())), headerForwrite);
writeReads(ctx, output2, secondReads.filter(read -> nameSet.value().contains(read.getName())), headerForwrite);
writeReads(ctx, output, firstReads.filter(read -> nameSet.value().contains(read.getName())), headerForwrite, true);
writeReads(ctx, output2, secondReads.filter(read -> nameSet.value().contains(read.getName())), headerForwrite, true);
}

// Produces an RDD of MatchTypes, e.g., EQUAL, DIFFERENT_REPRESENTATIVE_READ, etc. per MarkDuplicates key,
@@ -58,8 +58,7 @@ public void printReadSparkOnDataproc(final String input) throws IOException {
DataprocTestUtils.launchGatkTool(PrintReadsSpark.class.getSimpleName(), argBuilder.getArgsList(), clusterName);
final File expected = copyLocally(gcsInputPath, "expected");
final File actual = copyLocally(outputPath, "actual");
IntegrationTestSpec.assertMatchingFiles(Collections.singletonList(actual), Collections.singletonList(expected.toString()), true, ValidationStringency.LENIENT);
assertReadsAreInCoordinatishOrder(actual);
IntegrationTestSpec.assertMatchingFiles(Collections.singletonList(actual), Collections.singletonList(expected.toString()), false, ValidationStringency.LENIENT);
}

private static void assertReadsAreInCoordinatishOrder(final File bam) {
@@ -117,7 +116,6 @@ public void markDuplicatesSparkOnDataproc() throws IOException {
final File actual = copyLocally(bamOut, "actual");

//assert that the output has the right number of reads and they're ordered correctly
assertReadsAreInCoordinatishOrder(actual);
try( ReadsDataSource reader = new ReadsDataSource(actual.toPath())){
Assert.assertEquals(Iterators.size(reader.iterator()), 1838);
}