Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor] Cleanup python imports #2226

Merged
merged 1 commit into from
Mar 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions python/taichi/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .util import *
from .settings import *
from .record import *
from .unit import unit
from taichi.core.util import *
from taichi.core.settings import *
from taichi.core.record import *
from taichi.core.logging import *

ti_core.build = build
ti_core.load_module = load_module
Expand Down
50 changes: 50 additions & 0 deletions python/taichi/core/logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import inspect
import os

from taichi.core import util


def get_logging(name):
def logger(msg, *args, **kwargs):
# Python inspection takes time (~0.1ms) so avoid it as much as possible
if util.ti_core.logging_effective(name):
msg_formatted = msg.format(*args, **kwargs)
func = getattr(util.ti_core, name)
frame = inspect.currentframe().f_back
file_name, lineno, func_name, _, _ = inspect.getframeinfo(frame)
file_name = os.path.basename(file_name)
msg = f'[{file_name}:{func_name}@{lineno}] {msg_formatted}'
func(msg)

return logger


def set_logging_level(level):
util.ti_core.set_logging_level(level)


DEBUG = 'debug'
TRACE = 'trace'
INFO = 'info'
WARN = 'warn'
ERROR = 'error'
CRITICAL = 'critical'

debug = get_logging(DEBUG)
trace = get_logging(TRACE)
info = get_logging(INFO)
warn = get_logging(WARN)
error = get_logging(ERROR)
critical = get_logging(CRITICAL)


def _get_file_name(asc=0):
return inspect.stack()[1 + asc][1]


def _get_function_name(asc=0):
return inspect.stack()[1 + asc][3]


def _get_line_number(asc=0):
return inspect.stack()[1 + asc][2]
29 changes: 0 additions & 29 deletions python/taichi/core/unit.py

This file was deleted.

1 change: 0 additions & 1 deletion python/taichi/core/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,6 @@ def get_unique_task_id():
try:
import_ti_core(tmp_dir)
except Exception as e:
from colorama import Fore, Back, Style
print_red_bold("Taichi core import failed: ", end='')
print(e)
print(
Expand Down
10 changes: 2 additions & 8 deletions python/taichi/lang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
import os
from copy import deepcopy as _deepcopy

from taichi.lang import type_factory as type_factory_mod
from taichi.lang.impl import *
from taichi.lang.matrix import Matrix, Vector
from taichi.lang.ndrange import GroupedNDRange, ndrange
from taichi.lang.quant_impl import quant
from taichi.lang.runtime_ops import async_flush, sync
from taichi.lang.transformer import TaichiSyntaxError
from taichi.lang.type_factory_impl import type_factory
from taichi.lang.util import deprecated

core = taichi_lang_core
Expand Down Expand Up @@ -54,10 +55,6 @@
# Legacy API
type_factory_ = core.get_type_factory_instance()

# Unstable API
quant = type_factory_mod.Quant
type_factory = type_factory_mod.TypeFactory()


def memory_profiler_print():
get_runtime().materialize()
Expand Down Expand Up @@ -330,9 +327,6 @@ def visit(node):
visit(ti.root)


lang_core = core


def benchmark(func, repeat=300, args=()):
import time

Expand Down
26 changes: 14 additions & 12 deletions python/taichi/lang/expr.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from taichi.lang import impl, runtime_ops
from taichi.lang import impl
from taichi.lang.common_ops import TaichiOperations
from taichi.lang.core import taichi_lang_core
from taichi.lang.util import (deprecated, is_taichi_class, python_scope,
to_numpy_type, to_pytorch_type)

import taichi as ti


# Scalar, basic data type
class Expr(TaichiOperations):
Expand Down Expand Up @@ -122,7 +124,7 @@ def clear(self, deactivate=False):
@python_scope
def fill(self, val):
# TODO: avoid too many template instantiations
from .meta import fill_tensor
from taichi.lang.meta import fill_tensor
fill_tensor(self, val)

def parent(self, n=1):
Expand All @@ -134,7 +136,7 @@ def is_global(self):

@property
def snode(self):
from .snode import SNode
from taichi.lang.snode import SNode
return SNode(self.ptr.snode())

def __hash__(self):
Expand Down Expand Up @@ -165,22 +167,24 @@ def data_type(self):

@python_scope
def to_numpy(self):
from .meta import tensor_to_ext_arr
import numpy as np

from taichi.lang.meta import tensor_to_ext_arr
arr = np.zeros(shape=self.shape, dtype=to_numpy_type(self.dtype))
tensor_to_ext_arr(self, arr)
runtime_ops.sync()
ti.sync()
return arr

@python_scope
def to_torch(self, device=None):
from .meta import tensor_to_ext_arr
import torch

from taichi.lang.meta import tensor_to_ext_arr
arr = torch.zeros(size=self.shape,
dtype=to_pytorch_type(self.dtype),
device=device)
tensor_to_ext_arr(self, arr)
runtime_ops.sync()
ti.sync()
return arr

@python_scope
Expand All @@ -193,7 +197,7 @@ def from_numpy(self, arr):
if hasattr(arr, 'contiguous'):
arr = arr.contiguous()
ext_arr_to_tensor(arr, self)
runtime_ops.sync()
ti.sync()

@python_scope
def from_torch(self, arr):
Expand Down Expand Up @@ -223,19 +227,17 @@ def __repr__(self):


def make_var_vector(size):
from taichi.lang.matrix import Vector
exprs = []
for _ in range(size):
exprs.append(taichi_lang_core.make_id_expr(''))
return Vector(exprs)
return ti.Vector(exprs)


def make_expr_group(*exprs):
if len(exprs) == 1:
from taichi.lang.matrix import Matrix
if isinstance(exprs[0], (list, tuple)):
exprs = exprs[0]
elif isinstance(exprs[0], Matrix):
elif isinstance(exprs[0], ti.Matrix):
mat = exprs[0]
assert mat.m == 1
exprs = mat.entries
Expand Down
Loading