morphoclass.layers.tree_lstm_pool module

Implementation of the tree-LSTM pooling layer.

class morphoclass.layers.tree_lstm_pool.TreeLSTMCell(x_size, h_size)

Bases: torch.nn.modules.module.Module

Child-Sum Tree LSTM Cell.

This class implements a single child-sum LSTM cell, which can be applied to a single node in a tree (see the implementation of the forward function).
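For reference, these are the child-sum Tree-LSTM transition equations from the paper linked under "See also". They are written for one node with embedding x, and k runs over its children, whose states h_k and c_k correspond to the rows of hs and cs in forward. This text does not state whether the module uses exactly this parametrization (separate W, U, b per gate).

```latex
\begin{aligned}
\tilde{h} &= \sum_{k} h_k \\
i &= \sigma\left(W^{(i)} x + U^{(i)} \tilde{h} + b^{(i)}\right) \\
f_k &= \sigma\left(W^{(f)} x + U^{(f)} h_k + b^{(f)}\right) \\
o &= \sigma\left(W^{(o)} x + U^{(o)} \tilde{h} + b^{(o)}\right) \\
u &= \tanh\left(W^{(u)} x + U^{(u)} \tilde{h} + b^{(u)}\right) \\
c &= i \odot u + \sum_{k} f_k \odot c_k \\
h &= o \odot \tanh(c)
\end{aligned}
```

Unlike the other gates, the forget gate f_k is computed separately for each child. This lets the cell weight each child's memory individually.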

WARNING: this implementation is not at all optimized; ideally, one should be able to apply LSTM cells to multiple nodes in parallel. The current implementation only supports sequential processing.

Parameters
  • x_size (int) – Dimension of the node embedding vector

  • h_size (int) – Dimension of the hidden state and the memory cell vectors

See also

https://arxiv.org/abs/1503.00075

forward(x, hs, cs)

Single forward pass of the LSTM cell on a node.

Parameters
  • x (torch.Tensor) – The node embedding vector of shape (x_size,)

  • hs (torch.Tensor) – Hidden states of the child nodes. The shape should be (n_children, h_size)

  • cs (torch.Tensor) – Memory cells of the child nodes. The shape should be (n_children, h_size)

Returns

  • h (torch.Tensor) – The updated hidden state of shape (h_size,)

  • c (torch.Tensor) – The updated memory cell of shape (h_size,)
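The following NumPy sketch shows what a single forward pass computes under the documented shapes. It is not the module's PyTorch code. The helper names `init_params` and `child_sum_cell` and the per-gate (W, U, b) layout are hypothetical. The zero biases follow the reset_parameters description, but the weight distribution used here is an assumption.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def init_params(x_size, h_size, seed=0):
    """Hypothetical parameter layout: one (W, U, b) triple per gate."""
    rng = np.random.default_rng(seed)
    return {
        gate: (
            rng.normal(scale=0.1, size=(h_size, x_size)),  # W: acts on x
            rng.normal(scale=0.1, size=(h_size, h_size)),  # U: acts on hidden states
            np.zeros(h_size),  # b: biases start at zero
        )
        for gate in ("i", "f", "o", "u")
    }

def child_sum_cell(x, hs, cs, params):
    """One child-sum Tree-LSTM update for a single node.

    x: (x_size,); hs, cs: (n_children, h_size). Returns h, c of shape (h_size,).
    """
    def gate(name, h_in, activation):
        W, U, b = params[name]
        # h_in may be (h_size,) or (n_children, h_size); W @ x broadcasts over rows
        return activation(W @ x + h_in @ U.T + b)

    h_tilde = hs.sum(axis=0)  # sum of child hidden states (zeros for a leaf)
    i = gate("i", h_tilde, sigmoid)  # input gate
    o = gate("o", h_tilde, sigmoid)  # output gate
    u = gate("u", h_tilde, np.tanh)  # candidate update
    f = gate("f", hs, sigmoid)  # one forget gate per child, (n_children, h_size)
    c = i * u + (f * cs).sum(axis=0)
    h = o * np.tanh(c)
    return h, c
```

To handle a leaf, pass empty arrays of shape (0, h_size) for hs and cs. The child sums then vanish and c reduces to i ⊙ u.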

reset_parameters()

Randomly initialize all weights and set biases to zero.

training: bool
class morphoclass.layers.tree_lstm_pool.TreeLSTMPool(x_size, h_size)

Bases: torch.nn.modules.module.Module

Child-Sum Tree LSTM Pooling Layer.

This class implements the Tree-LSTM as described in the reference. After traversing the whole tree, the hidden state of the root node is returned as the result of the pooling.

Batched graphs can be processed as well. In that case, the hidden states of the root nodes of all graphs are returned.

WARNING: this implementation is not at all optimized; ideally, one should be able to apply LSTM cells to multiple nodes in parallel. In principle, all nodes in a mask yielded by topologically_sorted can be processed in parallel, but the current implementation loops through them sequentially.
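The parallelism described in the warning can be illustrated with a NumPy sketch. It is not the module's code. The sketch evaluates the child-sum equations level by level, either node by node within each level or for a whole level at once using batched array operations, and both modes give the same states. The names `make_params` and `run_levels`, the per-gate (W, U, b) layout, and the convention that adj[p, c] is True when c is a child of p are all assumptions.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def make_params(x_size, h_size, seed=0):
    # Hypothetical layout: one (W, U, b) triple per gate i, f, o, u.
    rng = np.random.default_rng(seed)
    return {g: (rng.normal(scale=0.5, size=(h_size, x_size)),
                rng.normal(scale=0.5, size=(h_size, h_size)),
                rng.normal(scale=0.1, size=h_size))
            for g in "ifou"}

def run_levels(X, adj, params, parallel):
    """Child-sum Tree-LSTM hidden states for every node, computed level by level.

    adj[p, c] == True means node c is a child of node p (assumed convention).
    """
    def gate(name, x_in, h_in, act):
        W, U, b = params[name]
        return act(x_in @ W.T + h_in @ U.T + b)

    n, h_size = X.shape[0], params["i"][0].shape[0]
    H, C = np.zeros((n, h_size)), np.zeros((n, h_size))
    done = np.zeros(n, dtype=bool)
    while not done.all():
        ready = ~done & ~(adj & ~done).any(axis=1)  # one topological level
        if not ready.any():
            raise ValueError("adj does not describe a forest")
        if parallel:
            A = adj[ready].astype(float)  # (m, n) child indicators of the level
            Xa, h_t = X[ready], A @ H  # A @ H sums the children's hidden states
            i, o = gate("i", Xa, h_t, sigmoid), gate("o", Xa, h_t, sigmoid)
            u = gate("u", Xa, h_t, np.tanh)
            # Forget gate for every (level node, candidate child) pair: (m, n, h_size)
            f = gate("f", Xa[:, None, :], H[None, :, :], sigmoid)
            # Mask out non-children before summing f_k * c_k
            c = i * u + (A[:, :, None] * f * C[None, :, :]).sum(axis=1)
            H[ready], C[ready] = o * np.tanh(c), c
        else:
            for p in np.flatnonzero(ready):  # sequential, as described in the warning
                hs, cs = H[adj[p]], C[adj[p]]
                h_t = hs.sum(axis=0)
                i, o = gate("i", X[p], h_t, sigmoid), gate("o", X[p], h_t, sigmoid)
                u = gate("u", X[p], h_t, np.tanh)
                f = gate("f", X[p], hs, sigmoid)  # (n_children, h_size)
                C[p] = i * u + (f * cs).sum(axis=0)
                H[p] = o * np.tanh(C[p])
        done |= ready
    return H
```

The batched branch of this sketch trades memory for speed. Its forget-gate tensor has shape (m, n, h_size) for a level of m nodes in a graph of n nodes, so a practical version would more likely gather only the actual (parent, child) pairs.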

Parameters
  • x_size (int) – Dimension of the node embedding vector

  • h_size (int) – Dimension of the hidden state and the memory cell vectors

See also

https://arxiv.org/abs/1503.00075

forward(x, edge_index)

Compute the forward pass.

Parameters
  • x – The batched node feature maps for pooling.

  • edge_index – The batched adjacency matrices.

Returns

h – The pooled features, i.e. the hidden states of the root nodes

Return type

torch.Tensor
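The pooling forward pass described above can be sketched in NumPy as follows. It is not the module's implementation. The function name `tree_lstm_pool`, the extra `h_size` argument, and the edge convention (each column (p, c) of edge_index is an edge from parent p to child c) are assumptions. Any callable with the TreeLSTMCell.forward signature can be passed as `cell`. Roots are taken to be nodes that are nobody's child, so a batch of disjoint trees yields one row per tree.

```python
import numpy as np

def tree_lstm_pool(x, edge_index, cell, h_size):
    """Sketch of child-sum Tree-LSTM pooling over a batch of disjoint trees.

    Assumed (hypothetical) convention: edge_index has shape (2, n_edges) and
    each column (p, c) is an edge from parent p to child c.
    cell(x_node, hs, cs) must return (h, c) with the TreeLSTMCell.forward shapes.
    Returns the hidden states of all root nodes, one row per tree.
    """
    n_nodes = x.shape[0]
    adj = np.zeros((n_nodes, n_nodes), dtype=bool)
    adj[edge_index[0], edge_index[1]] = True  # adj[p, c]: c is a child of p

    H = np.zeros((n_nodes, h_size))  # hidden states
    C = np.zeros((n_nodes, h_size))  # memory cells
    done = np.zeros(n_nodes, dtype=bool)
    while not done.all():
        # Next topological level: unprocessed nodes with no unprocessed child
        ready = ~done & ~(adj & ~done).any(axis=1)
        if not ready.any():
            raise ValueError("edge_index does not describe a forest")
        for node in np.flatnonzero(ready):  # sequential loop over the level
            kids = adj[node]  # boolean mask of this node's children
            H[node], C[node] = cell(x[node], H[kids], C[kids])
        done |= ready

    is_root = ~adj.any(axis=0)  # roots are nobody's child
    return H[is_root]

# Usage with a toy stand-in cell whose hidden state counts subtree nodes
def subtree_size_cell(x_node, hs, cs):
    return 1.0 + hs.sum(axis=0), cs.sum(axis=0)

# Two trees in one batch: 0 -> {1, 2}, 2 -> 3 and 4 -> 5
edge_index = np.array([[0, 0, 2, 4], [1, 2, 3, 5]])
pooled = tree_lstm_pool(np.zeros((6, 1)), edge_index, subtree_size_cell, h_size=1)
# pooled[:, 0] == [4.0, 2.0]: one row per root, in node-index order
```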

static topologically_sorted(adj)

Topologically sort nodes in a given tree.

Topological sorting here means that we start by yielding all leaf nodes. Next, all nodes whose children have all already been processed are yielded. This is repeated until all nodes have been seen.

At each iteration step, the set of nodes that come next in the topological order is yielded, in the form of an equivalent boolean node mask.

Parameters

adj (matrix_like) – The adjacency matrix describing the tree to be sorted

Yields

active_nodes (array_like of type bool) – Node mask for the current set of nodes in a sorted tree.
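A generator with the described behaviour might look like the following NumPy sketch. It is not the module's source. It assumes that adj[i, j] is nonzero when node j is a child of node i. That orientation is not stated above, so transpose adj if the opposite convention holds.

```python
import numpy as np

def topologically_sorted(adj):
    """Yield boolean node masks: leaves first, then every node whose children
    have all been yielded already, until all nodes have been seen.

    Assumed (hypothetical) convention: adj[i, j] != 0 means node j is a child
    of node i.
    """
    adj = np.asarray(adj) != 0
    seen = np.zeros(adj.shape[0], dtype=bool)
    while not seen.all():
        # A node becomes active once none of its children is still unseen
        has_unseen_child = (adj & ~seen).any(axis=1)
        active_nodes = ~seen & ~has_unseen_child
        if not active_nodes.any():
            raise ValueError("adjacency matrix contains a cycle")
        yield active_nodes
        seen |= active_nodes

# Tree with edges 0 -> 1, 0 -> 2, 2 -> 3
adj = np.zeros((4, 4), dtype=int)
adj[0, 1] = adj[0, 2] = adj[2, 3] = 1
masks = list(topologically_sorted(adj))
# masks: leaves {1, 3}, then {2}, then the root {0}
```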

training: bool