Kích hoạt GELU là gì?


17

Tôi đã xem qua giấy BERT sử dụng GELU (Đơn vị tuyến tính lỗi Gaussian) trong đó nêu phương trình là mà tương ứng với 0,5x (1 + tanh [\ sqrt {2 / π} (x + 0,044715x ^ 3)]) Bạn có thể đơn giản hóa phương trình và giải thích cách nó đã được phê duyệt.

GELU(x)=xP(Xx)=xΦ(x).
0.5x(1+tanh[2/π(x+0.044715x3)])

Câu trả lời:


16

Hàm GELU

Chúng tôi có thể mở rộng phân phối tích lũy của N(0,1) , tức là Φ(x) , như sau:

GELU(x):=xP(Xx)=xΦ(x)=0.5x(1+erf(x2))

Lưu ý rằng đây là một định nghĩa , không phải là một phương trình (hoặc một mối quan hệ). Các tác giả đã cung cấp một số biện minh cho đề xuất này, ví dụ như một sự tương tự ngẫu nhiên , tuy nhiên về mặt toán học, đây chỉ là một định nghĩa.

Đây là cốt truyện của GELU:

Tanh xấp xỉ

Đối với các loại xấp xỉ bằng số này, ý tưởng chính là tìm một hàm tương tự (chủ yếu dựa trên kinh nghiệm), tham số hóa nó, sau đó khớp nó với một tập hợp các điểm từ hàm ban đầu.

Biết rằng rất gần vớierf(x)tanh(x)

và đạo hàm đầu tiên của trùng với tại , đó là , chúng tôi tiến hành điều chỉnh (hoặc có nhiều điều khoản hơn) cho một tập hợp các điểm .erf(x2)tanh(2πx)x=02π

tanh(2π(x+ax2+bx3+cx4+dx5))
(xi,erf(xi2))

Tôi đã trang bị chức năng này cho 20 mẫu trong khoảng ( sử dụng trang này ) và đây là các hệ số:(1.5,1.5)

Bằng cách đặt , được ước tính là . Với nhiều mẫu hơn từ phạm vi rộng hơn (trang web đó chỉ cho phép 20), hệ số sẽ gần hơn với của giấy . Cuối cùng chúng tôi cũng nhận đượca=c=d=0b0.04495641b0.044715

GELU(x)=xΦ(x)=0.5x(1+erf(x2))0.5x(1+tanh(2π(x+0.044715x3)))

với lỗi bình phương trung bình cho .108x[10,10]

Lưu ý rằng nếu chúng tôi không sử dụng mối quan hệ giữa các dẫn xuất đầu tiên, thuật ngữ sẽ được bao gồm trong các tham số như sau ít đẹp hơn (ít phân tích hơn, nhiều số hơn)!2π

0.5x(1+tanh(0.797885x+0.035677x3))

Sử dụng chẵn lẻ

Theo đề xuất của @BookYourLuck , chúng tôi có thể sử dụng tính chẵn lẻ của các hàm để hạn chế không gian của đa thức mà chúng tôi tìm kiếm. Nghĩa là, vì là một hàm lẻ, tức là và cũng là một hàm lẻ, hàm đa thức bên trong cũng phải là số lẻ (chỉ nên có quyền hạn lẻ của ) để có erff(x)=f(x)tanhpol(x)tanhx

erf(x)tanh(pol(x))=tanh(pol(x))=tanh(pol(x))erf(x)

Trước đây, chúng tôi đã may mắn khi kết thúc với (hầu như) không hệ số cho sức mạnh thậm chí và , tuy nhiên nói chung, điều này có thể dẫn đến xấp xỉ chất lượng thấp, ví dụ, có thời hạn như mà đang bị hủy bỏ bởi các điều khoản bổ sung (chẵn hoặc lẻ) thay vì chỉ chọn .x2x40.23x20x2

Xấp xỉ sigma

Một mối quan hệ tương tự giữ giữa và (sigmoid), được đề xuất trong bài báo dưới dạng một xấp xỉ khác, với lỗi bình phương trung bình cho .erf(x)2(σ(x)12)104x[10,10]

