Data science/Computer Vision
[5분 컷 리뷰] Matthew's correlation coefficient loss
연금(Pension)술사
2024. 7. 25. 19:39
요약
DICE loss에서 반영하지 못하는 background의 오분류를 패널티로 반영하기위한 손실함수. Confusion matrix의 요소들을 직접 사용하여 손실함수를 만듬.
DICE vs MCC(Mattew's correlation coefficient, MCC)
DICE loss을 사용하든, IoU(Jaccard loss)을 사용하든 둘 다 TN(True negative)에 대한 정보가 없습니다.
이 True Negative을 반영하기위해서, confusion matrix에서 각 요소를 아래와 같이 만듭니다. 즉, TP와 TN은 많이 맞추고, FP,FN은 적은 confusion matrix을 만들기를 바라며, 이를 정규화하기위한 분모로 만듭니다. 아래를 보면 각 요소들이 분모에 2번씩 사용됩니다.
MCC은 미분가능한가?
다행히 MCC은 segmentation에서 pixel wise로 연산하기때문에 미분가능합니다. 각 요소를 아래와 같이 표현 할 수 있기 때문입니다.
아래와 같이 1-MCC로 손실함수로 사용하고, 각 항목을 y예측값에 대해서 미분하면됩니다.
이 식을 pytorch로 구현하면 아래와 같습니다.
class MCCLosswithLogits(torch.nn.Module):
"""
Calculates the proposed Matthews Correlation Coefficient-based loss.
Args:
inputs (torch.Tensor): 1-hot encoded predictions
targets (torch.Tensor): 1-hot encoded ground truth
Reference:
https://github.com/kakumarabhishek/MCC-Loss/blob/main/loss.py
"""
def __init__(self):
super(MCCLosswithLogits, self).__init__()
def forward(self, logits, targets):
"""
Note:
위의 모든 코드가 logits값을 입력값으로 받고 있어서, logtis->confidence [0,1]으로 변경
MCC = (TP.TN - FP.FN) / sqrt((TP+FP) . (TP+FN) . (TN+FP) . (TN+FN))
where TP, TN, FP, and FN are elements in the confusion matrix.
"""
pred = torch.sigmoid(logits)
tp = torch.sum(torch.mul(pred, targets))
tn = torch.sum(torch.mul((1 - pred), (1 - targets)))
fp = torch.sum(torch.mul(pred, (1 - targets)))
fn = torch.sum(torch.mul((1 - pred), targets))
numerator = torch.mul(tp, tn) - torch.mul(fp, fn)
denominator = torch.sqrt(
torch.add(tp, 1, fp)
* torch.add(tp, 1, fn)
* torch.add(tn, 1, fp)
* torch.add(tn, 1, fn)
)
# Adding 1 to the denominator to avoid divide-by-zero errors.
mcc = torch.div(numerator.sum(), denominator.sum() + 1.0)
return 1 - mcc
반응형