Tôi đang đào tạo một mạng nơ ron tích chập đơn giản để hồi quy, trong đó nhiệm vụ là dự đoán vị trí (x, y) của một hộp trong một hình ảnh, ví dụ:
Đầu ra của mạng có hai nút, một cho x và một cho y. Phần còn lại của mạng là một mạng nơ ron tích chập tiêu chuẩn. Mất mát là một lỗi bình phương trung bình tiêu chuẩn giữa vị trí dự đoán của hộp và vị trí sự thật mặt đất. Tôi đang đào tạo trên 10000 hình ảnh này và xác nhận vào năm 2000.
Vấn đề tôi đang gặp phải là ngay cả sau khi được đào tạo bài bản, tổn thất vẫn không thực sự giảm. Sau khi quan sát đầu ra của mạng, tôi nhận thấy rằng mạng có xu hướng xuất các giá trị gần bằng 0, cho cả hai nút đầu ra. Do đó, dự đoán vị trí của hộp luôn là trung tâm của hình ảnh. Có một số sai lệch trong các dự đoán, nhưng luôn ở khoảng không. Dưới đây cho thấy sự mất mát:
Tôi đã chạy nó trong nhiều kỷ nguyên hơn so với hiển thị trong biểu đồ này và tổn thất vẫn không bao giờ giảm. Điều thú vị ở đây, sự mất mát thực sự tăng ở một điểm.
Vì vậy, có vẻ như mạng chỉ dự đoán mức trung bình của dữ liệu đào tạo, thay vì học một cách phù hợp. Bất kỳ ý tưởng về lý do tại sao điều này có thể được? Tôi đang sử dụng Adam làm trình tối ưu hóa, với tỷ lệ học tập ban đầu là 0,01 và được sử dụng làm kích hoạt
Nếu bạn quan tâm đến một số mã của tôi (Keras), thì đây là:
# Create the model
model = Sequential()
model.add(Convolution2D(32, 5, 5, border_mode='same', subsample=(2, 2), activation='relu', input_shape=(3, image_width, image_height)))
model.add(Convolution2D(64, 5, 5, border_mode='same', subsample=(2, 2), activation='relu'))
model.add(Convolution2D(128, 5, 5, border_mode='same', subsample=(2, 2), activation='relu'))
model.add(Flatten())
model.add(Dense(100, activation='relu'))
model.add(Dense(2, activation='linear'))
# Compile the model
adam = Adam(lr=0.01, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)
model.compile(loss='mean_squared_error', optimizer=adam)
# Fit the model
model.fit(images, targets, batch_size=128, nb_epoch=1000, verbose=1, callbacks=[plot_callback], validation_split=0.2, shuffle=True)