Hồi quy rừng ngẫu nhiên không dự đoán cao hơn dữ liệu đào tạo


12

Tôi đã nhận thấy rằng khi xây dựng các mô hình hồi quy rừng ngẫu nhiên, ít nhất là trong R, giá trị dự đoán không bao giờ vượt quá giá trị tối đa của biến mục tiêu được thấy trong dữ liệu huấn luyện. Ví dụ, xem mã dưới đây. Tôi đang xây dựng mô hình hồi quy để dự đoán mpgdựa trên mtcarsdữ liệu. Tôi xây dựng OLS và các mô hình rừng ngẫu nhiên, và sử dụng chúng để dự đoán mpgcho một chiếc xe giả định nên có khả năng tiết kiệm nhiên liệu rất tốt. OLS dự đoán một mpgkhu rừng cao , như mong đợi, nhưng rừng ngẫu nhiên thì không. Tôi cũng nhận thấy điều này trong các mô hình phức tạp hơn. Tại sao lại thế này?

> library(datasets)
> library(randomForest)
> 
> data(mtcars)
> max(mtcars$mpg)
[1] 33.9
> 
> set.seed(2)
> fit1 <- lm(mpg~., data=mtcars) #OLS fit
> fit2 <- randomForest(mpg~., data=mtcars) #random forest fit
> 
> #Hypothetical car that should have very high mpg
> hypCar <- data.frame(cyl=4, disp=50, hp=40, drat=5.5, wt=1, qsec=24, vs=1, am=1, gear=4, carb=1)
> 
> predict(fit1, hypCar) #OLS predicts higher mpg than max(mtcars$mpg)
      1 
37.2441 
> predict(fit2, hypCar) #RF does not predict higher mpg than max(mtcars$mpg)
       1 
30.78899 

Có phải mọi người thường gọi hồi quy tuyến tính là OLS? Tôi đã luôn nghĩ về OLS như một phương pháp.
Hao Ye

1
Tôi tin rằng OLS là phương pháp hồi quy tuyến tính mặc định, ít nhất là trong R.
Gaurav Bansal

Đối với cây / rừng ngẫu nhiên, dự đoán là trung bình của dữ liệu huấn luyện trong nút tương ứng. Vì vậy, nó không thể lớn hơn các giá trị trong dữ liệu đào tạo.
Jason

1
Tôi đồng ý nhưng nó đã được trả lời bởi ít nhất ba người dùng khác.
HelloWorld

Câu trả lời:


12

Như đã được đề cập trong các câu trả lời trước, rừng ngẫu nhiên cho cây hồi quy / hồi quy không tạo ra dự đoán dự kiến ​​cho các điểm dữ liệu nằm ngoài phạm vi của phạm vi dữ liệu đào tạo vì chúng không thể ngoại suy (tốt). Cây hồi quy bao gồm một hệ thống phân cấp của các nút, trong đó mỗi nút chỉ định một thử nghiệm được thực hiện trên một giá trị thuộc tính và mỗi nút lá (terminal) chỉ định một quy tắc để tính toán đầu ra dự đoán. Trong trường hợp của bạn, quan sát thử nghiệm chảy qua các cây đến các nút lá, ví dụ: "nếu x> 335, thì y = 15", sau đó được lấy trung bình bởi rừng ngẫu nhiên.

Dưới đây là một kịch bản R trực quan hóa tình huống với cả rừng ngẫu nhiên và hồi quy tuyến tính. Trong trường hợp rừng ngẫu nhiên, các dự đoán là không đổi để kiểm tra các điểm dữ liệu thấp hơn giá trị x dữ liệu huấn luyện thấp nhất hoặc cao hơn giá trị x dữ liệu đào tạo cao nhất.

library(datasets)
library(randomForest)
library(ggplot2)
library(ggthemes)

# Import mtcars (Motor Trend Car Road Tests) dataset
data(mtcars)

# Define training data
train_data = data.frame(
    x = mtcars$hp,  # Gross horsepower
    y = mtcars$qsec)  # 1/4 mile time

# Train random forest model for regression
random_forest <- randomForest(x = matrix(train_data$x),
                              y = matrix(train_data$y), ntree = 20)
# Train linear regression model using ordinary least squares (OLS) estimator
linear_regr <- lm(y ~ x, train_data)

# Create testing data
test_data = data.frame(x = seq(0, 400))

# Predict targets for testing data points
test_data$y_predicted_rf <- predict(random_forest, matrix(test_data$x)) 
test_data$y_predicted_linreg <- predict(linear_regr, test_data)

# Visualize
ggplot2::ggplot() + 
    # Training data points
    ggplot2::geom_point(data = train_data, size = 2,
                        ggplot2::aes(x = x, y = y, color = "Training data")) +
    # Random forest predictions
    ggplot2::geom_line(data = test_data, size = 2, alpha = 0.7,
                       ggplot2::aes(x = x, y = y_predicted_rf,
                                    color = "Predicted with random forest")) +
    # Linear regression predictions
    ggplot2::geom_line(data = test_data, size = 2, alpha = 0.7,
                       ggplot2::aes(x = x, y = y_predicted_linreg,
                                    color = "Predicted with linear regression")) +
    # Hide legend title, change legend location and add axis labels
    ggplot2::theme(legend.title = element_blank(),
                   legend.position = "bottom") + labs(y = "1/4 mile time",
                                                      x = "Gross horsepower") +
    ggthemes::scale_colour_colorblind()

Ngoại suy với rừng ngẫu nhiên và hồi quy tuyến tính


16

Không có cách nào để Rừng ngẫu nhiên ngoại suy như OLS. Lý do rất đơn giản: các dự đoán từ Rừng ngẫu nhiên được thực hiện thông qua việc lấy trung bình các kết quả thu được trong một số cây. Các cây tự xuất giá trị trung bình của các mẫu trong mỗi nút đầu cuối, các lá. Kết quả không thể nằm ngoài phạm vi của dữ liệu đào tạo, vì trung bình luôn nằm trong phạm vi của các thành phần của nó.

Nói cách khác, trung bình không thể lớn hơn (hoặc thấp hơn) so với mọi mẫu và hồi quy Rừng ngẫu nhiên dựa trên tính trung bình.


11

Cây quyết định / Forrest ngẫu nhiên không thể ngoại suy bên ngoài dữ liệu đào tạo. Và mặc dù OLS có thể làm điều này, những dự đoán như vậy nên được xem xét một cách thận trọng; vì mẫu đã xác định có thể không tiếp tục nằm ngoài phạm vi quan sát được.

Khi sử dụng trang web của chúng tôi, bạn xác nhận rằng bạn đã đọc và hiểu Chính sách cookieChính sách bảo mật của chúng tôi.
Licensed under cc by-sa 3.0 with attribution required.