Tại sao LSTM hoạt động kém hơn trong việc chốt thông tin so với mạng nơ ron tái phát vanilla


8

Tôi muốn hiểu rõ hơn tại sao LSTM có thể nhớ thông tin trong một khoảng thời gian dài hơn vanilla / mạng thần kinh tái phát đơn giản (SRNN) bằng cách làm lại một thử nghiệm từ bài báo Học phụ thuộc lâu dài với Gradient Descent là khó khăn bởi Bengio et al. 1994 .

Xem hình 1. và 2 trên tờ giấy đó. Tác vụ rất đơn giản, được đưa ra một chuỗi, nếu nó bắt đầu với giá trị cao (ví dụ 1), thì nhãn đầu ra là 1; nếu nó bắt đầu với giá trị thấp (ví dụ -1), thì nhãn đầu ra là 0. Giữa là tiếng ồn. Nhiệm vụ này được gọi là chốt thông tin vì mô hình cần nhớ giá trị bắt đầu trong khi đi qua nhiễu trung bình để tạo ra nhãn chính xác. Nó đã sử dụng một nơron thần kinh duy nhất để xây dựng một mô hình thể hiện hành vi đó. Hình 2 (b) cho thấy kết quả và tần suất thành công của việc đào tạo một mô hình như vậy giảm đáng kể khi độ dài chuỗi tăng. Không có kết quả cho LSTM vì nó chưa được phát minh vào năm 1994.

Vì vậy, tôi trở nên tò mò và muốn thấy rằng nếu LSTM thực sự sẽ hoạt động tốt hơn cho một nhiệm vụ như vậy. Tương tự, tôi đã tạo ra một RNN nơ-ron duy nhất cho cả tế bào vanilla và LSTM để mô hình hóa thông tin chốt. Đáng ngạc nhiên, tôi thấy LSTM hoạt động kém hơn và tôi không biết tại sao. Ai đó có thể giúp tôi giải thích hoặc nếu có bất cứ điều gì sai với mã của tôi, xin vui lòng?

Đây là kết quả của tôi:

nhập mô tả hình ảnh ở đây

Đây là mã của tôi:

import matplotlib.pyplot as plt
import numpy as np    
from keras.models import Model
from keras.layers import Input, LSTM, Dense, SimpleRNN


N = 10000
num_repeats = 30
num_epochs = 5
# sequence length options
lens = [2, 5, 8, 10, 15, 20, 25, 30] + np.arange(30, 210, 10).tolist()

