Điều gì đang xảy ra ở đây, khi tôi sử dụng mất bình phương trong cài đặt hồi quy logistic?


16

Tôi đang cố gắng sử dụng tổn thất bình phương để phân loại nhị phân trên tập dữ liệu đồ chơi.

Tôi đang sử dụng mtcarstập dữ liệu, sử dụng dặm trên mỗi gallon và trọng lượng để dự đoán loại truyền. Biểu đồ bên dưới hiển thị hai loại dữ liệu loại truyền có màu khác nhau và ranh giới quyết định được tạo bởi chức năng mất khác nhau. Sự mất mát bình phương là i(yipi)2 nơi yi là nhãn thực địa (0 hoặc 1) và pi là dự đoán khả pi=Logit1(βTxi). Nói cách khác, tôi thay thế mất logistic bằng mất bình phương trong cài đặt phân loại, các phần khác đều giống nhau.

Đối với một ví dụ đồ chơi có mtcarsdữ liệu, trong nhiều trường hợp, tôi đã có một mô hình "tương tự" với hồi quy logistic (xem hình dưới đây, với hạt giống ngẫu nhiên 0).

nhập mô tả hình ảnh ở đây

Nhưng trong một số thời điểm (nếu chúng ta làm set.seed(1)), mất bình phương dường như không hoạt động tốt. nhập mô tả hình ảnh ở đây Chuyện gì đang xảy ra ở đây? Tối ưu hóa không hội tụ? Mất logistic dễ dàng hơn để tối ưu hóa so với mất bình phương? Bất kỳ trợ giúp sẽ được đánh giá cao.


d=mtcars[,c("am","mpg","wt")]
plot(d$mpg,d$wt,col=factor(d$am))
lg_fit=glm(am~.,d, family = binomial())
abline(-lg_fit$coefficients[1]/lg_fit$coefficients[3],
       -lg_fit$coefficients[2]/lg_fit$coefficients[3])
grid()

# sq loss
lossSqOnBinary<-function(x,y,w){
  p=plogis(x %*% w)
  return(sum((y-p)^2))
}

# ----------------------------------------------------------------
# note, this random seed is important for squared loss work
# ----------------------------------------------------------------
set.seed(0)

x0=runif(3)
x=as.matrix(cbind(1,d[,2:3]))
y=d$am
opt=optim(x0, lossSqOnBinary, method="BFGS", x=x,y=y)

abline(-opt$par[1]/opt$par[3],
       -opt$par[2]/opt$par[3], lty=2)
legend(25,5,c("logisitc loss","squared loss"), lty=c(1,2))

1
Có lẽ giá trị bắt đầu ngẫu nhiên là một người nghèo. Tại sao không chọn một cái tốt hơn?
whuber

1
@whuber mất logistic là lồi, nên bắt đầu không thành vấn đề. Mất bình phương trên p và y thì sao? có phải lồi không?
Haitao Du

5
Tôi không thể tái tạo những gì bạn mô tả. optimnói với bạn rằng nó chưa kết thúc, đó là tất cả: nó đang hội tụ. Bạn có thể học được rất nhiều bằng cách chạy lại mã của bạn với đối số bổ sung control=list(maxit=10000), vẽ sơ đồ mức độ phù hợp của nó và so sánh các hệ số của nó với mã gốc.
whuber

2
@amoeba cảm ơn bạn đã góp ý, tôi đã sửa lại câu hỏi. hy vọng nó tốt hơn
Haitao Du

@amoeba Tôi sẽ sửa lại truyền thuyết, nhưng tuyên bố này sẽ không sửa (3)? "Tôi đang sử dụng tập dữ liệu mtcars, sử dụng dặm trên mỗi gallon và trọng lượng để dự đoán loại truyền. Biểu đồ bên dưới hiển thị hai loại dữ liệu loại truyền có màu khác nhau và ranh giới quyết định được tạo bởi chức năng mất khác nhau."
Haitao Du

Câu trả lời:


19

