You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							176 lines
						
					
					
						
							6.7 KiB
						
					
					
				
			
		
		
	
	
							176 lines
						
					
					
						
							6.7 KiB
						
					
					
				| from __future__ import annotations
 | |
| from typing import List, Optional, Dict, cast
 | |
| import numpy as np
 | |
| np.set_printoptions(suppress=True)
 | |
| import math, functools, time, random, statistics
 | |
| from tinygrad.helpers import DEBUG, getenv, CACHELEVEL, diskcache_get, diskcache_put, colored, Profiling
 | |
| from tinygrad.codegen.opt.kernel import Kernel
 | |
| from tinygrad.device import Buffer, Device, CompileError
 | |
| from tinygrad.codegen.opt.search import _ensure_buffer_alloc, get_kernel_actions, _time_program
 | |
| from tinygrad.engine.realize import get_program
 | |
| 
 | |
class MCTSNode:
  """One node of the search tree, wrapping a candidate kernel.

  Attributes:
    kernel: the kernel this node stands for.
    t: lowest time pushed through this node by `backprop` (inf until one arrives).
    n: visit strength accumulated by `backprop`.
    tm: time measured for this node itself, assigned by `mcts_search` (inf until measured).
    i: search iteration at which `mcts_search` picked this node (-1 until then).
    parents: nodes that hold this one as a child; empty when no parent was given.
    children: None until `expand_node` populates it.
    removed_children: children detached by `remove_node`.
  """
  def __init__(self, kernel:Kernel, parent=None):
    self.kernel:Kernel = kernel
    # search statistics
    self.n = 0
    self.t = self.tm = math.inf
    self.i = -1
    # tree links
    self.children: Optional[List[MCTSNode]] = None
    self.removed_children: List[MCTSNode] = []
    self.parents: List[MCTSNode] = [] if parent is None else [parent]
 | |
| 
 | |
def expand_node(node:MCTSNode):
  """Populate `node.children` with one child per action returned by `get_kernel_actions`.

  The node must not have been expanded before (`children` is still None).
  """
  assert node.children is None
  actions = get_kernel_actions(node.kernel, include_0=False)
  node.children = [MCTSNode(next_kernel, node) for next_kernel in actions.values()]
 | |
| 
 | |
def remove_node(node:MCTSNode):
  """Detach `node` from each of its parents, moving it from `children` to `removed_children`."""
  for owner in node.parents:
    live = owner.children
    assert live is not None
    live.remove(node)
    owner.removed_children.append(node)
 | |
| 
 | |
C = math.sqrt(2)  # weight of the exploration term in the UCB score below
TEMP = 0.5  # softmax temperature applied to the UCB scores when sampling a child
def _sample_tree(node:MCTSNode, best_tm:float) -> MCTSNode:
  """Walk down from `node` and return the node the search should look at next.

  At each level: an unvisited child (n == 0) is returned at random if there is one; otherwise a visited
  child is sampled with softmax(ucb/TEMP) and the walk recurses into it. A node with no children, or whose
  visited children all have a non-finite score, is returned as is.

  Args:
    node: node to descend from.
    best_tm: best time found so far, used to normalise each child's `t`.
  """
  if node.children is None or len(node.children) == 0: return node
  unexplored_children = []
  explored_children = []
  ucb_explored_children: List[float] = []
  for child in node.children:
    if child.n == 0: unexplored_children.append(child)
    else:
      ucb = -child.t/best_tm + C*math.sqrt(math.log(node.n)/child.n)
      # isfinite drops -inf (failed child) and also NaN, which shows up as -inf/inf when the child failed
      # while best_tm is still inf; a NaN weight would make np.random.choice raise below
      if math.isfinite(ucb):
        explored_children.append(child)
        ucb_explored_children.append(ucb)
  if unexplored_children: return random.choice(unexplored_children)
  if not explored_children: return node
  # safe softmax
  ucb_exp = np.exp((np.array(ucb_explored_children)-max(ucb_explored_children))/TEMP)
  return _sample_tree(explored_children[np.random.choice(len(ucb_exp), p=ucb_exp/np.sum(ucb_exp))], best_tm)
 | |
| 
 | |
# this will expand/remove sometimes
def sample_tree(root:MCTSNode, best_tm:float) -> Optional[MCTSNode]:
  """Pick the next node to evaluate, or return None once `root` has no children left.

  `root` is expanded on first use. Each round takes the node chosen by `_sample_tree`: a node that has never
  been visited (n == 0) is returned directly; a visited one is expanded if needed and a random child of it
  is returned. Nodes that turn out to have no children are removed from the tree and the walk is retried.
  """
  if root.children is None: expand_node(root)
  while root.children:
    # tree traversal
    picked = _sample_tree(root, best_tm)
    visited = picked.n != 0

    # node expansion (only for nodes that were already visited)
    if visited and picked.children is None: expand_node(picked)

    kids = picked.children
    if kids is not None and len(kids) == 0:
      # dead end: drop it and walk the tree again
      remove_node(picked)
      continue

    if not visited: return picked
    assert kids is not None
    return random.choice(kids)
  return None
 | |
| 
 | |
def backprop(bnode:MCTSNode, tm, strength=1.0):
  """Push the result `tm` from `bnode` up through all of its ancestors.

  Every node reached keeps the smaller of its current `t` and `tm`, and has its `n` increased by the
  strength that arrives at it; a node's strength is divided evenly among its parents.
  """
  # explicit stack; parents are pushed reversed so they are visited in list order, depth first
  pending = [(bnode, strength)]
  while pending:
    cur, weight = pending.pop()
    if cur.t > tm: cur.t = tm
    cur.n += weight
    pending.extend((up, weight/len(cur.parents)) for up in reversed(cur.parents))
 | |
