My list of Fortran codes on GitHub has a Neural Networks and Machine Learning section.
Where’s the `fpm.toml` file?
This is very impressive! I hope this will bring a lot of attention to Fortran.
What’s the benefit of using `n_embd` and `n_seq` in the `mha` function as extra arguments, instead of using `ubound` or the extent of `x`? Is it just for readability, or am I missing something?

```fortran
real(sp), intent(in) :: x(n_embd,n_seq)
```
Oops, forgot to mention that this is `erf(x/sqrt(2))`, since that is what `gelu` needs.
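For reference, the exact GELU used here (it is what the `gelu_erf` function in the benchmarks below computes) is

$$\mathrm{gelu}(x) = \frac{x}{2}\left(1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right),$$

so an `erf` approximation only ever sees the argument `x/sqrt(2)`.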
@oscardssmith here I found a case where my `fast_tanh()` function produces different output than `tanh()`: An example where the current fast_tanh() gives different results · Issue #25 · certik/fastGPT · GitHub. Both outputs look OK. How do you judge which one is better?
I assume what is happening is that it gives probabilities for all the tokens, and if I printed them, I would find similar probabilities in both cases, but slightly numerically different (due to the `tanh` numerical differences); the “greedy” mode then selects a different token, but from the probability perspective the results might still be “equivalent”. Is there a way to determine at which point the results stop being “equivalent”? What accuracy in the final token probabilities is needed?
I wonder if one can think of the reduced-precision `tanh` as reducing the precision of the whole model; there are other ways to do that as well, such as reducing the default 32-bit float weights to 16-bit, 8-bit, or even 4-bit. It must affect the final probabilities, but I wonder what some good ways are to judge the quality of the result. One way would be to compute the loss (error) function for some texts and see how much it changes under the various reduced-precision changes. Is that the way to approach it? And if it gets worse by just a few percent, it’s not a problem, but if it changes a lot, it might be?
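One simple way to quantify “equivalent” (just a minimal sketch with made-up probability vectors, not fastGPT’s actual output) is to compare the two distributions directly, e.g. the maximum absolute difference and whether the greedy (argmax) token agrees:

```fortran
! Minimal sketch: judge whether two runs are "equivalent" by comparing
! their token probability vectors. The numbers below are made up.
program compare_probs
implicit none
real :: p_ref(5), p_fast(5)
p_ref  = [0.70, 0.15, 0.08, 0.05, 0.02]  ! e.g. tanh-based probabilities
p_fast = [0.69, 0.16, 0.08, 0.05, 0.02]  ! e.g. fast_tanh-based probabilities
print *, "max abs difference: ", maxval(abs(p_ref - p_fast))
print *, "greedy token agrees:", maxloc(p_ref, 1) == maxloc(p_fast, 1)
end program compare_probs
```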
It’s a little bit hard to tell. On GPUs these models are likely running with bfloat16 or a mixed-precision scheme, so I would think that as long as you are within 2^-10 or so you should get reasonable results, but it’s hard to say.
This is a very good point. I initially just used declarations like `real(sp), intent(in) :: x(:,:)`, but there are lots of arrays and it quickly became really hard to ensure I didn’t make a mistake. So I reverted to the style of `real(sp), intent(in) :: x(n_embd,n_seq)`, where the compiler can check the compatibility of arrays (typically at runtime), and that helped a lot to catch bugs: that I multiply matrices correctly, loop over the correct bounds, etc. In Python the indices are reversed (column- vs row-major), so it’s really easy to get it wrong. I also found it’s nicer to document what each index is doing directly like this, rather than having it as comments. I could infer the sizes and still set the dimensions of everything, but that becomes very hard to read if you have `size(x, 1)` everywhere. It is much more natural to use the problem parameters like `n_embd`, `n_seq`, or `n_layers`.
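To illustrate the difference (a toy sketch, not fastGPT’s actual routines): with explicit-shape dummy arguments, the declaration itself documents which index is which, and a mismatched actual argument can be flagged, typically at runtime (e.g. with gfortran’s `-fcheck=bounds`):

```fortran
module layers
implicit none
contains
  ! Assumed shape: compact, but nothing says which index is which.
  subroutine layer_assumed(x, y)
    real, intent(in)  :: x(:,:)
    real, intent(out) :: y(:,:)
    y = 2*x
  end subroutine
  ! Explicit shape: self-documenting (n_embd, n_seq), and a wrongly
  ! shaped actual argument can be caught at runtime.
  subroutine layer_explicit(n_embd, n_seq, x, y)
    integer, intent(in) :: n_embd, n_seq
    real, intent(in)  :: x(n_embd, n_seq)
    real, intent(out) :: y(n_embd, n_seq)
    y = 2*x
  end subroutine
end module
```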
See my proposal here for how this can be improved: How to concisely declare a function result as a multi-dimensional automatic array? - #4 by certik
Obviously … if I had read the post more carefully. Sorry for the noise.
I created a GSoC idea for fastGPT here:
If you are a student and interested, please let us know!
I implemented the input tokenizer (encoder) in Fortran, so now `fastGPT` is completely standalone. This makes it possible to create a chat interface, which I also did. Here is an example session using the largest 1558M model:
```
$ ./chat
Your name is fastGPT and you are an AI bot. The user will ask you questions and you answer in a nice, truthful, short way.
User: What is the capital of Czechia?
fastGPT: Prague.
User: How many legs does a dog have?
fastGPT: Four.
User: What color does the sky have?
fastGPT: Blue.
User: What can you type a document on?
fastGPT: A typewriter.
User: What can you drive in?
fastGPT: A car.
User: What can you fly in?
fastGPT: A plane.
User: What continent is Germany in?
fastGPT: Europe.
User: When did Second World War start?
fastGPT: 1939.
User: When did it end?
fastGPT: 1945.
User: When did the U.S. enter the Second World War?
fastGPT: 1941.
User: When did the First World War start?
fastGPT: 1914.
User: When did it end?
fastGPT: 1918.
User: When did the Mexican-American war start?
fastGPT: 1846.
User: When did it end?
fastGPT: 1848.
User: What color is snow?
fastGPT: White.
User: What color do plants usually have?
fastGPT: Green.
User: What is your name?
fastGPT: fastGPT.
```
(You must scroll to see the full output.)
I think it’s actually very impressive that GPT-2 without any fine-tuning can not only act as a chat bot, but even answer all these questions correctly! All running locally, the inference calculation in about 300 lines of Fortran.
Brilliant, thank you!
Readers should now know for sure that this is still barely scratching the surface of what’s possible with Fortran.
With a bit of added language support and the increasingly better ecosystem, Fortran can be among the first-choice languages for any form of computing, not merely number-crunching, with easy-to-read, good-looking syntax as well as elegant semantics.
Is the next step to make use of @interkosmos’s and/or GSoC effort to use the Fortran interface to CURL, or some transfer-protocol library, so that the model creation can also be in Fortran? From the “create_model” Python script, it appears the only gap currently might be this `get` request, and it can be addressed with a wrapper Fortran library:

Ultimately it should be possible to author a Fortran-only multiprotocol file-transfer library from the ground up, especially for easy-to-use client-side URL transfer, and perhaps make it part of `stdlib` too?
The `create_model.py` script uses TensorFlow (I think) to load the model and transform it to our custom (fast) binary file `model.dat`. It could be done in Fortran, but I think a better approach is to upload our own `model.dat` file online and just download it; I even have an issue for it here: Upload model.dat online · Issue #31 · certik/fastGPT · GitHub.
However, being able to download from Fortran would be helpful.
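Until proper Fortran bindings to libcurl (or a native library) are wired in, a pragmatic stopgap is to shell out to `curl` via the standard `execute_command_line` intrinsic. A minimal sketch (the URL below is a placeholder, not the real location of `model.dat`):

```fortran
! Minimal stopgap sketch: download model.dat by shelling out to curl.
! The URL is a placeholder; replace it with the real location.
program download_model
implicit none
integer :: stat
call execute_command_line( &
    "curl -L -o model.dat https://example.com/path/to/model.dat", &
    exitstat=stat)
if (stat /= 0) then
    print *, "download failed with exit status", stat
    stop 1
end if
end program download_model
```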
It might be better to use an approximated `erf` function instead of an approximated `tanh` function. I did some testing, and it seems to be 10-20% faster to calculate `gelu` with `fast_erf` compared to `gelu` using `fast_tanh`. Tested on Linux, i7-10510U CPU @ 1.80GHz, GFortran with `-O3 -ffast-math -march=native`.
Result (seconds per evaluation):

```
gelu_erf
  loop   1.11502607E-09
  array  1.10498388E-09
gelu_fast_erf
  loop   6.49572729E-10
  array  6.29547081E-10
gelu_fast_erf2
  loop   5.84478632E-10
  array  5.82690174E-10
gelu_tanh
  loop   1.60507918E-09
  array  1.57718894E-09
gelu_fast_tanh
  loop   6.57730481E-10
  array  6.95166535E-10
gelu_fast_tanh2
  loop   7.59801999E-10
  array  7.49453777E-10
```
Code:

```fortran
program bench
use iso_fortran_env, only: int64
integer, parameter :: n = 1e9
real, parameter :: x(*) = [(i, i=-5,5)]
integer :: i
real :: dat(n)
integer(int64) :: clock_start, clock_end, clock_rate
call system_clock(count_rate=clock_rate)
print *, "comparison, x = ", x
print *, "gelu_erf, y = ", gelu_erf(x)
print *, "gelu_fast_erf, y = ", gelu_fast_erf(x)
print *, "gelu_fast_erf2, y = ", gelu_fast_erf2(x)
print *, "gelu_tanh, y = ", gelu_tanh(x)
print *, "gelu_fast_tanh, y = ", gelu_fast_tanh(x)
print *, "gelu_fast_tanh2, y = ", gelu_fast_tanh2(x)
! warmup
print *, "Warmup"
dat(:) = 1.
dat(:) = erf(dat(:))
dat(:) = erf(dat(:))
dat(:) = erf(dat(:))
if(minval(dat) /= 1.) print *
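! (The minval checks after each section force dat to actually be used,
! so the compiler cannot optimize the timed loops away.)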
print *, "gelu_erf"
dat(:) = 1.
call system_clock(clock_start)
do i=1,n
dat(i) = gelu_erf(dat(i))
end do
call system_clock(clock_end)
print*, "loop", real(clock_end-clock_start) / clock_rate / n
if(minval(dat) == 1.) print *
dat(:) = 1.
call system_clock(clock_start)
dat(:) = gelu_erf(dat(:))
call system_clock(clock_end)
print*, "array", real(clock_end-clock_start) / clock_rate / n
if(minval(dat) /= 1.) print *
print *, "gelu_fast_erf"
dat(:) = 1.
call system_clock(clock_start)
do i=1,n
dat(i) = gelu_fast_erf(dat(i))
end do
call system_clock(clock_end)
print*, "loop", real(clock_end-clock_start) / clock_rate / n
if(minval(dat) == 1.) print *
dat(:) = 1.
call system_clock(clock_start)
dat(:) = gelu_fast_erf(dat(:))
call system_clock(clock_end)
print*, "array", real(clock_end-clock_start) / clock_rate / n
if(minval(dat) /= 1.) print *
print *, "gelu_fast_erf2"
dat(:) = 1.
call system_clock(clock_start)
do i=1,n
dat(i) = gelu_fast_erf2(dat(i))
end do
call system_clock(clock_end)
print*, "loop", real(clock_end-clock_start) / clock_rate / n
if(minval(dat) == 1.) print *
dat(:) = 1.
call system_clock(clock_start)
dat(:) = gelu_fast_erf2(dat(:))
call system_clock(clock_end)
print*, "array", real(clock_end-clock_start) / clock_rate / n
if(minval(dat) /= 1.) print *
print *, "gelu_tanh"
dat(:) = 1.
call system_clock(clock_start)
do i=1,n
dat(i) = gelu_tanh(dat(i))
end do
call system_clock(clock_end)
print*, "loop", real(clock_end-clock_start) / clock_rate / n
if(minval(dat) == 1.) print *
dat(:) = 1.
call system_clock(clock_start)
dat(:) = gelu_tanh(dat(:))
call system_clock(clock_end)
print*, "array", real(clock_end-clock_start) / clock_rate / n
if(minval(dat) /= 1.) print *
print *, "gelu_fast_tanh"
dat(:) = 1.
call system_clock(clock_start)
do i=1,n
dat(i) = gelu_fast_tanh(dat(i))
end do
call system_clock(clock_end)
print*, "loop", real(clock_end-clock_start) / clock_rate / n
if(minval(dat) == 1.) print *
dat(:) = 1.
call system_clock(clock_start)
dat(:) = gelu_fast_tanh(dat(:))
call system_clock(clock_end)
print*, "array", real(clock_end-clock_start) / clock_rate / n
if(minval(dat) /= 1.) print *
print *, "gelu_fast_tanh2"
dat(:) = 1.
call system_clock(clock_start)
do i=1,n
dat(i) = gelu_fast_tanh2(dat(i))
end do
call system_clock(clock_end)
print*, "loop", real(clock_end-clock_start) / clock_rate / n
if(minval(dat) == 1.) print *
dat(:) = 1.
call system_clock(clock_start)
dat(:) = gelu_fast_tanh2(dat(:))
call system_clock(clock_end)
print*, "array", real(clock_end-clock_start) / clock_rate / n
if(minval(dat) /= 1.) print *
contains
elemental real function gelu_erf(x) result(y)
real, intent(in) :: x
real, parameter :: inverse_root_2 = 1 / sqrt(2.)
y = 0.5 * x * (1 + erf(x * inverse_root_2))
end function
elemental real function gelu_fast_erf(x) result(y)
real, intent(in) :: x
real, parameter :: inverse_root_2 = 1 / sqrt(2.)
y = 0.5 * x * (1 + fast_erf(x * inverse_root_2))
end function
elemental real function gelu_fast_erf2(x) result(y)
real, intent(in) :: x
real, parameter :: inverse_root_2 = 1 / sqrt(2.)
y = 0.5 * x * (1 + fast_erf2(x * inverse_root_2))
end function
elemental real function gelu_tanh(x) result(y)
real, intent(in) :: x
real, parameter :: inverse_pi = 1 / (4*atan(1.))
y = 0.5 * x * (1 + tanh(sqrt(2 * inverse_pi) * (x + 0.044715 * x**3)))
end function
elemental real function gelu_fast_tanh(x) result(y)
real, intent(in) :: x
real, parameter :: inverse_pi = 1 / (4*atan(1.))
y = 0.5 * x * (1 + fast_tanh(sqrt(2 * inverse_pi) * (x + 0.044715 * x**3)))
end function
elemental real function gelu_fast_tanh2(x) result(y)
real, intent(in) :: x
real, parameter :: inverse_pi = 1 / (4*atan(1.))
y = 0.5 * x * (1 + fast_tanh2(sqrt(2 * inverse_pi) * (x + 0.044715 * x**3)))
end function
elemental real function fast_tanh(x) result(y)
real, intent(in) :: x
real :: x2
if (x > 5) then
y = 1
elseif (x < -5) then
y = -1
else
x2 = x*x
y = x * (0.98569772605911309407 + x2 *(-0.2794500993392901382 &
+ x2 * (6.8280504526399188164e-2 + x2 * (-1.0972014877337651823e-2 &
+ x2 * (1.1132367134444316902e-3 + x2 * (-7.018851897305717565e-5 &
+ x2 * (2.656616768082727089e-6 + x2 * (-5.5138381821615909058e-8 &
+ x2 * 4.8162484477588665996e-10))))))))
end if
end function
elemental real function fast_tanh2(x) result(y)
real, intent(in) :: x
real :: x2, a, b
if (x > 5) then
y = 1
elseif (x < -5) then
y = -1
else
x2 = x*x
a = x * (135135.0 + x2 * (17325.0 + x2 * (378.0 + x2)))
b = 135135.0 + x2 * (62370.0 + x2 * (3150.0 + x2 * 28.0))
y = a / b
end if
end function
elemental real function fast_erf(x) result(y)
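! Rational approximation of erf from Abramowitz & Stegun, formula 7.1.25
! (max absolute error about 5e-4).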
real, intent(in) :: x
real :: abs_x
abs_x = abs(x)
y = 1 - 1 / (1+ 0.278393*abs_x + 0.230389*abs_x**2 + 0.000972*abs_x**3 + 0.078108*abs_x**4)**4
y = merge(y, -y, x >= 0)
end function
elemental real function fast_erf2(x) result(y)
real, intent(in) :: x
real :: abs_x
abs_x = abs(x)
y = 1 - 1 / (1+ abs_x * (0.278393 + abs_x * (0.230389 + abs_x * (0.000972 + 0.078108*abs_x))))**4
y = merge(y, -y, x >= 0)
end function
end program bench
```
I’ve given my version of a `fasterf` (one that folds in the division by `sqrt(2)`) above, which I’m pretty sure will be faster, since it uses only one more coefficient and no division.
How did you get those coefficients?
The result is impressive!

- 83% faster than what I found
- 2× as fast as `gelu` with `fast_tanh`
- 3.6× faster than `gelu` with `erf`!
- 4.5× faster than the original `gelu` with `tanh`
Result (seconds per evaluation):

```
gelu_erf
  loop   1.12205600E-09
  array  1.11854548E-09
gelu_fast_erf
  loop   6.15510864E-10
  array  6.07725537E-10
gelu_fast_erf2
  loop   5.67470404E-10
  array  5.66036051E-10
gelu_fast_erf_fold
  loop   3.10133447E-10
  array  3.04507197E-10
gelu_tanh
  loop   1.38453005E-09
  array  1.37178480E-09
gelu_fast_tanh
  loop   6.36799724E-10
  array  6.35009989E-10
gelu_fast_tanh2
  loop   7.15637327E-10
  array  7.13035464E-10
```
Code:

```fortran
program bench
use iso_fortran_env, only: int64
integer, parameter :: n = 1e9
real, parameter :: x(*) = [(i, i=-5,5)]
integer :: i
real :: dat(n)
integer(int64) :: clock_start, clock_end, clock_rate
call system_clock(count_rate=clock_rate)
print *, "comparison, x = ", x
print *, "gelu_erf, y = ", gelu_erf(x)
print *, "gelu_fast_erf, y = ", gelu_fast_erf(x)
print *, "gelu_fast_erf2, y = ", gelu_fast_erf2(x)
print *, "gelu_fast_erf_fold, y = ", gelu_fast_erf2(x)
print *, "gelu_tanh, y = ", gelu_tanh(x)
print *, "gelu_fast_tanh, y = ", gelu_fast_tanh(x)
print *, "gelu_fast_tanh2, y = ", gelu_fast_tanh2(x)
! warmup
print *, "Warmup"
dat(:) = 1.
dat(:) = erf(dat(:))
dat(:) = erf(dat(:))
dat(:) = erf(dat(:))
if(minval(dat) /= 1.) print *
print *, "gelu_erf"
dat(:) = 1.
call system_clock(clock_start)
do i=1,n
dat(i) = gelu_erf(dat(i))
end do
call system_clock(clock_end)
print*, "loop", real(clock_end-clock_start) / clock_rate / n
if(minval(dat) == 1.) print *
dat(:) = 1.
call system_clock(clock_start)
dat(:) = gelu_erf(dat(:))
call system_clock(clock_end)
print*, "array", real(clock_end-clock_start) / clock_rate / n
if(minval(dat) /= 1.) print *
print *, "gelu_fast_erf"
dat(:) = 1.
call system_clock(clock_start)
do i=1,n
dat(i) = gelu_fast_erf(dat(i))
end do
call system_clock(clock_end)
print*, "loop", real(clock_end-clock_start) / clock_rate / n
if(minval(dat) == 1.) print *
dat(:) = 1.
call system_clock(clock_start)
dat(:) = gelu_fast_erf(dat(:))
call system_clock(clock_end)
print*, "array", real(clock_end-clock_start) / clock_rate / n
if(minval(dat) /= 1.) print *
print *, "gelu_fast_erf2"
dat(:) = 1.
call system_clock(clock_start)
do i=1,n
dat(i) = gelu_fast_erf2(dat(i))
end do
call system_clock(clock_end)
print*, "loop", real(clock_end-clock_start) / clock_rate / n
if(minval(dat) == 1.) print *
dat(:) = 1.
call system_clock(clock_start)
dat(:) = gelu_fast_erf2(dat(:))
call system_clock(clock_end)
print*, "array", real(clock_end-clock_start) / clock_rate / n
if(minval(dat) /= 1.) print *
print *, "gelu_fast_erf_fold"
dat(:) = 1.
call system_clock(clock_start)
do i=1,n
dat(i) = gelu_fast_erf_fold(dat(i))
end do
call system_clock(clock_end)
print*, "loop", real(clock_end-clock_start) / clock_rate / n
if(minval(dat) == 1.) print *
dat(:) = 1.
call system_clock(clock_start)
dat(:) = gelu_fast_erf_fold(dat(:))
call system_clock(clock_end)
print*, "array", real(clock_end-clock_start) / clock_rate / n
if(minval(dat) /= 1.) print *
print *, "gelu_tanh"
dat(:) = 1.
call system_clock(clock_start)
do i=1,n
dat(i) = gelu_tanh(dat(i))
end do
call system_clock(clock_end)
print*, "loop", real(clock_end-clock_start) / clock_rate / n
if(minval(dat) == 1.) print *
dat(:) = 1.
call system_clock(clock_start)
dat(:) = gelu_tanh(dat(:))
call system_clock(clock_end)
print*, "array", real(clock_end-clock_start) / clock_rate / n
if(minval(dat) /= 1.) print *
print *, "gelu_fast_tanh"
dat(:) = 1.
call system_clock(clock_start)
do i=1,n
dat(i) = gelu_fast_tanh(dat(i))
end do
call system_clock(clock_end)
print*, "loop", real(clock_end-clock_start) / clock_rate / n
if(minval(dat) == 1.) print *
dat(:) = 1.
call system_clock(clock_start)
dat(:) = gelu_fast_tanh(dat(:))
call system_clock(clock_end)
print*, "array", real(clock_end-clock_start) / clock_rate / n
if(minval(dat) /= 1.) print *
print *, "gelu_fast_tanh2"
dat(:) = 1.
call system_clock(clock_start)
do i=1,n
dat(i) = gelu_fast_tanh2(dat(i))
end do
call system_clock(clock_end)
print*, "loop", real(clock_end-clock_start) / clock_rate / n
if(minval(dat) == 1.) print *
dat(:) = 1.
call system_clock(clock_start)
dat(:) = gelu_fast_tanh2(dat(:))
call system_clock(clock_end)
print*, "array", real(clock_end-clock_start) / clock_rate / n
if(minval(dat) /= 1.) print *
contains
elemental real function gelu_erf(x) result(y)
real, intent(in) :: x
real, parameter :: inverse_root_2 = 1 / sqrt(2.)
y = 0.5 * x * (1 + erf(x * inverse_root_2))
end function
elemental real function gelu_fast_erf(x) result(y)
real, intent(in) :: x
real, parameter :: inverse_root_2 = 1 / sqrt(2.)
y = 0.5 * x * (1 + fast_erf(x * inverse_root_2))
end function
elemental real function gelu_fast_erf2(x) result(y)
real, intent(in) :: x
real, parameter :: inverse_root_2 = 1 / sqrt(2.)
y = 0.5 * x * (1 + fast_erf2(x * inverse_root_2))
end function
elemental real function gelu_fast_erf_fold(x) result(y)
real, intent(in) :: x
y = 0.5 * x * (1 + fast_erf_fold(x))
end function
elemental real function gelu_tanh(x) result(y)
real, intent(in) :: x
real, parameter :: inverse_pi = 1 / (4*atan(1.))
y = 0.5 * x * (1 + tanh(sqrt(2 * inverse_pi) * (x + 0.044715 * x**3)))
end function
elemental real function gelu_fast_tanh(x) result(y)
real, intent(in) :: x
real, parameter :: inverse_pi = 1 / (4*atan(1.))
y = 0.5 * x * (1 + fast_tanh(sqrt(2 * inverse_pi) * (x + 0.044715 * x**3)))
end function
elemental real function gelu_fast_tanh2(x) result(y)
real, intent(in) :: x
real, parameter :: inverse_pi = 1 / (4*atan(1.))
y = 0.5 * x * (1 + fast_tanh2(sqrt(2 * inverse_pi) * (x + 0.044715 * x**3)))
end function
elemental real function fast_tanh(x) result(y)
real, intent(in) :: x
real :: x2
if (x > 5) then
y = 1
elseif (x < -5) then
y = -1
else
x2 = x*x
y = x * (0.98569772605911309407 + x2 *(-0.2794500993392901382 &
+ x2 * (6.8280504526399188164e-2 + x2 * (-1.0972014877337651823e-2 &
+ x2 * (1.1132367134444316902e-3 + x2 * (-7.018851897305717565e-5 &
+ x2 * (2.656616768082727089e-6 + x2 * (-5.5138381821615909058e-8 &
+ x2 * 4.8162484477588665996e-10))))))))
end if
end function
elemental real function fast_tanh2(x) result(y)
real, intent(in) :: x
real :: x2, a, b
if (x > 5) then
y = 1
elseif (x < -5) then
y = -1
else
x2 = x*x
a = x * (135135.0 + x2 * (17325.0 + x2 * (378.0 + x2)))
b = 135135.0 + x2 * (62370.0 + x2 * (3150.0 + x2 * 28.0))
y = a / b
end if
end function
elemental real function fast_erf(x) result(y)
real, intent(in) :: x
real :: abs_x
abs_x = abs(x)
y = 1 - 1 / (1+ 0.278393*abs_x + 0.230389*abs_x**2 + 0.000972*abs_x**3 + 0.078108*abs_x**4)**4
y = merge(y, -y, x >= 0)
end function
elemental real function fast_erf2(x) result(y)
real, intent(in) :: x
real :: abs_x
abs_x = abs(x)
y = 1 - 1 / (1+ abs_x * (0.278393 + abs_x * (0.230389 + abs_x * (0.000972 + 0.078108*abs_x))))**4
y = merge(y, -y, x >= 0)
end function
elemental real function fast_erf_fold(x) result(y)
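! Approximates erf(x/sqrt(2)): the division by sqrt(2) is folded into
! the coefficients, so gelu_fast_erf_fold calls it on x directly.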
real, intent(in) :: x
real :: x2, res
real, parameter :: c1 = 0.7975839, c2 = -0.13200624, c3 = 0.019021248, &
c4 = -0.0019748025, c5 = 0.00013678304, c6 = -5.5545797e-6, c7 = 9.853275e-8
x2 = x * x
res = x * (c1 + x2 * (c2 + x2 * (c3 + x2 * (c4 + x2 * (c5 + x2 * (c6 + x2 * c7))))))
if (x2 < 12.75) then
fasterf = res
else
fasterf = sign(1.0, x)
end if
end function
end program bench
```
PS: ChatGPT was kind enough to translate the Julia version to Fortran.
TLDR: `Remez.jl` (and the power of Julia to make exploratory analysis easy).
Step 1 was to find an accuracy target. I went with around `2^-11`, since that was about the error between the `tanh` and `erf` versions.
Since `erf` is an odd function, the obvious approximation is an odd minimax polynomial.
Since `1 - erf(x/sqrt(2))` drops below the tolerance at `sqrt(12.75)`, that’s the domain over which we need to find a minimax polynomial.
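Equivalently (this is the bound behind the `if (x2 < 12.75)` clamp in `fast_erf_fold` above):

$$1 - \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right) < 2^{-11} \quad \text{for } |x| > \sqrt{12.75} \approx 3.57,$$

so outside that interval the approximation can simply return $\operatorname{sign}(x)$.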
Play with the error incurred by different degree polynomials until you find one that works.
```julia
using Remez, SpecialFunctions
f(x) = erf(x/sqrt(big(2)))
UB = sqrt(12.75)
# see https://github.com/simonbyrne/Remez.jl/issues/7 for how to generate odd minimax functions
Float32.(ratfn_minimax(x->f(sqrt(x))/sqrt(x), (1e-20, UB^2), 6, 0, (x,y)->sqrt(x))[1])
```
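(If I read the code right, the trick from that issue is: since the target is odd, fit a degree-6 polynomial $p(u)$ to $f(\sqrt{u})/\sqrt{u}$ in $u = x^2$; then $x\,p(x^2)$ is an odd degree-13 approximation, whose seven coefficients are exactly the `c1` through `c7` appearing in `fast_erf_fold` above.)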
Thanks @oscardssmith, and thanks @Carltoffel for testing it! Yes, that is likely the fastest way to do it; the only other way I can think of is to reduce the accuracy, say to 1e-6, which would allow even fewer terms, but as discussed above, we would have to investigate its effects.
I implemented this fast version here: Implement fast_erf() and fast_gelu() by certik · Pull Request #45 · certik/fastGPT · GitHub. In my test with the Accelerate framework, it is about as fast as `fast_tanh`; however, I think that’s because most of the time is spent in matrix-matrix multiplication and the arrays that GELU is applied on are small, thanks to caching (`fast_tanh` and `fast_gelu` run at 288 ms, the original at 300 ms). Without caching, both `fast_tanh` and `fast_gelu` (via `fast_erf`) run at about the same speed (about 520 ms), compared to the original (720 ms). I don’t know why; I would expect `fast_gelu` to be faster than `fast_tanh`.