Làm thế nào để một tham số trong mô hình pytorch không phải là lá và nằm trong biểu đồ tính toán?


10

Tôi đang cố gắng cập nhật / thay đổi các tham số của mô hình mạng nơ ron và sau đó có đường truyền chuyển tiếp của mạng nơ ron cập nhật nằm trong biểu đồ tính toán (cho dù chúng tôi có thực hiện bao nhiêu thay đổi / cập nhật).

Tôi đã thử ý tưởng này nhưng bất cứ khi nào tôi thực hiện nó, pytorch sẽ đặt các thang đo được cập nhật của tôi (bên trong mô hình) thành các lá, nó sẽ giết dòng chảy của gradient đến các mạng mà tôi muốn nhận độ dốc. Nó giết chết dòng chảy của gradient vì các nút lá không phải là một phần của biểu đồ tính toán theo cách tôi muốn (vì chúng không thực sự là lá).

Tôi đã thử nhiều thứ nhưng dường như không có gì để làm việc. Tôi đã tạo một mã giả bao gồm các phần tử in các gradient của mạng mà tôi muốn có độ dốc:

import torch
import torch.nn as nn

import copy

from collections import OrderedDict

# img = torch.randn([8,3,32,32])
# targets = torch.LongTensor([1, 2, 0, 6, 2, 9, 4, 9])
# img = torch.randn([1,3,32,32])
# targets = torch.LongTensor([1])
x = torch.randn(1)
target = 12.0*x**2

criterion = nn.CrossEntropyLoss()

#loss_net = nn.Sequential(OrderedDict([('conv0',nn.Conv2d(in_channels=3,out_channels=10,kernel_size=32))]))
loss_net = nn.Sequential(OrderedDict([('fc0', nn.Linear(in_features=1,out_features=1))]))

hidden = torch.randn(size=(1,1),requires_grad=True)
updater_net = nn.Sequential(OrderedDict([('fc0',nn.Linear(in_features=1,out_features=1))]))
print(f'updater_net.fc0.weight.is_leaf = {updater_net.fc0.weight.is_leaf}')
#
nb_updates = 2
for i in range(nb_updates):
    print(f'i = {i}')
    new_params = copy.deepcopy( loss_net.state_dict() )
    ## w^<t> := f(w^<t-1>,delta^<t-1>)
    for (name, w) in loss_net.named_parameters():
        print(f'name = {name}')
        print(w.size())
        hidden = updater_net(hidden).view(1)
        print(hidden.size())
        #delta = ((hidden**2)*w/2)
        delta = w + hidden
        wt = w + delta
        print(wt.size())
        new_params[name] = wt
        #del loss_net.fc0.weight
        #setattr(loss_net.fc0, 'weight', nn.Parameter( wt ))
        #setattr(loss_net.fc0, 'weight', wt)
        #loss_net.fc0.weight = wt
        #loss_net.fc0.weight = nn.Parameter( wt )
    ##
    loss_net.load_state_dict(new_params)
#
print()
print(f'updater_net.fc0.weight.is_leaf = {updater_net.fc0.weight.is_leaf}')
outputs = loss_net(x)
loss_val = 0.5*(target - outputs)**2
loss_val.backward()
print()
print(f'-- params that dont matter if they have gradients --')
print(f'loss_net.grad = {loss_net.fc0.weight.grad}')
print('-- params we want to have gradients --')
print(f'hidden.grad = {hidden.grad}')
print(f'updater_net.fc0.weight.grad = {updater_net.fc0.weight.grad}')
print(f'updater_net.fc0.bias.grad = {updater_net.fc0.bias.grad}')

nếu có ai biết cách thực hiện việc này, vui lòng cho tôi ping ... Tôi đặt số lần cập nhật là 2 vì thao tác cập nhật phải nằm trong biểu đồ tính toán một số lần tùy ý ... vì vậy nó PHẢI hoạt động cho 2.


Bài liên quan mạnh mẽ:

Đăng chéo:


Bạn đã thử tranh luận cho backward? Cụ thể retain_graph=Truevà / hoặc create_graph=True?
Szymon Maszke

