openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

200 lines
6.9 KiB

"""This is where the forwards and backwards passes live."""
import math
from tinygrad.helpers import argsort
from tinygrad.dtype import dtypes, DType, sum_acc_dtype
from tinygrad.ops import Ops, resolve, sint, UOp
from tinygrad.tensor import Function
class Contiguous(Function):
  """Force the tensor to be realized contiguously; the gradient flows through unchanged."""
  def forward(self, x:UOp) -> UOp:
    return x.contiguous()
  def backward(self, grad_output:UOp) -> UOp:
    return grad_output
class ContiguousBackward(Function):
  """Identity in the forward pass; forces the *gradient* to be contiguous on the way back."""
  def forward(self, x:UOp) -> UOp:
    return x
  def backward(self, grad_output:UOp) -> UOp:
    return grad_output.contiguous()
class Cast(Function):
  """Cast (or bitcast) x to a new dtype; backward casts the gradient back to the source dtype."""
  def forward(self, x:UOp, dtype:DType, bitcast:bool=False) -> UOp:
    # remember the source dtype so backward can undo the cast
    self.input_dtype = x.dtype
    self.bitcast = bitcast
    if self.bitcast:
      return x.bitcast(dtype)
    return x.cast(dtype)
  def backward(self, grad_output:UOp) -> UOp:
    # a bitcast reinterprets raw bits and has no meaningful derivative
    if self.bitcast: raise RuntimeError("bitcast cannot backward")
    return grad_output.cast(self.input_dtype)
# ************* unary ops *************
class Reciprocal(Function):
  """1/x. d(1/x)/dx = -1/x^2, computed by reusing the saved forward result."""
  def forward(self, x:UOp) -> UOp:
    self.recip = x.reciprocal()
    return self.recip
  def backward(self, grad_output:UOp) -> UOp:
    return -grad_output * self.recip * self.recip
class Sin(Function):
  """Elementwise sine. d(sin x)/dx = cos x, expressed here as sin(pi/2 - x)."""
  def forward(self, x:UOp) -> UOp:
    self.inp = x
    return x.sin()
  def backward(self, grad_output:UOp) -> UOp:
    return (math.pi/2 - self.inp).sin() * grad_output
class Relu(Function):
  """max(x, 0). The gradient passes through only where the output is positive."""
  def forward(self, x:UOp) -> UOp:
    self.out = (x>0).where(x, 0)
    return self.out
  def backward(self, grad_output:UOp) -> UOp:
    return (self.out>0).cast(grad_output.dtype) * grad_output
class Log(Function):
  """Natural log via log2: ln(x) = log2(x) * ln(2). d(ln x)/dx = 1/x."""
  def forward(self, x:UOp) -> UOp:
    self.inp = x
    return x.log2() * math.log(2)
  def backward(self, grad_output:UOp) -> UOp:
    return grad_output / self.inp
class Exp(Function):
  """e^x via exp2: e^x = 2^(x/ln 2). d(e^x)/dx = e^x, so backward reuses the output."""
  def forward(self, x:UOp) -> UOp:
    self.out = (x * (1/math.log(2))).exp2()
    return self.out
  def backward(self, grad_output:UOp) -> UOp:
    return self.out * grad_output
class Sqrt(Function):
  """Elementwise square root. d(sqrt x)/dx = 1/(2*sqrt(x)), reusing the saved output."""
  def forward(self, x:UOp) -> UOp:
    self.root = x.sqrt()
    return self.root
  def backward(self, grad_output:UOp) -> UOp:
    # NOTE(review): divides by zero when x == 0 — matches the usual framework behavior
    return grad_output / (self.root*2)
class Sign(Function):
  """Elementwise sign: -1, 0, or 1. The gradient is identically zero."""
  def forward(self, x:UOp) -> UOp:
    minus_one, one, zero = x.const_like(-1), x.const_like(1), x.const_like(0)
    # NOTE: the x*0 is to match torch behavior without function.py
    return x.ne(0).where((x<0).where(minus_one, one), zero) + x*0
  def backward(self, grad_output:UOp) -> UOp:
    # backward always returns 0 to match torch
    return grad_output.const_like(0)
