Source code for marl.model.nn.mlpnet

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

def hidden_init(layer):
    # Fan-in initialisation bounds (DDPG-style). nn.Linear stores its weight
    # as (out_features, in_features), so the fan-in is size()[1].
    fan_in = layer.weight.data.size()[1]
    lim = 1. / np.sqrt(fan_in)
    return (-lim, lim)

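# Worked example (illustrative values, not part of the module): for
# nn.Linear(64, 32) the weight shape is (32, 64), the fan-in is 64, and
# hidden_init returns (-0.125, 0.125) since 1 / sqrt(64) = 0.125.
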
class MlpNet(nn.Module):
    def __init__(self, obs_sp, act_sp, hidden_size=[64, 64], hidden_activ=nn.ReLU,
                 last_activ=None, lay_norm=False):
        super(MlpNet, self).__init__()
        self.input_size = obs_sp
        self.output_size = act_sp
        self.h_activ = hidden_activ
        self.last_activ = last_activ
        self.lay_norm = lay_norm

        in_size = hidden_size[-1] if len(hidden_size) > 0 else self.input_size
        self.feature_extractor = self._build_module(hidden_size)
        self.output_layer = nn.Linear(in_size, self.output_size)
        self.reset_parameters()

    def _build_module(self, h_size):
        # Stack a Linear -> activation (-> LayerNorm) block per hidden size.
        in_size = self.input_size
        modules = []
        for n_units in h_size:
            modules.append(nn.Linear(in_size, n_units))
            modules.append(self.h_activ())
            if self.lay_norm:
                modules.append(nn.LayerNorm(n_units))
            in_size = n_units
        return nn.Sequential(*modules)

    def reset_parameters(self):
        # Hidden layers use fan-in bounds; the output layer uses small
        # uniform weights, as in the DDPG initialisation scheme.
        for lay in self.feature_extractor:
            if isinstance(lay, nn.Linear):
                lay.weight.data.uniform_(*hidden_init(lay))
        self.output_layer.weight.data.uniform_(-3e-3, 3e-3)

    def forward(self, x):
        x = self.feature_extractor(x)
        x = self.output_layer(x)
        if self.last_activ is not None:
            x = self.last_activ(x)
        return x

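# Minimal usage sketch (the sizes and the tanh head are assumptions chosen
# for illustration, not defaults of the module):
#
#     net = MlpNet(obs_sp=4, act_sp=2, last_activ=torch.tanh)
#     obs = torch.randn(8, 4)   # batch of 8 observations
#     out = net(obs)            # -> shape (8, 2), squashed into (-1, 1)
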
class GumbelMlpNet(MlpNet):
    def __init__(self, obs_sp, act_sp, hidden_size=[64, 64], hidden_activ=nn.ReLU,
                 tau=1., lay_norm=False):
        # Forward the requested architecture to MlpNet and keep the
        # Gumbel-Softmax temperature.
        super(GumbelMlpNet, self).__init__(obs_sp=obs_sp, act_sp=act_sp,
                                           hidden_size=hidden_size,
                                           hidden_activ=hidden_activ,
                                           lay_norm=lay_norm)
        self.tau = tau

    def forward(self, x):
        # Relax the MLP logits into a differentiable soft one-hot sample.
        x = super().forward(x)
        x = F.gumbel_softmax(x, tau=self.tau, hard=False)
        return x

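# Usage sketch (sizes are assumptions): with hard=False, each output row is
# a soft probability vector over the act_sp discrete actions — it sums to 1
# and stays differentiable w.r.t. the network parameters.
#
#     gnet = GumbelMlpNet(obs_sp=4, act_sp=3, tau=0.5)
#     probs = gnet(torch.randn(8, 4))   # -> shape (8, 3), rows sum to 1
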
class ContinuousCritic(nn.Module):
    def __init__(self, obs_sp, act_sp, hidden_size=[64, 64]):
        super(ContinuousCritic, self).__init__()
        if len(hidden_size) != 2:
            raise ValueError("'hidden_size' must be of length 2")
        self.action_space = act_sp
        self.obs_space = obs_sp

        self.linear1 = nn.Linear(obs_sp, hidden_size[0])
        self.ln1 = nn.LayerNorm(hidden_size[0])
        # The actions are concatenated after the first layer, so the second
        # layer takes hidden_size[0] + act_sp inputs.
        self.linear2 = nn.Linear(hidden_size[0] + self.action_space, hidden_size[1])
        self.ln2 = nn.LayerNorm(hidden_size[1])
        self.output_layer = nn.Linear(hidden_size[1], 1)
        self.reset_parameters()

    def reset_parameters(self):
        for lay in [self.linear1, self.linear2]:
            if isinstance(lay, nn.Linear):
                lay.weight.data.uniform_(*hidden_init(lay))
        self.output_layer.weight.data.uniform_(-3e-3, 3e-3)

    def forward(self, obs, act):
        x = obs
        # Extract features from the observation
        x = self.linear1(x)
        x = self.ln1(x)
        x = F.relu(x)
        # Concatenate the observation features with the actions
        x = torch.cat((x, act), 1)
        x = self.linear2(x)
        x = self.ln2(x)
        x = F.relu(x)
        outputs = self.output_layer(x)
        return outputs
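
# Usage sketch (sizes are assumptions): the critic maps an (obs, act) batch
# to one Q-value per sample, as in DDPG/MADDPG-style actor-critic training.
#
#     critic = ContinuousCritic(obs_sp=4, act_sp=2)
#     q = critic(torch.randn(8, 4), torch.randn(8, 2))   # -> shape (8, 1)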