You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
45 lines
1.8 KiB
45 lines
1.8 KiB
6 days ago
|
import unittest
|
||
|
from tinygrad import Tensor, dtypes, TinyJit, UOp
|
||
|
from tinygrad.apps.llm import apply_rope
|
||
|
|
||
|
# TODO: test_scheduler, but just in uint
|
||
|
class TestAttention(unittest.TestCase):
  """Tests for scaled dot-product attention scheduling and the `apply_rope` helper."""

  def test_half_qkv_buffers(self):
    """Schedule attention over half-precision q/k/v and check the middle kernels stay in half.

    The schedule is expected to contain exactly 5 items; every buffer of items 1..3
    (the ones this test treats as the softmax inputs) must have dtype half.
    """
    BS, seqlen, dim = 10, 4, 100
    q = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
    k = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
    v = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
    attn = q.scaled_dot_product_attention(k, v)
    sched = attn.schedule()
    # attention has 5 kernels now
    self.assertEqual(len(sched), 5)
    softmax_inputs = sched[1:4]
    for si in softmax_inputs:
      # Explicit failure instead of a bare `assert`: bare asserts are removed under `python -O`,
      # which would silently turn this check into a no-op. The message is only built on failure.
      if not all(b.dtype == dtypes.half for b in si.bufs):
        self.fail(f"non half {si.bufs=}")

  def test_apply_rope(self):
    """Check that `apply_rope` keeps shape and dtype, depends on its second argument, and rejects a last dim of 7."""
    x = Tensor.randn(1, 2, 4, 8, dtype=dtypes.float32)
    result = apply_rope(x, 0)
    self.assertEqual(result.shape, x.shape)
    self.assertEqual(result.dtype, x.dtype)
    # the same input with a different second argument (0 vs 5) must give a different output
    self.assertGreater((result - apply_rope(x, 5)).abs().max().item(), 1e-6)
    # a last dimension of 7 is expected to trip an assertion inside apply_rope
    with self.assertRaises(AssertionError):
      apply_rope(Tensor.randn(1, 1, 4, 7, dtype=dtypes.float32), 0)

  def test_apply_rope_jit_prune(self):
    """Compare the captured jit cache size of `apply_rope` under TinyJit with and without `prune=True`.

    After three calls with a bound "start_pos" variable, the pruned jit is expected to hold
    exactly one cache entry, and the unpruned one at least three (and strictly more than the pruned).
    """
    def rope_fn(x_in, pos): return apply_rope(x_in, pos)
    rope_noprune = TinyJit(rope_fn)
    rope_prune = TinyJit(rope_fn, prune=True)

    v_pos = UOp.variable("start_pos", 0, 100)
    # NOTE(review): presumably three calls are needed so that both jits have captured - confirm against TinyJit
    for _ in range(3):
      rope_noprune(Tensor.randn(1, 2, 4, 8, dtype=dtypes.float32), v_pos.bind(1))
      rope_prune(Tensor.randn(1, 2, 4, 8, dtype=dtypes.float32), v_pos.bind(1))
    noprune_size = len(rope_noprune.captured.jit_cache)
    prune_size = len(rope_prune.captured.jit_cache)

    self.assertGreater(noprune_size, prune_size)
    self.assertGreaterEqual(noprune_size, 3)
    self.assertEqual(prune_size, 1)
|
||
|
|
||
|
# Run the tests in this module when the file is executed directly as a script.
if __name__ == '__main__':
  unittest.main()
|