You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
124 lines
6.0 KiB
124 lines
6.0 KiB
# this converts a program produced by the lowerer into a vectorized program
|
|
|
|
import functools, itertools, operator
|
|
from tinygrad.helpers import AMX, dedup, flatten, all_same, prod
|
|
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, graph_rewrite
|
|
from tinygrad.codegen.symbolic import sym
|
|
|
|
def _expand_arg_to_idx(args:tuple[tuple[int, int], ...], rpk:dict[int, int]) -> int:
|
|
idx, mul = 0, 1
|
|
for axis,m in args[::-1]:
|
|
idx += rpk[axis] * mul
|
|
mul *= m
|
|
return idx
|
|
|
|
def _choices_from_args(args:tuple[tuple[int, int], ...]) -> list[dict[int, int]]:
|
|
return [dict(x) for x in itertools.product(*[zip(itertools.repeat(axis), range(m)) for axis,m in args])]
|
|
|
|
@functools.lru_cache(None)
def _swizzle_args(cargs:tuple[tuple[int, int], ...], eargs:tuple[tuple[int, int], ...], exclude_args:tuple[int, ...]) -> list[int]:
  """For every position in the `cargs` index space, return its flat index in the `eargs` layout.

  Positions are visited in _choices_from_args order. Axes listed in `exclude_args` are pinned to
  position 0 before indexing. The result is memoized, so the returned list is shared between
  calls with the same arguments and must not be mutated by callers.
  """
  pinned = {axis: 0 for axis in exclude_args}
  swizzle = []
  for rpk in _choices_from_args(cargs):
    # excluded axes override whatever position the choice assigned (or add one if absent)
    swizzle.append(_expand_arg_to_idx(eargs, {**rpk, **pinned}))
  return swizzle