Dưới đây là mã Python để tạo các điểm dữ liệu, khớp các hàm và tính toán các lỗi bình phương trung bình:

import math
import numpy as np
import scipy.optimize as optimize


def tahn(xs, a):
    return [math.tanh(math.sqrt(2 / math.pi) * (x + a * x**3)) for x in xs]


def sigmoid(xs, a):
    return [2 * (1 / (1 + math.exp(-a * x)) - 0.5) for x in xs]


print_points = 0
np.random.seed(123)
# xs = [-2, -1, -.9, -.7, 0.6, -.5, -.4, -.3, -0.2, -.1, 0,
#       .1, 0.2, .3, .4, .5, 0.6, .7, .9, 2]
# xs = np.concatenate((np.arange(-1, 1, 0.2), np.arange(-4, 4, 0.8)))
# xs = np.concatenate((np.arange(-2, 2, 0.5), np.arange(-8, 8, 1.6)))
xs = np.arange(-10, 10, 0.001)
erfs = np.array([math.erf(x/math.sqrt(2)) for x in xs])
ys = np.array([0.5 * x * (1 + math.erf(x/math.sqrt(2))) for x in xs])

# Fit tanh and sigmoid curves to erf points
tanh_popt, _ = optimize.curve_fit(tahn, xs, erfs)
print('Tanh fit: a=%5.5f' % tuple(tanh_popt))

sig_popt, _ = optimize.curve_fit(sigmoid, xs, erfs)
print('Sigmoid fit: a=%5.5f' % tuple(sig_popt))

# curves used in https://mycurvefit.com:
# 1. sinh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))/cosh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))
# 2. sinh(sqrt(2/3.141593)*(x+b*x^3))/cosh(sqrt(2/3.141593)*(x+b*x^3))
y_paper_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + 0.044715 * x**3))) for x in xs])
tanh_error_paper = (np.square(ys - y_paper_tanh)).mean()
y_alt_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + tanh_popt[0] * x**3))) for x in xs])
tanh_error_alt = (np.square(ys - y_alt_tanh)).mean()

# curve used in https://mycurvefit.com:
# 1. 2*(1/(1+2.718281828459^(-(a*x))) - 0.5)
y_paper_sigmoid = np.array([x * (1 / (1 + math.exp(-1.702 * x))) for x in xs])
sigmoid_error_paper = (np.square(ys - y_paper_sigmoid)).mean()
y_alt_sigmoid = np.array([x * (1 / (1 + math.exp(-sig_popt[0] * x))) for x in xs])
sigmoid_error_alt = (np.square(ys - y_alt_sigmoid)).mean()

print('Paper tanh error:', tanh_error_paper)
print('Alternative tanh error:', tanh_error_alt)
print('Paper sigmoid error:', sigmoid_error_paper)
print('Alternative sigmoid error:', sigmoid_error_alt)

if print_points == 1:
    print(len(xs))
    for x, erf in zip(xs, erfs):
        print(x, erf)

Đầu ra:

Tanh fit: a=0.04485
Sigmoid fit: a=1.70099
Paper tanh error: 2.4329173471294176e-08
Alternative tanh error: 2.698034519269613e-08
Paper sigmoid error: 5.6479106346814546e-05
Alternative sigmoid error: 5.704246564663601e-05

2
Tại sao gần đúng là cần thiết? Họ không thể sử dụng chức năng erf sao?
SebiSebi

8

Φ(x)=12erfc(x2)=12(1+erf(x2))
erf
erf(x2)tanh(2π(x+ax3))
a0.044715

x[1,1]x

tanh(x)=xx33+o(x3)
erf(x)=2π(xx33)+o(x3).
tanh(2π(x+ax3))=2π(x+(a23π)x3)+o(x3)
erf(x2)=2π(xx36)+o(x3).
x3
a0.04553992412
0.044715

Khi sử dụng trang web của chúng tôi, bạn xác nhận rằng bạn đã đọc và hiểu Chính sách cookieChính sách bảo mật của chúng tôi.
Licensed under cc by-sa 3.0 with attribution required.