Có vẻ như bạn đã khắc phục vấn đề trong ví dụ cụ thể của mình nhưng tôi nghĩ rằng nó vẫn đáng để nghiên cứu cẩn thận hơn về sự khác biệt giữa bình phương nhỏ nhất và hồi quy logistic khả năng tối đa.

LS(yi,y^i)=12(yiy^i)2LL(yi,y^i)=yilogy^i+(1yi)log(1y^i)

β^L:=argminbRpi=1nyilogg1(xiTb)+(1yi)log(1g1(xiTb))
g là chức năng liên kết của chúng tôi.

β^S:=argminbRp12i=1n(yig1(xiTb))2
β^SLSLL

fSfLLSLLβ^Sβ^Lh=g1y^i=h(xiTb)

h(z)=11+ezh(z)=h(z)(1h(z)).


fLbj=i=1nh(xiTb)xij(yih(xiTb)1yi1h(xiTb)).
h=h(1h)
fLbj=i=1nxij(yi(1y^i)(1yi)y^i)=i=1nxij(yiy^i)
fL(b)=XT(YY^).

Tiếp theo hãy làm dẫn xuất thứ hai. Người Hessian

HL:=2fLbjbk=i=1nxijxiky^i(1y^i).
HL=XTAXA=diag(Y^(1Y^))HLY^YHLb


Chúng ta hãy so sánh điều này với hình vuông nhỏ nhất.

fSbj=i=1n(yiy^i)h(xiTb)xij.

This means we have

fS(b)=XTA(YY^).
This is a vital point: the gradient is almost the same except for all i y^i(1y^i)(0,1) so basically we're flattening the gradient relative to fL. This'll make convergence slower.

For the Hessian we can first write

fSbj=i=1nxij(yiy^i)y^i(1y^i)=i=1nxij(yiy^i(1+yi)y^i2+y^i3).

This leads us to

HS:=2fSbjbk=i=1nxijxikh(xiTb)(yi2(1+yi)y^i+3y^i2).

Let B=diag(yi2(1+yi)y^i+3y^i2). We now have

HS=XTABX.

Unfortunately for us, the weights in B are not guaranteed to be non-negative: if yi=0 then yi2(1+yi)y^i+3y^i2=y^i(3y^i2) which is positive iff y^i>23. Similarly, if yi=1 then yi2(1+yi)y^i+3y^i2=14y^i+3y^i2 which is positive when y^i<13 (it's also positive for y^i>1 but that's not possible). This means that HS is not necessarily PSD, so not only are we squashing our gradients which will make learning harder, but we've also messed up the convexity of our problem.


All in all, it's no surprise that least squares logistic regression struggles sometimes, and in your example you've got enough fitted values close to 0 or 1 so that y^i(1y^i) can be pretty small and thus the gradient is quite flattened.

Connecting this to neural networks, even though this is but a humble logistic regression I think with squared loss you're experiencing something like what Goodfellow, Bengio, and Courville are referring to in their Deep Learning book when they write the following:

One recurring theme throughout neural network design is that the gradient of the cost function must be large and predictable enough to serve as a good guide for the learning algorithm. Functions that saturate (become very flat) undermine this objective because they make the gradient become very small. In many cases this happens because the activation functions used to produce the output of the hidden units or the output units saturate. The negative log-likelihood helps to avoid this problem for many models. Many output units involve an exp function that can saturate when its argument is very negative. The log function in the negative log-likelihood cost function undoes the exp of some output units. We will discuss the interaction between the cost function and the choice of output unit in Sec. 6.2.2.

and, in 6.2.2,

Unfortunately, mean squared error and mean absolute error often lead to poor results when used with gradient-based optimization. Some output units that saturate produce very small gradients when combined with these cost functions. This is one reason that the cross-entropy cost function is more popular than mean squared error or mean absolute error, even when it is not necessary to estimate an entire distribution p(y|x).

(both excerpts are from chapter 6).


1
I really like you helped me to derive the derivative and hessian. I will check it more careful tomorrow.
Haitao Du

1
@hxd1011 you're very welcome, and thanks for the link to that older question of yours! I've really been meaning to go through this more carefully so this was a great excuse :)
jld

