Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[numpy] add trunc #2316

Merged
merged 9 commits into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,7 @@ RUN(NAME elemental_09 LABELS cpython llvm c NOFAST)
RUN(NAME elemental_10 LABELS cpython llvm c NOFAST)
RUN(NAME elemental_11 LABELS cpython llvm c NOFAST)
RUN(NAME elemental_12 LABELS cpython llvm c NOFAST)
RUN(NAME elemental_13 LABELS cpython llvm c NOFAST)
RUN(NAME test_random LABELS cpython llvm NOFAST)
RUN(NAME test_os LABELS cpython llvm c NOFAST)
RUN(NAME test_builtin LABELS cpython llvm c)
Expand Down
64 changes: 64 additions & 0 deletions integration_tests/elemental_13.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from lpython import f32, f64
from numpy import trunc, empty, sqrt, reshape, int32, float32, float64


def elemental_trunc64():
i: i32
j: i32
k: i32
l: i32
eps: f32
eps = f32(1e-6)

arraynd: f64[32, 16, 8, 4] = empty((32, 16, 8, 4), dtype=float64)

newshape: i32[1] = empty(1, dtype = int32)
newshape[0] = 16384

for i in range(32):
for j in range(16):
for k in range(8):
for l in range(4):
arraynd[i, j, k, l] = f64((-1)**l) * sqrt(float(i + j + j + l))

observed: f64[32, 16, 8, 4] = empty((32, 16, 8, 4), dtype=float64)
observed = trunc(arraynd)

observed1d: f64[16384] = empty(16384, dtype=float64)
observed1d = reshape(observed, newshape)

array: f64[16384] = empty(16384, dtype=float64)
array = reshape(arraynd, newshape)

for i in range(16384):
assert f32(abs(trunc(array[i]) - observed1d[i])) <= eps


def elemental_trunc32():
i: i32
j: i32
k: i32
l: i32
eps: f32
eps = f32(1e-6)

arraynd: f32[32, 16, 8, 4] = empty((32, 16, 8, 4), dtype=float32)

for i in range(32):
for j in range(16):
for k in range(8):
for l in range(4):
arraynd[i, j, k, l] = f32(f64((-1)**l) * sqrt(float(i + j + j + l)))

observed: f32[32, 16, 8, 4] = empty((32, 16, 8, 4), dtype=float32)
observed = trunc(arraynd)

for i in range(32):
for j in range(16):
for k in range(8):
for l in range(4):
assert abs(trunc(arraynd[i, j, k, l]) - observed[i, j, k, l]) <= eps


