from tinygrad.tensor import Tensor
from tinygrad.engine.jit import TinyJit
from tinygrad.nn import Linear, Embedding
from tinygrad.helpers import fetch
import numpy as np
from pathlib import Path


class RNNT:
    """Encoder / prediction / joint network bundle with a greedy decoder.

    The three sub-networks are built from the constructor arguments; the
    defaults match the checkpoint fetched by ``load_from_pretrained``.
    """

    def __init__(self, input_features=240, vocab_size=29, enc_hidden_size=1024, pred_hidden_size=320,
                 joint_hidden_size=512, pre_enc_layers=2, post_enc_layers=3, pred_layers=2,
                 stack_time_factor=2, dropout=0.32):
        self.encoder = Encoder(input_features, enc_hidden_size, pre_enc_layers, post_enc_layers,
                               stack_time_factor, dropout)
        self.prediction = Prediction(vocab_size, pred_hidden_size, pred_layers, dropout)
        self.joint = Joint(vocab_size, pred_hidden_size, enc_hidden_size, joint_hidden_size, dropout)

    @TinyJit
    def __call__(self, x, y, hc=None):
        """Run encoder on ``x``, prediction net on ``y`` and return the realized joint output.

        The prediction network is called with an all-ones mask, so the label
        embeddings are used unmasked.
        """
        f, _ = self.encoder(x, None)
        g, _ = self.prediction(y, hc, Tensor.ones(1, requires_grad=False))
        out = self.joint(f, g)
        return out.realize()

    def decode(self, x, x_lens):
        """Encode ``x`` and greedily decode every batch entry.

        Returns a list with one label sequence per batch entry. The number of
        encoder frames decoded for entry ``b`` is ``ceil(logit_lens[b])``.
        """
        logits, logit_lens = self.encoder(x, x_lens)
        outputs = []
        for b in range(logits.shape[0]):
            # (T, H) -> (T, 1, H): one sequence with a singleton batch axis.
            inseq = logits[b, :, :].unsqueeze(1)
            n_frames = int(np.ceil(logit_lens[b].numpy()).item())
            outputs.append(self._greedy_decode(inseq, n_frames))
        return outputs

    def _greedy_decode(self, logits, logit_len):
        """Greedy search over the first ``logit_len`` frames of ``logits``.

        For each frame, symbols are emitted until index 28 (treated as blank)
        is predicted or 30 steps have been taken for that frame.

        NOTE: 28 / 29 are hard-coded and correspond to the default
        ``vocab_size=29``; other vocabulary sizes are not supported here.
        """
        hc = Tensor.zeros(self.prediction.rnn.layers, 2, self.prediction.hidden_size, requires_grad=False)
        labels = []
        label = Tensor.zeros(1, 1, requires_grad=False)
        # The mask stays 0 until the first label exists, zeroing the embedding
        # of the initial dummy label inside the prediction network.
        mask = Tensor.zeros(1, requires_grad=False)
        for time_idx in range(logit_len):
            logit = logits[time_idx, :, :].unsqueeze(0)
            not_blank = True
            added = 0
            while not_blank and added < 30:
                if labels:
                    mask = (mask + 1).clip(0, 1)
                    prev = labels[-1]
                    # "+ 1 - 1" is kept from the original; it appears to force a
                    # fresh tensor for the jitted call -- reason unconfirmed.
                    label = Tensor([[prev if prev <= 28 else prev - 1]], requires_grad=False) + 1 - 1
                jhc = self._pred_joint(Tensor(logit.numpy()), label, hc, mask)
                # First 29 channels of row [0, 0] are the joint scores; the rest is hc.
                k = jhc[0, 0, :29].argmax(axis=0).numpy()
                not_blank = k != 28
                if not_blank:
                    labels.append(k)
                    # State only advances when a non-blank symbol was emitted.
                    hc = jhc[:, :, 29:] + 1 - 1
                added += 1
        return labels

    @TinyJit
    def _pred_joint(self, logit, label, hc, mask):
        """One prediction + joint step, returning joint scores and new state in one tensor.

        The joint output is padded by one along its first two axes and then
        concatenated with ``hc`` on the last axis; ``_greedy_decode`` splits
        them again at channel 29.
        NOTE(review): the fixed padding of 1 looks tied to ``pred_layers=2``
        and a single sequence -- confirm before using other configurations.
        """
        g, hc = self.prediction(label, hc, mask)
        j = self.joint(logit, g)[0]
        j = j.pad(((0, 1), (0, 1), (0, 0)))
        out = j.cat(hc, dim=2)
        return out.realize()

    @staticmethod
    def _load_lstm_weights(lstm, state_dict, prefix):
        """Assign ``{prefix}.weight_*_l{i}`` / ``{prefix}.bias_*_l{i}`` to every cell of ``lstm``."""
        for i, cell in enumerate(lstm.cells):
            cell.weights_ih.assign(state_dict[f"{prefix}.weight_ih_l{i}"].numpy())
            cell.weights_hh.assign(state_dict[f"{prefix}.weight_hh_l{i}"].numpy())
            cell.bias_ih.assign(state_dict[f"{prefix}.bias_ih_l{i}"].numpy())
            cell.bias_hh.assign(state_dict[f"{prefix}.bias_hh_l{i}"].numpy())

    def load_from_pretrained(self):
        """Fetch the pretrained checkpoint and copy its weights into this model.

        Every LSTM layer that exists in the model is loaded; a layer without a
        matching checkpoint entry raises ``KeyError`` instead of silently
        keeping its random initialisation.
        """
        fn = Path(__file__).parents[1] / "weights/rnnt.pt"
        fetch("https://zenodo.org/record/3662521/files/DistributedDataParallel_1576581068.9962234-epoch-100.pt?download=1", fn)

        import torch
        # NOTE(security): torch.load unpickles the downloaded file; this is only
        # acceptable because the URL above is a fixed, trusted source.
        with open(fn, "rb") as f:
            state_dict = torch.load(f, map_location="cpu")["state_dict"]

        # encoder
        self._load_lstm_weights(self.encoder.pre_rnn, state_dict, "encoder.pre_rnn.lstm")
        self._load_lstm_weights(self.encoder.post_rnn, state_dict, "encoder.post_rnn.lstm")

        # prediction
        self.prediction.emb.weight.assign(state_dict["prediction.embed.weight"].numpy())
        self._load_lstm_weights(self.prediction.rnn, state_dict, "prediction.dec_rnn.lstm")

        # joint
        self.joint.l1.weight.assign(state_dict["joint_net.0.weight"].numpy())
        self.joint.l1.bias.assign(state_dict["joint_net.0.bias"].numpy())
        self.joint.l2.weight.assign(state_dict["joint_net.3.weight"].numpy())
        self.joint.l2.bias.assign(state_dict["joint_net.3.bias"].numpy())


