#!/usr/bin/env python3
import multiprocessing , pickle , functools , difflib , os , threading , json , time , sys , webbrowser , socket , argparse , decimal , socketserver
from http . server import BaseHTTPRequestHandler
from urllib . parse import parse_qs , urlparse
from typing import Any , Callable , TypedDict , Generator
from tinygrad . helpers import colored , getenv , tqdm , unwrap , word_wrap
from tinygrad . ops import TrackedGraphRewrite , UOp , Ops , lines , GroupOp
from tinygrad . codegen . kernel import Kernel
from tinygrad . device import ProfileEvent , ProfileDeviceEvent , ProfileRangeEvent , ProfileGraphEvent
from tinygrad . dtype import dtypes
# fill color of each op's node in the rendered graph; ops missing here fall back to "#ffffff" in uop_to_json
# NOTE: later entries win on duplicate keys, so the group-wide spreads and the ops after them must keep their position
uops_colors = {
  Ops.LOAD: "#ffc0c0",
  Ops.STORE: "#87CEEB",
  Ops.CONST: "#e0e0e0",
  Ops.VCONST: "#e0e0e0",
  Ops.REDUCE: "#FF5B5B",
  Ops.DEFINE_GLOBAL: "#ffe0b0",
  Ops.DEFINE_LOCAL: "#ffe0d0",
  Ops.DEFINE_ACC: "#f0ffe0",
  Ops.REDUCE_AXIS: "#FF6B6B",
  Ops.RANGE: "#c8a0e0",
  Ops.ASSIGN: "#e0ffc0",
  Ops.BARRIER: "#ff8080",
  Ops.IF: "#c8b0c0",
  Ops.SPECIAL: "#c0c0ff",
  Ops.INDEX: "#e8ffa0",
  Ops.WMMA: "#efefc0",
  Ops.VIEW: "#C8F9D4",
  Ops.MULTI: "#f6ccff",
  Ops.KERNEL: "#3e7f55",
  Ops.IGNORE: "#00C000",
  # one shared color per op group
  **dict.fromkeys(GroupOp.Movement, "#D8F9E4"),
  **dict.fromkeys(GroupOp.ALU, "#ffffc0"),
  Ops.THREEFRY: "#ffff80",
  Ops.BUFFER_VIEW: "#E5EAFF",
  Ops.BLOCK: "#C4A484",
  Ops.BLOCKEND: "#C4A4A4",
  Ops.BUFFER: "#B0BDFF",
  Ops.COPY: "#a040a0",
  Ops.NAME: "#808080",
}
# VIZ API
# NOTE: if any extra rendering in VIZ fails, we don't crash
def pcall(fxn:Callable[..., str], *args, **kwargs) -> str:
  """Call `fxn(*args, **kwargs)` and return its result, or an error string if it raises.

  Lets optional VIZ rendering fail without taking the server down: any `Exception` is turned into
  "ERROR in <name>: <message>". Callables without a `__name__` (e.g. `functools.partial`) are named by their repr.
  """
  try: return fxn(*args, **kwargs)
  except Exception as e:
    # building the message must never raise itself, otherwise this guard defeats its own purpose
    return f"ERROR in {getattr(fxn, '__name__', repr(fxn))}: {e}"
# ** Metadata for a track_rewrites scope
class GraphRewriteMetadata ( TypedDict ) :
loc : tuple [ str , int ] # [path, lineno] calling graph_rewrite
match_count : int # total match count in this context
code_line : str # source code calling graph_rewrite
kernel_code : str | None # optionally render the final kernel code
name : str | None # optional name of the rewrite
@functools.lru_cache(maxsize=None)
def render_program(k:Kernel):
  """Render the uops of kernel `k` through `k.opts.render`; results are memoized per kernel argument."""
  renderer = k.opts
  return renderer.render(k.uops)
