You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
			
				
					122 lines
				
				4.6 KiB
			
		
		
			
		
	
	
					122 lines
				
				4.6 KiB
			| 
											4 days ago
										 | from enum import auto, IntEnum, Enum
 | ||
|  | 
 | ||
|  | # wrapper around IntEnum that preserves Enum.__str__ and makes auto() unique across all FastEnum subclasses
 | ||
|  | class FastEnum(IntEnum):
 | ||
|  |   def __str__(self): return Enum.__str__(self)
 | ||
|  |   @staticmethod
 | ||
|  |   def _generate_next_value_(_, __, ___, last_values): return 1 + max([0, *last_values, *[max(c) for c in FastEnum.__subclasses__()]])
 | ||
|  | 
 | ||
|  | # the order of these Ops controls the order of the toposort
 | ||
|  | class Ops(FastEnum):
 | ||
|  |   # uops that aren't rendered
 | ||
|  |   NOOP = auto(); SINK = auto(); UNIQUE = auto(); DEVICE = auto(); KERNEL = auto(); PRECAST = auto(); REWRITE_ERROR = auto()  # noqa: E702
 | ||
|  | 
 | ||
|  |   # track children
 | ||
|  |   CHILD = auto(); CHILDREN = auto() # noqa: E702
 | ||
|  | 
 | ||
|  |   # buffer ops
 | ||
|  |   COPY = auto(); BUFFER = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto() # noqa: E702
 | ||
|  | 
 | ||
|  |   # create buffer
 | ||
|  |   BUFFERIZE = auto()
 | ||
|  | 
 | ||
|  |   # ops that adjust the behavior of the scheduler
 | ||
|  |   CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); FUSE = auto() # noqa: E702
 | ||
|  |   REALIZE = auto()
 | ||
|  | 
 | ||
|  |   # blocks in linearizer (only used there)
 | ||
|  |   BLOCK = auto(); BLOCKSTART = auto(); BLOCKEND = auto(); BLOCKFINAL = auto() # noqa: E702
 | ||
|  | 
 | ||
|  |   # movement ops! these only exist in the tensor graph
 | ||
|  |   RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto() # noqa: E702
 | ||
|  |   MULTI = auto()  # MULTI is really a movement op
 | ||
|  | 
 | ||
|  |   # view is what all movement ops become
 | ||
|  |   VIEW = auto()
 | ||
|  | 
 | ||
|  |   # TODO: remove VALID with the VIEW(CONST(DEVICE)) refactor
 | ||
|  |   VALID = auto()
 | ||
|  | 
 | ||
|  |   # TODO: unify these ops into the levels of the memory hierarchy. depends on ASSIGN is STORE
 | ||
|  |   DEFINE_GLOBAL = auto(); DEFINE_LOCAL = auto(); DEFINE_REG = auto() # noqa: E702
 | ||
|  | 
 | ||
|  |   # this is for symbolic shapes
 | ||
|  |   DEFINE_VAR = auto(); BIND = auto() # noqa: E702
 | ||
|  | 
 | ||
|  |   # this is a RANGE for GPU dimensions, similar to symbolic shapes but not exactly
 | ||
|  |   SPECIAL = auto()
 | ||
|  | 
 | ||
|  |   # reduce
 | ||
|  |   REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto() # noqa: E702
 | ||
|  | 
 | ||
|  |   # optimization helper ops
 | ||
|  |   UNROLL = auto(); CONTRACT = auto(); GEP = auto(); VECTORIZE = auto(); CAT = auto(); PTRCAT = auto() # noqa: E702
 | ||
|  | 
 | ||
|  |   # UnaryOps
 | ||
|  |   CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto(); TRUNC = auto() # noqa: E702
 | ||
|  | 
 | ||
|  |   # load/store before math
 | ||
|  |   LOAD = auto(); STORE = auto() # noqa: E702
 | ||
