fastGPT: Faster than PyTorch in 300 lines of Fortran

This is a very good point. I initially used declarations like real(sp), intent(in) :: x(:,:), but there are lots of arrays and it quickly became hard to make sure I didn't make a mistake. So I went back to the style real(sp), intent(in) :: x(n_embd,n_seq), where the compiler can check that the arrays are compatible (typically at runtime). That helped a lot to catch bugs: it verifies that I multiply matrices correctly, loop over the correct bounds, etc. In Python the index order is reversed (row-major vs. column-major), so it's really easy to get it wrong. I also find it nicer to document what each index means directly in the declaration rather than in comments. I could instead infer the dimensions from the arrays themselves, but that becomes very hard to read with size(x, 1) everywhere. It's much more natural to use the problem parameters like n_embd, n_seq, or n_layers.
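Here is a minimal sketch of the explicit-shape style described above. It is not fastGPT's code: the module `shape_demo`, the function `linear` and its arguments are hypothetical. Each dummy argument spells out what its indices mean (one column of `x` per token, with `n_embd` rows), and the automatic result `y(n_out, n_seq)` documents the output shape:

```fortran
module shape_demo
    implicit none
    integer, parameter :: sp = kind(1.0)
contains
    ! Hypothetical dense layer: y(:, i) = transpose(w) x(:, i) + b for each token i.
    ! The dummy shapes document the meaning of every index (n_embd = embedding
    ! dimension, n_seq = sequence position, n_out = output features).
    function linear(n_embd, n_out, n_seq, x, w, b) result(y)
        integer, intent(in) :: n_embd, n_out, n_seq
        real(sp), intent(in) :: x(n_embd, n_seq)
        real(sp), intent(in) :: w(n_embd, n_out)
        real(sp), intent(in) :: b(n_out)
        real(sp) :: y(n_out, n_seq)   ! automatic result: shape is explicit
        integer :: i

        ! (n_out, n_embd) x (n_embd, n_seq) -> (n_out, n_seq)
        y = matmul(transpose(w), x)
        do i = 1, n_seq
            y(:, i) = y(:, i) + b
        end do
    end function
end module shape_demo
```

When compiled with bounds checking (for example `gfortran -fcheck=bounds`), out-of-range indexing inside such a routine is typically reported at runtime; exactly which mismatches get caught depends on the compiler and flags.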

See my proposal here for how this can be improved: How to concisely declare a function result as a multi-dimensional automatic array? - #4 by certik


Obviously … if I had read the post more carefully. Sorry for the noise.


I created a GSoC idea for fastGPT here:

https://github.com/fortran-lang/webpage/wiki/GSoC-2023-Project-ideas#improving-fastgpt-making-it-faster-easier-to-use-and-more-general

If you are a student and interested, please let us know!


I implemented the input tokenizer (encoder) in Fortran, so fastGPT is now completely standalone. This made it possible to create a chat interface, which I also did. Here is an example session using the largest model (1558M):

$ ./chat
Your name is fastGPT and you are an AI bot. The user will ask you questions and you answer in a nice, truthful, short way.
User: What is the capital of Czechia?
fastGPT: Prague.
User: How many legs does a dog have?
fastGPT: Four.
User: What color does the sky have?
fastGPT: Blue.
User: What can you type a document on?
fastGPT: A typewriter.
User: What can you drive in?
fastGPT: A car.
User: What can you fly in?
fastGPT: A plane.
User: What continent is Germany in?
fastGPT: Europe.
User: When did Second World War start?
fastGPT: 1939.
User: When did it end?
fastGPT: 1945.
User: When did the U.S. enter the Second World War?
fastGPT: 1941.
User: When did the First World War start?
fastGPT: 1914.
User: When did it end?
fastGPT: 1918.
User: When did the Mexican-American war start?
fastGPT: 1846.
User: When did it end?
fastGPT: 1848.
User: What color is snow?
fastGPT: White.
User: What color do plants usually have?
fastGPT: Green.
User: What is your name?
fastGPT: fastGPT.


I think it's actually very impressive that GPT-2, without any fine-tuning, can not only act as a chatbot but also answer all these questions correctly! It all runs locally, with the inference computation in about 300 lines of Fortran.


Brilliant, thank you!

Readers can now be sure that this barely scratches the surface of what's possible with Fortran.

With a bit of added language support and an ever-improving ecosystem, Fortran can be among the first-choice languages for any kind of computing, not merely number-crunching, with easy-to-read, good-looking syntax and elegant semantics.


Is the next step to use @interkosmos's and/or the GSoC effort on a Fortran interface to curl (or some other transfer-protocol library), so that model creation can also be done in Fortran? From the create_model Python script, it appears the only remaining gap might be this GET request, which could be addressed with a Fortran wrapper library:

Ultimately it should be possible to author a Fortran-only multiprotocol file transfer library from the ground up, especially for easy-to-use client-side URL transfer, and perhaps make it part of stdlib too?


create_model.py uses TensorFlow (I think) to load the model and convert it to our custom (fast) binary format, model.dat. That could be done in Fortran, but I think a better approach is to upload our own model.dat file and just download it. I even have an issue for it here: Upload model.dat online · Issue #31 · certik/fastGPT · GitHub.

However, being able to download from Fortran would be helpful.

It might be better to use an approximate erf function instead of an approximate tanh function. In my tests, computing GELU with fast_erf is about 10-20% faster than computing it with fast_tanh (the Horner-form fast_erf2 gives the larger gain). Tested on Linux (Intel i7-10510U CPU @ 1.80GHz) with GFortran and -O3 -ffast-math -march=native.
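For reference, the variants compared below evaluate the standard GELU definition and its usual tanh approximation (as in gelu_erf and gelu_tanh), while the coefficients in fast_erf and fast_erf2 match what appears to be Abramowitz and Stegun formula 7.1.27 (maximum absolute error about 5e-4), written for x >= 0 and extended by odd symmetry:

```latex
\mathrm{GELU}(x) = \frac{x}{2}\left[1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right]
\approx \frac{x}{2}\left[1 + \tanh\!\left(\sqrt{\frac{2}{\pi}}\,\bigl(x + 0.044715\,x^{3}\bigr)\right)\right]

\operatorname{erf}(x) \approx 1 - \frac{1}{\left(1 + a_1 x + a_2 x^{2} + a_3 x^{3} + a_4 x^{4}\right)^{4}},
\qquad x \ge 0, \qquad \operatorname{erf}(-x) = -\operatorname{erf}(x)

a_1 = 0.278393, \quad a_2 = 0.230389, \quad a_3 = 0.000972, \quad a_4 = 0.078108
```

fast_erf2 evaluates the same quartic in Horner form, 1 + x(a_1 + x(a_2 + x(a_3 + a_4 x))), which avoids the separate powers and is likely why it benchmarks slightly faster.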

Result (seconds per GELU evaluation)
 gelu_erf
 loop   1.11502607E-09
 array   1.10498388E-09

 gelu_fast_erf
 loop   6.49572729E-10
 array   6.29547081E-10

 gelu_fast_erf2
 loop   5.84478632E-10
 array   5.82690174E-10

 gelu_tanh
 loop   1.60507918E-09
 array   1.57718894E-09

 gelu_fast_tanh
 loop   6.57730481E-10
 array   6.95166535E-10

 gelu_fast_tanh2
 loop   7.59801999E-10
 array   7.49453777E-10
Code
program bench
    use iso_fortran_env, only: int64
    integer, parameter :: n = 1e9
    real,    parameter :: x(*) = [(i, i=-5,5)]

    integer :: i
    real    :: dat(n)
    integer(int64)  :: clock_start, clock_end, clock_rate

    call system_clock(count_rate=clock_rate)


    print *, "comparison,      x = ", x
    print *, "gelu_erf,        y = ", gelu_erf(x)
    print *, "gelu_fast_erf,   y = ", gelu_fast_erf(x)
    print *, "gelu_fast_erf2,  y = ", gelu_fast_erf2(x)
    print *, "gelu_tanh,       y = ", gelu_tanh(x)
    print *, "gelu_fast_tanh,  y = ", gelu_fast_tanh(x)
    print *, "gelu_fast_tanh2, y = ", gelu_fast_tanh2(x)

    ! warmup
    print *, "Warmup"
    dat(:) = 1.
    dat(:) = erf(dat(:))
    dat(:) = erf(dat(:))
    dat(:) = erf(dat(:))
    if(minval(dat) /= 1.) print *


    print *, "gelu_erf"

    dat(:) = 1.
    call system_clock(clock_start)
    do i=1,n
        dat(i) = gelu_erf(dat(i))
    end do
    call system_clock(clock_end)
    print*, "loop", real(clock_end-clock_start) / clock_rate / n
    if(minval(dat) == 1.) print *

    dat(:) = 1.
    call system_clock(clock_start)
    dat(:) = gelu_erf(dat(:))
    call system_clock(clock_end)
    print*, "array", real(clock_end-clock_start) / clock_rate / n
    if(minval(dat) /= 1.) print *


    print *, "gelu_fast_erf"

    dat(:) = 1.
    call system_clock(clock_start)
    do i=1,n
        dat(i) = gelu_fast_erf(dat(i))
    end do
    call system_clock(clock_end)
    print*, "loop", real(clock_end-clock_start) / clock_rate / n
    if(minval(dat) == 1.) print *

    dat(:) = 1.
    call system_clock(clock_start)
    dat(:) = gelu_fast_erf(dat(:))
    call system_clock(clock_end)
    print*, "array", real(clock_end-clock_start) / clock_rate / n
    if(minval(dat) /= 1.) print *


    print *, "gelu_fast_erf2"

    dat(:) = 1.
    call system_clock(clock_start)
    do i=1,n
        dat(i) = gelu_fast_erf2(dat(i))
    end do
    call system_clock(clock_end)
    print*, "loop", real(clock_end-clock_start) / clock_rate / n
    if(minval(dat) == 1.) print *

    dat(:) = 1.
    call system_clock(clock_start)
    dat(:) = gelu_fast_erf2(dat(:))
    call system_clock(clock_end)
    print*, "array", real(clock_end-clock_start) / clock_rate / n
    if(minval(dat) /= 1.) print *


    print *, "gelu_tanh"

    dat(:) = 1.
    call system_clock(clock_start)
    do i=1,n
        dat(i) = gelu_tanh(dat(i))
    end do
    call system_clock(clock_end)
    print*, "loop", real(clock_end-clock_start) / clock_rate / n
    if(minval(dat) == 1.) print *

    dat(:) = 1.
    call system_clock(clock_start)
    dat(:) = gelu_tanh(dat(:))
    call system_clock(clock_end)
    print*, "array", real(clock_end-clock_start) / clock_rate / n
    if(minval(dat) /= 1.) print *


    print *, "gelu_fast_tanh"

    dat(:) = 1.
    call system_clock(clock_start)
    do i=1,n
        dat(i) = gelu_fast_tanh(dat(i))
    end do
    call system_clock(clock_end)
    print*, "loop", real(clock_end-clock_start) / clock_rate / n
    if(minval(dat) == 1.) print *

    dat(:) = 1.
    call system_clock(clock_start)
    dat(:) = gelu_fast_tanh(dat(:))
    call system_clock(clock_end)
    print*, "array", real(clock_end-clock_start) / clock_rate / n
    if(minval(dat) /= 1.) print *


    print *, "gelu_fast_tanh2"

    dat(:) = 1.
    call system_clock(clock_start)
    do i=1,n
        dat(i) = gelu_fast_tanh2(dat(i))
    end do
    call system_clock(clock_end)
    print*, "loop", real(clock_end-clock_start) / clock_rate / n
    if(minval(dat) == 1.) print *

    dat(:) = 1.
    call system_clock(clock_start)
    dat(:) = gelu_fast_tanh2(dat(:))
    call system_clock(clock_end)
    print*, "array", real(clock_end-clock_start) / clock_rate / n
    if(minval(dat) /= 1.) print *

contains

    elemental real function gelu_erf(x) result(y)
        real, intent(in) :: x
        real, parameter :: inverse_root_2 = 1 / sqrt(2.)

        y = 0.5 * x * (1 + erf(x * inverse_root_2))
    end function

    elemental real function gelu_fast_erf(x) result(y)
        real, intent(in) :: x
        real, parameter :: inverse_root_2 = 1 / sqrt(2.)

        y = 0.5 * x * (1 + fast_erf(x * inverse_root_2))
    end function

    elemental real function gelu_fast_erf2(x) result(y)
        real, intent(in) :: x
        real, parameter :: inverse_root_2 = 1 / sqrt(2.)

        y = 0.5 * x * (1 + fast_erf2(x * inverse_root_2))
    end function

    elemental real function gelu_tanh(x) result(y)
        real, intent(in) :: x
        real, parameter :: inverse_pi = 1 / (4*atan(1.))

        y = 0.5 * x * (1 + tanh(sqrt(2 * inverse_pi) * (x + 0.044715 * x**3)))
    end function

    elemental real function gelu_fast_tanh(x) result(y)
        real, intent(in) :: x
        real, parameter :: inverse_pi = 1 / (4*atan(1.))

        y = 0.5 * x * (1 + fast_tanh(sqrt(2 * inverse_pi) * (x + 0.044715 * x**3)))
    end function

    elemental real function gelu_fast_tanh2(x) result(y)
        real, intent(in) :: x
        real, parameter :: inverse_pi = 1 / (4*atan(1.))

        y = 0.5 * x * (1 + fast_tanh2(sqrt(2 * inverse_pi) * (x + 0.044715 * x**3)))
    end function

    elemental real function fast_tanh(x) result(y)
        real, intent(in) :: x
        real :: x2

        if (x > 5) then
            y = 1
        elseif (x < -5) then
            y = -1
        else
            x2 = x*x
            y = x * (0.98569772605911309407 + x2 *(-0.2794500993392901382 &
                + x2 * (6.8280504526399188164e-2 + x2 * (-1.0972014877337651823e-2 &
                + x2 * (1.1132367134444316902e-3 + x2 * (-7.018851897305717565e-5 &
                + x2 * (2.656616768082727089e-6 + x2 * (-5.5138381821615909058e-8 &
                + x2 * 4.8162484477588665996e-10))))))))
        end if
    end function

    elemental real function fast_tanh2(x) result(y)
        real, intent(in) :: x
        real :: x2, a, b

        if (x > 5) then
            y = 1
        elseif (x < -5) then
            y = -1
        else
            x2 = x*x
            a = x * (135135.0 + x2 * (17325.0 + x2 * (378.0 + x2)))
            b = 135135.0 + x2 * (62370.0 + x2 * (3150.0 + x2 * 28.0))
            y = a / b
        end if
    end function

    elemental real function fast_erf(x) result(y)
        real, intent(in) :: x
        real :: abs_x

        abs_x = abs(x)
        y = 1 - 1 / (1+ 0.278393*abs_x + 0.230389*abs_x**2 + 0.000972*abs_x**3 + 0.078108*abs_x**4)**4
        y = merge(y, -y, x >= 0)
    end function

    elemental real function fast_erf2(x) result(y)
        real, intent(in) :: x
        real :: abs_x

        abs_x = abs(x)
        y = 1 - 1 / (1+ abs_x * (0.278393 + abs_x * (0.230389 + abs_x * (0.000972 + 0.078108*abs_x))))**4
        y = merge(y, -y, x >= 0)
    end function

end program bench

I've given my version of a fasterf above (it folds in the division by sqrt(2)), which I'm pretty sure will be faster since it uses only one more coefficient and no division.
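Written out, the folded version (fast_erf_fold in the benchmark below) approximates erf(x/sqrt(2)) directly by an odd polynomial in x, so neither the scaling by 1/sqrt(2) nor any division is needed, and it clamps to ±1 beyond the cutoff:

```latex
\operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right) \approx
\begin{cases}
x\,\bigl(c_1 + x^{2}(c_2 + x^{2}(c_3 + x^{2}(c_4 + x^{2}(c_5 + x^{2}(c_6 + c_7 x^{2})))))\bigr), & x^{2} < 12.75,\\[4pt]
\operatorname{sign}(x), & x^{2} \ge 12.75,
\end{cases}

c_1 = 0.7975839,\; c_2 = -0.13200624,\; c_3 = 0.019021248,\; c_4 = -0.0019748025,
c_5 = 0.00013678304,\; c_6 = -5.5545797\times 10^{-6},\; c_7 = 9.853275\times 10^{-8}
```

GELU then needs only one multiply-add chain: GELU(x) ≈ (x/2)[1 + erf(x/sqrt(2))], with the bracketed term taken from the expression above.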


How did you get those coefficients?
The result is impressive!

  • 83% faster than the best version I found (gelu_fast_erf2)
  • 2x as fast as gelu with fast_tanh
  • 3.6x faster than gelu with erf!
  • 4.5x faster than the original gelu with tanh
Result (seconds per GELU evaluation)
 gelu_erf
 loop   1.12205600E-09
 array   1.11854548E-09

 gelu_fast_erf
 loop   6.15510864E-10
 array   6.07725537E-10

 gelu_fast_erf2
 loop   5.67470404E-10
 array   5.66036051E-10

 gelu_fast_erf_fold
 loop   3.10133447E-10
 array   3.04507197E-10

 gelu_tanh
 loop   1.38453005E-09
 array   1.37178480E-09

 gelu_fast_tanh
 loop   6.36799724E-10
 array   6.35009989E-10

 gelu_fast_tanh2
 loop   7.15637327E-10
 array   7.13035464E-10
Code
program bench
    use iso_fortran_env, only: int64
    integer, parameter :: n = 1e9
    real,    parameter :: x(*) = [(i, i=-5,5)]

    integer :: i
    real    :: dat(n)
    integer(int64)  :: clock_start, clock_end, clock_rate

    call system_clock(count_rate=clock_rate)


    print *, "comparison,         x = ", x
    print *, "gelu_erf,           y = ", gelu_erf(x)
    print *, "gelu_fast_erf,      y = ", gelu_fast_erf(x)
    print *, "gelu_fast_erf2,     y = ", gelu_fast_erf2(x)
    print *, "gelu_fast_erf_fold, y = ", gelu_fast_erf_fold(x)
    print *, "gelu_tanh,          y = ", gelu_tanh(x)
    print *, "gelu_fast_tanh,     y = ", gelu_fast_tanh(x)
    print *, "gelu_fast_tanh2,    y = ", gelu_fast_tanh2(x)

    ! warmup
    print *, "Warmup"
    dat(:) = 1.
    dat(:) = erf(dat(:))
    dat(:) = erf(dat(:))
    dat(:) = erf(dat(:))
    if(minval(dat) /= 1.) print *


    print *, "gelu_erf"

    dat(:) = 1.
    call system_clock(clock_start)
    do i=1,n
        dat(i) = gelu_erf(dat(i))
    end do
    call system_clock(clock_end)
    print*, "loop", real(clock_end-clock_start) / clock_rate / n
    if(minval(dat) == 1.) print *

    dat(:) = 1.
    call system_clock(clock_start)
    dat(:) = gelu_erf(dat(:))
    call system_clock(clock_end)
    print*, "array", real(clock_end-clock_start) / clock_rate / n
    if(minval(dat) /= 1.) print *


    print *, "gelu_fast_erf"

    dat(:) = 1.
    call system_clock(clock_start)
    do i=1,n
        dat(i) = gelu_fast_erf(dat(i))
    end do
    call system_clock(clock_end)
    print*, "loop", real(clock_end-clock_start) / clock_rate / n
    if(minval(dat) == 1.) print *

    dat(:) = 1.
    call system_clock(clock_start)
    dat(:) = gelu_fast_erf(dat(:))
    call system_clock(clock_end)
    print*, "array", real(clock_end-clock_start) / clock_rate / n
    if(minval(dat) /= 1.) print *


    print *, "gelu_fast_erf2"

    dat(:) = 1.
    call system_clock(clock_start)
    do i=1,n
        dat(i) = gelu_fast_erf2(dat(i))
    end do
    call system_clock(clock_end)
    print*, "loop", real(clock_end-clock_start) / clock_rate / n
    if(minval(dat) == 1.) print *

    dat(:) = 1.
    call system_clock(clock_start)
    dat(:) = gelu_fast_erf2(dat(:))
    call system_clock(clock_end)
    print*, "array", real(clock_end-clock_start) / clock_rate / n
    if(minval(dat) /= 1.) print *


    print *, "gelu_fast_erf_fold"

    dat(:) = 1.
    call system_clock(clock_start)
    do i=1,n
        dat(i) = gelu_fast_erf_fold(dat(i))
    end do
    call system_clock(clock_end)
    print*, "loop", real(clock_end-clock_start) / clock_rate / n
    if(minval(dat) == 1.) print *

    dat(:) = 1.
    call system_clock(clock_start)
    dat(:) = gelu_fast_erf_fold(dat(:))
    call system_clock(clock_end)
    print*, "array", real(clock_end-clock_start) / clock_rate / n
    if(minval(dat) /= 1.) print *


    print *, "gelu_tanh"

    dat(:) = 1.
    call system_clock(clock_start)
    do i=1,n
        dat(i) = gelu_tanh(dat(i))
    end do
    call system_clock(clock_end)
    print*, "loop", real(clock_end-clock_start) / clock_rate / n
    if(minval(dat) == 1.) print *

    dat(:) = 1.
    call system_clock(clock_start)
    dat(:) = gelu_tanh(dat(:))
    call system_clock(clock_end)
    print*, "array", real(clock_end-clock_start) / clock_rate / n
    if(minval(dat) /= 1.) print *


    print *, "gelu_fast_tanh"

    dat(:) = 1.
    call system_clock(clock_start)
    do i=1,n
        dat(i) = gelu_fast_tanh(dat(i))
    end do
    call system_clock(clock_end)
    print*, "loop", real(clock_end-clock_start) / clock_rate / n
    if(minval(dat) == 1.) print *

    dat(:) = 1.
    call system_clock(clock_start)
    dat(:) = gelu_fast_tanh(dat(:))
    call system_clock(clock_end)
    print*, "array", real(clock_end-clock_start) / clock_rate / n
    if(minval(dat) /= 1.) print *


    print *, "gelu_fast_tanh2"

    dat(:) = 1.
    call system_clock(clock_start)
    do i=1,n
        dat(i) = gelu_fast_tanh2(dat(i))
    end do
    call system_clock(clock_end)
    print*, "loop", real(clock_end-clock_start) / clock_rate / n
    if(minval(dat) == 1.) print *

    dat(:) = 1.
    call system_clock(clock_start)
    dat(:) = gelu_fast_tanh2(dat(:))
    call system_clock(clock_end)
    print*, "array", real(clock_end-clock_start) / clock_rate / n
    if(minval(dat) /= 1.) print *

contains

    elemental real function gelu_erf(x) result(y)
        real, intent(in) :: x
        real, parameter :: inverse_root_2 = 1 / sqrt(2.)

        y = 0.5 * x * (1 + erf(x * inverse_root_2))
    end function

    elemental real function gelu_fast_erf(x) result(y)
        real, intent(in) :: x
        real, parameter :: inverse_root_2 = 1 / sqrt(2.)

        y = 0.5 * x * (1 + fast_erf(x * inverse_root_2))
    end function

    elemental real function gelu_fast_erf2(x) result(y)
        real, intent(in) :: x
        real, parameter :: inverse_root_2 = 1 / sqrt(2.)

        y = 0.5 * x * (1 + fast_erf2(x * inverse_root_2))
    end function

    elemental real function gelu_fast_erf_fold(x) result(y)
        real, intent(in) :: x

        y = 0.5 * x * (1 + fast_erf_fold(x))
    end function

    elemental real function gelu_tanh(x) result(y)
        real, intent(in) :: x
        real, parameter :: inverse_pi = 1 / (4*atan(1.))

        y = 0.5 * x * (1 + tanh(sqrt(2 * inverse_pi) * (x + 0.044715 * x**3)))
    end function

    elemental real function gelu_fast_tanh(x) result(y)
        real, intent(in) :: x
        real, parameter :: inverse_pi = 1 / (4*atan(1.))

        y = 0.5 * x * (1 + fast_tanh(sqrt(2 * inverse_pi) * (x + 0.044715 * x**3)))
    end function

    elemental real function gelu_fast_tanh2(x) result(y)
        real, intent(in) :: x
        real, parameter :: inverse_pi = 1 / (4*atan(1.))

        y = 0.5 * x * (1 + fast_tanh2(sqrt(2 * inverse_pi) * (x + 0.044715 * x**3)))
    end function

    elemental real function fast_tanh(x) result(y)
        real, intent(in) :: x
        real :: x2

        if (x > 5) then
            y = 1
        elseif (x < -5) then
            y = -1
        else
            x2 = x*x
            y = x * (0.98569772605911309407 + x2 *(-0.2794500993392901382 &
                + x2 * (6.8280504526399188164e-2 + x2 * (-1.0972014877337651823e-2 &
                + x2 * (1.1132367134444316902e-3 + x2 * (-7.018851897305717565e-5 &
                + x2 * (2.656616768082727089e-6 + x2 * (-5.5138381821615909058e-8 &
                + x2 * 4.8162484477588665996e-10))))))))
        end if
    end function

    elemental real function fast_tanh2(x) result(y)
        real, intent(in) :: x
        real :: x2, a, b

        if (x > 5) then
            y = 1
        elseif (x < -5) then
            y = -1
        else
            x2 = x*x
            a = x * (135135.0 + x2 * (17325.0 + x2 * (378.0 + x2)))
            b = 135135.0 + x2 * (62370.0 + x2 * (3150.0 + x2 * 28.0))
            y = a / b
        end if
    end function

    elemental real function fast_erf(x) result(y)
        real, intent(in) :: x
        real :: abs_x

        abs_x = abs(x)
        y = 1 - 1 / (1+ 0.278393*abs_x + 0.230389*abs_x**2 + 0.000972*abs_x**3 + 0.078108*abs_x**4)**4
        y = merge(y, -y, x >= 0)
    end function

    elemental real function fast_erf2(x) result(y)
        real, intent(in) :: x
        real :: abs_x

        abs_x = abs(x)
        y = 1 - 1 / (1+ abs_x * (0.278393 + abs_x * (0.230389 + abs_x * (0.000972 + 0.078108*abs_x))))**4
        y = merge(y, -y, x >= 0)
    end function

    elemental real function fast_erf_fold(x) result(y)
        real, intent(in) :: x

        real :: x2, res
        real, parameter :: c1 = 0.7975839, c2 = -0.13200624, c3 = 0.019021248, &
            c4 = -0.0019748025, c5 = 0.00013678304, c6 = -5.5545797e-6, c7 = 9.853275e-8

        x2 = x * x
        res = x * (c1 + x2 * (c2 + x2 * (c3 + x2 * (c4 + x2 * (c5 + x2 * (c6 + x2 * c7))))))

        if (x2 < 12.75) then
            y = res
        else
            y = sign(1.0, x)
        end if
    end function

end program bench
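Before relying on the speedup, it is worth checking the folded approximation against the intrinsic erf. Below is a small accuracy check along those lines. It is a sketch, not part of fastGPT: the module `erf_fold_check` and the helper `max_abs_err` are hypothetical, and the sampling range and grid size are arbitrary choices. It contains a self-contained copy of fast_erf_fold with the same coefficients and cutoff:

```fortran
module erf_fold_check
    implicit none
contains
    ! Self-contained copy of the folded approximation of erf(x/sqrt(2)):
    ! odd degree-13 polynomial for x**2 < 12.75, clamped to +-1 beyond it.
    elemental real function fast_erf_fold(x) result(y)
        real, intent(in) :: x
        real :: x2
        real, parameter :: c1 = 0.7975839, c2 = -0.13200624, c3 = 0.019021248, &
            c4 = -0.0019748025, c5 = 0.00013678304, c6 = -5.5545797e-6, c7 = 9.853275e-8

        x2 = x * x
        if (x2 < 12.75) then
            y = x * (c1 + x2 * (c2 + x2 * (c3 + x2 * (c4 + x2 * (c5 + x2 * (c6 + x2 * c7))))))
        else
            y = sign(1.0, x)
        end if
    end function

    ! Largest absolute difference from the intrinsic erf(x/sqrt(2))
    ! over n+1 equally spaced points in [-xmax, xmax].
    real function max_abs_err(n, xmax) result(err)
        integer, intent(in) :: n
        real, intent(in) :: xmax
        real, parameter :: inverse_root_2 = 1 / sqrt(2.)
        real :: x
        integer :: i

        err = 0
        do i = 0, n
            x = -xmax + 2 * xmax * real(i) / real(n)
            err = max(err, abs(fast_erf_fold(x) - erf(x * inverse_root_2)))
        end do
    end function
end module erf_fold_check
```

Following the reasoning later in the thread, the largest error should come from the clamp just past x² = 12.75, where 1 - erf(x/sqrt(2)) is still slightly below the 2^-11 target.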

PS: ChatGPT was kind enough to translate the Julia version to Fortran. 😌


TLDR: Remez.jl (and the power of Julia to make exploratory analysis easy)
Step 1 was to find an accuracy target. I went with around 2^-11, since that was roughly the error between the tanh and erf versions.
Since erf is an odd function, the obvious approximation is an odd minimax polynomial.
Since 1-erf(x/sqrt(2)) is already below the tolerance at sqrt(12.75), that's the domain over which we need to find a minimax polynomial (beyond it, the result is just ±1).
Then experiment with polynomials of different degrees until you find one whose error meets the target.

using Remez, SpecialFunctions
f(x) = erf(x/sqrt(big(2)))
UB = sqrt(12.75)
# see https://github.com/simonbyrne/Remez.jl/issues/7 for how to generate odd minimax functions
Float32.(ratfn_minimax(x->f(sqrt(x))/sqrt(x), (1e-20, UB^2), 6, 0, (x,y)->sqrt(x))[1])
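The cutoff step can also be checked numerically. The sketch below (not part of fastGPT; the module `erf_cutoff` and the function `tail_cutoff` are hypothetical) bisects for the point where 1 - erf(x/sqrt(2)) equals a given tolerance, assuming the tolerance is exactly 2^-11:

```fortran
module erf_cutoff
    use iso_fortran_env, only: dp => real64
    implicit none
contains
    ! Find x > 0 where erfc(x/sqrt(2)) = 1 - erf(x/sqrt(2)) falls to tol,
    ! by bisection on [0, 10]. erfc is strictly decreasing, so the root is unique.
    real(dp) function tail_cutoff(tol) result(x)
        real(dp), intent(in) :: tol
        real(dp) :: lo, hi
        integer :: iter

        lo = 0
        hi = 10
        do iter = 1, 60
            x = 0.5_dp * (lo + hi)
            if (erfc(x / sqrt(2.0_dp)) > tol) then
                lo = x   ! tail still above tolerance: move right
            else
                hi = x   ! tail below tolerance: move left
            end if
        end do
        x = 0.5_dp * (lo + hi)
    end function
end module erf_cutoff
```

With tol = 2^-11 the returned x satisfies 12 < x² < 13, so the sqrt(12.75) cutoff lies slightly beyond the exact crossover, leaving a little margin below the tolerance.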

Thanks @oscardssmith, and thanks @Carltoffel for testing it! Yes, that is likely the fastest way to do it. The only other way I can think of is to reduce the accuracy, say to 1e-6, which would allow even fewer terms, but as discussed above, we would have to investigate the effects of that.

I implemented this fast version here: Implement fast_erf() and fast_gelu() by certik · Pull Request #45 · certik/fastGPT · GitHub. In my test with the Accelerate framework, it is about as fast as fast_tanh. I think that's because most of the time is spent in matrix-matrix multiplication, and thanks to caching, the arrays that GELU is applied to are small (fast_tanh and fast_gelu run in 288 ms, the original in 300 ms). Without caching, fast_tanh and fast_gelu (via fast_erf) run at about the same speed (about 520 ms), compared to 720 ms for the original. I don't know why; I would expect fast_gelu to be faster than fast_tanh.

5 posts were merged into an existing topic: Fast.ai - Mojo may be the biggest programming language advance in decades