mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
Enhance CosyVoice with CUDA stream management and estimator handling
- Introduced a queue-based system for managing CUDA streams to improve inference performance. - Updated inference methods to utilize CUDA streams for asynchronous processing. - Added an EstimatorWrapper class to manage TensorRT estimators, allowing for efficient execution context handling. - Modified model loading functions to support estimator count configuration. - Improved logging and performance tracking during inference operations.
This commit is contained in:
@@ -22,7 +22,7 @@ from cosyvoice.cli.frontend import CosyVoiceFrontEnd
|
|||||||
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, VllmCosyVoice2Model
|
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, VllmCosyVoice2Model
|
||||||
from cosyvoice.utils.file_utils import logging
|
from cosyvoice.utils.file_utils import logging
|
||||||
from cosyvoice.utils.class_utils import get_model_type
|
from cosyvoice.utils.class_utils import get_model_type
|
||||||
|
import queue
|
||||||
|
|
||||||
class CosyVoice:
|
class CosyVoice:
|
||||||
|
|
||||||
@@ -54,11 +54,18 @@ class CosyVoice:
|
|||||||
'{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
'{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
||||||
'{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
|
'{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
|
||||||
if load_trt:
|
if load_trt:
|
||||||
|
self.estimator_count = configs['flow']['decoder']['estimator'].get('estimator_count', 1)
|
||||||
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
||||||
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
|
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
|
||||||
self.fp16)
|
self.fp16, self.estimator_count)
|
||||||
del configs
|
del configs
|
||||||
|
|
||||||
|
thread_count = 10
|
||||||
|
self.stream_pool = queue.Queue(maxsize=thread_count)
|
||||||
|
for _ in range(thread_count):
|
||||||
|
self.stream_pool.put(torch.cuda.Stream(self.device))
|
||||||
|
|
||||||
|
|
||||||
def list_available_spks(self):
|
def list_available_spks(self):
|
||||||
spks = list(self.frontend.spk2info.keys())
|
spks = list(self.frontend.spk2info.keys())
|
||||||
return spks
|
return spks
|
||||||
@@ -67,6 +74,8 @@ class CosyVoice:
|
|||||||
self.frontend.add_spk_info(spk_id, spk_info)
|
self.frontend.add_spk_info(spk_id, spk_info)
|
||||||
|
|
||||||
def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
|
def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
|
||||||
|
cuda_stream = self.stream_pool.get()
|
||||||
|
with torch.cuda.stream(cuda_stream):
|
||||||
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
||||||
model_input = self.frontend.frontend_sft(i, spk_id)
|
model_input = self.frontend.frontend_sft(i, spk_id)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@@ -76,8 +85,12 @@ class CosyVoice:
|
|||||||
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
||||||
yield model_output
|
yield model_output
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
cuda_stream.synchronize()
|
||||||
|
self.stream_pool.put(cuda_stream)
|
||||||
|
|
||||||
def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
|
def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
|
||||||
|
cuda_stream = self.stream_pool.get()
|
||||||
|
with torch.cuda.stream(cuda_stream):
|
||||||
prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
|
prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
|
||||||
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
||||||
if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
|
if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
|
||||||
@@ -90,9 +103,13 @@ class CosyVoice:
|
|||||||
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
||||||
yield model_output
|
yield model_output
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
cuda_stream.synchronize()
|
||||||
|
self.stream_pool.put(cuda_stream)
|
||||||
|
|
||||||
def inference_zero_shot_by_spk_id(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
|
def inference_zero_shot_by_spk_id(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
|
||||||
"""使用预定义的说话人执行 zero_shot 推理"""
|
"""使用预定义的说话人执行 zero_shot 推理"""
|
||||||
|
cuda_stream = self.stream_pool.get()
|
||||||
|
with torch.cuda.stream(cuda_stream):
|
||||||
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
||||||
model_input = self.frontend.frontend_zero_shot_by_spk_id(i, spk_id)
|
model_input = self.frontend.frontend_zero_shot_by_spk_id(i, spk_id)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@@ -106,8 +123,12 @@ class CosyVoice:
|
|||||||
yield model_output
|
yield model_output
|
||||||
last_time = time.time()
|
last_time = time.time()
|
||||||
chunk_index += 1
|
chunk_index += 1
|
||||||
|
cuda_stream.synchronize()
|
||||||
|
self.stream_pool.put(cuda_stream)
|
||||||
|
|
||||||
def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
|
def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
|
||||||
|
cuda_stream = self.stream_pool.get()
|
||||||
|
with torch.cuda.stream(cuda_stream):
|
||||||
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
||||||
model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
|
model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@@ -117,8 +138,12 @@ class CosyVoice:
|
|||||||
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
||||||
yield model_output
|
yield model_output
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
cuda_stream.synchronize()
|
||||||
|
self.stream_pool.put(cuda_stream)
|
||||||
|
|
||||||
def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
|
def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
|
||||||
|
cuda_stream = self.stream_pool.get()
|
||||||
|
with torch.cuda.stream(cuda_stream):
|
||||||
assert isinstance(self.model, CosyVoiceModel), 'inference_instruct is only implemented for CosyVoice!'
|
assert isinstance(self.model, CosyVoiceModel), 'inference_instruct is only implemented for CosyVoice!'
|
||||||
if self.instruct is False:
|
if self.instruct is False:
|
||||||
raise ValueError('{} do not support instruct inference'.format(self.model_dir))
|
raise ValueError('{} do not support instruct inference'.format(self.model_dir))
|
||||||
@@ -132,8 +157,12 @@ class CosyVoice:
|
|||||||
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
||||||
yield model_output
|
yield model_output
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
cuda_stream.synchronize()
|
||||||
|
self.stream_pool.put(cuda_stream)
|
||||||
|
|
||||||
def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
|
def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
|
||||||
|
cuda_stream = self.stream_pool.get()
|
||||||
|
with torch.cuda.stream(cuda_stream):
|
||||||
model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
|
model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
|
for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
|
||||||
@@ -141,6 +170,8 @@ class CosyVoice:
|
|||||||
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
||||||
yield model_output
|
yield model_output
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
cuda_stream.synchronize()
|
||||||
|
self.stream_pool.put(cuda_stream)
|
||||||
|
|
||||||
|
|
||||||
class CosyVoice2(CosyVoice):
|
class CosyVoice2(CosyVoice):
|
||||||
@@ -178,15 +209,23 @@ class CosyVoice2(CosyVoice):
|
|||||||
if load_jit:
|
if load_jit:
|
||||||
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
|
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
|
||||||
if load_trt:
|
if load_trt:
|
||||||
|
self.estimator_count = configs['flow']['decoder']['estimator'].get('estimator_count', 1)
|
||||||
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
||||||
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
|
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
|
||||||
self.fp16)
|
self.fp16, self.estimator_count)
|
||||||
del configs
|
del configs
|
||||||
|
|
||||||
|
thread_count = 10
|
||||||
|
self.stream_pool = queue.Queue(maxsize=thread_count)
|
||||||
|
for _ in range(thread_count):
|
||||||
|
self.stream_pool.put(torch.cuda.Stream(self.device))
|
||||||
|
|
||||||
def inference_instruct(self, *args, **kwargs):
|
def inference_instruct(self, *args, **kwargs):
|
||||||
raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')
|
raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')
|
||||||
|
|
||||||
def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
|
def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
|
||||||
|
cuda_stream = self.stream_pool.get()
|
||||||
|
with torch.cuda.stream(cuda_stream):
|
||||||
assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
|
assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
|
||||||
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
||||||
model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
|
model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
|
||||||
@@ -197,8 +236,13 @@ class CosyVoice2(CosyVoice):
|
|||||||
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
||||||
yield model_output
|
yield model_output
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
cuda_stream.synchronize()
|
||||||
|
self.stream_pool.put(cuda_stream)
|
||||||
|
|
||||||
def inference_instruct2_by_spk_id(self, tts_text, instruct_text, spk_id, stream=False, speed=1.0, text_frontend=True):
|
def inference_instruct2_by_spk_id(self, tts_text, instruct_text, spk_id, stream=False, speed=1.0, text_frontend=True):
|
||||||
|
cuda_stream = self.stream_pool.get()
|
||||||
|
with torch.cuda.stream(cuda_stream):
|
||||||
|
assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
|
||||||
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
||||||
model_input = self.frontend.frontend_instruct2_by_spk_id(i, instruct_text, spk_id)
|
model_input = self.frontend.frontend_instruct2_by_spk_id(i, instruct_text, spk_id)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@@ -208,3 +252,5 @@ class CosyVoice2(CosyVoice):
|
|||||||
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
||||||
yield model_output
|
yield model_output
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
cuda_stream.synchronize()
|
||||||
|
self.stream_pool.put(cuda_stream)
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from contextlib import nullcontext
|
|||||||
import uuid
|
import uuid
|
||||||
from cosyvoice.utils.common import fade_in_out
|
from cosyvoice.utils.common import fade_in_out
|
||||||
from cosyvoice.utils.file_utils import convert_onnx_to_trt
|
from cosyvoice.utils.file_utils import convert_onnx_to_trt
|
||||||
|
from cosyvoice.flow.flow_matching import EstimatorWrapper
|
||||||
|
|
||||||
class CosyVoiceModel:
|
class CosyVoiceModel:
|
||||||
|
|
||||||
@@ -84,7 +84,7 @@ class CosyVoiceModel:
|
|||||||
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
|
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
|
||||||
self.flow.encoder = flow_encoder
|
self.flow.encoder = flow_encoder
|
||||||
|
|
||||||
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
|
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16, estimator_count=1):
|
||||||
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
|
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
|
||||||
if not os.path.exists(flow_decoder_estimator_model):
|
if not os.path.exists(flow_decoder_estimator_model):
|
||||||
convert_onnx_to_trt(flow_decoder_estimator_model, flow_decoder_onnx_model, fp16)
|
convert_onnx_to_trt(flow_decoder_estimator_model, flow_decoder_onnx_model, fp16)
|
||||||
@@ -96,7 +96,7 @@ class CosyVoiceModel:
|
|||||||
self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
|
self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
|
||||||
if self.flow.decoder.estimator_engine is None:
|
if self.flow.decoder.estimator_engine is None:
|
||||||
raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model))
|
raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model))
|
||||||
self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
|
self.flow.decoder.estimator = EstimatorWrapper(self.flow.decoder.estimator_engine, estimator_count=estimator_count)
|
||||||
|
|
||||||
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
|
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
|
||||||
with self.llm_context:
|
with self.llm_context:
|
||||||
@@ -319,6 +319,7 @@ class CosyVoice2Model(CosyVoiceModel):
|
|||||||
self.flow.encoder = flow_encoder
|
self.flow.encoder = flow_encoder
|
||||||
|
|
||||||
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
|
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
|
||||||
|
|
||||||
tts_mel, _ = self.flow.inference(token=token.to(self.device),
|
tts_mel, _ = self.flow.inference(token=token.to(self.device),
|
||||||
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
||||||
prompt_token=prompt_token.to(self.device),
|
prompt_token=prompt_token.to(self.device),
|
||||||
|
|||||||
@@ -15,7 +15,26 @@ import threading
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from matcha.models.components.flow_matching import BASECFM
|
from matcha.models.components.flow_matching import BASECFM
|
||||||
|
import queue
|
||||||
|
|
||||||
|
class EstimatorWrapper:
|
||||||
|
def __init__(self, estimator_engine, estimator_count=2,):
|
||||||
|
self.estimators = queue.Queue()
|
||||||
|
self.estimator_engine = estimator_engine
|
||||||
|
for _ in range(estimator_count):
|
||||||
|
estimator = estimator_engine.create_execution_context()
|
||||||
|
if estimator is not None:
|
||||||
|
self.estimators.put(estimator)
|
||||||
|
|
||||||
|
if self.estimators.empty():
|
||||||
|
raise Exception("No available estimator")
|
||||||
|
|
||||||
|
def acquire_estimator(self):
|
||||||
|
return self.estimators.get(), self.estimator_engine
|
||||||
|
|
||||||
|
def release_estimator(self, estimator):
|
||||||
|
self.estimators.put(estimator)
|
||||||
|
return
|
||||||
|
|
||||||
class ConditionalCFM(BASECFM):
|
class ConditionalCFM(BASECFM):
|
||||||
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
|
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
|
||||||
@@ -124,6 +143,34 @@ class ConditionalCFM(BASECFM):
|
|||||||
def forward_estimator(self, x, mask, mu, t, spks, cond):
|
def forward_estimator(self, x, mask, mu, t, spks, cond):
|
||||||
if isinstance(self.estimator, torch.nn.Module):
|
if isinstance(self.estimator, torch.nn.Module):
|
||||||
return self.estimator.forward(x, mask, mu, t, spks, cond)
|
return self.estimator.forward(x, mask, mu, t, spks, cond)
|
||||||
|
else:
|
||||||
|
if isinstance(self.estimator, EstimatorWrapper):
|
||||||
|
estimator, engine = self.estimator.acquire_estimator()
|
||||||
|
|
||||||
|
estimator.set_input_shape('x', (2, 80, x.size(2)))
|
||||||
|
estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
||||||
|
estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
||||||
|
estimator.set_input_shape('t', (2,))
|
||||||
|
estimator.set_input_shape('spks', (2, 80))
|
||||||
|
estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
||||||
|
|
||||||
|
data_ptrs = [x.contiguous().data_ptr(),
|
||||||
|
mask.contiguous().data_ptr(),
|
||||||
|
mu.contiguous().data_ptr(),
|
||||||
|
t.contiguous().data_ptr(),
|
||||||
|
spks.contiguous().data_ptr(),
|
||||||
|
cond.contiguous().data_ptr(),
|
||||||
|
x.data_ptr()]
|
||||||
|
|
||||||
|
for idx, data_ptr in enumerate(data_ptrs):
|
||||||
|
estimator.set_tensor_address(engine.get_tensor_name(idx), data_ptr)
|
||||||
|
|
||||||
|
# run trt engine
|
||||||
|
estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream)
|
||||||
|
|
||||||
|
torch.cuda.current_stream().synchronize()
|
||||||
|
self.estimator.release_estimator(estimator)
|
||||||
|
return x
|
||||||
else:
|
else:
|
||||||
with self.lock:
|
with self.lock:
|
||||||
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
|
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
|
||||||
|
|||||||
Reference in New Issue
Block a user