#!/usr/bin/env python3
import os, math, time
import numpy as np
from tinygrad import Tensor, nn, fetch, Device, TinyJit, GlobalCounters
from dataclasses import dataclass

@dataclass
class GPTConfig:
  """Hyperparameters of a GPT-2 model; the defaults describe the base `gpt2` size."""
  block_size: int = 1024          # maximum sequence length (rows of the position embedding)
  vocab_size: int = 50257         # number of real tokens the logits are reported for
  padded_vocab_size: int = 50304  # embedding / lm_head rows actually allocated (>= vocab_size)
  n_layer: int = 12               # number of transformer blocks
  n_head: int = 12                # attention heads per block (must divide n_embd)
  n_embd: int = 768               # width of the residual stream

class CausalSelfAttention:
  """Multi-head causal self-attention, computed explicitly (scores -> mask -> softmax -> weighted sum).

  Attribute names (c_attn, c_proj, bias) follow the OpenAI/HF GPT-2 parameter names.
  """
  def __init__(self, config:GPTConfig):
    assert config.n_embd % config.n_head == 0
    width = config.n_embd
    # one fused projection yields query, key and value for all heads
    self.c_attn = nn.Linear(width, 3 * width)
    # projection applied after the heads are concatenated again
    self.c_proj = nn.Linear(width, width)
    self.n_head = config.n_head
    self.n_embd = width
    # lower-triangular matrix of ones: a zero marks a (query, key) pair where the key lies in the future.
    # It is a fixed mask rather than a trainable weight, hence requires_grad = False.
    causal = Tensor.ones(1, 1, config.block_size, config.block_size).tril()
    causal.requires_grad = False
    self.bias = causal

  def __call__(self, x:Tensor):
    B, T, C = x.shape
    head_dim = C // self.n_head

    def to_heads(t:Tensor) -> Tensor:
      # (B, T, C) -> (B, nh, T, hs)
      return t.view(B, T, self.n_head, head_dim).transpose(1, 2)

    q, k, v = map(to_heads, self.c_attn(x).split(self.n_embd, dim=2))

    # scaled dot-product scores; positions in the future get -inf so softmax assigns them zero weight
    scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))
    future = self.bias[:, :, :T, :T] == 0
    weights = scores.masked_fill(future, float('-inf')).softmax()

    # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs), then lay the heads back side by side
    merged = (weights @ v).transpose(1, 2).view(B, T, C)
    return self.c_proj(merged)

class MLP:
  """Feed-forward sub-layer: widen to 4 * n_embd, apply GELU, project back to n_embd."""
  def __init__(self, config:GPTConfig):
    hidden = 4 * config.n_embd
    self.c_fc = nn.Linear(config.n_embd, hidden)
    self.c_proj = nn.Linear(hidden, config.n_embd)

  def __call__(self, x:Tensor) -> Tensor:
    expanded = self.c_fc(x)
    activated = expanded.gelu()
    return self.c_proj(activated)

class Block:
  """One pre-norm transformer layer: LayerNorm -> attention and LayerNorm -> MLP, each added residually."""
  def __init__(self, config:GPTConfig):
    self.ln_1 = nn.LayerNorm(config.n_embd)
    self.attn = CausalSelfAttention(config)
    self.ln_2 = nn.LayerNorm(config.n_embd)
    self.mlp = MLP(config)

  def __call__(self, x:Tensor):
    # each sub-layer reads a normalized copy of the stream; its output is added back onto the stream
    for norm, sublayer in ((self.ln_1, self.attn), (self.ln_2, self.mlp)):
      x = x + sublayer(norm(x))
    return x

