mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
keep only embedding mean as spk embedding
This commit is contained in:
@@ -167,7 +167,7 @@ def parse_embedding(data, normalize, mode='train'):
|
|||||||
"""
|
"""
|
||||||
for sample in data:
|
for sample in data:
|
||||||
sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
|
sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
|
||||||
sample['spk_embedding'] = torch.stack([torch.tensor(i, dtype=torch.float32) for i in sample['spk_embedding']], dim=0).mean(dim=0)
|
sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
|
||||||
if normalize:
|
if normalize:
|
||||||
sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
|
sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
|
||||||
sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
|
sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
|
||||||
|
|||||||
@@ -53,6 +53,8 @@ def main(args):
|
|||||||
if spk not in spk2embedding:
|
if spk not in spk2embedding:
|
||||||
spk2embedding[spk] = []
|
spk2embedding[spk] = []
|
||||||
spk2embedding[spk].append(embedding)
|
spk2embedding[spk].append(embedding)
|
||||||
|
for k, v in spk2embedding.items():
|
||||||
|
spk2embedding[k] = torch.tensor(v).mean(dim=0, keepdim=True).tolist()
|
||||||
|
|
||||||
torch.save(utt2embedding, '{}/utt2embedding.pt'.format(args.dir))
|
torch.save(utt2embedding, '{}/utt2embedding.pt'.format(args.dir))
|
||||||
torch.save(spk2embedding, '{}/spk2embedding.pt'.format(args.dir))
|
torch.save(spk2embedding, '{}/spk2embedding.pt'.format(args.dir))
|
||||||
|
|||||||
Reference in New Issue
Block a user