Việc sử dụng Torch.no_grad trong pytorch là gì?


18

Tôi mới sử dụng pytorch và bắt đầu với mã github này . Tôi không hiểu nhận xét trong dòng 60-61 trong mã "because weights have requires_grad=True, but we don't need to track this in autograd". Tôi hiểu rằng chúng tôi đề cập requires_grad=Trueđến các biến mà chúng tôi cần tính toán độ dốc cho việc sử dụng autograd nhưng nó có nghĩa là "tracked by autograd"gì?

Câu trả lời:


21

Trình bao bọc "với Torch.no_grad ()" tạm thời đặt tất cả cờ Yêu cầu thành sai. Một ví dụ từ hướng dẫn chính thức của PyTorch ( https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#gradrons ):

x = torch.randn(3, requires_grad=True)
print(x.requires_grad)
print((x ** 2).requires_grad)

with torch.no_grad():
    print((x ** 2).requires_grad)

Ngoài:

True
True
False

Tôi khuyên bạn nên đọc tất cả các hướng dẫn từ trang web ở trên.

Trong ví dụ của bạn: Tôi đoán tác giả không muốn PyTorch tính toán độ dốc của các biến được xác định mới w1 và w2 vì anh ta chỉ muốn cập nhật giá trị của chúng.


5
with torch.no_grad()

sẽ làm cho tất cả các hoạt động trong khối không có độ dốc.

Trong pytorch, bạn không thể thực hiện thay đổi vị trí của w1 và w2, đó là hai biến có require_grad = True. Tôi nghĩ rằng việc tránh thay đổi vị trí của w1 và w2 là vì nó sẽ gây ra lỗi trong tính toán lan truyền ngược. Vì thay đổi vị trí sẽ thay đổi hoàn toàn w1 và w2.

Tuy nhiên, nếu bạn sử dụng điều này no_grad(), bạn có thể điều khiển w1 mới và w2 mới không có độ dốc do chúng được tạo bởi các thao tác, có nghĩa là bạn chỉ thay đổi giá trị của w1 và w2, chứ không phải phần gradient, chúng vẫn có thông tin độ dốc biến được xác định trước đó và tuyên truyền trở lại có thể tiếp tục.

Khi sử dụng trang web của chúng tôi, bạn xác nhận rằng bạn đã đọc và hiểu Chính sách cookieChính sách bảo mật của chúng tôi.
Licensed under cc by-sa 3.0 with attribution required.