elemental_trunc64()
elemental_trunc32()
1 change: 1 addition & 0 deletions src/libasr/codegen/asr_to_c_cpp.h
Original file line number Diff line number Diff line change
Expand Up @@ -2790,6 +2790,7 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
SET_INTRINSIC_NAME(Exp, "exp");
SET_INTRINSIC_NAME(Exp2, "exp2");
SET_INTRINSIC_NAME(Expm1, "expm1");
SET_INTRINSIC_NAME(Trunc, "trunc");
default : {
throw LCompilersException("IntrinsicScalarFunction: `"
+ ASRUtils::get_intrinsic_name(x.m_intrinsic_id)
Expand Down
1 change: 1 addition & 0 deletions src/libasr/codegen/asr_to_julia.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1899,6 +1899,7 @@ class ASRToJuliaVisitor : public ASR::BaseVisitor<ASRToJuliaVisitor>
SET_INTRINSIC_NAME(Exp, "exp");
SET_INTRINSIC_NAME(Exp2, "exp2");
SET_INTRINSIC_NAME(Expm1, "expm1");
SET_INTRINSIC_NAME(Trunc, "trunc");
default : {
throw LCompilersException("IntrinsicFunction: `"
+ ASRUtils::get_intrinsic_name(x.m_intrinsic_id)
Expand Down
46 changes: 46 additions & 0 deletions src/libasr/pass/intrinsic_function_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ enum class IntrinsicScalarFunctions : int64_t {
Atan2,
Gamma,
LogGamma,
Trunc,
Abs,
Exp,
Exp2,
Expand Down Expand Up @@ -96,6 +97,7 @@ inline std::string get_intrinsic_name(int x) {
INTRINSIC_NAME_CASE(Atan2)
INTRINSIC_NAME_CASE(Gamma)
INTRINSIC_NAME_CASE(LogGamma)
INTRINSIC_NAME_CASE(Trunc)
INTRINSIC_NAME_CASE(Abs)
INTRINSIC_NAME_CASE(Exp)
INTRINSIC_NAME_CASE(Exp2)
Expand Down Expand Up @@ -1142,6 +1144,44 @@ static inline ASR::expr_t* instantiate_LogGamma (Allocator &al,

} // namespace LogGamma

#define create_trunc_macro(X, stdeval) \
namespace X { \
static inline ASR::expr_t *eval_##X(Allocator &al, const Location &loc, \
ASR::ttype_t *t, Vec<ASR::expr_t*>& args) { \
LCOMPILERS_ASSERT(args.size() == 1); \
double rv = ASR::down_cast<ASR::RealConstant_t>(args[0])->m_r; \
if (ASRUtils::extract_value(args[0], rv)) { \
double val = std::stdeval(rv); \
return make_ConstantWithType(make_RealConstant_t, val, t, loc); \
} \
return nullptr; \
} \
static inline ASR::asr_t* create_##X(Allocator& al, const Location& loc, \
Vec<ASR::expr_t*>& args, \
const std::function<void (const std::string &, const Location &)> err) { \
ASR::ttype_t *type = ASRUtils::expr_type(args[0]); \
if (args.n != 1) { \
err("Intrinsic `#X` accepts exactly one argument", loc); \
} else if (!ASRUtils::is_real(*type)) { \
err("`x` argument of `#X` must be real", \
args[0]->base.loc); \
} \
return UnaryIntrinsicFunction::create_UnaryFunction(al, loc, args, \
eval_##X, static_cast<int64_t>(IntrinsicScalarFunctions::Trunc), \
0, type); \
} \
static inline ASR::expr_t* instantiate_##X (Allocator &al, \
const Location &loc, SymbolTable *scope, Vec<ASR::ttype_t*>& arg_types, \
ASR::ttype_t *return_type, Vec<ASR::call_arg_t>& new_args, \
int64_t overload_id) { \
ASR::ttype_t* arg_type = arg_types[0]; \
return UnaryIntrinsicFunction::instantiate_functions(al, loc, scope, \
"#X", arg_type, return_type, new_args, overload_id); \
} \
} // namespace X

create_trunc_macro(Trunc, trunc)

// `X` is the name of the function in the IntrinsicScalarFunctions enum and
// we use the same name for `create_X` and other places
// `stdeval` is the name of the function in the `std` namespace for compile
Expand Down Expand Up @@ -2879,6 +2919,8 @@ namespace IntrinsicScalarFunctionRegistry {
verify_function>>& intrinsic_function_by_id_db = {
{static_cast<int64_t>(IntrinsicScalarFunctions::LogGamma),
{&LogGamma::instantiate_LogGamma, &UnaryIntrinsicFunction::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::Trunc),
{&Trunc::instantiate_Trunc, &UnaryIntrinsicFunction::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::Sin),
{&Sin::instantiate_Sin, &UnaryIntrinsicFunction::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::Cos),
Expand Down Expand Up @@ -2977,6 +3019,8 @@ namespace IntrinsicScalarFunctionRegistry {
{static_cast<int64_t>(IntrinsicScalarFunctions::LogGamma),
"log_gamma"},

{static_cast<int64_t>(IntrinsicScalarFunctions::Trunc),
"trunc"},
{static_cast<int64_t>(IntrinsicScalarFunctions::Sin),
"sin"},
{static_cast<int64_t>(IntrinsicScalarFunctions::Cos),
Expand Down Expand Up @@ -3074,6 +3118,7 @@ namespace IntrinsicScalarFunctionRegistry {
std::tuple<create_intrinsic_function,
eval_intrinsic_function>>& intrinsic_function_by_name_db = {
{"log_gamma", {&LogGamma::create_LogGamma, &LogGamma::eval_log_gamma}},
{"trunc", {&Trunc::create_Trunc, &Trunc::eval_Trunc}},
{"sin", {&Sin::create_Sin, &Sin::eval_Sin}},
{"cos", {&Cos::create_Cos, &Cos::eval_Cos}},
{"tan", {&Tan::create_Tan, &Tan::eval_Tan}},
Expand Down Expand Up @@ -3134,6 +3179,7 @@ namespace IntrinsicScalarFunctionRegistry {
id_ == IntrinsicScalarFunctions::Cos ||
id_ == IntrinsicScalarFunctions::Gamma ||
id_ == IntrinsicScalarFunctions::LogGamma ||
id_ == IntrinsicScalarFunctions::Trunc ||
id_ == IntrinsicScalarFunctions::Sin ||
id_ == IntrinsicScalarFunctions::Exp ||
id_ == IntrinsicScalarFunctions::Exp2 ||
Expand Down
12 changes: 12 additions & 0 deletions src/libasr/runtime/lfortran_intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -1134,6 +1134,18 @@ LFORTRAN_API double_complex_t _lfortran_zatanh(double_complex_t x)
return catanh(x);
}

// trunc -----------------------------------------------------------------------

LFORTRAN_API float _lfortran_strunc(float x)
{
return truncf(x);
}

LFORTRAN_API double _lfortran_dtrunc(double x)
{
return trunc(x);
}

// phase --------------------------------------------------------------------

LFORTRAN_API float _lfortran_cphase(float_complex_t x)
Expand Down
2 changes: 2 additions & 0 deletions src/libasr/runtime/lfortran_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ LFORTRAN_API float _lfortran_satanh(float x);
LFORTRAN_API double _lfortran_datanh(double x);
LFORTRAN_API float_complex_t _lfortran_catanh(float_complex_t x);
LFORTRAN_API double_complex_t _lfortran_zatanh(double_complex_t x);
LFORTRAN_API float _lfortran_strunc(float x);
LFORTRAN_API double _lfortran_dtrunc(double x);
LFORTRAN_API float _lfortran_cphase(float_complex_t x);
LFORTRAN_API double _lfortran_zphase(double_complex_t x);
LFORTRAN_API bool _lpython_str_compare_eq(char** s1, char** s2);
Expand Down
2 changes: 1 addition & 1 deletion src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7311,7 +7311,7 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
if (!s) {
std::string intrinsic_name = call_name;
std::set<std::string> not_cpython_builtin = {
"sin", "cos", "gamma", "tan", "asin", "acos", "atan", "sinh", "cosh", "tanh", "exp", "exp2", "expm1", "Symbol", "diff", "expand",
"sin", "cos", "gamma", "tan", "asin", "acos", "atan", "sinh", "cosh", "tanh", "exp", "exp2", "expm1", "Symbol", "diff", "expand", "trunc",
"sum" // For sum called over lists
};
std::set<std::string> symbolic_functions = {
Expand Down
20 changes: 20 additions & 0 deletions src/runtime/lpython_intrinsic_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,3 +410,23 @@ def ceil(x: f32) -> f32:
if x <= f32(0) or x == resultf:
return resultf
return resultf + f32(1)

########## trunc ##########

@ccall
def _lfortran_dtrunc(x: f64) -> f64:
pass

@overload
@vectorize
def trunc(x: f64) -> f64:
return _lfortran_dtrunc(x)

@ccall
def _lfortran_strunc(x: f32) -> f32:
pass

@overload
@vectorize
def trunc(x: f32) -> f32:
return _lfortran_strunc(x)
certik marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 1 addition & 1 deletion tests/reference/asr-array_01_decl-39cf894.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"outfile": null,
"outfile_hash": null,
"stdout": "asr-array_01_decl-39cf894.stdout",
"stdout_hash": "137a0c427925ba7da2e7151f2cf52bfa9a64fede11fe8d2653f20b64",
"stdout_hash": "2aa47467473392c970bb1ddde961e3007d4c157bb0ea507b5e0db4a4",
"stderr": null,
khushi-411 marked this conversation as resolved.
Show resolved Hide resolved
"stderr_hash": null,
"returncode": 0
Expand Down
Loading