You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
			
				
					150 lines
				
				5.7 KiB
			
		
		
			
		
	
	
					150 lines
				
				5.7 KiB
			| 
											4 days ago
										 | from tinygrad import Tensor, dtypes, TinyJit
 | ||
|  | from tinygrad.helpers import fetch
 | ||
|  | from tinygrad.nn.state import safe_load, load_state_dict, get_state_dict
 | ||
|  | from examples.stable_diffusion import AutoencoderKL, get_alphas_cumprod
 | ||
|  | from examples.sdxl import DPMPP2MSampler, append_dims, LegacyDDPMDiscretization
 | ||
|  | from extra.models.unet import UNetModel
 | ||
|  | from extra.models.clip import FrozenOpenClipEmbedder
 | ||
|  | from extra.bench_log import BenchEvent, WallTimeEvent
 | ||
|  | 
 | ||
|  | from typing import Dict
 | ||
|  | import argparse, tempfile, os
 | ||
|  | from pathlib import Path
 | ||
|  | from PIL import Image
 | ||
|  | 
 | ||
class DiffusionModel:
  """Holder that nests a UNet under the attribute name ``diffusion_model``.

  It adds no behaviour of its own. Its only role is to give the UNet's parameters
  a ``diffusion_model.`` attribute path, presumably to mirror the checkpoint's key
  layout (confirm against the weights file).
  """

  def __init__(self, model:UNetModel) -> None:
    self.diffusion_model = model
 | ||
|  | 
 | ||
@TinyJit
def run(model, x, tms, ctx, c_out, add):
  """Evaluate ``model(x, tms, ctx) * c_out + add`` and realize the result.

  The function is wrapped in TinyJit so repeated calls can reuse captured kernels.
  Callers presumably must keep input shapes and dtypes fixed across calls (TODO confirm).
  """
  scaled = model(x, tms, ctx) * c_out
  return (scaled + add).realize()
 | ||
|  | 
 | ||
|  | # https://github.com/Stability-AI/stablediffusion/blob/cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf/ldm/models/diffusion/ddpm.py#L521
 | ||
class StableDiffusionV2:
  """Stable Diffusion 2.x: UNet denoiser, VAE (first stage) and OpenCLIP text encoder (cond stage).

  Attribute names follow the original LDM module layout (``model``, ``first_stage_model``,
  ``cond_stage_model``). This is presumably so checkpoint keys load via ``load_state_dict``.

  Args:
    unet_config: keyword arguments for ``UNetModel``.
    cond_stage_config: keyword arguments for ``FrozenOpenClipEmbedder``.
    parameterization: what the UNet predicts, either ``"v"`` (default) or ``"eps"``.
      This selects the output scaling used in ``denoise``.

  Raises:
    ValueError: if ``parameterization`` is not ``"v"`` or ``"eps"``.
  """

  def __init__(self, unet_config:Dict, cond_stage_config:Dict, parameterization:str="v"):
    # Fail fast, before building the large submodules. Previously any value was
    # accepted here and then silently handled as "v" by denoise().
    if parameterization not in ("v", "eps"):
      raise ValueError(f"unsupported parameterization {parameterization!r}, expected 'v' or 'eps'")
    self.model             = DiffusionModel(UNetModel(**unet_config))
    self.first_stage_model = AutoencoderKL()
    self.cond_stage_model  = FrozenOpenClipEmbedder(**cond_stage_config)
    self.alphas_cumprod    = get_alphas_cumprod()
    self.parameterization  = parameterization

    self.discretization = LegacyDDPMDiscretization()
    # Table of 1000 discrete sigmas. denoise() snaps incoming sigmas onto this table.
    self.sigmas = self.discretization(1000, flip=True)

  def denoise(self, x:Tensor, sigma:Tensor, cond:Dict) -> Tensor:
    """Return the denoised estimate of latent ``x`` at noise level ``sigma``.

    Steps:
      - ``sigma`` is snapped to the nearest entry of ``self.sigmas``.
      - The index of that entry is fed to the UNet as its timestep.
      - ``cond["crossattn"]`` is passed as the UNet context.
      - All UNet inputs are cast to float16 and realized before the JIT-compiled ``run``.
    """
    def sigma_to_idx(s:Tensor) -> Tensor:
      # Index of the closest table sigma for each element of s.
      # s is broadcast against a column view of the table, then argmin runs over the table axis.
      dists = s - self.sigmas.unsqueeze(1)
      return dists.abs().argmin(axis=0).view(*s.shape)

    # Find the nearest discrete sigma once. Its index doubles as the UNet timestep.
    # Re-deriving the index from the snapped sigma gives the same result (table entries
    # are distinct), so the second argmin pass is skipped.
    idx = sigma_to_idx(sigma)
    # append_dims presumably broadcasts sigma up to x's rank (helper defined in examples.sdxl).
    sigma = append_dims(self.sigmas[idx], x)

    c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
    if self.parameterization == "v":
      # v-prediction scaling.
      c_skip = 1.0 / (sigma**2 + 1.0)
      c_out  = -sigma / (sigma**2 + 1.0) ** 0.5
    else:
      # "eps" (validated in __init__): eps-prediction scaling, i.e. denoised = x - sigma * eps.
      c_skip = 1.0
      c_out  = -sigma
    c_noise = idx

    def prep(*tensors:Tensor):
      return tuple(t.cast(dtypes.float16).realize() for t in tensors)

    return run(self.model.diffusion_model, *prep(x*c_in, c_noise, cond["crossattn"], c_out, x*c_skip))

  def decode(self, x:Tensor, height:int, width:int) -> Tensor:
    """Decode latent ``x`` with the VAE into a ``(height, width, 3)`` uint8 image tensor.

    The reshape to ``(3, height, width)`` means ``x`` must decode to exactly one image.
    """
    # 1/0.18215 presumably undoes the SD VAE latent scale factor applied at encode time.
    x = self.first_stage_model.post_quant_conv(1/0.18215 * x)
    x = self.first_stage_model.decoder(x)

    # Make the image the correct size and scale.
    # Decoder output is presumably in [-1, 1]; shift it to [0, 1] and clip before converting to bytes.
    x = (x + 1.0) / 2.0
    x = x.reshape(3,height,width).permute(1,2,0).clip(0,1).mul(255).cast(dtypes.uint8)
    return x
 | ||
