You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
134 lines
7.9 KiB
134 lines
7.9 KiB
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, resolve, sint
|
|
from tinygrad.helpers import all_same, prod, unwrap, colored
|
|
from tinygrad.shape.shapetracker import ShapeTracker
|
|
from tinygrad.shape.view import View, strides_for_shape, get_contraction_with_reduce
|
|
from tinygrad.schedule.grouper import ALWAYS_CONTIGUOUS
|
|
from tinygrad.dtype import ImageDType, dtypes
|
|
|
|
# Rewrites shared by view_left and view_right: fold stacked VIEWs, lower movement ops to VIEWs,
# and drop VIEWs that are no-ops or whose mask selects nothing.
merge_views = PatternMatcher([
  # two stacked VIEWs collapse into one whose arg is the sum of both ShapeTrackers
  (UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="v1"),), name="v2"), lambda v1,v2: v1.replace(arg=v1.arg+v2.arg)),
  # a movement op becomes a VIEW (carrying the movement op's ShapeTracker) of its source's base
  (UPat(GroupOp.Movement, src=(UPat.var("x"),), name="mop"), lambda mop,x: x.base.view(mop.st)),
  # a contiguous VIEW that keeps the shape of a non-define source with a ShapeTracker is a no-op
  (UPat.var("x").view(name="view"),
   lambda x,view: x if x.st is not None and x.op not in GroupOp.Defines and view.st.contiguous and view.shape == x.shape else None),
  # if the last view's mask has an empty range on some axis, nothing is selected: replace with a zero constant
  (UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}).view(name="view"),
   lambda view: view.const_like(0) if (last_mask:=view.st.views[-1].mask) is not None and any((hi-lo) == 0 for lo,hi in last_mask) else None),
  # only an unmasked VIEW on CONST/DEFINE_VAR is folded into the ShapeTracker stored on the source's first child
  (UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),), name="view"),
   lambda x,view: x.replace(src=(x.src[0].replace(arg=x.st+view.st),)) if all(v.mask is None for v in (x.st+view.st).views) else None),
])
|
|
|
|
def reduce_push_add_ones(src:UOp, r:UOp, view:UOp):
  """Move a VIEW that only inserts size-1 axes after a REDUCE_AXIS to before the reduce.

  Matches VIEW(REDUCE_AXIS(src)). Applies only when the view is contiguous, has more dims than the reduce output,
  and both shapes are equal once their size-1 dims are dropped. The source is then reshaped so that the rebuilt
  reduce directly produces view.shape.

  Returns the rewritten REDUCE_AXIS, or None when the pattern does not apply or no contraction exists.
  """
  if not (unwrap(view.st).contiguous and len(r.shape) < len(view.shape)): return None
  if tuple(d for d in r.shape if resolve(d != 1)) != tuple(d for d in view.shape if resolve(d != 1)): return None
  if (contraction:=get_contraction_with_reduce(view.shape, r.shape, r.arg[1])) is None: return None
  new_shape: list[sint] = []
  new_reduce_axis: list[int] = []
  for axis, pairs in enumerate(contraction):
    chunk = [view.shape[p] for p in pairs]
    if axis in r.arg[1]:
      # reduce axis: pad with 1s, then put the full source extent last; that last slot is the new reduce axis
      assert len(chunk) > 0
      new_shape.extend([1]*(len(pairs)-1) + [src.shape[axis]])
      new_reduce_axis.append(len(new_shape)-1)
    else:
      # non-reduced axis: copy the view's dims for this group unchanged
      new_shape.extend(chunk)
  ret = r.replace(src=(src.reshape(tuple(new_shape)),), arg=(r.arg[0], tuple(new_reduce_axis))+r.arg[2:])
  assert ret.shape == view.shape, f"shape mismatch on reduce_push_add_ones, {ret.shape} != {view.shape}"
  return ret
|
|
|
|
# On top of merge_views, push VIEWs toward the sources: below elementwise/buffer ops and in front of a REDUCE_AXIS.
view_left = merge_views+PatternMatcher([
  # VIEW(op(a, b, ...)) -> op(VIEW(a), VIEW(b), ...) for ALU ops, CAST, BITCAST, BIND, STORE, VALID and SINK
  (UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.BIND, Ops.STORE, Ops.VALID, Ops.SINK}, name="e"),), name="view"),
   lambda e,view: e.replace(src=tuple(operand.view(view.st) for operand in e.src))),
  # a VIEW that only adds size-1 axes after a reduce is moved before the reduce
  (UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), reduce_push_add_ones),
])
|
|
|
|
# VIEW(LOAD(...)) -> LOAD(VIEW(...)): the view is pushed onto every source of the LOAD.
# This rule is kept separate from view_left; fix_kernel_ops combines it with its own rules.
view_left_through_load = PatternMatcher([
  (UPat(Ops.VIEW, src=(UPat(Ops.LOAD, name="e"),), name="view"),
   lambda e,view: e.replace(src=tuple(operand.view(view.st) for operand in e.src))),
])
|
|
|
|
def apply_swizzle(u:UOp) -> UOp:
  """Rewrite the graph rooted at `u` with the view_left rules, pushing its VIEWs toward the sources."""
  rewritten = graph_rewrite(u, view_left, name="Sub View Left")
  return rewritten