| 
 | |
graph_mcts_cnt = 0
def mcts_search(lin:Kernel, rawbufs:List[Buffer], amt:int) -> Kernel:
  """Run a tree search over kernel optimisations of `lin` and return the fastest kernel found.

  At most `amt` nodes are sampled. A node whose optimized AST or compiled binary was already seen is removed
  from the tree and reuses the earlier node's `t`; any other node is compiled and timed on `rawbufs`.
  A compile error or a RuntimeError while timing counts as an infinite time.

  The winning `applied_opts` are stored in the disk cache under "mcts_search" when CACHELEVEL >= 1, and a
  cached entry is replayed instead of searching unless IGNORE_MCTS_CACHE is set. With MCTSGRAPH set, the
  visited tree is written to /tmp/net.<count>.mcts.dot and rendered to .svg with `dot`.
  """
  global graph_mcts_cnt
  # TODO: copied from BEAM
  key = {"ast": lin.ast.key, "amt": amt, "device": lin.opts.device, "suffix": lin.opts.suffix}
  if not getenv("IGNORE_MCTS_CACHE") and CACHELEVEL >= 1 and (hit:=diskcache_get("mcts_search", key)) is not None:
    # replay only the cached opts that lin does not already carry
    replayed = lin.copy()
    for cached_opt in hit[len(lin.applied_opts):]: replayed.apply_opt(cached_opt)
    return replayed

  rawbufs = _ensure_buffer_alloc(rawbufs)
  var_vals = {var.expr:(var.vmax+var.vmin)//2 for var in lin.ast.variables()}
  device = Device[lin.opts.device]
  root = MCTSNode(lin)

  st = time.perf_counter()
  best, best_idx, best_tm = lin, 0, math.inf
  seen_libs: Dict[bytes, MCTSNode] = {}
  seen_asts: Dict[bytes, MCTSNode] = {}
  compile_time, runtime_time = 0.0, 0.0
  for i in range(amt):
    node = sample_tree(root, best_tm)  # sample and expand
    if node is None: break  # finished the whole tree
    node.i = i  # when was node explored

    ast_key = node.kernel.get_optimized_ast().key
    twin = seen_asts.get(ast_key)
    if twin is not None:
      # early check for same optimized AST hit
      remove_node(node)
      tm = twin.t
    else:
      seen_asts[ast_key] = node

      # lowering (50% of the time)
      prog = get_program(node.kernel.get_optimized_ast(name_override="test"), node.kernel.opts)

      # rollout
      compile_st = time.perf_counter()
      # NOTE: many of these "compiler errors" are caused by bad code output from the lowerer
      try: lib = device.compiler.compile(prog.src)
      except CompileError: lib = None
      compile_et = time.perf_counter()
      if lib is None: tm = math.inf
      elif (twin:=seen_libs.get(lib)) is not None:
        # NOTE: these should all be caught by the AST check, need to canonicalize
        # remove this node, it's a duplicate
        remove_node(node)
        tm = twin.t
      else:
        seen_libs[lib] = node
        try: tm = statistics.median(_time_program(prog, lib, var_vals, rawbufs, cnt=3, early_stop=best_tm*5/1e6))*1e6
        except RuntimeError: tm = math.inf
        node.tm = tm
      run_et = time.perf_counter()
      compile_time += compile_et-compile_st
      runtime_time += run_et-compile_et

      # mock rollout
      #node.tm = tm = random.random() + 0.1

    if tm < best_tm: best, best_idx, best_tm = node.kernel, i, tm
    et = time.perf_counter() - st
    if DEBUG>=2: print(f"\r{et:7.2f}s {colored(f'{compile_time*100/et:3.0f}%', 'cyan')} {colored(f'{runtime_time*100/et:3.0f}%', 'red')}: {tm:12.2f} us     best: {best_tm:12.2f} us @ {best_idx+1:4d}      {i+1:4d}/{amt:4d}  {int(round((i+1)/et)):4d}/s     {node.kernel.colored_shape()}\033[K", end="")  # noqa: E501

    # backprop
    backprop(node, tm)
  if DEBUG>=2: print()

  if getenv("MCTSGRAPH"):
    import networkx as nx
    import os
    GRAPHPATH = "/tmp/net"

    G = nx.DiGraph()
    def add_node(node:MCTSNode):
      # nodes that were never visited are left out of the picture
      if node.n == 0: return
      for parent in node.parents: G.add_edge(parent, node)
      gopts = node.kernel.applied_opts
      edge_lbl = "ROOT" if not len(gopts) else f"{str(gopts[-1].op)[7:]} {gopts[-1].axis} {gopts[-1].arg}"
      fill = "#80ff8080" if node.tm == best_tm else "#ffff8080"
      style = 'filled' if node.t == best_tm else ''
      G.add_node(node, label=f"{node.i+1}\n{node.tm:.2f} us\n{edge_lbl}\nt {node.t:.2f}\nn {node.n}",
                 fillcolor=fill, style=style)
      if node.children is None: return
      for child in [*node.children, *node.removed_children]: add_node(child)
    add_node(root)

    # write the .dot file and render it to .svg
    fn, opt = f"{GRAPHPATH}.{graph_mcts_cnt}.mcts", '-Grankdir=LR'
    print("saving", G, f"to {fn}.svg")
    nx.drawing.nx_pydot.write_dot(G, f'{fn}.dot')
    os.system(f'dot {opt} -Tsvg {fn}.dot -o {fn}.svg')
    graph_mcts_cnt += 1

  if CACHELEVEL >= 1: diskcache_put("mcts_search", key, best.applied_opts)
  return best
 | |
| 
 |