TensorFlow lưu vào / tải biểu đồ từ một tệp


98

Từ những gì tôi đã thu thập cho đến nay, có một số cách khác nhau để kết xuất đồ thị TensorFlow vào một tệp và sau đó tải nó vào một chương trình khác, nhưng tôi không thể tìm thấy các ví dụ / thông tin rõ ràng về cách chúng hoạt động. Những gì tôi đã biết là:

  1. Lưu các biến của mô hình vào tệp điểm kiểm tra (.ckpt) bằng cách sử dụng a tf.train.Saver()và khôi phục chúng sau ( nguồn )
  2. Lưu mô hình thành tệp .pb và tải lại bằng cách sử dụng tf.train.write_graph()tf.import_graph_def()( nguồn )
  3. Tải vào một mô hình từ tệp .pb, đào tạo lại và kết xuất mô hình đó vào tệp .pb mới bằng cách sử dụng Bazel ( nguồn )
  4. Cố định biểu đồ để lưu biểu đồ và trọng số cùng nhau ( nguồn )
  5. Sử dụng as_graph_def()để lưu mô hình và đối với trọng số / biến, ánh xạ chúng thành hằng số ( nguồn )

Tuy nhiên, tôi không thể giải đáp một số câu hỏi liên quan đến các phương pháp khác nhau này:

  1. Về các tệp điểm kiểm tra, chúng chỉ lưu các trọng lượng đã được đào tạo của một mô hình? Các tệp điểm kiểm tra có thể được tải vào một chương trình mới và được sử dụng để chạy mô hình hay chúng chỉ đơn giản là cách để lưu trọng số trong một mô hình tại một thời điểm / giai đoạn nhất định?
  2. Về vấn đề tf.train.write_graph(), các trọng số / biến có được lưu không?
  3. Về Bazel, nó chỉ có thể lưu vào / tải từ các tệp .pb để đào tạo lại? Có một lệnh Bazel đơn giản chỉ để kết xuất một biểu đồ thành .pb không?
  4. Về vấn đề đóng băng, có thể tải đồ thị được đóng băng khi sử dụng tf.import_graph_def()không?
  5. Bản demo Android cho TensorFlow tải trong mô hình Inception của Google từ tệp .pb. Nếu tôi muốn thay thế tệp .pb của chính mình, tôi sẽ làm như thế nào? Tôi có cần thay đổi bất kỳ mã / phương thức gốc nào không?
  6. Nói chung, sự khác biệt chính xác giữa tất cả các phương pháp này là gì? Hay rộng hơn, sự khác biệt giữa as_graph_def()/.ckpt/.pb là gì?

Tóm lại, những gì tôi đang tìm kiếm là một phương pháp để lưu cả một biểu đồ (như trong, các phép toán khác nhau, v.v.) và trọng số / biến của nó vào một tệp, sau đó có thể được sử dụng để tải biểu đồ và trọng số vào một chương trình khác , để sử dụng (không nhất thiết phải tiếp tục / đào tạo lại).

Tài liệu về chủ đề này không đơn giản lắm, vì vậy mọi câu trả lời / thông tin sẽ được đánh giá cao.


2
API mới nhất / hoàn chỉnh nhất là biểu đồ meta, sẽ cung cấp cho bạn cách lưu cả ba cùng một lúc - 1) biểu đồ 2) giá trị tham số 3) collection: tensorflow.org/versions/r0.10/how_tos/meta_graph/ index.html
Yaroslav Bulatov

Câu trả lời:


80

