commit f26cde56df
parent 66b80dbccb
Author: lyuxiang.lx
Date: 2026-01-29 06:13:36 +00:00
7 changed files with 90 additions and 73 deletions

@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import queue
+import os, queue
import random
import time
import threading
@@ -28,7 +28,7 @@ from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
from cosyvoice.utils.common import th_accuracy
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.mask import make_pad_mask
-from cosyvoice.utils.onnx import SpeechTokenExtractor
+from cosyvoice.utils.onnx import SpeechTokenExtractor, online_feature, onnx_path
class TransformerLM(torch.nn.Module):
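The widened import brings in two module-level names alongside SpeechTokenExtractor. The onnx utility module itself is not part of this diff, so the following is only a sketch of the interface the change appears to assume: an online_feature switch, an onnx_path directory holding the exported tokenizer models, and an extractor class wrapping an onnxruntime session. The defaults, environment variables, and tensor names below are hypothetical.

```python
# Hypothetical sketch of cosyvoice/utils/onnx.py as assumed by this diff.
# Only the names SpeechTokenExtractor, online_feature and onnx_path come from
# the import line; the defaults and tensor names are guesses.
import os

import numpy as np
import onnxruntime

# Switch for extracting speech tokens on the fly instead of reading
# precomputed tokens from disk (assumed default: off).
online_feature = os.environ.get('COSYVOICE_ONLINE_FEATURE', '0') == '1'
# Directory holding the exported speech_tokenizer_*.batch.onnx models (assumed layout).
onnx_path = os.environ.get('COSYVOICE_ONNX_PATH', 'pretrained_models')


class SpeechTokenExtractor:
    """Thin wrapper around a batch ONNX speech tokenizer (sketch only)."""

    def __init__(self, model_path: str):
        # CPUExecutionProvider keeps the sketch portable; the real code may pin a GPU.
        self.session = onnxruntime.InferenceSession(
            model_path, providers=['CPUExecutionProvider'])

    def extract(self, feats: np.ndarray, feats_len: np.ndarray) -> np.ndarray:
        # Input names are placeholders; inspect the exported model for the real ones.
        outputs = self.session.run(None, {'feats': feats, 'feats_len': feats_len})
        return outputs[0]
```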
@@ -301,6 +301,8 @@ class Qwen2LM(TransformerLM):
# 5. vllm related
self.stop_token_ids = [speech_token_size + i for i in range(3)]
self.vllm_output_queue = {}
+        if online_feature is True:
+            self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v2.batch.onnx'))
def prepare_lm_input_target(self, sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len, instruct_token=None, instruct_token_emb=None, instruct_token_len=None):
lm_target, lm_input = [], []
@@ -667,6 +669,8 @@ class CosyVoice3LM(Qwen2LM):
# 5. vllm related
self.stop_token_ids = [speech_token_size + i for i in range(200)]
self.vllm_output_queue = {}
+        if online_feature is True:
+            self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v3.batch.onnx'))
def forward(
self,
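Both constructor hunks follow the same pattern: when online_feature is enabled, each LM owns an extractor matched to its tokenizer version (speech_tokenizer_v2.batch.onnx for Qwen2LM, speech_tokenizer_v3.batch.onnx for CosyVoice3LM). A minimal illustration of that pattern follows; the extract call and the fallback behaviour are assumptions about how the extractor would be used during training, not code from this commit.

```python
# Minimal illustration of the guarded-initialization pattern added in both
# constructors. The training-time hook below is an assumed usage, not code
# from this commit.
import os

import numpy as np

from cosyvoice.utils.onnx import SpeechTokenExtractor, online_feature, onnx_path


class OnlineTokenLM:
    """Stand-in for Qwen2LM / CosyVoice3LM; only the extractor wiring is shown."""

    def __init__(self, tokenizer_onnx: str = 'speech_tokenizer_v2.batch.onnx'):
        if online_feature is True:
            # Speech tokens are produced on the fly from acoustic features
            # rather than loaded from precomputed files.
            self.speech_token_extractor = SpeechTokenExtractor(
                model_path=os.path.join(onnx_path, tokenizer_onnx))

    def speech_tokens_for_batch(self, feats: np.ndarray, feats_len: np.ndarray):
        # Hypothetical hook: extract online when enabled, otherwise expect the
        # data pipeline to supply precomputed speech_token tensors.
        if online_feature is True:
            return self.speech_token_extractor.extract(feats, feats_len)
        raise RuntimeError('online_feature disabled; expect precomputed speech_token in the batch')
```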