I wrote a simple Fortran function to compute the log-sum-exp (the logarithm of a sum of exponentials). My first version uses array syntax. Out of curiosity, I also implemented the same function with explicit loops, and the loop-based version turned out to be faster.
To provide some context, I am using this function in a larger project where it is called many times within several nested loops, so performance is important.
I prefer the array-syntax version (fun_logsum) because it is closer to a similar function I have in MATLAB.
module mymodule
  use, intrinsic :: iso_fortran_env, only: real64
  use, intrinsic :: ieee_arithmetic, only: ieee_is_finite, ieee_value, &
                                           ieee_negative_inf
  implicit none
  private
  public :: fun_logsum, fun_logsum_loops

  ! Working precision. real64 is the portable name for the 8-byte real that
  ! gfortran/ifort spell real(8), so callers declaring real(8) still match.
  integer, parameter :: dp = real64

contains

  ! fun_logsum, fun_logsum_loops
  ! DESCRIPTION
  !   Scaled log-sum-exp ("smooth maximum") of the vector V:
  !
  !       LogSum = sigma * log( sum_i exp(V(i)/sigma) )
  !
  !   evaluated as  max(V) + sigma*log( sum_i exp((V(i) - max(V))/sigma) ).
  !   The two forms are algebraically identical, but the shifted one cannot
  !   overflow: every exponent is <= 0 and the largest is exactly 0, so the
  !   sum lies in [1, n]. For sigma = 1 this is the ordinary log(sum(exp(V))).
  !   https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/
  !   https://nhigham.com/2021/01/05/what-is-the-log-sum-exp-function/
  ! INPUTS
  !   V:     Vector with values of the different choices
  !   sigma: Standard deviation of the taste shock; must be > 0
  !   n:     Size of the vector V (n = 0 is allowed)
  ! OUTPUTS
  !   LogSum: Scaled log sum of exponentials, as defined above.
  !   Special cases:
  !     n = 0                 -> -Inf (log of an empty sum)
  !     max(V) = +Inf or -Inf -> that infinity (shifting by it would give
  !                              Inf - Inf = NaN)
  ! AUTHOR
  !   Alessandro Di Nola, March 2022.

  pure function fun_logsum(V, sigma, n) result(LogSum)
    ! Uses array syntax.
    ! NOTE: fun_logsum_loops may run faster than this version because the
    ! MAXVAL intrinsic must also cope with zero-size arrays and, in some
    ! compilers (e.g. gfortran without -ffinite-math-only), NaN elements.
    ! Those extra cases can add branches to the generated loop.
    integer,  intent(in) :: n          ! Size of the vector V
    real(dp), intent(in) :: V(n)       ! Input vector
    real(dp), intent(in) :: sigma      ! St. dev. of the preference shock (> 0)
    real(dp) :: LogSum                 ! sigma*log(sum(exp(V/sigma)))
    real(dp) :: max_val, sum_exp

    ! Empty choice set: the sum of exponentials is 0, so the result is -Inf.
    ! Guarded explicitly rather than relying on maxval(empty) = -huge and log(0).
    if (n < 1) then
      LogSum = ieee_value(1.0_dp, ieee_negative_inf)
      return
    end if

    ! Step 1: Find the maximum value to ensure numerical stability
    max_val = maxval(V)

    ! A non-finite maximum would turn V - max_val into Inf - Inf = NaN;
    ! the exact answer in that case is the maximum itself.
    if (.not. ieee_is_finite(max_val)) then
      LogSum = max_val
      return
    end if

    ! Step 2: Compute sum of exp( (v_i - maxval)/sigma )
    sum_exp = sum(exp((V - max_val)/sigma))
    ! Step 3: Return max_val + sigma*log(sum_exp)
    LogSum = max_val + sigma*log(sum_exp)
  end function fun_logsum

  pure function fun_logsum_loops(V, sigma, n) result(LogSum)
    ! Uses loops
    integer,  intent(in) :: n          ! Size of the vector V
    real(dp), intent(in) :: V(n)       ! Input vector
    real(dp), intent(in) :: sigma      ! St. dev. of the preference shock (> 0)
    real(dp) :: LogSum                 ! sigma*log(sum(exp(V/sigma)))
    real(dp) :: max_val, sum_exp
    integer  :: i

    ! Empty choice set: must be caught before V(1) is read, because V(1)
    ! does not exist when n = 0. The result matches fun_logsum (-Inf).
    if (n < 1) then
      LogSum = ieee_value(1.0_dp, ieee_negative_inf)
      return
    end if

    ! Step 1: Find the maximum value to ensure numerical stability
    max_val = V(1)
    do i = 2, n
      if (V(i) > max_val) max_val = V(i)
    end do

    ! A non-finite maximum would turn V(i) - max_val into Inf - Inf = NaN;
    ! the exact answer in that case is the maximum itself.
    if (.not. ieee_is_finite(max_val)) then
      LogSum = max_val
      return
    end if

    ! Step 2: Compute sum of exp( (v_i - maxval)/sigma )
    sum_exp = 0.0_dp
    do i = 1, n
      sum_exp = sum_exp + exp((V(i) - max_val)/sigma)
    end do

    ! Step 3: Return max_val + sigma*log(sum_exp)
    LogSum = max_val + sigma*log(sum_exp)
  end function fun_logsum_loops

