from __future__ import annotations
import itertools, functools, math
from dataclasses import dataclass
from collections import defaultdict
from typing import Optional, cast, Final, Callable, Sequence

from tinygrad.uop.ops import GroupOp, KernelInfo, UOp, Ops, can_pad, resolve, Variable, sint, graph_rewrite, track_rewrites, print_uops
from tinygrad.uop.ops import PatternMatcher, smax
from tinygrad.uop.spec import type_verify, ast_spec
from tinygrad.device import Device
from tinygrad.renderer import Renderer, TensorCore, ProgramSpec, Opt, OptOps
from tinygrad.dtype import ImageDType
from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, unwrap
from tinygrad.helpers import DEBUG, TC_SELECT, TC_OPT, AMX
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import strides_for_shape
from tinygrad.codegen.lowerer import get_contraction
from tinygrad.engine.grouper import view_left
from tinygrad.codegen import full_rewrite

class KernelOptError(Exception): pass

def check(cond:bool, msg:str=""):
  if not cond: raise KernelOptError(msg)

@dataclass
class TensorCoreOptions:
  axes: tuple[int, ...] # the location of the original N and M axes if still in the shape
  axes_exist: tuple[bool, ...] # true if the original N and M axes are still in the shape
  axis_pads: tuple[tuple[int, int], ...]
  def fix_axes(self, removed_axis:int): # adjust the TC axes if necessary when a dimension is removed
    axes, axes_exist = list(self.axes), list(self.axes_exist)
    for tc_dim in [i for i in range(2) if axes_exist[i]]:
      if removed_axis < axes[tc_dim]: axes[tc_dim] -= 1
      elif removed_axis == axes[tc_dim]: axes_exist[tc_dim] = False
    self.axes, self.axes_exist = tuple(axes), tuple(axes_exist)

