diff --git a/firecrown/models/cluster_mass_true.py b/firecrown/models/cluster_mass_true.py new file mode 100644 index 00000000..9a7d5fd6 --- /dev/null +++ b/firecrown/models/cluster_mass_true.py @@ -0,0 +1,90 @@ +"""Cluster Mass True Module + +Class to compute cluster mass functions with no proxy, +i.e., assuming we have the true masses of the clusters. + +""" + +from typing import final, List, Tuple + +import numpy as np +from .. import sacc_support +from ..sacc_support import sacc + + +from ..parameters import ( + ParamsMap, + RequiredParameters, + DerivedParameterCollection, +) +from .cluster_mass import ClusterMass, ClusterMassArgument + + +class ClusterMassTrue(ClusterMass): + """Cluster Mass class.""" + + @final + def _update_cluster_mass(self, params: ParamsMap): + """Method to update the ClusterMassTrue from the given ParamsMap.""" + + @final + def _reset_cluster_mass(self): + """Method to reset the ClusterMassTrue.""" + + @final + def _required_parameters(self) -> RequiredParameters: + return RequiredParameters([]) + + @final + def _get_derived_parameters(self) -> DerivedParameterCollection: + return DerivedParameterCollection([]) + + def read(self, sacc_data: sacc.Sacc): + """Method to read the data for this source from the SACC file.""" + + def gen_bins_by_array(self, logM_bins: np.ndarray) -> List[ClusterMassArgument]: + """Generate the bins by an array of bin edges.""" + + if len(logM_bins) < 2: + raise ValueError("logM_bins must have at least two elements") + + # itertools.pairwise is only available in Python 3.10 + # using zip instead + return [ + ClusterMassTrueArgument(lower, upper) + for lower, upper in zip(logM_bins[:-1], logM_bins[1:]) + ] + + def point_arg(self, logM: float) -> ClusterMassArgument: + """Return the argument for the given mass.""" + + return ClusterMassTrueArgument(logM, logM) + + def gen_bin_from_tracer(self, tracer: sacc.BaseTracer) -> ClusterMassArgument: + """Return the argument for the given tracer.""" + + if not isinstance(tracer, sacc_support.BinLogMTracer): + raise ValueError("Tracer must be a BinLogMTracer") + + return ClusterMassTrueArgument(tracer.lower, tracer.upper) + + +class ClusterMassTrueArgument(ClusterMassArgument): + """Cluster mass true argument class.""" + + @property + def dim(self) -> int: + """Return the dimension of the argument.""" + return 0 + + def get_logM_bounds(self) -> Tuple[float, float]: + """Return the bounds of the cluster mass argument.""" + return (self.logMl, self.logMu) + + def get_proxy_bounds(self) -> List[Tuple[float, float]]: + """Return the bounds of the cluster mass proxy argument.""" + return [] + + def p(self, logM: float, z: float, *args) -> float: + """Return the probability of the argument.""" + return 1.0 diff --git a/firecrown/sacc_support.py b/firecrown/sacc_support.py index 25c8c4ca..cca2deab 100644 --- a/firecrown/sacc_support.py +++ b/firecrown/sacc_support.py @@ -91,6 +91,82 @@ def from_tables(cls, table_list): return tracers +class BinLogMTracer(BaseTracer, tracer_type="bin_logM"): # type: ignore + """A tracer for a single log-mass bin.""" + + def __init__(self, name: str, lower: float, upper: float, **kwargs): + """ + Create a tracer corresponding to a single log-mass bin. + + :param name: The name of the tracer + :param lower: The lower bound of the log-mass bin + :param upper: The upper bound of the log-mass bin + """ + super().__init__(name, **kwargs) + self.lower = lower + self.upper = upper + + def __eq__(self, other) -> bool: + """Test for equality. If :python:`other` is not a + :python:`BinLogMTracer`, then it is not equal to :python:`self`. + Otherwise, they are equal if names, and the z-range of the bins, + are equal.""" + if not isinstance(other, BinLogMTracer): + return False + return ( + self.name == other.name + and self.lower == other.lower + and self.upper == other.upper + ) + + @classmethod + def to_tables(cls, instance_list): + """Convert a list of BinLogMTracers to a single astropy table + + This is used when saving data to a file. + One table is generated with the information for all the tracers. + + :param instance_list: List of tracer instances + :return: List with a single astropy table + """ + + names = ["name", "quantity", "lower", "upper"] + + cols = [ + [obj.name for obj in instance_list], + [obj.quantity for obj in instance_list], + [obj.lower for obj in instance_list], + [obj.upper for obj in instance_list], + ] + + table = Table(data=cols, names=names) + table.meta["SACCTYPE"] = "tracer" + table.meta["SACCCLSS"] = cls.tracer_type + table.meta["EXTNAME"] = f"tracer:{cls.tracer_type}" + return [table] + + @classmethod + def from_tables(cls, table_list): + """Convert an astropy table into a dictionary of tracers + + This is used when loading data from a file. + One tracer object is created for each "row" in each table. + + :param table_list: List of astropy tables + :return: Dictionary of tracers + """ + tracers = {} + + for table in table_list: + for row in table: + name = row["name"] + quantity = row["quantity"] + lower = row["lower"] + upper = row["upper"] + tracers[name] = cls(name, quantity=quantity, lower=lower, upper=upper) + return tracers + + class BinRichnessTracer(BaseTracer, tracer_type="bin_richness"): # type: ignore """A tracer for a single richness bin.""" @@ -112,8 +188,8 @@ def __init__(self, name: str, lower: float, upper: float, **kwargs): Create a tracer corresponding to a single richness bin. :param name: The name of the tracer - :param lower: The lower bound of the redshift bin - :param upper: The upper bound of the redshift bin + :param lower: The lower bound of the richness bin + :param upper: The upper bound of the richness bin """ super().__init__(name, **kwargs) self.lower = lower diff --git a/pylintrc b/pylintrc index bb0e974c..22980150 100644 --- a/pylintrc +++ b/pylintrc @@ -42,6 +42,9 @@ const-naming-style=any # Naming stule matching correct method names. method-naming-style=any +# Naming style matching correct function names. +function-naming-style=any + [STRING] # This flag controls whether inconsistent-quotes generates a warning when the diff --git a/tests/pylintrc b/tests/pylintrc index 3f1edd72..fa87643d 100644 --- a/tests/pylintrc +++ b/tests/pylintrc @@ -39,6 +39,9 @@ variable-naming-style=any # Naming style matching correct method names. method-naming-style=any +# Naming style matching correct function names. +function-naming-style=any + # Do not require docstrings in test functions no-docstring-rgx=^(test|fixture) diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 1b377111..253ca9ed 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -1,6 +1,6 @@ """Tests for the cluster module.""" -from typing import Any, Dict, List, Tuple +from typing import Any, Dict import itertools import math @@ -10,6 +10,7 @@ from firecrown.models.cluster_mass import ClusterMass, ClusterMassArgument from firecrown.models.cluster_redshift import ClusterRedshift, ClusterRedshiftArgument +from firecrown.models.cluster_mass_true import ClusterMassTrue from firecrown.models.cluster_abundance import ClusterAbundance from firecrown.models.cluster_mass_rich_proxy import ClusterMassRich from firecrown.models.cluster_redshift_spec import ClusterRedshiftSpec @@ -25,16 +26,9 @@ def fixture_ccl_cosmo(): ) -@pytest.fixture(name="cluster_objects") -def fixture_cluster_objects(): - """Fixture for cluster objects.""" - - hmd_200 = ccl.halos.MassDef200c() - hmf_args: Dict[str, Any] = {} - hmf_name = "Bocquet16" - pivot_mass = 14.0 - pivot_redshift = 0.6 - sky_area = 489 +@pytest.fixture(name="parameters") +def fixture_parameters(): + """Fixture for a parameter map.""" parameters = ParamsMap( { @@ -46,43 +40,80 @@ def fixture_cluster_objects(): "sigma_p2": 0.0, } ) + return parameters - z_bins = np.array([0.2000146, 0.31251036, 0.42500611, 0.53750187, 0.64999763]) - proxy_bins = np.array([0.45805137, 0.81610273, 1.1741541, 1.53220547, 1.89025684]) +@pytest.fixture(name="z_args") +def fixture_cluster_z_args(parameters): + """Fixture for cluster redshifts.""" + z_bins = np.array([0.2000146, 0.31251036, 0.42500611, 0.53750187, 0.64999763]) cluster_z = ClusterRedshiftSpec() assert isinstance(cluster_z, ClusterRedshift) z_args = cluster_z.gen_bins_by_array(z_bins) + cluster_z.update(parameters) + + return z_args + + +@pytest.fixture(name="logM_args") +def fixture_cluster_mass_logM_args(parameters): + """Fixture for cluster masses.""" + logM_bins = np.array([13.0, 13.5, 14.0, 14.5, 15.0]) + cluster_mass_t = ClusterMassTrue() + assert isinstance(cluster_mass_t, ClusterMass) + + logM_args = cluster_mass_t.gen_bins_by_array(logM_bins) + cluster_mass_t.update(parameters) + + return logM_args + + +@pytest.fixture(name="rich_args") +def fixture_cluster_mass_rich_args(parameters): + """Fixture for cluster masses.""" + pivot_mass = 14.0 + pivot_redshift = 0.6 + proxy_bins = np.array([0.45805137, 0.81610273, 1.1741541, 1.53220547, 1.89025684]) + cluster_mass_r = ClusterMassRich(pivot_mass, pivot_redshift) assert isinstance(cluster_mass_r, ClusterMass) rich_args = cluster_mass_r.gen_bins_by_array(proxy_bins) + cluster_mass_r.update(parameters) + + return rich_args + + +@pytest.fixture(name="cluster_abundance") +def fixture_cluster_abundance(parameters): + """Fixture for cluster objects.""" + + hmd_200 = ccl.halos.MassDef200c() + hmf_args: Dict[str, Any] = {} + hmf_name = "Bocquet16" + sky_area = 489 cluster_abundance = ClusterAbundance(hmd_200, hmf_name, hmf_args, sky_area) assert isinstance(cluster_abundance, ClusterAbundance) cluster_abundance.update(parameters) - cluster_z.update(parameters) - cluster_mass_r.update(parameters) - return cluster_abundance, z_args, rich_args + return cluster_abundance def test_initialize_objects( - ccl_cosmo: ccl.Cosmology, - cluster_objects: Tuple[ - ClusterAbundance, List[ClusterRedshiftArgument], List[ClusterMassArgument] - ], + ccl_cosmo: ccl.Cosmology, cluster_abundance, z_args, logM_args, rich_args ): """Test initialization of cluster objects.""" - cluster_abundance, z_args, rich_args = cluster_objects - for z_arg in z_args: assert isinstance(z_arg, ClusterRedshiftArgument) + for logM_arg in logM_args: + assert isinstance(logM_arg, ClusterMassArgument) + for rich_arg in rich_args: assert isinstance(rich_arg, ClusterMassArgument) @@ -91,16 +122,16 @@ def test_initialize_objects( def test_cluster_mass_function_compute( - ccl_cosmo: ccl.Cosmology, - cluster_objects: Tuple[ - ClusterAbundance, List[ClusterRedshiftArgument], List[ClusterMassArgument] - ], + ccl_cosmo: ccl.Cosmology, cluster_abundance, z_args, logM_args, rich_args ): """Test cluster mass function computations.""" - cluster_abundance, z_args, rich_args = cluster_objects - - for redshift_arg, mass_arg in itertools.product(z_args, rich_args): + for redshift_arg, logM_arg, rich_arg in itertools.product( + z_args, logM_args, rich_args + ): + assert math.isfinite( + cluster_abundance.compute(ccl_cosmo, rich_arg, redshift_arg) + ) assert math.isfinite( - cluster_abundance.compute(ccl_cosmo, mass_arg, redshift_arg) + cluster_abundance.compute(ccl_cosmo, logM_arg, redshift_arg) ) diff --git a/tests/test_sacc_support.py b/tests/test_sacc_support.py index 63d73fe5..8e7b334e 100644 --- a/tests/test_sacc_support.py +++ b/tests/test_sacc_support.py @@ -4,6 +4,7 @@ """ from firecrown.sacc_support import ( BinZTracer, + BinLogMTracer, BinRichnessTracer, BinRadiusTracer, ClusterSurveyTracer, @@ -44,6 +45,40 @@ def test_binztracer_tables(): assert d["wilma"] == b +def test_make_binlogmtracer(): + tracer = BinLogMTracer.make("bin_logm", name="fred", lower=13.0, upper=15.0) + assert isinstance(tracer, BinLogMTracer) + assert tracer.quantity == "generic" + assert tracer.name == "fred" + assert tracer.lower == 13.0 + assert tracer.upper == 15.0 + + +def test_binlogmtracer_equality(): + a = BinLogMTracer.make("bin_logm", name="fred", lower=13.0, upper=15.0) + b = BinLogMTracer.make("bin_logm", name="fred", lower=13.0, upper=15.0) + c = BinLogMTracer.make("bin_logm", name="wilma", lower=13.0, upper=15.0) + d = BinLogMTracer.make("bin_logm", name="fred", lower=14.0, upper=15.0) + e = BinLogMTracer.make("bin_logm", name="fred", lower=13.0, upper=15.1) + assert a == b + assert a != "fred" + assert a != c + assert a != d + assert a != e + + +def test_binlogmtracer_tables(): + a = BinLogMTracer.make("bin_logm", name="fred", lower=13.0, upper=15.0) + b = BinLogMTracer.make("bin_logm", name="wilma", lower=14.0, upper=15.5) + tables = BinLogMTracer.to_tables([a, b]) + assert len(tables) == 1 # all BinLogMTracers are written to a single table + + d = BinLogMTracer.from_tables(tables) + assert len(d) == 2 # this list of tables recovers both BinLogMTracers + assert d["fred"] == a + assert d["wilma"] == b + + def test_make_binrichness_tracer(): tracer = BinRichnessTracer.make( "bin_richness",