Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Renamed create to register_new_updatable_parameter. #334

Merged
merged 1 commit into from
Oct 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions firecrown/likelihood/gauss_family/statistic/source/number_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ def __init__(self, sacc_tracer: str):
"""
super().__init__(parameter_prefix=sacc_tracer)

self.alphaz = parameters.create()
self.alphag = parameters.create()
self.z_piv = parameters.create()
self.alphaz = parameters.register_new_updatable_parameter()
self.alphag = parameters.register_new_updatable_parameter()
self.z_piv = parameters.register_new_updatable_parameter()

def apply(
self, tools: ModelingTools, tracer_arg: NumberCountsArgs
Expand Down Expand Up @@ -143,8 +143,8 @@ def __init__(self, sacc_tracer: str):
"""
super().__init__(parameter_prefix=sacc_tracer)

self.b_2 = parameters.create()
self.b_s = parameters.create()
self.b_2 = parameters.register_new_updatable_parameter()
self.b_s = parameters.register_new_updatable_parameter()

def apply(
self, tools: ModelingTools, tracer_arg: NumberCountsArgs
Expand Down Expand Up @@ -186,11 +186,11 @@ def __init__(self, sacc_tracer: str):
"""
super().__init__(parameter_prefix=sacc_tracer)

self.r_lim = parameters.create()
self.sig_c = parameters.create()
self.eta = parameters.create()
self.z_c = parameters.create()
self.z_m = parameters.create()
self.r_lim = parameters.register_new_updatable_parameter()
self.sig_c = parameters.register_new_updatable_parameter()
self.eta = parameters.register_new_updatable_parameter()
self.z_c = parameters.register_new_updatable_parameter()
self.z_m = parameters.register_new_updatable_parameter()

def apply(
self, tools: ModelingTools, tracer_arg: NumberCountsArgs
Expand Down Expand Up @@ -245,7 +245,7 @@ def __init__(self, sacc_tracer: str):
"""
super().__init__(parameter_prefix=sacc_tracer)

self.mag_bias = parameters.create()
self.mag_bias = parameters.register_new_updatable_parameter()

def apply(
self, tools: ModelingTools, tracer_arg: NumberCountsArgs
Expand Down Expand Up @@ -291,7 +291,7 @@ def __init__(
self.has_rsd = has_rsd
self.derived_scale = derived_scale

self.bias = parameters.create()
self.bias = parameters.register_new_updatable_parameter()
self.systematics = UpdatableCollection(systematics)
self.scale = scale
self.current_tracer_args: Optional[NumberCountsArgs] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def __init__(self, sacc_tracer: str):
"""
super().__init__(parameter_prefix=sacc_tracer)

self.delta_z = parameters.create()
self.delta_z = parameters.register_new_updatable_parameter()

def apply(self, tools: ModelingTools, tracer_arg: _SourceGalaxyArgsT):
"""Apply a shift to the photo-z distribution of a source."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __init__(self, sacc_tracer: str) -> None:
"""
super().__init__(parameter_prefix=sacc_tracer)

self.mult_bias = parameters.create()
self.mult_bias = parameters.register_new_updatable_parameter()

def apply(
self, tools: ModelingTools, tracer_arg: WeakLensingArgs
Expand Down Expand Up @@ -124,10 +124,10 @@ def __init__(self, sacc_tracer: Optional[str] = None, alphag=1.0):
"""
super().__init__(parameter_prefix=sacc_tracer)

self.ia_bias = parameters.create()
self.alphaz = parameters.create()
self.alphag = parameters.create(alphag)
self.z_piv = parameters.create()
self.ia_bias = parameters.register_new_updatable_parameter()
self.alphaz = parameters.register_new_updatable_parameter()
self.alphag = parameters.register_new_updatable_parameter(alphag)
self.z_piv = parameters.register_new_updatable_parameter()

def apply(
self, tools: ModelingTools, tracer_arg: WeakLensingArgs
Expand Down Expand Up @@ -172,9 +172,9 @@ def __init__(self, sacc_tracer: Optional[str] = None):
as a prefix for its parameters.
"""
super().__init__(parameter_prefix=sacc_tracer)
self.ia_a_1 = parameters.create()
self.ia_a_2 = parameters.create()
self.ia_a_d = parameters.create()
self.ia_a_1 = parameters.register_new_updatable_parameter()
self.ia_a_2 = parameters.register_new_updatable_parameter()
self.ia_a_d = parameters.register_new_updatable_parameter()

def apply(
self, tools: ModelingTools, tracer_arg: WeakLensingArgs
Expand Down
2 changes: 1 addition & 1 deletion firecrown/likelihood/gauss_family/statistic/statistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def __init__(self) -> None:
# Data and theory will both be of length self.count
self.count = 3
self.data_vector: Optional[DataVector] = None
self.mean = firecrown.parameters.create()
self.mean = firecrown.parameters.register_new_updatable_parameter()
self.computed_theory_vector = False

def read(self, sacc_data: sacc.Sacc):
Expand Down
2 changes: 1 addition & 1 deletion firecrown/likelihood/gauss_family/statistic/supernova.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, sacc_tracer) -> None:
self.sacc_tracer = sacc_tracer
self.data_vector: Optional[DataVector] = None
self.a: Optional[npt.NDArray[np.float64]] = None
self.M = parameters.create()
self.M = parameters.register_new_updatable_parameter()

def read(self, sacc_data: sacc.Sacc):
"""Read the data for this statistic from the SACC file."""
Expand Down
2 changes: 1 addition & 1 deletion firecrown/likelihood/gauss_family/student_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(
nu: Optional[float],
):
super().__init__(statistics)
self.nu = parameters.create(nu)
self.nu = parameters.register_new_updatable_parameter(nu)

def compute_loglike(self, tools: ModelingTools):
"""Compute the log-likelihood.
Expand Down
12 changes: 6 additions & 6 deletions firecrown/models/cluster_mass_rich_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ def __init__(
self.logMu = logMu

# Updatable parameters
self.mu_p0 = parameters.create()
self.mu_p1 = parameters.create()
self.mu_p2 = parameters.create()
self.sigma_p0 = parameters.create()
self.sigma_p1 = parameters.create()
self.sigma_p2 = parameters.create()
self.mu_p0 = parameters.register_new_updatable_parameter()
self.mu_p1 = parameters.register_new_updatable_parameter()
self.mu_p2 = parameters.register_new_updatable_parameter()
self.sigma_p0 = parameters.register_new_updatable_parameter()
self.sigma_p1 = parameters.register_new_updatable_parameter()
self.sigma_p2 = parameters.register_new_updatable_parameter()

self.logM_obs_min = 0.0
self.logM_obs_max = np.inf
Expand Down
15 changes: 15 additions & 0 deletions firecrown/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from __future__ import annotations
from typing import Iterable, List, Dict, Set, Tuple, Optional, Iterator, Sequence
import warnings
from abc import ABC, abstractmethod


Expand Down Expand Up @@ -304,6 +305,20 @@ def get_value(self) -> float:
def create(value: Optional[float] = None):
"""Create a new parameter, either a SamplerParameter or an InternalParameter.

See register_new_updatable_parameter for details."""
warnings.simplefilter("always", DeprecationWarning)
warnings.warn(
"This function is named `create` and will be removed in a future version "
"due to its name being too generic."
"Use `register_new_updatable_parameter` instead.",
category=DeprecationWarning,
)
return register_new_updatable_parameter(value)


def register_new_updatable_parameter(value: Optional[float] = None):
"""Create a new parameter, either a SamplerParameter or an InternalParameter.

If `value` is `None`, the result will be a `SamplerParameter`; Firecrown
will expect this value to be supplied by the sampling framwork. If `value`
is a `float` quantity, then Firecrown will expect this parameter to *not*
Expand Down
2 changes: 1 addition & 1 deletion tests/likelihood/lkdir/lkmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(self, params: NamedParameters):
parameter_prefix value and creates a sampler parameter called "sampler_param0".
"""
super().__init__(parameter_prefix=params.get_string("parameter_prefix"))
self.sampler_param0 = parameters.create()
self.sampler_param0 = parameters.register_new_updatable_parameter()

def read(self, sacc_data: sacc.Sacc) -> None:
"""This class has nothing to read."""
Expand Down
28 changes: 23 additions & 5 deletions tests/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from firecrown.parameters import (
DerivedParameterScalar,
DerivedParameterCollection,
register_new_updatable_parameter,
create,
InternalParameter,
SamplerParameter,
Expand All @@ -16,23 +17,40 @@
def test_create_with_no_arg():
"""Calling parameters.create() with no argument should return an
SamplerParameter"""
a_parameter = create()
assert isinstance(a_parameter, SamplerParameter)
with pytest.deprecated_call():
a_parameter = create()
assert isinstance(a_parameter, SamplerParameter)


def test_create_with_float_arg():
"""Calling parameters.create() with a float argument should return a
InternalParameter ."""
a_parameter = create(1.5)
with pytest.deprecated_call():
a_parameter = create(1.5)
assert isinstance(a_parameter, InternalParameter)
assert a_parameter.value == 1.5


def test_register_new_updatable_parameter_with_no_arg():
"""Calling parameters.create() with no argument should return an
SamplerParameter"""
a_parameter = register_new_updatable_parameter()
assert isinstance(a_parameter, SamplerParameter)


def test_register_new_updatable_parameter_with_float_arg():
"""Calling parameters.create() with a float argument should return a
InternalParameter ."""
a_parameter = register_new_updatable_parameter(1.5)
assert isinstance(a_parameter, InternalParameter)
assert a_parameter.value == 1.5


def test_create_with_wrong_arg():
def test_register_new_updatable_parameter_with_wrong_arg():
"""Calling parameters.create() with an org that is neither float nor None should
raise a TypeError."""
with pytest.raises(TypeError):
_ = create("cow") # type: ignore
_ = register_new_updatable_parameter("cow") # type: ignore


def test_required_parameters_length():
Expand Down
44 changes: 29 additions & 15 deletions tests/test_updatable.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self):
"""Initialize object with defaulted value."""
super().__init__()

self.a = parameters.create()
self.a = parameters.register_new_updatable_parameter()


class SimpleUpdatable(Updatable):
Expand All @@ -32,8 +32,8 @@ def __init__(self):
"""Initialize object with defaulted values."""
super().__init__()

self.x = parameters.create()
self.y = parameters.create()
self.x = parameters.register_new_updatable_parameter()
self.y = parameters.register_new_updatable_parameter()


class UpdatableWithDerived(Updatable):
Expand All @@ -43,8 +43,8 @@ def __init__(self):
"""Initialize object with defaulted values."""
super().__init__()

self.A = parameters.create()
self.B = parameters.create()
self.A = parameters.register_new_updatable_parameter()
self.B = parameters.register_new_updatable_parameter()

def _get_derived_parameters(self) -> DerivedParameterCollection:
derived_scale = DerivedParameterScalar("Section", "Name", self.A + self.B)
Expand Down Expand Up @@ -132,7 +132,9 @@ def test_updatable_collection_insertion():

def test_set_sampler_parameter():
my_updatable = MinimalUpdatable()
my_updatable.set_sampler_parameter("the_meaning_of_life", parameters.create())
my_updatable.set_sampler_parameter(
"the_meaning_of_life", parameters.register_new_updatable_parameter()
)

assert hasattr(my_updatable, "the_meaning_of_life")
assert my_updatable.the_meaning_of_life is None
Expand All @@ -143,21 +145,27 @@ def test_set_sampler_parameter_rejects_internal_parameter():

with pytest.raises(TypeError):
my_updatable.set_sampler_parameter(
"the_meaning_of_life", parameters.create(42.0)
"the_meaning_of_life", parameters.register_new_updatable_parameter(42.0)
)


def test_set_sampler_parameter_rejects_duplicates():
my_updatable = MinimalUpdatable()
my_updatable.set_sampler_parameter("the_meaning_of_life", parameters.create())
my_updatable.set_sampler_parameter(
"the_meaning_of_life", parameters.register_new_updatable_parameter()
)

with pytest.raises(ValueError):
my_updatable.set_sampler_parameter("the_meaning_of_life", parameters.create())
my_updatable.set_sampler_parameter(
"the_meaning_of_life", parameters.register_new_updatable_parameter()
)


def test_set_internal_parameter():
my_updatable = MinimalUpdatable()
my_updatable.set_internal_parameter("the_meaning_of_life", parameters.create(42.0))
my_updatable.set_internal_parameter(
"the_meaning_of_life", parameters.register_new_updatable_parameter(42.0)
)

assert hasattr(my_updatable, "the_meaning_of_life")
assert my_updatable.the_meaning_of_life == 42.0
Expand All @@ -166,27 +174,33 @@ def test_set_internal_parameter():
def test_set_internal_parameter_rejects_sampler_parameter():
my_updatable = MinimalUpdatable()
with pytest.raises(TypeError):
my_updatable.set_internal_parameter("sampler_param", parameters.create())
my_updatable.set_internal_parameter(
"sampler_param", parameters.register_new_updatable_parameter()
)


def test_set_internal_parameter_rejects_duplicates():
my_updatable = MinimalUpdatable()
my_updatable.set_internal_parameter("the_meaning_of_life", parameters.create(42.0))
my_updatable.set_internal_parameter(
"the_meaning_of_life", parameters.register_new_updatable_parameter(42.0)
)

with pytest.raises(ValueError):
my_updatable.set_internal_parameter(
"the_meaning_of_life", parameters.create(42.0)
"the_meaning_of_life", parameters.register_new_updatable_parameter(42.0)
)

with pytest.raises(ValueError):
my_updatable.set_internal_parameter(
"the_meaning_of_life", parameters.create(41.0)
"the_meaning_of_life", parameters.register_new_updatable_parameter(41.0)
)


def test_update_rejects_internal_parameters():
my_updatable = MinimalUpdatable()
my_updatable.set_internal_parameter("the_meaning_of_life", parameters.create(42.0))
my_updatable.set_internal_parameter(
"the_meaning_of_life", parameters.register_new_updatable_parameter(42.0)
)
assert hasattr(my_updatable, "the_meaning_of_life")

params = ParamsMap({"a": 1.1, "the_meaning_of_life": 34.0})
Expand Down
Loading