diff --git a/build.gradle b/build.gradle
index 4a5e333db15..f6b5084549b 100644
--- a/build.gradle
+++ b/build.gradle
@@ -63,7 +63,7 @@ final barclayVersion = System.getProperty('barclay.version','2.1.0')
 final sparkVersion = System.getProperty('spark.version', '2.2.0')
 final hadoopVersion = System.getProperty('hadoop.version', '2.8.2')
 final hadoopBamVersion = System.getProperty('hadoopBam.version','7.10.0')
-final tensorflowVersion = System.getProperty('tensorflow.version','1.4.0')
+final tensorflowVersion = System.getProperty('tensorflow.version','1.9.0')
 final genomicsdbVersion = System.getProperty('genomicsdb.version','0.10.0-proto-3.0.0-beta-1+bdce8be25b873')
 final testNGVersion = '6.11'
 // Using the shaded version to avoid conflicts between its protobuf dependency
diff --git a/scripts/gatkcondaenv.yml.template b/scripts/gatkcondaenv.yml.template
index 287756a0e9e..54454a722d3 100644
--- a/scripts/gatkcondaenv.yml.template
+++ b/scripts/gatkcondaenv.yml.template
@@ -26,7 +26,7 @@ dependencies:
 - h5py==2.7.1
 - html5lib==0.9999999
 - joblib==0.11
-- keras==2.1.4
+- keras==2.2.0
 - markdown==2.6.9
 - matplotlib==2.1.0
 - numpy==1.13.3
@@ -44,8 +44,7 @@ dependencies:
 - scipy==1.0.0
 - six==1.11.0
 - $tensorFlowDependency
-- tensorflow-tensorboard==0.4.0rc3
 - theano==0.9.0
 - tqdm==4.19.4
 - werkzeug==0.12.2
-- gatkPythonPackageArchive.zip
+- gatkPythonPackageArchive.zip
\ No newline at end of file
diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/validation/EvaluateInfoFieldConcordance.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/validation/EvaluateInfoFieldConcordance.java
new file mode 100644
index 00000000000..b255f160877
--- /dev/null
+++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/validation/EvaluateInfoFieldConcordance.java
@@ -0,0 +1,191 @@
+package org.broadinstitute.hellbender.tools.walkers.validation;
+
+import java.nio.file.Paths;
+import java.io.IOException;
+
+import org.apache.commons.collections4.Predicate;
+
+import htsjdk.variant.variantcontext.VariantContext;
+
+import org.broadinstitute.barclay.argparser.Advanced;
+import org.broadinstitute.barclay.argparser.Argument;
+import org.broadinstitute.barclay.argparser.BetaFeature;
+import org.broadinstitute.barclay.help.DocumentedFeature;
+import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
+
+import org.broadinstitute.hellbender.engine.ReadsContext;
+import org.broadinstitute.hellbender.engine.ReferenceContext;
+import org.broadinstitute.hellbender.exceptions.UserException;
+import org.broadinstitute.hellbender.engine.AbstractConcordanceWalker;
+
+import picard.cmdline.programgroups.VariantEvaluationProgramGroup;
+
+/**
+ * Compare INFO field values between two VCFs, or compare two different INFO fields from one VCF.
+ * We only evaluate sites that are in both VCFs.
+ * Although we use the arguments eval and truth, we only compare the scores; we do not determine which score is correct.
+ * Either VCF can be used as eval or truth, or the same VCF can be used for both.
+ * Differences greater than the epsilon argument will trigger a warning.
+ *

+ * <h3>Compare the CNN_2D info fields for the same sites from two different VCFs:</h3>
+ *
+ * <pre>
+ * gatk EvaluateInfoFieldConcordance \
+ *  -eval a.vcf \
+ *  -truth another.vcf \
+ *  -S summary.txt \
+ *  -eval-info-key CNN_2D \
+ *  -truth-info-key CNN_2D \
+ *  -epsilon 0.01
+ * </pre>
+ *
+ * <h3>Compare the CNN_2D info field with the CNN_1D field from the same sites in one VCF:</h3>
+ *
+ * <pre>
+ * gatk EvaluateInfoFieldConcordance \
+ *  -eval my.vcf \
+ *  -truth my.vcf \
+ *  -S summary.txt \
+ *  -eval-info-key CNN_2D \
+ *  -truth-info-key CNN_1D \
+ *  -epsilon 0.01
+ * </pre>
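+ *
+ * The summary is a tab-separated table with one row per variant type, using the column
+ * names defined by {@link InfoConcordanceRecord}. The values below are illustrative only:
+ *
+ * <pre>
+ * type   eval_info_key  true_info_key  mean_difference  std_difference
+ * SNP    CNN_2D         CNN_1D         0.011            0.043
+ * INDEL  CNN_2D         CNN_1D         0.020            0.061
+ * </pre>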
+ */ +@CommandLineProgramProperties( + summary=EvaluateInfoFieldConcordance.USAGE_SUMMARY, + oneLineSummary=EvaluateInfoFieldConcordance.USAGE_ONE_LINE_SUMMARY, + programGroup=VariantEvaluationProgramGroup.class) +@DocumentedFeature +@BetaFeature +public class EvaluateInfoFieldConcordance extends AbstractConcordanceWalker { + static final String USAGE_ONE_LINE_SUMMARY = "Evaluate concordance of info fields in an input VCF against a validated truth VCF"; + static final String USAGE_SUMMARY = "This tool evaluates info fields from an input VCF against a VCF that has been validated and is considered to represent ground truth.\n"; + public static final String SUMMARY_LONG_NAME = "summary"; + public static final String SUMMARY_SHORT_NAME = "S"; + + @Argument(doc="A table of summary statistics (true positives, sensitivity, etc.)", fullName=SUMMARY_LONG_NAME, shortName=SUMMARY_SHORT_NAME) + protected String summary; + + @Argument(fullName="eval-info-key", shortName="eval-info-key", doc="Info key from eval vcf") + protected String evalInfoKey; + + @Argument(fullName="truth-info-key", shortName="truth-info-key", doc="Info key from truth vcf") + protected String truthInfoKey; + + @Advanced + @Argument(fullName="warn-big-differences", + shortName="warn-big-differences", + doc="If set differences in the info key values greater than epsilon will trigger warnings.", + optional=true) + protected boolean warnBigDifferences = false; + + @Advanced + @Argument(fullName="epsilon", shortName="epsilon", doc="Difference tolerance", optional=true) + protected double epsilon = 0.1; + + private int snpCount = 0; + private int indelCount = 0; + + private double snpSumDelta = 0.0; + private double snpSumDeltaSquared = 0.0; + private double indelSumDelta = 0.0; + private double indelSumDeltaSquared = 0.0; + + @Override + public void onTraversalStart() { + if(getEvalHeader().getInfoHeaderLine(evalInfoKey) == null){ + throw new UserException("Missing key:"+evalInfoKey+" in Eval VCF:"+evalVariantsFile); + } + + if(getTruthHeader().getInfoHeaderLine(truthInfoKey) == null){ + throw new UserException("Missing key:"+truthInfoKey+" in Truth VCF:"+truthVariantsFile); + } + } + + @Override + protected void apply(AbstractConcordanceWalker.TruthVersusEval truthVersusEval, ReadsContext readsContext, ReferenceContext refContext) { + ConcordanceState concordanceState = truthVersusEval.getConcordance(); + switch (concordanceState) { + case TRUE_POSITIVE: { + if(truthVersusEval.getEval().isSNP()){ + snpCount++; + } else if (truthVersusEval.getEval().isIndel()) { + indelCount++; + } + this.infoDifference(truthVersusEval.getEval(), truthVersusEval.getTruth()); + break; + } + case FALSE_POSITIVE: + case FALSE_NEGATIVE: + case FILTERED_TRUE_NEGATIVE: + case FILTERED_FALSE_NEGATIVE: { + break; + } + default: { + throw new IllegalStateException("Unexpected ConcordanceState: " + concordanceState.toString()); + } + } + } + + private void infoDifference(final VariantContext eval, final VariantContext truth) { + if(eval.hasAttribute(this.evalInfoKey) && truth.hasAttribute(truthInfoKey)) { + final double evalVal = Double.valueOf((String) eval.getAttribute(this.evalInfoKey)); + final double truthVal = Double.valueOf((String) truth.getAttribute(this.truthInfoKey)); + final double delta = evalVal - truthVal; + final double deltaSquared = delta * delta; + if (eval.isSNP()) { + this.snpSumDelta += Math.sqrt(deltaSquared); + this.snpSumDeltaSquared += deltaSquared; + } else if (eval.isIndel()) { + this.indelSumDelta += Math.sqrt(deltaSquared); + 
this.indelSumDeltaSquared += deltaSquared; + } + if (warnBigDifferences && Math.abs(delta) > this.epsilon) { + this.logger.warn(String.format("Difference (%f) greater than epsilon (%f) at %s:%d %s:", delta, this.epsilon, eval.getContig(), eval.getStart(), eval.getAlleles().toString())); + this.logger.warn(String.format("\t\tTruth info: " + truth.getAttributes().toString())); + this.logger.warn(String.format("\t\tEval info: " + eval.getAttributes().toString())); + } + } + } + + @Override + public Object onTraversalSuccess() { + final double snpMean = this.snpSumDelta / snpCount; + final double snpVariance = (this.snpSumDeltaSquared - this.snpSumDelta * this.snpSumDelta / snpCount) / snpCount; + final double snpStd = Math.sqrt(snpVariance); + final double indelMean = this.indelSumDelta / indelCount; + final double indelVariance = (this.indelSumDeltaSquared - this.indelSumDelta * this.indelSumDelta / indelCount) / indelCount; + final double indelStd = Math.sqrt(indelVariance); + + this.logger.info(String.format("SNP average delta %f and standard deviation: %f", snpMean, snpStd)); + this.logger.info(String.format("INDEL average delta %f and standard deviation: %f", indelMean, indelStd)); + + try (final InfoConcordanceRecord.InfoConcordanceWriter + concordanceWriter = InfoConcordanceRecord.getWriter(Paths.get(this.summary))){ + concordanceWriter.writeRecord(new InfoConcordanceRecord(VariantContext.Type.SNP, this.evalInfoKey, this.truthInfoKey, snpMean, snpStd)); + concordanceWriter.writeRecord(new InfoConcordanceRecord(VariantContext.Type.INDEL, this.evalInfoKey, this.truthInfoKey, indelMean, indelStd)); + } catch (IOException e) { + throw new UserException("Encountered an IO exception writing the concordance summary table", e); + } + + return "SUCCESS"; + } + + @Override + protected boolean areVariantsAtSameLocusConcordant(VariantContext truth, VariantContext eval) { + final boolean sameRefAllele = truth.getReference().equals(eval.getReference()); + final boolean containsAltAllele = eval.getAlternateAlleles().contains(truth.getAlternateAllele(0)); + return sameRefAllele && containsAltAllele; + } + + @Override + protected Predicate makeTruthVariantFilter() { + return vc -> !vc.isFiltered() && !vc.isSymbolicOrSV(); + } + + @Override + protected Predicate makeEvalVariantFilter() { + return vc -> !vc.isFiltered() && !vc.isSymbolicOrSV(); + } + +} diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/validation/InfoConcordanceRecord.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/validation/InfoConcordanceRecord.java new file mode 100644 index 00000000000..3b704cd4d06 --- /dev/null +++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/validation/InfoConcordanceRecord.java @@ -0,0 +1,140 @@ +package org.broadinstitute.hellbender.tools.walkers.validation; + +import htsjdk.variant.variantcontext.VariantContext; + +import java.io.IOException; +import java.nio.file.Path; + +import org.broadinstitute.hellbender.exceptions.UserException; + +import org.broadinstitute.hellbender.utils.tsv.DataLine; +import org.broadinstitute.hellbender.utils.tsv.TableColumnCollection; +import org.broadinstitute.hellbender.utils.tsv.TableWriter; +import org.broadinstitute.hellbender.utils.tsv.TableReader; + +/** + * Keeps track of concordance between two info fields. 
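+ *
+ * A summary table can be read back with the nested reader class. A minimal sketch, assuming
+ * a summary file at a hypothetical path ({@link TableReader} instances are iterable and closeable):
+ *
+ * <pre>
+ * try (InfoConcordanceRecord.InfoConcordanceReader reader =
+ *         new InfoConcordanceRecord.InfoConcordanceReader(Paths.get("summary.txt"))) {
+ *     for (InfoConcordanceRecord record : reader) {
+ *         System.out.println(record.getVariantType() + " mean: " + record.getMean());
+ *     }
+ * }
+ * </pre>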
+ */ +public class InfoConcordanceRecord { + private static final String VARIANT_TYPE_COLUMN_NAME = "type"; + private static final String EVAL_INFO_KEY = "eval_info_key"; + private static final String TRUE_INFO_KEY = "true_info_key"; + private static final String MEAN_DIFFERENCE = "mean_difference"; + private static final String STD_DIFFERENCE = "std_difference"; + private static final String[] INFO_CONCORDANCE_COLUMN_HEADER = + {VARIANT_TYPE_COLUMN_NAME, EVAL_INFO_KEY, TRUE_INFO_KEY, MEAN_DIFFERENCE, STD_DIFFERENCE}; + final VariantContext.Type type; + private final String evalKey; + private final String trueKey; + private final double mean; + private final double std; + + /** + * Record keeps track of concordance between values from INFO-field keys of a VCF. + * + * @param type SNP or INDEL + * @param evalKey The INFO field key from the eval VCF + * @param trueKey The INFO field key from the truth VCF + * @param mean The mean of the differences in values for these INFO fields. + * @param std The standard deviation of the differences in values for these INFO fields. + */ + public InfoConcordanceRecord(VariantContext.Type type, String evalKey, String trueKey, double mean, double std) { + this.type = type; + this.evalKey = evalKey; + this.trueKey = trueKey; + this.mean = mean; + this.std = std; + } + + /** + * + * @return Variant type (e.g. SNP or INDEL) + */ + public VariantContext.Type getVariantType() { + return this.type; + } + + /** + * + * @return The mean of the differences between two INFO fields + */ + public double getMean() { + return this.mean; + } + + /** + * + * @return The Standard Deviation of the differences between two INFO fields + */ + public double getStd() { + return this.std; + } + + /** + * + * @return The INFO field for the eval VCF + */ + public String getEvalKey() { + return this.evalKey; + } + + /** + * + * @return The INFO field for the truth VCF + */ + public String getTrueKey() { + return this.trueKey; + } + + /** + * Get a table writer + * @param outputTable A Path where the output table will be written + * @return A Table writer for INFO field concordances + */ + public static InfoConcordanceWriter getWriter(Path outputTable) { + try { + InfoConcordanceWriter writer = new InfoConcordanceWriter(outputTable); + return writer; + } + catch (IOException e) { + throw new UserException(String.format("Encountered an IO exception while writing from %s.", outputTable), e); + } + } + + /** + * Table writing class for InfoConcordanceRecords + */ + public static class InfoConcordanceWriter extends TableWriter { + private InfoConcordanceWriter(Path output) throws IOException { + super(output.toFile(), new TableColumnCollection(INFO_CONCORDANCE_COLUMN_HEADER)); + } + + @Override + protected void composeLine(InfoConcordanceRecord record, DataLine dataLine) { + dataLine.set(VARIANT_TYPE_COLUMN_NAME, record.getVariantType().toString()) + .set(EVAL_INFO_KEY, record.getEvalKey()) + .set(TRUE_INFO_KEY, record.getTrueKey()) + .set(MEAN_DIFFERENCE, record.getMean()) + .set(STD_DIFFERENCE, record.getStd()); + } + } + + /** + * Table reading class for InfoConcordanceRecords + */ + public static class InfoConcordanceReader extends TableReader { + public InfoConcordanceReader(Path summary) throws IOException { + super(summary.toFile()); + } + + @Override + protected InfoConcordanceRecord createRecord(DataLine dataLine) { + VariantContext.Type type = VariantContext.Type.valueOf(dataLine.get(VARIANT_TYPE_COLUMN_NAME)); + String evalKey = dataLine.get(EVAL_INFO_KEY); + String trueKey = 
dataLine.get(TRUE_INFO_KEY); + double mean = Double.parseDouble(dataLine.get(MEAN_DIFFERENCE)); + double std = Double.parseDouble(dataLine.get(STD_DIFFERENCE)); + return new InfoConcordanceRecord(type, evalKey, trueKey, mean, std); + } + } +} \ No newline at end of file diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNScoreVariants.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNScoreVariants.java index be4e0c34ebb..1a114aea907 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNScoreVariants.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNScoreVariants.java @@ -24,7 +24,6 @@ import java.io.*; import java.util.*; -import java.util.stream.StreamSupport; /** @@ -92,15 +91,15 @@ * -weights path/to/my_weights.hd5 * */ -@DocumentedFeature @ExperimentalFeature +@DocumentedFeature @CommandLineProgramProperties( summary = CNNScoreVariants.USAGE_SUMMARY, oneLineSummary = CNNScoreVariants.USAGE_ONE_LINE_SUMMARY, programGroup = VariantFilteringProgramGroup.class ) -public class CNNScoreVariants extends VariantWalker { +public class CNNScoreVariants extends TwoPassVariantWalker { private final static String NL = String.format("%n"); static final String USAGE_ONE_LINE_SUMMARY = "Apply a Convolutional Neural Net to filter annotated variants"; static final String USAGE_SUMMARY = "Annotate a VCF with scores from a Convolutional Neural Network (CNN)." + @@ -116,7 +115,6 @@ public class CNNScoreVariants extends VariantWalker { private static final int ALT_INDEX = 3; private static final int KEY_INDEX = 4; private static final int FIFO_STRING_INITIAL_CAPACITY = 1024; - private static final int MAX_READ_BATCH = 4098; @Argument(fullName = StandardArgumentDefinitions.OUTPUT_LONG_NAME, shortName = StandardArgumentDefinitions.OUTPUT_SHORT_NAME, @@ -138,6 +136,10 @@ public class CNNScoreVariants extends VariantWalker { @Argument(fullName = "filter-symbolic-and-sv", shortName = "filter-symbolic-and-sv", doc = "If set will filter symbolic and and structural variants from the input VCF", optional = true) private boolean filterSymbolicAndSV = false; + @Advanced + @Argument(fullName="info-annotation-keys", shortName="info-annotation-keys", doc="The VCF info fields to send to python.", optional=true) + private List annotationKeys = new ArrayList<>(Arrays.asList("MQ", "DP", "SOR", "FS", "QD", "MQRankSum", "ReadPosRankSum")); + @Advanced @Argument(fullName = "inference-batch-size", shortName = "inference-batch-size", doc = "Size of batches for python to do inference on.", minValue = 1, maxValue = 4096, optional = true) private int inferenceBatchSize = 256; @@ -180,9 +182,11 @@ public class CNNScoreVariants extends VariantWalker { private int windowEnd = windowSize / 2; private int windowStart = windowSize / 2; private boolean waitforBatchCompletion = false; - private File scoreFile; + private File scoreFile; private String scoreKey; + private Scanner scoreScan; + private VariantContextWriter vcfWriter; private static String resourcePathReadTensor = Resource.LARGE_RUNTIME_RESOURCES_PATH + "/cnn_score_variants/small_2d.json"; private static String resourcePathReferenceTensor = Resource.LARGE_RUNTIME_RESOURCES_PATH + "/cnn_score_variants/1d_cnn_mix_train_full_bn.json"; @@ -228,15 +232,14 @@ public List getDefaultReadFilters() { @Override public void onTraversalStart() { - scoreKey = getScoreKeyAndCheckModelAndReadsHarmony(); - if (architecture == null && weights == null) { - setArchitectureAndWeightsFromResources(); 
+ if (getHeaderForVariants().getGenotypeSamples().size() > 1) { + logger.warn("CNNScoreVariants is a single sample tool, but the input VCF has more than 1 sample."); } // Start the Python process and initialize a stream writer for streaming data to the Python code pythonExecutor.start(Collections.emptyList(), enableJournal, pythonProfileResults); - pythonExecutor.initStreamWriter(AsynchronousStreamWriter.stringSerializer); + batchList = new ArrayList<>(transferBatchSize); // Execute Python code to open our output file, where it will write the contents of everything it reads @@ -248,26 +251,11 @@ public void onTraversalStart() { } else { logger.info("Saving temp file from python:" + scoreFile.getAbsolutePath()); } - - pythonExecutor.sendSynchronousCommand("from keras import backend" + NL); - pythonExecutor.sendSynchronousCommand(String.format("backend.set_session(backend.tf.Session(config=backend.tf.ConfigProto(intra_op_parallelism_threads=%d, inter_op_parallelism_threads=%d)))" + NL, intraOpThreads, interOpThreads)); - pythonExecutor.sendSynchronousCommand(String.format("tempFile = open('%s', 'w+')" + NL, scoreFile.getAbsolutePath())); pythonExecutor.sendSynchronousCommand("import vqsr_cnn" + NL); - String getArgsAndModel; - if (weights != null && architecture != null) { - getArgsAndModel = String.format("args, model = vqsr_cnn.args_and_model_from_semantics('%s', weights_hd5='%s')", architecture, weights) + NL; - logger.info("Using key:" + scoreKey + " for CNN architecture:" + architecture + " and weights:" + weights); - } else if (architecture == null) { - getArgsAndModel = String.format("args, model = vqsr_cnn.args_and_model_from_semantics(None, weights_hd5='%s', tensor_type='%s')", weights, tensorType.name()) + NL; - logger.info("Using key:" + scoreKey + " for CNN weights:" + weights); - } else { - getArgsAndModel = String.format("args, model = vqsr_cnn.args_and_model_from_semantics('%s')", architecture) + NL; - logger.info("Using key:" + scoreKey + " for CNN architecture:" + architecture); - } - pythonExecutor.sendSynchronousCommand(getArgsAndModel); - + scoreKey = getScoreKeyAndCheckModelAndReadsHarmony(); + initializePythonArgsAndModel(); } catch (IOException e) { throw new GATKException("Error when creating temp file and initializing python executor.", e); } @@ -275,7 +263,7 @@ public void onTraversalStart() { } @Override - public void apply(final VariantContext variant, final ReadsContext readsContext, final ReferenceContext referenceContext, final FeatureContext featureContext) { + public void firstPassApply(final VariantContext variant, final ReadsContext readsContext, final ReferenceContext referenceContext, final FeatureContext featureContext) { referenceContext.setWindow(windowStart, windowEnd); if (tensorType.isReadsRequired()) { transferReadsToPythonViaFifo(variant, readsContext, referenceContext); @@ -285,6 +273,64 @@ public void apply(final VariantContext variant, final ReadsContext readsContext, sendBatchIfReady(); } + @Override + public void afterFirstPass() { + if (waitforBatchCompletion) { + pythonExecutor.waitForPreviousBatchCompletion(); + } + if (curBatchSize > 0) { + executePythonCommand(); + pythonExecutor.waitForPreviousBatchCompletion(); + } + + pythonExecutor.sendSynchronousCommand("tempFile.close()" + NL); + pythonExecutor.terminate(); + + try { + scoreScan = new Scanner(scoreFile); + vcfWriter = createVCFWriter(new File(outputFile)); + scoreScan.useDelimiter("\\n"); + writeVCFHeader(vcfWriter); + } catch (IOException e) { + throw new GATKException("Error when 
trying to temporary score file scanner.", e); + } + + } + + @Override + protected void secondPassApply(VariantContext variant, ReadsContext readsContext, ReferenceContext referenceContext, FeatureContext featureContext) { + String sv = scoreScan.nextLine(); + String[] scoredVariant = sv.split("\\t"); + + if (variant.getContig().equals(scoredVariant[CONTIG_INDEX]) + && Integer.toString(variant.getStart()).equals(scoredVariant[POS_INDEX]) + && variant.getReference().getBaseString().equals(scoredVariant[REF_INDEX]) + && variant.getAlternateAlleles().toString().equals(scoredVariant[ALT_INDEX])) { + + final VariantContextBuilder builder = new VariantContextBuilder(variant); + if (scoredVariant.length > KEY_INDEX) { + builder.attribute(scoreKey, scoredVariant[KEY_INDEX]); + } + vcfWriter.add(builder.make()); + + } else { + String errorMsg = "Score file out of sync with original VCF. Score file has:" + sv; + errorMsg += "\n But VCF has:" + variant.toStringWithoutGenotypes(); + throw new GATKException(errorMsg); + } + } + + @Override + public void closeTool() { + logger.info("Done scoring variants with CNN."); + if (vcfWriter != null) { + vcfWriter.close(); + } + if (scoreScan != null){ + scoreScan.close(); + } + } + private void transferToPythonViaFifo(final VariantContext variant, final ReferenceContext referenceContext) { try { final String outDat = String.format("%s\t%s\t%s\t%s\n", @@ -329,6 +375,7 @@ private void transferReadsToPythonViaFifo(final VariantContext variant, final Re if (!readIt.hasNext()) { logger.warn("No reads at contig:" + variant.getContig() + " site:" + String.valueOf(variant.getStart())); } + while (readIt.hasNext()) { sb.append(GATKReadToString(readIt.next())); } @@ -374,29 +421,15 @@ private String getVariantDataString(final VariantContext variant) { private String getVariantInfoString(final VariantContext variant) { // Create a string that will easily be parsed as a python dictionary - String varInfo = ""; - for (final String attributeKey : variant.getAttributes().keySet()) { - varInfo += attributeKey + "=" + variant.getAttribute(attributeKey).toString().replace(" ", "").replace("[", "").replace("]", "") + ";"; - } - return varInfo; - } - - @Override - public Object onTraversalSuccess() { - if (waitforBatchCompletion) { - pythonExecutor.waitForPreviousBatchCompletion(); - } - if (curBatchSize > 0) { - executePythonCommand(); - pythonExecutor.waitForPreviousBatchCompletion(); + StringBuilder sb = new StringBuilder(FIFO_STRING_INITIAL_CAPACITY); + for (final String attributeKey : annotationKeys) { + if (variant.hasAttribute(attributeKey)) { + sb.append(attributeKey); + sb.append("="); + sb.append(variant.getAttribute(attributeKey).toString().replace(" ", "").replace("[", "").replace("]", "") + ";"); + } } - - pythonExecutor.sendSynchronousCommand("tempFile.close()" + NL); - pythonExecutor.terminate(); - - writeOutputVCFWithScores(); - - return true; + return sb.toString(); } private void executePythonCommand() { @@ -408,42 +441,6 @@ private void executePythonCommand() { pythonExecutor.startBatchWrite(pythonCommand, batchList); } - - private void writeOutputVCFWithScores() { - try (final Scanner scoreScan = new Scanner(scoreFile); - final VariantContextWriter vcfWriter = createVCFWriter(new File(outputFile))) { - scoreScan.useDelimiter("\\n"); - writeVCFHeader(vcfWriter); - final VariantFilter variantfilter = makeVariantFilter(); - - // Annotate each variant in the input stream, as in variantWalkerBase.traverse() - StreamSupport.stream(getSpliteratorForDrivingVariants(), 
false) - .filter(variantfilter) - .forEach(variant -> { - String sv = scoreScan.nextLine(); - String[] scoredVariant = sv.split("\\t"); - if (variant.getContig().equals(scoredVariant[CONTIG_INDEX]) - && Integer.toString(variant.getStart()).equals(scoredVariant[POS_INDEX]) - && variant.getReference().getBaseString().equals(scoredVariant[REF_INDEX]) - && variant.getAlternateAlleles().toString().equals(scoredVariant[ALT_INDEX])) { - final VariantContextBuilder builder = new VariantContextBuilder(variant); - if (scoredVariant.length > KEY_INDEX) { - builder.attribute(scoreKey, scoredVariant[KEY_INDEX]); - } - vcfWriter.add(builder.make()); - } else { - String errorMsg = "Score file out of sync with original VCF. Score file has:" + sv; - errorMsg += "\n But VCF has:" + variant.toStringWithoutGenotypes(); - throw new GATKException(errorMsg); - } - }); - - } catch (IOException e) { - throw new GATKException("Error when trying to write annotated VCF.", e); - } - - } - private void writeVCFHeader(VariantContextWriter vcfWriter) { // setup the header fields final VCFHeader inputHeader = getHeaderForVariants(); @@ -471,20 +468,33 @@ private String getScoreKeyAndCheckModelAndReadsHarmony() { } } - private void setArchitectureAndWeightsFromResources() { - if (tensorType.equals(TensorType.read_tensor)) { - architecture = IOUtils.writeTempResourceFromPath(resourcePathReadTensor, null).getAbsolutePath(); - weights = IOUtils.writeTempResourceFromPath( - resourcePathReadTensor.replace(".json", ".hd5"), - null).getAbsolutePath(); - } else if (tensorType.equals(TensorType.reference)) { - architecture = IOUtils.writeTempResourceFromPath(resourcePathReferenceTensor, null).getAbsolutePath(); - weights = IOUtils.writeTempResourceFromPath( - resourcePathReferenceTensor.replace(".json", ".hd5"), null).getAbsolutePath(); + private void initializePythonArgsAndModel(){ + if (weights == null && architecture == null) { + if (tensorType.equals(TensorType.read_tensor)) { + architecture = IOUtils.writeTempResourceFromPath(resourcePathReadTensor, null).getAbsolutePath(); + weights = IOUtils.writeTempResourceFromPath( + resourcePathReadTensor.replace(".json", ".hd5"), + null).getAbsolutePath(); + } else if (tensorType.equals(TensorType.reference)) { + architecture = IOUtils.writeTempResourceFromPath(resourcePathReferenceTensor, null).getAbsolutePath(); + weights = IOUtils.writeTempResourceFromPath( + resourcePathReferenceTensor.replace(".json", ".hd5"), null).getAbsolutePath(); + } else { + throw new GATKException("No default architecture for tensor type:" + tensorType.name()); + } + } + + String getArgsAndModel; + if (weights != null && architecture != null) { + getArgsAndModel = String.format("args, model = vqsr_cnn.start_session_get_args_and_model(%d, %d, '%s', weights_hd5='%s')", intraOpThreads, interOpThreads, architecture, weights) + NL; + logger.info("Using key:" + scoreKey + " for CNN architecture:" + architecture + " and weights:" + weights); + } else if (architecture == null) { + getArgsAndModel = String.format("args, model = vqsr_cnn.start_session_get_args_and_model(%d, %d, None, weights_hd5='%s', tensor_type='%s')", intraOpThreads, interOpThreads, weights, tensorType.name()) + NL; + logger.info("Using key:" + scoreKey + " for CNN weights:" + weights); } else { - throw new GATKException("No default architecture for tensor type:" + tensorType.name()); + getArgsAndModel = String.format("args, model = vqsr_cnn.start_session_get_args_and_model(%d, %d, '%s')", intraOpThreads, interOpThreads, architecture) + NL; + 
logger.info("Using key:" + scoreKey + " for CNN architecture:" + architecture); } + pythonExecutor.sendSynchronousCommand(getArgsAndModel); } - } - diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNVariantTrain.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNVariantTrain.java index 57127447a0c..bde150764d0 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNVariantTrain.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNVariantTrain.java @@ -13,6 +13,7 @@ import java.util.Arrays; import java.util.List; + /** * Train a Convolutional Neural Network (CNN) for filtering variants. * This tool expects requires training data generated by {@link CNNVariantWriteTensors}. @@ -37,17 +38,17 @@ *

  * <h3>Train a 1D CNN on Reference Tensors</h3>
  *
  * <pre>
  * gatk CNNVariantTrain \
- *   --tensor-type reference \
- *   --input-tensor-dir my_tensor_folder \
- *   --model-name my_1d_model
+ *   -tensor-type reference \
+ *   -input-tensor-dir my_tensor_folder \
+ *   -model-name my_1d_model
  * </pre>
  *
  * <h3>Train a 2D CNN on Read Tensors</h3>
  *
  * <pre>
  * gatk CNNVariantTrain \
- *   --input-tensor-dir my_tensor_folder \
- *   --tensor-type read-tensor \
- *   --model-name my_2d_model
+ *   -input-tensor-dir my_tensor_folder \
+ *   -tensor-type read-tensor \
+ *   -model-name my_2d_model
  * </pre>
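  *
  * <h3>Train a 2D CNN with a custom architecture</h3>
  * When any conv-layers or fc-layers are given, training uses the architecture defined by
  * those arguments instead of a default model. The layer sizes here are illustrative only:
  *
  * <pre>
  * gatk CNNVariantTrain \
  *   -input-tensor-dir my_tensor_folder \
  *   -tensor-type read-tensor \
  *   -model-name my_custom_2d_model \
  *   -conv-layers 64 -conv-layers 48 -conv-layers 32 \
  *   -fc-layers 32 \
  *   -conv-batch-normalize \
  *   -spatial-dropout
  * </pre>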
* */ @@ -66,7 +67,7 @@ public class CNNVariantTrain extends CommandLineProgram { @Argument(fullName = "output-dir", shortName = "output-dir", doc = "Directory where models will be saved, defaults to current working directory.", optional = true) private String outputDir = "./"; - @Argument(fullName = "tensor-type", shortName = "tensor-type", doc = "Name of the tensors to generate, reference for 1D reference tensors and read_tensor for 2D tensors.", optional = true) + @Argument(fullName = "tensor-type", shortName = "tensor-type", doc = "Type of tensors to use as input reference for 1D reference tensors and read_tensor for 2D tensors.", optional = true) private TensorType tensorType = TensorType.reference; @Argument(fullName = "model-name", shortName = "model-name", doc = "Name of the model to be trained.", optional = true) @@ -84,6 +85,42 @@ public class CNNVariantTrain extends CommandLineProgram { @Argument(fullName = "image-dir", shortName = "image-dir", doc = "Path where plots and figures are saved.", optional = true) private String imageDir; + @Argument(fullName = "conv-width", shortName = "conv-width", doc = "Width of convolution kernels", optional = true) + private int convWidth = 5; + + @Argument(fullName = "conv-height", shortName = "conv-height", doc = "Height of convolution kernels", optional = true) + private int convHeight = 5; + + @Argument(fullName = "conv-dropout", shortName = "conv-dropout", doc = "Dropout rate in convolution layers", optional = true) + private float convDropout = 0.0f; + + @Argument(fullName = "conv-batch-normalize", shortName = "conv-batch-normalize", doc = "Batch normalize convolution layers", optional = true) + private boolean convBatchNormalize = false; + + @Argument(fullName = "conv-layers", shortName = "conv-layers", doc = "List of number of filters to use in each convolutional layer", optional = true) + private List convLayers = new ArrayList(); + + @Argument(fullName = "padding", shortName = "padding", doc = "Padding for convolution layers, valid or same", optional = true) + private String padding = "valid"; + + @Argument(fullName = "spatial-dropout", shortName = "spatial-dropout", doc = "Spatial dropout on convolution layers", optional = true) + private boolean spatialDropout = false; + + @Argument(fullName = "fc-layers", shortName = "fc-layers", doc = "List of number of filters to use in each fully-connected layer", optional = true) + private List fcLayers = new ArrayList(); + + @Argument(fullName = "fc-dropout", shortName = "fc-dropout", doc = "Dropout rate in fully-connected layers", optional = true) + private float fcDropout = 0.0f; + + @Argument(fullName = "fc-batch-normalize", shortName = "fc-batch-normalize", doc = "Batch normalize fully-connected layers", optional = true) + private boolean fcBatchNormalize = false; + + @Argument(fullName = "annotation-units", shortName = "annotation-units", doc = "Number of units connected to the annotation input layer", optional = true) + private int annotationUnits = 16; + + @Argument(fullName = "annotation-shortcut", shortName = "annotation-shortcut", doc = "Shortcut connections on the annotation layers.", optional = true) + private boolean annotationShortcut = false; + @Advanced @Argument(fullName = "channels-last", shortName = "channels-last", doc = "Store the channels in the last axis of tensors, tensorflow->true, theano->false", optional = true) private boolean channelsLast = true; @@ -109,11 +146,18 @@ protected Object doWork() { "--output_dir", outputDir, "--tensor_name", tensorType.name(), 
"--annotation_set", annotationSet, + "--conv_width", Integer.toString(convWidth), + "--conv_height", Integer.toString(convHeight), + "--conv_dropout", Float.toString(convDropout), + "--padding", padding, + "--fc_dropout", Float.toString(fcDropout), + "--annotation_units", Integer.toString(annotationUnits), "--epochs", Integer.toString(epochs), "--training_steps", Integer.toString(trainingSteps), "--validation_steps", Integer.toString(validationSteps), "--id", modelName)); + // Add boolean arguments if(channelsLast){ arguments.add("--channels_last"); } else { @@ -124,12 +168,45 @@ protected Object doWork() { arguments.addAll(Arrays.asList("--image_dir", imageDir)); } - if (tensorType == TensorType.reference) { - arguments.addAll(Arrays.asList("--mode", "train_on_reference_tensors_and_annotations")); - } else if (tensorType == TensorType.read_tensor) { - arguments.addAll(Arrays.asList("--mode", "train_small_model_on_read_tensors_and_annotations")); - } else { - throw new GATKException("Unknown tensor mapping mode:"+ tensorType.name()); + if (convLayers.size() == 0 && fcLayers.size() == 0){ + if (tensorType == TensorType.reference) { + arguments.addAll(Arrays.asList("--mode", "train_default_1d_model")); + } else if (tensorType == TensorType.read_tensor) { + arguments.addAll(Arrays.asList("--mode", "train_default_2d_model")); + } else { + throw new GATKException("Unknown tensor mapping mode:"+ tensorType.name()); + } + } else { // Command line specified custom architecture + if(convBatchNormalize){ + arguments.add("--conv_batch_normalize"); + } + if(fcBatchNormalize){ + arguments.add("--fc_batch_normalize"); + } + if(spatialDropout){ + arguments.add("--spatial_dropout"); + } + if(annotationShortcut){ + arguments.add("--annotation_shortcut"); + } + + // Add list arguments + arguments.add("--conv_layers"); + for(Integer cl : convLayers){ + arguments.add(Integer.toString(cl)); + } + arguments.add("--fc_layers"); + for(Integer fl : fcLayers){ + arguments.add(Integer.toString(fl)); + } + + if (tensorType == TensorType.reference) { + arguments.addAll(Arrays.asList("--mode", "train_args_model_on_reference_and_annotations")); + } else if (tensorType == TensorType.read_tensor) { + arguments.addAll(Arrays.asList("--mode", "train_args_model_on_read_tensors_and_annotations")); + } else { + throw new GATKException("Unknown tensor mapping mode:"+ tensorType.name()); + } } logger.info("Args are:"+ Arrays.toString(arguments.toArray())); diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNVariantWriteTensors.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNVariantWriteTensors.java index fd1cd1fc390..be7ba2b60be 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNVariantWriteTensors.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNVariantWriteTensors.java @@ -47,10 +47,10 @@ * gatk CNNVariantWriteTensors \ * -R reference.fasta \ * -V input.vcf.gz \ - * --truth-vcf platinum-genomes.vcf \ - * --truth-bed platinum-confident-region.bed \ - * --tensor-type reference \ - * --output-tensor-dir my-tensor-folder + * -truth-vcf platinum-genomes.vcf \ + * -truth-bed platinum-confident-region.bed \ + * -tensor-type reference \ + * -output-tensor-dir my-tensor-folder * * *

 * <h3>Write Read Tensors</h3>

@@ -58,11 +58,11 @@ * gatk CNNVariantWriteTensors \ * -R reference.fasta \ * -V input.vcf.gz \ - * --truth-vcf platinum-genomes.vcf \ - * --truth-bed platinum-confident-region.bed \ - * --tensor-type read_tensor \ - * --bam-file input.bam \ - * --output-tensor-dir my-tensor-folder + * -truth-vcf platinum-genomes.vcf \ + * -truth-bed platinum-confident-region.bed \ + * -tensor-type read_tensor \ + * -bam-file input.bam \ + * -output-tensor-dir my-tensor-folder * * */ @@ -100,6 +100,12 @@ public class CNNVariantWriteTensors extends CommandLineProgram { @Argument(fullName = "tensor-type", shortName = "tensor-type", doc = "Name of the tensors to generate.") private TensorType tensorType = TensorType.reference; + @Argument(fullName = "downsample-snps", shortName = "downsample-snps", doc = "Fraction of SNPs to write tensors for.", optional = true) + private float downsampleSnps = 0.05f; + + @Argument(fullName = "downsample-indels", shortName = "downsample-indels", doc = "Fraction of INDELs to write tensors for.", optional = true) + private float downsampleIndels = 0.5f; + @Advanced @Argument(fullName = "channels-last", shortName = "channels-last", doc = "Store the channels in the last axis of tensors, tensorflow->true, theano->false", optional = true) private boolean channelsLast = true; @@ -131,6 +137,8 @@ protected Object doWork() { "--tensor_name", tensorType.name(), "--annotation_set", annotationSet, "--samples", Integer.toString(maxTensors), + "--downsample_snps", Float.toString(downsampleSnps), + "--downsample_indels", Float.toString(downsampleIndels), "--data_dir", outputTensorsDir)); if(channelsLast){ diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/FilterVariantTranches.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/FilterVariantTranches.java index 1788d7790cb..8f83b5b2482 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/FilterVariantTranches.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/FilterVariantTranches.java @@ -176,9 +176,11 @@ public void afterFirstPass() { @Override protected void secondPassApply(VariantContext variant, ReadsContext readsContext, ReferenceContext referenceContext, FeatureContext featureContext) { final VariantContextBuilder builder = new VariantContextBuilder(variant); + if (removeOldFilters) { builder.unfiltered(); } + if (variant.hasAttribute(infoKey)) { final double score = Double.parseDouble((String) variant.getAttribute(infoKey)); if (variant.isSNP() && isTrancheFiltered(score, snpCutoffs)) { diff --git a/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/__init__.py b/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/__init__.py index 41ed768c90b..ece27f8dc1e 100644 --- a/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/__init__.py +++ b/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/__init__.py @@ -1,5 +1,6 @@ -from .vqsr_cnn.models import build_read_tensor_2d_and_annotations_model, build_tiny_2d_annotation_model, build_reference_annotation_model -from .vqsr_cnn.models import args_and_model_from_semantics, train_model_from_generators, build_small_2d_annotation_model +from .vqsr_cnn.models import build_2d_annotation_model_from_args, build_1d_annotation_model_from_args +from .vqsr_cnn.models import build_default_1d_annotation_model, build_default_2d_annotation_model +from .vqsr_cnn.models import start_session_get_args_and_model, train_model_from_generators from .vqsr_cnn.tensor_maps import get_tensor_channel_map_from_args, 
tensor_shape_from_args from .vqsr_cnn.arguments import parse_args, weight_path_from_args, annotations_from_args from .vqsr_cnn.inference import score_and_write_batch diff --git a/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/__init__.py b/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/__init__.py index 247587826d9..c89424a56ad 100644 --- a/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/__init__.py +++ b/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/__init__.py @@ -1,6 +1,6 @@ -from .models import build_read_tensor_2d_and_annotations_model, build_tiny_2d_annotation_model -from .models import args_and_model_from_semantics, build_small_2d_annotation_model -from .models import build_reference_annotation_model, train_model_from_generators +from .models import build_2d_annotation_model_from_args, build_1d_annotation_model_from_args +from .models import build_default_1d_annotation_model, build_default_2d_annotation_model +from .models import start_session_get_args_and_model, train_model_from_generators from .tensor_maps import get_tensor_channel_map_from_args, tensor_shape_from_args from .arguments import parse_args, weight_path_from_args, annotations_from_args from .inference import score_and_write_batch diff --git a/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/arguments.py b/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/arguments.py index 6ab1bf15f33..f985626269c 100644 --- a/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/arguments.py +++ b/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/arguments.py @@ -31,7 +31,6 @@ def parse_args(): help='Key which maps to an input symbol to index mapping.') parser.add_argument('--input_symbols', help='Dict mapping input symbols to their index within input tensors, ' + 'initialised via input_symbols_set argument') - parser.add_argument('--batch_size', default=32, type=int, help='Mini batch size for stochastic gradient descent algorithms.') parser.add_argument('--read_limit', default=128, type=int, @@ -73,6 +72,7 @@ def parse_args(): help='Whether to skip positive examples when writing tensors.') parser.add_argument('--chrom', help='Chromosome to load for parallel tensor writing.') + # I/O files and directories: vcfs, bams, beds, hd5, fasta parser.add_argument('--output_dir', default='./', help='Directory to write models or other data out.') parser.add_argument('--image_dir', default=None, help='Directory to write images and plots to.') @@ -111,6 +111,32 @@ def parse_args(): parser.add_argument('--tensor_board', default=False, action='store_true', help='Add the tensor board callback.') + # Architecture defining arguments + parser.add_argument('--conv_width', default=5, type=int, help='Width of convolutional kernels.') + parser.add_argument('--conv_height', default=5, type=int, help='Height of convolutional kernels.') + parser.add_argument('--conv_dropout', default=0.0, type=float, + help='Dropout rate in convolutional layers.') + parser.add_argument('--conv_batch_normalize', default=False, action='store_true', + help='Batch normalize convolutional layers.') + parser.add_argument('--conv_layers', nargs='+', default=[128, 96, 64, 48], type=int, + help='List of sizes for each convolutional filter layer') + parser.add_argument('--padding', default='valid', choices=['valid', 'same'], + help='Valid or same border padding for convolutional layers.') + parser.add_argument('--spatial_dropout', default=False, action='store_true', + 
help='Spatial dropout on the convolutional layers.') + parser.add_argument('--max_pools', nargs='+', default=[], type=int, + help='List of max-pooling layers.') + parser.add_argument('--fc_layers', nargs='+', default=[32], type=int, + help='List of sizes for each fully connected layer') + parser.add_argument('--fc_dropout', default=0.0, type=float, + help='Dropout rate in fully connected layers.') + parser.add_argument('--fc_batch_normalize', default=False, action='store_true', + help='Batch normalize fully connected layers.') + parser.add_argument('--annotation_units', default=16, type=int, + help='Number of units connected to the annotation input layer.') + parser.add_argument('--annotation_shortcut', default=False, action='store_true', + help='Shortcut connections on the annotations.') + # Evaluation related arguments parser.add_argument('--score_keys', nargs='+', default=['VQSLOD'], help='List of variant score keys for performance comparisons.') diff --git a/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/inference.py b/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/inference.py index 6c43b7c184f..a99baf1f384 100644 --- a/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/inference.py +++ b/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/inference.py @@ -1,5 +1,6 @@ # Imports import os +import math import h5py import numpy as np from collections import Counter, defaultdict, namedtuple @@ -52,7 +53,6 @@ def score_and_write_batch(args, model, file_out, batch_size, python_batch_size, variant_types = [] variant_data = [] read_batch = [] - for _ in range(batch_size): fifo_line = tool.readDataFIFO() fifo_data = fifo_line.split(defines.SEPARATOR_CHAR) @@ -117,7 +117,6 @@ def reference_string_to_tensor(reference): break else: raise ValueError('Error! Unknown code:', b) - return dna_data @@ -126,9 +125,8 @@ def annotation_string_to_tensor(args, annotation_string): name_val_arrays = [p.split('=') for p in name_val_pairs] annotation_map = {str(p[0]).strip() : p[1] for p in name_val_arrays if len(p) > 1} annotation_data = np.zeros(( len(defines.ANNOTATIONS[args.annotation_set]),)) - for i,a in enumerate(defines.ANNOTATIONS[args.annotation_set]): - if a in annotation_map: + if a in annotation_map and not math.isnan(float(annotation_map[a])): annotation_data[i] = annotation_map[a] return annotation_data @@ -434,3 +432,14 @@ def _write_tensor_to_hd5(args, tensor, annotations, contig, pos, variant_type): with h5py.File(tensor_path, 'w') as hf: hf.create_dataset(args.tensor_name, data=tensor, compression='gzip') hf.create_dataset(args.annotation_set, data=annotations, compression='gzip') + +def clear_session(): + try: + K.clear_session() + K.get_session().close() + cfg = K.tf.ConfigProto() + cfg.gpu_options.allow_growth = True + K.set_session(K.tf.Session(config=cfg)) + except AttributeError as e: + print('Could not clear session. 
Maybe you are using Theano backend?') + diff --git a/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/models.py b/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/models.py index 4b5e3900486..b42f8fd5cc0 100644 --- a/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/models.py +++ b/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/models.py @@ -9,7 +9,7 @@ from keras.models import Model, load_model from keras.layers.convolutional import Conv1D, Conv2D, MaxPooling1D, MaxPooling2D from keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard, ReduceLROnPlateau -from keras.layers import Input, Dense, Dropout, BatchNormalization, SpatialDropout2D, Activation, Flatten +from keras.layers import Input, Dense, Dropout, BatchNormalization, SpatialDropout1D, SpatialDropout2D, Activation, Flatten, AlphaDropout from . import plots from . import defines @@ -17,10 +17,19 @@ from . import tensor_maps +def start_session_get_args_and_model(intra_ops, inter_ops, semantics_json, weights_hd5=None, tensor_type=None): + K.clear_session() + K.get_session().close() + cfg = K.tf.ConfigProto(intra_op_parallelism_threads=intra_ops, inter_op_parallelism_threads=inter_ops) + cfg.gpu_options.allow_growth = True + K.set_session(K.tf.Session(config=cfg)) + return args_and_model_from_semantics(semantics_json, weights_hd5, tensor_type) + + def args_and_model_from_semantics(semantics_json, weights_hd5=None, tensor_type=None): args = arguments.parse_args() - if semantics_json is not None: + if semantics_json is not None and os.path.exists(semantics_json): model = set_args_and_get_model_from_semantics(args, semantics_json, weights_hd5) else: model = load_model(weights_hd5, custom_objects=get_metric_dict(args.labels)) @@ -83,162 +92,70 @@ def set_args_and_get_model_from_semantics(args, semantics_json, weights_hd5=None # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~ Models ~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -def build_reference_annotation_model(args): - '''Build Reference 1d CNN model for classifying variants with skip connected annotations. - - Convolutions followed by dense connection, concatenated with annotations. - Dynamically sets input channels based on args via tensor_maps.get_tensor_channel_map_from_args(args) - Uses the functional API. - Prints out model summary. - - Arguments - args.tensor_name: The name of the tensor mapping which data goes to which channels - args.annotation_set: The variant annotation set, perhaps from a HaplotypeCaller VCF. - args.labels: The output labels (e.g. 
SNP, NOT_SNP, INDEL, NOT_INDEL) - - Returns - The keras model - ''' - if K.image_data_format() == 'channels_last': - channel_axis = -1 - else: - channel_axis = 1 - - channel_map = tensor_maps.get_tensor_channel_map_from_args(args) - reference = Input(shape=(args.window_size, len(channel_map)), name=args.tensor_name) - conv_width = 12 - conv_dropout = 0.1 - fc_dropout = 0.2 - x = Conv1D(filters=256, kernel_size=conv_width, activation="relu", kernel_initializer='he_normal')(reference) - x = Conv1D(filters=256, kernel_size=conv_width, activation="relu", kernel_initializer='he_normal')(x) - x = Dropout(conv_dropout)(x) - x = Conv1D(filters=128, kernel_size=conv_width, activation="relu", kernel_initializer='he_normal')(x) - x = Dropout(conv_dropout)(x) - x = Flatten()(x) - - annotations = Input(shape=(len(args.annotations),), name=args.annotation_set) - annos_normed = BatchNormalization(axis=channel_axis)(annotations) - annos_normed_x = Dense(units=40, kernel_initializer='normal', activation='relu')(annos_normed) - - x = layers.concatenate([x, annos_normed_x], axis=channel_axis) - x = Dense(units=40, kernel_initializer='normal', activation='relu')(x) - x = Dropout(fc_dropout)(x) - x = layers.concatenate([x, annos_normed], axis=channel_axis) - - prob_output = Dense(units=len(args.labels), kernel_initializer='glorot_normal', activation='softmax')(x) - - model = Model(inputs=[reference, annotations], outputs=[prob_output]) - - adamo = Adam(lr=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, clipnorm=1.) - - model.compile(optimizer=adamo, loss='categorical_crossentropy', metrics=get_metrics(args.labels)) - model.summary() - - if os.path.exists(args.weights_hd5): - model.load_weights(args.weights_hd5, by_name=True) - print('Loaded model weights from:', args.weights_hd5) - - return model - - - -def build_read_tensor_2d_and_annotations_model(args): - '''Build Read Tensor 2d CNN model with variant annotations mixed in for classifying variants. - - 2d Convolutions followed by dense connection mixed with annotation values. - Dynamically sets input channels based on args via defines.total_input_channels_from_args(args) - Uses the functional API. Supports theano or tensorflow channel ordering via K.image_data_format(). - Prints out model summary. - - Arguments - args.window_size: Length in base-pairs of sequence centered at the variant to use as input. - args.labels: The output labels (e.g. 
SNP, NOT_SNP, INDEL, NOT_INDEL) - - Returns - The keras model - ''' - in_channels = tensor_maps.total_input_channels_from_args(args) - - if K.image_data_format() == 'channels_last': - in_shape = (args.read_limit, args.window_size, in_channels) - concat_axis = -1 - else: - in_shape = (in_channels, args.read_limit, args.window_size) - concat_axis = 1 - - read_tensor = Input(shape=in_shape, name=args.tensor_name) - - read_conv_width = 16 - conv_dropout = 0.2 - fc_dropout = 0.3 - x = Conv2D(216, (read_conv_width, 1), padding='valid', activation="relu")(read_tensor) - x = Conv2D(160, (1, read_conv_width), padding='valid', activation="relu")(x) - x = Conv2D(128, (read_conv_width, 1), padding='valid', activation="relu")(x) - x = MaxPooling2D((2,1))(x) - x = Conv2D(96, (1, read_conv_width), padding='valid', activation="relu")(x) - x = MaxPooling2D((2,1))(x) - x = Dropout(conv_dropout)(x) - x = Conv2D(64, (read_conv_width, 1), padding='valid', activation="relu")(x) - x = MaxPooling2D((2,1))(x) - x = Dropout(conv_dropout)(x) - - x = Flatten()(x) - - # Mix the variant annotations in - annotations = Input(shape=(len(args.annotations),), name=args.annotation_set) - annotations_bn = BatchNormalization(axis=1)(annotations) - alt_input_mlp = Dense(units=16, kernel_initializer='glorot_normal', activation='relu')(annotations_bn) - x = layers.concatenate([x, alt_input_mlp], axis=concat_axis) - - x = Dense(units=32, kernel_initializer='glorot_normal', activation='relu')(x) - x = layers.concatenate([x, annotations_bn], axis=concat_axis) - x = Dropout(fc_dropout)(x) - - prob_output = Dense(units=len(args.labels), kernel_initializer='glorot_normal', activation='softmax')(x) - - model = Model(inputs=[read_tensor, annotations], outputs=[prob_output]) - - adamo = Adam(lr=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, clipnorm=1.) 
- model.compile(loss='categorical_crossentropy', optimizer=adamo, metrics=get_metrics(args.labels)) - - model.summary() - - if os.path.exists(args.weights_hd5): - model.load_weights(args.weights_hd5, by_name=True) - print('Loaded model weights from:', args.weights_hd5) - - return model - - -def build_tiny_2d_annotation_model(args): +def build_default_1d_annotation_model(args): + return build_reference_annotation_1d_model_from_args(args, + conv_width=7, + conv_layers=[256, 216, 128, 64, 32], + conv_dropout=0.1, + conv_batch_normalize=True, + spatial_dropout=True, + max_pools=[], + padding='same', + annotation_units=64, + annotation_shortcut=True, + fc_layers=[64, 64], + fc_dropout=0.2, + annotation_batch_normalize=True, + fc_batch_normalize=False) + + +def build_1d_annotation_model_from_args(args): + return build_reference_annotation_1d_model_from_args(args, + conv_width=args.conv_width, + conv_layers=args.conv_layers, + conv_dropout=args.conv_dropout, + conv_batch_normalize=args.conv_batch_normalize, + spatial_dropout=args.spatial_dropout, + max_pools=args.max_pools, + padding=args.padding, + annotation_units=args.annotation_units, + annotation_shortcut=args.annotation_shortcut, + fc_layers=args.fc_layers, + fc_dropout=args.fc_dropout, + fc_batch_normalize=args.fc_batch_normalize) + + +def build_2d_annotation_model_from_args(args): return read_tensor_2d_annotation_model_from_args(args, - conv_width = 11, - conv_height = 5, - conv_layers = [32, 32], - conv_dropout = 0.0, - spatial_dropout = False, - max_pools = [(2,1),(8,1)], - padding='valid', - annotation_units = 10, - annotation_shortcut = False, - fc_layers = [16], - fc_dropout = 0.0) - - -def build_small_2d_annotation_model(args): + conv_width = args.conv_width, + conv_height = args.conv_height, + conv_layers = args.conv_layers, + conv_dropout = args.conv_dropout, + conv_batch_normalize = args.conv_batch_normalize, + spatial_dropout = args.spatial_dropout, + max_pools = args.max_pools, + padding = args.padding, + annotation_units = args.annotation_units, + annotation_shortcut = args.annotation_shortcut, + fc_layers = args.fc_layers, + fc_dropout = args.fc_dropout, + fc_batch_normalize = args.fc_batch_normalize) + + +def build_default_2d_annotation_model(args): return read_tensor_2d_annotation_model_from_args(args, conv_width = 25, conv_height = 25, conv_layers = [64, 48, 32, 24], - conv_dropout = 0.0, + conv_dropout = 0.1, conv_batch_normalize = False, - spatial_dropout = False, + spatial_dropout = True, max_pools = [(3,1),(3,1)], padding='valid', annotation_units = 64, annotation_shortcut = False, fc_layers = [24], - fc_dropout = 0.0, + fc_dropout = 0.3, fc_batch_normalize = False) @@ -302,11 +219,11 @@ def read_tensor_2d_annotation_model_from_args(args, cur_kernel = (conv_width, conv_height) if conv_batch_normalize: - x = Conv2D(f, cur_kernel, activation='linear', padding=padding, kernel_initializer=kernel_initializer)(x) + x = Conv2D(int(f), cur_kernel, activation='linear', padding=padding, kernel_initializer=kernel_initializer)(x) x = BatchNormalization(axis=concat_axis)(x) x = Activation('relu')(x) else: - x = Conv2D(f, cur_kernel, activation='relu', padding=padding, kernel_initializer=kernel_initializer)(x) + x = Conv2D(int(f), cur_kernel, activation='relu', padding=padding, kernel_initializer=kernel_initializer)(x) if conv_dropout > 0 and spatial_dropout: x = SpatialDropout2D(conv_dropout)(x) @@ -386,6 +303,103 @@ def read_tensor_2d_annotation_model_from_args(args, return model + +def 
+def build_reference_annotation_1d_model_from_args(args,
+                                                  conv_width = 6,
+                                                  conv_layers = [128, 128, 128, 128],
+                                                  conv_dropout = 0.0,
+                                                  conv_batch_normalize = False,
+                                                  spatial_dropout = True,
+                                                  max_pools = [],
+                                                  padding='valid',
+                                                  activation = 'relu',
+                                                  annotation_units = 16,
+                                                  annotation_shortcut = False,
+                                                  annotation_batch_normalize = True,
+                                                  fc_layers = [64],
+                                                  fc_dropout = 0.0,
+                                                  fc_batch_normalize = False,
+                                                  fc_initializer = 'glorot_normal',
+                                                  kernel_initializer = 'glorot_normal',
+                                                  alpha_dropout = False
+                                                  ):
+    '''Build Reference 1d CNN model for classifying variants.
+
+    Architecture specified by parameters.
+    Dynamically sets input channels based on args via tensor_maps.total_input_channels_from_args(args)
+    Uses the functional API.
+    Prints out model summary.
+
+    Arguments
+        args.annotations: The variant annotations, perhaps from a HaplotypeCaller VCF.
+        args.labels: The output labels (e.g. SNP, NOT_SNP, INDEL, NOT_INDEL)
+
+    Returns
+        The keras model
+    '''
+    in_channels = tensor_maps.total_input_channels_from_args(args)
+    concat_axis = -1
+    x = reference = Input(shape=(args.window_size, in_channels), name=args.tensor_name)
+
+    max_pool_diff = len(conv_layers)-len(max_pools)
+    for i,c in enumerate(conv_layers):
+
+        if conv_batch_normalize:
+            x = Conv1D(filters=c, kernel_size=conv_width, activation='linear', padding=padding, kernel_initializer=kernel_initializer)(x)
+            x = BatchNormalization(axis=concat_axis)(x)
+            x = Activation(activation)(x)
+        else:
+            x = Conv1D(filters=c, kernel_size=conv_width, activation=activation, padding=padding, kernel_initializer=kernel_initializer)(x)
+
+        if conv_dropout > 0 and alpha_dropout:
+            x = AlphaDropout(conv_dropout)(x)
+        elif conv_dropout > 0 and spatial_dropout:
+            x = SpatialDropout1D(conv_dropout)(x)
+        elif conv_dropout > 0:
+            x = Dropout(conv_dropout)(x)
+
+        # Only the last len(max_pools) conv layers are followed by a pooling step.
+        if i >= max_pool_diff:
+            x = MaxPooling1D(max_pools[i-max_pool_diff])(x)
+
+    f = Flatten()(x)
+
+    annotations = annotations_in = Input(shape=(len(args.annotations),), name=args.annotation_set)
+    if annotation_batch_normalize:
+        annotations_in = BatchNormalization(axis=concat_axis)(annotations_in)
+    annotation_mlp = Dense(units=annotation_units, kernel_initializer=fc_initializer, activation=activation)(annotations_in)
+
+    x = layers.concatenate([f, annotation_mlp], axis=1)
+    for fc in fc_layers:
+        if fc_batch_normalize:
+            x = Dense(units=fc, activation='linear', kernel_initializer=fc_initializer)(x)
+            x = BatchNormalization(axis=1)(x)
+            x = Activation(activation)(x)
+        else:
+            x = Dense(units=fc, activation=activation, kernel_initializer=fc_initializer)(x)
+
+        if fc_dropout > 0 and alpha_dropout:
+            x = AlphaDropout(fc_dropout)(x)
+        elif fc_dropout > 0:
+            x = Dropout(fc_dropout)(x)
+
+    if annotation_shortcut:
+        x = layers.concatenate([x, annotations_in], axis=1)
+
+    prob_output = Dense(units=len(args.labels), activation='softmax', name='softmax_predictions')(x)
+
+    model = Model(inputs=[reference, annotations], outputs=[prob_output])
+
+    adam = Adam(lr=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, clipnorm=1.)
+    model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=get_metrics(args.labels))
+    model.summary()
+
+    if os.path.exists(args.weights_hd5):
+        model.load_weights(args.weights_hd5, by_name=True)
+        print('Loaded model weights from:', args.weights_hd5)
+
+    return model
+
+
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 # ~~~~~~~ Optimizing ~~~~~~~~~~~~~~
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -411,16 +425,18 @@ def train_model_from_generators(args, model, generate_train, generate_valid, sav
     '''
     if not os.path.exists(os.path.dirname(save_weight_hd5)):
        os.makedirs(os.path.dirname(save_weight_hd5))
+    serialize_model_semantics(args, save_weight_hd5)
 
     history = model.fit_generator(generate_train,
                                   steps_per_epoch=args.training_steps, epochs=args.epochs, verbose=1,
                                   validation_steps=args.validation_steps, validation_data=generate_valid,
                                   callbacks=get_callbacks(args, save_weight_hd5))
+    print('Training complete, model weights saved at: %s' % save_weight_hd5)
 
     if args.image_dir:
         plots.plot_metric_history(history, plots.weight_path_to_title(save_weight_hd5), prefix=args.image_dir)
-    serialize_model_semantics(args, save_weight_hd5)
-    print('Model weights saved at: %s' % save_weight_hd5)
+
+    return model
diff --git a/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/tensor_maps.py b/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/tensor_maps.py
index 3ecef9ab2bc..ee9fc749d7a 100644
--- a/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/tensor_maps.py
+++ b/src/main/python/org/broadinstitute/hellbender/vqsr_cnn/vqsr_cnn/tensor_maps.py
@@ -6,7 +6,7 @@
 
 def get_tensor_channel_map_from_args(args):
     '''Return tensor mapping dict given args.tensor_name'''
-    if not args.tensor_name:
+    if args.tensor_name is None:
         return None
 
     if 'read_tensor' == args.tensor_name:
diff --git a/src/main/resources/large/cnn_score_variants/small_2d.hd5 b/src/main/resources/large/cnn_score_variants/small_2d.hd5
index 5c007f29b26..deb36d22e04 100644
--- a/src/main/resources/large/cnn_score_variants/small_2d.hd5
+++ b/src/main/resources/large/cnn_score_variants/small_2d.hd5
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:acd9efef4700826a8244a550c40b42bf95dbeafeac5a05bab0ec8d0f353bb80c
-size 6410504
+oid sha256:6f663a2fdbcde0addc5cb755f7af5d4c19bed92dccfd20e25b2acf2bc8c2ca7c
+size 2163096
diff --git a/src/main/resources/large/cnn_score_variants/small_2d.json b/src/main/resources/large/cnn_score_variants/small_2d.json
index 648565cb199..c35cfbdfcae 100644
--- a/src/main/resources/large/cnn_score_variants/small_2d.json
+++ b/src/main/resources/large/cnn_score_variants/small_2d.json
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:226c0344e590051c67a61873733bfafef3db5a0b8e95015b0b688a542199142b
-size 720
+oid sha256:e38e09cfe7b7ffbc80dce4972bc9c382148520147d46738a3f6f3235b2d876c6
+size 758
diff --git a/src/main/resources/org/broadinstitute/hellbender/tools/walkers/vqsr/training.py b/src/main/resources/org/broadinstitute/hellbender/tools/walkers/vqsr/training.py
index 4fe90ac70fd..574a50ae97e 100644
--- a/src/main/resources/org/broadinstitute/hellbender/tools/walkers/vqsr/training.py
+++ b/src/main/resources/org/broadinstitute/hellbender/tools/walkers/vqsr/training.py
@@ -1,10 +1,8 @@
 # Imports
 import os
-import sys
 import vcf
 import math
 import h5py
-import time
 import pysam
 import vqsr_cnn
 import numpy as np
@@ -14,20 +12,21 @@
 
 # Keras Imports
 import keras.backend as K
+
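[Note on the tensor_maps.py change above: it is behavioral, not cosmetic. `not args.tensor_name` also fires on an empty string, while `args.tensor_name is None` only short-circuits when no name was supplied at all, so an empty-string name now falls through to the map lookup and raises there instead of silently returning None. A self-contained illustration:

    for name in (None, '', 'read_tensor'):
        print('%-15r  falsy: %-5s  is None: %s' % (name, not name, name is None))
    # None          -> falsy: True   is None: True
    # ''            -> falsy: True   is None: False  (now reaches the lookup)
    # 'read_tensor' -> falsy: False  is None: False
]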
 
 def run():
     args = vqsr_cnn.parse_args()
     if 'write_reference_and_annotation_tensors' == args.mode:
         write_reference_and_annotation_tensors(args)
     elif 'write_read_and_annotation_tensors' == args.mode:
         write_read_and_annotation_tensors(args)
-    elif 'train_on_reference_tensors_and_annotations' == args.mode:
-        train_on_reference_tensors_and_annotations(args)
-    elif 'train_on_read_tensors_and_annotations' == args.mode:
-        train_on_read_tensors_and_annotations(args)
-    elif 'train_tiny_model_on_read_tensors_and_annotations' == args.mode:
-        train_tiny_model_on_read_tensors_and_annotations(args)
-    elif 'train_small_model_on_read_tensors_and_annotations' == args.mode:
-        train_small_model_on_read_tensors_and_annotations(args)
+    elif 'train_default_1d_model' == args.mode:
+        train_default_1d_model(args)
+    elif 'train_default_2d_model' == args.mode:
+        train_default_2d_model(args)
+    elif 'train_args_model_on_read_tensors_and_annotations' == args.mode:
+        train_args_model_on_read_tensors_and_annotations(args)
+    elif 'train_args_model_on_reference_and_annotations' == args.mode:
+        train_args_model_on_read_tensors_and_annotations(args)
     else:
         raise ValueError('Unknown training mode:', args.mode)
@@ -37,21 +36,15 @@ def write_reference_and_annotation_tensors(args, include_dna=True, include_annot
         raise ValueError('Unknown tensor name:', args.tensor_name, '1d maps must be in:', str(vqsr_cnn.TENSOR_MAPS_1D))
 
     record_dict = SeqIO.to_dict(SeqIO.parse(args.reference_fasta, "fasta"))
-    if os.path.splitext(args.input_vcf)[-1].lower() == '.gz':
-        vcf_reader = vcf.Reader(open(args.input_vcf, 'rb'))
-    else:
-        vcf_reader = vcf.Reader(open(args.input_vcf, 'r'))
-    if os.path.splitext(args.train_vcf)[-1].lower() == '.gz':
-        vcf_ram = vcf.Reader(open(args.train_vcf, 'rb'))
-    else:
-        vcf_ram = vcf.Reader(open(args.train_vcf, 'r'))
+    vcf_reader = get_vcf_reader(args.input_vcf)
+    vcf_ram = get_vcf_reader(args.train_vcf)
 
     bed_dict = bed_file_to_dict(args.bed_file)
     stats = Counter()
 
     if args.chrom:
-        variants = vcf_reader.fetch(args.chrom, args.start_pos, args.end_pos) 
+        variants = vcf_reader.fetch(args.chrom, args.start_pos, args.end_pos)
     else:
         variants = vcf_reader
@@ -106,7 +99,6 @@ def write_reference_and_annotation_tensors(args, include_dna=True, include_annot
         print(k, ' has:', stats[k])
-
 
 def write_read_and_annotation_tensors(args, include_annotations=True, pileup=False):
     '''Create tensors structured as tensor map of reads organized by labels in the data directory.
@@ -134,8 +126,8 @@ def write_read_and_annotation_tensors(args, include_annotations=True, pileup=Fal
     samfile = pysam.AlignmentFile(args.bam_file, "rb")
     bed_dict = bed_file_to_dict(args.bed_file)
     record_dict = SeqIO.to_dict(SeqIO.parse(args.reference_fasta, "fasta"))
-    vcf_reader = vcf.Reader(open(args.input_vcf, 'r'))
-    vcf_ram = vcf.Reader(open(args.train_vcf, 'rb'))
+    vcf_reader = get_vcf_reader(args.input_vcf)
+    vcf_ram = get_vcf_reader(args.train_vcf)
 
     if args.chrom:
         variants = vcf_reader.fetch(args.chrom, args.start_pos, args.end_pos)
@@ -205,7 +197,7 @@ def write_read_and_annotation_tensors(args, include_annotations=True, pileup=Fal
     print('Done generating tensors. Last variant:', str(variant), 'from vcf:', args.input_vcf)
 
 
-def train_on_reference_tensors_and_annotations(args):
+def train_default_1d_model(args):
     '''Train a 1D Convolution plus reference tracks and MLP Annotation architecture.
 
     Arguments:
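[Note on run() above: the 'train_args_model_on_reference_and_annotations' branch dispatches to train_args_model_on_read_tensors_and_annotations, which reads like a copy-paste slip worth double-checking. A table-driven dispatch would make such mismatches easier to spot; this is a hypothetical refactor, not what the patch does, and it assumes every handler keeps the handler(args) signature:

    # Hypothetical alternative; mode names and handlers are those in the patch.
    HANDLERS = {
        'write_reference_and_annotation_tensors': write_reference_and_annotation_tensors,
        'write_read_and_annotation_tensors': write_read_and_annotation_tensors,
        'train_default_1d_model': train_default_1d_model,
        'train_default_2d_model': train_default_2d_model,
        'train_args_model_on_read_tensors_and_annotations': train_args_model_on_read_tensors_and_annotations,
    }

    def run():
        args = vqsr_cnn.parse_args()
        handler = HANDLERS.get(args.mode)
        if handler is None:
            raise ValueError('Unknown training mode:', args.mode)
        handler(args)

Adding a mode then becomes a one-line dict entry, and a missing or misrouted handler is visible in a single place.]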
@@ -223,7 +215,7 @@ def train_on_reference_tensors_and_annotations(args):
     generate_valid = dna_annotation_generator(args, valid_paths)
 
     weight_path = vqsr_cnn.weight_path_from_args(args)
-    model = vqsr_cnn.build_reference_annotation_model(args)
+    model = vqsr_cnn.build_default_1d_annotation_model(args)
     model = vqsr_cnn.train_model_from_generators(args, model, generate_train, generate_valid, weight_path)
 
     test = load_dna_annotations_positions_from_class_dirs(args, test_paths, per_class_max=args.samples)
@@ -231,8 +223,7 @@ def train_on_reference_tensors_and_annotations(args):
     vqsr_cnn.plot_roc_per_class(model, [test[0], test[1]], test[2], args.labels, args.id, prefix=args.image_dir)
 
-
-def train_on_read_tensors_and_annotations(args):
+def train_default_2d_model(args):
     '''Trains a reference, read, and annotation CNN architecture on tensors at the supplied data directory.
 
     This architecture looks at reads, read flags, reference sequence, and variant annotations.
@@ -251,7 +242,7 @@ def train_on_read_tensors_and_annotations(args):
     generate_valid = tensor_generator_from_label_dirs_and_args(args, valid_paths)
 
     weight_path = vqsr_cnn.weight_path_from_args(args)
-    model = vqsr_cnn.build_read_tensor_2d_and_annotations_model(args)
+    model = vqsr_cnn.build_default_2d_annotation_model(args)
     model = vqsr_cnn.train_model_from_generators(args, model, generate_train, generate_valid, weight_path)
 
     test = load_tensors_and_annotations_from_class_dirs(args, test_paths, per_class_max=args.samples)
@@ -260,7 +251,7 @@ def train_on_read_tensors_and_annotations(args):
                                 prefix=args.image_dir, batch_size=args.batch_size)
 
 
-def train_tiny_model_on_read_tensors_and_annotations(args):
+def train_args_model_on_read_tensors_and_annotations(args):
     '''Trains a reference, read, and annotation CNN architecture on tensors at the supplied data directory.
 
     This architecture looks at reads, read flags, reference sequence, and variant annotations.
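[Note on the train_* functions above: they differ only in which builder they call; all of them hand Keras generators to train_model_from_generators, and fit_generator consumes (inputs, targets) batches keyed by the model's Input layer names. A dummy stand-in for tensor_generator_from_label_dirs_and_args, with purely illustrative shapes (the real ones come from the tensor maps and args):

    import numpy as np

    def dummy_batches(tensor_name, annotation_set, batch_size=32):
        # fit_generator pulls (inputs, targets) tuples from this forever;
        # the dict keys must match the names given to the Input layers.
        while True:
            reads = np.random.rand(batch_size, 128, 128, 15)           # read tensor batch
            annotations = np.random.rand(batch_size, 7)                # annotation batch
            labels = np.eye(4)[np.random.randint(4, size=batch_size)]  # one-hot labels
            yield {tensor_name: reads, annotation_set: annotations}, labels
]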
@@ -279,7 +270,7 @@ def train_tiny_model_on_read_tensors_and_annotations(args):
     generate_valid = tensor_generator_from_label_dirs_and_args(args, valid_paths)
 
     weight_path = vqsr_cnn.weight_path_from_args(args)
-    model = vqsr_cnn.build_tiny_2d_annotation_model(args)
+    model = vqsr_cnn.build_2d_annotation_model_from_args(args)
     model = vqsr_cnn.train_model_from_generators(args, model, generate_train, generate_valid, weight_path)
 
     test = load_tensors_and_annotations_from_class_dirs(args, test_paths, per_class_max=args.samples)
@@ -712,6 +703,7 @@ def get_true_label(allele, variant, bed_dict, truth_vcf, stats):
             NOT_INDEL if variant is indel and not in truth vcf
     '''
     in_bed = in_bed_file(bed_dict, variant.CHROM, variant.POS)
+
    if allele_in_vcf(allele, variant, truth_vcf) and in_bed:
         class_prefix = ''
     elif in_bed:
@@ -815,9 +807,14 @@ def bed_file_to_dict(bed_file):
 
 
 def in_bed_file(bed_dict, contig, pos):
-    # Exclusive
+
+    if contig not in bed_dict:
+        return False
+
     lows = bed_dict[contig][0]
     ups = bed_dict[contig][1]
+
+    # Half-open interval [low, up)
     return np.any((lows <= pos) & (pos < ups))
 
 
@@ -832,7 +829,14 @@ def allele_in_vcf(allele, variant, vcf_ram):
     Returns
         variant if it is found otherwise None
     '''
-    variants = vcf_ram.fetch(variant.CHROM, variant.POS-1, variant.POS)
+    if variant.CHROM not in vcf_ram.contigs:
+        return None
+
+    try:
+        variants = vcf_ram.fetch(variant.CHROM, variant.POS-1, variant.POS)
+    except ValueError as e:
+        print('Caught ValueError on fetch:', e)
+        return None
 
     for v in variants:
         if v.CHROM == variant.CHROM and v.POS == variant.POS and allele in v.ALT:
@@ -1145,6 +1149,12 @@ def plain_name(full_name):
     name = os.path.basename(full_name)
     return name.split('.')[0]
 
+def get_vcf_reader(my_vcf):
+    if os.path.splitext(my_vcf)[-1].lower() == '.gz':
+        return vcf.Reader(open(my_vcf, 'rb'))
+    else:
+        return vcf.Reader(open(my_vcf, 'r'))
+
 
 # Back to the top!
 if "__main__" == __name__:
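[Note on in_bed_file above: the patch replaces the misleading '# Exclusive' comment, since the test is half-open (low <= pos < up), and it now tolerates contigs absent from the bed dict. The vectorized membership check in isolation:

    import numpy as np

    lows = np.array([100, 500])   # interval starts, inclusive
    ups = np.array([200, 600])    # interval ends, exclusive
    for pos in (99, 100, 199, 200, 600):
        print(pos, bool(np.any((lows <= pos) & (pos < ups))))
    # 99 False, 100 True, 199 True, 200 False, 600 False

A variant sitting exactly on an interval's end coordinate is outside it, matching BED conventions.]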
if "__main__" == __name__: diff --git a/src/test/java/org/broadinstitute/hellbender/tools/walkers/validation/EvaluateInfoFieldConcordanceIntegrationTest.java b/src/test/java/org/broadinstitute/hellbender/tools/walkers/validation/EvaluateInfoFieldConcordanceIntegrationTest.java new file mode 100644 index 00000000000..75a247f14bf --- /dev/null +++ b/src/test/java/org/broadinstitute/hellbender/tools/walkers/validation/EvaluateInfoFieldConcordanceIntegrationTest.java @@ -0,0 +1,65 @@ +package org.broadinstitute.hellbender.tools.walkers.validation; + +import htsjdk.variant.variantcontext.VariantContext; +import org.broadinstitute.hellbender.CommandLineProgramTest; +import org.broadinstitute.hellbender.engine.AbstractConcordanceWalker; +import org.broadinstitute.hellbender.testutils.ArgumentsBuilder; +import org.broadinstitute.hellbender.utils.variant.GATKVCFConstants; +import org.testng.Assert; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.nio.file.Path; + +public class EvaluateInfoFieldConcordanceIntegrationTest extends CommandLineProgramTest { + final double epsilon = 1e-3; + + @Test(dataProvider= "infoConcordanceDataProvider") + public void testInfoConcordanceFromProvider(String inputVcf1, String inputVcf2, String evalKey, String truthKey, + double snpMean, double snpSTD, + double indelMean, double indelSTD) throws Exception { + final Path summary = createTempPath("summary", ".txt"); + final ArgumentsBuilder argsBuilder = new ArgumentsBuilder(); + argsBuilder.addArgument(AbstractConcordanceWalker.EVAL_VARIANTS_SHORT_NAME, inputVcf1) + .addArgument(AbstractConcordanceWalker.TRUTH_VARIANTS_LONG_NAME, inputVcf2) + .addArgument("eval-info-key", evalKey) + .addArgument("truth-info-key", truthKey) + .addArgument(EvaluateInfoFieldConcordance.SUMMARY_LONG_NAME, summary.toString()); + runCommandLine(argsBuilder); + + try(InfoConcordanceRecord.InfoConcordanceReader + reader = new InfoConcordanceRecord.InfoConcordanceReader(summary)) { + InfoConcordanceRecord snpRecord = reader.readRecord(); + InfoConcordanceRecord indelRecord = reader.readRecord(); + + Assert.assertEquals(snpRecord.getVariantType(), VariantContext.Type.SNP); + Assert.assertEquals(indelRecord.getVariantType(), VariantContext.Type.INDEL); + + Assert.assertEquals(snpRecord.getMean(), snpMean, epsilon); + Assert.assertEquals(snpRecord.getStd(), snpSTD, epsilon); + Assert.assertEquals(indelRecord.getMean(), indelMean, epsilon); + Assert.assertEquals(indelRecord.getStd(), indelSTD, epsilon); + } + } + + @DataProvider + public Object[][] infoConcordanceDataProvider() { + return new Object [][]{ + new Object[]{ + largeFileTestDir + "VQSR/expected/chr20_tiny_tf_python_gpu2.vcf", + largeFileTestDir + "VQSR/expected/chr20_tiny_tf_python_gpu2.vcf", + GATKVCFConstants.CNN_2D_KEY, + "NOVA_HISEQ_MIX_SMALL", + 0.108878, 0.229415, 0.067024, 0.142705 // numbers verified by manual inspection + + }, + new Object[]{ + largeFileTestDir + "VQSR/expected/chr20_tiny_tf_python_cpu.vcf", + largeFileTestDir + "VQSR/expected/chr20_tiny_th_python_gpu.vcf", + GATKVCFConstants.CNN_1D_KEY, + "NOVA_HISEQ_MIX_1D_RAB", + 0.000256, 0.000136, 0.000240, 0.000153 // numbers verified by manual inspection + } + }; + } +} diff --git a/src/test/resources/large/VQSR/expected/chr20_tiny_tf_python_cpu.vcf b/src/test/resources/large/VQSR/expected/chr20_tiny_tf_python_cpu.vcf new file mode 100644 index 00000000000..e4941b9bd9b --- /dev/null +++ b/src/test/resources/large/VQSR/expected/chr20_tiny_tf_python_cpu.vcf @@ -0,0 +1,3 @@ 
diff --git a/src/test/resources/large/VQSR/expected/chr20_tiny_tf_python_cpu.vcf b/src/test/resources/large/VQSR/expected/chr20_tiny_tf_python_cpu.vcf
new file mode 100644
index 00000000000..e4941b9bd9b
--- /dev/null
+++ b/src/test/resources/large/VQSR/expected/chr20_tiny_tf_python_cpu.vcf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b4e21b9c31936d59a5b4a52756254d932b9b63db418a918567538e70215c52e9
+size 171577
diff --git a/src/test/resources/large/VQSR/expected/chr20_tiny_tf_python_gpu2.vcf b/src/test/resources/large/VQSR/expected/chr20_tiny_tf_python_gpu2.vcf
new file mode 100644
index 00000000000..b14fb5b7bf7
--- /dev/null
+++ b/src/test/resources/large/VQSR/expected/chr20_tiny_tf_python_gpu2.vcf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6338d032989dc6b9168bc94ee1da85c666ee1de995c3859046a65b3dba610350
+size 177316
diff --git a/src/test/resources/large/VQSR/expected/chr20_tiny_th_python_cpu.vcf b/src/test/resources/large/VQSR/expected/chr20_tiny_th_python_cpu.vcf
new file mode 100755
index 00000000000..f8c92ca64b7
--- /dev/null
+++ b/src/test/resources/large/VQSR/expected/chr20_tiny_th_python_cpu.vcf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3f20647108452f717b84c7403d510d919fbbd83c7ce2ddf5545534fe6e8ee08f
+size 155325
diff --git a/src/test/resources/large/VQSR/expected/chr20_tiny_th_python_gpu.vcf b/src/test/resources/large/VQSR/expected/chr20_tiny_th_python_gpu.vcf
new file mode 100644
index 00000000000..333e6644636
--- /dev/null
+++ b/src/test/resources/large/VQSR/expected/chr20_tiny_th_python_gpu.vcf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ebffb2b960005d08830e3b879350ec1a75f1121a3be1d0205a79cbbe24cd94e4
+size 169754
diff --git a/src/test/resources/large/VQSR/expected/cnn_2d_chr20_subset_expected.vcf b/src/test/resources/large/VQSR/expected/cnn_2d_chr20_subset_expected.vcf
index 438dc0b4f67..8de9528ce1e 100644
--- a/src/test/resources/large/VQSR/expected/cnn_2d_chr20_subset_expected.vcf
+++ b/src/test/resources/large/VQSR/expected/cnn_2d_chr20_subset_expected.vcf
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:8f4a4ccf2ebbaed80ee40313f0f50be4287b947fc1e331159de871f7f0856c7c
-size 150076
+oid sha256:6fe2f366e19080280fc8666b9fdc9f98e1143aa71b3ee27bd8002ccd6d4055d3
+size 150048
diff --git a/src/test/resources/large/VQSR/expected/nn_outy2d.vcf b/src/test/resources/large/VQSR/expected/nn_outy2d.vcf
new file mode 100644
index 00000000000..073956299af
--- /dev/null
+++ b/src/test/resources/large/VQSR/expected/nn_outy2d.vcf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2b58efd72c6b0af998ff7c8450f091119a46a3c0f7d4760e082709770f776d46
+size 163170
diff --git a/src/test/resources/large/VQSR/small_2d.hd5 b/src/test/resources/large/VQSR/small_2d.hd5
index 5c007f29b26..deb36d22e04 100644
--- a/src/test/resources/large/VQSR/small_2d.hd5
+++ b/src/test/resources/large/VQSR/small_2d.hd5
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:acd9efef4700826a8244a550c40b42bf95dbeafeac5a05bab0ec8d0f353bb80c
-size 6410504
+oid sha256:6f663a2fdbcde0addc5cb755f7af5d4c19bed92dccfd20e25b2acf2bc8c2ca7c
+size 2163096
diff --git a/src/test/resources/large/VQSR/small_2d.json b/src/test/resources/large/VQSR/small_2d.json
index 648565cb199..c35cfbdfcae 100644
--- a/src/test/resources/large/VQSR/small_2d.json
+++ b/src/test/resources/large/VQSR/small_2d.json
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:226c0344e590051c67a61873733bfafef3db5a0b8e95015b0b688a542199142b
-size 720
+oid sha256:e38e09cfe7b7ffbc80dce4972bc9c382148520147d46738a3f6f3235b2d876c6
+size 758