def set_env(self, env: AECEnv) -> None:
    """
    Attach an environment to this selector.

    This override adds nothing of its own; the work is done by the parent
    class's ``set_env``.

    Args:
        env(AECEnv): Environment the selector will choose agents from.

    Returns:
        None.
    """
    super().set_env(env)
def _next_agent(self):
    """
    Return the next agent: the lowest-index active agent of each batch element.

    Inactive agents get a score of +inf, so ``argmin`` can only land on an
    active agent. If a batch element has no active agent, every entry is +inf
    and ``argmin`` returns index 0 (its first minimum).

    Args:
        n/a.

    Returns:
        selected_agent(torch.Tensor): Selected agent index per batch element,
        with ``keepdim=True`` along dim 1. Presumably shape ``(batch, 1)``;
        this assumes ``self.env.batch_size`` is one-dimensional. TODO confirm.
    """
    # Create the index tensor directly on the env device. The previous
    # ``torch.arange(...).to(device)`` allocated on the CPU and paid a
    # host->device copy on every call, even though the values are identical.
    avail = (
        torch.arange(self.env.num_agents, dtype=torch.float, device=self.env.device)
        .unsqueeze(0)
        .repeat(*self.env.batch_size, 1)
    )
    # Mask inactive agents so they can never be the minimum.
    avail[~self.env.td_state['agents']['active_agents_mask']] = float('inf')
    selected_agent = avail.argmin(1, keepdim=True)
    return selected_agent
class RandomSelector(BaseSelector):
    """
    MDVRPTW random agent selector class.

    For each batch element, samples one agent among those flagged in the
    environment's ``active_agents_mask``.
    """

    def set_env(self, env: AECEnv):
        """
        Set environment.

        Args:
            env(AECEnv): Environment.

        Returns:
            None.
        """
        super().set_env(env)

    def _next_agent(self):
        """
        Return the next agent, sampled at random among the active agents.

        The mask is converted to float and used as ``torch.multinomial``
        weights. With a boolean mask, every active agent therefore gets weight
        1 and every inactive agent weight 0. A batch element with no active
        agent would have an all-zero weight row, which ``torch.multinomial``
        rejects with an error. Such rows instead fall back to agent 0, which
        is also what an argmin over an all-inf row returns.

        Args:
            n/a.

        Returns:
            selected_agent(torch.Tensor): Sampled agent index per batch
            element. Presumably shape ``(batch, 1)``, assuming a 2-D mask of
            shape ``(batch, num_agents)``. TODO confirm.
        """
        weights = self.env.td_state['agents']['active_agents_mask'].float()
        # Detect rows that torch.multinomial would reject (weights sum to zero).
        no_active = weights.sum(dim=-1, keepdim=True) <= 0
        # Fallback distribution that puts all mass on agent 0.
        fallback = torch.zeros_like(weights)
        fallback[..., 0] = 1.0
        # torch.where builds a new tensor, so the env's mask is never mutated.
        # Rows that have active agents keep exactly their original weights.
        weights = torch.where(no_active, fallback, weights)
        selected_agent = torch.multinomial(weights, 1).to(self.env.device)
        return selected_agent
class SmallestTimeAgentSelector(BaseSelector):
    """
    MDVRPTW smallest time agent selector class.

    For each batch element, chooses the active agent whose ``cur_time`` is
    the lowest.
    """

    def set_env(self, env: AECEnv):
        """
        Attach the environment to the selector (delegates to the base class).

        Args:
            env(AECEnv): Environment.

        Returns:
            None.
        """
        super().set_env(env)

    def _next_agent(self):
        """
        Pick the active agent with the earliest current time.

        Inactive agents are set to +inf so they can never be the minimum.
        Ties go to the lowest index, because ``argmin`` returns the first
        minimum it finds.

        Args:
            n/a.

        Returns:
            selected_agent(torch.Tensor): Index of the chosen agent per batch
            element, reduced along dim 1 with ``keepdim=True``.
        """
        agents = self.env.td_state['agents']
        inactive = ~agents['active_agents_mask']
        # Work on a copy so the environment's own time tensor is left untouched.
        times = agents['cur_time'].clone()
        times[inactive] = float('inf')
        return times.argmin(dim=1, keepdim=True)