import itertools
from enum import Enum, auto
from typing import List, Tuple
from tinygrad.helpers import prod, dedup, all_same, colored
from tinygrad.ops import LazyOp, MovementOps, get_lazyop_info, get_buffers, ReduceOps, get_lazyops, map_buffers
from tinygrad.shape import ShapeTracker, View, strides_for_shape

def get_first_reduce(shapes):
  """Return the index of the first axis on which the given shapes disagree.

  If every shape agrees on every axis, return len(shapes[0]) (one past the last axis).
  Assumes every shape has at least len(shapes[0]) dimensions.
  """
  for i in range(len(shapes[0])):
    if not all_same([x[i] for x in shapes]): return i
  return len(shapes[0])  # off the end

# this will be removed soon anyway
class Types(Enum): FLOAT = auto(); FLOAT4 = auto() # noqa: E702

class Token:
  """A named kernel value (`tok`) with a type and pointer flag, plus the upcast axes attached to it.

  `axis` is a list of (length, stride, reduce) triples appended by `array`.
  """
  def __init__(self, tok:str, typ:Types, ptr:bool=False):
    assert isinstance(tok, str)
    self.tok, self.typ, self.ptr = tok, typ, ptr
    self.axis : List[Tuple[int, int, bool]] = []

  # attach one more (length, stride, reduce) axis to this token
  def array(self, length, stride, reduce): self.axis.append((length, stride, reduce))

  # number of elements covered by all attached axes (product of their lengths)
  def size(self): return prod([x[0] for x in self.axis])

  # flat offset (sum of index*stride) of every element in the cartesian product of the attached axes;
  # the most recently attached axis varies slowest. [0] when no axes are attached.
  def offsets(self): return [sum(t) for t in itertools.product(*[[y*x[1] for y in range(x[0])] for x in self.axis[::-1]])] if len(self.axis) else [0]

  # True if any attached axis has length 4 and stride 1
  def can_float4(self): return any(a[0:2] == (4,1) for a in self.axis)

  # TODO: this is sort of a hack, it gets the accumulator indices
  def acc_offsets(self):
    # Like offsets(), but strides come from strides_for_shape over the axis lengths with reduce axes
    # replaced by 1, and reduce axes are then forced to stride 0 -- so every position along a reduce
    # axis maps to the same accumulator index.
    if len(self.axis) == 0: return [0]
    acc_strides = [x*(1-self.axis[::-1][i][2]) for i,x in enumerate(strides_for_shape(tuple(1 if r else s for s,_,r in self.axis[::-1])))]
    return [sum(t) for t in itertools.product(*[[y*acc_strides[i] for y in range(x[0])] for i,x in enumerate(self.axis[::-1])])]

  # C-style declaration type: 'float' or 'float4', with a trailing '*' for pointers
  def decltype(self): return ('float' if self.typ == Types.FLOAT else 'float4') + ('*' if self.ptr else str())
  def __repr__(self): return f"<{self.typ}{'*' if self.ptr else str()} {self.tok}{f'[{self.axis}]' if len(self.axis) else str()}>"

