From 0cfcad50694e14df1465ab9b08aaf424dca8e3b1 Mon Sep 17 00:00:00 2001 From: Marc Paterno Date: Tue, 30 May 2023 09:52:02 -0500 Subject: [PATCH 01/34] Add test for linear bias systemaic --- .../statistic/source/number_counts.py | 4 ++- .../statistic/source/test_source.py | 31 +++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/firecrown/likelihood/gauss_family/statistic/source/number_counts.py b/firecrown/likelihood/gauss_family/statistic/source/number_counts.py index 593dd2cf..dd7ba1c2 100644 --- a/firecrown/likelihood/gauss_family/statistic/source/number_counts.py +++ b/firecrown/likelihood/gauss_family/statistic/source/number_counts.py @@ -43,7 +43,9 @@ class NumberCountsArgs: class NumberCountsSystematic(SourceSystematic): - """Class implementing systematics for Number Counts sources.""" + """Abstract base class for systematics for Number Counts sources. + + Derived classes must implement :python`apply` with the correct signature.""" @abstractmethod def apply( diff --git a/tests/likelihood/gauss_family/statistic/source/test_source.py b/tests/likelihood/gauss_family/statistic/source/test_source.py index 0d371286..bc87ff7c 100644 --- a/tests/likelihood/gauss_family/statistic/source/test_source.py +++ b/tests/likelihood/gauss_family/statistic/source/test_source.py @@ -4,6 +4,8 @@ import pytest import pyccl from firecrown.likelihood.gauss_family.statistic.source.source import Tracer +import firecrown.likelihood.gauss_family.statistic.source.number_counts as nc +from firecrown.parameters import ParamsMap class TrivialTracer(Tracer): @@ -42,3 +44,32 @@ def test_tracer_construction_with_name(empty_pyccl_tracer): assert named.halo_2pt is None assert not named.has_pt assert not named.has_hm + + +def test_linear_bias_systematic(): + a = nc.LinearBiasSystematic("xxx") + assert isinstance(a, nc.LinearBiasSystematic) + assert a.alphag is None + assert a.alphaz is None + assert a.z_piv is None + assert not a.is_updated() + + a.update(ParamsMap({"xxx_alphag": 1.0, "xxx_alphaz": 2.0, "xxx_z_piv": 1.5})) + assert a.is_updated() + assert a.alphag == 1.0 + assert a.alphaz == 2.0 + assert a.z_piv == 1.5 + + a.reset() + assert not a.is_updated() + assert a.alphag is None + assert a.alphaz is None + assert a.z_piv is None + + +def test_weak_lensing_source(): + pass + + +def test_number_counts_source(): + pass From 06532d11601af34b8336818a97b09627f9751f2b Mon Sep 17 00:00:00 2001 From: Marc Paterno Date: Mon, 17 Jul 2023 15:51:03 -0500 Subject: [PATCH 02/34] Initialize instance variable in __init__ --- firecrown/updatable.py | 1 + 1 file changed, 1 insertion(+) diff --git a/firecrown/updatable.py b/firecrown/updatable.py index bd28d710..2847ab5e 100644 --- a/firecrown/updatable.py +++ b/firecrown/updatable.py @@ -253,6 +253,7 @@ def __init__(self, iterable=None): :param iterable: An iterable that yields Updatable objects """ super().__init__(iterable) + self._updated: bool = False for item in self: if not isinstance(item, Updatable): raise TypeError( From 1a314a45097a56ce7627b82412e28054abb0f398 Mon Sep 17 00:00:00 2001 From: Marc Paterno Date: Mon, 17 Jul 2023 15:54:19 -0500 Subject: [PATCH 03/34] Add is_updated to generic updatable interface --- firecrown/updatable.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/firecrown/updatable.py b/firecrown/updatable.py index 2847ab5e..7c909a79 100644 --- a/firecrown/updatable.py +++ b/firecrown/updatable.py @@ -146,6 +146,13 @@ def update(self, params: ParamsMap) -> None: # worked. self._updated = True + def is_updated(self) -> bool: + """Return True if the object is currently updated, and False if not. + A default-constructed Updatable has not been updated. After `update`, + but before `reset`, has been called the object is updated. After + `reset` has been called, the object is not currently updated.""" + return self._updated + @final def reset(self) -> None: """Clean up self by clearing the _updated status and reseting all @@ -266,12 +273,26 @@ def update(self, params: ParamsMap): :param params: new parameter values """ + if self._updated: + return + for updatable in self: updatable.update(params) + self._updated = True + + + def is_updated(self) -> bool: + """Return True if the object is currently updated, and False if not. + A default-constructed Updatable has not been updated. After `update`, + but before `reset`, has been called the object is updated. After + `reset` has been called, the object is not currently updated.""" + return self._updated + @final def reset(self): """Resets self by calling reset() on each contained item.""" + self._updated = False for updatable in self: updatable.reset() From 449a5ae7750de5b59ef56485562bf72becf0717a Mon Sep 17 00:00:00 2001 From: Marc Paterno Date: Mon, 17 Jul 2023 15:54:41 -0500 Subject: [PATCH 04/34] Fix line breaking --- firecrown/likelihood/gauss_family/statistic/source/source.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/firecrown/likelihood/gauss_family/statistic/source/source.py b/firecrown/likelihood/gauss_family/statistic/source/source.py index 59688377..e3f60f25 100644 --- a/firecrown/likelihood/gauss_family/statistic/source/source.py +++ b/firecrown/likelihood/gauss_family/statistic/source/source.py @@ -75,8 +75,8 @@ def _update(self, params: ParamsMap): def _reset(self) -> None: """Implementation of the Updatable interface method `_reset`. - This calls the abstract method `_reset_source`, which must be implemented by - all subclasses.""" + This calls the abstract method `_reset_source`, which must be implemented + by all subclasses.""" self._reset_source() @abstractmethod From 7808e79d6f3795cc49ce8de03f1ba6e095af4611 Mon Sep 17 00:00:00 2001 From: Marc Paterno Date: Mon, 17 Jul 2023 15:55:27 -0500 Subject: [PATCH 05/34] Enhance test of number_counts.LinearBiasSystematic --- tests/likelihood/gauss_family/statistic/source/test_source.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/likelihood/gauss_family/statistic/source/test_source.py b/tests/likelihood/gauss_family/statistic/source/test_source.py index bc87ff7c..1d156fcd 100644 --- a/tests/likelihood/gauss_family/statistic/source/test_source.py +++ b/tests/likelihood/gauss_family/statistic/source/test_source.py @@ -49,6 +49,7 @@ def test_tracer_construction_with_name(empty_pyccl_tracer): def test_linear_bias_systematic(): a = nc.LinearBiasSystematic("xxx") assert isinstance(a, nc.LinearBiasSystematic) + assert a.sacc_tracer == "xxx" assert a.alphag is None assert a.alphaz is None assert a.z_piv is None @@ -62,6 +63,7 @@ def test_linear_bias_systematic(): a.reset() assert not a.is_updated() + assert a.sacc_tracer == "xxx" assert a.alphag is None assert a.alphaz is None assert a.z_piv is None From 4860faa0aae51e4c802cf31c3a520bf6894711f1 Mon Sep 17 00:00:00 2001 From: Marc Paterno Date: Mon, 17 Jul 2023 15:55:51 -0500 Subject: [PATCH 06/34] Fix the implementation of _reset --- .../gauss_family/statistic/source/number_counts.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/firecrown/likelihood/gauss_family/statistic/source/number_counts.py b/firecrown/likelihood/gauss_family/statistic/source/number_counts.py index dd7ba1c2..cc33d451 100644 --- a/firecrown/likelihood/gauss_family/statistic/source/number_counts.py +++ b/firecrown/likelihood/gauss_family/statistic/source/number_counts.py @@ -81,9 +81,10 @@ def __init__(self, sacc_tracer: str): @final def _reset(self) -> None: - """Reset this systematic. - - This implementation has nothing to do.""" + """Reset this systematic.""" + self.alphaz = None + self.alphag = None + self.z_piv = None @final def _required_parameters(self) -> RequiredParameters: From ab0ffceb05fd0164c45bc439ecfaca00e084c94b Mon Sep 17 00:00:00 2001 From: Marc Paterno Date: Thu, 10 Aug 2023 20:21:35 -0500 Subject: [PATCH 07/34] Fix and improve comments --- .../statistic/cluster_number_counts.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/firecrown/likelihood/gauss_family/statistic/cluster_number_counts.py b/firecrown/likelihood/gauss_family/statistic/cluster_number_counts.py index 53e5498a..92c6b9c8 100644 --- a/firecrown/likelihood/gauss_family/statistic/cluster_number_counts.py +++ b/firecrown/likelihood/gauss_family/statistic/cluster_number_counts.py @@ -29,14 +29,12 @@ class ClusterNumberCounts(Statistic): multiplicity functions, volume element, etc.). This subclass implements the read and computes method for the Statistic class. It is used to compute the theoretical prediction of - cluster number counts given a SACC file and a cosmology. For - further information on how the SACC file shall be created, - check README.md. + cluster number counts given a SACC file and a cosmology. """ def __init__( self, - sacc_tracer: str, + survey_tracer: str, cluster_abundance: ClusterAbundance, cluster_mass: ClusterMass, cluster_redshift: ClusterRedshift, @@ -47,18 +45,13 @@ def __init__( """Initialize the ClusterNumberCounts object. Parameters - :param sacc_tracer: The SACC tracer. There must be only one tracer for all - the number Counts data points. Following the SACC file - documentation README.md, this string should be - 'cluster_counts_true_mass'. + :param survey_tracer: name of the survey tracer in the SACC data. :param cluster_abundance: The cluster abundance model to use. :param systematics: A list of the statistics-level systematics to apply to the statistic. The default of `None` implies no systematics. - """ super().__init__() - - self.sacc_tracer = sacc_tracer + self.sacc_tracer = survey_tracer self.systematics = systematics or [] self.data_vector: Optional[DataVector] = None self.theory_vector: Optional[TheoryVector] = None @@ -163,7 +156,7 @@ def _read_data_type(self, sacc_data, data_type): return data_vector_list, sacc_indices_list - def read(self, sacc_data): + def read(self, sacc_data: sacc.Sacc): """Read the data for this statistic from the SACC file. This function takes the SACC file and extract the necessary parameters needed to compute the number counts likelihood. @@ -223,7 +216,7 @@ def compute_theory_vector(self, tools: ModelingTools) -> TheoryVector: Check README.MD for a complete description of the method. :param tools: ModelingTools firecrown object - Firecrown object used to load the required cosmology. + used to load the required cosmology. :return: Numpy Array of floats An array with the theoretical prediction of the number of clusters From 322cfffc610f99fb0d4fda793d73dc05144e989f Mon Sep 17 00:00:00 2001 From: Marc Paterno Date: Sun, 20 Aug 2023 18:31:40 -0500 Subject: [PATCH 08/34] Add more tests --- .../statistic/cluster_number_counts.py | 10 +--- .../statistic/test_cluster_number_counts.py | 55 +++++++++++++++++++ 2 files changed, 58 insertions(+), 7 deletions(-) create mode 100644 tests/likelihood/gauss_family/statistic/test_cluster_number_counts.py diff --git a/firecrown/likelihood/gauss_family/statistic/cluster_number_counts.py b/firecrown/likelihood/gauss_family/statistic/cluster_number_counts.py index 92c6b9c8..2e458825 100644 --- a/firecrown/likelihood/gauss_family/statistic/cluster_number_counts.py +++ b/firecrown/likelihood/gauss_family/statistic/cluster_number_counts.py @@ -1,7 +1,7 @@ """Cluster Number Count statistic support. This module reads the necessary data from a SACC file to compute the theoretical prediction of cluster number counts inside bins of redshift -and a mass proxy. For further information, check README.md. +and a mass proxy. """ from __future__ import annotations @@ -29,7 +29,7 @@ class ClusterNumberCounts(Statistic): multiplicity functions, volume element, etc.). This subclass implements the read and computes method for the Statistic class. It is used to compute the theoretical prediction of - cluster number counts given a SACC file and a cosmology. + cluster number counts. """ def __init__( @@ -158,11 +158,8 @@ def _read_data_type(self, sacc_data, data_type): def read(self, sacc_data: sacc.Sacc): """Read the data for this statistic from the SACC file. - This function takes the SACC file and extract the necessary - parameters needed to compute the number counts likelihood. - Check README.MD for a complete description of the method. - :param sacc_data: The data in the sacc format. + :param sacc_data: The data in the SACC format. """ survey_tracer: SurveyTracer = sacc_data.get_tracer(self.sacc_tracer) @@ -213,7 +210,6 @@ def get_data_vector(self) -> DataVector: def compute_theory_vector(self, tools: ModelingTools) -> TheoryVector: """Compute a Number Count statistic using the data from the Read method, the cosmology object, and the Bocquet16 halo mass function. - Check README.MD for a complete description of the method. :param tools: ModelingTools firecrown object used to load the required cosmology. diff --git a/tests/likelihood/gauss_family/statistic/test_cluster_number_counts.py b/tests/likelihood/gauss_family/statistic/test_cluster_number_counts.py new file mode 100644 index 00000000..f3e08111 --- /dev/null +++ b/tests/likelihood/gauss_family/statistic/test_cluster_number_counts.py @@ -0,0 +1,55 @@ +"""Tests for ClusterNumberCounts. +""" +import numpy as np +import pytest + +import sacc + +from firecrown.modeling_tools import ModelingTools +from firecrown.likelihood.gauss_family.statistic.cluster_number_counts import ( + ClusterNumberCounts, +) + + +@pytest.fixture(name="minimal_stat") +def fixture_minimal_stat() -> ClusterNumberCounts: + """Return a correctly initialized :python:`ClusterNumberCounts` object.""" + stat = ClusterNumberCounts( + survey_tracer="SDSS", + cluster_abundance=None, + cluster_mass=None, + cluster_redshift=None, + ) + return stat + + +@pytest.fixture(name="missing_survey_tracer") +def fixture_missing_survey_tracer() -> sacc.Sacc: + """Return a sacc.Sacc object that lacks a survey_tracer.""" + return sacc.Sacc() + + +@pytest.fixture(name="good_sacc_data") +def fixture_sacc_data(): + """Return a sacc.Sacc object sufficient to correctly set a + :python:`ClusterNumberCounts` object. + """ + data = sacc.Sacc() + return data + + +def test_missing_survey_tracer( + minimal_stat: ClusterNumberCounts, missing_survey_tracer: sacc.Sacc +): + with pytest.raises( + ValueError, match="The SACC file does not contain the SurveyTracer SDSS." + ): + minimal_stat.read(missing_survey_tracer) + + +def test_read_works(minimal_stat: ClusterNumberCounts, good_sacc_data: sacc.Sacc): + """After read() is called, we should be able to get the statistic's + + :python:`DataVector` and also should be able to call + :python:`compute_theory_vector`. + """ From 6e3275f1ae946b8a566ebef6fa64401cd8fc896d Mon Sep 17 00:00:00 2001 From: Marc Paterno Date: Sun, 20 Aug 2023 18:56:13 -0500 Subject: [PATCH 09/34] Raise ValueError when survey tracer is missing --- .../gauss_family/statistic/cluster_number_counts.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/firecrown/likelihood/gauss_family/statistic/cluster_number_counts.py b/firecrown/likelihood/gauss_family/statistic/cluster_number_counts.py index 2e458825..0a7c3049 100644 --- a/firecrown/likelihood/gauss_family/statistic/cluster_number_counts.py +++ b/firecrown/likelihood/gauss_family/statistic/cluster_number_counts.py @@ -162,8 +162,9 @@ def read(self, sacc_data: sacc.Sacc): :param sacc_data: The data in the SACC format. """ - survey_tracer: SurveyTracer = sacc_data.get_tracer(self.sacc_tracer) - if survey_tracer is None: + try: + survey_tracer: SurveyTracer = sacc_data.get_tracer(self.sacc_tracer) + except KeyError: raise ValueError( f"The SACC file does not contain the SurveyTracer " f"{self.sacc_tracer}." From 142567a27a47a2b350ee5a5c0f2a8acff3a87445 Mon Sep 17 00:00:00 2001 From: Marc Paterno Date: Sun, 20 Aug 2023 18:57:27 -0500 Subject: [PATCH 10/34] Apply black --- firecrown/updatable.py | 1 - 1 file changed, 1 deletion(-) diff --git a/firecrown/updatable.py b/firecrown/updatable.py index 7c909a79..58237f6b 100644 --- a/firecrown/updatable.py +++ b/firecrown/updatable.py @@ -281,7 +281,6 @@ def update(self, params: ParamsMap): self._updated = True - def is_updated(self) -> bool: """Return True if the object is currently updated, and False if not. A default-constructed Updatable has not been updated. After `update`, From 373adcf51a96467cfbfcf6248a5de094f5a93b32 Mon Sep 17 00:00:00 2001 From: Marc Paterno Date: Sun, 20 Aug 2023 19:02:52 -0500 Subject: [PATCH 11/34] Fix mypy deprecation warning --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index f3257905..559b57bb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -12,7 +12,7 @@ warn_redundant_casts = True warn_unused_ignores = True no_implicit_optional = True strict_equality = True -strict_concatenate = True +extra_checks = True disallow_subclassing_any = True disallow_untyped_decorators = True explicit_package_bases = True From 3fada95a38e7843cbec72a97320768607118e789 Mon Sep 17 00:00:00 2001 From: Marc Paterno Date: Sun, 20 Aug 2023 19:03:08 -0500 Subject: [PATCH 12/34] Remove unused imports --- .../gauss_family/statistic/test_cluster_number_counts.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/likelihood/gauss_family/statistic/test_cluster_number_counts.py b/tests/likelihood/gauss_family/statistic/test_cluster_number_counts.py index f3e08111..758a5311 100644 --- a/tests/likelihood/gauss_family/statistic/test_cluster_number_counts.py +++ b/tests/likelihood/gauss_family/statistic/test_cluster_number_counts.py @@ -1,11 +1,9 @@ """Tests for ClusterNumberCounts. """ -import numpy as np import pytest import sacc -from firecrown.modeling_tools import ModelingTools from firecrown.likelihood.gauss_family.statistic.cluster_number_counts import ( ClusterNumberCounts, ) From 52ca0f3060b3e089c6b2cda93793c1f27e7aca12 Mon Sep 17 00:00:00 2001 From: Marc Paterno Date: Mon, 21 Aug 2023 15:42:21 -0500 Subject: [PATCH 13/34] Minimal tweaks for test to pass mypy check --- .../statistic/test_cluster_number_counts.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/likelihood/gauss_family/statistic/test_cluster_number_counts.py b/tests/likelihood/gauss_family/statistic/test_cluster_number_counts.py index 758a5311..acf6aca6 100644 --- a/tests/likelihood/gauss_family/statistic/test_cluster_number_counts.py +++ b/tests/likelihood/gauss_family/statistic/test_cluster_number_counts.py @@ -6,7 +6,10 @@ from firecrown.likelihood.gauss_family.statistic.cluster_number_counts import ( ClusterNumberCounts, + ClusterAbundance, ) +from firecrown.models.cluster_mass_rich_proxy import ClusterMassRich +from firecrown.models.cluster_redshift_spec import ClusterRedshiftSpec @pytest.fixture(name="minimal_stat") @@ -14,9 +17,13 @@ def fixture_minimal_stat() -> ClusterNumberCounts: """Return a correctly initialized :python:`ClusterNumberCounts` object.""" stat = ClusterNumberCounts( survey_tracer="SDSS", - cluster_abundance=None, - cluster_mass=None, - cluster_redshift=None, + cluster_abundance=ClusterAbundance( + halo_mass_definition=None, + halo_mass_function_name="hmf_func", + halo_mass_function_args={}, + ), + cluster_mass=ClusterMassRich(pivot_mass=10.0, pivot_redshift=1.25), + cluster_redshift=ClusterRedshiftSpec(), ) return stat @@ -45,7 +52,7 @@ def test_missing_survey_tracer( minimal_stat.read(missing_survey_tracer) -def test_read_works(minimal_stat: ClusterNumberCounts, good_sacc_data: sacc.Sacc): +def test_read_works(): """After read() is called, we should be able to get the statistic's :python:`DataVector` and also should be able to call From c642786c98cf18b78931bc3858d7349eb9caaa46 Mon Sep 17 00:00:00 2001 From: Sandro Dias Pinto Vitenti Date: Fri, 25 Aug 2023 12:04:49 -0300 Subject: [PATCH 14/34] Automated reset calls in all updatables. --- .../likelihood/gauss_family/gauss_family.py | 14 -------- firecrown/likelihood/gauss_family/gaussian.py | 4 --- .../statistic/cluster_number_counts.py | 7 ---- .../statistic/source/number_counts.py | 35 ------------------- .../gauss_family/statistic/source/source.py | 12 ------- .../statistic/source/weak_lensing.py | 26 -------------- .../gauss_family/statistic/supernova.py | 4 --- .../gauss_family/statistic/two_point.py | 2 -- .../likelihood/gauss_family/student_t.py | 4 --- firecrown/models/cluster_mass.py | 12 ------- firecrown/models/cluster_mass_rich_proxy.py | 6 ---- firecrown/models/cluster_mass_true.py | 4 --- firecrown/models/cluster_redshift.py | 12 ------- firecrown/models/cluster_redshift_spec.py | 4 --- firecrown/updatable.py | 15 +++++++- tests/test_updatable.py | 16 --------- 16 files changed, 14 insertions(+), 163 deletions(-) diff --git a/firecrown/likelihood/gauss_family/gauss_family.py b/firecrown/likelihood/gauss_family/gauss_family.py index fc3eefb8..00e58805 100644 --- a/firecrown/likelihood/gauss_family/gauss_family.py +++ b/firecrown/likelihood/gauss_family/gauss_family.py @@ -129,15 +129,6 @@ def compute_chisq(self, tools: ModelingTools) -> float: return chisq - @final - def _reset(self) -> None: - """Implementation of Likelihood interface method _reset. - - This resets all statistics and calls the abstract method - _reset_gaussian_family.""" - self._reset_gaussian_family() - self.statistics.reset() - @final def _get_derived_parameters(self) -> DerivedParameterCollection: derived_parameters = ( @@ -147,11 +138,6 @@ def _get_derived_parameters(self) -> DerivedParameterCollection: return derived_parameters - @abstractmethod - def _reset_gaussian_family(self) -> None: - """Abstract method to reset GaussianFamily state. Must be implemented by all - subclasses.""" - @final def _required_parameters(self) -> RequiredParameters: """Return a RequiredParameters object containing the information for diff --git a/firecrown/likelihood/gauss_family/gaussian.py b/firecrown/likelihood/gauss_family/gaussian.py index d2027fdb..b6fcf790 100644 --- a/firecrown/likelihood/gauss_family/gaussian.py +++ b/firecrown/likelihood/gauss_family/gaussian.py @@ -18,10 +18,6 @@ def compute_loglike(self, tools: ModelingTools): return -0.5 * self.compute_chisq(tools) - @final - def _reset_gaussian_family(self): - pass - @final def _required_parameters_gaussian_family(self): return RequiredParameters([]) diff --git a/firecrown/likelihood/gauss_family/statistic/cluster_number_counts.py b/firecrown/likelihood/gauss_family/statistic/cluster_number_counts.py index 0a7c3049..aa51fde0 100644 --- a/firecrown/likelihood/gauss_family/statistic/cluster_number_counts.py +++ b/firecrown/likelihood/gauss_family/statistic/cluster_number_counts.py @@ -67,13 +67,6 @@ def __init__( "At least one of use_cluster_counts and use_mean_log_mass must be True." ) - @final - def _reset(self) -> None: - """Reset all contained Updatable objects.""" - self.cluster_abundance.reset() - self.cluster_mass.reset() - self.cluster_redshift.reset() - @final def _required_parameters(self) -> RequiredParameters: """Return an empty RequiredParameters.""" diff --git a/firecrown/likelihood/gauss_family/statistic/source/number_counts.py b/firecrown/likelihood/gauss_family/statistic/source/number_counts.py index cc33d451..9aa9d96e 100644 --- a/firecrown/likelihood/gauss_family/statistic/source/number_counts.py +++ b/firecrown/likelihood/gauss_family/statistic/source/number_counts.py @@ -79,13 +79,6 @@ def __init__(self, sacc_tracer: str): self.z_piv = parameters.create() self.sacc_tracer = sacc_tracer - @final - def _reset(self) -> None: - """Reset this systematic.""" - self.alphaz = None - self.alphag = None - self.z_piv = None - @final def _required_parameters(self) -> RequiredParameters: return RequiredParameters([]) @@ -143,12 +136,6 @@ def __init__(self, sacc_tracer: str): self.b_s = parameters.create() self.sacc_tracer = sacc_tracer - @final - def _reset(self) -> None: - """Reset this systematic. - - This implementation has nothing to do.""" - @final def _required_parameters(self) -> RequiredParameters: return RequiredParameters([]) @@ -202,12 +189,6 @@ def __init__(self, sacc_tracer: str): self.sacc_tracer = sacc_tracer - @final - def _reset(self) -> None: - """Reset this systematic. - - This implementation has nothing to do.""" - @final def _required_parameters(self) -> RequiredParameters: return RequiredParameters([]) @@ -262,12 +243,6 @@ def __init__(self, sacc_tracer: str): self.mag_bias = parameters.create() self.sacc_tracer = sacc_tracer - @final - def _reset(self) -> None: - """Reset this systematic. - - This implementation has nothing to do.""" - @final def _required_parameters(self) -> RequiredParameters: return RequiredParameters([]) @@ -297,12 +272,6 @@ def __init__(self, sacc_tracer: str): self.delta_z = parameters.create() self.sacc_tracer = sacc_tracer - @final - def _reset(self) -> None: - """Reset this systematic. - - This implementation has nothing to do.""" - @final def _required_parameters(self) -> RequiredParameters: return RequiredParameters([]) @@ -358,10 +327,6 @@ def _update_source(self, params: ParamsMap): This implementation must update all contained Updatable instances.""" self.systematics.update(params) - @final - def _reset_source(self) -> None: - self.systematics.reset() - @final def _required_parameters(self) -> RequiredParameters: return self.systematics.required_parameters() diff --git a/firecrown/likelihood/gauss_family/statistic/source/source.py b/firecrown/likelihood/gauss_family/statistic/source/source.py index e3f60f25..1225a0cb 100644 --- a/firecrown/likelihood/gauss_family/statistic/source/source.py +++ b/firecrown/likelihood/gauss_family/statistic/source/source.py @@ -57,10 +57,6 @@ def _update_source(self, params: ParamsMap): that needs to do more than update its contained :python:`Updatable` objects should implement this method.""" - @abstractmethod - def _reset_source(self): - """Abstract method to reset the source.""" - @final def _update(self, params: ParamsMap): """Implementation of Updatable interface method `_update`. @@ -71,14 +67,6 @@ def _update(self, params: ParamsMap): self.tracers = [] self._update_source(params) - @final - def _reset(self) -> None: - """Implementation of the Updatable interface method `_reset`. - - This calls the abstract method `_reset_source`, which must be implemented - by all subclasses.""" - self._reset_source() - @abstractmethod def get_scale(self) -> float: """Abstract method to return the scales for this `Source`.""" diff --git a/firecrown/likelihood/gauss_family/statistic/source/weak_lensing.py b/firecrown/likelihood/gauss_family/statistic/source/weak_lensing.py index 67e07845..e4b25157 100644 --- a/firecrown/likelihood/gauss_family/statistic/source/weak_lensing.py +++ b/firecrown/likelihood/gauss_family/statistic/source/weak_lensing.py @@ -72,12 +72,6 @@ def __init__(self, sacc_tracer: str) -> None: self.mult_bias = parameters.create() self.sacc_tracer = sacc_tracer - @final - def _reset(self) -> None: - """Reset this systematic. - - This implementation has nothing to do.""" - def _required_parameters(self) -> RequiredParameters: return RequiredParameters([]) @@ -138,12 +132,6 @@ def __init__(self, sacc_tracer: Optional[str] = None, alphag=1.0): self.sacc_tracer = sacc_tracer - @final - def _reset(self) -> None: - """Reset this systematic. - - This implementation has nothing to do.""" - @final def _required_parameters(self) -> RequiredParameters: return RequiredParameters([]) @@ -197,12 +185,6 @@ def __init__(self, sacc_tracer: Optional[str] = None): self.sacc_tracer = sacc_tracer - @final - def _reset(self) -> None: - """Reset this systematic. - - This implementation has nothing to do.""" - @final def _required_parameters(self) -> RequiredParameters: return RequiredParameters([]) @@ -249,10 +231,6 @@ def __init__(self, sacc_tracer: str): self.delta_z = parameters.create() self.sacc_tracer = sacc_tracer - @final - def _reset(self) -> None: - pass - @final def _required_parameters(self) -> RequiredParameters: return RequiredParameters([]) @@ -305,10 +283,6 @@ def _update_source(self, params: ParamsMap): This updates all the contained systematics.""" self.systematics.update(params) - @final - def _reset_source(self) -> None: - self.systematics.reset() - @final def _required_parameters(self) -> RequiredParameters: return self.systematics.required_parameters() diff --git a/firecrown/likelihood/gauss_family/statistic/supernova.py b/firecrown/likelihood/gauss_family/statistic/supernova.py index 9198529d..2ddaa911 100644 --- a/firecrown/likelihood/gauss_family/statistic/supernova.py +++ b/firecrown/likelihood/gauss_family/statistic/supernova.py @@ -40,10 +40,6 @@ def read(self, sacc_data: sacc.Sacc): self.data_vector = DataVector.from_list([dp.value for dp in data_points]) self.sacc_indices = np.arange(len(self.data_vector)) - @final - def _reset(self): - """Reset this statistic. This implementation has nothing to do.""" - @final def _required_parameters(self) -> RequiredParameters: """Return an empty RequiredParameters.""" diff --git a/firecrown/likelihood/gauss_family/statistic/two_point.py b/firecrown/likelihood/gauss_family/statistic/two_point.py index 87d819b2..cf99aadd 100644 --- a/firecrown/likelihood/gauss_family/statistic/two_point.py +++ b/firecrown/likelihood/gauss_family/statistic/two_point.py @@ -226,8 +226,6 @@ def __init__( @final def _reset(self) -> None: """Prepared to be called again for a new cosmology.""" - self.source0.reset() - self.source1.reset() # TODO: Why is self.predicted_statistic_ not re-set to None here? # If we do that, then the CosmoSIS module fails -- because this data # is accessed from that code. diff --git a/firecrown/likelihood/gauss_family/student_t.py b/firecrown/likelihood/gauss_family/student_t.py index 5e113506..0c0ed9b6 100644 --- a/firecrown/likelihood/gauss_family/student_t.py +++ b/firecrown/likelihood/gauss_family/student_t.py @@ -44,10 +44,6 @@ def compute_loglike(self, tools: ModelingTools): chi2 = self.compute_chisq(ccl_cosmo) return -0.5 * self.nu * np.log(1.0 + chi2 / (self.nu - 1.0)) - @final - def _reset_gaussian_family(self): - pass - @final def _required_parameters_gaussian_family(self): return RequiredParameters([]) diff --git a/firecrown/models/cluster_mass.py b/firecrown/models/cluster_mass.py index 129c9918..a7efe8dc 100644 --- a/firecrown/models/cluster_mass.py +++ b/firecrown/models/cluster_mass.py @@ -71,24 +71,12 @@ def _update_cluster_mass(self, params: ParamsMap): Subclasses that need to do more than update their contained :python:`Updatable` instance variables should implement this method.""" - @abstractmethod - def _reset_cluster_mass(self): - """Abstract method to reset the ClusterMass.""" - @final def _update(self, params: ParamsMap): """Implementation of Updatable interface method `_update`.""" self._update_cluster_mass(params) - @final - def _reset(self) -> None: - """Implementation of the Updatable interface method `_reset`. - - This calls the abstract method `_reset_cluster_mass`, which must be implemented - by all subclasses.""" - self._reset_cluster_mass() - @abstractmethod def gen_bins_by_array(self, logM_obs_bins: np.ndarray) -> List[ClusterMassArgument]: """Generate bins by an array of bin edges.""" diff --git a/firecrown/models/cluster_mass_rich_proxy.py b/firecrown/models/cluster_mass_rich_proxy.py index 362f0488..268f84dd 100644 --- a/firecrown/models/cluster_mass_rich_proxy.py +++ b/firecrown/models/cluster_mass_rich_proxy.py @@ -50,12 +50,6 @@ def _update_cluster_mass(self, params: ParamsMap): This implementation has nothing to do.""" - @final - def _reset_cluster_mass(self) -> None: - """Reset the ClusterMass object. - - This implementation has nothing to do.""" - @final def _required_parameters(self) -> RequiredParameters: return RequiredParameters([]) diff --git a/firecrown/models/cluster_mass_true.py b/firecrown/models/cluster_mass_true.py index 90fbd8cb..df1750b3 100644 --- a/firecrown/models/cluster_mass_true.py +++ b/firecrown/models/cluster_mass_true.py @@ -26,10 +26,6 @@ class ClusterMassTrue(ClusterMass): def _update_cluster_mass(self, params: ParamsMap): """Method to update the ClusterMassTrue from the given ParamsMap.""" - @final - def _reset_cluster_mass(self): - """Method to reset the ClusterMassTrue.""" - @final def _required_parameters(self) -> RequiredParameters: return RequiredParameters([]) diff --git a/firecrown/models/cluster_redshift.py b/firecrown/models/cluster_redshift.py index c5a262c3..ea9952fb 100644 --- a/firecrown/models/cluster_redshift.py +++ b/firecrown/models/cluster_redshift.py @@ -72,24 +72,12 @@ def _update_cluster_redshift(self, params: ParamsMap): Subclasses that need to do more than update their contained :python:`Updatable` instance variables should implement this method.""" - @abstractmethod - def _reset_cluster_redshift(self): - """Abstract method to reset the ClusterRedshift.""" - @final def _update(self, params: ParamsMap): """Implementation of Updatable interface method `_update`.""" self._update_cluster_redshift(params) - @final - def _reset(self) -> None: - """Implementation of the Updatable interface method `_reset`. - - This calls the abstract method `_reset_cluster_redshift`, which must be - implemented by all subclasses.""" - self._reset_cluster_redshift() - @abstractmethod def gen_bins_by_array(self, z_bins: np.ndarray) -> List[ClusterRedshiftArgument]: """Generate the bins by an array of bin edges.""" diff --git a/firecrown/models/cluster_redshift_spec.py b/firecrown/models/cluster_redshift_spec.py index 8888fd78..3b7d7d42 100644 --- a/firecrown/models/cluster_redshift_spec.py +++ b/firecrown/models/cluster_redshift_spec.py @@ -24,10 +24,6 @@ class ClusterRedshiftSpec(ClusterRedshift): def _update_cluster_redshift(self, params: ParamsMap): """Method to update the ClusterRedshiftSpec from the given ParamsMap.""" - @final - def _reset_cluster_redshift(self): - """Method to reset the ClusterRedshiftSpec.""" - @final def _required_parameters(self) -> RequiredParameters: return RequiredParameters([]) diff --git a/firecrown/updatable.py b/firecrown/updatable.py index 58237f6b..ec8974f6 100644 --- a/firecrown/updatable.py +++ b/firecrown/updatable.py @@ -161,6 +161,20 @@ def reset(self) -> None: Each MCMC framework connector should call this after handling an MCMC sample.""" + + # If we have not been updated, there is nothing to do. + if not self._updated: + return + + # We reset in the inverse order, first the contained updatables, then + # the current object. + for item in self._updatables: + item.reset() + + # Reset the sampler parameters to None. + for parameter in self._sampler_parameters: + setattr(self, parameter, None) + self._updated = False self._returned_derived = False self._reset() @@ -178,7 +192,6 @@ def _update(self, params: ParamsMap) -> None: :param params: a new set of parameter values """ - @abstractmethod def _reset(self) -> None: # pragma: no cover """Abstract method to be implemented by all concrete classes to update self. diff --git a/tests/test_updatable.py b/tests/test_updatable.py index 67d55d01..08d40803 100644 --- a/tests/test_updatable.py +++ b/tests/test_updatable.py @@ -11,19 +11,6 @@ ) -class MissingReset(Updatable): - """A type that is abstract because it does not implement required_parameters.""" - - def _update(self, params): # pragma: no cover - pass - - def _required_parameters(self): # pragma: no cover - pass - - def _get_derived_parameters(self) -> DerivedParameterCollection: - return DerivedParameterCollection([]) - - class MissingRequiredParameters(Updatable): """A type that is abstract because it does not implement required_parameters.""" @@ -83,9 +70,6 @@ def _get_derived_parameters(self) -> DerivedParameterCollection: def test_verify_abstract_interface(): - with pytest.raises(TypeError): - # pylint: disable-next=E0110,W0612 - _ = MissingReset() # type: ignore with pytest.raises(TypeError): # pylint: disable-next=E0110,W0612 _ = MissingRequiredParameters() # type: ignore From 989ac7161a7e20a9ad7e930df7624fb88cc54ba5 Mon Sep 17 00:00:00 2001 From: Sandro Dias Pinto Vitenti Date: Fri, 25 Aug 2023 13:29:01 -0300 Subject: [PATCH 15/34] Automated required_parameters protocol. --- .../likelihood/gauss_family/gauss_family.py | 20 +-------------- firecrown/likelihood/gauss_family/gaussian.py | 6 +---- .../statistic/cluster_number_counts.py | 10 -------- .../statistic/source/number_counts.py | 25 ------------------- .../statistic/source/weak_lensing.py | 20 --------------- .../gauss_family/statistic/supernova.py | 7 +----- .../gauss_family/statistic/two_point.py | 6 +---- .../likelihood/gauss_family/student_t.py | 6 +---- firecrown/models/cluster_abundance.py | 6 +---- firecrown/models/cluster_mass_rich_proxy.py | 5 ---- firecrown/models/cluster_mass_true.py | 5 ---- firecrown/models/cluster_redshift_spec.py | 5 ---- firecrown/updatable.py | 12 ++++++--- tests/test_updatable.py | 19 -------------- 14 files changed, 15 insertions(+), 137 deletions(-) diff --git a/firecrown/likelihood/gauss_family/gauss_family.py b/firecrown/likelihood/gauss_family/gauss_family.py index 00e58805..52223540 100644 --- a/firecrown/likelihood/gauss_family/gauss_family.py +++ b/firecrown/likelihood/gauss_family/gauss_family.py @@ -23,7 +23,7 @@ from ...modeling_tools import ModelingTools from ...updatable import UpdatableCollection from .statistic.statistic import Statistic -from ...parameters import RequiredParameters, DerivedParameterCollection +from ...parameters import DerivedParameterCollection class GaussFamily(Likelihood): @@ -138,24 +138,6 @@ def _get_derived_parameters(self) -> DerivedParameterCollection: return derived_parameters - @final - def _required_parameters(self) -> RequiredParameters: - """Return a RequiredParameters object containing the information for - this Updatable. - - This includes the required parameters for all statistics, as well as those - for the derived class. - - Derived classes must implement required_parameters_gaussian_family.""" - stats_rp = self.statistics.required_parameters() - stats_rp = self._required_parameters_gaussian_family() + stats_rp - - return stats_rp - - @abstractmethod - def _required_parameters_gaussian_family(self): - """Required parameters for GaussFamily subclasses.""" - @abstractmethod def _get_derived_parameters_gaussian_family(self) -> DerivedParameterCollection: """Get derived parameters for GaussFamily subclasses.""" diff --git a/firecrown/likelihood/gauss_family/gaussian.py b/firecrown/likelihood/gauss_family/gaussian.py index b6fcf790..8ea6ad79 100644 --- a/firecrown/likelihood/gauss_family/gaussian.py +++ b/firecrown/likelihood/gauss_family/gaussian.py @@ -6,7 +6,7 @@ from typing import final from .gauss_family import GaussFamily -from ...parameters import RequiredParameters, DerivedParameterCollection +from ...parameters import DerivedParameterCollection from ...modeling_tools import ModelingTools @@ -18,10 +18,6 @@ def compute_loglike(self, tools: ModelingTools): return -0.5 * self.compute_chisq(tools) - @final - def _required_parameters_gaussian_family(self): - return RequiredParameters([]) - @final def _get_derived_parameters_gaussian_family(self) -> DerivedParameterCollection: return DerivedParameterCollection([]) diff --git a/firecrown/likelihood/gauss_family/statistic/cluster_number_counts.py b/firecrown/likelihood/gauss_family/statistic/cluster_number_counts.py index aa51fde0..bcd266ae 100644 --- a/firecrown/likelihood/gauss_family/statistic/cluster_number_counts.py +++ b/firecrown/likelihood/gauss_family/statistic/cluster_number_counts.py @@ -15,7 +15,6 @@ from .statistic import Statistic, DataVector, TheoryVector from .source.source import SourceSystematic from ....parameters import ( - RequiredParameters, DerivedParameterCollection, ) from ....models.cluster_abundance import ClusterAbundance @@ -67,15 +66,6 @@ def __init__( "At least one of use_cluster_counts and use_mean_log_mass must be True." ) - @final - def _required_parameters(self) -> RequiredParameters: - """Return an empty RequiredParameters.""" - return ( - self.cluster_abundance.required_parameters() - + self.cluster_mass.required_parameters() - + self.cluster_redshift.required_parameters() - ) - @final def _get_derived_parameters(self) -> DerivedParameterCollection: """Return an empty DerivedParameterCollection.""" diff --git a/firecrown/likelihood/gauss_family/statistic/source/number_counts.py b/firecrown/likelihood/gauss_family/statistic/source/number_counts.py index 9aa9d96e..c50751ac 100644 --- a/firecrown/likelihood/gauss_family/statistic/source/number_counts.py +++ b/firecrown/likelihood/gauss_family/statistic/source/number_counts.py @@ -18,7 +18,6 @@ from .....modeling_tools import ModelingTools from .....parameters import ( ParamsMap, - RequiredParameters, DerivedParameterScalar, DerivedParameterCollection, ) @@ -79,10 +78,6 @@ def __init__(self, sacc_tracer: str): self.z_piv = parameters.create() self.sacc_tracer = sacc_tracer - @final - def _required_parameters(self) -> RequiredParameters: - return RequiredParameters([]) - @final def _get_derived_parameters(self) -> DerivedParameterCollection: return DerivedParameterCollection([]) @@ -136,10 +131,6 @@ def __init__(self, sacc_tracer: str): self.b_s = parameters.create() self.sacc_tracer = sacc_tracer - @final - def _required_parameters(self) -> RequiredParameters: - return RequiredParameters([]) - @final def _get_derived_parameters(self) -> DerivedParameterCollection: return DerivedParameterCollection([]) @@ -189,10 +180,6 @@ def __init__(self, sacc_tracer: str): self.sacc_tracer = sacc_tracer - @final - def _required_parameters(self) -> RequiredParameters: - return RequiredParameters([]) - @final def _get_derived_parameters(self) -> DerivedParameterCollection: return DerivedParameterCollection([]) @@ -243,10 +230,6 @@ def __init__(self, sacc_tracer: str): self.mag_bias = parameters.create() self.sacc_tracer = sacc_tracer - @final - def _required_parameters(self) -> RequiredParameters: - return RequiredParameters([]) - @final def _get_derived_parameters(self) -> DerivedParameterCollection: return DerivedParameterCollection([]) @@ -272,10 +255,6 @@ def __init__(self, sacc_tracer: str): self.delta_z = parameters.create() self.sacc_tracer = sacc_tracer - @final - def _required_parameters(self) -> RequiredParameters: - return RequiredParameters([]) - @final def _get_derived_parameters(self) -> DerivedParameterCollection: return DerivedParameterCollection([]) @@ -327,10 +306,6 @@ def _update_source(self, params: ParamsMap): This implementation must update all contained Updatable instances.""" self.systematics.update(params) - @final - def _required_parameters(self) -> RequiredParameters: - return self.systematics.required_parameters() - @final def _get_derived_parameters(self) -> DerivedParameterCollection: if self.derived_scale: diff --git a/firecrown/likelihood/gauss_family/statistic/source/weak_lensing.py b/firecrown/likelihood/gauss_family/statistic/source/weak_lensing.py index e4b25157..102b78ed 100644 --- a/firecrown/likelihood/gauss_family/statistic/source/weak_lensing.py +++ b/firecrown/likelihood/gauss_family/statistic/source/weak_lensing.py @@ -18,7 +18,6 @@ from ..... import parameters from .....parameters import ( ParamsMap, - RequiredParameters, DerivedParameterCollection, ) from .....modeling_tools import ModelingTools @@ -72,9 +71,6 @@ def __init__(self, sacc_tracer: str) -> None: self.mult_bias = parameters.create() self.sacc_tracer = sacc_tracer - def _required_parameters(self) -> RequiredParameters: - return RequiredParameters([]) - @final def _get_derived_parameters(self) -> DerivedParameterCollection: return DerivedParameterCollection([]) @@ -132,10 +128,6 @@ def __init__(self, sacc_tracer: Optional[str] = None, alphag=1.0): self.sacc_tracer = sacc_tracer - @final - def _required_parameters(self) -> RequiredParameters: - return RequiredParameters([]) - @final def _get_derived_parameters(self) -> DerivedParameterCollection: return DerivedParameterCollection([]) @@ -185,10 +177,6 @@ def __init__(self, sacc_tracer: Optional[str] = None): self.sacc_tracer = sacc_tracer - @final - def _required_parameters(self) -> RequiredParameters: - return RequiredParameters([]) - @final def _get_derived_parameters(self) -> DerivedParameterCollection: return DerivedParameterCollection([]) @@ -231,10 +219,6 @@ def __init__(self, sacc_tracer: str): self.delta_z = parameters.create() self.sacc_tracer = sacc_tracer - @final - def _required_parameters(self) -> RequiredParameters: - return RequiredParameters([]) - @final def _get_derived_parameters(self) -> DerivedParameterCollection: derived_parameters = DerivedParameterCollection([]) @@ -283,10 +267,6 @@ def _update_source(self, params: ParamsMap): This updates all the contained systematics.""" self.systematics.update(params) - @final - def _required_parameters(self) -> RequiredParameters: - return self.systematics.required_parameters() - @final def _get_derived_parameters(self) -> DerivedParameterCollection: derived_parameters = DerivedParameterCollection([]) diff --git a/firecrown/likelihood/gauss_family/statistic/supernova.py b/firecrown/likelihood/gauss_family/statistic/supernova.py index 2ddaa911..e76dc0b0 100644 --- a/firecrown/likelihood/gauss_family/statistic/supernova.py +++ b/firecrown/likelihood/gauss_family/statistic/supernova.py @@ -13,7 +13,7 @@ from ....modeling_tools import ModelingTools from .statistic import Statistic, DataVector, TheoryVector from .... import parameters -from ....parameters import RequiredParameters, DerivedParameterCollection +from ....parameters import DerivedParameterCollection class Supernova(Statistic): @@ -40,11 +40,6 @@ def read(self, sacc_data: sacc.Sacc): self.data_vector = DataVector.from_list([dp.value for dp in data_points]) self.sacc_indices = np.arange(len(self.data_vector)) - @final - def _required_parameters(self) -> RequiredParameters: - """Return an empty RequiredParameters.""" - return RequiredParameters([]) - @final def _get_derived_parameters(self) -> DerivedParameterCollection: """Return an empty DerivedParameterCollection.""" diff --git a/firecrown/likelihood/gauss_family/statistic/two_point.py b/firecrown/likelihood/gauss_family/statistic/two_point.py index cf99aadd..a55b3a30 100644 --- a/firecrown/likelihood/gauss_family/statistic/two_point.py +++ b/firecrown/likelihood/gauss_family/statistic/two_point.py @@ -19,7 +19,7 @@ from .statistic import Statistic, DataVector, TheoryVector from .source.source import Source, Tracer -from ....parameters import RequiredParameters, DerivedParameterCollection +from ....parameters import DerivedParameterCollection # only supported types are here, anything else will throw # a value error @@ -230,10 +230,6 @@ def _reset(self) -> None: # If we do that, then the CosmoSIS module fails -- because this data # is accessed from that code. - @final - def _required_parameters(self) -> RequiredParameters: - return self.source0.required_parameters() + self.source1.required_parameters() - @final def _get_derived_parameters(self) -> DerivedParameterCollection: derived_parameters = DerivedParameterCollection([]) diff --git a/firecrown/likelihood/gauss_family/student_t.py b/firecrown/likelihood/gauss_family/student_t.py index 0c0ed9b6..7ba01e58 100644 --- a/firecrown/likelihood/gauss_family/student_t.py +++ b/firecrown/likelihood/gauss_family/student_t.py @@ -11,7 +11,7 @@ from ...modeling_tools import ModelingTools from .statistic.statistic import Statistic from ... import parameters -from ...parameters import RequiredParameters, DerivedParameterCollection +from ...parameters import DerivedParameterCollection class StudentT(GaussFamily): @@ -44,10 +44,6 @@ def compute_loglike(self, tools: ModelingTools): chi2 = self.compute_chisq(ccl_cosmo) return -0.5 * self.nu * np.log(1.0 + chi2 / (self.nu - 1.0)) - @final - def _required_parameters_gaussian_family(self): - return RequiredParameters([]) - @final def _get_derived_parameters_gaussian_family(self) -> DerivedParameterCollection: return DerivedParameterCollection([]) diff --git a/firecrown/models/cluster_abundance.py b/firecrown/models/cluster_abundance.py index 1cba111f..0ecfbed2 100644 --- a/firecrown/models/cluster_abundance.py +++ b/firecrown/models/cluster_abundance.py @@ -12,7 +12,7 @@ import scipy.integrate from ..updatable import Updatable -from ..parameters import RequiredParameters, DerivedParameterCollection +from ..parameters import DerivedParameterCollection from .cluster_mass import ClusterMassArgument from .cluster_redshift import ClusterRedshiftArgument @@ -104,10 +104,6 @@ def _reset(self) -> None: """Implementation of the Updatable interface method `_reset`.""" self.halo_mass_function = None - @final - def _required_parameters(self) -> RequiredParameters: - return RequiredParameters([]) - @final def _get_derived_parameters(self) -> DerivedParameterCollection: return DerivedParameterCollection([]) diff --git a/firecrown/models/cluster_mass_rich_proxy.py b/firecrown/models/cluster_mass_rich_proxy.py index 268f84dd..ed25eb67 100644 --- a/firecrown/models/cluster_mass_rich_proxy.py +++ b/firecrown/models/cluster_mass_rich_proxy.py @@ -10,7 +10,6 @@ from ..parameters import ( ParamsMap, - RequiredParameters, DerivedParameterCollection, ) from .cluster_mass import ClusterMass, ClusterMassArgument @@ -50,10 +49,6 @@ def _update_cluster_mass(self, params: ParamsMap): This implementation has nothing to do.""" - @final - def _required_parameters(self) -> RequiredParameters: - return RequiredParameters([]) - @final def _get_derived_parameters(self) -> DerivedParameterCollection: return DerivedParameterCollection([]) diff --git a/firecrown/models/cluster_mass_true.py b/firecrown/models/cluster_mass_true.py index df1750b3..0699974f 100644 --- a/firecrown/models/cluster_mass_true.py +++ b/firecrown/models/cluster_mass_true.py @@ -13,7 +13,6 @@ from ..parameters import ( ParamsMap, - RequiredParameters, DerivedParameterCollection, ) from .cluster_mass import ClusterMass, ClusterMassArgument @@ -26,10 +25,6 @@ class ClusterMassTrue(ClusterMass): def _update_cluster_mass(self, params: ParamsMap): """Method to update the ClusterMassTrue from the given ParamsMap.""" - @final - def _required_parameters(self) -> RequiredParameters: - return RequiredParameters([]) - @final def _get_derived_parameters(self) -> DerivedParameterCollection: return DerivedParameterCollection([]) diff --git a/firecrown/models/cluster_redshift_spec.py b/firecrown/models/cluster_redshift_spec.py index 3b7d7d42..746d0bba 100644 --- a/firecrown/models/cluster_redshift_spec.py +++ b/firecrown/models/cluster_redshift_spec.py @@ -11,7 +11,6 @@ from ..parameters import ( ParamsMap, - RequiredParameters, DerivedParameterCollection, ) from .cluster_redshift import ClusterRedshift, ClusterRedshiftArgument @@ -24,10 +23,6 @@ class ClusterRedshiftSpec(ClusterRedshift): def _update_cluster_redshift(self, params: ParamsMap): """Method to update the ClusterRedshiftSpec from the given ParamsMap.""" - @final - def _required_parameters(self) -> RequiredParameters: - return RequiredParameters([]) - @final def _get_derived_parameters(self) -> DerivedParameterCollection: return DerivedParameterCollection([]) diff --git a/firecrown/updatable.py b/firecrown/updatable.py index ec8974f6..28e1d19d 100644 --- a/firecrown/updatable.py +++ b/firecrown/updatable.py @@ -205,7 +205,7 @@ def _reset(self) -> None: # pragma: no cover def required_parameters(self) -> RequiredParameters: # pragma: no cover """Return a RequiredParameters object containing the information for all parameters defined in the implementing class, any additional - parameter + parameter. """ sampler_parameters = RequiredParameters( @@ -216,17 +216,23 @@ def required_parameters(self) -> RequiredParameters: # pragma: no cover ) additional_parameters = self._required_parameters() + for item in self._updatables: + additional_parameters = additional_parameters + item.required_parameters() + return sampler_parameters + additional_parameters - @abstractmethod def _required_parameters(self) -> RequiredParameters: # pragma: no cover """Return a RequiredParameters object containing the information for - this Updatable. This method must be overridden by concrete classes. + this Updatable. This method can be overridden by subclasses to add + additional parameters. The default implementation returns an empty + RequiredParameters object. This is only implemented to allow The base class implementation returns a list with all SamplerParameter objects properties. """ + return RequiredParameters([]) + @final def get_derived_parameters( self, diff --git a/tests/test_updatable.py b/tests/test_updatable.py index 08d40803..ea324406 100644 --- a/tests/test_updatable.py +++ b/tests/test_updatable.py @@ -11,19 +11,6 @@ ) -class MissingRequiredParameters(Updatable): - """A type that is abstract because it does not implement required_parameters.""" - - def _update(self, params): # pragma: no cover - pass - - def _reset(self) -> None: - pass - - def _get_derived_parameters(self) -> DerivedParameterCollection: - return DerivedParameterCollection([]) - - class MinimalUpdatable(Updatable): """A concrete time that implements Updatable.""" @@ -69,12 +56,6 @@ def _get_derived_parameters(self) -> DerivedParameterCollection: return DerivedParameterCollection([]) -def test_verify_abstract_interface(): - with pytest.raises(TypeError): - # pylint: disable-next=E0110,W0612 - _ = MissingRequiredParameters() # type: ignore - - def test_simple_updatable(): obj = SimpleUpdatable() expected_params = RequiredParameters(["x", "y"]) From 74a1a5bf5c0e9bdfcea292312a8a0b86f3e65bac Mon Sep 17 00:00:00 2001 From: Sandro Dias Pinto Vitenti Date: Fri, 25 Aug 2023 14:03:53 -0300 Subject: [PATCH 16/34] Automated derived_parameters protocol. --- .../likelihood/gauss_family/gauss_family.py | 15 ----------- firecrown/likelihood/gauss_family/gaussian.py | 6 ----- .../statistic/cluster_number_counts.py | 15 +---------- .../statistic/source/number_counts.py | 23 ---------------- .../statistic/source/weak_lensing.py | 27 ------------------- .../gauss_family/statistic/supernova.py | 8 +----- .../gauss_family/statistic/two_point.py | 8 ------ .../likelihood/gauss_family/student_t.py | 7 +---- firecrown/models/cluster_abundance.py | 5 ---- firecrown/models/cluster_mass_rich_proxy.py | 5 ---- firecrown/models/cluster_mass_true.py | 5 ---- firecrown/models/cluster_redshift_spec.py | 5 ---- firecrown/updatable.py | 21 +++++++++++---- tests/test_updatable.py | 9 ------- 14 files changed, 19 insertions(+), 140 deletions(-) diff --git a/firecrown/likelihood/gauss_family/gauss_family.py b/firecrown/likelihood/gauss_family/gauss_family.py index 52223540..884739b7 100644 --- a/firecrown/likelihood/gauss_family/gauss_family.py +++ b/firecrown/likelihood/gauss_family/gauss_family.py @@ -10,7 +10,6 @@ from __future__ import annotations from typing import List, Optional, Tuple, Sequence from typing import final -from abc import abstractmethod import warnings import numpy as np @@ -23,7 +22,6 @@ from ...modeling_tools import ModelingTools from ...updatable import UpdatableCollection from .statistic.statistic import Statistic -from ...parameters import DerivedParameterCollection class GaussFamily(Likelihood): @@ -128,16 +126,3 @@ def compute_chisq(self, tools: ModelingTools) -> float: chisq = np.dot(x, x) return chisq - - @final - def _get_derived_parameters(self) -> DerivedParameterCollection: - derived_parameters = ( - self._get_derived_parameters_gaussian_family() - + self.statistics.get_derived_parameters() - ) - - return derived_parameters - - @abstractmethod - def _get_derived_parameters_gaussian_family(self) -> DerivedParameterCollection: - """Get derived parameters for GaussFamily subclasses.""" diff --git a/firecrown/likelihood/gauss_family/gaussian.py b/firecrown/likelihood/gauss_family/gaussian.py index 8ea6ad79..af25bab4 100644 --- a/firecrown/likelihood/gauss_family/gaussian.py +++ b/firecrown/likelihood/gauss_family/gaussian.py @@ -3,10 +3,8 @@ """ from __future__ import annotations -from typing import final from .gauss_family import GaussFamily -from ...parameters import DerivedParameterCollection from ...modeling_tools import ModelingTools @@ -17,7 +15,3 @@ def compute_loglike(self, tools: ModelingTools): """Compute the log-likelihood.""" return -0.5 * self.compute_chisq(tools) - - @final - def _get_derived_parameters_gaussian_family(self) -> DerivedParameterCollection: - return DerivedParameterCollection([]) diff --git a/firecrown/likelihood/gauss_family/statistic/cluster_number_counts.py b/firecrown/likelihood/gauss_family/statistic/cluster_number_counts.py index bcd266ae..eced347a 100644 --- a/firecrown/likelihood/gauss_family/statistic/cluster_number_counts.py +++ b/firecrown/likelihood/gauss_family/statistic/cluster_number_counts.py @@ -5,7 +5,7 @@ """ from __future__ import annotations -from typing import List, Dict, Tuple, Optional, final +from typing import List, Dict, Tuple, Optional import numpy as np @@ -14,9 +14,6 @@ from .statistic import Statistic, DataVector, TheoryVector from .source.source import SourceSystematic -from ....parameters import ( - DerivedParameterCollection, -) from ....models.cluster_abundance import ClusterAbundance from ....models.cluster_mass import ClusterMass, ClusterMassArgument from ....models.cluster_redshift import ClusterRedshift, ClusterRedshiftArgument @@ -66,16 +63,6 @@ def __init__( "At least one of use_cluster_counts and use_mean_log_mass must be True." ) - @final - def _get_derived_parameters(self) -> DerivedParameterCollection: - """Return an empty DerivedParameterCollection.""" - derived_parameters = DerivedParameterCollection([]) - derived_parameters += self.cluster_abundance.get_derived_parameters() - derived_parameters += self.cluster_mass.get_derived_parameters() - derived_parameters += self.cluster_redshift.get_derived_parameters() - - return derived_parameters - def _read_data_type(self, sacc_data, data_type): """Internal function to read the data from the SACC file.""" tracers_combinations = np.array( diff --git a/firecrown/likelihood/gauss_family/statistic/source/number_counts.py b/firecrown/likelihood/gauss_family/statistic/source/number_counts.py index c50751ac..4c1fa69c 100644 --- a/firecrown/likelihood/gauss_family/statistic/source/number_counts.py +++ b/firecrown/likelihood/gauss_family/statistic/source/number_counts.py @@ -78,10 +78,6 @@ def __init__(self, sacc_tracer: str): self.z_piv = parameters.create() self.sacc_tracer = sacc_tracer - @final - def _get_derived_parameters(self) -> DerivedParameterCollection: - return DerivedParameterCollection([]) - def apply( self, tools: ModelingTools, tracer_arg: NumberCountsArgs ) -> NumberCountsArgs: @@ -131,10 +127,6 @@ def __init__(self, sacc_tracer: str): self.b_s = parameters.create() self.sacc_tracer = sacc_tracer - @final - def _get_derived_parameters(self) -> DerivedParameterCollection: - return DerivedParameterCollection([]) - def apply( self, tools: ModelingTools, tracer_arg: NumberCountsArgs ) -> NumberCountsArgs: @@ -180,10 +172,6 @@ def __init__(self, sacc_tracer: str): self.sacc_tracer = sacc_tracer - @final - def _get_derived_parameters(self) -> DerivedParameterCollection: - return DerivedParameterCollection([]) - def apply( self, tools: ModelingTools, tracer_arg: NumberCountsArgs ) -> NumberCountsArgs: @@ -230,10 +218,6 @@ def __init__(self, sacc_tracer: str): self.mag_bias = parameters.create() self.sacc_tracer = sacc_tracer - @final - def _get_derived_parameters(self) -> DerivedParameterCollection: - return DerivedParameterCollection([]) - def apply( self, tools: ModelingTools, tracer_arg: NumberCountsArgs ) -> NumberCountsArgs: @@ -255,10 +239,6 @@ def __init__(self, sacc_tracer: str): self.delta_z = parameters.create() self.sacc_tracer = sacc_tracer - @final - def _get_derived_parameters(self) -> DerivedParameterCollection: - return DerivedParameterCollection([]) - def apply(self, tools: ModelingTools, tracer_arg: NumberCountsArgs): """Apply a shift to the photo-z distribution of a source.""" @@ -318,9 +298,6 @@ def _get_derived_parameters(self) -> DerivedParameterCollection: derived_parameters = DerivedParameterCollection([derived_scale]) else: derived_parameters = DerivedParameterCollection([]) - derived_parameters = ( - derived_parameters + self.systematics.get_derived_parameters() - ) return derived_parameters diff --git a/firecrown/likelihood/gauss_family/statistic/source/weak_lensing.py b/firecrown/likelihood/gauss_family/statistic/source/weak_lensing.py index 102b78ed..bbe0f09c 100644 --- a/firecrown/likelihood/gauss_family/statistic/source/weak_lensing.py +++ b/firecrown/likelihood/gauss_family/statistic/source/weak_lensing.py @@ -18,7 +18,6 @@ from ..... import parameters from .....parameters import ( ParamsMap, - DerivedParameterCollection, ) from .....modeling_tools import ModelingTools from .....updatable import UpdatableCollection @@ -71,10 +70,6 @@ def __init__(self, sacc_tracer: str) -> None: self.mult_bias = parameters.create() self.sacc_tracer = sacc_tracer - @final - def _get_derived_parameters(self) -> DerivedParameterCollection: - return DerivedParameterCollection([]) - def apply( self, tools: ModelingTools, tracer_arg: WeakLensingArgs ) -> WeakLensingArgs: @@ -128,10 +123,6 @@ def __init__(self, sacc_tracer: Optional[str] = None, alphag=1.0): self.sacc_tracer = sacc_tracer - @final - def _get_derived_parameters(self) -> DerivedParameterCollection: - return DerivedParameterCollection([]) - def apply( self, tools: ModelingTools, tracer_arg: WeakLensingArgs ) -> WeakLensingArgs: @@ -177,10 +168,6 @@ def __init__(self, sacc_tracer: Optional[str] = None): self.sacc_tracer = sacc_tracer - @final - def _get_derived_parameters(self) -> DerivedParameterCollection: - return DerivedParameterCollection([]) - def apply( self, tools: ModelingTools, tracer_arg: WeakLensingArgs ) -> WeakLensingArgs: @@ -219,12 +206,6 @@ def __init__(self, sacc_tracer: str): self.delta_z = parameters.create() self.sacc_tracer = sacc_tracer - @final - def _get_derived_parameters(self) -> DerivedParameterCollection: - derived_parameters = DerivedParameterCollection([]) - - return derived_parameters - def apply(self, tools: ModelingTools, tracer_arg: WeakLensingArgs): """Apply a shift to the photo-z distribution of a source.""" @@ -267,14 +248,6 @@ def _update_source(self, params: ParamsMap): This updates all the contained systematics.""" self.systematics.update(params) - @final - def _get_derived_parameters(self) -> DerivedParameterCollection: - derived_parameters = DerivedParameterCollection([]) - derived_parameters = ( - derived_parameters + self.systematics.get_derived_parameters() - ) - return derived_parameters - def _read(self, sacc_data: sacc.Sacc) -> None: """Read the data for this source from the SACC file. diff --git a/firecrown/likelihood/gauss_family/statistic/supernova.py b/firecrown/likelihood/gauss_family/statistic/supernova.py index e76dc0b0..ebd03e6d 100644 --- a/firecrown/likelihood/gauss_family/statistic/supernova.py +++ b/firecrown/likelihood/gauss_family/statistic/supernova.py @@ -2,7 +2,7 @@ """ from __future__ import annotations -from typing import Optional, final +from typing import Optional import numpy as np import numpy.typing as npt @@ -13,7 +13,6 @@ from ....modeling_tools import ModelingTools from .statistic import Statistic, DataVector, TheoryVector from .... import parameters -from ....parameters import DerivedParameterCollection class Supernova(Statistic): @@ -40,11 +39,6 @@ def read(self, sacc_data: sacc.Sacc): self.data_vector = DataVector.from_list([dp.value for dp in data_points]) self.sacc_indices = np.arange(len(self.data_vector)) - @final - def _get_derived_parameters(self) -> DerivedParameterCollection: - """Return an empty DerivedParameterCollection.""" - return DerivedParameterCollection([]) - def get_data_vector(self) -> DataVector: """Return the data vector; raise exception if there is none.""" assert self.data_vector is not None diff --git a/firecrown/likelihood/gauss_family/statistic/two_point.py b/firecrown/likelihood/gauss_family/statistic/two_point.py index a55b3a30..1aee616b 100644 --- a/firecrown/likelihood/gauss_family/statistic/two_point.py +++ b/firecrown/likelihood/gauss_family/statistic/two_point.py @@ -19,7 +19,6 @@ from .statistic import Statistic, DataVector, TheoryVector from .source.source import Source, Tracer -from ....parameters import DerivedParameterCollection # only supported types are here, anything else will throw # a value error @@ -230,13 +229,6 @@ def _reset(self) -> None: # If we do that, then the CosmoSIS module fails -- because this data # is accessed from that code. - @final - def _get_derived_parameters(self) -> DerivedParameterCollection: - derived_parameters = DerivedParameterCollection([]) - derived_parameters = derived_parameters + self.source0.get_derived_parameters() - derived_parameters = derived_parameters + self.source1.get_derived_parameters() - return derived_parameters - def read(self, sacc_data: sacc.Sacc) -> None: """Read the data for this statistic from the SACC file. diff --git a/firecrown/likelihood/gauss_family/student_t.py b/firecrown/likelihood/gauss_family/student_t.py index 7ba01e58..eefb3bfc 100644 --- a/firecrown/likelihood/gauss_family/student_t.py +++ b/firecrown/likelihood/gauss_family/student_t.py @@ -3,7 +3,7 @@ """ from __future__ import annotations -from typing import List, Optional, final +from typing import List, Optional import numpy as np @@ -11,7 +11,6 @@ from ...modeling_tools import ModelingTools from .statistic.statistic import Statistic from ... import parameters -from ...parameters import DerivedParameterCollection class StudentT(GaussFamily): @@ -43,7 +42,3 @@ def compute_loglike(self, tools: ModelingTools): ccl_cosmo = tools.get_ccl_cosmology() chi2 = self.compute_chisq(ccl_cosmo) return -0.5 * self.nu * np.log(1.0 + chi2 / (self.nu - 1.0)) - - @final - def _get_derived_parameters_gaussian_family(self) -> DerivedParameterCollection: - return DerivedParameterCollection([]) diff --git a/firecrown/models/cluster_abundance.py b/firecrown/models/cluster_abundance.py index 0ecfbed2..38727308 100644 --- a/firecrown/models/cluster_abundance.py +++ b/firecrown/models/cluster_abundance.py @@ -12,7 +12,6 @@ import scipy.integrate from ..updatable import Updatable -from ..parameters import DerivedParameterCollection from .cluster_mass import ClusterMassArgument from .cluster_redshift import ClusterRedshiftArgument @@ -104,10 +103,6 @@ def _reset(self) -> None: """Implementation of the Updatable interface method `_reset`.""" self.halo_mass_function = None - @final - def _get_derived_parameters(self) -> DerivedParameterCollection: - return DerivedParameterCollection([]) - def read(self, sacc_data): """Read the data for this statistic from the SACC file. diff --git a/firecrown/models/cluster_mass_rich_proxy.py b/firecrown/models/cluster_mass_rich_proxy.py index ed25eb67..fee42b42 100644 --- a/firecrown/models/cluster_mass_rich_proxy.py +++ b/firecrown/models/cluster_mass_rich_proxy.py @@ -10,7 +10,6 @@ from ..parameters import ( ParamsMap, - DerivedParameterCollection, ) from .cluster_mass import ClusterMass, ClusterMassArgument from .. import parameters @@ -49,10 +48,6 @@ def _update_cluster_mass(self, params: ParamsMap): This implementation has nothing to do.""" - @final - def _get_derived_parameters(self) -> DerivedParameterCollection: - return DerivedParameterCollection([]) - def read(self, _: sacc.Sacc): """Method to read the data for this source from the SACC file.""" diff --git a/firecrown/models/cluster_mass_true.py b/firecrown/models/cluster_mass_true.py index 0699974f..33d3adc3 100644 --- a/firecrown/models/cluster_mass_true.py +++ b/firecrown/models/cluster_mass_true.py @@ -13,7 +13,6 @@ from ..parameters import ( ParamsMap, - DerivedParameterCollection, ) from .cluster_mass import ClusterMass, ClusterMassArgument @@ -25,10 +24,6 @@ class ClusterMassTrue(ClusterMass): def _update_cluster_mass(self, params: ParamsMap): """Method to update the ClusterMassTrue from the given ParamsMap.""" - @final - def _get_derived_parameters(self) -> DerivedParameterCollection: - return DerivedParameterCollection([]) - def read(self, sacc_data: sacc.Sacc): """Method to read the data for this source from the SACC file.""" diff --git a/firecrown/models/cluster_redshift_spec.py b/firecrown/models/cluster_redshift_spec.py index 746d0bba..32de4294 100644 --- a/firecrown/models/cluster_redshift_spec.py +++ b/firecrown/models/cluster_redshift_spec.py @@ -11,7 +11,6 @@ from ..parameters import ( ParamsMap, - DerivedParameterCollection, ) from .cluster_redshift import ClusterRedshift, ClusterRedshiftArgument @@ -23,10 +22,6 @@ class ClusterRedshiftSpec(ClusterRedshift): def _update_cluster_redshift(self, params: ParamsMap): """Method to update the ClusterRedshiftSpec from the given ParamsMap.""" - @final - def _get_derived_parameters(self) -> DerivedParameterCollection: - return DerivedParameterCollection([]) - def read(self, sacc_data: sacc.Sacc): """Method to read the data for this source from the SACC file.""" diff --git a/firecrown/updatable.py b/firecrown/updatable.py index 28e1d19d..0b50af02 100644 --- a/firecrown/updatable.py +++ b/firecrown/updatable.py @@ -15,7 +15,7 @@ from __future__ import annotations from typing import final, Dict, Optional, Any, List, Union -from abc import ABC, abstractmethod +from abc import ABC from collections import UserList from .parameters import ( ParamsMap, @@ -241,19 +241,30 @@ def get_derived_parameters( statistical analysis. First call returns the DerivedParameterCollection, further calls return None. """ + + if not self._updated: + raise RuntimeError( + "Derived parameters can only be obtained after update has been called." + ) + if self._returned_derived: return None self._returned_derived = True - return self._get_derived_parameters() + derived_parameters = self._get_derived_parameters() + + for item in self._updatables: + derived_parameters = derived_parameters + item.get_derived_parameters() + + return derived_parameters - @abstractmethod def _get_derived_parameters(self) -> DerivedParameterCollection: """Abstract method to be implemented by all concrete classes to return their derived parameters. - Concrete classes must override this. If no derived parameters are required - derived classes must simply return super()._get_derived_parameters(). + Derived classes can override this, returning a DerivedParameterCollection + containing the derived parameters for the class. The default implementation + returns an empty DerivedParameterCollection. """ return DerivedParameterCollection([]) diff --git a/tests/test_updatable.py b/tests/test_updatable.py index ea324406..8557524c 100644 --- a/tests/test_updatable.py +++ b/tests/test_updatable.py @@ -20,15 +20,6 @@ def __init__(self): self.a = parameters.create() - def _update(self, params): - pass - - def _reset(self) -> None: - pass - - def _required_parameters(self): - return RequiredParameters([]) - def _get_derived_parameters(self) -> DerivedParameterCollection: return DerivedParameterCollection([]) From 7cf59b89fcda71cf00825aa51ae0cd95aa08fc49 Mon Sep 17 00:00:00 2001 From: Sandro Dias Pinto Vitenti Date: Fri, 25 Aug 2023 15:19:39 -0300 Subject: [PATCH 17/34] Added tests for updatable nesting. --- .../statistic/cluster_number_counts.py | 4 +- firecrown/parameters.py | 26 ++++- tests/test_updatable.py | 100 ++++++++++++++++-- 3 files changed, 117 insertions(+), 13 deletions(-) diff --git a/firecrown/likelihood/gauss_family/statistic/cluster_number_counts.py b/firecrown/likelihood/gauss_family/statistic/cluster_number_counts.py index eced347a..36f01fc4 100644 --- a/firecrown/likelihood/gauss_family/statistic/cluster_number_counts.py +++ b/firecrown/likelihood/gauss_family/statistic/cluster_number_counts.py @@ -134,11 +134,11 @@ def read(self, sacc_data: sacc.Sacc): try: survey_tracer: SurveyTracer = sacc_data.get_tracer(self.sacc_tracer) - except KeyError: + except KeyError as exc: raise ValueError( f"The SACC file does not contain the SurveyTracer " f"{self.sacc_tracer}." - ) + ) from exc if not isinstance(survey_tracer, SurveyTracer): raise ValueError( f"The SACC tracer {self.sacc_tracer} is not a SurveyTracer." diff --git a/firecrown/parameters.py b/firecrown/parameters.py index 58b55682..dbc25201 100644 --- a/firecrown/parameters.py +++ b/firecrown/parameters.py @@ -155,6 +155,26 @@ def __init__(self, section: str, name: str, val: float): def get_val(self) -> float: return self.val + def __eq__(self, other: object) -> bool: + """Compare two DerivedParameterScalar objects for equality. + + This implementation raises a NotImplemented exception unless both + objects are DerivedParameterScalar objects. + + Two DerivedParameterScalar objects are equal if they have the same + section, name and value. + """ + if not isinstance(other, DerivedParameterScalar): + raise NotImplementedError( + "DerivedParameterScalar comparison is only implemented for " + "DerivedParameterScalar objects" + ) + return ( + self.section == other.section + and self.name == other.name + and self.val == other.val + ) + class DerivedParameterCollection: """Represents a list of DerivedParameter objects.""" @@ -199,8 +219,12 @@ def __eq__(self, other: object): Two DerivedParameterCollection objects are equal if they contain the same DerivedParameter objects. """ + if not isinstance(other, DerivedParameterCollection): - return NotImplemented + raise NotImplementedError( + "DerivedParameterCollection comparison is only implemented for " + "DerivedParameterCollection objects" + ) return self.derived_parameters == other.derived_parameters def __iter__(self) -> Iterator[Tuple[str, str, float]]: diff --git a/tests/test_updatable.py b/tests/test_updatable.py index 8557524c..1c349b48 100644 --- a/tests/test_updatable.py +++ b/tests/test_updatable.py @@ -1,12 +1,16 @@ """ Tests for the Updatable class. """ +from itertools import permutations import pytest +import numpy as np + from firecrown.updatable import Updatable, UpdatableCollection from firecrown import parameters from firecrown.parameters import ( RequiredParameters, ParamsMap, + DerivedParameterScalar, DerivedParameterCollection, ) @@ -20,9 +24,6 @@ def __init__(self): self.a = parameters.create() - def _get_derived_parameters(self) -> DerivedParameterCollection: - return DerivedParameterCollection([]) - class SimpleUpdatable(Updatable): """A concrete type that implements Updatable.""" @@ -34,17 +35,22 @@ def __init__(self): self.x = parameters.create() self.y = parameters.create() - def _update(self, params): - pass - def _reset(self) -> None: - pass +class UpdatableWithDerived(Updatable): + """A concrete type that implements Updatable that implements derived parameters.""" - def _required_parameters(self): - return RequiredParameters([]) + def __init__(self): + """Initialize object with defaulted values.""" + super().__init__() + + self.A = parameters.create() + self.B = parameters.create() def _get_derived_parameters(self) -> DerivedParameterCollection: - return DerivedParameterCollection([]) + derived_scale = DerivedParameterScalar("Section", "Name", self.A + self.B) + derived_parameters = DerivedParameterCollection([derived_scale]) + + return derived_parameters def test_simple_updatable(): @@ -192,3 +198,77 @@ def test_update_rejects_internal_parameters(): assert my_updatable.a is None assert my_updatable.the_meaning_of_life == 42.0 + + +@pytest.fixture(name="nested_updatables", params=permutations(range(3))) +def fixture_nested_updatables(request): + updatables = np.array( + [MinimalUpdatable(), SimpleUpdatable(), UpdatableWithDerived()] + ) + + # Reorder the updatables and set up the nesting + updatables = updatables[list(request.param)] + updatables[0].sub_updatable = updatables[1] + updatables[1].sub_updatable = updatables[2] + + return updatables + + +def test_nesting_updatables_missing_parameters(nested_updatables): + base = nested_updatables[0] + assert isinstance(base, Updatable) + + params = ParamsMap({}) + + with pytest.raises( + RuntimeError, + ): + base.update(params) + + params = ParamsMap({"a": 1.1}) + + with pytest.raises( + RuntimeError, + ): + base.update(params) + + params = ParamsMap({"a": 1.1, "x": 2.0, "y": 3.0}) + + with pytest.raises( + RuntimeError, + ): + base.update(params) + + params = ParamsMap({"a": 1.1, "x": 2.0, "y": 3.0, "A": 4.0, "B": 5.0}) + + base.update(params) + + for updatable in nested_updatables: + assert updatable.is_updated() + + +def test_nesting_updatables_required_parameters(nested_updatables): + base = nested_updatables[0] + assert isinstance(base, Updatable) + + assert base.required_parameters() == RequiredParameters(["a", "x", "y", "A", "B"]) + + +def test_nesting_updatables_derived_parameters(nested_updatables): + base = nested_updatables[0] + assert isinstance(base, Updatable) + + with pytest.raises( + RuntimeError, + match="Derived parameters can only be obtained after update has been called.", + ): + base.get_derived_parameters() + + params = ParamsMap({"a": 1.1, "x": 2.0, "y": 3.0, "A": 4.0, "B": 5.0}) + + base.update(params) + + derived_scale = DerivedParameterScalar("Section", "Name", 9.0) + derived_parameters = DerivedParameterCollection([derived_scale]) + + assert base.get_derived_parameters() == derived_parameters From 1ede0d4f8d8cdad51afe4ac3d8aa3a5b84156686 Mon Sep 17 00:00:00 2001 From: Sandro Dias Pinto Vitenti Date: Fri, 25 Aug 2023 15:37:08 -0300 Subject: [PATCH 18/34] Testing multiple get_derived. --- tests/test_updatable.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_updatable.py b/tests/test_updatable.py index 1c349b48..3e93e2dc 100644 --- a/tests/test_updatable.py +++ b/tests/test_updatable.py @@ -272,3 +272,4 @@ def test_nesting_updatables_derived_parameters(nested_updatables): derived_parameters = DerivedParameterCollection([derived_scale]) assert base.get_derived_parameters() == derived_parameters + assert base.get_derived_parameters() is None From 8d6f69b9b4b2bfdc2cf5dd8f30ea3a89ce80ae3a Mon Sep 17 00:00:00 2001 From: Marc Paterno Date: Thu, 10 Aug 2023 20:22:49 -0500 Subject: [PATCH 19/34] Add test, fix code to raise expected exception --- .../source/test_cluster_number_counts.py | 56 +++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 tests/likelihood/gauss_family/statistic/source/test_cluster_number_counts.py diff --git a/tests/likelihood/gauss_family/statistic/source/test_cluster_number_counts.py b/tests/likelihood/gauss_family/statistic/source/test_cluster_number_counts.py new file mode 100644 index 00000000..eb3a93ae --- /dev/null +++ b/tests/likelihood/gauss_family/statistic/source/test_cluster_number_counts.py @@ -0,0 +1,56 @@ +"""Tests for ClusterNumberCounts. +""" +import numpy as np +import pytest + +import sacc + +from firecrown.modeling_tools import ModelingTools +from firecrown.likelihood.gauss_family.statistic.cluster_number_counts import ( + ClusterNumberCounts, +) + + +@pytest.fixture(name="minimal_stat") +def fixture_minimal_stat() -> ClusterNumberCounts: + """Return a correctly initialized :python:`ClusterNumberCounts` object.""" + stat = ClusterNumberCounts( + survey_tracer="SDSS", + cluster_abundance=None, + cluster_mass=None, + cluster_redshift=None, + ) + return stat + + +@pytest.fixture(name="missing_survey_tracer") +def fixture_missing_survey_tracer() -> sacc.Sacc: + """Return a sacc.Sacc object that lacks a survey_tracer.""" + return sacc.Sacc() + + +@pytest.fixture(name="good_sacc_data") +def fixture_sacc_data(): + """Return a sacc.Sacc object sufficient to correctly set a + :python:`ClusterNumberCounts` object. + """ + data = sacc.Sacc() + return data + + +def test_missing_survey_tracer( + minimal_stat: ClusterNumberCounts, missing_survey_tracer: sacc.Sacc +): + with pytest.raises(ValueError) as exc_info: + minimal_stat.read(missing_survey_tracer) + assert exc_info.value.args[0] == ( + "The SACC file does not contain the " "SurveyTracer SDSS." + ) + + +def test_read_works(minimal_stat: ClusterNumberCounts, good_sacc_data: sacc.Sacc): + """After read() is called, we should be able to get the statistic's + + :python:`DataVector` and also should be able to call + :python:`compute_theory_vector`. + """ From 64abcb172bd03b56e465d7a02e6ad08a81ebe935 Mon Sep 17 00:00:00 2001 From: Marc Paterno Date: Thu, 10 Aug 2023 20:23:18 -0500 Subject: [PATCH 20/34] Improve comment --- firecrown/likelihood/gauss_family/statistic/statistic.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/firecrown/likelihood/gauss_family/statistic/statistic.py b/firecrown/likelihood/gauss_family/statistic/statistic.py index 32e2dfb0..aa842c90 100644 --- a/firecrown/likelihood/gauss_family/statistic/statistic.py +++ b/firecrown/likelihood/gauss_family/statistic/statistic.py @@ -91,7 +91,14 @@ def __iter__(self): class Statistic(Updatable): - """An abstract statistic class (e.g., two-point function, mass function, etc.).""" + """An abstract statistic class. + + Statistics read data from a SACC object as part of a multi-phase + initialization. The manage a :python:`DataVector` and, given a + :python:`ModelingTools` object, can compute a :python:`TheoryVector`. + + Statistics represent things like two-point functions and mass functinos. + """ systematics: List[SourceSystematic] sacc_indices: npt.NDArray[np.int64] From e2355cff9251223a8bf41de94a554e9de7c19e0d Mon Sep 17 00:00:00 2001 From: Marc Paterno Date: Fri, 25 Aug 2023 13:49:27 -0500 Subject: [PATCH 21/34] Improve minimal_stat fixture --- .../gauss_family/statistic/test_cluster_number_counts.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/likelihood/gauss_family/statistic/test_cluster_number_counts.py b/tests/likelihood/gauss_family/statistic/test_cluster_number_counts.py index acf6aca6..99bc2702 100644 --- a/tests/likelihood/gauss_family/statistic/test_cluster_number_counts.py +++ b/tests/likelihood/gauss_family/statistic/test_cluster_number_counts.py @@ -2,6 +2,7 @@ """ import pytest +import ccl.halos import sacc from firecrown.likelihood.gauss_family.statistic.cluster_number_counts import ( @@ -18,8 +19,8 @@ def fixture_minimal_stat() -> ClusterNumberCounts: stat = ClusterNumberCounts( survey_tracer="SDSS", cluster_abundance=ClusterAbundance( - halo_mass_definition=None, - halo_mass_function_name="hmf_func", + halo_mass_definition=ccl.halos.MassDef(0.5, "matter"), + halo_mass_function_name="200m", halo_mass_function_args={}, ), cluster_mass=ClusterMassRich(pivot_mass=10.0, pivot_redshift=1.25), From 256d4611134149c894282b5312581d7260cd3d19 Mon Sep 17 00:00:00 2001 From: Marc Paterno Date: Fri, 25 Aug 2023 14:36:05 -0500 Subject: [PATCH 22/34] Remove duplicate test, in the wrong directory --- .../source/test_cluster_number_counts.py | 56 ------------------- 1 file changed, 56 deletions(-) delete mode 100644 tests/likelihood/gauss_family/statistic/source/test_cluster_number_counts.py diff --git a/tests/likelihood/gauss_family/statistic/source/test_cluster_number_counts.py b/tests/likelihood/gauss_family/statistic/source/test_cluster_number_counts.py deleted file mode 100644 index eb3a93ae..00000000 --- a/tests/likelihood/gauss_family/statistic/source/test_cluster_number_counts.py +++ /dev/null @@ -1,56 +0,0 @@ -"""Tests for ClusterNumberCounts. -""" -import numpy as np -import pytest - -import sacc - -from firecrown.modeling_tools import ModelingTools -from firecrown.likelihood.gauss_family.statistic.cluster_number_counts import ( - ClusterNumberCounts, -) - - -@pytest.fixture(name="minimal_stat") -def fixture_minimal_stat() -> ClusterNumberCounts: - """Return a correctly initialized :python:`ClusterNumberCounts` object.""" - stat = ClusterNumberCounts( - survey_tracer="SDSS", - cluster_abundance=None, - cluster_mass=None, - cluster_redshift=None, - ) - return stat - - -@pytest.fixture(name="missing_survey_tracer") -def fixture_missing_survey_tracer() -> sacc.Sacc: - """Return a sacc.Sacc object that lacks a survey_tracer.""" - return sacc.Sacc() - - -@pytest.fixture(name="good_sacc_data") -def fixture_sacc_data(): - """Return a sacc.Sacc object sufficient to correctly set a - :python:`ClusterNumberCounts` object. - """ - data = sacc.Sacc() - return data - - -def test_missing_survey_tracer( - minimal_stat: ClusterNumberCounts, missing_survey_tracer: sacc.Sacc -): - with pytest.raises(ValueError) as exc_info: - minimal_stat.read(missing_survey_tracer) - assert exc_info.value.args[0] == ( - "The SACC file does not contain the " "SurveyTracer SDSS." - ) - - -def test_read_works(minimal_stat: ClusterNumberCounts, good_sacc_data: sacc.Sacc): - """After read() is called, we should be able to get the statistic's - - :python:`DataVector` and also should be able to call - :python:`compute_theory_vector`. - """ From 73e0c718c4a244bdd001187ee27952549ea18935 Mon Sep 17 00:00:00 2001 From: Marc Paterno Date: Fri, 25 Aug 2023 14:37:02 -0500 Subject: [PATCH 23/34] Fix package name usage --- .../gauss_family/statistic/test_cluster_number_counts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/likelihood/gauss_family/statistic/test_cluster_number_counts.py b/tests/likelihood/gauss_family/statistic/test_cluster_number_counts.py index 99bc2702..ecea740a 100644 --- a/tests/likelihood/gauss_family/statistic/test_cluster_number_counts.py +++ b/tests/likelihood/gauss_family/statistic/test_cluster_number_counts.py @@ -2,7 +2,7 @@ """ import pytest -import ccl.halos +import pyccl.halos import sacc from firecrown.likelihood.gauss_family.statistic.cluster_number_counts import ( @@ -19,7 +19,7 @@ def fixture_minimal_stat() -> ClusterNumberCounts: stat = ClusterNumberCounts( survey_tracer="SDSS", cluster_abundance=ClusterAbundance( - halo_mass_definition=ccl.halos.MassDef(0.5, "matter"), + halo_mass_definition=pyccl.halos.MassDef(0.5, "matter"), halo_mass_function_name="200m", halo_mass_function_args={}, ), From e1419bdc613ce25d1795da3f0f2774d0e2d9cf12 Mon Sep 17 00:00:00 2001 From: Sandro Dias Pinto Vitenti Date: Sat, 26 Aug 2023 20:59:18 -0300 Subject: [PATCH 24/34] Minimal tests for Cobaya --- .../connector/cobaya/test_model_likelihood.py | 120 ++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 tests/connector/cobaya/test_model_likelihood.py diff --git a/tests/connector/cobaya/test_model_likelihood.py b/tests/connector/cobaya/test_model_likelihood.py new file mode 100644 index 00000000..9f5b7114 --- /dev/null +++ b/tests/connector/cobaya/test_model_likelihood.py @@ -0,0 +1,120 @@ +"""Unit tests for the cobaya Mapping connector.""" + +import pytest +import pyccl as ccl +from cobaya.model import get_model, Model +from firecrown.connector.cobaya.ccl import CCLConnector +from firecrown.connector.cobaya.likelihood import LikelihoodConnector + + +def test_cobaya_ccl_initialize(): + ccl_connector = CCLConnector(info={"input_style": "CAMB"}) + + assert isinstance(ccl_connector, CCLConnector) + assert ccl_connector.input_style == "CAMB" + + +def test_cobaya_ccl_initialize_with_params(): + ccl_connector = CCLConnector(info={"input_style": "CAMB"}) + + ccl_connector.initialize_with_params() + + assert isinstance(ccl_connector, CCLConnector) + assert ccl_connector.input_style == "CAMB" + + +def test_cobaya_likelihood_initialize(): + lk_connector = LikelihoodConnector( + info={"firecrownIni": "tests/likelihood/lkdir/lkscript.py"} + ) + + assert isinstance(lk_connector, LikelihoodConnector) + assert lk_connector.firecrownIni == "tests/likelihood/lkdir/lkscript.py" + + +def test_cobaya_likelihood_initialize_with_params(): + lk_connector = LikelihoodConnector( + info={"firecrownIni": "tests/likelihood/lkdir/lkscript.py"} + ) + + lk_connector.initialize_with_params() + + assert isinstance(lk_connector, LikelihoodConnector) + assert lk_connector.firecrownIni == "tests/likelihood/lkdir/lkscript.py" + + +@pytest.fixture(name="fiducial_params") +def fixture_fiducial_params(): + fiducial_params = { + "ombh2": 0.022, + "omch2": 0.12, + "H0": 68, + "tau": 0.07, + "As": 2.2e-9, + "ns": 0.96, + "mnu": 0.06, + "nnu": 3.046, + } + + return fiducial_params + + +def test_cobaya_ccl_with_model(fiducial_params): + # Fiducial parameters for CAMB + info_fiducial = { + "params": fiducial_params, + "likelihood": { + "test_lk": { + "external": lambda _self=None: 0.0, + "requires": {"pyccl": None}, + } + }, + "theory": { + "camb": {"extra_args": {"num_massive_neutrinos": 1}}, + "fcc_ccl": {"external": CCLConnector, "input_style": "CAMB"}, + }, + } + + model_fiducial = get_model(info_fiducial) + assert isinstance(model_fiducial, Model) + model_fiducial.logposterior({}) + + cosmo = model_fiducial.provider.get_pyccl() + assert isinstance(cosmo, ccl.Cosmology) + + h = fiducial_params["H0"] / 100.0 + assert cosmo["H0"] == pytest.approx(fiducial_params["H0"], rel=1.0e-5) + assert cosmo["Omega_c"] == pytest.approx( + fiducial_params["omch2"] / h**2, rel=1.0e-5 + ) + assert cosmo["Omega_b"] == pytest.approx( + fiducial_params["ombh2"] / h**2, rel=1.0e-5 + ) + assert cosmo["Omega_k"] == pytest.approx(0.0, rel=1.0e-5) + assert cosmo["A_s"] == pytest.approx(fiducial_params["As"], rel=1.0e-5) + assert cosmo["n_s"] == pytest.approx(fiducial_params["ns"], rel=1.0e-5) + # The following test fails because of we are using the default + # neutrino hierarchy, which is normal, while CAMB depends on the + # parameter which we do not have access to. + # assert cosmo["m_nu"] == pytest.approx(0.06, rel=1.0e-5) + + +def test_cobaya_ccl_with_likelihood(fiducial_params): + # Fiducial parameters for CAMB + info_fiducial = { + "params": fiducial_params, + "likelihood": { + "lk_connector": { + "external": LikelihoodConnector, + "firecrownIni": "tests/likelihood/lkdir/lkscript.py", + } + }, + "theory": { + "camb": {"extra_args": {"num_massive_neutrinos": 1}}, + "fcc_ccl": {"external": CCLConnector, "input_style": "CAMB"}, + }, + } + + model_fiducial = get_model(info_fiducial) + assert isinstance(model_fiducial, Model) + assert model_fiducial.logposterior({}).logpost == -3.0 From 7da2248f50ace517a51e01d2c98868c648f565d9 Mon Sep 17 00:00:00 2001 From: Sandro Dias Pinto Vitenti Date: Sat, 26 Aug 2023 21:57:06 -0300 Subject: [PATCH 25/34] More tests for Cobaya --- firecrown/connector/cobaya/likelihood.py | 9 +- .../connector/cobaya/test_model_likelihood.py | 135 +++++++++++++++++- .../likelihood/lkdir/lk_sampler_parameter.py | 12 ++ tests/likelihood/lkdir/lkmodule.py | 47 +++--- 4 files changed, 175 insertions(+), 28 deletions(-) create mode 100644 tests/likelihood/lkdir/lk_sampler_parameter.py diff --git a/firecrown/connector/cobaya/likelihood.py b/firecrown/connector/cobaya/likelihood.py index df272314..7262c4a9 100644 --- a/firecrown/connector/cobaya/likelihood.py +++ b/firecrown/connector/cobaya/likelihood.py @@ -29,7 +29,14 @@ def initialize(self): if not hasattr(self, "build_parameters"): build_parameters = NamedParameters() else: - build_parameters = self.build_parameters + if isinstance(self.build_parameters, dict): + build_parameters = NamedParameters(self.build_parameters) + else: + if not isinstance(self.build_parameters, NamedParameters): + raise TypeError( + "build_parameters must be a NamedParameters or dict" + ) + build_parameters = self.build_parameters self.likelihood, self.tools = load_likelihood( self.firecrownIni, build_parameters diff --git a/tests/connector/cobaya/test_model_likelihood.py b/tests/connector/cobaya/test_model_likelihood.py index 9f5b7114..cfd6fdc9 100644 --- a/tests/connector/cobaya/test_model_likelihood.py +++ b/tests/connector/cobaya/test_model_likelihood.py @@ -3,8 +3,10 @@ import pytest import pyccl as ccl from cobaya.model import get_model, Model +from cobaya.log import LoggedError from firecrown.connector.cobaya.ccl import CCLConnector from firecrown.connector.cobaya.likelihood import LikelihoodConnector +from firecrown.likelihood.likelihood import NamedParameters def test_cobaya_ccl_initialize(): @@ -45,6 +47,7 @@ def test_cobaya_likelihood_initialize_with_params(): @pytest.fixture(name="fiducial_params") def fixture_fiducial_params(): + # Fiducial parameters for CAMB fiducial_params = { "ombh2": 0.022, "omch2": 0.12, @@ -59,7 +62,7 @@ def fixture_fiducial_params(): return fiducial_params -def test_cobaya_ccl_with_model(fiducial_params): +def test_cobaya_ccl_model(fiducial_params): # Fiducial parameters for CAMB info_fiducial = { "params": fiducial_params, @@ -99,8 +102,7 @@ def test_cobaya_ccl_with_model(fiducial_params): # assert cosmo["m_nu"] == pytest.approx(0.06, rel=1.0e-5) -def test_cobaya_ccl_with_likelihood(fiducial_params): - # Fiducial parameters for CAMB +def test_cobaya_ccl_likelihood(fiducial_params): info_fiducial = { "params": fiducial_params, "likelihood": { @@ -118,3 +120,130 @@ def test_cobaya_ccl_with_likelihood(fiducial_params): model_fiducial = get_model(info_fiducial) assert isinstance(model_fiducial, Model) assert model_fiducial.logposterior({}).logpost == -3.0 + + +def test_parameterized_likelihood_missing(fiducial_params): + info_fiducial = { + "params": fiducial_params, + "likelihood": { + "lk_connector": { + "external": LikelihoodConnector, + "firecrownIni": "tests/likelihood/lkdir/lk_needing_param.py", + } + }, + "theory": { + "camb": {"extra_args": {"num_massive_neutrinos": 1}}, + "fcc_ccl": {"external": CCLConnector, "input_style": "CAMB"}, + }, + } + + with pytest.raises(KeyError): + _ = get_model(info_fiducial) + + +def test_parameterized_likelihood_wrong_type(fiducial_params): + info_fiducial = { + "params": fiducial_params, + "likelihood": { + "lk_connector": { + "external": LikelihoodConnector, + "firecrownIni": "tests/likelihood/lkdir/lk_needing_param.py", + "build_parameters": 1.0, + } + }, + "theory": { + "camb": {"extra_args": {"num_massive_neutrinos": 1}}, + "fcc_ccl": {"external": CCLConnector, "input_style": "CAMB"}, + }, + } + + with pytest.raises( + TypeError, match="build_parameters must be a NamedParameters or dict" + ): + _ = get_model(info_fiducial) + + +def test_parameterized_likelihood_dict(fiducial_params): + info_fiducial = { + "params": fiducial_params, + "likelihood": { + "lk_connector": { + "external": LikelihoodConnector, + "firecrownIni": "tests/likelihood/lkdir/lk_needing_param.py", + "build_parameters": {"sacc_filename": "this_sacc_does_not_exist.fits"}, + } + }, + "theory": { + "camb": {"extra_args": {"num_massive_neutrinos": 1}}, + "fcc_ccl": {"external": CCLConnector, "input_style": "CAMB"}, + }, + } + + model_fiducial = get_model(info_fiducial) + assert isinstance(model_fiducial, Model) + assert model_fiducial.logposterior({}).logpost == -1.5 + + +def test_parameterized_likelihood_namedparameters(fiducial_params): + info_fiducial = { + "params": fiducial_params, + "likelihood": { + "lk_connector": { + "external": LikelihoodConnector, + "firecrownIni": "tests/likelihood/lkdir/lk_needing_param.py", + "build_parameters": NamedParameters( + {"sacc_filename": "this_sacc_does_not_exist.fits"} + ), + } + }, + "theory": { + "camb": {"extra_args": {"num_massive_neutrinos": 1}}, + "fcc_ccl": {"external": CCLConnector, "input_style": "CAMB"}, + }, + } + + model_fiducial = get_model(info_fiducial) + assert isinstance(model_fiducial, Model) + assert model_fiducial.logposterior({}).logpost == -1.5 + + +def test_sampler_parameter_likelihood_missing(fiducial_params): + info_fiducial = { + "params": fiducial_params, + "likelihood": { + "lk_connector": { + "external": LikelihoodConnector, + "firecrownIni": "tests/likelihood/lkdir/lk_sampler_parameter.py", + "build_parameters": NamedParameters({"sacc_tracer": "my_prefix"}), + } + }, + "theory": { + "camb": {"extra_args": {"num_massive_neutrinos": 1}}, + "fcc_ccl": {"external": CCLConnector, "input_style": "CAMB"}, + }, + } + + with pytest.raises(LoggedError, match="my_prefix_sampler_param0"): + _ = get_model(info_fiducial) + + +def test_sampler_parameter_likelihood(fiducial_params): + fiducial_params.update({"my_prefix_sampler_param0": 1.0}) + info_fiducial = { + "params": fiducial_params, + "likelihood": { + "lk_connector": { + "external": LikelihoodConnector, + "firecrownIni": "tests/likelihood/lkdir/lk_sampler_parameter.py", + "build_parameters": NamedParameters({"sacc_tracer": "my_prefix"}), + } + }, + "theory": { + "camb": {"extra_args": {"num_massive_neutrinos": 1}}, + "fcc_ccl": {"external": CCLConnector, "input_style": "CAMB"}, + }, + } + + model_fiducial = get_model(info_fiducial) + assert isinstance(model_fiducial, Model) + assert model_fiducial.logposterior({}).logpost == -2.1 diff --git a/tests/likelihood/lkdir/lk_sampler_parameter.py b/tests/likelihood/lkdir/lk_sampler_parameter.py new file mode 100644 index 00000000..5bc943c9 --- /dev/null +++ b/tests/likelihood/lkdir/lk_sampler_parameter.py @@ -0,0 +1,12 @@ +""" +Provides a trivial likelihood factory function for testing purposes. +The likelihood created requires a string parameter named "sacc_tracer" +and has a sampler parameter named "sampler_param0". +""" +from firecrown.likelihood.likelihood import NamedParameters +from . import lkmodule + + +def build_likelihood(params: NamedParameters): + """Return a SamplerParameterLikelihood object.""" + return lkmodule.sampler_parameter_likelihood(params) diff --git a/tests/likelihood/lkdir/lkmodule.py b/tests/likelihood/lkdir/lkmodule.py index 5e7d518f..99488eca 100644 --- a/tests/likelihood/lkdir/lkmodule.py +++ b/tests/likelihood/lkdir/lkmodule.py @@ -2,12 +2,9 @@ Provides a trivial likelihood class and factory function for testing purposes. """ import sacc -from firecrown.parameters import ( - RequiredParameters, - DerivedParameterCollection, -) from firecrown.likelihood.likelihood import Likelihood, NamedParameters from firecrown.modeling_tools import ModelingTools +from firecrown import parameters class EmptyLikelihood(Likelihood): @@ -21,17 +18,6 @@ def __init__(self) -> None: def read(self, sacc_data: sacc.Sacc) -> None: """This class has nothing to read.""" - def _reset(self) -> None: - """This class has no state to reset.""" - - def _required_parameters(self) -> RequiredParameters: - """Return an empty RequiredParameters object.""" - return RequiredParameters([]) - - def _get_derived_parameters(self) -> DerivedParameterCollection: - """Return an empty DerivedParameterCollection.""" - return DerivedParameterCollection([]) - def compute_loglike(self, tools: ModelingTools) -> float: """Return a constant value of the likelihood, determined by the value of self.placeholder.""" @@ -56,22 +42,35 @@ def __init__(self, params: NamedParameters): def read(self, sacc_data: sacc.Sacc) -> None: """This class has nothing to read.""" - def _reset(self) -> None: - """This class has no state to reset""" + def compute_loglike(self, tools: ModelingTools) -> float: + """Return a constant value of the likelihood.""" + return -1.5 - def _required_parameters(self) -> RequiredParameters: - """Return an empty RequiredParameters object.""" - return RequiredParameters([]) - def _get_derived_parameters(self) -> DerivedParameterCollection: - """Return an empty DerivedParameterCollection.""" - return DerivedParameterCollection([]) +class SamplerParameterLikelihood(Likelihood): + """A minimal likelihood for testing. This likelihood requires a parameter + named 'sacc_filename'.""" + + def __init__(self, params: NamedParameters): + """Initialize the ParameterizedLikelihood by reading the specificed + sacc_filename value.""" + super().__init__() + self.sacc_tracer = params.get_string("sacc_tracer") + self.sampler_param0 = parameters.create() + + def read(self, sacc_data: sacc.Sacc) -> None: + """This class has nothing to read.""" def compute_loglike(self, tools: ModelingTools) -> float: """Return a constant value of the likelihood.""" - return -1.5 + return -2.1 def parameterized_likelihood(params: NamedParameters): """Return a ParameterizedLikelihood object.""" return ParamaterizedLikelihood(params) + + +def sampler_parameter_likelihood(params: NamedParameters): + """Return a SamplerParameterLikelihood object.""" + return SamplerParameterLikelihood(params) From 0d004029077b91fb9e9901fba68be13cf0e519c6 Mon Sep 17 00:00:00 2001 From: Sandro Dias Pinto Vitenti Date: Sat, 26 Aug 2023 23:06:07 -0300 Subject: [PATCH 26/34] Tests for Cobaya derived parameters. --- .../connector/cobaya/test_model_likelihood.py | 24 +++++++++++++ tests/likelihood/lkdir/lkmodule.py | 35 +++++++++++++++++-- 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/tests/connector/cobaya/test_model_likelihood.py b/tests/connector/cobaya/test_model_likelihood.py index cfd6fdc9..7c50a928 100644 --- a/tests/connector/cobaya/test_model_likelihood.py +++ b/tests/connector/cobaya/test_model_likelihood.py @@ -247,3 +247,27 @@ def test_sampler_parameter_likelihood(fiducial_params): model_fiducial = get_model(info_fiducial) assert isinstance(model_fiducial, Model) assert model_fiducial.logposterior({}).logpost == -2.1 + + +def test_derived_parameter_likelihood(fiducial_params): + fiducial_params.update({"derived_section__derived_param0": {"derived": True}}) + info_fiducial = { + "params": fiducial_params, + "likelihood": { + "lk_connector": { + "external": LikelihoodConnector, + "firecrownIni": "tests/likelihood/lkdir/lk_derived_parameter.py", + "derived_parameters": ["derived_section__derived_param0"], + } + }, + "theory": { + "camb": {"extra_args": {"num_massive_neutrinos": 1}}, + "fcc_ccl": {"external": CCLConnector, "input_style": "CAMB"}, + }, + } + + model_fiducial = get_model(info_fiducial) + assert isinstance(model_fiducial, Model) + logpost = model_fiducial.logposterior({}) + assert logpost.logpost == -3.14 + assert logpost.derived[0] == 1.0 diff --git a/tests/likelihood/lkdir/lkmodule.py b/tests/likelihood/lkdir/lkmodule.py index 99488eca..4debd6bb 100644 --- a/tests/likelihood/lkdir/lkmodule.py +++ b/tests/likelihood/lkdir/lkmodule.py @@ -1,6 +1,7 @@ """ Provides a trivial likelihood class and factory function for testing purposes. """ +from firecrown.parameters import DerivedParameterCollection, DerivedParameterScalar import sacc from firecrown.likelihood.likelihood import Likelihood, NamedParameters from firecrown.modeling_tools import ModelingTools @@ -52,8 +53,9 @@ class SamplerParameterLikelihood(Likelihood): named 'sacc_filename'.""" def __init__(self, params: NamedParameters): - """Initialize the ParameterizedLikelihood by reading the specificed - sacc_filename value.""" + """Initialize the SamplerParameterLikelihood by reading the specificed + sacc_tracer value and creates a sampler parameter called "sampler_param0". + """ super().__init__() self.sacc_tracer = params.get_string("sacc_tracer") self.sampler_param0 = parameters.create() @@ -66,6 +68,30 @@ def compute_loglike(self, tools: ModelingTools) -> float: return -2.1 +class DerivedParameterLikelihood(Likelihood): + """A minimal likelihood for testing. This likelihood requires a parameter + named 'sacc_filename'.""" + + def __init__(self): + """Initialize the DerivedParameterLikelihood where _get_derived_parameters + creates a derived parameter called "derived_param0". + """ + super().__init__() + self.placeholder = 1.0 + + def _get_derived_parameters(self) -> DerivedParameterCollection: + return DerivedParameterCollection( + [DerivedParameterScalar("derived_section", "derived_param0", 1.0)] + ) + + def read(self, sacc_data: sacc.Sacc) -> None: + """This class has nothing to read.""" + + def compute_loglike(self, tools: ModelingTools) -> float: + """Return a constant value of the likelihood.""" + return -3.14 + + def parameterized_likelihood(params: NamedParameters): """Return a ParameterizedLikelihood object.""" return ParamaterizedLikelihood(params) @@ -74,3 +100,8 @@ def parameterized_likelihood(params: NamedParameters): def sampler_parameter_likelihood(params: NamedParameters): """Return a SamplerParameterLikelihood object.""" return SamplerParameterLikelihood(params) + + +def derived_parameter_likelihood(): + """Return a DerivedParameterLikelihood object.""" + return DerivedParameterLikelihood() From 257a9d6da5e5c280ce2bfee589ca8b86d73f8247 Mon Sep 17 00:00:00 2001 From: Sandro Dias Pinto Vitenti Date: Sat, 26 Aug 2023 23:06:16 -0300 Subject: [PATCH 27/34] Tests for Cobaya derived parameters. --- tests/likelihood/lkdir/lk_derived_parameter.py | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 tests/likelihood/lkdir/lk_derived_parameter.py diff --git a/tests/likelihood/lkdir/lk_derived_parameter.py b/tests/likelihood/lkdir/lk_derived_parameter.py new file mode 100644 index 00000000..baff91fa --- /dev/null +++ b/tests/likelihood/lkdir/lk_derived_parameter.py @@ -0,0 +1,11 @@ +""" +Provides a trivial likelihood factory function for testing purposes. +The likelihood created provides one derived parameter named "derived_param0". +""" +from firecrown.likelihood.likelihood import NamedParameters +from . import lkmodule + + +def build_likelihood(_: NamedParameters): + """Return a DerivedParameterLikelihood object.""" + return lkmodule.derived_parameter_likelihood() From b2326bc7f7a7939f7a5b0ac37be17ef1ef3871bc Mon Sep 17 00:00:00 2001 From: Sandro Dias Pinto Vitenti Date: Sat, 26 Aug 2023 23:17:08 -0300 Subject: [PATCH 28/34] Fixed pylint issue on tests. --- tests/likelihood/lkdir/lkmodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/likelihood/lkdir/lkmodule.py b/tests/likelihood/lkdir/lkmodule.py index 4debd6bb..d18a46f3 100644 --- a/tests/likelihood/lkdir/lkmodule.py +++ b/tests/likelihood/lkdir/lkmodule.py @@ -1,8 +1,8 @@ """ Provides a trivial likelihood class and factory function for testing purposes. """ -from firecrown.parameters import DerivedParameterCollection, DerivedParameterScalar import sacc +from firecrown.parameters import DerivedParameterCollection, DerivedParameterScalar from firecrown.likelihood.likelihood import Likelihood, NamedParameters from firecrown.modeling_tools import ModelingTools from firecrown import parameters From b654c8fbb3456b5922a7ad9b7e4aa2fb045be100 Mon Sep 17 00:00:00 2001 From: Sandro Dias Pinto Vitenti Date: Sat, 26 Aug 2023 23:29:06 -0300 Subject: [PATCH 29/34] Removed redundant methods in cobaya connector. --- firecrown/connector/cobaya/ccl.py | 7 ------- firecrown/connector/cobaya/likelihood.py | 4 ---- 2 files changed, 11 deletions(-) diff --git a/firecrown/connector/cobaya/ccl.py b/firecrown/connector/cobaya/ccl.py index 6c8fa400..31f86a14 100644 --- a/firecrown/connector/cobaya/ccl.py +++ b/firecrown/connector/cobaya/ccl.py @@ -38,13 +38,6 @@ def initialize(self): self.z_Pk = np.arange(0.0, 6.0, 1) self.Pk_kmax = 1.0 - def get_param(self, p: str) -> None: - """Return the current value of the parameter named 'p'. - - This implementation always returns None. - """ - return None - def initialize_with_params(self): """Required by Cobaya. diff --git a/firecrown/connector/cobaya/likelihood.py b/firecrown/connector/cobaya/likelihood.py index 7262c4a9..3f58d201 100644 --- a/firecrown/connector/cobaya/likelihood.py +++ b/firecrown/connector/cobaya/likelihood.py @@ -42,10 +42,6 @@ def initialize(self): self.firecrownIni, build_parameters ) - def get_param(self, p: str): - """Return the current value of the parameter named 'p'.""" - return self._current_state["derived"][p] - def initialize_with_params(self) -> None: """Required by Cobaya. From ecd3d230b05fe885ff9a2566e84a1d9528fcb854 Mon Sep 17 00:00:00 2001 From: Sandro Dias Pinto Vitenti Date: Sun, 27 Aug 2023 01:03:48 -0300 Subject: [PATCH 30/34] Added tests for DerivedParameter* eq. --- tests/test_parameters.py | 52 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/tests/test_parameters.py b/tests/test_parameters.py index d13bfdfe..394eaf51 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -150,3 +150,55 @@ def test_derived_parameters_collection_add_iter(): assert section == derived_parameter.section assert name == derived_parameter.name assert val == derived_parameter.get_val() + + +def test_derived_parameter_eq(): + dv1 = DerivedParameterScalar("sec1", "name1", 3.14) + dv2 = DerivedParameterScalar("sec1", "name1", 3.14) + + assert dv1 == dv2 + + +def test_derived_parameter_eq_invalid(): + dv1 = DerivedParameterScalar("sec1", "name1", 3.14) + + with pytest.raises( + NotImplementedError, + match="DerivedParameterScalar comparison is only " + "implemented for DerivedParameterScalar objects", + ): + _ = dv1 == 1.0 + + +def test_derived_parameters_collection_eq(): + olist1 = [ + DerivedParameterScalar("sec1", "name1", 3.14), + DerivedParameterScalar("sec2", "name2", 2.72), + DerivedParameterScalar("sec2", "name3", 0.58), + ] + dpc1 = DerivedParameterCollection(olist1) + + olist2 = [ + DerivedParameterScalar("sec1", "name1", 3.14), + DerivedParameterScalar("sec2", "name2", 2.72), + DerivedParameterScalar("sec2", "name3", 0.58), + ] + dpc2 = DerivedParameterCollection(olist2) + + assert dpc1 == dpc2 + + +def test_derived_parameters_collection_eq_invalid(): + olist1 = [ + DerivedParameterScalar("sec1", "name1", 3.14), + DerivedParameterScalar("sec2", "name2", 2.72), + DerivedParameterScalar("sec2", "name3", 0.58), + ] + dpc1 = DerivedParameterCollection(olist1) + + with pytest.raises( + NotImplementedError, + match="DerivedParameterCollection comparison is only " + "implemented for DerivedParameterCollection objects", + ): + _ = dpc1 == 1.0 From f34d87e7a700547d3eeff65291acedc33166f441 Mon Sep 17 00:00:00 2001 From: Sandro Dias Pinto Vitenti Date: Sun, 27 Aug 2023 10:32:08 -0300 Subject: [PATCH 31/34] Removed old interface which were replaced by Updatable.required_parameters. --- firecrown/likelihood/likelihood.py | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/firecrown/likelihood/likelihood.py b/firecrown/likelihood/likelihood.py index 43b8869d..ebcbe3f8 100644 --- a/firecrown/likelihood/likelihood.py +++ b/firecrown/likelihood/likelihood.py @@ -10,7 +10,7 @@ """ from __future__ import annotations -from typing import List, Mapping, Tuple, Union, Optional +from typing import Mapping, Tuple, Union, Optional from abc import abstractmethod import warnings import importlib @@ -40,28 +40,11 @@ def __init__(self) -> None: """Default initialization for a base Likelihood object.""" super().__init__() - self.params_names: Optional[List[str]] = None self.predicted_data_vector: Optional[npt.NDArray[np.double]] = None self.measured_data_vector: Optional[npt.NDArray[np.double]] = None self.inv_cov: Optional[npt.NDArray[np.double]] = None self.statistics: UpdatableCollection = UpdatableCollection() - def set_params_names(self, params_names: List[str]) -> None: - """Set the parameter names for this Likelihood.""" - self.params_names = params_names - - def get_params_names(self) -> Optional[List[str]]: - """Return the parameter names of this Likelihood.""" - - # TODO: This test for the presence of the instance variable - # params_names seems unnecessary; we set the instance variable - # to None in the initializer. Since we would return an empty list - # if we did *not* have an instance variable, should we just make - # the default value used in the initializer an empty list? - if hasattr(self, "params_names"): - return self.params_names - return [] - @abstractmethod def read(self, sacc_data: sacc.Sacc): """Read the covariance matrix for this likelihood from the SACC file.""" From c8c84ee6aa527c8d08ae7d2a3f25c0d4df1dabf9 Mon Sep 17 00:00:00 2001 From: Sandro Dias Pinto Vitenti Date: Sun, 27 Aug 2023 11:24:54 -0300 Subject: [PATCH 32/34] More testing for load_likelihood. --- firecrown/likelihood/likelihood.py | 18 ++++-- tests/likelihood/lkdir/lkscript_invalid.py | 12 ++++ .../lkdir/lkscript_not_a_function.py | 9 +++ .../lkdir/lkscript_returns_wrong_type.py | 11 ++++ tests/likelihood/test_likelihood.py | 59 ++++++++++++++++++- 5 files changed, 103 insertions(+), 6 deletions(-) create mode 100644 tests/likelihood/lkdir/lkscript_invalid.py create mode 100644 tests/likelihood/lkdir/lkscript_not_a_function.py create mode 100644 tests/likelihood/lkdir/lkscript_returns_wrong_type.py diff --git a/firecrown/likelihood/likelihood.py b/firecrown/likelihood/likelihood.py index ebcbe3f8..95954684 100644 --- a/firecrown/likelihood/likelihood.py +++ b/firecrown/likelihood/likelihood.py @@ -171,14 +171,22 @@ def load_likelihood( modname, filename, submodule_search_locations=[script_path] ) - if spec is None: - raise ImportError(f"Could not load spec for module '{modname}' at: {filename}") + # Apparently, the spec can be None if the file extension is not .py + # However, we already checked for that, so this should never happen. + # if spec is None: + # raise ImportError(f"Could not load spec for module '{modname}' at: {filename}") + # Instead, we just assert that it is not None. + assert spec is not None mod = importlib.util.module_from_spec(spec) sys.modules[modname] = mod - if spec.loader is None: - raise ImportError(f"Spec for module '{modname}' has no loader.") - + # Apparently, the spec.loader can be None if the file extension is not + # recognized. However, we already checked for that, so this should never + # happen. + # if spec.loader is None: + # raise ImportError(f"Spec for module '{modname}' has no loader.") + # Instead, we just assert that it is not None. + assert spec.loader is not None spec.loader.exec_module(mod) if not hasattr(mod, "build_likelihood"): diff --git a/tests/likelihood/lkdir/lkscript_invalid.py b/tests/likelihood/lkdir/lkscript_invalid.py new file mode 100644 index 00000000..02aece5b --- /dev/null +++ b/tests/likelihood/lkdir/lkscript_invalid.py @@ -0,0 +1,12 @@ +""" +Provides a trivial likelihood factory function for testing purposes. +This module should be loaded by the test_load_likelihood_submodule test. +It should raise an exception because the factory function does not define +a build_likelihood Callable. +""" +from . import lkmodule + + +def not_a_build_likelihood(_): + """Return an EmptyLikelihood object.""" + return lkmodule.empty_likelihood() diff --git a/tests/likelihood/lkdir/lkscript_not_a_function.py b/tests/likelihood/lkdir/lkscript_not_a_function.py new file mode 100644 index 00000000..95bd5c30 --- /dev/null +++ b/tests/likelihood/lkdir/lkscript_not_a_function.py @@ -0,0 +1,9 @@ +""" +Provides a trivial likelihood factory function for testing purposes. +This module should be loaded by the test_load_likelihood_submodule test. +It should raise an exception because the factory function does not define +a build_likelihood as a Callable. +""" + + +build_likelihood = "I am not a function" diff --git a/tests/likelihood/lkdir/lkscript_returns_wrong_type.py b/tests/likelihood/lkdir/lkscript_returns_wrong_type.py new file mode 100644 index 00000000..6e13b0f6 --- /dev/null +++ b/tests/likelihood/lkdir/lkscript_returns_wrong_type.py @@ -0,0 +1,11 @@ +""" +Provides a trivial likelihood factory function for testing purposes. +This module should be loaded by the test_load_likelihood_submodule test. +It should raise an exception because the factory function does not return +a Likelihood object. +""" + + +def build_likelihood(_): + """Return an EmptyLikelihood object.""" + return "Not a Likelihood" diff --git a/tests/likelihood/test_likelihood.py b/tests/likelihood/test_likelihood.py index b3438047..eb59c8a3 100644 --- a/tests/likelihood/test_likelihood.py +++ b/tests/likelihood/test_likelihood.py @@ -2,7 +2,7 @@ Tests for the module firecrown.likelihood.likelihood. """ import os - +import pytest from firecrown.likelihood.likelihood import load_likelihood, NamedParameters @@ -14,3 +14,60 @@ def test_load_likelihood_submodule(): dir_path = os.path.dirname(os.path.realpath(__file__)) load_likelihood(os.path.join(dir_path, "lkdir/lkscript.py"), NamedParameters()) + + +def test_load_likelihood_submodule_invalid(): + """The likelihood script should be able to load other modules from its + directory using relative import.""" + + dir_path = os.path.dirname(os.path.realpath(__file__)) + + with pytest.raises(ValueError, match="Unrecognized Firecrown initialization file"): + load_likelihood( + os.path.join(dir_path, "lkdir/lkscript_invalid.ext"), NamedParameters() + ) + + +def test_load_likelihood_submodule_no_build_likelihood(): + """The likelihood script should be able to load other modules from its + directory using relative import.""" + + dir_path = os.path.dirname(os.path.realpath(__file__)) + + with pytest.raises( + AttributeError, match="does not define a `build_likelihood` factory function." + ): + load_likelihood( + os.path.join(dir_path, "lkdir/lkscript_invalid.py"), NamedParameters() + ) + + +def test_load_likelihood_submodule_not_a_function(): + """The likelihood script should be able to load other modules from its + directory using relative import.""" + + dir_path = os.path.dirname(os.path.realpath(__file__)) + + with pytest.raises( + TypeError, match="The factory function `build_likelihood` must be a callable." + ): + load_likelihood( + os.path.join(dir_path, "lkdir/lkscript_not_a_function.py"), + NamedParameters(), + ) + + +def test_load_likelihood_submodule_returns_wrong_type(): + """The likelihood script should be able to load other modules from its + directory using relative import.""" + + dir_path = os.path.dirname(os.path.realpath(__file__)) + + with pytest.raises( + TypeError, + match="The returned likelihood must be a Firecrown's `Likelihood` type,", + ): + load_likelihood( + os.path.join(dir_path, "lkdir/lkscript_returns_wrong_type.py"), + NamedParameters(), + ) From a75c7d02028c732159773098d9134574afc7a4de Mon Sep 17 00:00:00 2001 From: Sandro Dias Pinto Vitenti Date: Sun, 27 Aug 2023 11:59:07 -0300 Subject: [PATCH 33/34] Filtering warnings from external packages that we cannot handle. --- pytest.ini | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pytest.ini b/pytest.ini index b722a492..c130f5b8 100644 --- a/pytest.ini +++ b/pytest.ini @@ -6,3 +6,7 @@ addopts = testpaths = tests markers = slow: Mark slow tests to ignore them unless they are requested + +filterwarnings = + ignore::DeprecationWarning:pkg_resources.*: + ignore::DeprecationWarning:cobaya.*: From 5c5715088a8b87019be53d5e9d321d08fa1fb78b Mon Sep 17 00:00:00 2001 From: Sandro Dias Pinto Vitenti Date: Sun, 27 Aug 2023 22:14:43 -0300 Subject: [PATCH 34/34] Added tests for NamedParameters and load_likelihood. --- firecrown/likelihood/likelihood.py | 7 + tests/likelihood/lkdir/lkscript.py | 5 +- tests/likelihood/lkdir/lkscript_old.py | 8 + .../lkscript_returns_wrong_type_tools.py | 12 ++ tests/likelihood/test_likelihood.py | 47 ++++-- tests/likelihood/test_named_parameters.py | 157 ++++++++++++++++++ 6 files changed, 220 insertions(+), 16 deletions(-) create mode 100644 tests/likelihood/lkdir/lkscript_old.py create mode 100644 tests/likelihood/lkdir/lkscript_returns_wrong_type_tools.py create mode 100644 tests/likelihood/test_named_parameters.py diff --git a/firecrown/likelihood/likelihood.py b/firecrown/likelihood/likelihood.py index 95954684..0945174f 100644 --- a/firecrown/likelihood/likelihood.py +++ b/firecrown/likelihood/likelihood.py @@ -204,6 +204,7 @@ def load_likelihood( category=DeprecationWarning, ) likelihood = mod.likelihood + tools = ModelingTools() else: if not callable(mod.build_likelihood): raise TypeError( @@ -222,4 +223,10 @@ def load_likelihood( f"received {type(likelihood)} instead." ) + if not isinstance(tools, ModelingTools): + raise TypeError( + f"The returned tools must be a Firecrown's `ModelingTools` type, " + f"received {type(tools)} instead." + ) + return likelihood, tools diff --git a/tests/likelihood/lkdir/lkscript.py b/tests/likelihood/lkdir/lkscript.py index 44f04349..e3424d32 100644 --- a/tests/likelihood/lkdir/lkscript.py +++ b/tests/likelihood/lkdir/lkscript.py @@ -1,9 +1,12 @@ """ Provides a trivial likelihood factory function for testing purposes. """ +from firecrown.modeling_tools import ModelingTools from . import lkmodule def build_likelihood(_): """Return an EmptyLikelihood object.""" - return lkmodule.empty_likelihood() + tools = ModelingTools() + tools.test_attribute = "test" # type: ignore + return (lkmodule.empty_likelihood(), tools) diff --git a/tests/likelihood/lkdir/lkscript_old.py b/tests/likelihood/lkdir/lkscript_old.py new file mode 100644 index 00000000..bf298360 --- /dev/null +++ b/tests/likelihood/lkdir/lkscript_old.py @@ -0,0 +1,8 @@ +""" +Provides a trivial likelihood factory function for testing purposes. +""" +from . import lkmodule + + +# Defines a module variable "likelihood" +likelihood = lkmodule.empty_likelihood() diff --git a/tests/likelihood/lkdir/lkscript_returns_wrong_type_tools.py b/tests/likelihood/lkdir/lkscript_returns_wrong_type_tools.py new file mode 100644 index 00000000..1ce1abe4 --- /dev/null +++ b/tests/likelihood/lkdir/lkscript_returns_wrong_type_tools.py @@ -0,0 +1,12 @@ +""" +Provides a trivial likelihood factory function for testing purposes. +This module should be loaded by the test_load_likelihood_submodule test. +It should raise an exception because the factory function does not return +a Likelihood object. +""" +from . import lkmodule + + +def build_likelihood(_): + """Return an EmptyLikelihood object.""" + return lkmodule.empty_likelihood(), "Not a ModelingTools" diff --git a/tests/likelihood/test_likelihood.py b/tests/likelihood/test_likelihood.py index eb59c8a3..6ea10821 100644 --- a/tests/likelihood/test_likelihood.py +++ b/tests/likelihood/test_likelihood.py @@ -8,18 +8,12 @@ def test_load_likelihood_submodule(): - """The likelihood script should be able to load other modules from its - directory using relative import.""" - dir_path = os.path.dirname(os.path.realpath(__file__)) load_likelihood(os.path.join(dir_path, "lkdir/lkscript.py"), NamedParameters()) def test_load_likelihood_submodule_invalid(): - """The likelihood script should be able to load other modules from its - directory using relative import.""" - dir_path = os.path.dirname(os.path.realpath(__file__)) with pytest.raises(ValueError, match="Unrecognized Firecrown initialization file"): @@ -29,9 +23,6 @@ def test_load_likelihood_submodule_invalid(): def test_load_likelihood_submodule_no_build_likelihood(): - """The likelihood script should be able to load other modules from its - directory using relative import.""" - dir_path = os.path.dirname(os.path.realpath(__file__)) with pytest.raises( @@ -43,9 +34,6 @@ def test_load_likelihood_submodule_no_build_likelihood(): def test_load_likelihood_submodule_not_a_function(): - """The likelihood script should be able to load other modules from its - directory using relative import.""" - dir_path = os.path.dirname(os.path.realpath(__file__)) with pytest.raises( @@ -58,9 +46,6 @@ def test_load_likelihood_submodule_not_a_function(): def test_load_likelihood_submodule_returns_wrong_type(): - """The likelihood script should be able to load other modules from its - directory using relative import.""" - dir_path = os.path.dirname(os.path.realpath(__file__)) with pytest.raises( @@ -71,3 +56,35 @@ def test_load_likelihood_submodule_returns_wrong_type(): os.path.join(dir_path, "lkdir/lkscript_returns_wrong_type.py"), NamedParameters(), ) + + +def test_load_likelihood_submodule_returns_wrong_type_tools(): + dir_path = os.path.dirname(os.path.realpath(__file__)) + + with pytest.raises( + TypeError, + match="The returned tools must be a Firecrown's `ModelingTools` type", + ): + load_likelihood( + os.path.join(dir_path, "lkdir/lkscript_returns_wrong_type_tools.py"), + NamedParameters(), + ) + + +def test_load_likelihood_submodule_old(): + dir_path = os.path.dirname(os.path.realpath(__file__)) + + load_likelihood( + os.path.join(dir_path, "lkdir/lkscript_old.py"), + NamedParameters(), + ) + + +def test_load_likelihood_correct_tools(): + dir_path = os.path.dirname(os.path.realpath(__file__)) + + _, tools = load_likelihood( + os.path.join(dir_path, "lkdir/lkscript.py"), NamedParameters() + ) + + assert tools.test_attribute == "test" # type: ignore diff --git a/tests/likelihood/test_named_parameters.py b/tests/likelihood/test_named_parameters.py new file mode 100644 index 00000000..bdd3f31c --- /dev/null +++ b/tests/likelihood/test_named_parameters.py @@ -0,0 +1,157 @@ +"""Tests for the class firecrown.likelihood.NamedParameters.""" + + +import pytest +import numpy as np + +from firecrown.likelihood.likelihood import NamedParameters + + +def test_named_parameters_sanity(): + params = NamedParameters({"a": True}) + assert params.get_bool("a") is True + + params = NamedParameters({"a": "Im a string"}) + assert params.get_string("a") == "Im a string" + + params = NamedParameters({"a": 1}) + assert params.get_int("a") == 1 + + params = NamedParameters({"a": 1.0}) + assert params.get_float("a") == 1.0 + + params = NamedParameters({"a": np.array([1, 2, 3])}) + assert params.get_int_array("a").tolist() == [1, 2, 3] + + params = NamedParameters({"a": np.array([1.0, 2.0, 3.0])}) + assert params.get_float_array("a").tolist() == [1.0, 2.0, 3.0] + + params = NamedParameters({"a": False}) + assert params.to_set() == {"a"} + + +def test_named_parameters_default(): + params = NamedParameters({}) + + assert params.get_bool("a", True) is True + assert params.get_string("b", "Im a string") == "Im a string" + assert params.get_int("c", 1) == 1 + assert params.get_float("d", 1.0) == 1.0 + + params = NamedParameters({"a": False, "b": "Im another string", "c": 2, "d": 2.2}) + + assert params.get_bool("a", True) is False + assert params.get_string("b", "Im a string") == "Im another string" + assert params.get_int("c", 1) == 2 + assert params.get_float("d", 1.0) == 2.2 + + +def test_named_parameters_wrong_type_bool(): + params = NamedParameters({"a": True}) + with pytest.raises(AssertionError): + params.get_string("a") + + # Bools are ints in python + # with pytest.raises(AssertionError): + # params.get_int("a") + + with pytest.raises(AssertionError): + params.get_float("a") + + with pytest.raises(AssertionError): + params.get_int_array("a") + + with pytest.raises(AssertionError): + params.get_float_array("a") + + +def test_named_parameters_wrong_type_string(): + params = NamedParameters({"a": "Im a string"}) + with pytest.raises(AssertionError): + params.get_bool("a") + + with pytest.raises(AssertionError): + params.get_int("a") + + with pytest.raises(AssertionError): + params.get_float("a") + + with pytest.raises(AssertionError): + params.get_int_array("a") + + with pytest.raises(AssertionError): + params.get_float_array("a") + + +def test_named_parameters_wrong_type_int(): + params = NamedParameters({"a": 1}) + with pytest.raises(AssertionError): + params.get_bool("a") + + with pytest.raises(AssertionError): + params.get_string("a") + + with pytest.raises(AssertionError): + params.get_float("a") + + with pytest.raises(AssertionError): + params.get_int_array("a") + + with pytest.raises(AssertionError): + params.get_float_array("a") + + +def test_named_parameters_wrong_type_float(): + params = NamedParameters({"a": 1.0}) + with pytest.raises(AssertionError): + params.get_bool("a") + + with pytest.raises(AssertionError): + params.get_string("a") + + with pytest.raises(AssertionError): + params.get_int("a") + + with pytest.raises(AssertionError): + params.get_int_array("a") + + with pytest.raises(AssertionError): + params.get_float_array("a") + + +def test_named_parameters_wrong_type_int_array(): + params = NamedParameters({"a": np.array([1, 2, 3])}) + with pytest.raises(AssertionError): + params.get_bool("a") + + with pytest.raises(AssertionError): + params.get_string("a") + + with pytest.raises(AssertionError): + params.get_int("a") + + with pytest.raises(AssertionError): + params.get_float("a") + + # Int arrays are float arrays in python + # with pytest.raises(AssertionError): + # params.get_float_array("a") + + +def test_named_parameters_wrong_type_float_array(): + params = NamedParameters({"a": np.array([1.0, 2.0, 3.0])}) + with pytest.raises(AssertionError): + params.get_bool("a") + + with pytest.raises(AssertionError): + params.get_string("a") + + with pytest.raises(AssertionError): + params.get_int("a") + + with pytest.raises(AssertionError): + params.get_float("a") + + # Float arrays are int arrays in python + # with pytest.raises(AssertionError): + # params.get_int_array("a")