|  |   ASSIGN = auto()  # TODO: ASSIGN is STORE, remove ASSIGN
 | ||
|  | 
 | ||
|  |   # tensor core math op, not elementwise
 | ||
|  |   WMMA = auto()
 | ||
|  | 
 | ||
|  |   # INDEX is a BinaryOp similar to ADD, but it operates on pointers
 | ||
|  |   INDEX = auto()
 | ||
|  | 
 | ||
|  |   # BinaryOps
 | ||
|  |   ADD = auto(); MUL = auto(); SHL = auto(); SHR = auto(); IDIV = auto(); MAX = auto(); MOD = auto() # noqa: E702
 | ||
|  |   CMPLT = auto(); CMPNE = auto(); CMPEQ = auto() # noqa: E702
 | ||
|  |   XOR = auto(); OR = auto(); AND = auto() # noqa: E702
 | ||
|  |   THREEFRY = auto(); SUB = auto(); FDIV = auto(); POW = auto() # noqa: E702
 | ||
|  | 
 | ||
|  |   # TernaryOps
 | ||
|  |   WHERE = auto(); MULACC = auto() # noqa: E702
 | ||
|  | 
 | ||
|  |   # control flow ops
 | ||
|  |   BARRIER = auto(); RANGE = auto(); IF = auto(); ENDRANGE = auto(); ENDIF = auto() # noqa: E702
 | ||
|  | 
 | ||
|  |   # consts. VCONST is a vectorized const
 | ||
|  |   VCONST = auto(); CONST = auto() # noqa: E702
 | ||
|  | 
 | ||
|  |   # CUSTOM/CUSTOMI are used to output strings into codegen. the I makes the string inline
 | ||
|  |   CUSTOM = auto(); CUSTOMI = auto() # noqa: E702
 | ||
|  | 
 | ||
|  | class GroupOp:
 | ||
|  |   Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG, Ops.TRUNC}
 | ||
|  |   Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ,
 | ||
|  |             Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, Ops.SUB, Ops.FDIV, Ops.POW}
 | ||
|  |   Ternary = {Ops.WHERE, Ops.MULACC}
 | ||
|  |   ALU = set.union(Unary, Binary, Ternary)
 | ||
|  | 
 | ||
|  |   # TODO: is BITCAST always Elementwise if it's shape changing?
 | ||
|  |   Elementwise = set.union(ALU, {Ops.CAST, Ops.BITCAST})
 | ||
|  | 
 | ||
|  |   Defines = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}
 | ||
|  | 
 | ||
|  |   Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
 | ||
|  |   Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP}
 | ||
|  | 
 | ||
|  |   Buffer = {Ops.LOAD, Ops.STORE, Ops.VALID, Ops.CONST, Ops.DEFINE_VAR}
 | ||
|  |   Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKSTART}
 | ||
|  | 
 | ||
|  |   # BinaryOps that can be flipped
 | ||
|  |   Commutative = {Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPNE, Ops.CMPEQ, Ops.XOR, Ops.AND, Ops.OR}
 | ||
|  | 
 | ||
|  |   # BinaryOps where f(f(a,b),c) = f(a,f(b,c))
 | ||
|  |   Associative = {Ops.ADD, Ops.MUL, Ops.AND, Ops.OR, Ops.MAX}
 | ||
|  | 
 | ||
|  |   # BinaryOps that satisfy f(x,x)=x see https://en.wikipedia.org/wiki/Idempotence
 | ||
|  |   Idempotent = {Ops.OR, Ops.AND, Ops.MAX}
 | ||
|  | 
 | ||
|  |   # These can change the dtype to bool
 | ||
|  |   Comparison = {Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ}
 | ||
|  | 
 | ||
|  |   # do not preserve f(0) = 0
 | ||
|  |   UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW}
 | ||
|  | 
 | ||
|  |   Meta = {Ops.COPY, Ops.BUFFER_VIEW}
 | ||
|  | 
 | ||
|  |   All = set(Ops)
 |