CNN
The overall idea of training a convolutional neural network (CNN) is quite similar to the one for the GNN described in section GNN.
At the moment there are some differences in the interface that arose for historical reasons and might be adjusted in the future if necessary.
Loading Data
Unlike GNNs, CNNs operate on images rather than graphs. Therefore the morphologies need to be transformed into some kind of 2D representation. Luckily, Lida Kanari has developed a topological framework for doing exactly that; see the TMD package and the corresponding publications for details.
Here we will use this TMD approach to obtain the so-called persistence diagrams of the neurites, and then apply Gaussian kernel-density estimation to generate images from the diagrams. The following sample code loads morphology data together with their TMD persistence diagrams and images:
import numpy as np
import torch
from tmd.Topology.analysis import get_persistence_image_data

from morphoclass.data import MorphologyDataset
from morphoclass.features.non_graph import get_tmd_diagrams
from morphoclass.transforms import BranchingOnlyNeurites
from morphoclass.transforms import Compose
from morphoclass.transforms import ExtractEdgeIndex
from morphoclass.transforms import ExtractTMDNeurites


def load_persistence_dataset(input_csv):
    # Pre-processing transformations
    pre_transform = [
        ExtractTMDNeurites(neurite_type="apical"),
        BranchingOnlyNeurites(),
        ExtractEdgeIndex(),
    ]
    pre_transform = Compose(pre_transform)

    # Load neurites
    dataset = MorphologyDataset.from_csv(
        csv_file=input_csv,
        pre_transform=pre_transform,
    )

    # Attach TMD diagram to each sample
    for data in dataset:
        data.num_nodes = sum(len(tree.p) for tree in data.tmd_neurites)
    neurite_collection = [data.tmd_neurites for data in dataset]
    tmd_diagrams = get_tmd_diagrams(
        neurite_collection, feature="projection"
    )  # or feature="radial_distance"

    # Normalize TMD diagrams
    xmin, ymin = np.stack([d.min(axis=0) for d in tmd_diagrams]).min(axis=0)
    xmax, ymax = np.stack([d.max(axis=0) for d in tmd_diagrams]).max(axis=0)
    xscale = max(abs(xmax), abs(xmin))
    yscale = max(abs(ymax), abs(ymin))
    scale = np.array([[xscale, yscale]])
    for sample, diagram in zip(dataset, tmd_diagrams):
        sample.diagram = torch.tensor(diagram / scale).float()

    # Attach TMD images
    xmin_norm = min(xmin, 0)
    ymin_norm = min(ymin, 0)
    for sample, diagram in zip(dataset, tmd_diagrams):
        image = get_persistence_image_data(
            diagram,
            xlims=(xmin_norm, xmax),
            ylims=(ymin_norm, ymax),
        )
        image = np.rot90(image)[np.newaxis, np.newaxis]  # shape = (batch, c, w, h)
        sample.image = torch.tensor(image.copy()).float()

    return dataset
As described in section Data, we first create a MorphologyDataset instance, and then attach the corresponding persistence diagram and image to each sample in this dataset.
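To give a feeling for how a persistence diagram turns into an image, here is a minimal NumPy-only sketch of kernel-density estimation on a grid. It is not the TMD implementation: the helper name persistence_image, the fixed isotropic bandwidth, the grid resolution and the toy diagram are all assumptions made for illustration.

```python
import numpy as np


def persistence_image(diagram, xlims, ylims, resolution=100, bandwidth=0.1):
    """Sum isotropic Gaussians centred on the diagram points over a grid.

    Hypothetical helper for illustration only; not the TMD implementation.
    """
    diagram = np.asarray(diagram, dtype=float)
    xs = np.linspace(xlims[0], xlims[1], resolution)
    ys = np.linspace(ylims[0], ylims[1], resolution)
    grid_x, grid_y = np.meshgrid(xs, ys, indexing="ij")  # both (res, res)

    # Squared distance from every grid node to every diagram point
    dx = grid_x[..., np.newaxis] - diagram[:, 0]
    dy = grid_y[..., np.newaxis] - diagram[:, 1]
    sq_dist = dx**2 + dy**2  # shape (res, res, n_points)

    density = np.exp(-sq_dist / (2 * bandwidth**2)).sum(axis=-1)
    return density / density.max()  # scale so that the peak equals 1


# Toy diagram with three made-up points, one point per row
diagram = np.array([[0.2, 0.8], [0.4, 0.5], [0.1, 0.3]])
image = persistence_image(diagram, xlims=(0, 1), ylims=(0, 1))
print(image.shape)  # (100, 100)
```

Each point of the diagram contributes one Gaussian bump, so points that lie close together reinforce each other in the resulting image.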
Training
The convolutional model we propose in this package is defined in morphoclass.models.CNNet. Here’s an example of how it can be trained on persistence images:
import numpy as np
import torch
from tqdm import tqdm

from morphoclass.data.morphology_data_loader import MorphologyDataLoader
from morphoclass.models import CNNet
from morphoclass.training import reset_seeds
from morphoclass.training.trainers import Trainer

dataset = load_persistence_dataset(input_csv_train)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
labels = torch.tensor([s.y for s in dataset]).to(device)
label_to_y = dataset.label_to_y
labels_unique_str = sorted(label_to_y, key=lambda label: label_to_y[label])
n_classes = len(labels_unique_str)

reset_seeds(numpy_seed=0, torch_seed=0)
model = CNNet(n_classes=n_classes, image_size=100)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-3, weight_decay=5e-4)
trainer = Trainer(model, dataset, optimizer, MorphologyDataLoader)

train_idx = torch.arange(len(dataset))
val_idx = torch.arange(0)
history = trainer.train(
    n_epochs=100,
    batch_size=2,
    train_idx=train_idx,
    val_idx=None,
    progress_bar=tqdm,
)
The main difference from the GNN case is that the trainer class accepts a set of train and validation indices.
The logic here is that one can load a set of morphologies that contains both the train and
validation sets and then specify which of the morphologies should be used in training and which
in validation by providing train_idx and val_idx, which are sequences of indices.
Here we just want to train on the whole set, so train_idx contains all indices and no validation
set is used: the example defines an empty val_idx and passes val_idx=None to trainer.train.
Otherwise the code should be straightforward and self-explanatory. After running it, the model instance is trained and can be used for prediction.
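If a validation set is wanted, the two index sequences have to be disjoint and together cover the dataset. Here is a minimal sketch of such a split using NumPy. The dataset size, the 20% hold-out fraction and the seed are arbitrary assumptions, and whether the trainer accepts NumPy arrays directly or needs them converted to tensors first is not covered here.

```python
import numpy as np

n_samples = 10  # stand-in for len(dataset)
rng = np.random.default_rng(0)
permutation = rng.permutation(n_samples)

n_val = n_samples // 5  # hold out 20% for validation (arbitrary choice)
val_idx = np.sort(permutation[:n_val])
train_idx = np.sort(permutation[n_val:])

# The two index sets are disjoint and together cover the whole dataset
assert set(train_idx).isdisjoint(val_idx)
assert len(train_idx) + len(val_idx) == n_samples
```

Shuffling before splitting avoids a validation set that only contains the last rows of the input file, which matters if the file is sorted by class.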
Evaluating
Unlike for GNNs, the evaluation of the CNN has to be done in a more manual way. This may change in the future. Let’s first look at the code and then comment on it:
from morphoclass.data.morphology_data_loader import MorphologyDataLoader

val_idx = torch.arange(len(dataset))
data_loader = MorphologyDataLoader(dataset.index_select(val_idx))

model.eval()
logits = []
with torch.no_grad():
    for batch in data_loader:
        batch = batch.to(device)
        out = model(batch)
        logits.append(out)
logits = torch.cat(logits)

# Compute predictions and accuracy
predictions = logits.argmax(dim=1)
acc_train = (predictions == labels).float().mean()
print(f"Accuracy: {acc_train * 100:.2f}%")
As you can see, one needs to manually loop through the data by creating a data loader. As for the GNN, the model outputs logits, i.e. logarithms of the probabilities over the classes. These can be transformed into actual predictions by taking the arg-max, just as before.
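The step from log-probabilities to predictions can be illustrated in isolation with a small NumPy sketch. All numbers and the class names below are made up; class_names only plays the role that labels_unique_str has in the training code.

```python
import numpy as np

# Made-up log-probabilities for 3 samples and 2 classes
log_probs = np.log(np.array([[0.9, 0.1], [0.3, 0.7], [0.6, 0.4]]))
labels = np.array([0, 1, 1])
class_names = ["A", "B"]  # hypothetical class names

probabilities = np.exp(log_probs)  # each row sums to 1
predictions = log_probs.argmax(axis=1)  # same arg-max as for the probabilities
predicted_names = [class_names[i] for i in predictions]
accuracy = (predictions == labels).mean()

print(predicted_names)  # ['A', 'B', 'A']
print(f"Accuracy: {accuracy * 100:.2f}%")  # Accuracy: 66.67%
```

Because the exponential is monotonic, taking the arg-max of the log-probabilities gives the same class as taking it over the probabilities themselves.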