diff --git a/phylo/src/substitution_models/dna_models/gtr.rs b/phylo/src/substitution_models/dna_models/gtr.rs index 8296b3d..6410e71 100644 --- a/phylo/src/substitution_models/dna_models/gtr.rs +++ b/phylo/src/substitution_models/dna_models/gtr.rs @@ -6,17 +6,17 @@ use argmin::solver::brent::BrentOpt; use log::info; use crate::evolutionary_models::EvolutionaryModelInfo; +use crate::substitution_models::SubstParams; use crate::substitution_models::{ dna_models::{make_dna_model, make_pi, DNASubstModel, DNASubstParams}, SubstMatrix, SubstitutionLikelihoodCost, SubstitutionModelInfo, }; use crate::Result; -pub fn gtr(model_params: &[f64]) -> Result { - let gtr_params = parse_gtr_parameters(model_params)?; +pub fn gtr(gtr_params: DNASubstParams) -> DNASubstModel { info!("Setting up gtr with rates: {}", gtr_params.print_as_gtr()); let q = gtr_q(&gtr_params); - Ok(make_dna_model(gtr_params, q)) + make_dna_model(gtr_params, q) } pub fn parse_gtr_parameters(model_params: &[f64]) -> Result { @@ -86,7 +86,7 @@ fn gtr_q(gtr: &DNASubstParams) -> SubstMatrix { impl DNASubstModel { pub(crate) fn reset_gtr(&mut self, params: &DNASubstParams) { - self.params = ((*params).clone()).into(); + self.params = SubstParams::DNA((*params).clone()); self.q = gtr_q(params); } } @@ -116,7 +116,9 @@ impl CostFunction for GTRParamOptimiser<'_> { type Output = f64; fn cost(&self, param: &Self::Param) -> Result { - let mut params = parse_gtr_parameters(self.base_model.params.as_slice())?; + let SubstParams::DNA(mut params) = self.base_model.params.clone() else { + unreachable!() + }; match self.parameter { ParamEnum::Pit | ParamEnum::Pic | ParamEnum::Pia | ParamEnum::Pig => { bail!("Cannot optimise frequencies for now.") @@ -158,13 +160,7 @@ impl<'a> GTRModelOptimiser<'a> { } pub fn optimise_parameters(&self) -> Result<(u32, DNASubstParams, f64)> { - let epsilon = 1e-10; - let params = self.base_model.params.clone(); - let model = gtr(params.as_slice())?; - let mut logl = f64::NEG_INFINITY; - let 
mut new_logl = 0.0; - let mut gtr_params = parse_gtr_parameters(params.as_slice())?; - let mut iters = 0; + let epsilon = 1e-5; let params_to_optimise = [ ParamEnum::Rag, ParamEnum::Rca, @@ -173,6 +169,14 @@ impl<'a> GTRModelOptimiser<'a> { ParamEnum::Rtc, ParamEnum::Rtg, ]; + let SubstParams::DNA(mut gtr_params) = self.base_model.params.clone() else { + unreachable!() + }; + let model = gtr(gtr_params.clone()); + let mut logl = f64::NEG_INFINITY; + let mut new_logl = 0.0; + let mut iters = 0; + while (logl - new_logl).abs() > epsilon { println!("Iteration: {}", iters); logl = new_logl; @@ -226,7 +230,7 @@ mod gtr_optimisation_tests { fn check_parameter_optimisation_gtr() { // Original params from paml: 0.88892 0.03190 0.00001 0.07102 0.02418 let info = phyloinfo_from_files( - PathBuf::from("./data/sim/gtr/gtr.fasta"), + PathBuf::from("./data/sim/GTR/gtr.fasta"), PathBuf::from("./data/sim/tree.newick"), ) .unwrap(); @@ -242,7 +246,7 @@ mod gtr_optimisation_tests { let mut tmp_info = SubstitutionModelInfo::new(likelihood.info, &model).unwrap(); let unopt_logl = likelihood.compute_log_likelihood(&model, &mut tmp_info); assert_relative_eq!(unopt_logl, -3474.48083, epsilon = 1.0e-5); - let (iters, params, logl) = GTRModelOptimiser::new(&likelihood, &model) + let (_, _, logl) = GTRModelOptimiser::new(&likelihood, &model) .optimise_parameters() .unwrap(); assert!(logl > unopt_logl); diff --git a/phylo/src/substitution_models/dna_models/hky.rs b/phylo/src/substitution_models/dna_models/hky.rs index 8f4d34b..41faf43 100644 --- a/phylo/src/substitution_models/dna_models/hky.rs +++ b/phylo/src/substitution_models/dna_models/hky.rs @@ -7,14 +7,13 @@ use crate::substitution_models::dna_models::{ use crate::substitution_models::FreqVector; use crate::Result; -pub fn hky(model_params: &[f64]) -> Result { - let hky_params = parse_hky_parameters(model_params)?; +pub fn hky(hky_params: DNASubstParams) -> DNASubstModel { info!( "Setting up hky with parameters {}", 
hky_params.print_as_hky() ); let q = tn93_q(&hky_params); - Ok(make_dna_model(hky_params, q)) + make_dna_model(hky_params, q) } pub fn parse_hky_parameters(model_params: &[f64]) -> Result { diff --git a/phylo/src/substitution_models/dna_models/jc69.rs b/phylo/src/substitution_models/dna_models/jc69.rs index cf6c7b6..0ec96ce 100644 --- a/phylo/src/substitution_models/dna_models/jc69.rs +++ b/phylo/src/substitution_models/dna_models/jc69.rs @@ -6,13 +6,12 @@ use crate::substitution_models::{ }; use crate::Result; -pub fn jc69(model_params: &[f64]) -> Result { - let jc69_params = parse_jc69_parameters(model_params)?; +pub fn jc69(jc69_params: DNASubstParams) -> DNASubstModel { info!( "Setting up jc69 with parameters: {}", jc69_params.print_as_jc69() ); - Ok(make_dna_model(jc69_params, jc69_q())) + make_dna_model(jc69_params, jc69_q()) } pub fn parse_jc69_parameters(model_params: &[f64]) -> Result { diff --git a/phylo/src/substitution_models/dna_models/k80.rs b/phylo/src/substitution_models/dna_models/k80.rs index 1b58ea6..df34ad1 100644 --- a/phylo/src/substitution_models/dna_models/k80.rs +++ b/phylo/src/substitution_models/dna_models/k80.rs @@ -1,24 +1,25 @@ use std::ops::Div; +use anyhow::bail; use argmin::core::{CostFunction, Executor}; use argmin::solver::brent::BrentOpt; use log::{info, warn}; use crate::evolutionary_models::EvolutionaryModelInfo; +use crate::substitution_models::SubstParams; use crate::substitution_models::{ dna_models::{dna_substitution_parameters::DNASubstParams, make_dna_model, DNASubstModel}, FreqVector, SubstMatrix, SubstitutionLikelihoodCost, SubstitutionModelInfo, }; use crate::Result; -pub fn k80(model_params: &[f64]) -> Result { - let k80_params = parse_k80_parameters(model_params)?; +pub fn k80(k80_params: DNASubstParams) -> DNASubstModel { info!( "Setting up k80 with parameters: {}", k80_params.print_as_k80() ); let q = k80_q(&k80_params); - Ok(make_dna_model(k80_params, q)) + make_dna_model(k80_params, q) } pub fn 
parse_k80_parameters(model_params: &[f64]) -> Result { @@ -80,9 +81,9 @@ pub fn k80_q(p: &DNASubstParams) -> SubstMatrix { } impl DNASubstModel { - pub(crate) fn reset_k80_q(&mut self, params: &DNASubstParams) { - self.params = ((*params).clone()).into(); - self.q = k80_q(params); + pub(crate) fn reset_k80_q(&mut self, params: DNASubstParams) { + self.q = k80_q(&params); + self.params = SubstParams::DNA(params); } } @@ -114,12 +115,12 @@ impl<'a> K80ModelOptimiser<'a> { pub fn optimise_parameters(&self) -> Result<(u32, DNASubstParams, f64)> { let epsilon = 1e-10; - let alpha = self.base_model.params[0]; - let beta = self.base_model.params[1]; - let mut model = k80(&[alpha, beta])?; + let SubstParams::DNA(mut k80_params) = self.base_model.params.clone() else { + unreachable!(); + }; + let mut model = k80(k80_params.clone()); let mut logl = f64::NEG_INFINITY; let mut new_logl = 0.0; - let mut k80_params = k80_params(alpha, beta); let mut iters = 0; while (logl - new_logl).abs() > epsilon { logl = new_logl; @@ -135,7 +136,7 @@ impl<'a> K80ModelOptimiser<'a> { model: &mut DNASubstModel, k80_params: &mut DNASubstParams, ) -> Result { - model.reset_k80_q(k80_params); + model.reset_k80_q(k80_params.clone()); let alpha_optimiser = K80ModelAlphaOptimiser { likelihood_cost: self.likelihood_cost, base_model: model, @@ -152,7 +153,7 @@ impl<'a> K80ModelOptimiser<'a> { model: &mut DNASubstModel, k80_params: &mut DNASubstParams, ) -> Result { - model.reset_k80_q(k80_params); + model.reset_k80_q(k80_params.clone()); let beta_optimiser = K80ModelBetaOptimiser { likelihood_cost: self.likelihood_cost, base_model: model, @@ -177,7 +178,12 @@ impl CostFunction for K80ModelAlphaOptimiser<'_> { fn cost(&self, param: &Self::Param) -> Result { let mut model = self.base_model.clone(); - model.reset_k80_q(&k80_params(*param, self.base_model.params[1])); + let SubstParams::DNA(mut k80_params) = model.params.clone() else { + bail!("Incorrect substitution model parameter type.") + }; + 
k80_params.rtc = *param; + k80_params.rag = *param; + model.reset_k80_q(k80_params); let mut tmp_info = SubstitutionModelInfo::new(self.likelihood_cost.info, &model)?; Ok(-self .likelihood_cost @@ -195,7 +201,14 @@ impl CostFunction for K80ModelBetaOptimiser<'_> { fn cost(&self, param: &Self::Param) -> Result { let mut model = self.base_model.clone(); - model.reset_k80_q(&k80_params(self.base_model.params[0], *param)); + let SubstParams::DNA(mut k80_params) = model.params.clone() else { + bail!("Incorrect substitution model parameter type.") + }; + k80_params.rta = *param; + k80_params.rtg = *param; + k80_params.rca = *param; + k80_params.rcg = *param; + model.reset_k80_q(k80_params); let mut tmp_info = SubstitutionModelInfo::new(self.likelihood_cost.info, &model)?; Ok(-self .likelihood_cost diff --git a/phylo/src/substitution_models/dna_models/mod.rs b/phylo/src/substitution_models/dna_models/mod.rs index 98be91c..6d40636 100644 --- a/phylo/src/substitution_models/dna_models/mod.rs +++ b/phylo/src/substitution_models/dna_models/mod.rs @@ -9,8 +9,8 @@ use crate::evolutionary_models::{EvolutionaryModel, EvolutionaryModelInfo}; use crate::likelihood::LikelihoodCostFunction; use crate::sequences::{charify, dna_alphabet, AMBIG, NUCLEOTIDES_STR}; use crate::substitution_models::{ - FreqVector, ParsimonyModel, SubstMatrix, SubstitutionLikelihoodCost, SubstitutionModel, - SubstitutionModelInfo, + FreqVector, ParsimonyModel, SubstMatrix, SubstParams, SubstitutionLikelihoodCost, + SubstitutionModel, SubstitutionModelInfo, }; use crate::{Result, Rounding}; @@ -61,13 +61,10 @@ fn dna_ambiguous_chars() -> HashMap> { } } -fn make_dna_model( - params: dna_substitution_parameters::DNASubstParams, - q: SubstMatrix, -) -> DNASubstModel { +fn make_dna_model(params: DNASubstParams, q: SubstMatrix) -> DNASubstModel { let pi = params.pi.clone(); DNASubstModel { - params: params.into(), + params: SubstParams::DNA(params), index: nucleotide_index(), q, pi, @@ -80,11 +77,11 @@ impl 
EvolutionaryModel<4> for DNASubstModel { Self: std::marker::Sized, { match model_name.to_uppercase().as_str() { - "JC69" => jc69(model_params), - "K80" => k80(model_params), - "HKY" => hky(model_params), - "TN93" => tn93(model_params), - "GTR" => gtr(model_params), + "JC69" => Ok(jc69(parse_jc69_parameters(model_params)?)), + "K80" => Ok(k80(parse_k80_parameters(model_params)?)), + "HKY" => Ok(hky(parse_hky_parameters(model_params)?)), + "TN93" => Ok(tn93(parse_tn93_parameters(model_params)?)), + "GTR" => Ok(gtr(parse_gtr_parameters(model_params)?)), _ => bail!("Unknown DNA model requested."), } } diff --git a/phylo/src/substitution_models/dna_models/tn93.rs b/phylo/src/substitution_models/dna_models/tn93.rs index fa80e7e..e35541a 100644 --- a/phylo/src/substitution_models/dna_models/tn93.rs +++ b/phylo/src/substitution_models/dna_models/tn93.rs @@ -12,14 +12,13 @@ use crate::substitution_models::{ }; use crate::Result; -pub fn tn93(model_params: &[f64]) -> Result { - let tn93_params = parse_tn93_parameters(model_params)?; +pub fn tn93(tn93_params: DNASubstParams) -> DNASubstModel { info!( "Setting up tn93 with parameters {}", tn93_params.print_as_tn93() ); let q = tn93_q(&tn93_params); - Ok(make_dna_model(tn93_params, q)) + make_dna_model(tn93_params, q) } pub fn parse_tn93_parameters(model_params: &[f64]) -> Result { diff --git a/phylo/src/substitution_models/mod.rs b/phylo/src/substitution_models/mod.rs index 55b477d..1704c9a 100644 --- a/phylo/src/substitution_models/mod.rs +++ b/phylo/src/substitution_models/mod.rs @@ -16,10 +16,16 @@ pub mod protein_models; pub type SubstMatrix = DMatrix; pub type FreqVector = DVector; +#[derive(Clone, Debug, PartialEq)] +pub enum SubstParams { + DNA(dna_models::DNASubstParams), + Protein(protein_models::ProteinSubstParams), +} + #[derive(Clone, Debug, PartialEq)] pub struct SubstitutionModel { index: [i32; 255], - pub params: Vec, + pub params: SubstParams, pub(crate) q: SubstMatrix, pub(crate) pi: FreqVector, } diff --git 
a/phylo/src/substitution_models/protein_models.rs b/phylo/src/substitution_models/protein_models.rs index 5fc947c..19cb007 100644 --- a/phylo/src/substitution_models/protein_models.rs +++ b/phylo/src/substitution_models/protein_models.rs @@ -8,8 +8,8 @@ use crate::evolutionary_models::EvolutionaryModel; use crate::likelihood::LikelihoodCostFunction; use crate::sequences::{charify, AMINOACIDS_STR}; use crate::substitution_models::{ - FreqVector, ParsimonyModel, SubstMatrix, SubstitutionLikelihoodCost, SubstitutionModel, - SubstitutionModelInfo, + FreqVector, ParsimonyModel, SubstMatrix, SubstParams, SubstitutionLikelihoodCost, + SubstitutionModel, SubstitutionModelInfo, }; use crate::{Result, Rounding}; @@ -20,6 +20,11 @@ pub type ProteinSubstModel = SubstitutionModel<20>; pub type ProteinLikelihoodCost<'a> = SubstitutionLikelihoodCost<'a, 20>; pub type ProteinSubstModelInfo = SubstitutionModelInfo<20>; +#[derive(Clone, Debug, PartialEq)] +pub struct ProteinSubstParams { + pub(crate) pi: FreqVector, +} + impl ProteinSubstModel { fn normalise(&mut self) { let factor = -(self.pi.transpose() * self.q.diagonal())[(0, 0)]; @@ -39,7 +44,7 @@ impl EvolutionaryModel<20> for ProteinSubstModel { _ => bail!("Unknown protein model requested."), }; let mut model = ProteinSubstModel { - params: vec![], + params: SubstParams::Protein(ProteinSubstParams { pi: pi.clone() }), index: aminoacid_index(), q, pi,