Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Not for Merge]: Visualize the gradient of each node in the lattice. #251

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

csukuangfj
Copy link
Collaborator

This PR visualizes the gradient of each node in the lattice, which is used to compute the transducer loss.

The following shows some plots for different utterances.

You can see that

  • Most of the nodes have a very small gradient, i.e., most of the nodes have the background color.
  • Positions of nodes with non-zero gradient change somewhat monotonically, from the lower left to the upper right
  • At each time frame, only a small set of nodes have a non-zero gradient, which justifies the pruned RNN-T loss, i.e., putting a limit on the number of symbols per frame.

4160-11550-0025-15342-0_sp0 9
3440-171006-0000-22865-0_sp0 9

4195-186238-0001-16076-0_sp0 9

8425-246962-0023-25419-0_sp0 9
5652-39938-0025-14246-0

@csukuangfj
Copy link
Collaborator Author

This PR is not for merge. It is useful for visualizing the node gradient in the lattice during training.

@pkufool
Copy link
Collaborator

pkufool commented Mar 14, 2022

Are these pictures from the very first beginning steps or the stable training steps(i.e. middle steps at epoch 5 or other larger epochs).

@csukuangfj
Copy link
Collaborator Author

Note: The above plots are from the first batch at the very beginning of the training, i.e., the model weights are randomly initialized and no backward pass has been performed on it yet.

The following plots use the pre-trained model from #248

4160-11550-0025-15342-0_sp0 9

3440-171006-0000-22865-0_sp0 9
4195-186238-0001-16076-0_sp0 9
8425-246962-0023-25419-0_sp0 9

5652-39938-0025-14246-0

@csukuangfj
Copy link
Collaborator Author

csukuangfj commented Mar 14, 2022

For better comparison, the plots between the model with randomly initialized weights and the pre-trained model are listed as follows:

Randomly initialized Pre-trained
4160-11550-0025-15342-0_sp0 9 4160-11550-0025-15342-0_sp0 9
3440-171006-0000-22865-0_sp0 9 3440-171006-0000-22865-0_sp0 9
4195-186238-0001-16076-0_sp0 9 4195-186238-0001-16076-0_sp0 9
8425-246962-0023-25419-0_sp0 9 8425-246962-0023-25419-0_sp0 9
5652-39938-0025-14246-0 5652-39938-0025-14246-0

@desh2608
Copy link
Collaborator

@csukuangfj which quantity are you plotting here exactly? Is it simple_loss.grad?

@csukuangfj
Copy link
Collaborator Author

@csukuangfj which quantity are you plotting here exactly? Is it simple_loss.grad?

It is related to simple_loss, but it is not simple_loss.grad.

We are plotting the occupation probability of each node in the lattice. Please
refer to the following code if you want to learn more.

    # this is a kind of "fake gradient" that we use, in effect to compute
    # occupation probabilities.  The backprop will work regardless of the
    # actual derivative w.r.t. the total probs.
    ans_grad = torch.ones(B, device=px_tot.device, dtype=px_tot.dtype)

    (px_grad,
     py_grad) = _k2.mutual_information_backward(px_tot, py_tot, boundary, p,
                                                ans_grad)
// backward of mutual_information.  Returns (px_grad, py_grad).
// p corresponds to what we computed in the forward pass.
std::vector<torch::Tensor> MutualInformationBackwardCpu(
    torch::Tensor px, torch::Tensor py,
    torch::optional<torch::Tensor> opt_boundary, torch::Tensor p,
    torch::Tensor ans_grad) {

I suggest that you derive the formula of the occupation probability of each node on your own. You can find the code at
https://github.com/k2-fsa/k2/blob/0d7ef1a7867f70354ab5c59f2feb98c45558dcc7/k2/python/csrc/torch/mutual_information_cpu.cu#L189-L215

              // The s,t indexes correspond to
              // The statement we are backpropagating here is:
              // p_a[b][s][t] = LogAdd(
              //    p_a[b][s - 1][t + t_offset] + px_a[b][s - 1][t + t_offset],
              //    p_a[b][s][t - 1] + py_a[b][s][t - 1]);
              // .. which obtains p_a[b][s][t - 1] from a register.
              scalar_t term1 = p_a[b][s - 1][t + t_offset] +
                               px_a[b][s - 1][t + t_offset],
                       // term2 = p_a[b][s][t - 1] + py_a[b][s][t - 1], <-- not
                       // actually needed..
                  total = p_a[b][s][t];
              if (total - total != 0) total = 0;
              scalar_t term1_deriv = exp(term1 - total),
                       term2_deriv = 1.0 - term1_deriv,
                       grad = p_grad_a[b][s][t];
              scalar_t term1_grad, term2_grad;
              if (term1_deriv - term1_deriv == 0.0) {
                term1_grad = term1_deriv * grad;
                term2_grad = term2_deriv * grad;
              } else {
                // could happen if total == -inf
                term1_grad = term2_grad = 0.0;
              }
              px_grad_a[b][s - 1][t + t_offset] = term1_grad;
              p_grad_a[b][s - 1][t + t_offset] = term1_grad;
              py_grad_a[b][s][t - 1] = term2_grad;
              p_grad_a[b][s][t - 1] += term2_grad;

@desh2608
Copy link
Collaborator

Thanks for the detailed explanation!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants