add gradient test framework #3226
Conversation
@@ -23,6 +23,32 @@ def setUp(self):
        self.Y = np.apply_along_axis(stable_softmax, 1, self.X)


class TestSoftmaxGradOp1(unittest.TestCase):
Why does the class name have the suffix `1`? Should we give it a more specific name?
                  no_grad_set=set(),
                  only_cpu=False,
                  rtol=0.005,
                  atol=0.05):
Only one argument, `rtol` or `atol`, is enough. Also, `rtol` or maybe `delta`/`epsilon` would be an easier name to understand. Because gradient checking computes the gradient numerically, it is normal for it to differ somewhat from the analytic gradient.
How about `max_relative_error`?
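For illustration, a minimal sketch of the kind of single-bound check being discussed; the function name and the zero-guard floor are assumptions, not the PR's code:

```python
import numpy as np

def assert_max_relative_error(analytic, numeric, max_relative_error=0.005):
    # Compare two gradient arrays with one relative bound
    # instead of separate rtol/atol arguments.
    abs_diff = np.abs(analytic - numeric)
    # Guard against division by zero where the numeric gradient is ~0.
    denom = np.maximum(np.abs(numeric), 1e-8)
    max_err = (abs_diff / denom).max()
    if max_err > max_relative_error:
        raise AssertionError("max relative error %f exceeds bound %f" %
                             (max_err, max_relative_error))
```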
        :param only_cpu: only compute and check gradient on cpu kernel.
        :return:
        """
        out_names = filter(lambda name: name != "@TEMP@", forward_op.outputs())
`@TEMP@` is not the only temporary output in `forward_op`. Use `forward_op.attr("temporary_index")` to get the non-temporary outputs.
done
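A sketch of what the fix might look like, assuming `forward_op.attr("temporary_index")` returns the indices of temporary entries in `forward_op.outputs()` (the attribute's exact semantics live on the C++ side):

```python
def non_temp_outputs(forward_op):
    # Keep only outputs whose index is not marked as temporary.
    outputs = forward_op.outputs()
    temp_idx = set(forward_op.attr("temporary_index"))
    return [name for i, name in enumerate(outputs) if i not in temp_idx]
```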
            raise ValueError("no_grad should be in in_names")

        for check_name in inputs_to_check:
            if check_name not in in_names:
This check is not necessary.
removed
            # get and store numeric_grad
            if not numeric_grad.has_key(check_name):
                numeric_grad[check_name] = \
                    get_numeric_gradient(forward_op, numeric_input, output_name, check_name)
The numeric gradient only needs to be calculated once, not inside the for loop.
Yes, it is only calculated once; I have optimized the code for better readability.
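A sketch of the restructuring being described: hoist the numeric-gradient computation out of the comparison loop so it runs exactly once per input. `get_numeric_gradient` and the variable names follow the diff above; `check_against_backward` is a hypothetical helper standing in for the per-place comparison done later in the real function.

```python
# Compute each numeric gradient once, up front, instead of
# guarding with has_key() on every loop iteration.
numeric_grad = {
    name: get_numeric_gradient(forward_op, numeric_input, output_name, name)
    for name in inputs_to_check
}
for check_name in inputs_to_check:
    # Hypothetical helper: compare the stored numeric gradient
    # against the backward op's analytic gradient.
    check_against_backward(numeric_grad[check_name], check_name)
```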
Is anyone following up on this PR? Also, we need a design doc, a description, or an issue describing this PR.
@wangkuiyi I am now working on this and will add a design doc.
def grad_var_name(var_name):
    return var_name + "@GRAD"
Maybe we should expose this method from C++
                  input_vars,
                  inputs_to_check,
                  output_name,
                  no_grad_set=set(),
Do not use a mutable value like `set()` as a default argument. Python evaluates default argument values only once, at function definition time, so a mutable default object is shared across all calls.
See here
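The classic illustration of the pitfall (generic Python, not this PR's code): the default object is created once and reused by every call.

```python
def append_item(item, bucket=[]):
    bucket.append(item)
    return bucket

print(append_item(1))  # [1]
print(append_item(2))  # [1, 2] -- the same list leaked across calls!

# The usual fix: default to None and create a fresh object inside.
def append_item_fixed(item, bucket=None):
    if bucket is None:
        bucket = []
    bucket.append(item)
    return bucket
```

For the signature above, that means writing `no_grad_set=None` and adding `if no_grad_set is None: no_grad_set = set()` in the body.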
        places = [core.CPUPlace()]
        if not only_cpu and core.is_compile_gpu() \
                and core.Operator.support_gpu(backward_op.type()):
The backward operator is a network (composite) operator. Whether it supports GPU is not decided by its type but by the op itself.
        # create input var and set value
        for name, value in input_vars.iteritems():
            assert name in in_names
Raise a `ValueError` instead of using `assert`.
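A sketch of the suggested change, reusing the loop from the diff above; the error message is illustrative:

```python
for name, value in input_vars.iteritems():
    if name not in in_names:
        # An explicit ValueError survives `python -O` (which strips
        # asserts) and carries a descriptive message for the user.
        raise ValueError("%s is not an input of the forward operator" % name)
```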
        backward_op.infer_shape(scope)
        backward_op.run(scope, ctx)

        for check_name in inputs_to_check:
Could we check the backward gradient against the numeric gradient right here? We can verify immediately whether the two agree; there is no need to store them in two dicts and compare them afterwards. Going further, we could add a function `bool is_close(numeric_dict, scope)` that checks whether the gradients in the scope are correct. That would also simplify this function.
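A sketch of the suggested `is_close` helper, assuming the scope's gradient tensors can be read back via `scope.find_var(...).get_tensor()` and converted with `np.array`, and reusing `grad_var_name` from the diff above; the tolerance handling is illustrative:

```python
import numpy as np

def is_close(numeric_grads, scope, max_relative_error=0.005):
    # Compare each analytic gradient stored in `scope` with the
    # pre-computed numeric gradient, right after backward_op runs.
    for name, numeric in numeric_grads.iteritems():
        grad_tensor = scope.find_var(grad_var_name(name)).get_tensor()
        analytic = np.array(grad_tensor)
        if not np.allclose(analytic, numeric, atol=0.0,
                           rtol=max_relative_error):
            return False
    return True
```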
msg = "CPU kernel gradient is not close to numeric gradient" | ||
else: | ||
if isinstance(place, core.GPUPlace): | ||
msg = "CPU kernel gradient is not close to numeric gradient" |
`CPU kernel` -> `GPU kernel` in this message. Also, maybe we should add a `DebugString` method to `Place`, perhaps in another PR.
fixed
LGTM, but it lacks a design doc.
No description provided.