You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
31 lines
1.1 KiB
31 lines
1.1 KiB
1 day ago
|
import pickle, sys
|
||
|
from dataclasses import replace
|
||
|
from tinygrad import Device, Context
|
||
|
from tinygrad.device import Buffer
|
||
|
from tinygrad.helpers import getenv, BEAM
|
||
|
from tinygrad.engine.jit import TinyJit
|
||
|
from tinygrad.engine.realize import CompiledRunner
|
||
|
from tinygrad.renderer import ProgramSpec
|
||
|
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
|
||
|
|
||
|
if __name__ == "__main__":
    # Replay the kernels captured in a pickled TinyJit on the "DSP" device.
    #
    # Usage: <script> <pickled TinyJit file>
    # Env:   KNUM=<n> replays only the n-th eligible kernel (1-based);
    #        KNUM=0 or unset replays all of them.
    if len(sys.argv) < 2:
        # Fail with a usage message rather than a bare IndexError traceback.
        raise SystemExit(f"usage: {sys.argv[0]} <pickled TinyJit file>")

    with Context(DEBUG=0):
        # NOTE: pickle.load can execute arbitrary code while unpickling;
        # only pass files that come from a trusted source.
        with open(sys.argv[1], "rb") as f:
            fxn: TinyJit = pickle.load(f)
            # f.tell() after the load is the number of bytes consumed.
            print(f"{f.tell()/1e6:.2f}M loaded")
        print(type(fxn))

    # NOTE(review): the replay loop is assumed to sit outside the DEBUG=0
    # context (so the caller's DEBUG setting applies to it) -- confirm.

    # Loop-invariant: read the kernel selector once instead of per iteration.
    only_knum = getenv("KNUM", 0)

    knum = 1
    for ei in fxn.captured.jit_cache:
        # Only items run by a CompiledRunner with every buffer present are
        # replayed (original note: "skip the copy and the first kernel").
        if not isinstance(ei.prg, CompiledRunner) or any(x is None for x in ei.bufs):
            continue

        if only_knum in (0, knum):
            p: ProgramSpec = ei.prg.p
            # Rebuild the kernel from its AST using the DSP renderer.
            k = Kernel(p.ast, Device["DSP"].renderer)
            # Fresh DSP buffers: each is allocated 8192 elements larger than
            # the captured one and viewed with an offset argument of 4096
            # (presumably padding on both sides of the data -- confirm).
            dsp_bufs = [
                Buffer("DSP", 8192+b.size, b.dtype).view(b.size, b.dtype, 4096)
                for b in ei.bufs
            ]
            k.hand_coded_optimizations()
            p2 = k.to_program()
            # Same exec item, but with the re-compiled program and DSP buffers.
            new_ei = replace(ei, prg=CompiledRunner(p2), bufs=dsp_bufs)
            new_ei.run()

        # Counts every eligible kernel, whether or not it was replayed.
        knum += 1
|