import os
import pathlib
import unittest

import jiwer
import numpy as np
import torch
import torchaudio
import tqdm
from whisper.normalizers import EnglishTextNormalizer

from examples.whisper import init_whisper, transcribe_waveform

class TestWhisperLibriSpeech(unittest.TestCase):
  # reference WERs determined by running https://github.com/openai/whisper/blob/main/notebooks/LibriSpeech.ipynb
  # the values should be consistent with D.1.1 of the paper: https://cdn.openai.com/papers/whisper.pdf#page=22
  # tinygrad WERs do not perfectly match the reference due to what appear to be precision differences vs torch
  def test_en_tiny(self):
    run_evaluation("tiny.en", 0.056629001883239174, 0.05655609406528749)

  def test_tiny(self):
    run_evaluation("tiny", 0.0771121409407306, 0.07558413638335187)

  def test_en_base(self):
    run_evaluation("base.en", 0.041412520064205455, 0.04271408904897505)

  def test_en_small(self):
    run_evaluation("small.en", 0.03369011117172363, 0.030531615969223228)
def run_evaluation(model_name, tinygrad_expected_wer, reference_wer):
  dataset = LibriSpeech()
  batch_size = 16
  loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
  model, enc = init_whisper(model_name, batch_size=batch_size)

  hypotheses = []
  references = []
  for audio, texts in tqdm.tqdm(loader):
    transcriptions = transcribe_waveform(model, enc, audio.numpy(), truncate=True)
    hypotheses.extend(transcriptions)
    references.extend(texts)

  # normalize both sides so casing and punctuation differences are not counted as errors
  normalizer = EnglishTextNormalizer()
  normalized_hypotheses = [normalizer(text) for text in hypotheses]
  normalized_references = [normalizer(text) for text in references]
  # jiwer.wer expects the reference transcripts first, then the hypotheses
  wer = jiwer.wer(normalized_references, normalized_hypotheses)
  np.testing.assert_almost_equal(wer, tinygrad_expected_wer)
  print(f'tinygrad WER {wer} vs reference WER {reference_wer}')
  del model, enc

class LibriSpeech(torch.utils.data.Dataset):
  def __init__(self):
    folder = pathlib.Path(__file__).parent.parent.parent / "extra" / "datasets" / "librispeech"
    if not os.path.exists(folder):
      os.makedirs(folder)
    self.dataset = torchaudio.datasets.LIBRISPEECH(
      root=folder,
      url="test-clean",
      download=True,
    )

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, item):
    audio, sample_rate, text, _, _, _ = self.dataset[item]
    assert sample_rate == 16000
    return pad_or_trim_tensor(audio[0]), text
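
# whisper consumes fixed 30-second windows; at LibriSpeech's 16 kHz sample rate
# that is 16000 * 30 = 480000 samples, hence the default target_len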
def pad_or_trim_tensor(tensor, target_len=480000):
  curr_len = len(tensor)
  if curr_len == target_len:
    return tensor
  elif curr_len < target_len:
    return torch.cat((tensor, torch.zeros(target_len - curr_len)))
  else:
    return tensor[:target_len]
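
# a minimal standalone sketch of the normalize-then-score step used in run_evaluation
# (assumes the jiwer and openai-whisper packages are installed; the strings are
# illustrative, not taken from the test data):
#
#   normalizer = EnglishTextNormalizer()
#   ref = normalizer("He's here!")  # -> "he is here" (lowercased, punctuation stripped, contraction expanded)
#   hyp = normalizer("he is near")  # one substituted word out of three
#   print(jiwer.wer([ref], [hyp]))  # -> 0.333...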

if __name__ == '__main__':
  unittest.main()