|
|
|
@ -22,9 +22,10 @@ class Parser: |
|
|
|
|
self.ignore_missing = ignore_missing |
|
|
|
|
|
|
|
|
|
def check_missing(self, outs, name): |
|
|
|
|
if name not in outs and not self.ignore_missing: |
|
|
|
|
missing = name not in outs |
|
|
|
|
if missing and not self.ignore_missing: |
|
|
|
|
raise ValueError(f"Missing output {name}") |
|
|
|
|
return name not in outs |
|
|
|
|
return missing |
|
|
|
|
|
|
|
|
|
def parse_categorical_crossentropy(self, name, outs, out_shape=None): |
|
|
|
|
if self.check_missing(outs, name): |
|
|
|
@ -84,6 +85,13 @@ class Parser: |
|
|
|
|
outs[name] = pred_mu_final.reshape(final_shape) |
|
|
|
|
outs[name + '_stds'] = pred_std_final.reshape(final_shape) |
|
|
|
|
|
|
|
|
|
def is_mhp(self, outs, name, shape): |
|
|
|
|
if self.check_missing(outs, name): |
|
|
|
|
return False |
|
|
|
|
if outs[name].shape[1] == 2 * shape: |
|
|
|
|
return False |
|
|
|
|
return True |
|
|
|
|
|
|
|
|
|
def parse_vision_outputs(self, outs: dict[str, np.ndarray]) -> dict[str, np.ndarray]: |
|
|
|
|
self.parse_mdn('pose', outs, in_N=0, out_N=0, out_shape=(ModelConstants.POSE_WIDTH,)) |
|
|
|
|
self.parse_mdn('wide_from_device_euler', outs, in_N=0, out_N=0, out_shape=(ModelConstants.WIDE_FROM_DEVICE_WIDTH,)) |
|
|
|
@ -94,23 +102,18 @@ class Parser: |
|
|
|
|
self.parse_categorical_crossentropy('desire_pred', outs, out_shape=(ModelConstants.DESIRE_PRED_LEN,ModelConstants.DESIRE_PRED_WIDTH)) |
|
|
|
|
self.parse_binary_crossentropy('meta', outs) |
|
|
|
|
self.parse_binary_crossentropy('lead_prob', outs) |
|
|
|
|
if outs['lead'].shape[1] == 2 * ModelConstants.LEAD_MHP_SELECTION *ModelConstants.LEAD_TRAJ_LEN * ModelConstants.LEAD_WIDTH: |
|
|
|
|
self.parse_mdn('lead', outs, in_N=0, out_N=0, |
|
|
|
|
out_shape=(ModelConstants.LEAD_MHP_SELECTION, ModelConstants.LEAD_TRAJ_LEN,ModelConstants.LEAD_WIDTH)) |
|
|
|
|
else: |
|
|
|
|
self.parse_mdn('lead', outs, in_N=ModelConstants.LEAD_MHP_N, out_N=ModelConstants.LEAD_MHP_SELECTION, |
|
|
|
|
out_shape=(ModelConstants.LEAD_TRAJ_LEN,ModelConstants.LEAD_WIDTH)) |
|
|
|
|
lead_mhp = self.is_mhp(outs, 'lead', ModelConstants.LEAD_MHP_SELECTION * ModelConstants.LEAD_TRAJ_LEN * ModelConstants.LEAD_WIDTH) |
|
|
|
|
lead_in_N, lead_out_N = (ModelConstants.LEAD_MHP_N, ModelConstants.LEAD_MHP_SELECTION) if lead_mhp else (0, 0) |
|
|
|
|
self.parse_mdn( |
|
|
|
|
'lead', outs, in_N=lead_in_N, out_N=lead_out_N, |
|
|
|
|
out_shape=(ModelConstants.LEAD_MHP_SELECTION, ModelConstants.LEAD_TRAJ_LEN, ModelConstants.LEAD_WIDTH) |
|
|
|
|
) |
|
|
|
|
return outs |
|
|
|
|
|
|
|
|
|
def parse_policy_outputs(self, outs: dict[str, np.ndarray]) -> dict[str, np.ndarray]: |
|
|
|
|
if outs['plan'].shape[1] == 2 * ModelConstants.IDX_N * ModelConstants.PLAN_WIDTH: |
|
|
|
|
self.parse_mdn('plan', outs, in_N=0, out_N=0, |
|
|
|
|
out_shape=(ModelConstants.IDX_N,ModelConstants.PLAN_WIDTH)) |
|
|
|
|
else: |
|
|
|
|
self.parse_mdn('plan', outs, in_N=ModelConstants.PLAN_MHP_N, out_N=ModelConstants.PLAN_MHP_SELECTION, |
|
|
|
|
out_shape=(ModelConstants.IDX_N,ModelConstants.PLAN_WIDTH)) |
|
|
|
|
if 'desired_curvature' in outs: |
|
|
|
|
self.parse_mdn('desired_curvature', outs, in_N=0, out_N=0, out_shape=(ModelConstants.DESIRED_CURV_WIDTH,)) |
|
|
|
|
plan_mhp = self.is_mhp(outs, 'plan', ModelConstants.IDX_N * ModelConstants.PLAN_WIDTH) |
|
|
|
|
plan_in_N, plan_out_N = (ModelConstants.PLAN_MHP_N, ModelConstants.PLAN_MHP_SELECTION) if plan_mhp else (0, 0) |
|
|
|
|
self.parse_mdn('plan', outs, in_N=plan_in_N, out_N=plan_out_N, out_shape=(ModelConstants.IDX_N,ModelConstants.PLAN_WIDTH)) |
|
|
|
|
self.parse_categorical_crossentropy('desire_state', outs, out_shape=(ModelConstants.DESIRE_PRED_WIDTH,)) |
|
|
|
|
return outs |
|
|
|
|
|
|
|
|
|