diff --git a/lenstronomy/LightModel/Profiles/hernquist.py b/lenstronomy/LightModel/Profiles/hernquist.py index 7a33bb4a6..049f6a20e 100644 --- a/lenstronomy/LightModel/Profiles/hernquist.py +++ b/lenstronomy/LightModel/Profiles/hernquist.py @@ -1,4 +1,5 @@ import lenstronomy.Util.param_util as param_util +import numpy as np __all__ = ["Hernquist", "HernquistEllipse"] @@ -49,6 +50,18 @@ def light_3d(self, r, amp, Rs): rho0 = self.lens.sigma2rho(amp, Rs) return self.lens.density(r, rho0, Rs) + @staticmethod + def total_flux(amp, Rs): + """ + + :param amp: + :param Rs: + :return: + """ + rhos = amp / Rs + m_tot = 2 * np.pi * rhos * Rs ** 3 + return m_tot + class HernquistEllipse(object): """Class for elliptical pseudo Jaffe lens light (2d projected light/mass @@ -108,3 +121,14 @@ def light_3d(self, r, amp, Rs, e1=0, e2=0): """ rho0 = self.lens.sigma2rho(amp, Rs) return self.lens.density(r, rho0, Rs) + + def total_flux(self, amp, Rs, e1=0, e2=0): + """ + + :param amp: + :param Rs: + :param e1: + :param e2: + :return: + """ + return self.spherical.total_flux(amp=amp, Rs=Rs) diff --git a/test/test_LightModel/test_Profiles/test_hernquist.py b/test/test_LightModel/test_Profiles/test_hernquist.py new file mode 100644 index 000000000..b0cf597e8 --- /dev/null +++ b/test/test_LightModel/test_Profiles/test_hernquist.py @@ -0,0 +1,26 @@ +from lenstronomy.LightModel.Profiles.hernquist import Hernquist, HernquistEllipse +from lenstronomy.Util.util import make_grid +import numpy as np +import numpy.testing as npt + + +class TestHernquist(object): + + def setup_method(self): + self.hernquist = Hernquist() + self.hernquist_ellipse = HernquistEllipse() + + def test_total_flux(self): + delta_pix = 0.2 + x, y = make_grid(numPix=1000, deltapix=delta_pix) + + rs, amp = 1, 1 + total_flux = self.hernquist.total_flux(amp=amp, Rs=rs) + flux = self.hernquist.function(x, y, amp=amp, Rs=rs) + total_flux_numerics = np.sum(flux) * delta_pix**2 + npt.assert_almost_equal(total_flux_numerics/total_flux, 1, decimal=1) + + total_flux_ellipse = self.hernquist_ellipse.total_flux(amp=amp, Rs=rs) + npt.assert_almost_equal(total_flux_ellipse, total_flux) + +