Tôi đang làm việc trên Trình phân loại rừng ngẫu nhiên và trình phân loại này có thuộc tính xác suất trong dự đoán tức là nếu bạn có bản tóm tắt predictions = model.transform(testData)
như print(predictions)
trong PySpark, bạn sẽ nhận được xác suất của mỗi nhãn, Bạn có thể kiểm tra mã bên dưới và đầu ra của mã:
from pyspark.sql import DataFrame
from pyspark import SparkContext, SQLContext
from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.feature import StringIndexer, VectorIndexer
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
# Split the data into training and test sets (30% held out for testing)
(trainingData, testData) = data.randomSplit([0.7, 0.3])
# Train a Random Forest model.
rf = RandomForestClassifier(labelCol="label", featuresCol="features", numTrees=12, maxDepth=10)
# Chain RF in a Pipeline
pipeline = Pipeline(stages=[rf])
# Train model.
model = pipeline.fit(trainingData)
# Make predictions.
predictions = model.transform(testData)
Bây giờ công việc của bạn bắt đầu từ đây. thử in dự đoán và giá trị của dự đoán
print(predictions)
Đầu ra:
DataFrame[label: double, features: vector, indexed: double, rawPrediction: vector, probability: vector, prediction: double]
Vì vậy, trong DataFrame bạn có xác suất là xác suất của mỗi nhãn được lập chỉ mục, tôi đã kiểm tra thêm dưới dạng:
print predictions.show(3)
Đầu ra:
+-----+--------------------+-------+--------------------+--------------------+----------+
|label| features|indexed| rawPrediction| probability|prediction|
+-----+--------------------+-------+--------------------+--------------------+----------+
| 5.0|(2000,[141,260,50...| 0.0|[34.8672584923246...|[0.69734516984649...| 0.0|
| 5.0|(2000,[109,126,18...| 0.0|[34.6231572522266...|[0.69246314504453...| 0.0|
| 5.0|(2000,[185,306,34...| 0.0|[34.5016453103805...|[0.69003290620761...| 0.0|
+-----+--------------------+-------+--------------------+--------------------+----------+
only showing top 3 rows
Chỉ cho cột xác suất:
print predictions.select('probability').take(2)
Đầu ra:
[Row(probability=DenseVector([0.6973, 0.1889, 0.0532, 0.0448, 0.0157])), Row(probability=DenseVector([0.6925, 0.1825, 0.0579, 0.0497, 0.0174]))]
Trong trường hợp của tôi, tôi có 5 chỉ sốLabels và vì vậy độ dài vectơ xác suất là 5 , Hy vọng điều này sẽ giúp bạn có được xác suất của mỗi nhãn trong vấn đề của bạn.
PS: Bạn có thể sẽ nhận được xác suất trong Cây quyết định , hồi quy logistic . Chỉ cần cố gắng để có được bản tóm tắt model.transform(testData)
.