-
Notifications
You must be signed in to change notification settings - Fork 107
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Speed up combine_wave_lists using new merge_sorted function
- Loading branch information
Showing
8 changed files
with
362 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
# Copyright (c) 2012-2022 by the GalSim developers team on GitHub | ||
# https://github.com/GalSim-developers | ||
# | ||
# This file is part of GalSim: The modular galaxy image simulation toolkit. | ||
# https://github.com/GalSim-developers/GalSim | ||
# | ||
# GalSim is free software: redistribution and use in source and binary forms, | ||
# with or without modification, are permitted provided that the following | ||
# conditions are met: | ||
# | ||
# 1. Redistributions of source code must retain the above copyright notice, this | ||
# list of conditions, and the disclaimer given in the accompanying LICENSE | ||
# file. | ||
# 2. Redistributions in binary form must reproduce the above copyright notice, | ||
# this list of conditions, and the disclaimer given in the documentation | ||
# and/or other materials provided with the distribution. | ||
# | ||
|
||
import timeit | ||
import galsim | ||
import numpy as np | ||
|
||
from galsim.utilities import Profile | ||
|
||
def old_combine_wave_list(*args): | ||
if len(args) == 1: | ||
if isinstance(args[0], (list, tuple)): | ||
args = args[0] | ||
else: | ||
raise TypeError("Single input argument must be a list or tuple") | ||
|
||
if len(args) == 0: | ||
return np.array([], dtype=float), 0.0, np.inf | ||
|
||
if len(args) == 1: | ||
obj = args[0] | ||
return obj.wave_list, getattr(obj, 'blue_limit', 0.0), getattr(obj, 'red_limit', np.inf) | ||
|
||
blue_limit = np.max([getattr(obj, 'blue_limit', 0.0) for obj in args]) | ||
red_limit = np.min([getattr(obj, 'red_limit', np.inf) for obj in args]) | ||
if blue_limit > red_limit: | ||
raise GalSimError("Empty wave_list intersection.") | ||
|
||
waves = [np.asarray(obj.wave_list) for obj in args] | ||
waves = [w[(blue_limit <= w) & (w <= red_limit)] for w in waves] | ||
wave_list = np.union1d(waves[0], waves[1]) | ||
for w in waves[2:]: | ||
wave_list = np.union1d(wave_list, w) | ||
# Make sure both limits are included in final list | ||
if len(wave_list) > 0 and (wave_list[0] != blue_limit or wave_list[-1] != red_limit): | ||
wave_list = np.union1d([blue_limit, red_limit], wave_list) | ||
return wave_list, blue_limit, red_limit | ||
|
||
# This edit was suggested by Jim Chiang to not merge things if they are all equal. | ||
# (Slightly improved to use np.array_equal, rather than all(waves[0] == w).) | ||
# It helps a lot when the inputs are equal, but not quite as much as the new C++ code. | ||
def jims_combine_wave_list(*args): | ||
if len(args) == 1: | ||
if isinstance(args[0], (list, tuple)): | ||
args = args[0] | ||
else: | ||
raise TypeError("Single input argument must be a list or tuple") | ||
|
||
if len(args) == 0: | ||
return np.array([], dtype=float), 0.0, np.inf | ||
|
||
if len(args) == 1: | ||
obj = args[0] | ||
return obj.wave_list, getattr(obj, 'blue_limit', 0.0), getattr(obj, 'red_limit', np.inf) | ||
|
||
blue_limit = np.max([getattr(obj, 'blue_limit', 0.0) for obj in args]) | ||
red_limit = np.min([getattr(obj, 'red_limit', np.inf) for obj in args]) | ||
if blue_limit > red_limit: | ||
raise GalSimError("Empty wave_list intersection.") | ||
|
||
waves = [np.asarray(obj.wave_list) for obj in args] | ||
waves = [w[(blue_limit <= w) & (w <= red_limit)] for w in waves] | ||
if (len(waves[0]) == len(waves[1]) | ||
and all(np.array_equal(waves[0], w) for w in waves[1:])): | ||
wave_list = waves[0] | ||
else: | ||
wave_list = np.union1d(waves[0], waves[1]) | ||
for w in waves[2:]: | ||
wave_list = np.union1d(wave_list, w) | ||
# Make sure both limits are included in final list | ||
if len(wave_list) > 0 and (wave_list[0] != blue_limit or wave_list[-1] != red_limit): | ||
wave_list = np.union1d([blue_limit, red_limit], wave_list) | ||
return wave_list, blue_limit, red_limit | ||
|
||
|
||
sed_list = [ galsim.SED(name, wave_type='ang', flux_type='flambda') for name in | ||
['CWW_E_ext.sed', 'CWW_Im_ext.sed', 'CWW_Sbc_ext.sed', 'CWW_Scd_ext.sed'] ] | ||
|
||
ref_wave, ref_bl, ref_rl = old_combine_wave_list(sed_list) | ||
wave_list, blue_limit, red_limit = galsim.utilities.combine_wave_list(sed_list) | ||
np.testing.assert_array_equal(wave_list, ref_wave) | ||
assert blue_limit == ref_bl | ||
assert red_limit == ref_rl | ||
|
||
n = 10000 | ||
t1 = min(timeit.repeat(lambda: old_combine_wave_list(sed_list), number=n)) | ||
t2 = min(timeit.repeat(lambda: jims_combine_wave_list(sed_list), number=n)) | ||
t3 = min(timeit.repeat(lambda: galsim.utilities.combine_wave_list(sed_list), number=n)) | ||
|
||
print(f'Time for {n} iterations of combine_wave_list') | ||
print('old time = ',t1) | ||
print('jims time = ',t2) | ||
print('new time = ',t3) | ||
|
||
# Check when all wave_lists are equal. | ||
sed_list = [ galsim.SED(name, wave_type='ang', flux_type='flambda') for name in | ||
['CWW_E_ext.sed', 'CWW_E_ext.sed', 'CWW_E_ext.sed', 'CWW_E_ext.sed'] ] | ||
|
||
ref_wave, ref_bl, ref_rl = old_combine_wave_list(sed_list) | ||
jims_wave, jims_bl, jims_rl = jims_combine_wave_list(sed_list) | ||
wave_list, blue_limit, red_limit = galsim.utilities.combine_wave_list(sed_list) | ||
|
||
np.testing.assert_array_equal(wave_list, ref_wave) | ||
assert blue_limit == ref_bl | ||
assert red_limit == ref_rl | ||
|
||
t1 = min(timeit.repeat(lambda: old_combine_wave_list(sed_list), number=n)) | ||
t2 = min(timeit.repeat(lambda: jims_combine_wave_list(sed_list), number=n)) | ||
t3 = min(timeit.repeat(lambda: galsim.utilities.combine_wave_list(sed_list), number=n)) | ||
|
||
print(f'Time for {n} iterations of combine_wave_list with identical wave_lists') | ||
print('old time = ',t1) | ||
print('jims time = ',t2) | ||
print('new time = ',t3) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
/* -*- c++ -*- | ||
* Copyright (c) 2012-2022 by the GalSim developers team on GitHub | ||
* https://github.com/GalSim-developers | ||
* | ||
* This file is part of GalSim: The modular galaxy image simulation toolkit. | ||
* https://github.com/GalSim-developers/GalSim | ||
* | ||
* GalSim is free software: redistribution and use in source and binary forms, | ||
* with or without modification, are permitted provided that the following | ||
* conditions are met: | ||
* | ||
* 1. Redistributions of source code must retain the above copyright notice, this | ||
* list of conditions, and the disclaimer given in the accompanying LICENSE | ||
* file. | ||
* 2. Redistributions in binary form must reproduce the above copyright notice, | ||
* this list of conditions, and the disclaimer given in the documentation | ||
* and/or other materials provided with the distribution. | ||
*/ | ||
|
||
//#define DEBUGLOGGING | ||
|
||
#include <limits> | ||
#include "PyBind11Helper.h" | ||
#include "Std.h" | ||
|
||
namespace galsim { | ||
|
||
static py::array_t<double> MergeSorted(py::list& arrays) | ||
{ | ||
dbg<<"Start MergeSorted: "<<arrays.size()<<std::endl; | ||
const int n_arrays = arrays.size(); | ||
if (n_arrays == 0) | ||
throw std::runtime_error("No arrays provided to merge_sorted"); | ||
assert(n_arrays > 0); | ||
py::array_t<double> a0 = arrays[0].cast<py::array_t<double> >(); | ||
int n0 = a0.size(); | ||
dbg<<"size of array 0 = "<<n0<<std::endl; | ||
|
||
// First figure out the maximum possible size of the return array. | ||
int max_ret_size = n0; | ||
const double* a0_begin = static_cast<const double*>(a0.data()); | ||
for(int k=1; k<n_arrays; ++k) { | ||
py::array_t<double> ak = arrays[k].cast<py::array_t<double> >(); | ||
int nk = ak.size(); | ||
dbg<<"size of array "<<k<<" = "<<nk<<std::endl; | ||
// Check how far into the array, this one is itentical to a0. | ||
// Do this from both sizes. Not least because in GalSim, the typical | ||
// way with use this includes a 2-element array which is often just the | ||
// first and last values of other arrays. | ||
const double* a0_p1 = static_cast<const double*>(a0.data()); | ||
const double* ak_p1 = static_cast<const double*>(ak.data()); | ||
const double* a0_p2 = a0_p1 + n0; | ||
const double* ak_p2 = ak_p1 + nk; | ||
while (a0_p1 != a0_p2 && ak_p1 != ak_p2 && *a0_p1 == *ak_p1) { | ||
++a0_p1; ++ak_p1; | ||
} | ||
while (a0_p1 != a0_p2 && ak_p1 != ak_p2 && *(a0_p2-1) == *(ak_p2-1)) { | ||
--a0_p2; --ak_p2; | ||
} | ||
int n_left = ak_p2 - ak_p1; | ||
dbg<<"For array "<<k<<", "<<nk - n_left<<" elements are identical to a0\n"; | ||
max_ret_size += n_left; | ||
} | ||
dbg<<"max_ret_size = "<<max_ret_size<<std::endl; | ||
if (max_ret_size == n0) { | ||
// Then arrays[0] already has all the values. No need to merge. | ||
// (This is not terribly uncommon, and the early exit saves a lot of time!) | ||
return a0; | ||
} | ||
|
||
// We actually merge these 1 at a time, since that's much simpler (and maybe even faster?). | ||
// At each step, | ||
// p0 is a pointer into the first array being merged (possibly a previous merge result) | ||
// p1 is a pointer into the second array being merged | ||
// p2 is a pointer into the resulting merged array. | ||
// Note: If more than 2 input arrays to merge, we might need a second temporary vector. | ||
// This will be swapped with res as needed during the iteration. | ||
|
||
std::vector<double> res(max_ret_size); | ||
std::vector<double> res2(n_arrays == 2 ? 0 : max_ret_size); | ||
const double* p0 = static_cast<const double*>(a0.data()); | ||
const double* p0_end = p0 + n0; | ||
double* p2 = res.data(); | ||
int n_res = 0; | ||
assert(n_arrays > 1); // Can't get here if len(arrays) == 1 | ||
|
||
for(int k=1; k<n_arrays; ++k) { | ||
py::array_t<double> a1 = arrays[k].cast<py::array_t<double> >(); | ||
const double* p1 = static_cast<const double*>(a1.data()); | ||
const double* p1_end = p1 + a1.size(); | ||
|
||
// Keep track of the previous value to be placed in the result array, | ||
// so we can raise an exception if an input array is not sorted. | ||
double prev = -std::numeric_limits<double>::max(); | ||
|
||
while (p0 != p0_end && p1 != p1_end) { | ||
double x; | ||
// Select the smaller one. | ||
if (*p1 < *p0) { | ||
x = *p1++; | ||
} else { | ||
x = *p0++; | ||
if (*p1 == x) ++p1; | ||
} | ||
if (x == prev) continue; // skip duplicates | ||
// Make sure the inputs make sense. | ||
if (x < prev) { | ||
throw std::runtime_error("Arrays are not sorted"); | ||
} | ||
*p2++ = prev = x; | ||
} | ||
|
||
// Now at least one of the two arrays are exhausted. Fill the rest of res. | ||
while (p0 != p0_end) { | ||
double x = *p0++; | ||
if (x == prev) continue; | ||
if (x < prev) { | ||
throw std::runtime_error("Arrays are not sorted"); | ||
} | ||
*p2++ = prev = x; | ||
} | ||
while (p1 != p1_end) { | ||
double x = *p1++; | ||
if (x == prev) continue; | ||
if (x < prev) { | ||
throw std::runtime_error("Arrays are not sorted"); | ||
} | ||
*p2++ = prev = x; | ||
} | ||
assert(p2 <= res.data()+max_ret_size); | ||
// The final value of p2-res.data() is the relevant length of res. | ||
n_res = p2 - res.data(); | ||
|
||
if (k+1 < n_arrays) { | ||
// Set up for the next pass through the loop. | ||
res.swap(res2); | ||
// Now res2 has the result of this loop. Use that for a0 in next loop. | ||
p0 = res2.data(); | ||
p0_end = p0 + n_res; | ||
p2 = res.data(); | ||
} | ||
} | ||
dbg<<"Done. Final size = "<<n_res<<std::endl; | ||
// Finally, return res as a numpy array | ||
return py::array_t<double>(n_res, res.data()); | ||
} | ||
|
||
void pyExportUtilities(py::module& _galsim) | ||
{ | ||
_galsim.def("MergeSorted", &MergeSorted); | ||
} | ||
|
||
} // namespace galsim |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.