mirror of
https://github.com/HumanAIGC/lite-avatar.git
synced 2026-02-05 18:09:20 +08:00
add files
This commit is contained in:
60
funasr_local/runtime/triton_gpu/client/utils.py
Normal file
60
funasr_local/runtime/triton_gpu/client/utils.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _levenshtein_distance(ref, hyp):
    """Return the Levenshtein (edit) distance between two sequences.

    The distance is the minimum number of single-element edits
    (substitutions, insertions or deletions) needed to turn one sequence
    into the other. The elements may be characters of a string or items of
    a list (e.g. words of a sentence); they are only compared with ``==``.

    Args:
        ref: Reference sequence (must support ``len`` and integer indexing).
        hyp: Hypothesis sequence (same requirements as ``ref``).

    Returns:
        The edit distance. The early-exit paths return a plain ``int``;
        the dynamic-programming path returns a NumPy ``int32`` scalar.
    """
    ref_len = len(ref)
    hyp_len = len(hyp)

    # Trivial cases that need no table.
    if ref == hyp:
        return 0
    if ref_len == 0:
        return hyp_len
    if hyp_len == 0:
        return ref_len

    # The distance is symmetric, so let the shorter sequence define the row
    # width; the table then needs only O(min(len(ref), len(hyp))) space.
    if ref_len < hyp_len:
        long_seq, short_seq = hyp, ref
        rows, cols = hyp_len, ref_len
    else:
        long_seq, short_seq = ref, hyp
        rows, cols = ref_len, hyp_len

    # Two rolling rows: ``prev_row`` holds the results for the previous
    # element of ``long_seq``, ``cur_row`` is being filled in.
    table = np.zeros((2, cols + 1), dtype=np.int32)
    table[0] = np.arange(cols + 1)
    prev_row, cur_row = table[0], table[1]

    for i in range(1, rows + 1):
        cur_row[0] = i
        long_item = long_seq[i - 1]
        for j in range(1, cols + 1):
            if long_item == short_seq[j - 1]:
                # Matching elements cost nothing: carry the diagonal over.
                cur_row[j] = prev_row[j - 1]
            else:
                # Cheapest of substitution, insertion and deletion.
                cur_row[j] = min(prev_row[j - 1],
                                 cur_row[j - 1],
                                 prev_row[j]) + 1
        # The row just filled becomes the "previous" row of the next pass.
        prev_row, cur_row = cur_row, prev_row

    # After the final swap the last computed row is ``prev_row``.
    return prev_row[cols]
|
||||
|
||||
|
||||
def cal_cer(references, predictions):
    """Compute the corpus-level character error rate (CER).

    Every reference/prediction pair is split into individual characters,
    the edit distance between the two character lists is accumulated, and
    the total is divided by the total number of reference characters.

    Args:
        references: Iterable of reference transcripts.
        predictions: Iterable of predicted transcripts, paired with
            ``references`` by position. Pairing uses ``zip``, so surplus
            items in the longer iterable are ignored.

    Returns:
        float: Total edit distance divided by the total reference length.
        When there are no reference characters at all the ratio is
        undefined; ``0.0`` is returned if there were also no errors
        (e.g. empty inputs) and ``inf`` otherwise, instead of raising
        ``ZeroDivisionError``.
    """
    total_errors = 0
    total_length = 0
    for ref, pred in zip(references, predictions):
        ref_chars = list(ref)
        hyp_chars = list(pred)
        # The distance helper can return a NumPy int32 scalar; converting
        # to a Python int keeps the running total from being held in a
        # fixed-width integer type.
        total_errors += int(_levenshtein_distance(ref_chars, hyp_chars))
        total_length += len(ref_chars)

    if total_length == 0:
        # Nothing to normalise by: avoid dividing by zero.
        return 0.0 if total_errors == 0 else float("inf")
    return float(total_errors) / total_length
|
||||
Reference in New Issue
Block a user