Code
program bench
  ! Micro-benchmark of GELU activation variants: the exact erf form, the tanh
  ! approximation, and "fast" polynomial/rational replacements for erf/tanh.
  ! Each variant is timed twice over n elements -- as an explicit scalar loop
  ! and as a whole-array elemental call -- and the mean wall-clock seconds per
  ! element is printed.
  ! NOTE(review): no program-level `implicit none` here; it would also apply to
  ! the internal functions below, which must then declare every local.
  use iso_fortran_env, only: int64, real64
  integer, parameter :: n = 1000000000
  integer :: i
  ! Sample points for the side-by-side accuracy printout; `i` is declared above
  ! because the implied-do takes its type from it.
  real, parameter :: x(*) = [(real(i), i = -5, 5)]
  ! n default reals is ~4 GB.  Allocatable keeps it on the heap: a fixed-size
  ! main-program array this large is placed in static storage and can fail to
  ! link under the default code model (e.g. x86-64 without -mcmodel=medium).
  real, allocatable :: dat(:)
  integer :: alloc_stat
  integer(int64) :: clock_start, clock_end, clock_rate

  call system_clock(count_rate=clock_rate)

  print *, "comparison, x = ", x
  print *, "gelu_erf, y = ", gelu_erf(x)
  print *, "gelu_fast_erf, y = ", gelu_fast_erf(x)
  print *, "gelu_fast_erf2, y = ", gelu_fast_erf2(x)
  print *, "gelu_fast_erf_fold, y = ", gelu_fast_erf_fold(x)
  print *, "gelu_tanh, y = ", gelu_tanh(x)
  print *, "gelu_fast_tanh, y = ", gelu_fast_tanh(x)
  print *, "gelu_fast_tanh2, y = ", gelu_fast_tanh2(x)

  allocate (dat(n), stat=alloc_stat)
  if (alloc_stat /= 0) error stop "bench: cannot allocate the benchmark array"

  ! warmup: write every element of dat and exercise erf before timing anything
  print *, "Warmup"
  dat(:) = 1.
  dat(:) = erf(dat(:))
  dat(:) = erf(dat(:))
  dat(:) = erf(dat(:))
  call consume(dat)

  print *, "gelu_erf"
  dat(:) = 1.
  call system_clock(clock_start)
  do i = 1, n
    dat(i) = gelu_erf(dat(i))
  end do
  call system_clock(clock_end)
  call report("loop", clock_start, clock_end)
  call consume(dat)
  dat(:) = 1.
  call system_clock(clock_start)
  dat(:) = gelu_erf(dat(:))
  call system_clock(clock_end)
  call report("array", clock_start, clock_end)
  call consume(dat)

  print *, "gelu_fast_erf"
  dat(:) = 1.
  call system_clock(clock_start)
  do i = 1, n
    dat(i) = gelu_fast_erf(dat(i))
  end do
  call system_clock(clock_end)
  call report("loop", clock_start, clock_end)
  call consume(dat)
  dat(:) = 1.
  call system_clock(clock_start)
  dat(:) = gelu_fast_erf(dat(:))
  call system_clock(clock_end)
  call report("array", clock_start, clock_end)
  call consume(dat)

  print *, "gelu_fast_erf2"
  dat(:) = 1.
  call system_clock(clock_start)
  do i = 1, n
    dat(i) = gelu_fast_erf2(dat(i))
  end do
  call system_clock(clock_end)
  call report("loop", clock_start, clock_end)
  call consume(dat)
  dat(:) = 1.
  call system_clock(clock_start)
  dat(:) = gelu_fast_erf2(dat(:))
  call system_clock(clock_end)
  call report("array", clock_start, clock_end)
  call consume(dat)

  print *, "gelu_fast_erf_fold"
  dat(:) = 1.
  call system_clock(clock_start)
  do i = 1, n
    dat(i) = gelu_fast_erf_fold(dat(i))
  end do
  call system_clock(clock_end)
  call report("loop", clock_start, clock_end)
  call consume(dat)
  dat(:) = 1.
  call system_clock(clock_start)
  dat(:) = gelu_fast_erf_fold(dat(:))
  call system_clock(clock_end)
  call report("array", clock_start, clock_end)
  call consume(dat)

  print *, "gelu_tanh"
  dat(:) = 1.
  call system_clock(clock_start)
  do i = 1, n
    dat(i) = gelu_tanh(dat(i))
  end do
  call system_clock(clock_end)
  call report("loop", clock_start, clock_end)
  call consume(dat)
  dat(:) = 1.
  call system_clock(clock_start)
  dat(:) = gelu_tanh(dat(:))
  call system_clock(clock_end)
  call report("array", clock_start, clock_end)
  call consume(dat)

  print *, "gelu_fast_tanh"
  dat(:) = 1.
  call system_clock(clock_start)
  do i = 1, n
    dat(i) = gelu_fast_tanh(dat(i))
  end do
  call system_clock(clock_end)
  call report("loop", clock_start, clock_end)
  call consume(dat)
  dat(:) = 1.
  call system_clock(clock_start)
  dat(:) = gelu_fast_tanh(dat(:))
  call system_clock(clock_end)
  call report("array", clock_start, clock_end)
  call consume(dat)

  print *, "gelu_fast_tanh2"
  dat(:) = 1.
  call system_clock(clock_start)
  do i = 1, n
    dat(i) = gelu_fast_tanh2(dat(i))
  end do
  call system_clock(clock_end)
  call report("loop", clock_start, clock_end)
  call consume(dat)
  dat(:) = 1.
  call system_clock(clock_start)
  dat(:) = gelu_fast_tanh2(dat(:))
  call system_clock(clock_end)
  call report("array", clock_start, clock_end)
  call consume(dat)

