import itertools from enum import Enum, auto from collections import defaultdict from typing import List, Tuple, DefaultDict from tinygrad.helpers import prod, tqdm from tinygrad.uop.ops import UOp, Ops from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.uop.ops import sym_infer from tinygrad.tensor import Tensor class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702 def apply_mop(st: Tensor|ShapeTracker, mop_arg: Tuple[MovementOps, Tuple]) -> ShapeTracker: mop, arg = mop_arg if mop == MovementOps.RESHAPE: # shapetracker doesn't allow flattening with -1 but required for MovementOps.RESHAPE if arg == (-1,): return st.reshape((prod(st.shape),)) return st.reshape(arg) if mop == MovementOps.PERMUTE: return st.permute(arg) if mop == MovementOps.EXPAND: if len(arg) != len(st.shape): st = st.reshape((1,*st.shape)) return st.expand(arg) if mop == MovementOps.PAD: return st.pad(arg) if mop == MovementOps.SHRINK: return st.shrink(arg) if mop == MovementOps.STRIDE: assert all(x in [-1, 1] for x in arg) return st.flip(tuple(i for i,x in enumerate(arg) if x == -1)) raise ValueError("invalid mop") def make_scratch_st(st: ShapeTracker) -> ShapeTracker: return ShapeTracker.from_shape((get_buffer_size(st.views[0].shape, st.views[0].strides, st.views[0].offset, st.views[0].mask),)) # ShapeTracker to an equivalent series of MovementOps (https://github.com/tinygrad/tinygrad/pull/2216) def to_movement_ops(st: ShapeTracker) -> List[Tuple[MovementOps, Tuple]]: to_apply:List[Tuple[MovementOps, Tuple]] = [] for i, v in enumerate(st.views): real_shape = tuple(y-x for x,y in v.mask) if v.mask else v.shape offset = (v.offset or 0) + sum(st*(s-1) for s,st in zip(real_shape, v.strides) if st<0) real_offset = offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0) real_real_shape = [s for s,st in zip(real_shape, v.strides) if st] strides: List[int] = [abs(st) if isinstance(st,int) else st for st in v.strides if st] buffer_size = sum((s-1)*st for s,st in zip(real_real_shape,strides)) + 1 if i: buffer_size = prod(st.views[i-1].shape) - real_offset if real_shape else 1 def sort_by_strides(shape, strides): return sorted(zip(shape, strides), key=lambda k: (k[1],-k[0]), reverse=True), sorted(range(len(strides)), key=lambda k: (strides[k],-real_real_shape[k]), reverse=True) ordered_shape_strides, order = sort_by_strides(real_real_shape, strides) to_apply.extend([(MovementOps.RESHAPE, (-1,)), (MovementOps.SHRINK, ((real_offset, real_offset+buffer_size),))]) if strides: if (ordered_shape_strides[0][0]*ordered_shape_strides[0][1])-buffer_size>0: to_apply.append((MovementOps.PAD, ((0, (ordered_shape_strides[0][0] * ordered_shape_strides[0][1]) - buffer_size),))) for i, shape_stride in enumerate(ordered_shape_strides): if i0 else buffer_size to_apply.append((MovementOps.EXPAND, (shape_stride[0], *(s[0] for s in ordered_shape_strides[:i]), remaining_buffer))) to_apply.append((MovementOps.PERMUTE, (*range(1,i+1), 0, i+1))) to_apply.append((MovementOps.RESHAPE, (*(s[0] for s in ordered_shape_strides[:i]), shape_stride[0]*remaining_buffer))) to_apply.append((MovementOps.PAD, (*((0,0) for _ in range(i)), (0, shape_stride[0]*shape_stride[1])))) to_apply.append((MovementOps.RESHAPE, (*(s[0] for s in ordered_shape_strides[:i+1]), remaining_buffer+shape_stride[1]))) ordered_shape_strides[i] = (ordered_shape_strides[i][0], remaining_buffer+shape_stride[1]) else: to_apply.append((MovementOps.SHRINK, (*((0, s[0]) for s in ordered_shape_strides[:i]), (0, shape_stride[0]*shape_stride[1])))) to_apply.append((MovementOps.RESHAPE, (*[s[0] for s in ordered_shape_strides[:i+1]], shape_stride[1]))) to_apply.extend([(MovementOps.SHRINK, (*[(0, s[0]) for s in ordered_shape_strides], (0,1))), (MovementOps.RESHAPE, tuple(s[0] for s in ordered_shape_strides))]) if order != list(range(len(order))): to_apply.append((MovementOps.PERMUTE, tuple(order.index(i) for i in range(len(strides))))) to_apply.append((MovementOps.RESHAPE, tuple(s if st else 1 for s,st in zip(real_shape, v.strides)))) if any(i<0 for i in v.strides): to_apply.append((MovementOps.STRIDE, tuple(-1 if st<0 else 1 for st in v.strides))) # then, we apply pre expand pads if v.mask is not None: pre_expand_pads = tuple((x,s-y) if st != 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides)) post_expand_pads = tuple((x,s-y) if st == 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides)) if any(x != (0,0) for x in pre_expand_pads): to_apply.append((MovementOps.PAD, pre_expand_pads)) real_shape = tuple(x+s[0]+s[1] for x,s in zip(real_shape, pre_expand_pads)) # then, we do any expands if any(s != 1 and st == 0 for s,st in zip(real_shape, v.strides)): to_apply.append((MovementOps.EXPAND, real_shape)) # lastly, we apply post expand pads if v.mask is not None and any(x != (0,0) for x in post_expand_pads): to_apply.append((MovementOps.PAD, post_expand_pads)) scratch_st = make_scratch_st(st) ret = [] seen = {} # {shapetracker: list of mops to generate that shapetracker} for mop_arg in to_apply: scratch_st = apply_mop(scratch_st, mop_arg) if scratch_st in seen: ret = seen[scratch_st][:] else: if len(ret) and ret[-1][0] == MovementOps.RESHAPE and mop_arg[0] == MovementOps.RESHAPE: ret[-1] = mop_arg else: if mop_arg == (MovementOps.RESHAPE, -1): mop_arg = (MovementOps.RESHAPE, (prod(st.shape),)) ret.append(mop_arg) seen[scratch_st] = ret[:] return ret def get_real_view(shape, strides, offset, mask): real_shape = tuple(y-x for x,y in mask) if mask else shape offset = offset + sum(st * (s-1) for s,st in zip(real_shape, strides) if st<0) real_offset = offset + (sum(x*st for (x,_),st in zip(mask, strides)) if mask else 0) real_real_shape = [s for s,st in zip(real_shape, strides) if st] strides = [abs(st) if isinstance(st,int) else st for st in strides if st] return real_real_shape, strides, real_offset def get_buffer_size(shape, strides, offset, mask): real_real_shape, strides, real_offset = get_real_view(shape, strides, offset, mask) return real_offset + sum((s-1)*st for s, st in zip(real_real_shape,strides)) + 1 def st_equivalent(st1: ShapeTracker, st2: ShapeTracker): if (idxs1:=st1.expr_idxs()) == (idxs2:=st2.expr_idxs()): return True idx1, valid1 = idxs1 idx2, valid2 = idxs2 # always invalid if valid1 == 0 and valid2 == 0: return True var1 = idx1.vars() | valid1.vars() var2 = idx2.vars() | valid2.vars() # Maybe there are cases that vars are different yet the sts are the same? if var1 != var2: return False # brute force over the vars range vs = list(var1) for i, ranges in enumerate(itertools.product(*[range(v.min, v.max+1) for v in vs])): if i > 1000: print("WARNING: did not search all possible combinations") break var_vals = {k.expr:v for k,v in zip(vs, ranges)} r1 = sym_infer(idx1, var_vals) if sym_infer(valid1, var_vals) else 0 r2 = sym_infer(idx2, var_vals) if sym_infer(valid2, var_vals) else 0 if r1 != r2: return False return True c: DefaultDict[int,int] = defaultdict(int) def test_rebuild(st: ShapeTracker): rebuilt_st = make_scratch_st(st) mops = to_movement_ops(st) c[len(mops)] += 1 for mop_arg in mops: rebuilt_st = apply_mop(rebuilt_st, mop_arg) rebuilt_st = rebuilt_st.simplify() # why is the "all(x == 0 for x in rebuilt_st.views[-1].strides)" hack needed? assert st_equivalent(st, rebuilt_st) or all(x == 0 for x in rebuilt_st.views[-1].strides), f"mismatch {st} {rebuilt_st}" last_v1 = st.views[-1] last_v2 = rebuilt_st.views[-1] assert last_v1.shape == last_v2.shape, f"{last_v1.shape} != {last_v2.shape}" def test_rebuild_bufferop_st(ast:UOp): if ast.op is Ops.SHAPETRACKER: test_rebuild(ast.arg) for src in ast.src: test_rebuild_bufferop_st(src) if __name__ == "__main__": from extra.optimization.helpers import load_worlds, ast_str_to_ast ast_strs = load_worlds(False, False, True)[:2000] for ast_str in tqdm(ast_strs): test_rebuild_bufferop_st(ast_str_to_ast(ast_str)) print(f"avg length of mop = {sum(k*v for k,v in c.items()) / sum(c.values()):.2f}")