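# rangeify: turn a lazy UOp graph into explicitly indexed form.
# Rough flow of this file: pm_generate_realize_map decides which UOps must be realized into buffers,
# run_rangeify walks the graph in reverse toposort order assigning input/output ranges to every op,
# and pm_apply_rangeify then rewrites the graph to BUFFERIZE/INDEX and drops the movement ops.
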
from typing import Iterator
import functools, operator, itertools
from dataclasses import dataclass, field
from tinygrad.dtype import dtypes, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, graph_rewrite, sint, AxisType, profile_matches
from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses
from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored

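# ops the realize rules below treat as already contiguous: srcs with these ops are never forced to realize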
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
                               Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL,
                               Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.KERNEL}

def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None

def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None:
  for s in rb.src:
    if s.base.op not in ALWAYS_CONTIGUOUS: ctx[s] = None

def realize_assign(ctx:dict[UOp, None], a:UOp) -> None:
  if a.src[1].op not in ALWAYS_CONTIGUOUS: ctx[a.src[1]] = None
  # if the assigned value is a KERNEL, we don't realize the assign itself
  if a.src[1].op is not Ops.KERNEL: ctx[a] = None

pm_generate_realize_map = PatternMatcher([
  # always realize SINK src
  (UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
  # always realize COPY/BUFFER_VIEW/CONTIGUOUS
  (UPat({Ops.COPY, Ops.BUFFER_VIEW, Ops.CONTIGUOUS}, name="tr"), realize),
  # realize srcs of COPY, MSELECT, MSTACK
  (UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_srcs),
  # realize ASSIGN and the input to the assign (might be optimized out)
  (UPat(Ops.ASSIGN, name="a"), realize_assign),
])

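# Illustrative example: for SINK(CONTIGUOUS(x)), the SINK rule skips its src (CONTIGUOUS is in
# ALWAYS_CONTIGUOUS) but the CONTIGUOUS rule realizes it, while x itself is left alone unless another rule
# matches it. The ctx dict acts as an ordered set: values start as None and run_rangeify later replaces
# them with the list of realized axes.
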
@dataclass(frozen=True)
class BufferizeOpts:
  # on AddrSpace.LOCAL, device is the id
  device: str|tuple[str, ...]|int|None
  addrspace: AddrSpace = AddrSpace.GLOBAL

@dataclass
class IndexingContext:
  realize_map: dict[UOp, None|list[int]] = field(default_factory=dict)
  range_map: dict[UOp, tuple[tuple[UOp, ...], tuple[UOp, ...]]] = field(default_factory=dict)

  # create ranges
  range_idx: Iterator[int] = field(default_factory=itertools.count)
  def new_range(self, s:sint, axistype:AxisType=AxisType.LOOP) -> UOp:
    if isinstance(s, UOp) and s.op is Ops.RANGE: return s
    # if a range has a 1 src, it's the same as UOp.const(dtypes.index, 0)
    return UOp.range(s, next(self.range_idx), axistype) if resolve(s!=1) else UOp.const(dtypes.index, 0)

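# IndexingContext.new_range in short: size-1 axes never loop (they index at UOp.const(dtypes.index, 0)),
# an existing RANGE is returned as-is, and any other size gets a fresh UOp.range numbered by range_idx.
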
def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
  if x.op in {Ops.BUFFERIZE, Ops.INDEX}: return None
  if x.op is Ops.AFTER and x.src[1].op is Ops.KERNEL: return None
  new_srcs = []
  for s in x.src:
    new_src = s
    if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.AFTER and s.src[1].op is Ops.KERNEL):
      if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0])
    elif s in ctx.realize_map:
      realized_ranges = ctx.realize_map[s]
      assert isinstance(realized_ranges, list), "realize map must contain range list"
      closed_ranges = tuple([r for i,r in enumerate(ctx.range_map[s][1]) if i in realized_ranges])
      # None in the device assigns it a number later
      opts = BufferizeOpts(device=s.device) if len(ctx.range_map[s][1]) == len(realized_ranges) else BufferizeOpts(None, AddrSpace.LOCAL)
      new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts, tag=s.tag if opts.addrspace == AddrSpace.GLOBAL else None)
      if x in ctx.range_map: new_src = new_src.index(*[r for i,r in enumerate(ctx.range_map[x][0]) if i in realized_ranges])
    new_srcs.append(new_src)
  # NOTE: do we need this?
  return x.replace(src=tns) if x.src != (tns:=tuple(new_srcs)) else None

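# Put differently: a src realized on all of its axes becomes a GLOBAL BUFFERIZE over its full set of output
# ranges and is then INDEXed with the consumer's (x's) ranges; a src realized on only some axes becomes a
# device-less LOCAL BUFFERIZE that closes just the realized ranges, and only those axes are re-indexed.
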
def convert_pad_to_where_to_keep_behavior_local(ctx:IndexingContext, x:UOp):
  if x not in ctx.range_map: return None
  valid: UOp = functools.reduce(operator.and_, [r.get_valid() for r in ctx.range_map[x][0]], UOp.const(dtypes.bool, True))
  ret = valid.where(x.src[0], UOp.const(x.dtype, 0))
  ctx.range_map[ret] = ctx.range_map[x]
  return ret

def convert_reduce_axis_to_reduce_with_ranges(ctx:IndexingContext, x:UOp):
  # input ranges
  new_ranges = [r for i,r in enumerate(ctx.range_map[x][0]) if i in x.arg[1]]
  ret = UOp(Ops.REDUCE, x.dtype, src=(x.src[0],)+tuple(new_ranges), arg=x.arg[0], tag=x.tag)
  ctx.range_map[ret] = ctx.range_map[x]
  return ret

def remove_movement_op_after_rangeify(ctx:IndexingContext, x:UOp):
  if x in ctx.range_map or x.src[0].op is Ops.INDEX: return x.src[0]

def add_third_op_to_assign_to_track_shape(ctx:IndexingContext, assign:UOp):
  if assign.src[1].op is Ops.KERNEL: return None
  to_mop = graph_rewrite(assign.src[0], PatternMatcher([(UPat(GroupOp.Movement, name="x"), lambda x: x.replace(tag=()))]))
  ret = assign.replace(src=assign.src+(to_mop,))
  ctx.range_map[ret] = ctx.range_map[assign]
  return ret

pm_apply_rangeify = PatternMatcher([
  # REDUCE_AXIS -> REDUCE
  (UPat(Ops.REDUCE_AXIS, name="x"), convert_reduce_axis_to_reduce_with_ranges),
  # PAD -> WHERE
  (UPat(Ops.PAD, name="x"), convert_pad_to_where_to_keep_behavior_local),
  # add third op to assign
  (UPat(Ops.ASSIGN, src=(UPat(), UPat()), name="assign"), add_third_op_to_assign_to_track_shape),
  # finally, apply_rangeify
  (UPat(GroupOp.All, name="x"), create_bufferize_and_index_based_on_ranges),
  # remove movement op
  (UPat(GroupOp.Movement, name="x"), remove_movement_op_after_rangeify),
  # const/define_var shouldn't have src
  (UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"), lambda ctx,c: c.replace(src=()) if c in ctx.range_map else None),
])

# this is the definition of the movement ops
@functools.cache
def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UOp, ...]) -> tuple[UOp, ...]:
  match op:
    case Ops.SHRINK: rngs = tuple(a if ss == 0 else a+ss for a,(ss,_) in zip(rngs, arg))
    case Ops.PERMUTE: rngs = tuple(rngs[p] for p in argsort(arg))
    case Ops.FLIP: rngs = tuple(((s-1)-a) if f else a for a,s,f in zip(rngs, in_shape, arg))
    case Ops.EXPAND: rngs = tuple(a if in_sh == out_sh else a.const_like(0) for a,in_sh,out_sh in zip(rngs, in_shape, arg))
    case Ops.PAD:
      # TODO: why is multiple graph_rewrites faster than one here?
      # TODO: the .where(r-s, i) is not inside the graph_rewrite so that `convert_pad_to_where_to_keep_behavior_local`
      # wraps the pad with only the newly added valid
      rngs = tuple(r if (s == 0 and e == 0) else graph_rewrite(((r >= s) & (r < (sh+s))),
        symbolic+pm_simplify_valid, name="pad").where(r-s, UOp.invalid()) for r,sh,(s,e) in zip(rngs, in_shape, arg))
    case Ops.RESHAPE:
      acc = 1
      axes_in:list[UOp] = []
      for s,src in list(zip(arg, rngs))[::-1]:
        axes_in.append(acc*src)
        acc *= s
      combined_axes = sum(axes_in, start=UOp.const(dtypes.index, 0))
      axes_out:list[UOp] = []
      for s in in_shape[::-1]:
        axes_out.append(combined_axes % s)
        combined_axes //= s
      # this simplify is doing a lot of heavy lifting. this is the replacement for the reshape view merging code
      rngs = graph_rewrite(UOp.sink(*axes_out[::-1]), symbolic+pm_simplify_valid+pm_drop_and_clauses, name="reshape").src
    case _: raise RuntimeError(f"{op} is not a MovementOp")
  return rngs

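# Worked examples of the index math above (assuming rngs = (r0, r1, ...) over the op's output shape):
#   SHRINK  in_shape=(10,), arg=((2,5),): output range r maps to input index r+2
#   FLIP    in_shape=(10,), arg=(True,):  r maps to 9-r
#   EXPAND  in_shape=(1,4), arg=(3,4):    (r0, r1) maps to (0, r1), the broadcast axis reads index 0
#   PAD     in_shape=(4,),  arg=((2,2),): r maps to ((r>=2)&(r<6)).where(r-2, invalid)
#   RESHAPE in_shape=(4,6), arg=(2,12):   (r0, r1) flattens to 12*r0+r1, which splits back into
#           ((12*r0+r1)//6 % 4, (12*r0+r1) % 6) before symbolic simplification
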
@profile_matches
def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
  if debug: print("**************************")
  rctx = IndexingContext()

  # get ops to realize
  graph_rewrite(tsink, pm_generate_realize_map, ctx=rctx.realize_map, name="get realize")

  # get the traversal order
  with cpu_profile("reverse toposort", "TINY"):
    tsink_reverse_toposort = tsink.reverse_toposort(consumer_map:=tsink.get_consumer_map())

  # explicit rangeify
  ending_ranges: dict[UOp, list[UOp]] = {}
  for x in tsink_reverse_toposort:
    if x.op in {Ops.DEVICE, Ops.UNIQUE}: continue

    # no ranges on kernels, they are internal
    if x.op is Ops.KERNEL: continue

    if x.dtype.scalar() == dtypes.index: continue # TODO: why do I need this?
    ending_ranges[x] = sum([ending_ranges.get(u, []) for u in consumer_map[x]], [])

    # *** the ranges on the output are
    # 1. new if this op is realized
    # 2. from the single consumer if this op only has one consumer
    # 3. potentially new if this op has 2+ consumers
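    # e.g. if two consumers index x with the same ranges, case 3 reuses them (with the OR of the consumers'
    # valids); if they disagree on some axis, that axis gets a fresh range and is recorded in realize_map,
    # which is a partial realize.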

    consumer_rngs = [rctx.range_map[c][0] for c in consumer_map[x] if c in rctx.range_map]
    if x in rctx.realize_map:
      # if this is in the realize_map, we create new ranges (at the output)
      out_rngs = tuple(rctx.new_range(s) for s in x.shape)
      # all ranges are ended now
      ending_ranges[x] = []
      # mark all axes as realized (all of their ranges end here)
      assert rctx.realize_map[x] is None
      rctx.realize_map[x] = list(range(len(x.shape)))
    elif x.op in {Ops.MSTACK, Ops.MSELECT}:
      # treat MSTACK/MSELECT like SINK
      continue
    elif len(consumer_rngs) == 0:
      # if no consumers have ranges and this isn't realized, this doesn't have ranges either.
      continue
    elif len(consumer_rngs) == 1:
      # if this has one consumer, it inherits the ranges from it
      out_rngs = consumer_rngs[0]
    elif len(consumer_rngs) > 1:
      # if this has two or more consumers, we have to merge the ranges and might create new ones
      all_rngs: list[tuple[UOp, ...]] = list(zip(*consumer_rngs))
      rngs_valids = []
      for valid_rngs in all_rngs:
        local_rngs, valids = zip(*[(r.get_idx(), r.get_valid()) for r in valid_rngs])
        rngs_valids.append((local_rngs, valids))

      # TODO: in RANGEIFY > 1 all_all_same isn't required
      all_all_same = all(all_same(local_rngs) for local_rngs,_ in rngs_valids)
      _out_rngs = []
      _realize_axis = []
      for i,(local_rngs,valids) in enumerate(rngs_valids):
        # we compare the ranges without their valids
        if all_all_same or (PCONTIG and all_same(local_rngs)):
          # the new valid is the OR of all the children valids
          minimum_valid = functools.reduce(operator.or_, valids, UOp.const(dtypes.bool, False))
          _out_rngs.append(graph_rewrite(minimum_valid.where(local_rngs[0], UOp.invalid()), symbolic, name="minimum_valid"))
        else:
          _out_rngs.append(rctx.new_range(x.shape[i]))
          _realize_axis.append(i)
      out_rngs = tuple(_out_rngs)

      # we have to (partially) realize here if there's new ranges
      if len(_realize_axis): rctx.realize_map[x] = _realize_axis

    # if this element is elementwise or a reduce and there's ended ranges, we might have to end some other ranges
    if len(ending_ranges[x]) and x.op in GroupOp.Elementwise.union({Ops.REDUCE_AXIS}):
      _realize_axis = rctx.realize_map.get(x, []) or []
      for i,r in enumerate(out_rngs):
        if i in _realize_axis: continue
        if not (PCONTIG > 1) or any(any(rr.arg > e.arg for e in ending_ranges[x]) for rr in r.ranges):
          _realize_axis.append(i)
          ending_ranges[x] = []
      if len(_realize_axis):
        rctx.realize_map[x] = _realize_axis
        out_rngs = tuple([(rctx.new_range(x.shape[i]) if i in _realize_axis else r) for i,r in enumerate(out_rngs)])

    # TODO: some ops don't have shape, enable this after the `.st` property is removed
    #assert len(out_rngs) == len(x.shape), \
    #  f"shape len mismatch {len(out_rngs)} != {len(x.shape)} on {x.op} with {len(consumer_map[x])} consumers and realize {x in realize_map}"

    # *** the ranges on the inputs are
    # 1. swizzled for MovementOps
    # 2. newly created for REDUCE_AXIS
    # 3. passed through for everything else

    rngs = out_rngs # rngs is the input ranges # pylint: disable=possibly-used-before-assignment

    # apply movement ops
    if x.op in GroupOp.Movement: rngs = apply_movement_op(x.op, x.src[0].shape, x.marg, rngs)
    # if the EXPAND is used to inject a range, we don't mark it as ending_ranges. otherwise we do.
    # NOTE: this doesn't actually always end a range, but this is why convs are realized, so for now we need it
    if x.op is Ops.EXPAND and all(isinstance(y, int) or y.op is not Ops.RANGE for y in x.shape):
      ending_ranges[x] += list(UOp.sink(*[ro for ri, ro in zip(rngs, out_rngs) if ri is not ro]).ranges.keys())

    # REDUCE_AXIS creates ranges for the axes it is reducing
    if x.op is Ops.REDUCE_AXIS:
      rngs = tuple(rctx.new_range(s, axistype=AxisType.REDUCE) if i in x.arg[1] else r for i,(r,s) in enumerate(zip(rngs, x.src[0].shape)))

    if debug:
      realized_ranges = rctx.realize_map.get(x, None)
      if x.op is Ops.RESHAPE or len(rngs) != len(out_rngs):
        disp = render_ranges(rngs, realized=realized_ranges) + " -> " + render_ranges(out_rngs, realized=realized_ranges)
      else:
        disp = render_ranges(rngs, out_rngs, realized=realized_ranges)
      print("***" if x in rctx.realize_map else "   ",
            f"{len(consumer_map[x]):2d} {str(x.op):20s} {str(x._shape):35s} {len(ending_ranges[x]):2d}", disp)

    # assign to the range map. rngs are the input ranges, out_rngs are the output ranges, from the x op.
    rctx.range_map[x] = (rngs, out_rngs)

  tsink = graph_rewrite(tsink, pm_apply_rangeify, ctx=rctx, bottom_up=True, name="apply rangeify")
  return tsink, rctx

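# Usage sketch: run_rangeify takes the graph's SINK and returns the rewritten SINK plus the IndexingContext
# holding realize_map and range_map, e.g.
#   tsink, rctx = run_rangeify(tsink, debug=True)
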
def render_ranges(*rngs_list, realized) -> str:
  disp = []
  for i, rs in enumerate(zip(*[[r.render() for r in rngs] for rngs in rngs_list])):
    rng = rs[0] if all_same(rs) else " -> ".join(rs)
    if realized is not None and i in realized: rng = colored(rng, "yellow")
    disp.append("["+rng+"]")
  return ''.join(disp)
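# Debug rendering note: each axis shows as "[<rng>]", or "[<in> -> <out>]" when the passed range tuples
# differ on that axis, and realized axes are colored yellow.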