# this converts a lowered program into a vectorized program by resolving UNROLL / CONTRACT ops

import functools, itertools, operator
from tinygrad.helpers import AMX, dedup, flatten, all_same, prod
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, graph_rewrite
from tinygrad.codegen.symbolic import sym

def _expand_arg_to_idx(args:tuple[tuple[int, int], ...], rpk:dict[int, int]) -> int:
  """Flatten the per-axis indices in `rpk` into one linear index over the `(axis, size)` pairs in `args`.

  The last pair in `args` varies fastest (stride 1), i.e. row-major order.
  """
  idx, mul = 0, 1
  for axis,m in args[::-1]:
    idx += rpk[axis] * mul
    mul *= m
  return idx

def _choices_from_args(args:tuple[tuple[int, int], ...]) -> list[dict[int, int]]:
  """Enumerate every `{axis: index}` assignment with `0 <= index < size` for each `(axis, size)` in `args`.

  The order follows `itertools.product`, so the last axis varies fastest, matching `_expand_arg_to_idx`.
  """
  return [dict(x) for x in itertools.product(*[zip(itertools.repeat(axis), range(m)) for axis,m in args])]

@functools.lru_cache(None)
def _swizzle_args(cargs:tuple[tuple[int, int], ...], eargs:tuple[tuple[int, int], ...],
                  exclude_args:tuple[int, ...]) -> tuple[int, ...]:
  """For each assignment over `cargs`, return the matching flat lane index in the `eargs` layout.

  Axes listed in `exclude_args` are pinned to index 0.
  The result is a tuple, not a list, because lru_cache hands the same object to every caller.
  A mutable result could be altered by one caller and corrupt it for all the others.
  """
  zeros = {x:0 for x in exclude_args}
  return tuple(_expand_arg_to_idx(eargs, {**rpk, **zeros} if zeros else rpk) for rpk in _choices_from_args(cargs))

