#!/usr/bin/env python3
import os, math, time
import numpy as np
from tinygrad import Tensor, nn, fetch, Device, TinyJit, GlobalCounters
from dataclasses import dataclass

@dataclass
class GPTConfig:
  block_size: int = 1024
  vocab_size: int = 50257
  padded_vocab_size: int = 50304
  n_layer: int = 12
  n_head: int = 12
  n_embd: int = 768

class CausalSelfAttention:
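  # Multi-head causal self-attention, written out by hand: a single batched qkv projection (c_attn),
  # a lower-triangular mask buffer ("bias") that blocks attention to future positions up to block_size,
  # and an output projection (c_proj).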
  def __init__(self, config:GPTConfig):
    assert config.n_embd % config.n_head == 0
    # key, query, value projections for all heads, but in a batch
    self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
    # output projection
    self.c_proj = nn.Linear(config.n_embd, config.n_embd)
    # regularization
    self.n_head = config.n_head
    self.n_embd = config.n_embd
    # not really a 'bias', more of a mask, but following the OpenAI/HF naming
    self.bias = Tensor.ones(1, 1, config.block_size, config.block_size).tril()
    self.bias.requires_grad = False

  def __call__(self, x:Tensor):
    B, T, C = x.shape
    qkv = self.c_attn(x)
    q, k, v = qkv.split(self.n_embd, dim=2)
    k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
    q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
    v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

    # manual implementation of attention
    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
    att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
    att = att.softmax()
    y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
    y = y.transpose(1, 2).view(B, T, C) # re-assemble all head outputs side by side
    # output projection
    y = self.c_proj(y)
    return y

class MLP:
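  # Position-wise feed-forward network: expand to 4*n_embd, apply GELU, project back to n_embd.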
  def __init__(self, config:GPTConfig):
    self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
    self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)

  def __call__(self, x:Tensor) -> Tensor:
    return self.c_proj(self.c_fc(x).gelu())

class Block:
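  # Pre-norm transformer block: x + attn(ln_1(x)), then x + mlp(ln_2(x)).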
  def __init__(self, config:GPTConfig):
    self.ln_1 = nn.LayerNorm(config.n_embd)
    self.attn = CausalSelfAttention(config)
    self.ln_2 = nn.LayerNorm(config.n_embd)
    self.mlp = MLP(config)

  def __call__(self, x:Tensor):
    x = x + self.attn(self.ln_1(x))
    x = x + self.mlp(self.ln_2(x))
    return x

class GPT:
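  # GPT-2 style model: tied token embedding / lm_head, learned position embeddings, n_layer Blocks,
  # and a final LayerNorm. The embedding table uses padded_vocab_size (50304, presumably padded up
  # from 50257 for friendlier kernel shapes); logits are sliced back to vocab_size in __call__.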
  def __init__(self, config:GPTConfig):
    self.config = config

    self.wte = nn.Embedding(config.padded_vocab_size, config.n_embd)
    self.wpe = nn.Embedding(config.block_size, config.n_embd)
    self.h = [Block(config) for _ in range(config.n_layer)]
    self.ln_f = nn.LayerNorm(config.n_embd)
    self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
    self.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying

  def load_pretrained(self):
    weights = nn.state.torch_load(fetch('https://huggingface.co/gpt2/resolve/main/pytorch_model.bin'))
    # HF GPT-2 checkpoints store these as Conv1D weights, transposed relative to nn.Linear
    transposed = ('attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight')
    for k in weights:
      if k == "wte.weight":
        # pad the embedding table from vocab_size (50257) up to padded_vocab_size (50304)
        weights[k] = weights[k].pad(((0, self.config.padded_vocab_size-self.config.vocab_size), (0,0))).to(None).contiguous()
      if k.endswith(transposed):
        weights[k] = weights[k].to(None).T.contiguous()
    # lm head and wte are tied
    weights['lm_head.weight'] = weights['wte.weight']
    nn.state.load_state_dict(self, weights)

  def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
    # note: top_k is accepted here but currently unused below
    for _ in range(max_new_tokens):
      idx_cond = idx if idx.shape[1] <= self.config.block_size else idx[:, -self.config.block_size:]
      logits, _ = self(idx_cond)
      logits = logits[:, -1, :] / temperature
      idx_next = logits.softmax().multinomial()
      idx = Tensor.cat(idx, idx_next, dim=1)
    return idx

  def __call__(self, idx:Tensor, targets=None):
    b, t = idx.shape
    pos = Tensor.arange(0, t)

    tok_emb = self.wte(idx) # token embeddings of shape (b, t, n_embd)
    pos_emb = self.wpe(pos) # position embeddings of shape (t, n_embd)
    x = tok_emb + pos_emb

    x = self.ln_f(x.sequential(self.h))

    if targets is not None:
      logits = self.lm_head(x)[:, :, :self.config.vocab_size]
      loss = logits.sparse_categorical_crossentropy(targets)
    else:
      logits = self.lm_head(x[:, [-1], :])[:, :, :self.config.vocab_size]
      loss = None

    return logits, loss

if __name__ == "__main__":
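  # Smoke test: load the pretrained GPT-2 (124M) weights, overfit a single batch of tiny_shakespeare
  # for a few AdamW steps under TinyJit, then sample a few tokens from the model.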
  import tiktoken, argparse

  parser = argparse.ArgumentParser()
  parser.add_argument("--num_iterations", type=int, default=10, help="number of iterations to run")
  parser.add_argument("--batch_size", type=int, default=4, help="batch size")
  parser.add_argument("--sequence_length", type=int, default=64, help="sequence length")
  parser.add_argument("--skip_test", action="store_true", help="skip test")
  args = parser.parse_args()
  B, T = args.batch_size, args.sequence_length
  assert 1 <= T <= 1024

  model = GPT(GPTConfig(n_layer=12, n_head=12, n_embd=768))
  model.load_pretrained()

  # init the tokenizer
  enc = tiktoken.get_encoding("gpt2")
  encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
  decode = lambda l: enc.decode(l)

  # load the tokens
  # prefer to use tiny_shakespeare if it's available, otherwise use tiny_stories
  # we're using val instead of train split just because it is smaller/faster
  tokens_bin = fetch("https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/tiny_shakespeare_val.bin")
  assert os.path.isfile(tokens_bin)
  print(f"loading cached tokens in {tokens_bin}")
  with open(tokens_bin, "rb") as f:
    f.seek(0x400)  # skip the 1024-byte header that precedes the raw uint16 tokens
    tokens = np.frombuffer(f.read(), dtype=np.uint16).astype(np.int32)
  tokens = Tensor(tokens)

  # lightweight dataloader
  def get_batch():
    assert B*T+1 <= len(tokens), "not enough tokens"
    # for 338,025 tokens. E.g. with B=8 T=1024, this will yield 41 batches before looping
    i = 0
    while True:
      x = tokens[i:i+B*T].view(B, T)
      y = tokens[i+1:i+B*T+1].view(B, T)
      yield x, y
      i += B*T
      if i + B*T + 1 >= len(tokens):
        i = 0 # in prod we'd want to randomize the start point a bit

  # forward backward for a few iterations
  data_iter = iter(get_batch())
  x, y = next(data_iter) # we'll overfit this batch below
  optimizer = nn.optim.AdamW(nn.state.get_parameters(model), lr=1e-4, weight_decay=0)

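  # TinyJit captures the kernels launched during the first couple of calls to step and replays them
  # on subsequent calls with the new inputs, so the forward/backward/optimizer update is not
  # re-scheduled every iteration.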
  @TinyJit
  def step(x, y):
    _, loss = model(x, y)
    optimizer.zero_grad()
    loss.backward()
    return loss.realize(*optimizer.schedule_step())

  with Tensor.train():
    for i in range(args.num_iterations):
      GlobalCounters.reset()
      t0 = time.time()
      loss = step(x.contiguous(), y.contiguous())
      Device[Device.DEFAULT].synchronize()
      t1 = time.time()
      print(f"iteration {i}, loss: {loss.item():.6f}, time: {(t1-t0)*1000:.3f}ms, {int(B*T/(t1-t0))} tok/s")

  if not args.skip_test:
    start = "<|endoftext|>"
    start_ids = encode(start)
    x = (Tensor(start_ids)[None, ...])
    max_new_tokens = 16
    temperature = 1.0
    top_k = 40
    y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
    print(decode(y[0].tolist()))