# https://arxiv.org/pdf/2112.10752.pdf
# https://github.com/ekagra-ranjan/huggingface-blog/blob/main/stable_diffusion.md
import tempfile
from pathlib import Path
import argparse, time
from collections import namedtuple
from typing import Dict, Any
from PIL import Image
import numpy as np
from tinygrad import Device, GlobalCounters, dtypes, Tensor, TinyJit
from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm, flatten
from tinygrad.nn import Conv2d, GroupNorm
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
from extra.models.clip import Closed, Tokenizer, FrozenOpenClipEmbedder
from extra.models import unet, clip
from extra.models.unet import UNetModel
from examples.mlperf.initializers import AutocastLinear, AutocastConv2d, AutocastGroupNorm, AutocastLayerNorm, zero_module, attn_f32_softmax, gelu_erf
from extra.bench_log import BenchEvent, WallTimeEvent

class AttnBlock:
  def __init__(self, in_channels):
    self.norm = GroupNorm(32, in_channels)
    self.q = Conv2d(in_channels, in_channels, 1)
    self.k = Conv2d(in_channels, in_channels, 1)
    self.v = Conv2d(in_channels, in_channels, 1)
    self.proj_out = Conv2d(in_channels, in_channels, 1)

  # copied from AttnBlock in ldm repo
  def __call__(self, x):
    h_ = self.norm(x)
    q,k,v = self.q(h_), self.k(h_), self.v(h_)

    # compute attention
    b,c,h,w = q.shape
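    # flatten the spatial grid so each of the h*w pixels becomes a token of width c:
    # (b, c, h, w) -> (b, h*w, c), run single-head scaled dot-product attention, then reshape back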
    q,k,v = [x.reshape(b,c,h*w).transpose(1,2) for x in (q,k,v)]
    h_ = Tensor.scaled_dot_product_attention(q,k,v).transpose(1,2).reshape(b,c,h,w)
    return x + self.proj_out(h_)

class ResnetBlock:
  def __init__(self, in_channels, out_channels=None):
    self.norm1 = GroupNorm(32, in_channels)
    self.conv1 = Conv2d(in_channels, out_channels, 3, padding=1)
    self.norm2 = GroupNorm(32, out_channels)
    self.conv2 = Conv2d(out_channels, out_channels, 3, padding=1)
    self.nin_shortcut = Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else lambda x: x

  def __call__(self, x):
    h = self.conv1(self.norm1(x).swish())
    h = self.conv2(self.norm2(h).swish())
    return self.nin_shortcut(x) + h

class Mid:
  def __init__(self, block_in):
    self.block_1 = ResnetBlock(block_in, block_in)
    self.attn_1 = AttnBlock(block_in)
    self.block_2 = ResnetBlock(block_in, block_in)

  def __call__(self, x):
    return x.sequential([self.block_1, self.attn_1, self.block_2])

class Decoder:
  def __init__(self):
    sz = [(128, 256), (256, 512), (512, 512), (512, 512)]
    self.conv_in = Conv2d(4,512,3, padding=1)
    self.mid = Mid(512)

    arr = []
    for i,s in enumerate(sz):
      arr.append({"block":
        [ResnetBlock(s[1], s[0]),
         ResnetBlock(s[0], s[0]),
         ResnetBlock(s[0], s[0])]})
      if i != 0: arr[-1]['upsample'] = {"conv": Conv2d(s[0], s[0], 3, padding=1)}
    self.up = arr

    self.norm_out = GroupNorm(32, 128)
    self.conv_out = Conv2d(128, 3, 3, padding=1)

  def __call__(self, x):
    x = self.conv_in(x)
    x = self.mid(x)

    for l in self.up[::-1]:
      print("decode", x.shape)
      for b in l['block']: x = b(x)
      if 'upsample' in l:
        # https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html ?
        bs,c,py,px = x.shape
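        # nearest-neighbor 2x upsample done with reshape/expand: every pixel is repeated twice
        # along both spatial axes, e.g. a row [a, b] becomes [a, a, b, b]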
        x = x.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, 2, px, 2).reshape(bs, c, py*2, px*2)
        x = l['upsample']['conv'](x)
      x.realize()

    return self.conv_out(self.norm_out(x).swish())

class Encoder:
  def __init__(self):
    sz = [(128, 128), (128, 256), (256, 512), (512, 512)]
    self.conv_in = Conv2d(3,128,3, padding=1)

    arr = []
    for i,s in enumerate(sz):
      arr.append({"block":
        [ResnetBlock(s[0], s[1]),
         ResnetBlock(s[1], s[1])]})
      if i != 3: arr[-1]['downsample'] = {"conv": Conv2d(s[1], s[1], 3, stride=2, padding=(0,1,0,1))}
    self.down = arr

    self.mid = Mid(512)
    self.norm_out = GroupNorm(32, 512)
    self.conv_out = Conv2d(512, 8, 3, padding=1)

  def __call__(self, x):
    x = self.conv_in(x)

    for l in self.down:
      print("encode", x.shape)
      for b in l['block']: x = b(x)
      if 'downsample' in l: x = l['downsample']['conv'](x)

    x = self.mid(x)
    return self.conv_out(self.norm_out(x).swish())

class AutoencoderKL:
  def __init__(self):
    self.encoder = Encoder()
    self.decoder = Decoder()
    self.quant_conv = Conv2d(8, 8, 1)
    self.post_quant_conv = Conv2d(4, 4, 1)

  def __call__(self, x):
    latent = self.encoder(x)
    latent = self.quant_conv(latent)
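    # the encoder outputs 8 channels: 4 latent means and 4 log-variances of a diagonal Gaussian.
    # sampling is skipped here and the mean is used directly as the latent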
    latent = latent[:, 0:4] # only the means
    print("latent", latent.shape)
    latent = self.post_quant_conv(latent)
    return self.decoder(latent)

def get_alphas_cumprod(beta_start=0.00085, beta_end=0.0120, n_training_steps=1000):
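  # "scaled linear" noise schedule from the LDM training config: betas are linear in sqrt space.
  # alphas_cumprod[t] = prod_{s<=t} (1 - beta_s) is the fraction of signal kept at timestep t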
  betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, n_training_steps, dtype=np.float32) ** 2
  alphas = 1.0 - betas
  alphas_cumprod = np.cumprod(alphas, axis=0)
  return Tensor(alphas_cumprod)
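
# UNet hyperparameters. unet_params is the Stable Diffusion v1.x config (cross-attention context
# width 768, the CLIP ViT-L/14 text embedding size); mlperf_params targets the SD v2.0 MLPerf
# benchmark, which conditions on a 1024-wide OpenCLIP text embedding and uses linear projections
# in the spatial transformers.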

unet_params: Dict[str,Any] = {
  "adm_in_ch": None,
  "in_ch": 4,
  "out_ch": 4,
  "model_ch": 320,
  "attention_resolutions": [4, 2, 1],
  "num_res_blocks": 2,
  "channel_mult": [1, 2, 4, 4],
  "n_heads": 8,
  "transformer_depth": [1, 1, 1, 1],
  "ctx_dim": 768,
  "use_linear": False,
}

mlperf_params: Dict[str,Any] = {"adm_in_ch": None, "in_ch": 4, "out_ch": 4, "model_ch": 320, "attention_resolutions": [4, 2, 1], "num_res_blocks": 2,
                                "channel_mult": [1, 2, 4, 4], "d_head": 64, "transformer_depth": [1, 1, 1, 1], "ctx_dim": 1024, "use_linear": True,
                                "num_groups":16, "st_norm_eps":1e-6}

class StableDiffusion:
  def __init__(self, version:str|None=None, pretrained:str|None=None):
    self.alphas_cumprod = get_alphas_cumprod()

    if version != "v2-mlperf-train":
      self.first_stage_model = AutoencoderKL() # only needed for decoding generated latents to images; not needed in mlperf training from preprocessed moments

    if not version:
      self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = Closed.ClipTextTransformer()))
      unet_init_params = unet_params
    elif version in {"v2-mlperf-train", "v2-mlperf-eval"}:
      unet_init_params = mlperf_params
      clip.gelu = gelu_erf
      self.cond_stage_model = FrozenOpenClipEmbedder(**{"dims": 1024, "n_heads": 16, "layers": 24, "return_pooled": False, "ln_penultimate": True,
                                                        "clip_tokenizer_version": "sd_mlperf_v5_0"})
      unet.Linear, unet.Conv2d, unet.GroupNorm, unet.LayerNorm = AutocastLinear, AutocastConv2d, AutocastGroupNorm, AutocastLayerNorm
      unet.attention, unet.gelu, unet.mixed_precision_dtype = attn_f32_softmax, gelu_erf, dtypes.bfloat16

      if pretrained:
        print("loading text encoder")
        weights: dict[str,Tensor] = {k.replace("cond_stage_model.", "", 1):v for k,v in torch_load(pretrained)["state_dict"].items() if k.startswith("cond_stage_model.")}
        weights["model.attn_mask"] = Tensor.full((77, 77), fill_value=float("-inf")).triu(1)
        load_state_dict(self.cond_stage_model, weights)

        # only the eval model needs the decoder
        if version == "v2-mlperf-eval":
          print("loading image latent encoder")
          weights = {k.replace("first_stage_model.", "", 1):v for k,v in torch_load(pretrained)["state_dict"].items() if k.startswith("first_stage_model.")}
          load_state_dict(self.first_stage_model, weights)

    self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel(**unet_init_params))

    if version == "v2-mlperf-train":
      # the mlperf reference inits certain weights as zeroes
      for bb in flatten(self.model.diffusion_model.input_blocks) + self.model.diffusion_model.middle_block + flatten(self.model.diffusion_model.output_blocks):
        if isinstance(bb, unet.ResBlock):
          zero_module(bb.out_layers[3])
        elif isinstance(bb, unet.SpatialTransformer):
          zero_module(bb.proj_out)
      zero_module(self.model.diffusion_model.out[2])

  def get_x_prev_and_pred_x0(self, x, e_t, a_t, a_prev):
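    # one deterministic DDIM update (eta = 0, so sigma_t = 0):
    #   pred_x0 = (x_t - sqrt(1 - a_t) * e_t) / sqrt(a_t)
    #   x_prev  = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev - sigma_t**2) * e_t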
    temperature = 1
    sigma_t = 0
    sqrt_one_minus_at = (1-a_t).sqrt()
    #print(a_t, a_prev, sigma_t, sqrt_one_minus_at)

    pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()

    # direction pointing to x_t
    dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t

    x_prev = a_prev.sqrt() * pred_x0 + dir_xt
    return x_prev, pred_x0

  def get_model_output(self, unconditional_context, context, latent, timestep, unconditional_guidance_scale):
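    # classifier-free guidance: run the UNet on the unconditional (empty prompt) and conditional
    # contexts as a batch of 2, then push the noise prediction away from the unconditional one
    # by unconditional_guidance_scale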
    # put into diffuser
    latents = self.model.diffusion_model(latent.expand(2, *latent.shape[1:]), timestep, unconditional_context.cat(context, dim=0))
    unconditional_latent, latent = latents[0:1], latents[1:2]

    e_t = unconditional_latent + unconditional_guidance_scale * (latent - unconditional_latent)
    return e_t

  def decode(self, x):
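    # 0.18215 is the latent scale factor from the SD config: encoder latents are multiplied by it,
    # so divide it back out before running the VAE decoder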
    x = self.first_stage_model.post_quant_conv(1/0.18215 * x)
    x = self.first_stage_model.decoder(x)

    # make image correct size and scale
    x = (x + 1.0) / 2.0
    x = x.reshape(3,512,512).permute(1,2,0).clip(0,1)*255
    return x.cast(dtypes.uint8)

  def __call__(self, unconditional_context, context, latent, timestep, alphas, alphas_prev, guidance):
    e_t = self.get_model_output(unconditional_context, context, latent, timestep, guidance)
    x_prev, _ = self.get_x_prev_and_pred_x0(latent, e_t, alphas, alphas_prev)
    #e_t_next = get_model_output(x_prev)
    #e_t_prime = (e_t + e_t_next) / 2
    #x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t_prime, index)
    return x_prev.realize()

# ** ldm.models.autoencoder.AutoencoderKL (done!)
# 3x512x512 <--> 4x64x64 (16384)
# decode torch.Size([1, 4, 64, 64]) torch.Size([1, 3, 512, 512])
# section 4.3 of paper
# first_stage_model.encoder, first_stage_model.decoder
# ** ldm.modules.diffusionmodules.openaimodel.UNetModel
# this is what runs each time to sample. is this the LDM?
# input: 4x64x64
# output: 4x64x64
# model.diffusion_model
# it has attention?
# ** ldm.modules.encoders.modules.FrozenCLIPEmbedder
# cond_stage_model.transformer.text_model

if __name__ == "__main__":
  default_prompt = "a horse sized cat eating a bagel"

  parser = argparse.ArgumentParser(description='Run Stable Diffusion', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument('--steps', type=int, default=6, help="Number of steps in diffusion")
  parser.add_argument('--prompt', type=str, default=default_prompt, help="Phrase to render")
  parser.add_argument('--out', type=str, default=Path(tempfile.gettempdir()) / "rendered.png", help="Output filename")
  parser.add_argument('--noshow', action='store_true', help="Don't show the image")
  parser.add_argument('--fp16', action='store_true', help="Cast the weights to float16")
  parser.add_argument('--timing', action='store_true', help="Print timing per step")
  parser.add_argument('--seed', type=int, help="Set the random latent seed")
  parser.add_argument('--guidance', type=float, default=7.5, help="Prompt strength")
  parser.add_argument('--fakeweights', action='store_true', help="Skip loading checkpoints and use fake weights")
  args = parser.parse_args()

  model = StableDiffusion()

  # load in weights
  with WallTimeEvent(BenchEvent.LOAD_WEIGHTS):
    if not args.fakeweights:
      model_bin = fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt')
      load_state_dict(model, torch_load(model_bin)['state_dict'], verbose=False, strict=False, realize=False)

    if args.fp16:
      for k,v in get_state_dict(model).items():
        if k.startswith("model"):
          v.replace(v.cast(dtypes.float16))

    Tensor.realize(*get_state_dict(model).values())

  # run through CLIP to get context
  tokenizer = Tokenizer.ClipTokenizer()
  prompt = Tensor([tokenizer.encode(args.prompt)])
  context = model.cond_stage_model.transformer.text_model(prompt).realize()
  print("got CLIP context", context.shape)
prompt = Tensor([tokenizer.encode("")])
unconditional_context = model.cond_stage_model.transformer.text_model(prompt).realize()
print("got unconditional CLIP context", unconditional_context.shape)
# done with clip model
del model.cond_stage_model
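
  # DDIM sampling walks an evenly spaced subsequence of the 1000 training timesteps in reverse.
  # alphas_prev is alphas shifted right by one (1.0 for the last denoising step), so each step
  # moves the latent from noise level alpha_t to alpha_t_prev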
  timesteps = list(range(1, 1000, 1000//args.steps))
  print(f"running for {timesteps} timesteps")
  alphas = model.alphas_cumprod[Tensor(timesteps)]
  alphas_prev = Tensor([1.0]).cat(alphas[:-1])

  # start with random noise
  if args.seed is not None: Tensor.manual_seed(args.seed)
  latent = Tensor.randn(1,4,64,64)
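
  # the 1x4x64x64 latent decodes to a 3x512x512 image (the VAE downsamples by 8x)

  # TinyJit captures the kernels launched by run() and replays the compiled graph on later calls,
  # so every denoising step after the first couple reuses the same GPU schedule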
  @TinyJit
  def run(model, *x): return model(*x).realize()

  # this is diffusion
  step_times = []
  with Context(BEAM=getenv("LATEBEAM")):
    for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])):
      GlobalCounters.reset()
      st = time.perf_counter_ns()
      t.set_description("%3d %3d" % (index, timestep))
      with Timing("step in ", enabled=args.timing, on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB"):
        with WallTimeEvent(BenchEvent.STEP):
          tid = Tensor([index])
          latent = run(model, unconditional_context, context, latent, Tensor([timestep]), alphas[tid], alphas_prev[tid], Tensor([args.guidance]))
          if args.timing: Device[Device.DEFAULT].synchronize()
      step_times.append((time.perf_counter_ns() - st)*1e-6)
  del run

  if (assert_time:=getenv("ASSERT_MIN_STEP_TIME")):
    min_time = min(step_times)
    assert min_time < assert_time, f"Speed regression, expected min step time of < {assert_time} ms but took: {min_time} ms"

  # upsample latent space to image with autoencoder
  x = model.decode(latent)
  print(x.shape)

  # save image
  im = Image.fromarray(x.numpy())
  print(f"saving {args.out}")
  im.save(args.out)

  # Open image.
  if not args.noshow: im.show()

  # validation!
  if args.prompt == default_prompt and args.steps == 6 and args.seed == 0 and args.guidance == 7.5:
    ref_image = Tensor(np.array(Image.open(Path(__file__).parent / "stable_diffusion_seed0.png")))
    distance = (((x.cast(dtypes.float) - ref_image.cast(dtypes.float)) / ref_image.max())**2).mean().item()
    assert distance < 3e-3, colored(f"validation failed with {distance=}", "red") # higher distance with WINO
    print(colored(f"output validated with {distance=}", "green"))