【PaddlePaddle Hackathon 4】No.56 : add fp16 test and bf16 for bernoulli and trunc #51657
Status: Closed
Changes shown from 5 of 28 commits.
Commit history (author: longranger2):

- a1d0522 add fp16 and bf16 support for bernoulli
- f6455e7 add fp16 and bf16 support for trunc
- 3279c68 Merge branch 'PaddlePaddle:develop' into fp16_56_2
- 99f5854 fix bug
- 9ee7d3a Merge branch 'develop' into fp16_56_2
- dce1754 fix bug
- 63c6f39 Merge branch 'PaddlePaddle:develop' into fp16_56_2
- b1771eb fix bug
- 528e5b8 fix PR-CI-Codestyle-Check
- 2fc39e1 fix bug of trunc_kernel.cu
- 8b8361d fix bug of trunc_kernel.cu
- 099d3bb fix bug of trunc_kernel.cu
- 22dbf8d fix bug of trunc and bernoulli
- 9db702f fix bug
- 38d7bc1 fix bug
- f4ce773 fix bug of MPType
- bd62029 fix check_variable_and_dtype
- 3782bd1 fix bug of MPType
- b20ac1a fix bug of undefined T
- 7def562 fix bug
- 3f44c3d Merge branch 'PaddlePaddle:develop' into fp16_56_2
- 3e9063a Update test_bernoulli_op.py
- 13a2c74 Update test_bernoulli_op.py
- 3c4e333 Update test_bernoulli_op.py
- e7ad7f2 fix bug of import
- 10336f8 Merge branch 'PaddlePaddle:develop' into fp16_56_2
- f922dd8 remove the trunc
- ea1d0ed Merge branch 'PaddlePaddle:develop' into fp16_56_2
Changes to test_bernoulli_op.py:

```diff
@@ -31,10 +31,15 @@ def output_hist(out):
 class TestBernoulliOp(OpTest):
     def setUp(self):
         self.op_type = "bernoulli"
-        self.inputs = {"X": np.random.uniform(size=(1000, 784))}
+        self.inputs = {
+            "X": np.random.uniform(size=(1000, 784)).astype(self.dtype)
+        }
         self.attrs = {}
         self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")}
 
+    def init_dtype(self):
+        self.dtype = np.float32
+
     def test_check_output(self):
         self.check_output_customized(self.verify_output)
 
@@ -98,5 +103,10 @@ def test_fixed_random_number(self):
         paddle.enable_static()
 
 
+class TestBernoulliFP16OP(TestBernoulliOp):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
 if __name__ == "__main__":
     unittest.main()
```

Review thread on the `TestBernoulliFP16OP` class:

> Reviewer: Shouldn't a BF16 unit test be added as well?
>
> Author: Yes, it has been added now.
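Regarding the BF16 test the reviewer asked for: PaddlePaddle's OpTest machinery commonly stores bfloat16 tensors as `uint16` bit patterns (via helpers such as `convert_float_to_uint16`). As a minimal sketch of that representation, assuming only NumPy (the helper names below are illustrative, not Paddle's code), bfloat16 is just the upper 16 bits of a float32:

```python
import numpy as np

def float_to_bf16_bits(x):
    # bfloat16 keeps float32's sign bit, 8 exponent bits, and the top 7
    # mantissa bits, so truncating to the upper 16 bits gives its bit pattern.
    return (np.asarray(x, dtype=np.float32).view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_float(bits):
    # Zero-fill the dropped low mantissa bits to recover a float32 value.
    return (np.asarray(bits, dtype=np.uint16).astype(np.uint32) << 16).view(np.float32)

# Values exactly representable in bfloat16 survive the round trip.
x = np.array([0.25, 1.5, 3.140625], dtype=np.float32)
assert np.array_equal(bf16_bits_to_float(float_to_bf16_bits(x)), x)
```

A BF16 `TestBernoulliBF16Op` would then feed such `uint16` arrays through the op and compare after converting back, rather than comparing raw floats.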
Review thread on the `self.outputs` line:

> Reviewer: The output for a float16 input shouldn't be float32, should it?
>
> Author: OK 👌