Mirror of https://github.com/FunAudioLLM/CosyVoice.git, synced 2026-02-05 18:09:24 +08:00
fix lint
@@ -180,7 +180,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
     input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
     prompt_text_after_apply_template_list = []
     mels, prompt_audio_cosy2tokens_list, full_text_list = [], [], []
-    for i, item in enumerate(batch):
+    for _, item in enumerate(batch):
         audio_processing_start_time = time.time()
         prompt_text, target_text = (
             item["prompt_text"],
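This hunk silences the unused-variable warning that linters such as pylint (W0612) or flake8-bugbear (B007) raise when a loop index is bound but never read. A minimal sketch of the pattern, not repo code; the batch data below is made up, only the "prompt_text" field name comes from the diff:

# Sketch of the lint fix: the index from enumerate() is never read,
# so binding it to "_" (or dropping enumerate entirely) keeps linters quiet.
batch = [{"prompt_text": "hello"}, {"prompt_text": "world"}]  # illustrative data

texts = []
for _, item in enumerate(batch):  # "_" marks the index as intentionally unused
    texts.append(item["prompt_text"])

# Since the index is not needed at all, iterating directly is equivalent:
texts_direct = [item["prompt_text"] for item in batch]
assert texts == texts_direct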
@@ -402,7 +402,7 @@ def main(args):
             )
             torch.cuda.synchronize()
         elif args.backend == "trtllm":
-            batch_input_ids = [ids for ids in batch["input_ids"]]
+            batch_input_ids = list(batch["input_ids"])
             input_lengths = [x.size(0) for x in batch_input_ids]

             end_id = tokenizer.convert_tokens_to_ids("<|eos1|>") if "<|eos1|>" in tokenizer.get_vocab() else tokenizer.eos_token_id
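The second hunk removes an unnecessary comprehension (the pattern flagged by flake8-comprehensions rule C416): a comprehension that only re-emits each element is just a slower spelling of the list() constructor. A quick illustrative check, with made-up tensor contents:

# Sketch (not repo code): [x for x in seq] and list(seq) build the same list,
# so the constructor form is preferred by comprehension-focused lint rules.
import torch

batch = {"input_ids": [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]}  # illustrative

via_comprehension = [ids for ids in batch["input_ids"]]
via_constructor = list(batch["input_ids"])

assert all(torch.equal(a, b) for a, b in zip(via_comprehension, via_constructor))
input_lengths = [x.size(0) for x in via_constructor]  # [3, 2], as in the diff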
@@ -286,7 +286,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):

 def collate_fn(batch):
     ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
-    for i, item in enumerate(batch):
+    for _, item in enumerate(batch):
         generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
         audio = torch.from_numpy(item['prompt_audio']['array']).float()
         prompt_audios_list.append(audio)
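Same lint rule as the first hunk, applied to the collate function; since neither the index nor "_" is actually used here, the loop could also iterate the batch directly. A self-contained sketch of that collate pattern, with field names taken from the diff and the audio payload invented for illustration:

# Sketch of the collate pattern (field names from the diff, data made up):
# each item carries generated speech tokens plus a prompt waveform stored as
# a numpy array, converted here to a float32 torch tensor.
import numpy as np
import torch

def collate_fn(batch):
    generated_speech_tokens_list, prompt_audios_list = [], []
    for item in batch:  # no index needed, so enumerate() can be dropped entirely
        generated_speech_tokens_list.append(item["target_audio_cosy2_tokens"])
        audio = torch.from_numpy(item["prompt_audio"]["array"]).float()
        prompt_audios_list.append(audio)
    return generated_speech_tokens_list, prompt_audios_list

# Usage with two fake items:
fake_batch = [
    {"target_audio_cosy2_tokens": [1, 2, 3],
     "prompt_audio": {"array": np.zeros(16000, dtype=np.float64)}},
    {"target_audio_cosy2_tokens": [4, 5],
     "prompt_audio": {"array": np.ones(8000, dtype=np.float64)}},
]
tokens, audios = collate_fn(fake_batch)
assert audios[0].dtype == torch.float32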
@@ -319,7 +319,7 @@ if __name__ == "__main__":

     data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)

-    for epoch in range(args.warmup):
+    for _ in range(args.warmup):
         start_time = time.time()

         for batch in data_loader:
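The last hunk applies the same unused-variable fix to a warmup loop: the script runs the data loader args.warmup times before the timed pass, and the counter itself is never read. A hedged sketch of that benchmarking shape; the work() body, the fake loader, and the warmup count below are placeholders, not the repo's code:

# Sketch of a warmup-then-measure loop: the first passes absorb one-off costs
# (CUDA context setup, kernel autotuning, data caching) so the timed pass
# reflects steady-state throughput.
import time

def work(batch):
    return sum(batch)  # placeholder for the real inference step

data_loader = [[1, 2, 3], [4, 5, 6]]  # placeholder for the torch DataLoader
warmup = 2                            # stands in for args.warmup

for _ in range(warmup):               # counter unused, hence "_"
    for batch in data_loader:
        work(batch)

start_time = time.time()
for batch in data_loader:
    work(batch)
print(f"measured pass took {time.time() - start_time:.4f}s")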