# ************* binary ops *************
class Less(Function):
  """x < y. Comparison is not differentiable: both input gradients are None."""
  def forward(self, x:UOp, y:UOp) -> UOp:
    return x<y
  def backward(self, grad_output:UOp) -> tuple[UOp|None, UOp|None]:
    return None, None
class Neq(Function):
  """x != y. Comparison is not differentiable: both input gradients are None."""
  def forward(self, x:UOp, y:UOp) -> UOp:
    return x.ne(y)
  def backward(self, grad_output:UOp) -> tuple[UOp|None, UOp|None]:
    return None, None
class Xor(Function):
  """Bitwise xor; integer/bool op with no backward defined."""
  def forward(self, x:UOp, y:UOp) -> UOp:
    return x^y
class BitwiseAnd(Function):
  """Bitwise and; integer/bool op with no backward defined."""
  def forward(self, x:UOp, y:UOp) -> UOp:
    return x&y
class BitwiseOr(Function):
  """Bitwise or; integer/bool op with no backward defined."""
  def forward(self, x:UOp, y:UOp) -> UOp:
    return x|y
class Threefry(Function):
  """Threefry counter-based RNG step (used for random number generation); no backward."""
  def forward(self, x:UOp, seed:UOp) -> UOp:
    return x.threefry(seed)
class Add(Function):
  """Elementwise add. The gradient of a sum is the incoming gradient for each input."""
  def forward(self, x:UOp, y:UOp) -> UOp:
    return x+y
  def backward(self, grad_output:UOp) -> tuple[UOp|None, UOp|None]:
    grad_x = grad_output if self.needs_input_grad[0] else None
    grad_y = grad_output if self.needs_input_grad[1] else None
    return grad_x, grad_y
class Mul(Function):
  """Elementwise multiply. d(x*y)/dx = y and d(x*y)/dy = x, so each saved input scales the gradient of the other."""
  def forward(self, x:UOp, y:UOp) -> UOp:
    self.a, self.b = x, y
    return x * y
  def backward(self, grad_output:UOp) -> tuple[UOp|None, UOp|None]:
    grad_a = (self.b * grad_output) if self.needs_input_grad[0] else None
    grad_b = (self.a * grad_output) if self.needs_input_grad[1] else None
    return grad_a, grad_b
class IDiv(Function):
  """Integer (floor) division; not differentiable, so no backward is defined."""
  def forward(self, x:UOp, y:UOp) -> UOp:
    return x // y
# ************* ternary ops *************
class Where(Function):
  """cond ? y : z. The condition gets no gradient; y and z each receive the gradient masked to where they were selected."""
  def forward(self, x:UOp, y:UOp, z:UOp) -> UOp:
    self.cond = x
    return self.cond.where(y, z)
  def backward(self, grad_output:UOp) -> tuple[None, UOp|None, UOp|None]:
    zero = grad_output.const_like(0)
    grad_y = self.cond.where(grad_output, zero) if self.needs_input_grad[1] else None
    grad_z = self.cond.where(zero, grad_output) if self.needs_input_grad[2] else None
    return None, grad_y, grad_z
# ************* reduce ops *************
class Sum(Function):
  """Sum-reduce over `axis`; backward broadcasts the gradient back to the input shape."""
  def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp:
    self.in_shape = x.shape
    return x.r(Ops.ADD, axis)
  def backward(self, grad_output:UOp) -> UOp:
    return grad_output.expand(self.in_shape)
class Prod(Function):
  """Product-reduce over `axis`. d(prod)/dx_i = prod/x_i, computed via the saved output."""
  def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp:
    self.inp = x
    self.out = x.r(Ops.MUL, axis)
    return self.out
  def backward(self, grad_output:UOp) -> UOp:
    # NOTE(review): dividing by the input means zero elements yield inf/nan gradients
    return (grad_output * self.out).expand(self.inp.shape) / self.inp
