|
|
|
@ -58,23 +58,20 @@ def parse_mdn(name, outs, in_N=0, out_N=1, out_shape=None): |
|
|
|
|
final_shape = tuple([raw.shape[0],] + list(out_shape)) |
|
|
|
|
outs[name] = pred_mu_final.reshape(final_shape) |
|
|
|
|
outs[name + '_stds'] = pred_std_final.reshape(final_shape) |
|
|
|
|
return |
|
|
|
|
|
|
|
|
|
def parse_binary_crossentropy(name, outs): |
|
|
|
|
if name not in outs: |
|
|
|
|
return |
|
|
|
|
raw = outs[name] |
|
|
|
|
outs[name] = sigmoid(raw) |
|
|
|
|
return |
|
|
|
|
|
|
|
|
|
def parse_categorical_crossentropy(name, outs, size=1): |
|
|
|
|
def parse_categorical_crossentropy(name, outs, out_shape=None): |
|
|
|
|
if name not in outs: |
|
|
|
|
return |
|
|
|
|
raw = outs[name] |
|
|
|
|
if size > 1: |
|
|
|
|
raw = raw.reshape((raw.shape[0], size, -1)) |
|
|
|
|
if out_shape is not None: |
|
|
|
|
raw = raw.reshape((raw.shape[0],) + out_shape) |
|
|
|
|
outs[name] = softmax(raw, axis=-1) |
|
|
|
|
return |
|
|
|
|
|
|
|
|
|
def parse_outputs(outs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: |
|
|
|
|
parse_mdn('plan', outs, in_N=PLAN_MHP_N, out_N=PLAN_MHP_SELECTION, out_shape=(IDX_N,PLAN_WIDTH)) |
|
|
|
@ -87,7 +84,6 @@ def parse_outputs(outs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: |
|
|
|
|
parse_mdn('lead', outs, in_N=LEAD_MHP_N, out_N=LEAD_MHP_SELECTION, out_shape=(LEAD_TRAJ_LEN,LEAD_WIDTH)) |
|
|
|
|
for k in ['lead_prob', 'lane_lines_prob', 'meta']: |
|
|
|
|
parse_binary_crossentropy(k, outs) |
|
|
|
|
for k in ['desire_pred', 'desire_state']: |
|
|
|
|
parse_categorical_crossentropy(k, outs, size=DESIRE_PRED_WIDTH) |
|
|
|
|
parse_categorical_crossentropy(k, outs, size=DESIRE_PRED_WIDTH) |
|
|
|
|
parse_categorical_crossentropy('desire_state', outs, out_shape=(DESIRE_PRED_WIDTH,)) |
|
|
|
|
parse_categorical_crossentropy('desire_pred', outs, out_shape=(DESIRE_PRED_LEN,DESIRE_PRED_WIDTH)) |
|
|
|
|
return outs |
|
|
|
|