import torch
from tensordict import TensorDict
from typing import Tuple, Optional
import warnings
from maenvs4vrp.core.env_generator_builder import InstanceBuilder
from maenvs4vrp.core.env_observation_builder import ObservationBuilder
from maenvs4vrp.core.env_agent_selector import BaseSelector
from maenvs4vrp.core.env_agent_reward import RewardFn
from maenvs4vrp.core.env import AECEnv
from maenvs4vrp.utils.utils import gather_by_index
class Environment(AECEnv):
    """
    TOPTW environment class.

    Several agents leave a shared depot and visit nodes to collect their profits.
    They must respect each node's time window and return to the depot by the
    tour end time. An agent becomes inactive once it moves back to the depot,
    and the episode is done when every agent is inactive.
    """

    def __init__(self,
                 instance_generator_object: InstanceBuilder,
                 obs_builder_object: ObservationBuilder,
                 agent_selector_object: BaseSelector,
                 reward_evaluator: RewardFn,
                 seed: Optional[int] = None,
                 device: Optional[str] = None,
                 batch_size: Optional[torch.Size] = None):
        """
        Constructor.

        Args:
            instance_generator_object(InstanceBuilder): Generator instance.
            obs_builder_object(ObservationBuilder): Observations instance.
            agent_selector_object(BaseSelector): Agent selector instance.
            reward_evaluator(RewardFn): Reward evaluator instance.
            seed(int, optional): Random number generator seed. Defaults to None,
                in which case ``self.DEFAULT_SEED`` is used.
            device(str, optional): Torch device, e.g. "cpu" or "cuda". Defaults to None,
                in which case the instance generator's device is used.
            batch_size(torch.Size | int, optional): Batch size. Defaults to None,
                in which case the instance generator's batch size is used.
        """
        self.version = 'v0'
        self.env_name = 'toptw'
        # seed the environment
        if seed is None:
            self._set_seed(self.DEFAULT_SEED)
        else:
            self._set_seed(seed)
        self.agent_selector = agent_selector_object
        self.inst_generator = instance_generator_object
        self.inst_generator._set_seed(self.seed)
        self.obs_builder = obs_builder_object
        self.obs_builder.set_env(self)
        self.reward_evaluator = reward_evaluator
        self.reward_evaluator.set_env(self)
        self.env_nsteps = 0
        if device is None:
            self.device = self.inst_generator.device
        else:
            self.device = device
            self.inst_generator.device = device
        if batch_size is None:
            self.batch_size = self.inst_generator.batch_size
        else:
            batch_size = [batch_size] if isinstance(batch_size, int) else batch_size
            self.batch_size = torch.Size(batch_size)
            self.inst_generator.batch_size = torch.Size(batch_size)
        self.td_state = TensorDict({}, batch_size=self.batch_size, device=self.device)

    def _travel_time(self, src: torch.Tensor, dst: torch.Tensor) -> torch.Tensor:
        """
        Euclidean travel time between two broadcastable sets of points.

        When ``self.n_digits`` is set, the time is floored to a multiple of
        ``1 / n_digits``. Feasibility, state updates and solution validation
        all use this helper, so they share the same discretised times.

        Args:
            src(torch.Tensor): Source coordinates (last dim = 2).
            dst(torch.Tensor): Destination coordinates (last dim = 2).
        Returns:
            torch.Tensor: Travel times, with the coordinate dimension reduced.
        """
        travel = torch.pairwise_distance(src, dst, eps=0, keepdim=False)
        if self.n_digits is not None:
            travel = torch.floor(self.n_digits * travel) / self.n_digits
        return travel

    def observe(self, is_reset=False) -> TensorDict:
        """
        Refresh the current agent's feasibility mask and build its observations.

        Args:
            is_reset(bool): Whether the environment is being reset. Defaults to False.
        Returns:
            td_observations(TensorDict): Observations from the observation builder,
                plus the current agent's "action_mask" and the "agents_mask" of
                active agents.
        """
        self._update_feasibility()
        td_observations = self.obs_builder.get_observations(is_reset=is_reset)
        td_observations['action_mask'] = self.td_state['cur_agent']['action_mask'].clone()
        td_observations['agents_mask'] = self.td_state['agents']['active_agents_mask'].clone()
        return td_observations

    def sample_action(self, td: TensorDict) -> TensorDict:
        """
        Sample a uniformly random feasible action for the current agent.

        Args:
            td(TensorDict): Environment step tensordict.
        Returns:
            td(TensorDict): The same tensordict with "action" set to shape (batch, 1).
        """
        # every feasible node gets weight 1, so sampling is uniform over feasible nodes
        action = torch.multinomial(self.td_state['cur_agent']["action_mask"].float(), 1).to(self.device)
        td['action'] = action
        return td

    def reset(self,
              num_agents: int | None = None,
              num_nodes: int | None = None,
              service_times: float | None = None,
              profits: str = 'constant',
              instance_name: str | None = None,
              sample_type: str = 'random',
              batch_size: Optional[torch.Size] = None,
              n_augment: Optional[int] = None,
              seed: int | None = None) -> TensorDict:
        """
        Reset the environment by sampling a new batch of instances.

        Args:
            num_agents(int, optional): Total number of agents. Defaults to None.
            num_nodes(int, optional): Total number of nodes. Defaults to None.
            service_times(float, optional): Service time in the nodes. Defaults to None.
            profits(str): Profit scheme forwarded to the instance generator. Defaults to "constant".
            instance_name(str, optional): Instance name. Defaults to None.
            sample_type(str): Sample type. It can be "random", "augment" or "saved". Defaults to "random".
            batch_size(torch.Size | int, optional): Batch size. Defaults to None (keep the current one).
            n_augment(int, optional): Data augmentation. Defaults to None.
            seed(int, optional): Random number generator seed. Defaults to None.
        Returns:
            TensorDict: Initial step information with keys "agent_step", "observations",
                "cur_agent_idx", "cur_node_idx", "reward", "penalty" and "done".
        """
        if seed is not None:
            self._set_seed(seed)
        if batch_size is None:
            batch_size = self.batch_size
        else:
            batch_size = [batch_size] if isinstance(batch_size, int) else batch_size
            self.batch_size = torch.Size(batch_size)
            self.inst_generator.batch_size = torch.Size(batch_size)
        instance_info = self.inst_generator.sample_instance(num_agents=num_agents,
                                                            num_nodes=num_nodes,
                                                            service_times=service_times,
                                                            profits=profits,
                                                            instance_name=instance_name,
                                                            sample_type=sample_type,
                                                            batch_size=batch_size,
                                                            n_augment=n_augment,
                                                            seed=seed)
        self.num_nodes = instance_info['num_nodes']
        self.num_agents = instance_info['num_agents']
        # when set, travel times are floored to multiples of 1 / n_digits
        if 'n_digits' in instance_info:
            self.n_digits = instance_info['n_digits']
        else:
            self.n_digits = None
        self.td_state = instance_info['data']
        self.td_state['done'] = torch.zeros(*batch_size, dtype=torch.bool, device=self.device)
        self.td_state['is_last_step'] = torch.zeros(*batch_size, dtype=torch.bool, device=self.device)
        self.td_state['depot_loc'] = self.td_state['coords'].gather(
            1, self.td_state['depot_idx'][:, :, None].expand(-1, -1, 2))
        self.td_state['max_tour_duration'] = self.td_state['end_time'] - self.td_state['start_time']
        # travel time between the depot and every node, used to guarantee a feasible return
        self.td_state['time2depot'] = self._travel_time(self.td_state['depot_loc'], self.td_state['coords'])
        self.td_state['nodes'] = TensorDict(
            source={'cur_profits': self.td_state['profits'].clone(),
                    'active_nodes_mask': torch.ones((*batch_size, self.num_nodes), dtype=torch.bool, device=self.device)},
            batch_size=batch_size, device=self.device)
        self.td_state['agents'] = TensorDict(
            source={'cum_profit': torch.zeros((*batch_size, self.num_agents), dtype=torch.float, device=self.device),
                    'step_profit': torch.zeros((*batch_size, self.num_agents), dtype=torch.float, device=self.device),
                    'cur_time': self.td_state['start_time'].unsqueeze(1).clone() * torch.ones((*batch_size, self.num_agents), dtype=torch.float, device=self.device),
                    'cur_node': self.td_state['depot_idx'] * torch.ones((*batch_size, self.num_agents), dtype=torch.int64, device=self.device),
                    'visited_nodes': torch.zeros((*batch_size, self.num_agents, self.num_nodes), dtype=torch.bool, device=self.device),
                    'feasible_nodes': torch.ones((*batch_size, self.num_agents, self.num_nodes), dtype=torch.bool, device=self.device),
                    'active_agents_mask': torch.ones((*batch_size, self.num_agents), dtype=torch.bool, device=self.device),
                    'cur_step': torch.zeros((*batch_size, self.num_agents), dtype=torch.int32, device=self.device)},
            batch_size=batch_size, device=self.device)
        # agent 0 acts first; cur_agent_idx has shape (batch, 1)
        self.td_state['cur_agent_idx'] = torch.zeros((*batch_size, 1), dtype=torch.int64, device=self.device)
        self.td_state['cur_node_idx'] = self.td_state['depot_idx'].clone()
        self.td_state['cur_agent'] = TensorDict({
            'action_mask': self.td_state['agents']['feasible_nodes'].gather(1, self.td_state['cur_agent_idx'][:, :, None].expand(-1, -1, self.num_nodes)).squeeze(1),
            'cum_profit': self.td_state['agents']['cum_profit'].gather(1, self.td_state['cur_agent_idx']).clone(),
            'step_profit': self.td_state['agents']['step_profit'].gather(1, self.td_state['cur_agent_idx']).clone(),
            'cur_time': self.td_state['agents']['cur_time'].gather(1, self.td_state['cur_agent_idx']).clone(),
            'cur_node': self.td_state['agents']['cur_node'].gather(1, self.td_state['cur_agent_idx']).clone(),
            'cur_step': self.td_state['agents']['cur_step'].gather(1, self.td_state['cur_agent_idx']).clone(),
        }, batch_size=batch_size, device=self.device)
        self.td_state['solution'] = TensorDict({}, batch_size=batch_size, device=self.device)
        self.agent_selector.set_env(self)
        self.obs_builder.set_env(self)
        self.reward_evaluator.set_env(self)
        agent_step = self.td_state['cur_agent']['cur_step']
        done = self.td_state['done'].clone()
        reward = torch.zeros_like(done, dtype=torch.float, device=self.device)
        penalty = torch.zeros_like(done, dtype=torch.float, device=self.device)
        td_observations = self.observe(is_reset=True)
        self.env_nsteps = 0
        return TensorDict(
            {
                "agent_step": agent_step,
                "observations": td_observations,
                "cur_agent_idx": self.td_state['cur_agent_idx'].clone(),
                "cur_node_idx": self.td_state['cur_node_idx'].clone(),
                "reward": reward,
                "penalty": penalty,
                "done": done,
            },
            batch_size=batch_size, device=self.device)

    def _update_feasibility(self):
        """
        Recompute the current agent's feasible-node mask.

        A node is feasible when all of the following hold:
        - it is still active;
        - service can start before its time window closes, waiting if the
          agent arrives early;
        - after service, the agent can still reach the depot by the tour end time.

        The mask is written to ``cur_agent.action_mask`` and to the agent's row
        of ``agents.feasible_nodes``.
        """
        mask = self.td_state['nodes']['active_nodes_mask'].clone()
        # time windows constraints
        loc = self.td_state['coords'].gather(1, self.td_state['cur_agent']['cur_node'][:, :, None].expand(-1, -1, 2))
        ptime = self.td_state['cur_agent']['cur_time']
        arrivej = ptime + self._travel_time(loc, self.td_state['coords'])
        waitj = torch.clip(self.td_state['tw_low'] - arrivej, min=0)
        service_startj = arrivej + waitj
        c1 = service_startj <= self.td_state['tw_high']
        c2 = service_startj + self.td_state['service_time'] + self.td_state['time2depot'] <= self.td_state['end_time'].unsqueeze(-1)
        mask = mask & c1 & c2
        # update state
        self.td_state['cur_agent'].update({'action_mask': mask})
        self.td_state['agents']['feasible_nodes'].scatter_(
            1, self.td_state['cur_agent_idx'][:, :, None].expand(-1, -1, self.num_nodes), mask.unsqueeze(1))

    def _update_done(self, action):
        """
        Update agent activity, served nodes and the episode termination flags.

        Args:
            action(torch.Tensor): Node chosen by the current agent, shape (batch, 1).
        Returns:
            None.
        """
        former_done = self.td_state['done'].clone()
        # an agent that moves to the depot finishes its tour and becomes inactive
        self.td_state['agents']['active_agents_mask'].scatter_(1, self.td_state['cur_agent_idx'],
                                                               ~action.eq(self.td_state['depot_idx']))
        self.td_state['done'] = (~self.td_state['agents']['active_agents_mask']).all(dim=-1)
        # a served node is deactivated; selecting the depot leaves it active
        self.td_state['nodes']['active_nodes_mask'].scatter_(1, action, action.eq(self.td_state['depot_idx']))
        # the last step is the one on which the episode switches from not done to done
        self.td_state['is_last_step'] = self.td_state['done'] & ~former_done

    def _update_state(self, action):
        """
        Move the current agent to ``action`` and update its time, profit, visits and step.

        Args:
            action(torch.Tensor): Node chosen by the current agent, shape (batch, 1).
        Returns:
            None.
        """
        loc = self.td_state['coords'].gather(1, self.td_state['cur_agent']['cur_node'][:, :, None].expand(-1, -1, 2))
        next_loc = self.td_state['coords'].gather(1, action[:, :, None].expand(-1, -1, 2))
        ptime = self.td_state['cur_agent']['cur_time'].clone()
        time2j = self._travel_time(loc, next_loc)
        tw = self.td_state['tw_low'].gather(1, action)
        service_time = self.td_state['service_time'].gather(1, action)
        arrivej = ptime + time2j
        # arriving before the window opens means waiting until it opens
        waitj = torch.clip(tw - arrivej, min=0)
        time_update = arrivej + waitj + service_time
        # update agent cur node
        self.td_state['cur_agent']['cur_node'] = action
        self.td_state['agents']['cur_node'].scatter_(1, self.td_state['cur_agent_idx'], self.td_state['cur_agent']['cur_node'])
        # update agent cur time
        self.td_state['cur_agent']['cur_time'] = time_update
        # if the agent is done, set its time to end_time
        agents_done = ~self.td_state['agents']['active_agents_mask'].gather(1, self.td_state['cur_agent_idx']).clone()
        self.td_state['cur_agent']['cur_time'] = torch.where(agents_done, self.td_state['end_time'].unsqueeze(-1),
                                                             self.td_state['cur_agent']['cur_time'])
        self.td_state['agents']['cur_time'].scatter_(1, self.td_state['cur_agent_idx'], self.td_state['cur_agent']['cur_time'])
        # update agent profit; the collected node's remaining profit drops to zero
        step_profit = self.td_state['profits'].gather(1, action)
        self.td_state['cur_agent']['cum_profit'] += step_profit
        self.td_state['agents']['cum_profit'].scatter_(1, self.td_state['cur_agent_idx'], self.td_state['cur_agent']['cum_profit'])
        self.td_state['cur_agent']['step_profit'] = self.td_state['profits'].gather(1, action)
        self.td_state['agents']['step_profit'].scatter_(1, self.td_state['cur_agent_idx'], self.td_state['cur_agent']['step_profit'])
        self.td_state['nodes']['cur_profits'].scatter_(1, action, torch.zeros_like(action, dtype=torch.float))
        # update visited nodes: one (batch, agent, node) triple per batch element.
        # All three index tensors must be 1-D of length batch; a (batch, 1) agent index
        # would broadcast to (batch, batch) and mark nodes for agents of other batch elements.
        batch_idx = torch.arange(*self.td_state.batch_size, device=self.device)
        self.td_state['agents']['visited_nodes'][batch_idx,
                                                 self.td_state['cur_agent_idx'].squeeze(-1),
                                                 action.squeeze(-1)] = True
        # update agent step (only while the agent is still active)
        self.td_state['cur_agent']['cur_step'] = torch.where(~agents_done, self.td_state['cur_agent']['cur_step'] + 1,
                                                             self.td_state['cur_agent']['cur_step'])
        self.td_state['agents']['cur_step'].scatter_(1, self.td_state['cur_agent_idx'], self.td_state['cur_agent']['cur_step'])
        self.td_state['cur_node_idx'] = action.clone()
        # if all agents are done, reactivate the first one so every batch element keeps
        # at least one active agent for the agent selector
        self.td_state['agents']['active_agents_mask'][self.td_state['agents']['active_agents_mask'].sum(1).eq(0), 0] = True

    def _update_cur_agent(self, cur_agent_idx):
        """
        Set the current agent and rebuild the ``cur_agent`` view from the per-agent state.

        Args:
            cur_agent_idx(torch.Tensor): Current agent index, shape (batch, 1).
        Returns:
            None.
        """
        self.td_state['cur_agent_idx'] = cur_agent_idx
        self.td_state['cur_agent'] = TensorDict({
            'action_mask': self.td_state['agents']['feasible_nodes'].gather(1, self.td_state['cur_agent_idx'][:, :, None].expand(-1, -1, self.num_nodes)).squeeze(1).clone(),
            'cum_profit': self.td_state['agents']['cum_profit'].gather(1, self.td_state['cur_agent_idx']).clone(),
            'step_profit': self.td_state['agents']['step_profit'].gather(1, self.td_state['cur_agent_idx']).clone(),
            'cur_time': self.td_state['agents']['cur_time'].gather(1, self.td_state['cur_agent_idx']).clone(),
            'cur_node': self.td_state['agents']['cur_node'].gather(1, self.td_state['cur_agent_idx']).clone(),
            'cur_step': self.td_state['agents']['cur_step'].gather(1, self.td_state['cur_agent_idx']).clone(),
        }, batch_size=self.td_state.batch_size, device=self.device)

    def _update_solution(self, action):
        """
        Append the chosen action and the acting agent to the recorded solution.

        Args:
            action(torch.Tensor): Node chosen by the current agent, shape (batch, 1).
        Returns:
            None.
        """
        if 'actions' in self.td_state['solution'].keys():
            self.td_state['solution', 'actions'] = torch.concat([self.td_state['solution', 'actions'], action], dim=-1)
        else:
            self.td_state['solution', 'actions'] = action
        if 'agents' in self.td_state['solution'].keys():
            self.td_state['solution', 'agents'] = torch.concat([self.td_state['solution', 'agents'], self.td_state['cur_agent_idx']], dim=-1)
        else:
            self.td_state['solution', 'agents'] = self.td_state['cur_agent_idx']

    def step(self, td: TensorDict) -> TensorDict:
        """
        Perform an environment step for the current agent.

        Args:
            td(TensorDict): Environment step tensordict containing "action".
        Returns:
            td(TensorDict): The same tensordict updated with the next agent's
                observations, the step reward/penalty and the termination flags.
        """
        action = td["action"]
        assert self.td_state['cur_agent']['action_mask'].gather(1, action).all(), "not feasible action"
        self._update_done(action)
        done = self.td_state['done'].clone()
        is_last_step = self.td_state['is_last_step'].clone()
        # reward is evaluated after termination flags change but before the agent moves
        reward, penalty = self.reward_evaluator.get_reward(action)
        # update env state
        self._update_state(action)
        # update solution dict
        self._update_solution(action)
        # select and update cur agent
        cur_agent_idx = self.agent_selector._next_agent()
        self._update_cur_agent(cur_agent_idx)
        agent_step = self.td_state['cur_agent']['cur_step']
        # new observations
        td_observations = self.observe()
        self.env_nsteps += 1
        td.update(
            {
                "agent_step": agent_step,
                "observations": td_observations,
                "reward": reward,
                "penalty": penalty,
                "cur_agent_idx": cur_agent_idx,
                "cur_node_idx": self.td_state['cur_node_idx'].clone(),
                "done": done,
                "is_last_step": is_last_step
            },
        )
        return td

    def check_solution_validity(self):
        """
        Check that the recorded solution satisfies the problem constraints.

        The routes are replayed with the same (optionally floored) travel times
        the environment uses. Each agent starts its route at the depot at
        ``start_time``.

        Raises:
            AssertionError: If some node cannot be served at its opening time and
                still allow a return to the depot before the depot closes, if a
                node is served after its time window closes, or if a non-depot
                node is visited more than once.
        """
        depot_idx = self.td_state['depot_idx']            # (batch, 1)
        depot = depot_idx.squeeze(-1)                     # (batch,)
        # Time to serve each node at its opening time and get back to the depot
        a = self.td_state['tw_low'] + self.td_state['time2depot'] + self.td_state['service_time']
        b = self.td_state['tw_high'].gather(1, depot_idx)  # Depot late tw
        # Can agent serve node and get back to depot?
        assert torch.all(a <= b), "Agent cannot serve node and get back to depot."
        start_time = self.td_state['start_time'].to(torch.float32)
        curr_node = depot.clone()
        curr_time = start_time.clone()
        visited_nodes = torch.zeros(*self.batch_size, self.num_nodes, dtype=torch.int64, device=self.device)
        # Group actions per agent while keeping each agent's chronological order
        sorted_indices = torch.argsort(self.td_state['solution']['agents'], dim=-1, stable=True)
        sorted_actions = torch.gather(self.td_state['solution']['actions'], dim=-1, index=sorted_indices)
        for step_idx in range(sorted_actions.size(1)):
            next_node = sorted_actions[:, step_idx]
            curr_loc = gather_by_index(self.td_state['coords'], curr_node)
            next_loc = gather_by_index(self.td_state['coords'], next_node)
            dist = self._travel_time(curr_loc, next_loc)
            node_col = next_node.unsqueeze(-1)
            visited_nodes.scatter_add_(1, node_col, torch.ones_like(node_col))
            curr_time = torch.max(curr_time + dist, gather_by_index(self.td_state['tw_low'], next_node))
            assert torch.all(curr_time <= gather_by_index(self.td_state['tw_high'], next_node)), "Agent must perform service before node's time window closes."
            curr_time = curr_time + gather_by_index(self.td_state['service_time'], next_node)
            curr_node = next_node
            # Reaching the depot ends a route; the next agent's route starts at start_time
            curr_time = torch.where(next_node == depot, start_time, curr_time)
        node_ids = torch.arange(self.num_nodes, device=self.device)
        is_depot = node_ids.unsqueeze(0) == depot_idx     # (batch, num_nodes)
        visited_nodes_exc_depot = visited_nodes[~is_depot]
        assert torch.all((visited_nodes_exc_depot == 0) | (visited_nodes_exc_depot == 1)), "Nodes were visited more than once!"