Có nhiều cách để tiếp cận vấn đề lưu một mô hình trong TensorFlow, điều này có thể khiến bạn hơi khó hiểu. Lần lượt trả lời từng câu hỏi phụ của bạn:

  1. Các tệp điểm kiểm tra (ví dụ: được tạo ra bằng cách gọi saver.save()một tf.train.Saverđối tượng) chỉ chứa các trọng số và bất kỳ biến nào khác được xác định trong cùng một chương trình. Để sử dụng chúng trong một chương trình khác, bạn phải tạo lại cấu trúc đồ thị được liên kết (ví dụ: bằng cách chạy mã để xây dựng lại hoặc gọi tf.import_graph_def()), điều này cho TensorFlow biết phải làm gì với các trọng số đó. Lưu ý rằng việc gọi saver.save()cũng tạo ra một tệp chứa a MetaGraphDef, tệp này chứa một biểu đồ và chi tiết về cách liên kết các trọng số từ một điểm kiểm tra với biểu đồ đó. Xem hướng dẫn để biết thêm chi tiết.

  2. tf.train.write_graph()chỉ viết cấu trúc đồ thị; không phải là trọng lượng.

  3. Bazel không liên quan đến việc đọc hoặc viết đồ thị TensorFlow. (Có lẽ tôi hiểu sai câu hỏi của bạn: vui lòng làm rõ nó trong một bình luận.)

  4. Một biểu đồ cố định có thể được tải bằng cách sử dụng tf.import_graph_def(). Trong trường hợp này, các trọng số (thường) được nhúng vào biểu đồ, vì vậy bạn không cần tải một điểm kiểm tra riêng biệt.

  5. Thay đổi chính sẽ là cập nhật tên của (các) tensor được đưa vào mô hình và tên của (các) tensor được lấy từ mô hình. Trong bản demo TensorFlow Android, điều này sẽ tương ứng với chuỗi inputNameoutputNamechuỗi được chuyển đến TensorFlowClassifier.initializeTensorFlow().

  6. Đây GraphDeflà cấu trúc chương trình, thường không thay đổi trong quá trình đào tạo. Điểm kiểm tra là một ảnh chụp nhanh trạng thái của quá trình đào tạo, trạng thái này thường thay đổi ở mọi bước của quá trình đào tạo. Do đó, TensorFlow sử dụng các định dạng lưu trữ khác nhau cho các loại dữ liệu này và API cấp thấp cung cấp các cách khác nhau để lưu và tải chúng. Các thư viện cấp cao hơn, chẳng hạn như MetaGraphDefthư viện, Kerasskflow xây dựng dựa trên các cơ chế này để cung cấp các cách thuận tiện hơn để lưu và khôi phục toàn bộ mô hình.


Điều này có nghĩa là tài liệu C ++ API nói dối, khi nó nói rằng bạn có thể tải biểu đồ đã lưu tf.train.write_graph()và sau đó thực thi nó?
mnicky

2
Tài liệu API C ++ không nói dối, nhưng nó thiếu một vài chi tiết. Chi tiết quan trọng nhất là, ngoài GraphDeflưu bởi tf.train.write_graph(), bạn cũng cần nhớ tên của tensor mà bạn muốn cấp và tìm nạp khi thực hiện biểu đồ (mục 5 ở trên).
mrry

@mrry: Tôi đã cố gắng sử dụng ví dụ về tensorflows DeepDream. nhưng có vẻ như nó cần các mô hình được đào tạo trước ở định dạng pb! Tôi đã chạy ví dụ Cifar10, nhưng nó chỉ tạo các điểm kiểm tra! Tôi không thể tìm thấy bất kỳ tệp pb nào hoặc bất kỳ thứ gì! làm cách nào để chuyển đổi các điểm kiểm tra của tôi sang định dạng pb mà ví dụ deepdream sử dụng?
Rika

2
@ Coderx7 Tôi thực sự nghĩ rằng bạn không thể chuyển đổi một .ckpt đến một .pb kể từ khi trạm kiểm soát chỉ chứa trọng lượng và các biến và không biết gì về cấu trúc của đồ thị
davidivad

1
có một mã đơn giản để tải một tệp .pb và sau đó chạy nó không?
Kong

1

Bạn có thể thử mã sau:

with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    g_in = tf.import_graph_def(graph_def, name="")
sess = tf.Session(graph=g_in)
Khi sử dụng trang web của chúng tôi, bạn xác nhận rằng bạn đã đọc và hiểu Chính sách cookieChính sách bảo mật của chúng tôi.
Licensed under cc by-sa 3.0 with attribution required.