Skip to content

Commit

Permalink
[Bug] [lang] Allow numpy int as snode dimension
Browse files Browse the repository at this point in the history
  • Loading branch information
strongoier committed Sep 30, 2022
1 parent bf20101 commit 6239e9c
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 4 deletions.
7 changes: 7 additions & 0 deletions 6186.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import taichi as ti
import numpy as np

ti.init(arch=ti.gpu)

np_int = np.max([0, 1, 2])
block = ti.root.pointer(ti.i, np_int)
8 changes: 4 additions & 4 deletions python/taichi/lang/snode.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def dense(self, axes, dimensions):
Returns:
The added :class:`~taichi.lang.SNode` instance.
"""
if isinstance(dimensions, int):
if isinstance(dimensions, numbers.Number):
dimensions = [dimensions] * len(axes)
return SNode(
self.ptr.dense(axes, dimensions,
Expand All @@ -46,7 +46,7 @@ def pointer(self, axes, dimensions):
Returns:
The added :class:`~taichi.lang.SNode` instance.
"""
if isinstance(dimensions, int):
if isinstance(dimensions, numbers.Number):
dimensions = [dimensions] * len(axes)
return SNode(
self.ptr.pointer(axes, dimensions,
Expand Down Expand Up @@ -90,7 +90,7 @@ def bitmasked(self, axes, dimensions):
Returns:
The added :class:`~taichi.lang.SNode` instance.
"""
if isinstance(dimensions, int):
if isinstance(dimensions, numbers.Number):
dimensions = [dimensions] * len(axes)
return SNode(
self.ptr.bitmasked(axes, dimensions,
Expand All @@ -107,7 +107,7 @@ def quant_array(self, axes, dimensions, max_num_bits):
Returns:
The added :class:`~taichi.lang.SNode` instance.
"""
if isinstance(dimensions, int):
if isinstance(dimensions, numbers.Number):
dimensions = [dimensions] * len(axes)
return SNode(
self.ptr.quant_array(axes, dimensions, max_num_bits,
Expand Down
12 changes: 12 additions & 0 deletions tests/python/test_fields_builder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
from taichi.lang.exception import TaichiRuntimeError

import numpy as np
import taichi as ti
from tests import test_utils

Expand Down Expand Up @@ -227,3 +228,14 @@ def calc_loss(arr: ti.template(), loss: ti.template()):
mul.grad(arr, out)
for i in range(10):
assert arr.grad[i] == 2.0


@test_utils.test(arch=ti.cpu)
def test_fields_builder_numpy_dimension():
shape = np.int32(5)
fb = ti.FieldsBuilder()
x = ti.field(ti.f32)
y = ti.field(ti.i32)
fb.dense(ti.i, shape).place(x)
fb.pointer(ti.j, shape).place(y)
fb.finalize()

0 comments on commit 6239e9c

Please sign in to comment.