def set_env(self, env) -> None:
    """
    Attach an environment to this object.

    Args:
        env (Environment): Environment to store on ``self.env``.

    Returns:
        None.
    """
    self.env = env
def get_reward(self, action):
    """
    Build the reward and penalty tensors for a batch of agent moves.

    Both outputs start as float32 zero tensors with the same shape as
    ``action``, placed on ``self.env.device``. For the entries selected by
    ``td_state['is_last_step']``, the reward is overwritten with the agents'
    cumulative profit summed over its last dimension (kept as a size-1 dim so
    it broadcasts into the reward). Nothing is ever written to the penalty,
    so it is always all zeros.

    Args:
        action (torch.Tensor): Tensor with agent moves; only its shape is used.

    Returns:
        reward (torch.Tensor): Reward.
        penalty (torch.Tensor): Penalty (currently always zeros).
    """
    target_device = self.env.device
    state = self.env.td_state

    reward, penalty = (
        torch.zeros_like(action, dtype=torch.float, device=target_device)
        for _ in range(2)
    )

    # NOTE(review): no penalty is computed here; the last dim of cum_profit
    # is presumably the agent axis — confirm against the environment.
    last_step_mask = state['is_last_step']
    summed_profit = state['agents']['cum_profit'].sum(-1, keepdim=True)

    # Masked assignment (not torch.where) keeps reward float32 and relies on
    # the mask indexing the leading dims of both tensors.
    reward[last_step_mask] = summed_profit[last_step_mask]
    return reward, penalty