# thanks to https://github.com/openai/whisper for a good chunk of MIT licensed code
import sys, base64, multiprocessing, itertools, collections
from typing import Optional, Union, Literal, List
from tinygrad import Tensor, TinyJit, Variable, nn
from tinygrad.nn.state import torch_load, load_state_dict
from tinygrad.helpers import getenv, DEBUG, fetch
import numpy as np
import librosa
class MultiHeadAttention:
  def __init__(self, n_state, n_head, kv_caching: Literal['cross', 'self'] = None, max_self_attn_cache_len=None):
    self.n_head = n_head
    self.query = nn.Linear(n_state, n_state)
    self.key = nn.Linear(n_state, n_state, bias=False)
    self.value = nn.Linear(n_state, n_state)
    self.out = nn.Linear(n_state, n_state)
    self.kv_caching = kv_caching
    self.max_self_attn_cache_len = max_self_attn_cache_len
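  # Caching notes (descriptive of the logic below):
  # - 'cross': k/v are computed once from the encoder output (xa) and reused for every
  #   decoding step; calling with xa=None reads them back from the cache.
  # - 'self': a fixed-size (batch, max_self_attn_cache_len, n_state) zero buffer holds
  #   past k/v. New entries are appended, then the buffer is padded back out to its
  #   full length so tensor shapes stay constant across TinyJit'ed steps.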
  def __call__(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, len: Union[Variable, int] = None):
    if self.kv_caching == 'cross':
      if xa is not None:
        k, v = self.key(xa), self.value(xa)
        if not hasattr(self, 'cache_k'):
          self.cache_k, self.cache_v = k, v
        else:
          self.cache_k.assign(k).realize()
          self.cache_v.assign(v).realize()
      else:
        k, v = self.cache_k, self.cache_v
    else:
      k, v = self.key(x), self.value(x)
      if self.kv_caching == 'self':
        if not hasattr(self, 'cache_k'):
          self.cache_k = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2])
          self.cache_v = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2])
        k = self.cache_k.shrink((None, (0, len), None)).cat(k, dim=1)
        v = self.cache_v.shrink((None, (0, len), None)).cat(v, dim=1)
        padding = self.max_self_attn_cache_len - len - x.shape[1]
        self.cache_k.assign(k.pad((None, (0, padding), None)).contiguous()).realize()
        self.cache_v.assign(v.pad((None, (0, padding), None)).contiguous()).realize()

    q = self.query(x)
    n_ctx = q.shape[1]
    assert q.shape[-1] == k.shape[-1] == v.shape[-1]
    head_dim = q.shape[-1] // self.n_head
    q = q.reshape(*q.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
    k = k.reshape(*k.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
    v = v.reshape(*v.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
    attn = Tensor.scaled_dot_product_attention(q, k, v, mask[:n_ctx, :n_ctx] if mask is not None else None)
    wv = attn.permute(0, 2, 1, 3).flatten(start_dim=2)
    return self.out(wv)
class ResidualAttentionBlock:
  def __init__(self, n_state, n_head, is_decoder_block=False, max_self_attn_cache_len=None):
    self.attn = MultiHeadAttention(n_state, n_head, kv_caching='self' if is_decoder_block else None, max_self_attn_cache_len=max_self_attn_cache_len)
    self.attn_ln = nn.LayerNorm(n_state)
    self.cross_attn = MultiHeadAttention(n_state, n_head, kv_caching='cross') if is_decoder_block else None
    self.cross_attn_ln = nn.LayerNorm(n_state) if is_decoder_block else None
    self.mlp = [nn.Linear(n_state, n_state * 4), Tensor.gelu, nn.Linear(n_state * 4, n_state)]
    self.mlp_ln = nn.LayerNorm(n_state)

  def __call__(self, x, xa=None, mask=None, len: Union[Variable, int] = None):
    x = x + self.attn(self.attn_ln(x), mask=mask, len=len)
    if self.cross_attn: x = x + self.cross_attn(self.cross_attn_ln(x), xa)
    x = x + self.mlp_ln(x).sequential(self.mlp)
    return x.realize()
class AudioEncoder:
  def __init__(self, n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer, **_):
    self.conv1 = nn.Conv1d(n_mels, n_audio_state, kernel_size=3, padding=1)
    self.conv2 = nn.Conv1d(n_audio_state, n_audio_state, kernel_size=3, stride=2, padding=1)
    self.blocks = [ResidualAttentionBlock(n_audio_state, n_audio_head) for _ in range(n_audio_layer)]
    self.ln_post = nn.LayerNorm(n_audio_state)
    self.positional_embedding = Tensor.empty(n_audio_ctx, n_audio_state)
    self.encode = TinyJit(self.__call__)
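  # Shape note: a 30s segment arrives as (batch, n_mels, 3000) mel frames; conv2's
  # stride of 2 halves the time axis, so the encoder emits (batch, 1500, n_audio_state),
  # matching the n_audio_ctx positional embeddings.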
  def __call__(self, x):
    x = self.conv1(x).gelu()
    x = self.conv2(x).gelu()
    x = x.permute(0, 2, 1)
    x = x + self.positional_embedding[:x.shape[1]]
    x = x.sequential(self.blocks)
    x = self.ln_post(x)
    return x.realize()
class TextDecoder:
  def __init__(self, n_vocab, n_text_ctx, n_text_state, n_text_head, n_text_layer, **_):
    self.max_tokens_to_sample = n_text_ctx // 2
    self.max_self_attn_cache_len = n_text_ctx
    self.token_embedding = nn.Embedding(n_vocab, n_text_state)
    self.positional_embedding = Tensor.empty(n_text_ctx, n_text_state)
    self.blocks = [ResidualAttentionBlock(n_text_state, n_text_head, is_decoder_block=True, max_self_attn_cache_len=self.max_self_attn_cache_len) for _ in range(n_text_layer)]
    self.ln = nn.LayerNorm(n_text_state)
    self.mask = Tensor.full((n_text_ctx, n_text_ctx), -np.inf).triu(1).realize()
    self.getjitted = collections.defaultdict(lambda: TinyJit(self.forward))
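  # One TinyJit per input shape: the prompt pass (seqlen > 1) and the incremental
  # single-token passes each get their own jitted graph, while the bound Variable
  # lets the self-attn cache length vary within a graph without retracing.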
  def __call__(self, x: Tensor, pos: int, encoded_audio: Tensor):
    pos = Variable("self_attn_cache_len", 1, self.max_self_attn_cache_len - 1).bind(pos) if pos else 0
    return self.getjitted[x.shape](x, pos, encoded_audio)

  def forward(self, x: Tensor, pos: Union[Variable, Literal[0]], encoded_audio: Tensor):
    seqlen = x.shape[-1]
    # positional_embedding is 2-D (n_text_ctx, n_text_state), so shrink takes two dims
    x = self.token_embedding(x) + self.positional_embedding.shrink(((pos, pos + seqlen), None))
    for block in self.blocks: x = block(x, xa=encoded_audio, mask=self.mask, len=pos)
    return self.output_tok(x)

  def output_tok(self, x):
    return (self.ln(x) @ self.token_embedding.weight.T).realize()
class Whisper:
  def __init__(self, dims, batch_size=1):
    self.encoder = AudioEncoder(**dims)
    self.decoder = TextDecoder(**dims)
    self.is_multilingual = dims["n_vocab"] == 51865
    self.batch_size = batch_size

RATE = 16000
SEGMENT_SECONDS = 30
SAMPLES_PER_SEGMENT = RATE * SEGMENT_SECONDS  # 480000
N_FFT = 400
HOP_LENGTH = 160
N_MELS = 80
FRAMES_PER_SEGMENT = SAMPLES_PER_SEGMENT // HOP_LENGTH  # 3000
def prep_audio(waveforms: List[np.ndarray], batch_size: int, truncate=False) -> np.ndarray:
  """
  :param waveforms: A list of possibly variable length 16000 Hz audio samples
  :param batch_size: The batch_size associated with the Whisper model being used to transcribe the audio.
    Used to prevent JIT mismatch errors since the encoder does not accept symbolic shapes
  :param truncate: If true, truncates (or pads) audio to exactly 30s for a single encoder pass
  :return: mel spectrogram of the given waveforms
  """
  def pad_or_trim(arr, target_len):
    curr_len = len(arr)
    if curr_len == target_len:
      return arr
    elif curr_len < target_len:
      return np.pad(arr, (0, target_len - curr_len), 'constant')
    else:
      return arr[:target_len]

  max_len = SAMPLES_PER_SEGMENT if truncate else max(len(wav) for wav in waveforms)
  if (r := max_len % SAMPLES_PER_SEGMENT) > 0: max_len += SAMPLES_PER_SEGMENT - r
  waveforms = np.array(list(map(lambda w: pad_or_trim(w, max_len), waveforms)))
  assert waveforms.shape[0] <= batch_size
  if waveforms.shape[0] < batch_size:
    # we could have a symbolic batch_size dim instead of manually padding here if conv/layernorm supported symbolic shapes
    waveforms = np.pad(waveforms, pad_width=((0, batch_size - waveforms.shape[0]), (0, 0)))

  stft = librosa.stft(waveforms, n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann', dtype=np.csingle)
  magnitudes = np.absolute(stft[..., :-1]) ** 2
  mel_spec = librosa.filters.mel(sr=RATE, n_fft=N_FFT, n_mels=N_MELS) @ magnitudes

  log_spec = np.log10(np.clip(mel_spec, 1e-10, None))
  log_spec = np.maximum(log_spec, log_spec.max((1, 2), keepdims=True) - 8.0)
  log_spec = (log_spec + 4.0) / 4.0
  return log_spec
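# A quick sketch of the shapes flowing through prep_audio (illustrative values):
#   waveforms = [np.zeros(480000, dtype=np.float32)]  # one 30s clip at 16 kHz
#   log_spec = prep_audio(waveforms, batch_size=1)    # -> (1, 80, 3000)
# librosa's centered stft yields 3001 frames for 480000 samples; the [..., :-1]
# slice above drops the last so each segment is exactly FRAMES_PER_SEGMENT wide.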
LANGUAGES = {
  "en": "english", "zh": "chinese", "de": "german", "es": "spanish", "ru": "russian", "ko": "korean", "fr": "french", "ja": "japanese", "pt": "portuguese", "tr": "turkish",
  "pl": "polish", "ca": "catalan", "nl": "dutch", "ar": "arabic", "sv": "swedish", "it": "italian", "id": "indonesian", "hi": "hindi", "fi": "finnish", "vi": "vietnamese",
  "he": "hebrew", "uk": "ukrainian", "el": "greek", "ms": "malay", "cs": "czech", "ro": "romanian", "da": "danish", "hu": "hungarian", "ta": "tamil", "no": "norwegian",
  "th": "thai", "ur": "urdu", "hr": "croatian", "bg": "bulgarian", "lt": "lithuanian", "la": "latin", "mi": "maori", "ml": "malayalam", "cy": "welsh", "sk": "slovak", "te": "telugu",
  "fa": "persian", "lv": "latvian", "bn": "bengali", "sr": "serbian", "az": "azerbaijani", "sl": "slovenian", "kn": "kannada", "et": "estonian", "mk": "macedonian",
  "br": "breton", "eu": "basque", "is": "icelandic", "hy": "armenian", "ne": "nepali", "mn": "mongolian", "bs": "bosnian", "kk": "kazakh", "sq": "albanian", "sw": "swahili",
  "gl": "galician", "mr": "marathi", "pa": "punjabi", "si": "sinhala", "km": "khmer", "sn": "shona", "yo": "yoruba", "so": "somali", "af": "afrikaans", "oc": "occitan", "ka": "georgian",
  "be": "belarusian", "tg": "tajik", "sd": "sindhi", "gu": "gujarati", "am": "amharic", "yi": "yiddish", "lo": "lao", "uz": "uzbek", "fo": "faroese", "ht": "haitian creole",
  "ps": "pashto", "tk": "turkmen", "nn": "nynorsk", "mt": "maltese", "sa": "sanskrit", "lb": "luxembourgish", "my": "myanmar", "bo": "tibetan", "tl": "tagalog", "mg": "malagasy",
  "as": "assamese", "tt": "tatar", "haw": "hawaiian", "ln": "lingala", "ha": "hausa", "ba": "bashkir", "jw": "javanese", "su": "sundanese",
}
def get_encoding(encoding_name):
  with fetch(f"https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/{encoding_name}.tiktoken").open() as f:
    ranks = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in f if line)}
  n_vocab = len(ranks)
  specials = [
    "<|endoftext|>",
    "<|startoftranscript|>",
    *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
    "<|translate|>",
    "<|transcribe|>",
    "<|startoflm|>",
    "<|startofprev|>",
    "<|nospeech|>",
    "<|notimestamps|>",
    *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
  ]
  special_tokens = dict(zip(specials, itertools.count(n_vocab)))
  n_vocab += len(specials)
  import tiktoken
  return tiktoken.Encoding(
    name=encoding_name,
    explicit_n_vocab=n_vocab,
    pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
    mergeable_ranks=ranks,
    special_tokens=special_tokens)
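# For reference, the special ids are assigned sequentially after the BPE ranks: with
# the standard rank files, "<|endoftext|>" lands at 50256 for "gpt2" and 50257 for
# "multilingual", with "<|startoftranscript|>" immediately following in each case.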
MODEL_URLS = {
  "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
  "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
  "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
  "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
  "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
  "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
  "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
  "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
  "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
  "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
  "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
}
def init_whisper(model_name="tiny.en", batch_size=1):
  assert model_name in MODEL_URLS, f"unknown model: {model_name}"

  filename = fetch(MODEL_URLS[model_name])
  state = torch_load(filename)
  model = Whisper(state['dims'], batch_size)
  load_state_dict(model, state['model_state_dict'], strict=False)
  enc = get_encoding("multilingual" if model.is_multilingual else "gpt2")
  return model, enc
def load_file_waveform(filename):
  waveform, _ = librosa.load(filename, sr=RATE)
  return waveform

def transcribe_file(model, enc, filename):
  return transcribe_waveform(model, enc, [load_file_waveform(filename)])
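# Minimal usage sketch (assuming a local audio file librosa can read):
#   model, enc = init_whisper("tiny.en", batch_size=1)
#   print(transcribe_file(model, enc, "speech.wav"))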
def transcribe_waveform(model: Whisper, enc, waveforms, truncate=False):
  """
  Expects an array of shape (N, S) where N is the number of waveforms to transcribe in parallel and S is the number of 16000 Hz samples
  Returns the transcribed text if a single waveform is provided, or an array of transcriptions if multiple are provided
  """
  log_spec = prep_audio(waveforms, model.batch_size, truncate)
  nsample = model.decoder.max_tokens_to_sample

  def inferloop(ctx: Union[np.ndarray, List[np.ndarray]], encoded_audio):
    # greedy decode: feed the whole prompt once, then one token at a time
    pos, next_tokens = 0, ctx
    for i in range((nsample - len(start_tokens)) * 2):
      next_tokens = model.decoder(Tensor(next_tokens), pos, encoded_audio)[:, -1].argmax(axis=-1).numpy().astype(np.int32).reshape(-1, 1)
      next_tokens[ctx[:, -1] == eot] = eot  # rows that already finished keep emitting eot
      ctx = np.concatenate((ctx, next_tokens), axis=1)
      pos = ctx.shape[-1] - 1
      if (next_tokens == eot).all(): break
    return ctx

  # keep text/timestamp tokens (drop specials), capped to what fits before start_tokens
  def gettexttoks(line): return [tok for tok in line if tok < eot or tok > enc._special_tokens["<|notimestamps|>"]][-nsample + len(start_tokens):]

  start_tokens = [enc._special_tokens["<|startoftranscript|>"]]
  if model.is_multilingual:
    # TODO detect language
    language_token = enc._special_tokens["<|startoftranscript|>"] + 1 + tuple(LANGUAGES.keys()).index("en")
    start_tokens.append(language_token)
    start_tokens.append(enc._special_tokens["<|transcribe|>"])
  start_tokens.append(enc._special_tokens["<|notimestamps|>"])
  eot = enc._special_tokens["<|endoftext|>"]

  ctx = np.tile(start_tokens, (model.batch_size, 1))
  transcriptions = [[] for _ in waveforms]
  for curr_frame in range(0, log_spec.shape[-1], FRAMES_PER_SEGMENT):
    encoded_audio = model.encoder.encode(Tensor(log_spec[:, :, curr_frame:curr_frame + FRAMES_PER_SEGMENT]))
    # batch the decode when all prompts share a length, otherwise run rows one at a time
    if all(len(c) == len(ctx[0]) for c in ctx): ctx = inferloop(np.array(ctx), encoded_audio)
    else: ctx = [inferloop(np.array([c] * model.batch_size), encoded_audio)[i] for i, c in enumerate(ctx)]
    for i, (res, arr) in enumerate(zip(transcriptions, ctx)):
      if curr_frame * HOP_LENGTH <= len(waveforms[i]):
        res.extend(arr[np.where(arr == start_tokens[-1])[0][0] + 1:eoti[0] if len(eoti := np.where(arr == eot)[0]) else None])
    ctx = [[enc._special_tokens['<|startofprev|>']] + gettexttoks(cs) + start_tokens for cs in ctx]

  transcriptions = list(map(lambda tokens: enc.decode(tokens).strip(), transcriptions))
  return transcriptions if len(transcriptions) > 1 else transcriptions[0]
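# Batched usage sketch (assuming two local files and a model built with batch_size=2):
#   model, enc = init_whisper("tiny.en", batch_size=2)
#   waves = [load_file_waveform("a.wav"), load_file_waveform("b.wav")]
#   text_a, text_b = transcribe_waveform(model, enc, waves)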
CHUNK = 1600
RECORD_SECONDS = 10

def listener(q):
  import pyaudio
  p = pyaudio.PyAudio()
  stream = p.open(format=pyaudio.paInt16, channels=1, rate=RATE, input=True, frames_per_buffer=CHUNK)
  print("listening")
  for _ in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
    data = stream.read(CHUNK)
    waveform = ((np.frombuffer(data, np.int16) / 32768).astype(np.float32) * 3)
    q.put(waveform)
  print("done listening")
if __name__ == "__main__":
  model, enc = init_whisper("small.en" if getenv("SMALL") else "tiny.en", batch_size=1)

  if len(sys.argv) > 1:
    print(transcribe_file(model, enc, sys.argv[1]))
  else:
    # online
    q = multiprocessing.Queue()
    p = multiprocessing.Process(target=listener, args=(q,))
    p.daemon = True
    p.start()

    lst = [enc._special_tokens["<|startoftranscript|>"], enc._special_tokens["<|notimestamps|>"]]
    total = None
    did_read = False
    for i in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
      while not q.empty() or total is None:
        waveform = q.get()
        if total is None: total = waveform
        else: total = np.concatenate([total, waveform])
        did_read = True
      if did_read:
        log_spec = prep_audio(total.reshape(1, -1), model.batch_size, truncate=True)
        encoded_audio = model.encoder.encode(Tensor(log_spec))
      # pass the previously inferred tokens as 'prefix' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
      out = model.decoder(Tensor([lst]), 0, encoded_audio).realize()
      idx = int(out[0, -1].argmax().numpy().item())
      lst.append(idx)
      dec = enc.decode(lst)
      print(dec)  # DO NOT REMOVE PRINT. IT'S VERY IMPORTANT
      if dec.endswith("<|endoftext|>"):
        lst.pop()  # drop the eot token so greedy decoding can continue on the next chunk