import itertools
import numpy as np
from typing import DefaultDict, Dict, List, Set, Tuple, TypeVar, Union
from tinygrad.device import Buffer
from tinygrad.engine.realize import capturing, lower_schedule_item
from tinygrad.helpers import DEBUG, MULTIOUTPUT, colored, getenv
from tinygrad.engine.schedule import LBScheduleItem, _graph_schedule, ScheduleItem
from tinygrad.uop.ops import Ops, UOp
from tinygrad.tensor import Tensor, _to_np_dtype
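
# scheduler ContextVars to sweep: each combination of values produces its own schedule graph to fuzz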
ctx_vars = { MULTIOUTPUT: (0, 1) }
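# cap on how many toposorts find_all_toposorts enumerates per graph (the full count can grow factorially)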
FUZZ_SCHEDULE_MAX_PATHS = getenv("FUZZ_SCHEDULE_MAX_PATHS", 10)
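
# fuzz the scheduler: collect distinct toposorts of the schedule graph across all ctx_vars combinations,
# execute the first one as ground truth, then replay every other permutation on fresh Buffers and check
# that every output still matches the ground truth (atol/rtol 1e-2)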
def fuzz_schedule(outs:List[UOp]):
  # find toposorts across all tunable params
  unique_ts: Dict[Tuple[LBScheduleItem, ...], Dict[str, int]] = {}
  for combination in itertools.product(*ctx_vars.values()):
    for var, val in zip(ctx_vars, combination): var.value = val
    ctx_var_values = dict(zip([v.key for v in ctx_vars], combination))
    graph, in_degree, _ = _graph_schedule(outs)
    for ts in find_all_toposorts(graph, in_degree): unique_ts[ts] = ctx_var_values
  toposorts = list(unique_ts.items())
  if DEBUG >= 1: print(colored(f"fuzzing {len(toposorts)} schedule permutations", "yellow"))

  # setup ground truth
  ground_truth: Dict[UOp, memoryview] = {}
  assign_targets: Dict[UOp, UOp] = {}
  # IMPORTANT: freeze prerealized bufs before ScheduleItem exec
  prerealized: Dict[UOp, memoryview] = {}
  seed = Tensor._seed
  ts,_ = toposorts[0]
  for lsi in ts:
    for out in lsi.outputs:
      # freeze assign state before exec
      if out.op is Ops.ASSIGN:
        prerealized[out] = out.buffer.as_buffer()
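        # remember the assign target so later reads of it can be redirected to this ASSIGN's buffer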
        assign_targets[out.srcs[1]] = out
    for x in lsi.inputs:
      if x not in ground_truth and x.device != "NPY": prerealized[x] = x.buffer.as_buffer()
    si = ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.outputs+lsi.inputs if x.size != 0), lsi.metadata)
    _exec_si(si, seed)
    for out in lsi.outputs:
      ground_truth[out] = out.buffer.as_buffer()
      del out.srcs # only schedule the LazyBuffer in this fuzz run

  # exec and validate each permutation with new Buffers
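  # (each permutation gets fresh output Buffers; inputs are rebuilt from the frozen snapshots, NPY inputs are reused as-is)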
  for i, (ts, ctx) in enumerate(toposorts[1:]):
    if DEBUG >= 1: print(colored(f"testing permutation {i} {ctx}", "yellow"))
    rawbufs: Dict[UOp, Buffer] = {}
    for lsi in ts:
      for out in lsi.outputs:
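        # BUFFER_VIEW outputs alias their first input, so reuse that input's base allocation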
        base = rawbufs[lsi.inputs[0]].base if out.op is Ops.BUFFER_VIEW else None
        rawbufs[out] = Buffer(out.buffer.device, out.buffer.size, out.buffer.dtype, base=base)
        if out.op is Ops.ASSIGN: rawbufs[out].ensure_allocated().copyin(prerealized[out])
      for x in lsi.inputs:
        if x not in rawbufs:
          # override the assign_target after ASSIGN
          if x in assign_targets and assign_targets[x] in rawbufs: rawbufs[x] = rawbufs[assign_targets[x]]
          elif x.device == "NPY": rawbufs[x] = x.buffer
          # copy the pre-realized input
          else: rawbufs[x] = Buffer(x.buffer.device, x.buffer.size, x.buffer.dtype, initial_value=bytes(prerealized[x]))
      si = ScheduleItem(lsi.ast, tuple(rawbufs[x] for x in lsi.bufs if x.size != 0), lsi.metadata)
      _exec_si(si, seed)
      for out in lsi.outputs:
        outbuf = np.frombuffer(rawbufs[out].as_buffer(), _to_np_dtype(out.dtype))
        try: np.testing.assert_allclose(outbuf, np.frombuffer(ground_truth[out], _to_np_dtype(out.dtype)), atol=1e-2, rtol=1e-2)
        except Exception as e:
          print(f"FAILED FOR {out}")
          raise e
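
# lower a single ScheduleItem to an ExecItem and run it, adding it to the capturing list if one is active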
def _exec_si(si:ScheduleItem, seed:int):
  ei = lower_schedule_item(si)
  if len(capturing): capturing[0].add(ei)
  ei.run()

T = TypeVar("T")
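# enumerate topological orderings of graph by backtracking over every zero in-degree node at each step,
# stopping once FUZZ_SCHEDULE_MAX_PATHS orderings have been collected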
def find_all_toposorts(graph:DefaultDict[T, List[T]], in_degree:Union[DefaultDict[T, int], Dict[T, int]]) -> List[Tuple[T, ...]]:
  visited: Set[T] = set()
  ret: List[Tuple[T, ...]] = []
  path: List[T] = []

  def recurse_paths(path:List[T]):
    for v, d in in_degree.items():
      if d != 0 or v in visited: continue
      for u in graph[v]: in_degree[u] -= 1
      path.append(v)
      visited.add(v)
      recurse_paths(path)
      if len(ret) >= FUZZ_SCHEDULE_MAX_PATHS: return
      # backtrack
      for u in graph[v]: in_degree[u] += 1
      path.pop()
      visited.remove(v)
    if len(path) == len(in_degree): ret.append(tuple(path))
  recurse_paths(path)

  if len(ret) == 0: raise RuntimeError("detected cycle in the graph")
  # verify all paths are unique
  assert len(ret) == len(set(ret))
  return ret
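
if __name__ == "__main__":
  # minimal self-check of the toposort enumerator on a hand-built diamond graph a->{b,c}->d
  # (illustrative only, not part of the fuzzer; fuzz_schedule gets graph/in_degree from _graph_schedule)
  from collections import defaultdict
  g: DefaultDict[str, List[str]] = defaultdict(list, {"a": ["b", "c"], "b": ["d"], "c": ["d"]})
  deg: Dict[str, int] = {"a": 0, "b": 1, "c": 1, "d": 2}
  print(find_all_toposorts(g, deg))  # expected: [('a', 'b', 'c', 'd'), ('a', 'c', 'b', 'd')]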