From 6a3e44242ad24f01ba64430d8f6ac5718442b0da Mon Sep 17 00:00:00 2001
From: "lyuxiang.lx"
Date: Wed, 10 Jul 2024 00:21:56 +0800
Subject: [PATCH] keep only embedding mean as spk embedding

---
 cosyvoice/dataset/processor.py | 2 +-
 tools/extract_embedding.py    | 2 ++
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/cosyvoice/dataset/processor.py b/cosyvoice/dataset/processor.py
index 9477d02..cb34a0c 100644
--- a/cosyvoice/dataset/processor.py
+++ b/cosyvoice/dataset/processor.py
@@ -167,7 +167,7 @@ def parse_embedding(data, normalize, mode='train'):
     """
     for sample in data:
         sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
-        sample['spk_embedding'] = torch.stack([torch.tensor(i, dtype=torch.float32) for i in sample['spk_embedding']], dim=0).mean(dim=0)
+        sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
         if normalize:
             sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
             sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
diff --git a/tools/extract_embedding.py b/tools/extract_embedding.py
index 02fa2f6..9c6f568 100755
--- a/tools/extract_embedding.py
+++ b/tools/extract_embedding.py
@@ -53,6 +53,8 @@ def main(args):
         if spk not in spk2embedding:
             spk2embedding[spk] = []
         spk2embedding[spk].append(embedding)
+    for k, v in spk2embedding.items():
+        spk2embedding[k] = torch.tensor(v).mean(dim=0, keepdim=True).tolist()
     torch.save(utt2embedding, '{}/utt2embedding.pt'.format(args.dir))
     torch.save(spk2embedding, '{}/spk2embedding.pt'.format(args.dir))
 
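
--
Reviewer sketch (not part of the patch): the change moves the per-speaker
averaging out of the training-time dataloader and into the offline
extraction step, so spk2embedding.pt now stores a single mean vector per
speaker rather than a list of per-utterance vectors. A minimal standalone
illustration of the new data flow, using hypothetical toy embeddings
(dim 4 instead of the real speaker-embedding dimension):

    import torch

    # Hypothetical per-utterance embeddings collected for one speaker,
    # as spk2embedding[spk] holds them before averaging.
    per_utt = [[1.0, 0.0, 0.0, 0.0],
               [0.0, 1.0, 0.0, 0.0]]

    # tools/extract_embedding.py now collapses the list once, offline,
    # exactly as the added loop does:
    spk_embedding = torch.tensor(per_utt).mean(dim=0, keepdim=True).tolist()
    print(spk_embedding)  # [[0.5, 0.5, 0.0, 0.0]] -- one vector per speaker

    # parse_embedding() in cosyvoice/dataset/processor.py then only needs
    # a plain tensor conversion per sample:
    spk_tensor = torch.tensor(spk_embedding, dtype=torch.float32)

Computing the mean once at extraction time avoids repeating the
stack-and-average work for every sample drawn during training.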