
[train] u2++-lite training support #2202

Merged
whiteshirt0429 merged 1 commit from diwu-u2++-lite into main on Dec 8, 2023

Conversation

whiteshirt0429
Collaborator

No description provided.

@manbaaaa
Contributor

manbaaaa commented Dec 7, 2023

What is the functionality of 'apply_non_blank_embedding'? Are there any reference materials available for learning?

@whiteshirt0429 whiteshirt0429 force-pushed the diwu-u2++-lite branch 2 times, most recently from a2f0674 to 0a5ee16 Compare December 7, 2023 15:46
@whiteshirt0429
Collaborator Author

whiteshirt0429 commented Dec 7, 2023

What is the functionality of 'apply_non_blank_embedding'? Are there any reference materials available for learning?

it is a new feature

@whiteshirt0429
Collaborator Author

whiteshirt0429 commented Dec 7, 2023

u2++-lite is used for reducing rescoring latency; the runtime code and latency results will be checked in soon.

Member

@xingchensong xingchensong left a comment


It looks like _forward_ctc is no longer needed.

@@ -133,6 +143,34 @@ def _forward_ctc(self, encoder_out: torch.Tensor,
         loss_ctc = self.ctc(encoder_out, encoder_out_lens, text, text_lengths)
         return loss_ctc
Member


It looks like _forward_ctc is no longer needed; nothing else in asr_model.py calls it.

Member


Both k2 and paraformer have their own overrides of _forward_ctc; for example, k2 computes LF-MMI inside it. So rather than deprecating asr_model._forward_ctc, I suggest modifying it to return both the loss and the logits, and updating the corresponding overrides in k2 and paraformer at the same time.

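A minimal sketch of what that suggestion could look like on the asr_model.py side (hedged: the method body and return names here are assumptions and may differ from the actual change in this PR):

def _forward_ctc(self, encoder_out: torch.Tensor,
                 encoder_out_lens: torch.Tensor,
                 text: torch.Tensor,
                 text_lengths: torch.Tensor):
    # self.ctc is assumed to follow the updated CTC.forward shown below,
    # which returns (loss, log_probs) instead of only the loss.
    loss_ctc, ctc_probs = self.ctc(encoder_out, encoder_out_lens,
                                   text, text_lengths)
    # Return both so that subclasses (k2 LF-MMI, paraformer) can reuse the logits.
    return loss_ctc, ctc_probs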

@@ -63,7 +67,8 @@ def forward(self, hs_pad: torch.Tensor, hlens: torch.Tensor,
         loss = self.ctc_loss(ys_hat, ys_pad, hlens, ys_lens)
         # Batch-size average
         loss = loss / ys_hat.size(1)
-        return loss
+        ys_hat = ys_hat.transpose(0, 1)
+        return loss, ys_hat
Member


The CTC forward now returns two values, so the call sites of ctc in k2 and paraformer also need to be updated to unpack both return values, otherwise they will raise an error.

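A self-contained toy example of the new two-value return and the matching call-site change (hedged: TinyCTC is a stand-in written only for illustration, not wenet's real CTC class):

import torch
import torch.nn as nn
from typing import Tuple

class TinyCTC(nn.Module):
    # Stand-in that mirrors the new contract: forward returns (loss, ys_hat).
    def __init__(self, idim: int = 8, odim: int = 5):
        super().__init__()
        self.ctc_lo = nn.Linear(idim, odim)
        self.ctc_loss = nn.CTCLoss(reduction="sum", zero_infinity=True)

    def forward(self, hs_pad: torch.Tensor, hlens: torch.Tensor,
                ys_pad: torch.Tensor, ys_lens: torch.Tensor
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        ys_hat = self.ctc_lo(hs_pad).transpose(0, 1).log_softmax(2)  # (T, B, odim)
        loss = self.ctc_loss(ys_hat, ys_pad, hlens, ys_lens) / ys_hat.size(1)
        return loss, ys_hat.transpose(0, 1)  # two values now

ctc = TinyCTC()
hs_pad = torch.randn(2, 10, 8)        # (batch, frames, feature), toy values
hlens = torch.tensor([10, 10])
ys_pad = torch.randint(1, 5, (2, 4))  # labels; 0 is reserved for blank
ys_lens = torch.tensor([4, 4])
# Old call sites did `loss = self.ctc(...)`; they must now unpack two values:
loss_ctc, ctc_probs = ctc(hs_pad, hlens, ys_pad, ys_lens)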

Comment on lines 41 to 45
if info_dict["model_conf"]["apply_non_blank_embedding"]:
    logging.warn(
        'Had better load a well trained model if '
        'apply_non_blank_embedding is true !!!'
    )
Member


Could this be moved into the train_utils.py::check_modify_and_save_config function? Reasons:

  1. Placed in executor::train, it gets printed on every epoch.
  2. check_modify_and_save_config is dedicated to checking the config, which matches the intent of this log message (see the sketch below).
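A possible shape for that move (hedged: the real check_modify_and_save_config does much more; this only sketches where the warning could sit):

import logging

def check_modify_and_save_config(args, configs):
    # ... existing config checks ...
    # Warn once at config-check time rather than on every epoch in executor.train.
    if configs.get("model_conf", {}).get("apply_non_blank_embedding", False):
        logging.warning("Had better load a well trained model "
                        "if apply_non_blank_embedding is true !!!")
    # ... existing config modification / saving ...
    return configs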

for module_name in args.freeze_modules:
    if module_name in name:
        param.requires_grad = False
        logging.debug("{} module is freezed".format(name))
Member


Just curious: are the results with freezing better than without?

Collaborator Author


Without freezing, multi-GPU training runs into problems, and the alignment also changes.

Member


Without freezing, multi-GPU training runs into problems, and the alignment also changes.

Got it. What error does multi-GPU training report?

maxlen = encoder_out.size(1)
top1_index = torch.argmax(ctc_probs, dim=2)
indices = []
for j in range(topk_prob.size(0)):
Contributor


topk_prob is undefined
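One possible fix, as a self-contained sketch (hedged: toy shapes, and blank id 0 is an assumption), is to loop over the batch dimension of top1_index instead of the undefined topk_prob:

import torch

encoder_out = torch.randn(2, 10, 256)                 # (batch, frames, dim), toy values
ctc_probs = torch.randn(2, 10, 5).log_softmax(dim=2)  # (batch, frames, vocab)

maxlen = encoder_out.size(1)
top1_index = torch.argmax(ctc_probs, dim=2)           # (batch, frames) greedy CTC labels
indices = []
for j in range(top1_index.size(0)):                   # was `topk_prob.size(0)` in the diff
    # keep the frame indices whose greedy CTC label is not <blank> (id 0 assumed)
    indices.append(torch.nonzero(top1_index[j] != 0, as_tuple=False).squeeze(-1))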

@whiteshirt0429 whiteshirt0429 force-pushed the diwu-u2++-lite branch 3 times, most recently from 78915c3 to f82a4c9 Compare December 8, 2023 04:34
[train] add instructions for use
@robin1001
Collaborator

Have you set up pre-commit?

@kobenaxie
Contributor

Is this the frame reduce / blank skip that Google used in RNN-T, applied here to the AED architecture?

@whiteshirt0429
Collaborator Author

whiteshirt0429 commented Dec 8, 2023

Is this the frame reduce / blank skip that Google used in RNN-T, applied here to the AED architecture?

I wasn't aware of that work when I did this. I just searched, and the k2 team also has similar work. As I understand it, the underlying ideas are roughly the same: reduce computation and cut latency. Here the main goal is to reduce inference latency.
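For readers new to the idea, a minimal illustrative sketch of the blank-skip principle (hedged: this is not the PR's actual implementation; select_non_blank_frames and blank id 0 are assumptions made here). Only the encoder frames whose greedy CTC label is non-blank are kept, so attention rescoring runs over a much shorter sequence:

import torch

def select_non_blank_frames(encoder_out: torch.Tensor,
                            ctc_probs: torch.Tensor,
                            blank_id: int = 0):
    # encoder_out: (batch, frames, dim), ctc_probs: (batch, frames, vocab)
    top1 = ctc_probs.argmax(dim=2)            # greedy CTC labels, (batch, frames)
    selected = []
    for b in range(encoder_out.size(0)):
        keep = top1[b] != blank_id            # boolean mask of non-blank frames
        selected.append(encoder_out[b][keep]) # (kept_frames, dim)
    return selected                           # variable length per utterance

# toy usage: 50 encoder frames shrink to however many are non-blank
enc = torch.randn(2, 50, 256)
probs = torch.randn(2, 50, 100).log_softmax(dim=2)
shortened = select_non_blank_frames(enc, probs)
print([t.shape for t in shortened])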

@whiteshirt0429 whiteshirt0429 merged commit 2894f7c into main Dec 8, 2023
6 checks passed
@xingchensong xingchensong deleted the diwu-u2++-lite branch December 8, 2023 07:28
@xingchensong xingchensong mentioned this pull request Dec 8, 2023