class GPT:
  """GPT-2 language model.

  Token embedding + learned position embedding, `n_layer` transformer Blocks, a final LayerNorm and
  an LM head whose weight is shared with the token embedding. The vocabulary dimension is allocated
  at `padded_vocab_size`; the extra columns are sliced off the logits in `__call__`, so callers only
  ever see `vocab_size` columns.
  """
  def __init__(self, config:GPTConfig):
    self.config = config

    self.wte = nn.Embedding(config.padded_vocab_size, config.n_embd)
    self.wpe = nn.Embedding(config.block_size, config.n_embd)
    self.h = [Block(config) for _ in range(config.n_layer)]
    self.ln_f = nn.LayerNorm(config.n_embd)
    self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
    self.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying

  def load_pretrained(self):
    """Download the HuggingFace base `gpt2` checkpoint and load it into this model in place.

    NOTE(review): only that checkpoint's key layout is handled; the config is assumed to match it
    (12 layers, 12 heads, 768 wide) and this is not checked here.
    """
    weights = nn.state.torch_load(fetch('https://huggingface.co/gpt2/resolve/main/pytorch_model.bin'))
    # these checkpoint matrices are stored transposed relative to the nn.Linear layout used here
    transposed = ('attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight')
    for k in weights:
      if k == "wte.weight":
        # grow the embedding from vocab_size rows to padded_vocab_size rows to match self.wte
        weights[k] = weights[k].pad(((0, self.config.padded_vocab_size-self.config.vocab_size), (0,0))).to(None).contiguous()
      if k.endswith(transposed):
        weights[k] = weights[k].to(None).T.contiguous()
    # lm head and wte are tied
    weights['lm_head.weight'] = weights['wte.weight']
    nn.state.load_state_dict(self, weights)

  def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
    """Autoregressively sample `max_new_tokens` tokens and append them to `idx`.

    idx: (B, T) Tensor of prompt token ids.
    temperature: the last-position logits are divided by this before the softmax.
    top_k: if not None, only the `top_k` largest logits in each row stay eligible for sampling.
    Returns the Tensor of prompt ids followed by the sampled ids, shape (B, T + max_new_tokens).
    """
    for _ in range(max_new_tokens):
      # keep at most block_size tokens of context so positions stay within the wpe table
      idx_cond = idx if idx.shape[1] <= self.config.block_size else idx[:, -self.config.block_size:]
      logits, _ = self(idx_cond)
      logits = logits[:, -1, :] / temperature
      if top_k is not None:
        # k-th largest logit per row; anything strictly smaller is excluded from sampling
        k = min(top_k, logits.shape[-1])
        kth = np.sort(logits.numpy(), axis=-1)[:, [-k]]
        logits = (logits < Tensor(kth)).where(float('-inf'), logits)
      idx_next = logits.softmax().multinomial()
      idx = Tensor.cat(idx, idx_next, dim=1)
    return idx

  def __call__(self, idx:Tensor, targets=None):
    """Run the model on token ids `idx` of shape (b, t), where t must not exceed block_size.

    With `targets` (token ids aligned with `idx`), returns (logits for every position, loss) where
    loss is the sparse categorical cross-entropy. Without targets, returns (logits for the last
    position only, shaped (b, 1, vocab_size), None).
    """
    b, t = idx.shape
    assert t <= self.config.block_size, f"cannot forward sequence of length {t}, block size is only {self.config.block_size}"
    pos = Tensor.arange(0, t)

    tok_emb = self.wte(idx) # token embeddings of shape (b, t, n_embd)
    pos_emb = self.wpe(pos) # position embeddings of shape (t, n_embd)
    x = tok_emb + pos_emb

    x = self.ln_f(x.sequential(self.h))

    if targets is not None:
      # drop the padding columns so the loss is taken over the real vocabulary only
      logits = self.lm_head(x)[:, :, :self.config.vocab_size]
      loss = logits.sparse_categorical_crossentropy(targets)
    else:
      # inference: only the last position is needed; [-1] as a list keeps the time axis
      logits = self.lm_head(x[:, [-1], :])[:, :, :self.config.vocab_size]
      loss = None

    return logits, loss

