import time, math, os
start = time.perf_counter()
from pathlib import Path
import numpy as np
from tinygrad import Tensor, Device, dtypes, GlobalCounters, TinyJit
from tinygrad.nn.state import get_parameters, load_state_dict, safe_load
from tinygrad.helpers import getenv, Context, prod
from extra.bench_log import BenchEvent, WallTimeEvent

def tlog(x): print(f"{x:25s} @ {time.perf_counter()-start:5.2f}s")
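# tlog prints elapsed wall-clock time since process start, so each stage's timestamp includes everything before it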

def eval_resnet():
  with WallTimeEvent(BenchEvent.FULL):
    # Resnet50-v1.5
    from extra.models.resnet import ResNet50
    tlog("imports")
    GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 6))]
    for x in GPUS: Device[x]
    tlog("got devices")  # NOTE: this is faster with rocm-smi running
    class ResnetRunner:
      def __init__(self, device=None):
        self.mdl = ResNet50()
        for x in get_parameters(self.mdl) if device else []: x.to_(device)
        if (fn:=getenv("RESNET_MODEL", "")): load_state_dict(self.mdl, safe_load(fn))
        else: self.mdl.load_from_pretrained()
        self.input_mean = Tensor([0.485, 0.456, 0.406], device=device).reshape(1, -1, 1, 1)
        self.input_std = Tensor([0.229, 0.224, 0.225], device=device).reshape(1, -1, 1, 1)
      def __call__(self, x:Tensor) -> Tensor:
        x = x.permute([0,3,1,2]).cast(dtypes.float32) / 255.0
        x -= self.input_mean
        x /= self.input_std
        return self.mdl(x).log_softmax().argmax(axis=1).realize()

    mdl = TinyJit(ResnetRunner(GPUS))
    tlog("loaded models")

    # evaluation on the mlperf classes of the validation set from imagenet
    from examples.mlperf.dataloader import batch_load_resnet
    iterator = batch_load_resnet(getenv("BS", 128*6), val=getenv("VAL", 1), shuffle=False, pad_first_batch=True)
    def data_get():
      x,y,cookie = next(iterator)
      return x.shard(GPUS, axis=0).realize(), y, cookie
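
    # double buffering: enqueue the model on the current batch, then fetch the next batch while the GPU runs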
    n,d = 0,0
    proc = data_get()
    tlog("loaded initial data")
    st = time.perf_counter()
    while proc is not None:
      GlobalCounters.reset()
      proc = (mdl(proc[0]), proc[1], proc[2])  # this frees the images
      run = time.perf_counter()
      # load the next data here
      try: next_proc = data_get()
      except StopIteration: next_proc = None
      nd = time.perf_counter()
      y = np.array(proc[1])
      proc = (proc[0].numpy() == y) & (y != -1)  # this realizes the models and frees the cookies
      n += proc.sum()
      d += (y != -1).sum()
      et = time.perf_counter()
      tlog(f"****** {n:5d}/{d:5d} {n*100.0/d:.2f}% -- {(run-st)*1000:7.2f} ms to enqueue, {(et-run)*1000:7.2f} ms to realize ({(nd-run)*1000:7.2f} ms fetching). {len(proc)/(et-st):8.2f} examples/sec. {GlobalCounters.global_ops*1e-12/(et-st):5.2f} TFLOPS")
      st = et
      proc, next_proc = next_proc, None
    tlog("done")

def eval_unet3d():
  # UNet3D
  from extra.models.unet3d import UNet3D
  from extra.datasets.kits19 import iterate, sliding_window_inference, get_val_files
  from examples.mlperf.metrics import dice_score
  mdl = UNet3D()
  mdl.load_from_pretrained()
  s = 0
  st = time.perf_counter()
  for i, (image, label) in enumerate(iterate(get_val_files()), start=1):
    mt = time.perf_counter()
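    # sliding_window_inference (from extra.datasets.kits19) tiles the volume into sub-volumes, runs the model on each, and stitches the outputs back to full resolution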
    pred, label = sliding_window_inference(mdl, image, label)
    et = time.perf_counter()
    print(f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model")
    s += dice_score(Tensor(pred), Tensor(label)).mean().item()
    print(f"****** {s:.2f}/{i} {s/i:.5f} Mean DICE score")
    st = time.perf_counter()

def eval_retinanet():
  # RetinaNet with a ResNeXt50_32X4D backbone
  from examples.mlperf.dataloader import batch_load_retinanet
  from extra.datasets.openimages import normalize, download_dataset, BASEDIR
  from extra.models.resnet import ResNeXt50_32X4D
  from extra.models.retinanet import RetinaNet
  from pycocotools.coco import COCO
  from pycocotools.cocoeval import COCOeval
  from contextlib import redirect_stdout
  tlog("imports")

  mdl = RetinaNet(ResNeXt50_32X4D())
  mdl.load_from_pretrained()
  tlog("loaded models")

  coco = COCO(download_dataset(base_dir:=getenv("BASEDIR", BASEDIR), 'validation'))
  coco_eval = COCOeval(coco, iouType="bbox")
  coco_evalimgs, evaluated_imgs, ncats, narea = [], [], len(coco_eval.params.catIds), len(coco_eval.params.areaRng)
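  # per-batch COCO eval results are pooled in coco_evalimgs, then merged into a single accumulate/summarize pass at the end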
  tlog("loaded dataset")

  iterator = batch_load_retinanet(coco, True, Path(base_dir), getenv("BS", 8), shuffle=False)
  def data_get():
    x, img_ids, img_sizes, cookie = next(iterator)
    return x.to(Device.DEFAULT).realize(), img_ids, img_sizes, cookie

  n = 0
  proc = data_get()
  tlog("loaded initial data")
  st = time.perf_counter()
  while proc is not None:
    GlobalCounters.reset()
    proc = (mdl(normalize(proc[0])), proc[1], proc[2], proc[3])
    run = time.perf_counter()
    # load the next data here
    try: next_proc = data_get()
    except StopIteration: next_proc = None
    nd = time.perf_counter()
    predictions, img_ids = mdl.postprocess_detections(proc[0].numpy(), orig_image_sizes=proc[2]), proc[1]
    pd = time.perf_counter()
    coco_results = [{"image_id": img_ids[i], "category_id": label, "bbox": box.tolist(), "score": score}
                    for i, prediction in enumerate(predictions) for box, score, label in zip(*prediction.values())]
    with redirect_stdout(None):
      coco_eval.cocoDt = coco.loadRes(coco_results)
      coco_eval.params.imgIds = img_ids
      coco_eval.evaluate()
    evaluated_imgs.extend(img_ids)
    coco_evalimgs.append(np.array(coco_eval.evalImgs).reshape(ncats, narea, len(img_ids)))
    n += len(proc[0])
    et = time.perf_counter()
    tlog(f"****** {(run-st)*1000:7.2f} ms to enqueue, {(et-run)*1000:7.2f} ms to realize ({(nd-run)*1000:7.2f} ms fetching, {(pd-run)*1000:4.2f} ms postprocess_detections). {len(predictions)/(et-st):8.2f} examples/sec. {GlobalCounters.global_ops*1e-12/(et-st):5.2f} TFLOPS")
    st = et
    proc, next_proc = next_proc, None

  coco_eval.params.imgIds = evaluated_imgs
  coco_eval._paramsEval.imgIds = evaluated_imgs
  coco_eval.evalImgs = list(np.concatenate(coco_evalimgs, -1).flatten())
  coco_eval.accumulate()
  coco_eval.summarize()
  tlog("done")

def eval_rnnt():
  # RNN-T
  from extra.models.rnnt import RNNT
  mdl = RNNT()
  mdl.load_from_pretrained()

  from extra.datasets.librispeech import iterate
  from examples.mlperf.metrics import word_error_rate

  LABELS = [" ", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "'"]

  c = 0
  scores = 0
  words = 0
  st = time.perf_counter()
  for X, Y in iterate():
    mt = time.perf_counter()
    tt = mdl.decode(Tensor(X[0]), Tensor([X[1]]))
    et = time.perf_counter()
    print(f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model")
    for n, t in enumerate(tt):
      tnp = np.array(t)
      _, scores_, words_ = word_error_rate(["".join([LABELS[int(tnp[i])] for i in range(tnp.shape[0])])], [Y[n]])
      scores += scores_
      words += words_
    c += len(tt)
    print(f"WER: {scores/words}, {words} words, raw scores: {scores}, c: {c}")
    st = time.perf_counter()

def eval_bert():
  # Bert-QA
  from extra.models.bert import BertForQuestionAnswering
  mdl = BertForQuestionAnswering()
  mdl.load_from_pretrained()

  @TinyJit
  def run(input_ids, input_mask, segment_ids):
    return mdl(input_ids, input_mask, segment_ids).realize()
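  # each SQuAD question can be split into several overlapping features; the jitted model runs on each, and get_bert_qa_prediction combines the per-feature outputs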

  from extra.datasets.squad import iterate
  from examples.mlperf.helpers import get_bert_qa_prediction
  from examples.mlperf.metrics import f1_score
  from transformers import BertTokenizer

  tokenizer = BertTokenizer(str(Path(__file__).parents[2] / "extra/weights/bert_vocab.txt"))

  c = 0
  f1 = 0.0
  st = time.perf_counter()
  for X, Y in iterate(tokenizer):
    mt = time.perf_counter()
    outs = []
    for x in X:
      outs.append(run(Tensor(x["input_ids"]), Tensor(x["input_mask"]), Tensor(x["segment_ids"])).numpy())
    et = time.perf_counter()
    print(f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model over {len(X)} features")

    pred = get_bert_qa_prediction(X, Y, outs)
    print(f"pred: {pred}\nans: {Y['answers']}")
    f1 += max([f1_score(pred, ans) for ans in Y["answers"]])
    c += 1
    print(f"f1: {f1/c}, raw: {f1}, c: {c}\n")

    st = time.perf_counter()

def eval_mrcnn():
  from tqdm import tqdm
  from extra.models.mask_rcnn import MaskRCNN
  from extra.models.resnet import ResNet
  from extra.datasets.coco import BASEDIR, images, convert_prediction_to_coco_bbox, convert_prediction_to_coco_mask, accumulate_predictions_for_coco, evaluate_predictions_on_coco, iterate
  from examples.mask_rcnn import compute_prediction_batched, Image
  mdl = MaskRCNN(ResNet(50, num_classes=None, stride_in_1x1=True))
  mdl.load_from_pretrained()

  bbox_output = '/tmp/results_bbox.json'
  mask_output = '/tmp/results_mask.json'

  accumulate_predictions_for_coco([], bbox_output, rm=True)
  accumulate_predictions_for_coco([], mask_output, rm=True)
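  # the two calls above (rm=True with no predictions) reset the output files so results from a previous run don't leak in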

  # TODO: bs > 1 is not as accurate
  bs = 1

  for batch in tqdm(iterate(images, bs=bs), total=len(images)//bs):
    batch_imgs = []
    for image_row in batch:
      image_name = image_row['file_name']
      img = Image.open(BASEDIR/f'val2017/{image_name}').convert("RGB")
      batch_imgs.append(img)
    batch_result = compute_prediction_batched(batch_imgs, mdl)
    for image_row, result in zip(batch, batch_result):
      image_name = image_row['file_name']
      box_pred = convert_prediction_to_coco_bbox(image_name, result)
      mask_pred = convert_prediction_to_coco_mask(image_name, result)
      accumulate_predictions_for_coco(box_pred, bbox_output)
      accumulate_predictions_for_coco(mask_pred, mask_output)
    del batch_imgs
    del batch_result

  evaluate_predictions_on_coco(bbox_output, iou_type='bbox')
  evaluate_predictions_on_coco(mask_output, iou_type='segm')

def eval_llama3():
  from extra.models.llama import Transformer
  from examples.llama3 import MODEL_PARAMS, load, convert_from_huggingface
  from tinygrad.helpers import tqdm

  BASEDIR = Path(getenv("BASEDIR", "/raid/datasets/c4/"))
  BS = getenv("BS", 4)
  SMALL = getenv("SMALL", 0)
  SEQLEN = getenv("SEQLEN", 8192)
  MODEL_PATH = Path(getenv("MODEL_PATH", "/raid/weights/llama31_8b/"))

  params = MODEL_PARAMS[getenv("LLAMA3_SIZE", "8B")]["args"]
  params = params | {"vocab_size": 32000} if not SMALL else params
  if (llama_layers:=getenv("LLAMA_LAYERS")) != 0: params['n_layers'] = llama_layers
  model = Transformer(**params, max_context=SEQLEN, jit=False, disable_kv_cache=True)
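  # the kv cache is disabled because eval scores whole sequences in a single teacher-forced forward pass, not token-by-token decoding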

  # load weights
  weights = load(str(MODEL_PATH / "model.safetensors.index.json"))
  if "model.embed_tokens.weight" in weights:
    print("converting from huggingface format")
    weights = convert_from_huggingface(weights, params["n_layers"], params["n_heads"], params["n_kv_heads"])

  load_state_dict(model, weights, strict=False, consume=True)

  @TinyJit
  def eval_step(model, tokens):
    logits:Tensor = model(tokens[:, :-1], start_pos=0, temperature=math.nan)
    loss = logits.sparse_categorical_crossentropy(tokens[:, 1:])
    return loss.flatten().float()
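  # next-token cross-entropy: tokens[:, :-1] predicts tokens[:, 1:]; the mean over the val set is the log perplexity reported below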

  if SMALL:
    from examples.mlperf.dataloader import batch_load_llama3_small
    iter = batch_load_llama3_small(BS, 5760, SEQLEN, BASEDIR, val=True)
  else:
    from examples.mlperf.dataloader import batch_load_llama3
    iter = batch_load_llama3(BS, 5760, SEQLEN, BASEDIR, val=True)

  losses = []
  for tokens in tqdm(iter, total=5760//BS):
    GlobalCounters.reset()
    losses += eval_step(model, tokens).tolist()
    tqdm.write(f"loss: {np.mean(losses)}")

  log_perplexity = np.mean(losses)
  print(f"Log Perplexity: {log_perplexity}")

# NOTE: BEAM hangs on 8xmi300x with DECODE_BS=384 in the final realize below; the function is declared at module level for external testing
@TinyJit
def vae_decode(x:Tensor, vae, disable_beam=False) -> Tensor:
  from examples.stable_diffusion import AutoencoderKL
  assert isinstance(vae, AutoencoderKL)
  x = vae.post_quant_conv(1./0.18215 * x)

  x = vae.decoder.conv_in(x)
  x = vae.decoder.mid(x)
  for i, l in enumerate(vae.decoder.up[::-1]):
    print("decode", x.shape)
    for b in l['block']: x = b(x)
    if 'upsample' in l:
      bs,c,py,px = x.shape
      x = x.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, 2, px, 2).reshape(bs, c, py*2, px*2)
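      # the reshape/expand above is a 2x nearest-neighbor upsample: every pixel is duplicated into a 2x2 block, with no dedicated resize kernel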
      x = l['upsample']['conv'](x)
    if i == len(vae.decoder.up) - 1 and disable_beam:
      with Context(BEAM=0): x.realize()
    else: x.realize()
  x = vae.decoder.conv_out(vae.decoder.norm_out(x).swish())

  x = ((x + 1.0) / 2.0).clip(0.0, 1.0)
  return x

def eval_stable_diffusion():
  import csv, sys
  import PIL.Image
  from tqdm import tqdm
  from examples.mlperf.initializers import init_stable_diffusion, gelu_erf
  from examples.stable_diffusion import AutoencoderKL
  from extra.models.unet import UNetModel
  from tinygrad.nn.state import load_state_dict, torch_load
  from tinygrad.helpers import BEAM
  from extra.models import clip
  from extra.models.clip import FrozenOpenClipEmbedder, OpenClipEncoder
  from extra.models.inception import FidInceptionV3

  config = {}
  GPUS = config["GPUS"] = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
  for x in GPUS: Device[x]
  print(f"running eval on {GPUS}")
  seed = config["seed"] = getenv("SEED", 12345)
  CKPTDIR = config["CKPTDIR"] = Path(getenv("CKPTDIR", "./checkpoints"))
  DATADIR = config["DATADIR"] = Path(getenv("DATADIR", "./datasets"))
  CONTEXT_BS = config["CONTEXT_BS"] = getenv("CONTEXT_BS", 1 * len(GPUS))
  DENOISE_BS = config["DENOISE_BS"] = getenv("DENOISE_BS", 1 * len(GPUS))
  DECODE_BS = config["DECODE_BS"] = getenv("DECODE_BS", 1 * len(GPUS))
  INCEPTION_BS = config["INCEPTION_BS"] = getenv("INCEPTION_BS", 1 * len(GPUS))
  CLIP_BS = config["CLIP_BS"] = getenv("CLIP_BS", 1 * len(GPUS))
  EVAL_CKPT_DIR = config["EVAL_CKPT_DIR"] = getenv("EVAL_CKPT_DIR", "")
  STOP_IF_CONVERGED = config["STOP_IF_CONVERGED"] = getenv("STOP_IF_CONVERGED", 0)

  if (WANDB := getenv("WANDB", "")):
    import wandb
    wandb.init(config=config, project="MLPerf-Stable-Diffusion")

  assert EVAL_CKPT_DIR != "", "provide a directory with checkpoints to be evaluated"
  print(f"running eval on checkpoints in {EVAL_CKPT_DIR}\nSEED={seed}")
  eval_queue:list[tuple[int, Path]] = []
  for p in Path(EVAL_CKPT_DIR).iterdir():
    if p.name.endswith(".safetensors"):
      ckpt_iteration = p.name.split(".safetensors")[0]
      assert ckpt_iteration.isdigit(), f"invalid checkpoint name: {p.name}, expected <digits>.safetensors"
      eval_queue.append((int(ckpt_iteration), p))
  assert len(eval_queue), f'no files ending with ".safetensors" were found in {EVAL_CKPT_DIR}'
  print(sorted(eval_queue, reverse=True))

  Tensor.manual_seed(seed)  # seed for weight initialization
  model, unet, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod = init_stable_diffusion("v2-mlperf-eval", CKPTDIR / "sd" / "512-base-ema.ckpt", GPUS)

  # load prompts for generating images for validation; 2 MB of data total
  with open(DATADIR / "coco2014" / "val2014_30k.tsv") as f:
    reader = csv.DictReader(f, delimiter="\t")
    eval_inputs:list[dict] = [{"image_id": int(row["image_id"]), "id": int(row["id"]), "caption": row["caption"]} for row in reader]
  assert len(eval_inputs) == 30_000
  # NOTE: the clip weights are the same between model.cond_stage_model and clip_encoder
  eval_timesteps = list(reversed(range(1, 1000, 20)))
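  # 50 DDIM sampling steps: t = 981, 961, ..., 21, 1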

  original_device, Device.DEFAULT = Device.DEFAULT, "CPU"
  # The choice of alphas_prev[0] = alphas_cumprod[0] seems arbitrary, but it's how the mlperf ref does it:
  # alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
  eval_alphas_prev = model.alphas_cumprod[0:1].cat(model.alphas_cumprod[list(range(1, 1000, 20))[:-1]]).to(GPUS).realize()
  inception = FidInceptionV3().load_from_pretrained(CKPTDIR / "inception" / "pt_inception-2015-12-05-6726825d.pth")
  vision_cfg = {'width': 1280, 'layers': 32, 'd_head': 80, 'image_size': 224, 'patch_size': 14}
  text_cfg = {'width': 1024, 'n_heads': 16, 'layers': 24, 'vocab_size': 49408, 'ctx_length': 77}
  clip.gelu = gelu_erf
  clip_encoder = OpenClipEncoder(1024, text_cfg, vision_cfg)
  loaded = torch_load(CKPTDIR / "clip" / "open_clip_pytorch_model.bin")
  loaded.update({"attn_mask": clip_encoder.attn_mask, "mean": clip_encoder.mean, "std": clip_encoder.std})
  load_state_dict(clip_encoder, loaded)
  Device.DEFAULT = original_device

  @TinyJit
  def denoise_step(x:Tensor, x_x:Tensor, t_t:Tensor, uc_c:Tensor, sqrt_alphas_cumprod_t:Tensor, sqrt_one_minus_alphas_cumprod_t:Tensor,
                   alpha_prev:Tensor, unet:UNetModel, GPUS) -> Tensor:
    out_uncond, out = unet(x_x, t_t, uc_c).to("CPU").reshape(-1, 2, 4, 64, 64).chunk(2, dim=1)
    out_uncond = out_uncond.squeeze(1).shard(GPUS, axis=0)
    out = out.squeeze(1).shard(GPUS, axis=0)
    v_t = out_uncond + 8.0 * (out - out_uncond)
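    # v_t above is classifier-free guidance (scale 8.0); below, the v-prediction is converted to eps and x0 for the DDIM update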
    e_t = sqrt_alphas_cumprod_t * v_t + sqrt_one_minus_alphas_cumprod_t * x
    pred_x0 = sqrt_alphas_cumprod_t * x - sqrt_one_minus_alphas_cumprod_t * v_t
    dir_xt = (1. - alpha_prev).sqrt() * e_t
    x_prev = alpha_prev.sqrt() * pred_x0 + dir_xt
    return x_prev.realize()

  def shard_tensor(t:Tensor) -> Tensor: return t.shard(GPUS, axis=0) if len(GPUS) > 1 else t.to(GPUS[0])
  def get_batch(whole:Tensor, i:int, bs:int) -> tuple[Tensor, int]:
    batch = whole[i: i + bs].to("CPU")
    if (unpadded_bs:=batch.shape[0]) < bs:
      # pad a short final batch by repeating its last element; unpadded_bs lets callers drop the padding afterwards
      batch = batch.cat(batch[-1:].expand(bs - unpadded_bs, *batch[-1].shape))
    return batch, unpadded_bs

  @Tensor.train(mode=False)
  def eval_unet(eval_inputs:list[dict], unet:UNetModel, cond_stage:FrozenOpenClipEmbedder, first_stage:AutoencoderKL,
                inception:FidInceptionV3, clip:OpenClipEncoder) -> tuple[float, float]:
    # Eval is divided into 5 jits, one per model.
    # It doesn't make sense to merge these jits: e.g. the unet repeats 50 times in isolation, and the images fork to separate inception/clip scoring.
    # We're generating and scoring 30,000 images per eval, and all the data can flow through one jit at a time.
    # To maximize throughput for each jit, we keep only one model/jit on the GPU at a time and pool outputs from each jit off-GPU.
    for model in (unet, first_stage, inception, clip):
      Tensor.realize(*[p.to_("CPU") for p in get_parameters(model)])

    uc_written = False
    models = (cond_stage, unet, first_stage, inception, clip)
    jits = (jit_context:=TinyJit(cond_stage.embed_tokens), denoise_step, vae_decode, jit_inception:=TinyJit(inception),
            jit_clip:=TinyJit(clip.get_clip_score))
    all_bs = (CONTEXT_BS, DENOISE_BS, DECODE_BS, INCEPTION_BS, CLIP_BS)
    if (EVAL_SAMPLES:=getenv("EVAL_SAMPLES", 0)) and EVAL_SAMPLES > 0:
      eval_inputs = eval_inputs[0:EVAL_SAMPLES]
    output_shapes = [(ns:=len(eval_inputs),77), (ns,77,1024), (ns,4,64,64), (ns,3,512,512), (ns,2048), (ns,)]
    # Writing progress to disk lets us resume eval if we crash
    stages = ["tokens", "embeds", "latents", "imgs", "inception", "clip"]
    disk_tensor_names, disk_tensor_shapes = stages + ["end", "uc"], output_shapes + [(6,), (1,77,1024)]
    if not all(os.path.exists(f"{EVAL_CKPT_DIR}/{name}.bytes") for name in disk_tensor_names):
      for name, shape in zip(disk_tensor_names, disk_tensor_shapes):
        file = Path(f"{EVAL_CKPT_DIR}/{name}.bytes")
        file.unlink(missing_ok=True)
        with file.open("wb") as f: f.truncate(prod(shape) * 4)
    progress = {name: Tensor.empty(*shape, device=f"disk:{EVAL_CKPT_DIR}/{name}.bytes", dtype=dtypes.int if name in {"tokens", "end"} else dtypes.float)
                for name, shape in zip(disk_tensor_names, disk_tensor_shapes)}
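    # the "disk:" device maps each file as a raw buffer; 4 bytes per element matches both dtypes.int and dtypes.float used above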

    def embed_tokens(tokens:Tensor) -> Tensor:
      nonlocal uc_written
      if not uc_written:
        with Context(BEAM=0): progress["uc"].assign(cond_stage.embed_tokens(cond_stage.tokenize("").to(GPUS)).to("CPU").realize()).realize()
        uc_written = True
      return jit_context(shard_tensor(tokens))

    def generate_latents(embeds:Tensor) -> Tensor:
      uc_c = Tensor.stack(progress["uc"].to("CPU").expand(bs, 77, 1024), embeds, dim=1).reshape(-1, 77, 1024)
      uc_c = shard_tensor(uc_c)
      x = shard_tensor(Tensor.randn(bs,4,64,64))
      for step_idx, timestep in enumerate(tqdm(eval_timesteps)):
        reversed_idx = Tensor([50 - step_idx - 1], device=GPUS)
        alpha_prev = eval_alphas_prev[reversed_idx]
        ts = Tensor.full(bs, fill_value=timestep, dtype=dtypes.int, device="CPU")
        ts_ts = shard_tensor(ts.cat(ts))
        ts = shard_tensor(ts)
        sqrt_alphas_cumprod_t = sqrt_alphas_cumprod[ts].reshape(bs, 1, 1, 1)
        sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alphas_cumprod[ts].reshape(bs, 1, 1, 1)
        x_x = shard_tensor(Tensor.stack(x.to("CPU"), x.to("CPU"), dim=1).reshape(-1, 4, 64, 64))
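        # the latents are stacked twice (x_x) so the unconditional and conditional unet passes run as one batched call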
        x.assign(denoise_step(x, x_x, ts_ts, uc_c, sqrt_alphas_cumprod_t, sqrt_one_minus_alphas_cumprod_t, alpha_prev, unet, GPUS)).realize()
      return x

    def decode_latents(latents:Tensor) -> Tensor: return vae_decode(shard_tensor(latents), first_stage, disable_beam=True)
    def generate_inception(imgs:Tensor) -> Tensor: return jit_inception(shard_tensor(imgs))[:,:,0,0]

    def calc_clip_scores(batch:Tensor, batch_tokens:Tensor) -> Tensor:
      # Tensor.interpolate does not yet support bicubic, so we use PIL
      batch = (batch.to(GPUS[0]).permute(0,2,3,1) * 255).clip(0, 255).cast(dtypes.uint8).numpy()
      batch = [np.array(PIL.Image.fromarray(batch[i]).resize((224,224), PIL.Image.BICUBIC)) for i in range(bs)]
      batch = shard_tensor(Tensor(np.stack(batch, axis=0).transpose(0,3,1,2), device="CPU").realize())
      batch = batch.cast(dtypes.float) / 255
      batch = (batch - model.mean) / model.std
      batch = jit_clip(shard_tensor(batch_tokens), batch)
      return batch

    callbacks = (embed_tokens, generate_latents, decode_latents, generate_inception, calc_clip_scores)

    # save every forward pass output to disk; NOTE: this needs ~100 GB disk space because 30k images are large
    def stage_progress(stage_idx:int) -> int: return progress["end"].to("CPU")[stage_idx].item()
    if stage_progress(0) < len(eval_inputs):
      tokens = []
      for i in tqdm(range(0, len(eval_inputs), CONTEXT_BS)):
        subset = [cond_stage.tokenize(row["caption"], device="CPU") for row in eval_inputs[i: i+CONTEXT_BS]]
        tokens.append(Tensor.cat(*subset, dim=0).realize())
      progress["tokens"].assign(Tensor.cat(*tokens, dim=0).realize()).realize()
      progress["end"][0:1].assign(Tensor([len(eval_inputs)], dtype=dtypes.int)).realize()
    prev_stage = "tokens"
    tokens = progress["tokens"]

    # wrapper code for every model
    for stage_idx, model, jit, bs, callback in zip(range(1,6), models, jits, all_bs, callbacks):
      stage = stages[stage_idx]
      if stage_progress(stage_idx) >= len(eval_inputs):
        prev_stage = stage
        continue  # use cache
      t0 = time.perf_counter()
      print(f"starting eval with model: {model}")
      if stage_idx == 1: inputs = tokens
      elif stage_idx == 5: inputs = progress["imgs"]
      else: inputs = progress[prev_stage]

      Tensor.realize(*[p.to_(GPUS) for p in get_parameters(model)])
      for batch_idx in tqdm(range(stage_progress(stage_idx), inputs.shape[0], bs)):
        t1 = time.perf_counter()
        batch, unpadded_bs = get_batch(inputs, batch_idx, bs)
        if isinstance(model, OpenClipEncoder): batch = callback(batch, get_batch(tokens, batch_idx, bs)[0].realize())
        else: batch = callback(batch)
        # to(GPUS[0]) is necessary for this to work; without it the result stays on GPUS, probably due to a bug
        batch = batch.to(GPUS[0]).to("CPU")[0:unpadded_bs].realize()
        progress[stage][batch_idx: batch_idx + bs].assign(batch).realize()
        # keep track of what our last output was, so we can resume from there if we crash in this loop
        progress["end"][stage_idx: stage_idx + 1].assign(Tensor([batch_idx + bs], dtype=dtypes.int)).realize()
        print(f"model: {model}, batch_idx: {batch_idx}, elapsed: {(time.perf_counter() - t1):.2f}")
        del batch

      jit.reset()
      Tensor.realize(*[p.to_("CPU") for p in get_parameters(model)])
      print(f"done with model: {model}, elapsed: {(time.perf_counter() - t0):.2f}")
      prev_stage = stage

    inception_stats_fn = str(DATADIR / "coco2014" / "val2014_30k_stats.npz")
    fid_score = inception.compute_score(progress["inception"].to("CPU"), inception_stats_fn)
    clip_score = progress["clip"].to(GPUS[0]).mean().item()
    for name in disk_tensor_names:
      Path(f"{EVAL_CKPT_DIR}/{name}.bytes").unlink(missing_ok=True)

    if EVAL_SAMPLES and BEAM:
      print("BEAM COMPLETE", flush=True)  # allows a wrapper script to detect BEAM search completion and retry if it failed
      sys.exit()  # don't eval additional models; clip/fid scores don't matter when running BEAM on an eval sample subset

    return clip_score, fid_score

  # evaluate checkpoints in reverse chronological order
  for ckpt_iteration, p in sorted(eval_queue, reverse=True):
    unet_ckpt = safe_load(p)
    load_state_dict(unet, unet_ckpt)
    clip_score, fid_score = eval_unet(eval_inputs, unet, model.cond_stage_model, model.first_stage_model, inception, clip_encoder)
    converged = clip_score >= 0.15 and fid_score <= 90
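    # clip >= 0.15 and fid <= 90 are the MLPerf stable diffusion convergence targets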
    print(f"eval results for {EVAL_CKPT_DIR}/{p.name}: clip={clip_score}, fid={fid_score}, converged={converged}")
    if WANDB:
      wandb.log({"eval/ckpt_iteration": ckpt_iteration, "eval/clip_score": clip_score, "eval/fid_score": fid_score})
    if converged and STOP_IF_CONVERGED:
      print(f"Convergence detected, exiting early before evaluating other checkpoints due to STOP_IF_CONVERGED={STOP_IF_CONVERGED}")
      sys.exit()

  # for testing
  return clip_score, fid_score, ckpt_iteration

if __name__ == "__main__":
  # inference only
  Tensor.training = False

  models = getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,mrcnn").split(",")
  for m in models:
    nm = f"eval_{m}"
    if nm in globals():
      print(f"eval {m}")
      globals()[nm]()