You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							145 lines
						
					
					
						
							4.7 KiB
						
					
					
				
			
		
		
	
	
							145 lines
						
					
					
						
							4.7 KiB
						
					
					
				| #!/usr/bin/env python3
 | |
| import os
 | |
| import struct
 | |
| import zipfile
 | |
| import numpy as np
 | |
| from tqdm import tqdm
 | |
| 
 | |
| from common.basedir import BASEDIR
 | |
| from selfdrive.modeld.thneed.lib import load_thneed, save_thneed
 | |
| 
 | |
| # this is junk code, but it doesn't have deps
 | |
def load_dlc_weights(fn):
  """Extract the float32 tensors of every layer stored in a DLC archive.

  The DLC file is opened as a zip archive and only its "model.params" member
  is read. Parsing is a dependency-free heuristic: an offset table starting at
  byte 0x20 locates each layer record, and each tensor inside a record is
  found by searching for a fixed 8-byte marker.

  Args:
    fn: path (or file-like object) accepted by zipfile.ZipFile.

  Returns:
    A list with one (None, [first, second]) tuple per table entry, ordered by
    descending position of the record in the params blob. The leading None
    sits where load_onnx_weights() puts a layer name. Each tensor is an
    np.float32 array; 4-D tensors are returned transposed with axes
    (3, 2, 0, 1).
    NOTE(review): the two tensors are presumably weights and bias - confirm.

  Raises:
    ValueError: if a tensor marker cannot be found in a layer record.
    AssertionError: if a record header does not carry the expected 4 / 2 fields.
  """
  # context manager so the archive's file handle is released once the blob is read
  with zipfile.ZipFile(fn, 'r') as archive:
    dlc_params = archive.read("model.params")

  tensor_marker = b"\x00\x00\x00\x09\x04\x00\x00\x00"

  def extract(rdat):
    # Return the first tensor found in rdat, located via the 8-byte marker.
    idx = rdat.find(tensor_marker)
    if idx < 0:
      # without this check, find()'s -1 would silently start parsing at offset 7
      raise ValueError("tensor marker not found in model.params layer data")
    rdat = rdat[idx+len(tensor_marker):]
    # element count, followed by that many float32 values
    count = struct.unpack_from("I", rdat, 0)[0]
    buf = np.frombuffer(rdat[4:4+count*4], dtype=np.float32)
    rdat = rdat[4+count*4:]
    # number of dimensions, followed by one uint32 per dimension
    dims = struct.unpack_from("I", rdat, 0)[0]
    buf = buf.reshape(struct.unpack_from("I"*dims, rdat, 4))
    if buf.ndim == 4:
      buf = np.transpose(buf, (3,2,0,1))
    return buf

  def parse(tdat):
    # The leading uint32 (+4 for the field itself) is where the search for the second tensor starts.
    second = struct.unpack_from("I", tdat, 0)[0] + 4
    return (None, [extract(tdat), extract(tdat[second:])])

  ptr = 0x20
  def r4():
    # Read the uint32 at ptr and advance ptr past it.
    nonlocal ptr
    ret = struct.unpack_from("I", dlc_params, ptr)[0]
    ptr += 4
    return ret

  ranges = []
  cnt = r4()
  for _ in range(cnt):
    # each table offset is relative to the position just after the offset field itself
    rel = r4()
    o = rel + ptr
    # the header is 0xC
    plen, is_4, is_2 = struct.unpack("III", dlc_params[o:o+0xC])
    assert is_4 == 4 and is_2 == 2, "unexpected layer record header in model.params"
    ranges.append((o+0xC, o+plen+0xC))
  ranges.sort(reverse=True)

  return [parse(dlc_params[s:e]) for s,e in ranges]
 | |
| 
 | |
| # this won't run on device without onnx
 | |
def load_onnx_weights(fn):
  """Collect the initializer tensors consumed by each node of an ONNX model.

  onnx is imported lazily so the rest of this module stays importable where
  the package is not installed.

  Args:
    fn: model path handed to onnx.load.

  Returns:
    A list of (node name, [numpy arrays]) tuples in graph order. Nodes whose
    inputs include no initializer are left out.
  """
  import onnx
  from onnx import numpy_helper

  graph = onnx.load(fn).graph  # pylint: disable=maybe-no-member
  initializers = {tensor.name: tensor for tensor in graph.initializer}

  layers = []
  for node in graph.node:
    # keep only the inputs that are backed by stored weights
    arrays = [numpy_helper.to_array(initializers[name]) for name in node.input if name in initializers]
    if arrays:
      layers.append((node.name, arrays))
  return layers
 | |
