#!/usr/bin/env python3 import os if "NOOPT" not in os.environ: os.environ["NOOPT"] = "1" from tinygrad import Device, nn, Tensor, dtypes, Variable Device.DEFAULT = "CPU" from train_gpt2 import GPT, GPTConfig from tinygrad.helpers import dedup, to_function_name, flatten, getenv, GlobalCounters, ansilen, to_function_name from tinygrad.engine.realize import get_kernel, run_schedule from tinygrad.engine.memory import memory_planner from tinygrad.ops import Ops TIMING = getenv("TIMING") if __name__ == "__main__": model = GPT(GPTConfig(n_layer=getenv("NLAYER", 12), n_head=12, n_embd=768)) #model.load_pretrained() for p in nn.state.get_parameters(model): p.replace(Tensor.empty(p.shape, dtype=p.dtype)) # fake load pretrained #early_sched = create_schedule([x.lazydata for x in nn.state.get_parameters(model)]) #print(f"built model {len(early_sched)}") #B, T = Variable("B", 1, 128).bind(4), 64 #Variable("T", 1, 1024).bind(64) B, T = 4, 64 Tensor.training = True optimizer = nn.optim.Adam(nn.state.get_parameters(model), lr=1e-4) warmup_count = getenv("WARMUP", 3) for i in range(warmup_count): # TODO: why does it take three and not two to stablize GlobalCounters.reset() X = Tensor.empty(4, 64, dtype=dtypes.int).reshape(B, T) Y = Tensor.empty(4, 64, dtype=dtypes.int).reshape(B, T) _, loss = model(X, Y) optimizer.zero_grad() if getenv("BACKWARD", 1): loss.backward() tensors = optimizer.schedule_step() else: tensors = [] sched = loss.schedule(*tensors) print(f"calls {i}:", len(sched)) #run_schedule(sched[:]) sched = memory_planner(sched) ast_dedup = dedup([si.ast for si in sched if si.ast.op is Ops.SINK]) srcs = {} for ast in ast_dedup: k = get_kernel(Device["CPU"].renderer, ast) k.linearize() src = Device["CPU"].renderer.render(to_function_name(k.name), k.uops) srcs[ast] = (k.name, src) print("functions:", len(srcs)) used_buffers = dedup(flatten([si.bufs for si in sched])) numbered_bufs = {x:i for i,x in enumerate(used_buffers)} print("buffers:", len(numbered_bufs)) state_dict = nn.state.get_state_dict(model) state_dict.update({'X': X, 'Y': Y, 'loss': loss}) grad_state_dict = {} for k,v in state_dict.items(): if v.lazydata.base.buffer not in used_buffers: print(f"UNUSED: {k}") if v.grad is not None: grad_state_dict['grad_'+k] = v.grad state_dict.update(grad_state_dict) state_dict.update({'adam_b1_t': optimizer.b1_t, 'adam_b2_t': optimizer.b2_t, 'adam_lr': optimizer.lr}) inverse_state_dict = {v:k for k,v in state_dict.items()} for p,m,v in zip(optimizer.params, optimizer.m, optimizer.v): nm = inverse_state_dict[p] state_dict["adam_m_"+nm] = m state_dict["adam_v_"+nm] = v named_buffers = {v.lazydata.base.buffer:k.replace(".", "_") for k,v in state_dict.items()} c_code = ["#include ", "#include ", "#include "] if TIMING: c_code += ["#include ", "#include "] c_code += [x[1].replace(" restrict ", " ")+"\n" for x in srcs.values()] premain = ["int main() {"] if TIMING: premain += [" struct timespec tm0; clock_gettime(CLOCK_MONOTONIC, &tm0);"] lst = 0 main = [] all_bufs = [] for i,si in enumerate(sched): bufs = [(named_buffers.get(b, f"b{numbered_bufs[b]}"), b) for b in si.bufs] all_bufs += bufs if si.ast.op is not Ops.SINK: print(f"// {si.ast.op}", bufs) else: print(f"{srcs[si.ast][0]}({', '.join([x[0] for x in bufs])})") main.append(f" {to_function_name(srcs[si.ast][0])}({', '.join([x[0] for x in bufs])});") if TIMING: main.append(f" struct timespec tm{i+1}; clock_gettime(CLOCK_MONOTONIC, &tm{i+1});") main.append(f" printf(\"%10.2f ms + %7.2f ms @ {to_function_name(srcs[si.ast][0])}\\n\"," +\ f"((tm{i+1}.tv_sec-tm{0}.tv_sec) + (tm{i+1}.tv_nsec-tm{0}.tv_nsec) / 1e9) * 1e3," +\ f"((tm{i+1}.tv_sec-tm{lst}.tv_sec) + (tm{i+1}.tv_nsec-tm{lst}.tv_nsec) / 1e9) * 1e3);") lst = i+1 #call = f"{srcs[si.ast][0]}({', '.join(bufs)})" #call += " "*(80-ansilen(call)) #print(f"{call} // {i+1}") #print(srcs[si.ast][1]) main.append("}") mallocs = [f" {b.dtype.name}* {n} = ({b.dtype.name}*)malloc({b.nbytes});" for n,b in dedup(all_bufs)] with open("out.c", "w") as f: f.write('\n'.join(c_code+premain+mallocs+main))