# train PolicyNet with REINFORCE: sample kernel-optimization episodes, reward = fraction of baseline time saved per step
import os, math, random
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
from tinygrad.engine.search import actions, bufs_from_lin, time_linearizer, get_kernel_actions
from tinygrad.nn.optim import Adam
from extra.optimization.extract_policynet import PolicyNet
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats

if __name__ == "__main__":
  net = PolicyNet()
  if os.path.isfile("/tmp/policynet.safetensors"): load_state_dict(net, safe_load("/tmp/policynet.safetensors"))
  optim = Adam(get_parameters(net))

  ast_strs = load_worlds()

  # select a world
  all_feats, all_acts, all_rews = [], [], []
  while 1:
    Tensor.no_grad, Tensor.training = True, False
    lin = ast_str_to_lin(random.choice(ast_strs))
    rawbufs = bufs_from_lin(lin)
    tm = last_tm = base_tm = time_linearizer(lin, rawbufs)

    # take actions
    feats, acts, rews = [], [], []
    while 1:
      feat = lin_to_feats(lin)
      feats.append(feat)
      probs = net(Tensor([feat])).exp()[0].numpy()

      # mask out invalid actions and renormalize. index 0 is the "stop" action
      valid_action_mask = np.zeros((len(actions)+1), dtype=np.float32)
      for x in get_kernel_actions(lin): valid_action_mask[x] = 1
      probs *= valid_action_mask
      probs /= probs.sum()

      act = np.random.choice(len(probs), p=probs)
      acts.append(act)
      if act == 0:
        rews.append(0)
        break

      # reward for a step is the fraction of the baseline time it saved
      try:
        lin.apply_opt(actions[act-1])
        tm = time_linearizer(lin, rawbufs)
        if math.isinf(tm): raise Exception("failed")
        rews.append((last_tm-tm)/base_tm)
        last_tm = tm
      except Exception:
        rews.append(-0.5)
        break
      #print(f"{tm*1e6:10.2f}", lin.colored_shape())

    assert len(feats) == len(acts) and len(acts) == len(rews)
    #print(rews)
    print(f"***** EPISODE {len(rews)} steps, {sum(rews):5.2f} reward, {base_tm*1e6:12.2f} -> {tm*1e6:12.2f} : {lin.colored_shape()}")
    all_feats += feats
    all_acts += acts
    # rewards to go: each step gets its own reward plus all future rewards in the episode
    for i in range(len(rews)-2, -1, -1): rews[i] += rews[i+1]
    all_rews += rews

    # policy gradient update once a full batch of transitions has been collected
    BS = 32
    if len(all_feats) >= BS:
      Tensor.no_grad, Tensor.training = False, True
      x = Tensor(all_feats[:BS])
      # mask carries the reward-to-go at each (sample, taken action); net outputs log-probs
      mask = np.zeros((BS, len(actions)+1), dtype=np.float32)
      mask[range(BS), all_acts[:BS]] = all_rews[:BS]
      loss = -(net(x) * Tensor(mask)).mean()
      optim.zero_grad()
      loss.backward()
      optim.step()
      all_feats = all_feats[BS:]
      all_acts = all_acts[BS:]
      all_rews = all_rews[BS:]
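
# NOTE: get_state_dict and safe_save are imported above but never called in this
# snippet, while a checkpoint is loaded from /tmp/policynet.safetensors at startup.
# A minimal checkpointing sketch (an assumption, not part of the original loop):
# saving after each batch update would make the load path at startup round-trip,
# e.g. placed right after optim.step():
#
#   safe_save(get_state_dict(net), "/tmp/policynet.safetensors")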