Pytorch: Cách chính xác để sử dụng bản đồ trọng lượng tùy chỉnh trong kiến ​​trúc unet


8

Có một mẹo nổi tiếng trong kiến ​​trúc u-net là sử dụng bản đồ trọng lượng tùy chỉnh để tăng độ chính xác. Dưới đây là các chi tiết của nó-

nhập mô tả hình ảnh ở đây

Bây giờ, bằng cách hỏi ở đây và tại nhiều nơi khác, tôi đã biết về 2 cách tiếp cận. Tôi muốn biết cách nào là đúng hoặc có cách tiếp cận nào khác đúng hơn không?

1) Đầu tiên là sử dụng torch.nn.Functionalphương pháp trong vòng huấn luyện-

loss = torch.nn.functional.cross_entropy(output, target, w) Trong đó w sẽ là trọng lượng tùy chỉnh tính toán.

2) Thứ hai là sử dụng reduction='none'chức năng gọi mất chức năng ngoài vòng huấn luyện criterion = torch.nn.CrossEntropy(reduction='none')

và sau đó trong vòng đào tạo nhân với trọng số tùy chỉnh-

gt # Ground truth, format torch.long
pd # Network output
W # per-element weighting based on the distance map from UNet
loss = criterion(pd, gt)
loss = W*loss # Ensure that weights are scaled appropriately
loss = torch.sum(loss.flatten(start_dim=1), axis=0) # Sums the loss per image
loss = torch.mean(loss) # Average across a batch

Bây giờ, tôi hơi bối rối không biết cái nào đúng hay còn cách nào khác, hay cả hai đều đúng?

Câu trả lời:


3

Phần trọng số trông giống như entropy chéo có trọng số đơn giản được thực hiện như thế này cho số lượng các lớp (2 trong ví dụ dưới đây).

weights = torch.FloatTensor([.3, .7])
loss_func = nn.CrossEntropyLoss(weight=weights)

BIÊN TẬP:

Bạn đã thấy triển khai này từ Patrick Black?

# Set properties
batch_size = 10
out_channels = 2
W = 10
H = 10

# Initialize logits etc. with random
logits = torch.FloatTensor(batch_size, out_channels, H, W).normal_()
target = torch.LongTensor(batch_size, H, W).random_(0, out_channels)
weights = torch.FloatTensor(batch_size, 1, H, W).random_(1, 3)

# Calculate log probabilities
logp = F.log_softmax(logits)

# Gather log probabilities with respect to target
logp = logp.gather(1, target.view(batch_size, 1, H, W))

# Multiply with weights
weighted_logp = (logp * weights).view(batch_size, -1)

# Rescale so that loss is in approx. same interval
weighted_loss = weighted_logp.sum(1) / weights.view(batch_size, -1).sum(1)

# Average over mini-batch
weighted_loss = -1. * weighted_loss.mean()

Vấn đề là trọng lượng được tính theo một chức năng nhất định ở đây và không kín đáo. Để biết thêm thông tin, đây là một bài báo - arxiv.org/abs/1505.04597
Đánh dấu

1
@Mark oh tôi thấy bây giờ. Vì vậy, nó là một đầu ra mất pixelwise. Và các đường viền được tính toán trước bằng cách sử dụng một số thư viện như opencvhoặc một cái gì đó, và sau đó các vị trí pixel đó được lưu cho mỗi hình ảnh và sau đó nhân với các thang đo mất sau này trong quá trình đào tạo để thuật toán tập trung vào việc giảm tổn thất ở các khu vực đó.
jchaykow

Thanks.this hợp pháp này trông giống như một câu trả lời, tôi sẽ thử xác minh và thực hiện nó nhiều hơn và sẽ chấp nhận câu trả lời của bạn sau nó.
Đánh dấu

Bạn có thể giải thích trực giác đằng sau dòng nàylogp = logp.gather(1, target.view(batch_size, 1, H, W))
Đánh dấu

0

Lưu ý rằng Torch.nn.CrossEntropyLoss () là một lớp gọi Torch.nn.feftal. Xem https://pytorch.org/docs/urdy/_modules/torch/nn/modules/loss.html#CrossEntropyLoss

Bạn có thể sử dụng các trọng số khi bạn xác định các tiêu chí. So sánh chúng theo chức năng, cả hai phương pháp đều giống nhau.

Bây giờ, tôi không hiểu ý tưởng của bạn về mất điện toán trong vòng đào tạo trong phương pháp 1 và bên ngoài vòng đào tạo trong phương pháp 2. nếu bạn tính toán mất bên ngoài vòng lặp thì bạn sẽ sao lưu như thế nào?


Tôi đã không nhầm lẫn giữa việc sử dụng torch.nn.CrossEntropyLoss() torch.nn.functional.cross_entropy(output, target, w), tôi đã bối rối làm thế nào để sử dụng bản đồ trọng lượng tùy chỉnh trong sự mất mát. Vui lòng xem bài viết này - arxiv.org/abs/1505.04597 và cho tôi biết, nếu bạn vẫn không thể tìm ra tôi là gì hỏi
Đánh dấu

1
Nếu tôi hiểu đúng, tôi nghĩ phương pháp 2 là đúng. Các trọng số (w) bên trong đèn pin mất.nn.feftal.cross_entropy (đầu ra, đích, w) là các trọng số cho các lớp không w (x) trong công thức. Chúng ta có thể dễ dàng kiểm tra nó với một kịch bản nhỏ.
Devansh Bisla

Đúng, ngay cả tôi cũng đi đến kết luận tương tự. Tôi sẽ quay lại với bạn nếu mạng của tôi chạy như mong đợi và sẽ đánh dấu câu trả lời là được chấp nhận.
Đánh dấu

được rồi, nó không hoạt động. Tôi nhận được grad can be implicitly created only for scalar outputskhi tôi chạy loss = loss * w method
Mark

Bạn có chắc chắn rằng bạn đang tóm tắt chúng hoặc lấy ý nghĩa?
Devansh Bisla
Khi sử dụng trang web của chúng tôi, bạn xác nhận rằng bạn đã đọc và hiểu Chính sách cookieChính sách bảo mật của chúng tôi.
Licensed under cc by-sa 3.0 with attribution required.