from __future__ import annotations
from typing import List , Optional , Dict , cast
import numpy as np
# module-wide numpy print setting: show floats in fixed-point instead of scientific notation
np.set_printoptions(suppress=True)
import math , functools , time , random , statistics
from tinygrad . helpers import DEBUG , getenv , CACHELEVEL , diskcache_get , diskcache_put , colored , Profiling
from tinygrad . codegen . kernel import Kernel
from tinygrad . device import Buffer , Device , CompileError
from tinygrad . engine . search import _ensure_buffer_alloc , get_kernel_actions , _time_program
class MCTSNode:
  """A single node of the search tree: one kernel plus the bookkeeping the search keeps for it."""

  def __init__(self, kernel: Kernel, parent=None):
    """Create an unexplored, unexpanded node for `kernel`, optionally linked to `parent`."""
    # the kernel this node stands for
    self.kernel: Kernel = kernel
    # lowest time backpropagated through this node (inf until something reaches it)
    self.t = math.inf
    # accumulated backprop strength, i.e. how often this node was visited
    self.n = 0
    # time recorded for this node's own kernel (inf until it is timed)
    self.tm = math.inf
    # index of the search iteration that explored this node (-1 = never)
    self.i = -1
    # a root node has no parents; any other node starts with exactly the given one
    if parent is None:
      self.parents: List[MCTSNode] = []
    else:
      self.parents = [parent]
    # None means "not expanded yet"; an empty list means "expanded, nothing left below"
    self.children: Optional[List[MCTSNode]] = None
    # children that were detached from `children` after being pruned
    self.removed_children: List[MCTSNode] = []
def expand_node(node: MCTSNode):
  """Populate `node.children` with one new MCTSNode per kernel returned by get_kernel_actions.

  The node must not have been expanded before (`node.children` is None).
  """
  assert node.children is None
  reachable = get_kernel_actions(node.kernel, include_0=False)
  node.children = [MCTSNode(next_kernel, parent=node) for next_kernel in reachable.values()]
def remove_node(node: MCTSNode):
  """Detach `node` from every parent, moving it from `children` to `removed_children`.

  Each parent must already be expanded and must still list `node` as a child.
  """
  for owner in node.parents:
    live = owner.children
    assert live is not None
    live.remove(node)
    owner.removed_children.append(node)
# weight of the exploration term sqrt(log(parent visits) / child visits) in the UCB score
C = math.sqrt(2)
# temperature dividing the UCB scores before the softmax; lower values make sampling greedier
TEMP = 0.5
def _sample_tree(node: MCTSNode, best_tm: float) -> MCTSNode:
  """Descend from `node` and return the node the search should look at next.

  At each level:
    * a node without children (unexpanded or empty) is returned as is;
    * if any child has never been visited (`n == 0`), one of them is returned uniformly at random;
    * otherwise a visited child is drawn from a softmax over UCB scores and the descent continues in it.

  The UCB score of a child is `-child.t/best_tm + C*sqrt(log(node.n)/child.n)`. Children whose score is not
  finite are never drawn; if that leaves no candidate, `node` itself is returned.

  Args:
    node: node to start the descent from.
    best_tm: best time found so far, used to normalize `child.t`.

  Returns:
    The selected node (possibly `node` itself).
  """
  if node.children is None or len(node.children) == 0: return node
  unexplored_children = []
  explored_children = []
  ucb_explored_children: List[float] = []
  for child in node.children:
    if child.n == 0:
      unexplored_children.append(child)
      continue
    ucb = -child.t/best_tm + C*math.sqrt(math.log(node.n)/child.n)
    # Keep finite scores only. -inf means nothing under this child ever produced a time. NaN appears when
    # child.t and best_tm are both inf (nothing has been timed at all) and would make np.random.choice
    # raise "probabilities contain NaN", so it has to be filtered exactly like -inf.
    if math.isfinite(ucb):
      explored_children.append(child)
      ucb_explored_children.append(ucb)
  if unexplored_children: return random.choice(unexplored_children)
  if not explored_children: return node
  # safe softmax: subtracting the max keeps exp() from overflowing
  ucb_exp = np.exp((np.array(ucb_explored_children)-max(ucb_explored_children))/TEMP)
  return _sample_tree(explored_children[np.random.choice(len(ucb_exp), p=ucb_exp/np.sum(ucb_exp))], best_tm)
