forked from ROCm/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_approx.cpp
96 lines (79 loc) · 2.94 KB
/
test_approx.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#ifdef TORCH_ENABLE_LLVM
#include <gtest/gtest.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/torch.h>
#include <cstring>
using namespace torch::indexing;
namespace te = torch::jit::tensorexpr;
// Splits the first loop that `ln` reports for `target` by `width` (with a
// tail loop) and vectorizes the resulting inner loop.
//
// Failures are reported through gtest ASSERT_* macros. Those macros only
// return from this helper, so a caller that must stop on failure should check
// for a fatal failure after the call (e.g. ASSERT_NO_FATAL_FAILURE).
//
// @param ln      Loop nest that owns the loops of `target`.
// @param target  Tensor whose loop is split and vectorized.
// @param width   Split factor passed to splitWithTail.
static void vectorize(te::LoopNest* ln, te::Tensor target, int width) {
  auto loops = ln->getLoopStmtsFor(target);
  // Guard the loops[0] access below: indexing an empty vector is undefined
  // behaviour, so fail the test cleanly instead.
  ASSERT_TRUE(!loops.empty());
  // `tail` is required by the splitWithTail signature but is not used here.
  te::ForPtr inner, tail;
  ln->splitWithTail(loops[0], width, &inner, &tail);
  ASSERT_TRUE(te::LoopNest::vectorize(inner));
}
// Builds a one-line, human-readable summary of how two tensors differ: the
// number of mismatching elements, the largest absolute difference, and the
// flattened index at which it occurs.
//
// Elements count as matching when they compare equal, or when both are NaN.
// Matching positions are forced to a difference of zero so that, for example,
// `inf - inf` (which is NaN) at a matching position cannot be reported as the
// greatest difference.
//
// @param a  First tensor (flattened before comparison).
// @param b  Second tensor (flattened before comparison).
// @return   The summary string.
std::string diffs(const at::Tensor& a, const at::Tensor& b) {
  const auto lhs = a.flatten();
  const auto rhs = b.flatten();
  std::stringstream ss;

  // argmax has nothing to reduce over on an empty tensor, so report directly.
  if (lhs.numel() == 0) {
    ss << "Found 0 unequal element(s).";
    return ss.str();
  }

  // With zero tolerances isclose reduces to an equality test that can also
  // treat NaN == NaN as a match.
  const auto same = torch::isclose(
      lhs, rhs, /*rtol=*/0.0, /*atol=*/0.0, /*equal_nan=*/true);
  const auto diff = torch::abs(lhs - rhs).masked_fill(same, 0);

  // Extract scalars so the message contains plain numbers.
  const auto count_diffs = torch::sum(same.logical_not()).item<int64_t>();
  const auto greatest_diff_index = torch::argmax(diff).item<int64_t>();

  ss << "Found " << count_diffs << " unequal element(s). "
     << "The greatest difference was "
     << diff[greatest_diff_index].item<double>() << " at index "
     << greatest_diff_index;
  return ss.str();
}
// Compiles `B[i] = log_vml(A[i])` with the LLVM codegen, with the loop split
// by 8 and vectorized, and compares its output against at::log over:
//   * every float in [1.0, 2.0), rescaled by several powers of two,
//   * tensors filled with +infinity and with NaN,
//   * zero and the denormal range (checked with a looser tolerance).
TEST(Approx, log_vml) {
  te::VarHandle N("N", te::kInt);
  te::BufHandle A("A", {N}, te::kFloat);
  te::Tensor B = te::Compute(
      "B", {N}, [&](const te::VarHandle& i) { return log_vml(A.load(i)); });

  te::LoopNest ln({B});
  ln.prepareForCodegen();
  // vectorize() reports problems via ASSERT_*, which only returns from that
  // helper; propagate a fatal failure so we don't build code from a bad nest.
  ASSERT_NO_FATAL_FAILURE(vectorize(&ln, B, 8));

  te::StmtPtr s = ln.root_stmt();
  s = te::IRSimplifier::simplify(s);
  te::LLVMCodeGen cg(s, {A, B, N});

  const auto eps = std::numeric_limits<float>::epsilon();

  // Runs the compiled code on `A_t` and checks it against at::log(A_t).
  auto test = [&](const at::Tensor& A_t) {
    at::Tensor B_ref = at::log(A_t);
    at::Tensor B_t = at::empty_like(A_t);
    auto ap = A_t.data_ptr<float>();
    auto bp = B_t.data_ptr<float>();
    // N is declared as te::kInt, so pass an int rather than the int64_t that
    // numel() returns.
    cg.call({ap, bp, static_cast<int>(A_t.numel())});
    // Results must agree to within a relative tolerance of one float epsilon
    // with no absolute slack; NaNs compare equal to each other.
    ASSERT_TRUE(torch::allclose(
        B_t, B_ref, /*rtol=*/eps, /*atol=*/0.0f, /*equal_nan=*/true))
        << "Input[:8]\n"
        << A_t.index({Slice(0, 8)}) << "\n"
        << "Test[:8]\n"
        << B_t.index({Slice(0, 8)}) << "\n"
        << "Ref[:8]\n"
        << B_ref.index({Slice(0, 8)}) << "\n"
        << diffs(B_t, B_ref);
  };

  // Generate every single-precision FP value in [1.0, 2.0).
  at::Tensor A_t = torch::arange(1.0f, 2.0f, eps);
  ASSERT_EQ(A_t.numel(), 1 << 23);

  // ASSERT_* inside the lambda only leaves the lambda, so each call is wrapped
  // to stop the test at the first failing input instead of running the rest.
  ASSERT_NO_FATAL_FAILURE(test(A_t));
  ASSERT_NO_FATAL_FAILURE(test(A_t * 2.0f));
  ASSERT_NO_FATAL_FAILURE(test(A_t * 0.5f));
  ASSERT_NO_FATAL_FAILURE(test(A_t * 4.0f));
  ASSERT_NO_FATAL_FAILURE(test(A_t * 0.25f));
  ASSERT_NO_FATAL_FAILURE(test(A_t * powf(2.0f, 16)));
  ASSERT_NO_FATAL_FAILURE(test(A_t * powf(2.0f, -16)));
  ASSERT_NO_FATAL_FAILURE(test(A_t * powf(2.0f, 126)));
  ASSERT_NO_FATAL_FAILURE(test(A_t * powf(2.0f, -126)));
  ASSERT_NO_FATAL_FAILURE(test(torch::full({32}, INFINITY)));
  ASSERT_NO_FATAL_FAILURE(test(torch::full({32}, NAN)));

  const auto smallest_normal = std::numeric_limits<float>::min();
  const auto denorm_min = std::numeric_limits<float>::denorm_min();

  // Zero and the denormals: [0, min) in steps of denorm_min. These are only
  // compared with allclose's default tolerances; denormals aren't bit precise,
  // because sleef isn't bit-precise either.
  A_t = torch::arange(0.0f, smallest_normal, denorm_min);
  ASSERT_EQ(A_t.numel(), 1 << 23);
  auto B_ref = at::log(A_t);
  auto B_t = at::empty_like(B_ref);
  cg.call(
      {A_t.data_ptr<float>(),
       B_t.data_ptr<float>(),
       static_cast<int>(A_t.numel())});
  ASSERT_TRUE(torch::allclose(B_t, B_ref)) << diffs(B_t, B_ref);
}
#endif // TORCH_ENABLE_LLVM