def to_metadata(k:Any, v:TrackedGraphRewrite) -> GraphRewriteMetadata:
  """Build the metadata entry for the tracked rewrite `v` recorded under key `k`."""
  # the stripped source line at the recorded [path, lineno] (lineno is 1-based)
  code_line = lines(v.loc[0])[v.loc[1]-1].strip()
  # kernel code is only rendered for Kernel keys; a render failure becomes an error string via pcall
  kernel_code = pcall(render_program, k) if isinstance(k, Kernel) else None
  return {"loc":v.loc, "match_count":len(v.matches), "code_line":code_line, "kernel_code":kernel_code, "name":v.name}
def get_metadata(keys:list[Any], contexts:list[list[TrackedGraphRewrite]]) -> list[tuple[str, list[GraphRewriteMetadata]]]:
  """Pair every key with the metadata of its tracked rewrites.

  The display name is `key.name` for Kernel keys and `str(key)` otherwise. `keys` and `contexts` are walked in lockstep.
  """
  ret: list[tuple[str, list[GraphRewriteMetadata]]] = []
  for key, rewrites in zip(keys, contexts):
    title = key.name if isinstance(key, Kernel) else str(key)
    ret.append((title, [to_metadata(key, rw) for rw in rewrites]))
  return ret
# ** Complete rewrite details for a graph_rewrite call
class GraphRewriteDetails ( TypedDict ) :
graph : dict # JSON serialized UOp for this rewrite step
uop : str # strigified UOp for this rewrite step
diff : list [ str ] | None # string diff of the single UOp that changed
changed_nodes : list [ int ] | None # the changed UOp id + all its parents ids
upat : tuple [ tuple [ str , int ] , str ] | None # [loc, source_code] of the matched UPat
def uop_to_json(x:UOp) -> dict[int, dict]:
  """Serialize the graph rooted at `x` into a JSON-friendly dict keyed by `id(uop)`.

  Every entry holds a "label" (op name, wrapped arg, dtype and a line per excluded source), the ids of its
  non-excluded sources under "src", and a "color" taken from `uops_colors` ("#ffffff" if the op is not listed).
  Excluded nodes get no entry of their own; they only show up inside the label of the node that uses them.
  """
  assert isinstance(x, UOp)
  nodes = x.toposort
  graph: dict[int, dict] = {}
  hidden: set[UOp] = set()
  # pass 1: decide which nodes are folded into their consumer's label
  for u in nodes:
    # always exclude DEVICE/CONST/UNIQUE
    if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE}: hidden.add(u)
    # only exclude CONST VIEW source if it has no other children in the graph
    if u.op is Ops.CONST and len(u.src) != 0:
      # children are callables that may return None; only those still alive and part of this graph count
      live_children = [cr for c in u.src[0].children if (cr:=c()) is not None and cr in nodes]
      if all(child.op is Ops.CONST for child in live_children): hidden.update(u.src)
  # pass 2: emit one entry per remaining node
  for u in nodes:
    if u in hidden: continue
    argst = str(u.arg)
    if u.op is Ops.VIEW:
      # one line per view: shape / strides, then the mask (if any) and the offset (unless it equals 0)
      view_lines = []
      for v in unwrap(u.st).views:
        line = f"{v.shape} / {v.strides}"
        if v.mask is not None: line += f"\nMASK {v.mask}"
        line += "" if v.offset == 0 else f" / {v.offset}"
        view_lines.append(line)
      argst = "\n".join(view_lines)
    parts = [str(u.op).split('.')[1]]
    if u.arg is not None: parts.append("\n" + word_wrap(argst.replace(':', '')))
    if u.dtype != dtypes.void: parts.append(f"\n{u.dtype}")
    for idx, s in enumerate(u.src):
      if s not in hidden: continue
      if s.op is Ops.CONST and dtypes.is_float(u.dtype): parts.append(f"\nCONST{idx} {s.arg:g}")
      else: parts.append(f"\n{s.op.name}{idx} {s.arg}")
    graph[id(u)] = {"label":"".join(parts), "src":[id(s) for s in u.src if s not in hidden], "color":uops_colors.get(u.op, "#ffffff")}
  return graph