class LSTMCell:
    """Single LSTM layer step operating on a packed ``hc`` state.

    ``hc`` holds the hidden state in its first ``batch`` rows and the cell
    state in the remaining rows; the call returns the new state packed the
    same way.
    """

    def __init__(self, input_size, hidden_size, dropout):
        self.dropout = dropout

        # The four gates are stored stacked along the first axis (4 * hidden_size rows).
        self.weights_ih = Tensor.uniform(hidden_size * 4, input_size)
        self.bias_ih = Tensor.uniform(hidden_size * 4)
        self.weights_hh = Tensor.uniform(hidden_size * 4, hidden_size)
        self.bias_hh = Tensor.uniform(hidden_size * 4)

    def __call__(self, x, hc):
        """Return ``cat(h, c)`` for input ``x`` and previous packed state ``hc``."""
        batch = x.shape[0]
        gates = x.linear(self.weights_ih.T, self.bias_ih) + hc[:batch].linear(self.weights_hh.T, self.bias_hh)

        i, f, g, o = gates.chunk(4, 1)
        i, f, g, o = i.sigmoid(), f.sigmoid(), g.tanh(), o.sigmoid()

        c = (f * hc[batch:]) + (i * g)
        h = (o * c.tanh()).dropout(self.dropout)

        return Tensor.cat(h, c).realize()


