mirror of
https://github.com/HumanAIGC/lite-avatar.git
synced 2026-02-05 18:09:20 +08:00
add files
This commit is contained in:
62
funasr_local/modules/streaming_utils/load_fr_tf.py
Normal file
62
funasr_local/modules/streaming_utils/load_fr_tf.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import numpy as np
|
||||
np.set_printoptions(threshold=np.inf)
|
||||
import logging
|
||||
|
||||
def load_ckpt(checkpoint_path):
|
||||
import tensorflow as tf
|
||||
if tf.__version__.startswith('2'):
|
||||
import tensorflow.compat.v1 as tf
|
||||
tf.disable_v2_behavior()
|
||||
reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
|
||||
else:
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
|
||||
var_to_shape_map = reader.get_variable_to_shape_map()
|
||||
|
||||
var_dict = dict()
|
||||
for var_name in sorted(var_to_shape_map):
|
||||
if "Adam" in var_name:
|
||||
continue
|
||||
tensor = reader.get_tensor(var_name)
|
||||
# print("in ckpt: {}, {}".format(var_name, tensor.shape))
|
||||
# print(tensor)
|
||||
var_dict[var_name] = tensor
|
||||
|
||||
return var_dict
|
||||
|
||||
|
||||
|
||||
def load_tf_pb_dict(pb_model):
|
||||
import tensorflow as tf
|
||||
if tf.__version__.startswith('2'):
|
||||
import tensorflow.compat.v1 as tf
|
||||
tf.disable_v2_behavior()
|
||||
# import tensorflow_addons as tfa
|
||||
# from tensorflow_addons.seq2seq.python.ops import beam_search_ops
|
||||
else:
|
||||
from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
|
||||
from tensorflow.python.ops import lookup_ops as lookup
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.platform import gfile
|
||||
|
||||
sess = tf.Session()
|
||||
with gfile.FastGFile(pb_model, 'rb') as f:
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def.ParseFromString(f.read())
|
||||
sess.graph.as_default()
|
||||
tf.import_graph_def(graph_def, name='')
|
||||
|
||||
var_dict = dict()
|
||||
for node in sess.graph_def.node:
|
||||
if node.op == 'Const':
|
||||
value = tensor_util.MakeNdarray(node.attr['value'].tensor)
|
||||
if len(value.shape) >= 1:
|
||||
var_dict[node.name] = value
|
||||
return var_dict
|
||||
|
||||
def load_tf_dict(pb_model):
|
||||
if "model.ckpt-" in pb_model:
|
||||
var_dict = load_ckpt(pb_model)
|
||||
else:
|
||||
var_dict = load_tf_pb_dict(pb_model)
|
||||
return var_dict
|
||||
Reference in New Issue
Block a user