Cách tốt nhất để cứu một mô hình được đào tạo trong PyTorch?


191

Tôi đang tìm cách thay thế để cứu một mô hình được đào tạo ở PyTorch. Cho đến nay, tôi đã tìm thấy hai lựa chọn thay thế.

  1. Torch.save () để lưu mô hình và Torch.load () để tải mô hình.
  2. model.state_dict () để lưu mô hình đã được đào tạo và model.load_state_dict () để tải mô hình đã lưu.

Tôi đã đi qua cuộc thảo luận này trong đó cách tiếp cận 2 được khuyến nghị so với cách tiếp cận 1.

Câu hỏi của tôi là, tại sao cách tiếp cận thứ hai được ưa thích? Có phải chỉ vì các mô-đun Torch.nn có hai chức năng đó và chúng tôi được khuyến khích sử dụng chúng?


2
Tôi nghĩ đó là bởi vì Torch.save () cũng lưu tất cả các biến trung gian, giống như các đầu ra trung gian để sử dụng lan truyền ngược. Nhưng bạn chỉ cần lưu các tham số mô hình, như trọng lượng / độ lệch, v.v. Đôi khi cái trước có thể lớn hơn cái sau.
Dawei Yang

2
Tôi đã thử nghiệm torch.save(model, f)torch.save(model.state_dict(), f). Các tập tin đã lưu có cùng kích thước. Tôi đang cảm thấy bối rối. Ngoài ra, tôi thấy việc sử dụng dưa chua để lưu model.state_dict () cực kỳ chậm. Tôi nghĩ cách tốt nhất là sử dụng torch.save(model.state_dict(), f)vì bạn xử lý việc tạo mô hình và đèn pin xử lý việc tải trọng lượng mô hình, do đó loại bỏ các vấn đề có thể xảy ra. Tham khảo: discuss.pytorch.org/t/saving-torch-models/838/4
Dawei Yang

Có vẻ như PyTorch đã giải quyết vấn đề này rõ ràng hơn một chút trong phần hướng dẫn của họ rất nhiều thông tin tốt không được liệt kê trong các câu trả lời ở đây, bao gồm lưu nhiều hơn một mô hình tại một thời điểm và các mô hình khởi động ấm áp.
whlteXbread 24/03/19

Có gì sai khi sử dụng pickle?
Charlie Parker

1
@CharlieParker Torch.save dựa trên dưa chua. Dưới đây là từ hướng dẫn được liên kết ở trên: "[Torch.save] sẽ lưu toàn bộ mô-đun bằng mô-đun dưa của Python. Nhược điểm của phương pháp này là dữ liệu được tuần tự hóa được liên kết với các lớp cụ thể và cấu trúc thư mục chính xác được sử dụng khi mô hình được lưu. Lý do cho điều này là do Pickle không tự lưu lớp mô hình. Thay vào đó, nó lưu một đường dẫn đến tệp chứa lớp, được sử dụng trong thời gian tải. Vì điều này, mã của bạn có thể bị phá vỡ theo nhiều cách khác nhau khi được sử dụng trong các dự án khác hoặc sau khi tái cấu trúc. "
David Miller

Câu trả lời:


213

Tôi đã tìm thấy trang này trên repo github của họ, tôi sẽ chỉ dán nội dung ở đây.


Phương pháp đề xuất để lưu mô hình

Có hai cách tiếp cận chính để tuần tự hóa và khôi phục mô hình.

Đầu tiên (được khuyến nghị) chỉ lưu và tải các tham số mô hình:

torch.save(the_model.state_dict(), PATH)

Sau đó:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

Thứ hai lưu và tải toàn bộ mô hình:

torch.save(the_model, PATH)

Sau đó:

the_model = torch.load(PATH)

Tuy nhiên, trong trường hợp này, dữ liệu tuần tự được liên kết với các lớp cụ thể và cấu trúc thư mục chính xác được sử dụng, do đó nó có thể bị phá vỡ theo nhiều cách khác nhau khi được sử dụng trong các dự án khác hoặc sau một số bộ tái cấu trúc nghiêm trọng.


7
Theo @smth thảo luận về công cụ tìm kiếm theo mô hình đào tạo theo mặc định. vì vậy cần phải gọi thủ công the_model.eval () sau khi tải, nếu bạn đang tải nó để suy luận, không tiếp tục đào tạo.
WillZ