class LSTM:
    """Stack of ``LSTMCell`` layers applied step by step along axis 0 of the input."""

    def __init__(self, input_size, hidden_size, layers, dropout):
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.layers = layers

        # Layer 0 maps input_size -> hidden_size; later layers are hidden -> hidden,
        # and the last of those later layers is built with dropout 0.
        self.cells = [
            LSTMCell(input_size, hidden_size, dropout) if i == 0
            else LSTMCell(hidden_size, hidden_size, dropout if i != layers - 1 else 0)
            for i in range(layers)
        ]

    def __call__(self, x, hc):
        """Run the stack over every step ``x[t]``.

        ``hc`` may be ``None``, in which case a zero state of shape
        ``(layers, 2 * x.shape[1], hidden_size)`` is used. Returns
        ``(output, hc)`` where ``output`` concatenates the top layer's hidden
        state of every step along axis 0 (``None`` if ``x`` has no steps).
        """
        # A fresh jitted step function is created on every call.
        @TinyJit
        def _do_step(x_, hc_):
            return self.do_step(x_, hc_)

        if hc is None:
            hc = Tensor.zeros(self.layers, 2 * x.shape[1], self.hidden_size, requires_grad=False)

        batch = x.shape[1]
        output = None
        for t in range(x.shape[0]):
            hc = _do_step(x[t] + 1 - 1, hc)  # TODO: why do we need to do this?
            top_hidden = hc[-1:, :batch]
            if output is None:
                output = top_hidden
            else:
                output = output.cat(top_hidden, dim=0).realize()

        return output, hc

    def do_step(self, x, hc):
        """Advance every layer by one step and return the stacked new states."""
        new_hc = [x]
        for i, cell in enumerate(self.cells):
            # The first x.shape[0] rows of the previous layer's result are its
            # hidden state, which feeds the next layer.
            new_hc.append(cell(new_hc[i][:x.shape[0]], hc[i]))
        return Tensor.stack(*new_hc[1:]).realize()


class StackTime:
    """Merge every ``factor`` consecutive time steps into one step with wider features."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, x, x_lens):
        """Stack time steps of ``x`` (time, batch, features) and scale ``x_lens``.

        The time axis is zero-padded up to a multiple of ``factor``. Returns
        ``(x, x_lens / factor)``, with ``None`` passed through for ``x_lens``.
        """
        x = x.pad(((0, (-x.shape[0]) % self.factor), (0, 0), (0, 0)))
        steps, batch, feats = x.shape
        # Reshape with batch leading so that only time steps of the SAME batch
        # entry are merged; reshaping (T, B, F) directly would interleave
        # different batch entries whenever batch > 1.
        x = x.transpose(0, 1).reshape(batch, steps // self.factor, feats * self.factor).transpose(0, 1)
        return x, x_lens / self.factor if x_lens is not None else None


class Encoder:
    """Pre-LSTM, time stacking, then post-LSTM."""

    def __init__(self, input_size, hidden_size, pre_layers, post_layers, stack_time_factor, dropout):
        self.pre_rnn = LSTM(input_size, hidden_size, pre_layers, dropout)
        self.stack_time = StackTime(stack_time_factor)
        self.post_rnn = LSTM(stack_time_factor * hidden_size, hidden_size, post_layers, dropout)

    def __call__(self, x, x_lens):
        """Return the encoded sequence with its first two axes swapped, plus the scaled lengths."""
        x, _ = self.pre_rnn(x, None)
        x, x_lens = self.stack_time(x, x_lens)
        x, _ = self.post_rnn(x, None)
        return x.transpose(0, 1), x_lens


class Prediction:
    """Label embedding followed by an LSTM stack."""

    def __init__(self, vocab_size, hidden_size, layers, dropout):
        self.hidden_size = hidden_size

        # The embedding table has vocab_size - 1 rows.
        self.emb = Embedding(vocab_size - 1, hidden_size)
        self.rnn = LSTM(hidden_size, hidden_size, layers, dropout)

    def __call__(self, x, hc, m):
        """Embed labels ``x``, scale by mask ``m`` and run the LSTM from state ``hc``.

        Returns ``(output, hc)`` with the output's first two axes swapped back
        to match the layout of ``x``.
        """
        emb = self.emb(x) * m
        x_, hc = self.rnn(emb.transpose(0, 1), hc)
        return x_.transpose(0, 1), hc


class Joint:
    """Combines encoder and prediction features through two linear layers."""

    def __init__(self, vocab_size, pred_hidden_size, enc_hidden_size, joint_hidden_size, dropout):
        self.dropout = dropout

        self.l1 = Linear(pred_hidden_size + enc_hidden_size, joint_hidden_size)
        self.l2 = Linear(joint_hidden_size, vocab_size)

    def __call__(self, f, g):
        """Return scores of shape ``(B, T, U, vocab_size)`` for ``f`` (_, T, H) and ``g`` (B, U, H2)."""
        (_, T, H), (B, U, H2) = f.shape, g.shape
        # Broadcast both inputs to a common (B, T, U, ·) grid before concatenating features.
        f = f.unsqueeze(2).expand(B, T, U, H)
        g = g.unsqueeze(1).expand(B, T, U, H2)

        inp = f.cat(g, dim=3)
        t = self.l1(inp).relu()
        t = t.dropout(self.dropout)
        return self.l2(t)