mirror of
https://github.com/HumanAIGC/lite-avatar.git
synced 2026-02-05 18:09:20 +08:00
add files
This commit is contained in:
43
funasr_local/bin/punc_train.py
Normal file
43
funasr_local/bin/punc_train.py
Normal file
@@ -0,0 +1,43 @@
|
||||
#!/usr/bin/env python3
|
||||
import os
|
||||
from funasr_local.tasks.punctuation import PunctuationTask
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = PunctuationTask.get_parser()
|
||||
parser.add_argument(
|
||||
"--gpu_id",
|
||||
type=int,
|
||||
default=0,
|
||||
help="local gpu id.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--punc_list",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Punctuation list",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main(args=None, cmd=None):
|
||||
"""
|
||||
punc training.
|
||||
"""
|
||||
PunctuationTask.main(args=args, cmd=cmd)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
# setup local gpu_id
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
|
||||
|
||||
# DDP settings
|
||||
if args.ngpu > 1:
|
||||
args.distributed = True
|
||||
else:
|
||||
args.distributed = False
|
||||
|
||||
main(args=args)
|
||||
Reference in New Issue
Block a user