|
|
@ -163,14 +163,14 @@ class ModelState: |
|
|
|
if prepare_only: |
|
|
|
if prepare_only: |
|
|
|
return None |
|
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
self.vision_output = self.vision_run(**self.vision_inputs).numpy().flatten() |
|
|
|
self.vision_output = self.vision_run(**self.vision_inputs).contiguous().realize().uop.base.buffer.numpy() |
|
|
|
vision_outputs_dict = self.parser.parse_vision_outputs(self.slice_outputs(self.vision_output, self.vision_output_slices)) |
|
|
|
vision_outputs_dict = self.parser.parse_vision_outputs(self.slice_outputs(self.vision_output, self.vision_output_slices)) |
|
|
|
|
|
|
|
|
|
|
|
self.full_features_buffer[0,:-1] = self.full_features_buffer[0,1:] |
|
|
|
self.full_features_buffer[0,:-1] = self.full_features_buffer[0,1:] |
|
|
|
self.full_features_buffer[0,-1] = vision_outputs_dict['hidden_state'][0, :] |
|
|
|
self.full_features_buffer[0,-1] = vision_outputs_dict['hidden_state'][0, :] |
|
|
|
self.numpy_inputs['features_buffer'][:] = self.full_features_buffer[0, self.temporal_idxs] |
|
|
|
self.numpy_inputs['features_buffer'][:] = self.full_features_buffer[0, self.temporal_idxs] |
|
|
|
|
|
|
|
|
|
|
|
self.policy_output = self.policy_run(**self.policy_inputs).numpy().flatten() |
|
|
|
self.policy_output = self.policy_run(**self.policy_inputs).contiguous().realize().uop.base.buffer.numpy() |
|
|
|
policy_outputs_dict = self.parser.parse_policy_outputs(self.slice_outputs(self.policy_output, self.policy_output_slices)) |
|
|
|
policy_outputs_dict = self.parser.parse_policy_outputs(self.slice_outputs(self.policy_output, self.policy_output_slices)) |
|
|
|
|
|
|
|
|
|
|
|
# TODO model only uses last value now |
|
|
|
# TODO model only uses last value now |
|
|
|