from typing import cast
import math, functools
from tinygrad.dtype import dtypes, sum_acc_dtype
from tinygrad.ops import UOp, PatternMatcher, UPat, Ops
from tinygrad.helpers import argsort

def reduce_gradient(ctx:UOp, ret:UOp):
  if ret.arg[0] == Ops.ADD: return (ctx.expand(ret.src[0].shape),)
  if ret.arg[0] == Ops.MAX:
    # 1.0 where the input equals the reduced max, 0.0 elsewhere
    max_is_1s = ret.src[0].ne(ret.expand(ret.src[0].shape)).ne(ret.src[0].const_like(1).cast(dtypes.bool)).cast(ctx.dtype)
    # number of tied maxima along the reduced axes, broadcast back to the input shape
    div = max_is_1s.r(Ops.ADD, ret.arg[1]).expand(ret.src[0].shape)
    return ((max_is_1s/div) * ctx.expand(ret.src[0].shape),)
  if ret.arg[0] == Ops.MUL: return ((ctx * ret).expand(ret.src[0].shape) / ret.src[0],)
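
# Worked example for the MAX case: reducing src = [[1., 3.], [3., 3.]] with MAX over
# axis 1 gives [3., 3.]; max_is_1s = [[0., 1.], [1., 1.]] and div = [[1., 1.], [2., 2.]],
# so row 0's incoming gradient flows only to its single max and row 1's is split evenly
# between the two tied maxima.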

# ctx is grad_output
pm_gradient = PatternMatcher([
  (UPat(Ops.CAST, name="ret"), lambda ctx, ret: (ctx.cast(ret.src[0].dtype),)),
  (UPat(Ops.RECIP, name="ret"), lambda ctx, ret: (-ctx * ret * ret,)),
  # d(sin x)/dx = cos x = sin(pi/2 - x)
  (UPat(Ops.SIN, name="ret"), lambda ctx, ret: ((math.pi/2 - ret.src[0]).sin() * ctx,)),
  (UPat(Ops.LOG2, name="ret"), lambda ctx, ret: (ctx / (ret.src[0] * math.log(2)),)),
  (UPat(Ops.EXP2, name="ret"), lambda ctx, ret: (ret * ctx * math.log(2),)),
  (UPat(Ops.SQRT, name="ret"), lambda ctx, ret: (ctx / (ret*2),)),
  (UPat((Ops.CMPLT, Ops.CMPNE)), lambda: (None, None)),
  (UPat(Ops.ADD), lambda ctx: (ctx, ctx)),
  # on a tie, the incoming gradient is split equally between the two inputs
  (UPat(Ops.MAX, name="ret"), lambda ctx, ret: ((ret.src[0]>ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)),
                                                (ret.src[0]<ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)))),
  (UPat(Ops.MUL, name="ret"), lambda ctx, ret: (ret.src[1]*ctx, ret.src[0]*ctx)),
  (UPat(Ops.WHERE, name="ret"), lambda ctx, ret: (None, ret.src[0].where(ctx, ctx.const_like(0)), ret.src[0].where(ctx.const_like(0), ctx))),
  (UPat(Ops.REDUCE_AXIS, name="ret"), reduce_gradient),
  (UPat(Ops.CONTIGUOUS), lambda ctx: (ctx,)),
  (UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape),)),
  (UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.arg)),)),
  (UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.arg)])),)),
  (UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.arg)])),)),
  (UPat(Ops.STRIDE, name="ret"), lambda ctx, ret: (ctx.stride(ret.arg) if all(x in {-1,1} for x in ret.arg) else None,)),
  # TODO: this cast can be removed by putting the casts around the EXPAND
  (UPat(Ops.EXPAND, name="ret"), lambda ctx, ret:
    (ctx.cast(sum_acc_dtype(ctx.dtype)).r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.arg)) if si!=so)).cast(ctx.dtype),)),

  # there's no gradient for...is this ASSIGN?
  (UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat(Ops.BUFFER_VIEW))), lambda: (None, None)),
])
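
# Each rule above maps the incoming gradient `ctx` to a tuple with one entry per
# ret.src: e.g. the MUL rule returns (d/d(src0), d/d(src1)) = (src1*ctx, src0*ctx).
# A None entry (the CMPLT/CMPNE operands, WHERE's condition) means no gradient
# flows to that source.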

# copied from tensor.py, get relevant toposort of gradients
def _deepwalk(root:UOp, targets:list[UOp]):
  @functools.lru_cache(None)
  def is_in_target_path(x:UOp) -> bool: return any(u in targets or is_in_target_path(u) for u in x.src)
  def _walk(node:UOp, visited:set[UOp]):
    visited.add(node)
    if node.op is Ops.DETACH: return  # gradients don't flow through DETACH
    # only yield nodes that lie on a path from root to one of the targets
    if is_in_target_path(node):
      for i in node.src:
        if i not in visited: yield from _walk(i, visited)
      yield node
  return list(_walk(root, set()))

def compute_gradient(root:UOp, root_grad:UOp, targets:list[UOp]) -> dict[UOp, UOp]:
  grads = {root: root_grad}
  for t0 in reversed(_deepwalk(root, targets)):
    if t0 not in grads: continue
    lgrads: tuple[UOp|None, ...]|None = cast(tuple[UOp, ...]|None, pm_gradient.rewrite(t0, ctx=grads[t0]))
    if lgrads is None: raise RuntimeError(f"failed to compute gradient for {t0.op}\n\nin {str(t0)[0:1000]}...")
    assert len(lgrads) == len(t0.src), f"got {len(lgrads)} gradients, expected {len(t0.src)}"
    for k,v in zip(t0.src, lgrads):
      if v is None: continue
      # accumulate: a UOp consumed in several places sums the gradients it receives
      if k in grads: grads[k] = grads[k] + v
      else: grads[k] = v
  return grads
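
# Usage sketch (illustrative; assumes the Tensor frontend wires backward() through
# compute_gradient, as the "copied from tensor.py" note above suggests):
#   from tinygrad import Tensor
#   x = Tensor([1.0, 2.0, 3.0], requires_grad=True)
#   (x * x).sum().backward()
#   print(x.grad.numpy())   # -> [2. 4. 6.], i.e. d(sum(x*x))/dx = 2x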