Tại sao chúng ta cần gọi zero_grad () trong PyTorch?


Câu trả lời:


144

Trong đó PyTorch, chúng ta cần đặt các gradient thành 0 trước khi bắt đầu thực hiện backpropragation vì PyTorch tích lũy các gradient trong các lần đi lùi tiếp theo. Điều này thuận tiện trong khi huấn luyện RNN. Vì vậy, hành động mặc định là tích lũy (tức là tổng) các gradient trên mỗi loss.backward()cuộc gọi.

Vì lý do này, khi bạn bắt đầu vòng lặp đào tạo của mình, lý tưởng nhất là bạn nên zero out the gradientscập nhật tham số một cách chính xác. Ngược lại, gradient sẽ hướng theo một số hướng khác với hướng dự định về phía cực tiểu (hoặc cực đại , trong trường hợp mục tiêu tối đa hóa).

Đây là một ví dụ đơn giản:

import torch
from torch.autograd import Variable
import torch.optim as optim

def linear_model(x, W, b):
    return torch.matmul(x, W) + b

data, targets = ...

W = Variable(torch.randn(4, 3), requires_grad=True)
b = Variable(torch.randn(3), requires_grad=True)

optimizer = optim.Adam([W, b])

for sample, target in zip(data, targets):
    # clear out the gradients of all Variables 
    # in this optimizer (i.e. W, b)
    optimizer.zero_grad()
    output = linear_model(sample, W, b)
    loss = (output - target) ** 2
    loss.backward()
    optimizer.step()

Ngoài ra, nếu bạn đang thực hiện giảm độ dốc màu vani , thì:

W = Variable(torch.randn(4, 3), requires_grad=True)
b = Variable(torch.randn(3), requires_grad=True)

for sample, target in zip(data, targets):
    # clear out the gradients of Variables 
    # (i.e. W, b)
    W.grad.data.zero_()
    b.grad.data.zero_()

    output = linear_model(sample, W, b)
    loss = (output - target) ** 2
    loss.backward()

    W -= learning_rate * W.grad.data
    b -= learning_rate * b.grad.data

Lưu ý : Sự tích lũy (tức là tổng ) của các gradient xảy ra khi .backward()được gọi trên losstensor .


3
cảm ơn bạn rất nhiều, điều này thực sự hữu ích! Bạn có tình cờ biết liệu tensorflow có hành vi không?
layser

Chỉ để chắc chắn rằng .. nếu bạn không làm điều này, thì bạn sẽ gặp phải vấn đề về gradient bùng nổ, phải không?
zwep

2
@zwep Nếu chúng ta tích lũy các gradient, điều đó không có nghĩa là độ lớn của chúng tăng lên: một ví dụ sẽ là nếu dấu hiệu của gradient tiếp tục lật. Vì vậy, nó sẽ không đảm bảo bạn sẽ gặp phải vấn đề gradient bùng nổ. Bên cạnh đó, các gradient bùng nổ vẫn tồn tại ngay cả khi bạn 0 chính xác.
Tom Roth

Khi bạn chạy vani gradient descent bạn không nhận được lỗi "Biến lá yêu cầu grad đã được sử dụng trong hoạt động tại chỗ" khi bạn cố gắng cập nhật các trọng số?
MUAS

1

zero_grad () đang khởi động lại vòng lặp mà không có tổn thất nào từ bước trước nếu bạn sử dụng phương pháp gradient để giảm lỗi (hoặc tổn thất)

nếu bạn không sử dụng zero_grad (), tổn thất sẽ giảm xuống không tăng theo yêu cầu

Ví dụ: nếu bạn sử dụng zero_grad (), bạn sẽ tìm thấy kết quả sau:

model training loss is 1.5
model training loss is 1.4
model training loss is 1.3
model training loss is 1.2

nếu bạn không sử dụng zero_grad (), bạn sẽ tìm thấy kết quả sau:

model training loss is 1.4
model training loss is 1.9
model training loss is 2
model training loss is 2.8
model training loss is 3.5
Khi sử dụng trang web của chúng tôi, bạn xác nhận rằng bạn đã đọc và hiểu Chính sách cookieChính sách bảo mật của chúng tôi.
Licensed under cc by-sa 3.0 with attribution required.