# much taken from https://github.com/cloneofsimo/minRF
import math
from tinygrad import Tensor, nn, GlobalCounters, TinyJit
from tinygrad.helpers import getenv, trange
from extra.models.llama import Attention, FeedForward, precompute_freqs_cis

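# adaLN helper: applies a per-sample shift and scale to every token in the sequence
# (x is (B, T, D); shift/scale are (B, D) and are broadcast over the token dim)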
def modulate(x:Tensor, shift:Tensor, scale:Tensor) -> Tensor: return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

# TODO: why doesn't the TimestepEmbedder from minRF work?
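# embeds the scalar diffusion timestep t into a hidden_size-dim vector with a small MLP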
class TimestepEmbedder:
  def __init__(self, hidden_size): self.mlp = [nn.Linear(1, hidden_size), Tensor.silu, nn.Linear(hidden_size, hidden_size)]
  def __call__(self, t:Tensor): return t.reshape(-1, 1).sequential(self.mlp)

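# one DiT block: LLaMA-style attention + feed-forward, conditioned on the timestep/class
# embedding through adaLN (per-sample shift, scale, and gate for each sub-layer)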
class TransformerBlock:
  def __init__(self, dim, n_heads, norm_eps=1e-5):
    self.attention = Attention(dim, n_heads)
    self.feed_forward = FeedForward(dim, 4*dim)
    self.attention_norm = nn.LayerNorm(dim, eps=norm_eps)
    self.ffn_norm = nn.LayerNorm(dim, eps=norm_eps)
    self.adaLN_modulation = nn.Linear(dim, 6 * dim, bias=True)

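  # gated residual branches: the attention and FFN outputs are modulated by the conditioning vector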
  def __call__(self, x:Tensor, freqs_cis:Tensor, adaln_input:Tensor):
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input.silu()).chunk(6, dim=1)
    x = x + gate_msa.unsqueeze(1) * self.attention(modulate(self.attention_norm(x), shift_msa, scale_msa), 0, freqs_cis)
    x = x + gate_mlp.unsqueeze(1) * self.feed_forward(modulate(self.ffn_norm(x), shift_mlp, scale_mlp))
    return x.contiguous().contiguous_backward()

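# projects each token back to patch_size*patch_size*out_channels pixels; the projection is
# zero-initialized so the model starts out predicting zeros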
class FinalLayer:
  def __init__(self, dim, patch_size, out_channels):
    self.norm_final = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
    self.linear = nn.Linear(dim, patch_size*patch_size*out_channels, bias=True)
    self.adaLN_modulation = nn.Linear(dim, 2 * dim, bias=True)

    # init weights/bias to 0
    self.linear.weight.replace(self.linear.weight.zeros_like().contiguous())
    self.linear.bias.replace(self.linear.bias.zeros_like().contiguous())

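  # final adaLN shift/scale, then linear projection to pixel space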
  def __call__(self, x:Tensor, c:Tensor):
    shift, scale = self.adaLN_modulation(c.silu()).chunk(2, dim=1)
    x = modulate(self.norm_final(x), shift, scale)
    return self.linear(x)

# channels=1, input_size=32, dim=64, n_layers=6, n_heads=4, num_classes=10
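# conv stem -> patchify -> linear patch embedding -> stack of TransformerBlocks (with rotary
# position embeddings) -> FinalLayer -> unpatchify back to an image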
class DiT_Llama:
  def __init__(self, in_channels=1, dim=64, n_layers=6, n_heads=4, num_classes=10, patch_size=2):
    self.patch_size = patch_size
    self.out_channels = in_channels
    self.num_classes = num_classes

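    # small conv stem applied to the raw image before patchifying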
    self.init_conv_seq = [
      nn.Conv2d(in_channels, dim // 2, kernel_size=5, padding=2, stride=1), Tensor.silu, nn.GroupNorm(32, dim // 2),
      nn.Conv2d(dim // 2, dim // 2, kernel_size=5, padding=2, stride=1), Tensor.silu, nn.GroupNorm(32, dim // 2),
    ]

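    # patch tokens, timestep, and class label all embed to dim; the class embedding table has one
    # extra slot (index num_classes) used as the null label for classifier-free guidance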
    self.x_embedder = nn.Linear(self.patch_size * self.patch_size * dim // 2, dim, bias=True)
    self.t_embedder = TimestepEmbedder(dim)
    self.y_embedder = nn.Embedding(num_classes+1, dim)
    self.final_layer = FinalLayer(dim, self.patch_size, self.out_channels)

    self.freqs_cis = precompute_freqs_cis(dim // n_heads, 4096)
    self.layers = [TransformerBlock(dim, n_heads) for _ in range(n_layers)]

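  # inverse of patchify: fold the sequence of patch tokens back into a (B, C, H, W) image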
  def unpatchify(self, x:Tensor):
    c, p = self.out_channels, self.patch_size
    h = w = int(x.shape[1] ** 0.5)
    x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
    x = x.rearrange("n h w p q c -> n c h p w q")
    return x.reshape(shape=(x.shape[0], c, h * p, h * p))

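  # cut the (B, C, H, W) feature map into non-overlapping patch_size x patch_size patches,
  # giving one token of length C*patch_size*patch_size per patch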
  def patchify(self, x:Tensor):
    B, C, H, W = x.shape
    x = x.reshape(B, C, H // self.patch_size, self.patch_size, W // self.patch_size, self.patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
    return x  # B <H*W ish> <C*patch_size*patch_size>

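  # predicts the velocity field for input x at timestep t, conditioned on class label y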
  def __call__(self, x:Tensor, t:Tensor, y:Tensor) -> Tensor:
    x = x.sequential(self.init_conv_seq)
    x = self.patchify(x)
    x = self.x_embedder(x)
    adaln_input = self.t_embedder(t) + self.y_embedder(y)
    adaln_input = adaln_input.contiguous()
    for layer in self.layers:
      x = layer(x, self.freqs_cis[:, :x.size(1)], adaln_input=adaln_input)
    x = self.final_layer(x, adaln_input)
    return self.unpatchify(x)

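  # rectified flow training loss: interpolate between the image x and noise z1 at a random
  # timestep, and regress the model output onto the straight-line velocity (z1 - x)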
  def rf(self, x:Tensor, cond:Tensor):
    b = x.shape[0]
    # self.ln is True
    t = Tensor.randn((b,)).sigmoid()
    texp = t.view([b, *([1] * len(x.shape[1:]))])

    # conditional dropout
    dropout_prob = 0.1
    cond = (Tensor.rand(cond.shape[0]) < dropout_prob).where(cond.full_like(self.num_classes), cond)

    # this is rectified flow
    z1 = Tensor.randn(x.shape)  # TODO: add Tensor.randn_like (and friends) to tinygrad
    zt = (1 - texp) * x + texp * z1
    vtheta = self(zt, t, cond)

    # MSE loss
    return ((z1 - x) - vtheta).square().mean()

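  # Euler integration from noise back to an image, with classifier-free guidance:
  # vc = vu + cfg * (vc - vu) mixes the conditional and unconditional predictions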
  def sample(self, z, cond, null_cond, sample_steps=50, cfg=2.0):
    b = z.size(0)
    dt = Tensor.full((b,)+(1,)*len(z.shape[1:]), fill_value=1.0/sample_steps).contiguous()
    images = [z]
    for i in range(sample_steps, 0, -1):
      t = Tensor.full((b,), fill_value=i/sample_steps).contiguous()
      vc = self(z, t, cond)
      vu = self(z, t, null_cond)
      vc = vu + cfg * (vc - vu)
      z = z - dt * vc
      z = z.contiguous()
      images.append(z)
    return images

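# render a batch of single-channel 32x32 images in the terminal using ANSI 256-color blocks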
def mviz(t:Tensor):
  assert len(t.shape) == 4 and t.shape[1] == 1
  ft = t.permute(1,2,0,3).reshape(32, -1)
  assert ft.shape[-1]%32 == 0
  print("")
  for y in ((ft+1)/2).clamp(0,1).tolist():
    ln = [f"\033[38;5;{232+int(x*23)}m██" for x in y]
    print(''.join(ln) + "\033[0m")

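# train a small class-conditional DiT on MNIST, padded to 32x32 and normalized to [-1, 1]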
if __name__ == "__main__":
  X_train, Y_train, X_test, Y_test = nn.datasets.mnist()
  X_train = X_train.pad((2,2,2,2))
  X_train = ((X_train.float()/255)-0.5)/0.5
  Y_train = Y_train.int()

  model = DiT_Llama(patch_size=getenv("PATCH_SIZE", 2))
  for r in nn.state.get_parameters(model): r.realize()
  optimizer = nn.optim.Adam(nn.state.get_parameters(model), lr=5e-4)

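  # one jitted optimization step: sample a random batch and minimize the rectified flow loss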
  @TinyJit
  @Tensor.train()
  def train_step():
    if getenv("OVERFIT"): samples = Tensor.zeros(getenv("BS", 256), dtype='int')
    else: samples = Tensor.randint(getenv("BS", 256), high=X_train.shape[0])
    optimizer.zero_grad()
    loss = model.rf(X_train[samples], Y_train[samples])
    loss.backward()
    optimizer.step()
    return loss

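  # jitted sampler; label 10 (one past the last MNIST class) is the null label for guidance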
  @TinyJit
  @Tensor.test()
  def sample(z:Tensor, cond:Tensor) -> Tensor:
    return model.sample(z, cond, Tensor.full_like(cond, 10), sample_steps=getenv("SAMPLE_STEPS", 20))[-1]

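  # every 10 steps, sample three digits (labels 5, 0, 4) and print them to the terminal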
  for steps in (t:=trange(getenv("STEPS", 5000))):
    if steps%10 == 0: mviz(sample(Tensor.randn(3, 1, 32, 32), Tensor([5,0,4], dtype='int')))
    GlobalCounters.reset()
    loss = train_step()
    t.set_description(f"loss: {loss.item():9.2f}")