1
I carefully read the math and verified with code. I found Hessian for squared loss does not match the numerical approximation. Could you check it? I am more than happy to show you the code if you want.
Haitao Du

@hxd1011 I just went through the derivation again and I think there's a sign error: for HS I think everywhere that I have yi2(1yi)y^i+3y^i2 it should be yi2(1+yi)y^i+3y^i2. Could you recheck and tell me if that fixes it? Thanks a lot for the correction.
jld

@hxd1011 glad that fixed it! thanks again for finding that
jld

5

I would thank to thank @whuber and @Chaconne for help. Especially @Chaconne, this derivation is what I wished to have for years.

The problem IS in the optimization part. If we set the random seed to 1, the default BFGS will not work. But if we change the algorithm and change the max iteration number it will work again.

As @Chaconne mentioned, the problem is squared loss for classification is non-convex and harder to optimize. To add on @Chaconne's math, I would like to present some visualizations on to logistic loss and squared loss.

We will change the demo data from mtcars, since the original toy example has 3 coefficients including the intercept. We will use another toy data set generated from mlbench, in this data set, we set 2 parameters, which is better for visualization.

Here is the demo

  • The data is shown in the left figure: we have two classes in two colors. x,y are two features for the data. In addition, we use red line to represent the linear classifier from logistic loss, and the blue line represent the linear classifier from squared loss.

  • The middle figure and right figure shows the contour for logistic loss (red) and squared loss (blue). x, y are two parameters we are fitting. The dot is the optimal point found by BFGS.

enter image description here

From the contour we can easily see how why optimizing squared loss is harder: as Chaconne mentioned, it is non-convex.

Here is one more view from persp3d.

enter image description here


Code

set.seed(0)
d=mlbench::mlbench.2dnormals(50,2,r=1)
x=d$x
y=ifelse(d$classes==1,1,0)

lg_loss <- function(w){
  p=plogis(x %*% w)
  L=-y*log(p)-(1-y)*log(1-p)
  return(sum(L))
}
sq_loss <- function(w){
  p=plogis(x %*% w)
  L=sum((y-p)^2)
  return(L)
}

w_grid_v=seq(-15,15,0.1)
w_grid=expand.grid(w_grid_v,w_grid_v)

opt1=optimx::optimx(c(1,1),fn=lg_loss ,method="BFGS")
z1=matrix(apply(w_grid,1,lg_loss),ncol=length(w_grid_v))

opt2=optimx::optimx(c(1,1),fn=sq_loss ,method="BFGS")
z2=matrix(apply(w_grid,1,sq_loss),ncol=length(w_grid_v))

par(mfrow=c(1,3))
plot(d,xlim=c(-3,3),ylim=c(-3,3))
abline(0,-opt1$p2/opt1$p1,col='darkred',lwd=2)
abline(0,-opt2$p2/opt2$p1,col='blue',lwd=2)
grid()
contour(w_grid_v,w_grid_v,z1,col='darkred',lwd=2, nlevels = 8)
points(opt1$p1,opt1$p2,col='darkred',pch=19)
grid()
contour(w_grid_v,w_grid_v,z2,col='blue',lwd=2, nlevels = 8)
points(opt2$p1,opt2$p2,col='blue',pch=19)
grid()


# library(rgl)
# persp3d(w_grid_v,w_grid_v,z1,col='darkred')

2
I don't see any non-convexity on the third subplot of your first figure...
amoeba says Reinstate Monica

@amoeba I thought convex contour is more like ellipse, two U shaped curve back to back is non-convex, is that right?
Haitao Du

2
No, why? Maybe it's a part of a larger ellipse-like contour? I mean, it might very well be non-convex, I am just saying that I do not see it on this particular figure.
amoeba says Reinstate Monica
Khi sử dụng trang web của chúng tôi, bạn xác nhận rằng bạn đã đọc và hiểu Chính sách cookieChính sách bảo mật của chúng tôi.
Licensed under cc by-sa 3.0 with attribution required.