|  | 
 | ||
# Constructor keyword arguments for StableDiffusionV2 (Stable Diffusion 2.x sizes).
# The UNet's ctx_dim equals the text encoder's dims (both 1024). This presumably must
# hold so the CLIP embedding can serve as the cross-attention context.
params: Dict = dict(
  unet_config=dict(
    adm_in_ch=None,
    in_ch=4,
    out_ch=4,
    model_ch=320,
    attention_resolutions=[4, 2, 1],
    num_res_blocks=2,
    channel_mult=[1, 2, 4, 4],
    d_head=64,
    transformer_depth=[1, 1, 1, 1],
    ctx_dim=1024,
    use_linear=True,
  ),
  cond_stage_config=dict(
    dims=1024,
    n_heads=16,
    layers=24,
    return_pooled=False,
    ln_penultimate=True,
  ),
)
 | ||
|  | 
 | ||
if __name__ == "__main__":
  # Command-line entry point: parse options, load SD 2.x weights, sample a latent with
  # DPM++ 2M, decode it and save (and optionally show) the resulting PNG.
  default_prompt = "a horse sized cat eating a bagel"
  parser = argparse.ArgumentParser(description='Run Stable Diffusion v2.X', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument('--steps',       type=int,   default=10, help="The number of diffusion steps")
  parser.add_argument('--prompt',      type=str,   default=default_prompt, help="Description of image to generate")
  parser.add_argument('--out',         type=str,   default=Path(tempfile.gettempdir()) / "rendered.png", help="Output filename")
  parser.add_argument('--seed',        type=int,   help="Set the random latent seed")
  parser.add_argument('--guidance',    type=float, default=7.5, help="Prompt strength")
  parser.add_argument('--width',       type=int,   default=768, help="The output image width")
  parser.add_argument('--height',      type=int,   default=768, help="The output image height")
  parser.add_argument('--weights-fn',  type=str,   help="Filename of weights to use")
  parser.add_argument('--weights-url', type=str,   help="Custom URL to download weights from")
  parser.add_argument('--timing',      action='store_true', help="Print timing per step")
  parser.add_argument('--noshow',      action='store_true', help="Don't show the image")
  parser.add_argument('--fp16',        action='store_true', help="Cast the weights to float16")
  args = parser.parse_args()

  N = 1  # batch size; decode() reshapes its output to a single (3, H, W) image
  C = 4  # latent channels, matching unet_config in_ch/out_ch
  F = 8  # spatial factor between image size and latent size
  # Validate with parser.error rather than assert, for two reasons:
  #   - asserts are stripped under `python -O`;
  #   - the old check let 0 through (0 % F == 0), which would create an empty latent.
  for dim_name, dim in (("width", args.width), ("height", args.height)):
    if dim <= 0 or dim % F != 0:
      parser.error(f"--{dim_name} must be a positive multiple of {F}, got {dim}")

  if args.seed is not None:
    Tensor.manual_seed(args.seed)

  model = StableDiffusionV2(**params)

  default_weights_url = 'https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors'
  weights_fn = args.weights_fn
  if not weights_fn:
    weights_url = args.weights_url or default_weights_url
    weights_fn  = fetch(weights_url, os.path.basename(weights_url))

  with WallTimeEvent(BenchEvent.LOAD_WEIGHTS):
    # strict=False: checkpoint keys with no matching model attribute (and vice versa) are tolerated.
    load_state_dict(model, safe_load(weights_fn), strict=False)

    if args.fp16:
      # Only the UNet ("model.*" keys) is converted; the VAE and text encoder keep their loaded dtype.
      for k,v in get_state_dict(model).items():
        if k.startswith("model"):
          v.replace(v.cast(dtypes.float16).realize())

  c  = { "crossattn": model.cond_stage_model(args.prompt) }
  uc = { "crossattn": model.cond_stage_model("") }
  del model.cond_stage_model
  # tinygrad evaluates lazily (hence the .realize() calls elsewhere). Realize the text
  # embeddings now so the CLIP work is done here, not counted in the first sampler step.
  for cond in (c, uc): cond["crossattn"].realize()
  print("created conditioning")

  shape = (N, C, args.height // F, args.width // F)
  randn = Tensor.randn(shape)

  sampler = DPMPP2MSampler(args.guidance)
  z = sampler(model.denoise, randn, c, uc, args.steps, timing=args.timing)
  print("created samples")
  x = model.decode(z, args.height, args.width).realize()
  print("decoded samples")
  print(x.shape)

  im = Image.fromarray(x.numpy())
  print(f"saving {args.out}")
  im.save(args.out)

  if not args.noshow:
    im.show()
 |