import  os ,  atexit ,  functools 
 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								try : 
 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  import  networkx  as  nx   # type: ignore 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								except  ImportError : 
 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  nx  =  None  # graph won't work 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								from  collections  import  defaultdict 
 
						 
					
						
							
								
							 
							
								
									
										 
								
							 
							
								 
							
							
								from  typing  import  Dict ,  List 
 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								from  tinygrad . ops  import  ScheduleItem ,  UnaryOps ,  BinaryOps ,  ReduceOps ,  MovementOps ,  LoadOps ,  BufferOps ,  TernaryOps ,  Op ,  OpType ,  LazyOp 
 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								from  tinygrad . helpers  import  GRAPH ,  GRAPHPATH ,  DEBUG ,  GlobalCounters ,  getenv ,  dedup 
 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								from  tinygrad . codegen . linearizer  import  UOps 
 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								# **** debugging and graphing **** 
 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								G  =  nx . DiGraph ( )  if  nx  is  not  None  else  None 
 
						 
					
						
							
								
							 
							
								
									
										 
								
							 
							
								 
							
							
								cnts :  Dict [ OpType ,  int ]  =  defaultdict ( int ) 
 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								if  DEBUG  > =  2 : 
 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  def  print_globalcounters ( ) : 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  GlobalCounters . time_sum_s  ==  0 :  return 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    print ( f " avg:  { GlobalCounters . global_ops * 1e-9 / GlobalCounters . time_sum_s : 8.2f }  GFLOPS  { GlobalCounters . global_mem * 1e-9 / GlobalCounters . time_sum_s : 8.2f }  GB/s " , 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          f " { '   ' * 10 } total:  { GlobalCounters . kernel_count : 5d }  kernels  { GlobalCounters . global_ops * 1e-9 : 8.2f }  GOPS  { GlobalCounters . global_mem * 1e-9 : 8.2f }  GB  { GlobalCounters . time_sum_s * 1e3 : 8.2f }  ms " ) 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  atexit . register ( print_globalcounters ) 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								if  GRAPH : 
 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  def  save_graph_exit ( ) : 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    for  k , v  in  cnts . items ( ) :  print ( k ,  v ) 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    print ( " saving " ,  G ) 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    nx . drawing . nx_pydot . write_dot ( G ,  f ' { GRAPHPATH } .dot ' ) 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    # -Gnslimit=100 can make it finish, but you won't like results 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    os . system ( f ' dot -Tsvg  { GRAPHPATH } .dot -o  { GRAPHPATH } .svg ' ) 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  atexit . register ( save_graph_exit ) 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								node_count  =  0 
 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								def  nm ( x ) : 
 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  global  node_count 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  if  not  hasattr ( x ,  ' node_id ' ) : 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    setattr ( x ,  ' node_id ' ,  node_count ) 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    node_count  + =  1 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  return  x . node_id 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
									
										 
								
							 
							
								 
							
							
								def  get_sop ( op :  List [ Op ] ) : 
 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  op  =  [ x  for  x  in  op  if  x  not  in  BufferOps ] 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  if  len ( op )  < =  2 :  return  ' . ' . join ( [ str ( y ) . split ( " . " ) [ 1 ]  for  y  in  op ] [ : : - 1 ] ) 
  
						 
					
						
							
								
							 
							
								
									
										 
								
							 
							
								 
							
							
								  if  len ( op )  < =  6 :  return  ' . ' . join ( [ str ( y ) . split ( " . " ) [ 1 ] [ 0 : 3 ]  for  y  in  op ] [ : : - 1 ] ) 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  return  str ( len ( op ) ) 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
									
										 
								
							 
							
								 
							
							
								def  str_dtype ( dtyp ) : 
 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  ret  =  str ( dtyp ) [ 7 : ] 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  return  " "  if  ret  ==  ' float '  else  f " \n { ret } " 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								@functools . lru_cache ( None ) 
 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								def  add_st_node ( nmx ,  nmo ,  label ,  st ) : 
 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  global  node_count 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  inter_node  =  node_count 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  node_count  + =  1 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  G . add_node ( inter_node ,  style = ' filled ' ,  fillcolor = " #80ff8080 " ,  color = " black " ,  label = f " { st . shape } \n { st . real_strides ( ) } "  +  ( f " \n { st . real_offset ( ) } "  if  st . real_offset ( )  !=  0  else  " " ) ) 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  G . add_edge ( nmx ,  inter_node ,  color = ' #00000060 ' ) 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  G . add_edge ( inter_node ,  nmo ,  label = label ,  color = ' #00000060 ' ) 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								logops  =  open ( getenv ( " LOGOPS " ,  " " ) , " a " )  if  getenv ( " LOGOPS " ,  " " )  else  None 
 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								def  log_schedule_item ( si :  ScheduleItem ) : 
 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  if  logops  and  si . ast . op  not  in  LoadOps :  logops . write ( str ( si . ast ) + " \n " ) 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  show_graph  =  bool ( GRAPH ) 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  if  not  DEBUG  and  not  show_graph :  return 
  
						 
					
						
							
								
							 
							
								
									
										 
								
							 
							
								 
							
							
								  if  si . ast . op  ==  LoadOps . CONTIGUOUS :  setattr ( si . out ,  ' node_id ' ,  nm ( si . inputs [ 0 ] . base ) ) 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  if  si . ast . op  in  { LoadOps . CONST ,  LoadOps . CONTIGUOUS } :  return 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  op :  List [ Op ]  =  [ x . op  for  x  in  si . ast . get_lazyops ( ) ] 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  oporder  =  [ LoadOps ,  TernaryOps ,  ReduceOps ,  BinaryOps ,  UnaryOps ,  MovementOps ,  BufferOps ] 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  optype  =  type ( sorted ( op ,  key = lambda  x :  oporder . index ( type ( x ) ) ) [ 0 ] ) 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  cnts [ optype ]  + =  1 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  if  show_graph : 
  
						 
					
						
							
								
							 
							
								
									
										 
								
							 
							
								 
							
							
								    assert  si . out . base  ==  si . out ,  " all outputs based " 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    top_colors  =  { LoadOps :  ' #FFFFa0 ' ,  UnaryOps :  " #c0c0c0 " ,  ReduceOps :  " #8080ff " ,  BinaryOps :  " #c0c0c0 " ,  MovementOps :  " #80ff80 " ,  TernaryOps :  " #c0c0c0 " ,  BufferOps :  ' #FF8080 ' } 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
									
										 
								
							 
							
								 
							
							
								    # get inputs for shapetrackers 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    input_to_st  =  defaultdict ( list ) 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    for  lo  in  si . ast . get_lazyops ( ) : 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      if  lo . op  !=  BufferOps . MEM :  continue 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      input_to_st [ si . inputs [ lo . arg . idx - 1 ] ] . append ( lo . arg . st ) 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    # add them to the graph, potentially with a movement op seperating them 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    for  x  in  input_to_st : 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      for  st  in  dedup ( input_to_st [ x ] ) : 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  st . contiguous : 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          G . add_edge ( nm ( x ) ,  nm ( si . out ) ,  label = get_sop ( op ) ,  color = ' #00000060 ' ) 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        else : 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          add_st_node ( nm ( x ) ,  nm ( si . out ) ,  get_sop ( op ) ,  st ) 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      if  ' label '  not  in  G . nodes [ nm ( x ) ] : 
  
						 
					
						
							
								
							 
							
								
									
										 
								
							 
							
								 
							
							
								        G . nodes [ nm ( x ) ] [ ' label ' ]  =  str ( x . shape ) + str_dtype ( si . out . dtype ) 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  nm ( si . out )  not  in  G . nodes :  G . add_node ( nm ( si . out ) ) 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    G . nodes [ nm ( si . out ) ] [ ' label ' ]  =  ( str ( set ( x . shape  for  x  in  si . inputs ) ) + " \n " + str ( si . out . shape )  if  optype  ==  ReduceOps  else  str ( si . out . shape ) ) + str_dtype ( si . out . dtype ) + ( f " \n { si . ast . op } "  if  si . ast . op  in  LoadOps  else  " " ) 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    G . nodes [ nm ( si . out ) ] [ ' fillcolor ' ]  =  top_colors [ optype ] 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    G . nodes [ nm ( si . out ) ] [ ' color ' ]  =  ' black ' 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    G . nodes [ nm ( si . out ) ] [ ' style ' ]  =  ' filled ' 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								def  _tree ( lazydata ,  prefix = " " ) : 
 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  if  type ( lazydata ) . __name__  ==  " LazyBuffer " :  return  [ f " ━━ realized  { lazydata . dtype . name }   { lazydata . shape } " ]  if  ( lazydata . realized )  else  _tree ( lazydata . op ,  " LB  " ) 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  if  len ( lazydata . src )  ==  0 :  return  [ f " ━━  { prefix } { lazydata . op . name }   { lazydata . arg  if  lazydata . arg  else  ' ' } " ] 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  lines  =  [ f " ━┳  { prefix } { lazydata . op . name }   { lazydata . arg  if  lazydata . arg  else  ' ' } " ] 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  childs  =  [ _tree ( c )  for  c  in  lazydata . src [ : ] ] 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  for  c  in  childs [ : - 1 ] :  lines  + =  [ f "  ┣ { c [ 0 ] } " ]  +  [ f "  ┃ { l } "  for  l  in  c [ 1 : ] ] 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  return  lines  +  [ "  ┗ " + childs [ - 1 ] [ 0 ] ]  +  [ "    " + l  for  l  in  childs [ - 1 ] [ 1 : ] ] 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
									
										 
								
							 
							
								 
							
							
								def  print_tree ( lazydata : LazyOp ) :  print ( " \n " . join ( [ f " { str ( i ) . rjust ( 3 ) }   { s } "  for  i , s  in  enumerate ( _tree ( lazydata ) ) ] ) ) 
 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
									
										 
								
							 
							
								 
							
							
								def  graph_uops ( uops ) : 
 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  colors  =  { UOps . ALU :  " #ffffc0 " ,  UOps . LOAD :  " #ffc0c0 " ,  UOps . STORE :  " #c0ffc0 " ,  UOps . SPECIAL :  " #c0c0ff " ,  UOps . CONST :  " #e0e0e0 " , 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            UOps . DEFINE_GLOBAL :  " #ffe0b0 " ,  UOps . DEFINE_LOCAL :  " #ffe0d0 " ,  UOps . DEFINE_ACC :  " #f0ffe0 " , 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            UOps . LOOP :  " #c8a0e0 " ,  UOps . PHI :  " #e0ffc0 " } 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  G  =  nx . DiGraph ( ) 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  for  u  in  uops : 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    G . add_node ( u . num ,  label = f " { str ( u . uop ) [ 5 : ] } { ( '   ' + str ( u . arg ) )  if  u . arg  is  not  None  else  ' ' } \n { str ( u . dtype ) } " ,  style = " filled " ,  fillcolor = colors . get ( u . uop ,  " #ffffff " ) ) 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    for  v  in  u . vin :  G . add_edge ( v . num ,  u . num ) 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  GRAPHPATH  =  " /tmp/uops " 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  nx . drawing . nx_pydot . write_dot ( G ,  f ' { GRAPHPATH } .dot ' ) 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  os . system ( f ' dot -Grankdir=LR -Tsvg  { GRAPHPATH } .dot -o  { GRAPHPATH } .svg ' )