res = {}
for (RNN_CELL, key) in zip([SimpleRNN, LSTM], ['srnn', 'lstm']):
    res[key] = {}
    print(key, end=': ')
    for seq_len in lens:
        print(seq_len, end=',')
        xs = np.zeros((N, seq_len))
        ys = np.zeros(N)

        # construct input data
        positive_indexes = np.arange(N // 2)
        negative_indexes = np.arange(N // 2, N)

        xs[positive_indexes, 0] = 1
        ys[positive_indexes] = 1

        xs[negative_indexes, 0] = -1
        ys[negative_indexes] = 0

        noise = np.random.normal(loc=0, scale=0.1, size=(N, seq_len))

        train_xs = (xs + noise).reshape(N, seq_len, 1)
        train_ys = ys

        # repeat each experiments multiple times
        hists = []
        for i in range(num_repeats):
            inputs = Input(shape=(None, 1), name='input')

            rnn = RNN_CELL(1, input_shape=(None, 1), name='rnn')(inputs)
            out = Dense(2, activation='softmax', name='output')(rnn)
            model = Model(inputs, out)
            model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
            hist = model.fit(train_xs, train_ys, epochs=num_epochs, shuffle=True, validation_split=0.2, batch_size=16, verbose=0)
            hists.append(hist.history['val_acc'][-1])
        res[key][seq_len] = hists
    print()


fig = plt.figure()
ax = fig.add_subplot(111)
ax.plot(pd.DataFrame.from_dict(res['lstm']).mean(), label='lstm')
ax.plot(pd.DataFrame.from_dict(res['srnn']).mean(), label='srnn')
ax.legend()

Tôi cũng có kết quả hiển thị trong sổ ghi chép , sẽ thuận tiện nếu bạn muốn sao chép kết quả. Phải mất một ngày để chạy thử nghiệm trên máy của tôi chỉ sử dụng CPU. Nó có thể nhanh hơn trên máy hỗ trợ GPU.

Cập nhật 2018-04-18 :

Tôi đã cố gắng tái tạo một hình ảnh về cảnh quan của RNN lấy cảm hứng từ Hình 6 trong Về khó khăn trong việc đào tạo Mạng thần kinh tái phát . Tôi thấy thú vị khi thấy sự hình thành của vách đá trong bối cảnh mất mát khi số lần tái phát / bước thời gian / độ dài chuỗi tăng lên, có thể liên quan đến việc giải thích sự khó khăn của việc huấn luyện các chuỗi dài được quan sát ở đây. Thêm chi tiết có sẵn ở đây .

nhập mô tả hình ảnh ở đây

Cập nhật 2018-04-19

Mở rộng thử nghiệm của @ shimao. Có vẻ như LSTM và GRU không giỏi trong việc nắm bắt thông tin. Nhưng chuyển sang một nhiệm vụ khác, mà tôi gọi là chuyển tiếp bit, (nhiệm vụ 2 của @ shimao), GRU hoạt động tốt hơn trong khi SRNN và LSTM đều tệ như nhau.

Bây giờ, tôi có xu hướng nghĩ rằng hiệu suất của một loại tế bào có thể là nhiệm vụ cụ thể.

Nhiệm vụ 1: chốt thông tin (1 đơn vị; 10 lần lặp lại; 10 kỷ nguyên)

nhập mô tả hình ảnh ở đây

Nhiệm vụ 2: rơle bit (8 đơn vị; 10 lần lặp lại; 10 epoch)

nhập mô tả hình ảnh ở đây

Thanh lỗi là độ lệch chuẩn.

Sau đó, một câu hỏi hấp dẫn là tại sao LSTM không hoạt động trong việc chốt thông tin. Với sự đơn giản của nhiệm vụ, nó sẽ có thể hoạt động, phải không? Có thể liên quan đến cảnh quan (ví dụ Vách đá) đối với độ dốc của nó.


Có vẻ như bạn đang sử dụng RNN một đơn vị trong cả hai trường hợp. Bạn đã thử làm cho nó lớn hơn, nói 10 hoặc 100?
Alex R.

1
Không, ý tưởng từ bài báo gốc là thiết kế hệ thống đơn giản nhất để chúng ta có thể suy luận một cách hiệu quả. Động lực của việc sử dụng nhiều tế bào là gì? Nhiệm vụ rất đơn giản và bạn có thể thấy rằng khi chuỗi ngắn, cả LSTM và SRNN đều hoạt động tốt.
zyxue

Tôi chưa đọc kỹ bài viết này (hoặc gần đây) nhưng bạn có chắc chắn rằng bạn đang tái tạo chính xác các điều kiện thí nghiệm? Có vẻ hợp lý rằng có bất kỳ số lượng khác biệt nhỏ nhưng quan trọng giữa thiết lập của bạn và giấy có thể gây ra lỗi. Hoặc, đã nêu một cách khác, về cơ bản đây là một vấn đề gỡ lỗi trong đó chỉ có một bài kiểm tra đơn vị, tương đối không chính xác.
Sycorax nói Phục hồi lại

@Sycorax, bài báo chỉ nói về vanilla RNN vì LSTM chưa được phát minh vào năm 1994. Kết quả của tôi về vanilla RNN phù hợp với họ, đó là lý do tại sao LSTM hoạt động kém hơn khiến tôi bối rối.
zyxue

Giấy tờ này.nips.cc/apers/ trộm có thể được quan tâm - nó giải quyết các vấn đề phụ thuộc lâu dài được đưa ra bởi Bengio trong bài báo mà bạn trích dẫn, do đó, sao chép thử nghiệm được sử dụng bởi Schmidhuber có thể là một cách hiệu quả hơn
Sycorax nói Phục hồi lại

Câu trả lời:


3

Có một lỗi trong mã của bạn, vì nửa đầu của các ví dụ được xây dựng của bạn là dương và phần còn lại là âm, nhưng các máy ảnh không xáo trộn trước khi chia dữ liệu thành train và val, có nghĩa là tất cả các bộ val đều âm và tập huấn luyện thiên về hướng tích cực, đó là lý do tại sao bạn nhận được kết quả lạ như độ chính xác 0 (tệ hơn cả cơ hội).

Ngoài ra, tôi đã điều chỉnh một số tham số (như tốc độ học tập, số lượng kỷ nguyên và kích thước lô) để đảm bảo đào tạo luôn hội tụ.

Cuối cùng, tôi chỉ chạy trong 5 và 100 bước thời gian để tiết kiệm tính toán.

nhập mô tả hình ảnh ở đây

Thật kỳ lạ, LSTM không đào tạo đúng cách, mặc dù GRU gần như hoạt động tốt như RNN.

Tôi đã thử với một nhiệm vụ khó khăn hơn một chút: trong các chuỗi tích cực, dấu hiệu của yếu tố đầu tiên và một yếu tố nửa chừng trong chuỗi là như nhau (cả +1 hoặc cả -1), trong các chuỗi âm, các dấu hiệu khác nhau. Tôi đã hy vọng rằng các ô nhớ bổ sung trong LSTM sẽ có lợi ở đây

nhập mô tả hình ảnh ở đây

Nó đã kết thúc hoạt động tốt hơn RNN, nhưng chỉ một chút và GRU thắng vì một số lý do.

Tôi không có câu trả lời đầy đủ về lý do tại sao RNN làm tốt hơn LSTM trong nhiệm vụ đơn giản. Tôi nghĩ rằng chúng ta phải tìm ra các siêu đường kính phù hợp để huấn luyện LSTM đúng cách, ngoài ra thực tế là vấn đề dễ dàng đối với một RNN đơn giản. Có thể, một mô hình với rất ít tham số cũng dễ bị kẹt ở mức tối thiểu cục bộ.

sửa đổi


Tôi mở rộng thử nghiệm của bạn và thu được kết quả phù hợp. Tôi nghĩ rằng hiệu suất có thể là nhiệm vụ cụ thể.
zyxue

0

Tôi không biết liệu nó có khác biệt không, nhưng tôi sẽ sử dụng:

out = Dense (1, activation='sigmoid', ...

model.compile(loss='binary_crossentropy', ...

cho một vấn đề phân loại nhị phân.


1
Tôi tin rằng về mặt toán học hai là như nhau. Đặc biệt, sigmoid tương đương với softmax khi chỉ có hai lớp. binary_crossentropy là sau tất cả crossentropy cho hai lớp. Vấn đề này thực sự làm tôi bối rối, và tôi sẵn sàng tăng tiền thưởng nếu có ai có thể đưa ra một lời giải thích hay.
zyxue

Bạn đang sử dụng các tham số tương tự cho mỗi phiên bản? LSTM phức tạp hơn và có thể cần nhiều dữ liệu hơn, nhiều kỷ nguyên hơn, tỷ lệ học tập khác nhau, v.v ... @ zyxue
Wayne
Khi sử dụng trang web của chúng tôi, bạn xác nhận rằng bạn đã đọc và hiểu Chính sách cookieChính sách bảo mật của chúng tôi.
Licensed under cc by-sa 3.0 with attribution required.