if __name__ == "__main__":
  import tiktoken, argparse

  parser = argparse.ArgumentParser()
  parser.add_argument("--num_iterations", type=int, default=10, help="number of iterations to run")
  parser.add_argument("--batch_size", type=int, default=4, help="batch size")
  parser.add_argument("--sequence_length", type=int, default=64, help="sequence length")
  parser.add_argument("--skip_test", action="store_true", help="skip test")
  args = parser.parse_args()
  B, T = args.batch_size, args.sequence_length

  config = GPTConfig(n_layer=12, n_head=12, n_embd=768)
  # validate CLI input with explicit errors (asserts vanish under python -O)
  if not 1 <= T <= config.block_size:
    parser.error(f"--sequence_length must be in [1, {config.block_size}], got {T}")
  if B < 1:
    parser.error(f"--batch_size must be >= 1, got {B}")

  model = GPT(config)
  model.load_pretrained()

  # init the tokenizer
  enc = tiktoken.get_encoding("gpt2")
  encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
  decode = lambda l: enc.decode(l)

  # load the tokens (tiny_shakespeare)
  # we're using val instead of train split just because it is smaller/faster
  tokens_bin = fetch("https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/tiny_shakespeare_val.bin")
  if not os.path.isfile(tokens_bin):
    raise FileNotFoundError(f"token file not found: {tokens_bin}")
  print(f"loading cached tokens in {tokens_bin}")
  with open(tokens_bin, "rb") as f:
    # llm.c GPT-2 token files: a 256 x int32 header (magic, version, token count) followed by uint16 token ids
    header = np.frombuffer(f.read(256*4), dtype=np.int32)
    if header[0] != 20240520:
      raise ValueError(f"{tokens_bin}: unexpected magic number {header[0]}, not an llm.c GPT-2 token file")
    tokens = np.frombuffer(f.read(), dtype=np.uint16).astype(np.int32)
  if len(tokens) != header[2]:
    raise ValueError(f"{tokens_bin}: header declares {header[2]} tokens but {len(tokens)} were read")
  tokens = Tensor(tokens)

  # lightweight dataloader
  def get_batch():
    """Yield (x, y) batches of shape (B, T) forever; y is x shifted one token ahead."""
    if B*T+1 > len(tokens):
      raise ValueError("not enough tokens")
    # for 338,025 tokens. E.g. with B=8 T=1024, this will yield 41 batches before looping
    i = 0
    while True:
      x = tokens[i:i+B*T].view(B, T)
      y = tokens[i+1:i+B*T+1].view(B, T)
      yield x, y
      i += B*T
      # the next batch reads tokens[i : i+B*T+1]; wrap only when that slice would run past the end
      if i + B*T + 1 > len(tokens):
        i = 0 # in prod we'd want to randomize the start point a bit

  # forward backward for a few iterations
  data_iter = iter(get_batch())
  x, y = next(data_iter) # we'll overfit this batch below
  optimizer = nn.optim.AdamW(nn.state.get_parameters(model), lr=1e-4, weight_decay=0)

  @TinyJit
  def step(x, y):
    _, loss = model(x, y)
    optimizer.zero_grad()
    loss.backward()
    return loss.realize(*optimizer.schedule_step())

  with Tensor.train():
    for i in range(args.num_iterations):
      GlobalCounters.reset()
      t0 = time.time()
      loss = step(x.contiguous(), y.contiguous())
      Device[Device.DEFAULT].synchronize()
      t1 = time.time()
      print(f"iteration {i}, loss: {loss.item():.6f}, time: {(t1-t0)*1000:.3f}ms, {int(B*T/(t1-t0))} tok/s")

  if not args.skip_test:
    start = "<|endoftext|>"
    start_ids = encode(start)
    x = (Tensor(start_ids)[None, ...])
    max_new_tokens = 16
    temperature = 1.0
    top_k = 40
    y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
    print(decode(y[0].tolist()))