import unittest
from tinygrad import Tensor
from tinygrad.helpers import getenv, GlobalCounters
from tinygrad.engine.realize import lower_schedule_item, ProgramSpec
from tinygrad.renderer import Estimates
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.ops import Ops, UOp
from tinygrad.dtype import dtypes
from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
from tinygrad.device import Device
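
# helper: estimate ops and loads/stores (lds) directly from a linearized uop list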
def flops_mem(uops, ignore_indexing=False):
  est = Estimates.from_uops(uops, ignore_indexing)
  return est.ops, est.lds
# **************** new FlopCounter ****************

def get_stats(x:Tensor):
  # lower the last scheduled kernel for x and read its op/byte estimates
  si = x.schedule()[-1]
  ei = lower_schedule_item(si)
  return ei.prg.estimates.ops, ei.prg.estimates.mem
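
# The memory estimates below count bytes actually moved per buffer access:
# broadcasted/expanded inputs only contribute their real (unexpanded) size, and a
# buffer read twice in the same kernel (e.g. a+a) is only counted once.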
class TestMemoryCount(unittest.TestCase):
  def test_add(self):
    a = Tensor.empty(1024, 1024, dtype=dtypes.uint8)
    b = Tensor.empty(1024, 1024, dtype=dtypes.uint8)
    _, mem = get_stats(a+b)
    self.assertEqual(mem, 1024*1024*3) # 2 reads + 1 write

  def test_add_const(self):
    a = Tensor.empty(1024, 1024, dtype=dtypes.uint8)
    _, mem = get_stats(a+3)
    self.assertEqual(mem, 1024*1024*2) # 1 read + 1 write

  def test_add_slice(self):
    a = Tensor.empty(1024, 1024, dtype=dtypes.uint8)[:512]
    _, mem = get_stats(a+3)
    self.assertEqual(mem, 512*1024*2) # 1 read + 1 write

  def test_expanded(self):
    a = Tensor.empty(1024, 1, dtype=dtypes.uint8).expand(1024, 1024)
    b = Tensor.empty(1024, 1024, dtype=dtypes.uint8)
    _, mem = get_stats(a+b)
    self.assertEqual(mem, 1024*1024*2 + 1024) # 1 full read + 1 lil read + 1 write

  def test_both_expanded(self):
    # TODO: this probably should be a full write
    a = Tensor.empty(1024, 1, dtype=dtypes.uint8).expand(1024, 1024)
    b = Tensor.empty(1024, 1, dtype=dtypes.uint8).expand(1024, 1024)
    _, mem = get_stats(a+b)
    self.assertEqual(mem, 1024*1024 + 2*1024) # 2 lil reads + 1 write

  def test_self_add(self):
    a = Tensor.empty(1024, 1024, dtype=dtypes.uint8)
    _, mem = get_stats(a+a)
    self.assertEqual(mem, 1024*1024*2) # 1 read + 1 write

  def test_self_add_transposed(self):
    a = Tensor.empty(1024, 1024, dtype=dtypes.uint8)
    _, mem = get_stats(a+a.T)
    self.assertEqual(mem, 1024*1024*2) # 1 read + 1 write

  def test_self_add_assign(self):
    a = Tensor.empty(1024, 1024, dtype=dtypes.uint8).realize()
    _, mem = get_stats(a.assign(a+a))
    self.assertEqual(mem, 1024*1024*2) # 1 read + 1 write

  @unittest.skipIf(Device.DEFAULT == "CPU", "test copy to CPU from other device")
  def test_copyout(self):
    a = Tensor.empty(32, dtype=dtypes.uint8).to("CPU")
    _, mem = get_stats(a)
    self.assertEqual(mem, 32*1)
    a = Tensor.empty(32, dtype=dtypes.uint32).to("CPU")
    _, mem = get_stats(a)
    self.assertEqual(mem, 32*4)

# NOTE: this still isn't testing unroll using the acc
@unittest.skipUnless(getenv("PYTHON"), "only run test on emulated tensor cores")
class TestUOpsStatsMatmulHalf(unittest.TestCase):
  def test_simple_matmul_half(self, N=16):
    GlobalCounters.reset()
    a, b = Tensor.empty(N, N, dtype=dtypes.half), Tensor.empty(N, N, dtype=dtypes.half)
    c = a.matmul(b)
    c.realize()
    expected_ops = N ** 3 * 2 # N*N outputs, each a length-N dot product: one mul + one add per step
    self.assertEqual(expected_ops, GlobalCounters.global_ops)

  def test_bigger_matmul_half(self): self.test_simple_matmul_half(64)

  def test_batched_matmul_half(self, N=16):
    GlobalCounters.reset()
    a, b = Tensor.empty(4, N, N, dtype=dtypes.half), Tensor.empty(1, N, N, dtype=dtypes.half)
    c = a.matmul(b)
    c.realize()
    expected_ops = 4 * N ** 3 * 2
    self.assertEqual(expected_ops, GlobalCounters.global_ops)

class TestUOpsStats(unittest.TestCase):
  @unittest.skipIf(getenv("PTX"), "wrong in PTX")
  def test_simple_add(self):
    a = Tensor.empty(100,100)
    b = Tensor.empty(100,100)
    c = a+b
    ops, mem = get_stats(c)
    expected_ops = c.numel()
    expected_mem = a.nbytes() + b.nbytes() + c.nbytes()
    self.assertEqual(mem, expected_mem)
    # NOTE: ops also include indexing ops
    assert expected_ops <= ops and ops <= expected_ops * 2

  @unittest.skipIf(getenv("PTX"), "wrong in PTX")
  def test_simple_add_sq(self):
    a = Tensor.empty(100,100)
    b = Tensor.empty(100,100)
    c = (a+b)*(a+b)
    ops, mem = get_stats(c)
    expected_ops = c.numel()*2
    expected_mem = a.nbytes() + b.nbytes() + c.nbytes()
    self.assertEqual(mem, expected_mem)
    # NOTE: ops also include indexing ops
    assert expected_ops <= ops and ops <= expected_ops * 2

  def test_simple_matmul(self):
    a = Tensor.empty(1024,1024)
    b = Tensor.empty(1024,1024)
    c = a@b
    ops, mem = get_stats(c)
    expected_ops = c.numel() * 1024 * 2
    required_mem = a.nbytes() + b.nbytes() + c.nbytes()
    assert expected_ops <= ops and ops <= expected_ops * 1.2
    # NOTE: it's hard to assert on the memory here, all depends on caching
    assert required_mem <= mem

  # MULACC should have the same stats as MUL + ADD
  def test_mulacc(self):
    globl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
    o1 = UOp(Ops.CONST, dtypes.int, tuple(), 1)
    o2 = UOp(Ops.CONST, dtypes.int, tuple(), 2)
    u1 = UOp(Ops.LOAD, dtypes.int, (globl.index(o1),))
    u2 = UOp(Ops.LOAD, dtypes.int, (globl.index(o2),))
    u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3)
    u4 = UOp(Ops.MUL, dtypes.int, (u1,u2))
    u5 = UOp(Ops.ADD, dtypes.int, (u4,u3))
    uops = linearize_uop(u5.sink())

    globl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
    o1 = UOp(Ops.CONST, dtypes.int, tuple(), 1)
    o2 = UOp(Ops.CONST, dtypes.int, tuple(), 2)
    u1 = UOp(Ops.LOAD, dtypes.int, (globl.index(o1),))
    u2 = UOp(Ops.LOAD, dtypes.int, (globl.index(o2),))
    u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3)
    u4 = UOp(Ops.MULACC, dtypes.int, (u1,u2,u3))
    uops_fma = linearize_uop(u4.sink())

    self.assertEqual(flops_mem(uops), flops_mem(uops_fma))
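
# The tests below check how kernel optimizations (UPCAST, UNROLL, LOCAL, GROUP) change the
# estimated load/store bytes (lds) for the same GEMM, while the ops/mem estimates stay essentially unchanged.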
N = 100
@unittest.skipIf(getenv("PTX"), "wrong in PTX") # maybe?
class TestStatsOptimized(unittest.TestCase):
  @classmethod
  def setUpClass(cls):
    cls.ast_gemm = (Tensor.empty(N, N) @ Tensor.empty(N, N)).schedule()[-1].ast
    cls.ast_reduce = (Tensor.empty(N*N).sum()).schedule()[-1].ast

  def check_gemm(self, p:ProgramSpec, extra_flops=0):
    #p.uops.print()
    #print(p.src)
    print(p.name, p.estimates.ops, p.estimates.mem, p.estimates.lds)
    self.assertEqual(p.estimates.ops, 2*N*N*N + extra_flops) # N**3 mulaccs
    self.assertEqual(p.estimates.mem, 3*N*N*4) # 3 NxN mats with floats

  def test_gemm(self):
    p = Kernel(self.ast_gemm).to_program()
    self.check_gemm(p)
    # naive: every one of the N*N*N MACs loads one float from each input (2*4 bytes), plus a 4-byte store per output element
    self.assertEqual(p.estimates.lds, 2*N*N*N*4 + 4*N*N)

  # this is a good lesson about why UPCASTing is a good idea
  def test_gemm_one_upcasted(self):
    k = Kernel(self.ast_gemm)
    k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
    p = k.to_program()
    self.check_gemm(p)
    # upcasting one output axis by 4 shares one input's loads across 4 output elements, cutting that input's traffic 4x
    self.assertEqual(p.estimates.lds, N*N*N*4 + N*N*N*4//4 + 4*N*N)

  def test_gemm_upcasted(self):
    k = Kernel(self.ast_gemm)
    k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
    k.apply_opt(Opt(OptOps.UPCAST, 1, 4))
    k.apply_opt(Opt(OptOps.UNROLL, 0, 4))
    p = k.to_program()
    self.check_gemm(p)
    # with both output axes upcasted by 4, each input's loads are reused 4x
    self.assertEqual(p.estimates.lds, 2*N*N*N*4//4 + 4*N*N)

  def test_gemm_upcasted_locals(self):
    k = Kernel(self.ast_gemm)
    k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
    k.apply_opt(Opt(OptOps.UPCAST, 1, 4))
    try:
      k.apply_opt(Opt(OptOps.LOCAL, 0, 5))
      k.apply_opt(Opt(OptOps.LOCAL, 1, 5))
    except KernelOptError:
      raise unittest.SkipTest("no locals")
    p = k.to_program()
    self.check_gemm(p)
    self.assertEqual(p.estimates.lds, 2*N*N*N*4//4 + 4*N*N)

  def test_gemm_group(self):
    k = Kernel(self.ast_gemm)
    try:
      k.apply_opt(Opt(OptOps.GROUP, 0, 4))
    except KernelOptError:
      raise unittest.SkipTest("no locals")
    SZ = N*N*4
    p = k.to_program()
    # NOTE: these are sort of wrong. they aren't honoring the IF statement
    self.check_gemm(p, extra_flops=SZ*4)
    self.assertEqual(p.estimates.lds, 2*N*N*N*4 + SZ*4 + (SZ*4 + 4*N*N)*4)

  def test_reduce(self):
    k = Kernel(self.ast_reduce)
    p = k.to_program()
    print(p.name, p.estimates.ops, p.estimates.mem, p.estimates.lds)
    self.assertEqual(p.estimates.ops, N*N)
    self.assertEqual(p.estimates.mem, N*N*4 + 4) # read N*N floats, write one float

  def test_reduce_group(self):
    k = Kernel(self.ast_reduce)
    try:
      k.apply_opt(Opt(OptOps.GROUP, 0, 50))
    except KernelOptError:
      raise unittest.SkipTest("no locals")
    p = k.to_program()
    # NOTE: these are wrong, they don't respect the if statement
    print(p.name, p.estimates.ops, p.estimates.mem, p.estimates.lds)

if __name__ == '__main__':
  unittest.main(verbosity=2)