end module mymodule
!===============================================================================!
program main
  ! Driver: prints fun_logsum / fun_logsum_loops on three sample vectors,
  ! waits for Enter, then times Nsim calls of each implementation.
  use, intrinsic :: iso_fortran_env, only: real64
  use mymodule, only: fun_logsum, fun_logsum_loops
  implicit none
  ! Declare variables:
  integer,  parameter :: dp = real64             ! Same 8-byte real as real(8)
  integer,  parameter :: n = 5                   ! Number of alternatives
  integer,  parameter :: Nsim = 100000000        ! Number of iterations
  real(dp), parameter :: sigma = 0.02_dp         ! Standard deviation of the preference shock
  ! Per-iteration offset written into v_vec5(1) while timing (see below).
  real(dp), parameter :: jitter = 1.0e-12_dp
  real(dp) :: v_vec5(n), tic, toc, tic2, toc2
  real(dp) :: checksum, checksum2
  integer  :: ii, ios

  ! Execution starts here:
  ! CASE 1: simple
  v_vec5 = [0.1_dp, 0.2_dp, 0.3_dp, 0.4_dp, 0.5_dp]
  call report_case(v_vec5)
  ! CASE 2: big numbers
  v_vec5 = [1.0e10_dp, 2.0e10_dp, 3.0e10_dp, 4.0e10_dp, 5.0e10_dp]
  call report_case(v_vec5)
  ! CASE 3: all zeros (this vector is also the timing input below)
  v_vec5 = 0.0_dp
  call report_case(v_vec5)

  ! PAUSE was deleted from the standard. gfortran's runtime resumes only
  ! when the user types "go"; any other input, including a plain Enter,
  ! terminates the job, so the speed test never ran. A plain read waits
  ! for Enter on any compiler.
  write(*,'(a)', advance='no') "Press Enter to continue..."
  read(*,'(a)', iostat=ios)   ! ios deliberately ignored: EOF on stdin just continues

  write(*,*) "Now testing speed..."
  ! Every result is accumulated and printed, and v_vec5(1) changes on each
  ! iteration. Without this, an optimizer could drop the calls as dead code
  ! or hoist a loop-invariant call out of the loop, so nothing would be timed.
  checksum = 0.0_dp
  call cpu_time(tic) ! Start timer
  do ii = 1, Nsim
    v_vec5(1) = jitter*real(ii, dp)
    checksum = checksum + fun_logsum(v_vec5, sigma, n)
  end do
  call cpu_time(toc) ! Stop timer

  checksum2 = 0.0_dp
  call cpu_time(tic2) ! Start timer
  do ii = 1, Nsim
    v_vec5(1) = jitter*real(ii, dp)
    checksum2 = checksum2 + fun_logsum_loops(v_vec5, sigma, n)
  end do
  call cpu_time(toc2) ! Stop timer

  write(*,*) "Time taken for fun_logsum: ", toc-tic, " seconds"
  write(*,*) "Time taken for fun_logsum_loops: ", toc2-tic2, " seconds"
  write(*,*) "Checksums (should agree up to rounding): ", checksum, checksum2

contains

  ! Evaluate both implementations on one input vector and print the results.
  subroutine report_case(v)
    real(dp), intent(in) :: v(n)
    real(dp) :: Vbar, Vbar2
    Vbar  = fun_logsum(v, sigma, n)
    Vbar2 = fun_logsum_loops(v, sigma, n)
    write(*,'(A,5E15.6)') "v_vec5 is: ", v
    write(*,*) "fun_logsum: ", Vbar
    write(*,*) "fun_logsum_loops: ", Vbar2
  end subroutine report_case

end program main