TensorFlow, tại sao có 3 tệp sau khi lưu mô hình?


113

Sau khi đọc tài liệu , tôi đã lưu một mô hình vào TensorFlow, đây là mã demo của tôi:

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  ..
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in file: %s" % save_path)

nhưng sau đó, tôi thấy có 3 tệp

model.ckpt.data-00000-of-00001
model.ckpt.index
model.ckpt.meta

Và tôi không thể khôi phục mô hình bằng cách khôi phục model.ckpttệp, vì không có tệp nào như vậy. Đây là mã của tôi

with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")

Vì vậy, tại sao có 3 tệp?


2
Bạn đã tìm ra cách giải quyết vấn đề này chưa? Làm cách nào để tải lại mô hình (sử dụng Keras)?
rajkiran 11/1017

Câu trả lời:


116

Thử cái này:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta')
    saver.restore(sess, "/tmp/model.ckpt")

Phương thức lưu TensorFlow lưu ba loại tệp vì nó lưu trữ cấu trúc đồ thị riêng biệt với các giá trị biến . Các .metatập tin mô tả cấu trúc đồ thị lưu, vì vậy bạn cần phải nhập nó trước khi khôi phục các trạm kiểm soát (nếu không nó không biết những gì biến các giá trị trạm kiểm soát lưu tương ứng với).

Ngoài ra, bạn có thể làm điều này:

# Recreate the EXACT SAME variables
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")

...

# Now load the checkpoint variable values
with tf.Session() as sess:
    saver = tf.train.Saver()
    saver.restore(sess, "/tmp/model.ckpt")

Mặc dù không có tệp nào được đặt tên model.ckpt, bạn vẫn tham chiếu đến trạm kiểm soát đã lưu theo tên đó khi khôi phục nó. Từ saver.pymã nguồn :

Người dùng chỉ cần tương tác với tiền tố do người dùng chỉ định ... thay vì bất kỳ tên đường dẫn vật lý nào.


1
vậy .index và .data không được sử dụng? Khi nào thì 2 tệp đó được sử dụng?
ajfbiw.s

26
@ ajfbiw.s .meta lưu trữ cấu trúc đồ thị, .data lưu trữ các giá trị của từng biến trong biểu đồ, .index xác định điểm kiểm tra. Vì vậy, trong ví dụ trên: import_meta_graph sử dụng .meta và saver.restore sử dụng .data và .index
TK Bartel

Ồ, tôi hiểu rồi. Cảm ơn.
ajfbiw.s

1
Có cơ hội nào bạn lưu mô hình bằng phiên bản TensorFlow khác với phiên bản bạn đang sử dụng để tải nó không? ( github.com/tensorflow/tensorflow/issues/5639 )
TK Bartel

5
Có ai biết điều đó 00000và những 00001con số có nghĩa là gì? trong variables.data-?????-of-?????hồ sơ
Ivan Talalaev

55
  • tệp meta : mô tả cấu trúc đồ thị đã lưu, bao gồm GraphDef, SaverDef, v.v. sau đó áp dụng tf.train.import_meta_graph('/tmp/model.ckpt.meta'), sẽ khôi phục SaverGraph.

  • tệp chỉ mục : nó là một bảng bất biến chuỗi-chuỗi (tensorflow :: table :: Table). Mỗi khóa là tên của một tensor và giá trị của nó là một BundleEntryProto được tuần tự hóa. Mỗi BundleEntryProto mô tả siêu dữ liệu của tensor: tệp "dữ liệu" nào chứa nội dung của tensor, phần bù vào tệp đó, tổng kiểm tra, một số dữ liệu phụ trợ, v.v.

  • tệp dữ liệu : nó là tập hợp TensorBundle, lưu giá trị của tất cả các biến.


Tôi đã có tệp pb mà tôi có để phân loại hình ảnh. Tôi có thể sử dụng nó để phân loại video theo thời gian thực không?

Bạn có thể vui lòng cho tôi biết, Sử dụng Keras 2, làm cách nào để tải mô hình nếu nó được lưu dưới dạng 3 tệp?
rajkiran

5

Tôi đang khôi phục embeddings từ đào tạo từ Word2Vec tensorflow hướng dẫn.

Trong trường hợp bạn đã tạo nhiều điểm kiểm tra:

ví dụ: các tệp được tạo trông như thế này

model.ckpt-55695.data-00000-of-00001

model.ckpt-55695.index

model.ckpt-55695.meta

thử cái này

def restore_session(self, session):
   saver = tf.train.import_meta_graph('./tmp/model.ckpt-55695.meta')
   saver.restore(session, './tmp/model.ckpt-55695')

khi gọi restore_session ():

def test_word2vec():
   opts = Options()    
   with tf.Graph().as_default(), tf.Session() as session:
       with tf.device("/cpu:0"):            
           model = Word2Vec(opts, session)
           model.restore_session(session)
           model.get_embedding("assistance")

"00000-of-00001" trong "model.ckpt-55695.data-00000-of-00001" có nghĩa là gì?
hafiz031

0

Ví dụ: nếu bạn đã đào tạo một CNN bị bỏ học, bạn có thể làm điều này:

def predict(image, model_name):
    """
    image -> single image, (width, height, channels)
    model_name -> model file that was saved without any extensions
    """
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('./' + model_name + '.meta')
        saver.restore(sess, './' + model_name)
        # Substitute 'logits' with your model
        prediction = tf.argmax(logits, 1)
        # 'x' is what you defined it to be. In my case it is a batch of RGB images, that's why I add the extra dimension
        return prediction.eval(feed_dict={x: image[np.newaxis,:,:,:], keep_prob_dnn: 1.0})
Khi sử dụng trang web của chúng tôi, bạn xác nhận rằng bạn đã đọc và hiểu Chính sách cookieChính sách bảo mật của chúng tôi.
Licensed under cc by-sa 3.0 with attribution required.