You cannot select more than 25 topics.
			Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
		
		
		
		
			
				
					62 lines
				
				1.8 KiB
			
		
		
			
		
	
	
					62 lines
				
				1.8 KiB
			| 
											1 week ago
										 | import unittest
 | ||
|  | from tinygrad.nn.state import get_parameters
 | ||
|  | from tinygrad.tensor import Tensor
 | ||
|  | from tinygrad.nn import Conv2d, BatchNorm2d, optim
 | ||
|  | 
 | ||
def model_step(lm, *, shape=(8, 12, 128, 256), lr=0.001):
  """Run one SGD training step on ``lm`` using a constant all-ones input.

  Args:
    lm: model object exposing ``forward(x)``; the tensors to optimize are
      collected from it with ``get_parameters``.
    shape: shape of the all-ones input tensor (keyword-only). The default's
      second dimension, 12, matches the input channels of the ``Conv2d``
      layers the tests in this file construct.
    lr: SGD learning rate (keyword-only).

  Returns:
    None. The tests in this file call it only to check that forward,
    backward and the optimizer step complete without raising.
  """
  with Tensor.train():
    x = Tensor.ones(*shape, requires_grad=False)
    optimizer = optim.SGD(get_parameters(lm), lr=lr)
    # scalar loss so backward() needs no explicit upstream gradient
    loss = lm.forward(x).sum()
    optimizer.zero_grad()
    loss.backward()
    # NOTE(review): the input and loss references are dropped before the
    # optimizer step, presumably so the step runs without them held alive
    # — TODO confirm intent before reordering anything here
    del x, loss
    optimizer.step()
 | ||
|  | 
 | ||
class TestBatchnorm(unittest.TestCase):
  """Smoke tests that push small conv / conv+batchnorm models through model_step.

  Nothing numeric is asserted: a case passes when one forward pass, backward
  pass and SGD update finish without raising. Each model's first Conv2d takes
  12 input channels, matching the all-ones input model_step builds.
  """

  def test_conv(self):
    # one bias-free 3x3 conv (12 -> 32 channels), then relu
    class LilModel:
      def __init__(self):
        self.c = Conv2d(12, 32, 3, padding=1, bias=False)
      def forward(self, x):
        out = self.c(x)
        return out.relu()
    model_step(LilModel())

  def test_two_conv(self):
    # two stacked bias-free 3x3 convs with a single relu at the very end
    class LilModel:
      def __init__(self):
        self.c = Conv2d(12, 32, 3, padding=1, bias=False)
        self.c2 = Conv2d(32, 32, 3, padding=1, bias=False)
      def forward(self, x):
        hidden = self.c(x)
        return self.c2(hidden).relu()
    model_step(LilModel())

  def test_two_conv_bn(self):
    # two conv -> batchnorm -> relu stages; batchnorm has running stats disabled
    class LilModel:
      def __init__(self):
        self.c = Conv2d(12, 24, 3, padding=1, bias=False)
        self.bn = BatchNorm2d(24, track_running_stats=False)
        self.c2 = Conv2d(24, 32, 3, padding=1, bias=False)
        self.bn2 = BatchNorm2d(32, track_running_stats=False)
      def forward(self, x):
        hidden = self.bn(self.c(x)).relu()
        return self.bn2(self.c2(hidden)).relu()
    model_step(LilModel())

  def test_conv_bn(self):
    # single conv -> batchnorm (running stats disabled) -> relu
    class LilModel:
      def __init__(self):
        self.c = Conv2d(12, 32, 3, padding=1, bias=False)
        self.bn = BatchNorm2d(32, track_running_stats=False)
      def forward(self, x):
        normed = self.bn(self.c(x))
        return normed.relu()
    model_step(LilModel())
 | ||
|  | 
 | ||
|  | 
 | ||
# Script entry point: running this file directly executes the TestCase classes above.
if __name__ == '__main__':
  unittest.main()
 |