I’m modeling a few concrete types that extend a common abstract type. Each type has two methods, `forward` and `backward`. The methods expect arguments of different rank depending on the concrete type; for example, a concrete `dense_layer` type expects rank-1 arguments to its methods, while a concrete `conv2d_layer` type expects rank-3 arguments to its methods.
```fortran
type, abstract :: base_layer
contains
  procedure(forward_interface), deferred :: forward
  procedure(backward_interface), deferred :: backward
end type base_layer

abstract interface

  subroutine forward_interface(self, input)
    import :: base_layer
    class(base_layer), intent(in out) :: self
    real, intent(in) :: input(..)
  end subroutine forward_interface

  subroutine backward_interface(self, input, loss)
    import :: base_layer
    class(base_layer), intent(in out) :: self
    real, intent(in) :: input(..)
    real, intent(in) :: loss(..)
  end subroutine backward_interface

end interface
```
Initially, I thought about modeling this with assumed-rank arrays and the `select rank` construct. For example, the `forward` method on the concrete `dense_layer` looks like this:
```fortran
subroutine forward(self, input)
  class(dense_layer), intent(in out) :: self
  real, intent(in) :: input(..)
  select rank(input)
    rank(1)
      self % z = matmul(input, self % weights) + self % biases
      self % output = sigmoid(self % z)
    rank default
      print *, 'Warning: rank ', rank(input), ' is not valid'
  end select
end subroutine forward
```
Here, only a rank-1 argument is valid for this type, yet I’m invoking all this `select rank` machinery just to be able to use `input`. OK, maybe not too bad, but it gets worse with the `backward` method, which has two assumed-rank arguments:
```fortran
subroutine backward(self, input, loss)
  class(dense_layer), intent(in out) :: self
  real, intent(in) :: input(..)
  real, intent(in) :: loss(..)
  select rank(loss)
    rank(1)
      select rank(input)
        rank(1)
          self % db = loss * sigmoid_prime(self % z)
          self % dw = matmul( &
            reshape(input, [size(input), 1]), &
            reshape(self % db, [1, size(self % db)]) &
          )
        rank default
          print *, 'Rank ', rank(input), ' is not valid for input(..)'
      end select
    rank default
      print *, 'Rank ', rank(loss), ' is not valid for loss(..)'
  end select
end subroutine backward
```
Notice that for each assumed-rank argument I want to use, I need to wrap the algorithm in another nested level of `select rank` (or do I? I couldn’t find a different way to do it). At this point, assumed-rank arrays don’t seem like such a good idea anymore. It feels like I’m shoehorning a language construct into a use case where it’s not the best choice.
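That said, one candidate for flattening the nesting would be a single `select rank` per argument that copies into a local rank-1 array, so the algorithm itself appears only once. This is just a sketch (the copies cost memory and time, and I’ve swapped the warning prints for `error stop`):

```fortran
subroutine backward(self, input, loss)
  class(dense_layer), intent(in out) :: self
  real, intent(in) :: input(..)
  real, intent(in) :: loss(..)
  real, allocatable :: input_1(:), loss_1(:)

  ! one select rank per argument, each only copying
  ! the data into a local rank-1 array
  select rank(input)
    rank(1)
      input_1 = input
    rank default
      error stop 'invalid rank for input'
  end select

  select rank(loss)
    rank(1)
      loss_1 = loss
    rank default
      error stop 'invalid rank for loss'
  end select

  ! the algorithm itself is written once, un-nested
  self % db = loss_1 * sigmoid_prime(self % z)
  self % dw = matmul( &
    reshape(input_1, [size(input_1), 1]), &
    reshape(self % db, [1, size(self % db)]) &
  )
end subroutine backward
```

The nesting is gone, but at the cost of a temporary copy per argument, so it trades one kind of overhead for another.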
What are my alternatives? I thought about whether the deferred methods could be generic, but I don’t see an obvious way to do that, considering that a deferred method must have a specific interface.
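For completeness, the closest thing I can see is a generic binding that resolves to one deferred specific per rank. A sketch of that idea (the `_1d`/`_3d` names and interfaces are mine, purely illustrative):

```fortran
type, abstract :: base_layer
contains
  ! each deferred binding still needs its own specific interface,
  ! so supporting two ranks means two deferred procedures...
  procedure(forward_1d_interface), deferred :: forward_1d
  procedure(forward_3d_interface), deferred :: forward_3d
  ! ...and the generic name merely dispatches between them
  generic :: forward => forward_1d, forward_3d
end type base_layer

abstract interface
  subroutine forward_1d_interface(self, input)
    import :: base_layer
    class(base_layer), intent(in out) :: self
    real, intent(in) :: input(:)
  end subroutine forward_1d_interface
  subroutine forward_3d_interface(self, input)
    import :: base_layer
    class(base_layer), intent(in out) :: self
    real, intent(in) :: input(:,:,:)
  end subroutine forward_3d_interface
end interface
```

The catch is that every concrete type must then implement every rank: `dense_layer` would be forced to provide a meaningless `forward_3d`, and `conv2d_layer` a meaningless `forward_1d`, which defeats the purpose.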
A simpler solution may be to let my deferred methods expect the highest rank needed (in my case, rank 3), and handle lower-rank data inside the algorithm itself.
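For example, the dense layer’s `forward` could fix `input` at rank 3 and flatten it internally. A minimal sketch, assuming the same `dense_layer` components and `sigmoid` function as in the snippets above:

```fortran
subroutine forward(self, input)
  class(dense_layer), intent(in out) :: self
  ! the interface fixes the argument at rank 3 for all concrete types
  real, intent(in) :: input(:,:,:)
  real, allocatable :: input_flat(:)
  ! collapse the rank-3 input to the rank-1 shape this layer expects
  input_flat = reshape(input, [size(input)])
  self % z = matmul(input_flat, self % weights) + self % biases
  self % output = sigmoid(self % z)
end subroutine forward
```

The caller would correspondingly wrap genuinely 1-d data as a `[n, 1, 1]`-shaped array before the call.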
In summary, I understand my options to be:
- Assumed-rank arrays + `select rank`, where the array-shape inquiry is cleanly separated from the algorithm, but at the cost of a lot of boilerplate, and possibly weaker compiler support;
- Rank-3 arguments for all concrete types, sorting out the shape on the caller side and in the algorithm. The upside is that it needs less code and is well supported by compilers; the downside is that the array-shape logic gets mixed in with the algorithm.
Currently option 2 seems more attractive to me.
What do you think about these options, and do you have any ideas for an alternative approach?