import numpy as np
import math
import random
np . set_printoptions ( suppress = True )
from copy import deepcopy
from tinygrad . helpers import getenv , colored
from tinygrad . tensor import Tensor
from tinygrad . nn . state import get_parameters , get_state_dict , safe_save , safe_load , load_state_dict
from tinygrad . engine . search import bufs_from_lin , actions , get_kernel_actions
from extra . optimization . helpers import load_worlds , ast_str_to_lin , lin_to_feats , time_linearizer
from extra . optimization . extract_policynet import PolicyNet
from extra . optimization . pretrain_valuenet import ValueNet
VALUE = getenv ( " VALUE " )
if __name__ == " __main__ " :
if VALUE :
net = ValueNet ( )
load_state_dict ( net , safe_load ( " /tmp/valuenet.safetensors " ) )
else :
net = PolicyNet ( )
load_state_dict ( net , safe_load ( " /tmp/policynet.safetensors " ) )
ast_strs = load_worlds ( )
# real randomness
random . seed ( )
random . shuffle ( ast_strs )
wins = 0
for ep_num , ast_str in enumerate ( ast_strs ) :
print ( " \n EPISODE " , ep_num , f " win { wins * 100 / max ( 1 , ep_num ) : .2f } % " )
lin = ast_str_to_lin ( ast_str )
rawbufs = bufs_from_lin ( lin )
linhc = deepcopy ( lin )
linhc . hand_coded_optimizations ( )
tmhc = time_linearizer ( linhc , rawbufs )
print ( f " { tmhc * 1e6 : 10.2f } HC " , linhc . colored_shape ( ) )
pred_time = float ( ' nan ' )
tm = float ( ' inf ' )
while 1 :
if VALUE :
acts , feats = [ ] , [ ]
for k , v in get_kernel_actions ( lin ) . items ( ) :
acts . append ( k )
feats . append ( lin_to_feats ( v ) )
preds = net ( Tensor ( feats ) )
pred_time = math . exp ( preds . numpy ( ) . min ( ) )
act = acts [ preds . numpy ( ) . argmin ( ) ]
else :
probs = net ( Tensor ( [ lin_to_feats ( lin ) ] ) )
dist = probs . exp ( ) . numpy ( )
act = dist . argmax ( )
if act == 0 : break
try :
lin . apply_opt ( actions [ act - 1 ] )
except Exception :
print ( " FAILED " )
break
tm = time_linearizer ( lin , rawbufs )
print ( f " { tm * 1e6 : 10.2f } { pred_time * 1e6 : 10.2f } " , lin . colored_shape ( ) )
print ( f " { colored ( ' BEAT ' , ' green ' ) if tm < tmhc else colored ( ' lost ' , ' red ' ) } hand coded { tmhc / tm : 5.2f } x " )
wins + = int ( tm < tmhc )