Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wasm] Fix cos and tan for large float numbers #7689

Merged
merged 12 commits into from
May 18, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tfjs-backend-wasm/src/cc/kernels/Cos.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ extern "C" {
EMSCRIPTEN_KEEPALIVE
#endif
void Cos(const int x_id, const DType dtype, const int out_id) {
unary_f32(x_id, out_id, tfjs::sin_cos_workaround::cos_fixed);
unary_f32(x_id, out_id, tfjs::sin_cos_workaround::CosFixed);
}

} // extern "C"
Expand Down
2 changes: 1 addition & 1 deletion tfjs-backend-wasm/src/cc/kernels/Sin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ extern "C" {
EMSCRIPTEN_KEEPALIVE
#endif
void Sin(const int x_id, const DType dtype, const int out_id) {
unary_f32(x_id, out_id, tfjs::sin_cos_workaround::sin_fixed);
unary_f32(x_id, out_id, tfjs::sin_cos_workaround::SinFixed);
}

} // extern "C"
Expand Down
2 changes: 1 addition & 1 deletion tfjs-backend-wasm/src/cc/kernels/Tan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ extern "C" {
EMSCRIPTEN_KEEPALIVE
#endif
void Tan(const int x_id, const DType dtype, const int out_id) {
unary_f32(x_id, out_id, tfjs::sin_cos_workaround::tan_fixed);
unary_f32(x_id, out_id, tfjs::sin_cos_workaround::TanFixed);
}

} // extern "C"
Expand Down
68 changes: 50 additions & 18 deletions tfjs-backend-wasm/src/cc/sin_cos_workaround.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,35 +11,67 @@
* See the License for the specific language governing permissions and
* limitations under the License.
* ===========================================================================*/
#include <math.h>

#include <cmath>

#include "tfjs-backend-wasm/src/cc/sin_cos_workaround.h"

namespace tfjs {
namespace sin_cos_workaround {

float sin_fixed(float x) {
if (isnan(x)) return nan("");
auto zero_to_2pi = fmod(fmod(x, 2 * M_PI) + 2 * M_PI, 2 * M_PI);

if (zero_to_2pi < M_PI_4) {
return sin(zero_to_2pi);
} else if (zero_to_2pi < M_PI_2) {
auto past_pi_4 = zero_to_2pi - M_PI_4;
return cos(M_PI_4 - past_pi_4);
} else if (zero_to_2pi < M_PI) {
auto past_pi_2 = zero_to_2pi - M_PI_2;
return sin_fixed(M_PI_2 - past_pi_2);
namespace {

template <typename T>
inline T ShiftRadianToZeroTo2PI(const T& x) {
if (std::isnan(x)) {
return x;
}
return std::fmod(std::fmod(x, 2 * M_PI) + 2 * M_PI, 2 * M_PI);
}

template <typename T>
inline T SinZeroTo2PI(T x) {
if (std::isnan(x)) {
return x;
}

if (x < M_PI_4) {
return std::sin(x);
} else if (x < M_PI_2) {
return std::cos(M_PI_2 - x);
} else if (x < M_PI) {
return SinZeroTo2PI<T, /*is_shifted=*/true>(M_PI - x);
} else {
return -sin_fixed(2 * M_PI - zero_to_2pi);
return -SinZeroTo2PI<T, /*is_shifted=*/true>(2 * M_PI - x);
mattsoulanille marked this conversation as resolved.
Show resolved Hide resolved
}
}

float cos_fixed(float x) { return sin_fixed(x + M_PI_2); }
template <typename T>
inline T CosZeroTo2PI(T x) {
if (std::isnan(x)) {
return x;
}

if (x < M_PI_4) {
return std::cos(x);
} else if (x < M_PI_2) {
return std::sin(M_PI_2 - x);
} else if (x < M_PI) {
return -CosZeroTo2PI<T, /*is_shifted=*/true>(M_PI - x);
} else {
return CosZeroTo2PI<T, /*is_shifted=*/true>(2 * M_PI - x);
}
}

} // namespace

float SinFixed(float x) { return SinZeroTo2PI(ShiftRadianToZeroTo2PI(x)); }

float CosFixed(float x) { return CosZeroTo2PI(ShiftRadianToZeroTo2PI(x)); }

float tan_fixed(float x) {
if (isnan(x)) return nan("");
return sin_fixed(x) / cos_fixed(x);
float TanFixed(float x) {
// TODO: Check if this work on iOS 11/12.
return std::tan(x);
}

} // namespace sin_cos_workaround
Expand Down
7 changes: 4 additions & 3 deletions tfjs-backend-wasm/src/cc/sin_cos_workaround.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@
namespace tfjs {
namespace sin_cos_workaround {

float sin_fixed(float x);
float SinFixed(float x);

float cos_fixed(float x);
float CosFixed(float x);

float tan_fixed(float x);
float TanFixed(float x);

} // namespace sin_cos_workaround
} // namespace tfjs

#endif // SIN_COS_WORKAROUND_H_
9 changes: 7 additions & 2 deletions tfjs-backend-webgl/src/setup_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,17 @@ import {registerTestEnvs} from './backend_webgl_test_registry';
registerTestEnvs();

const TEST_FILTERS: TestFilter[] = [];

const customInclude = (testName: string) => {
const toExclude = [
'isBrowser: false', 'dilation gradient',
'isBrowser: false',
'dilation gradient',
'throws when index is out of bound',
// otsu tests for threshold op is failing on windows
'method otsu', 'Draw on 2d context'
'method otsu',
'Draw on 2d context',
// https://github.com/tensorflow/tfjs/issues/7618
'numbers exceed float32 precision',
];
for (const subStr of toExclude) {
if (testName.includes(subStr)) {
Expand Down
2 changes: 2 additions & 0 deletions tfjs-backend-webgpu/src/setup_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ const TEST_FILTERS: TestFilter[] = [
excludes: [
'gradients', // Failing on MacOS
//'gradient with clones', // Failing on MacOS
// https://github.com/tensorflow/tfjs/issues/7618
'numbers exceed float32 precision',
],
},
{
Expand Down
17 changes: 17 additions & 0 deletions tfjs-core/src/ops/tan_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,23 @@ describeWithFlags('tan', ALL_ENVS, () => {
expectArraysClose(await result.data(), expected);
});

it('numbers exceed float32 precision', async () => {
const values = [
-608065414.8781943,
781902002.7943993,
-470910673.97399473,
1786759246.171617,
1873777868.5510726,
-1015107953.8969269,
830023227.6215034,
];
const a = tf.tensor1d(values, 'float32');
const result = tf.tan(a);

const expected = [...new Float32Array(values).map((v) => Math.tan(v))];
expectArraysClose(await result.data(), expected);
});

it('propagates NaNs', async () => {
const a = tf.tensor1d([4, NaN, 0]);
const res = tf.tan(a);
Expand Down
1 change: 1 addition & 0 deletions tfjs-node/src/run_tests.ts
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ const IGNORE_LIST: string[] = [
'upperBound',
'lowerBound',
'multinomial test-tensorflow {} creates the same data given the same seed',
'tan test-tensorflow {} numbers exceed float32 precision',
];

if (process.platform === 'win32') {
Expand Down