# this will expand/remove sometimes
def sample_tree(root: MCTSNode, best_tm: float) -> Optional[MCTSNode]:
  """Pick the next node to evaluate, expanding and pruning the tree along the way.

  Repeats until a node can be returned or `root` has no children left:
    * a selected node that is expanded but childless is removed and selection restarts;
    * a selected node that was never visited (`n == 0`) is returned directly;
    * a visited node is expanded if needed; if that yields no children it is removed and selection
      restarts, otherwise one of its children is returned uniformly at random.

  Returns:
    The node to evaluate, or None once `root` has no children left.
  """
  if root.children is None:
    expand_node(root)
  while root.children:
    # tree traversal
    picked = _sample_tree(root, best_tm)

    # expanded but empty: nothing left below this node
    if picked.children is not None and not picked.children:
      remove_node(picked)
      continue

    # never visited: evaluate it as is
    if picked.n == 0:
      return picked

    # node expansion
    if picked.children is None:
      expand_node(picked)
    assert picked.children is not None
    if not picked.children:
      remove_node(picked)
      continue
    return random.choice(picked.children)
  return None
def backprop(bnode: MCTSNode, tm, strength=1.0):
  """Propagate a result `tm` from `bnode` up through all of its ancestors.

  Every node on the way keeps the smallest time seen in `t` and adds `strength` to its visit counter `n`.
  When a node has several parents, the strength is divided evenly between them.
  """
  bnode.t = min(bnode.t, tm)
  bnode.n += strength
  ancestors = bnode.parents
  if not ancestors:
    return
  share = strength / len(ancestors)
  for upper in ancestors:
    backprop(upper, tm, share)
