|
|
@ -104,10 +104,9 @@ class Parser: |
|
|
|
self.parse_binary_crossentropy('lead_prob', outs) |
|
|
|
self.parse_binary_crossentropy('lead_prob', outs) |
|
|
|
lead_mhp = self.is_mhp(outs, 'lead', ModelConstants.LEAD_MHP_SELECTION * ModelConstants.LEAD_TRAJ_LEN * ModelConstants.LEAD_WIDTH) |
|
|
|
lead_mhp = self.is_mhp(outs, 'lead', ModelConstants.LEAD_MHP_SELECTION * ModelConstants.LEAD_TRAJ_LEN * ModelConstants.LEAD_WIDTH) |
|
|
|
lead_in_N, lead_out_N = (ModelConstants.LEAD_MHP_N, ModelConstants.LEAD_MHP_SELECTION) if lead_mhp else (0, 0) |
|
|
|
lead_in_N, lead_out_N = (ModelConstants.LEAD_MHP_N, ModelConstants.LEAD_MHP_SELECTION) if lead_mhp else (0, 0) |
|
|
|
self.parse_mdn( |
|
|
|
lead_out_shape = (ModelConstants.LEAD_TRAJ_LEN, ModelConstants.LEAD_WIDTH) if lead_mhp else \ |
|
|
|
'lead', outs, in_N=lead_in_N, out_N=lead_out_N, |
|
|
|
(ModelConstants.LEAD_MHP_SELECTION, ModelConstants.LEAD_TRAJ_LEN, ModelConstants.LEAD_WIDTH) |
|
|
|
out_shape=(ModelConstants.LEAD_MHP_SELECTION, ModelConstants.LEAD_TRAJ_LEN, ModelConstants.LEAD_WIDTH) |
|
|
|
self.parse_mdn('lead', outs, in_N=lead_in_N, out_N=lead_out_N, out_shape=lead_out_shape) |
|
|
|
) |
|
|
|
|
|
|
|
return outs |
|
|
|
return outs |
|
|
|
|
|
|
|
|
|
|
|
def parse_policy_outputs(self, outs: dict[str, np.ndarray]) -> dict[str, np.ndarray]: |
|
|
|
def parse_policy_outputs(self, outs: dict[str, np.ndarray]) -> dict[str, np.ndarray]: |
|
|
|