Skip to content

Commit

Permalink
Run HaplotypeCallerSpark on WGS in strict mode (#5721)
Browse files Browse the repository at this point in the history
* Perform downsampling in AssemblyRegionWalkerSpark's strict mode. Fixes #5476.

* Reduce ActivityProfileStateRange memory usage so it can run on genome-sized data.

* Repartition data so it is spread evenly across the cluster after finding assembly regions
  • Loading branch information
tomwhite authored Mar 8, 2019
1 parent bd61efb commit 02dca71
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,20 @@ public static JavaRDD<AssemblyRegionWalkerContext> getAssemblyRegionsStrict(
// at which points the reads can be filled in. (See next step.)
JavaRDD<ReadlessAssemblyRegion> readlessAssemblyRegions = contigToGroupedStates
.flatMap(getReadlessAssemblyRegionsFunction(header, assemblyRegionArgs));
// repartition to distribute the data evenly across the cluster again
readlessAssemblyRegions = readlessAssemblyRegions.repartition(readlessAssemblyRegions.getNumPartitions());

// 4. Fill in the reads. Each shard is an assembly region, with its overlapping reads.
JavaRDD<Shard<GATKRead>> assemblyRegionShardedReads = SparkSharder.shard(ctx, reads, GATKRead.class, header.getSequenceDictionary(), readlessAssemblyRegions, shardingArgs.readShardSize);

// 5. Convert shards to assembly regions.
JavaRDD<AssemblyRegion> assemblyRegions = assemblyRegionShardedReads.map((Function<Shard<GATKRead>, AssemblyRegion>) shard -> toAssemblyRegion(shard, header));
// 5. Convert shards to assembly regions. Reads downsampling is done again here. Note it will only be
// consistent with the downsampling done in step 1 when https://github.com/broadinstitute/gatk/issues/5437 is in.
JavaRDD<AssemblyRegion> assemblyRegions = assemblyRegionShardedReads.mapPartitions((FlatMapFunction<Iterator<Shard<GATKRead>>, AssemblyRegion>) shardedReadIterator -> {
final ReadsDownsampler readsDownsampler = assemblyRegionArgs.maxReadsPerAlignmentStart > 0 ?
new PositionalDownsampler(assemblyRegionArgs.maxReadsPerAlignmentStart, header) : null;
return Utils.stream(shardedReadIterator)
.map(shardedRead -> toAssemblyRegion(shardedRead, header, readsDownsampler)).iterator();
});

// 6. Add reference and feature context.
return assemblyRegions.mapPartitions(getAssemblyRegionWalkerContextFunction(referenceFileName, bFeatureManager));
Expand Down Expand Up @@ -221,12 +229,16 @@ public ReadlessAssemblyRegion apply(@Nullable AssemblyRegion input) {
});
}

private static AssemblyRegion toAssemblyRegion(Shard<GATKRead> shard, SAMFileHeader header) {
private static AssemblyRegion toAssemblyRegion(Shard<GATKRead> shard, SAMFileHeader header, ReadsDownsampler readsDownsampler) {
Shard<GATKRead> downsampledShardedRead =
new DownsampleableSparkReadShard(
new ShardBoundary(shard.getInterval(), shard.getPaddedInterval()), shard, readsDownsampler);

// TODO: interfaces could be improved to avoid casting
ReadlessAssemblyRegion readlessAssemblyRegion = (ReadlessAssemblyRegion) ((ShardBoundaryShard<GATKRead>) shard).getShardBoundary();
int extension = Math.max(shard.getInterval().getStart() - shard.getPaddedInterval().getStart(), shard.getPaddedInterval().getEnd() - shard.getInterval().getEnd());
AssemblyRegion assemblyRegion = new AssemblyRegion(shard.getInterval(), Collections.emptyList(), readlessAssemblyRegion.isActive(), extension, header);
assemblyRegion.addAll(Lists.newArrayList(shard));
assemblyRegion.addAll(Lists.newArrayList(downsampledShardedRead));
return assemblyRegion;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@ public class ActivityProfileStateRange {
private final SimpleInterval interval;
private final double[] activeProb;
private final ActivityProfileState.Type[] resultState;
private final Number[] resultValue;
private final double[] resultValue; // don't store as a Number since it uses more memory

public ActivityProfileStateRange(MultiIntervalShard<?> shard, Iterator<ActivityProfileState> activityProfileStateIterator) {
List<SimpleInterval> intervals = shard.getIntervals();
this.interval = Iterables.getOnlyElement(intervals);
int size = interval.size();
this.activeProb = new double[size];
this.resultState = new ActivityProfileState.Type[size];
this.resultValue = new Number[size];
this.resultValue = new double[size];

int i = 0;
ActivityProfileState prev = null;
Expand All @@ -39,7 +39,8 @@ public ActivityProfileStateRange(MultiIntervalShard<?> shard, Iterator<ActivityP
}
activeProb[i] = next.isActiveProb();
resultState[i] = next.getResultState();
resultValue[i] = next.getResultValue();
// store null result value as a negative number, since negative numbers are illegal in ActivityProfileState
resultValue[i] = next.getResultValue() == null ? Double.NEGATIVE_INFINITY : next.getResultValue().doubleValue();
i++;
prev = next;
}
Expand All @@ -59,7 +60,8 @@ protected ActivityProfileState computeNext() {
return endOfData();
}
int pos = interval.getStart() + i;
ActivityProfileState state = new ActivityProfileState(new SimpleInterval(interval.getContig(), pos, pos), activeProb[i], resultState[i], resultValue[i]);
double v = resultValue[i];
ActivityProfileState state = new ActivityProfileState(new SimpleInterval(interval.getContig(), pos, pos), activeProb[i], resultState[i], v == Double.NEGATIVE_INFINITY ? null : v);
i++;
return state;
}
Expand Down

0 comments on commit 02dca71

Please sign in to comment.