Mirror of https://github.com/FunAudioLLM/CosyVoice.git (synced 2026-02-04 17:39:25 +08:00)
Merge pull request #924 from FunAudioLLM/dev/lyuxiang.lx
add llm bistream
Changed file shown below: README.md (10 lines changed)
@@ -143,6 +143,16 @@ for i, j in enumerate(cosyvoice.inference_cross_lingual('在他讲述那个荒
 # instruct usage
 for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话', prompt_speech_16k, stream=False)):
     torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
+
+# bistream usage, you can use generator as input, this is useful when using text llm model as input
+# NOTE you should still have some basic sentence split logic because llm can not handle arbitrary sentence length
+def text_generator():
+    yield '收到好友从远方寄来的生日礼物,'
+    yield '那份意外的惊喜与深深的祝福'
+    yield '让我心中充满了甜蜜的快乐,'
+    yield '笑容如花儿般绽放。'
+for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
+    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
 ```
 
 **CosyVoice Usage**
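The NOTE in the README hunk above asks callers to keep some basic sentence-split logic when the text comes from a streaming text LLM. The sketch below is one minimal way to do that; it is not part of this PR, and `llm_stream` and the punctuation set are assumptions for illustration. Note that the generator object (the result of calling the function) is what gets passed to `inference_zero_shot`, so the frontend's `isinstance(text, Generator)` check takes the streaming path.

```python
# Minimal sentence-split sketch (illustration only, not part of this commit).
# `llm_stream` stands for any iterable of partial text chunks produced by an
# upstream text LLM; the sentence-ending punctuation set is an assumption.
def sentence_chunks(llm_stream, enders='。!?!?'):
    buf = ''
    for piece in llm_stream:
        buf += piece
        # flush complete sentences as soon as an ending punctuation mark appears
        while any(p in buf for p in enders):
            cut = min(buf.index(p) for p in enders if p in buf) + 1
            yield buf[:cut]
            buf = buf[cut:]
    if buf:
        yield buf  # flush whatever is left when the stream ends

# usage mirrors the README example above: pass the generator object itself
# for i, j in enumerate(cosyvoice.inference_zero_shot(sentence_chunks(llm_stream), '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
#     torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
```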
@@ -13,6 +13,7 @@
 # limitations under the License.
 import os
 import time
+from typing import Generator
 from tqdm import tqdm
 from hyperpyyaml import load_hyperpyyaml
 from modelscope import snapshot_download
@@ -76,7 +77,7 @@ class CosyVoice:
     def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
         prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
-            if len(i) < 0.5 * len(prompt_text):
+            if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
                 logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
             model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate)
             start_time = time.time()
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import partial
+from typing import Generator
 import json
 import onnxruntime
 import torch
@@ -31,6 +32,7 @@ except ImportError:
     from tn.chinese.normalizer import Normalizer as ZhNormalizer
     from tn.english.normalizer import Normalizer as EnNormalizer
     use_ttsfrd = False
+from cosyvoice.utils.file_utils import logging
 from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
 
 
@@ -71,10 +73,21 @@ class CosyVoiceFrontEnd:
         self.inflect_parser = inflect.engine()
 
     def _extract_text_token(self, text):
-        text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
-        text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
-        text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
-        return text_token, text_token_len
+        if isinstance(text, Generator):
+            logging.info('get tts_text generator, will return _extract_text_token_generator!')
+            # NOTE add a dummy text_token_len for compatibility
+            return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device)
+        else:
+            text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
+            text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
+            text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
+            return text_token, text_token_len
+
+    def _extract_text_token_generator(self, text_generator):
+        for text in text_generator:
+            text_token, _ = self._extract_text_token(text)
+            for i in range(text_token.shape[1]):
+                yield text_token[:, i: i + 1]
 
     def _extract_speech_token(self, speech):
         assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
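A short illustration of the generator branch added above (not part of the diff): `_extract_text_token` now returns the token generator plus a dummy length of 0, and `_extract_text_token_generator` yields one (1, 1) token tensor at a time. The helper below assumes `frontend` is an already constructed `CosyVoiceFrontEnd` instance.

```python
# Illustration only: what the streaming branch of _extract_text_token returns.
def demo_streaming_tokens(frontend):
    def gen():
        yield '收到好友从远方寄来的生日礼物,'
        yield '那份意外的惊喜与深深的祝福'
    text_token, text_token_len = frontend._extract_text_token(gen())
    assert text_token_len.item() == 0      # dummy length in the generator case
    first = next(text_token)               # tokens come out one at a time
    print(first.shape)                     # expected: torch.Size([1, 1])
```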
@@ -106,6 +119,9 @@ class CosyVoiceFrontEnd:
         return speech_feat, speech_feat_len
 
     def text_normalize(self, text, split=True, text_frontend=True):
+        if isinstance(text, Generator):
+            logging.info('get tts_text generator, will skip text_normalize!')
+            return [text]
         if text_frontend is False:
             return [text] if split is True else text
         text = text.strip()
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from typing import Generator
 import torch
 import numpy as np
 import threading
@@ -99,14 +100,24 @@ class CosyVoiceModel:
 
     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
         with self.llm_context:
-            for i in self.llm.inference(text=text.to(self.device),
-                                        text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
-                                        prompt_text=prompt_text.to(self.device),
-                                        prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
-                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
-                                        prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                        embedding=llm_embedding.to(self.device)):
-                self.tts_speech_token_dict[uuid].append(i)
+            if isinstance(text, Generator):
+                assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
+                for i in self.llm.inference_bistream(text=text,
+                                                     prompt_text=prompt_text.to(self.device),
+                                                     prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                                     prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                                     prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                                     embedding=llm_embedding.to(self.device)):
+                    self.tts_speech_token_dict[uuid].append(i)
+            else:
+                for i in self.llm.inference(text=text.to(self.device),
+                                            text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
+                                            prompt_text=prompt_text.to(self.device),
+                                            prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                            prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                            prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                            embedding=llm_embedding.to(self.device)):
+                    self.tts_speech_token_dict[uuid].append(i)
         self.llm_end_dict[uuid] = True
 
     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
@@ -20,6 +20,7 @@ from torch.nn.utils.rnn import pad_sequence, unpad_sequence
 from cosyvoice.utils.common import IGNORE_ID
 from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
 from cosyvoice.utils.common import th_accuracy
+from cosyvoice.utils.file_utils import logging
 
 
 class TransformerLM(torch.nn.Module):
@@ -144,10 +145,14 @@ class TransformerLM(torch.nn.Module):
             sampling: int,
             ignore_eos: bool = True,
     ):
+        num_trials, max_trials = 0, 100
         while True:
             top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
             if (not ignore_eos) or (self.speech_token_size not in top_ids):
                 break
+            num_trials += 1
+            if num_trials > max_trials:
+                raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
         return top_ids
 
     @torch.inference_mode()
@@ -239,7 +244,7 @@ class Qwen2Encoder(torch.nn.Module):
         return xs, new_cache
 
 
-class Qwen2LM(torch.nn.Module):
+class Qwen2LM(TransformerLM):
     def __init__(
             self,
             llm_input_size: int,
@@ -249,8 +254,9 @@ class Qwen2LM(torch.nn.Module):
             sampling: Callable,
             length_normalized_loss: bool = True,
             lsm_weight: float = 0.0,
+            mix_ratio: List[int] = [5, 15],
     ):
-        super().__init__()
+        torch.nn.Module.__init__(self)
         self.llm_input_size = llm_input_size
         self.llm_output_size = llm_output_size
         self.speech_token_size = speech_token_size
@@ -275,23 +281,7 @@ class Qwen2LM(torch.nn.Module):
 
         # 4. sampling method
         self.sampling = sampling
-
-    def sampling_ids(
-            self,
-            weighted_scores: torch.Tensor,
-            decoded_tokens: List,
-            sampling: int,
-            ignore_eos: bool = True,
-    ):
-        num_trials, max_trials = 0, 100
-        while True:
-            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
-            if (not ignore_eos) or (self.speech_token_size not in top_ids):
-                break
-            num_trials += 1
-            if num_trials > max_trials:
-                raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
-        return top_ids
+        self.mix_ratio = mix_ratio
 
     @torch.inference_mode()
     def inference(
@@ -312,9 +302,6 @@ class Qwen2LM(torch.nn.Module):
         text_len += prompt_text_len
         text = self.llm.model.model.embed_tokens(text)
 
-        # 2. encode embedding
-        embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
-
         # 3. concat llm_input
         sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
@@ -322,7 +309,7 @@ class Qwen2LM(torch.nn.Module):
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
             prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
-        lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
+        lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
 
         # 4. cal min/max_length
         min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
@@ -345,3 +332,100 @@ class Qwen2LM(torch.nn.Module):
             yield top_ids
             out_tokens.append(top_ids)
             lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+
+    @torch.inference_mode()
+    def inference_bistream(
+            self,
+            text: Generator,
+            prompt_text: torch.Tensor,
+            prompt_text_len: torch.Tensor,
+            prompt_speech_token: torch.Tensor,
+            prompt_speech_token_len: torch.Tensor,
+            embedding: torch.Tensor,
+            sampling: int = 25,
+            max_token_text_ratio: float = 20,
+            min_token_text_ratio: float = 2,
+    ) -> Generator[torch.Tensor, None, None]:
+
+        device = prompt_text.device
+        # 1. prepare input
+        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+        if prompt_speech_token_len != 0:
+            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
+        else:
+            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
+        lm_input = torch.concat([sos_eos_emb], dim=1)
+
+        # 2. iterate text
+        out_tokens = []
+        cache = None
+        # NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
+        text_cache = self.llm.model.model.embed_tokens(prompt_text)
+        next_fill_index = -1
+        for this_text in text:
+            text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
+            # prompt_speech_token_emb not empty, try append to lm_input
+            while prompt_speech_token_emb.size(1) != 0:
+                if text_cache.size(1) >= self.mix_ratio[0]:
+                    lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
+                    logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
+                    lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
+                    text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
+                else:
+                    logging.info('not enough text token to decode, wait for more')
+                    break
+            # no prompt_speech_token_emb remain, can decode some speech token
+            if prompt_speech_token_emb.size(1) == 0:
+                if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
+                    logging.info('get fill token, need to append more text token')
+                    if text_cache.size(1) >= self.mix_ratio[0]:
+                        lm_input_text = text_cache[:, :self.mix_ratio[0]]
+                        logging.info('append {} text token'.format(lm_input_text.size(1)))
+                        lm_input = torch.concat([lm_input, lm_input_text], dim=1)
+                        text_cache = text_cache[:, self.mix_ratio[0]:]
+                    else:
+                        logging.info('not enough text token to decode, wait for more')
+                        continue
+                while True:
+                    seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
+                    y_pred, cache = self.llm.forward_one_step(lm_input,
+                                                              masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
+                                                              cache=cache)
+                    logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
+                    if next_fill_index != -1 and len(out_tokens) == next_fill_index:
+                        top_ids = self.speech_token_size + 2
+                        next_fill_index += (self.mix_ratio[1] + 1)
+                    else:
+                        top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
+                    if top_ids == self.speech_token_size + 2:
+                        next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
+                        logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
+                    out_tokens.append(top_ids)
+                    if top_ids >= self.speech_token_size:
+                        if top_ids == self.speech_token_size + 2:
+                            break
+                        else:
+                            raise ValueError('should not get token {}'.format(top_ids))
+                    yield top_ids
+                    lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+
+        # 3. final decode
+        lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
+        logging.info('no more text token, decode until met eos')
+        while True:
+            seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
+            y_pred, cache = self.llm.forward_one_step(lm_input,
+                                                      masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
+                                                      cache=cache)
+            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
+            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
+            out_tokens.append(top_ids)
+            if top_ids >= self.speech_token_size:
+                if top_ids == self.speech_token_size:
+                    break
+                else:
+                    raise ValueError('should not get token {}'.format(top_ids))
+            # in stream mode, yield token one by one
+            yield top_ids
+            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
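For readers following the fill-token bookkeeping in the hunk above: with the default `mix_ratio=[5, 15]`, each block of 5 text tokens is paired with up to 15 speech tokens before a fill token (`speech_token_size + 2`) requests the next text block, and once the model emits its first fill token on its own, `next_fill_index` forces another fill token every `mix_ratio[1] + 1 = 16` positions. The snippet below only reproduces that arithmetic for illustration; it is not part of the diff.

```python
# Illustration only: positions in out_tokens where inference_bistream forces a
# fill token after the model has emitted its first one at index `first_fill`.
def forced_fill_positions(first_fill, num_blocks, mix_ratio=(5, 15)):
    step = mix_ratio[1] + 1  # 15 speech tokens + 1 fill token per text block
    return [first_fill + step * k for k in range(1, num_blocks + 1)]

print(forced_fill_positions(first_fill=15, num_blocks=3))  # [31, 47, 63]
```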
@@ -162,5 +162,5 @@ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
     # attention mask bias
     # NOTE(Mddct): torch.finfo jit issues
     # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
-    mask = (1.0 - mask) * torch.finfo(dtype).min
+    mask = (1.0 - mask) * -1.0e+10
     return mask
examples/libritts/cosyvoice2/cosyvoice (new symbolic link, 1 line)
@@ -0,0 +1 @@
+../../../cosyvoice

examples/libritts/cosyvoice2/tools (new symbolic link, 1 line)
@@ -0,0 +1 @@
+../../../tools
@@ -24,7 +24,7 @@ import whisper
 
 
 def single_job(utt):
-    audio, sample_rate = torchaudio.load(utt2wav[utt])
+    audio, sample_rate = torchaudio.load(utt2wav[utt], backend='soundfile')
     if sample_rate != 16000:
         audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
     if audio.shape[1] / 16000 > 30: