# sorted in order of increasing complexity
from typing import List
from tinygrad.helpers import dedup
from tinygrad.tensor import Tensor

class Optimizer:
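  # Base class: splits trainable params from buffers (realized but never updated) and keeps lr as a Tensor.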
  def __init__(self, params: List[Tensor], lr: float):
    # if it's None, but being put into an optimizer, set it to True
    for x in params:
      if x.requires_grad is None: x.requires_grad = True

    self.params: List[Tensor] = dedup([x for x in params if x.requires_grad])
    self.buffers: List[Tensor] = dedup([x for x in params if not x.requires_grad])  # buffers are still realized
    self.lr = Tensor([lr], requires_grad=False).contiguous()

  def zero_grad(self):
    for param in self.params: param.grad = None

  def realize(self, extra=None):
    # NOTE: in extra is too late for most of the params due to issues with assign
    Tensor.corealize(extra + self.params + self.buffers if extra is not None else self.params + self.buffers)

class SGD(Optimizer):
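  # Classic SGD with optional (Nesterov) momentum; weight_decay is added straight to the gradient (L2 style).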
  def __init__(self, params: List[Tensor], lr=0.001, momentum=0, weight_decay=0.0, nesterov=False):
    super().__init__(params, lr)
    self.momentum, self.wd, self.nesterov = momentum, weight_decay, nesterov
    self.b = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params] if self.momentum else []

  # https://pytorch.org/docs/stable/generated/torch.optim.SGD.html
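  # Update rule implemented below (same formulation as the PyTorch docs above):
  #   g = grad + wd * t                      (weight decay folded into the gradient)
  #   b = momentum * b + g                   (momentum buffer)
  #   g = g + momentum * b  if nesterov else b
  #   t = t - lr * g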
  def step(self) -> None:
    for i, t in enumerate(self.params):
      assert t.grad is not None
      g = t.grad.realize() + self.wd * t.detach()
      if self.momentum:
        self.b[i].assign(self.momentum * self.b[i] + g).realize()  # NOTE: self.b[i] is zero on the first run, no if required
        g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
      t.assign(t.detach() - g * self.lr)
    self.realize(self.b)

# LAMB is essentially just the trust ratio part of LARS applied to Adam/W, so if we just set the trust ratio to 1.0 it's just Adam/W.
def AdamW(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, wd=0.01): return LAMB(params, lr, b1, b2, eps, wd, adam=True)
def Adam(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8): return LAMB(params, lr, b1, b2, eps, 0.0, adam=True)
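# NOTE: Adam is just AdamW with the weight decay fixed at 0.0; both are LAMB with the trust ratio forced to 1.0 (adam=True).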
class LAMB(Optimizer):
  def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, wd=0.0, adam=False):
    super().__init__(params, lr)
    self.b1, self.b2, self.eps, self.wd, self.adam, self.t = b1, b2, eps, wd, adam, Tensor([0], requires_grad=False).realize()
    self.m = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]
    self.v = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]

  def step(self) -> None:
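    # Adam moments with bias correction, then the LAMB/LARS layer-wise trust ratio:
    #   m = b1*m + (1-b1)*g              v = b2*v + (1-b2)*g^2
    #   m_hat = m / (1 - b1^t)           v_hat = v / (1 - b2^t)
    #   up = m_hat / (sqrt(v_hat) + eps) + wd * t
    #   r  = ||t|| / ||up||   (1.0 if either norm is zero, or always 1.0 in adam mode)
    #   t  = t - lr * r * up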
    self.t.assign(self.t + 1).realize()
    for i, t in enumerate(self.params):
      assert t.grad is not None
      g = t.grad.realize()
      self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * g).realize()
      self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (g * g)).realize()
      m_hat = self.m[i] / (1.0 - self.b1**self.t)
      v_hat = self.v[i] / (1.0 - self.b2**self.t)
      up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * t.detach()
      if not self.adam:
        r1 = t.detach().square().sum().sqrt()
        r2 = up.square().sum().sqrt()
        r = Tensor.where(r1 > 0, Tensor.where(r2 > 0, r1 / r2, 1.0), 1.0)
      else:
        r = 1.0
      t.assign(t.detach() - self.lr * r * up)
    self.realize([self.t] + self.m + self.v)
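
if __name__ == "__main__":
  # Minimal, illustrative training-loop sketch, not part of the library. It assumes the usual
  # tinygrad Tensor API (randn, uniform, matmul, square, mean, backward, numpy) is available.
  x, y = Tensor.randn(16, 3), Tensor.randn(16, 1)
  w = Tensor.uniform(3, 1)                # requires_grad is None; the Optimizer flips it to True
  opt = SGD([w], lr=0.01, momentum=0.9)   # must be constructed before the forward pass
  for _ in range(10):
    opt.zero_grad()
    loss = (x.matmul(w) - y).square().mean()
    loss.backward()
    opt.step()
    print(loss.numpy())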