Skip to content

Commit

Permalink
Merge pull request #1141 from helmholtz-analytics/bug/789-pow-perform…
Browse files Browse the repository at this point in the history
…ance

`heat.pow()` speed-up when exponent is `int`
  • Loading branch information
mtar authored Apr 24, 2023
2 parents eaa90d6 + ed242bc commit 9f6b2eb
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 7 deletions.
21 changes: 15 additions & 6 deletions .gitlab-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,37 @@ test:
- x86_64
script:
- apt update
- apt -y install build-essential python3-pip curl git
- "curl -X POST -H \"Accept: application/vnd.github+json\" -H \"Authorization: Bearer $GITHUB_TOKEN\" -H \"X-GitHub-Api-Version: 2022-11-28\" https://api.github.com/repos/helmholtz-analytics/heat/statuses/$CI_COMMIT_SHA -d '{\"state\":\"pending\",\"target_url\":\"https://codebase.helmholtz.cloud/haf/heat/-/jobs/$CI_JOB_ID\",\"description\":\"The build runs!\",\"context\":\"continuous-integration/codebase\"}'\n"
- apt -y install build-essential python3-pip curl git jo
- json=$(jo state="pending" target_url="https://codebase.helmholtz.cloud/haf/heat/-/jobs/$CI_JOB_ID" description="The build runs!" context="continuous-integration/codebase")
- "curl -o /dev/null -L -X POST -H \"Accept: application/vnd.github+json\" -H \"Authorization: Bearer $GITHUB_TOKEN\" -H \"X-GitHub-Api-Version: 2022-11-28\" https://api.github.com/repos/helmholtz-analytics/heat/statuses/$CI_COMMIT_SHA -d \"$json\" \n"
- pip install pytest coverage
- pip install .[hdf5,netcdf]
- COVERAGE_FILE=report/cov/coverage1 HEAT_TEST_USE_DEVICE=cpu mpirun --allow-run-as-root -n 1 coverage run --source=heat --parallel-mode -m pytest --junitxml=report/test/report1.xml heat/
- COVERAGE_FILE=report/cov/coverage2 HEAT_TEST_USE_DEVICE=cpu mpirun --allow-run-as-root -n 2 coverage run --source=heat --parallel-mode -m pytest --junitxml=report/test/report2.xml heat/
- COVERAGE_FILE=report/cov/coverage3 HEAT_TEST_USE_DEVICE=gpu mpirun --allow-run-as-root -n 3 coverage run --source=heat --parallel-mode -m pytest --junitxml=report/test/report3.xml heat/
- COVERAGE_FILE=report/cov/coverage5 HEAT_TEST_USE_DEVICE=cpu mpirun --allow-run-as-root -n 5 coverage run --source=heat --parallel-mode -m pytest --junitxml=report/test/report5.xml heat/
- COVERAGE_FILE=report/cov/coverage8 HEAT_TEST_USE_DEVICE=gpu mpirun --allow-run-as-root -n 6 coverage run --source=heat --parallel-mode -m pytest --junitxml=report/test/report6.xml heat/
- COVERAGE_FILE=report/cov/coverage6 HEAT_TEST_USE_DEVICE=gpu mpirun --allow-run-as-root -n 6 coverage run --source=heat --parallel-mode -m pytest --junitxml=report/test/report6.xml heat/
- coverage combine report/cov/*
- coverage report
- coverage xml
- curl -Os https://uploader.codecov.io/latest/linux/codecov
- chmod +x codecov
- ./codecov -F unit -f ./coverage.xml -t $CODECOV_TOKEN -Z
after_script:
- apt -y install jo
- |
case $CI_JOB_STATUS in
'success') curl -X POST -H "Accept: application/vnd.github+json" -H "Authorization: Bearer $GITHUB_TOKEN" -H "X-GitHub-Api-Version: 2022-11-28" https://api.github.com/repos/helmholtz-analytics/heat/statuses/$CI_COMMIT_SHA -d '{"state":"success","target_url":"https://codebase.helmholtz.cloud/haf/heat/-/jobs/$CI_JOB_ID","description":"The build succeeded!","context":"continuous-integration/codebase"}';;
'failed') curl -X POST -H "Accept: application/vnd.github+json" -H "Authorization: Bearer $GITHUB_TOKEN" -H "X-GitHub-Api-Version: 2022-11-28" https://api.github.com/repos/helmholtz-analytics/heat/statuses/$CI_COMMIT_SHA -d '{"state":"failure","target_url":"https://codebase.helmholtz.cloud/haf/heat/-/jobs/$CI_JOB_ID","description":"The build failed!","context":"continuous-integration/codebase"}';;
*) curl -X POST -H "Accept: application/vnd.github+json" -H "Authorization: Bearer $GITHUB_TOKEN" -H "X-GitHub-Api-Version: 2022-11-28" https://api.github.com/repos/helmholtz-analytics/heat/statuses/$CI_COMMIT_SHA -d '{"state":"error","target_url":"https://codebase.helmholtz.cloud/haf/heat/-/jobs/$CI_JOB_ID","description":"The build errored!","context":"continuous-integration/codebase"}';;
'success')
json=$(jo state="success" target_url="https://codebase.helmholtz.cloud/haf/heat/-/jobs/$CI_JOB_ID" description="The build succeeded!" context="continuous-integration/codebase")
;;
'failed')
json=$(jo state="failure" target_url="https://codebase.helmholtz.cloud/haf/heat/-/jobs/$CI_JOB_ID" description="The build failed!" context="continuous-integration/codebase")
;;
*)
json=$(jo state="error" target_url="https://codebase.helmholtz.cloud/haf/heat/-/jobs/$CI_JOB_ID" description="The build errored!" context="continuous-integration/codebase")
;;
esac
curl -o /dev/null -L -X POST -H "Accept: application/vnd.github+json" -H "Authorization: Bearer $GITHUB_TOKEN" -H "X-GitHub-Api-Version: 2022-11-28" https://api.github.com/repos/helmholtz-analytics/heat/statuses/$CI_COMMIT_SHA -d "$json"
artifacts:
when: always
paths:
Expand Down
31 changes: 31 additions & 0 deletions heat/core/arithmetics.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,37 @@ def pow(t1: Union[DNDarray, float], t2: Union[DNDarray, float]) -> DNDarray:
DNDarray([[ 1., 8.],
[27., 64.]], dtype=ht.float32, device=cpu:0, split=None)
"""
# early exit for integer scalars
if isinstance(t2, int):
try:
result = torch.pow(t1.larray, t2)
return DNDarray(
result,
gshape=t1.gshape,
dtype=t1.dtype,
device=t1.device,
split=t1.split,
comm=t1.comm,
balanced=t1.balanced,
)
except AttributeError:
# t1 is no DNDarray
pass
elif isinstance(t1, int):
try:
result = torch.pow(t1, t2.larray)
return DNDarray(
result,
gshape=t2.gshape,
dtype=t2.dtype,
device=t2.device,
split=t2.split,
comm=t2.comm,
balanced=t2.balanced,
)
except AttributeError:
# t2 is no DNDarray
pass
return _operations.__binary_op(torch.pow, t1, t2)


Expand Down
10 changes: 9 additions & 1 deletion heat/core/tests/test_arithmetics.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,10 @@ def test_diff(self):
self.assertEqual(ht_diff.split, 1)
self.assertEqual(ht_diff.dtype, ht_array.dtype)

# test n=0
ht_diff = ht.diff(ht_array, n=0)
self.assertTrue(ht.equal(ht_diff, ht_array))

# raises
with self.assertRaises(ValueError):
ht.diff(ht_array, n=-2)
Expand Down Expand Up @@ -535,7 +539,6 @@ def test_pos(self):
def test_pow(self):
result = ht.array([[1.0, 4.0], [9.0, 16.0]])
commutated_result = ht.array([[2.0, 4.0], [8.0, 16.0]])

self.assertTrue(ht.equal(ht.pow(self.a_scalar, self.a_scalar), ht.array(4.0)))
self.assertTrue(ht.equal(ht.pow(self.a_tensor, self.a_scalar), result))
self.assertTrue(ht.equal(ht.pow(self.a_scalar, self.a_tensor), commutated_result))
Expand All @@ -544,6 +547,11 @@ def test_pow(self):
self.assertTrue(ht.equal(ht.pow(self.a_tensor, self.an_int_scalar), result))
self.assertTrue(ht.equal(ht.pow(self.a_split_tensor, self.a_tensor), commutated_result))

# test scalar base and exponent
self.assertTrue(ht.equal(ht.pow(2, 3), ht.array(8)))
self.assertTrue(ht.equal(ht.pow(2, 3.5), ht.array(11.313708498984761)))

# test exceptions
with self.assertRaises(ValueError):
ht.pow(self.a_tensor, self.another_vector)
with self.assertRaises(TypeError):
Expand Down

0 comments on commit 9f6b2eb

Please sign in to comment.