from tinygrad.uop.ops import Ops, UOp, resolve, can_pad, GroupOp, UPat, PatternMatcher, graph_rewrite
from tinygrad.helpers import all_int, prod, unwrap, dedup, DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES, FUSE_CONV_BW
from tinygrad.shape.shapetracker import ShapeTracker

ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
                     Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.GBARRIER}

# **** Grouper decides which of the UOps realize

def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None

def realize_parents(ctx:dict[UOp, None], rb:UOp) -> None:
  for s in rb.src:
    if s.op not in ALWAYS_CONTIGUOUS: ctx[s] = None

def realize_before_view(ctx:dict[UOp, None], view:UOp, tr:UOp) -> None:
  st = unwrap(view.st)
  # always realize unsafe pad ops before masked view
  if any(v.mask is not None for v in st.views) and not can_pad(tr, ctx): return realize(ctx, tr)
  # fold simple pads
  if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(tr.shape) and resolve(prod(tr.shape) >= prod([y-x for x,y in m])): return
  # realize before expand
  if resolve(prod(tr.shape) < prod(st.shape)) and not DONT_REALIZE_EXPAND: return realize(ctx, tr)

do_realize = PatternMatcher([
  # always realize SINK parents
  (UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
  # always realize ASSIGN/CONTIGUOUS/GroupOp.Meta
  (UPat({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}, name="tr"), realize),
  # realize before expand or unsafe pad ops
  (UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="tr"),), name="view"), realize_before_view),
  # realize parents of COPY, MSELECT, MSTACK
  (UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_parents),
])

def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:dict[UOp, dict[UOp, None]], realizes:dict[UOp, None],
                    reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None:
  if (tr, st) in cache: return
  cache.setdefault((tr, st))
  rsize = unwrap(r.st).size
  if tr in realizes and tr is not r:
    # can only fuse contiguous
    # max one reduceop per kernel
    if not st.contiguous or st.size != rsize or tr in reduce_for_op: group.setdefault(r)
    return group.setdefault(tr)
  for tr_next in children.get(tr, {}):
    # max one reduceop per kernel
    if tr_next.op is Ops.REDUCE_AXIS: return group.setdefault(r)
    # can only fuse contiguous
    if len(st_childs:=dedup(unwrap(x.st) for x in tr_next.src if x.base == tr)) > 1: return group.setdefault(r)
    recursive_group(tr_next, st+st_childs[0], r, children, realizes, reduce_for_op, group, cache)

def group_realizes(sink:UOp) -> dict[UOp, None]:
  # start by adding uops that always realize
  realizes: dict[UOp, None] = {}
  sink = graph_rewrite(sink, do_realize, ctx=realizes, name="do_realize")
  if DONT_GROUP_REDUCES: return realizes

  # construct children graph (only for bases)
  children: dict[UOp, dict[UOp, None]] = {}
  assigns: dict[UOp, None] = {}
  for u in (toposort:=sink.toposort()):
    if u.op in {Ops.VIEW, Ops.SINK}: continue
    if u.op is Ops.ASSIGN: assigns[u.buf_uop] = None
    for s in u.src: children.setdefault(s.base, {})[u] = None

  # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
  reduce_for_op: dict[UOp, UOp] = {}
  double_reduces: list[UOp] = []
  for r in toposort:
    if r.op is not Ops.REDUCE_AXIS: continue
    if len(r.arg) == 3 and r.arg[2] is True: continue
    if FUSE_CONV_BW and r.src[0].base.op is Ops.REDUCE_AXIS and r.src[0] is not r.src[0].base: double_reduces.append(r)
    if r in realizes: continue
    group: dict[UOp, None] = {}
    recursive_group(r, unwrap(r.st), r, children, realizes, reduce_for_op, group, cache={})
    # max one reduceop per kernel
    can_chase = all(tr not in reduce_for_op for tr in group)
    for u in r.toposort(gate=lambda u: u not in realizes):
      if u.op is Ops.REDUCE_AXIS and u.src[0].base.op is Ops.CONST:
        can_chase = False
        break
    # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
    forced_realize = r in group
    # can only have one output
    if not forced_realize and len(group) > 1: forced_realize = True
    # can only fuse assign if no other assign_target is used in the kernel
    if not forced_realize and (assign_targets:={x.buf_uop for x in group if x.op is Ops.ASSIGN}):
      parents = [r, *group]
      while parents and not forced_realize:
        p = parents.pop().base
        if p.op is Ops.BUFFER and p in assigns and p not in assign_targets: forced_realize, can_chase = True, False
        if p in realizes: continue
        parents.extend(p.src)
    if forced_realize or not group:
      tr = r
      if can_chase:
        # can chase this down to contiguous children
        st = unwrap(tr.st)
        while len(lst:=children.get(tr, {})) == 1:
          tr_next = next(iter(lst))
          st_childs = dedup(unwrap(s.st) for s in tr_next.src if s.base is tr)
          if len(st_childs) > 1: break
          if st.size != st_childs[0].size: break
          st = st + st_childs[0]
          if not st.contiguous or tr_next.op is Ops.REDUCE_AXIS: break
          tr = tr_next
        # don't cast to higher size before store (tr cannot be realized if forced_realize)
        if tr.op is Ops.CAST and tr.dtype.itemsize > tr.src[0].dtype.itemsize:
          tr = tr.src[0].base
      group = {tr: None}
      realizes[tr] = None
    reduce_for_op.update((tr, r) for tr in group)
  # fuse double reduces with no other child
  for reduceop in double_reduces:
    top_reduce = reduceop.src[0].base
    if len(children.get(top_reduce, {})) == 1: del realizes[top_reduce]
  return realizes