From 3ce718deb67a8aa069e878920d7e6a26b621bc1e Mon Sep 17 00:00:00 2001 From: Dror Kessler Date: Tue, 14 Jun 2022 15:19:12 +0100 Subject: [PATCH] Lookahead Covariates moved to UG_feature_branch --- .../tools/walkers/bqsr/BaseRecalibrator.java | 7 + .../RecalibrationArgumentCollection.java | 21 +++ .../recalibration/RecalibrationReport.java | 31 +++- .../covariates/AltCovariate.java | 53 ++++++ .../covariates/ContextCovariate.java | 162 +++++++++++++++++- .../covariates/StandardCovariateList.java | 11 +- 6 files changed, 268 insertions(+), 17 deletions(-) create mode 100644 src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/AltCovariate.java diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/bqsr/BaseRecalibrator.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/bqsr/BaseRecalibrator.java index e4bdc890c14..787d2202b48 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/walkers/bqsr/BaseRecalibrator.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/bqsr/BaseRecalibrator.java @@ -145,6 +145,13 @@ public void onTraversalStart() { Utils.warnOnNonIlluminaReadGroups(getHeaderForReads(), logger); + // grab reference if not specified in case of using an extended context covariate + if ( recalArgs.EXTENDED_CONTEXT_LOOKAHEAD != 0 ) { + if ( recalArgs.EXTENDED_CONTEXT_REFERENCE == null ) { + recalArgs.EXTENDED_CONTEXT_REFERENCE = new GATKPath(referenceArguments.getReferencePath().toAbsolutePath().toString()); + } + } + recalibrationEngine = new BaseRecalibrationEngine(recalArgs, getHeaderForReads()); recalibrationEngine.logCovariatesUsed(); referenceDataSource = ReferenceDataSource.of(referenceArguments.getReferencePath()); diff --git a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/RecalibrationArgumentCollection.java b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/RecalibrationArgumentCollection.java index 01f1d853432..a770834ba81 100644 --- 
a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/RecalibrationArgumentCollection.java +++ b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/RecalibrationArgumentCollection.java @@ -3,6 +3,7 @@ import org.broadinstitute.barclay.argparser.Argument; import org.broadinstitute.barclay.argparser.Hidden; import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions; +import org.broadinstitute.hellbender.engine.GATKPath; import org.broadinstitute.hellbender.utils.QualityUtils; import org.broadinstitute.hellbender.utils.baq.BAQ; import org.broadinstitute.hellbender.utils.report.GATKReportTable; @@ -42,6 +43,18 @@ public final class RecalibrationArgumentCollection implements Serializable { @Argument(fullName = "indels-context-size", shortName = "ics", doc = "Size of the k-mer context to be used for base insertions and deletions", optional = true) public int INDELS_CONTEXT_SIZE = 3; + /** + * experimental parameter indicating that we are using the extended version of the context covariate + */ + @Argument(fullName = "extended-context-lookahead", shortName = "ecl", doc = "amount of lookahead (out of the context) when using the extended context covariate", optional = true) + public int EXTENDED_CONTEXT_LOOKAHEAD; + + @Argument(fullName = "extended-context-reference", shortName = "ecr", doc = "reference for assembling the extended context", optional = true) + public GATKPath EXTENDED_CONTEXT_REFERENCE; + + @Argument(fullName = "extended-context-alt-embedded", shortName = "ecae", doc = "alt base, when extending context, embedded in context", optional = true) + public boolean EXTENDED_CONTEXT_ALT_EMBEDDED; + /** * The cycle covariate will generate an error if it encounters a cycle greater than this value. * This argument is ignored if the Cycle covariate is not used. 
@@ -164,6 +177,12 @@ public GATKReportTable generateReportTable(final String covariateNames) { argumentsTable.set("mismatches_context_size", RecalUtils.ARGUMENT_VALUE_COLUMN_NAME, MISMATCHES_CONTEXT_SIZE); argumentsTable.addRowID("indels_context_size", true); argumentsTable.set("indels_context_size", RecalUtils.ARGUMENT_VALUE_COLUMN_NAME, INDELS_CONTEXT_SIZE); + argumentsTable.addRowID("extended_context_lookahead", true); + argumentsTable.set("extended_context_lookahead", RecalUtils.ARGUMENT_VALUE_COLUMN_NAME, EXTENDED_CONTEXT_LOOKAHEAD); + argumentsTable.addRowID("extended_context_reference", true); + argumentsTable.set("extended_context_reference", RecalUtils.ARGUMENT_VALUE_COLUMN_NAME, EXTENDED_CONTEXT_REFERENCE); + argumentsTable.addRowID("extended_context_alt_embedded", true); + argumentsTable.set("extended_context_alt_embedded", RecalUtils.ARGUMENT_VALUE_COLUMN_NAME, EXTENDED_CONTEXT_ALT_EMBEDDED); argumentsTable.addRowID("mismatches_default_quality", true); argumentsTable.set("mismatches_default_quality", RecalUtils.ARGUMENT_VALUE_COLUMN_NAME, MISMATCHES_DEFAULT_QUALITY); argumentsTable.addRowID("deletions_default_quality", true); @@ -212,6 +231,8 @@ public Map compareReportArguments(final Recalibra compareSimpleReportArgument(result,"solid_recal_mode", SOLID_RECAL_MODE, SOLID_RECAL_MODE,thisRole,otherRole); compareSimpleReportArgument(result,"solid_nocall_strategy", SOLID_NOCALL_STRATEGY, SOLID_NOCALL_STRATEGY,thisRole,otherRole); compareSimpleReportArgument(result,"mismatches_context_size", MISMATCHES_CONTEXT_SIZE,other.MISMATCHES_CONTEXT_SIZE,thisRole,otherRole); + compareSimpleReportArgument(result,"extended_context_lookahead", EXTENDED_CONTEXT_LOOKAHEAD,other.EXTENDED_CONTEXT_LOOKAHEAD,thisRole,otherRole); + compareSimpleReportArgument(result,"extended_context_reference", EXTENDED_CONTEXT_REFERENCE,other.EXTENDED_CONTEXT_REFERENCE,thisRole,otherRole); compareSimpleReportArgument(result,"mismatches_default_quality", MISMATCHES_DEFAULT_QUALITY, 
other.MISMATCHES_DEFAULT_QUALITY,thisRole,otherRole); compareSimpleReportArgument(result,"deletions_default_quality", DELETIONS_DEFAULT_QUALITY, other.DELETIONS_DEFAULT_QUALITY,thisRole,otherRole); compareSimpleReportArgument(result,"insertions_default_quality", INSERTIONS_DEFAULT_QUALITY, other.INSERTIONS_DEFAULT_QUALITY,thisRole,otherRole); diff --git a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/RecalibrationReport.java b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/RecalibrationReport.java index 8e5ba8a7d2a..9e7e45aebda 100644 --- a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/RecalibrationReport.java +++ b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/RecalibrationReport.java @@ -4,6 +4,7 @@ import org.apache.commons.collections.CollectionUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.broadinstitute.hellbender.engine.GATKPath; import org.broadinstitute.hellbender.exceptions.GATKException; import org.broadinstitute.hellbender.exceptions.UserException; import org.broadinstitute.hellbender.utils.QualityUtils; @@ -325,6 +326,30 @@ private static QuantizationInfo initializeQuantizationTable(GATKReportTable tabl private static RecalibrationArgumentCollection initializeArgumentCollectionTable(GATKReportTable table) { final RecalibrationArgumentCollection RAC = new RecalibrationArgumentCollection(); + // peek for lookahead argument - since it may change the list of covariates + for ( int i = 0; i < table.getNumRows(); i++ ) { + final String argument = table.get(i, "Argument").toString(); + Object value = table.get(i, RecalUtils.ARGUMENT_VALUE_COLUMN_NAME); + if (value.equals("null")) { + value = null; // generic translation of null values that were printed out as strings | todo -- add this capability to the GATKReport + } + if (argument.equals("extended_context_lookahead")) + RAC.EXTENDED_CONTEXT_LOOKAHEAD = decodeInteger(value); + + else 
if (argument.equals("extended_context_reference")) + RAC.EXTENDED_CONTEXT_REFERENCE = (value != null) ? new GATKPath((String)value) : null; + + else if (argument.equals("extended_context_alt_embedded")) + RAC.EXTENDED_CONTEXT_ALT_EMBEDDED = decodeBoolean(value); + + else if (argument.equals("mismatches_context_size")) + RAC.MISMATCHES_CONTEXT_SIZE = decodeInteger(value); + + else if (argument.equals("indels_context_size")) + RAC.INDELS_CONTEXT_SIZE = decodeInteger(value); + } + + final List standardCovariateClassNames = new StandardCovariateList(RAC, Collections.emptyList()).getStandardCovariateClassNames(); for ( int i = 0; i < table.getNumRows(); i++ ) { @@ -356,12 +381,6 @@ else if (argument.equals("solid_nocall_strategy")) { throw new UserException("Solid is not supported. Only " + RecalibrationArgumentCollection.SOLID_NOCALL_STRATEGY + " is allowed as value for solid_nocall_strategy"); } } - else if (argument.equals("mismatches_context_size")) - RAC.MISMATCHES_CONTEXT_SIZE = decodeInteger(value); - - else if (argument.equals("indels_context_size")) - RAC.INDELS_CONTEXT_SIZE = decodeInteger(value); - else if (argument.equals("mismatches_default_quality")) RAC.MISMATCHES_DEFAULT_QUALITY = decodeByte(value); diff --git a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/AltCovariate.java b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/AltCovariate.java new file mode 100644 index 00000000000..c6cfa1c9640 --- /dev/null +++ b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/AltCovariate.java @@ -0,0 +1,53 @@ +package org.broadinstitute.hellbender.utils.recalibration.covariates; + +import htsjdk.samtools.SAMFileHeader; +import org.broadinstitute.hellbender.utils.BaseUtils; +import org.broadinstitute.hellbender.utils.read.GATKRead; +import org.broadinstitute.hellbender.utils.recalibration.RecalibrationArgumentCollection; + +public final class AltCovariate implements Covariate { + private 
static final long serialVersionUID = 1L; + + public AltCovariate(final RecalibrationArgumentCollection RAC){ + } + + // Used to pick out the covariate's value from attributes of the read + @Override + public void recordValues(final GATKRead read, final SAMFileHeader header, final ReadCovariates values, final boolean recordIndelValues) { + final int readLength = read.getLength(); + final byte[] readBases = read.getBasesNoCopy(); + for (int i = 0; i < readLength; i++) { + final int baseIndex = BaseUtils.simpleBaseToBaseIndex(readBases[i]); + values.addCovariate(baseIndex, 0, 0, i); + } + } + + @Override + public String formatKey(final int key){ + if ( key < 0 ) { + return "N"; + } else { + return String.format("%c", (char)BaseUtils.baseIndexToSimpleBase(key)); + } + } + + @Override + public int keyFromValue(final Object value) { + if ( (value instanceof String)) { + final byte base = ((String)value).getBytes()[0]; + if ( base == 'N' ) { + return -1; + } else { + return BaseUtils.simpleBaseToBaseIndex(base); + } + + } else { + return (Integer)value; + } + } + + @Override + public int maximumKeyValue() { + return BaseUtils.Base.values().length; + } +} \ No newline at end of file diff --git a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/ContextCovariate.java b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/ContextCovariate.java index a2f815d7857..e58ba11d231 100644 --- a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/ContextCovariate.java +++ b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/ContextCovariate.java @@ -1,20 +1,26 @@ package org.broadinstitute.hellbender.utils.recalibration.covariates; import com.google.common.annotations.VisibleForTesting; +import htsjdk.samtools.CigarElement; +import htsjdk.samtools.CigarOperator; import htsjdk.samtools.SAMFileHeader; import it.unimi.dsi.fastutil.ints.IntArrayList; import it.unimi.dsi.fastutil.ints.IntList; 
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.broadinstitute.barclay.argparser.CommandLineException; +import org.broadinstitute.hellbender.engine.ReferenceDataSource; import org.broadinstitute.hellbender.exceptions.GATKException; import org.broadinstitute.hellbender.utils.BaseUtils; +import org.broadinstitute.hellbender.utils.Utils; import org.broadinstitute.hellbender.utils.clipping.ClippingRepresentation; import org.broadinstitute.hellbender.utils.clipping.ReadClipper; import org.broadinstitute.hellbender.utils.read.GATKRead; import org.broadinstitute.hellbender.utils.recalibration.RecalibrationArgumentCollection; -public final class ContextCovariate implements Covariate { +import java.util.Arrays; + +public class ContextCovariate implements Covariate { private static final long serialVersionUID = 1L; private static final Logger logger = LogManager.getLogger(ContextCovariate.class); @@ -24,7 +30,7 @@ public final class ContextCovariate implements Covariate { private final int mismatchesKeyMask; private final int indelsKeyMask; - private static final int LENGTH_BITS = 4; + protected static final int LENGTH_BITS = 4; private static final int LENGTH_MASK = 15; @@ -33,6 +39,16 @@ public final class ContextCovariate implements Covariate { private static final int MAX_DNA_CONTEXT = 13; private final byte lowQualTail; + /** + * amount of lookahead allocated from the context size. 
The first byte of the lookahead is always taken from the reference + * + * For example, if the context size is 3 and the amount of lookahead allocated is 2 then + * the context will be composed of 1 byte before the location, 1 byte from the reference and 1 byte after the location + */ + private int lookaheadSize; + private ReferenceDataSource referenceDataSource; + private boolean altEmbedded; + public ContextCovariate(final RecalibrationArgumentCollection RAC){ mismatchesContextSize = RAC.MISMATCHES_CONTEXT_SIZE; indelsContextSize = RAC.INDELS_CONTEXT_SIZE; @@ -53,6 +69,16 @@ public ContextCovariate(final RecalibrationArgumentCollection RAC){ mismatchesKeyMask = createMask(mismatchesContextSize); indelsKeyMask = createMask(indelsContextSize); + + if ( RAC.EXTENDED_CONTEXT_LOOKAHEAD != 0 ) { + Utils.nonNull(RAC.EXTENDED_CONTEXT_REFERENCE, "extended context reference can not be null"); + Utils.validate(RAC.EXTENDED_CONTEXT_LOOKAHEAD > 0, "lookahead must be positive"); + Utils.validate(RAC.EXTENDED_CONTEXT_LOOKAHEAD <= Math.min(RAC.MISMATCHES_CONTEXT_SIZE, RAC.INDELS_CONTEXT_SIZE), "lookahead can not be larger than the context"); + + this.lookaheadSize = RAC.EXTENDED_CONTEXT_LOOKAHEAD; + this.referenceDataSource = ReferenceDataSource.of(RAC.EXTENDED_CONTEXT_REFERENCE.toPath()); + this.altEmbedded = RAC.EXTENDED_CONTEXT_ALT_EMBEDDED; + } } @Override @@ -62,10 +88,13 @@ public void recordValues(final GATKRead read, final SAMFileHeader header, final // store the original bases and then write Ns over low quality ones final byte[] strandedClippedBases = getStrandedClippedBytes(read, lowQualTail); //Note: this makes a copy of the read + final byte[] refBases = isExtended() ? getReadReferenceBases(read) : null; //Note: we're using a non-standard library here because boxing came up on profiling as taking 20% of time in applyBQSR. 
//IntList avoids boxing - final IntList mismatchKeys = contextWith(strandedClippedBases, mismatchesContextSize, mismatchesKeyMask); + final IntList mismatchKeys = isExtended() + ? contextWith(strandedClippedBases, mismatchesContextSize, mismatchesKeyMask, refBases) + : contextWith(strandedClippedBases, mismatchesContextSize, mismatchesKeyMask); final int readLengthAfterClipping = strandedClippedBases.length; @@ -84,7 +113,9 @@ public void recordValues(final GATKRead read, final SAMFileHeader header, final //Note: duplicated the loop to avoid checking recordIndelValues on each iteration if (recordIndelValues) { - final IntList indelKeys = contextWith(strandedClippedBases, indelsContextSize, indelsKeyMask); + final IntList indelKeys = isExtended() + ? contextWith(strandedClippedBases, indelsContextSize, indelsKeyMask, refBases) + : contextWith(strandedClippedBases, indelsContextSize, indelsKeyMask); for (int i = 0; i < readLengthAfterClipping; i++) { final int readOffset = getStrandedOffset(negativeStrand, i, readLengthAfterClipping); final int indelKey = indelKeys.getInt(i); @@ -155,12 +186,11 @@ private static int createMask(final int contextSize) { /** * calculates the context of a base independent of the covariate mode (mismatch, insertion or deletion) - * - * @param bases the bases in the read to build the context from + * @param bases the bases in the read to build the context from * @param contextSize context size to use building the context * @param mask mask for pulling out just the context bits */ - private static IntList contextWith(final byte[] bases, final int contextSize, final int mask) { + protected IntList contextWith(final byte[] bases, final int contextSize, final int mask) { final int readLength = bases.length; @@ -219,7 +249,7 @@ private static IntList contextWith(final byte[] bases, final int contextSize, fi return keys; } - public static int keyFromContext(final String dna) { + public int keyFromContext(final String dna) { return 
keyFromContext(dna.getBytes(), 0, dna.length()); } @@ -275,7 +305,8 @@ public static String contextFromKey(final int key) { @Override public int maximumKeyValue() { // the maximum value is T (11 in binary) for each base in the context - final int length = Math.max(mismatchesContextSize, indelsContextSize); // the length of the context + final int length = Math.max(mismatchesContextSize, indelsContextSize) + + (altEmbedded ? 1 : 0); // the length of the context int key = length; int bitOffset = LENGTH_BITS; for (int i = 0; i INS. supplement reference with N + Arrays.fill(readRefBases, readOfs, readOfs + length, (byte)'N'); + readOfs += length; + } else if ( op.consumesReferenceBases() ) { + // has reference but not read bases -> DEL. skip on reference + refOfs += length; + } + } + + Utils.validate(readOfs == read.getLength(), "did not read end of read"); + Utils.validate(refOfs == refBases.length, "did not reach end of reference"); + + // reverse complement? + if ( read.isReverseStrand() ) { + return BaseUtils.simpleReverseComplement(readRefBases); + } else { + return readRefBases; + } + } + + private int keyFromContext(final byte[] dna, final int start, final int end, final int refIndex, final byte refBase) { + + int key = end - start + (altEmbedded ? 1 : 0); + int bitOffset = LENGTH_BITS; + for (int i = start; i < end; i++) { + final int baseIndex = BaseUtils.simpleBaseToBaseIndex((i != refIndex) ? 
dna[i] : refBase); + if (baseIndex == -1) { // ignore non-ACGT bases + return -1; + } + key |= (baseIndex << bitOffset); + bitOffset += 2; + + if ( altEmbedded && i == refIndex ) { + final int altBaseIndex = BaseUtils.simpleBaseToBaseIndex(dna[i]); + if (altBaseIndex == -1) { // ignore non-ACGT bases + return -1; + } + key |= (altBaseIndex << bitOffset); + bitOffset += 2; + } + } + return key; + } } diff --git a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/StandardCovariateList.java b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/StandardCovariateList.java index 229badf9af3..3a329fe4f2f 100644 --- a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/StandardCovariateList.java +++ b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/StandardCovariateList.java @@ -34,8 +34,15 @@ public StandardCovariateList(final RecalibrationArgumentCollection rac, final Li final ContextCovariate contextCovariate = new ContextCovariate(rac); final CycleCovariate cycleCovariate = new CycleCovariate(rac); - additionalCovariates = Collections.unmodifiableList(Arrays.asList(contextCovariate, cycleCovariate)); - allCovariates = Collections.unmodifiableList(Arrays.asList(readGroupCovariate, qualityScoreCovariate, contextCovariate, cycleCovariate)); + if ( rac.EXTENDED_CONTEXT_LOOKAHEAD == 0 || rac.EXTENDED_CONTEXT_ALT_EMBEDDED ) { + additionalCovariates = Collections.unmodifiableList(Arrays.asList(contextCovariate, cycleCovariate)); + allCovariates = Collections.unmodifiableList(Arrays.asList(readGroupCovariate, qualityScoreCovariate, contextCovariate, cycleCovariate)); + } else { + final AltCovariate altCovariate = new AltCovariate(rac); + + additionalCovariates = Collections.unmodifiableList(Arrays.asList(contextCovariate, cycleCovariate, altCovariate)); + allCovariates = Collections.unmodifiableList(Arrays.asList(readGroupCovariate, qualityScoreCovariate, contextCovariate, 
cycleCovariate, altCovariate)); + } //precompute for faster lookup (shows up on profile) indexByClass = new LinkedHashMap<>();