Able to print gradients in event_handler #3085
@@ -161,14 +161,14 @@ def train(self, reader, num_passes=1, event_handler=None, feeding=None):
self.__parameter_updater__.update(each_param)
cost_sum = out_args.sum()
cost = cost_sum / len(data_batch)
self.__parameter_updater__.finishBatch(cost)
This is a little confusing: why do we move the batch_evaluator to the event_handler?
We need to call event_handler before the finishBatch operations so that we can read the internal state before it is cleared.
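The ordering argument can be illustrated with a toy sketch. Everything below is hypothetical (`ToyUpdater`, `run_batch`, and the parameter name `fc.w` are made up for illustration and are not PaddlePaddle classes); it only assumes, as the comment above says, that finishing a batch clears per-batch state such as gradients.

```python
class ToyUpdater:
    """Hypothetical stand-in for a parameter updater (not PaddlePaddle's class)."""

    def __init__(self):
        self.gradients = {}

    def update(self, name, grad):
        # Record the gradient computed for this batch.
        self.gradients[name] = grad

    def finish_batch(self):
        # Assumption: per-batch state is cleared when the batch finishes.
        self.gradients.clear()


def run_batch(updater, event_handler, handler_first):
    """Run one toy batch, calling the handler before or after finish_batch."""
    updater.update("fc.w", [0.1, -0.2])
    if handler_first:
        event_handler(updater)
        updater.finish_batch()
    else:
        updater.finish_batch()
        event_handler(updater)


seen = []


def record(updater):
    # Snapshot whatever gradients are still visible to the handler.
    seen.append(dict(updater.gradients))


run_batch(ToyUpdater(), record, handler_first=True)
run_batch(ToyUpdater(), record, handler_first=False)
print(seen)
```

In this sketch the handler only sees the gradient when it runs before `finish_batch`; when it runs afterwards, the snapshot is empty.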
Got it, thanks!
assert isinstance(val, api.Vector)
val = val.copyToNumpyArray()
return val
# else continue

raise RuntimeError("Unexpected branch")

def __getitem__(self, key):
rename to get_param?
__getitem__ is called when doing param[k]; it is an operator overload, so it needs to stay the same as before.
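As a minimal sketch of why the method cannot simply be renamed: Python routes the indexing syntax `obj[key]` to `__getitem__`. The class `ToyParameters` and its `get` method below are hypothetical and only illustrate the mechanism; they are not the PaddlePaddle `Parameters` implementation.

```python
class ToyParameters:
    """Hypothetical container used only to illustrate __getitem__."""

    def __init__(self):
        self._values = {"fc.w": [1.0, 2.0]}

    def get(self, key):
        # Return a copy so callers cannot mutate the stored value.
        return list(self._values[key])

    def __getitem__(self, key):
        # Invoked by the indexing syntax params[key].
        return self.get(key)


params = ToyParameters()
value_via_index = params["fc.w"]       # calls ToyParameters.__getitem__
value_via_method = params.get("fc.w")  # explicit method call
print(value_via_index == value_via_method)
```

Renaming `__getitem__` to an ordinary method name would break every existing `params[k]` call site, which is why a named accessor can only be added alongside it.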
Even though it copies the parameter from the C++ side, this is very helpful for debugging the training process and tuning the model. Thank you very much.
Able to print gradients in the event handler when training with the v2 API.
Related: #3040
Fixes: #3211