C++ Standard Library dense linear algebra interface

Hi Victor, Thanks for your input. I am aware of assembly code presence in BLAS implementations. My comment addressed frequent (re)discoveries by different groups and communities that they can outperform BLAS implementations, particularly for small tasks, by breaking the original BLAS interface and static method overloading. When one breaks the interface, performance comparisons with original BLAS-conforming implementations are meaningless. Some tend to make a splash about such rediscoveries.


I like this. It is indeed fast on my machine as well. Rather unfortunately, I get nearly the same performance just by rearranging my loops to the same order and removing all accumulators and compiler directives. I went looking into cache-oblivious algorithms, but only really found this: https://math.mit.edu/~stevenj/18.335/oblivious-matmul-handout.pdf which seems to always perform worse on my machine than the simple triple nested loop.
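For readers who have not seen the idea, here is a minimal sketch of a cache-oblivious style multiply: halve the largest of the three dimensions recursively until the subproblem is small, then run plain loops. This is my own illustration and is not taken from the handout; the module name `oblivious_mm`, the routine `mm_recursive`, and the cutoff `leaf = 32` are all hypothetical, and nothing here is tuned or claimed to be fast.

```fortran
module oblivious_mm
    use, intrinsic :: iso_fortran_env, only: rk => real64
    implicit none
    private
    public :: rk, mm_recursive

    ! Hypothetical cutoff: at or below this size, fall back to plain loops.
    integer, parameter :: leaf = 32

contains

    ! Computes c = c + a*b, where a is m x n, b is n x p and c is m x p.
    ! Each matrix lives inside a column-major array with leading
    ! dimension lda/ldb/ldc, so sub-blocks can be passed without copying.
    pure recursive subroutine mm_recursive(m, n, p, a, lda, b, ldb, c, ldc)
        integer, intent(in) :: m, n, p, lda, ldb, ldc
        real(rk), intent(in) :: a(lda,*), b(ldb,*)
        real(rk), intent(inout) :: c(ldc,*)
        integer :: i, j, k, h

        if (max(m, n, p) <= leaf) then
            ! Base case: P-N-M loop order, stride-1 access in the inner loop.
            do j = 1, p
                do k = 1, n
                    do i = 1, m
                        c(i,j) = c(i,j) + a(i,k)*b(k,j)
                    end do
                end do
            end do
        else if (m >= n .and. m >= p) then
            h = m/2   ! split the rows of a and c
            call mm_recursive(h, n, p, a, lda, b, ldb, c, ldc)
            call mm_recursive(m-h, n, p, a(h+1,1), lda, b, ldb, c(h+1,1), ldc)
        else if (n >= p) then
            h = n/2   ! split the shared dimension; both halves update all of c
            call mm_recursive(m, h, p, a, lda, b, ldb, c, ldc)
            call mm_recursive(m, n-h, p, a(1,h+1), lda, b(h+1,1), ldb, c, ldc)
        else
            h = p/2   ! split the columns of b and c
            call mm_recursive(m, n, h, a, lda, b, ldb, c, ldc)
            call mm_recursive(m, n, p-h, a, lda, b(1,h+1), ldb, c(1,h+1), ldc)
        end if
    end subroutine mm_recursive

end module oblivious_mm

program demo_oblivious
    use oblivious_mm
    implicit none
    integer, parameter :: n = 100
    real(rk), allocatable :: a(:,:), b(:,:), c(:,:)

    allocate(a(n,n), b(n,n), c(n,n))
    call random_number(a)
    call random_number(b)
    c = 0.0_rk
    call mm_recursive(n, n, n, a, n, b, n, c, n)
    ! Sanity check against the intrinsic; the difference should be rounding noise.
    print '(a,es10.3)', 'max abs difference vs matmul: ', maxval(abs(c - matmul(a, b)))
end program demo_oblivious
```

The sub-blocks are passed by their first element (for example `a(h+1,1)`) together with the original leading dimension, the same convention the BLAS interface uses, so no temporaries are needed for the recursion itself.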

I think it’s really inarguable that fewer input arguments should result in less calling overhead. I did not mean to claim this as any brilliant new discovery; rather, I believe many were stating it simply as a fact. The real issue is that without explicit control to move data directly into and out of vector registers, compilers seem unable to generate equivalent code from higher-level languages (including Fortran).

If the wall you run into is that the language simply lacks the ability to express the programmer’s desires, then it makes a lot more sense why numerical codes will continue to be ported to C++. The language is more expressive and has better availability of interfaces to modern hardware (inline assembly, explicit SIMD, GPU offload, other accelerator cards, etc.).

At the end of the day, I just think it’s lame that the fastest code I seem able to generate in Fortran looks like this:

pure subroutine mm_pmn(m, n, p, a, b, c)
    use, intrinsic :: iso_fortran_env, only: rk => real64
    implicit none
    integer, intent(in) :: m, n, p             
    real(rk), intent(in) :: a(m,n), b(n,p)     
    real(rk), intent(inout) :: c(m,p)          
    integer :: i, j, k                         
    do i=1,p                                   
        do j=1,m                               
            do k=1,n                           
                c(j,i) = c(j,i) + a(j,k)*b(k,i)
            end do                             
        end do                                 
    end do                                     
end subroutine mm_pmn                          

And it is still substantially slower than linking to an external library (none of which are actually implemented in Fortran). Maybe it would be possible following a problem decomposition like here: https://www.cs.utexas.edu/~flame/pubs/GotoTOMS_revision.pdf Each level of the solution could be compiled individually, finding the perfect set of compiler flags that manage to get the compiler to actually produce direct code for each part, then in the end using some additional compiler directives to force inlining during LTO. I think you would need some LTO step, because otherwise each individually compiled file will not have anything in its compilation unit to inline into.
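To make the decomposition idea a bit more concrete, here is a minimal sketch of just the outer layer: tiling the three loops so that each tile's working set is small. It deliberately leaves out the parts of the Goto approach that matter most for performance (packing tiles into contiguous buffers and a register-blocked micro-kernel), so treat it as an illustration of the structure only. The names `blocked_mm`, `mm_blocked` and the block sizes `mb`, `nb`, `pb` are hypothetical, and the value 64 is a placeholder rather than a tuned choice.

```fortran
module blocked_mm
    use, intrinsic :: iso_fortran_env, only: rk => real64
    implicit none
    private
    public :: rk, mm_blocked

    ! Hypothetical block sizes; sensible values depend on the cache hierarchy.
    integer, parameter :: mb = 64, nb = 64, pb = 64

contains

    ! Computes c = c + a*b with the same argument list as mm_pmn above.
    pure subroutine mm_blocked(m, n, p, a, b, c)
        integer, intent(in) :: m, n, p
        real(rk), intent(in) :: a(m,n), b(n,p)
        real(rk), intent(inout) :: c(m,p)
        integer :: i, j, k, i0, j0, k0

        ! Outer loops walk over tiles of c (i0, j0) and of the shared
        ! dimension (k0).
        do j0 = 1, p, pb
            do k0 = 1, n, nb
                do i0 = 1, m, mb
                    ! Inner loops update one tile in P-N-M order, so the
                    ! innermost loop is stride-1 in both a and c.
                    do j = j0, min(j0 + pb - 1, p)
                        do k = k0, min(k0 + nb - 1, n)
                            do i = i0, min(i0 + mb - 1, m)
                                c(i,j) = c(i,j) + a(i,k)*b(k,j)
                            end do
                        end do
                    end do
                end do
            end do
        end do
    end subroutine mm_blocked

end module blocked_mm

program demo_blocked
    use blocked_mm
    implicit none
    integer, parameter :: n = 150   ! not a multiple of 64, to exercise edge tiles
    real(rk), allocatable :: a(:,:), b(:,:), c(:,:)

    allocate(a(n,n), b(n,n), c(n,n))
    call random_number(a)
    call random_number(b)
    c = 0.0_rk
    call mm_blocked(n, n, n, a, b, c)
    ! Sanity check against the intrinsic; the difference should be rounding noise.
    print '(a,es10.3)', 'max abs difference vs matmul: ', maxval(abs(c - matmul(a, b)))
end program demo_blocked
```

In a layered build of the kind described above, the three inner loops would be the piece to pull out into their own routine and compile separately, with the tile loops left in the caller.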

I used both ifort and gfortran, and benchmarked intrinsic matmul, calling single-threaded OpenBLAS dgemm, the above dgemm_nn, and each permutation of triple nested loop. Results are somewhat messy:

[Chart: results for all variants, somewhat messy]

There’s a lot going on in that chart. If I filter out all the bad ones, keeping gfortran’s intrinsic matmul as the lowest acceptable result (~10 GFLOPS on my machine), it looks like this:

[Chart: filtered results, much cleaner]

In the filtered view, calling OpenBLAS dgemm is clearly the best, and its performance is the same from gfortran and ifort. This is no surprise. The dark blue line below those two is ifort’s intrinsic matmul. Averaging around 22 GFLOPS on my machine, it was in a league of its own: none of the Fortran source versions were within striking distance of it, yet it was still significantly slower than actually calling OpenBLAS.

Interestingly, the loop ordering of actual fast triple-nested-loop versions was different between ifort and gfortran. ifort favored M-P-N and P-M-N ordering, both offering consistent 15 GFLOPS on my machine, while gfortran favored P-N-M ordering at about 11 GFLOPS. Also interesting, the dgemm_nn routine above with explicit OpenMP and Intel compiler directives was still slower on ifort than the 2 good loop orderings. With gfortran, dgemm_nn was about 10% faster than the P-N-M loops.

I would love to see any pure Fortran implementation of matrix multiplication that can get within 10% of OpenBLAS in a single-threaded scenario. If that is deemed truly impossible by those more knowledgeable than myself, coming within 10% of ifort’s intrinsic matmul would be cool too. The simplest loops are already nearly there with ifort (about 15 versus 22 GFLOPS) and only need a roughly 50% speedup to achieve parity.

CODE:

module my_mod
use, intrinsic :: iso_fortran_env, i64 => int64, rk => real64
implicit none
private

    public :: i64, rk, mm_matmul, dgemm_nn, dgemm, mm_mnp, mm_mpn, mm_nmp, mm_npm, mm_pmn, mm_pnm


    interface
        subroutine dgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
            import rk
            implicit none
            character, intent(in) :: transa, transb
            integer, intent(in) :: m, n, k, lda, ldb, ldc
            real(rk), intent(in) :: alpha, a(lda,*), b(ldb,*), beta
            real(rk), intent(inout) :: c(ldc,*)
        end subroutine dgemm
    end interface


    contains


        pure subroutine mm_matmul(m, n, p, a, b, c)
            integer, intent(in) :: m, n, p
            real(rk), intent(in) :: a(m,n), b(n,p)
            real(rk), intent(inout) :: c(m,p)
            c = c + matmul(a, b)
        end subroutine mm_matmul


        ! computes c = beta*c + alpha*a*b, one column of c at a time
        pure subroutine dgemm_nn(m,n,k,alpha,a,lda,b,ldb,beta,c,ldc)
            implicit none
            integer, intent(in) :: m,n,k,lda,ldb,ldc
            real(rk), intent(in) :: alpha,beta,a(lda,k),b(ldb,n)
            real(rk), intent(inout) :: c(ldc,n)
            ! local variables: cc accumulates one column of c, t holds alpha*b(l,j)
            integer :: i,j,l
            real(rk) :: cc(m),t
            !DIR$ ATTRIBUTES ALIGN : 64 :: cc
            do j=1,n
              cc(1:m)=beta*c(1:m,j)
            !DIR$ UNROLL_AND_JAM(8)
              do l=1,k
                t=alpha*b(l,j)
            !$OMP SIMD SIMDLEN(8)
                do i=1,m
                  cc(i)=cc(i)+t*a(i,l)
                end do
              end do
              c(1:m,j)=cc(1:m)
            end do
        end subroutine


        pure subroutine mm_mnp(m, n, p, a, b, c)
            integer, intent(in) :: m, n, p
            real(rk), intent(in) :: a(m,n), b(n,p)
            real(rk), intent(inout) :: c(m,p)
            integer :: i, j, k
            do i=1,m
                do j=1,n
                    do k=1,p
                        c(i,k) = c(i,k) + a(i,j)*b(j,k)
                    end do
                end do
            end do
        end subroutine mm_mnp


        pure subroutine mm_mpn(m, n, p, a, b, c)
            integer, intent(in) :: m, n, p
            real(rk), intent(in) :: a(m,n), b(n,p)
            real(rk), intent(inout) :: c(m,p)
            integer :: i, j, k
            do i=1,m
                do j=1,p
                    do k=1,n
                        c(i,j) = c(i,j) + a(i,k)*b(k,j)
                    end do
                end do
            end do
        end subroutine mm_mpn


        pure subroutine mm_nmp(m, n, p, a, b, c)
            integer, intent(in) :: m, n, p
            real(rk), intent(in) :: a(m,n), b(n,p)
            real(rk), intent(inout) :: c(m,p)
            integer :: i, j, k
            do i=1,n
                do j=1,m
                    do k=1,p
                        c(j,k) = c(j,k) + a(j,i)*b(i,k)
                    end do
                end do
            end do
        end subroutine mm_nmp


        pure subroutine mm_npm(m, n, p, a, b, c)
            integer, intent(in) :: m, n, p
            real(rk), intent(in) :: a(m,n), b(n,p)
            real(rk), intent(inout) :: c(m,p)
            integer :: i, j, k
            do i=1,n
                do j=1,p
                    do k=1,m
                        c(k,j) = c(k,j) + a(k,i)*b(i,j)
                    end do
                end do
            end do
        end subroutine mm_npm


        pure subroutine mm_pmn(m, n, p, a, b, c)
            integer, intent(in) :: m, n, p
            real(rk), intent(in) :: a(m,n), b(n,p)
            real(rk), intent(inout) :: c(m,p)
            integer :: i, j, k
            do i=1,p
                do j=1,m
                    do k=1,n
                        c(j,i) = c(j,i) + a(j,k)*b(k,i)
                    end do
                end do
            end do
        end subroutine mm_pmn


        pure subroutine mm_pnm(m, n, p, a, b, c)
            integer, intent(in) :: m, n, p
            real(rk), intent(in) :: a(m,n), b(n,p)
            real(rk), intent(inout) :: c(m,p)
            integer :: i, j, k
            do i=1,p
                do j=1,n
                    do k=1,m
                        c(k,i) = c(k,i) + a(k,j)*b(j,i)
                    end do
                end do
            end do
        end subroutine mm_pnm


end module my_mod 


program main
use, intrinsic :: iso_fortran_env, only: compiler_version, compiler_options
use, non_intrinsic :: my_mod
implicit none

    integer :: n, n_min, n_max, n_step, r, r_max, argc
    real(rk), allocatable :: a(:,:), b(:,:), c_0(:,:), c_1(:,:)
    integer(i64) :: c1, cr, c2
    character(len=8) :: arg_buff
    real(rk) :: elapsed, gflops

    argc = command_argument_count()
    select case (argc)
        case (3)
            call get_command_argument(1, arg_buff)
            read(arg_buff,*) n_min
            call get_command_argument(2, arg_buff)
            read(arg_buff,*) n_max
            call get_command_argument(3, arg_buff)
            read(arg_buff,*) n_step
        case (2)
            call get_command_argument(1, arg_buff)
            read(arg_buff,*) n_min
            call get_command_argument(2, arg_buff)
            read(arg_buff,*) n_max
            n_step = 1
        case (1)
            call get_command_argument(1, arg_buff)
            read(arg_buff,*) n_max
            n_min = 1
            n_step = 1
        case (0)
            n_min = 10
            n_max = 100
            n_step = 10
        case default
            error stop 'invalid argument count'
    end select
    write(*,'(a)') compiler_version()//' -- '//compiler_options()
    write(*,'(3(a,i0))') 'n_min: ',n_min,', n_max: ',n_max,', n_step: ',n_step

    write(831,'(a)') 'n'// &
                     ',r_max'// &
                     ',matmul'// &
                     ',external_dgemm'// &
                     ',dgemm_nn'// &
                     ',mm_mnp'// &
                     ',mm_mpn'// &
                     ',mm_nmp'// &
                     ',mm_npm'// &
                     ',mm_pmn'// &
                     ',mm_pnm'// &
                     ''
    do n=n_min,n_max,n_step
        r_max = 10
        write(*,'(a,i0,a,i0)') 'n: ',n,', r_max: ',r_max
        write(831,'(i0,",",i0)',advance='no') n,r_max

        if (allocated(a)) deallocate(a)
        if (allocated(b)) deallocate(b)
        if (allocated(c_0)) deallocate(c_0)
        if (allocated(c_1)) deallocate(c_1)
        allocate(a(n,n), b(n,n), c_0(n,n), c_1(n,n))
        call random_number(a)
        call random_number(b)
        c_0 = 0.0_rk
        c_1 = 0.0_rk

        call system_clock(count=c1, count_rate=cr)
        do r=1,r_max
            call mm_matmul(n, n, n, a, b, c_0)
        end do
        call system_clock(count=c2)
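        elapsed = real(max(c2 - c1, 1_i64), rk)/real(cr, rk)
        ! counts n**3 multiply-adds per product (same formula for every variant below);
        ! the common 2*n**3 flop convention would double each reported figure
        gflops = real(r_max, rk)*n**3.0_rk*1.0D-9/elapsed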
        write(*,'(a20,f7.2,a,3e13.6)') 'MM_MATMUL: ',gflops,' GFLOPS -- min/max/n,n: ',minval(c_0),maxval(c_0),c_0(n,n)
        write(831,'(",",f0.2)',advance='no') gflops

        call system_clock(count=c1, count_rate=cr)
        do r=1,r_max
            call dgemm('N', 'N', n, n, n, 1.0_rk, a, n, b, n, 1.0_rk, c_1, n)
        end do
        call system_clock(count=c2)
        elapsed = real(max(c2 - c1, 1_i64), rk)/real(cr, rk)
        gflops = real(r_max, rk)*n**3.0_rk*1.0D-9/elapsed
        write(*,'(a20,f7.2,a,e13.6)') 'EXTERNAL_DGEMM: ',gflops,' GFLOPS -- maxval(abs(c_1-c_0)): ',maxval(abs(c_1-c_0))
        c_1 = 0.0_rk
        write(831,'(",",f0.2)',advance='no') gflops

        call system_clock(count=c1, count_rate=cr)
        do r=1,r_max
            call dgemm_nn(n, n, n, 1.0_rk, a, n, b, n, 1.0_rk, c_1, n)
        end do
        call system_clock(count=c2)
        elapsed = real(max(c2 - c1, 1_i64), rk)/real(cr, rk)
        gflops = real(r_max, rk)*n**3.0_rk*1.0D-9/elapsed
        write(*,'(a20,f7.2,a,e13.6)') 'DGEMM_NN: ',gflops,' GFLOPS -- maxval(abs(c_1-c_0)): ',maxval(abs(c_1-c_0))
        c_1 = 0.0_rk
        write(831,'(",",f0.2)',advance='no') gflops

        call system_clock(count=c1, count_rate=cr)
        do r=1,r_max
            call mm_mnp(n, n, n, a, b, c_1)
        end do
        call system_clock(count=c2)
        elapsed = real(max(c2 - c1, 1_i64), rk)/real(cr, rk)
        gflops = real(r_max, rk)*n**3.0_rk*1.0D-9/elapsed
        write(*,'(a20,f7.2,a,e13.6)') 'MM_MNP: ',gflops,' GFLOPS -- maxval(abs(c_1-c_0)): ',maxval(abs(c_1-c_0))
        c_1 = 0.0_rk
        write(831,'(",",f0.2)',advance='no') gflops

        call system_clock(count=c1, count_rate=cr)
        do r=1,r_max
            call mm_mpn(n, n, n, a, b, c_1)
        end do
        call system_clock(count=c2)
        elapsed = real(max(c2 - c1, 1_i64), rk)/real(cr, rk)
        gflops = real(r_max, rk)*n**3.0_rk*1.0D-9/elapsed
        write(*,'(a20,f7.2,a,e13.6)') 'MM_MPN: ',gflops,' GFLOPS -- maxval(abs(c_1-c_0)): ',maxval(abs(c_1-c_0))
        c_1 = 0.0_rk
        write(831,'(",",f0.2)',advance='no') gflops

        call system_clock(count=c1, count_rate=cr)
        do r=1,r_max
            call mm_nmp(n, n, n, a, b, c_1)
        end do
        call system_clock(count=c2)
        elapsed = real(max(c2 - c1, 1_i64), rk)/real(cr, rk)
        gflops = real(r_max, rk)*n**3.0_rk*1.0D-9/elapsed
        write(*,'(a20,f7.2,a,e13.6)') 'MM_NMP: ',gflops,' GFLOPS -- maxval(abs(c_1-c_0)): ',maxval(abs(c_1-c_0))
        c_1 = 0.0_rk
        write(831,'(",",f0.2)',advance='no') gflops

        call system_clock(count=c1, count_rate=cr)
        do r=1,r_max
            call mm_npm(n, n, n, a, b, c_1)
        end do
        call system_clock(count=c2)
        elapsed = real(max(c2 - c1, 1_i64), rk)/real(cr, rk)
        gflops = real(r_max, rk)*n**3.0_rk*1.0D-9/elapsed
        write(*,'(a20,f7.2,a,e13.6)') 'MM_NPM: ',gflops,' GFLOPS -- maxval(abs(c_1-c_0)): ',maxval(abs(c_1-c_0))
        c_1 = 0.0_rk
        write(831,'(",",f0.2)',advance='no') gflops

        call system_clock(count=c1, count_rate=cr)
        do r=1,r_max
            call mm_pmn(n, n, n, a, b, c_1)
        end do
        call system_clock(count=c2)
        elapsed = real(max(c2 - c1, 1_i64), rk)/real(cr, rk)
        gflops = real(r_max, rk)*n**3.0_rk*1.0D-9/elapsed
        write(*,'(a20,f7.2,a,e13.6)') 'MM_PMN: ',gflops,' GFLOPS -- maxval(abs(c_1-c_0)): ',maxval(abs(c_1-c_0))
        c_1 = 0.0_rk
        write(831,'(",",f0.2)',advance='no') gflops

        call system_clock(count=c1, count_rate=cr)
        do r=1,r_max
            call mm_pnm(n, n, n, a, b, c_1)
        end do
        call system_clock(count=c2)
        elapsed = real(max(c2 - c1, 1_i64), rk)/real(cr, rk)
        gflops = real(r_max, rk)*n**3.0_rk*1.0D-9/elapsed
        write(*,'(a20,f7.2,a,e13.6)') 'MM_PNM: ',gflops,' GFLOPS -- maxval(abs(c_1-c_0)): ',maxval(abs(c_1-c_0))
        c_1 = 0.0_rk
        write(831,'(",",f0.2)',advance='no') gflops

        write(*,'(a)') ''
        write(831,'(a)') ''
        flush(831)
    end do

end program main

BUILD SCRIPT (bash):

#!/bin/bash

clear
export OMP_NUM_THREADS=1

ifort -Ofast -mavx2 -fp-model=fast=2 -fp-speculation=fast -ipo -qopenmp -o ifort-ofast-mavx2-lto-openmp main.f90 -lopenblas
./ifort-ofast-mavx2-lto-openmp 16 1024 16
mv fort.831 ifort_mm_gflops.csv

gfortran -Ofast -march=native -flto -fwhole-program -fopenmp -o gfortran-ofast-marchnative-lto-openmp main.f90 -lopenblas
./gfortran-ofast-marchnative-lto-openmp 16 1024 16
mv fort.831 gfortran_mm_gflops.csv