Câu trả lời:


3

DOESNT LÀM VIỆC SỞ HỮU các mô-đun tham số có tên bị xóa.


Có vẻ như công việc này:

import torch
import torch.nn as nn

from torchviz import make_dot

import copy

from collections import OrderedDict

# img = torch.randn([8,3,32,32])
# targets = torch.LongTensor([1, 2, 0, 6, 2, 9, 4, 9])
# img = torch.randn([1,3,32,32])
# targets = torch.LongTensor([1])
x = torch.randn(1)
target = 12.0*x**2

criterion = nn.CrossEntropyLoss()

#loss_net = nn.Sequential(OrderedDict([('conv0',nn.Conv2d(in_channels=3,out_channels=10,kernel_size=32))]))
loss_net = nn.Sequential(OrderedDict([('fc0', nn.Linear(in_features=1,out_features=1))]))

hidden = torch.randn(size=(1,1),requires_grad=True)
updater_net = nn.Sequential(OrderedDict([('fc0',nn.Linear(in_features=1,out_features=1))]))
print(f'updater_net.fc0.weight.is_leaf = {updater_net.fc0.weight.is_leaf}')
#
def del_attr(obj, names):
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        del_attr(getattr(obj, names[0]), names[1:])
def set_attr(obj, names, val):
    if len(names) == 1:
        setattr(obj, names[0], val)
    else:
        set_attr(getattr(obj, names[0]), names[1:], val)

nb_updates = 2
for i in range(nb_updates):
    print(f'i = {i}')
    new_params = copy.deepcopy( loss_net.state_dict() )
    ## w^<t> := f(w^<t-1>,delta^<t-1>)
    for (name, w) in list(loss_net.named_parameters()):
        hidden = updater_net(hidden).view(1)
        #delta = ((hidden**2)*w/2)
        delta = w + hidden
        wt = w + delta
        del_attr(loss_net, name.split("."))
        set_attr(loss_net, name.split("."), wt)
    ##
#
print()
print(f'updater_net.fc0.weight.is_leaf = {updater_net.fc0.weight.is_leaf}')
print(f'loss_net.fc0.weight.is_leaf = {loss_net.fc0.weight.is_leaf}')
outputs = loss_net(x)
loss_val = 0.5*(target - outputs)**2
loss_val.backward()
print()
print(f'-- params that dont matter if they have gradients --')
print(f'loss_net.grad = {loss_net.fc0.weight.grad}')
print('-- params we want to have gradients --')
print(f'hidden.grad = {hidden.grad}') # None because this is not a leaf, it is overriden in the for loop above.
print(f'updater_net.fc0.weight.grad = {updater_net.fc0.weight.grad}')
print(f'updater_net.fc0.bias.grad = {updater_net.fc0.bias.grad}')
make_dot(loss_val)

đầu ra:

updater_net.fc0.weight.is_leaf = True
i = 0
i = 1

updater_net.fc0.weight.is_leaf = True
loss_net.fc0.weight.is_leaf = False

-- params that dont matter if they have gradients --
loss_net.grad = None
-- params we want to have gradients --
hidden.grad = None
updater_net.fc0.weight.grad = tensor([[0.7152]])
updater_net.fc0.bias.grad = tensor([-7.4249])

Lời cảm ơn: albanD hùng mạnh từ nhóm pytorch: https://discuss.pytorch.org/t/how-does-one-have-the-parameter-of-a-model-not-be-leafs/70076/9?u= pinocchio


Các bạn, điều này là sai, đừng sử dụng mã này, nó không cho phép truyền bá độ dốc trong hơn 1 bước. Sử dụng cái này thay thế: github.com/facebookresearch/higher
Pinocchio

Điều này không làm việc ppl!
Pinocchio

thư viện cao hơn chưa làm việc cho tôi.
Pinocchio

Khi sử dụng trang web của chúng tôi, bạn xác nhận rằng bạn đã đọc và hiểu Chính sách cookieChính sách bảo mật của chúng tôi.
Licensed under cc by-sa 3.0 with attribution required.