phương pháp thứ hai mang lại cho stackoverflow.com/questions/53798009/ Lỗi trên windows 10. không thể giải quyết nó
Gulzar

Có tùy chọn nào để lưu mà không cần truy cập cho lớp mô hình không?
Michael D

Với cách tiếp cận đó, làm thế nào để bạn theo dõi các * args và ** kwargs bạn cần truyền vào cho trường hợp tải?
Mariano Kamp

Có gì sai khi sử dụng pickle?
Charlie Parker

143

Nó phụ thuộc vào những gì bạn muốn làm.

Trường hợp # 1: Lưu mô hình để tự sử dụng mô hình để suy luận : Bạn lưu mô hình, bạn khôi phục mô hình và sau đó bạn thay đổi mô hình sang chế độ đánh giá. Điều này được thực hiện bởi vì bạn thường có BatchNormDropoutcác lớp theo mặc định đang ở chế độ đào tạo khi xây dựng:

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

Trường hợp 2: Lưu mô hình để tiếp tục đào tạo sau : Nếu bạn cần tiếp tục đào tạo mô hình mà bạn sắp lưu, bạn cần lưu nhiều hơn chỉ là mô hình. Bạn cũng cần lưu trạng thái của trình tối ưu hóa, kỷ nguyên, điểm số, v.v. Bạn sẽ làm như thế này:

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

Để tiếp tục đào tạo, bạn sẽ làm những việc như: state = torch.load(filepath)và sau đó, để khôi phục trạng thái của từng đối tượng riêng lẻ, đại loại như thế này:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

Vì bạn đang tiếp tục đào tạo, KHÔNG nên gọi model.eval()một khi bạn khôi phục trạng thái khi tải.

Trường hợp # 3: Mô hình được sử dụng bởi người khác không có quyền truy cập vào mã của bạn : Trong Tensorflow bạn có thể tạo một .pbtệp xác định cả kiến ​​trúc và trọng số của mô hình. Điều này rất tiện dụng, đặc biệt khi sử dụng Tensorflow serve. Cách tương đương để làm điều này trong Pytorch sẽ là:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

Cách này vẫn không phải là bằng chứng đạn và vì pytorch vẫn đang trải qua rất nhiều thay đổi, tôi không khuyến nghị điều đó.


1
Có một tập tin đề nghị kết thúc cho 3 trường hợp? Hoặc là luôn luôn .pth?
Verena Haunschmid

1
Trong trường hợp # 3 torch.loadchỉ trả về một OrderedDict. Làm thế nào để bạn có được mô hình để đưa ra dự đoán?
Alber8295

Xin chào, tôi có thể biết cách thực hiện "Trường hợp số 2: Lưu mô hình để tiếp tục đào tạo sau" không? Tôi đã quản lý để tải điểm kiểm tra vào mô hình, sau đó tôi không thể chạy hoặc tiếp tục đào tạo mô hình như "model.to (device) model = train_model_epoch (model, criterion, Optimizer, calendar, epochs)"
dnez

1
Xin chào, đối với trường hợp là suy luận, trong tài liệu pytorch chính thức nói rằng phải lưu trạng thái tối ưu hóa cho trạng thái suy luận hoặc hoàn thành đào tạo. "Khi lưu một điểm kiểm tra chung, được sử dụng cho mục đích suy luận hoặc tiếp tục đào tạo, bạn phải lưu nhiều hơn chỉ trạng thái của mô hình. Điều quan trọng là cũng phải lưu trạng thái của trình tối ưu hóa, vì điều này chứa bộ đệm và tham số được cập nhật khi tàu mô hình . "
Mohammed Awney

1
Trong trường hợp # 3, lớp mô hình nên được xác định ở đâu đó.
Michael D

11

Các dưa cụ thư viện Python giao thức nhị phân cho serializing và de-serializing một đối tượng Python.

Khi bạn import torch(hoặc khi bạn sử dụng PyTorch), nó sẽ import picklecho bạn và bạn không cần phải gọi pickle.dump()pickle.load()trực tiếp, đó là các phương thức để lưu và tải đối tượng.

Trong thực tế, torch.save()torch.load()sẽ bọc pickle.dump()pickle.load()cho bạn.

Một state_dictcâu trả lời khác được đề cập xứng đáng chỉ là một vài ghi chú.

