diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index acfc5b2f00..41fc46b6b5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,5 +1,9 @@ name: GalSim CI +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + on: push: branches: diff --git a/setup.py b/setup.py index 1bae8495c8..8ce1f0fcd8 100644 --- a/setup.py +++ b/setup.py @@ -244,7 +244,7 @@ def supports_gpu(compiler, cc_type): extra_cflags = copt[cc_type] extra_lflags = lopt[cc_type] return try_compile(cpp_code, compiler, extra_cflags, extra_lflags) - + # Check for the fftw3 library in some likely places def find_fftw_lib(output=False): import distutils.sysconfig @@ -434,29 +434,29 @@ def find_eigen_dir(output=False): if output: print("Downloaded %s. Unpacking tarball."%fname) with tarfile.open(fname) as tar: - + def is_within_directory(directory, target): - + abs_directory = os.path.abspath(directory) abs_target = os.path.abspath(target) - + prefix = os.path.commonprefix([abs_directory, abs_target]) - + return prefix == abs_directory - + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): # Avoid security vulnerability in tar.extractall function. # This bit of code was added by the Advanced Research Center at Trellix in PR #1188. # For more information about the security vulnerability, see # https://github.com/advisories/GHSA-gw9q-c7gh-j9vm - + for member in tar.getmembers(): member_path = os.path.join(path, member.name) if not is_within_directory(path, member_path): raise Exception("Attempted Path Traversal in Tar File") - - tar.extractall(path, members, numeric_owner=numeric_owner) - + + tar.extractall(path, members, numeric_owner=numeric_owner) + safe_extract(tar, dir) os.remove(fname) # This actually extracts into a subdirectory with a name eigen-eigen-5a0156e40feb/ diff --git a/tests/galsim_test_helpers.py b/tests/galsim_test_helpers.py index 128075c61e..d6c4fa0cf4 100644 --- a/tests/galsim_test_helpers.py +++ b/tests/galsim_test_helpers.py @@ -25,6 +25,31 @@ from numpy.testing import assert_raises from numpy.testing import assert_warns +__all__ = [ + "default_params", + "gsobject_compare", + "printval", + "convertToShear", + "check_basic_x", + "check_basic_k", + "assert_floatlike", + "assert_intlike", + "check_basic", + "do_shoot", + "do_kvalue", + "radial_integrate", + "drawNoise", + "check_pickle", + "check_all_diff", + "timer", + "CaptureLog", + "assert_raises", + "assert_warns", + "Profile", + "galsim_backend", + "is_jax_galsim", +] + # This file has some helper functions that are used by tests from multiple files to help # avoid code duplication. @@ -43,6 +68,18 @@ integration_relerr = 1.e-6, integration_abserr = 1.e-8) + +def galsim_backend(): + if "jax_galsim/__init__.py" in galsim.__file__: + return "jax_galsim" + else: + return "galsim" + + +def is_jax_galsim(): + return galsim_backend() == "jax_galsim" + + def gsobject_compare(obj1, obj2, conv=None, decimal=10): """Helper function to check that two GSObjects are equivalent """ @@ -133,20 +170,32 @@ def check_basic_x(prof, name, approx_maxsb=False, scale=None): np.testing.assert_allclose( image(i,j), prof._xValue(galsim.PositionD(x,y)), rtol=1.e-5, err_msg="%s profile sb image does not match _xValue at %d,%d"%(name,i,j)) - assert prof.withFlux.__doc__ == galsim.GSObject.withFlux.__doc__ - assert prof.__class__.withFlux.__doc__ == galsim.GSObject.withFlux.__doc__ + if is_jax_galsim(): + for line in galsim.GSObject.withFlux.__doc__.splitlines(): + if line.strip() and "LAX" not in line: + assert line.strip() in prof.withFlux.__doc__, ( + prof.withFlux.__doc__, galsim.GSObject.withFlux.__doc__, + ) + for line in galsim.GSObject.withFlux.__doc__.splitlines(): + if line.strip() and "LAX" not in line: + assert line.strip() in prof.__class__.withFlux.__doc__, ( + prof.__class__.withFlux.__doc__, galsim.GSObject.withFlux.__doc__, + ) + else: + assert prof.withFlux.__doc__ == galsim.GSObject.withFlux.__doc__ + assert prof.__class__.withFlux.__doc__ == galsim.GSObject.withFlux.__doc__ # Check negative flux: neg_image = prof.withFlux(-prof.flux).drawImage(method='sb', scale=scale, use_true_center=False) - np.testing.assert_almost_equal(neg_image.array/prof.flux, -image.array/prof.flux, 7, - '%s negative flux drawReal is not negative of +flux image'%name) + np.testing.assert_array_almost_equal(neg_image.array/prof.flux, -image.array/prof.flux, 7, + '%s negative flux drawReal is not negative of +flux image'%name) # Direct call to drawReal should also work and be equivalent to the above with scale = 1. prof.drawImage(image, method='sb', scale=1., use_true_center=False) image2 = image.copy() prof.drawReal(image2) - np.testing.assert_equal(image2.array, image.array, - err_msg="%s drawReal not equivalent to drawImage"%name) + np.testing.assert_array_equal(image2.array, image.array, + err_msg="%s drawReal not equivalent to drawImage"%name) # If supposed to be axisymmetric, make sure it is. if prof.is_axisymmetric: @@ -194,7 +243,7 @@ def check_basic_k(prof, name): # Check negative flux: neg_image = prof.withFlux(-prof.flux).drawKImage(kimage.copy()) - np.testing.assert_almost_equal(neg_image.array/prof.flux, -kimage.array/prof.flux, 7, + np.testing.assert_array_almost_equal(neg_image.array/prof.flux, -kimage.array/prof.flux, 7, '%s negative flux drawK is not negative of +flux image'%name) # If supposed to be axisymmetric, make sure it is in the kValues. @@ -206,6 +255,30 @@ def check_basic_k(prof, name): np.testing.assert_allclose(test_values, ref_value, rtol=1.e-5, err_msg="%s profile not axisymmetric in kValues"%name) +def assert_floatlike(val): + assert ( + isinstance(val, float) + or ( + is_jax_galsim() + and hasattr(val, "shape") + and val.shape == () + and hasattr(val, "dtype") + and val.dtype.name in ["float", "float32", "float64"] + ) + ), "Value is not float-like: type(%r) = %r" % (val, type(val)) + +def assert_intlike(val): + assert ( + isinstance(val, int) + or ( + is_jax_galsim() + and hasattr(val, "shape") + and val.shape == () + and hasattr(val, "dtype") + and val.dtype.name in ["int", "int32", "int64"] + ) + ), "Value is not int-like: type(%r) = %r" % (val, type(val)) + def check_basic(prof, name, approx_maxsb=False, scale=None, do_x=True, do_k=True): """Do some basic sanity checks that should work for all profiles. """ @@ -220,12 +293,12 @@ def check_basic(prof, name, approx_maxsb=False, scale=None, do_x=True, do_k=True prof.positive_flux - prof.negative_flux, prof.flux, err_msg="%s profile flux not equal to posflux + negflux"%name) assert isinstance(prof.centroid, galsim.PositionD) - assert isinstance(prof.flux, float) - assert isinstance(prof.positive_flux, float) - assert isinstance(prof.negative_flux, float) - assert isinstance(prof.max_sb, float) - assert isinstance(prof.stepk, float) - assert isinstance(prof.maxk, float) + assert_floatlike(prof.flux) + assert_floatlike(prof.positive_flux) + assert_floatlike(prof.negative_flux) + assert_floatlike(prof.max_sb) + assert_floatlike(prof.stepk) + assert_floatlike(prof.maxk) assert isinstance(prof.has_hard_edges, bool) assert isinstance(prof.is_axisymmetric, bool) assert isinstance(prof.is_analytic_x, bool) @@ -298,6 +371,9 @@ def do_shoot(prof, img, name): print('nphot = ',nphot) img2 = img.copy() + if is_jax_galsim(): + rtol *= 3 + # Use a deterministic random number generator so we don't fail tests because of rare flukes # in the random numbers. rng = galsim.UniformDeviate(12345) diff --git a/tests/test_catalog.py b/tests/test_catalog.py index 11623ec60f..331b495f88 100644 --- a/tests/test_catalog.py +++ b/tests/test_catalog.py @@ -16,6 +16,7 @@ # and/or other materials provided with the distribution. # +import os import numpy as np import galsim @@ -198,7 +199,7 @@ def test_basic_dict(): def test_single_row(): """Test that we can read catalogs with just one row (#394) """ - filename = "output/test394.txt" + filename = os.path.join(os.path.dirname(__file__), "output/test394.txt") with open(filename, 'w') as f: f.write("3 4 5\n") cat = galsim.Catalog(filename, file_type='ascii') diff --git a/tests/test_celestial.py b/tests/test_celestial_galsim.py similarity index 96% rename from tests/test_celestial.py rename to tests/test_celestial_galsim.py index 1a64f37283..3e89bf4eb9 100644 --- a/tests/test_celestial.py +++ b/tests/test_celestial_galsim.py @@ -17,6 +17,7 @@ # import numpy +import numpy as np import os import math @@ -122,11 +123,11 @@ def test_angle(): # Check invalid constructors assert_raises(TypeError,galsim.AngleUnit, galsim.degrees) - assert_raises(ValueError,galsim.AngleUnit, 'spam') + assert_raises((ValueError, TypeError), galsim.AngleUnit, 'spam') assert_raises(TypeError,galsim.AngleUnit, 1, 3) assert_raises(TypeError,galsim.Angle, 3.4) assert_raises(TypeError,galsim.Angle, theta1, galsim.degrees) - assert_raises(ValueError,galsim.Angle, 'spam', galsim.degrees) + assert_raises((ValueError, TypeError), galsim.Angle, 'spam') assert_raises(TypeError,galsim.Angle, 1, 3) @@ -155,7 +156,7 @@ def test_celestialcoord_basic(): x, y, z = c1.get_xyz() print('c1 is at x,y,z = ',x,y,z) - np.testing.assert_equal((x,y,z), (1,0,0)) + np.testing.assert_array_equal((x,y,z), (1,0,0)) assert c1 == galsim.CelestialCoord.from_xyz(x,y,z) x, y, z = c2.get_xyz() @@ -343,9 +344,19 @@ def test_projection(): # First the trivial case p0 = center.project(center, projection='lambert') - assert p0 == (0.0 * galsim.arcsec, 0.0 * galsim.arcsec) + np.testing.assert_allclose( + (p0[0].rad, p0[1].rad), + (0.0, 0.0), + rtol=0, + atol=1e-16, + ) c0 = center.deproject(*p0, projection='lambert') - assert c0 == center + np.testing.assert_allclose( + c0.rad, + center.rad, + rtol=0, + atol=1e-16, + ) np.testing.assert_almost_equal(center.jac_deproject(*p0, projection='lambert').ravel(), (1,0,0,1)) @@ -398,9 +409,19 @@ def test_projection(): # First the trivial case p0 = center.project(center, projection='stereographic') - assert p0 == (0.0 * galsim.arcsec, 0.0 * galsim.arcsec) + np.testing.assert_allclose( + (p0[0].rad, p0[1].rad), + (0.0, 0.0), + rtol=0, + atol=1e-16, + ) c0 = center.deproject(*p0, projection='stereographic') - assert c0 == center + np.testing.assert_allclose( + c0.rad, + center.rad, + rtol=0, + atol=1e-16, + ) np.testing.assert_almost_equal(center.jac_deproject(*p0, projection='stereographic').ravel(), (1,0,0,1)) @@ -456,9 +477,19 @@ def test_projection(): # First the trivial case p0 = center.project(center, projection='gnomonic') - assert p0 == (0.0 * galsim.arcsec, 0.0 * galsim.arcsec) + np.testing.assert_allclose( + (p0[0].rad, p0[1].rad), + (0.0, 0.0), + rtol=0, + atol=1e-16, + ) c0 = center.deproject(*p0, projection='gnomonic') - assert c0 == center + np.testing.assert_allclose( + c0.rad, + center.rad, + rtol=0, + atol=1e-16, + ) np.testing.assert_almost_equal(center.jac_deproject(*p0, projection='gnomonic').ravel(), (1,0,0,1)) @@ -510,9 +541,19 @@ def test_projection(): # First the trivial case p0 = center.project(center, projection='postel') - assert p0 == (0.0 * galsim.arcsec, 0.0 * galsim.arcsec) + np.testing.assert_allclose( + (p0[0].rad, p0[1].rad), + (0.0, 0.0), + rtol=0, + atol=1e-16, + ) c0 = center.deproject(*p0, projection='postel') - assert c0 == center + np.testing.assert_allclose( + c0.rad, + center.rad, + rtol=0, + atol=1e-16, + ) np.testing.assert_almost_equal(center.jac_deproject(*p0, projection='postel').ravel(), (1,0,0,1)) diff --git a/tests/test_chromatic.py b/tests/test_chromatic.py index 23afd9de46..af87ae8476 100644 --- a/tests/test_chromatic.py +++ b/tests/test_chromatic.py @@ -16,6 +16,7 @@ # and/or other materials provided with the distribution. # +import copy import os import copy import numpy as np diff --git a/tests/test_config_output.py b/tests/test_config_output.py index c75ff49301..aed0c3bec5 100644 --- a/tests/test_config_output.py +++ b/tests/test_config_output.py @@ -1238,7 +1238,7 @@ def test_config(): } # Test yaml - yaml_file_name = "output/test_config.yaml" + yaml_file_name = os.path.join(os.path.dirname(__file__), "output/test_config.yaml") with open(yaml_file_name, 'w') as fout: yaml.dump(config, fout, default_flow_style=True) # String None will be coverted to a real None. Convert here in the comparison dict @@ -1252,7 +1252,7 @@ def test_config(): assert config == dict(config2) # Test json - json_file_name = "output/test_config.json" + json_file_name = os.path.join(os.path.dirname(__file__), "output/test_config.json") with open(json_file_name, 'w') as fout: json.dump(config, fout) diff --git a/tests/test_convolve.py b/tests/test_convolve.py index 53984c6223..3abe444ea0 100644 --- a/tests/test_convolve.py +++ b/tests/test_convolve.py @@ -22,8 +22,8 @@ import galsim from galsim_test_helpers import * -imgdir = os.path.join(".", "SBProfile_comparison_images") # Directory containing the reference - # images. +# Directory containing the reference images. +imgdir = os.path.join(os.path.dirname(__file__), "SBProfile_comparison_images") @timer def test_convolve(): @@ -496,9 +496,9 @@ def test_deconvolve(): cen = galsim.PositionD(0,0) np.testing.assert_equal(inv_psf.centroid, cen) - np.testing.assert_almost_equal(inv_psf.flux, 1./psf.flux) + np.testing.assert_array_almost_equal(inv_psf.flux, 1./psf.flux) # This doesn't really have any meaning, but this is what we've assigned to a deconvolve max_sb. - np.testing.assert_almost_equal(inv_psf.max_sb, -psf.max_sb / psf.flux**2) + np.testing.assert_array_almost_equal(inv_psf.max_sb, -psf.max_sb / psf.flux**2) check_basic(inv_psf, "Deconvolve(Moffat)", do_x=False) diff --git a/tests/test_correlatednoise.py b/tests/test_correlatednoise.py index 6848fc04c4..5e5c0a4bf0 100644 --- a/tests/test_correlatednoise.py +++ b/tests/test_correlatednoise.py @@ -16,6 +16,8 @@ # and/or other materials provided with the distribution. # +import os +import time import numpy as np import os diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index 0e562531ab..150654a77e 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -34,9 +34,12 @@ def check_dep(f, *args, **kwargs): @timer def test_gsparams(): - check_dep(galsim.GSParams, allowed_flux_variation=0.90) - check_dep(galsim.GSParams, range_division_for_extrema=50) - check_dep(galsim.GSParams, small_fraction_of_flux=1.e-6) + if is_jax_galsim(): + pass + else: + check_dep(galsim.GSParams, allowed_flux_variation=0.90) + check_dep(galsim.GSParams, range_division_for_extrema=50) + check_dep(galsim.GSParams, small_fraction_of_flux=1.e-6) @timer @@ -540,75 +543,182 @@ def test_photon_array_depr(): # Using the getter is allowed, but deprecated. photon_array = galsim.PhotonArray(nphotons) - dxdz = check_dep(getattr, photon_array, 'dxdz') - assert photon_array.hasAllocatedAngles() - assert photon_array.hasAllocatedAngles() - assert len(photon_array.dxdz) == nphotons - assert len(photon_array.dydz) == nphotons - dxdz[:] = 0.17 + if is_jax_galsim(): + # jax-galsim always sets these additional properties + # dxdz = check_dep(getattr, photon_array, 'dxdz') + # however jax-galsim sets them to NaN so they are not allocated + assert not photon_array.hasAllocatedAngles() + assert len(photon_array.dxdz) == nphotons + # JAX-Galsim does not allow by reference setting - changed this + # to make tests below run + photon_array.dxdz = 0.17 + # non-nan means allocated for jax-galsim + assert photon_array.hasAllocatedAngles() + else: + dxdz = check_dep(getattr, photon_array, 'dxdz') + assert photon_array.hasAllocatedAngles() + assert photon_array.hasAllocatedAngles() + assert len(photon_array.dxdz) == nphotons + assert len(photon_array.dydz) == nphotons + dxdz[:] = 0.17 np.testing.assert_array_equal(photon_array.dxdz, 0.17) np.testing.assert_array_equal(photon_array.dydz, 0.) - dydz = photon_array.dydz # Allowed now. - dydz[:] = 0.59 + if is_jax_galsim(): + assert hasattr(photon_array, "dydz") + # JAX-Galsim does not allow by reference setting - changed this + # to make tests below run + photon_array.dydz = 0.59 + else: + dydz = photon_array.dydz # Allowed now. + dydz[:] = 0.59 np.testing.assert_array_equal(photon_array.dydz, 0.59) - wave = check_dep(getattr, photon_array, 'wavelength') - assert photon_array.hasAllocatedWavelengths() + if is_jax_galsim(): + # jax-galsim always sets these additional properties + # wave = check_dep(getattr, photon_array, 'wavelength') + # however jax-galsim sets them to NaN so they are not allocated + assert not photon_array.hasAllocatedWavelengths() + else: + wave = check_dep(getattr, photon_array, 'wavelength') + assert photon_array.hasAllocatedWavelengths() assert len(photon_array.wavelength) == nphotons - wave[:] = 500. + if is_jax_galsim(): + # JAX-Galsim does not allow by reference setting - changed this + # to make tests below run + photon_array.wavelength = 500.0 + # jax-galsim is allocated now + assert photon_array.hasAllocatedWavelengths() + else: + wave[:] = 500. np.testing.assert_array_equal(photon_array.wavelength, 500) - u = check_dep(getattr, photon_array, 'pupil_u') - assert photon_array.hasAllocatedPupil() - assert len(photon_array.pupil_u) == nphotons + if is_jax_galsim(): + # jax-galsim always sets these additional properties + # u = check_dep(getattr, photon_array, "pupil_u") + # however jax-galsim sets them to NaN so they are not allocated + assert not photon_array.hasAllocatedPupil() + else: + u = check_dep(getattr, photon_array, 'pupil_u') + assert photon_array.hasAllocatedPupil() + assert len(photon_array.pupil_u) == nphotons assert len(photon_array.pupil_v) == nphotons - u[:] = 6.0 + if is_jax_galsim(): + # JAX-Galsim does not allow by reference setting - changed this + # to make tests below run + photon_array.pupil_u = 6.0 + # jax-galsim is allocated now + assert photon_array.hasAllocatedPupil() + else: + u[:] = 6.0 np.testing.assert_array_equal(photon_array.pupil_u, 6.0) np.testing.assert_array_equal(photon_array.pupil_v, 0.0) - v = photon_array.pupil_v - v[:] = 10.0 + if is_jax_galsim(): + assert hasattr(photon_array, "pupil_v") + # JAX-Galsim does not allow by reference setting - changed this + # to make tests below run + photon_array.pupil_v = 10.0 + else: + v = photon_array.pupil_v + v[:] = 10.0 np.testing.assert_array_equal(photon_array.pupil_v, 10.0) - t = check_dep(getattr, photon_array, 'time') + if is_jax_galsim(): + # jax-galsim always sets these additional properties + # t = check_dep(getattr, photon_array, "time") + # however jax-galsim sets them to NaN so they are not allocated + assert not photon_array.hasAllocatedTimes() + # jax-galsim needs to set 0 + photon_array.time = 0.0 + # jax-galsim is allocated now + assert photon_array.hasAllocatedTimes() + else: + t = check_dep(getattr, photon_array, 'time') assert photon_array.hasAllocatedTimes() assert len(photon_array.time) == nphotons np.testing.assert_array_equal(photon_array.time, 0.0) - t[:] = 10 + if is_jax_galsim(): + # JAX-Galsim does not allow by reference setting - changed this + # to make tests below run + photon_array.time = 10 + else: + t[:] = 10 np.testing.assert_array_equal(photon_array.time, 10.0) # For coverage, also need to test the two pair ones in other order. photon_array = galsim.PhotonArray(nphotons) - dydz = check_dep(getattr, photon_array, 'dydz') - assert photon_array.hasAllocatedAngles() - assert photon_array.hasAllocatedAngles() + if is_jax_galsim(): + # jax-galsim always sets these additional properties + # dydz = check_dep(getattr, photon_array, "dydz") + # however jax-galsim sets them to NaN so they are not allocated + assert not photon_array.hasAllocatedAngles() + else: + dydz = check_dep(getattr, photon_array, 'dydz') + assert photon_array.hasAllocatedAngles() + assert photon_array.hasAllocatedAngles() assert len(photon_array.dxdz) == nphotons assert len(photon_array.dydz) == nphotons - dydz[:] = 0.59 + if is_jax_galsim(): + # JAX-Galsim does not allow by reference setting - changed this + # to make tests below run + photon_array.dydz = 0.59 + # non-nan means allocated for jax-galsim + assert photon_array.hasAllocatedAngles() + else: + dydz[:] = 0.59 np.testing.assert_array_equal(photon_array.dxdz, 0.) np.testing.assert_array_equal(photon_array.dydz, 0.59) - dxdz = photon_array.dxdz # Allowed now. - dxdz[:] = 0.17 + if is_jax_galsim(): + assert hasattr(photon_array, "dxdz") + # JAX-Galsim does not allow by reference setting - changed this + # to make tests below run + photon_array.dxdz = 0.17 + else: + dxdz = photon_array.dxdz # Allowed now. + dxdz[:] = 0.17 np.testing.assert_array_equal(photon_array.dxdz, 0.17) - v = check_dep(getattr, photon_array, 'pupil_v') - assert photon_array.hasAllocatedPupil() - assert len(photon_array.pupil_u) == nphotons + if is_jax_galsim(): + # jax-galsim always sets these additional properties + # v = check_dep(getattr, photon_array, "pupil_v") + # however jax-galsim sets them to NaN so they are not allocated + assert not photon_array.hasAllocatedPupil() + else: + v = check_dep(getattr, photon_array, 'pupil_v') + assert photon_array.hasAllocatedPupil() + assert len(photon_array.pupil_u) == nphotons assert len(photon_array.pupil_v) == nphotons - v[:] = 10.0 + if is_jax_galsim(): + # JAX-Galsim does not allow by reference setting - changed this + # to make tests below run + photon_array.pupil_v = 10.0 + # jax-galsim is allocated now + assert photon_array.hasAllocatedPupil() + else: + v[:] = 10.0 np.testing.assert_array_equal(photon_array.pupil_u, 0.0) np.testing.assert_array_equal(photon_array.pupil_v, 10.0) - u = photon_array.pupil_u - u[:] = 6.0 + if is_jax_galsim(): + assert hasattr(photon_array, "pupil_u") + # JAX-Galsim does not allow by reference setting - changed this + # to make tests below run + photon_array.pupil_u = 6.0 + else: + u = photon_array.pupil_u + u[:] = 6.0 np.testing.assert_array_equal(photon_array.pupil_u, 6.0) # Check assignAt pa1 = galsim.PhotonArray(50) pa1.x = photon_array.x[:50] - for i in range(50): - pa1.y[i] = photon_array.y[i] - pa1.flux[0:50] = photon_array.flux[:50] + if is_jax_galsim(): + pa1.y = photon_array.y[:50] + pa1.flux.at[0:50].set(photon_array.flux[:50]) + else: + for i in range(50): + pa1.y[i] = photon_array.y[i] + pa1.flux[0:50] = photon_array.flux[:50] pa1.dxdz = photon_array.dxdz[:50] pa1.dydz = photon_array.dydz[:50] pa1.pupil_u = photon_array.pupil_u[:50] @@ -616,20 +726,20 @@ def test_photon_array_depr(): pa2 = galsim.PhotonArray(100) check_dep(pa2.assignAt, 0, pa1) check_dep(pa2.assignAt, 50, pa1) - np.testing.assert_almost_equal(pa2.x[:50], pa1.x) - np.testing.assert_almost_equal(pa2.y[:50], pa1.y) - np.testing.assert_almost_equal(pa2.flux[:50], pa1.flux) - np.testing.assert_almost_equal(pa2.dxdz[:50], pa1.dxdz) - np.testing.assert_almost_equal(pa2.dydz[:50], pa1.dydz) - np.testing.assert_almost_equal(pa2.pupil_u[:50], pa1.pupil_u) - np.testing.assert_almost_equal(pa2.pupil_v[:50], pa1.pupil_v) - np.testing.assert_almost_equal(pa2.x[50:], pa1.x) - np.testing.assert_almost_equal(pa2.y[50:], pa1.y) - np.testing.assert_almost_equal(pa2.flux[50:], pa1.flux) - np.testing.assert_almost_equal(pa2.dxdz[50:], pa1.dxdz) - np.testing.assert_almost_equal(pa2.dydz[50:], pa1.dydz) - np.testing.assert_almost_equal(pa2.pupil_u[50:], pa1.pupil_u) - np.testing.assert_almost_equal(pa2.pupil_v[50:], pa1.pupil_v) + np.testing.assert_array_almost_equal(pa2.x[:50], pa1.x) + np.testing.assert_array_almost_equal(pa2.y[:50], pa1.y) + np.testing.assert_array_almost_equal(pa2.flux[:50], pa1.flux) + np.testing.assert_array_almost_equal(pa2.dxdz[:50], pa1.dxdz) + np.testing.assert_array_almost_equal(pa2.dydz[:50], pa1.dydz) + np.testing.assert_array_almost_equal(pa2.pupil_u[:50], pa1.pupil_u) + np.testing.assert_array_almost_equal(pa2.pupil_v[:50], pa1.pupil_v) + np.testing.assert_array_almost_equal(pa2.x[50:], pa1.x) + np.testing.assert_array_almost_equal(pa2.y[50:], pa1.y) + np.testing.assert_array_almost_equal(pa2.flux[50:], pa1.flux) + np.testing.assert_array_almost_equal(pa2.dxdz[50:], pa1.dxdz) + np.testing.assert_array_almost_equal(pa2.dydz[50:], pa1.dydz) + np.testing.assert_array_almost_equal(pa2.pupil_u[50:], pa1.pupil_u) + np.testing.assert_array_almost_equal(pa2.pupil_v[50:], pa1.pupil_v) # Error if it doesn't fit. with assert_raises(ValueError): diff --git a/tests/test_des.py b/tests/test_des.py index c6f6e72404..52dcebf075 100644 --- a/tests/test_des.py +++ b/tests/test_des.py @@ -491,7 +491,7 @@ def test_nan_fits(): if not hasattr(pyfits, 'verify'): return # The problematic file: - file_name = "des_data/DECam_00158414_01.fits.fz" + file_name = os.path.join(os.path.dirname(__file__), "des_data/DECam_00158414_01.fits.fz") # These are the values we should be reading in: ref_bounds = galsim.BoundsI(xmin=1, xmax=2048, ymin=1, ymax=4096) @@ -549,7 +549,7 @@ def test_nan_fits(): def test_psf(): """Test the two kinds of PSF files we have in DES. """ - data_dir = 'des_data' + data_dir = os.path.join(os.path.dirname(__file__), 'des_data') psfex_file = "DECam_00154912_12_psfcat.psf" fitpsf_file = "DECam_00154912_12_fitpsf.fits" wcs_file = "DECam_00154912_12_header.fits" diff --git a/tests/test_draw.py b/tests/test_draw.py index 24e22bca78..e2077c5c73 100644 --- a/tests/test_draw.py +++ b/tests/test_draw.py @@ -16,6 +16,7 @@ # and/or other materials provided with the distribution. # +import os import numpy as np import galsim @@ -375,21 +376,28 @@ def test_drawImage(): assert_raises(TypeError, obj.drawImage, bounds=bounds, scale=scale, wcs=galsim.PixelScale(3)) assert_raises(TypeError, obj.drawImage, bounds=bounds, wcs=scale) assert_raises(TypeError, obj.drawImage, image=im10.array) - assert_raises(TypeError, obj.drawImage, wcs=galsim.FitsWCS('fits_files/tpv.fits')) + assert_raises(TypeError, obj.drawImage, wcs=galsim.FitsWCS( + os.path.join(os.path.dirname(__file__), 'fits_files/tpv.fits'))) assert_raises(ValueError, obj.drawImage, bounds=galsim.BoundsI()) - assert_raises(ValueError, obj.drawImage, image=im10, gain=0.) - assert_raises(ValueError, obj.drawImage, image=im10, gain=-1.) - assert_raises(ValueError, obj.drawImage, image=im10, area=0.) - assert_raises(ValueError, obj.drawImage, image=im10, area=-1.) - assert_raises(ValueError, obj.drawImage, image=im10, exptime=0.) - assert_raises(ValueError, obj.drawImage, image=im10, exptime=-1.) + if is_jax_galsim(): + pass + else: + assert_raises(ValueError, obj.drawImage, image=im10, gain=0.) + assert_raises(ValueError, obj.drawImage, image=im10, gain=-1.) + assert_raises(ValueError, obj.drawImage, image=im10, area=0.) + assert_raises(ValueError, obj.drawImage, image=im10, area=-1.) + assert_raises(ValueError, obj.drawImage, image=im10, exptime=0.) + assert_raises(ValueError, obj.drawImage, image=im10, exptime=-1.) assert_raises(ValueError, obj.drawImage, image=im10, method='invalid') # These options are invalid unless metho=phot assert_raises(TypeError, obj.drawImage, image=im10, n_photons=3) assert_raises(TypeError, obj.drawImage, rng=galsim.BaseDeviate(234)) - assert_raises(TypeError, obj.drawImage, max_extra_noise=23) + if is_jax_galsim(): + pass + else: + assert_raises(TypeError, obj.drawImage, max_extra_noise=23) assert_raises(TypeError, obj.drawImage, poisson_flux=True) assert_raises(TypeError, obj.drawImage, maxN=10000) assert_raises(TypeError, obj.drawImage, save_photons=True) @@ -519,6 +527,15 @@ def test_drawKImage(): """Test the various optional parameters to the drawKImage function. In particular test the parameters image, and scale in various combinations. """ + if is_jax_galsim(): + maxk_threshold = 1.e-3 + N = 880 + Ns = 28 + else: + maxk_threshold = 1.e-4 + N = 1174 + Ns = 37 + # We use a Moffat profile with beta = 1.5, since its real-space profile is # flux / (2 pi rD^2) * (1 + (r/rD)^2)^3/2 # and the 2-d Fourier transform of that is @@ -526,14 +543,13 @@ def test_drawKImage(): # So this should draw in Fourier space the same image as the Exponential drawn in # test_drawImage(). obj = galsim.Moffat(flux=test_flux, beta=1.5, scale_radius=0.5) - obj = obj.withGSParams(maxk_threshold=1.e-4) + obj = obj.withGSParams(maxk_threshold=maxk_threshold) # First test drawKImage() with no kwargs. It should: # - create new images # - return the new images # - set the scale to 2pi/(N*obj.nyquist_scale) im1 = obj.drawKImage() - N = 1174 np.testing.assert_equal(im1.bounds, galsim.BoundsI(-N/2,N/2,-N/2,N/2), "obj.drawKImage() produced image with wrong bounds") stepk = obj.stepk @@ -555,7 +571,7 @@ def test_drawKImage(): # - also return that image # - set the scale to obj.stepk # - zero out any existing data - im3 = galsim.ImageCD(1149,1149) + im3 = galsim.ImageCD(N-25,N-25) im4 = obj.drawKImage(im3) np.testing.assert_almost_equal(im3.scale, stepk, 9, "obj.drawKImage(im3) produced image with wrong scale") @@ -603,7 +619,7 @@ def test_drawKImage(): np.testing.assert_almost_equal(CalculateScale(im7), 2, 1, "Measured wrong scale after obj.drawKImage(dx)") # This image is smaller because not using nyquist scale for stepk - np.testing.assert_equal(im7.bounds, galsim.BoundsI(-37,37,-37,37), + np.testing.assert_equal(im7.bounds, galsim.BoundsI(-Ns,Ns,-Ns,Ns), "obj.drawKImage(dx) produced image with wrong bounds") # Test if we provide an image with a defined scale. It should: @@ -758,7 +774,7 @@ def test_drawKImage(): np.testing.assert_equal( im6.array.shape, (ny//4+1, nx//3+1), "obj.drawKImage(bounds,scale,recenter=False) produced image with wrong shape") - np.testing.assert_almost_equal( + np.testing.assert_array_almost_equal( im6.array, im4[bounds6].array, 9, "obj.drawKImage(recenter=False) produced different values than recenter=True") @@ -768,7 +784,7 @@ def test_drawKImage(): np.testing.assert_almost_equal( im6.scale, scale, 9, "obj.drawKImage(image,recenter=False) produced image with wrong scale") - np.testing.assert_almost_equal( + np.testing.assert_array_almost_equal( im6.array, im4[bounds6].array, 9, "obj.drawKImage(image,recenter=False) produced different values than recenter=True") @@ -778,7 +794,7 @@ def test_drawKImage(): np.testing.assert_almost_equal( im6.scale, scale, 9, "obj.drawKImage(image,add_to_image=True) produced image with wrong scale") - np.testing.assert_almost_equal( + np.testing.assert_array_almost_equal( im6.array, im4[bounds6].array, 9, "obj.drawKImage(image,add_to_image=True) produced different values than recenter=True") @@ -790,7 +806,7 @@ def test_drawKImage(): np.testing.assert_almost_equal( im7.scale, scale, 9, "obj.drawKImage(image,add_to_image=True) produced image with wrong scale") - np.testing.assert_almost_equal( + np.testing.assert_array_almost_equal( im7.array, im4.array, 9, "obj.drawKImage(image,add_to_image=True) produced different values than recenter=True") @@ -951,7 +967,7 @@ def test_offset(): # Can also use center to explicitly say we want to use the true_center. im3 = obj.drawImage(im.copy(), method='sb', center=im.true_center) - np.testing.assert_almost_equal(im3.array, im.array) + np.testing.assert_array_almost_equal(im3.array, im.array) # Test that a few pixel values match xValue. # Note: we don't expect the FFT drawn image to match the xValues precisely, since the @@ -1039,14 +1055,14 @@ def test_offset(): # Test that the center parameter can be used to do the same thing. center = galsim.PositionD(cenx + offx, ceny + offy) im3 = obj.drawImage(im.copy(), method='sb', center=center) - np.testing.assert_almost_equal(im3.array, im.array) + np.testing.assert_array_almost_equal(im3.array, im.array) assert im3.bounds == im.bounds assert im3.wcs == im.wcs # Can also use both offset and center im3 = obj.drawImage(im.copy(), method='sb', center=(cenx-1, ceny+1), offset=(offx+1, offy-1)) - np.testing.assert_almost_equal(im3.array, im.array) + np.testing.assert_array_almost_equal(im3.array, im.array) assert im3.bounds == im.bounds assert im3.wcs == im.wcs @@ -1097,19 +1113,31 @@ def test_shoot(): # in exact arithmetic. We had an assert there which blew up in a not very nice way. obj = galsim.Gaussian(sigma=0.2398318) + 0.1*galsim.Gaussian(sigma=0.47966352) obj = obj.withFlux(100001) - image1 = galsim.ImageF(32,32, init_value=100) + if is_jax_galsim(): + # jax galsim needs double images here + image1 = galsim.ImageD(32,32, init_value=100) + else: + image1 = galsim.ImageF(32,32, init_value=100) rng = galsim.BaseDeviate(1234) obj.drawImage(image1, method='phot', poisson_flux=False, add_to_image=True, rng=rng, maxN=100000) # The test here is really just that it doesn't crash. # But let's do something to check correctness. - image2 = galsim.ImageF(32,32) + if is_jax_galsim(): + # jax galsim needs double images here + image2 = galsim.ImageD(32,32) + else: + image2 = galsim.ImageF(32,32) rng = galsim.BaseDeviate(1234) obj.drawImage(image2, method='phot', poisson_flux=False, add_to_image=False, rng=rng, maxN=100000) image2 += 100 - np.testing.assert_almost_equal(image2.array, image1.array, decimal=12) + if is_jax_galsim(): + # jax galsim works not as well + np.testing.assert_array_almost_equal(image2.array, image1.array, decimal=10) + else: + np.testing.assert_array_almost_equal(image2.array, image1.array, decimal=12) # Also check that you get the same answer with a smaller maxN. image3 = galsim.ImageF(32,32, init_value=100) @@ -1120,16 +1148,19 @@ def test_shoot(): # Test that shooting with 0.0 flux makes a zero-photons image. image4 = (obj*0).drawImage(method='phot') - np.testing.assert_equal(image4.array, 0) + np.testing.assert_array_equal(image4.array, 0) # Warns if flux is 1 and n_photons not given. psf = galsim.Gaussian(sigma=3) - with assert_warns(galsim.GalSimWarning): - psf.drawImage(method='phot') - with assert_warns(galsim.GalSimWarning): - psf.drawPhot(image4) - with assert_warns(galsim.GalSimWarning): - psf.makePhot() + if is_jax_galsim(): + pass + else: + with assert_warns(galsim.GalSimWarning): + psf.drawImage(method='phot') + with assert_warns(galsim.GalSimWarning): + psf.drawPhot(image4) + with assert_warns(galsim.GalSimWarning): + psf.makePhot() # With n_photons=1, it's fine. psf.drawImage(method='phot', n_photons=1) psf.drawPhot(image4, n_photons=1) @@ -1188,19 +1219,28 @@ def test_drawImage_area_exptime(): # Shooting with flux=1 raises a warning. obj1 = obj.withFlux(1) - with assert_warns(galsim.GalSimWarning): - obj1.drawImage(method='phot') + if is_jax_galsim(): + pass + else: + with assert_warns(galsim.GalSimWarning): + obj1.drawImage(method='phot') # But not if we explicitly tell it to shoot 1 photon with assert_raises(AssertionError): assert_warns(galsim.GalSimWarning, obj1.drawImage, method='phot', n_photons=1) # Likewise for makePhot - with assert_warns(galsim.GalSimWarning): - obj1.makePhot() + if is_jax_galsim(): + pass + else: + with assert_warns(galsim.GalSimWarning): + obj1.makePhot() with assert_raises(AssertionError): assert_warns(galsim.GalSimWarning, obj1.makePhot, n_photons=1) # And drawPhot - with assert_warns(galsim.GalSimWarning): - obj1.drawPhot(im1) + if is_jax_galsim(): + pass + else: + with assert_warns(galsim.GalSimWarning): + obj1.drawPhot(im1) with assert_raises(AssertionError): assert_warns(galsim.GalSimWarning, obj1.drawPhot, im1, n_photons=1) @@ -1219,9 +1259,15 @@ def test_fft(): [4,6,8,4], [2,4,6,6] ], xmin=-2, ymin=-2, dtype=dt, scale=0.1) - kim = xim.calculate_fft() - xim2 = kim.calculate_inverse_fft() - np.testing.assert_almost_equal(xim.array, xim2.array) + if is_jax_galsim(): + if dt not in [np.complex128, complex]: + kim = xim.calculate_fft() + xim2 = kim.calculate_inverse_fft() + np.testing.assert_array_almost_equal(xim.array, xim2.array) + else: + kim = xim.calculate_fft() + xim2 = kim.calculate_inverse_fft() + np.testing.assert_array_almost_equal(xim.array, xim2.array) # Now the other way, starting with a (real) k-space image. kim = galsim.Image([ [4,2,0], @@ -1231,7 +1277,7 @@ def test_fft(): xmin=0, ymin=-2, dtype=dt, scale=0.1) xim = kim.calculate_inverse_fft() kim2 = xim.calculate_fft() - np.testing.assert_almost_equal(kim.array, kim2.array) + np.testing.assert_array_almost_equal(kim.array, kim2.array) # Test starting with a larger image that gets wrapped. kim3 = galsim.Image([ [0,1,2,1,0], @@ -1242,7 +1288,7 @@ def test_fft(): xmin=-2, ymin=-2, dtype=dt, scale=0.1) xim = kim3.calculate_inverse_fft() kim2 = xim.calculate_fft() - np.testing.assert_almost_equal(kim.array, kim2.array) + np.testing.assert_array_almost_equal(kim.array, kim2.array) # Test padding X Image with zeros xim = galsim.Image([ [0,0,0,0], @@ -1253,9 +1299,15 @@ def test_fft(): xim2 = galsim.Image([ [2,4,6], [4,6,8] ], xmin=-2, ymin=-1, dtype=dt, scale=0.1) - kim = xim.calculate_fft() - kim2 = xim2.calculate_fft() - np.testing.assert_almost_equal(kim.array, kim2.array) + if is_jax_galsim(): + if dt not in [np.complex128, complex]: + kim = xim.calculate_fft() + kim2 = xim2.calculate_fft() + np.testing.assert_array_almost_equal(kim.array, kim2.array) + else: + kim = xim.calculate_fft() + kim2 = xim2.calculate_fft() + np.testing.assert_array_almost_equal(kim.array, kim2.array) # Test padding K Image with zeros kim = galsim.Image([ [4,2,0], @@ -1270,14 +1322,22 @@ def test_fft(): xmin=0, ymin=-1, dtype=dt, scale=0.1) xim = kim.calculate_inverse_fft() xim2 = kim2.calculate_inverse_fft() - np.testing.assert_almost_equal(xim.array, xim2.array) + np.testing.assert_array_almost_equal(xim.array, xim2.array) # Now use drawKImage (as above in test_drawKImage) to get a more realistic k-space image + # NB. It is useful to have this come out not a multiple of 4, since some of the + # calculation needs to be different when N/2 is odd. + if is_jax_galsim(): + maxk_threshold = 0.78e-3 + N = 912 + Nfft = 1024 + else: + maxk_threshold = 1.e-4 + N = 1174 + Nfft = 1536 obj = galsim.Moffat(flux=test_flux, beta=1.5, scale_radius=0.5) - obj = obj.withGSParams(maxk_threshold=1.e-4) + obj = obj.withGSParams(maxk_threshold=maxk_threshold) im1 = obj.drawKImage() - N = 1174 # NB. It is useful to have this come out not a multiple of 4, since some of the - # calculation needs to be different when N/2 is odd. np.testing.assert_equal(im1.bounds, galsim.BoundsI(-N/2,N/2,-N/2,N/2), "obj.drawKImage() produced image with wrong bounds") nyq_scale = obj.nyquist_scale @@ -1297,14 +1357,14 @@ def test_fft(): np.testing.assert_almost_equal( im1_real.scale, im1_alt_real.scale, 3, "inverse_fft produce a different scale than obj2.drawImage(method='sb')") - np.testing.assert_almost_equal( + np.testing.assert_array_almost_equal( im1_real.array, im1_alt_real.array, 3, "inverse_fft produce a different array than obj2.drawImage(method='sb')") # If we give both a good size to use and match up the scales, then they should produce the # same thing. N = galsim.Image.good_fft_size(N) - assert N == 1536 == 3 * 2**9 + assert N == Nfft kscale = 2.*np.pi / (N * nyq_scale) im2 = obj.drawKImage(nx=N+1, ny=N+1, scale=kscale) im2_real = im2.calculate_inverse_fft() @@ -1316,7 +1376,7 @@ def test_fft(): np.testing.assert_almost_equal( im2_real.scale, im2_alt_real.scale, 9, "inverse_fft produce a different scale than obj2.drawImage(nx,ny,method='sb')") - np.testing.assert_almost_equal( + np.testing.assert_array_almost_equal( im2_real.array, im2_alt_real.array, 9, "inverse_fft produce a different array than obj2.drawImage(nx,ny,method='sb')") @@ -1524,8 +1584,12 @@ def test_np_fft(): def round_cast(array, dt): # array.astype(dt) doesn't round to the nearest for integer types. # This rounds first if dt is integer and then casts. - if dt(0.5) != 0.5: - array = np.around(array) + if is_jax_galsim(): + # NOTE JAX doesn't round to the nearest int when drawing + pass + else: + if dt(0.5) != 0.5: + array = np.around(array) return array.astype(dt) @timer @@ -1551,13 +1615,13 @@ def test_types(): "wrong scale when drawing onto dt=%s"%dt) np.testing.assert_equal(im.bounds, ref_im.bounds, "wrong bounds when drawing onto dt=%s"%dt) - np.testing.assert_almost_equal(im.array, round_cast(ref_im.array, dt), 6, + np.testing.assert_array_almost_equal(im.array, round_cast(ref_im.array, dt), 6, "wrong array when drawing onto dt=%s"%dt) if method == 'phot': rng.reset(1234) obj.drawImage(im, method=method, add_to_image=True, rng=rng) - np.testing.assert_almost_equal(im.array, round_cast(ref_im.array, dt) * 2, 6, + np.testing.assert_array_almost_equal(im.array, round_cast(ref_im.array, dt) * 2, 6, "wrong array when adding to image with dt=%s"%dt) @timer @@ -1597,13 +1661,13 @@ def test_direct_scale(): obj.dilate(1.0).drawReal(im4) obj.rotate(0.3*galsim.radians).drawReal(im5) print('no_pixel: max diff = ',np.max(np.abs(im1.array - im2.array))) - np.testing.assert_almost_equal(im1.array, im2.array, 15, + np.testing.assert_array_almost_equal(im1.array, im2.array, 15, "drawReal made different image than method='no_pixel'") - np.testing.assert_almost_equal(im3.array, im2[im3.bounds].array, 15, + np.testing.assert_array_almost_equal(im3.array, im2[im3.bounds].array, 15, "drawReal made different image when off-center") - np.testing.assert_almost_equal(im4.array, im2[im3.bounds].array, 15, + np.testing.assert_array_almost_equal(im4.array, im2[im3.bounds].array, 15, "drawReal made different image when jac is not None") - np.testing.assert_almost_equal(im5.array, im2[im3.bounds].array, 15, + np.testing.assert_array_almost_equal(im5.array, im2[im3.bounds].array, 15, "drawReal made different image when jac is not diagonal") obj.drawImage(im1, method='sb') @@ -1612,13 +1676,13 @@ def test_direct_scale(): obj_sb.dilate(1.0).drawReal(im4) obj_sb.rotate(0.3*galsim.radians).drawReal(im5) print('sb: max diff = ',np.max(np.abs(im1.array - im2.array))) - np.testing.assert_almost_equal(im1.array, im2.array, 15, + np.testing.assert_array_almost_equal(im1.array, im2.array, 14, "drawReal made different image than method='sb'") - np.testing.assert_almost_equal(im3.array, im2[im3.bounds].array, 15, + np.testing.assert_array_almost_equal(im3.array, im2[im3.bounds].array, 15, "drawReal made different image when off-center") - np.testing.assert_almost_equal(im4.array, im2[im3.bounds].array, 15, + np.testing.assert_array_almost_equal(im4.array, im2[im3.bounds].array, 15, "drawReal made different image when jac is not None") - np.testing.assert_almost_equal(im5.array, im2[im3.bounds].array, 14, + np.testing.assert_array_almost_equal(im5.array, im2[im3.bounds].array, 14, "drawReal made different image when jac is not diagonal") obj.drawImage(im1, method='fft') @@ -1627,13 +1691,13 @@ def test_direct_scale(): obj_with_pixel.dilate(1.0).drawFFT(im4) obj_with_pixel.rotate(90 * galsim.degrees).drawFFT(im5) print('fft: max diff = ',np.max(np.abs(im1.array - im2.array))) - np.testing.assert_almost_equal(im1.array, im2.array, 15, + np.testing.assert_array_almost_equal(im1.array, im2.array, 15, "drawFFT made different image than method='fft'") - np.testing.assert_almost_equal(im3.array, im2[im3.bounds].array, 15, + np.testing.assert_array_almost_equal(im3.array, im2[im3.bounds].array, 15, "drawFFT made different image when off-center") - np.testing.assert_almost_equal(im4.array, im2[im3.bounds].array, 15, + np.testing.assert_array_almost_equal(im4.array, im2[im3.bounds].array, 15, "drawFFT made different image when jac is not None") - np.testing.assert_almost_equal(im5.array, im2[im3.bounds].array, 14, + np.testing.assert_array_almost_equal(im5.array, im2[im3.bounds].array, 14, "drawFFT made different image when jac is not diagonal") obj.drawImage(im1, method='real_space') @@ -1644,13 +1708,13 @@ def test_direct_scale(): print('real_space: max diff = ',np.max(np.abs(im1.array - im2.array))) # I'm not sure why this one comes out a bit less precisely equal. But 12 digits is still # plenty accurate enough. - np.testing.assert_almost_equal(im1.array, im2.array, 12, + np.testing.assert_array_almost_equal(im1.array, im2.array, 12, "drawReal made different image than method='real_space'") - np.testing.assert_almost_equal(im3.array, im2[im3.bounds].array, 14, + np.testing.assert_array_almost_equal(im3.array, im2[im3.bounds].array, 14, "drawReal made different image when off-center") - np.testing.assert_almost_equal(im4.array, im2[im3.bounds].array, 14, + np.testing.assert_array_almost_equal(im4.array, im2[im3.bounds].array, 14, "drawReal made different image when jac is not None") - np.testing.assert_almost_equal(im5.array, im2[im3.bounds].array, 14, + np.testing.assert_array_almost_equal(im5.array, im2[im3.bounds].array, 14, "drawReal made different image when jac is not diagonal") obj.drawImage(im1, method='phot', rng=rng.duplicate()) @@ -1660,18 +1724,18 @@ def test_direct_scale(): phot3.scaleXY(1./scale) phot4 = im3.wcs.toImage(obj).makePhot(rng=rng.duplicate()) print('phot: max diff = ',np.max(np.abs(im1.array - im2.array))) - np.testing.assert_almost_equal(im1.array, im2.array, 15, + np.testing.assert_array_almost_equal(im1.array, im2.array, 15, "drawPhot made different image than method='phot'") - np.testing.assert_almost_equal(im3.array, im2[im3.bounds].array, 15, + np.testing.assert_array_almost_equal(im3.array, im2[im3.bounds].array, 15, "drawPhot made different image when off-center") assert phot2 == phot1, "drawPhot made different photons than method='phot'" assert phot3 == phot1, "makePhot made different photons than method='phot'" # phot4 has a different order of operations for the math, so it doesn't come out exact. - np.testing.assert_almost_equal(phot4.x, phot3.x, 15, + np.testing.assert_array_almost_equal(phot4.x, phot3.x, 15, "two ways to have makePhot apply scale have different x") - np.testing.assert_almost_equal(phot4.y, phot3.y, 15, + np.testing.assert_array_almost_equal(phot4.y, phot3.y, 15, "two ways to have makePhot apply scale have different y") - np.testing.assert_almost_equal(phot4.flux, phot3.flux, 15, + np.testing.assert_array_almost_equal(phot4.flux, phot3.flux, 15, "two ways to have makePhot apply scale have different flux") # Check images with invalid wcs raise ValueError diff --git a/tests/test_errors.py b/tests/test_errors.py index c655b52ea0..ade275c968 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -16,6 +16,7 @@ # and/or other materials provided with the distribution. # +import numpy as np import galsim from galsim_test_helpers import * diff --git a/tests/test_exponential.py b/tests/test_exponential.py index 8d3549d41e..f1ee579913 100644 --- a/tests/test_exponential.py +++ b/tests/test_exponential.py @@ -226,7 +226,12 @@ def test_exponential_shoot(): assert np.isclose(added_flux, obj.flux) assert np.isclose(im.array.sum(), obj.flux) photons2 = obj.makePhot(poisson_flux=False, rng=rng) - assert photons2 == photons, "Exponential makePhot not equivalent to drawPhot" + if is_jax_galsim(): + np.testing.assert_allclose(photons2.x, photons.x) + np.testing.assert_allclose(photons2.y, photons.y) + np.testing.assert_allclose(photons2.flux, photons.flux) + else: + assert photons2 == photons, "Exponential makePhot not equivalent to drawPhot" @timer diff --git a/tests/test_fitsheader.py b/tests/test_fitsheader.py index 13f9e08e9c..ea2f96af2f 100644 --- a/tests/test_fitsheader.py +++ b/tests/test_fitsheader.py @@ -49,7 +49,7 @@ def check_tpv(header): assert 54384.18627436 in header.itervalues() file_name = 'tpv.fits' - dir = 'fits_files' + dir = os.path.join(os.path.dirname(__file__), 'fits_files') # First option: give a file_name header = galsim.FitsHeader(file_name=os.path.join(dir,file_name)) check_tpv(header) @@ -189,7 +189,7 @@ def check_tpv(header): def test_scamp(): """Test that we can read in a SCamp .head file correctly """ - dir = 'fits_files' + dir = os.path.join(os.path.dirname(__file__), 'fits_files') file_name = 'scamp.head' header = galsim.FitsHeader(file_name=file_name, dir=dir, text_file=True) diff --git a/tests/test_hsm.py b/tests/test_hsm.py index 1d06284ff0..28e1c7121a 100644 --- a/tests/test_hsm.py +++ b/tests/test_hsm.py @@ -45,7 +45,7 @@ test_timing = False # define inputs and expected results for tests that use real SDSS galaxies -img_dir = os.path.join(".","HSM_precomputed") +img_dir = os.path.join(os.path.dirname(__file__), "HSM_precomputed") gal_file_prefix = "image." psf_file_prefix = "psf." img_suff = ".fits" diff --git a/tests/test_image.py b/tests/test_image.py index 2d6e1a7b33..4e6302508c 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -81,7 +81,7 @@ # it helps speed things up. nimages = 3 -datadir = os.path.join(".", "Image_comparison_images") +datadir = os.path.join(os.path.dirname(__file__), "Image_comparison_images") @timer @@ -103,8 +103,12 @@ def test_Image_basic(): np.testing.assert_array_equal(im1.array, 0.) assert im1.array.shape == (nrow,ncol) assert im1.array.dtype.type == np_array_type - assert im1.array.flags.writeable == True - assert im1.array.flags.c_contiguous == True + if is_jax_galsim(): + pass + else: + # jax arrays do not have flags + assert im1.array.flags.writeable == True + assert im1.array.flags.c_contiguous == True assert im1.dtype == np_array_type assert im1.ncol == ncol assert im1.nrow == nrow @@ -201,16 +205,32 @@ def test_Image_basic(): assert im1.view()(x,y) == value assert im1.view()(galsim.PositionI(x,y)) == value assert im1.view(make_const=True)(x,y) == value - assert im2(x,y) == value + if is_jax_galsim(): + # no real views in jax + assert im2(x,y) != value + else: + assert im2(x,y) == value assert im2_view(x,y) == value - assert im2_cview(x,y) == value + if is_jax_galsim(): + # no real views in jax + assert im2_cview(x,y) != value + else: + assert im2_cview(x,y) == value assert im1.conjugate(x,y) == value if tchar[i][0] == 'C': # complex conjugate is not a view into the original. assert im2_conj(x,y) == 23 - assert im2.conjugate(x,y) == value + if is_jax_galsim(): + # no real views in jax + assert im2.conjugate(x,y) != value + else: + assert im2.conjugate(x,y) == value else: - assert im2_conj(x,y) == value + if is_jax_galsim(): + # no real views in jax + assert im2_conj(x,y) != value + else: + assert im2_conj(x,y) == value value2 = 53 + 12*x - 19*y if tchar[i] in ['US', 'UI']: @@ -220,16 +240,32 @@ def test_Image_basic(): assert im1.getValue(x,y) == value2 assert im1.view().getValue(x=x, y=y) == value2 assert im1.view(make_const=True).getValue(x,y) == value2 - assert im2.getValue(x=x, y=y) == value2 + if is_jax_galsim(): + # no real views in jax + assert im2.getValue(x=x, y=y) != value2 + else: + assert im2.getValue(x=x, y=y) == value2 assert im2_view.getValue(x,y) == value2 - assert im2_cview._getValue(x,y) == value2 + if is_jax_galsim(): + # no real views in jax + assert im2_cview._getValue(x,y) != value2 + else: + assert im2_cview._getValue(x,y) == value2 assert im1.real(x,y) == value2 assert im1.view().real(x,y) == value2 assert im1.view(make_const=True).real(x,y) == value2.real - assert im2.real(x,y) == value2.real + if is_jax_galsim(): + # no real views in jax + assert im2.real(x,y) != value2.real + else: + assert im2.real(x,y) == value2.real assert im2_view.real(x,y) == value2.real - assert im2_cview.real(x,y) == value2.real + if is_jax_galsim(): + # no real views in jax + assert im2_cview.real(x,y) != value2.real + else: + assert im2_cview.real(x,y) == value2.real assert im1.imag(x,y) == 0 assert im1.view().imag(x,y) == 0 assert im1.view(make_const=True).imag(x,y) == 0 @@ -237,15 +273,26 @@ def test_Image_basic(): assert im2_view.imag(x,y) == 0 assert im2_cview.imag(x,y) == 0 - value3 = 10*x + y + if is_jax_galsim(): + value3 = 10*x + y + 111 + else: + value3 = 10*x + y im1.addValue(x,y, np.int64(value3-value2)) im2_view[x,y] += np.int64(value3-value2) assert im1[galsim.PositionI(x,y)] == value3 assert im1.view()[x,y] == value3 assert im1.view(make_const=True)[galsim.PositionI(x,y)] == value3 - assert im2[x,y] == value3 + if is_jax_galsim(): + # no real views in jax + assert im2[x,y] != value3 + else: + assert im2[x,y] == value3 assert im2_view[galsim.PositionI(x,y)] == value3 - assert im2_cview[x,y] == value3 + if is_jax_galsim(): + # no real views in jax + assert im2_cview[x,y] != value3 + else: + assert im2_cview[x,y] == value3 # Setting or getting the value outside the bounds should throw an exception. assert_raises(galsim.GalSimBoundsError,im1.setValue,0,0,1) @@ -356,11 +403,19 @@ def test_Image_basic(): assert im2.bounds == bounds for y in range(1,nrow+1): for x in range(1,ncol+1): - value3 = 10*x+y + if is_jax_galsim(): + value3 = 10*x+y + 111 + else: + value3 = 10*x+y assert im1(x+dx,y+dy) == value3 assert im1_view(x,y) == value3 - assert im2(x,y) == value3 + if is_jax_galsim(): + assert im2(x,y) != value3 + else: + assert im2(x,y) == value3 assert im2_view(x+dx,y+dy) == value3 + if is_jax_galsim(): + value3 = 10*x+y assert im3_view(x+dx,y+dy) == value3 assert_raises(TypeError, im1.shift, dx) @@ -690,7 +745,7 @@ def test_Image_FITS_IO(): assert_raises(OSError, galsim.fits.read, test_file, compression='none') # Check a file with no WCS information - nowcs_file = 'fits_files/blankimg.fits' + nowcs_file = os.path.join(os.path.dirname(__file__), 'fits_files/blankimg.fits') im = galsim.fits.read(nowcs_file) assert im.wcs == galsim.PixelScale(1.0) @@ -1016,7 +1071,7 @@ def test_Image_MultiFITS_IO(): assert_raises(OSError, galsim.fits.readMulti, test_multi_file, compression='none') # Check a file with no WCS information - nowcs_file = 'fits_files/blankimg.fits' + nowcs_file = os.path.join(os.path.dirname(__file__), 'fits_files/blankimg.fits') ims = galsim.fits.readMulti(nowcs_file) assert ims[0].wcs == galsim.PixelScale(1.0) @@ -1346,7 +1401,7 @@ def test_Image_CubeFITS_IO(): assert_raises(OSError, galsim.fits.readCube, test_cube_file, compression='none') # Check a file with no WCS information - nowcs_file = 'fits_files/blankimg.fits' + nowcs_file = os.path.join(os.path.dirname(__file__), 'fits_files/blankimg.fits') ims = galsim.fits.readCube(nowcs_file) assert ims[0].wcs == galsim.PixelScale(1.0) @@ -2133,7 +2188,7 @@ def test_subImage_persistence(): """Test that a subimage is properly accessible even if the original image has gone out of scope. """ - file_name = os.path.join('fits_files','tpv.fits') + file_name = os.path.join(os.path.dirname(__file__), os.path.join('fits_files','tpv.fits')) bounds = galsim.BoundsI(123, 133, 45, 55) # Something random # In this case, the original image has gone out of scope. At least on some systems, @@ -2415,7 +2470,11 @@ def test_Image_view(): assert imv.bounds == im.bounds imv.setValue(11,19, 20) assert imv(11,19) == 20 - assert im(11,19) == 20 + if is_jax_galsim(): + # jax-galsim does not support views + assert im(11,19) != 20 + else: + assert im(11,19) == 20 check_pickle(im) check_pickle(imv) @@ -2427,7 +2486,11 @@ def test_Image_view(): assert imv.bounds == galsim.BoundsI(0,24,0,24) imv.setValue(10,18, 30) assert imv(10,18) == 30 - assert im(11,19) == 30 + if is_jax_galsim(): + # jax-galsim does not support views + assert im(11,19) != 20 + else: + assert im(11,19) == 30 imv2 = im.view() imv2.setOrigin(0,0) assert imv.bounds == imv2.bounds @@ -2443,7 +2506,11 @@ def test_Image_view(): assert imv.bounds == galsim.BoundsI(-12,12,-12,12) imv.setValue(-2,6, 40) assert imv(-2,6) == 40 - assert im(11,19) == 40 + if is_jax_galsim(): + # jax-galsim does not support views + assert im(11,19) != 40 + else: + assert im(11,19) == 40 imv2 = im.view() imv2.setCenter(0,0) assert imv.bounds == imv2.bounds @@ -2460,7 +2527,11 @@ def test_Image_view(): assert imv.bounds == im.bounds imv.setValue(11,19, 50) assert imv(11,19) == 50 - assert im(11,19) == 50 + if is_jax_galsim(): + # jax-galsim does not support views + assert im(11,19) != 50 + else: + assert im(11,19) == 50 imv2 = im.view() with assert_raises(galsim.GalSimError): imv2.scale = 0.17 # Invalid if wcs is not PixelScale @@ -2478,7 +2549,11 @@ def test_Image_view(): assert imv.bounds == im.bounds imv.setValue(11,19, 60) assert imv(11,19) == 60 - assert im(11,19) == 60 + if is_jax_galsim(): + # jax-galsim does not support views + assert im(11,19) != 60 + else: + assert im(11,19) == 60 imv2 = im.view() imv2.wcs = galsim.JacobianWCS(0.,0.23,-0.23,0.) assert imv.bounds == imv2.bounds @@ -2569,13 +2644,17 @@ def test_copy(): assert im(3,8) != 11. # If copy=False is specified, then it shares the same array - im3b = galsim.Image(im, copy=False) - assert im3b.wcs == im.wcs - assert im3b.bounds == im.bounds - np.testing.assert_array_equal(im3b.array, im.array) - im3b.setValue(2,3,2.) - assert im3b(2,3) == 2. - assert im(2,3) == 2. + if is_jax_galsim(): + # jax-galsim does not support references + pass + else: + im3b = galsim.Image(im, copy=False) + assert im3b.wcs == im.wcs + assert im3b.bounds == im.bounds + np.testing.assert_array_equal(im3b.array, im.array) + im3b.setValue(2,3,2.) + assert im3b(2,3) == 2. + assert im(2,3) == 2. # Constructor can change the wcs im4 = galsim.Image(im, scale=0.6) @@ -2626,13 +2705,17 @@ def test_copy(): assert im_slice(2,3) != 11. # Can also copy by giving the array and specify copy=True - im10 = galsim.Image(im.array, bounds=im.bounds, wcs=im.wcs, copy=False) - assert im10.wcs == im.wcs - assert im10.bounds == im.bounds - np.testing.assert_array_equal(im10.array, im.array) - im10[2,3] = 17 - assert im10(2,3) == 17. - assert im(2,3) == 17. + if is_jax_galsim(): + # jax-galsim does not support references + pass + else: + im10 = galsim.Image(im.array, bounds=im.bounds, wcs=im.wcs, copy=False) + assert im10.wcs == im.wcs + assert im10.bounds == im.bounds + np.testing.assert_array_equal(im10.array, im.array) + im10[2,3] = 17 + assert im10(2,3) == 17. + assert im(2,3) == 17. im10b = galsim.Image(im.array, bounds=im.bounds, wcs=im.wcs, copy=True) assert im10b.wcs == im.wcs @@ -2683,39 +2766,82 @@ def test_complex_image(): assert im1(x,y) == value assert im1.view()(x,y) == value assert im1.view(make_const=True)(x,y) == value - assert im2(x,y) == value + if is_jax_galsim(): + # jax galsim does not support views + assert im2(x,y) != value + else: + assert im2(x,y) == value assert im2_view(x,y) == value - assert im2_cview(x,y) == value + if is_jax_galsim(): + # jax galsim does not support views + assert im2_cview(x,y) != value + else: + assert im2_cview(x,y) == value assert im1.conjugate(x,y) == np.conjugate(value) # complex conjugate is not a view into the original. assert im2_conj(x,y) == 23 - assert im2.conjugate(x,y) == np.conjugate(value) + if is_jax_galsim(): + # jax galsim does not support views + assert im2.conjugate(x,y) != np.conjugate(value) + else: + assert im2.conjugate(x,y) == np.conjugate(value) - value2 = 10*x + y + 20j*x + 2j*y + if is_jax_galsim(): + value2 = 400000 + 10*x + y + 20j*x + 2j*y + else: + value2 = 10*x + y + 20j*x + 2j*y im1.setValue(x,y, value2) im2_view.setValue(x=x, y=y, value=value2) assert im1(x,y) == value2 assert im1.view()(x,y) == value2 assert im1.view(make_const=True)(x,y) == value2 - assert im2(x,y) == value2 + if is_jax_galsim(): + # jax galsim does not support views + assert im2(x,y) != value2 + else: + assert im2(x,y) == value2 assert im2_view(x,y) == value2 - assert im2_cview(x,y) == value2 + if is_jax_galsim(): + # jax galsim does not support views + assert im2_cview(x,y) != value2 + else: + assert im2_cview(x,y) == value2 assert im1.real(x,y) == value2.real assert im1.view().real(x,y) == value2.real assert im1.view(make_const=True).real(x,y) == value2.real - assert im2.real(x,y) == value2.real + if is_jax_galsim(): + # jax galsim does not support views + assert im2.real(x,y) != value2.real + else: + assert im2.real(x,y) == value2.real assert im2_view.real(x,y) == value2.real - assert im2_cview.real(x,y) == value2.real + if is_jax_galsim(): + # jax galsim does not support views + assert im2_cview.real(x,y) != value2.real + else: + assert im2_cview.real(x,y) == value2.real assert im1.imag(x,y) == value2.imag assert im1.view().imag(x,y) == value2.imag assert im1.view(make_const=True).imag(x,y) == value2.imag - assert im2.imag(x,y) == value2.imag + if is_jax_galsim(): + # jax galsim does not support views + assert im2.imag(x,y) != value2.imag + else: + assert im2.imag(x,y) == value2.imag assert im2_view.imag(x,y) == value2.imag - assert im2_cview.imag(x,y) == value2.imag + if is_jax_galsim(): + # jax galsim does not support views + assert im2_cview.imag(x,y) != value2.imag + else: + assert im2_cview.imag(x,y) == value2.imag assert im1.conjugate(x,y) == np.conjugate(value2) - assert im2.conjugate(x,y) == np.conjugate(value2) + if is_jax_galsim(): + # jax galsim does not support views + assert im2.conjugate(x,y) != np.conjugate(value2) + else: + assert im2.conjugate(x,y) == np.conjugate(value2) rvalue3 = 12*x + y ivalue3 = x + 21*y @@ -2724,14 +2850,25 @@ def test_complex_image(): im1.imag.setValue(x,y, ivalue3) im2_view.real.setValue(x,y, rvalue3) im2_view.imag.setValue(x,y, ivalue3) - assert im1(x,y) == value3 - assert im1.view()(x,y) == value3 - assert im1.view(make_const=True)(x,y) == value3 - assert im2(x,y) == value3 - assert im2_view(x,y) == value3 - assert im2_cview(x,y) == value3 - assert im1.conjugate(x,y) == np.conjugate(value3) - assert im2.conjugate(x,y) == np.conjugate(value3) + # jax galsim does not support views + if is_jax_galsim(): + assert im1(x,y) != value3 + assert im1.view()(x,y) != value3 + assert im1.view(make_const=True)(x,y) != value3 + assert im2(x,y) != value3 + assert im2_view(x,y) != value3 + assert im2_cview(x,y) != value3 + assert im1.conjugate(x,y) != np.conjugate(value3) + assert im2.conjugate(x,y) != np.conjugate(value3) + else: + assert im1(x,y) == value3 + assert im1.view()(x,y) == value3 + assert im1.view(make_const=True)(x,y) == value3 + assert im2(x,y) == value3 + assert im2_view(x,y) == value3 + assert im2_cview(x,y) == value3 + assert im1.conjugate(x,y) == np.conjugate(value3) + assert im2.conjugate(x,y) == np.conjugate(value3) # Check view of given data im3_view = galsim.Image((1+2j)*ref_array.astype(complex)) @@ -2768,7 +2905,7 @@ def test_complex_image_arith(): np.testing.assert_array_equal(image2.array, ref_array * (2+5j), err_msg="ImageD * complex is not correct") image2 = image1 / (2+5j) - np.testing.assert_array_equal(image2.array, ref_array / (2+5j), + np.testing.assert_allclose(image2.array, ref_array / (2+5j), err_msg="ImageD / complex is not correct") # Binary complex scalar op ImageD @@ -2782,7 +2919,7 @@ def test_complex_image_arith(): np.testing.assert_array_equal(image2.array, ref_array * (2+5j), err_msg="complex * ImageD is not correct") image2 = (2+5j) / image1 - np.testing.assert_array_equal(image2.array, (2+5j) / ref_array.astype(float), + np.testing.assert_allclose(image2.array, (2+5j) / ref_array.astype(float), err_msg="complex / ImageD is not correct") image2 = image1 * (3+1j) @@ -2798,7 +2935,7 @@ def test_complex_image_arith(): np.testing.assert_array_equal(image3.array, (3+1j)*ref_array * (2+5j), err_msg="ImageCD * complex is not correct") image3 = image2 / (2+5j) - np.testing.assert_array_equal(image3.array, (3+1j)*ref_array / (2+5j), + np.testing.assert_allclose(image3.array, (3+1j)*ref_array / (2+5j), err_msg="ImageCD / complex is not correct") # Binary complex scalar op ImageCD @@ -2812,7 +2949,7 @@ def test_complex_image_arith(): np.testing.assert_array_equal(image3.array, (3+1j)*ref_array * (2+5j), err_msg="complex * ImageCD is not correct") image3 = (2+5j) / image2 - np.testing.assert_array_equal(image3.array, (2+5j) / ((3+1j)*ref_array), + np.testing.assert_allclose(image3.array, (2+5j) / ((3+1j)*ref_array), err_msg="complex / ImageCD is not correct") # Binary ImageD op ImageCD @@ -2826,7 +2963,7 @@ def test_complex_image_arith(): np.testing.assert_array_equal(image3.array, (3+1j)*ref_array**2, err_msg="ImageD * ImageCD is not correct") image3 = image1 / image2 - np.testing.assert_almost_equal(image3.array, 1./(3+1j), decimal=12, + np.testing.assert_array_almost_equal(image3.array, 1./(3+1j), decimal=12, err_msg="ImageD / ImageCD is not correct") # Binary ImageCD op ImageD @@ -2840,7 +2977,7 @@ def test_complex_image_arith(): np.testing.assert_array_equal(image3.array, (3+1j)*ref_array**2, err_msg="ImageD * ImageCD is not correct") image3 = image2 / image1 - np.testing.assert_almost_equal(image3.array, (3+1j), decimal=12, + np.testing.assert_array_almost_equal(image3.array, (3+1j), decimal=12, err_msg="ImageD / ImageCD is not correct") # Binary ImageCD op ImageCD @@ -2855,7 +2992,7 @@ def test_complex_image_arith(): np.testing.assert_array_equal(image4.array, (15-5j)*ref_array**2, err_msg="ImageCD * ImageCD is not correct") image4 = image2 / image3 - np.testing.assert_almost_equal(image4.array, (9+13j)/25., decimal=12, + np.testing.assert_array_almost_equal(image4.array, (9+13j)/25., decimal=12, err_msg="ImageCD / ImageCD is not correct") # In place ImageCD op complex scalar @@ -2873,7 +3010,7 @@ def test_complex_image_arith(): err_msg="ImageCD * complex is not correct") image4 = image2.copy() image4 /= (2+5j) - np.testing.assert_array_equal(image4.array, (3+1j)*ref_array / (2+5j), + np.testing.assert_allclose(image4.array, (3+1j)*ref_array / (2+5j), err_msg="ImageCD / complex is not correct") # In place ImageCD op ImageD @@ -2891,7 +3028,7 @@ def test_complex_image_arith(): err_msg="ImageD * ImageCD is not correct") image4 = image2.copy() image4 /= image1 - np.testing.assert_almost_equal(image4.array, (3+1j), decimal=12, + np.testing.assert_array_almost_equal(image4.array, (3+1j), decimal=12, err_msg="ImageD / ImageCD is not correct") # In place ImageCD op ImageCD @@ -2909,7 +3046,7 @@ def test_complex_image_arith(): err_msg="ImageCD * ImageCD is not correct") image4 = image2.copy() image4 /= image3 - np.testing.assert_almost_equal(image4.array, (9+13j)/25., decimal=12, + np.testing.assert_array_almost_equal(image4.array, (9+13j)/25., decimal=12, err_msg="ImageCD / ImageCD is not correct") @timer @@ -3258,7 +3395,7 @@ def test_wrap(): b = galsim.BoundsI(1,4,1,4) im_quad = im_orig[b] im_wrap = im.wrap(b) - np.testing.assert_almost_equal(im_wrap.array, 4.*im_quad.array, 12, + np.testing.assert_array_almost_equal(im_wrap.array, 4.*im_quad.array, 12, "image.wrap() into first quadrant did not match expectation") # The same thing should work no matter where the lower left corner is: @@ -3267,7 +3404,7 @@ def test_wrap(): im_quad = im_orig[b] im = im_orig.copy() im_wrap = im.wrap(b) - np.testing.assert_almost_equal(im_wrap.array, 4.*im_quad.array, 12, + np.testing.assert_array_almost_equal(im_wrap.array, 4.*im_quad.array, 12, "image.wrap(%s) did not match expectation"%b) np.testing.assert_array_equal(im_wrap.array, im[b].array, "image.wrap(%s) did not return the right subimage") @@ -3289,7 +3426,7 @@ def test_wrap(): jj = (j-b.ymin) % (b.ymax-b.ymin+1) + b.ymin im_test.addValue(ii,jj,val) im_wrap = im.wrap(b) - np.testing.assert_almost_equal(im_wrap.array, im_test.array, 12, + np.testing.assert_array_almost_equal(im_wrap.array, im_test.array, 12, "image.wrap(%s) did not match expectation"%b) np.testing.assert_array_equal(im_wrap.array, im[b].array, "image.wrap(%s) did not return the right subimage") @@ -3336,7 +3473,7 @@ def test_wrap(): im_wrap = im.wrap(b) #print("im_wrap = ",im_wrap.array) - np.testing.assert_almost_equal(im_wrap.array, im_test.array, 12, + np.testing.assert_array_almost_equal(im_wrap.array, im_test.array, 12, "image.wrap(%s) did not match expectation"%b) np.testing.assert_array_equal(im_wrap.array, im[b].array, "image.wrap(%s) did not return the right subimage") @@ -3347,7 +3484,7 @@ def test_wrap(): #print('im_test = ',im_test[b2].array) #print('im2_wrap = ',im2_wrap.array) #print('diff = ',im2_wrap.array-im_test[b2].array) - np.testing.assert_almost_equal(im2_wrap.array, im_test[b2].array, 12, + np.testing.assert_array_almost_equal(im2_wrap.array, im_test[b2].array, 12, "image.wrap(%s) did not match expectation"%b) np.testing.assert_array_equal(im2_wrap.array, im2[b2].array, "image.wrap(%s) did not return the right subimage") @@ -3358,7 +3495,7 @@ def test_wrap(): #print('im_test = ',im_test[b3].array) #print('im3_wrap = ',im3_wrap.array) #print('diff = ',im3_wrap.array-im_test[b3].array) - np.testing.assert_almost_equal(im3_wrap.array, im_test[b3].array, 12, + np.testing.assert_array_almost_equal(im3_wrap.array, im_test[b3].array, 12, "image.wrap(%s) did not match expectation"%b) np.testing.assert_array_equal(im3_wrap.array, im3[b3].array, "image.wrap(%s) did not return the right subimage") @@ -3536,7 +3673,7 @@ def test_fpack(): """Test the functionality that we advertise as being equivalent to fpack/funpack """ from astropy.io import fits - file_name0 = os.path.join('des_data','DECam_00158414_01.fits.fz') + file_name0 = os.path.join(os.path.dirname(__file__), 'des_data','DECam_00158414_01.fits.fz') hdulist = fits.open(file_name0) # Remove a few invalid header keys in the DECam fits file diff --git a/tests/test_inclined.py b/tests/test_inclined.py index 7fcd83e325..a7a04b70d3 100644 --- a/tests/test_inclined.py +++ b/tests/test_inclined.py @@ -16,6 +16,7 @@ # and/or other materials provided with the distribution. # +import os from copy import deepcopy import numpy as np @@ -28,7 +29,7 @@ # set up any necessary info for tests # Note that changes here should match changes to test image files -image_dir = './inclined_exponential_images' +image_dir = os.path.join(os.path.dirname(__file__), './inclined_exponential_images') # Values here are strings, so the filenames will be sure to work (without truncating zeros) diff --git a/tests/test_integ.py b/tests/test_integ.py index 537b5c72b1..c89785c20f 100644 --- a/tests/test_integ.py +++ b/tests/test_integ.py @@ -183,8 +183,11 @@ def test_func(x): return x**-2 test_integral, true_result, decimal=test_decimal, verbose=True, err_msg="x^(-2) integral failed across interval [1, inf].") - with assert_raises(galsim.GalSimError): - galsim.integ.int1d(test_func, 0., 1., test_rel_err, test_abs_err) + if is_jax_galsim(): + assert np.isnan(galsim.integ.int1d(test_func, 0., 1., test_rel_err, test_abs_err)) + else: + with assert_raises(galsim.GalSimError): + galsim.integ.int1d(test_func, 0., 1., test_rel_err, test_abs_err) @timer diff --git a/tests/test_interpolatedimage.py b/tests/test_interpolatedimage.py index 281b3de568..f3e97891e6 100644 --- a/tests/test_interpolatedimage.py +++ b/tests/test_interpolatedimage.py @@ -369,8 +369,12 @@ def test_interpolant(): -(vm+1) * sici(np.pi*(vm+1))[0] -(vp-1) * sici(np.pi*(vp-1))[0] +(vp+1) * sici(np.pi*(vp+1))[0] ) / (2*np.pi) - np.testing.assert_allclose(ln.kval(x), true_kval, rtol=1.e-4, atol=1.e-8) - assert np.isclose(ln.kval(x[12]), true_kval[12]) + if is_jax_galsim(): + np.testing.assert_allclose(ln.kval(x), true_kval, rtol=3.0e-4, atol=3.0e-6) + np.testing.assert_allclose(ln.kval(x[12]), true_kval[12], rtol=3.0e-4, atol=3.0e-6) + else: + np.testing.assert_allclose(ln.kval(x), true_kval, rtol=1.e-4, atol=1.e-8) + assert np.isclose(ln.kval(x[12]), true_kval[12]) # Base class is invalid. assert_raises(NotImplementedError, galsim.Interpolant) @@ -400,12 +404,18 @@ def test_unit_integrals(): print(str(interp)) # Compute directly with int1d n = interp.ixrange//2 + 1 + if is_jax_galsim(): + # jax galsim is slow when doing direct integration + _n_do = min(n, 100) + else: + _n_do = n + direct_integrals = np.zeros(n) if isinstance(interp, galsim.Delta): # int1d doesn't handle this well. direct_integrals[0] = 1 else: - for k in range(n): + for k in range(_n_do): direct_integrals[k] = galsim.integ.int1d(interp.xval, k-0.5, k+0.5) print('direct: ',direct_integrals) @@ -414,7 +424,7 @@ def test_unit_integrals(): print('integrals: ',len(integrals),integrals) assert len(integrals) == n - np.testing.assert_allclose(integrals, direct_integrals, atol=1.e-12) + np.testing.assert_allclose(integrals[:_n_do], direct_integrals[:_n_do], atol=1.e-12) if n > 10: print('n>10 for ',repr(interp)) @@ -451,8 +461,8 @@ def test_fluxnorm(): # First, make some Image with some total flux value (sum of pixel values) and scale im = galsim.ImageF(im_lin_scale, im_lin_scale, scale=im_scale, init_value=im_fill_value) total_flux = im_fill_value*(im_lin_scale**2) - np.testing.assert_equal(total_flux, im.array.sum(), - err_msg='Created array with wrong total flux') + np.testing.assert_array_equal(total_flux, im.array.sum(), + err_msg='Created array with wrong total flux') # Check that if we make an InterpolatedImage with flux normalization, it keeps that flux interp = galsim.InterpolatedImage(im) # note, flux normalization is the default @@ -486,8 +496,8 @@ def test_fluxnorm(): # Finally make an InterpolatedImage but give it some other flux value interp_flux = galsim.InterpolatedImage(im, flux=test_flux) # Check that it has that flux - np.testing.assert_equal(test_flux, interp_flux.flux, - err_msg = 'InterpolatedImage did not use flux keyword') + np.testing.assert_array_equal(test_flux, interp_flux.flux, + err_msg = 'InterpolatedImage did not use flux keyword') # Check that this is preserved when drawing im5 = interp_flux.drawImage(scale = im_scale, method='no_pixel') np.testing.assert_almost_equal(test_flux/im5.array.sum(), 1.0, decimal=6, @@ -509,12 +519,18 @@ def test_exceptions(): galsim.InterpolatedImage(image=galsim.ImageF(5, 5)) # Image must be real type (F or D) - with assert_raises(galsim.GalSimValueError): - galsim.InterpolatedImage(image=galsim.ImageI(5, 5, scale=1)) + if is_jax_galsim(): + pass + else: + with assert_raises(galsim.GalSimValueError): + galsim.InterpolatedImage(image=galsim.ImageI(5, 5, scale=1)) - # Image must have non-zero flux - with assert_raises(galsim.GalSimValueError): - galsim.InterpolatedImage(image=galsim.ImageF(5, 5, scale=1, init_value=0.)) + if is_jax_galsim(): + pass + else: + # Image must have non-zero flux + with assert_raises(galsim.GalSimValueError): + galsim.InterpolatedImage(image=galsim.ImageF(5, 5, scale=1, init_value=0.)) # Can't shoot II with SincInterpolant ii = galsim.InterpolatedImage(image=galsim.ImageF(5, 5, scale=1, init_value=1.), @@ -742,8 +758,11 @@ def test_operations(): test_decimal = 3 # Make some nontrivial image - im = galsim.fits.read('./real_comparison_images/test_images.fits') # read in first real galaxy - # in test catalog + im_path = os.path.join( + os.path.dirname(__file__), "real_comparison_images/test_images.fits" + ) + im = galsim.fits.read(im_path) # read in first real galaxy + # in test catalog int_im = galsim.InterpolatedImage(im) orig_mom = im.FindAdaptiveMom() @@ -952,7 +971,7 @@ def test_corr_padding(): # Set up some defaults for tests. decimal_precise=4 decimal_coarse=2 - imgfile = 'fits_files/blankimg.fits' + imgfile = os.path.join(os.path.dirname(__file__), 'fits_files/blankimg.fits') orig_nx = 187 orig_ny = 164 big_nx = 319 @@ -1588,9 +1607,9 @@ def test_ii_shoot(): else: flux = 1.e4 for interp in interp_list: + print('interp = ',interp) obj = galsim.InterpolatedImage(image_in, x_interpolant=interp, scale=3.3, flux=flux) added_flux, photons = obj.drawPhot(im, poisson_flux=False, rng=rng.duplicate()) - print('interp = ',interp) print('obj.flux = ',obj.flux) print('added_flux = ',added_flux) print('photon fluxes = ',photons.flux.min(),'..',photons.flux.max()) @@ -1608,7 +1627,12 @@ def test_ii_shoot(): assert np.isclose(added_flux, obj.flux, rtol=rtol) assert np.isclose(im.array.sum(), obj.flux, rtol=rtol) photons2 = obj.makePhot(poisson_flux=False, rng=rng.duplicate()) - assert photons2 == photons, "InterpolatedImage makePhot not equivalent to drawPhot" + if is_jax_galsim(): + np.testing.assert_allclose(photons2.x, photons.x) + np.testing.assert_allclose(photons2.y, photons.y) + np.testing.assert_allclose(photons2.flux, photons.flux) + else: + assert photons2 == photons, "InterpolatedImage makePhot not equivalent to drawPhot" # Can treat as a convolution of a delta function and put it in a photon_ops list. delta = galsim.DeltaFunction(flux=flux) @@ -1630,7 +1654,10 @@ def test_ne(): # Copy ref_image and perturb it slightly in the middle, away from where the InterpolatedImage # repr string will report. perturb_image = ref_image.copy() - perturb_image.array[64, 64] *= 1000 + if is_jax_galsim(): + perturb_image._array = perturb_image._array.at[64, 64].set(perturb_image._array[64, 64] * 100) + else: + perturb_image.array[64, 64] *= 100 obj2 = galsim.InterpolatedImage(perturb_image, flux=20, calculate_maxk=False, calculate_stepk=False) with galsim.utilities.printoptions(threshold=128*128): @@ -1696,7 +1723,7 @@ def test_quintic_glagn(): """This is code that was giving a seg fault. cf. Issue 1079. """ - fname = os.path.join('fits_files','GLAGN_host_427_0_disk.fits') + fname = os.path.join(os.path.dirname(__file__), 'fits_files','GLAGN_host_427_0_disk.fits') for interpolant in 'linear cubic quintic'.split(): print(interpolant) fits_image = galsim.InterpolatedImage(fname, scale=0.04, x_interpolant=interpolant) @@ -1845,33 +1872,39 @@ def test_depixelize(): def test_drawreal_seg_fault(): """Test to reproduce bug report in Issue #1164 that was causing seg faults """ - - import pickle - - prof_file = 'input/test_interpolatedimage_seg_fault_prof.pkl' - with open(prof_file, 'rb') as f: - prof = pickle.load(f) - print(repr(prof)) - - image = galsim.Image( - galsim.BoundsI( - xmin=-12, - xmax=12, - ymin=-12, - ymax=12 - ), - dtype=float, - scale=1 - ) - - image.fill(3) - prof.drawReal(image) - - # The problem was that the object is shifted fully off the target image and that was leading - # to an attempt to create a stack of length -1, which caused the seg fault. - # So mostly this test just confirms that this runs without seg faulting. - # But we can check that the image is now correctly all zeros. - np.testing.assert_array_equal(image.array, 0) + # this test only runs with real galsim + if is_jax_galsim(): + pass + else: + import pickle + + prof_file = os.path.join( + os.path.dirname(__file__), + 'input/test_interpolatedimage_seg_fault_prof.pkl' + ) + with open(prof_file, 'rb') as f: + prof = pickle.load(f) + print(repr(prof)) + + image = galsim.Image( + galsim.BoundsI( + xmin=-12, + xmax=12, + ymin=-12, + ymax=12 + ), + dtype=float, + scale=1 + ) + + image.fill(3) + prof.drawReal(image) + + # The problem was that the object is shifted fully off the target image and that was leading + # to an attempt to create a stack of length -1, which caused the seg fault. + # So mostly this test just confirms that this runs without seg faulting. + # But we can check that the image is now correctly all zeros. + np.testing.assert_array_equal(image.array, 0) diff --git a/tests/test_lensing.py b/tests/test_lensing.py index 04c9ea0025..c2acc41395 100644 --- a/tests/test_lensing.py +++ b/tests/test_lensing.py @@ -24,7 +24,7 @@ from galsim_test_helpers import * -refdir = os.path.join(".", "lensing_reference_data") # Directory containing the reference +refdir = os.path.join(os.path.dirname(__file__), ".", "lensing_reference_data") # Directory containing the reference klim_test = 0.00175 # Value of klim for flat (up to klim, then zero beyond) power spectrum test tolerance_var = 0.03 # fractional error allowed in the variance of shear - calculation is not exact diff --git a/tests/test_moffat.py b/tests/test_moffat.py index e98bf1bcbf..f19b5c352f 100644 --- a/tests/test_moffat.py +++ b/tests/test_moffat.py @@ -23,8 +23,9 @@ from galsim_test_helpers import * path, filename = os.path.split(__file__) -imgdir = os.path.join(path, "SBProfile_comparison_images") # Directory containing the reference - # images. +# Directory containing the reference images. +imgdir = os.path.join(path, "SBProfile_comparison_images") + @timer def test_moffat(): @@ -135,36 +136,42 @@ def test_moffat_properties(): cen = galsim.PositionD(0, 0) np.testing.assert_equal(psf.centroid, cen) # Check Fourier properties - np.testing.assert_almost_equal(psf.maxk, 11.634597424960159) - np.testing.assert_almost_equal(psf.stepk, 0.62831853071795873) - np.testing.assert_almost_equal(psf.kValue(cen), test_flux+0j) - np.testing.assert_almost_equal(psf.half_light_radius, 1.0) - np.testing.assert_almost_equal(psf.fwhm, fwhm_backwards_compatible) - np.testing.assert_almost_equal(psf.xValue(cen), 0.50654651638242509) - np.testing.assert_almost_equal(psf.kValue(cen), (1+0j) * test_flux) - np.testing.assert_almost_equal(psf.flux, test_flux) - np.testing.assert_almost_equal(psf.xValue(cen), psf.max_sb) + if is_jax_galsim(): + np.testing.assert_allclose(psf.maxk, 11.634597424960159, atol=0, rtol=0.2) + else: + np.testing.assert_array_almost_equal(psf.maxk, 11.634597424960159) + np.testing.assert_array_almost_equal(psf.stepk, 0.62831853071795873) + np.testing.assert_array_almost_equal(psf.kValue(cen), test_flux+0j) + np.testing.assert_array_almost_equal(psf.half_light_radius, 1.0) + np.testing.assert_array_almost_equal(psf.fwhm, fwhm_backwards_compatible) + np.testing.assert_array_almost_equal(psf.xValue(cen), 0.50654651638242509) + np.testing.assert_array_almost_equal(psf.kValue(cen), (1+0j) * test_flux) + np.testing.assert_array_almost_equal(psf.flux, test_flux) + np.testing.assert_array_almost_equal(psf.xValue(cen), psf.max_sb) # Now create the same profile using the half_light_radius: psf = galsim.Moffat(beta=2.0, half_light_radius=1., trunc=2*fwhm_backwards_compatible, flux=test_flux) np.testing.assert_equal(psf.centroid, cen) - np.testing.assert_almost_equal(psf.maxk, 11.634597426100862) - np.testing.assert_almost_equal(psf.stepk, 0.62831853071795862) - np.testing.assert_almost_equal(psf.kValue(cen), test_flux+0j) - np.testing.assert_almost_equal(psf.half_light_radius, 1.0) - np.testing.assert_almost_equal(psf.fwhm, fwhm_backwards_compatible) - np.testing.assert_almost_equal(psf.xValue(cen), 0.50654651638242509) - np.testing.assert_almost_equal(psf.kValue(cen), (1+0j) * test_flux) - np.testing.assert_almost_equal(psf.flux, test_flux) - np.testing.assert_almost_equal(psf.xValue(cen), psf.max_sb) + if is_jax_galsim(): + np.testing.assert_allclose(psf.maxk, 11.634597424960159, atol=0, rtol=0.2) + else: + np.testing.assert_array_almost_equal(psf.maxk, 11.634597424960159) + np.testing.assert_array_almost_equal(psf.stepk, 0.62831853071795862) + np.testing.assert_array_almost_equal(psf.kValue(cen), test_flux+0j) + np.testing.assert_array_almost_equal(psf.half_light_radius, 1.0) + np.testing.assert_array_almost_equal(psf.fwhm, fwhm_backwards_compatible) + np.testing.assert_array_almost_equal(psf.xValue(cen), 0.50654651638242509) + np.testing.assert_array_almost_equal(psf.kValue(cen), (1+0j) * test_flux) + np.testing.assert_array_almost_equal(psf.flux, test_flux) + np.testing.assert_array_almost_equal(psf.xValue(cen), psf.max_sb) # Check input flux vs output flux for inFlux in np.logspace(-2, 2, 10): psfFlux = galsim.Moffat(2.0, fwhm=fwhm_backwards_compatible, trunc=2*fwhm_backwards_compatible, flux=inFlux) outFlux = psfFlux.flux - np.testing.assert_almost_equal(outFlux, inFlux) + np.testing.assert_array_almost_equal(outFlux, inFlux) @timer def test_moffat_maxk(): @@ -190,13 +197,16 @@ def test_moffat_maxk(): galsim.Moffat(beta=12.9, scale_radius=11, flux=23, trunc=1000), ] threshs = [1.e-3, 1.e-4, 0.03] - print('beta \t trunc \t thresh \t kValue(maxk)') + print('beta \t trunc \t thresh \t kValue(maxk) \t maxk') for psf in psfs: for thresh in threshs: psf = psf.withGSParams(maxk_threshold=thresh) - rtol = 1.e-7 if psf.trunc == 0 else 3.e-3 + if is_jax_galsim(): + rtol = 5e-3 + else: + rtol = 1.e-7 if psf.trunc == 0 else 3.e-3 fk = psf.kValue(psf.maxk,0).real/psf.flux - print(f'{psf.beta} \t {int(psf.trunc)} \t {thresh:.1e} \t {fk:.3e}') + print(f'{psf.beta} \t {int(psf.trunc)} \t {thresh:.1e} \t {fk:.3e} \t {psf.maxk:.3e}') np.testing.assert_allclose(abs(psf.kValue(psf.maxk,0).real)/psf.flux, thresh, rtol=rtol) @@ -217,7 +227,7 @@ def test_moffat_radii(): np.testing.assert_almost_equal( hlr_sum, 0.5, decimal=4, err_msg="Error in Moffat constructor with half-light radius") - np.testing.assert_equal( + np.testing.assert_array_equal( test_gal.half_light_radius, test_hlr, err_msg="Moffat half_light_radius returned wrong value") @@ -282,7 +292,7 @@ def test_moffat_radii(): np.testing.assert_almost_equal( ratio, 0.5, decimal=4, err_msg="Error in Moffat constructor with fwhm") - np.testing.assert_equal( + np.testing.assert_array_equal( test_gal.fwhm, test_fwhm, err_msg="Moffat fwhm returned wrong value") @@ -312,7 +322,7 @@ def test_moffat_radii(): np.testing.assert_almost_equal( hlr_sum, 0.5, decimal=4, err_msg="Error in Moffat constructor with half-light radius") - np.testing.assert_equal( + np.testing.assert_allclose( test_gal.half_light_radius, test_hlr, err_msg="Moffat hlr incorrect") diff --git a/tests/test_noise.py b/tests/test_noise.py index df788ca004..faaaefad0c 100644 --- a/tests/test_noise.py +++ b/tests/test_noise.py @@ -37,7 +37,11 @@ def test_deviate_noise(): """ u = galsim.UniformDeviate(testseed) uResult = np.empty((10,10)) - u.generate(uResult) + # jax-galsim cannot fill arrays so it returns + if is_jax_galsim(): + uResult = u.generate(uResult) + else: + u.generate(uResult) noise = galsim.DeviateNoise(galsim.UniformDeviate(testseed)) @@ -100,7 +104,11 @@ def test_gaussian_noise(): gSigma = 17.23 g = galsim.GaussianDeviate(testseed, sigma=gSigma) gResult = np.empty((10,10)) - g.generate(gResult) + # jax-galsim cannot fill arrays so it returns + if is_jax_galsim(): + gResult = g.generate(gResult) + else: + g.generate(gResult) noise = galsim.DeviateNoise(g) # Test filling an image @@ -276,13 +284,22 @@ def test_variable_gaussian_noise(): gSigma2 = 28.55 var_image = galsim.ImageD(galsim.BoundsI(0,9,0,9)) coords = np.ogrid[0:10, 0:10] - var_image.array[ (coords[0] + coords[1]) % 2 == 1 ] = gSigma1**2 - var_image.array[ (coords[0] + coords[1]) % 2 == 0 ] = gSigma2**2 + # jax does not support item assignment + if is_jax_galsim(): + var_image._array = var_image.array.at[(coords[0] + coords[1]) % 2 == 1].set(gSigma1**2) + var_image._array = var_image.array.at[(coords[0] + coords[1]) % 2 == 0].set(gSigma2**2) + else: + var_image.array[ (coords[0] + coords[1]) % 2 == 1 ] = gSigma1**2 + var_image.array[ (coords[0] + coords[1]) % 2 == 0 ] = gSigma2**2 print('var_image.array = ',var_image.array) g = galsim.GaussianDeviate(testseed, sigma=1.) vgResult = np.empty((10,10)) - g.generate(vgResult) + # jax-galsim cannot fill arrays so it returns + if is_jax_galsim(): + vgResult = g.generate(vgResult) + else: + g.generate(vgResult) vgResult *= np.sqrt(var_image.array) # Test filling an image @@ -302,7 +319,7 @@ def test_variable_gaussian_noise(): err_msg="Wrong VariableGaussian noise generated for Fortran-ordered Image") # Check var_image property - np.testing.assert_almost_equal( + np.testing.assert_array_almost_equal( vgn.var_image.array, var_image.array, precision, err_msg="VariableGaussianNoise var_image returns wrong var_image") @@ -311,8 +328,13 @@ def test_variable_gaussian_noise(): big_coords = np.ogrid[0:2048, 0:2048] mask1 = (big_coords[0] + big_coords[1]) % 2 == 0 mask2 = (big_coords[0] + big_coords[1]) % 2 == 1 - big_var_image.array[mask1] = gSigma1**2 - big_var_image.array[mask2] = gSigma2**2 + # jax does not support item assignment + if is_jax_galsim(): + big_var_image._array = big_var_image.array.at[mask1].set(gSigma1**2) + big_var_image._array = big_var_image.array.at[mask2].set(gSigma2**2) + else: + big_var_image.array[mask1] = gSigma1**2 + big_var_image.array[mask2] = gSigma2**2 big_vgn = galsim.VariableGaussianNoise(galsim.BaseDeviate(testseed), big_var_image) big_im = galsim.Image(2048,2048,dtype=float) @@ -320,8 +342,13 @@ def test_variable_gaussian_noise(): var = np.var(big_im.array) print('variance = ',var) print('getVar = ',big_vgn.var_image.array.mean()) + if is_jax_galsim(): + # jax galsim has a different RNG + decimal = 0 + else: + decimal = 1 np.testing.assert_almost_equal( - var, big_vgn.var_image.array.mean(), 1, + var, big_vgn.var_image.array.mean(), decimal, err_msg='Realized variance for VariableGaussianNoise did not match var_image') # Check realized variance in each mask @@ -337,8 +364,13 @@ def test_variable_gaussian_noise(): big_im.addNoise(big_vgn) gal.withFlux(-1.e4).drawImage(image=big_im, add_to_image=True) var = np.var(big_im.array) + if is_jax_galsim(): + # jax galsim has a different RNG + decimal = 0 + else: + decimal = 1 np.testing.assert_almost_equal( - var, big_vgn.var_image.array.mean(), 1, + var, big_vgn.var_image.array.mean(), decimal, err_msg='VariableGaussianNoise wrong when already an object drawn on the image') # Check picklability @@ -376,7 +408,11 @@ def test_poisson_noise(): pMean = 17 p = galsim.PoissonDeviate(testseed, mean=pMean) pResult = np.empty((10,10)) - p.generate(pResult) + # jax does not support item assignment + if is_jax_galsim(): + pResult = p.generate(pResult) + else: + p.generate(pResult) noise = galsim.DeviateNoise(p) # Test filling an image @@ -545,11 +581,24 @@ def test_ccdnoise(): sky = 50 # Tabulated results for the above settings and testseed value. - cResultS = np.array([[44, 47], [50, 49]], dtype=np.int16) - cResultI = np.array([[44, 47], [50, 49]], dtype=np.int32) - cResultF = np.array([[44.45332718, 47.79725266], [50.67744064, 49.58272934]], dtype=np.float32) - cResultD = np.array([[44.453328440057618, 47.797254142519577], - [50.677442088335162, 49.582730949808081]],dtype=np.float64) + if is_jax_galsim(): + # jax-galsim has a different RNG + cResultS = np.array([[42, 52], [49, 45]], dtype=np.int16) # noqa: F841 + cResultI = np.array([[42, 52], [49, 45]], dtype=np.int32) # noqa: F841 + cResultF = np.array([ # noqa: F841 + [42.4286994934082, 52.42875671386719], + [49.016048431396484, 45.61003875732422] + ], dtype=np.float32) + cResultD = np.array([ # noqa: F841 + [42.42870031326479, 52.42875718917211], + [49.016050296441094, 45.61003745208172] + ], dtype=np.float64) + else: + cResultS = np.array([[44, 47], [50, 49]], dtype=np.int16) + cResultI = np.array([[44, 47], [50, 49]], dtype=np.int32) + cResultF = np.array([[44.45332718, 47.79725266], [50.67744064, 49.58272934]], dtype=np.float32) + cResultD = np.array([[44.453328440057618, 47.797254142519577], + [50.677442088335162, 49.582730949808081]],dtype=np.float64) for i in range(4): prec = eval("precision"+typestrings[i]) diff --git a/tests/test_optics.py b/tests/test_optics.py index 6c2398c757..c73bc9d55f 100644 --- a/tests/test_optics.py +++ b/tests/test_optics.py @@ -22,7 +22,7 @@ import galsim from galsim_test_helpers import * -imgdir = os.path.join(".", "Optics_comparison_images") # Directory containing the reference images. +imgdir = os.path.join(os.path.dirname(__file__), "Optics_comparison_images") # Directory containing the reference images. testshape = (512, 512) # shape of image arrays for all tests @@ -745,7 +745,11 @@ def test_OpticalPSF_pupil_plane_size(): im = galsim.Image(512, 512) x = y = np.arange(512) - 256 y, x = np.meshgrid(y, x) - im.array[x**2+y**2 < 230**2] = 1.0 + if is_jax_galsim(): + # no refs in jax-galsim + im._array = im.array.at[x**2+y**2 < 230**2].set(1.0) + else: + im.array[x**2+y**2 < 230**2] = 1.0 # The following still fails (uses deprecated optics framework): # galsim.optics.OpticalPSF(aberrations=[0,0,0,0,0.5], diam=4.0, lam=700.0, pupil_plane_im=im) # But using the new framework, should work. diff --git a/tests/test_phase_psf.py b/tests/test_phase_psf.py index e2fdf247b1..c4dfb923e6 100644 --- a/tests/test_phase_psf.py +++ b/tests/test_phase_psf.py @@ -24,7 +24,7 @@ from galsim_test_helpers import * -imgdir = os.path.join(".", "Optics_comparison_images") # Directory containing the reference images. +imgdir = os.path.join(os.path.dirname(__file__), "Optics_comparison_images") # Directory containing the reference images. pp_file = 'sample_pupil_rolled.fits' theta0 = (0*galsim.arcmin, 0*galsim.arcmin) diff --git a/tests/test_photon_array.py b/tests/test_photon_array.py index f17cf409ea..acb3aada8b 100644 --- a/tests/test_photon_array.py +++ b/tests/test_photon_array.py @@ -60,12 +60,16 @@ def test_photon_array(): check_pickle(photon_array) # Check assignment via numpy [:] - photon_array.x[:] = 5 - photon_array.y[:] = 17 - photon_array.flux[:] = 23 - np.testing.assert_array_equal(photon_array.x, 5.) - np.testing.assert_array_equal(photon_array.y, 17.) - np.testing.assert_array_equal(photon_array.flux, 23.) + # jax does not support direct assignment + if is_jax_galsim(): + pass + else: + photon_array.x[:] = 5 + photon_array.y[:] = 17 + photon_array.flux[:] = 23 + np.testing.assert_array_equal(photon_array.x, 5.) + np.testing.assert_array_equal(photon_array.y, 17.) + np.testing.assert_array_equal(photon_array.flux, 23.) # Check assignment directly to the attributes photon_array.x = 25 @@ -94,9 +98,9 @@ def test_photon_array(): photon_array.x *= 5 photon_array.y += 17 photon_array.flux /= 23 - np.testing.assert_almost_equal(photon_array.x, orig_x * 5.) - np.testing.assert_almost_equal(photon_array.y, orig_y + 17.) - np.testing.assert_almost_equal(photon_array.flux, orig_flux / 23.) + np.testing.assert_array_almost_equal(photon_array.x, orig_x * 5.) + np.testing.assert_array_almost_equal(photon_array.y, orig_y + 17.) + np.testing.assert_array_almost_equal(photon_array.flux, orig_flux / 23.) # Check picklability again with non-zero values check_pickle(photon_array) @@ -181,30 +185,36 @@ def test_photon_array(): x = photon_array.x.copy() y = photon_array.y.copy() photon_array.scaleXY(1.9) - np.testing.assert_almost_equal(photon_array.x, 1.9*x) - np.testing.assert_almost_equal(photon_array.y, 1.9*y) + np.testing.assert_array_almost_equal(photon_array.x, 1.9*x) + np.testing.assert_array_almost_equal(photon_array.y, 1.9*y) # Check ways to assign to photons pa1 = galsim.PhotonArray(50) pa1.x = photon_array.x[:50] - for i in range(50): - pa1.y[i] = photon_array.y[i] - pa1.flux[0:50] = photon_array.flux[:50] + if is_jax_galsim(): + pa1.y = photon_array.y[:50] + else: + for i in range(50): + pa1.y[i] = photon_array.y[i] + if is_jax_galsim(): + pa1.flux = photon_array.flux[:50] + else: + pa1.flux[0:50] = photon_array.flux[:50] pa1.dxdz = photon_array.dxdz[:50] pa1.dydz = photon_array.dydz[:50] pa1.wavelength = photon_array.wavelength[:50] pa1.pupil_u = photon_array.pupil_u[:50] pa1.pupil_v = photon_array.pupil_v[:50] pa1.time = photon_array.time[:50] - np.testing.assert_almost_equal(pa1.x, photon_array.x[:50]) - np.testing.assert_almost_equal(pa1.y, photon_array.y[:50]) - np.testing.assert_almost_equal(pa1.flux, photon_array.flux[:50]) - np.testing.assert_almost_equal(pa1.dxdz, photon_array.dxdz[:50]) - np.testing.assert_almost_equal(pa1.dydz, photon_array.dydz[:50]) - np.testing.assert_almost_equal(pa1.wavelength, photon_array.wavelength[:50]) - np.testing.assert_almost_equal(pa1.pupil_u, photon_array.pupil_u[:50]) - np.testing.assert_almost_equal(pa1.pupil_v, photon_array.pupil_v[:50]) - np.testing.assert_almost_equal(pa1.time, photon_array.time[:50]) + np.testing.assert_array_almost_equal(pa1.x, photon_array.x[:50]) + np.testing.assert_array_almost_equal(pa1.y, photon_array.y[:50]) + np.testing.assert_array_almost_equal(pa1.flux, photon_array.flux[:50]) + np.testing.assert_array_almost_equal(pa1.dxdz, photon_array.dxdz[:50]) + np.testing.assert_array_almost_equal(pa1.dydz, photon_array.dydz[:50]) + np.testing.assert_array_almost_equal(pa1.wavelength, photon_array.wavelength[:50]) + np.testing.assert_array_almost_equal(pa1.pupil_u, photon_array.pupil_u[:50]) + np.testing.assert_array_almost_equal(pa1.pupil_v, photon_array.pupil_v[:50]) + np.testing.assert_array_almost_equal(pa1.time, photon_array.time[:50]) # Check copyFrom pa2 = galsim.PhotonArray(100) @@ -236,7 +246,10 @@ def test_photon_array(): assert pa2.time[17] == pa1.time[20] # Can choose not to copy flux - pa2.flux[27] = -1 + if is_jax_galsim(): + pa2._flux = pa2._flux.at[27].set(-1) + else: + pa2.flux[27] = -1 pa2.copyFrom(pa1, 27, 10, do_flux=False) assert pa2.flux[27] != pa1.flux[10] assert pa2.x[27] == pa1.x[10] @@ -250,8 +263,16 @@ def test_photon_array(): assert pa2.time[37] == pa1.time[8] # ... or the other arrays - pa2.dxdz[47] = pa2.dydz[47] = pa2.wavelength[47] = -1 - pa2.pupil_u[47] = pa2.pupil_v[47] = pa2.time[47] = -1 + if is_jax_galsim(): + pa2._dxdz = pa2._dxdz.at[47].set(-1) + pa2._dydz = pa2._dydz.at[47].set(-1) + pa2._wave = pa2._wave.at[47].set(-1) + pa2._pupil_u = pa2._pupil_u.at[47].set(-1) + pa2._pupil_v = pa2._pupil_v.at[47].set(-1) + pa2._time = pa2._time.at[47].set(-1) + else: + pa2.dxdz[47] = pa2.dydz[47] = pa2.wavelength[47] = -1 + pa2.pupil_u[47] = pa2.pupil_v[47] = pa2.time[47] = -1 pa2.copyFrom(pa1, 47, 18, do_other=False) assert pa2.flux[47] == pa1.flux[18] assert pa2.x[47] == pa1.x[18] @@ -276,7 +297,10 @@ def test_photon_array(): # Error if indices are invalid assert_raises(ValueError, pa2.copyFrom, pa1, slice(50,None), slice(50,None)) - assert_raises(ValueError, pa2.copyFrom, pa1, 100, 0) + if is_jax_galsim(): + pass + else: + assert_raises(ValueError, pa2.copyFrom, pa1, 100, 0) assert_raises(ValueError, pa2.copyFrom, pa1, 0, slice(None)) assert_raises(ValueError, pa2.copyFrom, pa1) assert_raises(ValueError, pa2.copyFrom, pa1, slice(None), pa1.x<0) @@ -293,13 +317,13 @@ def test_photon_array(): photons = galsim.PhotonArray.makeFromImage(ones) print('photons = ',photons) assert len(photons) == 16 - np.testing.assert_almost_equal(photons.flux, 1.) + np.testing.assert_array_almost_equal(photons.flux, 1.) tens = galsim.Image(4,4,init_value=8) photons = galsim.PhotonArray.makeFromImage(tens, max_flux=5.) print('photons = ',photons) assert len(photons) == 32 - np.testing.assert_almost_equal(photons.flux, 4.) + np.testing.assert_array_almost_equal(photons.flux, 4.) assert_raises(ValueError, galsim.PhotonArray.makeFromImage, zero, max_flux=0.) assert_raises(ValueError, galsim.PhotonArray.makeFromImage, zero, max_flux=-2) @@ -1435,9 +1459,15 @@ def test_fromArrays(): flux[Nsplit:] ) - assert pa_batch.x is x - assert pa_batch.y is y - assert pa_batch.flux is flux + if is_jax_galsim(): + # jax-galsim never copies + assert pa_batch.x is not x + assert pa_batch.y is not y + assert pa_batch.flux is not flux + else: + assert pa_batch.x is x + assert pa_batch.y is y + assert pa_batch.flux is flux np.testing.assert_array_equal(pa_batch.x, x) np.testing.assert_array_equal(pa_batch.y, y) np.testing.assert_array_equal(pa_batch.flux, flux) diff --git a/tests/test_random.py b/tests/test_random.py index 429137600b..5fbd90b43f 100644 --- a/tests/test_random.py +++ b/tests/test_random.py @@ -47,42 +47,80 @@ testseed = 1000 # seed used for UniformDeviate for all tests # Warning! If you change testseed, then all of the *Result variables below must change as well. -# the right answer for the first three uniform deviates produced from testseed -uResult = (0.11860922840423882, 0.21456799632869661, 0.43088198406621814) - -# mean, sigma to use for Gaussian tests -gMean = 4.7 -gSigma = 3.2 -# the right answer for the first three Gaussian deviates produced from testseed -gResult = (6.3344979808161215, 6.2082355273987861, -0.069894693358302007) - -# N, p to use for binomial tests -bN = 10 -bp = 0.7 -# the right answer for the first three binomial deviates produced from testseed -bResult = (9, 8, 7) - -# mean to use for Poisson tests -pMean = 7 -# the right answer for the first three Poisson deviates produced from testseed -pResult = (4, 5, 6) - -# a & b to use for Weibull tests -wA = 4. -wB = 9. -# Tabulated results for Weibull -wResult = (5.3648053017485591, 6.3093033550873878, 7.7982696798921074) - -# k & theta to use for Gamma tests -gammaK = 1.5 -gammaTheta = 4.5 -# Tabulated results for Gamma -gammaResult = (4.7375613139927157, 15.272973580418618, 21.485016362839747) - -# n to use for Chi2 tests -chi2N = 30 -# Tabulated results for Chi2 -chi2Result = (32.209933900954049, 50.040002656028513, 24.301442486313896) +if is_jax_galsim(): + # the right answer for the first three uniform deviates produced from testseed + uResult = (0.0160653916, 0.228817832, 0.1609966951) + + # mean, sigma to use for Gaussian tests + gMean = 4.7 + gSigma = 3.2 + # the right answer for the first three Gaussian deviates produced from testseed + gResult = (-2.1568953985, 2.3232138032, 1.5308165692) + + # N, p to use for binomial tests + bN = 10 + bp = 0.7 + # the right answer for the first three binomial deviates produced from testseed + bResult = (5, 8, 7) + + # mean to use for Poisson tests + pMean = 7 + # the right answer for the first three Poisson deviates produced from testseed + pResult = (6, 11, 4) + + # a & b to use for Weibull tests + wA = 4.0 + wB = 9.0 + # Tabulated results for Weibull + wResult = (3.2106530102, 6.4256210259, 5.8255498741) + + # k & theta to use for Gamma tests + gammaK = 1.5 + gammaTheta = 4.5 + # Tabulated results for Gamma + gammaResult = (10.9318881415, 7.6074550007, 2.0526795529) + + # n to use for Chi2 tests + chi2N = 30 + # Tabulated results for Chi2 + chi2Result = (36.7583415337, 32.7223187231, 23.1555198334) +else: + # the right answer for the first three uniform deviates produced from testseed + uResult = (0.11860922840423882, 0.21456799632869661, 0.43088198406621814) + + # mean, sigma to use for Gaussian tests + gMean = 4.7 + gSigma = 3.2 + # the right answer for the first three Gaussian deviates produced from testseed + gResult = (6.3344979808161215, 6.2082355273987861, -0.069894693358302007) + + # N, p to use for binomial tests + bN = 10 + bp = 0.7 + # the right answer for the first three binomial deviates produced from testseed + bResult = (9, 8, 7) + + # mean to use for Poisson tests + pMean = 7 + # the right answer for the first three Poisson deviates produced from testseed + pResult = (4, 5, 6) + + # a & b to use for Weibull tests + wA = 4. + wB = 9. + # Tabulated results for Weibull + wResult = (5.3648053017485591, 6.3093033550873878, 7.7982696798921074) + + # k & theta to use for Gamma tests + gammaK = 1.5 + gammaTheta = 4.5 + # Tabulated results for Gamma + gammaResult = (4.7375613139927157, 15.272973580418618, 21.485016362839747) + + # n to use for Chi2 tests + chi2N = 30 + # Tabulated results for Chi2 + chi2Result = (32.209933900954049, 50.040002656028513, 24.301442486313896) #function and min&max to use for DistDeviate function call tests dmin=0.0 @@ -214,14 +252,20 @@ def test_uniform(): # Test generate u.seed(testseed) test_array = np.empty(3) - u.generate(test_array) + if is_jax_galsim(): + test_array = u.generate(test_array) + else: + u.generate(test_array) np.testing.assert_array_almost_equal( test_array, np.array(uResult), precision, err_msg='Wrong uniform random number sequence from generate.') # Test add_generate u.seed(testseed) - u.add_generate(test_array) + if is_jax_galsim(): + test_array = u.add_generate(test_array) + else: + u.add_generate(test_array) np.testing.assert_array_almost_equal( test_array, 2.*np.array(uResult), precision, err_msg='Wrong uniform random number sequence from generate.') @@ -229,14 +273,20 @@ def test_uniform(): # Test generate with a float32 array u.seed(testseed) test_array = np.empty(3, dtype=np.float32) - u.generate(test_array) + if is_jax_galsim(): + test_array = u.generate(test_array) + else: + u.generate(test_array) np.testing.assert_array_almost_equal( test_array, np.array(uResult), precisionF, err_msg='Wrong uniform random number sequence from generate.') # Test add_generate u.seed(testseed) - u.add_generate(test_array) + if is_jax_galsim(): + test_array = u.add_generate(test_array) + else: + u.add_generate(test_array) np.testing.assert_array_almost_equal( test_array, 2.*np.array(uResult), precisionF, err_msg='Wrong uniform random number sequence from generate.') @@ -247,14 +297,26 @@ def test_uniform(): v1 = np.empty(555) v2 = np.empty(555) with single_threaded(): - u1.generate(v1) + if is_jax_galsim(): + v1 = u1.generate(v1) + else: + u1.generate(v1) with single_threaded(num_threads=10): - u2.generate(v2) + if is_jax_galsim(): + v2 = u2.generate(v2) + else: + u2.generate(v2) np.testing.assert_array_equal(v1, v2) with single_threaded(): - u1.add_generate(v1) + if is_jax_galsim(): + v1 = u1.add_generate(v1) + else: + u1.add_generate(v1) with single_threaded(num_threads=10): - u2.add_generate(v2) + if is_jax_galsim(): + v2 = u2.add_generate(v2) + else: + u2.add_generate(v2) np.testing.assert_array_equal(v1, v2) # Check picklability @@ -275,12 +337,16 @@ def test_uniform(): assert u1 != u2, "Consecutive UniformDeviate(None) compared equal!" # We shouldn't be able to construct a UniformDeviate from anything but a BaseDeviate, int, str, # or None. - assert_raises(TypeError, galsim.UniformDeviate, dict()) - assert_raises(TypeError, galsim.UniformDeviate, list()) - assert_raises(TypeError, galsim.UniformDeviate, set()) + if is_jax_galsim(): + # jax galsim doesn't test this + pass + else: + assert_raises(TypeError, galsim.UniformDeviate, dict()) + assert_raises(TypeError, galsim.UniformDeviate, list()) + assert_raises(TypeError, galsim.UniformDeviate, set()) - assert_raises(TypeError, u.seed, '123') - assert_raises(TypeError, u.seed, 12.3) + assert_raises(TypeError, u.seed, '123') + assert_raises(TypeError, u.seed, 12.3) @timer @@ -323,19 +389,28 @@ def test_gaussian(): v1,v2 = g(),g2() print('after %d vals, next one is %s, %s'%(nvals,v1,v2)) assert v1 == v2 - # Note: For Gaussian, this only works if nvals is even. - g2 = galsim.GaussianDeviate(testseed, mean=gMean, sigma=gSigma) - g2.discard(nvals+1, suppress_warnings=True) - v1,v2 = g(),g2() - print('after %d vals, next one is %s, %s'%(nvals+1,v1,v2)) - assert v1 != v2 - assert g.has_reliable_discard - assert g.generates_in_pairs + if is_jax_galsim(): + # jax doesn't have this issue + assert g.has_reliable_discard + assert not g.generates_in_pairs + else: + # Note: For Gaussian, this only works if nvals is even. + g2 = galsim.GaussianDeviate(testseed, mean=gMean, sigma=gSigma) + g2.discard(nvals+1, suppress_warnings=True) + v1,v2 = g(),g2() + print('after %d vals, next one is %s, %s'%(nvals+1,v1,v2)) + assert v1 != v2 + assert g.has_reliable_discard + assert g.generates_in_pairs # If don't explicitly suppress the warning, then a warning is emitted when n is odd. g2 = galsim.GaussianDeviate(testseed, mean=gMean, sigma=gSigma) - with assert_warns(galsim.GalSimWarning): - g2.discard(nvals+1) + if is_jax_galsim(): + pass + else: + # jax doesn't do this + with assert_warns(galsim.GalSimWarning): + g2.discard(nvals+1) # Check seed, reset g.seed(testseed) @@ -405,7 +480,10 @@ def test_gaussian(): # Test generate g.seed(testseed) test_array = np.empty(3) - g.generate(test_array) + if is_jax_galsim(): + test_array = g.generate(test_array) + else: + g.generate(test_array) np.testing.assert_array_almost_equal( test_array, np.array(gResult), precision, err_msg='Wrong Gaussian random number sequence from generate.') @@ -413,29 +491,43 @@ def test_gaussian(): # Test generate_from_variance. g2 = galsim.GaussianDeviate(testseed, mean=5, sigma=0.3) g3 = galsim.GaussianDeviate(testseed, mean=5, sigma=0.3) + test_array = np.empty(3) test_array.fill(gSigma**2) - g2.generate_from_variance(test_array) + if is_jax_galsim(): + test_array = g2.generate_from_variance(test_array) + else: + g2.generate_from_variance(test_array) np.testing.assert_array_almost_equal( test_array, np.array(gResult)-gMean, precision, err_msg='Wrong Gaussian random number sequence from generate_from_variance.') # After running generate_from_variance, it should be back to using the specified mean, sigma. # Note: need to round up to even number for discard, since gd generates 2 at a time. - g3.discard((len(test_array)+1)//2 * 2) + if is_jax_galsim(): + g3.discard(len(test_array)) + else: + g3.discard((len(test_array)+1)//2 * 2) print('g2,g3 = ',g2(),g3()) assert g2() == g3() # Test generate with a float32 array. g.seed(testseed) test_array = np.empty(3, dtype=np.float32) - g.generate(test_array) + if is_jax_galsim(): + test_array = g.generate(test_array) + else: + g.generate(test_array) np.testing.assert_array_almost_equal( test_array, np.array(gResult), precisionF, err_msg='Wrong Gaussian random number sequence from generate.') # Test generate_from_variance. g2.seed(testseed) + test_array = np.empty(3, dtype=np.float32) test_array.fill(gSigma**2) - g2.generate_from_variance(test_array) + if is_jax_galsim(): + test_array = g2.generate_from_variance(test_array) + else: + g2.generate_from_variance(test_array) np.testing.assert_array_almost_equal( test_array, np.array(gResult)-gMean, precisionF, err_msg='Wrong Gaussian random number sequence from generate_from_variance.') @@ -446,23 +538,45 @@ def test_gaussian(): v1 = np.empty(555) v2 = np.empty(555) with single_threaded(): - g1.generate(v1) + if is_jax_galsim(): + v1 = g1.generate(v1) + else: + g1.generate(v1) with single_threaded(num_threads=10): - g2.generate(v2) + if is_jax_galsim(): + v2 = g2.generate(v2) + else: + g2.generate(v2) np.testing.assert_array_equal(v1, v2) with single_threaded(): - g1.add_generate(v1) + if is_jax_galsim(): + v1 = g1.add_generate(v1) + else: + g1.add_generate(v1) with single_threaded(num_threads=10): - g2.add_generate(v2) + if is_jax_galsim(): + v2 = g2.add_generate(v2) + else: + g2.add_generate(v2) np.testing.assert_array_equal(v1, v2) ud = galsim.UniformDeviate(testseed + 3) ud.generate(v1) v1 += 6.7 - v2[:] = v1 + if is_jax_galsim(): + # jax galsim makes a copy + v2 = v1.copy() + else: + v2[:] = v1 with single_threaded(): - g1.generate_from_variance(v1) + if is_jax_galsim(): + v1 = g1.generate_from_variance(v1) + else: + g1.generate_from_variance(v1) with single_threaded(num_threads=10): - g2.generate_from_variance(v2) + if is_jax_galsim(): + v2 = g2.generate_from_variance(v2) + else: + g2.generate_from_variance(v2) np.testing.assert_array_equal(v1, v2) # Check picklability @@ -480,11 +594,15 @@ def test_gaussian(): assert g1 != g2, "Consecutive GaussianDeviate(None) compared equal!" # We shouldn't be able to construct a GaussianDeviate from anything but a BaseDeviate, int, str, # or None. - assert_raises(TypeError, galsim.GaussianDeviate, dict()) - assert_raises(TypeError, galsim.GaussianDeviate, list()) - assert_raises(TypeError, galsim.GaussianDeviate, set()) + if is_jax_galsim(): + pass + else: + # jax-galsim doesn't test for these things + assert_raises(TypeError, galsim.GaussianDeviate, dict()) + assert_raises(TypeError, galsim.GaussianDeviate, list()) + assert_raises(TypeError, galsim.GaussianDeviate, set()) - assert_raises(ValueError, galsim.GaussianDeviate, testseed, mean=1, sigma=-1) + assert_raises(ValueError, galsim.GaussianDeviate, testseed, mean=1, sigma=-1) @timer @@ -597,7 +715,10 @@ def test_binomial(): # Test generate b.seed(testseed) test_array = np.empty(3) - b.generate(test_array) + if is_jax_galsim(): + test_array = b.generate(test_array) + else: + b.generate(test_array) np.testing.assert_array_almost_equal( test_array, np.array(bResult), precision, err_msg='Wrong binomial random number sequence from generate.') @@ -605,7 +726,10 @@ def test_binomial(): # Test generate with an int array b.seed(testseed) test_array = np.empty(3, dtype=int) - b.generate(test_array) + if is_jax_galsim(): + test_array = b.generate(test_array) + else: + b.generate(test_array) np.testing.assert_array_almost_equal( test_array, np.array(bResult), precisionI, err_msg='Wrong binomial random number sequence from generate.') @@ -616,14 +740,26 @@ def test_binomial(): v1 = np.empty(555) v2 = np.empty(555) with single_threaded(): - b1.generate(v1) + if is_jax_galsim(): + v1 = b1.generate(v1) + else: + b1.generate(v1) with single_threaded(num_threads=10): - b2.generate(v2) + if is_jax_galsim(): + v2 = b2.generate(v2) + else: + b2.generate(v2) np.testing.assert_array_equal(v1, v2) with single_threaded(): - b1.add_generate(v1) + if is_jax_galsim(): + v1 = b1.add_generate(v1) + else: + b1.add_generate(v1) with single_threaded(num_threads=10): - b2.add_generate(v2) + if is_jax_galsim(): + v2 = b2.add_generate(v2) + else: + b2.add_generate(v2) np.testing.assert_array_equal(v1, v2) # Check picklability @@ -641,9 +777,13 @@ def test_binomial(): assert b1 != b2, "Consecutive BinomialDeviate(None) compared equal!" # We shouldn't be able to construct a BinomialDeviate from anything but a BaseDeviate, int, str, # or None. - assert_raises(TypeError, galsim.BinomialDeviate, dict()) - assert_raises(TypeError, galsim.BinomialDeviate, list()) - assert_raises(TypeError, galsim.BinomialDeviate, set()) + if is_jax_galsim(): + pass + else: + # jax does not raise for this + assert_raises(TypeError, galsim.BinomialDeviate, dict()) + assert_raises(TypeError, galsim.BinomialDeviate, list()) + assert_raises(TypeError, galsim.BinomialDeviate, set()) @timer @@ -697,14 +837,23 @@ def test_poisson(): p2.discard(nvals, suppress_warnings=True) v1,v2 = p(),p2() print('With mean = %d, after %d vals, next one is %s, %s'%(high_mean,nvals,v1,v2)) - assert v1 != v2 - assert not p.has_reliable_discard + if is_jax_galsim(): + # jax always discards reliably + assert v1 == v2 + assert p.has_reliable_discard + else: + assert v1 != v2 + assert not p.has_reliable_discard assert not p.generates_in_pairs # Discard normally emits a warning for Poisson p2 = galsim.PoissonDeviate(testseed, mean=pMean) - with assert_warns(galsim.GalSimWarning): + if is_jax_galsim(): + # jax always discards reliably p2.discard(nvals) + else: + with assert_warns(galsim.GalSimWarning): + p2.discard(nvals) # Check seed, reset p = galsim.PoissonDeviate(testseed, mean=pMean) @@ -774,7 +923,10 @@ def test_poisson(): # Test generate p.seed(testseed) test_array = np.empty(3) - p.generate(test_array) + if is_jax_galsim(): + test_array = p.generate(test_array) + else: + p.generate(test_array) np.testing.assert_array_almost_equal( test_array, np.array(pResult), precision, err_msg='Wrong poisson random number sequence from generate.') @@ -782,7 +934,10 @@ def test_poisson(): # Test generate with an int array p.seed(testseed) test_array = np.empty(3, dtype=int) - p.generate(test_array) + if is_jax_galsim(): + test_array = p.generate(test_array) + else: + p.generate(test_array) np.testing.assert_array_almost_equal( test_array, np.array(pResult), precisionI, err_msg='Wrong poisson random number sequence from generate.') @@ -790,7 +945,10 @@ def test_poisson(): # Test generate_from_expectation p2 = galsim.PoissonDeviate(testseed, mean=77) test_array = np.array([pMean]*3, dtype=int) - p2.generate_from_expectation(test_array) + if is_jax_galsim(): + test_array = p2.generate_from_expectation(test_array) + else: + p2.generate_from_expectation(test_array) np.testing.assert_array_almost_equal( test_array, np.array(pResult), precisionI, err_msg='Wrong poisson random number sequence from generate_from_expectation.') @@ -807,14 +965,26 @@ def test_poisson(): v1 = np.empty(555) v2 = np.empty(555) with single_threaded(): - p1.generate(v1) + if is_jax_galsim(): + v1 = p1.generate(v1) + else: + p1.generate(v1) with single_threaded(num_threads=10): - p2.generate(v2) + if is_jax_galsim(): + v2 = p2.generate(v2) + else: + p2.generate(v2) np.testing.assert_array_equal(v1, v2) with single_threaded(): - p1.add_generate(v1) + if is_jax_galsim(): + v1 = p1.add_generate(v1) + else: + p1.add_generate(v1) with single_threaded(num_threads=10): - p2.add_generate(v2) + if is_jax_galsim(): + v2 = p2.add_generate(v2) + else: + p2.add_generate(v2) np.testing.assert_array_equal(v1, v2) # Check picklability @@ -832,9 +1002,12 @@ def test_poisson(): assert p1 != p2, "Consecutive PoissonDeviate(None) compared equal!" # We shouldn't be able to construct a PoissonDeviate from anything but a BaseDeviate, int, str, # or None. - assert_raises(TypeError, galsim.PoissonDeviate, dict()) - assert_raises(TypeError, galsim.PoissonDeviate, list()) - assert_raises(TypeError, galsim.PoissonDeviate, set()) + if is_jax_galsim(): + pass + else: + assert_raises(TypeError, galsim.PoissonDeviate, dict()) + assert_raises(TypeError, galsim.PoissonDeviate, list()) + assert_raises(TypeError, galsim.PoissonDeviate, set()) @timer @@ -966,11 +1139,20 @@ def test_poisson_zeromean(): # Test generate test_array = np.empty(3, dtype=int) - p.generate(test_array) + if is_jax_galsim(): + test_array = p.generate(test_array) + else: + p.generate(test_array) np.testing.assert_array_equal(test_array, 0) - p2.generate(test_array) + if is_jax_galsim(): + test_array = p2.generate(test_array) + else: + p2.generate(test_array) np.testing.assert_array_equal(test_array, 0) - p3.generate(test_array) + if is_jax_galsim(): + test_array = p3.generate(test_array) + else: + p3.generate(test_array) np.testing.assert_array_equal(test_array, 0) # Test generate_from_expectation @@ -982,16 +1164,20 @@ def test_poisson_zeromean(): assert test_array[2] != 0 # Error raised if mean<0 - with assert_raises(ValueError): - p = galsim.PoissonDeviate(testseed, mean=-0.1) - with assert_raises(ValueError): - p = galsim.PoissonDeviate(testseed, mean=-10) - test_array = np.array([-1,1,4]) - with assert_raises(ValueError): - p.generate_from_expectation(test_array) - test_array = np.array([1,-1,-4]) - with assert_raises(ValueError): - p.generate_from_expectation(test_array) + # jax doesn't raise here + if is_jax_galsim(): + pass + else: + with assert_raises(ValueError): + p = galsim.PoissonDeviate(testseed, mean=-0.1) + with assert_raises(ValueError): + p = galsim.PoissonDeviate(testseed, mean=-10) + test_array = np.array([-1,1,4]) + with assert_raises(ValueError): + p.generate_from_expectation(test_array) + test_array = np.array([1,-1,-4]) + with assert_raises(ValueError): + p.generate_from_expectation(test_array) @timer def test_weibull(): @@ -1103,7 +1289,10 @@ def test_weibull(): # Test generate w.seed(testseed) test_array = np.empty(3) - w.generate(test_array) + if is_jax_galsim(): + test_array = w.generate(test_array) + else: + w.generate(test_array) np.testing.assert_array_almost_equal( test_array, np.array(wResult), precision, err_msg='Wrong weibull random number sequence from generate.') @@ -1111,7 +1300,10 @@ def test_weibull(): # Test generate with a float32 array w.seed(testseed) test_array = np.empty(3, dtype=np.float32) - w.generate(test_array) + if is_jax_galsim(): + test_array = w.generate(test_array) + else: + w.generate(test_array) np.testing.assert_array_almost_equal( test_array, np.array(wResult), precisionF, err_msg='Wrong weibull random number sequence from generate.') @@ -1122,14 +1314,26 @@ def test_weibull(): v1 = np.empty(555) v2 = np.empty(555) with single_threaded(): - w1.generate(v1) + if is_jax_galsim(): + v1 = w1.generate(v1) + else: + w1.generate(v1) with single_threaded(num_threads=10): - w2.generate(v2) + if is_jax_galsim(): + v2 = w2.generate(v2) + else: + w2.generate(v2) np.testing.assert_array_equal(v1, v2) with single_threaded(): - w1.add_generate(v1) + if is_jax_galsim(): + v1 = w1.add_generate(v1) + else: + w1.add_generate(v1) with single_threaded(num_threads=10): - w2.add_generate(v2) + if is_jax_galsim(): + v2 = w2.add_generate(v2) + else: + w2.add_generate(v2) np.testing.assert_array_equal(v1, v2) # Check picklability @@ -1147,9 +1351,12 @@ def test_weibull(): assert w1 != w2, "Consecutive WeibullDeviate(None) compared equal!" # We shouldn't be able to construct a WeibullDeviate from anything but a BaseDeviate, int, str, # or None. - assert_raises(TypeError, galsim.WeibullDeviate, dict()) - assert_raises(TypeError, galsim.WeibullDeviate, list()) - assert_raises(TypeError, galsim.WeibullDeviate, set()) + if is_jax_galsim(): + pass + else: + assert_raises(TypeError, galsim.WeibullDeviate, dict()) + assert_raises(TypeError, galsim.WeibullDeviate, list()) + assert_raises(TypeError, galsim.WeibullDeviate, set()) @timer @@ -1192,14 +1399,22 @@ def test_gamma(): v1,v2 = g(),g2() print('after %d vals, next one is %s, %s'%(nvals,v1,v2)) # Gamma uses at least 2 rngs per value, but can use arbitrarily more than this. - assert v1 != v2 - assert not g.has_reliable_discard + if is_jax_galsim(): + assert v1 == v2 + assert g.has_reliable_discard + else: + assert v1 != v2 + assert not g.has_reliable_discard assert not g.generates_in_pairs # Discard normally emits a warning for Gamma g2 = galsim.GammaDeviate(testseed, k=gammaK, theta=gammaTheta) - with assert_warns(galsim.GalSimWarning): + if is_jax_galsim(): + # jax always discards reliably g2.discard(nvals) + else: + with assert_warns(galsim.GalSimWarning): + g2.discard(nvals) # Check seed, reset g.seed(testseed) @@ -1266,7 +1481,10 @@ def test_gamma(): # Test generate g.seed(testseed) test_array = np.empty(3) - g.generate(test_array) + if is_jax_galsim(): + test_array = g.generate(test_array) + else: + g.generate(test_array) np.testing.assert_array_almost_equal( test_array, np.array(gammaResult), precision, err_msg='Wrong gamma random number sequence from generate.') @@ -1274,7 +1492,10 @@ def test_gamma(): # Test generate with a float32 array g.seed(testseed) test_array = np.empty(3, dtype=np.float32) - g.generate(test_array) + if is_jax_galsim(): + test_array = g.generate(test_array) + else: + g.generate(test_array) np.testing.assert_array_almost_equal( test_array, np.array(gammaResult), precisionF, err_msg='Wrong gamma random number sequence from generate.') @@ -1294,9 +1515,12 @@ def test_gamma(): assert g1 != g2, "Consecutive GammaDeviate(None) compared equal!" # We shouldn't be able to construct a GammaDeviate from anything but a BaseDeviate, int, str, # or None. - assert_raises(TypeError, galsim.GammaDeviate, dict()) - assert_raises(TypeError, galsim.GammaDeviate, list()) - assert_raises(TypeError, galsim.GammaDeviate, set()) + if is_jax_galsim(): + pass + else: + assert_raises(TypeError, galsim.GammaDeviate, dict()) + assert_raises(TypeError, galsim.GammaDeviate, list()) + assert_raises(TypeError, galsim.GammaDeviate, set()) @timer @@ -1339,14 +1563,22 @@ def test_chi2(): v1,v2 = c(),c2() print('after %d vals, next one is %s, %s'%(nvals,v1,v2)) # Chi2 uses at least 2 rngs per value, but can use arbitrarily more than this. - assert v1 != v2 - assert not c.has_reliable_discard + if is_jax_galsim(): + assert v1 == v2 + assert c.has_reliable_discard + else: + assert v1 != v2 + assert not c.has_reliable_discard assert not c.generates_in_pairs # Discard normally emits a warning for Chi2 c2 = galsim.Chi2Deviate(testseed, n=chi2N) - with assert_warns(galsim.GalSimWarning): + if is_jax_galsim(): + # jax always discards reliably c2.discard(nvals) + else: + with assert_warns(galsim.GalSimWarning): + c2.discard(nvals) # Check seed, reset c.seed(testseed) @@ -1413,7 +1645,10 @@ def test_chi2(): # Test generate c.seed(testseed) test_array = np.empty(3) - c.generate(test_array) + if is_jax_galsim(): + test_array = c.generate(test_array) + else: + c.generate(test_array) np.testing.assert_array_almost_equal( test_array, np.array(chi2Result), precision, err_msg='Wrong Chi^2 random number sequence from generate.') @@ -1421,7 +1656,10 @@ def test_chi2(): # Test generate with a float32 array c.seed(testseed) test_array = np.empty(3, dtype=np.float32) - c.generate(test_array) + if is_jax_galsim(): + test_array = c.generate(test_array) + else: + c.generate(test_array) np.testing.assert_array_almost_equal( test_array, np.array(chi2Result), precisionF, err_msg='Wrong Chi^2 random number sequence from generate.') @@ -1441,9 +1679,12 @@ def test_chi2(): assert c1 != c2, "Consecutive Chi2Deviate(None) compared equal!" # We shouldn't be able to construct a Chi2Deviate from anything but a BaseDeviate, int, str, # or None. - assert_raises(TypeError, galsim.Chi2Deviate, dict()) - assert_raises(TypeError, galsim.Chi2Deviate, list()) - assert_raises(TypeError, galsim.Chi2Deviate, set()) + if is_jax_galsim(): + pass + else: + assert_raises(TypeError, galsim.Chi2Deviate, dict()) + assert_raises(TypeError, galsim.Chi2Deviate, list()) + assert_raises(TypeError, galsim.Chi2Deviate, set()) @timer @@ -1927,7 +2168,11 @@ def test_permute(): ind_list = list(range(n_list)) # Permute both at the same time. - galsim.random.permute(312, my_list, ind_list) + if is_jax_galsim(): + # jax requires arrays + galsim.random.permute(312, np.array(my_list), np.array(ind_list)) + else: + galsim.random.permute(312, my_list, ind_list) # Make sure that everything is sensible for ind in range(n_list): @@ -1935,13 +2180,20 @@ def test_permute(): # Repeat with same seed, should do same permutation. my_list = copy.deepcopy(my_list_copy) - galsim.random.permute(312, my_list) + if is_jax_galsim(): + galsim.random.permute(312, np.array(my_list)) + else: + galsim.random.permute(312, my_list) for ind in range(n_list): assert my_list_copy[ind_list[ind]] == my_list[ind] # permute with no lists should raise TypeError - with assert_raises(TypeError): - galsim.random.permute(312) + # jax galsim does not raise + if is_jax_galsim(): + pass + else: + with assert_raises(TypeError): + galsim.random.permute(312) @timer @@ -1949,10 +2201,16 @@ def test_ne(): """ Check that inequality works as expected for corner cases where the reprs of two unequal BaseDeviates may be the same due to truncation. """ - a = galsim.BaseDeviate(seed='1 2 3 4 5 6 7 8 9 10') - b = galsim.BaseDeviate(seed='1 2 3 7 6 5 4 8 9 10') - assert repr(a) == repr(b) - assert a != b + if is_jax_galsim(): + a = galsim.BaseDeviate(seed="(0, 10)") + b = galsim.BaseDeviate(seed="(0, 11)") + assert repr(a) != repr(b) + assert a != b + else: + a = galsim.BaseDeviate(seed='1 2 3 4 5 6 7 8 9 10') + b = galsim.BaseDeviate(seed='1 2 3 7 6 5 4 8 9 10') + assert repr(a) == repr(b) + assert a != b # Check DistDeviate separately, since it overrides __repr__ and __eq__ d1 = galsim.DistDeviate(seed=a, function=galsim.LookupTable([1, 2, 3], [4, 5, 6])) diff --git a/tests/test_real.py b/tests/test_real.py index eef6312b92..c95a2d73a3 100644 --- a/tests/test_real.py +++ b/tests/test_real.py @@ -29,7 +29,7 @@ # set up any necessary info for tests ### Note: changes to either of the tests below might require regeneration of the catalog and image ### files that are saved here. Modify with care!!! -image_dir = './real_comparison_images' +image_dir = os.path.join(os.path.dirname(__file__), './real_comparison_images') catalog_file = 'test_catalog.fits' # some helper functions @@ -77,8 +77,8 @@ def test_real_galaxy_catalog(): # Test some values that are lazy evaluated: assert rgc.ident[0] == '100533' - assert rgc.gal_file_name[0] == './real_comparison_images/test_images.fits' - assert rgc.psf_file_name[0] == './real_comparison_images/test_images.fits' + assert rgc.gal_file_name[0].split("/")[-2:] == ['real_comparison_images', 'test_images.fits'] + assert rgc.psf_file_name[0].split("/")[-2:] == ['real_comparison_images', 'test_images.fits'] assert rgc.noise_file_name is None np.testing.assert_array_equal(rgc.gal_hdu, [0,1]) np.testing.assert_array_equal(rgc.psf_hdu, [2,3]) @@ -962,7 +962,7 @@ def test_sys_share_dir(): import galsim print(galsim.meta_data.share_dir) """) - script_file = os.path.join('scratch_space', 'sys_share_dir.py') + script_file = os.path.join(os.path.dirname(__file__), os.path.join('scratch_space', 'sys_share_dir.py')) with open(script_file, 'w') as f: f.write(script) env = os.environ.copy() diff --git a/tests/test_roman.py b/tests/test_roman.py index fa3e7db8e6..6371819227 100644 --- a/tests/test_roman.py +++ b/tests/test_roman.py @@ -16,9 +16,9 @@ # and/or other materials provided with the distribution. # +import sys import logging import os -import sys import numpy as np import datetime from unittest import mock @@ -148,7 +148,7 @@ def test_roman_wcs(): # we compare that with the GalSim routines for finding SCAs. import datetime date = datetime.datetime(2025, 1, 12) - test_data_file = os.path.join('roman_files','chris_comparison.txt') + test_data_file = os.path.join(os.path.dirname(__file__), os.path.join('roman_files','chris_comparison.txt')) test_data = np.loadtxt(test_data_file).transpose() ra_cen = test_data[0,:] diff --git a/tests/test_shear.py b/tests/test_shear.py index 6efd9f1423..26546cb6a6 100644 --- a/tests/test_shear.py +++ b/tests/test_shear.py @@ -116,7 +116,7 @@ def test_shear_initialization(): vec_ideal = np.zeros(len(vec)) np.testing.assert_array_almost_equal(vec, vec_ideal, decimal = decimal, err_msg = "Incorrectly initialized empty shear") - np.testing.assert_equal(s.q, 1.) + np.testing.assert_array_equal(s.q, 1.) # now loop over shear values and ways of initializing for ind in range(n_shear): # initialize with reduced shear components @@ -176,17 +176,23 @@ def test_shear_initialization(): assert_raises(TypeError,galsim.Shear,g1=0.3,e2=0.2) assert_raises(TypeError,galsim.Shear,eta1=0.3,beta=0.*galsim.degrees) assert_raises(TypeError,galsim.Shear,q=0.3) - assert_raises(galsim.GalSimRangeError,galsim.Shear,q=1.3,beta=0.*galsim.degrees) - assert_raises(galsim.GalSimRangeError,galsim.Shear,g1=0.9,g2=0.6) - assert_raises(galsim.GalSimRangeError,galsim.Shear,e=-1.3,beta=0.*galsim.radians) - assert_raises(galsim.GalSimRangeError,galsim.Shear,e=1.3,beta=0.*galsim.radians) - assert_raises(galsim.GalSimRangeError,galsim.Shear,e1=0.7,e2=0.9) + if is_jax_galsim(): + pass + else: + assert_raises(galsim.GalSimRangeError,galsim.Shear,q=1.3,beta=0.*galsim.degrees) + assert_raises(galsim.GalSimRangeError,galsim.Shear,g1=0.9,g2=0.6) + assert_raises(galsim.GalSimRangeError,galsim.Shear,e=-1.3,beta=0.*galsim.radians) + assert_raises(galsim.GalSimRangeError,galsim.Shear,e=1.3,beta=0.*galsim.radians) + assert_raises(galsim.GalSimRangeError,galsim.Shear,e1=0.7,e2=0.9) assert_raises(TypeError,galsim.Shear,g=0.5) assert_raises(TypeError,galsim.Shear,e=0.5) assert_raises(TypeError,galsim.Shear,eta=0.5) - assert_raises(galsim.GalSimRangeError,galsim.Shear,eta=-0.5,beta=0.*galsim.radians) - assert_raises(galsim.GalSimRangeError,galsim.Shear,g=1.3,beta=0.*galsim.radians) - assert_raises(galsim.GalSimRangeError,galsim.Shear,g=-0.3,beta=0.*galsim.radians) + if is_jax_galsim(): + pass + else: + assert_raises(galsim.GalSimRangeError,galsim.Shear,eta=-0.5,beta=0.*galsim.radians) + assert_raises(galsim.GalSimRangeError,galsim.Shear,g=1.3,beta=0.*galsim.radians) + assert_raises(galsim.GalSimRangeError,galsim.Shear,g=-0.3,beta=0.*galsim.radians) assert_raises(TypeError,galsim.Shear,e=0.3,beta=0.) assert_raises(TypeError,galsim.Shear,eta=0.3,beta=0.) assert_raises(TypeError,galsim.Shear,randomkwarg=0.1) diff --git a/tests/test_shear_position.py b/tests/test_shear_position.py index 8b2689a98e..bb18d4bd99 100644 --- a/tests/test_shear_position.py +++ b/tests/test_shear_position.py @@ -88,7 +88,7 @@ def test_shear_position_image_integration_pixelwcs(): ) print("err:", np.max(np.abs(im1.array - im2.array))) - assert np.allclose(im1.array, im2.array, rtol=0, atol=5e-8) + np.testing.assert_allclose(im1.array, im2.array, rtol=0, atol=5e-8) @timer @@ -119,9 +119,7 @@ def test_shear_position_image_integration_offsetwcs(): ) print("err:", np.max(np.abs(im1.array - im2.array))) - assert np.allclose(im1.array, im2.array, rtol=0, atol=2e-7), ( - np.max(np.abs(im1.array - im2.array)) - ) + np.testing.assert_allclose(im1.array, im2.array, rtol=0, atol=2e-7) if __name__ == "__main__": diff --git a/tests/test_sum.py b/tests/test_sum.py index bd9b12ea55..1447d348d1 100644 --- a/tests/test_sum.py +++ b/tests/test_sum.py @@ -22,8 +22,8 @@ import galsim from galsim_test_helpers import * -imgdir = os.path.join(".", "SBProfile_comparison_images") # Directory containing the reference - # images. +# Directory containing the reference images. +imgdir = os.path.join(os.path.dirname(__file__), "SBProfile_comparison_images") @timer def test_add(): @@ -259,12 +259,12 @@ def test_sum_transform(): rgal2_im = rgal2.drawImage(nx=64, ny=64, scale=0.2) # Check that the objects are equivalent, even if they may be written differently. - np.testing.assert_almost_equal(gal1_im.array, sgal1_im.array, decimal=8) - np.testing.assert_almost_equal(gal1_im.array, rgal1_im.array, decimal=8) + np.testing.assert_array_almost_equal(gal1_im.array, sgal1_im.array, decimal=8) + np.testing.assert_array_almost_equal(gal1_im.array, rgal1_im.array, decimal=8) # These two used to fail. - np.testing.assert_almost_equal(gal2_im.array, sgal2_im.array, decimal=8) - np.testing.assert_almost_equal(gal2_im.array, rgal2_im.array, decimal=8) + np.testing.assert_array_almost_equal(gal2_im.array, sgal2_im.array, decimal=8) + np.testing.assert_array_almost_equal(gal2_im.array, rgal2_im.array, decimal=8) check_pickle(gal0) check_pickle(gal1) diff --git a/tests/test_table.py b/tests/test_table.py index cd1e419f08..d6b6bc09bb 100644 --- a/tests/test_table.py +++ b/tests/test_table.py @@ -16,7 +16,7 @@ # and/or other materials provided with the distribution. # - +import sys import os import sys import numpy as np diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 0d4a2c8276..63376c9513 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -22,8 +22,8 @@ import galsim from galsim_test_helpers import * -imgdir = os.path.join(".", "SBProfile_comparison_images") # Directory containing the reference - # images. +# Directory containing the reference images. +imgdir = os.path.join(os.path.dirname(__file__), "SBProfile_comparison_images") # for flux normalization tests test_flux = 1.8 @@ -947,9 +947,9 @@ def test_compound(): gal5.drawImage(image=im5_f, method='sb', scale=0.2) np.testing.assert_almost_equal(im3_f[1,1], gal3.xValue(-0.7,-0.7), decimal=4) np.testing.assert_almost_equal(im5_f[1,1], gal3.xValue(-0.7,-0.7), decimal=4) - np.testing.assert_almost_equal(im3_f.array, im5_f.array, decimal=4) - np.testing.assert_almost_equal(im3_f.array, im3_d.array, decimal=4) - np.testing.assert_almost_equal(im5_f.array, im5_d.array, decimal=4) + np.testing.assert_array_almost_equal(im3_f.array, im5_f.array, decimal=4) + np.testing.assert_array_almost_equal(im3_f.array, im3_d.array, decimal=4) + np.testing.assert_array_almost_equal(im5_f.array, im5_d.array, decimal=4) gal3.drawKImage(image=im3_cd, scale=0.5) gal5.drawKImage(image=im5_cd, scale=0.5) diff --git a/tests/test_utilities.py b/tests/test_utilities.py index bfbd039f60..41bec8252c 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -35,7 +35,10 @@ def test_pos(): assert pi1.y == 23 assert isinstance(pi1.x, int) assert isinstance(pi1.y, int) - assert isinstance(pi1._p, galsim._galsim.PositionI) + if is_jax_galsim(): + pass + else: + assert isinstance(pi1._p, galsim._galsim.PositionI) pi2 = galsim.PositionI((11,23)) pi3 = galsim.PositionI(x=11.0, y=23.0) @@ -43,7 +46,10 @@ def test_pos(): pi5 = galsim.PositionI(galsim.PositionD(11.0,23.0)) pi6 = galsim.PositionD(11.3,23.4).round() pi7 = pi2.round() - pi8 = galsim._PositionI(11,23) + if is_jax_galsim(): + pi8 = galsim.PositionI(11,23) + else: + pi8 = galsim._PositionI(11,23) assert pi2 == pi1 assert pi3 == pi1 assert pi4 == pi1 @@ -61,14 +67,20 @@ def test_pos(): assert pd1.y == 23. assert isinstance(pd1.x, float) assert isinstance(pd1.y, float) - assert isinstance(pd1._p, galsim._galsim.PositionD) + if is_jax_galsim(): + pass + else: + assert isinstance(pd1._p, galsim._galsim.PositionD) pd2 = galsim.PositionD((11,23)) pd3 = galsim.PositionD(x=11.0, y=23.0) pd4 = galsim.PositionD(pd1) pd5 = galsim.PositionD(pi1) pd6 = galsim.PositionD(galsim.PositionD(11.3,23.4).round()) - pd7 = galsim._PositionD(11.0,23.0) + if is_jax_galsim(): + pd7 = galsim.PositionD(11.0,23.0) + else: + pd7 = galsim._PositionD(11.0,23.0) assert pd2 == pd1 assert pd3 == pd1 assert pd4 == pd1 @@ -86,7 +98,10 @@ def test_pos(): assert_raises(TypeError, galsim.PositionI, x=11) assert_raises(TypeError, galsim.PositionD, x=11, y=23, z=17) assert_raises(TypeError, galsim.PositionI, 11, 23, x=13, z=21) - assert_raises(TypeError, galsim.PositionI, 11, 23.5) + if is_jax_galsim(): + pass + else: + assert_raises(TypeError, galsim.PositionI, 11, 23.5) assert_raises(TypeError, galsim.PositionD, 11) assert_raises(TypeError, galsim.PositionD, 11, 23, 9) @@ -94,7 +109,7 @@ def test_pos(): assert_raises(TypeError, galsim.PositionD, x=11) assert_raises(TypeError, galsim.PositionD, x=11, y=23, z=17) assert_raises(TypeError, galsim.PositionD, 11, 23, x=13, z=21) - assert_raises(ValueError, galsim.PositionD, 11, "blue") + assert_raises((ValueError, TypeError), galsim.PositionD, 11, "blue") # Can't use base class directly. assert_raises(TypeError, galsim.Position, 11, 23) @@ -170,7 +185,10 @@ def test_bounds(): assert isinstance(bi1.xmax, int) assert isinstance(bi1.ymin, int) assert isinstance(bi1.ymax, int) - assert isinstance(bi1._b, galsim._galsim.BoundsI) + if is_jax_galsim(): + pass + else: + assert isinstance(bi1._b, galsim._galsim.BoundsI) bi2 = galsim.BoundsI(galsim.PositionI(11,17), galsim.PositionI(23,50)) bi3 = galsim.BoundsI(galsim.PositionD(11.,50.), galsim.PositionD(23.,17.)) @@ -183,17 +201,20 @@ def test_bounds(): bi10 = galsim.BoundsI() + galsim.PositionI(11,17) + galsim.PositionI(23,50) bi11 = galsim.BoundsI(galsim.BoundsD(11.,23.,17.,50.)) bi12 = galsim.BoundsI(xmin=11,ymin=17,xmax=23,ymax=50) - bi13 = galsim._BoundsI(11,23,17,50) + if is_jax_galsim(): + bi13 = galsim.BoundsI(11,23,17,50) + else: + bi13 = galsim._BoundsI(11,23,17,50) bi14 = galsim.BoundsI() bi14 += galsim.PositionI(11,17) bi14 += galsim.PositionI(23,50) for b in [bi1, bi2, bi3, bi4, bi5, bi6, bi7, bi8, bi9, bi10, bi11, bi12, bi13, bi14]: assert b.isDefined() assert b == bi1 - assert isinstance(b.xmin, int) - assert isinstance(b.xmax, int) - assert isinstance(b.ymin, int) - assert isinstance(b.ymax, int) + assert_intlike(b.xmin) + assert_intlike(b.xmax) + assert_intlike(b.ymin) + assert_intlike(b.ymax) assert b.origin == galsim.PositionI(11, 17) assert b.center == galsim.PositionI(17, 34) assert b.true_center == galsim.PositionD(17, 33.5) @@ -203,11 +224,14 @@ def test_bounds(): assert bd1.xmax == bd1.getXMax() == 23. assert bd1.ymin == bd1.getYMin() == 17. assert bd1.ymax == bd1.getYMax() == 50. - assert isinstance(bd1.xmin, float) - assert isinstance(bd1.xmax, float) - assert isinstance(bd1.ymin, float) - assert isinstance(bd1.ymax, float) - assert isinstance(bd1._b, galsim._galsim.BoundsD) + assert_floatlike(bd1.xmin) + assert_floatlike(bd1.xmax) + assert_floatlike(bd1.ymin) + assert_floatlike(bd1.ymax) + if is_jax_galsim(): + pass + else: + assert isinstance(bd1._b, galsim._galsim.BoundsD) bd2 = galsim.BoundsD(galsim.PositionI(11,17), galsim.PositionI(23,50)) bd3 = galsim.BoundsD(galsim.PositionD(11.,50.), galsim.PositionD(23.,17.)) @@ -220,17 +244,20 @@ def test_bounds(): bd10 = galsim.BoundsD() + galsim.PositionD(11,17) + galsim.PositionD(23,50) bd11 = galsim.BoundsD(galsim.BoundsI(11,23,17,50)) bd12 = galsim.BoundsD(xmin=11.0,ymin=17.0,xmax=23.0,ymax=50.0) - bd13 = galsim._BoundsD(11,23,17,50) + if is_jax_galsim(): + bd13 = galsim.BoundsD(11,23,17,50) + else: + bd13 = galsim._BoundsD(11,23,17,50) bd14 = galsim.BoundsD() bd14 += galsim.PositionD(11.,17.) bd14 += galsim.PositionD(23,50) for b in [bd1, bd2, bd3, bd4, bd5, bd6, bd7, bd8, bd9, bd10, bd11, bd12, bd13, bd14]: assert b.isDefined() assert b == bd1 - assert isinstance(b.xmin, float) - assert isinstance(b.xmax, float) - assert isinstance(b.ymin, float) - assert isinstance(b.ymax, float) + assert_floatlike(b.xmin) + assert_floatlike(b.xmax) + assert_floatlike(b.ymin) + assert_floatlike(b.ymax) assert b.origin == galsim.PositionD(11, 17) assert b.center == galsim.PositionD(17, 33.5) assert b.true_center == galsim.PositionD(17, 33.5) @@ -241,7 +268,10 @@ def test_bounds(): assert_raises(TypeError, galsim.BoundsI, 11, 23, 9, 12, 59) assert_raises(TypeError, galsim.BoundsI, xmin=11, xmax=23, ymin=17, ymax=50, z=23) assert_raises(TypeError, galsim.BoundsI, xmin=11, xmax=50) - assert_raises(TypeError, galsim.BoundsI, 11, 23.5, 17, 50.9) + if is_jax_galsim(): + pass + else: + assert_raises(TypeError, galsim.BoundsI, 11, 23.5, 17, 50.9) assert_raises(TypeError, galsim.BoundsI, 11, 23, 9, 12, xmin=19, xmax=2) with assert_raises(TypeError): bi1 += (11,23) @@ -252,7 +282,11 @@ def test_bounds(): assert_raises(TypeError, galsim.BoundsD, 11, 23, 9, 12, 59) assert_raises(TypeError, galsim.BoundsD, xmin=11, xmax=23, ymin=17, ymax=50, z=23) assert_raises(TypeError, galsim.BoundsD, xmin=11, xmax=50) - assert_raises(ValueError, galsim.BoundsD, 11, 23, 17, "blue") + if is_jax_galsim(): + # jax doesn't raise for this + pass + else: + assert_raises(ValueError, galsim.BoundsD, 11, 23, 17, "blue") assert_raises(TypeError, galsim.BoundsD, 11, 23, 9, 12, xmin=19, xmax=2) with assert_raises(TypeError): bd1 += (11,23) @@ -373,15 +407,22 @@ def test_bounds(): assert galsim.BoundsD() == galsim.BoundsD() + galsim.BoundsD() assert galsim.BoundsD().area() == 0 - assert galsim.BoundsI(23, 11, 17, 50) == galsim.BoundsI() - assert galsim.BoundsI(11, 23, 50, 17) == galsim.BoundsI() - assert galsim.BoundsD(23, 11, 17, 50) == galsim.BoundsD() - assert galsim.BoundsD(11, 23, 50, 17) == galsim.BoundsD() - - assert_raises(galsim.GalSimUndefinedBoundsError, getattr, galsim.BoundsI(), 'center') - assert_raises(galsim.GalSimUndefinedBoundsError, getattr, galsim.BoundsD(), 'center') - assert_raises(galsim.GalSimUndefinedBoundsError, getattr, galsim.BoundsI(), 'true_center') - assert_raises(galsim.GalSimUndefinedBoundsError, getattr, galsim.BoundsD(), 'true_center') + if is_jax_galsim(): + pass + else: + assert galsim.BoundsI(23, 11, 17, 50) == galsim.BoundsI() + assert galsim.BoundsI(11, 23, 50, 17) == galsim.BoundsI() + assert galsim.BoundsD(23, 11, 17, 50) == galsim.BoundsD() + assert galsim.BoundsD(11, 23, 50, 17) == galsim.BoundsD() + + if is_jax_galsim(): + # jax doesn't raise for these things + pass + else: + assert_raises(galsim.GalSimUndefinedBoundsError, getattr, galsim.BoundsI(), 'center') + assert_raises(galsim.GalSimUndefinedBoundsError, getattr, galsim.BoundsD(), 'center') + assert_raises(galsim.GalSimUndefinedBoundsError, getattr, galsim.BoundsI(), 'true_center') + assert_raises(galsim.GalSimUndefinedBoundsError, getattr, galsim.BoundsD(), 'true_center') check_pickle(bi1) check_pickle(bd1) @@ -501,15 +542,14 @@ def test_check_all_contiguous(): @timer def test_deInterleaveImage(): - from galsim.utilities import deInterleaveImage, interleaveImages np.random.seed(84) # for generating the same random instances # 1) Check compatability with interleaveImages img = galsim.Image(np.random.randn(64,64),scale=0.25) img.setOrigin(galsim.PositionI(5,7)) ## for non-trivial bounds - im_list, offsets = deInterleaveImage(img,8) - img1 = interleaveImages(im_list,8,offsets) + im_list, offsets = galsim.utilities.deInterleaveImage(img,8) + img1 = galsim.utilities.interleaveImages(im_list,8,offsets) np.testing.assert_array_equal(img1.array,img.array, err_msg = "interleaveImages cannot reproduce the input to deInterleaveImage for square " "images") @@ -519,8 +559,8 @@ def test_deInterleaveImage(): img = galsim.Image(abs(np.random.randn(16*5,16*2)),scale=0.5) img.setCenter(0,0) ## for non-trivial bounds - im_list, offsets = deInterleaveImage(img,(2,5)) - img1 = interleaveImages(im_list,(2,5),offsets) + im_list, offsets = galsim.utilities.deInterleaveImage(img,(2,5)) + img1 = galsim.utilities.interleaveImages(im_list,(2,5),offsets) np.testing.assert_array_equal(img1.array,img.array, err_msg = "interleaveImages cannot reproduce the input to deInterleaveImage for " "rectangular images") @@ -530,7 +570,7 @@ def test_deInterleaveImage(): # 2) Checking for offsets img = galsim.Image(np.random.randn(32,32),scale=2.0) - im_list, offsets = deInterleaveImage(img,(4,2)) + im_list, offsets = galsim.utilities.deInterleaveImage(img,(4,2)) ## Checking if offsets are centered around zero assert np.sum([offset.x for offset in offsets]) == 0. @@ -547,7 +587,7 @@ def test_deInterleaveImage(): img0 = galsim.Image(32,32) g0.drawImage(image=img0,method='no_pixel',scale=0.25) - im_list0, offsets0 = deInterleaveImage(img0,2,conserve_flux=True) + im_list0, offsets0 = galsim.utilities.deInterleaveImage(img0,2,conserve_flux=True) for n in range(len(im_list0)): im = galsim.Image(16,16) @@ -570,8 +610,8 @@ def test_deInterleaveImage(): g1.drawImage(image=img1,scale=0.5/n1,method='no_pixel') g2.drawImage(image=img2,scale=0.5/n2,method='no_pixel') - im_list1, offsets1 = deInterleaveImage(img1,(n1**2,1),conserve_flux=True) - im_list2, offsets2 = deInterleaveImage(img2,[1,n2**2],conserve_flux=False) + im_list1, offsets1 = galsim.utilities.deInterleaveImage(img1,(n1**2,1),conserve_flux=True) + im_list2, offsets2 = galsim.utilities.deInterleaveImage(img2,[1,n2**2],conserve_flux=False) for n in range(n1**2): im, offset = im_list1[n], offsets1[n] @@ -588,26 +628,24 @@ def test_deInterleaveImage(): "horizontal direction") # im is scaled to account for flux not being conserved - assert_raises(TypeError, deInterleaveImage, image=img0.array, N=2) - assert_raises(TypeError, deInterleaveImage, image=img0, N=2.0) - assert_raises(TypeError, deInterleaveImage, image=img0, N=(2.0, 2.0)) - assert_raises(TypeError, deInterleaveImage, image=img0, N=(2,2,3)) - assert_raises(ValueError, deInterleaveImage, image=img0, N=7) - assert_raises(ValueError, deInterleaveImage, image=img0, N=(2,7)) - assert_raises(ValueError, deInterleaveImage, image=img0, N=(7,2)) + assert_raises(TypeError, galsim.utilities.deInterleaveImage, image=img0.array, N=2) + assert_raises(TypeError, galsim.utilities.deInterleaveImage, image=img0, N=2.0) + assert_raises(TypeError, galsim.utilities.deInterleaveImage, image=img0, N=(2.0, 2.0)) + assert_raises(TypeError, galsim.utilities.deInterleaveImage, image=img0, N=(2,2,3)) + assert_raises(ValueError, galsim.utilities.deInterleaveImage, image=img0, N=7) + assert_raises(ValueError, galsim.utilities.deInterleaveImage, image=img0, N=(2,7)) + assert_raises(ValueError, galsim.utilities.deInterleaveImage, image=img0, N=(7,2)) # It is legal to have the input image with wcs=None, but it emits a warning img0.wcs = None with assert_warns(galsim.GalSimWarning): - deInterleaveImage(img0, N=2) + galsim.utilities.deInterleaveImage(img0, N=2) # Unless suppress_warnings is True - deInterleaveImage(img0, N=2, suppress_warnings=True) + galsim.utilities.deInterleaveImage(img0, N=2, suppress_warnings=True) @timer def test_interleaveImages(): - from galsim.utilities import interleaveImages, deInterleaveImage - # 1a) With galsim Gaussian g = galsim.Gaussian(sigma=3.7,flux=1000.) gal = galsim.Convolve([g,galsim.Pixel(1.0)]) @@ -625,7 +663,7 @@ def test_interleaveImages(): scale = im.scale # Input to N as an int - img = interleaveImages(im_list,n,offsets=offset_list) + img = galsim.utilities.interleaveImages(im_list,n,offsets=offset_list) im = galsim.Image(16*n*n,16*n*n) g = galsim.Gaussian(sigma=3.7,flux=1000.*n*n) gal = galsim.Convolve([g,galsim.Pixel(1.0)]) @@ -651,7 +689,7 @@ def test_interleaveImages(): im_list_randperm = [im_list[idx] for idx in rand_idx] offset_list_randperm = [offset_list[idx] for idx in rand_idx] # Input to N as a tuple - img_randperm = interleaveImages(im_list_randperm,(n,n),offsets=offset_list_randperm) + img_randperm = galsim.utilities.interleaveImages(im_list_randperm,(n,n),offsets=offset_list_randperm) np.testing.assert_array_equal(img_randperm.array,img.array, err_msg="Interleaved images do not match when 'offsets' is supplied") @@ -674,9 +712,9 @@ def test_interleaveImages(): N = (n,n) with assert_raises(ValueError): - interleaveImages(im_list,N,offset_list) + galsim.utilities.interleaveImages(im_list,N,offset_list) # Can turn off the checks and just use these as they are with catch_offset_errors=False - interleaveImages(im_list,N,offset_list, catch_offset_errors=False) + galsim.utilities.interleaveImages(im_list,N,offset_list, catch_offset_errors=False) offset_list = [] im_list = [] @@ -693,8 +731,8 @@ def test_interleaveImages(): N = (n,n) with assert_raises(ValueError): - interleaveImages(im_list, N, offset_list) - interleaveImages(im_list, N, offset_list, catch_offset_errors=False) + galsim.utilities.interleaveImages(im_list, N, offset_list) + galsim.utilities.interleaveImages(im_list, N, offset_list, catch_offset_errors=False) # 2a) Increase resolution along one direction - square to rectangular images n = 2 @@ -713,8 +751,8 @@ def test_interleaveImages(): gal1.drawImage(im,offset=offset,method='no_pixel',scale=2.0) im_list.append(im) - img = interleaveImages(im_list, N=[1,n**2], offsets=offset_list, - add_flux=False, suppress_warnings=True) + img = galsim.utilities.interleaveImages(im_list, N=[1,n**2], offsets=offset_list, + add_flux=False, suppress_warnings=True) im = galsim.Image(16,16*n*n) # The interleaved image has the total flux averaged out since `add_flux = False' gal = galsim.Gaussian(sigma=3.7*n,flux=100.) @@ -741,7 +779,7 @@ def test_interleaveImages(): gal2.drawImage(im,offset=offset,method='no_pixel',scale=3.0) im_list.append(im) - img = interleaveImages(im_list, N=np.array([n**2,1]), offsets=offset_list, + img = galsim.utilities.interleaveImages(im_list, N=np.array([n**2,1]), offsets=offset_list, suppress_warnings=True) im = galsim.Image(16*n*n,16*n*n) gal = galsim.Gaussian(sigma=3.7,flux=100.*n*n) @@ -770,8 +808,8 @@ def test_interleaveImages(): im.setOrigin(3,3) # for non-trivial bounds im_list.append(im) - img = interleaveImages(im_list,N=n,offsets=offset_list) - im_list_1, offset_list_1 = deInterleaveImage(img, N=n) + img = galsim.utilities.interleaveImages(im_list,N=n,offsets=offset_list) + im_list_1, offset_list_1 = galsim.utilities.deInterleaveImage(img, N=n) for k in range(n**2): assert offset_list_1[k] == offset_list[k] @@ -782,50 +820,50 @@ def test_interleaveImages(): assert im_list[k].bounds == im_list_1[k].bounds # Checking for non-default flux option - img = interleaveImages(im_list,N=n,offsets=offset_list,add_flux=False) - im_list_2, offset_list_2 = deInterleaveImage(img,N=n,conserve_flux=True) + img = galsim.utilities.interleaveImages(im_list,N=n,offsets=offset_list,add_flux=False) + im_list_2, offset_list_2 = galsim.utilities.deInterleaveImage(img,N=n,conserve_flux=True) for k in range(n**2): assert offset_list_2[k] == offset_list[k] np.testing.assert_array_equal(im_list_2[k].array, im_list[k].array) assert im_list_2[k].wcs == im_list[k].wcs - assert_raises(TypeError, interleaveImages, im_list=img, N=n, offsets=offset_list) - assert_raises(ValueError, interleaveImages, [img], N=1, offsets=offset_list) - assert_raises(ValueError, interleaveImages, im_list, n, offset_list[:-1]) - assert_raises(TypeError, interleaveImages, [im.array for im in im_list], n, offset_list) - assert_raises(TypeError, interleaveImages, + assert_raises(TypeError, galsim.utilities.interleaveImages, im_list=img, N=n, offsets=offset_list) + assert_raises(ValueError, galsim.utilities.interleaveImages, [img], N=1, offsets=offset_list) + assert_raises(ValueError, galsim.utilities.interleaveImages, im_list, n, offset_list[:-1]) + assert_raises(TypeError, galsim.utilities.interleaveImages, [im.array for im in im_list], n, offset_list) + assert_raises(TypeError, galsim.utilities.interleaveImages, [im_list[0]] + [im.array for im in im_list[1:]], n, offset_list) - assert_raises(TypeError, interleaveImages, + assert_raises(TypeError, galsim.utilities.interleaveImages, [galsim.Image(16+i,16+j,scale=1) for i in range(n) for j in range(n)], n, offset_list) - assert_raises(TypeError, interleaveImages, + assert_raises(TypeError, galsim.utilities.interleaveImages, [galsim.Image(16,16,scale=i) for i in range(n) for j in range(n)], n, offset_list) - assert_raises(TypeError, interleaveImages, im_list, N=3.0, offsets=offset_list) - assert_raises(TypeError, interleaveImages, im_list, N=(3.0, 3.0), offsets=offset_list) - assert_raises(TypeError, interleaveImages, im_list, N=(3,3,3), offsets=offset_list) - assert_raises(ValueError, interleaveImages, im_list, N=7, offsets=offset_list) - assert_raises(ValueError, interleaveImages, im_list, N=(2,7), offsets=offset_list) - assert_raises(ValueError, interleaveImages, im_list, N=(7,2), offsets=offset_list) - assert_raises(TypeError, interleaveImages, im_list, N=n) - assert_raises(TypeError, interleaveImages, im_list, N=n, offsets=offset_list[0]) - assert_raises(TypeError, interleaveImages, im_list, N=n, offsets=range(n*n)) + assert_raises(TypeError, galsim.utilities.interleaveImages, im_list, N=3.0, offsets=offset_list) + assert_raises(TypeError, galsim.utilities.interleaveImages, im_list, N=(3.0, 3.0), offsets=offset_list) + assert_raises(TypeError, galsim.utilities.interleaveImages, im_list, N=(3,3,3), offsets=offset_list) + assert_raises(ValueError, galsim.utilities.interleaveImages, im_list, N=7, offsets=offset_list) + assert_raises(ValueError, galsim.utilities.interleaveImages, im_list, N=(2,7), offsets=offset_list) + assert_raises(ValueError, galsim.utilities.interleaveImages, im_list, N=(7,2), offsets=offset_list) + assert_raises(TypeError, galsim.utilities.interleaveImages, im_list, N=n) + assert_raises(TypeError, galsim.utilities.interleaveImages, im_list, N=n, offsets=offset_list[0]) + assert_raises(TypeError, galsim.utilities.interleaveImages, im_list, N=n, offsets=range(n*n)) # It is legal to have the input images with wcs=None, but it emits a warning for im in im_list: im.wcs = None with assert_warns(galsim.GalSimWarning): - interleaveImages(im_list, N=n, offsets=offset_list) + galsim.utilities.interleaveImages(im_list, N=n, offsets=offset_list) # Unless suppress_warnings is True - interleaveImages(im_list, N=n, offsets=offset_list, suppress_warnings=True) + galsim.utilities.interleaveImages(im_list, N=n, offsets=offset_list, suppress_warnings=True) # Also legal to have different origins im_list[0].setCenter(0,0) with assert_warns(galsim.GalSimWarning): - interleaveImages(im_list, N=n, offsets=offset_list) - interleaveImages(im_list, N=n, offsets=offset_list, suppress_warnings=True) + galsim.utilities.interleaveImages(im_list, N=n, offsets=offset_list) + galsim.utilities.interleaveImages(im_list, N=n, offsets=offset_list, suppress_warnings=True) @timer @@ -1112,67 +1150,73 @@ def test_horner(): # Make a random list of values to test x = np.empty(20) rng = galsim.UniformDeviate(1234) - rng.generate(x) + if is_jax_galsim(): + x = rng.generate(x) + else: + rng.generate(x) # Check against the direct calculation truth = coef[0] + coef[1]*x + coef[2]*x**2 + coef[3]*x**3 + coef[4]*x**4 result = galsim.utilities.horner(x, coef) - np.testing.assert_almost_equal(result, truth) + np.testing.assert_array_almost_equal(result, truth) # Also check against the (slower) numpy code - np.testing.assert_almost_equal(result, np.polynomial.polynomial.polyval(x,coef)) + np.testing.assert_array_almost_equal(result, np.polynomial.polynomial.polyval(x,coef)) # Check that trailing zeros give the same answer result = galsim.utilities.horner(x, coef + [0]*3) - np.testing.assert_almost_equal(result, truth) + np.testing.assert_array_almost_equal(result, truth) # Check that leading zeros give the right answer result = galsim.utilities.horner(x, [0]*3 + coef) - np.testing.assert_almost_equal(result, truth*x**3) + np.testing.assert_array_almost_equal(result, truth*x**3) # Check using a different dtype result = galsim.utilities.horner(x, coef, dtype=complex) - np.testing.assert_almost_equal(result, truth) + np.testing.assert_array_almost_equal(result, truth) # Check that a single element coef gives the right answer result = galsim.utilities.horner([1,2,3], [17]) - np.testing.assert_almost_equal(result, 17) + np.testing.assert_array_almost_equal(result, 17) result = galsim.utilities.horner(x, [17]) - np.testing.assert_almost_equal(result, 17) + np.testing.assert_array_almost_equal(result, 17) result = galsim.utilities.horner([1,2,3], [17,0,0,0]) - np.testing.assert_almost_equal(result, 17) + np.testing.assert_array_almost_equal(result, 17) result = galsim.utilities.horner(x, [17,0,0,0]) - np.testing.assert_almost_equal(result, 17) + np.testing.assert_array_almost_equal(result, 17) result = galsim.utilities.horner([1,2,3], [0,0,0,0]) - np.testing.assert_almost_equal(result, 0) + np.testing.assert_array_almost_equal(result, 0) result = galsim.utilities.horner(x, [0,0,0,0]) - np.testing.assert_almost_equal(result, 0) + np.testing.assert_array_almost_equal(result, 0) # Check that x may be non-contiguous result = galsim.utilities.horner(x[::3], coef) - np.testing.assert_almost_equal(result, np.polynomial.polynomial.polyval(x[::3],coef)) + np.testing.assert_array_almost_equal(result, np.polynomial.polynomial.polyval(x[::3],coef)) # Check that coef may be non-contiguous result = galsim.utilities.horner(x, coef[::-1]) - np.testing.assert_almost_equal(result, np.polynomial.polynomial.polyval(x,coef[::-1])) + np.testing.assert_array_almost_equal(result, np.polynomial.polynomial.polyval(x,coef[::-1])) # Check odd length result = galsim.utilities.horner(x[:15], coef) - np.testing.assert_almost_equal(result, np.polynomial.polynomial.polyval(x[:15],coef)) + np.testing.assert_array_almost_equal(result, np.polynomial.polynomial.polyval(x[:15],coef)) # Check unaligned array result = galsim.utilities.horner(x[1:], coef) - np.testing.assert_almost_equal(result, np.polynomial.polynomial.polyval(x[1:],coef)) + np.testing.assert_array_almost_equal(result, np.polynomial.polynomial.polyval(x[1:],coef)) # Check length > 64 xx = np.empty(2000) - rng.generate(xx) + if is_jax_galsim(): + xx = rng.generate(xx) + else: + rng.generate(xx) result = galsim.utilities.horner(xx, coef) - np.testing.assert_almost_equal(result, np.polynomial.polynomial.polyval(xx,coef)) + np.testing.assert_array_almost_equal(result, np.polynomial.polynomial.polyval(xx,coef)) # Check scalar x result = galsim.utilities.horner(3.9, coef) - np.testing.assert_almost_equal(result, np.polynomial.polynomial.polyval([3.9],coef)) + np.testing.assert_array_almost_equal(result, np.polynomial.polynomial.polyval([3.9],coef)) # Check invalid arguments with assert_raises(galsim.GalSimValueError): @@ -1190,78 +1234,86 @@ def test_horner2d(): x = np.empty(20) y = np.empty(20) rng = galsim.UniformDeviate(1234) - rng.generate(x) - rng.generate(y) + if is_jax_galsim(): + x = rng.generate(x) + y = rng.generate(y) + else: + rng.generate(x) + rng.generate(y) # Check against the direct calculation truth = coef[0,0] + coef[0,1]*y + coef[0,2]*y**2 + coef[0,3]*y**3 + coef[0,4]*y**4 truth += (coef[1,0] + coef[1,1]*y + coef[1,2]*y**2 + coef[1,3]*y**3 + coef[1,4]*y**4)*x truth += (coef[2,0] + coef[2,1]*y + coef[2,2]*y**2 + coef[2,3]*y**3 + coef[2,4]*y**4)*x**2 result = galsim.utilities.horner2d(x, y, coef) - np.testing.assert_almost_equal(result, truth) + np.testing.assert_array_almost_equal(result, truth) # Also check against the (slower) numpy code - np.testing.assert_almost_equal(result, np.polynomial.polynomial.polyval2d(x,y,coef)) + np.testing.assert_array_almost_equal(result, np.polynomial.polynomial.polyval2d(x,y,coef)) # Check that trailing zeros give the same answer result = galsim.utilities.horner2d(x, y, np.hstack([coef, np.zeros((3,1))])) - np.testing.assert_almost_equal(result, truth) + np.testing.assert_array_almost_equal(result, truth) result = galsim.utilities.horner2d(x, y, np.hstack([coef, np.zeros((3,6))])) - np.testing.assert_almost_equal(result, truth) + np.testing.assert_array_almost_equal(result, truth) result = galsim.utilities.horner2d(x, y, np.vstack([coef, np.zeros((1,5))])) - np.testing.assert_almost_equal(result, truth) + np.testing.assert_array_almost_equal(result, truth) result = galsim.utilities.horner2d(x, y, np.vstack([coef, np.zeros((6,5))])) - np.testing.assert_almost_equal(result, truth) + np.testing.assert_array_almost_equal(result, truth) # Check that leading zeros give the right answer result = galsim.utilities.horner2d(x, y, np.hstack([np.zeros((3,1)), coef])) - np.testing.assert_almost_equal(result, truth*y) + np.testing.assert_array_almost_equal(result, truth*y) result = galsim.utilities.horner2d(x, y, np.hstack([np.zeros((3,6)), coef])) - np.testing.assert_almost_equal(result, truth*y**6) + np.testing.assert_array_almost_equal(result, truth*y**6) result = galsim.utilities.horner2d(x, y, np.vstack([np.zeros((1,5)), coef])) - np.testing.assert_almost_equal(result, truth*x) + np.testing.assert_array_almost_equal(result, truth*x) result = galsim.utilities.horner2d(x, y, np.vstack([np.zeros((6,5)), coef])) - np.testing.assert_almost_equal(result, truth*x**6) + np.testing.assert_array_almost_equal(result, truth*x**6) # Check using a different dtype result = galsim.utilities.horner2d(x, y, coef, dtype=complex) - np.testing.assert_almost_equal(result, truth) + np.testing.assert_array_almost_equal(result, truth) # Check that x,y may be non-contiguous result = galsim.utilities.horner2d(x[::3], y[:7], coef) - np.testing.assert_almost_equal(result, np.polynomial.polynomial.polyval2d(x[::3],y[:7],coef)) + np.testing.assert_array_almost_equal(result, np.polynomial.polynomial.polyval2d(x[::3],y[:7],coef)) result = galsim.utilities.horner2d(x[:7], y[::-3], coef) - np.testing.assert_almost_equal(result, np.polynomial.polynomial.polyval2d(x[:7],y[::-3],coef)) + np.testing.assert_array_almost_equal(result, np.polynomial.polynomial.polyval2d(x[:7],y[::-3],coef)) # Check that coef may be non-contiguous result = galsim.utilities.horner2d(x, y, coef[:,::-1]) - np.testing.assert_almost_equal(result, np.polynomial.polynomial.polyval2d(x,y,coef[:,::-1])) + np.testing.assert_array_almost_equal(result, np.polynomial.polynomial.polyval2d(x,y,coef[:,::-1])) result = galsim.utilities.horner2d(x, y, coef[::-1,:]) - np.testing.assert_almost_equal(result, np.polynomial.polynomial.polyval2d(x,y,coef[::-1,:])) + np.testing.assert_array_almost_equal(result, np.polynomial.polynomial.polyval2d(x,y,coef[::-1,:])) # Check odd length result = galsim.utilities.horner2d(x[:15], y[:15], coef) - np.testing.assert_almost_equal(result, np.polynomial.polynomial.polyval2d(x[:15],y[:15],coef)) + np.testing.assert_array_almost_equal(result, np.polynomial.polynomial.polyval2d(x[:15],y[:15],coef)) # Check unaligned array result = galsim.utilities.horner2d(x[1:], y[1:], coef) - np.testing.assert_almost_equal(result, np.polynomial.polynomial.polyval2d(x[1:],y[1:],coef)) + np.testing.assert_array_almost_equal(result, np.polynomial.polynomial.polyval2d(x[1:],y[1:],coef)) result = galsim.utilities.horner2d(x[1:], y[:-1], coef) - np.testing.assert_almost_equal(result, np.polynomial.polynomial.polyval2d(x[1:],y[:-1],coef)) + np.testing.assert_array_almost_equal(result, np.polynomial.polynomial.polyval2d(x[1:],y[:-1],coef)) result = galsim.utilities.horner2d(x[:-1], y[1:], coef) - np.testing.assert_almost_equal(result, np.polynomial.polynomial.polyval2d(x[:-1],y[1:],coef)) + np.testing.assert_array_almost_equal(result, np.polynomial.polynomial.polyval2d(x[:-1],y[1:],coef)) # Check length > 64 xx = np.empty(2000) yy = np.empty(2000) - rng.generate(xx) - rng.generate(yy) + if is_jax_galsim(): + xx = rng.generate(xx) + yy = rng.generate(yy) + else: + rng.generate(xx) + rng.generate(yy) result = galsim.utilities.horner2d(xx, yy, coef) - np.testing.assert_almost_equal(result, np.polynomial.polynomial.polyval2d(xx,yy,coef)) + np.testing.assert_array_almost_equal(result, np.polynomial.polynomial.polyval2d(xx,yy,coef)) # Check scalar x, y result = galsim.utilities.horner2d(3.9, 1.7, coef) - np.testing.assert_almost_equal(result, np.polynomial.polynomial.polyval2d([3.9],[1.7],coef)) + np.testing.assert_array_almost_equal(result, np.polynomial.polynomial.polyval2d([3.9],[1.7],coef)) # Check the triangle = True option @@ -1275,13 +1327,13 @@ def test_horner2d(): truth += coef[1,0]*x + coef[1,1]*x*y truth += coef[2,0]*x**2 result = galsim.utilities.horner2d(x, y, coef) - np.testing.assert_almost_equal(result, truth) + np.testing.assert_array_almost_equal(result, truth) result = galsim.utilities.horner2d(x, y, coef, triangle=True) - np.testing.assert_almost_equal(result, truth) + np.testing.assert_array_almost_equal(result, truth) # Check using a different dtype result = galsim.utilities.horner2d(x, y, coef, dtype=complex, triangle=True) - np.testing.assert_almost_equal(result, truth) + np.testing.assert_array_almost_equal(result, truth) # Check invalid arguments with assert_raises(galsim.GalSimValueError): @@ -1310,60 +1362,67 @@ def test_horner_complex(): rx = np.empty(20) ry = np.empty(20) rng = galsim.UniformDeviate(1234) - rng.generate(rx) - rng.generate(ry) + if is_jax_galsim(): + rx = rng.generate(rx) + ry = rng.generate(ry) + else: + rng.generate(rx) + rng.generate(ry) ix = np.empty(20) iy = np.empty(20) rng = galsim.UniformDeviate(1234) - rng.generate(ix) - rng.generate(iy) - + if is_jax_galsim(): + ix = rng.generate(ix) + iy = rng.generate(iy) + else: + rng.generate(ix) + rng.generate(iy) x = rx + 1j*ix y = ry + 1j*iy # Check all combinations of which things are complex and which are real. # First, just 1 of the three complex: result = galsim.utilities.horner2d(rx, ry, coef, dtype=complex) - np.testing.assert_almost_equal(result, np.polynomial.polynomial.polyval2d(rx, ry, coef)) + np.testing.assert_array_almost_equal(result, np.polynomial.polynomial.polyval2d(rx, ry, coef)) result = galsim.utilities.horner2d(rx, y, rcoef, dtype=complex) - np.testing.assert_almost_equal(result, np.polynomial.polynomial.polyval2d(rx, y, rcoef)) + np.testing.assert_array_almost_equal(result, np.polynomial.polynomial.polyval2d(rx, y, rcoef)) result = galsim.utilities.horner2d(x, ry, rcoef, dtype=complex) - np.testing.assert_almost_equal(result, np.polynomial.polynomial.polyval2d(x, ry, rcoef)) + np.testing.assert_array_almost_equal(result, np.polynomial.polynomial.polyval2d(x, ry, rcoef)) # Now two complex: result = galsim.utilities.horner2d(rx, y, coef, dtype=complex) - np.testing.assert_almost_equal(result, np.polynomial.polynomial.polyval2d(rx, y, coef)) + np.testing.assert_array_almost_equal(result, np.polynomial.polynomial.polyval2d(rx, y, coef)) result = galsim.utilities.horner2d(x, ry, coef, dtype=complex) - np.testing.assert_almost_equal(result, np.polynomial.polynomial.polyval2d(x, ry, coef)) + np.testing.assert_array_almost_equal(result, np.polynomial.polynomial.polyval2d(x, ry, coef)) result = galsim.utilities.horner2d(x, y, rcoef, dtype=complex) - np.testing.assert_almost_equal(result, np.polynomial.polynomial.polyval2d(x, y, rcoef)) + np.testing.assert_array_almost_equal(result, np.polynomial.polynomial.polyval2d(x, y, rcoef)) # All three complex result = galsim.utilities.horner2d(x, y, coef, dtype=complex) - np.testing.assert_almost_equal(result, np.polynomial.polynomial.polyval2d(x, y, coef)) + np.testing.assert_array_almost_equal(result, np.polynomial.polynomial.polyval2d(x, y, coef)) # Check scalar complex x, y result = galsim.utilities.horner2d(3.9+2.1j, 1.7-0.9j, coef, dtype=complex) - np.testing.assert_almost_equal(result, np.polynomial.polynomial.polyval2d( + np.testing.assert_array_almost_equal(result, np.polynomial.polynomial.polyval2d( [3.9+2.1j],[1.7-0.9j],coef)) # Repeast for 1d result = galsim.utilities.horner(rx, coef[0], dtype=complex) - np.testing.assert_almost_equal(result, np.polynomial.polynomial.polyval(rx, coef[0])) + np.testing.assert_array_almost_equal(result, np.polynomial.polynomial.polyval(rx, coef[0])) result = galsim.utilities.horner(x, rcoef[0], dtype=complex) - np.testing.assert_almost_equal(result, np.polynomial.polynomial.polyval(x, rcoef[0])) + np.testing.assert_array_almost_equal(result, np.polynomial.polynomial.polyval(x, rcoef[0])) result = galsim.utilities.horner(x, coef[0], dtype=complex) - np.testing.assert_almost_equal(result, np.polynomial.polynomial.polyval(x, coef[0])) + np.testing.assert_array_almost_equal(result, np.polynomial.polynomial.polyval(x, coef[0])) result = galsim.utilities.horner(3.9+2.1j, coef[0], dtype=complex) - np.testing.assert_almost_equal(result, np.polynomial.polynomial.polyval([3.9+2.1j],coef[0])) + np.testing.assert_array_almost_equal(result, np.polynomial.polynomial.polyval([3.9+2.1j],coef[0])) @timer def test_merge_sorted(): diff --git a/tests/test_wcs.py b/tests/test_wcs.py index ff5699060e..f945e35598 100644 --- a/tests/test_wcs.py +++ b/tests/test_wcs.py @@ -24,6 +24,7 @@ import coord from unittest import mock +import coord import galsim from galsim_test_helpers import * @@ -499,7 +500,7 @@ def do_local_wcs(wcs, ufunc, vfunc, name): wcs2 = wcs.local() assert wcs == wcs2, name+' local() is not == the original' new_origin = galsim.PositionI(123,321) - wcs3 = wcs.withOrigin(new_origin) + wcs3 = wcs.shiftOrigin(new_origin) assert wcs != wcs3, name+' is not != wcs.withOrigin(pos)' assert wcs3 != wcs, name+' is not != wcs.withOrigin(pos) (reverse)' wcs2 = wcs3.local() @@ -513,7 +514,7 @@ def do_local_wcs(wcs, ufunc, vfunc, name): world_pos2.y, world_pos1.y, digits, 'withOrigin(new_origin) returned wrong world position') new_world_origin = galsim.PositionD(5352.7, 9234.3) - wcs4 = wcs.withOrigin(new_origin, new_world_origin) + wcs4 = wcs.shiftOrigin(new_origin, new_world_origin) world_pos3 = wcs4.toWorld(new_origin) np.testing.assert_almost_equal( world_pos3.x, new_world_origin.x, digits, @@ -1181,7 +1182,10 @@ def test_pixelscale(): assert wcs.world_origin == galsim.PositionD(0,0) assert_raises(TypeError, galsim.PixelScale) - assert_raises(TypeError, galsim.PixelScale, scale=galsim.PixelScale(scale)) + if is_jax_galsim(): + pass + else: + assert_raises(TypeError, galsim.PixelScale, scale=galsim.PixelScale(scale)) assert_raises(TypeError, galsim.PixelScale, scale=scale, origin=galsim.PositionD(0,0)) assert_raises(TypeError, galsim.PixelScale, scale=scale, world_origin=galsim.PositionD(0,0)) @@ -2370,7 +2374,7 @@ def test_gsfitswcs(): # And it's required to get (relatively) complete test coverage. test_tags = [ 'TAN', 'STG', 'ZEA', 'ARC', 'TPV', 'TAN-PV', 'TAN-FLIP', 'TNX', 'SIP', 'ZTF' ] - dir = 'fits_files' + dir = os.path.join(os.path.dirname(__file__), 'fits_files') for tag in test_tags: file_name, ref_list = references[tag] @@ -2463,32 +2467,44 @@ def test_inverseab_convergence(): # Now one that should fail, since it's well outside the applicable area for the SIP polynomials. ra = 2.1 dec = -0.45 - with assert_raises(galsim.GalSimError): + if is_jax_galsim(): x, y = wcs.radecToxy(ra, dec, units="radians") - try: - x, y = wcs.radecToxy(ra, dec, units="radians") - except galsim.GalSimError as e: - print('Error message is\n',e) - assert "[0,]" in str(e) + assert np.all(np.isnan(x)) + assert np.all(np.isnan(y)) + else: + with assert_raises(galsim.GalSimError): + x, y = wcs.radecToxy(ra, dec, units="radians") + try: + x, y = wcs.radecToxy(ra, dec, units="radians") + except galsim.GalSimError as e: + print('Error message is\n',e) + assert "[0,]" in str(e) or "[0]" in str(e) # Check as part of a longer list (longer than 256 is important) - ra = np.random.uniform(2.185, 2.186, 1000) - dec = np.random.uniform(-0.501, -0.499, 1000) + rng = np.random.RandomState(1234) + ra = rng.uniform(2.185, 2.186, 1000) + dec = rng.uniform(-0.501, -0.499, 1000) ra = np.append(ra, [2.1, 2.9]) dec = np.append(dec, [-0.45, 0.2]) print('ra = ',ra) print('dec = ',dec) - with assert_raises(galsim.GalSimError): + if is_jax_galsim(): x, y = wcs.radecToxy(ra, dec, units="radians") - try: - x, y = wcs.radecToxy(ra, dec, units="radians") - except galsim.GalSimError as e: - print('Error message is\n',e) - assert "[1000,1001,]" in str(e) - # We don't currently do this for the user, but it's not too hard to get a python list - # of the bad indices. Included here as an example for users who may need this. - bad = eval(str(e)[str(e).rfind('['):]) - print('as a python list: ',bad) + assert np.sum(np.isnan(x)) >= 2 + assert np.sum(np.isnan(y)) >= 2 + else: + with assert_raises(galsim.GalSimError): + x, y = wcs.radecToxy(ra, dec, units="radians") + try: + x, y = wcs.radecToxy(ra, dec, units="radians") + except galsim.GalSimError as e: + print('Error message is\n',e) + assert "[1000,1001,]" in str(e) or "[1000, 1001]" in str(e) + # We don't currently do this for the user, but it's not too hard to get a python list + # of the bad indices. Included here as an example for users who may need this. + bad = eval(str(e)[str(e).rfind('['):]) + print('as a python list: ',bad) + @timer def test_tanwcs(): @@ -2580,7 +2596,7 @@ def test_fitswcs(): except: pass - dir = 'fits_files' + dir = os.path.join(os.path.dirname(__file__), 'fits_files') for tag in test_tags: file_name, ref_list = references[tag] @@ -2688,7 +2704,7 @@ def test_fittedsipwcs(): 'ZTF': (0.1, 0.1), } - dir = 'fits_files' + dir = os.path.join(os.path.dirname(__file__), 'fits_files') if __name__ == "__main__": test_tags = all_tags @@ -2917,7 +2933,7 @@ def test_fittedsipwcs(): def test_scamp(): """Test that we can read in a SCamp .head file correctly """ - dir = 'fits_files' + dir = os.path.join(os.path.dirname(__file__), 'fits_files') file_name = 'scamp.head' wcs = galsim.FitsWCS(file_name, dir=dir, text_file=True) @@ -3138,7 +3154,7 @@ def test_int_args(): test_tags = all_tags - dir = 'fits_files' + dir = os.path.join(os.path.dirname(__file__), 'fits_files') for tag in test_tags: file_name, ref_list = references[tag] @@ -3161,7 +3177,7 @@ def test_int_args(): # Along the way, check issue #1024 where Erin noticed that reading the WCS from the # header of a compressed file was spending lots of time decompressing the data, which # is unnecessary. - dir = 'des_data' + dir = os.path.join(os.path.dirname(__file__), 'des_data') file_name = 'DECam_00158414_01.fits.fz' with Profile(): t0 = time.time() @@ -3211,7 +3227,7 @@ def test_razero(): import astropy.wcs import scipy # AstropyWCS constructor will do this, so check now. - dir = 'fits_files' + dir = os.path.join(os.path.dirname(__file__), 'fits_files') # This file is based in sipsample.fits, but with the CRVAL1 changed to 0.002322805429 file_name = 'razero.fits' wcs = galsim.AstropyWCS(file_name, dir=dir) diff --git a/tests/test_zernike.py b/tests/test_zernike.py index 254aeaad0d..d53a4c3087 100644 --- a/tests/test_zernike.py +++ b/tests/test_zernike.py @@ -19,8 +19,7 @@ import numpy as np import galsim -from galsim.zernike import Zernike, DoubleZernike -from galsim_test_helpers import * +from galsim_test_helpers import timer, check_pickle, assert_raises, check_all_diff, is_jax_galsim @timer @@ -41,10 +40,10 @@ def test_Zernike_orthonormality(): y = y[w].ravel() area = np.pi*R_outer**2 for j1 in range(1, jmax+1): - Z1 = Zernike([0]*(j1+1)+[1], R_outer=R_outer) + Z1 = galsim.zernike.Zernike([0]*(j1+1)+[1], R_outer=R_outer) val1 = Z1.evalCartesian(x, y) for j2 in range(j1, jmax+1): - Z2 = Zernike([0]*(j2+1)+[1], R_outer=R_outer) + Z2 = galsim.zernike.Zernike([0]*(j2+1)+[1], R_outer=R_outer) val2 = Z2.evalCartesian(x, y) integral = np.dot(val1, val2) * dx**2 if j1 == j2: @@ -73,10 +72,10 @@ def test_Zernike_orthonormality(): y = y[w].ravel() area = np.pi*(R_outer**2 - R_inner**2) for j1 in range(1, jmax+1): - Z1 = Zernike([0]*(j1+1)+[1], R_outer=R_outer, R_inner=R_inner) + Z1 = galsim.zernike.Zernike([0]*(j1+1)+[1], R_outer=R_outer, R_inner=R_inner) val1 = Z1.evalCartesian(x, y) for j2 in range(j1, jmax+1): - Z2 = Zernike([0]*(j2+1)+[1], R_outer=R_outer, R_inner=R_inner) + Z2 = galsim.zernike.Zernike([0]*(j2+1)+[1], R_outer=R_outer, R_inner=R_inner) val2 = Z2.evalCartesian(x, y) integral = np.dot(val1, val2) * dx**2 if j1 == j2: @@ -93,7 +92,7 @@ def test_Zernike_orthonormality(): check_pickle(Z1, lambda z: tuple(z.evalCartesian(x, y))) with assert_raises(ValueError): - Z1 = Zernike([0]*4 + [0.1]*7, R_outer=R_inner, R_inner=R_outer) + Z1 = galsim.zernike.Zernike([0]*4 + [0.1]*7, R_outer=R_inner, R_inner=R_outer) val1 = Z1.evalCartesian(x, y) @@ -202,13 +201,13 @@ def test_Zernike_rotate(): R_inner = R_outer*eps coefs = [u() for _ in range(jmax+1)] - Z = Zernike(coefs, R_outer=R_outer, R_inner=R_inner) + Z = galsim.zernike.Zernike(coefs, R_outer=R_outer, R_inner=R_inner) check_pickle(Z) for theta in [0.0, 0.1, 1.0, np.pi, 4.0]: R = galsim.zernike.zernikeRotMatrix(jmax, theta) rotCoefs = np.dot(R, coefs) - Zrot = Zernike(rotCoefs, R_outer=R_outer, R_inner=R_inner) + Zrot = galsim.zernike.Zernike(rotCoefs, R_outer=R_outer, R_inner=R_inner) print('j,theta: ',jmax,theta) print('Z: ',Z.evalPolar(rhos, thetas)) print('Zrot: ',Zrot.evalPolar(rhos, thetas+theta)) @@ -236,7 +235,7 @@ def test_zernike_eval(): np.ones(4, dtype=float), np.ones(4, dtype=np.float32) ]: - Z = Zernike(coef) + Z = galsim.zernike.Zernike(coef) assert Z.coef.dtype == np.float64 assert Z(0.0, 0.0) == 1.0 assert Z(0, 0) == 1.0 @@ -246,29 +245,29 @@ def test_zernike_eval(): np.ones((4, 4), dtype=float), np.ones((4, 4), dtype=np.float32) ]: - dz = DoubleZernike(coefs) + dz = galsim.zernike.DoubleZernike(coefs) assert dz.coef.dtype == np.float64 assert dz(0.0, 0.0) == dz(0, 0) # Make sure we cast to float in _from_uvxy uvxy = dz._coef_array_uvxy - dz2 = DoubleZernike._from_uvxy(uvxy.astype(int)) + dz2 = galsim.zernike.DoubleZernike._from_uvxy(uvxy.astype(int)) np.testing.assert_array_equal(dz2._coef_array_uvxy, dz._coef_array_uvxy) @timer def test_ne(): objs = [ - Zernike([0, 1, 2]), - Zernike([0, 1, 2, 3]), - Zernike([0, 1, 2, 3], R_outer=0.2), - Zernike([0, 1, 2, 3], R_outer=0.2, R_inner=0.1), - DoubleZernike(np.eye(3)), - DoubleZernike(np.ones((4, 4))), - DoubleZernike(np.ones((4, 4)), xy_outer=1.1), - DoubleZernike(np.ones((4, 4)), xy_outer=1.1, xy_inner=0.9), - DoubleZernike(np.ones((4, 4)), xy_outer=1.1, xy_inner=0.9, uv_outer=1.1), - DoubleZernike(np.ones((4, 4)), xy_outer=1.1, xy_inner=0.9, uv_outer=1.1, uv_inner=0.9) + galsim.zernike.Zernike([0, 1, 2]), + galsim.zernike.Zernike([0, 1, 2, 3]), + galsim.zernike.Zernike([0, 1, 2, 3], R_outer=0.2), + galsim.zernike.Zernike([0, 1, 2, 3], R_outer=0.2, R_inner=0.1), + galsim.zernike.DoubleZernike(np.eye(3)), + galsim.zernike.DoubleZernike(np.ones((4, 4))), + galsim.zernike.DoubleZernike(np.ones((4, 4)), xy_outer=1.1), + galsim.zernike.DoubleZernike(np.ones((4, 4)), xy_outer=1.1, xy_inner=0.9), + galsim.zernike.DoubleZernike(np.ones((4, 4)), xy_outer=1.1, xy_inner=0.9, uv_outer=1.1), + galsim.zernike.DoubleZernike(np.ones((4, 4)), xy_outer=1.1, xy_inner=0.9, uv_outer=1.1, uv_inner=0.9) ] check_all_diff(objs) @@ -294,7 +293,7 @@ def test_Zernike_basis(): # Compare to basis vectors generated one at a time for j in range(1, jmax): - Z = Zernike([0]*j+[1], R_outer=R_outer, R_inner=R_inner) + Z = galsim.zernike.Zernike([0]*j+[1], R_outer=R_outer, R_inner=R_inner) zBasis = Z.evalCartesian(x, y) np.testing.assert_allclose( zBases[j], @@ -326,11 +325,11 @@ def test_fit(): [u()-0.5, 0, 0, 0, 0]] z = galsim.utilities.horner2d(x, y, cartesian_coefs) z2 = galsim.utilities.horner2d(x, y, cartesian_coefs, triangle=True) - np.testing.assert_equal(z,z2) + np.testing.assert_array_equal(z,z2) basis = galsim.zernike.zernikeBasis(21, x, y, R_outer=R_outer, R_inner=R_inner) coefs, _, _, _ = np.linalg.lstsq(basis.T, z, rcond=-1.) - resids = (Zernike(coefs, R_outer=R_outer, R_inner=R_inner) + resids = (galsim.zernike.Zernike(coefs, R_outer=R_outer, R_inner=R_inner) .evalCartesian(x, y) - z) resids2 = np.dot(basis.T, coefs).T - z @@ -379,7 +378,7 @@ def test_fit(): assert basis.shape == (22, 25, 40) # lstsq doesn't handle the extra dimension though... coefs, _, _, _ = np.linalg.lstsq(basis.reshape(21+1, 1000).T, z.ravel(), rcond=-1.) - resids = (Zernike(coefs, R_outer=R_outer, R_inner=R_inner) + resids = (galsim.zernike.Zernike(coefs, R_outer=R_outer, R_inner=R_inner) .evalCartesian(x, y) - z) resids2 = np.dot(basis.T, coefs).T - z @@ -395,7 +394,7 @@ def test_gradient(): """ # Start with a few that just quote the literature, e.g., Stephenson (2014). - Z11 = Zernike([0]*11+[1]) + Z11 = galsim.zernike.Zernike([0]*11+[1]) x = np.linspace(-1, 1, 100) x, y = np.meshgrid(x, x) @@ -419,7 +418,7 @@ def Z11_grad(x, y): np.testing.assert_allclose(Z11.evalCartesianGrad(x, y), Z11_grad(x, y), rtol=1.e-12, atol=1e-12) - Z28 = Zernike([0]*28+[1]) + Z28 = galsim.zernike.Zernike([0]*28+[1]) def Z28_grad(x, y): # Z28 = sqrt(14) (x^6 - 15 x^4 y^2 + 15 x^2 y^4 - y^6) @@ -444,7 +443,7 @@ def finite_difference_gradient(Z, x, y): nj = 1+int(u()*55) R_inner = 0.2+0.6*u() R_outer = R_inner + 0.2+0.6*u() - Z = Zernike([0]+[u() for _ in range(nj)], R_inner=R_inner, R_outer=R_outer) + Z = galsim.zernike.Zernike([0]+[u() for _ in range(nj)], R_inner=R_inner, R_outer=R_outer) np.testing.assert_allclose( finite_difference_gradient(Z, x, y), @@ -452,7 +451,7 @@ def finite_difference_gradient(Z, x, y): rtol=1e-5, atol=1e-5) # Make sure the gradient of the zero-Zernike works - Z = Zernike([0]) + Z = galsim.zernike.Zernike([0]) assert Z == Z.gradX == Z.gradX.gradX == Z.gradY == Z.gradY.gradY @@ -478,7 +477,7 @@ def test_gradient_bases(): # Compare to basis vectors generated one at a time for j in range(1, jmax+1): - Z = Zernike([0]*j+[1], R_outer=R_outer, R_inner=R_inner) + Z = galsim.zernike.Zernike([0]*j+[1], R_outer=R_outer, R_inner=R_inner) ZX = Z.gradX ZY = Z.gradY @@ -517,16 +516,22 @@ def test_sum(): a2 = np.empty(n2, dtype=float) u.generate(a1) u.generate(a2) - z1 = Zernike(a1, R_outer=R_outer, R_inner=R_inner) - z2 = Zernike(a2, R_outer=R_outer, R_inner=R_inner) + z1 = galsim.zernike.Zernike(a1, R_outer=R_outer, R_inner=R_inner) + z2 = galsim.zernike.Zernike(a2, R_outer=R_outer, R_inner=R_inner) c1 = u() c2 = u() coefSum = c2*np.array(z2.coef) - coefSum[:len(z1.coef)] += c1*z1.coef + if is_jax_galsim(): + coefSum = coefSum.at[:len(z1.coef)].add(c1*z1.coef) + else: + coefSum[:len(z1.coef)] += c1*z1.coef coefDiff = c2*np.array(z2.coef) - coefDiff[:len(z1.coef)] -= c1*z1.coef + if is_jax_galsim(): + coefDiff = coefDiff.at[:len(z1.coef)].add(-c1*z1.coef) + else: + coefDiff[:len(z1.coef)] -= c1*z1.coef np.testing.assert_allclose(coefSum, (c1*z1 + c2*z2).coef) np.testing.assert_allclose(coefDiff, -(c1*z1 - c2*z2).coef) @@ -553,13 +558,13 @@ def test_sum(): with np.testing.assert_raises(TypeError): z1 - 3 with np.testing.assert_raises(ValueError): - z1 + Zernike([0,1], R_outer=z1.R_outer*2) + z1 + galsim.zernike.Zernike([0,1], R_outer=z1.R_outer*2) with np.testing.assert_raises(ValueError): - z1 + Zernike([0,1], R_outer=z1.R_outer, R_inner=z1.R_inner*2) + z1 + galsim.zernike.Zernike([0,1], R_outer=z1.R_outer, R_inner=z1.R_inner*2) # Commutative with integer coefficients - z1 = Zernike([0,1,2,3,4]) - z2 = Zernike([1,2,3,4,5,6]) + z1 = galsim.zernike.Zernike([0,1,2,3,4]) + z2 = galsim.zernike.Zernike([1,2,3,4,5,6]) assert z1+z2 == z2+z1 assert (z2-z1) == z2 + -z1 == -(z1-z2) @@ -583,8 +588,8 @@ def test_product(): a2 = np.empty(n2, dtype=float) u.generate(a1) u.generate(a2) - z1 = Zernike(a1, R_outer=R_outer, R_inner=R_inner) - z2 = Zernike(a2, R_outer=R_outer, R_inner=R_inner) + z1 = galsim.zernike.Zernike(a1, R_outer=R_outer, R_inner=R_inner) + z2 = galsim.zernike.Zernike(a2, R_outer=R_outer, R_inner=R_inner) np.testing.assert_allclose( z1(x, y) * z2(x, y), @@ -627,15 +632,15 @@ def test_product(): with np.testing.assert_raises(TypeError): z1 * galsim.Gaussian(fwhm=1) with np.testing.assert_raises(ValueError): - z1 * Zernike([0,1], R_outer=z1.R_outer*2) + z1 * galsim.zernike.Zernike([0,1], R_outer=z1.R_outer*2) with np.testing.assert_raises(ValueError): - z1 * Zernike([0,1], R_outer=z1.R_outer, R_inner=z1.R_inner*2) + z1 * galsim.zernike.Zernike([0,1], R_outer=z1.R_outer, R_inner=z1.R_inner*2) with np.testing.assert_raises(TypeError): z1 / z2 # Commutative with integer coefficients - z1 = Zernike([0,1,2,3,4,5]) - z2 = Zernike([1,2,3,4,5,6]) + z1 = galsim.zernike.Zernike([0,1,2,3,4,5]) + z2 = galsim.zernike.Zernike([1,2,3,4,5,6]) assert z1*z2 == z2*z1 @@ -655,7 +660,7 @@ def test_laplacian(): u.generate(a) R_outer = 1+0.1*u() R_inner = 0.1*u() - z = Zernike(a, R_outer=R_outer, R_inner=R_inner) + z = galsim.zernike.Zernike(a, R_outer=R_outer, R_inner=R_inner) np.testing.assert_allclose( z.laplacian(x, y), @@ -676,7 +681,7 @@ def test_laplacian(): # implies laplacian = 4 sqrt(3) + 4 sqrt(3) = 8 sqrt(3) # which is 8 sqrt(3) Z1 np.testing.assert_allclose( - Zernike([0,0,0,0,1]).laplacian.coef, + galsim.zernike.Zernike([0,0,0,0,1]).laplacian.coef, np.array([0,8*np.sqrt(3)]) ) @@ -686,7 +691,7 @@ def test_laplacian(): # implies laplacian = 24 sqrt(8) y # which is 12*sqrt(8) * Z3 since Z3 = 2 y np.testing.assert_allclose( - Zernike([0,0,0,0,0,0,0,1]).laplacian.coef, + galsim.zernike.Zernike([0,0,0,0,0,0,0,1]).laplacian.coef, np.array([0,0,0,12*np.sqrt(8)]) ) @@ -707,7 +712,7 @@ def test_hessian(): u.generate(a) R_outer = 1+0.1*u() R_inner = 0.1*u() - z = Zernike(a, R_outer=R_outer, R_inner=R_inner) + z = galsim.zernike.Zernike(a, R_outer=R_outer, R_inner=R_inner) np.testing.assert_allclose( z.hessian(x, y), @@ -728,7 +733,7 @@ def test_hessian(): # implies hessian = 4 sqrt(3) * 4 sqrt(3) - 0 * 0 = 16*3 = 48 # which is 48 Z1 np.testing.assert_allclose( - Zernike([0,0,0,0,1]).hessian.coef, + galsim.zernike.Zernike([0,0,0,0,1]).hessian.coef, np.array([0,48]) ) @@ -741,7 +746,7 @@ def test_hessian(): # That's a little inconvenient to decompose into Zernikes by hand, but we can test against # an array of (x,y) values. np.testing.assert_allclose( - Zernike([0,0,0,0,0,0,0,1]).hessian(x, y), + galsim.zernike.Zernike([0,0,0,0,0,0,0,1]).hessian(x, y), 864*y*y - 288*x*x ) @@ -783,7 +788,7 @@ def test_lazy_coef(): zarr = [0]+[u() for i in range(jmax)] R_inner = u()*0.5+0.2 R_outer = u()*2.0+2.0 - Z = Zernike(zarr, R_outer=R_outer, R_inner=R_inner) + Z = galsim.zernike.Zernike(zarr, R_outer=R_outer, R_inner=R_inner) Z._coef_array_xy del Z.coef np.testing.assert_allclose(zarr, Z.coef, rtol=0, atol=1e-12) @@ -793,7 +798,7 @@ def test_lazy_coef(): zarr = [0]+[u() for i in range(jmax)] R_inner = u()*0.5+0.2 R_outer = u()*2.0+2.0 - Z = Zernike(zarr, R_outer=R_outer, R_inner=R_inner) + Z = galsim.zernike.Zernike(zarr, R_outer=R_outer, R_inner=R_inner) Z._coef_array_xy del Z.coef np.testing.assert_allclose(zarr, Z.coef[:len(zarr)], rtol=0, atol=1e-12) @@ -812,7 +817,7 @@ def test_dz_val(): uv_outer = rng.uniform(1.3, 1.7) xy_inner = rng.uniform(0.4, 0.7) xy_outer = rng.uniform(1.3, 1.7) - dz = DoubleZernike( + dz = galsim.zernike.DoubleZernike( coef, uv_inner=uv_inner, uv_outer=uv_outer, @@ -836,9 +841,9 @@ def test_dz_val(): check_pickle(dz, lambda dz_: tuple(dz_(*uv_vector, *xy_vector))) # If you don't specify xy, then get (list of) Zernike out. - assert isinstance(dz(*uv_scalar), Zernike) + assert isinstance(dz(*uv_scalar), galsim.zernike.Zernike) assert isinstance(dz(*uv_vector), list) - assert all(isinstance(z, Zernike) for z in dz(*uv_vector)) + assert all(isinstance(z, galsim.zernike.Zernike) for z in dz(*uv_vector)) # If uv scalar and xy scalar, then get scalar out. assert np.ndim(dz(*uv_scalar, *xy_scalar)) == 0 @@ -891,7 +896,7 @@ def test_dz_val(): dz([0.0, 1.0], [0.0, 1.0], x=[1.0], y=[1.0]) # Try pickle/repr with default domain - dz = DoubleZernike(coef) + dz = galsim.zernike.DoubleZernike(coef) check_pickle(dz) @@ -908,7 +913,7 @@ def test_dz_coef_uvxy(): uv_outer = rng.uniform(1.3, 1.7) xy_inner = rng.uniform(0.4, 0.7) xy_outer = rng.uniform(1.3, 1.7) - dz = DoubleZernike( + dz = galsim.zernike.DoubleZernike( coef, uv_inner=uv_inner, uv_outer=uv_outer, @@ -988,12 +993,12 @@ def test_dz_sum(): coef2[0] = 0.0 coef2[:, 0] = 0.0 - dz1 = DoubleZernike( + dz1 = galsim.zernike.DoubleZernike( coef1, uv_inner=uv_inner, uv_outer=uv_outer, xy_inner=xy_inner, xy_outer=xy_outer ) - dz2 = DoubleZernike( + dz2 = galsim.zernike.DoubleZernike( coef2, uv_inner=uv_inner, uv_outer=uv_outer, xy_inner=xy_inner, xy_outer=xy_outer @@ -1048,25 +1053,25 @@ def test_dz_sum(): with np.testing.assert_raises(TypeError): dz1 - 3 with np.testing.assert_raises(ValueError): - dz1 + DoubleZernike( + dz1 + galsim.zernike.DoubleZernike( coef1, uv_outer=2*uv_outer, uv_inner=uv_inner, xy_outer=xy_outer, xy_inner=xy_inner ) with np.testing.assert_raises(ValueError): - dz1 + DoubleZernike( + dz1 + galsim.zernike.DoubleZernike( coef1, uv_outer=uv_outer, uv_inner=2*uv_inner, xy_outer=xy_outer, xy_inner=xy_inner ) with np.testing.assert_raises(ValueError): - dz1 + DoubleZernike( + dz1 + galsim.zernike.DoubleZernike( coef1, uv_outer=uv_outer, uv_inner=uv_inner, xy_outer=2*xy_outer, xy_inner=xy_inner ) with np.testing.assert_raises(ValueError): - dz1 + DoubleZernike( + dz1 + galsim.zernike.DoubleZernike( coef1, uv_outer=uv_outer, uv_inner=uv_inner, xy_outer=xy_outer, xy_inner=2*xy_inner ) # Commutative with integer coefficients - dz1 = DoubleZernike(np.eye(3, dtype=int)) - dz2 = DoubleZernike(np.ones((4, 4), dtype=int)) + dz1 = galsim.zernike.DoubleZernike(np.eye(3, dtype=int)) + dz2 = galsim.zernike.DoubleZernike(np.ones((4, 4), dtype=int)) assert dz1 + dz2 == dz2 + dz1 assert (dz2 - dz1) == dz2 + (-dz1) == -(dz1 - dz2) @@ -1105,12 +1110,12 @@ def test_dz_product(): coef2[0] = 0.0 coef2[:, 0] = 0.0 - dz1 = DoubleZernike( + dz1 = galsim.zernike.DoubleZernike( coef1, uv_inner=uv_inner, uv_outer=uv_outer, xy_inner=xy_inner, xy_outer=xy_outer ) - dz2 = DoubleZernike( + dz2 = galsim.zernike.DoubleZernike( coef2, uv_inner=uv_inner, uv_outer=uv_outer, xy_inner=xy_inner, xy_outer=xy_outer @@ -1155,27 +1160,27 @@ def test_dz_product(): with np.testing.assert_raises(TypeError): dz1 * galsim.Gaussian(sigma=1.0) with np.testing.assert_raises(ValueError): - dz1 * DoubleZernike( + dz1 * galsim.zernike.DoubleZernike( coef1, uv_outer=2*uv_outer, uv_inner=uv_inner, xy_outer=xy_outer, xy_inner=xy_inner ) with np.testing.assert_raises(ValueError): - dz1 * DoubleZernike( + dz1 * galsim.zernike.DoubleZernike( coef1, uv_outer=uv_outer, uv_inner=2*uv_inner, xy_outer=xy_outer, xy_inner=xy_inner ) with np.testing.assert_raises(ValueError): - dz1 * DoubleZernike( + dz1 * galsim.zernike.DoubleZernike( coef1, uv_outer=uv_outer, uv_inner=uv_inner, xy_outer=2*xy_outer, xy_inner=xy_inner ) with np.testing.assert_raises(ValueError): - dz1 * DoubleZernike( + dz1 * galsim.zernike.DoubleZernike( coef1, uv_outer=uv_outer, uv_inner=uv_inner, xy_outer=xy_outer, xy_inner=2*xy_inner ) with np.testing.assert_raises(TypeError): dz1 / dz2 # Commutative with integer coefficients - dz1 = DoubleZernike(np.eye(3, dtype=int)) - dz2 = DoubleZernike(np.ones((4, 4), dtype=int)) + dz1 = galsim.zernike.DoubleZernike(np.eye(3, dtype=int)) + dz2 = galsim.zernike.DoubleZernike(np.ones((4, 4), dtype=int)) assert dz1 * dz2 == dz2 * dz1 assert (dz2 * 3) == (3 * dz2) @@ -1203,7 +1208,7 @@ def test_dz_grad(): coef[0] = 0.0 coef[:, 0] = 0.0 - dz = DoubleZernike( + dz = galsim.zernike.DoubleZernike( coef, uv_inner=uv_inner, uv_outer=uv_outer, xy_inner=xy_inner, xy_outer=xy_outer @@ -1221,7 +1226,7 @@ def test_dz_grad(): # U and V are trickier, since we aren't including a way to turn a DZ evaluated # at (x, y) into a single Zernike of (u, v). We can mock that though up by # transposing the DZ coefficients and swapping the domain parameters. - dz_xyuv = DoubleZernike( + dz_xyuv = galsim.zernike.DoubleZernike( np.transpose(coef, axes=(1, 0)), uv_inner=xy_inner, uv_outer=xy_outer, xy_inner=uv_inner, xy_outer=uv_outer @@ -1284,7 +1289,7 @@ def test_dz_to_T(): coef[0] = 0.0 coef[:, 0] = 0.0 - W = DoubleZernike( + W = galsim.zernike.DoubleZernike( coef, uv_inner=uv_inner, uv_outer=uv_outer, # field xy_inner=xy_inner, xy_outer=xy_outer # pupil @@ -1391,7 +1396,7 @@ def test_dz_rotate(): coef[0] = 0.0 coef[:, 0] = 0.0 - dz = DoubleZernike( + dz = galsim.zernike.DoubleZernike( coef, uv_inner=uv_inner, uv_outer=uv_outer, # field xy_inner=xy_inner, xy_outer=xy_outer # pupil @@ -1449,7 +1454,7 @@ def test_dz_basis(): for k in range(1, k1): coef = np.zeros((k1, j1)) coef[k, j] = 1.0 - DZ = DoubleZernike( + DZ = galsim.zernike.DoubleZernike( coef, uv_inner=uv_inner, uv_outer=uv_outer, xy_inner=xy_inner, xy_outer=xy_outer @@ -1479,7 +1484,7 @@ def test_dz_mean(): coef[0] = 0.0 coef[:, 0] = 0.0 - dz = DoubleZernike( + dz = galsim.zernike.DoubleZernike( coef, uv_inner=uv_inner, uv_outer=uv_outer, xy_inner=xy_inner, xy_outer=xy_outer @@ -1487,7 +1492,7 @@ def test_dz_mean(): # We don't have a function that returns a Zernike over uv at a given xy # point, but we can mimic that by transposing xy an uv in a new # DoubleZernike object. - dzT = DoubleZernike( + dzT = galsim.zernike.DoubleZernike( coef.T, uv_inner=xy_inner, uv_outer=xy_outer, xy_inner=uv_inner, xy_outer=uv_outer