from typing import Any, cast import functools, operator, itertools from collections import defaultdict from dataclasses import dataclass from tinygrad.dtype import dtypes, ImageDType, PtrDType, DType, AddrSpace from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite, GroupOp, identity_element from tinygrad.uop.symbolic import split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat from tinygrad.helpers import getenv, flatten, AMX, prod, partition from tinygrad.renderer import Renderer # ***** image load valid simplification ***** def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None: if (idx:=uop_given_valid(valid, start_idx)) is None: return buf.const_like(0) if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx, valid) # wait for it to be image indexed before running simplification if start_idx.dtype.count != 2: return None # can drop valid if idx is out of bound when valid is False drop_stmt = [] for stmt in split_uop(valid, Ops.AND): try: X, is_upper_bound, c = parse_valid(stmt) except ValueError: return None # for X0 + X1 + ... >= 1, check if it's out of bound when Xi = 0 for all i if not is_upper_bound and c == 1 and all(u.op in GroupOp.Irreducible and u.vmin == 0 for u in split_uop(X, Ops.ADD)): testidx = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), split_uop(X, Ops.ADD), idx) testidx = testidx.simplify() if testidx.gep(0).vmax < 0 or testidx.gep(1).vmax < 0: drop_stmt.append(stmt) continue # if X <= c, check if it's out of bound when X = c+1 # if X >= c, check if it's out of bound when X = c-1 test_value = c + 1 if is_upper_bound else c - 1 for i,b in zip(idx.src, (buf.dtype.shape[1], buf.dtype.shape[0])): if i.is_increasing(): rw = i.substitute({X:X.const_like(test_value)}).simplify() if rw.vmin >= b or rw.vmax < 0: drop_stmt.append(stmt) break if not drop_stmt and idx is start_idx: return None new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in split_uop(valid, Ops.AND) if s not in drop_stmt]) else None return buf.index(idx, new_valid) def delete_redundant_gates(store:UOp, buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|None=None) -> UOp|None: if store_gate not in [gate.src[0] for gate in val.toposort() if gate.op is Ops.IF]: return None # remove the gate from the index return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val, *store.src[2:]) load_store_indexing = PatternMatcher([ # simplify valid (UPat(Ops.AND, name="valid"), simplify_valid), # image load valid idx simplification (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load), # index True is just Index (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat(Ops.CONST, arg=True))), lambda buf,start_idx: buf.index(start_idx)), # delete_redundant_gates (after expand) (UPat(Ops.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")), UPat.var("val")), name="store", allow_any_len=True), delete_redundant_gates), ]) # ***** load/store grouping ***** def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None): if getenv("UNSAFE_DISABLE_MASK", 0): mask = None # generate the individual indexes midx = graph_rewrite(UOp.sink(*[buf.index(vec.gep(i), mask.gep(i) if mask is not None else None) for i in range(vec.dtype.count)]), symbolic_flat+load_store_indexing, name=f"index_buf_{buf.arg}") # extract all the relevant offsets offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict) for i in range(vec.dtype.count): idx: Any = midx.src[i].src[1] if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg else: root_src, arg = idx, 0 if len(midx.src[i].src) == 3: root_src = (midx.src[i].src[2], root_src) offsets_rootsrc[root_src].setdefault(arg, []).append(i) # the buf.dtype is always a pointer ptrdtype = cast(PtrDType, buf.dtype) # then rewrite everything we can into groups ret = [] idxs: list[int|None] = [None]*vec.dtype.count global_offset = 0 for offsets in offsets_rootsrc.values(): grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])] for grp in grouped_offsets: # get the index offset for this element. using [0] is okay, because they are the same lidx = midx.src[offsets[grp[0]][0]] if len(grp) > 1: lidx = lidx.cast(ptrdtype.base.vec(len(grp)).ptr(size=ptrdtype.size, addrspace=ptrdtype.addrspace)) # set the idxs of the output for i,g in enumerate(grp): for oo in offsets[g]: idxs[oo] = global_offset+i # add this lidx to the CAT ret.append(lidx) global_offset += len(grp) assert None not in idxs, f"some idxs are missing {idxs}" # this base thing is for image, we want the CAT to be a normal pointer post_cat = UOp(Ops.PTRCAT, ptrdtype.base.ptr(size=ptrdtype.size, addrspace=ptrdtype.addrspace).vec(vec.dtype.count), tuple(ret)) return post_cat.gep(tuple(cast(list[int], idxs))) def cat_after_store(cat:UOp, data:UOp, sto:UOp): # TODO: this is written in many places offset = 0 ret: list[UOp] = [] for s in cat.src: ret.append(s.store(data.gep(tuple(range(offset, offset+s.dtype.count))), *sto.src[2:])) offset += s.dtype.count return UOp(Ops.NOOP, src=tuple(ret)) def gep_on_store(gep:UOp, st:UOp, sto:UOp): # NOTE: we need to invert the gep here, but it may be an expanding gep # fake argsort. TODO: handle duplicates a = {} for i,x in enumerate(gep.arg): a[x] = i new_arg = tuple(x[1] for x in sorted(a.items())) return gep.src[0].store(st.gep(new_arg), *sto.src[2:]) load_store_folding = PatternMatcher([ (UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat(GroupOp.Defines, name="buf")), UPat.var("vec"))), expand_index), (UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat(GroupOp.Defines, name="buf")), UPat.var("vec"), UPat.var("mask"))), expand_index), # GEP after LOAD (UPat(Ops.LOAD, src=(UPat(Ops.GEP, name="gep"),), name="ld", allow_any_len=True), lambda gep, ld: ld.replace(dtype=ld.dtype.scalar().vec(gep.dtype.count), src=(gep.src[0],)+ld.src[1:]).gep(gep.arg)), # GEP on data of STORE (UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st")), allow_any_len=True, name="sto"), gep_on_store), # put PTRCAT after LOAD (UPat(Ops.LOAD, src=(UPat(Ops.PTRCAT, name="cat"),), name="ld", allow_any_len=True), lambda cat,ld: UOp(Ops.CAT, ld.dtype, tuple(ld.replace(dtype=x.dtype.base, src=(x,)+ld.src[1:]) for x in cat.src))), # put PTRCAT after STORE (UPat(Ops.STORE, src=(UPat(Ops.PTRCAT, name="cat"), UPat(name="data")), allow_any_len=True, name="sto"), cat_after_store), ]) # *** correct load/store *** def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp): # this splits loads and stores into multiple chunks # if there's only one element to load/store, no splitting needed if (sz:=ls.src[0].dtype.count) == 1: return None buf = idx.src[0] # determine fold lengths lengths = [] must_divide = True if ctx is not None and ctx.device == "DSP": lengths = [128,64,32,16,8,4] must_divide = False elif buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType): pass elif cast(PtrDType, buf.dtype).addrspace == AddrSpace.REG: pass elif isinstance(buf.dtype, ImageDType): lengths = [4] elif ctx is not None and ctx.supports_float4: # TODO: a better way to get this than ctx lengths = [8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2]) lengths.append(1) # worst case, it's not folded # filter fold lengths that don't divide if must_divide: lengths = [x for x in lengths if idx.src[1].divides(x) is not None] # split based on the fold lengths global_offset = 0 ret = [] ptrdtype = cast(PtrDType, buf.dtype) while global_offset < sz: # with 1 at the end of the lengths list, this will always hit for fold_length in lengths: if global_offset+fold_length > sz: continue lidx = buf.index(idx.src[1] + global_offset, idx.src[2] if len(idx.src) > 2 else None) if fold_length > 1: lidx = lidx.cast(ptrdtype.base.vec(fold_length).ptr(size=ptrdtype.size, addrspace=ptrdtype.addrspace)) if ls.op is Ops.STORE: ret.append(ls.replace(src=(lidx,ls.src[1].gep(tuple(range(global_offset, global_offset+fold_length))))+ls.src[2:])) else: ret.append(ls.replace(src=(lidx,)+ls.src[1:], dtype=ls.dtype.scalar().vec(fold_length))) global_offset += fold_length break # if it wasn't split, we return None. otherwise we CAT them if len(ret) <= 1: return None return UOp(Ops.CAT, ls.dtype, tuple(ret)) if ls.op is Ops.LOAD else UOp(Ops.NOOP, src=tuple(ret)) def image_fixup(ls:UOp): # normal image load or store, with the CAST from expand_index if ls.src[0].op is Ops.CAST and isinstance(image_dtype:=ls.src[0].src[0].dtype, ImageDType): assert ls.src[0].dtype.count == 4, "image must be casted to 4" idx = ls.src[0].src[0] oidx = UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((idx.src[1] // 4) % image_dtype.shape[1], (idx.src[1] // (4*image_dtype.shape[1])))) idx = idx.replace(src=(idx.src[0], oidx)+idx.src[2:]) return ls.replace(src=(idx,)+ls.src[1:]) # this is an unprocessed image without a cast, aka unfoldable image load. this doesn't work for stores if isinstance(image_dtype:=ls.src[0].dtype, ImageDType) and ls.src[0].src[1].dtype != dtypes.int.vec(2): assert ls.op is Ops.LOAD, "if an image store isn't upcasted to 4, we can't store it" idx = ls.src[0] id4 = idx.src[1] % 4 oidx = UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((idx.src[1] // 4) % image_dtype.shape[1], (idx.src[1] // (4*image_dtype.shape[1])))) idx = idx.replace(src=(idx.src[0], oidx)+idx.src[2:]) vec_load = ls.replace(dtype=ls.dtype.vec(4), src=(idx,)+ls.src[1:]) return functools.reduce(lambda ret, i: id4.ne(i).where(ret, vec_load.gep(i)), range(4), ls.const_like(float('nan'))) return None correct_load_store = PatternMatcher([ # split LOAD/STORE (UPat((Ops.LOAD, Ops.STORE), src=(UPat(Ops.INDEX, name="idx").cast(),), name="ls", allow_any_len=True), split_load_store), # image indexing, including unfoldable images (UPat((Ops.LOAD, Ops.STORE), name="ls"), image_fixup), ]) # *** uop expander *** # TODO: there's a lot shared with gep_through_wmma here def no_vectorized_wmma(wmma:UOp): out_sz = prod(x[1] for x in wmma.arg[6][-1]) if wmma.dtype.count == out_sz: return None tsrcs = [] for s,sz in zip(wmma.src, wmma.arg[6]): ssz = prod(x[1] for x in sz) tsrcs.append([s.gep(tuple(range(grp, grp+ssz))) for grp in range(0, s.dtype.count, ssz)]) wmmas = [UOp(Ops.WMMA, wmma.dtype.scalar().vec(out_sz), tsrc, wmma.arg) for tsrc in zip(*tsrcs)] wmma_ex = flatten([[e.gep(i) for i in range(out_sz)] for e in wmmas]) return UOp(Ops.VECTORIZE, wmma.dtype, tuple(wmma_ex)) def no_vectorized_alu(alu:UOp): if alu.dtype.vcount == 1: return None alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount)) return UOp(Ops.VECTORIZE, alu.dtype, alus) def no_vectorized_acc(acc:UOp, c:UOp): if acc.dtype.count == 1: return None assert c.arg == 0, "this only supports index 0" new_acc = acc.replace(dtype=acc.dtype.base.scalar().ptr(acc.dtype.count, cast(PtrDType, acc.dtype).addrspace)) return UOp(Ops.PTRCAT, acc.dtype, tuple([new_acc.index(UOp.const(dtypes.int, i)) for i in range(acc.dtype.count)])) devectorize = PatternMatcher([ # no ALU on vectorized dtypes (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name="alu"), no_vectorized_alu), (UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma), (UPat(Ops.DEFINE_REG, name="acc").index(UPat.cvar("c")), no_vectorized_acc), ]) pm_render = PatternMatcher([ # for rendering, we use explicit VECTORIZE (UPat(Ops.CONST, name='c'), lambda c: UOp(Ops.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None), (UPat(Ops.VCONST, name='c'), lambda c: UOp(Ops.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))), (UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None), (UPat(Ops.GEP, name='gep'), lambda gep: gep.src[0] if gep.src[0].dtype.vcount == 1 and gep.arg == (0,) else None), (UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x), # give any loads that are masked an alt value (UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat())).or_casted(),), allow_any_len=True, name="x"), lambda x: x.replace(src=(x.src[0], x.const_like(0))+x.src[1:]) if len(x.src) == 1 or x.src[1].op is Ops.CUSTOM else None), # gate any stores that aren't gated with ifs (UPat(Ops.STORE, src=(UPat(src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="idx").or_casted(), UPat()), name="store", allow_any_len=True), lambda store,idx: UOp(Ops.STORE, dtype=store.dtype, src=store.src[:2]+(UOp(Ops.IF, src=(idx.src[2],)),)+store.src[2:]) if \ len(store.src) <= 2 or store.src[2].op != Ops.IF else None), ]) # *** Ops.REDUCE -> Ops.DEFINE_ACC *** @dataclass class ReduceContext: acc_num: int = 0 def horizontal_reduce(inp:UOp, out_dtype:DType) -> list[UOp]: # if this has a horizontal reduction component, do that first if inp.dtype != out_dtype: # NOTE: [0 1 2 3 4 5 6 7] -> [0+4, 1+5, 2+6, 3+7] horizontal_amount = inp.dtype.count//out_dtype.count return [inp.gep(tuple(range(i, inp.dtype.count, horizontal_amount))) for i in range(0, horizontal_amount)] return [inp] def reduce_to_acc(ctx:ReduceContext, red:UOp): inp, reduce_range = red.src[0], red.src[1:] lst = horizontal_reduce(inp, red.dtype) assert all(x.dtype == red.dtype for x in lst), f"horizontal reduction mismatch {lst[0].dtype} != {red.dtype}" # if we have a range if len(reduce_range) != 0: topo = inp.toposort() stored_ranges = flatten([x.src[2:] for x in topo if x.op is Ops.STORE]) input_ranges = tuple([x for x in topo if x.op is Ops.RANGE and x not in reduce_range and x not in stored_ranges]) identity = red.const_like(identity_element(red.arg, red.dtype.scalar())) acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=(ctx.acc_num,)).index(UOp.const(dtypes.int, 0)) do_store = acc.store(identity, UOp(Ops.NOOP, src=input_ranges)) if len(input_ranges) else acc.store(identity) lst = [acc.load(do_store, *reduce_range)] + lst # put acc as the first element ctx.acc_num += 1 ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst) return acc.load(acc.store(ret, *reduce_range)) if len(reduce_range) != 0 else ret def no_vectorized_reduce(inp:UOp, red:UOp): if inp.dtype != red.dtype: red = red.replace(src=(functools.reduce(lambda x,y: x.alu(red.arg, y), horizontal_reduce(inp, red.dtype)),)+red.src[1:]) if red.dtype.vcount == 1: return red # no_vectorize_alu ignoring ranges if red.dtype.vcount == 1: return None alus = tuple(UOp(red.op, red.dtype.scalar(), (red.src[0].gep(i),)+red.src[1:], red.arg) for i in range(red.dtype.vcount)) return UOp(Ops.VECTORIZE, red.dtype, alus) def reduce_rangeless(red:UOp): # TODO: share code with reduce_unparented if red.arg not in {Ops.ADD, Ops.MAX}: return None if red.src[0].dtype != red.dtype: return None if any(x.op in {Ops.RANGE} for x in red.src[0].toposort()): return None ret = red.src[0] if red.arg is Ops.ADD: for r in red.src[1:]: ret = ret * r.src[0].cast(ret.dtype.scalar()).broadcast(ret.dtype.count) return ret def no_range(u:UOp) -> bool: return not any(x.op is Ops.RANGE for x in u.sparents) pm_reduce_collapse = PatternMatcher([ # lift x+y out of reduce on lt ((UPat.var("x")+UPat.var("y")) < UPat.var("c"), lambda x,y,c: (x < (c-y)) if no_range(y) and no_range(c) else None), # lift x*y out of reduce ((UPat.var("x")*UPat.var("y")) < UPat.var("c"), lambda x,y,c: (x < ((c+y-1) // y)) if no_range(y) and no_range(c) and y.vmin > 0 else None), # lift x+y out of reduce on ne ((UPat.var("x")+UPat.var("y")) != UPat.var("c"), lambda x,y,c: (x != (c-y)) if no_range(y) and no_range(c) else None), # fold the range ((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(0, UPat.cvar("val")).reduce(arg=Ops.ADD, allow_any_len=True), lambda r,cut,val: (r.src[0]-cut).maximum(0).minimum(r.src[0]).cast(val.dtype) * val), ((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(UPat.cvar("val"), 0).reduce(arg=Ops.ADD, allow_any_len=True), lambda r,cut,val: cut.maximum(0).minimum(r.src[0]).cast(val.dtype) * val), # REDUCE on ADD ((UPat.var("x")+UPat.var("y")).reduce(arg=Ops.ADD, allow_any_len=True, name="r"), lambda x,y,r: x.reduce(*r.src[1:], arg=Ops.ADD) + y.reduce(*r.src[1:],arg=Ops.ADD)), # MUL casted bool ((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast().or_broadcasted(name="b")), lambda x,gate,b=None: gate.broadcast(x.dtype.count).where(x, 0) if b is not None else gate.where(x, 0)), # WHERE on LOAD (works on max too) (UPat.var("gate").where(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load(), 0).reduce(arg=Ops.ADD, allow_any_len=True), lambda buf,idx,gate: buf.index(idx, gate).load()), (UPat.var("gate").where(0, UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load()).reduce(arg=Ops.ADD, allow_any_len=True), lambda buf,idx,gate: buf.index(idx, gate.logical_not()).load()), # INDEX on RANGE / gated RANGE (UPat.var("buf").index(UPat.var("expr"), UPat.var("idx").eq(UPat(Ops.RANGE, name="r").or_casted())), lambda buf,r,idx,expr: buf.index(expr.substitute({r:idx.cast(r.dtype)}), (idx.cast(r.dtype) >= 0) & (idx.cast(r.dtype) < r.src[0]))), # AND on WHERE ((UPat.any(UPat(Ops.DEFINE_VAR, name="x"), UPat(Ops.DEFINE_VAR).gep(name="x")) & UPat.var("y")) \ .where(UPat.cvar("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"), lambda x,y,c,r: y.where(c, 0).reduce(*r.src[1:], arg=Ops.ADD)*x.cast(c.dtype)), # remove REDUCEs that no longer have a RANGE in the src (UPat(Ops.REDUCE, name="red"), reduce_rangeless), # devectorize REDUCE (UPat(Ops.VECTORIZE, name="inp").reduce(name="red", allow_any_len=True), no_vectorized_reduce), # index/load/where. TODO: this is more aggressive than needed (UPat((Ops.INDEX, Ops.LOAD, Ops.WHERE), name="alu"), no_vectorized_alu), ])+sym def reduce_collapse(red:UOp): included, not_included = partition(red.parents, lambda x: any(y in x.sparents for y in red.src[1:])) if any(x.op in {Ops.STORE, Ops.REDUCE} for x in included): return None replaces: dict[UOp, UOp] = {} for u in included: for s in u.src: if s in not_included and s not in replaces and s.op not in {Ops.CONST, Ops.VCONST, Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR}: replaces[s] = UOp(Ops.DEFINE_VAR, dtype=s.dtype, arg=(f'in{len(replaces)}', s.vmin, s.vmax)) collapse_fxn = red.substitute(replaces) sink = graph_rewrite(collapse_fxn, pm_reduce_collapse, name="reduce_collapse") # TODO: why is REDUCE needed here and just RANGE isn't enough? if any(x.op in {Ops.REDUCE, Ops.RANGE} for x in sink.toposort()): return None return sink.substitute({v:k for k,v in replaces.items()}) def reduce_unparented(red:UOp): if red.arg not in {Ops.ADD, Ops.MAX}: return None reduce_parented, reduce_unparented = partition(red.src[1:], lambda x: x in red.src[0].sparents) if len(reduce_unparented) == 0: return None ret = red.replace(src=(red.src[0],)+tuple(reduce_parented)) if len(reduce_parented) or red.dtype != red.src[0].dtype else red.src[0] if red.arg is Ops.ADD: for r in reduce_unparented: ret = ret * r.src[0].cast(ret.dtype.scalar()).broadcast(ret.dtype.count) return ret pm_reduce = PatternMatcher([ # remove any ranges from a REDUCE that aren't referenced in the reduce source (UPat(Ops.REDUCE, name="red"), reduce_unparented), # remove REDUCE without loads (generic arange opt / indexing). TODO: support multi range (UPat(Ops.REDUCE, src=(UPat(), UPat()), name="red"), reduce_collapse), # REDUCE -> DEFINE_ACC+ASSIGN (UPat(Ops.REDUCE, name="red"), reduce_to_acc), # tensor core built in accumulate (UPat(Ops.WMMA, name="wmma") + UPat.var("add"), lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)), ])+sym