def get_details(k:Any, ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]:
  """Yield the initial graph of `ctx`, then one GraphRewriteDetails per recorded match.

  `k` is accepted but not used here. Replacements accumulate in one dict across steps. When `ctx.bottom_up` is
  false the substitution base advances to the newly rewritten sink after each step, otherwise it stays at `ctx.sink`.
  """
  base_sink = ctx.sink
  yield {"graph":uop_to_json(base_sink), "uop":str(ctx.sink), "changed_nodes":None, "diff":None, "upat":None}
  replaces: dict[UOp, UOp] = {}
  for before, after, upat in tqdm(ctx.matches):
    replaces[before] = after
    stepped = base_sink.substitute(replaces)
    step_graph = uop_to_json(stepped)
    uop_str = str(stepped)
    # only report nodes of the replacement that are actually drawn in this step's graph
    touched = [id(n) for n in after.toposort if id(n) in step_graph]
    diff = list(difflib.unified_diff(pcall(str, before).splitlines(), pcall(str, after).splitlines()))
    yield {"graph":step_graph, "uop":uop_str, "changed_nodes":touched, "diff":diff, "upat":(upat.location, upat.printable())}
    if not ctx.bottom_up: base_sink = stepped
# Profiler API
# per-device profiling state, written by dev_ev_to_perfetto_json and read by prep_ts / dev_to_pid:
#   device name -> (comp_tdiff, copy_tdiff or comp_tdiff when the event has none, pid = number of devices known at registration)
devices:dict[str, tuple[decimal.Decimal, decimal.Decimal, int]] = {}
def prep_ts(device:str, ts:decimal.Decimal, is_copy):
  """Return `ts` shifted by the time offset registered for `device`, converted to int.

  `is_copy` is used directly as the index into the device's entry (0 -> first offset, 1 -> second offset).
  """
  shifted = decimal.Decimal(ts) + devices[device][is_copy]
  return int(shifted)
def dev_to_pid(device:str, is_copy=False):
  """Return the {"pid", "tid"} pair for `device`: pid is the device's registered index, tid is `int(is_copy)`."""
  pid = devices[device][2]
  return {"pid":pid, "tid":int(is_copy)}
def dev_ev_to_perfetto_json(ev:ProfileDeviceEvent):
  """Register `ev.device` in `devices` and return its metadata records: one process name plus two thread names."""
  # a device without its own copy offset reuses the compute offset
  copy_tdiff = ev.copy_tdiff if ev.copy_tdiff is not None else ev.comp_tdiff
  devices[ev.device] = (ev.comp_tdiff, copy_tdiff, len(devices))
  pid = dev_to_pid(ev.device)['pid']
  records = [{"name":"process_name", "ph":"M", "pid":pid, "args":{"name":ev.device}}]
  for tid, thread_name in enumerate(("COMPUTE", "COPY")):
    records.append({"name":"thread_name", "ph":"M", "pid":pid, "tid":tid, "args":{"name":thread_name}})
  return records
def range_ev_to_perfetto_json(ev:ProfileRangeEvent):
  """Return a one-element list with the "X" record for the range event `ev` (duration is `ev.en - ev.st`)."""
  record = {"name":ev.name, "ph":"X", "ts":prep_ts(ev.device, ev.st, ev.is_copy), "dur":float(ev.en-ev.st)}
  record.update(dev_to_pid(ev.device, ev.is_copy))
  return [record]
