diff --git a/selfdrive/modeld/constants.py b/selfdrive/modeld/constants.py index e513922c72..e1741cbdcf 100644 --- a/selfdrive/modeld/constants.py +++ b/selfdrive/modeld/constants.py @@ -21,6 +21,7 @@ class ModelConstants: NAV_FEATURE_LEN = 256 NAV_INSTRUCTION_LEN = 150 DRIVING_STYLE_LEN = 12 + LAT_PLANNER_STATE_LEN = 4 LATERAL_CONTROL_PARAMS_LEN = 2 PREV_DESIRED_CURVS_LEN = 20 @@ -39,6 +40,7 @@ class ModelConstants: ROAD_EDGES_WIDTH = 2 PLAN_WIDTH = 15 DESIRE_PRED_WIDTH = 8 + LAT_PLANNER_SOLUTION_WIDTH = 4 DESIRED_CURV_WIDTH = 1 NUM_LANE_LINES = 4 diff --git a/selfdrive/modeld/parse_model_outputs.py b/selfdrive/modeld/parse_model_outputs.py index 07d5e0a921..01cba29d1c 100644 --- a/selfdrive/modeld/parse_model_outputs.py +++ b/selfdrive/modeld/parse_model_outputs.py @@ -93,7 +93,10 @@ class Parser: self.parse_mdn('wide_from_device_euler', outs, in_N=0, out_N=0, out_shape=(ModelConstants.WIDE_FROM_DEVICE_WIDTH,)) self.parse_mdn('lead', outs, in_N=ModelConstants.LEAD_MHP_N, out_N=ModelConstants.LEAD_MHP_SELECTION, out_shape=(ModelConstants.LEAD_TRAJ_LEN,ModelConstants.LEAD_WIDTH)) - self.parse_mdn('desired_curvature', outs, in_N=0, out_N=0, out_shape=(ModelConstants.DESIRED_CURV_WIDTH,)) + if 'lat_planner_solution' in outs: + self.parse_mdn('lat_planner_solution', outs, in_N=0, out_N=0, out_shape=(ModelConstants.IDX_N,ModelConstants.LAT_PLANNER_SOLUTION_WIDTH)) + if 'desired_curvature' in outs: + self.parse_mdn('desired_curvature', outs, in_N=0, out_N=0, out_shape=(ModelConstants.DESIRED_CURV_WIDTH,)) for k in ['lead_prob', 'lane_lines_prob', 'meta']: self.parse_binary_crossentropy(k, outs) self.parse_categorical_crossentropy('desire_state', outs, out_shape=(ModelConstants.DESIRE_PRED_WIDTH,))