contains

  !> Print the mean wall-clock seconds per element for one timed section of
  !> n elements, given the system_clock tick counts at its start and end.
  subroutine report(label, t_start, t_end)
    character(*), intent(in) :: label
    integer(int64), intent(in) :: t_start, t_end
    if (clock_rate <= 0) then
      ! system_clock reports a zero rate when no clock is available
      print *, label, "system_clock unavailable"
    else
      print *, label, real(t_end - t_start, real64) / real(clock_rate, real64) / n
    end if
  end subroutine report

  !> Print a data-dependent summary of the benchmark results so the compiler
  !> cannot discard the timed work as dead code.  Every element holds the same
  !> value after each benchmark, so this also doubles as a sanity check.
  subroutine consume(a)
    real, intent(in) :: a(:)
    print *, "  min value:", minval(a)
  end subroutine consume

!> Exact GELU: y = 0.5 * x * (1 + erf(x / sqrt(2))), via the erf intrinsic.
elemental real function gelu_erf(x) result(y)
  implicit none
  real, intent(in) :: x
  real, parameter :: rsqrt2 = 1 / sqrt(2.)
  real :: phi2
  ! phi2 = 1 + erf(x/sqrt(2)), i.e. twice the standard normal CDF at x
  phi2 = 1 + erf(x * rsqrt2)
  y = 0.5 * x * phi2
end function gelu_erf
!> GELU in erf form, with erf replaced by the fast_erf approximation:
!> y = 0.5 * x * (1 + fast_erf(x / sqrt(2))).
elemental real function gelu_fast_erf(x) result(y)
  implicit none
  real, intent(in) :: x
  real, parameter :: rsqrt2 = 1 / sqrt(2.)
  real :: half_x
  half_x = 0.5 * x
  y = half_x * (1 + fast_erf(x * rsqrt2))
end function gelu_fast_erf
!> GELU in erf form, with erf replaced by fast_erf2 (same polynomial as
!> fast_erf, written in nested/Horner form):
!> y = 0.5 * x * (1 + fast_erf2(x / sqrt(2))).
elemental real function gelu_fast_erf2(x) result(y)
  implicit none
  real, intent(in) :: x
  real, parameter :: rsqrt2 = 1 / sqrt(2.)
  real :: half_x
  half_x = 0.5 * x
  y = half_x * (1 + fast_erf2(x * rsqrt2))
end function gelu_fast_erf2
!> GELU in erf form using fast_erf_fold.  Unlike the other erf variants, x is
!> passed to fast_erf_fold unscaled; its coefficients appear to absorb the
!> 1/sqrt(2) factor (leading coefficient ~ sqrt(2/pi)).
elemental real function gelu_fast_erf_fold(x) result(y)
  implicit none
  real, intent(in) :: x
  real :: half_x
  half_x = 0.5 * x
  y = half_x * (1 + fast_erf_fold(x))
end function gelu_fast_erf_fold
!> tanh-form GELU approximation:
!> y = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3))).
elemental real function gelu_tanh(x) result(y)
  implicit none
  real, intent(in) :: x
  ! 4*atan(1) is pi
  real, parameter :: inv_pi = 1 / (4*atan(1.))
  real :: arg
  arg = sqrt(2 * inv_pi) * (x + 0.044715 * x**3)
  y = 0.5 * x * (1 + tanh(arg))
end function gelu_tanh
!> tanh-form GELU with tanh replaced by the polynomial fast_tanh:
!> y = 0.5 * x * (1 + fast_tanh(sqrt(2/pi) * (x + 0.044715 * x**3))).
elemental real function gelu_fast_tanh(x) result(y)
  implicit none
  real, intent(in) :: x
  ! 4*atan(1) is pi
  real, parameter :: inv_pi = 1 / (4*atan(1.))
  real :: arg
  arg = sqrt(2 * inv_pi) * (x + 0.044715 * x**3)
  y = 0.5 * x * (1 + fast_tanh(arg))
end function gelu_fast_tanh
!> tanh-form GELU with tanh replaced by the rational fast_tanh2:
!> y = 0.5 * x * (1 + fast_tanh2(sqrt(2/pi) * (x + 0.044715 * x**3))).
elemental real function gelu_fast_tanh2(x) result(y)
  implicit none
  real, intent(in) :: x
  ! 4*atan(1) is pi
  real, parameter :: inv_pi = 1 / (4*atan(1.))
  real :: arg
  arg = sqrt(2 * inv_pi) * (x + 0.044715 * x**3)
  y = 0.5 * x * (1 + fast_tanh2(arg))
