Backward Mode Auto-Diff in Modern Fortran

Hello,

I don’t know if this is of interest, and I’m also not sure we are doing the right thing.

@St_Maxwell wrote Fortran code for backward automatic differentiation. I modified it slightly, and I believe it can draw inspiration from the joddlehod/DNAD (forward mode) code and grow into a complete backward-mode differentiation package.

zoziha/Auto-Diff: Fortran backward mode automatic differentiation. (github.com)

Currently there is only one example: `fpm run --example demo1`

demo1.f90
!> Backward auto diff
program main
    use auto_diff, only: node_t, operator(*), operator(+), exp, dp
    implicit none
    type(node_t) :: a, b, c
    type(node_t), pointer :: y

    call a%constructor(value=2.0_dp)
    call b%constructor(value=1.0_dp)
    call c%constructor(value=0.0_dp)

    print *, "demo1: y = (a + b*b)*exp(c)"

    y => (a + b*b)*exp(c)

    print *, "y     = ", y%get_value()
    call y%backward()

    print *, "dy/da = ", a%get_grad()
    print *, "dy/db = ", b%get_grad()
    print *, "dy/dc = ", c%get_grad()

    ! --- demo2: re-initialize a and b and build a second expression ---
    
    call a%constructor(value=2.0_dp)
    call b%constructor(value=1.0_dp)
    
    print *, ""
    print *, "demo2: y = (a + b)*(b + 1.0)"
    
    y => (a + b)*(b + 1.0_dp)

    print *, "y     = ", y%get_value()
    
    call y%backward()

    print *, "dy/da = ", a%get_grad()
    print *, "dy/db = ", b%get_grad()

end program main

!>  demo1: y = (a + b*b)*exp(c)
!>  y     =    3.0000000000000000     
!>  dy/da =    1.0000000000000000
!>  dy/db =    2.0000000000000000
!>  dy/dc =    3.0000000000000000
!> 
!>  demo2: y = (a + b)*(b + 1.0)
!>  y     =    6.0000000000000000
!>  dy/da =    2.0000000000000000
!>  dy/db =    5.0000000000000000
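
For anyone who wants to check the printed gradients by hand, the chain rule gives (a small hand derivation, not part of the package):

!>  demo1: y = (a + b*b)*exp(c), with a = 2, b = 1, c = 0
!>  dy/da = exp(c)            = 1
!>  dy/db = 2*b*exp(c)        = 2
!>  dy/dc = (a + b*b)*exp(c)  = 3
!>
!>  demo2: y = (a + b)*(b + 1), with a = 2, b = 1
!>  dy/da = b + 1             = 2
!>  dy/db = (b + 1) + (a + b) = 5

which agrees with the output above.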

PS. We do not come from a mathematics background, so if there are errors, let's discuss them.

Related discussion

6 Likes

This is definitely of interest here. Thanks for posting it. I’ll let others comment on the best approach, I am not an expert on automatic differentiation.

1 Like

Thank you for your support, @certik. :heart:
I just added more functions, including the sigmoid function. Here are two examples that seem to work well:

`fpm run --example demo2`
!> Sigmoid func & gate
program main

    use auto_diff, only: sigmoid
    use auto_diff, only: node_t, dp
    use auto_diff, only: operator(*), operator(+)
    implicit none
    type(node_t) :: w0, w1, w2, x0, x1
    type(node_t), pointer :: y

    call w0%constructor(value=2.0_dp)
    call w1%constructor(value=-3.0_dp)
    call w2%constructor(value=-3.0_dp)
    call x0%constructor(value=-1.0_dp)
    call x1%constructor(value=-2.0_dp)

    print *, "sigmoid demo: y = 1/(1 + exp(-z), z = w0*x0 + w1*x1 + w2"
    y => sigmoid(w0*x0 + w1*x1 + w2)

    print *, "y      = ", y%get_value() ! should be  0.73
    call y%backward()

    print *, "dy/dw0 = ", w0%get_grad() ! should be -0.20
    print *, "dy/dw1 = ", w1%get_grad() ! should be -0.39
    print *, "dy/dw2 = ", w2%get_grad() ! should be  0.20
    print *, "dy/dx0 = ", x0%get_grad() ! should be  0.39
    print *, "dy/dx1 = ", x1%get_grad() ! should be -0.59

end program main

!> sigmoid demo: y = 1/(1 + exp(-z)), z = w0*x0 + w1*x1 + w2
!> y      =   0.73105857863000490     
!> dy/dw0 =  -0.19661193324148185
!> dy/dw1 =  -0.39322386648296370
!> dy/dw2 =   0.19661193324148185
!> dy/dx0 =   0.39322386648296370
!> dy/dx1 =  -0.58983579972444555
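
These values can be checked by hand as well (a short derivation, independent of the package). With z = w0*x0 + w1*x1 + w2 = 2*(-1) + (-3)*(-2) + (-3) = 1 and y = sigmoid(z) ≈ 0.7311, the chain rule gives dy/dz = y*(1 - y) ≈ 0.1966, so

!> dy/dw0 = dy/dz * x0 = 0.1966*(-1) ≈ -0.197
!> dy/dw1 = dy/dz * x1 = 0.1966*(-2) ≈ -0.393
!> dy/dw2 = dy/dz                    ≈  0.197
!> dy/dx0 = dy/dz * w0 = 0.1966*2    ≈  0.393
!> dy/dx1 = dy/dz * w1 = 0.1966*(-3) ≈ -0.590

in agreement with the printed gradients.
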
`fpm run --example demo3`
!> Staged solution
program main

    use auto_diff, only: sigmoid
    use auto_diff, only: node_t, dp
    use auto_diff, only: operator(*), operator(+), operator(/), operator(**)
    implicit none
    type(node_t) :: x1, x2
    type(node_t), pointer :: y

    call x1%constructor(value=3.0_dp)
    call x2%constructor(value=-4.0_dp)

    print *, "staged demo: y = (x1 + sigmoid(x2))/(sigmoid(x1) + (x1 + x2)**2)"
    y => (x1 + sigmoid(x2))/(sigmoid(x1) + (x1 + x2)**2.0_dp)

    print *, "y      = ", y%get_value()
    call y%backward()

    print *, "dy/dx1 = ", x1%get_grad()
    print *, "dy/dx2 = ", x2%get_grad()

end program main

!> staged demo: y = (x1 + sigmoid(x2))/(sigmoid(x1) + (x1 + x2)**2)
!> y      =    1.5456448841066441     
!> dy/dx1 =   -1.1068039935182090
!> dy/dx2 =   -1.5741410376065648
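
A cheap way to sanity-check any of these gradients is a central finite difference on the value alone. Below is a minimal sketch, assuming only the auto_diff API already used in the demos (node_t, constructor, get_value, the overloaded operators); the program name and the step size eps are just for illustration:

!> Finite-difference check of dy/dx1 for the staged demo (sketch)
program fd_check
    use auto_diff, only: sigmoid, node_t, dp
    use auto_diff, only: operator(*), operator(+), operator(/), operator(**)
    implicit none
    type(node_t) :: x1, x2
    type(node_t), pointer :: y_plus, y_minus
    real(dp), parameter :: eps = 1.0e-6_dp

    ! Evaluate y at x1 = 3 + eps, x2 = -4
    call x1%constructor(value=3.0_dp + eps)
    call x2%constructor(value=-4.0_dp)
    y_plus => (x1 + sigmoid(x2))/(sigmoid(x1) + (x1 + x2)**2.0_dp)

    ! Evaluate y at x1 = 3 - eps, x2 = -4
    call x1%constructor(value=3.0_dp - eps)
    call x2%constructor(value=-4.0_dp)
    y_minus => (x1 + sigmoid(x2))/(sigmoid(x1) + (x1 + x2)**2.0_dp)

    ! Central difference: an independent estimate of dy/dx1 to compare
    ! against the value returned by x1%get_grad() after y%backward()
    print *, "dy/dx1 (FD) = ", (y_plus%get_value() - y_minus%get_value())/(2.0_dp*eps)

end program fd_check

The agreement will not be exact, but it catches sign and scaling mistakes cheaply.
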

This package currently uses a pointer-based solution, though the Fortran 2008 standard also supports (thanks to @St_Maxwell):

Recursive allocatable components - as an alternative to recursive pointers in derived types.

If this approach proves feasible, details such as encapsulation and pointer destruction can be considered later. (I haven't yet figured out how to use it to compute the gradient of a matrix.)
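
For what it is worth, here is a rough sketch of what the allocatable alternative could look like; the type and component names are hypothetical, not the package's actual internals. With recursive allocatable components each node owns its operands, so the whole expression graph is released automatically when the root node goes out of scope:

!> Sketch only: recursive allocatable components (Fortran 2008)
module node_alloc_sketch
    implicit none
    private
    integer, parameter :: dp = kind(1.0d0)

    type, public :: anode_t
        real(dp) :: value = 0.0_dp
        real(dp) :: grad  = 0.0_dp
        ! Each node owns (a copy of) its operands; no manual cleanup needed.
        type(anode_t), allocatable :: left
        type(anode_t), allocatable :: right
    end type anode_t

end module node_alloc_sketch

The trade-off is that allocatable components imply ownership and copying on assignment, whereas pointers let several consumers share the same node, which is closer to how a reverse-mode graph (a DAG) is usually built.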

1 Like

I have not studied the code in detail yet, but it looks like a nice addition to the tools we have already.

One caveat: applying this technique in a large application is definitely not trivial. The literature on the subject describes all manner of techniques (checkpointing, for instance) to reduce memory usage and speed up the computations. Still, it ought to be useful in not-so-large applications :slight_smile:

2 Likes

Yes, forward differentiation and reverse (backward) differentiation have different application scenarios and characteristics:

  1. Forward differentiation: low space complexity, high time complexity;
  2. Reverse differentiation: high space complexity, low time complexity.

This is a case of trading space for time, so it seems important to reduce the space complexity of reverse differentiation.

For example, typical gates can be simplified to reduce the total number of gates in the graph.

Link

Automatic Differentiation - 李理的博客 (Li Li's blog, fancyerii.github.io)

2 Likes

FWIW, I did something similar in one of the examples in my book: using this technique to solve Diophantine equations :slight_smile:

1 Like

Though not related to the topic (auto-diff), I recently tried an “automatic translation” plug-in (?) for Chrome to read web pages, and it translates the linked page above very nicely to English or other languages. I wonder whether Chinese may be relatively straightforward to translate to English (because of similarities in structure, etc.?).
auto-diff page (linked above)

This characterization isn’t quite right. Forward mode is faster when the number of outputs is small compared to the number of inputs. Reverse mode is faster when the number of inputs is small compared to the number of outputs. For things like 2nd order derivatives, the optimal is often reverse mode over forward mode.

1 Like

I thought it was the other way around? To get all sensitivities, reverse mode is the better choice when there are fewer outputs than inputs.
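
To make the counting concrete (a small worked example, not from the original post): for a function with n inputs and a single scalar output, one reverse (backward) sweep yields all n derivatives dy/dx1 ... dy/dxn, while forward mode needs one sweep per input; with a single input and m outputs the counts swap.

  n inputs, 1 output : reverse mode = 1 sweep,  forward mode = n sweeps
  1 input, m outputs : reverse mode = m sweeps, forward mode = 1 sweep

The demos above are exactly the many-inputs, one-output case where reverse mode pays off.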

2 Likes

Yeah. Oops.