Như đã được đề cập trong các câu trả lời trước, rừng ngẫu nhiên cho cây hồi quy / hồi quy không tạo ra dự đoán dự kiến cho các điểm dữ liệu nằm ngoài phạm vi của phạm vi dữ liệu đào tạo vì chúng không thể ngoại suy (tốt). Cây hồi quy bao gồm một hệ thống phân cấp của các nút, trong đó mỗi nút chỉ định một thử nghiệm được thực hiện trên một giá trị thuộc tính và mỗi nút lá (terminal) chỉ định một quy tắc để tính toán đầu ra dự đoán. Trong trường hợp của bạn, quan sát thử nghiệm chảy qua các cây đến các nút lá, ví dụ: "nếu x> 335, thì y = 15", sau đó được lấy trung bình bởi rừng ngẫu nhiên.
Dưới đây là một kịch bản R trực quan hóa tình huống với cả rừng ngẫu nhiên và hồi quy tuyến tính. Trong trường hợp rừng ngẫu nhiên, các dự đoán là không đổi để kiểm tra các điểm dữ liệu thấp hơn giá trị x dữ liệu huấn luyện thấp nhất hoặc cao hơn giá trị x dữ liệu đào tạo cao nhất.
library(datasets)
library(randomForest)
library(ggplot2)
library(ggthemes)
# Import mtcars (Motor Trend Car Road Tests) dataset
data(mtcars)
# Define training data
train_data = data.frame(
x = mtcars$hp, # Gross horsepower
y = mtcars$qsec) # 1/4 mile time
# Train random forest model for regression
random_forest <- randomForest(x = matrix(train_data$x),
y = matrix(train_data$y), ntree = 20)
# Train linear regression model using ordinary least squares (OLS) estimator
linear_regr <- lm(y ~ x, train_data)
# Create testing data
test_data = data.frame(x = seq(0, 400))
# Predict targets for testing data points
test_data$y_predicted_rf <- predict(random_forest, matrix(test_data$x))
test_data$y_predicted_linreg <- predict(linear_regr, test_data)
# Visualize
ggplot2::ggplot() +
# Training data points
ggplot2::geom_point(data = train_data, size = 2,
ggplot2::aes(x = x, y = y, color = "Training data")) +
# Random forest predictions
ggplot2::geom_line(data = test_data, size = 2, alpha = 0.7,
ggplot2::aes(x = x, y = y_predicted_rf,
color = "Predicted with random forest")) +
# Linear regression predictions
ggplot2::geom_line(data = test_data, size = 2, alpha = 0.7,
ggplot2::aes(x = x, y = y_predicted_linreg,
color = "Predicted with linear regression")) +
# Hide legend title, change legend location and add axis labels
ggplot2::theme(legend.title = element_blank(),
legend.position = "bottom") + labs(y = "1/4 mile time",
x = "Gross horsepower") +
ggthemes::scale_colour_colorblind()