[Lang] [refactor] shows an error if not called from correct Taichi/Python-scope #1121

Merged: 9 commits, Jun 4, 2020
python/taichi/core/util.py (3 changes: 3 additions & 0 deletions)
@@ -3,6 +3,7 @@
import shutil
import sys
import ctypes
import warnings
from pathlib import Path
from colorama import Fore, Back, Style
from taichi.misc.settings import get_output_directory, get_build_directory, get_bin_directory, get_repo_directory, get_runtime_directory
@@ -13,6 +14,8 @@
print("Current Python version:", sys.version_info)
exit(-1)

warnings.filterwarnings('always')

ti_core = None


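Note: the `warnings.filterwarnings('always')` call added here makes Python report a matching warning every time it is raised, instead of suppressing repeats after the first occurrence, so scope-misuse and deprecation warnings stay visible on every offending call. A minimal, standalone illustration (not part of this diff):

import warnings

warnings.filterwarnings('always')

def legacy():
    # With Python's default filters a DeprecationWarning is often hidden or
    # shown at most once per call site; the 'always' filter overrides that.
    warnings.warn('legacy() is deprecated', DeprecationWarning)

legacy()  # warning printed
legacy()  # printed again thanks to the 'always' filter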
python/taichi/lang/expr.py (13 changes: 11 additions & 2 deletions)
@@ -37,6 +37,7 @@ def __init__(self, *args, tb=None):
self.grad = None
self.val = self

@python_scope
def __setitem__(self, key, value):
if not Expr.layout_materialized:
self.materialize_layout_callback()
@@ -50,9 +51,8 @@ def __setitem__(self, key, value):
(taichi_lang_core.get_max_num_indices() - len(key)))
self.setter(value, *key)

@python_scope
def __getitem__(self, key):
import taichi as ti
assert not ti.get_runtime().inside_kernel
if not Expr.layout_materialized:
self.materialize_layout_callback()
self.initialize_accessor()
@@ -70,6 +70,7 @@ def loop_range(self):
def serialize(self):
return self.ptr.serialize()

@python_scope
def initialize_accessor(self):
if self.getter:
return
@@ -103,16 +104,19 @@ def setter(value, *key):
self.getter = getter
self.setter = setter

@python_scope
def set_grad(self, grad):
self.grad = grad
self.ptr.set_grad(grad.ptr)

@python_scope
def clear(self, deactivate=False):
assert not deactivate
node = self.ptr.snode().parent
assert node
node.clear_data()

@python_scope
def fill(self, val):
# TODO: avoid too many template instantiations
from .meta import fill_tensor
@@ -145,6 +149,7 @@ def shape(self):
def data_type(self):
return self.snode().data_type()

@python_scope
def to_numpy(self):
from .meta import tensor_to_ext_arr
import numpy as np
@@ -155,6 +160,7 @@ def to_numpy(self):
ti.sync()
return arr

@python_scope
def to_torch(self, device=None):
from .meta import tensor_to_ext_arr
import torch
@@ -166,6 +172,7 @@ def to_torch(self, device=None):
ti.sync()
return arr

@python_scope
def from_numpy(self, arr):
assert self.dim() == len(arr.shape)
s = self.shape()
@@ -178,9 +185,11 @@ def from_numpy(self, arr):
import taichi as ti
ti.sync()

@python_scope
def from_torch(self, arr):
self.from_numpy(arr.contiguous())

@python_scope
def copy_from(self, other):
assert isinstance(other, Expr)
from .meta import tensor_to_tensor
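Note: the `@python_scope` decorator applied throughout this file is defined elsewhere in taichi.lang and is not shown in this diff. As a rough sketch of the intent, assuming the same `inside_kernel` flag that the assertion removed from `__getitem__` relied on, such a guard could look like:

import functools

def python_scope(func):
    # Illustrative sketch only: reject calls made from Taichi-scope,
    # i.e. from inside a @ti.kernel / @ti.func body being compiled.
    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        import taichi as ti
        assert not ti.get_runtime().inside_kernel, \
            f'{func.__name__}() can only be called in Python-scope'
        return func(*args, **kwargs)
    return wrapped

With such a decorator in place, the per-method `assert not ti.get_runtime().inside_kernel` checks (like the one deleted above) become redundant.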
python/taichi/lang/impl.py (10 changes: 10 additions & 0 deletions)
@@ -6,6 +6,7 @@
from .util import *


@taichi_scope
def expr_init(rhs):
import taichi as ti
if rhs is None:
@@ -29,6 +30,7 @@ def expr_init(rhs):
return Expr(taichi_lang_core.expr_var(Expr(rhs).ptr))


@taichi_scope
def expr_init_func(rhs): # temporary solution to allow passing in tensors as
import taichi as ti
if isinstance(rhs, Expr) and rhs.ptr.is_global_var():
@@ -45,6 +47,7 @@ def wrap_scalar(x):
return x


@taichi_scope
def subscript(value, *indices):
import numpy as np
if isinstance(value, np.ndarray):
@@ -76,6 +79,7 @@ def subscript(value, *indices):
return Expr(taichi_lang_core.subscript(value.ptr, indices_expr_group))


@taichi_scope
def chain_compare(comparators, ops):
assert len(comparators) == len(ops) + 1, \
f'Chain comparison invoked with {len(comparators)} comparators but {len(ops)} operators'
@@ -180,6 +184,7 @@ def get_runtime():
return pytaichi


@taichi_scope
def make_constant_expr(val):
if isinstance(val, int):
if pytaichi.default_ip == i32:
@@ -229,6 +234,7 @@ def __getattribute__(self, item):
root = Root()


@python_scope
def var(dt, shape=None, offset=None, needs_grad=False):
if isinstance(shape, numbers.Number):
shape = (shape, )
@@ -277,6 +283,7 @@ def __init__(self, soa=False):
AOS = Layout(soa=False)


@python_scope
def layout(func):
assert not pytaichi.materialized, "All layout must be specified before the first kernel launch / data access."
warnings.warn(
@@ -286,6 +293,7 @@ def layout(func):
pytaichi.layout_functions.append(func)


@taichi_scope
def ti_print(*vars):
def entry2content(var):
if isinstance(var, str):
@@ -327,13 +335,15 @@ def fused_string(entries):
taichi_lang_core.create_print(contentries)


@taichi_scope
def ti_int(var):
if hasattr(var, '__ti_int__'):
return var.__ti_int__()
else:
return int(var)


@taichi_scope
def ti_float(var):
if hasattr(var, '__ti_float__'):
return var.__ti_float__()
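Note: `@taichi_scope` is the counterpart marker for helpers such as `expr_init`, `subscript`, and `ti_print` that are only meaningful while a kernel is being compiled. A hypothetical sketch of that guard, under the same `inside_kernel` assumption as above:

import functools

def taichi_scope(func):
    # Illustrative sketch only: these helpers should only run from inside
    # a Taichi kernel/function body, never from plain Python code.
    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        import taichi as ti
        assert ti.get_runtime().inside_kernel, \
            f'{func.__name__}() can only be called in Taichi-scope'
        return func(*args, **kwargs)
    return wrapped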
python/taichi/lang/kernel.py (10 changes: 2 additions & 8 deletions)
@@ -121,11 +121,8 @@ def extract_arguments(self):
self.argument_names.append(param.name)


@deprecated('@ti.classfunc', '@ti.func directly')
def classfunc(foo):
import warnings
warnings.warn('@ti.classfunc is deprecated. Please use @ti.func directly.',
DeprecationWarning)

func = Func(foo, classfunc=True)

@functools.wraps(foo)
@@ -545,11 +542,8 @@ def kernel(func):
return _kernel_impl(func, level_of_class_stackframe=3)


@deprecated('@ti.classkernel', '@ti.kernel directly')
def classkernel(func):
import warnings
warnings.warn(
'@ti.classkernel is deprecated. Please use @ti.kernel directly.',
DeprecationWarning)
return _kernel_impl(func, level_of_class_stackframe=3)


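Note: the inline `warnings.warn(...)` boilerplate in `classfunc` and `classkernel` is replaced by a `@deprecated(old, new)` decorator whose definition is not part of this excerpt. A minimal sketch of what such a decorator factory could look like:

import functools
import warnings

def deprecated(old, new):
    # Illustrative sketch only: emit a DeprecationWarning on each call,
    # pointing the user from the old spelling to the new one.
    def decorator(func):
        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            warnings.warn(f'{old} is deprecated. Please use {new}.',
                          DeprecationWarning)
            return func(*args, **kwargs)
        return wrapped
    return decorator

Combined with the `warnings.filterwarnings('always')` change in util.py, the deprecation message is shown on every use rather than only the first.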