본문 바로가기
Data science/Computer Vision

[5분 컷 리뷰] Matthew's correlation coefficient loss

by 연금(Pension)술사 2024. 7. 25.

 

요약


DICE loss에서 반영하지 못하는 background의 오분류를 패널티로 반영하기위한 손실함수. Confusion matrix의 요소들을 직접 사용하여 손실함수를 만듬.

 

DICE vs MCC(Mattew's correlation coefficient, MCC)

DICE loss을 사용하든, IoU(Jaccard loss)을 사용하든 둘 다 TN(True negative)에 대한 정보가 없습니다. 

 

이 True Negative을 반영하기위해서, confusion matrix에서 각 요소를 아래와 같이 만듭니다. 즉, TP와 TN은 많이 맞추고, FP,FN은 적은 confusion matrix을 만들기를 바라며, 이를 정규화하기위한 분모로 만듭니다. 아래를 보면 각 요소들이 분모에 2번씩 사용됩니다.

 

MCC은 미분가능한가?

다행히 MCC은 segmentation에서 pixel wise로 연산하기때문에 미분가능합니다. 각 요소를 아래와 같이 표현 할 수 있기 때문입니다.

 

아래와 같이 1-MCC로 손실함수로 사용하고, 각 항목을 y예측값에 대해서 미분하면됩니다.

 

이 식을 pytorch로 구현하면 아래와 같습니다.

class MCCLosswithLogits(torch.nn.Module):
    """
    Calculates the proposed Matthews Correlation Coefficient-based loss.

    Args:
        inputs (torch.Tensor): 1-hot encoded predictions
        targets (torch.Tensor): 1-hot encoded ground truth
    
    Reference:
        https://github.com/kakumarabhishek/MCC-Loss/blob/main/loss.py
    """

    def __init__(self):
        super(MCCLosswithLogits, self).__init__()

    def forward(self, logits, targets):
        """
        
        Note:
            위의 모든 코드가 logits값을 입력값으로 받고 있어서, logtis->confidence [0,1]으로 변경
            MCC = (TP.TN - FP.FN) / sqrt((TP+FP) . (TP+FN) . (TN+FP) . (TN+FN))
            where TP, TN, FP, and FN are elements in the confusion matrix.
        
        
        """
        pred = torch.sigmoid(logits)
        tp = torch.sum(torch.mul(pred, targets))
        tn = torch.sum(torch.mul((1 - pred), (1 - targets)))
        fp = torch.sum(torch.mul(pred, (1 - targets)))
        fn = torch.sum(torch.mul((1 - pred), targets))

        numerator = torch.mul(tp, tn) - torch.mul(fp, fn)
        denominator = torch.sqrt(
            torch.add(tp, 1, fp)
            * torch.add(tp, 1, fn)
            * torch.add(tn, 1, fp)
            * torch.add(tn, 1, fn)
        )

        # Adding 1 to the denominator to avoid divide-by-zero errors.
        mcc = torch.div(numerator.sum(), denominator.sum() + 1.0)
        return 1 - mcc
반응형