MPM
- class deeplay.models.gnn.mpm.MPM(*args, **kwargs)
Bases: DeeplayModule
Message Passing Neural Network (MPN) model.
Parameters
- hidden_features: list[int]
Number of hidden units in each Message Passing Layer.
- out_features: int
Number of output features.
- pool: template-like
Specification for the pooling of the model. Default: nn.Identity.
- out_activation: template-like
Specification for the output activation of the model. Default: nn.Identity.
Configurables
- hidden_features (list[int]): Number of hidden units in each Message Passing Layer.
- out_features (int): Number of output features.
- pool (template-like): Specification for the pooling of the model. Default: nn.Identity.
- out_activation (template-like): Specification for the output activation of the model. Default: nn.Identity.
- encoder (template-like): Specification for the encoder of the model. Default: dl.Parallel of two MLPs, one processing the node features and one processing the edge features.
- backbone (template-like): Specification for the backbone of the model. Default: dl.MessagePassingNeuralNetwork.
- selector (template-like): Specification for the selector of the model. Default: dl.FromDict("x"), which selects the node features.
- head (template-like): Specification for the head of the model. Default: dl.MultiLayerPerceptron.
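To make the roles of the configurables concrete, the following is a minimal NumPy sketch of the data flow encoder → backbone → selector → head. Every detail is an assumption chosen for illustration, not deeplay's implementation: the helper names (make_dense, dense, mpm_sketch) are hypothetical, each MLP is reduced to a single random dense layer with ReLU, messages are built from the sender's state and the edge features and summed at the receiving node, and edge features are not updated between layers. Weights are drawn afresh on every call, so the sketch only demonstrates shapes and data flow, not a trainable model.

```python
import numpy as np

rng = np.random.default_rng(0)

def make_dense(in_dim, out_dim):
    """Random weights and zero bias for one dense layer (illustrative init only)."""
    return rng.normal(scale=in_dim ** -0.5, size=(in_dim, out_dim)), np.zeros(out_dim)

def dense(params, h, relu=True):
    """Affine map h @ W + b, optionally followed by ReLU."""
    w, b = params
    out = h @ w + b
    return np.maximum(out, 0.0) if relu else out

def mpm_sketch(inp, hidden_features, out_features):
    """Hypothetical shape-level sketch of an MPM-style forward pass."""
    x, edge_index, edge_attr = inp["x"], inp["edge_index"], inp["edge_attr"]
    src, dst = edge_index  # sender and receiver node of every edge
    enc_width = hidden_features[0]
    # Encoder: separate projections for node and edge features (stand-ins for
    # the two MLPs in dl.Parallel); input widths are read from the data.
    h = dense(make_dense(x.shape[1], enc_width), x)
    e = dense(make_dense(edge_attr.shape[1], enc_width), edge_attr)
    # Backbone: one message-passing step per entry of hidden_features.
    for width in hidden_features:
        # Message on each edge from the sender's state and the edge features.
        msg_in = np.concatenate([h[src], e], axis=1)
        msg = dense(make_dense(msg_in.shape[1], width), msg_in)
        # Sum the incoming messages at each receiving node.
        agg = np.zeros((h.shape[0], width))
        np.add.at(agg, dst, msg)
        # Update each node from its previous state and its aggregated messages.
        upd_in = np.concatenate([h, agg], axis=1)
        h = dense(make_dense(upd_in.shape[1], width), upd_in)
    # Selector: keep only the node states (the role of dl.FromDict("x")).
    # Head: a single dense layer here; pool and out_activation are left as identity.
    return dense(make_dense(h.shape[1], out_features), h, relu=False)

inp = {
    "x": rng.normal(size=(10, 16)),
    "edge_index": rng.integers(0, 10, size=(2, 20)),
    "edge_attr": rng.normal(size=(20, 8)),
}
out = mpm_sketch(inp, hidden_features=[64, 64], out_features=1)
print(out.shape)  # (10, 1)
```

With the same input sizes as the example further down, the sketch yields one output row per node, matching the documented output shape.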
Constraints
- input: Dict[str, Any] or torch_geometric Data object containing the following attributes:
  - x: torch.Tensor of shape (num_nodes, node_in_features).
  - edge_index: torch.Tensor of shape (2, num_edges).
  - edge_attr: torch.Tensor of shape (num_edges, edge_in_features).
  NOTE: node_in_features and edge_in_features are inferred from the input data.
- output: torch.Tensor of shape (num_nodes, out_features) with the default nn.Identity pool.
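The input constraints can be checked before calling the model. Below is a sketch of such a check, assuming a hypothetical helper infer_mpm_input_dims that is not part of deeplay; NumPy arrays stand in for torch.Tensor so it runs without PyTorch. It returns the trailing dimensions that the model would infer as node_in_features and edge_in_features. The edge_index range check is an extra sanity check of our own, not one listed in the constraints.

```python
import numpy as np

def infer_mpm_input_dims(inp):
    """Hypothetical helper (not part of deeplay): validate a dict input against
    the documented shape constraints and return (node_in_features, edge_in_features)."""
    x, edge_index, edge_attr = inp["x"], inp["edge_index"], inp["edge_attr"]
    if x.ndim != 2:
        raise ValueError(f"x must have shape (num_nodes, node_in_features), got {x.shape}")
    if edge_index.ndim != 2 or edge_index.shape[0] != 2:
        raise ValueError(f"edge_index must have shape (2, num_edges), got {edge_index.shape}")
    if not np.issubdtype(edge_index.dtype, np.integer):
        raise ValueError("edge_index must contain integer node indices")
    num_nodes, num_edges = x.shape[0], edge_index.shape[1]
    if edge_attr.ndim != 2 or edge_attr.shape[0] != num_edges:
        raise ValueError(
            f"edge_attr must have shape ({num_edges}, edge_in_features), got {edge_attr.shape}"
        )
    # Extra sanity check (assumption): every edge endpoint must be a valid node index.
    if num_edges and (edge_index.min() < 0 or edge_index.max() >= num_nodes):
        raise ValueError("edge_index refers to nodes outside [0, num_nodes)")
    # The trailing dimensions are the input widths inferred from the data.
    return x.shape[1], edge_attr.shape[1]

rng = np.random.default_rng(0)
example = {
    "x": rng.normal(size=(10, 16)),
    "edge_index": rng.integers(0, 10, size=(2, 20)),
    "edge_attr": rng.normal(size=(20, 8)),
}
print(infer_mpm_input_dims(example))  # (16, 8)
```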
Examples
>>> import torch
>>> from deeplay.models.gnn.mpm import MPM
>>> # MPM with 2 hidden layers of 64 units each and 1 output feature
>>> model = MPM([64, 64], 1).create()
>>> # Define input as a dictionary with node features, edge index and edge features
>>> inp = {}
>>> inp["x"] = torch.randn(10, 16)
>>> inp["edge_index"] = torch.randint(0, 10, (2, 20))
>>> inp["edge_attr"] = torch.randn(20, 8)
>>> out = model(inp)
>>> print(out.shape)
torch.Size([10, 1])
Methods Summary
forward(x): Define the computation performed at every call.
Methods Documentation
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass must be defined within this function, one should call the Module instance instead of calling this method directly, since the former runs the registered hooks while the latter silently ignores them.
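The difference described in the note can be illustrated with a toy class. ToyModule and Doubler are hypothetical names, and the class is a simplified analogy of how torch.nn.Module dispatches forward hooks (registered in PyTorch via register_forward_hook, with hooks called as hook(module, args, output)), not PyTorch's actual implementation, which also handles forward pre-hooks among other things.

```python
class ToyModule:
    """Toy analogy of torch.nn.Module hook dispatch (not PyTorch's code)."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        # Hooks receive (module, args, output); a non-None return replaces the output.
        self._forward_hooks.append(hook)

    def forward(self, x):
        raise NotImplementedError("subclasses define the computation here")

    def __call__(self, x):
        # Calling the instance runs forward() and then every registered hook.
        out = self.forward(x)
        for hook in self._forward_hooks:
            result = hook(self, (x,), out)
            if result is not None:
                out = result
        return out

class Doubler(ToyModule):
    def forward(self, x):
        return 2 * x

calls = []
m = Doubler()
m.register_forward_hook(lambda mod, args, out: calls.append(out))
m(3)          # runs forward, then the hook records 6
m.forward(3)  # runs forward only; the hook is silently skipped
print(calls)  # [6]
```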