diff --git a/selfdrive/modeld/constants.py b/selfdrive/modeld/constants.py index d9851b3435..05b319c278 100644 --- a/selfdrive/modeld/constants.py +++ b/selfdrive/modeld/constants.py @@ -29,6 +29,7 @@ NUM_LANE_LINES = 4 NUM_ROAD_EDGES = 2 LEAD_TRAJ_LEN = 6 +DESIRE_PRED_LEN = 4 PLAN_MHP_N = 5 LEAD_MHP_N = 2 diff --git a/selfdrive/modeld/parse_model_outputs.py b/selfdrive/modeld/parse_model_outputs.py index 43ac7f92cb..faeb12883c 100755 --- a/selfdrive/modeld/parse_model_outputs.py +++ b/selfdrive/modeld/parse_model_outputs.py @@ -58,23 +58,20 @@ def parse_mdn(name, outs, in_N=0, out_N=1, out_shape=None): final_shape = tuple([raw.shape[0],] + list(out_shape)) outs[name] = pred_mu_final.reshape(final_shape) outs[name + '_stds'] = pred_std_final.reshape(final_shape) - return def parse_binary_crossentropy(name, outs): if name not in outs: return raw = outs[name] outs[name] = sigmoid(raw) - return -def parse_categorical_crossentropy(name, outs, size=1): +def parse_categorical_crossentropy(name, outs, out_shape=None): if name not in outs: return raw = outs[name] - if size > 1: - raw = raw.reshape((raw.shape[0], size, -1)) + if out_shape is not None: + raw = raw.reshape((raw.shape[0],) + out_shape) outs[name] = softmax(raw, axis=-1) - return def parse_outputs(outs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: parse_mdn('plan', outs, in_N=PLAN_MHP_N, out_N=PLAN_MHP_SELECTION, out_shape=(IDX_N,PLAN_WIDTH)) @@ -87,7 +84,6 @@ def parse_outputs(outs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: parse_mdn('lead', outs, in_N=LEAD_MHP_N, out_N=LEAD_MHP_SELECTION, out_shape=(LEAD_TRAJ_LEN,LEAD_WIDTH)) for k in ['lead_prob', 'lane_lines_prob', 'meta']: parse_binary_crossentropy(k, outs) - for k in ['desire_pred', 'desire_state']: - parse_categorical_crossentropy(k, outs, size=DESIRE_PRED_WIDTH) - parse_categorical_crossentropy(k, outs, size=DESIRE_PRED_WIDTH) + parse_categorical_crossentropy('desire_state', outs, out_shape=(DESIRE_PRED_WIDTH,)) + parse_categorical_crossentropy('desire_pred', outs, out_shape=(DESIRE_PRED_LEN,DESIRE_PRED_WIDTH)) return outs