You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
			
				
					49 lines
				
				1.8 KiB
			
		
		
			
		
	
	
					49 lines
				
				1.8 KiB
			| 
											7 days ago
										 | import unittest
 | ||
|  | from tinygrad import Tensor, Device, Variable
 | ||
|  | from examples.gpt2 import Transformer
 | ||
|  | from tinygrad.nn.state import get_state_dict
 | ||
|  | 
 | ||
class TestMethodCache(unittest.TestCase):
  """Check that a repeated computation can be realized with the compiler entry point removed.

  Each test realizes an expression once, replaces the default device's
  ``compiler.compile_cached`` with ``None``, and then realizes a second
  expression of the same structure. Calling ``None`` would raise, so the second
  realize can only succeed if ``compile_cached`` is never invoked for it
  (presumably because the already-built kernel is reused -- confirm against the
  runtime's method cache).
  """

  def setUp(self):
    """Save the default device's ``compile_cached`` so it can be restored after the test."""
    default_compiler = Device[Device.DEFAULT].compiler
    self.backup_compiler = default_compiler.compile_cached

  def tearDown(self):
    """Put back the ``compile_cached`` attribute saved in ``setUp``."""
    default_compiler = Device[Device.DEFAULT].compiler
    default_compiler.compile_cached = self.backup_compiler

  def _forbid_compilation(self):
    """Replace ``compile_cached`` on the default device's compiler with ``None``."""
    Device[Device.DEFAULT].compiler.compile_cached = None

  def test_simple_methodcache(self):
    """A single add on fresh tensors must realize after ``compile_cached`` is removed."""
    a, b, c, d = Tensor([1]), Tensor([2]), Tensor([3]), Tensor([4])
    first_sum = a+b
    first_sum.realize()
    self._forbid_compilation()
    second_sum = c+d
    second_sum.realize()

  def test_nested_methodcache(self):
    """A sum of two identical adds must realize after ``compile_cached`` is removed."""
    a = Tensor([1])
    b = Tensor([2])
    c = Tensor([3])
    d = Tensor([4])
    ((a+b)+(a+b)).realize()
    self._forbid_compilation()
    ((c+d)+(c+d)).realize()

  def test_nested_methodcache_swap(self):
    """Same as the nested case, but the two inner adds trade places in the second expression."""
    a = Tensor([1])
    b = Tensor([2])
    c = Tensor([3])
    d = Tensor([4])
    ((a+b)+(c+d)).realize()
    self._forbid_compilation()
    ((c+d)+(a+b)).realize()

  @unittest.skip("incorrect use of transformer")
  def test_small_transformer(self):
    """Run a tiny GPT-2 style Transformer with ``compile_cached`` removed for the last pass (currently skipped)."""
    tiny_config = {"dim": 16, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 10}
    model = Transformer(**tiny_config)

    # Overwrite every entry of the state dict with a realized empty tensor of matching shape and dtype.
    for weight in get_state_dict(model).values():
      placeholder = Tensor.empty(*weight.shape, dtype=weight.dtype).realize()
      weight.assign(placeholder)

    def run_three_positions():
      # One forward pass per start_pos value 0, 1, 2, each on a fresh input tensor and a freshly bound Variable.
      for pos in range(3):
        model(Tensor([[1,2,3,4]]), Variable("start_pos", 0, 10).bind(pos)).realize()

    # NOTE: you have to do this twice due to the k-v cache
    run_three_positions()
    run_three_positions()
    self._forbid_compilation()
    run_three_positions()
 | ||
|  | 
 | ||
# Run the test cases in this module when the file is executed directly as a script.
if __name__ == '__main__':
  unittest.main()
 | ||
|  | 
 | ||
|  | 
 | ||
|  | 
 |