Vai trò của "Flatten" trong Keras là gì?


108

Tôi đang cố gắng hiểu vai trò của Flatten hàm trong Keras. Dưới đây là mã của tôi, đó là một mạng hai lớp đơn giản. Nó lấy dữ liệu 2 chiều về hình dạng (3, 2) và xuất ra dữ liệu 1 chiều về hình dạng (1, 4):

model = Sequential()
model.add(Dense(16, input_shape=(3, 2)))
model.add(Activation('relu'))
model.add(Flatten())
model.add(Dense(4))
model.compile(loss='mean_squared_error', optimizer='SGD')

x = np.array([[[1, 2], [3, 4], [5, 6]]])

y = model.predict(x)

print y.shape

Điều này in ra ycó hình dạng (1, 4). Tuy nhiên, nếu tôi xóaFlatten dòng, thì nó sẽ in ra ycó hình dạng (1, 3, 4).

Tôi không hiểu điều này. Từ hiểu biết của tôi về mạng nơ-ron,model.add(Dense(16, input_shape=(3, 2))) chức năng đang tạo một lớp được kết nối đầy đủ ẩn, với 16 nút. Mỗi nút này được kết nối với mỗi phần tử đầu vào 3x2. Do đó, 16 nút ở đầu ra của lớp đầu tiên này đã "phẳng". Vì vậy, hình dạng đầu ra của lớp đầu tiên phải là (1, 16). Sau đó, lớp thứ hai lấy đây làm đầu vào và xuất dữ liệu về hình dạng (1, 4).

Vì vậy, nếu đầu ra của lớp đầu tiên đã "phẳng" và có hình dạng (1, 16), tại sao tôi cần phải làm phẳng nó hơn nữa?

Câu trả lời:


123

Nếu bạn đọc mục nhập tài liệu Keras cho Dense, bạn sẽ thấy rằng lệnh gọi này:

Dense(16, input_shape=(5,3))

sẽ dẫn đến một Densemạng có 3 đầu vào và 16 đầu ra sẽ được áp dụng độc lập cho mỗi bước trong số 5 bước. Vì vậy, nếu D(x)chuyển đổi vectơ 3 chiều thành vectơ 16-d, những gì bạn sẽ nhận được dưới dạng đầu ra từ lớp của bạn sẽ là một chuỗi các vectơ: [D(x[0,:]), D(x[1,:]),..., D(x[4,:])]với hình dạng (5, 16). Để có hành vi mà bạn chỉ định, trước tiên bạn có thể Flattennhập vào vectơ 15-d và sau đó áp dụng Dense:

model = Sequential()
model.add(Flatten(input_shape=(3, 2)))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(4))
model.compile(loss='mean_squared_error', optimizer='SGD')

CHỈNH SỬA: Khi một số người cố gắng hiểu - ở đây bạn có một hình ảnh giải thích:

nhập mô tả hình ảnh ở đây