def graph_ev_to_perfetto_json(ev:ProfileGraphEvent, reccnt):
  """Return the records for the graph event `ev`: one "X" record per entry plus an "s"/"f" pair per dependency.

  `reccnt` is added to the local record count to form the id shared by each "s"/"f" pair.
  """
  out = []
  for i, ent in enumerate(ev.ents):
    start, end = ev.sigs[ent.st_id], ev.sigs[ent.en_id]
    out.append({"name":ent.name, "ph":"X", "ts":prep_ts(ent.device, start, ent.is_copy), "dur":float(end-start),
                **dev_to_pid(ent.device, ent.is_copy)})
    for dep_idx in ev.deps[i]:
      dep = ev.ents[dep_idx]
      # both halves of the pair share one id, derived from the record count before the pair is added
      flow_id = reccnt + len(out)
      out.append({"ph":"s", **dev_to_pid(dep.device, dep.is_copy), "id":flow_id,
                  "ts":prep_ts(dep.device, ev.sigs[dep.en_id], dep.is_copy), "bp":"e"})
      out.append({"ph":"f", **dev_to_pid(ent.device, ent.is_copy), "id":flow_id,
                  "ts":prep_ts(ent.device, start, ent.is_copy), "bp":"e"})
  return out
def to_perfetto(profile:list[ProfileEvent]):
  """Convert `profile` to an encoded {"traceEvents": [...]} JSON document, or None when no records were produced."""
  # Start json with devices.
  prof_json = []
  for ev in profile:
    if isinstance(ev, ProfileDeviceEvent): prof_json.extend(dev_ev_to_perfetto_json(ev))
  # second pass: range and graph events, which look up the devices registered above
  for ev in tqdm(profile, desc="preparing profile"):
    if isinstance(ev, ProfileRangeEvent): prof_json.extend(range_ev_to_perfetto_json(ev))
    elif isinstance(ev, ProfileGraphEvent): prof_json.extend(graph_ev_to_perfetto_json(ev, reccnt=len(prof_json)))
  if len(prof_json) == 0: return None
  return json.dumps({"traceEvents":prof_json}).encode()
# ** HTTP server
class Handler(BaseHTTPRequestHandler):
  """GET handler for the VIZ pages, static assets, kernel metadata / rewrite details and the profile.

  Routes (matched on the URL path, query string excluded):
    /                  -> index.html next to this file
    /profiler          -> perfetto.html next to this file
    /assets/*, /js/*   -> static file next to this file, 404 if it cannot be read
    /kernels           -> JSON of `kernels`; with ?kernel=<i>[&idx=<j>] a server-sent-event stream of get_details
    /get_profile       -> `perfetto_profile` when one is loaded
  Anything else is answered with 404.
  """

  def _stream_details(self, key, rewrite):
    """Stream every step of `get_details(key, rewrite)` as server-sent events, terminated by "data: END"."""
    try:
      # stream details
      self.send_response(200)
      self.send_header("Content-Type", "text/event-stream")
      self.send_header("Cache-Control", "no-cache")
      self.end_headers()
      for r in get_details(key, rewrite):
        self.wfile.write(f"data: {json.dumps(r)}\n\n".encode("utf-8"))
        self.wfile.flush()
      self.wfile.write("data: END\n\n".encode("utf-8"))
      self.wfile.flush()
    # pass if client closed connection
    except (BrokenPipeError, ConnectionResetError): return

  def do_GET(self):
    ret, status_code, content_type = b"", 200, "text/html"
    url = urlparse(self.path)
    base_dir = os.path.dirname(__file__)
    if url.path == "/":
      with open(os.path.join(base_dir, "index.html"), "rb") as f: ret = f.read()
    elif url.path == "/profiler":
      with open(os.path.join(base_dir, "perfetto.html"), "rb") as f: ret = f.read()
    elif url.path.startswith(("/assets/", "/js/")) and '/..' not in url.path:
      try:
        # resolve the file from the path component only, so a query string (e.g. ?v=1) is not part of the filename
        with open(os.path.join(base_dir, url.path.strip('/')), "rb") as f: ret = f.read()
        if url.path.endswith(".js"): content_type = "application/javascript"
        if url.path.endswith(".css"): content_type = "text/css"
      # missing file, directory or unreadable file: all are "not found" for the client
      except OSError: status_code = 404
    elif url.path == "/kernels":
      query = parse_qs(url.query)
      if "kernel" not in query: ret, content_type = json.dumps(kernels).encode(), "application/json"
      else:
        def getarg(k:str, default=0): return int(query[k][0]) if k in query else default
        # resolve the selection before any header is sent, so a bad request still gets a proper status code
        try:
          kidx, ridx = getarg("kernel"), getarg("idx")
          key, rewrite = contexts[0][kidx], contexts[1][kidx][ridx]
        except ValueError: status_code = 400
        # out of range index, or no kernels loaded at all (contexts is None)
        except (IndexError, TypeError): status_code = 404
        else: return self._stream_details(key, rewrite)
    elif url.path == "/get_profile" and perfetto_profile is not None: ret, content_type = perfetto_profile, "application/json"
    else: status_code = 404

    # send response
    self.send_response(status_code)
    self.send_header('Content-Type', content_type)
    self.send_header('Content-Length', str(len(ret)))
    self.end_headers()
    return self.wfile.write(ret)