# ast kernel can contain one ReduceOp with arbitrary Binary/Unary ops
class ASTKernel:
  """Holds one AST (at most one ReduceOp) plus its buffers, and the per-kernel ShapeTrackers built by `process`."""

  def __init__(self, ast:LazyOp, output_buffer=None):
    """
    ast: the LazyOp tree to compile. A trailing MovementOps.RESHAPE is stripped and its target shape
      is used for the returned buffer.
    output_buffer: optional buffer to write into. It is discarded if any input buffer shares its `_buf`
      while having a non-contiguous ShapeTracker (aliasing).
    """
    self.input_ast = ast  # kept unmodified (including any trailing RESHAPE) for print()

    # if the AST ends with a RESHAPE, we remove it and create the buffer accordingly
    if ast.op == MovementOps.RESHAPE:
      output_shape = ast.arg
      ast = ast.src[0]
    else:
      output_shape = None

    self.info = get_lazyop_info(ast)
    self.bufs = dedup(get_buffers(ast))
    for b in self.bufs: b.st.simplify()
    self.ast = ast

    # check if the output buffer is allowed to be used
    # if it's aliased, don't use it
    if output_buffer is not None:
      for a in self.bufs:
        if a._buf == output_buffer._buf and not a.st.contiguous:
          output_buffer = None
          break

    # create the buffer we are returning (as the same type as the input buffers) and add it as the first buffer
    self.ret = output_buffer if output_buffer else type(self.bufs[0])(output_shape if output_shape else self.info.shape, force_create=True)
    self.bufs = ([type(self.ret)(self.info.shape, hostbuf=self.ret)] if output_shape else [self.ret]) + self.bufs

    # key for lookup in cache (can change, str might not be right)
    # bufs are needed because kernels like f(x) = x + x and f(x, y) = x + y have the same str(ast), but are different kernels.
    # mapping the buffers to integers is required because a-b != b-a (and how would you tell a and b apart?)
    self.key = f"ASTKernelKey ast={str(map_buffers({x:i for i,x in enumerate(self.bufs)}, ast))} bufs={self.bufs}"

  def process(self) -> None:
    """Build per-buffer ShapeTracker copies (`self.sts`), buffer tokens and reduce info, then simplify the shapes.

    Returns immediately if already processed (i.e. `self.sts` exists).
    Raises AssertionError on an invalid kernel (more than one reduce op, or mismatched buffer shapes).
    """
    if hasattr(self, "sts"): return   # already processed

    reduceops = [x for x in get_lazyops(self.ast) if x.op in ReduceOps]
    assert len(dedup(reduceops)) <= 1, "max one reduce op in an ast"
    self.reduceop = reduceops[0] if reduceops else None
    # buffers consumed by the reduce op (empty when there is no reduce)
    self.earlybufs = dedup(get_buffers(self.reduceop)) if self.reduceop else []

    self.buftokens = [Token(f"data{i}", Types.FLOAT, ptr=True) for i in range(len(self.bufs))]
    self.group_for_reduce : List[int] = []

    # check valid AST kernel
    assert all_same([x.shape for x in self.earlybufs]), "all earlybufs must have the same shape"
    assert all_same([x.shape for x in self.bufs if x not in self.earlybufs]), "all latebufs must have the same shape"
    assert all_same([len(x.shape) for x in self.bufs]), "all bufs must have the same shape size"

    # process
    self.sts : List[ShapeTracker] = [x.st.copy() for x in self.bufs]   # create new shapetrackers inside this kernel
    self.simplify_ones()
    self.simplify_merge_adjacent()

    # get full shape buf index (earlybufs if there are any, otherwise output)
    self.full_buf_index : int = self.bufs.index(self.earlybufs[0]) if len(self.earlybufs) > 0 else 0

  def print(self):
    """Debug-print `self.input_ast` as one assignment per node.

    Leaves are named buf0, buf1, ...; LazyOps op0, op1, ...; the root is named "ast".
    Shared subtrees are printed once and then referenced by name.
    """
    buf_count, op_count, cache = -1, -1, {}
    def print_ast(x, name=None):
      nonlocal buf_count, op_count
      if x not in cache:
        if not isinstance(x, LazyOp):
          if name is None:
            buf_count += 1
            name = f"buf{buf_count}"
          # print the resolved name (not the running counter) so an explicitly named leaf -- e.g. a
          # bare-buffer root named "ast" -- is printed under the same name stored in the cache
          print(f"{name} = {x}")
          cache[x] = name
        else:
          srcs = [print_ast(y) for y in x.src]
          if name is None:
            op_count += 1
            name = f"op{op_count}"
          print(f"{name} = LazyOp({str(x.op)}, ({','.join(srcs)},), {x.arg})")
          cache[x] = name
      return cache[x]
    print_ast(self.input_ast, "ast")

  def printbufs(self, prefix="", print_shapetrackers=False):
    """Debug-print first_reduce/shape_len/group_for_reduce, then one line per ShapeTracker with its token,
    early-buffer flag, shape, last-view strides, view count, and backing buffer type ("FAKE" for None bufs)."""
    print(f"first_reduce: {self.first_reduce} shape_len: {self.shape_len} group_for_reduce: {self.group_for_reduce}")
    if print_shapetrackers:
      for st in self.sts: print(st)
    for i in range(len(self.sts)):
      print(prefix, self.buftokens[i], f"early:{'T' if i < len(self.bufs) and self.bufs[i] in self.earlybufs else 'F'}", self.sts[i].shape, self.sts[i].views[-1].strides, len(self.sts[i].views), type(self.bufs[i]._buf) if self.bufs[i] is not None else "FAKE")

  # number of dimensions of the (simplified) kernel shape, read from the output ShapeTracker
  @property
  def shape_len(self) -> int: return len(self.sts[0].shape)

  # shape of the buffer chosen by full_buf_index (a reduce input if any, otherwise the output)
  @property
  def full_shape(self) -> Tuple[int, ...]: return self.sts[self.full_buf_index].shape

  # axes in the grouped-reduce range [first_reduce, first_reduce+len(group_for_reduce)) whose full size
  # equals the output's size on that axis
  @property
  def upcast_in_mid_reduce_axes(self): return [j for j in range(self.first_reduce, self.first_reduce+len(self.group_for_reduce)) if self.full_shape[j] == self.sts[0].shape[j]]

  def colorshape(self, pad=50) -> str:
    """Render full_shape plus the full buffer's upcast axes as colored, space-padded text.

    Colors: blue = before first_reduce; cyan/green = grouped reduce axes (green if in
    upcast_in_mid_reduce_axes); red = remaining reduce axes; magenta/yellow = upcast reduce/non-reduce axes.
    """
    mid_reduce_upcasts = self.upcast_in_mid_reduce_axes  # loop-invariant; computed once
    axis = [(f"{rs:4d}", (("green" if i in mid_reduce_upcasts else "cyan") if i < self.first_reduce + len(self.group_for_reduce) else "red") if i >= self.first_reduce else "blue") for i, rs in enumerate(self.full_shape)]
    axis += [(f"{s:4d}", 'magenta' if reduce else 'yellow') for s, _, reduce in self.buftokens[self.full_buf_index].axis[::-1]]
    return ' '.join([colored(*x) for x in axis])+(" "*(pad-len(' '.join([x[0] for x in axis]))))

  def simplify_ones(self):
    """Drop every axis that is 1 in all ShapeTrackers (keeping at least one axis), then recompute first_reduce."""
    # remove places where the shape is all ones
    # TODO: this should be factored in to multi shape stride
    all_ones = [all(st.shape[i]==1 for st in self.sts) for i in range(self.shape_len)]
    # keep at least 1 one
    if all(all_ones): all_ones[-1] = False
    self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None)
    # find first mismatch, don't reduce this
    self.first_reduce = get_first_reduce([x.shape for x in self.sts])

  def simplify_merge_adjacent(self):
    """Merge adjacent axes that every ShapeTracker can fuse, then reshape each tracker and recompute first_reduce.

    Axis i merges into the previous group when, for every tracker, the group's stride equals
    shape[i]*stride[i] (stride[i] != 0), or both strides are 0. Axis first_reduce is never merged.
    """
    shapes, strides = [x.shape for x in self.sts], [x.views[-1].strides for x in self.sts]

    # merge dimensions if we can, multi get_shape_strides
    # TODO: does this always preserve the reduce dimension, NO
    # TODO: move this into shapetracker, with tests!
    rets = [[(shapes[j][0], strides[j][0])] for j in range(len(shapes))]
    for i in range(1, len(shapes[0])):
      can_merge = []
      for j in range(len(shapes)):
        # TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case
        can_merge.append((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*strides[j][i]) or (strides[j][i] == 0 and rets[j][-1][1] == 0))
      # more can merge than this
      mergeable = all(can_merge) and i != self.first_reduce
      for j in range(len(shapes)):
        if mergeable:
          rets[j][-1] = (rets[j][-1][0] * shapes[j][i], strides[j][i])
        else:
          rets[j].append((shapes[j][i], strides[j][i]))

    for i,x in enumerate(rets): self.sts[i].reshape(tuple(y[0] for y in x))
    self.first_reduce = get_first_reduce([x.shape for x in self.sts])

  # this should be aware of the three parts to the shape
  #  * the input/output dimensions
  #  * the reduce dimensions
  #  * the size outputted by each kernel
  def reshape_and_permute(self, new_shape_fxn, axis):
    """Apply the same reshape (new_shape_fxn(old_shape), if given) then permute (axis, if given) to every ShapeTracker."""
    for st in self.sts:
      if new_shape_fxn is not None: st.reshape(tuple(new_shape_fxn(st.shape)))
      if axis is not None: st.permute(tuple(axis))

  # axis : the axis to pull from
  # amount : the amount to take
  # top : if you want to pull that amount from the top
  # insert_before : place to insert the new stuff
  def shift_to(self, axis, amount, top=False, insert_before=None):
    """Split `axis` into factors (amount, size//amount) -- or (size//amount, amount) when not `top` --
    and move the `amount`-sized factor to just before `insert_before` (default: the end).
    Axes of size <= 1 are split into (1, 1). Uses floor division; presumably `amount` should divide the axis."""
    if insert_before is None: insert_before = self.shape_len
    move_axis = axis if top else axis+1
    # the split adds one axis, so positions at/after the moved factor shift right by one
    if move_axis < insert_before: insert_before += 1
    self.reshape_and_permute(
      lambda x: list(x[0:axis]) + (([amount, x[axis]//amount] if top else [x[axis]//amount, amount]) if x[axis] > 1 else [1,1]) + list(x[axis+1:]),
      [i for i in range(insert_before) if i != move_axis] + [move_axis] + [i for i in range(insert_before, self.shape_len+1) if i != move_axis])

  # drops the final dimension
  def upcast(self):
    """Remove the last axis from every ShapeTracker, recording it on the matching buffer token.

    All non-1 last dims must match (AssertionError otherwise). The recorded axis is flagged as reduce
    when some trackers had size 1 there.
    """
    upcasted = [x.shape[-1] for x in self.sts if x.shape[-1] != 1]
    assert len(upcasted) >= 1 and all_same(upcasted), f"can't upcast mismatch {upcasted}"
    for st,buftoken in zip(self.sts, self.buftokens):
      # add last axis to the buftoken (if it's not a 1)
      if st.shape[-1] == upcasted[0]: buftoken.array(st.shape[-1], st.views[-1].strides[-1], len(upcasted) != len(self.sts))
      # remove the last axis (unless it's the only dimension, then make it a 1)
      st.views[-1] = View(st.shape[0:-1], st.views[-1].strides[0:-1], st.views[-1].offset) if len(st.shape) > 1 else View((1,), (0,), st.views[-1].offset)