I’m modeling a few concrete types that extend a common abstract type. Each type has two methods, `forward` and `backward`. The methods expect arguments of different rank depending on the concrete type; for example, a concrete `dense_layer` type expects rank-1 arguments to its methods, while a concrete `conv2d_layer` type expects rank-3 arguments to its methods.
```fortran
type, abstract :: base_layer
contains
  procedure(forward_interface), deferred :: forward
  procedure(backward_interface), deferred :: backward
end type base_layer

abstract interface

  subroutine forward_interface(self, input)
    import :: base_layer
    class(base_layer), intent(in out) :: self
    real, intent(in) :: input(..)
  end subroutine forward_interface

  subroutine backward_interface(self, input, loss)
    import :: base_layer
    class(base_layer), intent(in out) :: self
    real, intent(in) :: input(..)
    real, intent(in) :: loss(..)
  end subroutine backward_interface

end interface
```
Initially, I thought about modeling this with assumed-rank arrays and the `select rank` construct. For example, the `forward` method on the concrete `dense_layer` looks like this:
```fortran
subroutine forward(self, input)
  class(dense_layer), intent(in out) :: self
  real, intent(in) :: input(..)
  select rank(input)
    rank(1)
      self % z = matmul(input, self % weights) + self % biases
      self % output = sigmoid(self % z)
    rank default
      print *, 'Warning: rank ', rank(input), ' is not valid'
  end select
end subroutine forward
```
Here, only a rank-1 argument is valid for this type, yet I’m invoking all this `select rank` machinery just to be able to use `input`. OK, maybe not too bad, but it gets worse with the `backward` method, which has two assumed-rank arguments:
```fortran
subroutine backward(self, input, loss)
  class(dense_layer), intent(in out) :: self
  real, intent(in) :: input(..)
  real, intent(in) :: loss(..)
  select rank(loss)
    rank(1)
      select rank(input)
        rank(1)
          self % db = loss * sigmoid_prime(self % z)
          self % dw = matmul( &
            reshape(input, [size(input), 1]), &
            reshape(self % db, [1, size(self % db)]) &
          )
        rank default
          print *, 'Rank ', rank(input), ' is not valid for input(..)'
      end select
    rank default
      print *, 'Rank ', rank(loss), ' is not valid for loss(..)'
  end select
end subroutine backward
```
Notice that for each assumed-rank argument I want to use, I need to wrap the algorithm in another nested level of `select rank` (or do I? I couldn’t find a different way to do it). At this point, assumed-rank arrays don’t seem like such a good idea anymore. It feels like I’m shoehorning a language construct into a use case where it’s not the best choice.
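That said, one candidate for flattening the nesting would be a single `select rank` per argument that copies into a local rank-1 array, so the algorithm itself appears only once. This is just a sketch (the copies cost memory and time, and I’ve swapped the warning prints for `error stop`):

```fortran
subroutine backward(self, input, loss)
  class(dense_layer), intent(in out) :: self
  real, intent(in) :: input(..)
  real, intent(in) :: loss(..)
  real, allocatable :: input_1(:), loss_1(:)

  ! one select rank per argument, each only copying
  ! the data into a local rank-1 array
  select rank(input)
    rank(1)
      input_1 = input
    rank default
      error stop 'invalid rank for input'
  end select

  select rank(loss)
    rank(1)
      loss_1 = loss
    rank default
      error stop 'invalid rank for loss'
  end select

  ! the algorithm itself is written once, un-nested
  self % db = loss_1 * sigmoid_prime(self % z)
  self % dw = matmul( &
    reshape(input_1, [size(input_1), 1]), &
    reshape(self % db, [1, size(self % db)]) &
  )
end subroutine backward
```

The nesting is gone, but at the cost of a temporary copy per argument, so it trades one kind of overhead for another.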
What are my alternatives? I thought about whether the deferred methods could be generic, but I don’t see an obvious way to do that, considering that a deferred method must have a specific interface.
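For completeness, the closest thing I can see is a generic binding that resolves to one deferred specific per rank. A sketch of that idea (the `_1d`/`_3d` names and interfaces are mine, purely illustrative):

```fortran
type, abstract :: base_layer
contains
  ! each deferred binding still needs its own specific interface,
  ! so supporting two ranks means two deferred procedures...
  procedure(forward_1d_interface), deferred :: forward_1d
  procedure(forward_3d_interface), deferred :: forward_3d
  ! ...and the generic name merely dispatches between them
  generic :: forward => forward_1d, forward_3d
end type base_layer

abstract interface
  subroutine forward_1d_interface(self, input)
    import :: base_layer
    class(base_layer), intent(in out) :: self
    real, intent(in) :: input(:)
  end subroutine forward_1d_interface
  subroutine forward_3d_interface(self, input)
    import :: base_layer
    class(base_layer), intent(in out) :: self
    real, intent(in) :: input(:,:,:)
  end subroutine forward_3d_interface
end interface
```

The catch is that every concrete type must then implement every rank: `dense_layer` would be forced to provide a meaningless `forward_3d`, and `conv2d_layer` a meaningless `forward_1d`, which defeats the purpose.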
A simpler solution may be to let my deferred methods expect the highest rank needed (in my case, rank 3), and handle lower-rank data inside the algorithm itself.
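For example, the dense layer’s `forward` could fix `input` at rank 3 and flatten it internally. A minimal sketch, assuming the same `dense_layer` components and `sigmoid` function as in the snippets above:

```fortran
subroutine forward(self, input)
  class(dense_layer), intent(in out) :: self
  ! the interface fixes the argument at rank 3 for all concrete types
  real, intent(in) :: input(:,:,:)
  real, allocatable :: input_flat(:)
  ! collapse the rank-3 input to the rank-1 shape this layer expects
  input_flat = reshape(input, [size(input)])
  self % z = matmul(input_flat, self % weights) + self % biases
  self % output = sigmoid(self % z)
end subroutine forward
```

The caller would correspondingly wrap genuinely 1-d data as a `[n, 1, 1]`-shaped array before the call.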
In summary, I understand my options to be:
- Assumed-rank arrays + `select rank`, where the array-shape inquiry is cleanly separated from the algorithm, but at the cost of a lot of boilerplate, and possibly weaker compiler support;
- Rank-3 arguments for all concrete types, sorting out the shape on the caller side and in the algorithm. The upside is that it needs less code and is well supported by compilers; the downside is that the array-shape logic gets mixed in with the algorithm.
Currently option 2 seems more attractive to me.
What do you think about these options, and do you have any ideas for an alternative approach?