| 
 | |
def weights_fixup(target, source_thneed, dlc):
  """Rewrite a thneed file's float16 image weights from the DLC's float32 weights.

  Kernels in the thneed file are matched positionally against the layers read
  from the DLC. For every "image2d_t" argument the float32 weights are repacked
  into the image layout, checked against the stored values, rounded to float16
  and written back into the backing buffer. "float*" arguments are only
  verified, never modified. The result is saved with save_thneed.

  Args:
    target: output path passed to save_thneed.
    source_thneed: input path passed to load_thneed.
    dlc: DLC path passed to load_dlc_weights.

  Raises:
    AssertionError: if the layer / tensor counts differ, if stored weights are
      too far from the DLC weights, or if the rewrite would not reduce the
      rounding error.
  """
  # load_onnx_weights() returns the same (name, arrays) layout as load_dlc_weights()
  onnx_layers = load_dlc_weights(dlc)
  jdat = load_thneed(source_thneed)

  bufs = {o['id']: o for o in jdat['objects']}

  # collect, per kernel, the arguments whose data (or whose backing buffer's data) is loaded from the file
  thneed_layers = []
  for k in jdat['kernels']:
    vals = []
    for a in k['args']:
      if a in bufs:
        o = bufs[a]
        if o['needs_load'] or ('buffer_id' in o and bufs[o['buffer_id']]['needs_load']):
          vals.append(o)
    if vals:
      thneed_layers.append((k['name'], vals))

  assert len(thneed_layers) == len(onnx_layers)

  # fix up weights
  for tl, ol in tqdm(zip(thneed_layers, onnx_layers), total=len(thneed_layers)):
    assert len(tl[1]) == len(ol[1])
    for o, onnx_weight in zip(tl[1], ol[1]):
      if o['arg_type'] == "image2d_t":
        obuf = bufs[o['buffer_id']]
        # rows are row_pitch bytes wide, i.e. row_pitch//2 float16 values
        padded_shape = (o['height'], o['row_pitch']//2)
        saved_weights = np.frombuffer(obuf['data'], dtype=np.float16).reshape(padded_shape).astype(np.float32)

        if len(onnx_weight.shape) == 4:
          # convolution
          oc,ic,ch,cw = onnx_weight.shape

          if 'depthwise' in tl[0]:
            assert ic == 1
            weights = np.transpose(onnx_weight.reshape(oc//4,4,ch,cw), (0,2,3,1)).reshape(o['height'], o['width']*4)
          else:
            weights = np.transpose(onnx_weight.reshape(oc//4,4,ic//4,4,ch,cw), (0,4,2,5,1,3)).reshape(o['height'], o['width']*4)
        else:
          # fc_Wtx
          weights = onnx_weight

        # columns beyond the weights' width are row padding and stay zero
        new_weights = np.zeros(padded_shape, dtype=np.float32)
        new_weights[:, :weights.shape[1]] = weights

        # weights shouldn't be too far off
        diff = saved_weights - new_weights
        err = np.mean(diff**2)
        assert err < 1e-3
        rerr = np.mean(np.abs(diff/(new_weights+1e-12)))
        assert rerr < 0.5

        # fix should improve things
        new_weights_f16 = new_weights.astype(np.float16)
        fixed_err = np.mean((new_weights_f16.astype(np.float32) - new_weights)**2)
        # compared directly instead of via err/fixed_err: the ratio is NaN (and the
        # check fails) when both errors are exactly 0
        assert err >= fixed_err

        obuf['data'] = new_weights_f16.tobytes()

      elif o['arg_type'] == "float*":
        # unconverted floats are correct
        new_weights = np.zeros(o['size']//4, dtype=np.float32)
        new_weights[:onnx_weight.shape[0]] = onnx_weight
        assert new_weights.tobytes() == o['data']

  save_thneed(jdat, target)
 | |
| 
 | |
if __name__ == "__main__":
  # Script entry point: rebuild supercombo_fixed.thneed from supercombo.thneed
  # using the weights stored in supercombo.dlc, all resolved under BASEDIR.
  fixed_path = os.path.join(BASEDIR, "models/supercombo_fixed.thneed")
  thneed_path = os.path.join(BASEDIR, "models/supercombo.thneed")
  dlc_path = os.path.join(BASEDIR, "models/supercombo.dlc")
  weights_fixup(fixed_path, thneed_path, dlc_path)
 | |
| 
 |