class Max(Function):
  """Max-reduce over `axis`; backward splits the gradient evenly among tied maxima."""
  def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp:
    # save the input, the reduced result, and the axes for the backward pass
    self.x, self.ret, self.axis = x, x.r(Ops.MAX, axis), axis
    return self.ret
  def backward(self, grad_output:UOp) -> UOp:
    # 1s in locations where the max was chosen (can be two locations)
    # (double-ne is an equality mask: x == broadcast(max), cast to the gradient dtype)
    max_is_1s = self.x.ne(self.ret.expand(self.x.shape)).ne(self.x.const_like(1).cast(dtypes.bool)).cast(grad_output.dtype)
    # count ties along the reduced axes so the gradient is shared equally among them
    div = max_is_1s.r(Ops.ADD, self.axis).expand(self.x.shape)
    return (max_is_1s/div) * grad_output.expand(self.x.shape)
# ************* movement ops *************
# NOTE: this is sum in reverse
class Expand(Function):
  """Broadcast x to `shape`; backward sums the gradient over the broadcast axes (sum in reverse)."""
  def forward(self, x:UOp, shape:tuple[int, ...]) -> UOp:
    # record which axes actually grew, so backward knows where to reduce
    grown = []
    for i, (si, so) in enumerate(zip(x.shape, shape)):
      if resolve(si != so): grown.append(i)
    self.expanded_axis = tuple(grown)
    return x.expand(shape)
  def backward(self, grad_output:UOp) -> UOp:
    # accumulate in a wider dtype to avoid precision loss, then cast back
    return grad_output.cast(sum_acc_dtype(grad_output.dtype)).r(Ops.ADD, self.expanded_axis).cast(grad_output.dtype)
class Reshape(Function):
  """Reshape to `shape`; backward reshapes the gradient back to the original shape."""
  def forward(self, x:UOp, shape:tuple[int, ...]) -> UOp:
    self.in_shape = x.shape
    return x.reshape(shape)
  def backward(self, grad_output:UOp) -> UOp:
    return grad_output.reshape(self.in_shape)
class Permute(Function):
  """Permute axes by `order`; backward applies the inverse permutation (argsort of order)."""
  def forward(self, x:UOp, order:tuple[int, ...]) -> UOp:
    self.order = order
    return x.permute(order)
  def backward(self, grad_output:UOp) -> UOp:
    return grad_output.permute(argsort(self.order))
class Pad(Function):
  """Pad with (before, after) per axis; backward shrinks the gradient back to the original region."""
  def forward(self, x:UOp, arg:tuple[tuple[int, int], ...]) -> UOp:
    # precompute the shrink window that undoes this pad: (before, before + original_size)
    self.shrink_arg = tuple((before, size+before) for size, (before, _) in zip(x.shape, arg))
    return x.pad(arg)
  def backward(self, grad_output:UOp) -> UOp:
    return grad_output.shrink(self.shrink_arg)
class Shrink(Function):
  """Slice to the (start, end) window per axis; backward zero-pads the gradient back out."""
  def forward(self, x:UOp, arg:tuple[tuple[sint, sint], ...]) -> UOp:
    # precompute the pad that undoes this shrink: (start, original_size - end)
    self.pad_arg = tuple((start, size-end) for size, (start, end) in zip(x.shape, arg))
    return x.shrink(arg)
  def backward(self, grad_output:UOp) -> UOp:
    return grad_output.pad(self.pad_arg)
class Flip(Function):
  """Reverse the listed axes via a -1 stride; flipping is its own inverse, so backward reuses the same strides."""
  def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp:
    self.strides = tuple(-1 if i in axis else 1 for i in range(len(x.shape)))
    return x.stride(self.strides)
  def backward(self, grad_output:UOp) -> UOp:
    return grad_output.stride(self.strides)