# sorted in order of increasing complexity
from typing import List
from tinygrad.helpers import dedup
from tinygrad.tensor import Tensor


class Optimizer:
    """Common state for all optimizers: trainable params, non-trainable buffers and the learning rate.

    Args:
        params: tensors handed to the optimizer. Any tensor whose ``requires_grad``
            is still ``None`` is switched to ``True``. Tensors that require grad become
            ``self.params``; the rest become ``self.buffers``. Both lists are deduplicated.
        lr: learning rate, stored as a one-element contiguous Tensor.
    """

    def __init__(self, params: List[Tensor], lr: float):
        # an undecided tensor (requires_grad is None) that is given to an optimizer is treated as trainable
        for tensor in params:
            if tensor.requires_grad is None:
                tensor.requires_grad = True

        self.params: List[Tensor] = dedup([tensor for tensor in params if tensor.requires_grad])
        # buffers are not updated by step(), but realize() still realizes them
        self.buffers: List[Tensor] = dedup([tensor for tensor in params if not tensor.requires_grad])
        self.lr = Tensor([lr], requires_grad=False).contiguous()

    def zero_grad(self):
        """Clear the gradient of every trainable parameter by setting it to None."""
        for param in self.params:
            param.grad = None

    def realize(self, extra=None):
        """Realize all params and buffers in one ``Tensor.corealize`` call.

        Args:
            extra: optional list of additional tensors (e.g. optimizer state) that are
                placed in front of the params and buffers.
        """
        # NOTE: in extra is too late for most of the params due to issues with assign
        pending = self.params + self.buffers
        if extra is not None:
            pending = extra + pending
        Tensor.corealize(pending)


class SGD(Optimizer):
    """Stochastic gradient descent with optional momentum, weight decay and Nesterov momentum.

    Follows https://pytorch.org/docs/stable/generated/torch.optim.SGD.html (without dampening):
        g = grad + weight_decay * w
        b = momentum * b + g                    (only when momentum is non-zero)
        g = g + momentum * b  if nesterov else b
        w = w - lr * g
    """

    def __init__(self, params: List[Tensor], lr=0.001, momentum=0, weight_decay=0.0, nesterov=False):
        super().__init__(params, lr)
        self.momentum, self.wd, self.nesterov = momentum, weight_decay, nesterov
        # one zero-initialised momentum buffer per trainable param, only when momentum is enabled
        if self.momentum:
            self.b = [Tensor.zeros(*param.shape, device=param.device, requires_grad=False) for param in self.params]
        else:
            self.b = []

    def step(self) -> None:
        """Apply one SGD update to every trainable parameter in place."""
        for idx, param in enumerate(self.params):
            assert param.grad is not None
            direction = param.grad.realize() + self.wd * param.detach()
            if self.momentum:
                buf = self.b[idx]
                buf.assign(self.momentum * buf + direction).realize()
                # the buffer starts at zero, so the first step needs no special case
                direction = (direction + self.momentum * buf) if self.nesterov else buf
            param.assign(param.detach() - direction * self.lr)
        self.realize(self.b)


# LAMB is essentially just the trust ratio part of LARS applied to Adam/W, so if we just set the trust ratio to 1.0 it's just Adam/W.
def AdamW(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, wd=0.01):
    """Build an AdamW optimizer: LAMB with the trust ratio fixed at 1.0 and weight decay ``wd``."""
    return LAMB(params, lr, b1, b2, eps, wd, adam=True)


def Adam(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
    """Build an Adam optimizer: LAMB with the trust ratio fixed at 1.0 and no weight decay."""
    return LAMB(params, lr, b1, b2, eps, 0.0, adam=True)


class LAMB(Optimizer):
    """LAMB optimizer; with ``adam=True`` the trust ratio is 1.0 and it reduces to Adam/AdamW.

    Per step, with t the step count after incrementing:
        m = b1 * m + (1 - b1) * g
        v = b2 * v + (1 - b2) * g^2
        update = (m / (1 - b1^t)) / (sqrt(v / (1 - b2^t)) + eps) + wd * w
        w = w - lr * r * update
    where r = ||w|| / ||update|| (falling back to 1.0 if either norm is zero), or r = 1.0 when ``adam``.
    Weight decay is added to the update rather than to the gradient.
    """

    def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, wd=0.0, adam=False):
        super().__init__(params, lr)
        self.b1, self.b2, self.eps, self.wd, self.adam = b1, b2, eps, wd, adam
        # step counter, kept as a realized Tensor so it can be assigned in place
        self.t = Tensor([0], requires_grad=False).realize()
        # first and second moment estimates, one per trainable param, zero-initialised
        self.m = [Tensor.zeros(*param.shape, device=param.device, requires_grad=False) for param in self.params]
        self.v = [Tensor.zeros(*param.shape, device=param.device, requires_grad=False) for param in self.params]

    @staticmethod
    def _trust_ratio(param: Tensor, update: Tensor) -> Tensor:
        """Return ||param|| / ||update||, or 1.0 when either L2 norm is zero."""
        weight_norm = param.detach().square().sum().sqrt()
        update_norm = update.square().sum().sqrt()
        return Tensor.where(weight_norm > 0, Tensor.where(update_norm > 0, weight_norm / update_norm, 1.0), 1.0)

    def step(self) -> None:
        """Increment the step counter, update both moment estimates and apply one update to every param."""
        self.t.assign(self.t + 1).realize()
        for param, m, v in zip(self.params, self.m, self.v):
            assert param.grad is not None
            grad = param.grad.realize()
            m.assign(self.b1 * m + (1.0 - self.b1) * grad).realize()
            v.assign(self.b2 * v + (1.0 - self.b2) * (grad * grad)).realize()
            # bias-corrected moment estimates
            m_corrected = m / (1.0 - self.b1**self.t)
            v_corrected = v / (1.0 - self.b2**self.t)
            update = m_corrected / (v_corrected.sqrt() + self.eps) + self.wd * param.detach()
            ratio = 1.0 if self.adam else self._trust_ratio(param, update)
            param.assign(param.detach() - self.lr * ratio * update)
        self.realize([self.t] + self.m + self.v)