model parser: fix lead mhp out shape (#36024)

* model parser: fix lead mhp out shape

* fix for real
pull/36027/head
YassineYousfi 3 days ago committed by GitHub
parent 560c503871
commit d097a0c201
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 9
      selfdrive/modeld/parse_model_outputs.py

@ -104,16 +104,15 @@ class Parser:
self.parse_binary_crossentropy('lead_prob', outs) self.parse_binary_crossentropy('lead_prob', outs)
lead_mhp = self.is_mhp(outs, 'lead', ModelConstants.LEAD_MHP_SELECTION * ModelConstants.LEAD_TRAJ_LEN * ModelConstants.LEAD_WIDTH) lead_mhp = self.is_mhp(outs, 'lead', ModelConstants.LEAD_MHP_SELECTION * ModelConstants.LEAD_TRAJ_LEN * ModelConstants.LEAD_WIDTH)
lead_in_N, lead_out_N = (ModelConstants.LEAD_MHP_N, ModelConstants.LEAD_MHP_SELECTION) if lead_mhp else (0, 0) lead_in_N, lead_out_N = (ModelConstants.LEAD_MHP_N, ModelConstants.LEAD_MHP_SELECTION) if lead_mhp else (0, 0)
self.parse_mdn( lead_out_shape = (ModelConstants.LEAD_TRAJ_LEN, ModelConstants.LEAD_WIDTH) if lead_mhp else \
'lead', outs, in_N=lead_in_N, out_N=lead_out_N, (ModelConstants.LEAD_MHP_SELECTION, ModelConstants.LEAD_TRAJ_LEN, ModelConstants.LEAD_WIDTH)
out_shape=(ModelConstants.LEAD_MHP_SELECTION, ModelConstants.LEAD_TRAJ_LEN, ModelConstants.LEAD_WIDTH) self.parse_mdn('lead', outs, in_N=lead_in_N, out_N=lead_out_N, out_shape=lead_out_shape)
)
return outs return outs
def parse_policy_outputs(self, outs: dict[str, np.ndarray]) -> dict[str, np.ndarray]: def parse_policy_outputs(self, outs: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
plan_mhp = self.is_mhp(outs, 'plan', ModelConstants.IDX_N * ModelConstants.PLAN_WIDTH) plan_mhp = self.is_mhp(outs, 'plan', ModelConstants.IDX_N * ModelConstants.PLAN_WIDTH)
plan_in_N, plan_out_N = (ModelConstants.PLAN_MHP_N, ModelConstants.PLAN_MHP_SELECTION) if plan_mhp else (0, 0) plan_in_N, plan_out_N = (ModelConstants.PLAN_MHP_N, ModelConstants.PLAN_MHP_SELECTION) if plan_mhp else (0, 0)
self.parse_mdn('plan', outs, in_N=plan_in_N, out_N=plan_out_N, out_shape=(ModelConstants.IDX_N,ModelConstants.PLAN_WIDTH)) self.parse_mdn('plan', outs, in_N=plan_in_N, out_N=plan_out_N, out_shape=(ModelConstants.IDX_N, ModelConstants.PLAN_WIDTH))
self.parse_categorical_crossentropy('desire_state', outs, out_shape=(ModelConstants.DESIRE_PRED_WIDTH,)) self.parse_categorical_crossentropy('desire_state', outs, out_shape=(ModelConstants.DESIRE_PRED_WIDTH,))
return outs return outs

Loading…
Cancel
Save