SPARK Mllib: Hồi quy logistic nhiều lớp, làm thế nào để có được xác suất của tất cả các lớp thay vì lớp trên cùng?


7

Tôi đang sử dụng LogisticRegressionWithLBFGSđể đào tạo một phân loại nhiều lớp.

Có cách nào để có được xác suất của tất cả các lớp (không chỉ lớp ứng viên hàng đầu) khi tôi kiểm tra mô hình trên các mẫu mới chưa thấy?

PS Tôi không nhất thiết phải sử dụng trình phân loại LBFGS nhưng tôi muốn sử dụng hồi quy logistic trong vấn đề của mình. Vì vậy, nếu có một giải pháp bằng cách sử dụng một loại phân loại LR khác, tôi sẽ tìm nó.

Câu trả lời:


4

Tôi đang làm việc trên Trình phân loại rừng ngẫu nhiên và trình phân loại này có thuộc tính xác suất trong dự đoán tức là nếu bạn có bản tóm tắt predictions = model.transform(testData)như print(predictions)trong PySpark, bạn sẽ nhận được xác suất của mỗi nhãn, Bạn có thể kiểm tra mã bên dưới và đầu ra của mã:

from pyspark.sql import DataFrame
from pyspark import SparkContext, SQLContext
from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.feature import StringIndexer, VectorIndexer
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Split the data into training and test sets (30% held out for testing)
(trainingData, testData) = data.randomSplit([0.7, 0.3])

# Train a Random Forest model.
rf = RandomForestClassifier(labelCol="label", featuresCol="features", numTrees=12,  maxDepth=10)

# Chain RF in a Pipeline
pipeline = Pipeline(stages=[rf])

# Train model.
model = pipeline.fit(trainingData)

# Make predictions.
predictions = model.transform(testData)

Bây giờ công việc của bạn bắt đầu từ đây. thử in dự đoán và giá trị của dự đoán

print(predictions)

Đầu ra:

DataFrame[label: double, features: vector, indexed: double, rawPrediction: vector, probability: vector, prediction: double]

Vì vậy, trong DataFrame bạn có xác suất là xác suất của mỗi nhãn được lập chỉ mục, tôi đã kiểm tra thêm dưới dạng:

print predictions.show(3)

Đầu ra:

+-----+--------------------+-------+--------------------+--------------------+----------+
|label|            features|indexed|       rawPrediction|         probability|prediction|
+-----+--------------------+-------+--------------------+--------------------+----------+
|  5.0|(2000,[141,260,50...|    0.0|[34.8672584923246...|[0.69734516984649...|       0.0|
|  5.0|(2000,[109,126,18...|    0.0|[34.6231572522266...|[0.69246314504453...|       0.0|
|  5.0|(2000,[185,306,34...|    0.0|[34.5016453103805...|[0.69003290620761...|       0.0|
+-----+--------------------+-------+--------------------+--------------------+----------+
only showing top 3 rows

Chỉ cho cột xác suất:

print predictions.select('probability').take(2)

Đầu ra:

[Row(probability=DenseVector([0.6973, 0.1889, 0.0532, 0.0448, 0.0157])), Row(probability=DenseVector([0.6925, 0.1825, 0.0579, 0.0497, 0.0174]))]

Trong trường hợp của tôi, tôi có 5 chỉ sốLabels và vì vậy độ dài vectơ xác suất5 , Hy vọng điều này sẽ giúp bạn có được xác suất của mỗi nhãn trong vấn đề của bạn.

PS: Bạn có thể sẽ nhận được xác suất trong Cây quyết định , hồi quy logistic . Chỉ cần cố gắng để có được bản tóm tắt model.transform(testData).


Để tham khảo, bạn có thể kiểm tra tham chiếu cây quyết định tại đây
krishna Prasad

1

Để có được tất cả các xác suất thay vì tất cả các lớp thay vì chỉ có lớp được gắn nhãn, cho đến nay không có phương thức rõ ràng nào (Spark 2.0) trong Spark MLlib hoặc ML. Nhưng bạn có thể mở rộng lớp Hồi quy Logistic từ mã nguồn MLlib để có được các xác suất đó.

Một đoạn mã mẫu có thể được tìm thấy trong câu trả lời này .

Khi sử dụng trang web của chúng tôi, bạn xác nhận rằng bạn đã đọc và hiểu Chính sách cookieChính sách bảo mật của chúng tôi.
Licensed under cc by-sa 3.0 with attribution required.