* squash

* double

* proper merge

* better organization
capnpy
Harald Schäfer 3 days ago committed by GitHub
parent 10cc87b80b
commit cd087a561e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 10
      selfdrive/modeld/parse_model_outputs.py

@ -94,15 +94,21 @@ class Parser:
self.parse_categorical_crossentropy('desire_pred', outs, out_shape=(ModelConstants.DESIRE_PRED_LEN,ModelConstants.DESIRE_PRED_WIDTH)) self.parse_categorical_crossentropy('desire_pred', outs, out_shape=(ModelConstants.DESIRE_PRED_LEN,ModelConstants.DESIRE_PRED_WIDTH))
self.parse_binary_crossentropy('meta', outs) self.parse_binary_crossentropy('meta', outs)
self.parse_binary_crossentropy('lead_prob', outs) self.parse_binary_crossentropy('lead_prob', outs)
if outs['lead'].shape[1] == 2 * ModelConstants.LEAD_MHP_SELECTION *ModelConstants.LEAD_TRAJ_LEN * ModelConstants.LEAD_WIDTH:
self.parse_mdn('lead', outs, in_N=0, out_N=0,
out_shape=(ModelConstants.LEAD_MHP_SELECTION, ModelConstants.LEAD_TRAJ_LEN,ModelConstants.LEAD_WIDTH))
else:
self.parse_mdn('lead', outs, in_N=ModelConstants.LEAD_MHP_N, out_N=ModelConstants.LEAD_MHP_SELECTION, self.parse_mdn('lead', outs, in_N=ModelConstants.LEAD_MHP_N, out_N=ModelConstants.LEAD_MHP_SELECTION,
out_shape=(ModelConstants.LEAD_TRAJ_LEN,ModelConstants.LEAD_WIDTH)) out_shape=(ModelConstants.LEAD_TRAJ_LEN,ModelConstants.LEAD_WIDTH))
return outs return outs
def parse_policy_outputs(self, outs: dict[str, np.ndarray]) -> dict[str, np.ndarray]: def parse_policy_outputs(self, outs: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
if outs['plan'].shape[1] == 2 * ModelConstants.IDX_N * ModelConstants.PLAN_WIDTH:
self.parse_mdn('plan', outs, in_N=0, out_N=0,
out_shape=(ModelConstants.IDX_N,ModelConstants.PLAN_WIDTH))
else:
self.parse_mdn('plan', outs, in_N=ModelConstants.PLAN_MHP_N, out_N=ModelConstants.PLAN_MHP_SELECTION, self.parse_mdn('plan', outs, in_N=ModelConstants.PLAN_MHP_N, out_N=ModelConstants.PLAN_MHP_SELECTION,
out_shape=(ModelConstants.IDX_N,ModelConstants.PLAN_WIDTH)) out_shape=(ModelConstants.IDX_N,ModelConstants.PLAN_WIDTH))
if 'lat_planner_solution' in outs:
self.parse_mdn('lat_planner_solution', outs, in_N=0, out_N=0, out_shape=(ModelConstants.IDX_N,ModelConstants.LAT_PLANNER_SOLUTION_WIDTH))
if 'desired_curvature' in outs: if 'desired_curvature' in outs:
self.parse_mdn('desired_curvature', outs, in_N=0, out_N=0, out_shape=(ModelConstants.DESIRED_CURV_WIDTH,)) self.parse_mdn('desired_curvature', outs, in_N=0, out_N=0, out_shape=(ModelConstants.DESIRED_CURV_WIDTH,))
self.parse_categorical_crossentropy('desire_state', outs, out_shape=(ModelConstants.DESIRE_PRED_WIDTH,)) self.parse_categorical_crossentropy('desire_state', outs, out_shape=(ModelConstants.DESIRE_PRED_WIDTH,))

Loading…
Cancel
Save