Skip to content

Commit

Permalink
[refactor] Cleanup python imports (#2226)
Browse files Browse the repository at this point in the history
  • Loading branch information
k-ye authored Mar 22, 2021
1 parent 904181b commit adfe729
Show file tree
Hide file tree
Showing 18 changed files with 243 additions and 263 deletions.
8 changes: 4 additions & 4 deletions python/taichi/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .util import *
from .settings import *
from .record import *
from .unit import unit
from taichi.core.util import *
from taichi.core.settings import *
from taichi.core.record import *
from taichi.core.logging import *

ti_core.build = build
ti_core.load_module = load_module
Expand Down
50 changes: 50 additions & 0 deletions python/taichi/core/logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import inspect
import os

from taichi.core import util


def get_logging(name):
def logger(msg, *args, **kwargs):
# Python inspection takes time (~0.1ms) so avoid it as much as possible
if util.ti_core.logging_effective(name):
msg_formatted = msg.format(*args, **kwargs)
func = getattr(util.ti_core, name)
frame = inspect.currentframe().f_back
file_name, lineno, func_name, _, _ = inspect.getframeinfo(frame)
file_name = os.path.basename(file_name)
msg = f'[{file_name}:{func_name}@{lineno}] {msg_formatted}'
func(msg)

return logger


def set_logging_level(level):
util.ti_core.set_logging_level(level)


DEBUG = 'debug'
TRACE = 'trace'
INFO = 'info'
WARN = 'warn'
ERROR = 'error'
CRITICAL = 'critical'

debug = get_logging(DEBUG)
trace = get_logging(TRACE)
info = get_logging(INFO)
warn = get_logging(WARN)
error = get_logging(ERROR)
critical = get_logging(CRITICAL)


def _get_file_name(asc=0):
return inspect.stack()[1 + asc][1]


def _get_function_name(asc=0):
return inspect.stack()[1 + asc][3]


def _get_line_number(asc=0):
return inspect.stack()[1 + asc][2]
29 changes: 0 additions & 29 deletions python/taichi/core/unit.py

This file was deleted.

1 change: 0 additions & 1 deletion python/taichi/core/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,6 @@ def get_unique_task_id():
try:
import_ti_core(tmp_dir)
except Exception as e:
from colorama import Fore, Back, Style
print_red_bold("Taichi core import failed: ", end='')
print(e)
print(
Expand Down
10 changes: 2 additions & 8 deletions python/taichi/lang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
import os
from copy import deepcopy as _deepcopy

from taichi.lang import type_factory as type_factory_mod
from taichi.lang.impl import *
from taichi.lang.matrix import Matrix, Vector
from taichi.lang.ndrange import GroupedNDRange, ndrange
from taichi.lang.quant_impl import quant
from taichi.lang.runtime_ops import async_flush, sync
from taichi.lang.transformer import TaichiSyntaxError
from taichi.lang.type_factory_impl import type_factory
from taichi.lang.util import deprecated

core = taichi_lang_core
Expand Down Expand Up @@ -54,10 +55,6 @@
# Legacy API
type_factory_ = core.get_type_factory_instance()

# Unstable API
quant = type_factory_mod.Quant
type_factory = type_factory_mod.TypeFactory()


def memory_profiler_print():
get_runtime().materialize()
Expand Down Expand Up @@ -330,9 +327,6 @@ def visit(node):
visit(ti.root)


lang_core = core


def benchmark(func, repeat=300, args=()):
import time

Expand Down
26 changes: 14 additions & 12 deletions python/taichi/lang/expr.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from taichi.lang import impl, runtime_ops
from taichi.lang import impl
from taichi.lang.common_ops import TaichiOperations
from taichi.lang.core import taichi_lang_core
from taichi.lang.util import (deprecated, is_taichi_class, python_scope,
to_numpy_type, to_pytorch_type)

import taichi as ti


# Scalar, basic data type
class Expr(TaichiOperations):
Expand Down Expand Up @@ -122,7 +124,7 @@ def clear(self, deactivate=False):
@python_scope
def fill(self, val):
# TODO: avoid too many template instantiations
from .meta import fill_tensor
from taichi.lang.meta import fill_tensor
fill_tensor(self, val)

def parent(self, n=1):
Expand All @@ -134,7 +136,7 @@ def is_global(self):

@property
def snode(self):
from .snode import SNode
from taichi.lang.snode import SNode
return SNode(self.ptr.snode())

def __hash__(self):
Expand Down Expand Up @@ -165,22 +167,24 @@ def data_type(self):

@python_scope
def to_numpy(self):
from .meta import tensor_to_ext_arr
import numpy as np

from taichi.lang.meta import tensor_to_ext_arr
arr = np.zeros(shape=self.shape, dtype=to_numpy_type(self.dtype))
tensor_to_ext_arr(self, arr)
runtime_ops.sync()
ti.sync()
return arr

@python_scope
def to_torch(self, device=None):
from .meta import tensor_to_ext_arr
import torch

from taichi.lang.meta import tensor_to_ext_arr
arr = torch.zeros(size=self.shape,
dtype=to_pytorch_type(self.dtype),
device=device)
tensor_to_ext_arr(self, arr)
runtime_ops.sync()
ti.sync()
return arr

@python_scope
Expand All @@ -193,7 +197,7 @@ def from_numpy(self, arr):
if hasattr(arr, 'contiguous'):
arr = arr.contiguous()
ext_arr_to_tensor(arr, self)
runtime_ops.sync()
ti.sync()

@python_scope
def from_torch(self, arr):
Expand Down Expand Up @@ -223,19 +227,17 @@ def __repr__(self):


def make_var_vector(size):
from taichi.lang.matrix import Vector
exprs = []
for _ in range(size):
exprs.append(taichi_lang_core.make_id_expr(''))
return Vector(exprs)
return ti.Vector(exprs)


def make_expr_group(*exprs):
if len(exprs) == 1:
from taichi.lang.matrix import Matrix
if isinstance(exprs[0], (list, tuple)):
exprs = exprs[0]
elif isinstance(exprs[0], Matrix):
elif isinstance(exprs[0], ti.Matrix):
mat = exprs[0]
assert mat.m == 1
exprs = mat.entries
Expand Down
Loading

0 comments on commit adfe729

Please sign in to comment.