Đầu ra của một tf.nn.dynamic_rnn () là gì?


8

Tôi không chắc chắn về những gì tôi hiểu từ tài liệu chính thức, trong đó nói:

Trả về: Một cặp (đầu ra, trạng thái) trong đó:

outputs: Tenor đầu ra RNN.

Nếu time_major == False(mặc định), đây sẽ là một hình dạng kéo căng : [batch_size, max_time, cell.output_size].

Nếu time_major == True, đây sẽ là một hình dạng Tenor : [max_time, batch_size, cell.output_size].

Lưu ý, nếu cell.output_sizelà một tuple (có thể lồng nhau) các số nguyên hoặc các đối tượng TensorShape, thì các đầu ra sẽ là một tuple có cấu trúc giống như cell.output_size, chứa các Tenor có hình dạng tương ứng với dữ liệu hình dạng cell.output_size.

state: Trạng thái cuối cùng. Nếu cell.state_size là một int, nó sẽ được định hình [batch_size, cell.state_size]. Nếu nó là một TensorShape, nó sẽ được định hình [batch_size] + cell.state_size. Nếu nó là một tuple (có thể lồng nhau) của ints hoặc TensorShape, đây sẽ là một tuple có các hình dạng tương ứng. Nếu các ô là trạng thái LSTMCells sẽ là một tuple chứa LSTMStateTuple cho mỗi ô.

Có phải output[-1] luôn luôn (trong cả ba loại ô, ví dụ RNN, GRU, LSTM) bằng trạng thái (phần tử thứ hai của bộ trả về)? Tôi đoán rằng văn học ở khắp mọi nơi là quá tự do trong việc sử dụng thuật ngữ ẩn trạng thái. Là trạng thái ẩn trong cả ba ô, điểm số xuất hiện (tại sao nó được gọi là ẩn nằm ngoài tôi, nó sẽ xuất hiện trạng thái ô trong LSTM nên được gọi là trạng thái ẩn vì nó không bị lộ)?

Câu trả lời:


10

Có, đầu ra tế bào bằng với trạng thái ẩn. Trong trường hợp của LSTM, đó là phần ngắn hạn của bộ dữ liệu (yếu tố thứ hai LSTMStateTuple), như có thể thấy trong hình này:

LSTM

Nhưng đối với tf.nn.dynamic_rnn, trạng thái trả về có thể khác khi chuỗi ngắn hơn ( sequence_lengthđối số). Hãy xem ví dụ này:

n_steps = 2
n_inputs = 3
n_neurons = 5

X = tf.placeholder(dtype=tf.float32, shape=[None, n_steps, n_inputs])
seq_length = tf.placeholder(tf.int32, [None])

basic_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=n_neurons)
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, sequence_length=seq_length, dtype=tf.float32)

X_batch = np.array([
  # t = 0      t = 1
  [[0, 1, 2], [9, 8, 7]], # instance 0
  [[3, 4, 5], [0, 0, 0]], # instance 1
  [[6, 7, 8], [6, 5, 4]], # instance 2
  [[9, 0, 1], [3, 2, 1]], # instance 3
])
seq_length_batch = np.array([2, 1, 2, 2])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  outputs_val, states_val = sess.run([outputs, states], 
                                     feed_dict={X: X_batch, seq_length: seq_length_batch})

  print(outputs_val)
  print()
  print(states_val)

Ở đây, lô đầu vào chứa 4 chuỗi và một trong số chúng ngắn và được đệm bằng số không. Khi chạy bạn nên một cái gì đó như thế này:

[[[ 0.2315362  -0.37939444 -0.625332   -0.80235624  0.2288385 ]
  [ 0.9999524   0.99987394  0.33580178 -0.9981791   0.99975705]]

 [[ 0.97374666  0.8373545  -0.7455188  -0.98751736  0.9658986 ]
  [ 0.          0.          0.          0.          0.        ]]

 [[ 0.9994331   0.9929737  -0.8311569  -0.99928087  0.9990415 ]
  [ 0.9984355   0.9936006   0.3662448  -0.87244385  0.993848  ]]

 [[ 0.9962312   0.99659646  0.98880637  0.99548346  0.9997809 ]
  [ 0.9915743   0.9936939   0.4348318   0.8798458   0.95265496]]]

