Tôi đang cố gắng cập nhật / thay đổi các tham số của mô hình mạng nơ ron và sau đó có đường truyền chuyển tiếp của mạng nơ ron cập nhật nằm trong biểu đồ tính toán (cho dù chúng tôi có thực hiện bao nhiêu thay đổi / cập nhật).
Tôi đã thử ý tưởng này nhưng bất cứ khi nào tôi thực hiện nó, pytorch sẽ đặt các thang đo được cập nhật của tôi (bên trong mô hình) thành các lá, nó sẽ giết dòng chảy của gradient đến các mạng mà tôi muốn nhận độ dốc. Nó giết chết dòng chảy của gradient vì các nút lá không phải là một phần của biểu đồ tính toán theo cách tôi muốn (vì chúng không thực sự là lá).
Tôi đã thử nhiều thứ nhưng dường như không có gì để làm việc. Tôi đã tạo một mã giả bao gồm các phần tử in các gradient của mạng mà tôi muốn có độ dốc:
import torch
import torch.nn as nn
import copy
from collections import OrderedDict
# img = torch.randn([8,3,32,32])
# targets = torch.LongTensor([1, 2, 0, 6, 2, 9, 4, 9])
# img = torch.randn([1,3,32,32])
# targets = torch.LongTensor([1])
x = torch.randn(1)
target = 12.0*x**2
criterion = nn.CrossEntropyLoss()
#loss_net = nn.Sequential(OrderedDict([('conv0',nn.Conv2d(in_channels=3,out_channels=10,kernel_size=32))]))
loss_net = nn.Sequential(OrderedDict([('fc0', nn.Linear(in_features=1,out_features=1))]))
hidden = torch.randn(size=(1,1),requires_grad=True)
updater_net = nn.Sequential(OrderedDict([('fc0',nn.Linear(in_features=1,out_features=1))]))
print(f'updater_net.fc0.weight.is_leaf = {updater_net.fc0.weight.is_leaf}')
#
nb_updates = 2
for i in range(nb_updates):
print(f'i = {i}')
new_params = copy.deepcopy( loss_net.state_dict() )
## w^<t> := f(w^<t-1>,delta^<t-1>)
for (name, w) in loss_net.named_parameters():
print(f'name = {name}')
print(w.size())
hidden = updater_net(hidden).view(1)
print(hidden.size())
#delta = ((hidden**2)*w/2)
delta = w + hidden
wt = w + delta
print(wt.size())
new_params[name] = wt
#del loss_net.fc0.weight
#setattr(loss_net.fc0, 'weight', nn.Parameter( wt ))
#setattr(loss_net.fc0, 'weight', wt)
#loss_net.fc0.weight = wt
#loss_net.fc0.weight = nn.Parameter( wt )
##
loss_net.load_state_dict(new_params)
#
print()
print(f'updater_net.fc0.weight.is_leaf = {updater_net.fc0.weight.is_leaf}')
outputs = loss_net(x)
loss_val = 0.5*(target - outputs)**2
loss_val.backward()
print()
print(f'-- params that dont matter if they have gradients --')
print(f'loss_net.grad = {loss_net.fc0.weight.grad}')
print('-- params we want to have gradients --')
print(f'hidden.grad = {hidden.grad}')
print(f'updater_net.fc0.weight.grad = {updater_net.fc0.weight.grad}')
print(f'updater_net.fc0.bias.grad = {updater_net.fc0.bias.grad}')
nếu có ai biết cách thực hiện việc này, vui lòng cho tôi ping ... Tôi đặt số lần cập nhật là 2 vì thao tác cập nhật phải nằm trong biểu đồ tính toán một số lần tùy ý ... vì vậy nó PHẢI hoạt động cho 2.
Bài liên quan mạnh mẽ:
- SO: Làm thế nào để một tham số trong mô hình pytorch không phải là lá và nằm trong biểu đồ tính toán?
- diễn đàn pytorch: https://discuss.pytorch.org/t/how-does-one-have-the-parameter-of-a-model-not-be-leafs/70076
Đăng chéo:
backward
? Cụ thểretain_graph=True
và / hoặccreate_graph=True
?