# ** main loop
def reloader():
  """Poll this file's mtime every 0.1s until `stop_reloader` is set; on a change, re-exec the process with the same argv."""
  initial_mtime = os.stat(__file__).st_mtime
  while not stop_reloader.is_set():
    if os.stat(__file__).st_mtime != initial_mtime:
      print("reloading server...")
      # replaces the current process image, so nothing after this line runs
      os.execv(sys.executable, [sys.executable, *sys.argv])
    time.sleep(0.1)
def load_pickle ( path : str ) :
if path is None or not os . path . exists ( path ) : return None
with open ( path , " rb " ) as f : return pickle . load ( f )
# NOTE: using HTTPServer forces a potentially slow socket.getfqdn
class TCPServerWithReuse(socketserver.TCPServer):
  """Plain `socketserver.TCPServer` with `allow_reuse_address` switched on."""
  allow_reuse_address = True
if __name__ == "__main__":
  # entry point: parse args, load the pickled kernels/profile, then serve them until interrupted
  parser = argparse.ArgumentParser()
  parser.add_argument('--kernels', type=str, help='Path to kernels', default=None)
  parser.add_argument('--profile', type=str, help='Path profile', default=None)
  args = parser.parse_args()

  # refuse to start when something already accepts connections on the port
  HOST, PORT = "http://127.0.0.1", getenv("PORT", 8000)
  with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    if s.connect_ex((HOST.replace("http://", ""), PORT)) == 0:
      raise RuntimeError(f"{HOST}:{PORT} is occupied! use PORT= to change.")
  stop_reloader = threading.Event()
  multiprocessing.current_process().name = "VizProcess"    # disallow opening of devices
  st = time.perf_counter()
  print("*** viz is starting")

  contexts, profile = load_pickle(args.kernels), load_pickle(args.profile)

  # NOTE: this context is a tuple of list[keys] and list[values]
  kernels = get_metadata(*contexts) if contexts is not None else []

  perfetto_profile = to_perfetto(profile) if profile is not None else None

  server = TCPServerWithReuse(('', PORT), Handler)
  reloader_thread = threading.Thread(target=reloader)
  reloader_thread.start()
  print(f"*** started viz on {HOST}:{PORT}")
  print(colored(f"*** ready in {(time.perf_counter()-st)*1e3:4.2f}ms", "green"))
  # without a kernels file there is nothing to show on the index page, so open the profiler instead
  if len(getenv("BROWSER", "")) > 0: webbrowser.open(f"{HOST}:{PORT}{'/profiler' if contexts is None else ''}")
  try: server.serve_forever()
  except KeyboardInterrupt: print("*** viz is shutting down...")
  finally:
    # always stop the (non-daemon) reloader thread, otherwise any other exit from serve_forever hangs the process
    stop_reloader.set()
    server.server_close()