Skip to content

Commit

Permalink
add pir
Browse files Browse the repository at this point in the history
  • Loading branch information
zrr1999 committed Oct 23, 2023
1 parent f899bfc commit f88b58d
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 6 deletions.
3 changes: 1 addition & 2 deletions python/paddle/nn/initializer/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,9 @@ def forward(self, var, block=None):
if self._force_cpu:
place = core.CPUPlace()
if in_dygraph_mode():
_C_ops.full_(
return _C_ops.full_(
var, var.shape, float(self._value), var.dtype, place
)
return None
else:
return _C_ops.full(
var.shape, float(self._value), var.dtype, place
Expand Down
4 changes: 2 additions & 2 deletions python/paddle/tensor/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def logical_and(x, y, out=None, name=None):
[True , False, True , False])
"""
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.logical_and(x, y)

return _logical_op(
Expand Down Expand Up @@ -413,7 +413,7 @@ def equal_all(x, y, name=None):
Tensor(shape=[], dtype=bool, place=Place(cpu), stop_gradient=True,
False)
"""
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.equal_all(x, y)
else:
helper = LayerHelper("equal_all", **locals())
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -4437,7 +4437,7 @@ def isnan(x, name=None):
Tensor(shape=[7], dtype=bool, place=Place(cpu), stop_gradient=True,
[False, False, False, False, False, True , True ])
"""
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.isnan(x)
else:
helper = LayerHelper("isnan_v2", **locals())
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/tensor/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ def nonzero(x, as_tuple=False):
shape = x.shape
rank = len(shape)

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
outs = _C_ops.nonzero(x)
else:
check_variable_and_dtype(
Expand Down

0 comments on commit f88b58d

Please sign in to comment.