Skip to content

Commit

Permalink
fix Flatten api test=develop (#26346)
Browse files Browse the repository at this point in the history
  • Loading branch information
yaoxuefeng6 authored Aug 24, 2020
1 parent 0d71cff commit d6e888c
Showing 1 changed file with 3 additions and 15 deletions.
18 changes: 3 additions & 15 deletions python/paddle/fluid/dygraph/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import print_function

import paddle
from six.moves import reduce
from .. import core
from ..layers import utils
Expand Down Expand Up @@ -3457,19 +3458,6 @@ def __init__(self, start_axis=1, stop_axis=-1):
self.stop_axis = stop_axis

def forward(self, input):
out = self._helper.create_variable_for_type_inference(input.dtype)
x_shape = self._helper.create_variable_for_type_inference(input.dtype)

if in_dygraph_mode():
dy_out, _ = core.ops.flatten_contiguous_range(
input, 'start_axis', self.start_axis, 'stop_axis',
self.stop_axis)
return dy_out
self._helper.append_op(
type="flatten_contiguous_range",
inputs={"X": input},
outputs={"Out": out,
"XShape": x_shape},
attrs={"start_axis": self.start_axis,
"stop_axis": self.stop_axis})
out = paddle.tensor.manipulation.flatten(
input, start_axis=self.start_axis, stop_axis=self.stop_axis)
return out

0 comments on commit d6e888c

Please sign in to comment.