Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More testing #310

Merged
merged 34 commits into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
0cfcad5
Add test for linear bias systemaic
marcpaterno May 30, 2023
06532d1
Initialize instance variable in __init__
marcpaterno Jul 17, 2023
1a314a4
Add is_updated to generic updatable interface
marcpaterno Jul 17, 2023
449a5ae
Fix line breaking
marcpaterno Jul 17, 2023
7808e79
Enhance test of number_counts.LinearBiasSystematic
marcpaterno Jul 17, 2023
4860faa
Fix the implementation of _reset
marcpaterno Jul 17, 2023
ab0ffce
Fix and improve comments
marcpaterno Aug 11, 2023
322cfff
Add more tests
marcpaterno Aug 20, 2023
6e3275f
Raise ValueError when survey tracer is missing
marcpaterno Aug 20, 2023
142567a
Apply black
marcpaterno Aug 20, 2023
373adcf
Fix mypy deprecation warning
marcpaterno Aug 21, 2023
3fada95
Remove unused imports
marcpaterno Aug 21, 2023
52ca0f3
Minimal tweaks for test to pass mypy check
marcpaterno Aug 21, 2023
c642786
Automated reset calls in all updatables.
vitenti Aug 25, 2023
989ac71
Automated required_parameters protocol.
vitenti Aug 25, 2023
74a1a5b
Automated derived_parameters protocol.
vitenti Aug 25, 2023
7cf59b8
Added tests for updatable nesting.
vitenti Aug 25, 2023
1ede0d4
Testing multiple get_derived.
vitenti Aug 25, 2023
8d6f69b
Add test, fix code to raise expected exception
marcpaterno Aug 11, 2023
64abcb1
Improve comment
marcpaterno Aug 11, 2023
e2355cf
Improve minimal_stat fixture
marcpaterno Aug 25, 2023
256d461
Remove duplicate test, in the wrong directory
marcpaterno Aug 25, 2023
73e0c71
Fix package name usage
marcpaterno Aug 25, 2023
e1419bd
Minimal tests for Cobaya
vitenti Aug 26, 2023
7da2248
More tests for Cobaya
vitenti Aug 27, 2023
0d00402
Tests for Cobaya derived parameters.
vitenti Aug 27, 2023
257a9d6
Tests for Cobaya derived parameters.
vitenti Aug 27, 2023
b2326bc
Fixed pylint issue on tests.
vitenti Aug 27, 2023
b654c8f
Removed redundant methods in cobaya connector.
vitenti Aug 27, 2023
ecd3d23
Added tests for DerivedParameter* eq.
vitenti Aug 27, 2023
f34d87e
Removed old interface which were replaced by Updatable.required_param…
vitenti Aug 27, 2023
c8c84ee
More testing for load_likelihood.
vitenti Aug 27, 2023
a75c7d0
Filtering warnings from external packages that we cannot handle.
vitenti Aug 27, 2023
5c57150
Added tests for NamedParameters and load_likelihood.
vitenti Aug 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions firecrown/connector/cobaya/ccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,6 @@ def initialize(self):
self.z_Pk = np.arange(0.0, 6.0, 1)
self.Pk_kmax = 1.0

def get_param(self, p: str) -> None:
"""Return the current value of the parameter named 'p'.

This implementation always returns None.
"""
return None

def initialize_with_params(self):
"""Required by Cobaya.

Expand Down
13 changes: 8 additions & 5 deletions firecrown/connector/cobaya/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,19 @@ def initialize(self):
if not hasattr(self, "build_parameters"):
build_parameters = NamedParameters()
else:
build_parameters = self.build_parameters
if isinstance(self.build_parameters, dict):
build_parameters = NamedParameters(self.build_parameters)
else:
if not isinstance(self.build_parameters, NamedParameters):
raise TypeError(
"build_parameters must be a NamedParameters or dict"
)
build_parameters = self.build_parameters

self.likelihood, self.tools = load_likelihood(
self.firecrownIni, build_parameters
)

def get_param(self, p: str):
"""Return the current value of the parameter named 'p'."""
return self._current_state["derived"][p]

def initialize_with_params(self) -> None:
"""Required by Cobaya.

