Làm thế nào để một mô hình hồi quy logistic đơn giản đạt được độ chính xác phân loại 92% trên MNIST?


68

Mặc dù tất cả các hình ảnh trong bộ dữ liệu MNIST đều được căn giữa, với tỷ lệ tương tự và không có góc quay, chúng có một biến thể chữ viết đáng kể đánh đố tôi làm thế nào một mô hình tuyến tính đạt được độ chính xác phân loại cao như vậy.

Theo như tôi có thể hình dung, với sự khác biệt đáng kể về chữ viết tay, các chữ số nên không thể tách rời tuyến tính trong không gian 784 chiều, nghĩa là, nên có một ranh giới phi tuyến tính phức tạp (mặc dù không phức tạp) để phân tách các chữ số khác nhau , tương tự như ví dụ được trích dẫn tốt trong đó các lớp dương và âm không thể được phân tách bằng bất kỳ phân loại tuyến tính nào. Nó có vẻ gây trở ngại cho tôi làm thế nào hồi quy logistic đa lớp tạo ra độ chính xác cao như vậy với các tính năng hoàn toàn tuyến tính (không có tính năng đa thức).XOR

Ví dụ, với bất kỳ pixel nào trong ảnh, các biến thể viết tay khác nhau của các chữ số và có thể làm cho pixel đó được chiếu sáng hoặc không. Do đó, với một tập hợp các trọng số đã học, mỗi pixel có thể làm cho một chữ số trông giống như cũng như . Chỉ với sự kết hợp của các giá trị pixel, có thể nói liệu một chữ số là hay . Điều này đúng với hầu hết các cặp chữ số. Vì vậy, làm thế nào là hồi quy logistic, một cách mù quáng dựa trên quyết định của nó một cách độc lập trên tất cả các giá trị pixel (mà không xem xét bất kỳ phụ thuộc giữa các pixel nào), có thể đạt được độ chính xác cao như vậy.232323

Tôi biết rằng tôi đã sai ở đâu đó hoặc chỉ ước tính quá mức sự thay đổi trong hình ảnh. Tuy nhiên, sẽ thật tuyệt nếu ai đó có thể giúp tôi có trực giác về cách các chữ số 'gần như' có thể phân tách tuyến tính.


Hãy xem sách giáo khoa Học thống kê với độ thưa thớt: Lasso và khái quát hóa 3.3.1 Ví dụ: Chữ số viết tay web.stanford.edu/~hastie/StatLearnSparsity_files/SLS.pdf
Adrian

Tôi đã tò mò: làm thế nào một cái gì đó giống như một mô hình tuyến tính bị phạt (ví dụ, glmnet) làm gì với vấn đề này? Nếu tôi nhớ lại, những gì bạn đang báo cáo là độ chính xác ngoài mẫu không được đánh giá cao.
Vách đá AB

Câu trả lời:


86

tl; dr Mặc dù đây là một bộ dữ liệu phân loại hình ảnh, nó vẫn là một nhiệm vụ rất dễ dàng , mà người ta có thể dễ dàng tìm thấy một ánh xạ trực tiếp từ đầu vào đến dự đoán.


Câu trả lời:

Đây là một câu hỏi rất thú vị và nhờ sự đơn giản của hồi quy logistic mà bạn thực sự có thể tìm ra câu trả lời.

Những gì hồi quy logistic làm là cho mỗi hình ảnh chấp nhận đầu vào và nhân chúng với trọng số để tạo dự đoán của nó. Điều thú vị là do ánh xạ trực tiếp giữa đầu vào và đầu ra (nghĩa là không có lớp ẩn), giá trị của mỗi trọng số tương ứng với mỗi một trong số đầu vào được tính đến khi tính toán xác suất của mỗi lớp. Bây giờ, bằng cách lấy các trọng số cho mỗi lớp và định hình lại chúng thành (tức là độ phân giải hình ảnh), chúng ta có thể biết pixel nào là quan trọng nhất cho tính toán của mỗi lớp .78478428×28

Lưu ý, một lần nữa, đây là những trọng lượng .

Bây giờ hãy xem hình ảnh trên và tập trung vào hai chữ số đầu tiên (tức là 0 và một). Trọng lượng màu xanh có nghĩa là cường độ của pixel này đóng góp rất nhiều cho lớp đó và các giá trị màu đỏ có nghĩa là nó đóng góp tiêu cực.

Bây giờ hãy tưởng tượng, làm thế nào để một người vẽ ? Anh ta vẽ một hình tròn trống rỗng ở giữa. Đó chính xác là những gì các trọng lượng nhặt được trên. Trong thực tế, nếu một người nào đó thu hút giữa của hình ảnh, nó đếm tiêu cực như một số không. Vì vậy, để nhận ra số không, bạn không cần một số bộ lọc tinh vi và các tính năng cấp cao. Bạn chỉ có thể nhìn vào các vị trí pixel được vẽ và đánh giá theo điều này.0

