"""Benchmark a 3x3 conv2d on Apple Metal: PyTorch MPS vs tinygrad (with TinyJit)."""
import os
#os.environ["METAL"] = "1"
import time

import numpy as np
# Convolution benchmark configuration.
BS = 64        # batch size
CIN = 256      # input channels
COUT = 256     # output channels
HW = 32        # input height/width (square images)
K = 3          # kernel size (KxK)
PADDING = 0

# TODO: this is doing some trick, since with CIN=256 COUT=256 it's over 10.4 TFLOPS.
# are winograd convs less flops? it appears so if they are batched
# https://www.cse.ust.hk/~weiwa/papers/yan-ppopp20.pdf
# Nominal op count: 2 ops (mul+add) per kernel tap per output channel per input
# position. NOTE(review): with PADDING=0 the real output is (HW-K+1)^2, not
# HW^2, so this slightly overcounts — kept as in the original formula.
FLOPS = BS*K*K*CIN*HW*HW*COUT*2

# Random fp32 inputs: activations (BS,CIN,HW,HW) and conv weights (COUT,CIN,K,K).
# One Generator is enough; the original constructed a fresh default_rng() per array.
rng = np.random.default_rng()
nb = rng.standard_normal(size=(BS, CIN, HW, HW), dtype=np.float32)
nc = rng.standard_normal(size=(COUT, CIN, K, K), dtype=np.float32)
try:
  import time, torch, torch.mps
  # Move the numpy inputs to the Metal (MPS) device once, outside the timed loop.
  b = torch.from_numpy(nb).to('mps')
  c = torch.from_numpy(nc).to('mps')

  def torch_prog(b, c):
    """Run one conv2d on the MPS device and return its wall time in seconds.

    Includes torch.mps.synchronize() so the asynchronous GPU work is actually
    finished before the clock stops.
    """
    st = time.perf_counter()
    a = torch.nn.functional.conv2d(b, c, padding=PADDING)
    torch.mps.synchronize()
    return time.perf_counter() - st

  # Best of 20 runs to amortize warmup/compile noise.
  tm = min([torch_prog(b, c) for _ in range(20)])
  print(f"{tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS conv in torch")
except (ImportError, RuntimeError):
  # ImportError: torch not installed at all (the original only caught
  # RuntimeError, so a missing torch crashed instead of printing this).
  # RuntimeError: torch present but MPS backend unavailable.
  print("no torch metal conv")
from tinygrad.tensor import Tensor
from tinygrad.engine.jit import TinyJit
from tinygrad import Device

# Wrap the same numpy inputs as tinygrad Tensors.
b = Tensor(nb)
c = Tensor(nc)

# TODO: slowness without the JIT I suspect comes from a lack of a caching allocator
@TinyJit
def tiny_jit(b, c):
  """JIT-captured conv2d; realize() forces execution inside the timed region."""
  return b.conv2d(c, padding=PADDING).realize()

def tiny_prog(b, c):
  """Run one JIT conv2d and return its wall time in seconds, including a
  device synchronize so async kernel execution is counted."""
  st = time.perf_counter()
  a = tiny_jit(b, c)
  Device[a.device].synchronize()
  return time.perf_counter() - st

# Best of 5 runs (first call pays JIT capture cost, so min is the honest number).
tm = min([tiny_prog(b, c) for _ in range(5)])
print(f"{tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS conv in tinygrad")