Skip to content

Commit

Permalink
fixed concat tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jakpiase committed Oct 4, 2021
1 parent 9361040 commit 47a73fe
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,14 @@ def setUp(self):
'mkldnn_data_type': self.mkldnn_data_type
}

self.sections = [self.x0.shape[self.axis]] * 2
self.sections[1] += self.x1.shape[self.axis]

self.output = np.concatenate(
(self.x0, self.x1, self.x2), axis=self.axis).astype(np.uint16)
self.outputs = {'Out': self.output}

def calculate_grads(self):
def calculate_grads(self):
self.dout = self.outputs['Out']
self.dxs = np.split(self.dout, self.sections, self.axis)

Expand Down Expand Up @@ -73,9 +76,9 @@ def init_axis(self):
self.axis = 0

def init_shape(self):
self.x0_shape = [2, 2, 1, 2]
self.x1_shape = [1, 2, 1, 2]
self.x2_shape = [3, 2, 1, 2]
self.x0_shape = [6, 2, 4, 3]
self.x1_shape = [7, 2, 4, 3]
self.x2_shape = [8, 2, 4, 3]


# --------------------test concat bf16 in with axis 1--------------------
Expand All @@ -86,9 +89,9 @@ def init_axis(self):
self.axis = 1

def init_shape(self):
self.x0_shape = [1, 1, 5, 5]
self.x1_shape = [1, 2, 5, 5]
self.x2_shape = [1, 3, 5, 5]
self.x0_shape = [1, 4, 5, 5]
self.x1_shape = [1, 8, 5, 5]
self.x2_shape = [1, 6, 5, 5]


# --------------------test concat bf16 in with axis 2--------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,6 @@ def setUp(self):
self.output = np.concatenate(
(self.x0, self.x1, self.x2), axis=self.axis).astype(self.dtype)

self.sections = [self.x0.shape[self.axis]] * 2
self.sections[1] += self.x1.shape[self.axis]

self.outputs = {'Out': self.output}

def configure_datatype(self):
Expand Down

1 comment on commit 47a73fe

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.