Điều tương tự cho . Nó luôn có một đường thẳng đứng ở giữa hình ảnh. Tất cả những thứ khác được tính tiêu cực.1

Các chữ số còn lại phức tạp hơn một chút, nhưng với trí tưởng tượng nhỏ, bạn có thể thấy , , và . Các con số còn lại khó khăn hơn một chút, đó là điều thực sự hạn chế hồi quy logistic khi đạt đến những năm 90 cao.2378

Thông qua điều này, bạn có thể thấy rằng hồi quy logistic có cơ hội rất tốt để có được nhiều hình ảnh đúng và đó là lý do tại sao nó đạt điểm rất cao.


Mã để tái tạo hình trên có một chút ngày, nhưng ở đây bạn đi:

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Create model
x = tf.placeholder(tf.float32, shape=(None, 784))
y = tf.placeholder(tf.float32, shape=(None, 10))

W = tf.Variable(tf.zeros((784,10)))
b = tf.Variable(tf.zeros((10)))
z = tf.matmul(x, W) + b

y_hat = tf.nn.softmax(z)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_hat), reduction_indices=[1]))
optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) # 

correct_pred = tf.equal(tf.argmax(y_hat, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Train model
batch_size = 64
with tf.Session() as sess:

    loss_tr, acc_tr, loss_ts, acc_ts = [], [], [], []

    sess.run(tf.global_variables_initializer()) 

    for step in range(1, 1001):

        x_batch, y_batch = mnist.train.next_batch(batch_size) 
        sess.run(optimizer, feed_dict={x: x_batch, y: y_batch})

        l_tr, a_tr = sess.run([cross_entropy, accuracy], feed_dict={x: x_batch, y: y_batch})
        l_ts, a_ts = sess.run([cross_entropy, accuracy], feed_dict={x: mnist.test.images, y: mnist.test.labels})
        loss_tr.append(l_tr)
        acc_tr.append(a_tr)
        loss_ts.append(l_ts)
        acc_ts.append(a_ts)

    weights = sess.run(W)      
    print('Test Accuracy =', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})) 

# Plotting:
for i in range(10):
    plt.subplot(2, 5, i+1)
    weight = weights[:,i].reshape([28,28])
    plt.title(i)
    plt.imshow(weight, cmap='RdBu')  # as noted by @Eric Duminil, cmap='gray' makes the numbers stand out more
    frame1 = plt.gca()
    frame1.axes.get_xaxis().set_visible(False)
    frame1.axes.get_yaxis().set_visible(False)

12
Cảm ơn đã minh họa. Những hình ảnh trọng lượng này làm cho nó rõ ràng hơn như thế nào là độ chính xác rất cao. Phép nhân số chấm của hình ảnh chữ số viết tay với hình ảnh trọng lượng tương ứng với nhãn thật của hình ảnh 'dường như' là cao nhất so với sản phẩm chấm với hầu hết các nhãn trọng lượng khác (vẫn giống tôi đến 92%) của các hình ảnh trong MNIST. Tuy nhiên, có một chút ngạc nhiên khi và hoặc và hiếm khi bị phân loại nhầm lẫn nhau khi kiểm tra ma trận nhầm lẫn. Dù sao, đây là những gì nó được. Các dữ liệu không bao giờ nói dối. :)2378
Nitish Agarwal

13
Tất nhiên, điều đó giúp các mẫu của MNIST được căn giữa, chia tỷ lệ và được chuẩn hóa tương phản trước khi bộ phân loại nhìn thấy chúng. Bạn không phải giải quyết các câu hỏi như "nếu cạnh số 0 thực sự đi qua giữa hộp thì sao?" bởi vì bộ xử lý trước đã đi một chặng đường dài hướng tới việc làm cho tất cả các số 0 trông giống nhau.
hobbs

1
@EricDuminil Tôi đã thêm một lời khen về kịch bản với đề xuất của bạn. Cảm ơn rất nhiều cho đầu vào! : D
Djib2011

1
@NitishAgarwal, Nếu bạn nghĩ rằng câu trả lời này là Câu trả lời cho Câu hỏi của bạn, hãy xem xét việc đánh dấu nó như vậy.
sintax

11
Đối với ai đó quan tâm nhưng không đặc biệt quen thuộc với cách xử lý này, câu trả lời này cung cấp một ví dụ trực quan tuyệt vời về cơ học.
chrylis -on đình công-
Khi sử dụng trang web của chúng tôi, bạn xác nhận rằng bạn đã đọc và hiểu Chính sách cookieChính sách bảo mật của chúng tôi.
Licensed under cc by-sa 3.0 with attribution required.