end function gelu_fast_tanh2
!> Odd degree-17 polynomial approximation of tanh, evaluated for -5 <= x <= 5
!> (and for NaN).  Outside that interval the result saturates to +1 / -1.
elemental real function fast_tanh(x) result(y)
  implicit none
  real, intent(in) :: x
  ! Coefficients of x, x**3, ..., x**17 (held as default real, like the literals)
  real, parameter :: k1 = 0.98569772605911309407, &
                     k3 = -0.2794500993392901382, &
                     k5 = 6.8280504526399188164e-2, &
                     k7 = -1.0972014877337651823e-2, &
                     k9 = 1.1132367134444316902e-3, &
                     k11 = -7.018851897305717565e-5, &
                     k13 = 2.656616768082727089e-6, &
                     k15 = -5.5138381821615909058e-8, &
                     k17 = 4.8162484477588665996e-10
  real :: sq
  if (x > 5) then
    y = 1
    return
  end if
  if (x < -5) then
    y = -1
    return
  end if
  ! Horner evaluation in x**2, then multiply by x to keep the result odd
  sq = x*x
  y = x * (k1 + sq*(k3 + sq*(k5 + sq*(k7 + sq*(k9 + sq*(k11 &
      + sq*(k13 + sq*(k15 + sq*k17))))))))
end function fast_tanh
!> Rational approximation of tanh:
!>   x*(135135 + 17325 x^2 + 378 x^4 + x^6) / (135135 + 62370 x^2 + 3150 x^4 + 28 x^6)
!> evaluated for -5 <= x <= 5 (and for NaN); saturates to +1 / -1 outside.
!> NOTE(review): at |x| = 5 the ratio is ~1.00001 in magnitude, slightly above 1.
elemental real function fast_tanh2(x) result(y)
  implicit none
  real, intent(in) :: x
  real :: s, num, den
  if (.not. (x > 5 .or. x < -5)) then
    s = x*x
    num = x * (135135.0 + s * (17325.0 + s * (378.0 + s)))
    den = 135135.0 + s * (62370.0 + s * (3150.0 + s * 28.0))
    y = num / den
  else if (x > 5) then
    y = 1
  else
    y = -1
  end if
end function fast_tanh2
!> Abramowitz & Stegun 7.1.27 approximation of erf (|error| <= 5e-4):
!> erf(t) ~ 1 - 1/(1 + a1 t + a2 t^2 + a3 t^3 + a4 t^4)^4 for t = |x|,
!> mirrored for negative x since erf is odd.
elemental real function fast_erf(x) result(y)
  implicit none
  real, intent(in) :: x
  real :: t, q
  t = abs(x)
  q = (1+ 0.278393*t + 0.230389*t**2 + 0.000972*t**3 + 0.078108*t**4)**4
  y = 1 - 1 / q
  ! negate unless x >= 0 (a NaN x therefore also takes the negated branch)
  if (.not. (x >= 0)) y = -y
end function fast_erf
!> Abramowitz & Stegun 7.1.27 approximation of erf (|error| <= 5e-4), the same
!> polynomial as fast_erf but evaluated in nested (Horner) form on |x| and
!> mirrored for negative x since erf is odd.
elemental real function fast_erf2(x) result(y)
  implicit none
  real, intent(in) :: x
  ! Explicitly declared: it was previously only implicitly typed.
  real :: abs_x
  abs_x = abs(x)
  y = 1 - 1 / (1 + abs_x * (0.278393 + abs_x * (0.230389 &
      + abs_x * (0.000972 + 0.078108*abs_x))))**4
  y = merge(y, -y, x >= 0)
end function fast_erf2
!> Fast erf-like approximation used by gelu_fast_erf_fold, which passes x
!> unscaled.  The leading coefficient c1 is close to sqrt(2/pi) ~ 0.79788, so
!> this appears to approximate erf(x / sqrt(2)) with the 1/sqrt(2) input
!> scaling folded into the coefficients.
!> For x**2 < 12.75 (|x| below ~3.57) it evaluates an odd degree-13 polynomial
!> in Horner form; otherwise it saturates to sign(1.0, x).
elemental real function fast_erf_fold(x) result(y)
  implicit none
  real, intent(in) :: x
  real, parameter :: c1 = 0.7975839, c2 = -0.13200624, c3 = 0.019021248, &
                     c4 = -0.0019748025, c5 = 0.00013678304, c6 = -5.5545797e-6, &
                     c7 = 9.853275e-8
  real :: x2
  x2 = x * x
  ! Assign to the result variable y in both branches (the value is returned).
  if (x2 < 12.75) then
    y = x * (c1 + x2 * (c2 + x2 * (c3 + x2 * (c4 + x2 * (c5 + x2 * (c6 + x2 * c7))))))
  else
    y = sign(1.0, x)
  end if
end function fast_erf_fold
end program bench
PS: ChatGPT was kind enough to translate the Julia version to Fortran.