You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
62 lines
2.0 KiB
62 lines
2.0 KiB
1 month ago
|
import re
|
||
|
import string
|
||
|
from collections import Counter
|
||
|
|
||
|
def levenshtein(a, b):
    """Return the Levenshtein (edit) distance between two sequences.

    Works on any sized, indexable sequences whose items can be compared
    with ``!=`` (e.g. two strings, or two lists of words). Insertions,
    deletions and substitutions each cost 1.

    Only two rows of the dynamic-programming table are kept at a time, and
    the shorter sequence is laid out along the row.
    """
    # Put the shorter sequence along the row; ties keep the original order.
    short, long_ = (a, b) if len(a) <= len(b) else (b, a)
    width = len(short)

    # Row 0: distance from the empty prefix of `long_` to each prefix of `short`.
    row = list(range(width + 1))
    for i, long_item in enumerate(long_, start=1):
        prev_row = row
        row = [i] + [0] * width
        for j, short_item in enumerate(short, start=1):
            cost = 1 if short_item != long_item else 0
            insertion = prev_row[j] + 1
            deletion = row[j - 1] + 1
            substitution = prev_row[j - 1] + cost
            row[j] = min(insertion, deletion, substitution)

    return row[width]
|
||
|
|
||
|
def word_error_rate(x, y):
    """Compute the corpus-level word error rate of hypotheses vs. references.

    Args:
        x: iterable of hypothesis strings.
        y: iterable of reference strings, paired positionally with ``x``.
            Pairing uses ``zip``, so it stops at the shorter of the two.

    Returns:
        A tuple ``(wer, scores, words)``:
        ``scores`` is the summed word-level edit distance (as a float),
        ``words`` is the total number of whitespace-separated reference
        words, and ``wer`` is ``scores / words``. When there are no
        reference words at all, ``wer`` is ``float("inf")`` rather than
        raising ``ZeroDivisionError``.
    """
    scores = words = 0
    for hypothesis, reference in zip(x, y):
        reference_words = reference.split()
        words += len(reference_words)
        scores += levenshtein(hypothesis.split(), reference_words)

    scores = float(scores)
    # Guard the division: an empty corpus (or only empty references) has no
    # reference words to normalise by.
    wer = scores / words if words else float("inf")
    return wer, scores, words
|
||
|
|
||
|
def one_hot(x):
    """One-hot encode ``x`` with 3 classes and move the class axis to index 1.

    NOTE(review): assumes ``x`` is a tensor-like object exposing
    ``one_hot``, ``squeeze`` and ``permute``, and that the result is
    5-dimensional after the squeeze (the permute lists five axes) — confirm
    against callers.
    """
    encoded = x.one_hot(3)
    squeezed = encoded.squeeze(1)
    # Bring the last axis (index 4) forward to position 1.
    return squeezed.permute(0, 4, 1, 2, 3)
|
||
|
|
||
|
def dice_score(prediction, target, channel_axis=1, smooth_nr=1e-6, smooth_dr=1e-6, argmax=True, to_one_hot_x=True):
    """Compute a smoothed Dice score between ``prediction`` and ``target``.

    Args:
        prediction: tensor-like model output.
        target: tensor-like labels; always passed through ``one_hot``.
        channel_axis: NOTE(review): currently overwritten with 1 inside the
            function, so the passed value has no effect — confirm intent.
        smooth_nr: constant added to the numerator.
        smooth_dr: constant added to the denominator.
        argmax: if true, take ``argmax`` of ``prediction`` over the channel
            axis; otherwise apply ``softmax`` over it.
        to_one_hot_x: if true, pass ``prediction`` through ``one_hot``.

    Returns:
        ``(2 * intersection + smooth_nr) / (target_sum + prediction_sum +
        smooth_dr)``, where each sum reduces over every axis from 2 onward
        (so axes 0 and 1 are kept).
    """
    # NOTE(review): this discards whatever the caller passed as channel_axis.
    channel_axis = 1
    # Reduction axes are derived from the incoming prediction's rank, before
    # argmax/softmax and one-hot encoding are applied.
    reduce_dims = tuple(range(2, len(prediction.shape)))

    if argmax:
        prediction = prediction.argmax(axis=channel_axis)
    else:
        prediction = prediction.softmax(axis=channel_axis)
    if to_one_hot_x:
        prediction = one_hot(prediction)
    target = one_hot(target)

    # Drop index 0 along axis 1 (presumably the background class — confirm).
    prediction = prediction[:, 1:]
    target = target[:, 1:]
    assert prediction.shape == target.shape, f"prediction ({prediction.shape}) and target ({target.shape}) shapes do not match"

    overlap = (prediction * target).sum(axis=reduce_dims)
    denominator = target.sum(axis=reduce_dims) + prediction.sum(axis=reduce_dims) + smooth_dr
    return (2.0 * overlap + smooth_nr) / denominator
|
||
|
|
||
|
def normalize_string(s):
    """Normalise text for token-level comparison.

    Steps, in order: lowercase, delete every character found in
    ``string.punctuation``, replace the standalone words "a", "an" and
    "the" with a space, then collapse all whitespace runs to single spaces
    and trim the ends.
    """
    lowered = s.lower()
    # Deleting via a translation table is equivalent to filtering characters
    # one by one against string.punctuation.
    without_punctuation = lowered.translate(str.maketrans("", "", string.punctuation))
    without_articles = re.sub(r'\b(a|an|the)\b', ' ', without_punctuation)
    tokens = without_articles.split()
    return " ".join(tokens)
|
||
|
|
||
|
def f1_score(x, y):
    """Return the token-overlap F1 between two strings.

    Both strings are normalised with ``normalize_string`` and split on
    whitespace. The overlap is the multiset intersection of the two token
    lists; precision is the overlap relative to the tokens of ``x`` and
    recall relative to the tokens of ``y``. Returns ``0.0`` when no tokens
    are shared (which includes either side being empty after
    normalisation).
    """
    x_tokens = normalize_string(x).split()
    y_tokens = normalize_string(y).split()

    # Counter & Counter keeps each token with the minimum of its two counts.
    shared = sum((Counter(x_tokens) & Counter(y_tokens)).values())
    if shared == 0:
        return 0.0

    precision = shared / len(x_tokens)
    recall = shared / len(y_tokens)
    return 2 * precision * recall / (precision + recall)
|