Làm thế nào tôi có thể có được dự đoán cho chỉ một trường hợp trong Keras?


13

Khi tôi yêu cầu Keras áp dụng dự đoán với một mô hình được trang bị cho bộ dữ liệu mới mà không có nhãn như thế này:

model1.predict_classes(X_test)

nó hoạt động tốt. Nhưng khi tôi cố gắng đưa ra dự đoán chỉ một hàng thì thất bại:

model1.predict_classes(X_test[10])

Exception: Error when checking : expected dense_input_6 to have shape (None, 784) but got array with shape (784, 1)

Tôi tự hỏi, tại sao?

Câu trả lời:


12

Bạn có thể làm:

q = model.predict( np.array( [single_x_test,] )  )

Mà cũng trả về a numpy.ndarray. Vì vậy, để có được giá trị bạn muốn:q = model.predict(np.array([single_x_test]))[0]
Loisaida Sam Sandberg

7

predict_classesđang mong đợi một mảng 2D của hình dạng (num_instances, features), giống như X_testlà. Nhưng lập chỉ mục một thể hiện như X_test[10]trả về một mảng hình dạng 1D (features,).

Để thêm phía sau trục phụ, bạn có thể sử dụng , hoặc , hoặc không được thoát khỏi nó ở nơi đầu tiên (ví dụ, bằng cách sử dụng ).np.expand_dims(X_test[10], axis=0)X_test[10][np.newaxis,:]X_test[10:11]


Nó dường như không hoạt động: không có thông báo lỗi nhưng cũng không có đầu ra. Kỳ dị.
Hendrik

5

Hiện tại (Keras v2.0.8) phải mất thêm một chút nỗ lực để có được dự đoán trên các hàng đơn sau khi đào tạo theo đợt.

Về cơ bản, batch_size được cố định tại thời điểm đào tạo và phải giống nhau ở thời điểm dự đoán.

Cách giải quyết ngay bây giờ là lấy các trọng số từ mô hình đã được đào tạo và sử dụng các trọng số đó làm các trọng số trong một mô hình mới mà bạn vừa tạo, có một lô có kích thước 1.

Mã nhanh chóng cho điều đó là

model = create_model(batch_size=64)
mode.fit(X, y)
weights = model.get_weights()
single_item_model = create_model(batch_size=1)
single_item_model.set_weights(weights)
single_item_model.compile(compile_params)

Đây là một bài đăng blog đi sâu hơn: https://machinelearningmastery.com/use-different-batch-sizes-training-predicting-python-keras/

Trước đây, tôi đã sử dụng phương pháp này để có nhiều mô hình dự đoán - một mô hình đưa ra dự đoán về các lô lớn, một mô hình đưa ra dự đoán về các lô nhỏ và một mô hình đưa ra dự đoán về các mục đơn lẻ. Vì các dự đoán theo đợt hiệu quả hơn nhiều, điều này cho phép chúng tôi linh hoạt trong bất kỳ số lượng hàng dự đoán nào (không chỉ là một số chia đều cho batch_size), trong khi vẫn nhận được dự đoán khá nhanh.


3

Đây sẽ là cách dự đoán cho một yếu tố, lần này là số 17.

model.predict_classes(X_test[17:18])

Có gì sai với câu trả lời?
Supamee

2

Bạn nên vượt qua một danh sách chỉ với 1 ví dụ, tôi không thể kiểm tra ngay bây giờ nhưng điều này sẽ hoạt động:

model1.predict_classes([X_test[10]])

Thật không may làm việc không may.
Hendrik

1
self.result = self.model.predict(X)

Trong đó X là mảng numpy. Đó là tất cả những gì tôi đã làm và nó đã làm việc.


1

Tôi đã sửa lỗi này bằng cách sử dụng phương pháp sau:

single_test = X_test[10]
single_test = single_test.reshape(1,784)

Xin lưu ý rằng số lượng tính năng (784) trong chức năng định hình lại dựa trên ví dụ của bạn ở trên, nếu bạn có ít tính năng hơn thì bạn cần điều chỉnh nó.

Hy vọng nó sẽ làm việc cho bạn quá.


1

nếu bạn cố gắng in ra ví dụ, bạn sẽ thấy điều này:

x_test:\n
array([[0., 1., 1., ..., 0., 0., 0.],
        [0., 1., 1., ..., 0., 0., 0.],
        [0., 1., 1., ..., 0., 0., 0.],
        ...,
        [0., 1., 0., ..., 0., 0., 0.],
        [0., 1., 1., ..., 0., 0., 0.],
        [0., 1., 1., ..., 0., 0., 0.]])

x_test[0]:
array([0., 1., 1., ..., 0., 0., 0.])

vì vậy tôi nghĩ rằng chúng ta chỉ có thể thêm lại một thứ nguyên bằng cách sử dụng np.array:

mode.predict(np.array(x_test[0],ndmin=2))

0

Nó có nghĩa là dữ liệu đào tạo của bạn có hình dạng (784, 1). Bạn chỉ có thể định hình lại nó như sau. Nó làm việc cho tôi.

model1.predict_classes(X_test[10].reshape(784,1))

Bạn cũng có thể làm transpose()nếu hình dạng là (1,784),

model1.predict_classes(X_test[10].transpose())
Khi sử dụng trang web của chúng tôi, bạn xác nhận rằng bạn đã đọc và hiểu Chính sách cookieChính sách bảo mật của chúng tôi.
Licensed under cc by-sa 3.0 with attribution required.