[coarray] How to use a custom type with co_reduce operation?

I was trying to mimic an MPI code shown in this paper, following a recommendation by @ivanpribec. The paper shows how to use MPI_Op_create to define an operator that is invoked during MPI_Allreduce. I wanted to do the same with coarrays, just to learn.

The basic idea is to implement a Kahan sum operator that is applied during the reduction of the partial sums from every process. Since the operator must take and return the same type (it acts as an element-wise reduction), they use a struct to encapsulate the sum and the correction term of the Kahan sum.
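To make the idea concrete, here is a minimal serial sketch of such a combine operator on a (sum, correction) pair. The names `ksum` and `combine` are hypothetical and this is not the paper's code. It assumes the same sign convention as a standard Kahan update, where the true value is approximately `s - c`, and it folds both correction terms into the increment.

```fortran
module kahan_pair
    use iso_fortran_env, only: sp => real32
    implicit none

    ! Hypothetical pair type: running sum plus its compensation term.
    type :: ksum
        real(sp) :: s = 0.0_sp
        real(sp) :: c = 0.0_sp
    end type

contains

    ! Combine two partial results into one, Kahan style.
    ! Both correction terms are folded into the increment, so neither is dropped.
    pure function combine(a, b) result(r)
        type(ksum), intent(in) :: a, b
        type(ksum) :: r
        real(sp) :: y, t
        y = b%s - (a%c + b%c)
        t = a%s + y
        r%c = (t - a%s) - y
        r%s = t
    end function

end module

program demo
    use kahan_pair
    implicit none
    type(ksum) :: parts(8), total
    integer :: i

    ! Stand-ins for the partial results of 8 processes.
    do i = 1, 8
        parts(i) = ksum(real(i, sp), 0.0_sp)
    end do

    ! Serial left-to-right reduction with the combine operator.
    total = parts(1)
    do i = 2, 8
        total = combine(total, parts(i))
    end do

    print *, total%s
end program
```

In floating point this operator is only approximately commutative and associative, so the result can depend on the order in which a reduction applies it.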

When reading about co_reduce I saw that one can pass a pure function as the operation, and the documentation states the following:

OPERATION
pure function with two scalar nonallocatable arguments, which shall be nonpolymorphic and have the same type and type parameters as A. The function shall return a nonallocatable scalar of the same type and type parameters as A. The function shall be the same on all images and with regards to the arguments mathematically commutative and associative. Note that OPERATION may not be an elemental function, unless it is an intrinsic function.
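As I read these requirements, a conforming operation on a derived type could look like the sketch below. The names `pair` and `add_pair` are hypothetical. The sketch deliberately uses a module-level function with `intent(in)` arguments. Whether a particular compiler handles derived-type arguments in co_reduce correctly is a separate question that this sketch does not settle.

```fortran
module pair_reduce
    implicit none

    ! Hypothetical derived type with two real components.
    type :: pair
        real :: s = 0.0
        real :: c = 0.0
    end type

contains

    ! Pure, non-elemental function: two scalar arguments of type(pair),
    ! scalar result of type(pair), component-wise addition.
    pure function add_pair(a, b) result(r)
        type(pair), intent(in) :: a, b
        type(pair) :: r
        r%s = a%s + b%s
        r%c = a%c + b%c
    end function

end module

program demo_reduce
    use pair_reduce
    implicit none
    type(pair) :: p

    ! Each image contributes its own image index.
    p = pair(real(this_image()), 0.0)

    ! Without result_image, every image receives the reduced value.
    call co_reduce(p, operation=add_pair)

    if (this_image() == 1) print *, 'reduced s: ', p%s
end program
```

On a conforming implementation, `p%s` should end up as the sum of the image indices, which makes this a small reproducer for checking derived-type support.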

In principle I see no limitation on using a simple derived type for the input/output arguments of the operation, so I came up with the following (it compiles but doesn’t work):

module custom_sum
    use iso_fortran_env, only: sp=>real32
    implicit none
    integer, parameter :: chunk32 = 64
    
    type :: esum
     real(sp) :: s !> sum
     real(sp) :: c !> corrector
    end type
    
    contains
    
    function psum(a) result(sout)
      integer, parameter :: wp = sp
      integer, parameter :: chunk = chunk32
      real(wp), intent(in) :: a(:)
      real(wp) :: sout, c
      ! --
      real(wp) :: sbatch(chunk)
      real(wp) :: cbatch(chunk)
      integer :: i, dr, rr
      ! -----------------------------
      dr = size(a)/chunk
      rr = size(a) - dr*chunk
      !------ serial vectorizable sum
      sbatch(:) = 0.0_wp
      cbatch(:) = 0.0_wp
      ! plain loop: every iteration updates sbatch/cbatch, so iterations are not independent
      do i = 1, dr
          call vkahans( a(chunk*i-chunk+1:chunk*i) , sbatch(1:chunk) , cbatch(1:chunk) )
      end do
      call vkahans( a(size(a)-rr+1:size(a)) , sbatch(1:rr) , cbatch(1:rr) )
      !------ sequential reduction
      sout = 0.0_wp
      do i = 1, chunk
          call vkahans( sbatch(i) , sout , cbatch(i) )
      end do
      !> up to here everything works like a charm
      ! If instead of the next block I just use 
      ! call co_sum( sout ) ! everything fine, but no correction for the parallel reduction then.
      !------ parallel reduction with correction
      block
          type(esum) :: A
          A%s= sout
          A%c= cbatch(chunk)
          call CO_REDUCE(A, operation = mysum )
          sout = A%s
      end block
    contains
        pure function mysum(a,b) result(r)
            type(esum) , value :: a 
            type(esum) , value :: b 
            type(esum)  :: r 
            !-----------------------
            real(wp) :: s, t, y, c
            s = a%s
            c = a%c
            
            y = b%s - c
            t = s + y
            r%c = (t - s) - y
            r%s = t
        end function
    end function
    
    elemental subroutine vkahans(a,s,c)
        real(sp), intent(in) :: a
        real(sp), intent(inout) :: s
        real(sp), intent(inout) :: c
        real(sp) :: t, y
        y = a - c
        t = s + y
        c = (t - s) - y
        s = t
    end subroutine

end module
    
program main
    use custom_sum
    use iso_fortran_env, only: sp=>real32, dp=>real64
    implicit none

    ! Variables
    real(sp), allocatable :: x4(:)
    real(dp), allocatable :: x8(:)
    real(sp) :: s4, sk
    real(dp) :: s8
    integer :: n = 1000000
    
    allocate( x4(n) , x8(n) )
    CALL RANDOM_INIT(.true., .true.)
    call RANDOM_NUMBER( x4 )
    x4 = x4 / (n/10)
    x8 = dble( x4 )
    ! -- Sum and co_sum with original data type
    s4 = sum( x4 )
    call CO_SUM( s4 )
    ! -- Sum and co_sum with kahan sum operator
    sk = psum( x4 )
    ! -- sum and co_sum on double precision for comparison
    s8 = sum( x8 )
    call CO_SUM( s8 )
    
    if(this_image()==1) then
        print *, 'Parallel sum4: ',s4
        print *, 'Parallel sumk: ',sk
        print *, 'Parallel sum8: ',s8
        pause
    end if
end program main

I tested it with ifx 2023 with -qcoarray=distributed and 8 images.
Removing the non-working co_reduce, I get results such as:

 Parallel sum4:    40.02064
 Parallel sumk:    40.02067
 Parallel sum8:    40.0206743357571

EDIT
Putting the co_reduce call back gives nonsense. For instance, if I define the arrays as follows:

allocate( x4(1) , x8(1) )
x4 = real( this_image() )
x8 = dble( x4 )

I get 8 as the result of psum(x4) instead of 36 (the sum 1 + 2 + ... + 8 over the 8 images). This leads me to the hypothesis that the result of the operation in mysum is not being assigned at each call, and only the last value is retained…
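One way to cross-check the expected value without co_reduce is to gather the partial results through a coarray and combine them serially on image 1. The sketch below does that. The coarray name `part` and the program layout are my own assumptions, and this is only a diagnostic aid, not a replacement for the collective.

```fortran
program gather_reduce
    use iso_fortran_env, only: sp => real32
    implicit none
    real(sp) :: part(2)[*]   ! (sum, correction) held by each image
    real(sp) :: s, c, y, t
    integer :: img

    ! Same test data as above: each image contributes its image index.
    part = [real(this_image(), sp), 0.0_sp]
    sync all                 ! make every image's part visible before reading

    if (this_image() == 1) then
        s = 0.0_sp
        c = 0.0_sp
        ! Serial Kahan-style accumulation over the remote partial results.
        do img = 1, num_images()
            y = part(1)[img] - (c + part(2)[img])
            t = s + y
            c = (t - s) - y
            s = t
        end do
        print *, 'gathered sum: ', s
    end if
end program
```

With 8 images this should print 36, which is the value I would expect from a working co_reduce with a summing operation.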

Before going for a class definition, I tried to pass the corrector term via import or a common block, but since the function must be pure, this is not possible … Any ideas?