# number of search graphs written so far by mcts_search (used in the output file name when MCTSGRAPH is set)
graph_mcts_cnt = 0
def mcts_search(lin:Kernel, rawbufs:List[Buffer], amt:int) -> Kernel:
  """Run a Monte Carlo tree search over kernel optimizations and return the fastest kernel found.

  Each of the (at most) `amt` iterations samples a node with sample_tree, compiles and times its kernel
  unless an identical optimized AST or compiled binary was already seen, and backpropagates the time.

  Args:
    lin: kernel to start the search from; returned unchanged if nothing faster is timed.
    rawbufs: buffers passed to the timed program (allocated via _ensure_buffer_alloc first).
    amt: maximum number of search iterations; also part of the disk cache key.

  Returns:
    The best kernel found, or `lin` with the cached opts applied on a disk cache hit.

  Environment (read through getenv): IGNORE_MCTS_CACHE skips the cache lookup, MCTSGRAPH dumps the
  explored tree as .dot/.svg via networkx and the `dot` command line tool.
  """
  global graph_mcts_cnt
  # TODO: copied from BEAM
  key = {"ast": lin.ast.key, "amt": amt, "device": lin.opts.device, "suffix": lin.opts.suffix}
  if not getenv("IGNORE_MCTS_CACHE") and CACHELEVEL >= 1 and (val:=diskcache_get("mcts_search", key)) is not None:
    # cache hit: replay only the opts that are not already applied to lin
    ret = lin.copy()
    for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
    return ret

  rawbufs = _ensure_buffer_alloc(rawbufs)
  # every symbolic variable is pinned to the midpoint of its range for timing
  var_vals = {k:(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
  dev = Device[lin.opts.device]
  root = MCTSNode(lin)

  st = time.perf_counter()
  best, best_idx, best_tm = lin, 0, math.inf
  # first node seen for each compiled binary / optimized AST key, used to detect duplicates
  seen_libs: Dict[bytes, MCTSNode] = {}
  seen_asts: Dict[bytes, MCTSNode] = {}
  compile_time, runtime_time = 0.0, 0.0
  for i in range(amt):
    node = sample_tree(root, best_tm)  # sample and expand
    if node is None: break  # finished the whole tree
    node.i = i  # when was node explored

    opt_ast = node.kernel.get_optimized_ast()
    if (sibling_node:=seen_asts.get(opt_ast.key, None)) is not None:
      # early check for same optimized AST hit
      remove_node(node)
      tm = sibling_node.t
    else:
      seen_asts[opt_ast.key] = node

      # lowering (50% of the time)
      p = node.kernel.to_program(name_override="test")

      # rollout
      tm1 = time.perf_counter()
      try:
        lib = dev.compiler.compile(p.src)
      except CompileError:
        # NOTE: many of these "compiler errors" are caused by bad code output from the lowerer
        lib = None
      tm2 = time.perf_counter()
      if lib is None:
        # a kernel that does not compile gets an infinite time
        tm = math.inf
      else:
        if (sibling_node:=seen_libs.get(lib, None)) is not None:
          # NOTE: these should all be caught by the AST check, need to canonicalize
          # remove this node, it's a duplicate
          remove_node(node)
          tm = sibling_node.t
        else:
          seen_libs[lib] = node
          # median of 3 runs, scaled by 1e6 (printed as "us" below; presumably _time_program returns seconds)
          try: tm = statistics.median(_time_program(p, lib, var_vals, rawbufs, cnt=3, early_stop=best_tm*5/1e6))*1e6
          except RuntimeError: tm = math.inf
          node.tm = tm
      tm3 = time.perf_counter()
      compile_time += tm2-tm1
      runtime_time += tm3-tm2

    # mock rollout
    #node.tm = tm = random.random() + 0.1

    if tm < best_tm: best, best_idx, best_tm = node.kernel, i, tm
    et = time.perf_counter() - st
    # single-line progress display: elapsed, share of time spent compiling / running, current and best time
    if DEBUG>=2: print(f"\r{et:7.2f}s {colored(f'{compile_time*100/et:3.0f}%', 'cyan')} {colored(f'{runtime_time*100/et:3.0f}%', 'red')}: {tm:12.2f} us best: {best_tm:12.2f} us @ {best_idx+1:4d} {i+1:4d}/{amt:4d} {int(round((i+1)/et)):4d}/s {node.kernel.colored_shape()}\033[K", end="")  # noqa: E501

    # backprop
    backprop(node, tm)
  # finish the "\r" progress line
  if DEBUG>=2: print()

  if getenv("MCTSGRAPH"):
    # optional debug output: dump every visited node of the tree as a graph
    import networkx as nx
    import os
    GRAPHPATH = "/tmp/net"
    def save_graph(G, fn, opt=""):
      # write the .dot file, then render it to .svg with the `dot` command line tool
      print("saving", G, f"to {fn}.svg")
      nx.drawing.nx_pydot.write_dot(G, f'{fn}.dot')
      os.system(f'dot {opt} -Tsvg {fn}.dot -o {fn}.svg')

    G = nx.DiGraph()
    def add_node(node:MCTSNode):
      # nodes that were never visited are left out of the graph
      if node.n == 0: return
      for parent in node.parents: G.add_edge(parent, node)
      gopts = node.kernel.applied_opts
      # label with the last applied opt, or ROOT when no opt has been applied
      edge_lbl = f"{str(gopts[-1].op)[7:]} {gopts[-1].axis} {gopts[-1].arg}" if len(gopts) else "ROOT"
      G.add_node(node, label=f"{node.i+1}\n{node.tm:.2f} us\n{edge_lbl}\nt {node.t:.2f}\nn {node.n}",
                 fillcolor="#80ff8080" if node.tm == best_tm else "#ffff8080", style='filled' if node.t == best_tm else '')
      if node.children is not None:
        # removed (pruned) children are drawn too
        for child in node.children+node.removed_children: add_node(child)
    add_node(root)
    save_graph(G, f"{GRAPHPATH}.{graph_mcts_cnt}.mcts", '-Grankdir=LR')
    graph_mcts_cnt += 1

  if CACHELEVEL >= 1: diskcache_put("mcts_search", key, best.applied_opts)
  return best