Tôi đã xem qua giấy BERT sử dụng GELU (Đơn vị tuyến tính lỗi Gaussian) trong đó nêu phương trình là
mà tương ứng với 0,5x (1 + tanh [\ sqrt {2 / π} (x + 0,044715x ^ 3)])
Bạn có thể đơn giản hóa phương trình và giải thích cách nó đã được phê duyệt.
Tôi đã xem qua giấy BERT sử dụng GELU (Đơn vị tuyến tính lỗi Gaussian) trong đó nêu phương trình là
mà tương ứng với 0,5x (1 + tanh [\ sqrt {2 / π} (x + 0,044715x ^ 3)])
Bạn có thể đơn giản hóa phương trình và giải thích cách nó đã được phê duyệt.
Câu trả lời:
Chúng tôi có thể mở rộng phân phối tích lũy của , tức là , như sau:
Lưu ý rằng đây là một định nghĩa , không phải là một phương trình (hoặc một mối quan hệ). Các tác giả đã cung cấp một số biện minh cho đề xuất này, ví dụ như một sự tương tự ngẫu nhiên , tuy nhiên về mặt toán học, đây chỉ là một định nghĩa.
Đây là cốt truyện của GELU:
Đối với các loại xấp xỉ bằng số này, ý tưởng chính là tìm một hàm tương tự (chủ yếu dựa trên kinh nghiệm), tham số hóa nó, sau đó khớp nó với một tập hợp các điểm từ hàm ban đầu.
Biết rằng rất gần với
và đạo hàm đầu tiên của trùng với tại , đó là , chúng tôi tiến hành điều chỉnh
(hoặc có nhiều điều khoản hơn) cho một tập hợp các điểm .
Tôi đã trang bị chức năng này cho 20 mẫu trong khoảng ( sử dụng trang này ) và đây là các hệ số:
Bằng cách đặt , được ước tính là . Với nhiều mẫu hơn từ phạm vi rộng hơn (trang web đó chỉ cho phép 20), hệ số sẽ gần hơn với của giấy . Cuối cùng chúng tôi cũng nhận được
với lỗi bình phương trung bình cho .
Lưu ý rằng nếu chúng tôi không sử dụng mối quan hệ giữa các dẫn xuất đầu tiên, thuật ngữ sẽ được bao gồm trong các tham số như sau
ít đẹp hơn (ít phân tích hơn, nhiều số hơn)!
Theo đề xuất của @BookYourLuck , chúng tôi có thể sử dụng tính chẵn lẻ của các hàm để hạn chế không gian của đa thức mà chúng tôi tìm kiếm. Nghĩa là, vì là một hàm lẻ, tức là và cũng là một hàm lẻ, hàm đa thức bên trong cũng phải là số lẻ (chỉ nên có quyền hạn lẻ của ) để có
Trước đây, chúng tôi đã may mắn khi kết thúc với (hầu như) không hệ số cho sức mạnh thậm chí và , tuy nhiên nói chung, điều này có thể dẫn đến xấp xỉ chất lượng thấp, ví dụ, có thời hạn như mà đang bị hủy bỏ bởi các điều khoản bổ sung (chẵn hoặc lẻ) thay vì chỉ chọn .
Một mối quan hệ tương tự giữ giữa và (sigmoid), được đề xuất trong bài báo dưới dạng một xấp xỉ khác, với lỗi bình phương trung bình cho .
Dưới đây là mã Python để tạo các điểm dữ liệu, khớp các hàm và tính toán các lỗi bình phương trung bình:
import math
import numpy as np
import scipy.optimize as optimize
def tahn(xs, a):
return [math.tanh(math.sqrt(2 / math.pi) * (x + a * x**3)) for x in xs]
def sigmoid(xs, a):
return [2 * (1 / (1 + math.exp(-a * x)) - 0.5) for x in xs]
print_points = 0
np.random.seed(123)
# xs = [-2, -1, -.9, -.7, 0.6, -.5, -.4, -.3, -0.2, -.1, 0,
# .1, 0.2, .3, .4, .5, 0.6, .7, .9, 2]
# xs = np.concatenate((np.arange(-1, 1, 0.2), np.arange(-4, 4, 0.8)))
# xs = np.concatenate((np.arange(-2, 2, 0.5), np.arange(-8, 8, 1.6)))
xs = np.arange(-10, 10, 0.001)
erfs = np.array([math.erf(x/math.sqrt(2)) for x in xs])
ys = np.array([0.5 * x * (1 + math.erf(x/math.sqrt(2))) for x in xs])
# Fit tanh and sigmoid curves to erf points
tanh_popt, _ = optimize.curve_fit(tahn, xs, erfs)
print('Tanh fit: a=%5.5f' % tuple(tanh_popt))
sig_popt, _ = optimize.curve_fit(sigmoid, xs, erfs)
print('Sigmoid fit: a=%5.5f' % tuple(sig_popt))
# curves used in https://mycurvefit.com:
# 1. sinh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))/cosh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))
# 2. sinh(sqrt(2/3.141593)*(x+b*x^3))/cosh(sqrt(2/3.141593)*(x+b*x^3))
y_paper_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + 0.044715 * x**3))) for x in xs])
tanh_error_paper = (np.square(ys - y_paper_tanh)).mean()
y_alt_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + tanh_popt[0] * x**3))) for x in xs])
tanh_error_alt = (np.square(ys - y_alt_tanh)).mean()
# curve used in https://mycurvefit.com:
# 1. 2*(1/(1+2.718281828459^(-(a*x))) - 0.5)
y_paper_sigmoid = np.array([x * (1 / (1 + math.exp(-1.702 * x))) for x in xs])
sigmoid_error_paper = (np.square(ys - y_paper_sigmoid)).mean()
y_alt_sigmoid = np.array([x * (1 / (1 + math.exp(-sig_popt[0] * x))) for x in xs])
sigmoid_error_alt = (np.square(ys - y_alt_sigmoid)).mean()
print('Paper tanh error:', tanh_error_paper)
print('Alternative tanh error:', tanh_error_alt)
print('Paper sigmoid error:', sigmoid_error_paper)
print('Alternative sigmoid error:', sigmoid_error_alt)
if print_points == 1:
print(len(xs))
for x, erf in zip(xs, erfs):
print(x, erf)
Đầu ra:
Tanh fit: a=0.04485
Sigmoid fit: a=1.70099
Paper tanh error: 2.4329173471294176e-08
Alternative tanh error: 2.698034519269613e-08
Paper sigmoid error: 5.6479106346814546e-05
Alternative sigmoid error: 5.704246564663601e-05