Expand Down
47 changes: 0 additions & 47 deletions firecrown/likelihood/gauss_family/gauss_family.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from __future__ import annotations
from typing import List, Optional, Tuple, Sequence
from typing import final
from abc import abstractmethod
import warnings

import numpy as np
Expand All @@ -23,7 +22,6 @@
from ...modeling_tools import ModelingTools
from ...updatable import UpdatableCollection
from .statistic.statistic import Statistic
from ...parameters import RequiredParameters, DerivedParameterCollection


class GaussFamily(Likelihood):
Expand Down Expand Up @@ -128,48 +126,3 @@ def compute_chisq(self, tools: ModelingTools) -> float:
chisq = np.dot(x, x)

return chisq

@final
def _reset(self) -> None:
"""Implementation of Likelihood interface method _reset.

This resets all statistics and calls the abstract method
_reset_gaussian_family."""
self._reset_gaussian_family()
self.statistics.reset()

@final
def _get_derived_parameters(self) -> DerivedParameterCollection:
derived_parameters = (
self._get_derived_parameters_gaussian_family()
+ self.statistics.get_derived_parameters()
)

return derived_parameters

@abstractmethod
def _reset_gaussian_family(self) -> None:
"""Abstract method to reset GaussianFamily state. Must be implemented by all
subclasses."""

@final
def _required_parameters(self) -> RequiredParameters:
"""Return a RequiredParameters object containing the information for
this Updatable.

This includes the required parameters for all statistics, as well as those
for the derived class.

Derived classes must implement required_parameters_gaussian_family."""
stats_rp = self.statistics.required_parameters()
stats_rp = self._required_parameters_gaussian_family() + stats_rp

return stats_rp

@abstractmethod
def _required_parameters_gaussian_family(self):
"""Required parameters for GaussFamily subclasses."""

@abstractmethod
def _get_derived_parameters_gaussian_family(self) -> DerivedParameterCollection:
"""Get derived parameters for GaussFamily subclasses."""
14 changes: 0 additions & 14 deletions firecrown/likelihood/gauss_family/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@
"""

from __future__ import annotations
from typing import final

from .gauss_family import GaussFamily
from ...parameters import RequiredParameters, DerivedParameterCollection
from ...modeling_tools import ModelingTools


Expand All @@ -17,15 +15,3 @@ def compute_loglike(self, tools: ModelingTools):
"""Compute the log-likelihood."""

return -0.5 * self.compute_chisq(tools)

@final
def _reset_gaussian_family(self):
pass

@final
def _required_parameters_gaussian_family(self):
return RequiredParameters([])

@final
def _get_derived_parameters_gaussian_family(self) -> DerivedParameterCollection:
return DerivedParameterCollection([])
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""Cluster Number Count statistic support.
This module reads the necessary data from a SACC file to compute the
theoretical prediction of cluster number counts inside bins of redshift
and a mass proxy. For further information, check README.md.
and a mass proxy.
"""

from __future__ import annotations
from typing import List, Dict, Tuple, Optional, final
from typing import List, Dict, Tuple, Optional

import numpy as np

Expand All @@ -14,10 +14,6 @@

from .statistic import Statistic, DataVector, TheoryVector
from .source.source import SourceSystematic
from ....parameters import (
RequiredParameters,
DerivedParameterCollection,
)
from ....models.cluster_abundance import ClusterAbundance
from ....models.cluster_mass import ClusterMass, ClusterMassArgument
from ....models.cluster_redshift import ClusterRedshift, ClusterRedshiftArgument
Expand All @@ -29,14 +25,12 @@ class ClusterNumberCounts(Statistic):
multiplicity functions, volume element, etc.).
This subclass implements the read and computes method for
the Statistic class. It is used to compute the theoretical prediction of
cluster number counts given a SACC file and a cosmology. For
further information on how the SACC file shall be created,
check README.md.
cluster number counts.
"""

def __init__(
self,
sacc_tracer: str,
survey_tracer: str,
cluster_abundance: ClusterAbundance,
cluster_mass: ClusterMass,
cluster_redshift: ClusterRedshift,
Expand All @@ -47,18 +41,13 @@ def __init__(
"""Initialize the ClusterNumberCounts object.
Parameters

