diff --git a/selfdrive/modeld/models/driving_policy.onnx b/selfdrive/modeld/models/driving_policy.onnx index 2fb716997d..2582a9f1f5 100644 --- a/selfdrive/modeld/models/driving_policy.onnx +++ b/selfdrive/modeld/models/driving_policy.onnx @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:1741cad23f6f451782b5db6182218749ee12072e393d57eac36d8d5c55d9358a -size 15583374 +oid sha256:cddefa5dbe21c60858f0d321d12d1ed4b126d05a7563f651bdf021674be63d64 +size 15701037 diff --git a/selfdrive/modeld/models/driving_vision.onnx b/selfdrive/modeld/models/driving_vision.onnx index fe636fb682..6cdca2f93c 100644 --- a/selfdrive/modeld/models/driving_vision.onnx +++ b/selfdrive/modeld/models/driving_vision.onnx @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:3d2bd82ba42341dba1bda5426e45c4c646db604c9ac422156eaa2b9ef26194f9 -size 46265993 +oid sha256:afee02876ed51f21883e731990a252fd57afda6a78700c41fbf22d8cb288b4b5 +size 46147818 diff --git a/selfdrive/modeld/parse_model_outputs.py b/selfdrive/modeld/parse_model_outputs.py index 783572d436..d5a5dc1bb4 100644 --- a/selfdrive/modeld/parse_model_outputs.py +++ b/selfdrive/modeld/parse_model_outputs.py @@ -90,22 +90,22 @@ class Parser: self.parse_mdn('road_transform', outs, in_N=0, out_N=0, out_shape=(ModelConstants.POSE_WIDTH,)) self.parse_mdn('lane_lines', outs, in_N=0, out_N=0, out_shape=(ModelConstants.NUM_LANE_LINES,ModelConstants.IDX_N,ModelConstants.LANE_LINES_WIDTH)) self.parse_mdn('road_edges', outs, in_N=0, out_N=0, out_shape=(ModelConstants.NUM_ROAD_EDGES,ModelConstants.IDX_N,ModelConstants.LANE_LINES_WIDTH)) - self.parse_mdn('lead', outs, in_N=ModelConstants.LEAD_MHP_N, out_N=ModelConstants.LEAD_MHP_SELECTION, - out_shape=(ModelConstants.LEAD_TRAJ_LEN,ModelConstants.LEAD_WIDTH)) - for k in ['lead_prob', 'lane_lines_prob']: - self.parse_binary_crossentropy(k, outs) + self.parse_binary_crossentropy('lane_lines_prob', outs) self.parse_categorical_crossentropy('desire_pred', outs, out_shape=(ModelConstants.DESIRE_PRED_LEN,ModelConstants.DESIRE_PRED_WIDTH)) self.parse_binary_crossentropy('meta', outs) return outs def parse_policy_outputs(self, outs: dict[str, np.ndarray]) -> dict[str, np.ndarray]: self.parse_mdn('plan', outs, in_N=ModelConstants.PLAN_MHP_N, out_N=ModelConstants.PLAN_MHP_SELECTION, - out_shape=(ModelConstants.IDX_N,ModelConstants.PLAN_WIDTH)) + out_shape=(ModelConstants.IDX_N,ModelConstants.PLAN_WIDTH)) if 'lat_planner_solution' in outs: self.parse_mdn('lat_planner_solution', outs, in_N=0, out_N=0, out_shape=(ModelConstants.IDX_N,ModelConstants.LAT_PLANNER_SOLUTION_WIDTH)) if 'desired_curvature' in outs: self.parse_mdn('desired_curvature', outs, in_N=0, out_N=0, out_shape=(ModelConstants.DESIRED_CURV_WIDTH,)) self.parse_categorical_crossentropy('desire_state', outs, out_shape=(ModelConstants.DESIRE_PRED_WIDTH,)) + self.parse_binary_crossentropy('lead_prob', outs) + self.parse_mdn('lead', outs, in_N=ModelConstants.LEAD_MHP_N, out_N=ModelConstants.LEAD_MHP_SELECTION, + out_shape=(ModelConstants.LEAD_TRAJ_LEN,ModelConstants.LEAD_WIDTH)) return outs def parse_outputs(self, outs: dict[str, np.ndarray]) -> dict[str, np.ndarray]: