You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
37 lines
1.4 KiB
37 lines
1.4 KiB
3 days ago
|
# https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html
|
||
|
import torch
|
||
|
import torch._dynamo
|
||
|
from extra.torch_backend.backend import unwrap, wrap
|
||
|
|
||
|
from torch._dynamo.backends.registry import register_backend
|
||
|
from torch._functorch.aot_autograd import aot_module_simplified
|
||
|
|
||
|
from tinygrad import Tensor, TinyJit
|
||
|
|
||
|
@register_backend
def tiny(gm: torch.fx.GraphModule, sample_inputs):
    """torch.compile backend that executes the captured graph under TinyJit.

    The incoming graph module is handed to ``aot_module_simplified`` with an
    empty decomposition table and ``my_compiler`` as the forward compiler;
    whatever that call returns is returned to the caller unchanged.
    """

    def my_compiler(gm: torch.fx.GraphModule, sample_inputs):
        # NOTE: `gm` here is the module passed to this callback; it shadows
        # the outer `gm` of the same name.

        # TODO: the jit should capture the graph directly, not need three runs. this is a planned tinygrad refactor after becomes_map
        @TinyJit
        def jitted_forward(*args: Tensor):
            # wrap()/unwrap() are project helpers; presumably they convert
            # between tinygrad Tensors and torch tensors -- confirm in
            # extra.torch_backend.backend.
            results = gm(*map(wrap, args))
            # Realize every output before returning from the jitted function.
            for result in results:
                unwrap(result).realize()
            return results

        # TODO: this should be able to pass in .tiny() Tensors, not need to convert them. it tries to access Storage if you pass in.
        def run_on_tiny(*args: torch.Tensor):
            converted = [unwrap(arg.tiny()) for arg in args]
            return jitted_forward(*converted)

        return run_on_tiny

    return aot_module_simplified(
        gm,
        sample_inputs,
        decompositions={},
        fw_compiler=my_compiler,
    )
|
||
|
|
||
|
if __name__ == "__main__":
    # Small smoke test: compile an elementwise function with the "tiny"
    # backend and call it a few times, printing the output device each time.
    def foo(x, y):
        """Return sin(x) + cos(y)."""
        return torch.sin(x) + torch.cos(y)

    print("calling compile")
    compiled_foo = torch.compile(foo, backend="tiny")
    print("compiled")

    for _ in range(5):
        # Fresh random inputs on every call.
        lhs = torch.randn(10, 10)
        rhs = torch.randn(10, 10)
        result = compiled_foo(lhs, rhs)
        print(result.device)
|