from typing import Dict

import torch
from torchmetrics import functional as FM


def classification_metrics(
    preds: torch.Tensor,
    target: torch.Tensor,
    num_classes: int,
    average: str = 'macro',
    task: str = 'multiclass',
) -> Dict[str, torch.Tensor]:
    """Compute F1, precision and recall for a classification task.

    Parameters
    ----------
    preds : torch.Tensor
        Model predictions (logits or probabilities), in a layout accepted by
        ``torchmetrics`` for the given ``task``.
    target : torch.Tensor
        Ground-truth labels.
    num_classes : int
        Number of classes when ``task='multiclass'``; also forwarded as the
        number of labels when ``task='multilabel'``.
    average : str, default 'macro'
        Reduction over classes/labels, forwarded unchanged to torchmetrics
        (e.g. ``'micro'``, ``'macro'``, ``'weighted'``, ``'none'``).
    task : str, default 'multiclass'
        One of ``'binary'``, ``'multiclass'`` or ``'multilabel'``.

    Returns
    -------
    Dict[str, torch.Tensor]
        Mapping with keys ``'f1'``, ``'precision'`` and ``'recall'``.
    """
    # torchmetrics' task wrappers read ``num_classes`` for multiclass but
    # ``num_labels`` for multilabel (raising if it is missing), so forward the
    # count under both names; the one that does not apply to ``task`` is unused.
    common = dict(
        preds=preds,
        target=target,
        task=task,
        num_classes=num_classes,
        num_labels=num_classes,
        average=average,
    )
    f1 = FM.f1_score(**common)
    precision = FM.precision(**common)
    recall = FM.recall(**common)
    return dict(f1=f1, precision=precision, recall=recall)