diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py index 40062b3..f4e0ace 100644 --- a/cosyvoice/flow/flow_matching.py +++ b/cosyvoice/flow/flow_matching.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import threading import torch import torch.nn.functional as F from matcha.models.components.flow_matching import BASECFM @@ -30,6 +31,7 @@ class ConditionalCFM(BASECFM): in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0) # Just change the architecture of the estimator here self.estimator = estimator + self.lock = threading.Lock() @torch.inference_mode() def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)): @@ -123,20 +125,21 @@ class ConditionalCFM(BASECFM): if isinstance(self.estimator, torch.nn.Module): return self.estimator.forward(x, mask, mu, t, spks, cond) else: - self.estimator.set_input_shape('x', (2, 80, x.size(2))) - self.estimator.set_input_shape('mask', (2, 1, x.size(2))) - self.estimator.set_input_shape('mu', (2, 80, x.size(2))) - self.estimator.set_input_shape('t', (2,)) - self.estimator.set_input_shape('spks', (2, 80)) - self.estimator.set_input_shape('cond', (2, 80, x.size(2))) - # run trt engine - self.estimator.execute_v2([x.contiguous().data_ptr(), - mask.contiguous().data_ptr(), - mu.contiguous().data_ptr(), - t.contiguous().data_ptr(), - spks.contiguous().data_ptr(), - cond.contiguous().data_ptr(), - x.data_ptr()]) + with self.lock: + self.estimator.set_input_shape('x', (2, 80, x.size(2))) + self.estimator.set_input_shape('mask', (2, 1, x.size(2))) + self.estimator.set_input_shape('mu', (2, 80, x.size(2))) + self.estimator.set_input_shape('t', (2,)) + self.estimator.set_input_shape('spks', (2, 80)) + self.estimator.set_input_shape('cond', (2, 80, x.size(2))) + # run trt engine + self.estimator.execute_v2([x.contiguous().data_ptr(), + mask.contiguous().data_ptr(), + mu.contiguous().data_ptr(), + t.contiguous().data_ptr(), + spks.contiguous().data_ptr(), + cond.contiguous().data_ptr(), + x.data_ptr()]) return x def compute_loss(self, x1, mask, mu, spks=None, cond=None): diff --git a/requirements.txt b/requirements.txt index e02452b..70f5e5c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ conformer==0.3.2 deepspeed==0.14.2; sys_platform == 'linux' diffusers==0.27.2 gdown==5.1.0 -gradio==4.32.2 +gradio==5.4.0 grpcio==1.57.0 grpcio-tools==1.57.0 huggingface-hub==0.25.2