Skip to content

Commit

Permalink
support for paddle2.4
Browse files Browse the repository at this point in the history
  • Loading branch information
Asthestarsfalll committed Dec 4, 2023
1 parent afab0fd commit 1660758
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,15 @@ def partial_concat(name: str, x, y, start_index=0, length=-1):
with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()):
x_data = paddle.static.data(name="x", shape=x.shape, dtype=x.dtype)
y_data = paddle.static.data(name="y", shape=x.shape, dtype=y.dtype)
out = paddle.incubate.layers.nn.partial_concat(

if paddle.__version__ >= '2.5.1':
out = paddle.incubate.layers.nn.partial_concat(
[x_data, y_data], start_index=start_index, length=length
)
else:
out = paddle.fluid.contrib.layers.partial_concat(
[x_data, y_data], start_index=start_index, length=length
)
)

cpu = paddle.static.cpu_places(1)
exe = paddle.static.Executor(cpu[0])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,23 @@
import paddle
import sys


def partial_sum(name: str, x, y, start_index=0, length=-1):
paddle.enable_static()

with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()):
x_data = paddle.static.data(name="x", shape=x.shape, dtype=x.dtype)
y_data = paddle.static.data(name="y", shape=x.shape, dtype=y.dtype)
out = paddle.incubate.layers.nn.partial_sum(

if paddle.__version__ >= '2.5.1':
out = paddle.incubate.layers.nn.partial_sum(
[x_data, y_data], start_index=start_index, length=length
)
else:
out = paddle.fluid.contrib.layers.partial_sum(
[x_data, y_data], start_index=start_index, length=length
)
)


cpu = paddle.static.cpu_places(1)
exe = paddle.static.Executor(cpu[0])
Expand Down

0 comments on commit 1660758

Please sign in to comment.