Phương thức zero_grad()
cần được gọi trong quá trình đào tạo. Nhưng tài liệu không hữu ích lắm
| zero_grad(self)
| Sets gradients of all model parameters to zero.
Tại sao chúng ta cần gọi phương thức này?
Phương thức zero_grad()
cần được gọi trong quá trình đào tạo. Nhưng tài liệu không hữu ích lắm
| zero_grad(self)
| Sets gradients of all model parameters to zero.
Tại sao chúng ta cần gọi phương thức này?
Câu trả lời:
Trong đó PyTorch
, chúng ta cần đặt các gradient thành 0 trước khi bắt đầu thực hiện backpropragation vì PyTorch tích lũy các gradient trong các lần đi lùi tiếp theo. Điều này thuận tiện trong khi huấn luyện RNN. Vì vậy, hành động mặc định là tích lũy (tức là tổng) các gradient trên mỗi loss.backward()
cuộc gọi.
Vì lý do này, khi bạn bắt đầu vòng lặp đào tạo của mình, lý tưởng nhất là bạn nên zero out the gradients
cập nhật tham số một cách chính xác. Ngược lại, gradient sẽ hướng theo một số hướng khác với hướng dự định về phía cực tiểu (hoặc cực đại , trong trường hợp mục tiêu tối đa hóa).
Đây là một ví dụ đơn giản:
import torch
from torch.autograd import Variable
import torch.optim as optim
def linear_model(x, W, b):
return torch.matmul(x, W) + b
data, targets = ...
W = Variable(torch.randn(4, 3), requires_grad=True)
b = Variable(torch.randn(3), requires_grad=True)
optimizer = optim.Adam([W, b])
for sample, target in zip(data, targets):
# clear out the gradients of all Variables
# in this optimizer (i.e. W, b)
optimizer.zero_grad()
output = linear_model(sample, W, b)
loss = (output - target) ** 2
loss.backward()
optimizer.step()
Ngoài ra, nếu bạn đang thực hiện giảm độ dốc màu vani , thì:
W = Variable(torch.randn(4, 3), requires_grad=True)
b = Variable(torch.randn(3), requires_grad=True)
for sample, target in zip(data, targets):
# clear out the gradients of Variables
# (i.e. W, b)
W.grad.data.zero_()
b.grad.data.zero_()
output = linear_model(sample, W, b)
loss = (output - target) ** 2
loss.backward()
W -= learning_rate * W.grad.data
b -= learning_rate * b.grad.data
Lưu ý : Sự tích lũy (tức là tổng ) của các gradient xảy ra khi .backward()
được gọi trên loss
tensor .
zero_grad () đang khởi động lại vòng lặp mà không có tổn thất nào từ bước trước nếu bạn sử dụng phương pháp gradient để giảm lỗi (hoặc tổn thất)
nếu bạn không sử dụng zero_grad (), tổn thất sẽ giảm xuống không tăng theo yêu cầu
Ví dụ: nếu bạn sử dụng zero_grad (), bạn sẽ tìm thấy kết quả sau:
model training loss is 1.5
model training loss is 1.4
model training loss is 1.3
model training loss is 1.2
nếu bạn không sử dụng zero_grad (), bạn sẽ tìm thấy kết quả sau:
model training loss is 1.4
model training loss is 1.9
model training loss is 2
model training loss is 2.8
model training loss is 3.5