Skip to content

Commit

Permalink
[Lang] [refactor] shows an error if not called from correct Taichi/Py…
Browse files Browse the repository at this point in the history
…thon-scope (#1121)

* [Lang] [refactor] shows an error if not called from correct Taichi/Python-scope

* [skip ci] fix deprecation warning (instead of raising error)

* also use deprecated deco for classkernel & classfunc

* [skip ci] format

* [skip ci] deprecation warning use stacklevel

* [skip ci] enforce code format

* [skip ci] more taichi_scope

Co-authored-by: Taichi Gardener <taichigardener@gmail.com>
  • Loading branch information
archibate and taichi-gardener authored Jun 4, 2020
1 parent 2ad7d22 commit 7536eba
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 16 deletions.
3 changes: 3 additions & 0 deletions python/taichi/core/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import shutil
import sys
import ctypes
import warnings
from pathlib import Path
from colorama import Fore, Back, Style
from taichi.misc.settings import get_output_directory, get_build_directory, get_bin_directory, get_repo_directory, get_runtime_directory
Expand All @@ -13,6 +14,8 @@
print("Current Python version:", sys.version_info)
exit(-1)

warnings.filterwarnings('always')

ti_core = None


Expand Down
13 changes: 11 additions & 2 deletions python/taichi/lang/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(self, *args, tb=None):
self.grad = None
self.val = self

@python_scope
def __setitem__(self, key, value):
if not Expr.layout_materialized:
self.materialize_layout_callback()
Expand All @@ -50,9 +51,8 @@ def __setitem__(self, key, value):
(taichi_lang_core.get_max_num_indices() - len(key)))
self.setter(value, *key)

@python_scope
def __getitem__(self, key):
import taichi as ti
assert not ti.get_runtime().inside_kernel
if not Expr.layout_materialized:
self.materialize_layout_callback()
self.initialize_accessor()
Expand All @@ -70,6 +70,7 @@ def loop_range(self):
def serialize(self):
return self.ptr.serialize()

@python_scope
def initialize_accessor(self):
if self.getter:
return
Expand Down Expand Up @@ -103,16 +104,19 @@ def setter(value, *key):
self.getter = getter
self.setter = setter

@python_scope
def set_grad(self, grad):
self.grad = grad
self.ptr.set_grad(grad.ptr)

@python_scope
def clear(self, deactivate=False):
assert not deactivate
node = self.ptr.snode().parent
assert node
node.clear_data()

@python_scope
def fill(self, val):
# TODO: avoid too many template instantiations
from .meta import fill_tensor
Expand Down Expand Up @@ -145,6 +149,7 @@ def shape(self):
def data_type(self):
return self.snode().data_type()

@python_scope
def to_numpy(self):
from .meta import tensor_to_ext_arr
import numpy as np
Expand All @@ -155,6 +160,7 @@ def to_numpy(self):
ti.sync()
return arr

@python_scope
def to_torch(self, device=None):
from .meta import tensor_to_ext_arr
import torch
Expand All @@ -166,6 +172,7 @@ def to_torch(self, device=None):
ti.sync()
return arr

@python_scope
def from_numpy(self, arr):
assert self.dim() == len(arr.shape)
s = self.shape()
Expand All @@ -178,9 +185,11 @@ def from_numpy(self, arr):
import taichi as ti
ti.sync()

@python_scope
def from_torch(self, arr):
self.from_numpy(arr.contiguous())

@python_scope
def copy_from(self, other):
assert isinstance(other, Expr)
from .meta import tensor_to_tensor
Expand Down
10 changes: 10 additions & 0 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .util import *


@taichi_scope
def expr_init(rhs):
import taichi as ti
if rhs is None:
Expand All @@ -29,6 +30,7 @@ def expr_init(rhs):
return Expr(taichi_lang_core.expr_var(Expr(rhs).ptr))


@taichi_scope
def expr_init_func(rhs): # temporary solution to allow passing in tensors as
import taichi as ti
if isinstance(rhs, Expr) and rhs.ptr.is_global_var():
Expand All @@ -45,6 +47,7 @@ def wrap_scalar(x):
return x


@taichi_scope
def subscript(value, *indices):
import numpy as np
if isinstance(value, np.ndarray):
Expand Down Expand Up @@ -76,6 +79,7 @@ def subscript(value, *indices):
return Expr(taichi_lang_core.subscript(value.ptr, indices_expr_group))


@taichi_scope
def chain_compare(comparators, ops):
assert len(comparators) == len(ops) + 1, \
f'Chain comparison invoked with {len(comparators)} comparators but {len(ops)} operators'
Expand Down Expand Up @@ -180,6 +184,7 @@ def get_runtime():
return pytaichi


@taichi_scope
def make_constant_expr(val):
if isinstance(val, int):
if pytaichi.default_ip == i32:
Expand Down Expand Up @@ -229,6 +234,7 @@ def __getattribute__(self, item):
root = Root()


@python_scope
def var(dt, shape=None, offset=None, needs_grad=False):
if isinstance(shape, numbers.Number):
shape = (shape, )
Expand Down Expand Up @@ -277,6 +283,7 @@ def __init__(self, soa=False):
AOS = Layout(soa=False)


@python_scope
def layout(func):
assert not pytaichi.materialized, "All layout must be specified before the first kernel launch / data access."
warnings.warn(
Expand All @@ -286,6 +293,7 @@ def layout(func):
pytaichi.layout_functions.append(func)


@taichi_scope
def ti_print(*vars):
def entry2content(var):
if isinstance(var, str):
Expand Down Expand Up @@ -327,13 +335,15 @@ def fused_string(entries):
taichi_lang_core.create_print(contentries)


@taichi_scope
def ti_int(var):
if hasattr(var, '__ti_int__'):
return var.__ti_int__()
else:
return int(var)


@taichi_scope
def ti_float(var):
if hasattr(var, '__ti_float__'):
return var.__ti_float__()
Expand Down
10 changes: 2 additions & 8 deletions python/taichi/lang/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,8 @@ def extract_arguments(self):
self.argument_names.append(param.name)


@deprecated('@ti.classfunc', '@ti.func directly')
def classfunc(foo):
import warnings
warnings.warn('@ti.classfunc is deprecated. Please use @ti.func directly.',
DeprecationWarning)

func = Func(foo, classfunc=True)

@functools.wraps(foo)
Expand Down Expand Up @@ -545,11 +542,8 @@ def kernel(func):
return _kernel_impl(func, level_of_class_stackframe=3)


@deprecated('@ti.classkernel', '@ti.kernel directly')
def classkernel(func):
import warnings
warnings.warn(
'@ti.classkernel is deprecated. Please use @ti.kernel directly.',
DeprecationWarning)
return _kernel_impl(func, level_of_class_stackframe=3)


Expand Down
Loading

0 comments on commit 7536eba

Please sign in to comment.