|
|
|
@ -56,24 +56,16 @@ class ModelState: |
|
|
|
|
prev_desire: np.ndarray # for tracking the rising edge of the pulse |
|
|
|
|
|
|
|
|
|
def __init__(self, context: CLContext): |
|
|
|
|
self.frames = { |
|
|
|
|
'input_imgs': DrivingModelFrame(context, ModelConstants.TEMPORAL_SKIP), |
|
|
|
|
'big_input_imgs': DrivingModelFrame(context, ModelConstants.TEMPORAL_SKIP) |
|
|
|
|
} |
|
|
|
|
self.frames = {'input_imgs': DrivingModelFrame(context), 'big_input_imgs': DrivingModelFrame(context)} |
|
|
|
|
self.prev_desire = np.zeros(ModelConstants.DESIRE_LEN, dtype=np.float32) |
|
|
|
|
|
|
|
|
|
self.full_features_buffer = np.zeros((1, ModelConstants.FULL_HISTORY_BUFFER_LEN, ModelConstants.FEATURE_LEN), dtype=np.float32) |
|
|
|
|
self.full_desire = np.zeros((1, ModelConstants.FULL_HISTORY_BUFFER_LEN, ModelConstants.DESIRE_LEN), dtype=np.float32) |
|
|
|
|
self.full_prev_desired_curv = np.zeros((1, ModelConstants.FULL_HISTORY_BUFFER_LEN, ModelConstants.PREV_DESIRED_CURV_LEN), dtype=np.float32) |
|
|
|
|
self.temporal_idxs = slice(-1-(ModelConstants.TEMPORAL_SKIP*(ModelConstants.INPUT_HISTORY_BUFFER_LEN-1)), None, ModelConstants.TEMPORAL_SKIP) |
|
|
|
|
|
|
|
|
|
# policy inputs |
|
|
|
|
self.numpy_inputs = { |
|
|
|
|
'desire': np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.DESIRE_LEN), dtype=np.float32), |
|
|
|
|
'desire': np.zeros((1, ModelConstants.FULL_HISTORY_BUFFER_LEN, ModelConstants.DESIRE_LEN), dtype=np.float32), |
|
|
|
|
'traffic_convention': np.zeros((1, ModelConstants.TRAFFIC_CONVENTION_LEN), dtype=np.float32), |
|
|
|
|
'lateral_control_params': np.zeros((1, ModelConstants.LATERAL_CONTROL_PARAMS_LEN), dtype=np.float32), |
|
|
|
|
'prev_desired_curv': np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.PREV_DESIRED_CURV_LEN), dtype=np.float32), |
|
|
|
|
'features_buffer': np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.FEATURE_LEN), dtype=np.float32), |
|
|
|
|
'prev_desired_curv': np.zeros((1, ModelConstants.FULL_HISTORY_BUFFER_LEN, ModelConstants.PREV_DESIRED_CURV_LEN), dtype=np.float32), |
|
|
|
|
'features_buffer': np.zeros((1, ModelConstants.FULL_HISTORY_BUFFER_LEN, ModelConstants.FEATURE_LEN), dtype=np.float32), |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
with open(VISION_METADATA_PATH, 'rb') as f: |
|
|
|
@ -112,9 +104,8 @@ class ModelState: |
|
|
|
|
new_desire = np.where(inputs['desire'] - self.prev_desire > .99, inputs['desire'], 0) |
|
|
|
|
self.prev_desire[:] = inputs['desire'] |
|
|
|
|
|
|
|
|
|
self.full_desire[0,:-1] = self.full_desire[0,1:] |
|
|
|
|
self.full_desire[0,-1] = new_desire |
|
|
|
|
self.numpy_inputs['desire'][:] = self.full_desire.reshape((1,ModelConstants.INPUT_HISTORY_BUFFER_LEN,ModelConstants.TEMPORAL_SKIP,-1)).max(axis=2) |
|
|
|
|
self.numpy_inputs['desire'][0,:-1] = self.numpy_inputs['desire'][0,1:] |
|
|
|
|
self.numpy_inputs['desire'][0,-1] = new_desire |
|
|
|
|
|
|
|
|
|
self.numpy_inputs['traffic_convention'][:] = inputs['traffic_convention'] |
|
|
|
|
self.numpy_inputs['lateral_control_params'][:] = inputs['lateral_control_params'] |
|
|
|
@ -137,17 +128,15 @@ class ModelState: |
|
|
|
|
self.vision_output = self.vision_run(**self.vision_inputs).numpy().flatten() |
|
|
|
|
vision_outputs_dict = self.parser.parse_vision_outputs(self.slice_outputs(self.vision_output, self.vision_output_slices)) |
|
|
|
|
|
|
|
|
|
self.full_features_buffer[0,:-1] = self.full_features_buffer[0,1:] |
|
|
|
|
self.full_features_buffer[0,-1] = vision_outputs_dict['hidden_state'][0, :] |
|
|
|
|
self.numpy_inputs['features_buffer'][:] = self.full_features_buffer[0, self.temporal_idxs] |
|
|
|
|
self.numpy_inputs['features_buffer'][0,:-1] = self.numpy_inputs['features_buffer'][0,1:] |
|
|
|
|
self.numpy_inputs['features_buffer'][0,-1] = vision_outputs_dict['hidden_state'][0, :] |
|
|
|
|
|
|
|
|
|
self.policy_output = self.policy_run(**self.policy_inputs).numpy().flatten() |
|
|
|
|
policy_outputs_dict = self.parser.parse_policy_outputs(self.slice_outputs(self.policy_output, self.policy_output_slices)) |
|
|
|
|
|
|
|
|
|
# TODO model only uses last value now |
|
|
|
|
self.full_prev_desired_curv[0,:-1] = self.full_prev_desired_curv[0,1:] |
|
|
|
|
self.full_prev_desired_curv[0,-1,:] = policy_outputs_dict['desired_curvature'][0, :] |
|
|
|
|
self.numpy_inputs['prev_desired_curv'][:] = self.full_prev_desired_curv[0, self.temporal_idxs] |
|
|
|
|
self.numpy_inputs['prev_desired_curv'][0,:-1] = self.numpy_inputs['prev_desired_curv'][0,1:] |
|
|
|
|
self.numpy_inputs['prev_desired_curv'][0,-1,:] = policy_outputs_dict['desired_curvature'][0, :] |
|
|
|
|
|
|
|
|
|
combined_outputs_dict = {**vision_outputs_dict, **policy_outputs_dict} |
|
|
|
|
if SEND_RAW_PRED: |
|
|
|
|