morphoclass.layers.tree_lstm_pool module
Implementation of the tree-LSTM pooling layer.
-
class morphoclass.layers.tree_lstm_pool.TreeLSTMCell(x_size, h_size)
Bases: torch.nn.modules.module.Module
Child-Sum Tree LSTM Cell.
This class implements a single child-sum LSTM cell that is applied to one node of a tree at a time (see the implementation of the forward method).
WARNING: this implementation is not at all optimized. Ideally, LSTM cells would be applied to multiple nodes in parallel, but the current implementation only supports sequential processing.
- Parameters
x_size (int) – Dimension of the node embedding vector
h_size (int) – Dimension of the hidden state and the memory cell vectors
See also
https://arxiv.org/abs/1503.00075
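For reference, the child-sum Tree-LSTM update from the paper linked above is reproduced below for a node $j$ with embedding $x_j$ and child set $C(j)$. In the terms of forward, $x_j$ is x, and the $h_k$, $c_k$ are the rows of hs and cs. Whether this class parametrizes the gates exactly this way (for example, with separate or fused weight matrices) is not stated here, so treat this as the reference formulation rather than a description of the code.

```latex
\begin{aligned}
\tilde{h}_j &= \sum_{k \in C(j)} h_k \\
i_j &= \sigma\left(W^{(i)} x_j + U^{(i)} \tilde{h}_j + b^{(i)}\right) \\
f_{jk} &= \sigma\left(W^{(f)} x_j + U^{(f)} h_k + b^{(f)}\right), \quad k \in C(j) \\
o_j &= \sigma\left(W^{(o)} x_j + U^{(o)} \tilde{h}_j + b^{(o)}\right) \\
u_j &= \tanh\left(W^{(u)} x_j + U^{(u)} \tilde{h}_j + b^{(u)}\right) \\
c_j &= i_j \odot u_j + \sum_{k \in C(j)} f_{jk} \odot c_k \\
h_j &= o_j \odot \tanh(c_j)
\end{aligned}
```

Each $W^{(\cdot)}$ has shape (h_size, x_size), each $U^{(\cdot)}$ has shape (h_size, h_size), and each $b^{(\cdot)}$ has shape (h_size,). Note that there is one forget gate per child, conditioned on that child's own hidden state.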
-
forward(x, hs, cs)
Single forward pass of the LSTM cell on a node.
- Parameters
x (torch.Tensor) – The node embedding vector of shape (x_size,)
hs (torch.Tensor) – Hidden states of the child nodes. The shape should be (n_children, h_size)
cs (torch.Tensor) – Memory states of the child nodes. The shape should be (n_children, h_size)
- Returns
h (torch.Tensor) – The updated hidden state of shape (h_size,)
c (torch.Tensor) – The updated memory cell of shape (h_size,)
-
reset_parameters()
Randomly initialize all weights and set biases to zero.
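The single-node update and the initialization described above can be sketched in dependency-free Python. This is a minimal, hypothetical sketch of the standard child-sum equations, not the morphoclass implementation: the class name ChildSumCellSketch, the per-gate weight dictionaries, and the uniform range used in reset_parameters are all assumptions. Plain lists stand in for torch tensors.

```python
import math
import random


def _sigmoid(v):
    return [1.0 / (1.0 + math.exp(-a)) for a in v]


def _matvec(m, v):
    # Matrix (list of rows) times vector.
    return [sum(w * a for w, a in zip(row, v)) for row in m]


def _add(*vs):
    # Element-wise sum of equally long vectors.
    return [sum(t) for t in zip(*vs)]


class ChildSumCellSketch:
    """Hypothetical pure-Python child-sum Tree-LSTM cell (Tai et al., 2015)."""

    GATES = ("i", "f", "o", "u")

    def __init__(self, x_size, h_size, seed=0):
        self.x_size, self.h_size = x_size, h_size
        self.rng = random.Random(seed)
        self.reset_parameters()

    def reset_parameters(self):
        # Random weights and zero biases; the U(-0.1, 0.1) range is an assumption.
        def rand(rows, cols):
            return [[self.rng.uniform(-0.1, 0.1) for _ in range(cols)]
                    for _ in range(rows)]

        self.W = {g: rand(self.h_size, self.x_size) for g in self.GATES}
        self.U = {g: rand(self.h_size, self.h_size) for g in self.GATES}
        self.b = {g: [0.0] * self.h_size for g in self.GATES}

    def forward(self, x, hs, cs):
        # x: length x_size; hs, cs: n_children vectors of length h_size each.
        h_tilde = _add([0.0] * self.h_size, *hs)  # sum of child hidden states

        def gate(g, h_in):
            return _add(_matvec(self.W[g], x), _matvec(self.U[g], h_in), self.b[g])

        i = _sigmoid(gate("i", h_tilde))
        o = _sigmoid(gate("o", h_tilde))
        u = [math.tanh(a) for a in gate("u", h_tilde)]
        c = [a * b for a, b in zip(i, u)]
        # One forget gate per child, conditioned on that child's hidden state.
        for h_k, c_k in zip(hs, cs):
            f_k = _sigmoid(gate("f", h_k))
            c = _add(c, [a * b for a, b in zip(f_k, c_k)])
        h = [a * math.tanh(b) for a, b in zip(o, c)]
        return h, c
```

A leaf node is just a call with empty child lists, for example cell.forward(x, [], []). In that case h_tilde is the zero vector and the memory cell reduces to i ⊙ u. Keeping this per-node call signature is what forces the sequential processing mentioned in the warning above.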
-
training: bool
-
class morphoclass.layers.tree_lstm_pool.TreeLSTMPool(x_size, h_size)
Bases: torch.nn.modules.module.Module
Child-Sum Tree LSTM Pooling Layer.
This class implements the Tree-LSTM as described in the reference. After the whole tree has been traversed, the hidden state of the root node is returned as the result of the pooling.
Batched graphs can be processed as well. In that case, the hidden states of the root nodes of all graphs are returned.
WARNING: this implementation is not at all optimized. Ideally, LSTM cells would be applied to multiple nodes in parallel: in principle, all nodes in a single mask yielded by topologically_sorted could be processed in parallel, but the current implementation loops through them sequentially.
- Parameters
x_size (int) – Dimension of the node embedding vector
h_size (int) – Dimension of the hidden state and the memory cell vectors
See also
https://arxiv.org/abs/1503.00075
-
forward(x, edge_index)
Compute the forward pass.
- Parameters
x – The batched node feature maps for pooling.
edge_index – The batched adjacency matrices.
- Returns
h – The pooled features
- Return type
torch.Tensor
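The pooling procedure described for this class can be sketched as a loop over the topological node masks, followed by a gather of the root hidden states. This is a hypothetical, dependency-free sketch, not the actual forward. The following are assumptions: the dense adjacency convention adj[i][j] truthy iff node j is a child of node i, the function name tree_lstm_pool, and passing the masks and the cell in as arguments.

```python
def tree_lstm_pool(x, adj, masks, cell):
    """Hypothetical sketch of child-sum Tree-LSTM pooling over a batch of trees.

    x     -- list of node features, one entry per node
    adj   -- adj[i][j] truthy iff node j is a child of node i (assumed convention)
    masks -- boolean node masks in topological order, leaves first
    cell  -- callable (x_i, child_hs, child_cs) -> (h_i, c_i)
    """
    n = len(x)
    h = [None] * n
    c = [None] * n
    for mask in masks:
        # Nodes within one mask do not depend on each other, but as in the
        # warning above they are still processed one after another here.
        for i in range(n):
            if mask[i]:
                kids = [j for j in range(n) if adj[i][j]]
                h[i], c[i] = cell(x[i], [h[j] for j in kids], [c[j] for j in kids])
    # Roots are nodes that are nobody's child; a batch has one root per tree.
    roots = [i for i in range(n) if not any(adj[p][i] for p in range(n))]
    return [h[i] for i in roots]
```

For example, a batch of two trees (0 with children 1 and 2, and 3 with child 4), pooled with a stand-in "cell" that sums the node value and its children's outputs, returns one value per root:

```python
adj = [[0, 1, 1, 0, 0], [0] * 5, [0] * 5, [0, 0, 0, 0, 1], [0] * 5]
masks = [[False, True, True, False, True], [True, False, False, True, False]]
toy_cell = lambda xi, hs, cs: (xi + sum(hs), 0)  # stand-in, not an LSTM
tree_lstm_pool([1, 2, 3, 10, 20], adj, masks, toy_cell)  # [6, 30]
```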
-
static topologically_sorted(adj)
Topologically sort nodes in a given tree.
Topological sorting here means that all leaf nodes are yielded first. Next, all nodes whose children have all been processed are yielded. This is repeated until every node has been seen.
At each iteration step, the nodes that come next in the topological order are yielded (to be precise, as an equivalent boolean node mask rather than a list of nodes).
- Parameters
adj (matrix_like) – The adjacency matrix describing the tree to be sorted
- Yields
active_nodes (array_like of type bool) – Node mask for the current set of nodes in a sorted tree.
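The level-by-level procedure described above can be sketched in plain Python. This is a minimal sketch under an assumed convention: adj[i][j] is truthy iff node j is a child of node i. The actual method may use the transposed convention or array types. The cycle check is an addition of this sketch, not documented behaviour.

```python
def topologically_sorted(adj):
    """Yield boolean node masks: leaves first, then nodes whose children are done.

    Assumes adj[i][j] is truthy iff node j is a child of node i.
    """
    n = len(adj)
    children = [[j for j in range(n) if adj[i][j]] for i in range(n)]
    done = [False] * n
    while not all(done):
        # A node becomes active once every one of its children has been processed.
        active = [
            not done[i] and all(done[j] for j in children[i])
            for i in range(n)
        ]
        if not any(active):
            raise ValueError("adjacency matrix contains a cycle")
        yield active
        for i in range(n):
            if active[i]:
                done[i] = True
```

For the tree 0 → {1, 2}, 1 → {3}, the sketch yields the leaves 2 and 3 first, then node 1, then the root 0:

```python
adj = [[0, 1, 1, 0], [0, 0, 0, 1], [0, 0, 0, 0], [0, 0, 0, 0]]
list(topologically_sorted(adj))
# [[False, False, True, True], [False, True, False, False], [True, False, False, False]]
```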
-
training: bool