Skip to content

Commit

Permalink
Fix param@grad type error for amp in run_program
Browse files Browse the repository at this point in the history
  • Loading branch information
0x45f committed Mar 25, 2022
1 parent b79c6a9 commit 47cc9bc
Showing 1 changed file with 7 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,9 @@ def _train_amp_program(self):
"""
Lazy initialized property of train_amp_program.
"""
return self._append_backward_desc(self._infer_amp_program)
train_amp_program = self._append_backward_desc(self._infer_amp_program)
self._set_grad_type(self._params, train_amp_program)
return train_amp_program

@LazyInitialized
@switch_to_static_graph
Expand All @@ -224,7 +226,10 @@ def _train_pure_fp16_program(self):
"""
Lazy initialized property of _train_pure_fp16_program.
"""
return self._append_backward_desc(self._infer_pure_fp16_program)
train_pure_fp16_program = self._append_backward_desc(
self._infer_pure_fp16_program)
self._set_grad_type(self._params, train_pure_fp16_program)
return train_pure_fp16_program

@LazyInitialized
def _infer_program_id(self):
Expand Down

0 comments on commit 47cc9bc

Please sign in to comment.