from typing import cast, Iterator
import math, functools, dataclasses
from tinygrad.dtype import dtypes, sum_acc_dtype
from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, all_metadata
from tinygrad.helpers import argsort

def reduce_gradient(ctx:UOp, ret:UOp):
  if ret.arg[0] == Ops.ADD: return (ctx.expand(ret.src[0].shape),)
  if ret.arg[0] == Ops.MAX:
    # gradient flows evenly to every element that attains the max
    max_is_1s = ret.src[0].ne(ret.expand(ret.src[0].shape)).ne(ret.src[0].const_like(1).cast(dtypes.bool)).cast(ctx.dtype)
    div = max_is_1s.r(Ops.ADD, ret.arg[1]).expand(ret.src[0].shape)
    return ((max_is_1s/div) * ctx.expand(ret.src[0].shape),)
  if ret.arg[0] == Ops.MUL: return ((ctx * ret).expand(ret.src[0].shape) / ret.src[0],)

# ctx is grad_output
pm_gradient = PatternMatcher([
  (UPat(Ops.CAST, name="ret"), lambda ctx, ret: (ctx.cast(ret.src[0].dtype),)),
  (UPat(Ops.RECIP, name="ret"), lambda ctx, ret: (-ctx * ret * ret,)),
  (UPat(Ops.SIN, name="ret"), lambda ctx, ret: ((math.pi/2 - ret.src[0]).sin() * ctx,)),
  (UPat(Ops.LOG2, name="ret"), lambda ctx, ret: (ctx / (ret.src[0] * math.log(2)),)),
  (UPat(Ops.EXP2, name="ret"), lambda ctx, ret: (ret * ctx * math.log(2),)),
  (UPat(Ops.SQRT, name="ret"), lambda ctx, ret: (ctx / (ret*2),)),
  # comparisons are not differentiable
  (UPat((Ops.CMPLT, Ops.CMPNE)), lambda: (None, None)),
  (UPat(Ops.ADD), lambda ctx: (ctx, ctx)),
  (UPat(Ops.POW, name="ret"), lambda ctx, ret: (
    ctx*(ret.src[0].eq(0) & ret.src[1].eq(0)).where(ret.src[1], ret.src[1]*ret.src[0].pow(ret.src[1]-1)),
    ctx*ret.src[0].eq(0).where((ret.src[1]<0).where(ret.const_like(-math.inf), ret.const_like(0)), ret*ret.src[0].log2()*math.log(2.0)))),
  # on ties, the gradient is split evenly between the two sources
  (UPat(Ops.MAX, name="ret"), lambda ctx, ret: ((ret.src[0]>ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)),
                                                (ret.src[0]<ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)))),
  (UPat(Ops.MUL, name="ret"), lambda ctx, ret: (ret.src[1]*ctx, ret.src[0]*ctx)),
  # no gradient for the condition of WHERE
  (UPat(Ops.WHERE, name="ret"), lambda ctx, ret: (None, ret.src[0].where(ctx, ctx.const_like(0)), ret.src[0].where(ctx.const_like(0), ctx))),
  (UPat(Ops.REDUCE_AXIS, name="ret"), reduce_gradient),
  (UPat(Ops.CONTIGUOUS), lambda ctx: (ctx,)),
  (UPat(Ops.CONTIGUOUS_BACKWARD), lambda ctx: (ctx.contiguous(),)),
  # movement ops invert themselves in the backward pass
  (UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape),)),
  (UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.arg)),)),
  (UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.arg)])),)),
  (UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.arg)])),)),
  (UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip(ret.arg),)),
  # EXPAND's backward is a sum over the broadcast axes
  (UPat(Ops.EXPAND, name="ret"), lambda ctx, ret:
    (ctx.cast(sum_acc_dtype(ctx.dtype)).r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.arg)) if si!=so)).cast(ctx.dtype),)),
  (UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
  # there's no gradient for bitcast
  (UPat(Ops.BITCAST), lambda ctx: (None,)),
])

# get the toposort of everything on a path between root and the targets
def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]:
  @functools.lru_cache(None)
  def is_in_target_path(x:UOp) -> bool: return any(u in targets or is_in_target_path(u) for u in x.src)
  def _walk(node:UOp, visited:set[UOp]) -> Iterator[UOp]:
    visited.add(node)
    if node.op is Ops.DETACH: return
    if is_in_target_path(node):
      for i in node.src:
        if i not in visited: yield from _walk(i, visited)
      yield node
  return list(_walk(root, set()))

def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]:
  grads = {root: root_grad}
  for t0 in reversed(_deepwalk(root, targets)):
    if t0 not in grads: continue
    lgrads: tuple[UOp|None, ...]|None = cast(tuple[UOp, ...]|None, pm_gradient.rewrite(t0, ctx=grads[t0]))
    if lgrads is None: raise RuntimeError(f"failed to compute gradient for {t0.op}\n\nin {str(t0)[0:1000]}...")
    assert len(lgrads) == len(t0.src), f"got {len(lgrads)} gradients, expected {len(t0.src)}"
    for k,v in zip(t0.src, lgrads):
      if v is None: continue
      if k in grads: grads[k] = grads[k] + v
      else: grads[k] = v
      if (forward_metadata:=all_metadata.get(t0)) is not None: all_metadata[v] = dataclasses.replace(forward_metadata, backward=True)
  return grads
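
# Usage sketch (illustrative, not part of the module): the public Tensor API is what
# normally drives compute_gradient, seeding the root gradient with ones over the UOp
# graph and using the targets' underlying UOps. Differentiating y = sum(x*x) with
# respect to x should yield 2*x. Assumes the tinygrad Tensor.gradient entry point.
if __name__ == "__main__":
  from tinygrad import Tensor
  x = Tensor([1.0, 2.0, 3.0])
  y = (x * x).sum()
  # under the hood this reduces to compute_gradient(y's UOp, ones-like root grad, {x's UOp})
  (dx,) = y.gradient(x)
  print(dx.tolist())  # expect [2.0, 4.0, 6.0]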