Làm thế nào để lấy kích thước tensor Tensorflow (hình dạng) dưới dạng giá trị int?


88

Giả sử tôi có một tensor Tensorflow. Làm cách nào để lấy kích thước (hình dạng) của tenxơ dưới dạng giá trị nguyên? Tôi biết có hai phương pháp tensor.get_shape()tf.shape(tensor), nhưng tôi không thể lấy giá trị hình dạng dưới dạng giá trị số nguyên int32.

Ví dụ, bên dưới tôi đã tạo tensor 2-D và tôi cần lấy số hàng và cột int32để tôi có thể gọi reshape()để tạo ra tensor hình dạng (num_rows * num_cols, 1). Tuy nhiên, phương thức tensor.get_shape()trả về giá trị là Dimensionkiểu, không phải int32.

import tensorflow as tf
import numpy as np

sess = tf.Session()    
tensor = tf.convert_to_tensor(np.array([[1001,1002,1003],[3,4,5]]), dtype=tf.float32)

sess.run(tensor)    
# array([[ 1001.,  1002.,  1003.],
#        [    3.,     4.,     5.]], dtype=float32)

tensor_shape = tensor.get_shape()    
tensor_shape
# TensorShape([Dimension(2), Dimension(3)])    
print tensor_shape    
# (2, 3)

num_rows = tensor_shape[0] # ???
num_cols = tensor_shape[1] # ???

tensor2 = tf.reshape(tensor, (num_rows*num_cols, 1))    
# Traceback (most recent call last):
#   File "<stdin>", line 1, in <module>
#   File "/usr/local/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 1750, in reshape
#     name=name)
#   File "/usr/local/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 454, in apply_op
#     as_ref=input_arg.is_ref)
#   File "/usr/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 621, in convert_to_tensor
#     ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
#   File "/usr/local/lib/python2.7/site-packages/tensorflow/python/framework/constant_op.py", line 180, in _constant_tensor_conversion_function
#     return constant(v, dtype=dtype, name=name)
#   File "/usr/local/lib/python2.7/site-packages/tensorflow/python/framework/constant_op.py", line 163, in constant
#     tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape))
#   File "/usr/local/lib/python2.7/site-packages/tensorflow/python/framework/tensor_util.py", line 353, in make_tensor_proto
#     _AssertCompatible(values, dtype)
#   File "/usr/local/lib/python2.7/site-packages/tensorflow/python/framework/tensor_util.py", line 290, in _AssertCompatible
#     (dtype.name, repr(mismatch), type(mismatch).__name__))
# TypeError: Expected int32, got Dimension(6) of type 'Dimension' instead.

Câu trả lời:


126

Để có được hình dạng như một danh sách các int, hãy làm tensor.get_shape().as_list().

Để hoàn thành tf.shape()cuộc gọi của bạn , hãy thử tensor2 = tf.reshape(tensor, tf.TensorShape([num_rows*num_cols, 1])). Hoặc bạn có thể trực tiếp làm tensor2 = tf.reshape(tensor, tf.TensorShape([-1, 1]))ở nơi có thể suy ra thứ nguyên đầu tiên của nó.


Cảm ơn, điều đó cho phép tôi gọi và hoàn thành tf.reshape(), nhưng tôi thực sự muốn lấy num_rowsnum_colsdưới dạng số nguyên cho các hoạt động khác.
stackoverflowuser2010

6
Hãy thửtensor.get_shape().as_list()
yuefengz

1
Đúng vậy as_list(). Vui lòng thêm nó vào câu trả lời của bạn, và tôi sẽ chấp nhận.
stackoverflowuser2010

2
Để hoàn thiện, mã này hoạt động:num_rows, num_cols = x.get_shape().as_list()
stackoverflowuser2010

1
Đẹp! Tôi đang sử dụng python int () để truyền kết quả của x.get_shape (). tức là num_rows = int (x.get_shape () [1]), num_cols = int (x.get_shape () [2]), v.v. Vâng, thật khó để khắc phục lỗi khó chịu đó, nhưng nó đã hoạt động. Cảm ơn vì đã khai sáng cho tôi một cách tốt hơn :-)
SherylHohman

31

Một cách khác để giải quyết vấn đề này là như sau:

tensor_shape[0].value

Điều này sẽ trả về giá trị int của đối tượng Thứ nguyên.


6

đối với tensor 2-D, bạn có thể lấy số hàng và cột dưới dạng int32 bằng cách sử dụng mã sau:

rows, columns = map(lambda i: i.value, tensor.get_shape())

2
Rất không nhã nhặn. Làm thế nào để điều này thêm vào các câu trả lời đã được cung cấp?
rayryeng

4

2.0 Câu trả lời tương thích : Trong Tensorflow 2.x (2.1), bạn có thể lấy kích thước (hình dạng) của tensor dưới dạng giá trị số nguyên, như được hiển thị trong Mã bên dưới:

Phương pháp 1 (sử dụng tf.shape) :

import tensorflow as tf
c = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
Shape = c.shape.as_list()
print(Shape)   # [2,3]

Phương pháp 2 (sử dụng tf.get_shape()) :

import tensorflow as tf
c = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
Shape = c.get_shape().as_list()
print(Shape)   # [2,3]

1

Một giải pháp đơn giản khác là sử dụng map()như sau:

tensor_shape = map(int, my_tensor.shape)

Điều này chuyển đổi tất cả các Dimensionđối tượng thànhint


0

Trong các phiên bản sau (được thử nghiệm với TensorFlow 1.14), có một cách giống như numpy hơn để có được hình dạng của tensor. Bạn có thể sử dụng tensor.shapeđể có được hình dạng của tensor.

tensor_shape = tensor.shape
print(tensor_shape)
Khi sử dụng trang web của chúng tôi, bạn xác nhận rằng bạn đã đọc và hiểu Chính sách cookieChính sách bảo mật của chúng tôi.
Licensed under cc by-sa 3.0 with attribution required.