[[ 0.9999524   0.99987394  0.33580178 -0.9981791   0.99975705]
 [ 0.97374666  0.8373545  -0.7455188  -0.98751736  0.9658986 ]
 [ 0.9984355   0.9936006   0.3662448  -0.87244385  0.993848  ]
 [ 0.9915743   0.9936939   0.4348318   0.8798458   0.95265496]]

... Điều đó thực sự cho thấy rằng state == output[1]đối với các chuỗi đầy đủ và state == output[0]ngắn. Cũng output[1]là một vector không cho chuỗi này. Điều tương tự giữ cho các tế bào LSTM và GRU.

Vì vậy, statemột tenx thuận tiện giữ trạng thái RNN thực tế cuối cùng , bỏ qua các số không. Các outputtensor giữ kết quả của tất cả các tế bào, vì vậy nó không bỏ qua những số không. Đó là lý do để trả lại cả hai.


2

Bản sao có thể của /programming/36817596/get-last-output-of-dynamic-rnn-in-tensorflow/49705930#49705930

Dù sao đi nữa, hãy tiếp tục với câu trả lời.

Đoạn mã này có thể giúp hiểu những gì thực sự được trả về bởi dynamic_rnnlớp

=> Tuple của (đầu ra, FinalDefput_state) .

Vì vậy, đối với một đầu vào có độ dài chuỗi tối đa của các bước thời gian T, các đầu ra có hình dạng [Batch_size, T, num_inputs](đã cho time_major= Sai; giá trị mặc định) và nó chứa trạng thái đầu ra ở mỗi dấu thời gian h1, h2.....hT.

FinalDefput_state có hình dạng [Batch_size,num_inputs]và có trạng thái ô cuối cùng cTvà trạng thái đầu ra hTcủa từng chuỗi lô.

Nhưng vì dynamic_rnnnó đang được sử dụng nên tôi đoán là độ dài chuỗi của bạn thay đổi theo từng đợt.

    import tensorflow as tf
    import numpy as np
    from tensorflow.contrib import rnn
    tf.reset_default_graph()

    # Create input data
    X = np.random.randn(2, 10, 8)

    # The second example is of length 6 
    X[1,6:] = 0
    X_lengths = [10, 6]

    cell = tf.nn.rnn_cell.LSTMCell(num_units=64, state_is_tuple=True)

    outputs, states  = tf.nn.dynamic_rnn(cell=cell,
                                         dtype=tf.float64,
                                         sequence_length=X_lengths,
                                         inputs=X)

    result = tf.contrib.learn.run_n({"outputs": outputs, "states":states},
                                    n=1,
                                    feed_dict=None)
    assert result[0]["outputs"].shape == (2, 10, 64)
    print result[0]["outputs"].shape
    print result[0]["states"].h.shape
    # the final outputs state and states returned must be equal for each      
    # sequence
    assert(result[0]["outputs"][0][-1]==result[0]["states"].h[0]).all()
    assert(result[0]["outputs"][-1][5]==result[0]["states"].h[-1]).all()
    assert(result[0]["outputs"][-1][-1]==result[0]["states"].h[-1]).all()

Khẳng định cuối cùng sẽ thất bại vì trạng thái cuối cùng cho chuỗi thứ 2 là ở bước thứ 6 tức là. chỉ số 5 và phần còn lại của đầu ra từ [6: 9] đều là 0 trong dấu thời gian thứ 2

Khi sử dụng trang web của chúng tôi, bạn xác nhận rằng bạn đã đọc và hiểu Chính sách cookieChính sách bảo mật của chúng tôi.
Licensed under cc by-sa 3.0 with attribution required.