Ben geçiyordu Bert kağıt kullanan Gelu (Gauss hata Doğrusal Birimi) olarak denklemi devletler
bu da
Denklemi basitleştirebilir ve nasıl yaklaştığını açıklayabilir misiniz?
Ben geçiyordu Bert kağıt kullanan Gelu (Gauss hata Doğrusal Birimi) olarak denklemi devletler
bu da
Denklemi basitleştirebilir ve nasıl yaklaştığını açıklayabilir misiniz?
Yanıtlar:
N ( 0 , 1 ) , yani Φ ( x ) kümülatif dağılımını aşağıdaki gibi genişletebiliriz :
GELU ( x ) : = x P ( X ≤ x ) = x Φ ( x ) = 0,5 x ( 1 + erf ( x
Bunun bir denklem (veya bir ilişki) değil bir tanım olduğunu unutmayın . Yazarlar bu öneri için bazı gerekçeler sunmuşlardır, örneğin stokastik bir benzetme , ancak matematiksel olarak bu sadece bir tanımdır.
İşte GELU'nun konusu:
Bu tür sayısal yaklaşımlar için, anahtar fikir benzer bir işlevi (öncelikle deneyime dayanarak) bulmak, parametreleştirmek ve daha sonra orijinal işlevden bir dizi noktaya sığdırmaktır.
in tanh ( x )' a çok yakın olduğunu bilmek
ve erf ( x'in ilk türeviTan'inkinedenk gelir(√deise, ,
Bu işlevi ( bu siteyi kullanarak ) arasında 20 örneğe yerleştirdim ve işte katsayılar:
Ayarlayarak , olarak tahmin edilmiştir . Daha geniş bir aralıktan daha fazla örnekle (yalnızca bu alana 20 izin verilir), katsayısı kağıdın daha yakın olacaktır . Sonunda aldık
Veri noktaları oluşturmak, işlevlere uymak ve ortalama kare hatalarını hesaplamak için bir Python kodu:
import math
import numpy as np
import scipy.optimize as optimize
def tahn(xs, a):
return [math.tanh(math.sqrt(2 / math.pi) * (x + a * x**3)) for x in xs]
def sigmoid(xs, a):
return [2 * (1 / (1 + math.exp(-a * x)) - 0.5) for x in xs]
print_points = 0
np.random.seed(123)
# xs = [-2, -1, -.9, -.7, 0.6, -.5, -.4, -.3, -0.2, -.1, 0,
# .1, 0.2, .3, .4, .5, 0.6, .7, .9, 2]
# xs = np.concatenate((np.arange(-1, 1, 0.2), np.arange(-4, 4, 0.8)))
# xs = np.concatenate((np.arange(-2, 2, 0.5), np.arange(-8, 8, 1.6)))
xs = np.arange(-10, 10, 0.001)
erfs = np.array([math.erf(x/math.sqrt(2)) for x in xs])
ys = np.array([0.5 * x * (1 + math.erf(x/math.sqrt(2))) for x in xs])
# Fit tanh and sigmoid curves to erf points
tanh_popt, _ = optimize.curve_fit(tahn, xs, erfs)
print('Tanh fit: a=%5.5f' % tuple(tanh_popt))
sig_popt, _ = optimize.curve_fit(sigmoid, xs, erfs)
print('Sigmoid fit: a=%5.5f' % tuple(sig_popt))
# curves used in https://mycurvefit.com:
# 1. sinh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))/cosh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))
# 2. sinh(sqrt(2/3.141593)*(x+b*x^3))/cosh(sqrt(2/3.141593)*(x+b*x^3))
y_paper_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + 0.044715 * x**3))) for x in xs])
tanh_error_paper = (np.square(ys - y_paper_tanh)).mean()
y_alt_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + tanh_popt[0] * x**3))) for x in xs])
tanh_error_alt = (np.square(ys - y_alt_tanh)).mean()
# curve used in https://mycurvefit.com:
# 1. 2*(1/(1+2.718281828459^(-(a*x))) - 0.5)
y_paper_sigmoid = np.array([x * (1 / (1 + math.exp(-1.702 * x))) for x in xs])
sigmoid_error_paper = (np.square(ys - y_paper_sigmoid)).mean()
y_alt_sigmoid = np.array([x * (1 / (1 + math.exp(-sig_popt[0] * x))) for x in xs])
sigmoid_error_alt = (np.square(ys - y_alt_sigmoid)).mean()
print('Paper tanh error:', tanh_error_paper)
print('Alternative tanh error:', tanh_error_alt)
print('Paper sigmoid error:', sigmoid_error_paper)
print('Alternative sigmoid error:', sigmoid_error_alt)
if print_points == 1:
print(len(xs))
for x, erf in zip(xs, erfs):
print(x, erf)
Çıktı:
Tanh fit: a=0.04485
Sigmoid fit: a=1.70099
Paper tanh error: 2.4329173471294176e-08
Alternative tanh error: 2.698034519269613e-08
Paper sigmoid error: 5.6479106346814546e-05
Alternative sigmoid error: 5.704246564663601e-05
Büyük değerler için , her iki işlev de . Küçük içinilgili Taylor dizisi