Cảm ơn vì lời giải thích của bạn. Tuy nhiên, chỉ cần làm rõ: với Dense(16, input_shape=(5,3), liệu mỗi nơ-ron đầu ra từ tập 16 (và, đối với tất cả 5 bộ nơ-ron này), có được kết nối với tất cả (3 x 5 = 15) nơ-ron đầu vào không? Hay mỗi nơ-ron trong bộ 16 đầu tiên chỉ được kết nối với 3 nơ-ron trong bộ 5 nơ-ron đầu vào, và sau đó mỗi nơ-ron trong bộ 16 đầu vào chỉ được kết nối với 3 nơ-ron trong bộ 5 đầu vào thứ hai. tế bào thần kinh, vv .... Tôi bối rối không biết nó là gì!
Karnivaurus

1
Bạn có một lớp dày đặc có 3 nơ-ron và đầu ra 16 được áp dụng cho mỗi 5 bộ 3 nơ-ron.
Marcin Możejko

1
À được rồi. Những gì tôi đang cố gắng làm là lấy danh sách 5 pixel màu làm đầu vào và tôi muốn chúng chuyển qua một lớp được kết nối đầy đủ. Vậy input_shape=(5,3)có nghĩa là có 5 pixel và mỗi pixel có ba kênh (R, G, B). Nhưng theo những gì bạn đang nói, mỗi kênh sẽ được xử lý riêng lẻ, trong khi tôi muốn tất cả ba kênh được xử lý bởi tất cả các tế bào thần kinh trong lớp đầu tiên. Vì vậy, việc áp dụng Flattenlớp ngay lập tức khi bắt đầu có mang lại cho tôi những gì tôi muốn không?
Karnivaurus

8
Một chút bản vẽ có và không Flattencó có thể giúp hiểu.
Xvolks

2
Ok, các bạn - Tôi đã cung cấp cho bạn một hình ảnh. Bây giờ bạn có thể xóa phiếu phản đối của mình.
Marcin Możejko

52

nhập mô tả hình ảnh ở đây Đây là cách Flatten hoạt động khi chuyển đổi Ma trận thành mảng đơn.


4
Anh chàng này cần phải tạo ra nhiều hình ảnh hơn. Tôi thích điều này. Nó có lý.
alofgran

10
Có, nhưng tại sao nó lại cần, đây là câu hỏi thực tế mà tôi nghĩ.
Helen - xuống với PCorrectness

35

đọc ngắn:

Làm phẳng một tensor có nghĩa là loại bỏ tất cả các kích thước ngoại trừ một kích thước. Đây chính xác là những gì lớp Flatten làm.

đọc lâu:

Nếu chúng tôi xem xét mô hình ban đầu (với lớp Flatten) được tạo ra, chúng tôi có thể nhận được tóm tắt mô hình sau:

Layer (type)                 Output Shape              Param #   
=================================================================
D16 (Dense)                  (None, 3, 16)             48        
_________________________________________________________________
A (Activation)               (None, 3, 16)             0         
_________________________________________________________________
F (Flatten)                  (None, 48)                0         
_________________________________________________________________
D4 (Dense)                   (None, 4)                 196       
=================================================================
Total params: 244
Trainable params: 244
Non-trainable params: 0

Đối với bản tóm tắt này, hình ảnh tiếp theo hy vọng sẽ cung cấp thêm chút ý nghĩa về kích thước đầu vào và đầu ra cho mỗi lớp.

Hình dạng đầu ra cho lớp Flatten như bạn có thể đọc là (None, 48). Đây là mẹo. Bạn nên đọc nó (1, 48)hoặc (2, 48)hoặc ... hoặc (16, 48)... hoặc(32, 48) , ...

Trong thực tế, None trên vị trí đó có nghĩa là bất kỳ kích thước lô nào. Đối với các đầu vào để gọi lại, thứ nguyên đầu tiên có nghĩa là kích thước lô và thứ hai có nghĩa là số lượng tính năng đầu vào.

Vai trò của lớp Flatten trong Keras rất đơn giản:

Thao tác làm phẳng trên tensor định hình lại tensor để có hình dạng bằng với số phần tử có trong tensor không bao gồm kích thước lô .

nhập mô tả hình ảnh ở đây


Lưu ý: Tôi đã sử dụng model.summary()phương pháp này để cung cấp hình dạng đầu ra và chi tiết tham số.


1
Sơ đồ rất sâu sắc.
Shrey Joshi, 09/07/19

1
Cảm ơn vì sơ đồ. Nó cho tôi một bức tranh rõ ràng.
Sultan Ahmed Sagor

0

Làm phẳng làm cho rõ ràng cách bạn tuần tự hóa một tensor đa chiều (chủ yếu là đầu vào). Điều này cho phép ánh xạ giữa tensor đầu vào (phẳng) và lớp ẩn đầu tiên. Nếu lớp ẩn đầu tiên "dày đặc" thì mỗi phần tử của tensor đầu vào (được nối tiếp hóa) sẽ được kết nối với mỗi phần tử của mảng ẩn. Nếu bạn không sử dụng Flatten, cách thức ánh xạ đầu vào tensor lên lớp ẩn đầu tiên sẽ không rõ ràng.


0

Tôi đã xem qua điều này gần đây, nó chắc chắn đã giúp tôi hiểu: https://www.cs.ryerson.ca/~aharley/vis/conv/

Vì vậy, có một đầu vào, một Conv2D, MaxPooling2D, v.v., các lớp Flatten ở cuối và hiển thị chính xác cách chúng được hình thành và cách chúng tiếp tục để xác định các phân loại cuối cùng (0-9).

Khi sử dụng trang web của chúng tôi, bạn xác nhận rằng bạn đã đọc và hiểu Chính sách cookieChính sách bảo mật của chúng tôi.
Licensed under cc by-sa 3.0 with attribution required.