|
|
|
@ -93,7 +93,10 @@ class Parser: |
|
|
|
|
self.parse_mdn('wide_from_device_euler', outs, in_N=0, out_N=0, out_shape=(ModelConstants.WIDE_FROM_DEVICE_WIDTH,)) |
|
|
|
|
self.parse_mdn('lead', outs, in_N=ModelConstants.LEAD_MHP_N, out_N=ModelConstants.LEAD_MHP_SELECTION, |
|
|
|
|
out_shape=(ModelConstants.LEAD_TRAJ_LEN,ModelConstants.LEAD_WIDTH)) |
|
|
|
|
self.parse_mdn('desired_curvature', outs, in_N=0, out_N=0, out_shape=(ModelConstants.DESIRED_CURV_WIDTH,)) |
|
|
|
|
if 'lat_planner_solution' in outs: |
|
|
|
|
self.parse_mdn('lat_planner_solution', outs, in_N=0, out_N=0, out_shape=(ModelConstants.IDX_N,ModelConstants.LAT_PLANNER_SOLUTION_WIDTH)) |
|
|
|
|
if 'desired_curvature' in outs: |
|
|
|
|
self.parse_mdn('desired_curvature', outs, in_N=0, out_N=0, out_shape=(ModelConstants.DESIRED_CURV_WIDTH,)) |
|
|
|
|
for k in ['lead_prob', 'lane_lines_prob', 'meta']: |
|
|
|
|
self.parse_binary_crossentropy(k, outs) |
|
|
|
|
self.parse_categorical_crossentropy('desire_state', outs, out_shape=(ModelConstants.DESIRE_PRED_WIDTH,)) |
|
|
|
|