morphoclass.training.cli module
Utilities for the morphoclass train command.
-
morphoclass.training.cli.ask(msg: str) → bool
Ask an interactive question in the terminal.
-
morphoclass.training.cli.collect_metrics(y_true: numpy.ndarray, y_pred: numpy.ndarray, target_names: Sequence[str]) → dict
Collect several evaluation metrics.
- Parameters
y_true (1d array-like) – Ground-truth labels.
y_pred (1d array-like) – Predicted labels.
target_names (array-like of shape (n_labels,)) – Names of the classes.
- Returns
metrics_dict – The dictionary with all computed metrics. The keys are:
classification_report: The sklearn classification report.
confusion_matrix: The sklearn confusion matrix.
accuracy: The accuracy score.
f1_micro: The micro-averaged F1-score.
f1_macro: The macro-averaged F1-score.
f1_weighted: The weighted-average F1-score.
- Return type
dict
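The confusion matrix and the averaged F1-scores listed above can be computed from first principles. The sketch below is a dependency-free illustration, not the morphoclass implementation (which, per the description above, relies on scikit-learn); the function name metrics_sketch is hypothetical. It follows scikit-learn's conventions as far as this sketch goes: labels are the union of those seen in y_true and y_pred, the confusion matrix has true labels on rows and predicted labels on columns, and the weighted average weights each class by its support in y_true.

```python
from collections import Counter
from typing import Hashable, Sequence


def metrics_sketch(y_true: Sequence[Hashable], y_pred: Sequence[Hashable]) -> dict:
    """Hypothetical stand-in for collect_metrics using only the standard library."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("y_true and y_pred must be non-empty and of equal length")
    # Label order for matrix rows/columns: every label seen in either array.
    labels = sorted(set(y_true) | set(y_pred))
    index = {label: i for i, label in enumerate(labels)}

    # Confusion matrix: rows are true labels, columns are predicted labels.
    cm = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred):
        cm[index[t]][index[p]] += 1

    support = Counter(y_true)
    f1_per_class = []
    for i in range(len(labels)):
        tp = cm[i][i]
        fp = sum(cm[r][i] for r in range(len(labels))) - tp  # predicted as this class, but wrong
        fn = sum(cm[i]) - tp  # truly this class, but missed
        denom = 2 * tp + fp + fn
        f1_per_class.append(2 * tp / denom if denom else 0.0)

    n = len(y_true)
    accuracy = sum(cm[i][i] for i in range(len(labels))) / n
    return {
        "confusion_matrix": cm,
        "accuracy": accuracy,
        # For single-label multiclass data, micro-averaged F1 equals accuracy.
        "f1_micro": accuracy,
        "f1_macro": sum(f1_per_class) / len(labels),
        "f1_weighted": sum(f * support[label] for f, label in zip(f1_per_class, labels)) / n,
    }
```

Note that macro averaging treats every class equally, so rare classes pull the score down as much as common ones, while the weighted average is dominated by the most populated classes.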
-
morphoclass.training.cli.get_model(config, pretrained_state_dict, n_classes)
Reconstruct the model from the config file.
-
morphoclass.training.cli.oversample(ids, labels, random_state=None)
Oversample to balance the label counts.
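A common way to balance label counts is random oversampling: duplicate randomly chosen members of every smaller class until each class is as large as the biggest one. The sketch below illustrates that idea under stated assumptions: it works on plain lists, returns a new (ids, labels) pair, and uses random_state to seed Python's random.Random. The name oversample_sketch and its return value are hypothetical; the documentation above does not say what oversample returns or which sampling scheme it uses.

```python
import random
from collections import defaultdict
from typing import Hashable, Optional, Sequence


def oversample_sketch(
    ids: Sequence[int],
    labels: Sequence[Hashable],
    random_state: Optional[int] = None,
) -> tuple[list[int], list[Hashable]]:
    """Hypothetical random oversampling: grow every class to the size of the largest."""
    rng = random.Random(random_state)
    # Group sample ids by label, keeping first-seen label order.
    members: dict = defaultdict(list)
    for sample_id, label in zip(ids, labels):
        members[label].append(sample_id)

    target = max(len(group) for group in members.values())
    new_ids: list[int] = []
    new_labels: list[Hashable] = []
    for label, group in members.items():
        # Keep every original sample, then draw the missing ones with replacement.
        extra = rng.choices(group, k=target - len(group))
        new_ids.extend(group + extra)
        new_labels.extend([label] * target)
    return new_ids, new_labels
```

Passing the same random_state makes the duplicated samples reproducible, which matters when comparing runs.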
-
morphoclass.training.cli.plot_confusion_matrices(img_dir: pathlib.Path | None, training_log: TrainingLog) → None
Plot confusion matrices.
-
morphoclass.training.cli.prune_one_member_classes(dataset)
Prune classes with only one member.
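Singleton classes get in the way of stratified splitting; scikit-learn's stratified shuffle split, for example, raises an error when the least populated class has only one member. The sketch below shows only the filtering step on a plain list of labels. The helper prune_singletons_sketch is hypothetical and returns the indices to keep, whereas the real function takes a dataset object whose interface is not documented here.

```python
from collections import Counter
from typing import Hashable, Sequence


def prune_singletons_sketch(labels: Sequence[Hashable]) -> list[int]:
    """Hypothetical helper: indices of samples whose class has at least two members."""
    counts = Counter(labels)
    return [i for i, label in enumerate(labels) if counts[label] > 1]
```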
-
morphoclass.training.cli.run_training(dataset: morphoclass.data.morphology_dataset.MorphologyDataset, config: morphoclass.training.training_config.TrainingConfig) → morphoclass.training.training_log.TrainingLog
Train and evaluate the model.
-
morphoclass.training.cli.split_metrics(splits: Sequence[dict]) → dict[str, float]
Compute average metrics across splits.
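Averaging across splits amounts to taking, for each metric, the mean of its values over all splits. A minimal sketch, assuming each split is a dict mapping metric names to values and that non-scalar entries (such as a classification report or a confusion matrix) should be skipped; the helper name average_split_metrics is hypothetical, and the real function may select different keys.

```python
from numbers import Real
from statistics import fmean
from typing import Any, Sequence


def average_split_metrics(splits: Sequence[dict[str, Any]]) -> dict[str, float]:
    """Hypothetical sketch: mean of every scalar metric present in all splits."""
    if not splits:
        raise ValueError("need at least one split")
    # Only keys shared by every split can be averaged over all of them.
    common = set(splits[0]).intersection(*splits[1:])
    averages = {}
    for key in sorted(common):
        values = [split[key] for split in splits]
        # Skip non-scalar entries such as reports or confusion matrices.
        if all(isinstance(v, Real) and not isinstance(v, bool) for v in values):
            averages[key] = fmean(values)
    return averages
```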
-
morphoclass.training.cli.train_dm_model(model: torch.nn.Module, dataset: MorphologyDataset, train_idx: torch_geometric.data.dataset.IndexType, val_idx: torch_geometric.data.dataset.IndexType, optimizer: torch.optim.Optimizer, batch_size: int, n_epochs: int, interactive: bool = False) → dict[str, Any]
Train a morphoclass PyTorch model.
-
morphoclass.training.cli.train_ml_model(model: sklearn.base.BaseEstimator, dataset: MorphologyDataset, train_idx: torch_geometric.data.dataset.IndexType, val_idx: torch_geometric.data.dataset.IndexType) → dict[str, Any]
Train a sklearn-like model.
- Parameters
model – A sklearn-like model with methods fit, predict and predict_proba.
dataset – A morphology dataset.
train_idx – The indices of the training set.
val_idx – The indices of the validation set.
- Returns
The training history with the keys “model”, “predictions”, “ground_truths”, “probabilities”, “latent_features”.
- Return type
dict
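The documented requirement on model is the sklearn-style duck-typed interface: fit(X, y), predict(X) and predict_proba(X). The sketch below illustrates that contract with a toy majority-class classifier and a hypothetical fit_and_evaluate function that fits on the training indices and collects validation outputs under some of the history keys listed above. It assumes features and labels are plain lists rather than a MorphologyDataset, and it leaves out "latent_features", whose contents are not documented here.

```python
from collections import Counter
from typing import Any, Hashable, Sequence


class MajorityClassifier:
    """Toy sklearn-like model: always predicts the most frequent training label."""

    def fit(self, X: Sequence[Any], y: Sequence[Hashable]) -> "MajorityClassifier":
        counts = Counter(y)
        self.classes_ = sorted(counts)
        self.priors_ = [counts[c] / len(y) for c in self.classes_]
        self.majority_ = counts.most_common(1)[0][0]
        return self

    def predict(self, X: Sequence[Any]) -> list[Hashable]:
        return [self.majority_ for _ in X]

    def predict_proba(self, X: Sequence[Any]) -> list[list[float]]:
        # One row per sample, columns ordered like self.classes_.
        return [list(self.priors_) for _ in X]


def fit_and_evaluate(model, X, y, train_idx, val_idx) -> dict[str, Any]:
    """Hypothetical sketch of the fit/predict/predict_proba workflow."""
    X_train = [X[i] for i in train_idx]
    y_train = [y[i] for i in train_idx]
    X_val = [X[i] for i in val_idx]
    model.fit(X_train, y_train)
    return {
        "model": model,
        "predictions": model.predict(X_val),
        "ground_truths": [y[i] for i in val_idx],
        "probabilities": model.predict_proba(X_val),
    }
```

Because any object with those three methods fits the contract, a scikit-learn classifier such as LogisticRegression could take the place of the toy class.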
-
morphoclass.training.cli.train_model(config, pretrained_model, train_idx, val_idx, dataset)
Train a model.