mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 01:49:25 +08:00
@@ -11,6 +11,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import threading
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from matcha.models.components.flow_matching import BASECFM
|
from matcha.models.components.flow_matching import BASECFM
|
||||||
@@ -30,6 +31,7 @@ class ConditionalCFM(BASECFM):
|
|||||||
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
|
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
|
||||||
# Just change the architecture of the estimator here
|
# Just change the architecture of the estimator here
|
||||||
self.estimator = estimator
|
self.estimator = estimator
|
||||||
|
self.lock = threading.Lock()
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
|
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
|
||||||
@@ -123,20 +125,21 @@ class ConditionalCFM(BASECFM):
|
|||||||
if isinstance(self.estimator, torch.nn.Module):
|
if isinstance(self.estimator, torch.nn.Module):
|
||||||
return self.estimator.forward(x, mask, mu, t, spks, cond)
|
return self.estimator.forward(x, mask, mu, t, spks, cond)
|
||||||
else:
|
else:
|
||||||
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
|
with self.lock:
|
||||||
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
|
||||||
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
||||||
self.estimator.set_input_shape('t', (2,))
|
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
||||||
self.estimator.set_input_shape('spks', (2, 80))
|
self.estimator.set_input_shape('t', (2,))
|
||||||
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
self.estimator.set_input_shape('spks', (2, 80))
|
||||||
# run trt engine
|
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
||||||
self.estimator.execute_v2([x.contiguous().data_ptr(),
|
# run trt engine
|
||||||
mask.contiguous().data_ptr(),
|
self.estimator.execute_v2([x.contiguous().data_ptr(),
|
||||||
mu.contiguous().data_ptr(),
|
mask.contiguous().data_ptr(),
|
||||||
t.contiguous().data_ptr(),
|
mu.contiguous().data_ptr(),
|
||||||
spks.contiguous().data_ptr(),
|
t.contiguous().data_ptr(),
|
||||||
cond.contiguous().data_ptr(),
|
spks.contiguous().data_ptr(),
|
||||||
x.data_ptr()])
|
cond.contiguous().data_ptr(),
|
||||||
|
x.data_ptr()])
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ conformer==0.3.2
|
|||||||
deepspeed==0.14.2; sys_platform == 'linux'
|
deepspeed==0.14.2; sys_platform == 'linux'
|
||||||
diffusers==0.27.2
|
diffusers==0.27.2
|
||||||
gdown==5.1.0
|
gdown==5.1.0
|
||||||
gradio==4.32.2
|
gradio==5.4.0
|
||||||
grpcio==1.57.0
|
grpcio==1.57.0
|
||||||
grpcio-tools==1.57.0
|
grpcio-tools==1.57.0
|
||||||
huggingface-hub==0.25.2
|
huggingface-hub==0.25.2
|
||||||
|
|||||||
Reference in New Issue
Block a user