from typing import cast
import math, functools
from tinygrad.dtype import dtypes, sum_acc_dtype
from tinygrad.ops import UOp, PatternMatcher, UPat, Ops
from tinygrad.helpers import argsort

# gradients for REDUCE_AXIS: ADD broadcasts the output grad, MAX routes it to the argmax
# locations (split evenly on ties), MUL uses grad * out / x
def reduce_gradient(ctx:UOp, ret:UOp):
  if ret.arg[0] == Ops.ADD: return (ctx.expand(ret.src[0].shape),)
  if ret.arg[0] == Ops.MAX:
    # 1s where the max was selected (there can be more than one), 0s elsewhere
    max_is_1s = ret.src[0].ne(ret.expand(ret.src[0].shape)).ne(ret.src[0].const_like(1).cast(dtypes.bool)).cast(ctx.dtype)
    div = max_is_1s.r(Ops.ADD, ret.arg[1]).expand(ret.src[0].shape)
    return ((max_is_1s/div) * ctx.expand(ret.src[0].shape),)
  if ret.arg[0] == Ops.MUL: return ((ctx * ret).expand(ret.src[0].shape) / ret.src[0],)

# ctx is grad_output
pm_gradient = PatternMatcher([
  (UPat(Ops.CAST, name="ret"), lambda ctx, ret: (ctx.cast(ret.src[0].dtype),)),
  (UPat(Ops.RECIP, name="ret"), lambda ctx, ret: (-ctx * ret * ret,)),
  (UPat(Ops.SIN, name="ret"), lambda ctx, ret: ((math.pi/2 - ret.src[0]).sin() * ctx,)),
  (UPat(Ops.LOG2, name="ret"), lambda ctx, ret: (ctx / (ret.src[0] * math.log(2)),)),
  (UPat(Ops.EXP2, name="ret"), lambda ctx, ret: (ret * ctx * math.log(2),)),
  (UPat(Ops.SQRT, name="ret"), lambda ctx, ret: (ctx / (ret*2),)),
  # comparisons have no gradient
  (UPat((Ops.CMPLT, Ops.CMPNE)), lambda: (None, None)),
  (UPat(Ops.ADD), lambda ctx: (ctx, ctx)),
  # max passes the gradient to the larger side, splitting it evenly on ties
  (UPat(Ops.MAX, name="ret"), lambda ctx, ret: ((ret.src[0]>ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)),
                                                (ret.src[0]<ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)))),
  (UPat(Ops.MUL, name="ret"), lambda ctx, ret: (ret.src[1]*ctx, ret.src[0]*ctx)),
  (UPat(Ops.WHERE, name="ret"), lambda ctx, ret: (None, ret.src[0].where(ctx, ctx.const_like(0)), ret.src[0].where(ctx.const_like(0), ctx))),
  (UPat(Ops.REDUCE_AXIS, name="ret"), reduce_gradient),
  (UPat(Ops.CONTIGUOUS), lambda ctx: (ctx,)),
  (UPat(Ops.CONTIGUOUS_BACKWARD), lambda ctx: (ctx.contiguous(),)),
  # movement ops invert themselves
  (UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape),)),
  (UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.arg)),)),
  (UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.arg)])),)),
  (UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.arg)])),)),
  (UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip(ret.arg),)),
  # expand sums the gradient over the expanded axes, accumulating in a wider dtype
  (UPat(Ops.EXPAND, name="ret"), lambda ctx, ret:
    (ctx.cast(sum_acc_dtype(ctx.dtype)).r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.arg)) if si != so)).cast(ctx.dtype),)),
])

# copied from tensor.py, get relevant toposort of gradients
def _deepwalk(root:UOp, targets:list[UOp]) -> list[UOp]:
  @functools.lru_cache(None)
  def is_in_target_path(x:UOp) -> bool: return any(u in targets or is_in_target_path(u) for u in x.src)
  def _walk(node:UOp, visited:set[UOp]):
    visited.add(node)
    if node.op is Ops.DETACH: return
    if is_in_target_path(node):
      for i in node.src:
        if i not in visited: yield from _walk(i, visited)
      yield node
  return list(_walk(root, set()))

# backpropagate root_grad from root through the graph, returning a map of UOp -> accumulated gradient
def compute_gradient(root:UOp, root_grad:UOp, targets:list[UOp]) -> dict[UOp, UOp]:
  grads = {root: root_grad}
  for t0 in reversed(_deepwalk(root, targets)):
    if t0 not in grads: continue
    lgrads: tuple[UOp|None, ...]|None = cast(tuple[UOp, ...]|None, pm_gradient.rewrite(t0, ctx=grads[t0]))
    if lgrads is None: raise RuntimeError(f"failed to compute gradient for {t0.op}\n\nin {str(t0)[0:1000]}...")
    assert len(lgrads) == len(t0.src), f"got {len(lgrads)} gradients, expected {len(t0.src)}"
    for k,v in zip(t0.src, lgrads):
      if v is None: continue
      if k in grads: grads[k] = grads[k] + v
      else: grads[k] = v
  return grads
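
# Usage sketch (illustrative; not part of the library). compute_gradient is the
# core of reverse-mode autodiff on the UOp graph, and the public Tensor frontend
# wraps it. A minimal end-to-end check, assuming the standard tinygrad Tensor API:
if __name__ == "__main__":
  from tinygrad import Tensor
  x = Tensor([1.0, 2.0, 3.0], requires_grad=True)
  loss = (x * x).sum()         # d(loss)/dx = 2*x
  loss.backward()              # routes through compute_gradient on the UOp graph
  print(x.grad.tolist())       # expected: [2.0, 4.0, 6.0]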