Skip to content

Commit

Permalink
Merge pull request #1951 from Zhangjingyu06/develop
Browse files Browse the repository at this point in the history
deepspeech2 modify for kunlun
  • Loading branch information
Jackwaterveg authored May 25, 2022
2 parents 657c424 + acb19cf commit 86b9473
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 1 deletion.
6 changes: 6 additions & 0 deletions paddlespeech/s2t/exps/deepspeech2/bin/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ def main(config, args):
"--export_path", type=str, help="path of the jit model to save")
parser.add_argument(
"--model_type", type=str, default='offline', help="offline/online")
parser.add_argument(
'--nxpu',
type=int,
default=0,
choices=[0, 1],
help="if nxpu == 0 and ngpu == 0, use cpu.")
args = parser.parse_args()
print("model_type:{}".format(args.model_type))
print_arguments(args)
Expand Down
6 changes: 6 additions & 0 deletions paddlespeech/s2t/exps/deepspeech2/bin/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ def main(config, args):
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
parser.add_argument(
'--nxpu',
type=int,
default=0,
choices=[0, 1],
help="if nxpu == 0 and ngpu == 0, use cpu.")
args = parser.parse_args()
print_arguments(args, globals())
print("model_type:{}".format(args.model_type))
Expand Down
6 changes: 6 additions & 0 deletions paddlespeech/s2t/exps/deepspeech2/bin/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ def main(config, args):
"--export_path", type=str, help="path of the jit model to save")
parser.add_argument(
"--model_type", type=str, default='offline', help='offline/online')
parser.add_argument(
'--nxpu',
type=int,
default=0,
choices=[0, 1],
help="if nxpu == 0 and ngpu == 0, use cpu.")
parser.add_argument(
"--enable-auto-log", action="store_true", help="use auto log")
args = parser.parse_args()
Expand Down
6 changes: 6 additions & 0 deletions paddlespeech/s2t/exps/deepspeech2/bin/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ def main(config, args):
parser = default_argument_parser()
parser.add_argument(
"--model_type", type=str, default='offline', help='offline/online')
parser.add_argument(
'--nxpu',
type=int,
default=0,
choices=[0, 1],
help="if nxpu == 0 and ngpu == 0, use cpu.")
args = parser.parse_args()
print("model_type:{}".format(args.model_type))
print_arguments(args, globals())
Expand Down
11 changes: 10 additions & 1 deletion paddlespeech/s2t/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,16 @@ def __init__(self, config, args):
logger.info(f"Rank: {self.rank}/{self.world_size}")

# set device
paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu')
if self.args.ngpu == 0:
if self.args.nxpu == 0:
paddle.set_device('cpu')
else:
paddle.set_device('xpu')
elif self.args.ngpu > 0:
paddle.set_device("gpu")
else:
raise Exception("invalid device")

if self.parallel:
self.init_parallel()

Expand Down

0 comments on commit 86b9473

Please sign in to comment.