Làm thế nào để liệt kê tất cả các hoạt động được sử dụng trong Tensorflow SavingModel?


10

Nếu tôi lưu mô hình của mình bằng cách sử dụng tensorflow.saved_model.savechức năng ở định dạng SavingModel, làm thế nào tôi có thể truy xuất Ops Tensorflow nào được sử dụng trong mô hình này sau đó. Khi mô hình có thể được khôi phục, các hoạt động này được lưu trữ trong biểu đồ, tôi đoán là trong saved_model.pbtệp. Nếu tôi tải protobuf này (không phải toàn bộ mô hình) thì phần thư viện của protobuf liệt kê những thứ này, nhưng hiện tại nó không được ghi lại và được gắn thẻ như một tính năng thử nghiệm. Các mô hình được tạo trong Tensorflow 1.x sẽ không có phần này.

Vậy đâu là cách nhanh chóng và đáng tin cậy để lấy danh sách các Hoạt động đã sử dụng (Thích MatchingFileshoặc WriteFile) từ một mô hình ở định dạng SavingModel?

Ngay bây giờ tôi có thể đóng băng toàn bộ, như thế tensorflowjs-converter. Vì họ cũng kiểm tra các hoạt động được hỗ trợ. Điều này hiện không hoạt động khi một LSTM trong mô hình, xem tại đây . Có cách nào tốt hơn để làm điều này không, vì Ops chắc chắn ở trong đó?

Một mô hình ví dụ:

class FileReader(tf.Module):

@tf.function(input_signature=[tf.TensorSpec(name='filename', shape=[None], dtype=tf.string)])
def read_disk(self, file_name):
    input_scalar = tf.reshape(file_name, [])
    output = tf.io.read_file(input_scalar)
    return tf.stack([output], name='content')

file_reader = FileReader()

tf.saved_model.save(file_reader, 'file_reader')

Dự kiến ​​trong đầu ra tất cả các Ops, chứa trong trường hợp này ít nhất là:

  • ReadFilenhư được mô tả ở đây
  • ...

1
Thật khó để nói chính xác những gì bạn muốn, saved_model.pbnó là gì tf.GraphDef, hoặc là một SavedModeltin nhắn protobuf? Nếu bạn có một cuộc tf.GraphDefgọi gd, bạn có thể nhận được danh sách các op được sử dụng với sorted(set(n.op for n in gd.node)). Nếu bạn có một mô hình được tải, bạn có thể làm sorted(set(op.type for op in tf.get_default_graph().get_operations())). Nếu nó là một SavedModel, bạn có thể lấy tf.GraphDeftừ nó (ví dụ saved_model.meta_graphs[0].graph_def).
jdehesa

Tôi muốn lấy ops từ một SavingModel được lưu trữ. Vì vậy, thực sự, tùy chọn cuối cùng bạn đang mô tả. Là gì saved_modelbiến trong ví dụ cuối cùng của bạn? Kết quả tf.saved_model.load('/path/to/model')hoặc tải protobuf của tệp save_model.pb.
sampers

Câu trả lời:


1

Nếu saved_model.pblà một SavedModelthông điệp protobuf, thì bạn có được các hoạt động trực tiếp từ đó. Giả sử chúng ta tạo một mô hình như sau:

import tensorflow as tf

class FileReader(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec(name='filename', shape=[None], dtype=tf.string)])
    def read_disk(self, file_name):
        input_scalar = tf.reshape(file_name, [])
        output = tf.io.read_file(input_scalar)
        return tf.stack([output], name='content')

file_reader = FileReader()
tf.saved_model.save(file_reader, 'tmp')

Bây giờ chúng ta có thể tìm thấy các hoạt động được sử dụng bởi mô hình đó như thế này:

from tensorflow.core.protobuf.saved_model_pb2 import SavedModel

saved_model = SavedModel()
with open('tmp/saved_model.pb', 'rb') as f:
    saved_model.ParseFromString(f.read())
model_op_names = set()
# Iterate over every metagraph in case there is more than one
for meta_graph in saved_model.meta_graphs:
    # Add operations in the graph definition
    model_op_names.update(node.op for node in meta_graph.graph_def.node)
    # Go through the functions in the graph definition
    for func in meta_graph.graph_def.library.function:
        # Add operations in each function
        model_op_names.update(node.op for node in func.node_def)
# Convert to list, sorted if you want
model_op_names = sorted(model_op_names)
print(*model_op_names, sep='\n')
# Const
# Identity
# MergeV2Checkpoints
# NoOp
# Pack
# PartitionedCall
# Placeholder
# ReadFile
# Reshape
# RestoreV2
# SaveV2
# ShardedFilename
# StatefulPartitionedCall
# StringJoin

Tôi đã thử một cái gì đó như thế này, nhưng thật không may, đây không phải là điều tôi mong đợi: Giả sử tôi có một mô hình thực hiện điều này: input_scalar = tf.reshape(file_name, []) output = tf.io.read_file(input_scalar) return tf.stack([output], name='content')Sau đó, ReadFile Op như được liệt kê ở đây có trong đó, nhưng không được in.
sampers

1
@sampers Tôi đã chỉnh sửa câu trả lời với một ví dụ như bạn đề xuất. Tôi nhận được các ReadFilehoạt động trong đầu ra. Có thể là, trong trường hợp thực tế của bạn, hoạt động đó không nằm giữa đầu vào và đầu ra của mô hình đã lưu? Trong trường hợp đó tôi nghĩ rằng nó có thể được cắt tỉa.
jdehesa

Thật vậy với mô hình đã cho, nó hoạt động. Thật không may cho một mô-đun được thực hiện trong tf2, nó không. Nếu tôi tạo một tf.Module với 1 hàm với chú thích file_nameđối số @tf.function, có chứa các cuộc gọi tôi đã liệt kê trong nhận xét trước đó, nó sẽ đưa ra danh sách sau:Const, NoOp, PartitionedCall, Placeholder, StatefulPartitionedCall
sampers

đã thêm một mô hình vào câu hỏi của tôi
sampers

@sampers Mình đã cập nhật câu trả lời. Tôi đã sử dụng TF 1.x trước đây, tôi không quen với các thay đổi đối với các đối tượng định nghĩa biểu đồ trong TF 2.x, tôi nghĩ rằng câu trả lời hiện bao gồm mọi thứ trong mô hình đã lưu. Tôi nghĩ rằng các hoạt động tương ứng với hàm Python mà bạn đã viết nằm trong saved_model.meta_graphs[0].graph_def.library.function[0]( node_defbộ sưu tập trong đối tượng hàm đó).
jdehesa
Khi sử dụng trang web của chúng tôi, bạn xác nhận rằng bạn đã đọc và hiểu Chính sách cookieChính sách bảo mật của chúng tôi.
Licensed under cc by-sa 3.0 with attribution required.