Source code for maenvs4vrp.environments.gmtdvrp.env

import torch
from tensordict import TensorDict

from typing import Optional, Dict

from maenvs4vrp.core.env_generator_builder import InstanceBuilder
from maenvs4vrp.core.env_observation_builder import ObservationBuilder
from maenvs4vrp.core.env_agent_selector import BaseSelector
from maenvs4vrp.core.env_agent_reward import RewardFn
from maenvs4vrp.core.env import AECEnv

from maenvs4vrp.utils.utils import gather_by_index

# Large time constant. It is not referenced by the code visible in this module;
# presumably an upper bound for time values used elsewhere -- TODO confirm.
MAX_TIME = 1_000_000

class Environment(AECEnv):
    """
    GMTDVRP environment generator class.

    The environment keeps its whole state in ``self.td_state`` (a TensorDict)
    and advances one agent at a time: ``reset`` builds the state from an
    instance, ``step`` applies the current agent's action and selects the next
    agent, and ``observe`` returns the observations and masks for that agent.
    """

    def __init__(
        self,
        instance_generator_object: InstanceBuilder,
        obs_builder_object: ObservationBuilder,
        agent_selector_object: BaseSelector,
        reward_evaluator: RewardFn,
        seed=None,
        device: Optional[str] = None,
        batch_size: torch.Size = None
    ):
        """
        Constructor.

        Args:
            instance_generator_object(InstanceBuilder): Generator instance.
            obs_builder_object(ObservationBuilder): Observations instance.
            agent_selector_object(BaseSelector): Agent selector instance.
            reward_evaluator(RewardFn): Reward evaluator instance.
            seed(int): Random number generator seed. If None, ``self.DEFAULT_SEED`` is used. Defaults to None.
            device(str, optional): Device to place tensors on. If None, the generator's device is used;
                otherwise the generator's device is overwritten with this value. Defaults to None.
            batch_size(torch.Size): Batch size (an int is also accepted). If None, the generator's batch
                size is used; otherwise the generator's batch size is overwritten. Defaults to None.
        """
        self.version = 'v0'
        self.env_name = 'gmtdvrp'

        # seed the environment
        if seed is None:
            self._set_seed(self.DEFAULT_SEED)
        else:
            self._set_seed(seed)

        self.agent_selector = agent_selector_object
        self.inst_generator = instance_generator_object
        self.inst_generator._set_seed(self.seed)
        self.obs_builder = obs_builder_object
        self.obs_builder.set_env(self)
        self.reward_evaluator = reward_evaluator
        self.reward_evaluator.set_env(self)
        self.env_nsteps = 0

        if device is None:
            self.device = self.inst_generator.device
        else:
            self.device = device
            self.inst_generator.device = device

        if batch_size is None:
            self.batch_size = self.inst_generator.batch_size
        else:
            batch_size = [batch_size] if isinstance(batch_size, int) else batch_size
            self.batch_size = torch.Size(batch_size)
            self.inst_generator.batch_size = torch.Size(batch_size)

        # Environment TensorDict (empty until ``reset`` is called)
        self.td_state = TensorDict({}, batch_size=self.batch_size, device=self.device)

    def observe(self, is_reset=False) -> TensorDict:
        """
        Retrieve agent environment observations.

        The action feasibility mask of the current agent is refreshed before
        the observations are built.

        Args:
            is_reset(bool): If the environment is on reset. Defaults to False.

        Returns:
            td_observations(TensorDict): Current agent observations and masks dictionary.
        """
        self._update_feasibility()
        td_observations = self.obs_builder.get_observations(is_reset=is_reset)
        td_observations['action_mask'] = self.td_state['cur_agent']['action_mask'].clone()
        td_observations['agents_mask'] = self.td_state['agents']['active_agents_mask'].clone()
        return td_observations

    def sample_initial_load(self, td: TensorDict):
        """
        Sample random initial loads for agents.

        Args:
            td(TensorDict): Environment instance tensor.

        Returns:
            td(TensorDict): Environment instance tensor with updated initial load.

        Raises:
            AssertionError: If called after the first environment step.
        """
        assert self.env_nsteps == 0, "Initial load must be done at step = 0"
        # Random number between 0 and capacity. The sample is created on the environment
        # device so the product with the capacity tensor does not mix devices.
        initial_load = torch.rand(*self.batch_size, self.num_agents, device=self.device) * self.td_state['capacity']
        td['initial_load'] = initial_load
        return td

    def set_initial_load(self, td: TensorDict):
        """
        Set initial loads for agents. Initial loads are filled with td['initial_load'].

        Args:
            td(TensorDict): Environment instance tensor.

        Returns:
            td(TensorDict): The input tensor dictionary, returned unchanged.

        Raises:
            AssertionError: If called after the first environment step.
        """
        assert self.env_nsteps == 0, "Initial load must be done at step = 0"
        self.td_state['agents']['cur_linehaul_load'] = td['initial_load']
        # Keep the current agent's cached load in sync with the per-agent table.
        self.td_state['cur_agent']['cur_linehaul_load'] = self.td_state['agents']['cur_linehaul_load'].gather(1, self.td_state['cur_agent_idx'])
        return td

    def sample_action(
        self,
        td: TensorDict
    ) -> TensorDict:
        """
        Compute a random action from available actions to current agent.

        Args:
            td(TensorDict): Environment instance tensor.

        Returns:
            td(TensorDict): Environment instance tensor with updated action.
        """
        # Feasible actions get weight 1 and infeasible ones weight 0.
        action = torch.multinomial(self.td_state['cur_agent']["action_mask"].float(), 1).to(self.device)
        td['action'] = action
        return td

    def reset(
        self,
        num_depots: int = None,
        num_agents: int = None,
        num_nodes: int = None,
        min_coords: float = None,
        max_coords: float = None,
        capacity: int = None,
        service_time: float = None,
        instance_name:str|None=None,
        min_demands: int = None,
        max_demands: int = None,
        min_backhaul: int = None,
        max_backhaul: int = None,
        max_time: float = None,
        backhaul_ratio: float = None,
        backhaul_class: int = None,
        sample_backhaul_class: bool = None,
        max_distance_limit: float = None,
        speed: float = None,
        initial_load: float = None,
        subsample: bool = True,
        variant_preset: str = None,
        use_combinations: bool = False,
        instance_dict:Dict=None,
        force_visit: bool = False,
        batch_size: Optional[torch.Size] = None,
        n_augment: Optional[int] = 2,
        sample_type: str = 'random',
        seed: int = None,
        device: Optional[str] = "cpu"
    ):
        """
        Reset the environment.

        Unless ``instance_dict`` is given, a new instance is sampled from the
        instance generator with the arguments below.

        Args:
            num_depots(int): Total number of depots. Defaults to None.
            num_agents(int): Total number of agents. Defaults to None.
            num_nodes(int): Total number of nodes. Defaults to None.
            min_coords(float): Minimum number of coords. Defaults to None.
            max_coords(float): Maximum number of coords. Defaults to None.
            capacity(int): Vehicles' capacity. Defaults to None.
            service_time(float): Service time. Defaults to None.
            instance_name(str, optional): Instance name forwarded to the generator. Defaults to None.
            min_demands(int): Minimum number of demands. Defaults to None.
            max_demands(int): Maximum number of demands. Defaults to None.
            min_backhaul(int): Minimum number of backhauls. Defaults to None.
            max_backhaul(int): Maximum number of backhauls. Defaults to None.
            max_time(float): Maximum route time. Defaults to None.
            backhaul_ratio(float): Ratio of backhaul demands. Defaults to None.
            backhaul_class(int): Class of backhaul problem. If 1, it's unmixed, if 2, it's mixed. Defaults to None.
            sample_backhaul_class(bool): If backhaul class is sampled across batches. Defaults to None.
            max_distance_limit(float): Route distance limits. Defaults to None.
            speed(float): Vehicles' speed. Defaults to None.
            initial_load(float): Vehicles' initial load. Defaults to None.
            subsample(bool): If problem variants are to be sampled. Defaults to True.
            variant_preset(str): Variant preset to be sampled. Defaults to None.
            use_combinations(bool): It considers combinations for which sampling mask the instance is defined. Defaults to False.
            instance_dict(Dict): Ready-made instance information. If truthy it is used instead of sampling
                from the generator. Defaults to None.
            force_visit(bool): It forces the agent to visit all feasible nodes before going back to depot. Defaults to False.
            batch_size(torch.Size, optional): Batch size (an int is also accepted). Defaults to None.
            n_augment(int, optional): Number of augmentations. Defaults to 2.
            sample_type(str): Type of instance to sample. It can be "random", "augment" or "saved". Defaults to "random".
            seed(int): Random number generator seed. Defaults to None.
            device(str, optional): Device forwarded to the instance generator. Defaults to "cpu".

        Returns:
            TensorDict: Environment information dictionary.
        """
        if seed is not None:
            self._set_seed(seed)

        if batch_size is None:
            batch_size = self.batch_size
        else:
            batch_size = [batch_size] if isinstance(batch_size, int) else batch_size
            self.batch_size = torch.Size(batch_size)
            self.inst_generator.batch_size = torch.Size(batch_size)

        if force_visit is not None:
            self.force_visit = force_visit

        if instance_dict:
            instance_info = instance_dict
        else:
            instance_info = self.inst_generator.sample_instance(
                num_depots=num_depots,
                num_agents=num_agents,
                num_nodes=num_nodes,
                min_coords=min_coords,
                max_coords=max_coords,
                capacity=capacity,
                service_time=service_time,
                instance_name=instance_name,
                min_demands=min_demands,
                max_demands=max_demands,
                min_backhaul=min_backhaul,
                max_backhaul=max_backhaul,
                max_time=max_time,
                backhaul_ratio=backhaul_ratio,
                backhaul_class=backhaul_class,
                sample_backhaul_class=sample_backhaul_class,
                max_distance_limit=max_distance_limit,
                speed=speed,
                initial_load=initial_load,
                subsample=subsample,
                variant_preset=variant_preset,
                use_combinations=use_combinations,
                sample_type=sample_type,
                n_augment=n_augment,
                batch_size=batch_size,
                seed=seed,
                device=device
            )

        self.num_nodes = instance_info['num_nodes']
        self.num_depots = instance_info['num_depots']
        # 'num_agents' in the instance is the number of agents per depot.
        self.num_agents_depot = instance_info['num_agents']
        self.num_agents = instance_info['num_agents'] * self.num_depots

        if 'n_digits' in instance_info:
            self.n_digits = instance_info['n_digits']
        else:
            self.n_digits = None

        # Data from instance goes into env td_state
        self.td_state = instance_info['data'].to(self.device)

        self.td_state['done'] = torch.zeros(*batch_size, dtype=torch.bool)
        self.td_state['is_last_step'] = torch.zeros(*batch_size, dtype=torch.bool)
        self.td_state['depot_loc'] = self.td_state['coords'].gather(1, self.td_state['depot_idx'][:, :, None].expand(-1, -1, 2))

        # Planning horizon is read from the time window of node 0.
        self.td_state['start_time'] = self.td_state['tw_low'].gather(1, torch.zeros((*self.batch_size, 1), dtype=torch.int64, device=self.device)).squeeze(-1)
        self.td_state['end_time'] = self.td_state['tw_high'].gather(1, torch.zeros((*self.batch_size, 1), dtype=torch.int64, device=self.device)).squeeze(-1)
        self.td_state['max_tour_duration'] = self.td_state['end_time'] - self.td_state['start_time']

        self.td_state['initial_load'] = instance_info['data']['initial_load'].clone()

        cur_agent_idx = torch.zeros((*batch_size, 1), dtype=torch.int64, device=self.device)
        self.td_state['cur_agent_idx'] = cur_agent_idx

        # Per-agent state. Depot indices are tiled so that consecutive agents belong to consecutive depots.
        self.td_state['agents'] = TensorDict(
            source={
                'capacity': self.td_state['capacity'],
                'depot_idx': self.td_state['depot_idx'].repeat((1, self.num_agents_depot)),
                'cur_time': self.td_state['start_time'].unsqueeze(1).clone() * torch.ones((*batch_size, self.num_agents), dtype=torch.float, device=self.device),
                'cur_node': self.td_state['depot_idx'].repeat((1, self.num_agents_depot)),
                'cur_ttime': torch.zeros((*batch_size, self.num_agents), dtype=torch.float, device=self.device),
                'cum_ttime': torch.zeros((*batch_size, self.num_agents), dtype=torch.float, device=self.device),
                'visited_nodes': torch.zeros((*batch_size, self.num_agents, self.num_nodes), dtype=torch.bool, device=self.device),
                'feasible_nodes': torch.ones((*batch_size, self.num_agents, self.num_nodes), dtype=torch.bool, device=self.device),
                'active_agents_mask': torch.ones((*batch_size, self.num_agents), dtype=torch.bool, device=self.device),
                'cur_step': torch.zeros((*batch_size, self.num_agents), dtype=torch.int32, device=self.device),
                'route_length': torch.zeros((*batch_size, self.num_agents), dtype=torch.float, device=self.device),
                'cur_linehaul_load': torch.ones((*self.batch_size, self.num_agents), dtype=torch.float32, device=self.device) * self.td_state['initial_load'],
                'cur_backhaul_load': torch.zeros((*self.batch_size, self.num_agents), dtype=torch.float32, device=self.device),
            },
            batch_size=batch_size,
            device=self.device)

        # Cached view of the currently selected agent (agent 0 after reset).
        self.td_state['cur_agent'] = TensorDict({
            'action_mask': self.td_state['agents']['feasible_nodes'].gather(1, cur_agent_idx[:, :, None].expand(-1, -1, self.num_nodes)).squeeze(1),
            'depot_idx': self.td_state['agents']['depot_idx'].gather(1, self.td_state['cur_agent_idx']).clone(),
            'cur_time': self.td_state['agents']['cur_time'].gather(1, cur_agent_idx).clone(),
            'cur_node': self.td_state['agents']['cur_node'].gather(1, cur_agent_idx).clone(),
            'cur_ttime': self.td_state['agents']['cur_ttime'].gather(1, cur_agent_idx).clone(),
            'cum_ttime': self.td_state['agents']['cum_ttime'].gather(1, cur_agent_idx).clone(),
            'cur_route_length': self.td_state['agents']['route_length'].gather(1, cur_agent_idx).clone(),
            'cur_step': self.td_state['agents']['cur_step'].gather(1, cur_agent_idx).clone(),
            'cur_linehaul_load': self.td_state['agents']['cur_linehaul_load'].gather(1, cur_agent_idx).clone(),
            'cur_backhaul_load': self.td_state['agents']['cur_backhaul_load'].gather(1, cur_agent_idx).clone(),
        }, batch_size=batch_size)

        self.td_state['cur_node_idx'] = self.td_state['cur_agent']['depot_idx'].clone()

        # Distance and travel time from every depot to every node.
        distance2depot = torch.cdist(self.td_state['depot_loc'], self.td_state['coords'], p=2, compute_mode="donot_use_mm_for_euclid_dist")
        self.td_state['speed'] = instance_info['data']['speed'].clone()
        time2depot = distance2depot / self.td_state['speed'].unsqueeze(-1)

        # Truncate to the instance precision. This must happen after both tensors exist:
        # doing it earlier read the variables before assignment and raised UnboundLocalError.
        if self.n_digits is not None:
            distance2depot = torch.floor(self.n_digits * distance2depot) / self.n_digits
            time2depot = torch.floor(self.n_digits * time2depot) / self.n_digits

        self.td_state['nodes'] = TensorDict(
            source={
                'linehaul_demands': self.td_state['linehaul_demands'].clone(),
                'backhaul_demands': self.td_state['backhaul_demands'].clone(),
                'distance2depot': distance2depot,
                'time2depot': time2depot,
                'active_nodes_mask': torch.ones((*batch_size, self.num_nodes), dtype=torch.bool, device=self.device),
            },
            batch_size=batch_size,
            device=self.device)

        self.td_state['backhaul_class'] = instance_info['data']['backhaul_class'].clone()

        self.td_state['solution'] = TensorDict({}, batch_size=batch_size)

        # Attach the environment to the agent selector, observation builder and reward evaluator
        self.agent_selector.set_env(self)
        self.obs_builder.set_env(self)
        self.reward_evaluator.set_env(self)

        agent_step = self.td_state['cur_agent']['cur_step']
        done = self.td_state['done'].clone()

        reward = torch.zeros_like(done, dtype=torch.float, device=self.device)
        penalty = torch.zeros_like(done, dtype=torch.float, device=self.device)

        # td observations
        td_observations = self.observe(is_reset=True)

        self.env_nsteps = 0

        return TensorDict(
            {
                "agent_step": agent_step,
                "observations": td_observations,
                "cur_agent_idx": self.td_state['cur_agent_idx'].clone(),
                "cur_node_idx": self.td_state['cur_node_idx'].clone(),
                "initial_load": self.td_state['initial_load'].clone(),
                "reward": reward,
                "penalty": penalty,
                "done": done,
            },
            batch_size=batch_size,
            device=self.device)

    def _update_feasibility(self):
        """
        Update actions feasibility.

        Recomputes the current agent's action mask from the time-window,
        return-to-depot, distance-limit and demand constraints, and writes it
        to both ``cur_agent.action_mask`` and the agent's row of
        ``agents.feasible_nodes``.

        Args:
            n/a.

        Returns:
            None.
        """
        # Active nodes. Agent can only visit node if it's active
        active_nodes = self.td_state['nodes']['active_nodes_mask'].clone()

        # Current agent location
        loc = self.td_state['coords'].gather(1, self.td_state['cur_agent']['cur_node'][:, :, None].expand(-1, -1, 2))
        # Agent current time
        ptime = self.td_state['cur_agent']['cur_time'].clone()

        # Distance between current agent and nodes
        distance2j = torch.pairwise_distance(loc, self.td_state["coords"], eps=0, keepdim=False)
        if self.n_digits is not None:
            distance2j = torch.floor(self.n_digits * distance2j) / self.n_digits

        # Select the row of the depot tables that belongs to the current agent's depot.
        time2depot = self.td_state['nodes']['time2depot'].gather(1, self.td_state['cur_agent']['depot_idx'].unsqueeze(-1).expand(-1, -1, self.num_nodes)).squeeze(1)
        distance2depot = self.td_state['nodes']['distance2depot'].gather(1, self.td_state['cur_agent']['depot_idx'].unsqueeze(-1).expand(-1, -1, self.num_nodes)).squeeze(1)

        time2arrive = distance2j / self.td_state['speed']
        # Arrival time. Current time + time 2 arrive (distance / speed)
        arrival_time = ptime + time2arrive

        # Constraint 1. Can arrive to node in time.
        c1 = arrival_time <= self.td_state['tw_high']
        # Constraint 2. If problem is closed, if agent can arrive to depot in time.
        # Multiplying by ~open_routes zeroes the left-hand side for open routes.
        c2 = (torch.max(arrival_time, self.td_state['tw_low']) + self.td_state['service_time'] + time2depot) * ~self.td_state['open_routes'] <= self.td_state['end_time'].unsqueeze(-1)
        # Constraint 3. Does agent exceed distance limit.
        c3 = self.td_state['cur_agent']['cur_route_length'] + distance2j + (distance2depot * ~self.td_state['open_routes']) <= self.td_state['distance_limits']

        # Demands constraints
        total_load = self.td_state['cur_agent']['cur_linehaul_load'] + self.td_state['cur_agent']['cur_backhaul_load']
        can_go_to_linehaul = self.td_state['cur_agent']['cur_linehaul_load'] - self.td_state['linehaul_demands'] >= 0
        can_go_to_backhaul = total_load + self.td_state['backhaul_demands'] <= self.td_state['capacity']

        # Backhaul class 1. Node either linehaul or backhaul. Linehauls before backhauls.
        linehaul_missing = ((self.td_state['linehaul_demands'] * active_nodes).sum(-1) > 0).unsqueeze(-1)
        # True when the node the agent currently stands on has a backhaul demand.
        is_carrying_backhaul = gather_by_index(src=self.td_state['backhaul_demands'], idx=self.td_state['cur_agent']['cur_node'], dim=1, squeeze=False) > 0
        meets_demand_constraint_backhaul_1 = (
            (linehaul_missing & can_go_to_linehaul & ~is_carrying_backhaul & (self.td_state['linehaul_demands'] > 0))
            | (can_go_to_backhaul & (self.td_state['backhaul_demands'] > 0))
        )

        # Backhaul class 2. Mixed linehauls and backhauls
        cannot_serve_linehaul = self.td_state['linehaul_demands'] > self.td_state['capacity'] - self.td_state['cur_agent']['cur_backhaul_load']
        meets_demand_constraint_backhaul_2 = can_go_to_linehaul & can_go_to_backhaul & ~cannot_serve_linehaul

        # Demand constraints according to backhaul class
        meet_demand_constraints = (
            ((self.td_state['backhaul_class'] == 1) & meets_demand_constraint_backhaul_1)
            | ((self.td_state['backhaul_class'] == 2) & meets_demand_constraint_backhaul_2)
        )

        _mask = active_nodes & c1 & c2 & c3 & meet_demand_constraints

        # after done close all services and open depot
        _mask = _mask * ~self.td_state['done'].unsqueeze(-1)
        # Close all depots but the one of the agent. Regardless of the above condition.
        _mask.scatter_(1, self.td_state['depot_idx'], 0)
        _mask.scatter_(1, self.td_state['cur_agent']['depot_idx'], True)

        if self.force_visit:
            # An agent standing at its depot may only stay there when no non-depot node is feasible.
            can_visit = ~((self.td_state['cur_agent']['cur_node'] == self.td_state['cur_agent']['depot_idx']).squeeze(-1) & (_mask[:, self.num_depots:].sum(-1) > 0))
            _mask.scatter_(1, self.td_state['cur_agent']['depot_idx'], can_visit.unsqueeze(-1))

        self.td_state['cur_agent'].update({'action_mask': _mask})
        self.td_state['agents']['feasible_nodes'].scatter_(1, self.td_state['cur_agent_idx'][:, :, None].expand(-1, -1, self.num_nodes), _mask.unsqueeze(1))

    def _update_done(
        self,
        action
    ):
        """
        Update done state.

        An agent becomes inactive when its action is its own depot; a batch
        entry is done when all of its agents are inactive.

        Args:
            action(torch.Tensor): Tensor with agent moves.

        Returns:
            None.
        """
        former_done = self.td_state['done'].clone()

        # update done agents
        self.td_state['agents']['active_agents_mask'].scatter_(1, self.td_state['cur_agent_idx'], ~action.eq(self.td_state['cur_agent']['depot_idx']))

        self.td_state['done'] = (~self.td_state['agents']['active_agents_mask']).all(dim=-1)
        # Once done, always done (the first agent is re-activated in _update_state for finished entries).
        self.td_state['done'][former_done] = True

        # update served nodes: the chosen node is closed unless it is the agent's depot
        self.td_state['nodes']['active_nodes_mask'].scatter_(1, action, action.eq(self.td_state['cur_agent']['depot_idx']))

        # True only for entries that became done on this step.
        self.td_state['is_last_step'] = self.td_state['done'].eq(~former_done)

    def _update_state(self, action):
        """
        Update environment state.

        Moves the current agent to ``action`` and updates its time, route
        length, travelled time, visited nodes, step counter and loads, both in
        the ``cur_agent`` cache and in the per-agent tables.

        Args:
            action(torch.Tensor): Tensor with agent moves.

        Returns:
            None.
        """
        loc = self.td_state['coords'].gather(1, self.td_state['cur_agent']['cur_node'][:, :, None].expand(-1, -1, 2))
        next_loc = self.td_state['coords'].gather(1, action[:, :, None].expand(-1, -1, 2))

        ptime = self.td_state['cur_agent']['cur_time'].clone()
        distance2j = torch.pairwise_distance(loc, next_loc, eps=0, keepdim=False)
        time2j = distance2j / self.td_state['speed']
        if self.n_digits is not None:
            time2j = torch.floor(self.n_digits * time2j) / self.n_digits

        tw = self.td_state['tw_low'].gather(1, action)
        service_time = self.td_state['service_time'].gather(1, action)

        arrivej = ptime + time2j
        waitj = torch.clip(tw - arrivej, min=0)
        time_update = arrivej + waitj + service_time

        # Update distances and time if problem is open and agent going back to depot
        is_open_and_getting_to_depot = (self.td_state['open_routes']) & (action.eq(self.td_state['cur_agent']['depot_idx']))
        distance2j[is_open_and_getting_to_depot] = 0.
        time2j[is_open_and_getting_to_depot] = 0.

        # update agent cur node
        self.td_state['cur_agent']['cur_node'] = action
        self.td_state['agents']['cur_node'].scatter_(1, self.td_state['cur_agent_idx'], self.td_state['cur_agent']['cur_node'])

        # update agent cur time
        self.td_state['cur_agent']['cur_time'] = time_update
        self.td_state['agents']['cur_time'].scatter_(1, self.td_state['cur_agent_idx'], self.td_state['cur_agent']['cur_time'])

        # Current route length
        self.td_state['cur_agent']['cur_route_length'] += distance2j
        self.td_state['agents']['route_length'].scatter_(1, self.td_state['cur_agent_idx'], self.td_state['cur_agent']['cur_route_length'])

        # update agent cum traveled time
        self.td_state['cur_agent']['cur_ttime'] = time2j
        self.td_state['cur_agent']['cum_ttime'] += time2j
        self.td_state['agents']['cur_ttime'].scatter_(1, self.td_state['cur_agent_idx'], self.td_state['cur_agent']['cur_ttime'])
        self.td_state['agents']['cum_ttime'].scatter_(1, self.td_state['cur_agent_idx'], self.td_state['cur_agent']['cum_ttime'])

        # Served node no longer has pending demand in the nodes table.
        self.td_state['nodes']['linehaul_demands'].scatter_(1, action, torch.zeros_like(action, dtype=torch.float))
        self.td_state['nodes']['backhaul_demands'].scatter_(1, action, torch.zeros_like(action, dtype=torch.float))

        # update visited nodes
        r = torch.arange(*self.td_state.batch_size, device=self.device)
        self.td_state['agents']['visited_nodes'][r, self.td_state['cur_agent_idx'].squeeze(-1), action.squeeze(-1)] = True

        # update agent step (only while the agent is still active)
        agents_done = ~self.td_state['agents']['active_agents_mask'].gather(1, self.td_state['cur_agent_idx']).clone()
        self.td_state['cur_agent']['cur_step'] = torch.where(~agents_done, self.td_state['cur_agent']['cur_step'] + 1, self.td_state['cur_agent']['cur_step'])
        self.td_state['agents']['cur_step'].scatter_(1, self.td_state['cur_agent_idx'], self.td_state['cur_agent']['cur_step'])

        # update used capacities (demands are read from the original instance tensors)
        selected_demand_linehaul = gather_by_index(src=self.td_state['linehaul_demands'], idx=self.td_state['cur_agent']['cur_node'], dim=1, squeeze=False)
        selected_demand_backhaul = gather_by_index(src=self.td_state['backhaul_demands'], idx=self.td_state['cur_agent']['cur_node'], dim=1, squeeze=False)

        cur_linehaul_load = (self.td_state['cur_agent']['cur_linehaul_load'] - selected_demand_linehaul)
        cur_backhaul_load = (self.td_state['cur_agent']['cur_backhaul_load'] + selected_demand_backhaul)

        self.td_state['cur_agent']['cur_linehaul_load'] = cur_linehaul_load
        self.td_state['cur_agent']['cur_backhaul_load'] = cur_backhaul_load

        self.td_state['agents']['cur_linehaul_load'].scatter_(1, self.td_state['cur_agent_idx'], self.td_state['cur_agent']['cur_linehaul_load'])
        self.td_state['agents']['cur_backhaul_load'].scatter_(1, self.td_state['cur_agent_idx'], self.td_state['cur_agent']['cur_backhaul_load'])

        # if all done activate first agent to guarantee batch consistency during agent sampling
        self.td_state['agents']['active_agents_mask'][self.td_state['agents']['active_agents_mask'].sum(1).eq(0), 0] = True

        self.td_state['cur_node_idx'] = action.clone()

    def _update_cur_agent(self, cur_agent_idx):
        """
        Update current agent.

        Rebuilds the ``cur_agent`` cache from the per-agent tables for the
        newly selected agent.

        Args:
            cur_agent_idx(torch.Tensor): Current agent id.

        Returns:
            None.
        """
        self.td_state['cur_agent_idx'] = cur_agent_idx
        self.td_state['cur_agent'] = TensorDict({
            'action_mask': self.td_state['agents']['feasible_nodes'].gather(1, self.td_state['cur_agent_idx'][:, :, None].expand(-1, -1, self.num_nodes)).squeeze(1).clone(),
            'depot_idx': self.td_state['agents']['depot_idx'].gather(1, self.td_state['cur_agent_idx']).clone(),
            'cur_agent_idx': cur_agent_idx,
            'cur_route_length': self.td_state['agents']['route_length'].gather(1, self.td_state['cur_agent_idx']).clone(),
            'cur_time': self.td_state['agents']['cur_time'].gather(1, self.td_state['cur_agent_idx']).clone(),
            'cur_node': self.td_state['agents']['cur_node'].gather(1, self.td_state['cur_agent_idx']).clone(),
            'cur_ttime': self.td_state['agents']['cur_ttime'].gather(1, self.td_state['cur_agent_idx']).clone(),
            'cum_ttime': self.td_state['agents']['cum_ttime'].gather(1, self.td_state['cur_agent_idx']).clone(),
            'cur_step': self.td_state['agents']['cur_step'].gather(1, self.td_state['cur_agent_idx']).clone(),
            'cur_linehaul_load': self.td_state['agents']['cur_linehaul_load'].gather(1, self.td_state['cur_agent_idx']).clone(),
            'cur_backhaul_load': self.td_state['agents']['cur_backhaul_load'].gather(1, self.td_state['cur_agent_idx']).clone(),
        }, batch_size=self.td_state.batch_size, device=self.device)

    def _update_solution(self, action):
        """
        Update agents and actions in solution.

        Appends the action and the acting agent's index as new columns of
        ``solution.actions`` and ``solution.agents``.

        Args:
            action(torch.Tensor): Tensor with agent moves.

        Returns:
            None.
        """
        # update solution dic
        if 'actions' in self.td_state['solution'].keys():
            self.td_state['solution', 'actions'] = torch.concat(
                [self.td_state['solution', 'actions'], action], dim=-1)
        else:
            self.td_state['solution', 'actions'] = action

        if 'agents' in self.td_state['solution'].keys():
            self.td_state['solution', 'agents'] = torch.concat(
                [self.td_state['solution', 'agents'], self.td_state['cur_agent_idx']], dim=-1)
        else:
            self.td_state['solution', 'agents'] = self.td_state['cur_agent_idx']

    def step(
        self,
        td: TensorDict
    ) -> TensorDict:
        """
        Perform an environment step for active agent.

        Args:
            td(TensorDict): Environment tensor instance. Must contain the key 'action'.

        Returns:
            td(TensorDict): Updated environment tensor instance.

        Raises:
            AssertionError: If any selected action is masked out for the current agent.
        """
        action = td['action']
        assert self.td_state['cur_agent']['action_mask'].gather(1, action).all(), f"not feasible action"

        self._update_done(action)
        # Snapshot the flags before the state update and agent switch.
        done = self.td_state['done'].clone()
        is_last_step = self.td_state['is_last_step'].clone()

        # update env state
        self._update_state(action)

        # update solution dic
        self._update_solution(action)

        # get reward and penalty
        reward, penalty = self.reward_evaluator.get_reward(action)

        # select and update cur agent
        cur_agent_idx = self.agent_selector._next_agent()
        self._update_cur_agent(cur_agent_idx)
        agent_step = self.td_state['cur_agent']['cur_step']

        # new observations
        td_observations = self.observe()

        self.env_nsteps += 1
        td.update(
            {
                "agent_step": agent_step,
                "observations": td_observations,
                "reward": reward,
                "penalty": penalty,
                "cur_agent_idx": cur_agent_idx,
                "cur_node_idx": self.td_state['cur_node_idx'].clone(),
                "done": done,
                "is_last_step": is_last_step
            },
        )
        return td

    def check_solution_validity(self):
        """
        Check if solution is valid according to constraints.

        Replays the stored solution agent by agent and asserts the distance
        limit, time window, single-visit, backhaul and capacity constraints.

        Args:
            N/a.

        Returns:
            None.

        Raises:
            AssertionError: If any constraint is violated.
        """
        for i in range(self.td_state['num_depots'][0].item()):
            distance2depot = torch.pairwise_distance(self.td_state['coords'], self.td_state['coords'][..., i:i+1, :], eps=0, keepdim=False)

            # Time 2 serve node and get back to depot
            a = self.td_state['tw_low'] + distance2depot + self.td_state['service_time']
            # Depot late tw
            b = self.td_state['time_windows'][..., 0, 1, None]

            # Can agent serve node and get back to depot?
            assert torch.all(a <= b), "Agent cannot serve node and get back to depot."

        # Actions cycle assert. Curr_node starts at 0 (depot) and iteratively keeps going onto the next.
        curr_node = torch.zeros(*self.batch_size, dtype=torch.int64, device=self.device)
        curr_time = torch.zeros(*self.batch_size, dtype=torch.float32, device=self.device)
        curr_length = torch.zeros(*self.batch_size, dtype=torch.float32, device=self.device)

        visited_nodes = torch.zeros(*self.batch_size, self.num_nodes, dtype=torch.int64, device=self.device)

        # Sort indices along each row (stable, so each agent's actions keep their order)
        sorted_indices = torch.argsort(self.td_state['solution']['agents'], dim=-1, stable=True)

        # Use gather to reorder data per row
        sorted_data = torch.gather(self.td_state['solution']['actions'], dim=-1, index=sorted_indices)

        for ii in range(sorted_data.size(1)):
            next_node = sorted_data[:, ii]
            curr_loc = gather_by_index(self.td_state['coords'], curr_node)
            next_loc = gather_by_index(self.td_state['coords'], next_node)
            dist = torch.pairwise_distance(curr_loc, next_loc, eps=0, keepdim=False)

            fill = visited_nodes.gather(1, next_node.unsqueeze(-1))
            visited_nodes.scatter_(1, next_node.unsqueeze(-1), fill + 1)

            # Update curr_length (the leg back to a depot is free on open routes)
            curr_length = curr_length + dist * ~(self.td_state['open_routes'].squeeze(-1) & (torch.isin(next_node, self.td_state['depot_idx'])))
            assert torch.all(curr_length <= self.td_state['distance_limits'].squeeze(-1)), "Route length exceeds distance limit."
            is_next_node_depot = torch.isin(next_node, self.td_state['depot_idx'])
            curr_length[is_next_node_depot] = 0.0  # Reset length for depot

            # Curr time either time to get to node or early tw
            curr_time = torch.max(curr_time + dist, gather_by_index(self.td_state['time_windows'], next_node)[..., 0])
            assert torch.all(curr_time <= gather_by_index(self.td_state['time_windows'], next_node)[..., 1]), "Agent must perform service before node's time window closes."
            curr_time = curr_time + gather_by_index(self.td_state['service_time'], next_node)

            # Clone so the in-place depot rotation below does not write through the
            # column view into sorted_data, which the demand checks read later.
            curr_node = next_node.clone()
            # After a depot return the replay continues from the following depot index.
            curr_node[is_next_node_depot] = (curr_node[is_next_node_depot] + 1) % self.num_depots
            curr_time[is_next_node_depot] = 0.0

        visited_nodes_exc_depot = visited_nodes[:, self.num_depots:]
        assert(torch.all((visited_nodes_exc_depot == 0) | (visited_nodes_exc_depot == 1))), "Nodes were visited more than once!"

        demand_l = self.td_state['linehaul_demands'].gather(1, sorted_data)
        demand_b = self.td_state['backhaul_demands'].gather(1, sorted_data)

        used_cap_l = torch.zeros_like(self.td_state['linehaul_demands'][:, 0])  # Starts at 0
        used_cap_b = torch.zeros_like(self.td_state['backhaul_demands'][:, 0])  # Starts at 0

        for ii in range(sorted_data.size(1)):
            # reset at depot
            used_cap_l = used_cap_l * (~torch.isin(sorted_data[:, ii], self.td_state['depot_idx']))
            used_cap_b = used_cap_b * (~torch.isin(sorted_data[:, ii], self.td_state['depot_idx']))
            used_cap_l += demand_l[:, ii]
            used_cap_b += demand_b[:, ii]

            # Backhaul class 1 (unmixed), agents cannot supply linehaul if carrying backhaul
            assert(
                (self.td_state['backhaul_class'] == 2)
                | (used_cap_b == 0)
                | ((self.td_state['backhaul_class'] == 1) & ~(demand_l[:, ii] > 0))
            ).all(), "Cannot pickup linehaul while carrying backhaul in unmixed problems."

            # Backhaul class 2 (mixed), agents cannot supply linehaul, if backhaul load + linehaul demand in node exceeds agent's capacity
            assert(
                (self.td_state['backhaul_class'] == 1)
                | (used_cap_b == 0)
                | ((self.td_state['backhaul_class'] == 2) & (used_cap_b + demand_l[:, ii] <= self.td_state['capacity']))
            ).all(), "Cannot supply linehaul, not enough load."

            # Loads must not exceed capacity
            assert(
                used_cap_l <= self.td_state['capacity']
            ).all(), "Used more linehaul than capacity: {}/{}".format(used_cap_l, self.td_state['capacity'])
            assert(
                used_cap_b <= self.td_state['capacity']
            ).all(), "Used more backhaul than capacity: {}/{}".format(used_cap_b, self.td_state['capacity'])