class Kernel:
  def __init__(self, ast:UOp, opts:Optional[Renderer]=None):
    assert ast.op is Ops.SINK, ast.op
    self.ast = ast

    self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
    # verify AST matches the spec
    if __debug__: type_verify(list(self.ast.toposort()), ast_spec)

    self.reduceops = [x for x in self.ast.toposort() if x.op is Ops.REDUCE_AXIS]

    self.vars: list[Variable] = self.ast.variables()
    # NOTE: this requires a specific order with the [::-1], this is likely a bug
    self.bufs: list[UOp] = [x for x in self.ast.toposort() if x.op in GroupOp.Buffer][::-1]

    # create new shapetrackers inside this kernel, we will permute them
    self.sts: list[ShapeTracker] = [x.st_arg for x in self.bufs]

    # add the shapetrackers for each reduce
    # we use this to track which axes are reduced in each reduce
    for x in self.reduceops:
      self.sts.append(unwrap(x.st))
      self.sts.append(unwrap(x.src[0].st))

    # add a shapetracker to the end to track the full shape, with 0 strides so it can merge
    self.sts.append(ShapeTracker.from_shape(tuple([smax(*s) for s in zip(*[x.shape for x in self.sts])]), (0,)*self.shape_len))

    # move all reduce axes to the end
    reduce = list(enumerate(zip(self.full_shape, self.output_shape)))
    permute = tuple([i for i,(s,n) in reduce if not resolve(s != n)] + [i for i,(s,n) in reduce if resolve(s != n)])
    self.reshape_and_permute(None, permute)

    # parameters for optimization
    self.applied_opts: list[Opt] = []
    self.group_for_reduces: int = 0
    self.upcasted: int = 0
    self.local_dims: int = 0
    self.tensor_core: Optional[TensorCore] = None
    self.tensor_core_opts: Optional[TensorCoreOptions] = None
    self.use_tensor_cores: int = 0
    self.dont_use_locals: bool = False

    # group simplifies
    self.simplify_ones()
    self.simplify_merge_adjacent()

  def copy(self):
    ret = type(self).__new__(type(self))

    # base linearizer params
    ret.opts, ret.ast = self.opts, self.ast

    # things downstream of the AST
    ret.reduceops, ret.vars, ret.bufs = self.reduceops, self.vars, self.bufs
    ret.sts = self.sts[:]

    # parameters for optimizations
    ret.applied_opts, ret.group_for_reduces, ret.upcasted, ret.local_dims, ret.dont_use_locals = \
      self.applied_opts[:], self.group_for_reduces, self.upcasted, self.local_dims, self.dont_use_locals
    ret.tensor_core, ret.tensor_core_opts, ret.use_tensor_cores = self.tensor_core, self.tensor_core_opts, self.use_tensor_cores

    return ret

  @property
  def membufs(self) -> list[UOp]: return dedup([x.src[0].base for x in self.bufs if x.op in {Ops.LOAD, Ops.STORE}])

  def upcasted_axis(self, i:int) -> list[tuple[int, Optional[sint], bool]]:
    upcasted_shape, upcasted_stride = self.sts[i].shape[self.first_upcast:], self.sts[i].real_strides()[self.first_upcast:]
    assert all_int(upcasted_shape), f"cannot upcast a symbolic amount {upcasted_shape=}"
    return list(zip(upcasted_shape, upcasted_stride,
                    [x!=y for x,y in zip(self.sts[0].shape[self.first_upcast:], self.full_shape[self.first_upcast:])]))

  @property
  def first_reduce(self) -> int:
    return [resolve(x!=y) for x,y in zip(self.sts[0].shape[:self.first_upcast]+(0,), self.full_shape[:self.first_upcast]+(1,))].index(True)

  @property
  def first_upcast(self) -> int: return self.shape_len-self.upcasted

  @property
  def reduceop(self) -> UOp|None: return self.reduceops[0] if len(self.reduceops) > 0 else None

  @property
  def output_shape(self) -> tuple[sint, ...]: return self.sts[0].shape

  @property
  def full_shape(self) -> tuple[sint, ...]: return self.sts[-1].shape

  @property
  def full_unupcasted_shape(self) -> tuple[sint, ...]: return self.full_shape[:self.first_upcast]

  @property
  def shape_len(self) -> int: return len(self.sts[0].shape)

  @property
  def global_dims(self) -> int: return self.first_reduce-self.local_dims

  # there's eight chunks of the shape
  # blue   -- global dims
  # cyan   -- local dims (warp ones first)
  #  *** self.first_reduce
  # green  -- reduce-local dims
  # red    -- reduce loops
  #  *** self.upcasted
  # purple -- reduce upcasted
  # yellow -- normal upcasted dimensions
  def colors(self) -> list[str]:
    # first non local non reduce dims are global (blue)
    colors = ["blue"] * self.global_dims if not self.dont_use_locals else ["BLUE"] * self.global_dims
    # after global are local_dims; warp ones used in tensor cores must be closest to first_reduce (cyan)
    colors += ["cyan"] * self.local_dims
    # between first_reduce and first_reduce + group_for_reduces, they are late upcasted (green)
    colors += ["green"] * self.group_for_reduces
    # between first_reduce + group_for_reduces and upcasted, they are reduce (red)
    colors += ["red"] * (self.first_upcast - (self.first_reduce + self.group_for_reduces))
    # upcasted dimensions are reduce (magenta) or normal (yellow)
    colors += ["magenta" if self.full_shape[i] != self.sts[0].shape[i] else "yellow" for i in range(self.first_upcast, self.shape_len)]
    assert len(colors) == self.shape_len, "colors size mismatch"
    return colors

  def colored_shape(self, pad:Optional[int]=None, dense=False) -> str:
    shape_strs = [(s if dense else f"{s:4d}") if isinstance(s, int) else s.render() for s in self.full_shape]
    ret = ' '.join(colored(s, color) for s,color in zip(shape_strs, self.colors()))
    if pad: ret += ' '*(pad-ansilen(ret))
    return ret

  # ******************** base simplifiers ********************

  # apply reshape and permute to all shapetrackers
  def reshape_and_permute(self, new_shape_fxn:Optional[Callable[[tuple[sint, ...]], Sequence[sint]]], axis:Optional[Sequence[int]]):
    def reshape(st:ShapeTracker): return st.reshape(tuple(new_shape_fxn(st.shape))) if new_shape_fxn is not None else st
    def permute(st:ShapeTracker): return st.permute(tuple(axis)) if axis is not None else st
    self.sts = [permute(reshape(st)) for st in self.sts]

  # drops the final dimension
  def upcast(self):
    check(self.full_shape[-1] != 1, "can't upcast a dimension with size 1")
    self.upcasted += 1

  # axis : the axis to pull from
  # amount : the amount to take
  # top : if you want to pull that amount from the top
  # insert_before : place to insert the new stuff
  def shift_to(self, axis, amount, top=False, insert_before=None):
    if insert_before is None: insert_before = self.shape_len
    move_axis = axis if top else axis+1
    if move_axis < insert_before: insert_before += 1
    self.reshape_and_permute(
      lambda x: x[0:axis] + (((amount, x[axis]//amount) if top else (x[axis]//amount, amount)) if x[axis] > 1 else (1,1)) + x[axis+1:],
      [i for i in range(insert_before) if i != move_axis] + [move_axis] + [i for i in range(insert_before, self.shape_len+1) if i != move_axis])

  # ******************** complex simplifiers ********************

  def simplify_ones(self) -> bool:
    # remove places where the shape is all ones
    # TODO: this should be factored in to multi shape stride
    if self.shape_len == 0: return False
    all_ones = [s==1 for s in self.full_shape]
    self.local_dims -= sum(all_ones[self.first_reduce-self.local_dims:self.first_reduce])
    self.upcasted -= sum(all_ones[self.first_upcast:]) # TODO: no necessary since upcasted axis can't be un-upcasted
    self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None)
    return any(all_ones)

  def simplify_merge_adjacent(self):
    if self.shape_len == 0: return
    shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts]

    # if it's an image, insert fake strides such that this fusion doesn't happen across image axes
    if isinstance(self.membufs[0].dtype, ImageDType):
      base_shape = self.membufs[0].dtype.shape
      if shape_idx_groups := get_contraction(self.output_shape, base_shape):
        special_strides: tuple[sint, ...] = tuple()
        for i,g in enumerate(shape_idx_groups):
          shape_piece = tuple(self.output_shape[x] for x in g)
          assert prod(shape_piece) == base_shape[i], f"get_contraction was wrong? {shape_piece} != {base_shape[i]}"
          special_strides += strides_for_shape(shape_piece)
        # adding the fake image shape
        shapes.append(self.output_shape)
        strides.append(special_strides)

    # merge dimensions if we can, multi _merge_dims
    # NOTE: this does not always preserve the reduce dimension
    # TODO: move this into shapetracker, with tests!
    # TODO: how does this work with multi-reduce?
    rets = [[(s[0], st[0])] for s,st in zip(shapes, strides)]
    for i in range(1, len(shapes[0])):
      can_merge = []
      for s,st,ret in zip(shapes, strides, rets):
        # TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case
        si, sti, last_st = s[i], st[i], ret[-1][1]
        can_merge.append((sti is not None) and ((sti != 0 and last_st == si*sti) or (sti == 0 and last_st == 0)))
      # more can merge than this
      mergeable = all(can_merge) and i != self.first_reduce
      for j,(s,st) in enumerate(zip(shapes, strides)):
        if mergeable: rets[j][-1] = (rets[j][-1][0] * s[i], st[i])
        else: rets[j].append((s[i], st[i]))

    # do the reshapes
    for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))

  # ******************** high level optimizers ********************

  def _create_tc_opts(self, reduceop:UOp, tc:TensorCore, axis:int, opt_level:int) -> Optional[TensorCoreOptions]:
    has_cast = tc.dtype_in != tc.dtype_out
    if has_cast and not (reduceop.src[0].op is Ops.CAST and reduceop.src[0].dtype == tc.dtype_out): return None

    mul_op = reduceop.src[0].src[0] if has_cast else reduceop.src[0]
    if mul_op.op is not Ops.MUL: return None

    def buf_index(src:UOp) -> Optional[int]:
      # TODO: apply tc even if the sources are not from LOAD
      if src.op is Ops.LOAD and src.dtype == tc.dtype_in: return self.bufs.index(src)
      try:
        if opt_level >= 1 and src.op is Ops.CAST and src.dtype == tc.dtype_in: return self.bufs.index(src.src[0])
      except ValueError: return None
      return None
    if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: return None

    buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
    axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0]
    axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0]
    if not (axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): return None

    axis_choices = list(itertools.product(axis_buf0, axis_buf1, range(self.first_reduce, self.shape_len)))
    if not (axis < len(axis_choices)): return None

    s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2]  # s0 is n, s1 is m, s2 is k
    axis_pads = tuple((x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if resolve(self.full_shape[x]%tc.dims[i] != 0))
    if axis_pads and (opt_level < 2): return None
    if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
    return TensorCoreOptions(axes=(s0, s1, s2), axes_exist=(True, True), axis_pads=axis_pads)

  def _apply_tc_opt(self, use_tensor_cores:int, axis:int, tc_select:int, opt_level:int) -> bool:
    if use_tensor_cores and self.reduceop is not None and self.reduceop.arg[0] is Ops.ADD:
      tensor_cores = self.opts.tensor_cores if tc_select == -1 else [self.opts.tensor_cores[tc_select]]
      for tc in tensor_cores:
        tensor_core_opts = [self._create_tc_opts(reduceop, tc, axis, opt_level) for reduceop in self.reduceops]
        # can only fuse reduces with the same tc options
        assert all_same(tensor_core_opts)
        if tensor_core_opts[0] is None: continue
        self.tensor_core_opts = tc_opts = tensor_core_opts[0]

        # attempt to pad the tensor axes that require it
        try:
          for axis, dim in tc_opts.axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
        except KernelOptError: continue
        # tensor core -- unroll the reduce dim (K), upcast and local the inner and outer dims (N, M)
        for dim, amt in tc.get_reduce_axes(): self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, amt), append_opt=False)
        for opt in tc.opts: self.apply_opt(Opt({"u":OptOps.UPCAST, "l":OptOps.LOCAL}[opt[0]], tc_opts.axes[int(opt[1])], 2), append_opt=False)
        self.tensor_core = tc
        self.use_tensor_cores = use_tensor_cores  # TC=2 will do the shape ops without the WMMA
        return True
    return False

  def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[list[Opt]]=None, axis:int=0, tc_select:Optional[int]=None,
                         tc_opt:Optional[int]=None) -> bool:
    """ Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false.
    Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N).

    Keyword arguments:
    use_tensor_cores -- controls how tensor cores are applied (default 1)
      0: will disable any tensor core matching
      1: enable tensor cores
      2: apply tensor core shape but don't use UOp.WMMA
      3: emulate tensor cores with local memory
    extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None)
    tc_select -- specifies which tensor core(s) to use for optimization (default -1)
      -1: iterates through all available tensor cores in order and uses the first one that matches the requirements (dims and dtypes)
      [0-N]: uses only the n'th tensor core available; useful for search
    tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default 2 during BEAM, 0 otherwise)
      0: applies to only kernels with a single reduce axis and direct Ops.LOAD into Ops.MUL
      1: allows kernels with multiple reduce axes and also multiplication of Ops.CAST'd buffers
      2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed
    """
    if tc_select is None: tc_select = TC_SELECT.value
    if tc_opt is None: tc_opt = TC_OPT.value
    if not self.opts.tensor_cores: return False
    try: # check TC first and apply hand-coded opts if successful
      self.apply_opt(Opt(OptOps.TC, axis, (tc_select, tc_opt, use_tensor_cores)))

      if (tc_opts:=self.tensor_core_opts) is not None:
        if extra_opts is not None: self.apply_opts(extra_opts)
        else:
          if AMX: return True # skip hand-coded TC opts if AMX, upcasting will make kernel slower
          # hand-coded TC opts
          for tc_dim in [tc_dim for tc_dim in [1,0] if tc_opts.axes_exist[tc_dim]]: # attempt to upcast M and N
            szs = [sz for sz in [5,4,3,2] if self.full_shape[tc_opts.axes[tc_dim]] % sz == 0]
            if szs: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[tc_dim], szs[0]))

          if tc_opts.axes_exist[0] and (szs := [sz for sz in [4,2] if self.full_shape[tc_opts.axes[0]] % sz == 0]): # attempt to local N
            self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[0], szs[0]))
      return True
    except KernelOptError:
      return False

  def real_axis(self, opt:Opt):
    if opt.axis is None: return -1
    if opt.op is OptOps.UNROLL: return self.first_reduce+opt.axis
    if opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: return self.first_reduce+self.group_for_reduces+opt.axis
    return opt.axis

  def apply_opt(self, opt:Opt, append_opt:bool=True):
    if self.dont_use_locals: check(opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP}, "not using locals")

    if opt.op is OptOps.TC:
      check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine
      check(len(self.opts.tensor_cores) > 0, "must have tensor cores")
      check(opt.axis is not None, "tensor core opts must have an axis")
      check(opt.arg is not None and isinstance(opt.arg, tuple) and len(opt.arg) == 3, "tensor core opts must have valid arg")
      check(-1 <= (tc_select:=cast(tuple, opt.arg)[0]) < len(self.opts.tensor_cores), "tensor core opts must have valid tc_select")
      check(0 <= (tc_opt:=cast(tuple, opt.arg)[1]) <= 2, "tensor core opts must have valid tc_opt")
      check(0 < (use_tensor_cores:=cast(tuple, opt.arg)[2]) <= 3, "use_tensor_cores value is not valid")
      check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), tc_select, tc_opt), "no tensor core available")
      self.applied_opts.append(opt)
      return

    axis = self.real_axis(opt)
    check(axis < len(self.full_shape), "invalid axis")

    if opt.op is OptOps.SWAP: amt = cast(int, opt.arg)  # arg is an axis in the SWAPs
    elif opt.arg is not None:
      check(isinstance(opt.arg, int), "arg should be int")
      amt = arg if (arg:=cast(int, opt.arg)) != 0 else self.full_shape[axis]
      check(isinstance(amt, int) and amt != 1, f"shift/padto of {amt=}, 1 or symbolic amount is meaningless")
      if opt.op is not OptOps.PADTO: check(self.full_shape[axis] % amt == 0, f"no longer valid shift {self.full_shape[axis]=}, {amt=}")
    else: amt = -1

    if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or \
                                      (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
      acc_sz = self.reduceop.dtype.itemsize
      upcast_sz = prod([a for a,b in zip(self.full_shape[self.first_upcast:], self.sts[0].shape[self.first_upcast:]) if a == b])
      local_sz = prod(self.full_shape[self.first_reduce-self.local_dims:self.first_reduce+self.group_for_reduces])
      smem_sz = amt*acc_sz*upcast_sz*local_sz
      check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}")

    if opt.op is OptOps.LOCAL:    # cyan
      # NOTE: LLVM/CPU can use locals too, but they are treated the same as globals (still helpful for L1 cache)
      # it's disabled for now since it makes BEAM slow for little gain
      check(self.opts.has_local, "target does not support local")
      check(axis < self.global_dims, "local is for globals")
      self.shift_to(axis, amt, insert_before=self.first_reduce)
      self.local_dims += 1
    elif opt.op in {OptOps.GROUP, OptOps.GROUPTOP}:   # green
      check(self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem")
      check(self.first_reduce + self.group_for_reduces <= axis < self.first_upcast, "must be reduce axis to group")
      check(not self.tensor_core, "can't group with tensor cores")
      check(len(reduce_axes:=[i for r in self.reduceops for i in r.axis_arg]) == len(set(reduce_axes)), "can't group with parallel reduces")
      self.shift_to(axis, amt, top=(opt.op is OptOps.GROUPTOP), insert_before=self.first_reduce + self.group_for_reduces)
      self.group_for_reduces += 1
    elif opt.op is OptOps.UNROLL:                     # purple
      check(axis < self.first_upcast, "can't upcasted already upcasted")
      check(amt <= 32, "don't unroll more than 32")
      # TODO: fix upcast_count to put purples before yellows. broken because of METAL tensor cores
      #upcast_count = sum(x == y for x,y in zip(self.full_shape[-self.upcasted:], self.output_shape[-self.upcasted:])) if self.upcasted else 0
      #self.shift_to(axis, amt, insert_before=None if upcast_count == 0 else self.shape_len-upcast_count)
      if self.full_shape[axis] == amt and axis == self.first_reduce: self.local_dims += 1 # first_reduce will ++, so offset loss in simplify_ones
      if self.full_shape[axis] == amt and axis < self.first_reduce+self.group_for_reduces: self.group_for_reduces -= 1 # fully unrolling a GROUP
      self.shift_to(axis, amt, insert_before=None)
      self.upcast()
    elif opt.op is OptOps.UPCAST:                     # yellow
      check(axis < self.first_reduce, "upcast is for non-reduce")
      check(not (self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.get_local_axes())), "can't upcast TC locals")
      check((self.opts is not None and self.opts.device == "DSP") or amt <= 16, "don't upcast more than 16")
      self.shift_to(axis, amt, insert_before=None)
      self.upcast()
    elif opt.op is OptOps.NOLOCALS:
      check(self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals")
      check(self.local_dims == 0 and self.group_for_reduces == 0, "can't have no locals with locals")
      self.dont_use_locals = True
    elif opt.op is OptOps.SWAP:
      check(axis < amt < self.global_dims, f"swap is only for globals with axis < amt, getting {amt=}, {axis=}, {self.global_dims=}")
      permute = list(range(self.shape_len))
      permute[axis], permute[amt] = permute[amt], permute[axis]
      self.reshape_and_permute(None, tuple(permute))
    elif opt.op is OptOps.PADTO:
      check(not self.vars, "does not work with symbolic shape")
      check(axis < self.first_upcast, "cannot pad upcasted")
      # ok to pad SUM if all parent ALU ops have f(0) = 0
      if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is Ops.ADD and can_pad(r, {}), f"cannot pad {r}")
      padded = False
      for i,st in enumerate(self.sts):
        if (s:=st.shape[axis]) == 1: continue  # reduced
        check(s > amt//4, f"pad adds more than quadruple the work {st.shape[axis]=} > {amt//4=}")
        if (ru := round_up(cast(int, s), amt) - s):
          # pad right seems to be faster
          self.sts[i] = st.pad(((0,0),) * axis + ((0,ru),) + ((0,0),) * (len(st.shape)-axis-1))
          padded = True
      check(padded, "nothing was padded")

    if append_opt: self.applied_opts.append(opt)
    if self.simplify_ones() and self.tensor_core_opts:
      self.tensor_core_opts.fix_axes(axis) # fix up axes in TC opts if required after simplify_ones()

  def apply_opts(self, opts:Sequence[Opt]) -> Kernel:
    for opt in opts: self.apply_opt(opt)
    return self

  # **** kernel outputs ****

  kernel_cnt: Final[defaultdict[str, int]] = defaultdict(int)
  @functools.cached_property
  def name(self) -> str:
    # kernel name (before late upcast)
    kernel_type = "r" if self.reduceop is not None else ("C" if all(x.op is Ops.SINK or x.op in GroupOp.Buffer for x in self.ast.toposort()) else "E")
    suffix = colored('_', 'BLACK').join([colored(x.render() if isinstance(x, UOp) else str(x), c) for x,c in zip(self.full_shape, self.colors())])
    name = kernel_type + (f"{len(self.ast.src)}" if len(self.ast.src) > 1 else "") + "_" + suffix

    # name the function something unique
    Kernel.kernel_cnt[(function_name := to_function_name(name))] += 1
    num = f"n{Kernel.kernel_cnt[function_name]-1}" if Kernel.kernel_cnt[function_name] > 1 else ""
    return name + colored(num, 'BLACK')

  def get_optimized_ast(self, name_override:Optional[str]=None) -> UOp:
    @functools.cache
    def fixup_ast(op:UOp) -> UOp:
      ret = op.replace(src=tuple(fixup_ast(x) for x in op.src)) # noqa: F821
      if op.op in GroupOp.Buffer and op in self.bufs:
        st = self.sts[self.bufs.index(op)]
        # NOTE: if CONST got masked after applying opts, we create a new VALID
        if op.op is Ops.CONST and any(v.mask is not None for v in st.views): return op.view(st).valid()
        # otherwise we just replace the VIEW source
        return ret.replace(src=(ret.src[0].replace(arg=st),)+ret.src[1:])
      if op.op is Ops.SINK:
        return ret.replace(arg = KernelInfo(to_function_name(self.name) if name_override is None else name_override,
                                            self.local_dims, self.upcasted, self.dont_use_locals))
      if op.op is Ops.REDUCE_AXIS:
        reduce_idx = len(self.bufs) + self.reduceops.index(op) * 2

        def reduced_axes(start, stop):
          return tuple(i for i in range(start, stop) if resolve(self.sts[reduce_idx].shape[i] != self.sts[reduce_idx + 1].shape[i]))
        axes = reduced_axes(self.first_reduce + self.group_for_reduces, self.shape_len)
        grouped_axes = reduced_axes(self.first_reduce, self.first_reduce + self.group_for_reduces)

        if (tc := self.tensor_core) and (self.use_tensor_cores == 1 or self.use_tensor_cores == 3):
          wd, tcd = self.global_dims, self.first_upcast
          def get_upcast_axes(buf): # upcast along non-zero dimensions of (tc_reduce + tc_upcast)
            upcast_axes = int(math.log2(tc.elements_per_thread[buf]))
            return tuple((tcd + len(tc.get_reduce_axes()) + len(tc.get_upcast_axes()) - (i+1), 2) for i in range(upcast_axes))
          def get_tc_swizzle_st(shape, local_perm, upcast_perm):
            offset = (tcd - (wd + len(local_perm)))
            permaxis = list(range(wd)) \
              + [wd + x + (offset if x >= len(local_perm) else 0) for x in local_perm]  + list(range(wd + len(local_perm), tcd)) \
              + [wd + x + (offset if x >= len(local_perm) else 0) for x in upcast_perm] + list(range(tcd + len(upcast_perm), len(shape)))
            return ShapeTracker.from_shape(shape).permute(tuple(permaxis))

          srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src)
          for i, (src, swizzle) in enumerate(zip(srcs, tc.swizzle)):
            src_st = (src if src.op is Ops.LOAD else src.src[0]).st_arg
            if swizzle: srcs[i] = src.view(get_tc_swizzle_st(src_st.shape, *swizzle))

            if self.use_tensor_cores == 3:  # for TC=3, emulate the warp addressing with locals
              local_shape = tuple(1 if st == 0 or i < wd or (i >= self.first_reduce and i < tcd) else src_st.shape[i] \
                                  for i,st in enumerate(src_st.real_strides()))
              st = store_st = ShapeTracker.from_shape(local_shape)
              local_buffer = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(size=st.real_size(), local=True), (), f"temp{i}")
              if swizzle: store_st = get_tc_swizzle_st(store_st.shape, *swizzle)
              local_store = UOp.store(local_buffer.view(store_st), srcs[i])
              srcs[i] = UOp(Ops.LOAD, tc.dtype_in, (local_buffer.view(st), local_store))

          tc_reduce_axes = tuple(tcd + ax for ax, _ in tc.get_reduce_axes())
          if self.use_tensor_cores == 1: # real WMMA, use CONTRACT/UNROLL to get the vectorization right
            tc_upcast_axes = (get_upcast_axes(0), get_upcast_axes(1), get_upcast_axes(2))
            wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, tc.threads, tc_upcast_axes, tc_reduce_axes)
            wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(tc.elements_per_thread[2]), src=(
              UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(tc.elements_per_thread[0]), src=(srcs[0],), arg=tc_upcast_axes[0]),
              UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(tc.elements_per_thread[1]), src=(srcs[1],), arg=tc_upcast_axes[1]),
              UOp.const(tc.dtype_out.vec(tc.elements_per_thread[2]), 0.0)), arg=wmma_arg)
            tc_uop = UOp(Ops.UNROLL, tc.dtype_out, (wmma,), arg=tc_upcast_axes[2])

          else: # for TC=3 MUL/SUM instead of WMMA
            tc_uop = UOp(Ops.REDUCE_AXIS, tc.dtype_out, ((srcs[0] * srcs[1]).cast(tc.dtype_out),), (Ops.ADD, tc_reduce_axes))

          return ret.replace(src=(tc_uop,), arg=(Ops.ADD, new_axes)) if (new_axes := tuple(i for i in axes if i not in tc_reduce_axes)) else tc_uop

        ret = ret.replace(arg = (op.arg[0], axes))
        if self.group_for_reduces and grouped_axes:
          local_shape = (1,) * self.global_dims + self.full_shape[self.global_dims:self.global_dims+self.local_dims] + \
            tuple([self.full_shape[i] if self.sts[reduce_idx].shape[i] != self.sts[reduce_idx+1].shape[i] else 1 \
              for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces)]) + \
            (1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
          st = ShapeTracker.from_shape(local_shape)
          local_size = st.real_size()
          local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, local=True), (), f"temp{self.reduceops.index(op)}")
          local_load = UOp(Ops.LOAD, op.dtype, (local_buffer.view(st), UOp.store(local_buffer.view(st), ret)))
          grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], grouped_axes))
          if op is self.reduceops[-1]: return grouped_reduce
          st = ShapeTracker.from_shape(tuple([1 if i in grouped_axes else a for i,a in enumerate(local_shape)]))
          return UOp(Ops.LOAD, op.dtype, (local_buffer.view(st), UOp.store(local_buffer.view(st), grouped_reduce)))

      return ret
    fixed_ast = fixup_ast(self.ast)
    del fixup_ast
    return graph_rewrite(fixed_ast, view_left, name="fixup optimized AST")

  # **** this is the lowerer ****

  @track_rewrites()
  def linearize(self, name_override:Optional[str]=None, ast_transform:Optional[Callable]=None) -> Kernel:
    # display the AST
    if getenv("VIZ"): graph_rewrite(self.ast, PatternMatcher([]), name="View Base AST")

    modified_ast = self.get_optimized_ast(name_override)
    if ast_transform is not None: modified_ast = ast_transform(self, modified_ast)

    if DEBUG >= 3:
      print(self.name)
      if DEBUG >= 5: print(self.ast)
      for i,(buf,st) in enumerate([(buf,st) for buf,st in zip(self.bufs, self.sts) if buf.op not in {Ops.CONST, Ops.VALID}]):
        print(f"{i:2d}: {str(st.shape):25s} {str(buf.src[0].dtype).replace('dtypes.',''):20s} {str(st.real_strides()):30s}",
              str(st) if DEBUG >= 4 else "")
      print(self.applied_opts)
      if DEBUG >= 5: print(modified_ast)
    # verify AST matches the spec after applying opts
    if __debug__: type_verify(list(modified_ast.toposort()))
    # TODO: sadly modified_ast doesn't pass the shape spec because of how group_for_reduces constructs UOps, there's probably a way to fix this
    #if __debug__: type_verify(list(modified_ast.toposort()), ast_spec)

    try:
      self.uops:list[UOp] = full_rewrite(modified_ast, self.opts)
    except RuntimeError:
      print("***** LINEARIZE FAILURE *****")
      print(f"ast = {self.ast}")
      print(f"opts = {self.applied_opts}")
      raise
    if DEBUG >= 6: print_uops(self.uops)
    return self

  def to_program(self, name_override:Optional[str]=None, ast_transform:Optional[Callable]=None) -> ProgramSpec:
    self.linearize(name_override, ast_transform)
    assert self.uops[-1].op is Ops.SINK, "last uop must be sink"
    src = self.opts.render(self.uops)
    # group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
    # TODO: these max and min don't work on symbolic, and results are very wrong.
    mem_bytes = sum(max(x.src[0].dtype.nbytes() for x in group)
      for _, group in itertools.groupby([x for x in self.ast.toposort() if x.op in GroupOp.Buffer and x.src[0].base.op is Ops.DEFINE_GLOBAL],
                        key=lambda x: (x.op, x.src[0].base.arg)))
    return ProgramSpec(self.name if not name_override else name_override, src, self.opts.device, self.ast, self.uops, self.applied_opts, mem_bytes,
                       global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)