|
|
|
|
# change reduceop axes and input ShapeTrackers, view gets replaced with a reshape.
def swizzle_reduceop(r:UOp, src:UOp, view:UOp, fuse=False):
  """Push a VIEW that sits on top of a REDUCE_AXIS into the reduce's input.

  The reduced axes of `src` are permuted to the end, and every view of `view.st` is extended over those trailing
  reduce dims. A new REDUCE_AXIS over the trailing axes is then built and reshaped to view.shape. With `fuse`, the
  input is wrapped in .fuse() and the new reduce carries True in arg[2].

  Returns None (no rewrite) when the view is contiguous and size-preserving. The exception is a reduce marked for
  fusing whose non-1 dims (position and size) differ from the view's.
  """
  view_st = unwrap(view.st)
  # contiguous and same size can push to children
  if view_st.contiguous and view.size == r.size:
    marked_fuse = len(r.arg) == 3 and r.arg[2]  # arg[2] = True is fuse marker
    # if there's a reduce child, shapes match with ones removed
    if not marked_fuse or \
       tuple((i,x) for i,x in enumerate(r.shape) if resolve(x != 1)) == tuple((i,x) for i,x in enumerate(view.shape) if resolve(x != 1)):
      return None
  # permute the reduced axes of the input to the end
  input_st = ShapeTracker.from_shape(src.shape)
  kept_axes = tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)
  permuted_st = input_st.permute(kept_axes+r.axis_arg)
  rshape = permuted_st.shape[-len(r.axis_arg):]
  prshape = prod(rshape)
  rstrides = strides_for_shape(rshape)
  # extend each view with the trailing reduce dims: outer strides and offset scale by the reduce size,
  # and an existing mask is widened to cover the full extent of the new dims
  extended: list[View] = []
  for v in view_st.views:
    new_mask = v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None
    extended.append(View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+rstrides, v.offset*prshape, new_mask))
  swizzled_input = apply_swizzle(src.view(permuted_st + ShapeTracker(tuple(extended))))
  # the new reduce runs over the trailing axes, which start right after view's dims
  first_axis = len(view.shape)
  new_axis = tuple(range(first_axis, first_axis + len(r.axis_arg)))
  if fuse: red_src, red_arg = (swizzled_input.fuse(),), (r.arg[0], new_axis, True)
  else: red_src, red_arg = (swizzled_input,), (r.arg[0], new_axis)
  return UOp(Ops.REDUCE_AXIS, r.dtype, red_src, red_arg).reshape(view.shape)
|
|
|
|
def reduceop_view_right(src:UOp, v:UOp, r:UOp):
  """Rewrite REDUCE_AXIS(VIEW(src)) as a reduce directly on `src`, followed by a reshape to r.shape.

  The VIEW must be contiguous and size-preserving; otherwise this raises AssertionError. The new reduce axes are the
  positions where src.shape and r.shape differ.
  """
  assert unwrap(v.st).contiguous and v.size == src.size, f"can't compute new axis for {src.shape} -> {r.shape}"
  reduced_axes = tuple(i for i,(before,after) in enumerate(zip(src.shape, r.shape)) if before != after)
  return src.r(r.arg[0], reduced_axes).reshape(r.shape)
|
|
|
|
def elementwise_view_right(root:UOp):
  """Move VIEWs on the inputs of `root` to after `root`, reshaping the result back to root.shape.

  Only VIEW inputs whose base op is not in ALWAYS_CONTIGUOUS count; all of them must share the same base size.
  The target is a fresh ShapeTracker built from the first such input's base shape. Inputs whose base already has
  that shape are replaced by their base; every other input is re-viewed to that shape and swizzled.

  Returns None when no input qualifies.
  """
  swizzles = [s for s in root.src if s.op is Ops.VIEW and s.base.op not in ALWAYS_CONTIGUOUS]
  if not swizzles: return None
  assert all_same([s.base.size for s in swizzles]), f"swizzle inputs must have the same size {swizzles}"
  base_st = ShapeTracker.from_shape(swizzles[0].base.shape)
  new_src: list[UOp] = []
  for s in root.src:
    new_src.append(s.base if s.base.shape==base_st.shape else apply_swizzle(s.view(base_st)))
  # rebuild root on the unviewed inputs, then reshape to match downstream shapes
  return root.replace(src=tuple(new_src)).reshape(root.shape)