:param sacc_tracer: The SACC tracer. There must be only one tracer for all
the number Counts data points. Following the SACC file
documentation README.md, this string should be
'cluster_counts_true_mass'.
:param survey_tracer: name of the survey tracer in the SACC data.
:param cluster_abundance: The cluster abundance model to use.
:param systematics: A list of the statistics-level systematics to apply to
the statistic. The default of `None` implies no systematics.

"""
super().__init__()

self.sacc_tracer = sacc_tracer
self.sacc_tracer = survey_tracer
self.systematics = systematics or []
self.data_vector: Optional[DataVector] = None
self.theory_vector: Optional[TheoryVector] = None
Expand All @@ -74,32 +63,6 @@ def __init__(
"At least one of use_cluster_counts and use_mean_log_mass must be True."
)

@final
def _reset(self) -> None:
"""Reset all contained Updatable objects."""
self.cluster_abundance.reset()
self.cluster_mass.reset()
self.cluster_redshift.reset()

@final
def _required_parameters(self) -> RequiredParameters:
"""Return an empty RequiredParameters."""
return (
self.cluster_abundance.required_parameters()
+ self.cluster_mass.required_parameters()
+ self.cluster_redshift.required_parameters()
)

@final
def _get_derived_parameters(self) -> DerivedParameterCollection:
"""Return an empty DerivedParameterCollection."""
derived_parameters = DerivedParameterCollection([])
derived_parameters += self.cluster_abundance.get_derived_parameters()
derived_parameters += self.cluster_mass.get_derived_parameters()
derived_parameters += self.cluster_redshift.get_derived_parameters()

return derived_parameters

def _read_data_type(self, sacc_data, data_type):
"""Internal function to read the data from the SACC file."""
tracers_combinations = np.array(
Expand Down Expand Up @@ -163,21 +126,19 @@ def _read_data_type(self, sacc_data, data_type):

return data_vector_list, sacc_indices_list

def read(self, sacc_data):
def read(self, sacc_data: sacc.Sacc):
"""Read the data for this statistic from the SACC file.
This function takes the SACC file and extract the necessary
parameters needed to compute the number counts likelihood.
Check README.MD for a complete description of the method.

:param sacc_data: The data in the sacc format.
:param sacc_data: The data in the SACC format.
"""

survey_tracer: SurveyTracer = sacc_data.get_tracer(self.sacc_tracer)
if survey_tracer is None:
try:
survey_tracer: SurveyTracer = sacc_data.get_tracer(self.sacc_tracer)
except KeyError as exc:
raise ValueError(
f"The SACC file does not contain the SurveyTracer "
f"{self.sacc_tracer}."
)
) from exc
if not isinstance(survey_tracer, SurveyTracer):
raise ValueError(
f"The SACC tracer {self.sacc_tracer} is not a SurveyTracer."
Expand Down Expand Up @@ -220,10 +181,9 @@ def get_data_vector(self) -> DataVector:
def compute_theory_vector(self, tools: ModelingTools) -> TheoryVector:
"""Compute a Number Count statistic using the data from the
Read method, the cosmology object, and the Bocquet16 halo mass function.
Check README.MD for a complete description of the method.

:param tools: ModelingTools firecrown object
Firecrown object used to load the required cosmology.
used to load the required cosmology.

:return: Numpy Array of floats
An array with the theoretical prediction of the number of clusters
Expand Down
Loading
Loading