"""Time the compile pipeline stages for a ResNet50 graph built from empty tensors.

Stages timed (each wrapped in ``Timing``): building the output tensor, scheduling,
per-kernel optimization, preparing the optimized ASTs, applying the renderer
rewrites and, optionally, applying the linearizer rewrites.

Environment variables read by this script:
  PYPROFILE        value passed to ``Profiling`` (``>= 3`` for the schedule and
                   opts stages, the raw value for the rewrite stage)
  FORWARD_ONLY     if set, stop after building the output tensor
  SCHEDULE_ONLY    if set, stop after scheduling
  LINEARIZE        if nonzero (default 1), also apply the linearizer rewrites
  RESTRICT_KERNEL  if not -1, keep only the AST at that index
  VERBOSE          passed as ``enabled`` to the per-kernel rewrite ``Timing``
"""
from typing import List
from extra.models.resnet import ResNet50
from tinygrad import Tensor, nn
from tinygrad.helpers import Profiling, Timing, getenv, BEAM, NOOPT, DEBUG, Context, ansilen
from tinygrad.uop.ops import Ops
from tinygrad.codegen.kernel import Kernel
from tinygrad.codegen.heuristic import hand_coded_optimizations
from tinygrad.codegen import get_rewrites_for_renderer, apply_rewrites, rewrites_for_linearizer
from tinygrad.engine.search import beam_search, bufs_from_lin

if __name__ == "__main__":
  mdl = ResNet50()
  # swap every parameter for an empty tensor of the same shape
  for param in nn.state.get_parameters(mdl):
    param.replace(Tensor.empty(param.shape))
  img = Tensor.empty(64, 3, 224, 224)

  PROFILE = getenv("PYPROFILE", 0)
  FORWARD_ONLY = getenv("FORWARD_ONLY", 0)
  SCHEDULE_ONLY = getenv("SCHEDULE_ONLY", 0)
  LINEARIZE = bool(getenv("LINEARIZE", 1))

  with Timing("all "):
    with Timing("***** model tensor in "):
      output = mdl(img)

    if not FORWARD_ONLY:
      with Timing("***** model schedule in "), Profiling(PROFILE >= 3):
        sched = output.schedule()

      if not SCHEDULE_ONLY:
        # keep one AST per distinct key, considering only SINK-rooted schedule items
        ast_by_key = {}
        for item in sched:
          if item.ast.op is Ops.SINK:
            ast_by_key[item.ast.key] = item.ast
        asts = list(ast_by_key.values())

        restrict_kernel = getenv("RESTRICT_KERNEL", -1)
        if restrict_kernel != -1:
          asts = asts[restrict_kernel:restrict_kernel+1]

        kernels: List[Kernel] = []
        with Timing(f"***** model opts({len(asts):2d}) in "), Profiling(PROFILE >= 3):
          for ast in asts:
            kern = Kernel(ast)
            if BEAM:
              # beam search runs with DEBUG raised to at least 2
              with Context(DEBUG=max(2, DEBUG.value)):
                kern = beam_search(kern, bufs_from_lin(kern), BEAM.value)
            elif not NOOPT:
              kern.apply_opts(hand_coded_optimizations(kern))
            kernels.append(kern)

        with Timing("***** model prep in "):
          # (kernel, optimized ast, rewrite list) triples consumed by the rewrite stage
          prepared = [(kern, kern.get_optimized_ast(), get_rewrites_for_renderer(kern.opts, linearizer=False))
                      for kern in kernels]

        with Profiling(PROFILE, fn="/tmp/rewrite.prof"), Timing("***** model rewrite in "):
          rewritten_uops = []
          for i, (k, opt_ast, rewrites) in enumerate(prepared):
            with Timing(f"rewrite {i:2d} {k.name}{' '*(50-ansilen(k.name))}", enabled=getenv("VERBOSE", 0)):
              rewritten_uops.append(apply_rewrites(opt_ast, rewrites))

        if LINEARIZE:
          with Timing("***** model linearize in "):
            uops_line = [apply_rewrites(rewritten, rewrites_for_linearizer) for rewritten in rewritten_uops]
          # stays inside the LINEARIZE branch: uops_line only exists here
          print(sum(len(lin.arg.lst) for lin in uops_line))