morphoclass.training.cli module

Utilities for the morphoclass train command.

morphoclass.training.cli.ask(msg: str) → bool

Ask an interactive question in the terminal.
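The source documents only the signature, so the following is a minimal sketch of what a yes/no terminal question could look like. The `input_fn` parameter and the name `ask_sketch` are hypothetical additions (they make the prompt testable without a real terminal), and the exact accepted answers are an assumption, not the documented behaviour of `ask`.

```python
from typing import Callable


def ask_sketch(msg: str, input_fn: Callable[[str], str] = input) -> bool:
    """Hypothetical sketch: repeat the prompt until the user answers yes or no."""
    while True:
        answer = input_fn(f"{msg} [y/n] ").strip().lower()
        if answer in ("y", "yes"):
            return True
        if answer in ("n", "no"):
            return False
        # Any other reply is rejected and the question is asked again.
        print("Please answer 'y' or 'n'.")
```

Injecting the input function keeps the loop logic independent of the terminal, so it can be exercised with scripted replies.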

morphoclass.training.cli.collect_metrics(y_true: numpy.ndarray, y_pred: numpy.ndarray, target_names: Sequence[str]) → dict

Collect different evaluation metrics.

Parameters
  • y_true (1d array-like) – Ground truth labels.

  • y_pred (1d array-like) – Predicted labels.

  • target_names (array-like of shape (n_labels,)) – Names of the classes.

Returns

metrics_dict – The dictionary with all computed metrics. The keys are:

  • classification_report: The sklearn classification report.

  • confusion_matrix: The sklearn confusion matrix.

  • accuracy: The accuracy score.

  • f1_micro: The micro-averaged F1-score.

  • f1_macro: The macro-averaged F1-score.

  • f1_weighted: The support-weighted average of the per-class F1-scores.

Return type

dict
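A dictionary with the keys listed above can be assembled directly from `sklearn.metrics`. The sketch below is an illustration under the assumption that labels are integers `0 … n_labels - 1` indexing into `target_names`; `collect_metrics_sketch` is a hypothetical name and the real function may pass different options to sklearn.

```python
from typing import Sequence

import numpy as np
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    f1_score,
)


def collect_metrics_sketch(
    y_true: np.ndarray, y_pred: np.ndarray, target_names: Sequence[str]
) -> dict:
    """Hypothetical sketch returning the metric keys documented for collect_metrics."""
    # Assumption: label i corresponds to target_names[i].
    labels = np.arange(len(target_names))
    return {
        "classification_report": classification_report(
            y_true, y_pred, labels=labels, target_names=list(target_names), zero_division=0
        ),
        "confusion_matrix": confusion_matrix(y_true, y_pred, labels=labels),
        "accuracy": accuracy_score(y_true, y_pred),
        "f1_micro": f1_score(y_true, y_pred, labels=labels, average="micro"),
        "f1_macro": f1_score(y_true, y_pred, labels=labels, average="macro"),
        "f1_weighted": f1_score(y_true, y_pred, labels=labels, average="weighted"),
    }
```

For single-label multi-class data where every label is included, the micro-averaged F1-score coincides with the accuracy, which is a convenient sanity check.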

morphoclass.training.cli.get_model(config, pretrained_state_dict, n_classes)

Reconstruct the model from the config file.

morphoclass.training.cli.oversample(ids, labels, random_state=None)

Oversample to balance the label count.
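One common way to balance label counts is random oversampling with replacement: every class is topped up with randomly repeated members until it matches the largest class. The following is a minimal sketch of that technique, assuming `ids` and `labels` are parallel 1D sequences; the name `oversample_sketch` and the `(ids, labels)` return value are assumptions, not the documented return type of `oversample`.

```python
import numpy as np


def oversample_sketch(ids, labels, random_state=None):
    """Hypothetical sketch: randomly repeat samples so every class matches the largest one."""
    rng = np.random.default_rng(random_state)
    ids = np.asarray(ids)
    labels = np.asarray(labels)
    classes, counts = np.unique(labels, return_counts=True)
    n_target = counts.max()

    new_ids, new_labels = [], []
    for cls, count in zip(classes, counts):
        cls_ids = ids[labels == cls]
        # Draw the missing number of samples with replacement from this class.
        extra = rng.choice(cls_ids, size=n_target - count, replace=True)
        new_ids.append(np.concatenate([cls_ids, extra]))
        new_labels.append(np.full(n_target, cls))
    return np.concatenate(new_ids), np.concatenate(new_labels)
```

Oversampling of this kind should only be applied to the training indices, so that repeated samples cannot leak into the validation set.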

morphoclass.training.cli.plot_confusion_matrices(img_dir: pathlib.Path | None, training_log: TrainingLog) → None

Plot confusion matrices.

morphoclass.training.cli.prune_one_member_classes(dataset)

Prune classes with only one member.
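The structure of `dataset` is not documented here, so the sketch below works on a plain list of labels and returns the indices of the samples to keep. This is an assumption made for illustration; `prune_one_member_classes_sketch` is a hypothetical name. One plausible motivation for such pruning is that stratified splitting (for example `sklearn.model_selection.train_test_split` with `stratify=`) fails when a class has only a single member.

```python
from collections import Counter
from typing import Hashable, List, Sequence


def prune_one_member_classes_sketch(labels: Sequence[Hashable]) -> List[int]:
    """Hypothetical sketch: keep only samples whose class has at least two members."""
    counts = Counter(labels)
    return [i for i, label in enumerate(labels) if counts[label] > 1]
```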

morphoclass.training.cli.run_training(dataset: morphoclass.data.morphology_dataset.MorphologyDataset, config: morphoclass.training.training_config.TrainingConfig) → morphoclass.training.training_log.TrainingLog

Train and evaluate the model.

morphoclass.training.cli.split_metrics(splits: Sequence[dict]) → dict[str, float]

Compute average metrics across splits.
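A minimal sketch of averaging metrics across splits, assuming each split is a dictionary shaped like the output of `collect_metrics` and that only scalar numeric entries (such as `accuracy` or `f1_macro`) are averaged while non-scalar entries (reports, matrices) are skipped. The name `split_metrics_sketch` and the skipping rule are assumptions; the real function may also report spreads or other statistics.

```python
import statistics
from numbers import Real
from typing import Dict, Sequence


def split_metrics_sketch(splits: Sequence[dict]) -> Dict[str, float]:
    """Hypothetical sketch: average every scalar numeric metric over all splits."""
    # Assumption: all splits share the same keys as the first one.
    scalar_keys = [
        key
        for key, value in splits[0].items()
        if isinstance(value, Real) and not isinstance(value, bool)
    ]
    return {
        key: statistics.mean(float(split[key]) for split in splits)
        for key in scalar_keys
    }
```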

morphoclass.training.cli.train_dm_model(model: torch.nn.Module, dataset: MorphologyDataset, train_idx: torch_geometric.data.dataset.IndexType, val_idx: torch_geometric.data.dataset.IndexType, optimizer: torch.optim.Optimizer, batch_size: int, n_epochs: int, interactive: bool = False) → dict[str, Any]

Train morphoclass models.

morphoclass.training.cli.train_ml_model(model: sklearn.base.BaseEstimator, dataset: MorphologyDataset, train_idx: torch_geometric.data.dataset.IndexType, val_idx: torch_geometric.data.dataset.IndexType) → dict[str, Any]

Train a sklearn-like model.

Parameters
  • model – A sklearn-like model with methods fit, predict and predict_proba.

  • dataset – A morphology dataset.

  • train_idx – The indices of the training set.

  • val_idx – The indices of the validation set.

Returns

The training history with the keys “model”, “predictions”, “ground_truths”, “probabilities”, “latent_features”.

Return type

dict
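The fit/predict flow described above can be sketched with any sklearn classifier. Because the internals of `MorphologyDataset` are not documented here, this sketch assumes the dataset has already been converted into a 2D feature array and a 1D label array; `train_ml_model_sketch` is a hypothetical name, and storing the raw validation features under `"latent_features"` is a placeholder assumption rather than the documented behaviour.

```python
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.tree import DecisionTreeClassifier


def train_ml_model_sketch(
    model: BaseEstimator,
    features: np.ndarray,
    labels: np.ndarray,
    train_idx,
    val_idx,
) -> dict:
    """Hypothetical sketch: fit on the training split, evaluate on the validation split."""
    x_train, y_train = features[train_idx], labels[train_idx]
    x_val, y_val = features[val_idx], labels[val_idx]
    model.fit(x_train, y_train)
    return {
        "model": model,
        "predictions": model.predict(x_val),
        "ground_truths": y_val,
        "probabilities": model.predict_proba(x_val),
        # Placeholder assumption: a plain sklearn model has no learned latent space.
        "latent_features": x_val,
    }


# Usage: a tiny, clearly separable 1D problem.
features = np.array([[0.0], [0.1], [0.2], [1.0], [1.1], [1.2]])
labels = np.array([0, 0, 0, 1, 1, 1])
history = train_ml_model_sketch(
    DecisionTreeClassifier(random_state=0), features, labels, [0, 1, 3, 4], [2, 5]
)
```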

morphoclass.training.cli.train_model(config, pretrained_model, train_idx, val_idx, dataset)

Train a model.