# Trains a policy network that, given kernel features, predicts the next
# beam-search action to apply (index 0 = "stop"). Training data is replayed
# from a beam-search sqlite cache; weights go to /tmp/policynet.safetensors.
import os, sys, sqlite3, pickle, random
from tqdm import tqdm, trange
from copy import deepcopy
from tinygrad.nn import Linear
from tinygrad.tensor import Tensor
from tinygrad.nn.optim import Adam
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
from tinygrad.engine.search import actions
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats, assert_same_lin
from tinygrad.codegen.kernel import Kernel
from tinygrad.helpers import getenv

# stuff needed to unpack a kernel -- these names must be in scope for the eval() below
from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer
from tinygrad.dtype import dtypes
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.ops import Variable
inf, nan = float('inf'), float('nan')
from tinygrad.codegen.kernel import Opt, OptOps

INNER = 256

class PolicyNet:
  """Three-layer MLP: 1021 kernel features -> log-probs over 1+len(actions) choices (index 0 = stop)."""
  def __init__(self):
    self.l1 = Linear(1021, INNER)
    self.l2 = Linear(INNER, INNER)
    self.l3 = Linear(INNER, 1+len(actions))

  def __call__(self, x):
    x = self.l1(x).relu()
    x = self.l2(x).relu().dropout(0.9)
    return self.l3(x).log_softmax()

def dataset_from_cache(fn):
  """Replay every beam_search cache row into parallel (features, action-index) lists.

  For each cached kernel the opt sequence is re-applied, emitting the feature
  vector before each opt (label = that opt's index in `actions`) plus a final
  (features, 0) pair marking "stop". Rows that fail to replay contribute only
  the pairs produced before the failure.
  """
  conn = sqlite3.connect(fn)
  try:
    cur = conn.cursor()
    cur.execute("SELECT * FROM beam_search")
    rows = cur.fetchall()
  finally:
    conn.close()  # BUGFIX: release the sqlite handle before the slow replay loop
  X, A = [], []
  for f in tqdm(rows):
    Xs, As = [], []
    try:
      # SECURITY: eval() and pickle.loads() execute arbitrary code -- only run
      # this against a beam-search cache you generated yourself.
      lin = Kernel(eval(f[0]))
      opts = pickle.loads(f[-1])
      for o in opts:
        Xs.append(lin_to_feats(lin, use_sts=True))
        As.append(actions.index(o))
        lin.apply_opt(o)
      Xs.append(lin_to_feats(lin, use_sts=True))
      As.append(0)  # 0 = stop: no further opt after the full sequence
    except Exception:
      # best effort: stale cache entries may not replay against current
      # tinygrad; keep whatever pairs were produced before the failure
      pass
    X += Xs
    A += As
  return X, A

if __name__ == "__main__":
  if getenv("REGEN"):
    X, V = dataset_from_cache(sys.argv[1] if len(sys.argv) > 1 else "/tmp/tinygrad_cache")
    safe_save({"X": Tensor(X), "V": Tensor(V)}, "/tmp/dataset_policy")
    # BUGFIX: continue with numpy arrays -- the Python lists above have no
    # .shape and don't support fancy indexing; match the load path below
    X, V = Tensor(X).numpy(), Tensor(V).numpy()
  else:
    ld = safe_load("/tmp/dataset_policy")
    X, V = ld['X'].numpy(), ld['V'].numpy()
  print(X.shape, V.shape)

  # shuffle, then hold out the last 256 samples as a test set
  order = list(range(X.shape[0]))
  random.shuffle(order)
  X, V = X[order], V[order]
  ratio = -256
  X_test, V_test = Tensor(X[ratio:]), Tensor(V[ratio:])
  X, V = X[:ratio], V[:ratio]
  print(X.shape, V.shape)

  net = PolicyNet()
  #if os.path.isfile("/tmp/policynet.safetensors"): load_state_dict(net, safe_load("/tmp/policynet.safetensors"))
  optim = Adam(get_parameters(net))

  def get_minibatch(X, Y, bs):
    # sample bs rows with replacement
    xs, ys = [], []
    for _ in range(bs):
      sel = random.randint(0, len(X)-1)
      xs.append(X[sel])
      ys.append(Y[sel])
    return Tensor(xs), Tensor(ys)

  Tensor.no_grad, Tensor.training = False, True
  losses = []
  test_losses = []
  test_accuracy = 0
  test_loss = float('inf')
  for i in (t := trange(500)):
    x, y = get_minibatch(X, V, bs=256)
    out = net(x)
    loss = out.sparse_categorical_crossentropy(y)
    optim.zero_grad()
    loss.backward()
    optim.step()
    cat = out.argmax(axis=-1)
    accuracy = (cat == y).mean()
    t.set_description(f"loss {loss.numpy():7.2f} accuracy {accuracy.numpy()*100:7.2f}%, test loss {test_loss:7.2f} test accuracy {test_accuracy*100:7.2f}%")
    losses.append(loss.numpy().item())
    test_losses.append(test_loss)
    if i % 10 == 0:  # BUGFIX: `if i % 10:` evaluated on 9 of every 10 steps; evaluate every 10th instead
      Tensor.training = False  # BUGFIX: disable the 0.9 dropout while measuring test metrics
      out = net(X_test)
      # BUGFIX: the crossentropy is already a scalar mean; the old
      # .square().mean() made test loss incomparable to the train loss
      test_loss = out.sparse_categorical_crossentropy(V_test).numpy().item()
      cat = out.argmax(axis=-1)
      test_accuracy = (cat == V_test).mean().numpy()  # BUGFIX: was compared against the train minibatch `y`
      Tensor.training = True
  safe_save(get_state_dict(net), "/tmp/policynet.safetensors")

  import matplotlib.pyplot as plt
  plt.plot(losses[10:])
  plt.plot(test_losses[10:])
  plt.show()