|
|
|
@ -62,6 +62,7 @@ class ModelState: |
|
|
|
|
self.full_features_buffer = np.zeros((1, ModelConstants.FULL_HISTORY_BUFFER_LEN, ModelConstants.FEATURE_LEN), dtype=np.float32) |
|
|
|
|
self.full_desire = np.zeros((1, ModelConstants.FULL_HISTORY_BUFFER_LEN, ModelConstants.DESIRE_LEN), dtype=np.float32) |
|
|
|
|
self.full_prev_desired_curv = np.zeros((1, ModelConstants.FULL_HISTORY_BUFFER_LEN, ModelConstants.PREV_DESIRED_CURV_LEN), dtype=np.float32) |
|
|
|
|
self.temporal_idxs = slice(-1-(ModelConstants.TEMPORAL_SKIP*(ModelConstants.FULL_HISTORY_BUFFER_LEN_INPUT-1)), None, ModelConstants.TEMPORAL_SKIP) |
|
|
|
|
|
|
|
|
|
# policy inputs |
|
|
|
|
self.numpy_inputs = { |
|
|
|
@ -110,7 +111,7 @@ class ModelState: |
|
|
|
|
|
|
|
|
|
self.full_desire[0,:-1] = self.full_desire[0,1:] |
|
|
|
|
self.full_desire[0,-1] = new_desire |
|
|
|
|
self.numpy_inputs['desire'][:] = self.full_desire[0, -1-ModelConstants.FULL_HISTORY_BUFFER_LEN_INPUT*ModelConstants.TEMPORAL_SKIP::ModelConstants.TEMPORAL_SKIP] |
|
|
|
|
self.numpy_inputs['desire'][:] = self.full_desire[self.temporal_idxs] |
|
|
|
|
|
|
|
|
|
self.numpy_inputs['traffic_convention'][:] = inputs['traffic_convention'] |
|
|
|
|
self.numpy_inputs['lateral_control_params'][:] = inputs['lateral_control_params'] |
|
|
|
@ -135,7 +136,7 @@ class ModelState: |
|
|
|
|
|
|
|
|
|
self.full_features_buffer[0,:-1] = self.full_features_buffer[0,1:] |
|
|
|
|
self.full_features_buffer[0,-1] = vision_outputs_dict['hidden_state'][0, :] |
|
|
|
|
self.numpy_inputs['features_buffer'][:] = self.full_features_buffer[0, -1-ModelConstants.FULL_HISTORY_BUFFER_LEN_INPUT*ModelConstants.TEMPORAL_SKIP::ModelConstants.TEMPORAL_SKIP] |
|
|
|
|
self.numpy_inputs['features_buffer'][:] = self.full_features_buffer[self.temporal_idxs] |
|
|
|
|
|
|
|
|
|
self.policy_output = self.policy_run(**self.policy_inputs).numpy().flatten() |
|
|
|
|
policy_outputs_dict = self.parser.parse_policy_outputs(self.slice_outputs(self.policy_output, self.policy_output_slices)) |
|
|
|
@ -143,7 +144,7 @@ class ModelState: |
|
|
|
|
# TODO model only uses last value now |
|
|
|
|
self.full_prev_desired_curv[0,:-1] = self.full_prev_desired_curv[0,1:] |
|
|
|
|
self.full_prev_desired_curv[0,-1,:] = policy_outputs_dict['desired_curvature'][0, :] |
|
|
|
|
self.numpy_inputs['prev_desired_curv'][:] = self.full_prev_desired_curv[0, -1-ModelConstants.FULL_HISTORY_BUFFER_LEN_INPUT*ModelConstants.TEMPORAL_SKIP::ModelConstants.TEMPORAL_SKIP] |
|
|
|
|
self.numpy_inputs['prev_desired_curv'][:] = self.full_prev_desired_curv[self.temporal_idxs] |
|
|
|
|
|
|
|
|
|
combined_outputs_dict = {**vision_outputs_dict, **policy_outputs_dict} |
|
|
|
|
if SEND_RAW_PRED: |
|
|
|
|