|
|
|
|
def do_expand(root:UOp):
  """Push the UNROLL sources of `root` through it.

  Returns None when no source of `root` is an UNROLL. Otherwise every source is rewritten to span
  the combined unrolled axes, `root` is rebuilt with a dtype `expand_sz` times wider, and the
  rebuilt op is wrapped in a single UNROLL carrying the combined axes.
  """
  unrolls = [s for s in root.src if s.op is Ops.UNROLL]
  if not unrolls: return None
  # NOTE: we 0 out the reduce axis for WMMA. in theory they should all be the same, but is this always correct?
  exclude_args = tuple(dedup(root.arg[-1] + tuple(y[0] for y in flatten(root.arg[-2])))) if root.op is Ops.WMMA else ()
  unroll_args = [u.arg for u in unrolls]
  if all_same(unroll_args) and not exclude_args:
    # all UNROLLs share the same arg, so it can be used directly (optimization)
    expand_args = unrolls[0].arg
  else:
    # otherwise take the sorted, deduplicated union of the args (minus excluded axes) and GEP into it
    expand_args = tuple(x for x in sorted(dedup(flatten(unroll_args))) if x[0] not in exclude_args)
  expand_sz = prod([x[1] for x in expand_args])

  def _widen_src(pos:int, src:UOp) -> UOp:
    # rewrite one source of root so that it covers expand_args
    if src.op is Ops.UNROLL:
      # IF means OR on first arg to IF
      if root.op is Ops.IF and pos == 0:
        return functools.reduce(operator.or_, [src.src[0].gep(j) for j in range(expand_sz)])
      # this UNROLL is already in the target layout: just remove the expand
      if expand_args == src.arg: return src.src[0]
      lst = _swizzle_args(expand_args, src.arg, exclude_args)
      # if the base dtype is > 1, put those at the end
      if src.dtype.count > 1: lst = flatten([[k*src.dtype.count+j for j in range(src.dtype.count)] for k in lst])
      return src.src[0].gep(tuple(lst))
    # non-UNROLL input: sources of IF, and RANGE sources of REDUCE, are passed through untouched
    if root.op is Ops.IF or (root.op is Ops.REDUCE and src.op is Ops.RANGE): return src
    # put any input dtype > 1 grouped together
    if src.dtype.count > 1: return UOp(Ops.CAT, src.dtype.scalar().vec(expand_sz*src.dtype.count), (src,)*expand_sz)
    # repeat the arg
    return src.broadcast(expand_sz)

  new_srcs = [_widen_src(i, s) for i, s in enumerate(root.src)]

  new_arg = root.arg
  if root.op is Ops.GEP:
    assert root.dtype.count == 1
    # is this right?
    widened_count = new_srcs[0].dtype.count
    new_arg = tuple(range(root.arg[0], widened_count, widened_count // expand_sz))
  nsrc = UOp(root.op, root.dtype.scalar().vec(root.dtype.count*expand_sz), tuple(new_srcs), new_arg)
  return UOp(Ops.UNROLL, root.dtype, (nsrc,), expand_args)
|
|
|
|
def do_contract(con:UOp):
  """Resolve a CONTRACT op against its (possibly UNROLL) source.

  If the source is not an UNROLL, the source is repeated `con.dtype.count` times in a VECTORIZE.
  Otherwise the axes named in `con.arg` are removed from the UNROLL: the UNROLL's source is
  GEP'd so the contracted axes vary fastest, and the result is a new UNROLL over the remaining axes.
  """
  inner = con.src[0]
  # CONTRACT without UNROLL repeats the element VECTORIZED
  if inner.op is not Ops.UNROLL: return UOp(Ops.VECTORIZE, con.dtype, con.src*con.dtype.count)
  # CONTRACT may remove several axes from UNROLL
  assert con.dtype.count == prod([x[1] for x in con.arg]), "dtype is wrong"
  kept_args = tuple(x for x in inner.arg if x not in con.arg)
  contracted = _choices_from_args(con.arg)
  # outer loop: positions on the axes that stay unrolled; inner loop: positions on the contracted axes
  idxs = tuple(_expand_arg_to_idx(inner.arg, {**kept, **lrpk}) for kept in _choices_from_args(kept_args) for lrpk in contracted)
  return UOp(Ops.UNROLL, con.dtype, (inner.src[0].gep(idxs),), kept_args)
|
|
|
|
# rewrite rules that merge, propagate and resolve UNROLL/CONTRACT ops; used by expand_rewrite together with sym
expander = PatternMatcher([
  # double expand: an UNROLL directly over another UNROLL collapses into one UNROLL over the
  # inner source, keeping the outer dtype and concatenating the args (inner first, then outer)
  (UPat(Ops.UNROLL, name="outer", src=(UPat(Ops.UNROLL, name="inner"),)),
   lambda outer, inner: UOp(Ops.UNROLL, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
  # do expansion: any of the listed ops is handed to do_expand, which returns None when no source is an UNROLL
  # NOTE(review): custom_early_reject presumably lets the matcher skip roots without an UNROLL source - confirm against UPat
  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.ASSIGN,
         Ops.VECTORIZE, Ops.IF, Ops.REDUCE), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
  # every CONTRACT is resolved by do_contract
  (UPat(Ops.CONTRACT, name="con"), do_contract),
  # vectorize DEFINE_ACC: the acc takes the VECTORIZE's dtype and its first source is broadcast
  # to that width; the remaining acc sources are kept unchanged
  (UPat(Ops.VECTORIZE, src=UPat(Ops.DEFINE_ACC, name="acc"), name="v"),
   lambda acc,v: acc.replace(dtype=v.dtype, src=(acc.src[0].broadcast(v.dtype.count),)+acc.src[1:])),
  # BARRIERs aren't actually expanded: the UNROLL is hoisted above a BARRIER built on the UNROLL's own sources
  (UPat(Ops.BARRIER, src=(UPat(Ops.UNROLL, name="ex"),)),
   lambda ex: UOp(Ops.UNROLL, src=(UOp(Ops.BARRIER, src=ex.src),)*len(ex.src), arg=ex.arg)),
  # empty UNROLL is NOOP
  (UPat(Ops.UNROLL, src=(UPat.var('x'),), arg=()), lambda x: x),
  # UNROLL GEP (needed for WMMA, generalize this) -> vectorized ALU
  # matches an UNROLL of element-wise x.gep(i)+y.gep(i) and rewrites it to gep(i) of one x+y;
  # the element count is fixed at 256 when AMX is set and 8 otherwise
  (UPat(Ops.UNROLL, name="ex", src=tuple(UPat.var('x').gep(i)+UPat.var('y').gep(i) for i in range(256 if AMX else 8))),
   lambda ex,x,y: UOp(Ops.UNROLL, ex.dtype, tuple((x+y).gep(i) for i in range(256 if AMX else 8)), ex.arg)),
])
|
|
|
|
def create_gate(root:UOp) -> UOp|None:
  """Gate the BARRIERs feeding `root` with the gate of its INDEX.

  `root.src[0]` (looked through one CAST) must be an INDEX with a third source, which is used as
  the gate. Every LOAD under `root` whose last source is a BARRIER gets that BARRIER wrapped in
  an IF on the gate. Returns the rewritten op, or None when there is no gated INDEX or nothing changed.
  """
  @functools.lru_cache(None)
  def _gate_srcs(u:UOp, gate:UOp) -> UOp:
    # BARRIERs are left as they are and not descended into
    if u.op is Ops.BARRIER: return u
    if u.op is Ops.LOAD and u.src[-1].op is Ops.BARRIER:
      gated_barrier = UOp(Ops.IF, src=(gate, u.src[-1]))
      return UOp(u.op, u.dtype, u.src[:-1]+(gated_barrier,), arg=u.arg)
    # recurse into the sources; rebuild this op only when one of them changed
    new_src = tuple(_gate_srcs(x, gate) for x in u.src)
    if new_src == u.src: return u
    return UOp(u.op, u.dtype, new_src, u.arg)

  idx = root.src[0]
  if idx.op is Ops.CAST: idx = idx.src[0]
  # an INDEX with exactly two sources carries no gate
  if idx.op is not Ops.INDEX or len(idx.src) == 2: return None
  gated = _gate_srcs(root, idx.src[2])
  return None if gated is root else gated
|
|
|
|
# single-rule matcher: every STORE is passed to create_gate, which returns None when nothing needs gating
migrate_indexing = PatternMatcher([
  # create gate MUST BE BEFORE expander
  (UPat(Ops.STORE, name="root"), create_gate),
])
|
|
|
|
def expand_rewrite(sink:UOp) -> UOp:
  """Rewrite `sink` in two graph_rewrite passes, each combined with the symbolic rules.

  Pass 1 is the initial symbolic + migrate indexing (remove this); pass 2 is the expand.
  The order matters: gates have to be created before the expander runs.
  """
  for rules in (migrate_indexing, expander):
    sink = graph_rewrite(sink, sym+rules)
  return sink
|
|
|