I’m modeling a few concrete types that extend a common abstract type. Each type has two methods, forward and backward. The rank of the method arguments differs between types. For example, a concrete dense_layer type expects rank-1 arguments to its methods, while a concrete conv2d_layer type expects rank-3 arguments.
```fortran
type, abstract :: base_layer
contains
  procedure(forward_interface), deferred :: forward
  procedure(backward_interface), deferred :: backward
end type base_layer

abstract interface
  subroutine forward_interface(self, input)
    import :: base_layer
    class(base_layer), intent(in out) :: self
    real, intent(in) :: input(..)
  end subroutine forward_interface

  subroutine backward_interface(self, input, loss)
    import :: base_layer
    class(base_layer), intent(in out) :: self
    real, intent(in) :: input(..)
    real, intent(in) :: loss(..)
  end subroutine backward_interface
end interface
```
Initially, I thought about modeling this using assumed-rank arrays and the select rank construct. For example, the forward method of the concrete dense_layer type looks like this:
```fortran
subroutine forward(self, input)
  class(dense_layer), intent(in out) :: self
  real, intent(in) :: input(..)
  select rank(input)
  rank(1)
    self % z = matmul(input, self % weights) + self % biases
    self % output = sigmoid(self % z)
  rank default
    print *, 'Warning: rank ', rank(input), ' is not valid'
  end select
end subroutine forward
```
Here, only a rank-1 argument is valid for this type, yet I’m invoking all of the select rank machinery just to be able to use input. That is not too bad on its own, but it gets worse with the backward method, which has two assumed-rank arguments:
```fortran
subroutine backward(self, input, loss)
  class(dense_layer), intent(in out) :: self
  real, intent(in) :: input(..)
  real, intent(in) :: loss(..)
  select rank(loss)
  rank(1)
    select rank(input)
    rank(1)
      self % db = loss * sigmoid_prime(self % z)
      self % dw = matmul( &
        reshape(input, [size(input), 1]), &
        reshape(self % db, [1, size(self % db)]) &
      )
    rank default
      print *, 'Rank ', rank(input), ' is not valid for input(..)'
    end select
  rank default
    print *, 'Rank ', rank(loss), ' is not valid for loss(..)'
  end select
end subroutine backward
```
Notice that for each assumed-rank argument I want to use, I need to wrap the algorithm in one more nested level of select rank (or do I? I couldn’t find a different way to do it). At this point, assumed-rank arrays no longer seem like a good idea. It feels like I’m shoehorning a language construct into a use case it isn’t well suited for.
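One way to avoid the nesting, at the cost of a copy, is to run one flat select rank per argument and copy each argument into a local allocatable of the expected rank; the algorithm then works on plain rank-1 arrays. The sketch below is written as a free-standing subroutine rather than a type-bound method so it is self-contained. The module name, the name dense_backward, passing z explicitly instead of through self, and the logistic form of sigmoid_prime are all assumptions for illustration.

```fortran
module dense_backward_sketch
  implicit none
  private
  public :: dense_backward
contains

  ! Hypothetical free-standing version of dense_layer%backward:
  ! the layer state z is passed in explicitly instead of via self.
  subroutine dense_backward(z, input, loss, db, dw)
    real, intent(in) :: z(:)
    real, intent(in) :: input(..)
    real, intent(in) :: loss(..)
    real, allocatable, intent(out) :: db(:)
    real, allocatable, intent(out) :: dw(:,:)
    real, allocatable :: x(:), g(:)

    ! One flat select rank per argument: copy into a rank-1 local.
    select rank (input)
    rank (1)
      x = input
    rank default
      error stop 'dense_backward: input must be rank 1'
    end select

    select rank (loss)
    rank (1)
      g = loss
    rank default
      error stop 'dense_backward: loss must be rank 1'
    end select

    ! From here on the algorithm sees ordinary rank-1 arrays.
    db = g * sigmoid_prime(z)
    dw = matmul(reshape(x, [size(x), 1]), reshape(db, [1, size(db)]))
  end subroutine dense_backward

  ! Assumed: derivative of the logistic sigmoid.
  elemental real function sigmoid_prime(z)
    real, intent(in) :: z
    real :: s
    s = 1 / (1 + exp(-z))
    sigmoid_prime = s * (1 - s)
  end function sigmoid_prime

end module dense_backward_sketch
```

The trade-off is memory: each copied argument is duplicated for the duration of the call, which may matter for large inputs.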
What are my alternatives? I wondered whether the deferred methods could be generic, but I don’t see an obvious way to do that, given that a deferred binding must have a specific interface.
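A deferred binding must indeed be specific, but the abstract type can still declare a generic binding that groups several rank-specific specific bindings. The sketch below gives those specifics non-deferred default implementations that stop with an error, so each concrete type overrides only the rank it supports. The names forward_1d, forward_3d and dense_forward_1d are hypothetical, and the dense_layer components mirror the ones used in the forward method above.

```fortran
module layers_generic_sketch
  implicit none
  private
  public :: base_layer, dense_layer

  type, abstract :: base_layer
  contains
    ! Rank-specific bindings with default implementations that fail;
    ! a concrete type overrides only the rank(s) it supports.
    procedure :: forward_1d
    procedure :: forward_3d
    ! Generic resolution picks the specific binding by the rank of input.
    generic :: forward => forward_1d, forward_3d
  end type base_layer

  type, extends(base_layer) :: dense_layer
    real, allocatable :: weights(:,:), biases(:), z(:), output(:)
  contains
    procedure :: forward_1d => dense_forward_1d
  end type dense_layer

contains

  subroutine forward_1d(self, input)
    class(base_layer), intent(in out) :: self
    real, intent(in) :: input(:)
    error stop 'forward: rank-1 input not supported by this layer'
  end subroutine forward_1d

  subroutine forward_3d(self, input)
    class(base_layer), intent(in out) :: self
    real, intent(in) :: input(:,:,:)
    error stop 'forward: rank-3 input not supported by this layer'
  end subroutine forward_3d

  subroutine dense_forward_1d(self, input)
    class(dense_layer), intent(in out) :: self
    real, intent(in) :: input(:)
    self % z = matmul(input, self % weights) + self % biases
    self % output = 1 / (1 + exp(-self % z))   ! logistic sigmoid
  end subroutine dense_forward_1d

end module layers_generic_sketch
```

A call such as `call layer % forward(x)` on a `class(base_layer)` variable resolves to forward_1d or forward_3d from the rank of x at compile time and then dispatches to the concrete override at run time. Calling a rank the layer does not support still compiles and fails only at run time; making the specifics deferred instead would force every concrete type to implement every rank.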
A simpler solution may be to let my deferred methods expect the highest rank needed (rank 3 in my case), and then handle 1-d input in the algorithm itself, e.g. by using input(:,1,1).
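A minimal sketch of this rank-3 approach, showing only forward: every layer takes `input(:,:,:)`, the caller pads a 1-d vector to shape (n, 1, 1), and the dense layer reads `input(:,1,1)`. The helper as_rank3 and the module name are hypothetical, and the shape check in dense_forward is an assumption about how one might guard against misuse.

```fortran
module layers_rank3_sketch
  implicit none
  private
  public :: base_layer, dense_layer, as_rank3

  type, abstract :: base_layer
  contains
    procedure(forward_interface), deferred :: forward
  end type base_layer

  abstract interface
    subroutine forward_interface(self, input)
      import :: base_layer
      class(base_layer), intent(in out) :: self
      ! Every layer receives rank-3 input; lower-rank data is padded.
      real, intent(in) :: input(:,:,:)
    end subroutine forward_interface
  end interface

  type, extends(base_layer) :: dense_layer
    real, allocatable :: weights(:,:), biases(:), z(:), output(:)
  contains
    procedure :: forward => dense_forward
  end type dense_layer

contains

  ! Hypothetical caller-side helper: view a 1-d vector as (n, 1, 1).
  pure function as_rank3(x) result(y)
    real, intent(in) :: x(:)
    real :: y(size(x), 1, 1)
    y = reshape(x, [size(x), 1, 1])
  end function as_rank3

  subroutine dense_forward(self, input)
    class(dense_layer), intent(in out) :: self
    real, intent(in) :: input(:,:,:)
    ! Shape logic mixed into the algorithm: only (n, 1, 1) makes sense here.
    if (size(input, 2) /= 1 .or. size(input, 3) /= 1) &
      error stop 'dense_forward: expected input of shape (n, 1, 1)'
    self % z = matmul(input(:,1,1), self % weights) + self % biases
    self % output = 1 / (1 + exp(-self % z))   ! logistic sigmoid
  end subroutine dense_forward

end module layers_rank3_sketch
```

The size check at the top of dense_forward is exactly the kind of array-shape logic that this approach mixes into the algorithm; note also that reshape in as_rank3 produces a copy.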
In summary, I understand my options to be:
1. Assumed-rank + select rank: the array shape inquiry is cleanly separated from the algorithm, but at the cost of a lot of boilerplate, and the feature is possibly not as well supported by compilers.
2. Rank-3 arguments for all concrete types, with the shape sorted out on the caller side and in the algorithm. The upside is that it needs less code and is well supported by compilers. The downside is that the array-shape logic gets mixed in with the algorithm.
Currently option 2 seems more attractive to me.
What do you think about these options, and do you have any ideas for an alternative approach?
Thank you!