Mirror of https://github.com/HumanAIGC/lite-avatar.git, synced 2026-02-05 18:09:20 +08:00
add files
funasr_local/modules/data2vec/ema_module.py (new file, 132 lines)
@@ -0,0 +1,132 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Used for EMA tracking a given pytorch module. The user is responsible for calling step()
and setting the appropriate decay
"""

import copy
import logging

import torch


class EMAModule:
    """Exponential Moving Average of Fairseq Models"""

    def __init__(self, model, ema_decay=0.9999, ema_fp32=False, device=None, skip_keys=None):
        """
        @param model model to initialize the EMA with
        @param ema_decay decay factor of the exponential moving average
        @param ema_fp32 if True, keep an fp32 copy of the EMA parameters
        @param device If provided, copy EMA to this device (e.g. gpu).
            Otherwise EMA stays on the same device as the model.
        @param skip_keys parameter keys that are copied verbatim instead of decayed
        """

        self.decay = ema_decay
        self.ema_fp32 = ema_fp32
        self.model = copy.deepcopy(model)
        self.model.requires_grad_(False)
        self.skip_keys = skip_keys or set()
        self.fp32_params = {}

        if device is not None:
            logging.info(f"Copying EMA model to device {device}")
            self.model = self.model.to(device=device)

        if self.ema_fp32:
            self.build_fp32_params()

        self.update_freq_counter = 0

    def build_fp32_params(self, state_dict=None):
        """
        Store a copy of the EMA params in fp32.
        If a state dict is passed, the EMA params are copied from
        the provided state dict. Otherwise, they are copied from the
        current EMA model parameters.
        """
        if not self.ema_fp32:
            raise RuntimeError(
                "build_fp32_params should not be called if ema_fp32=False. "
                "Use ema_fp32=True if this is really intended."
            )

        if state_dict is None:
            state_dict = self.model.state_dict()

        def _to_float(t):
            return t.float() if torch.is_floating_point(t) else t

        for param_key in state_dict:
            if param_key in self.fp32_params:
                self.fp32_params[param_key].copy_(state_dict[param_key])
            else:
                self.fp32_params[param_key] = _to_float(state_dict[param_key])

    def restore(self, state_dict, build_fp32_params=False):
        """Load data from a model state dict into the EMA model"""
        self.model.load_state_dict(state_dict, strict=False)
        if build_fp32_params:
            self.build_fp32_params(state_dict)

    def set_decay(self, decay):
        self.decay = decay

    def get_decay(self):
        return self.decay

    def _step_internal(self, new_model):
        """One update of the EMA model based on new model weights"""
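        # Per-parameter update rule, applied in-place below:
        #   ema = decay * ema + (1 - decay) * param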
        decay = self.decay

        ema_state_dict = {}
        ema_params = (
            self.fp32_params if self.ema_fp32 else self.model.state_dict()
        )
        for key, param in new_model.state_dict().items():
            if isinstance(param, dict):
                continue
            try:
                ema_param = ema_params[key]
            except KeyError:
                ema_param = (
                    param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
                )

            if param.shape != ema_param.shape:
                raise ValueError(
                    "incompatible tensor shapes between model param and ema param "
                    + "{} vs. {}".format(param.shape, ema_param.shape)
                )

            if "version" in key:
                # Do not decay a model.version pytorch param
                continue

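            # Parameters listed in skip_keys and integer buffers such as
            # num_batches_tracked are copied from the new model verbatim
            # instead of being decayed.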
            if key in self.skip_keys or (
                "num_batches_tracked" in key and ema_param.dtype == torch.int64
            ):
                ema_param = param.to(dtype=ema_param.dtype).clone()
                ema_params[key].copy_(ema_param)
            else:
                ema_param.mul_(decay)
                ema_param.add_(param.to(dtype=ema_param.dtype), alpha=1 - decay)
            ema_state_dict[key] = ema_param
        self.restore(ema_state_dict, build_fp32_params=False)

    def step(self, new_model):
        self._step_internal(new_model)

    def reverse(self, model):
        """
        Load the model parameters from the EMA model.
        Useful for inference or fine-tuning from the EMA model.
        """
        d = self.model.state_dict()
        if "_ema" in d:
            del d["_ema"]

        model.load_state_dict(d, strict=False)
        return model
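A minimal usage sketch (not part of the commit), assuming a typical PyTorch training loop; MyModel, loader, and the optimizer settings are hypothetical stand-ins:

    import copy
    import torch

    model = MyModel()                       # hypothetical model returning a scalar loss
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    ema = EMAModule(model, ema_decay=0.999, ema_fp32=True)

    for batch in loader:                    # hypothetical data loader
        loss = model(batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        ema.step(model)                     # one EMA update per optimizer step

    # Copy the smoothed weights into a separate model for evaluation.
    eval_model = ema.reverse(copy.deepcopy(model))

With ema_fp32=True the averaged weights are accumulated in the float32 copies kept by build_fp32_params, so the running average stays precise even if the training model itself runs in fp16 or bf16.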