|
|
|
|
# push VIEW to children: on top of merge_views, move VIEWs past reduces and other ops toward their consumers
view_right = merge_views+PatternMatcher([
  # a non-contiguous VIEW on a REDUCE_AXIS is pushed into the reduce's input (swizzle_reduceop)
  (UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), swizzle_reduceop),
  # a REDUCE_AXIS reading a VIEW of a non-ALWAYS_CONTIGUOUS op reduces that op directly and reshapes afterwards
  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="src"),), name="v"),), name="r"), reduceop_view_right),
  # for every op except SINK and REDUCE_AXIS, apply input VIEWs after the op
  (UPat(GroupOp.All-{Ops.SINK, Ops.REDUCE_AXIS}, name="root"), elementwise_view_right),
  # REDUCE_AXIS of REDUCE_AXIS with the same reduce op merges the axis sets (invert of SPLIT_REDUCEOP=1)
  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"),
   lambda r1,r2: None if r1.arg[0] is not r2.arg[0] else r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1]))),
  # a SINK over a VIEW is rebuilt on the VIEW's source, keeping the SINK's arg
  (UPat(Ops.VIEW, name="v").sink(name="sink"), lambda v,sink: v.src[0].sink(arg=sink.arg)),
])
|
|
|
|
def check_load_st(glbl:UOp, view:UOp):
  """Reject a LOAD through `view` that cannot be indexed safely when the kernel also assigns to the loaded buffer.

  Only sources whose `arg` is 0 are checked (presumably the kernel's output buffer -- NOTE(review): confirm).
  Returns None when the load is acceptable:
    - the view is contiguous, or
    - it has a single view that becomes contiguous once every stride-0 (expanded) axis is shrunk to size 1, or
    - it has a single masked view that matches a contiguous tracker of the same shape inside the mask.

  Raises:
    RuntimeError: otherwise, with a hint to call .contiguous() on the right-hand side.
  """
  if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return
  if len(st.views) == 1:
    only_view = st.views[0]
    # shrink each expanded (stride 0) axis to size 1; if what remains is contiguous, it's fine
    # (loop names are distinct from `st` so the ShapeTracker is not shadowed inside the generator)
    if st.shrink(tuple((0,1) if stride == 0 else (0,dim) for dim,stride in zip(st.shape, only_view.strides))).contiguous: return
    # if it's equal to a contiguous tracker when both are shrunk to the mask, it's fine
    if (mask:=only_view.mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): return
  # otherwise, it's not fine
  raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
                     +colored("   - a += a.T\n", "red")+colored("   + a += a.T.contiguous()", "green"))
|
|
|
|
# Kernel AST fixups: VIEW-through-LOAD pushing, explicit VIEWs on global LOAD/STORE, VALID for viewed constants,
# ImageDType stripping, and the self-assign contiguity check.
fix_kernel_ops = view_left_through_load+PatternMatcher([
  # a LOAD/STORE directly on a DEFINE_GLOBAL gets an explicit VIEW of the global's own ShapeTracker
  (UPat(Ops.DEFINE_GLOBAL, name="g").load(), lambda g: g.view(g.st).load()),
  (UPat(Ops.DEFINE_GLOBAL, name="g").store(UPat.var('x')), lambda g,x: g.view(g.st).store(x)),
  # a VIEW over a constant becomes where(VALID(view), const, 0)
  (UPat(Ops.VIEW, src=(UPat.cvar(),), name="self"),
   lambda self: UOp.where(UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),)), self.const_like(self.base.arg), 0)),
  # every op except DEFINE_GLOBAL and VIEW replaces an ImageDType with its base dtype
  (UPat(GroupOp.All-{Ops.DEFINE_GLOBAL, Ops.VIEW}, name="x"),
   lambda x: None if not isinstance(x.dtype, ImageDType) else x.replace(dtype=x.dtype.base)),
  # if this kernel also assigns to the loaded buffer, ensure we can index it correctly
  (UPat(Ops.LOAD, src=(UPat.var("glbl").view(name="view"),)), check_load_st),
])
|
|
|