old grad clip has 0d tensor problem, fix it (PaddlePaddle#3334)
zh794390558 authored and luotao1 committed Jun 11, 2024
1 parent 7acf073 commit 9fbaebd
Showing 4 changed files with 12 additions and 90 deletions.
3 changes: 1 addition & 2 deletions paddlespeech/s2t/exps/deepspeech2/model.py
@@ -27,7 +27,6 @@
 from paddlespeech.s2t.io.dataloader import BatchDataLoader
 from paddlespeech.s2t.models.ds2 import DeepSpeech2InferModel
 from paddlespeech.s2t.models.ds2 import DeepSpeech2Model
-from paddlespeech.s2t.training.gradclip import ClipGradByGlobalNormWithLog
 from paddlespeech.s2t.training.reporter import report
 from paddlespeech.s2t.training.timer import Timer
 from paddlespeech.s2t.training.trainer import Trainer
@@ -148,7 +147,7 @@ def setup_model(self):
         if not self.train:
             return
 
-        grad_clip = ClipGradByGlobalNormWithLog(config.global_grad_clip)
+        grad_clip = paddle.nn.ClipGradByGlobalNorm(config.global_grad_clip)
         lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
             learning_rate=config.lr, gamma=config.lr_decay, verbose=True)
         optimizer = paddle.optimizer.Adam(
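For reference, a minimal sketch of the replacement pattern used here (hypothetical model and config values, not taken from the repo): paddle.nn.ClipGradByGlobalNorm is handed to the optimizer via grad_clip, so gradients are rescaled by their global norm before step() applies the update.

import paddle

# Hypothetical stand-ins for config.global_grad_clip, config.lr, config.lr_decay.
model = paddle.nn.Linear(16, 16)
grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=5.0)
lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
    learning_rate=2e-3, gamma=0.83, verbose=True)
optimizer = paddle.optimizer.Adam(
    learning_rate=lr_scheduler,
    parameters=model.parameters(),
    grad_clip=grad_clip)

# One step: backward builds gradients, step() clips them by global norm and updates.
loss = model(paddle.randn([4, 16])).mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()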
86 changes: 0 additions & 86 deletions paddlespeech/s2t/training/gradclip.py

This file was deleted.

4 changes: 2 additions & 2 deletions paddlespeech/s2t/training/optimizer/__init__.py
@@ -19,7 +19,7 @@
 import paddle
 from paddle.optimizer import Optimizer
 from paddle.regularizer import L2Decay
-from paddlespeech.s2t.training.gradclip import ClipGradByGlobalNormWithLog
+
 from paddlespeech.s2t.utils.dynamic_import import dynamic_import
 from paddlespeech.s2t.utils.dynamic_import import instance_class
 from paddlespeech.s2t.utils.log import Log
@@ -100,7 +100,7 @@ def from_args(cls, name: str, args: Dict[Text, Any]):
         assert "parameters" in args, "parameters not in args."
         assert "learning_rate" in args, "learning_rate not in args."
 
-        grad_clip = ClipGradByGlobalNormWithLog(
+        grad_clip = paddle.nn.ClipGradByGlobalNorm(
             args['grad_clip']) if "grad_clip" in args else None
         weight_decay = L2Decay(
             args['weight_decay']) if "weight_decay" in args else None
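The same swap in from_args keeps the conditional construction from the args dict. A rough usage sketch, with made-up values standing in for what the training config would supply:

import paddle
from paddle.regularizer import L2Decay

# Made-up args dict; in the real code it is assembled from the training config.
args = {"learning_rate": 1e-3, "grad_clip": 5.0, "weight_decay": 1e-6}

grad_clip = paddle.nn.ClipGradByGlobalNorm(
    args['grad_clip']) if "grad_clip" in args else None
weight_decay = L2Decay(
    args['weight_decay']) if "weight_decay" in args else None
print(type(grad_clip).__name__, type(weight_decay).__name__)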
9 changes: 9 additions & 0 deletions tests/unit/tts/test_ssml.py
@@ -72,3 +72,12 @@
 for i, sub in enumerate(outs):
     print(i, sub)
 print()
+
+import json
+import xmltodict
+text = "<speak>我们的声学模型使用了 Fast Speech Two。前浪<say-as pinyin='dao3'>倒</say-as>在沙滩上,沙滩上倒了一堆<say-as pinyin='tu3'>土</say-as>。 想象<say-as pinyin='gan1 gan1'>干干</say-as>的树干<say-as pinyin='dao3'>倒</say-as>了, 里面有个干尸,不知是被谁<say-as pinyin='gan4'>干</say-as>死的。</speak>"
+ssml = xmltodict.parse(text)
+print(json.dumps(ssml))
+print(ssml['speak'].keys())
+print(ssml['speak']['#text'])
+print(ssml['speak']['say-as'])
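The added test only prints the raw parsed structure. As a rough sketch of what that structure contains (an assumption about xmltodict's default mapping, not part of this commit), each <say-as> element exposes its pinyin override via the '@pinyin' attribute key and its text via '#text'; force_list keeps repeated tags as a list even when only one is present.

import xmltodict

# Minimal SSML snippet; force_list=('say-as',) is an assumption used here so
# 'say-as' is always a list, regardless of how many tags the document holds.
text = "<speak>前浪<say-as pinyin='dao3'>倒</say-as>在沙滩上</speak>"
ssml = xmltodict.parse(text, force_list=('say-as',))
for item in ssml['speak']['say-as']:
    print(item['@pinyin'], item['#text'])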