Có gì state_dictchúng ta có bên trong PyTorch? Thực tế có hai state_dicts.

Mô hình PyTorch torch.nn.Modulemodel.parameters()lệnh gọi để có được các tham số có thể học được (w và b). Các tham số có thể học được này, một khi được đặt ngẫu nhiên, sẽ cập nhật theo thời gian khi chúng ta tìm hiểu. Thông số có thể học được là đầu tiên state_dict.

Thứ hai state_dictlà dict nhà nước tối ưu hóa. Bạn nhớ lại rằng trình tối ưu hóa được sử dụng để cải thiện các tham số có thể học được của chúng tôi. Nhưng trình tối ưu hóa đã state_dictđược sửa. Không có gì để học trong đó.

Bởi vì state_dictcác đối tượng là từ điển Python, chúng có thể dễ dàng lưu, cập nhật, thay đổi và khôi phục, thêm rất nhiều mô-đun vào các mô hình và tối ưu hóa PyTorch.

Hãy tạo một mô hình siêu đơn giản để giải thích điều này:

import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Model weight:")    
print(model.weight)

print("Model bias:")    
print(model.bias)

print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

Mã này sẽ xuất ra như sau:

Model's state_dict:
weight   torch.Size([2, 5])
bias     torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328,  0.1360,  0.1553, -0.1838, -0.0316],
        [ 0.0479,  0.1760,  0.1712,  0.2244,  0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]

Lưu ý đây là một mô hình tối thiểu. Bạn có thể thử thêm ngăn xếp tuần tự

model = torch.nn.Sequential(
          torch.nn.Linear(D_in, H),
          torch.nn.Conv2d(A, B, C)
          torch.nn.Linear(H, D_out),
        )

Lưu ý rằng chỉ các lớp có tham số có thể học được (lớp chập, lớp tuyến tính, v.v.) và bộ đệm đã đăng ký (lớp batchnorm) có các mục trong mô hình state_dict.

Những thứ không thể học được, thuộc về đối tượng tối ưu hóa state_dict, chứa thông tin về trạng thái của trình tối ưu hóa, cũng như các siêu đường kính được sử dụng.

Phần còn lại của câu chuyện là như nhau; trong giai đoạn suy luận (đây là giai đoạn khi chúng ta sử dụng mô hình sau khi đào tạo) để dự đoán; chúng tôi dự đoán dựa trên các thông số chúng tôi đã học. Vì vậy, để suy luận, chúng ta chỉ cần lưu các tham số model.state_dict().

torch.save(model.state_dict(), filepath)

Và để sử dụng mô hình sau.load_state_dict (Torch.load (filepath)) model.eval ()

Lưu ý: Đừng quên dòng cuối cùng model.eval()này là rất quan trọng sau khi tải mô hình.

Cũng đừng cố gắng tiết kiệm torch.save(model.parameters(), filepath). Đây model.parameters()chỉ là đối tượng máy phát điện.

Mặt khác, torch.save(model, filepath)lưu chính đối tượng mô hình, nhưng hãy nhớ rằng mô hình không có trình tối ưu hóa state_dict. Kiểm tra câu trả lời xuất sắc khác của @Jadiel de Armas để lưu chính tả trạng thái của trình tối ưu hóa.


Mặc dù nó không phải là một giải pháp đơn giản, nhưng bản chất của vấn đề được phân tích sâu sắc! Upvote.
Jason Young

7

Một quy ước PyTorch phổ biến là lưu các mô hình bằng cách sử dụng phần mở rộng tệp .pt hoặc .pth.

Lưu / Tải toàn bộ mô hình Lưu:

path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)

Tải:

Lớp mô hình phải được định nghĩa ở đâu đó

model = torch.load(PATH)
model.eval()

3

Nếu bạn muốn lưu mô hình và muốn tiếp tục đào tạo sau:

GPU đơn: Lưu:

state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

Tải:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

Nhiều GPU: Lưu

state = {
        'epoch': epoch,
        'state_dict': model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

Tải:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

#Don't call DataParallel before loading the model otherwise you will get an error

model = nn.DataParallel(model) #ignore the line if you want to load on Single GPU
Khi sử dụng trang web của chúng tôi, bạn xác nhận rằng bạn đã đọc và hiểu Chính sách cookieChính sách bảo mật của chúng tôi.
Licensed under cc by-sa 3.0 with attribution required.