diff --git a/selfdrive/modeld/modeld.py b/selfdrive/modeld/modeld.py index 1e3d1782e6..ac88dfbb5f 100755 --- a/selfdrive/modeld/modeld.py +++ b/selfdrive/modeld/modeld.py @@ -86,10 +86,20 @@ class ModelState: prev_desire: np.ndarray # for tracking the rising edge of the pulse def __init__(self, context: CLContext): - self.frames = { - 'input_imgs': DrivingModelFrame(context, ModelConstants.TEMPORAL_SKIP), - 'big_input_imgs': DrivingModelFrame(context, ModelConstants.TEMPORAL_SKIP) - } + with open(VISION_METADATA_PATH, 'rb') as f: + vision_metadata = pickle.load(f) + self.vision_input_shapes = vision_metadata['input_shapes'] + self.vision_input_names = list(self.vision_input_shapes.keys()) + self.vision_output_slices = vision_metadata['output_slices'] + vision_output_size = vision_metadata['output_shapes']['outputs'][1] + + with open(POLICY_METADATA_PATH, 'rb') as f: + policy_metadata = pickle.load(f) + self.policy_input_shapes = policy_metadata['input_shapes'] + self.policy_output_slices = policy_metadata['output_slices'] + policy_output_size = policy_metadata['output_shapes']['outputs'][1] + + self.frames = {name: DrivingModelFrame(context, ModelConstants.TEMPORAL_SKIP) for name in self.vision_input_names} self.prev_desire = np.zeros(ModelConstants.DESIRE_LEN, dtype=np.float32) self.full_features_buffer = np.zeros((1, ModelConstants.FULL_HISTORY_BUFFER_LEN, ModelConstants.FEATURE_LEN), dtype=np.float32) @@ -106,18 +116,6 @@ class ModelState: 'features_buffer': np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.FEATURE_LEN), dtype=np.float32), } - with open(VISION_METADATA_PATH, 'rb') as f: - vision_metadata = pickle.load(f) - self.vision_input_shapes = vision_metadata['input_shapes'] - self.vision_output_slices = vision_metadata['output_slices'] - vision_output_size = vision_metadata['output_shapes']['outputs'][1] - - with open(POLICY_METADATA_PATH, 'rb') as f: - policy_metadata = pickle.load(f) - self.policy_input_shapes = policy_metadata['input_shapes'] - self.policy_output_slices = policy_metadata['output_slices'] - policy_output_size = policy_metadata['output_shapes']['outputs'][1] - # img buffers are managed in openCL transform code self.vision_inputs: dict[str, Tensor] = {} self.vision_output = np.zeros(vision_output_size, dtype=np.float32) @@ -135,7 +133,7 @@ class ModelState: parsed_model_outputs = {k: model_outputs[np.newaxis, v] for k,v in output_slices.items()} return parsed_model_outputs - def run(self, buf: VisionBuf, wbuf: VisionBuf, transform: np.ndarray, transform_wide: np.ndarray, + def run(self, bufs: dict[str, VisionBuf], transforms: dict[str, np.ndarray], inputs: dict[str, np.ndarray], prepare_only: bool) -> dict[str, np.ndarray] | None: # Model decides when action is completed, so desire input is just a pulse triggered on rising edge inputs['desire'][0] = 0 @@ -148,8 +146,7 @@ class ModelState: self.numpy_inputs['traffic_convention'][:] = inputs['traffic_convention'] self.numpy_inputs['lateral_control_params'][:] = inputs['lateral_control_params'] - imgs_cl = {'input_imgs': self.frames['input_imgs'].prepare(buf, transform.flatten()), - 'big_input_imgs': self.frames['big_input_imgs'].prepare(wbuf, transform_wide.flatten())} + imgs_cl = {name: self.frames[name].prepare(bufs[name], transforms[name].flatten()) for name in self.vision_input_names} if TICI and not USBGPU: # The imgs tensors are backed by opencl memory, only need init once @@ -328,14 +325,16 @@ def main(demo=False): if prepare_only: cloudlog.error(f"skipping model eval. Dropped {vipc_dropped_frames} frames") + bufs = {name: buf_extra if 'big' in name else buf_main for name in model.vision_input_names} + transforms = {name: model_transform_extra if 'big' in name else model_transform_main for name in model.vision_input_names} inputs:dict[str, np.ndarray] = { 'desire': vec_desire, 'traffic_convention': traffic_convention, 'lateral_control_params': lateral_control_params, - } + } mt1 = time.perf_counter() - model_output = model.run(buf_main, buf_extra, model_transform_main, model_transform_extra, inputs, prepare_only) + model_output = model.run(bufs, transforms, inputs, prepare_only) mt2 = time.perf_counter() model_execution_time = mt2 - mt1