def do_expand(root:UOp) -> UOp|None:
  """Push the UNROLL sources of `root` through `root` itself.

  Builds a copy of `root` whose dtype count is multiplied by the product of the unrolled axis sizes.
  Every source is rewritten to that width: UNROLLs are unwrapped or swizzled with GEP, and other
  sources are broadcast or concatenated. The result is wrapped in a single UNROLL carrying the
  combined axes. Returns None when no source of `root` is an UNROLL.
  """
  expands = [x for x in root.src if x.op is Ops.UNROLL]
  if not expands: return None
  # NOTE: we 0 out the reduce axis for WMMA. in theory they should all be the same, but is this always correct?
  exclude_args = tuple(dedup(root.arg[-1] + tuple(y[0] for y in flatten(root.arg[-2])))) if root.op is Ops.WMMA else ()
  if all_same(expands_args:=[x.arg for x in expands]) and not exclude_args:
    # if there's only one expand arg, it's okay to use it (optimization)
    expand_args = expands[0].arg
  else:
    # otherwise, we sort them and GEP
    expand_args = tuple(x for x in sorted(dedup(flatten(expands_args))) if x[0] not in exclude_args)
  expand_sz = prod([x[1] for x in expand_args])
  new_srcs = []
  for i,src in enumerate(root.src):
    if src.op is Ops.UNROLL:
      if root.op is Ops.IF and i == 0:
        # IF means OR on first arg to IF
        new_srcs.append(functools.reduce(operator.__or__, [src.src[0].gep(lane) for lane in range(expand_sz)]))
      elif expand_args == src.arg:
        # just remove the expand
        new_srcs.append(src.src[0])
      else:
        lst = _swizzle_args(expand_args, src.arg, exclude_args)
        # if the base dtype is > 1, put those at the end
        if src.dtype.count > 1:
          lst = flatten([[k*src.dtype.count+j for j in range(src.dtype.count)] for k in lst])
        new_srcs.append(src.src[0].gep(tuple(lst)))
    elif root.op is Ops.IF or (root.op is Ops.REDUCE and src.op is Ops.RANGE):
      # non-UNROLL sources of IF, and RANGE sources of REDUCE, are passed through unchanged
      new_srcs.append(src)
    elif src.dtype.count > 1:
      # put any input dtype > 1 grouped together
      new_srcs.append(UOp(Ops.CAT, src.dtype.scalar().vec(expand_sz*src.dtype.count), (src,)*expand_sz))
    else:
      # repeat the arg
      new_srcs.append(src.broadcast(expand_sz))
  new_arg = root.arg
  if root.op is Ops.GEP:
    assert root.dtype.count == 1  # is this right?
    # pick element root.arg[0] out of each of the expand_sz equally sized groups of the widened source
    new_arg = tuple(range(root.arg[0], new_srcs[0].dtype.count, new_srcs[0].dtype.count // expand_sz))
  nsrc = UOp(root.op, root.dtype.scalar().vec(root.dtype.count*expand_sz), tuple(new_srcs), new_arg)
  return UOp(Ops.UNROLL, root.dtype, (nsrc,), expand_args)

def do_contract(con:UOp) -> UOp:
  """Fold the axes listed in `con.arg` out of the UNROLL feeding `con` into `con`'s vector dtype.

  The remaining axes stay on a new UNROLL.
  If the source is not an UNROLL, it is simply repeated `con.dtype.count` times in a VECTORIZE.
  """
  ex = con.src[0]
  # CONTRACT without UNROLL repeats the element VECTORIZED
  if ex.op is not Ops.UNROLL: return UOp(Ops.VECTORIZE, con.dtype, con.src*con.dtype.count)
  # CONTRACT may remove several axes from UNROLL
  assert con.dtype.count == prod([x[1] for x in con.arg]), "dtype is wrong"
  new_ex_args = tuple(x for x in ex.arg if x not in con.arg)
  con_choices = _choices_from_args(con.arg)
  # outer loop over the axes that stay unrolled, inner loop over the contracted axes
  idxs = [_expand_arg_to_idx(ex.arg, {**rpk, **lrpk}) for rpk in _choices_from_args(new_ex_args) for lrpk in con_choices]
  return UOp(Ops.UNROLL, con.dtype, (ex.src[0].gep(tuple(idxs)),), new_ex_args)

expander = PatternMatcher([
  # double expand
  (UPat(Ops.UNROLL, name="outer", src=(UPat(Ops.UNROLL, name="inner"),)),
   lambda outer, inner: UOp(Ops.UNROLL, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
  # do expansion
  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.ASSIGN,
         Ops.VECTORIZE, Ops.IF, Ops.REDUCE), name="root", custom_early_reject={Ops.UNROLL}), do_expand),
  (UPat(Ops.CONTRACT, name="con"), do_contract),
  # vectorize DEFINE_ACC
  (UPat(Ops.VECTORIZE, src=UPat(Ops.DEFINE_ACC, name="acc"), name="v"),
   lambda acc,v: acc.replace(dtype=v.dtype, src=(acc.src[0].broadcast(v.dtype.count),)+acc.src[1:])),
  # BARRIERs aren't actually expanded
  (UPat(Ops.BARRIER, src=(UPat(Ops.UNROLL, name="ex"),)),
   lambda ex: UOp(Ops.UNROLL, src=(UOp(Ops.BARRIER, src=ex.src),)*len(ex.src), arg=ex.arg)),
  # empty UNROLL is NOOP
  (UPat(Ops.UNROLL, src=(UPat.var('x'),), arg=()), lambda x: x),
  # UNROLL GEP (needed for WMMA, generalize this) -> vectorized ALU
  # AMX is read once here when the pattern is built, so the rebuild below takes its width from the
  # matched sources. This keeps the rebuilt lane count equal to the matched one even if AMX changes later.
  (UPat(Ops.UNROLL, name="ex", src=tuple(UPat.var('x').gep(i)+UPat.var('y').gep(i) for i in range(256 if AMX else 8))),
   lambda ex,x,y: UOp(Ops.UNROLL, ex.dtype, tuple((x+y).gep(i) for i in range(len(ex.src))), ex.arg)),
])

def create_gate(root:UOp) -> UOp|None:
  """Propagate the gate of a STORE's INDEX onto BARRIER-terminated LOADs reachable from the STORE.

  The STORE's first source, looked through one optional CAST, must be an INDEX with a third source;
  that third source is the gate. Every reachable LOAD whose last source is a BARRIER gets that
  BARRIER replaced by IF(gate, BARRIER). The traversal does not descend below BARRIER nodes.
  Returns None when the index has no gate or nothing was rewritten.
  """
  @functools.lru_cache(None)
  def _gate_srcs(u:UOp, gate:UOp) -> UOp:
    if u.op is Ops.BARRIER: return u
    if u.op is Ops.LOAD and u.src[-1].op is Ops.BARRIER:
      return UOp(u.op, u.dtype, u.src[:-1]+(UOp(Ops.IF, src=(gate, u.src[-1])),), arg=u.arg)
    replace_source = tuple(_gate_srcs(x, gate) for x in u.src)
    # keep node identity when no source changed, so the caller can detect "no rewrite"
    return u if replace_source == u.src else UOp(u.op, u.dtype, replace_source, u.arg)
  idx = root.src[0]
  if idx.op is Ops.CAST: idx = idx.src[0]
  if idx.op is not Ops.INDEX or len(idx.src) == 2: return None
  ret = _gate_srcs(root, idx.src[2])
  return None if ret is root else ret

migrate_indexing = PatternMatcher([
  # create gate MUST BE BEFORE expander
  (UPat(Ops.STORE, name="root"), create_gate),
])

def expand_rewrite(sink:UOp) -> UOp:
  """Run symbolic simplification plus gate migration, then symbolic simplification plus UNROLL/CONTRACT expansion."""
  # initial symbolic + migrate indexing (remove this)
  sink = graph_rewrite(sink, sym+migrate_indexing)
  # expand
  return graph_rewrite(sink, sym+expander)