Multiplicative
RecurrentLayers.Multiplicative
Multiplicative(cell, inp, state)
Multiplicative RNN (Sutskever et al., 2011). Wraps a given cell and performs the following forward pass:
\[\begin{aligned} \mathbf{m}_t &= (\mathbf{W}_{mx} \mathbf{x}_t) \circ (\mathbf{W}_{mh} \mathbf{h}_{t-1}), \\ \mathbf{h}_{t} &= \text{cell}(\mathbf{x}_t, \mathbf{m}_t). \end{aligned}\]
Here \(\circ\) denotes the elementwise (Hadamard) product, and \(\mathbf{m}_t\) takes the place of the previous hidden state \(\mathbf{h}_{t-1}\) as the state input of the wrapped cell.
Currently this wrapper does not support the following cells: RHNCell, RHNCellUnit, FSRNNCell, TLSTMCell.
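As a minimal illustration of the forward pass defined above, the following plain-Julia sketch computes two steps with a toy tanh RNN cell standing in for the wrapped cell. It is a sketch under stated assumptions, not the RecurrentLayers implementation: `toy_cell`, `multiplicative_step`, and all weight names are hypothetical, and the toy cell simply uses m_t where a vanilla RNN would use h_{t-1}.

```julia
# Toy stand-in for the wrapped cell (hypothetical): a vanilla tanh RNN that
# receives the multiplicative term m in place of the previous hidden state.
toy_cell(Wx, Wh, b, x, m) = tanh.(Wx * x .+ Wh * m .+ b)

# One multiplicative step:
#   m_t = (W_mx x_t) ∘ (W_mh h_{t-1}),   h_t = cell(x_t, m_t)
function multiplicative_step(Wmx, Wmh, cell, x, h_prev)
    m = (Wmx * x) .* (Wmh * h_prev)   # ∘ is the elementwise (Hadamard) product
    return cell(x, m)
end

input_size, hidden_size = 3, 5
Wmx = randn(Float32, hidden_size, input_size)    # W_mx: hidden × input
Wmh = randn(Float32, hidden_size, hidden_size)   # W_mh: hidden × hidden
Wx  = randn(Float32, hidden_size, input_size)    # toy cell: input weights
Wh  = randn(Float32, hidden_size, hidden_size)   # toy cell: weights applied to m_t
b   = zeros(Float32, hidden_size)
cell(x, m) = toy_cell(Wx, Wh, b, x, m)

x  = randn(Float32, input_size)
h0 = zeros(Float32, hidden_size)
h1 = multiplicative_step(Wmx, Wmh, cell, x, h0)   # zero h0, so m_1 = 0
h2 = multiplicative_step(Wmx, Wmh, cell, x, h1)
```

With a zero initial state, m_1 vanishes, so the first step reduces to cell(x_1, 0) for any wrapped cell.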
Arguments
rcell: A recurrent cell constructor, such as MGUCell or Flux.LSTMCell.
input_size: The input dimension, passed together with hidden_size as a pair input_size => hidden_size (e.g. 3 => 5).
hidden_size: The dimension of the hidden state.
args...: Positional arguments forwarded to rcell.
Keyword arguments
init_multiplicative_kernel: Initializer for the multiplicative input kernel W_mx. Default is glorot_uniform.
init_multiplicativerecurrent_kernel: Initializer for the multiplicative recurrent kernel W_mh. Default is glorot_uniform.
kwargs...: Keyword arguments forwarded to rcell.
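The two kernels added by the wrapper map x_t and h_{t-1} into the hidden dimension, so for input_size => hidden_size they have sizes hidden_size x input_size and hidden_size x hidden_size; this matches the 5×3 and 5×5 matrices (15 + 25 parameters) printed in the MGUCell example of this section. The sketch below builds them with a hand-written Glorot uniform initializer, U(-a, a) with a = sqrt(6 / (fan_in + fan_out)). `multiplicative_kernels` and `glorot_uniform_sketch` are hypothetical names; only the keyword names mirror the documented ones.

```julia
# Glorot/Xavier uniform initializer (hand-written sketch):
# samples U(-a, a) with a = sqrt(6 / (fan_in + fan_out)).
function glorot_uniform_sketch(rows::Integer, cols::Integer)
    a = sqrt(6.0f0 / (rows + cols))
    return (rand(Float32, rows, cols) .* 2.0f0 .- 1.0f0) .* a
end

# Hypothetical helper: builds W_mx and W_mh for an `input_size => hidden_size`
# pair. Keyword names mirror the documented ones; the function is illustrative.
function multiplicative_kernels(dims::Pair{<:Integer,<:Integer};
        init_multiplicative_kernel = glorot_uniform_sketch,
        init_multiplicativerecurrent_kernel = glorot_uniform_sketch)
    input_size, hidden_size = dims.first, dims.second
    Wmx = init_multiplicative_kernel(hidden_size, input_size)             # acts on x_t
    Wmh = init_multiplicativerecurrent_kernel(hidden_size, hidden_size)   # acts on h_{t-1}
    return Wmx, Wmh
end

Wmx, Wmh = multiplicative_kernels(3 => 5)
n_params = length(Wmx) + length(Wmh)   # 15 + 25 = 40

# Overriding one initializer through its keyword, e.g. with zeros:
Wmx_zero, _ = multiplicative_kernels(3 => 5;
    init_multiplicative_kernel = (r, c) -> zeros(Float32, r, c))
```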
Forward
mrnn(inp, state)
mrnn(inp, (state, c_state))
mrnn(inp)
Arguments
inp: The input to the rcell. It should be a vector of size input_size or a matrix of size input_size x batch_size.
state: The hidden state of the rcell, if the rcell is single return. It should be a vector of size hidden_size or a matrix of size hidden_size x batch_size. If not provided, it is assumed to be a vector of zeros, initialized by Flux.initialstates.
(state, c_state): A tuple containing the hidden and cell states of the rcell, if the rcell is double return. They should be vectors of size hidden_size or matrices of size hidden_size x batch_size. If not provided, they are assumed to be vectors of zeros, initialized by Flux.initialstates.
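The state handling described above can be sketched in plain Julia for a single-return cell: the state argument is optional and defaults to zeros, and a zero vector state also works with a batched input_size x batch_size input because the elementwise product broadcasts across columns. The `MultSketch` struct, its toy tanh cell, and `initial_state` are hypothetical stand-ins, not the library's Multiplicative or Flux.initialstates.

```julia
# Hypothetical single-return wrapper around a toy tanh cell (not the library type).
struct MultSketch
    Wmx::Matrix{Float32}   # hidden × input
    Wmh::Matrix{Float32}   # hidden × hidden
    Wx::Matrix{Float32}    # toy cell: input weights
    Wh::Matrix{Float32}    # toy cell: weights applied to m_t
end

MultSketch(input_size::Int, hidden_size::Int) = MultSketch(
    randn(Float32, hidden_size, input_size), randn(Float32, hidden_size, hidden_size),
    randn(Float32, hidden_size, input_size), randn(Float32, hidden_size, hidden_size))

# Default state: a zero vector of length hidden_size (assumed, mirroring the docs).
initial_state(m::MultSketch) = zeros(Float32, size(m.Wmh, 1))

# Forward: returns (output, state); for a single-return cell both are the new state.
function (m::MultSketch)(inp::AbstractVecOrMat, state = initial_state(m))
    mult = (m.Wmx * inp) .* (m.Wmh * state)   # a vector state broadcasts over the batch
    new_state = tanh.(m.Wx * inp .+ m.Wh * mult)
    return new_state, new_state
end

mrnn = MultSketch(3, 5)
out_v, st_v = mrnn(randn(Float32, 3))            # vector input, implicit zero state
out_b, st_b = mrnn(randn(Float32, 3, 4))         # batch of 4, zero vector state broadcast
out_2, _    = mrnn(randn(Float32, 3, 4), st_b)   # continue from the batched state
```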
Returns
Either of:
A tuple (output, state), where both elements are the updated state new_state, a tensor of size hidden_size or hidden_size x batch_size. This applies if the rcell is single return (e.g. Flux.RNNCell).
A tuple (output, state), where output = new_state is the new hidden state and state = (new_state, new_cstate) holds the new hidden and cell states. They are tensors of size hidden_size or hidden_size x batch_size. This applies if the rcell is double return (e.g. Flux.LSTMCell).
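For a double-return cell, the return value differs from the single-return case only in how the state is packaged. The sketch below uses a hypothetical toy cell that carries a hidden and a cell state; it is not an LSTM. It also assumes, as the forward signature mrnn(inp, (state, c_state)) suggests but this page does not spell out, that only the hidden state enters m_t while the cell state is passed through to the wrapped cell unchanged.

```julia
# Toy double-return cell (hypothetical, not an LSTM): the cell state accumulates
# a tanh update, and the hidden state is tanh of the cell state.
function toy_double_cell(Wx, Wh, x, (m, c_prev))
    c = c_prev .+ tanh.(Wx * x .+ Wh * m)
    h = tanh.(c)
    return h, (h, c)                        # (output, (new_state, new_cstate))
end

# Multiplicative step for the double-return case (assumed behaviour):
# only the hidden state enters m_t; the cell state is passed through.
function multiplicative_double(Wmx, Wmh, Wx, Wh, x, (h_prev, c_prev))
    m = (Wmx * x) .* (Wmh * h_prev)
    return toy_double_cell(Wx, Wh, x, (m, c_prev))
end

hidden_size, input_size, batch_size = 5, 3, 2
Wmx, Wmh = randn(Float32, hidden_size, input_size), randn(Float32, hidden_size, hidden_size)
Wx, Wh   = randn(Float32, hidden_size, input_size), randn(Float32, hidden_size, hidden_size)
x = randn(Float32, input_size, batch_size)
state0 = (zeros(Float32, hidden_size, batch_size), zeros(Float32, hidden_size, batch_size))

output, (new_state, new_cstate) = multiplicative_double(Wmx, Wmh, Wx, Wh, x, state0)
```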
Examples
When used to wrap a cell, Multiplicative behaves like the wrapped cell: it takes input data in the same format and returns states as the rcell would.
julia> using RecurrentLayers
julia> mrnn = Multiplicative(MGUCell, 3 => 5)
Multiplicative(
5×3 Matrix{Float32}, # 15 parameters
5×5 Matrix{Float32}, # 25 parameters
MGUCell(3 => 5), # 100 parameters
) # Total: 6 arrays, 140 parameters, 880 bytes.
To make Multiplicative act on a full sequence, wrap it in a Flux.Recurrence layer.
julia> using RecurrentLayers, Flux
julia> wrap = Recurrence(Multiplicative(AntisymmetricRNNCell, 2 => 4))
Recurrence(
Multiplicative(
4×2 Matrix{Float32}, # 8 parameters
4×4 Matrix{Float32}, # 16 parameters
AntisymmetricRNNCell(2 => 4, tanh), # 32 parameters
),
) # Total: 6 arrays, 56 parameters, 552 bytes.
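Conceptually, running the wrapped cell over a sequence means threading the state through time and collecting the per-step outputs. The plain-Julia loop below sketches that idea with a toy single-return multiplicative cell. `mult_cell` and `run_sequence` are hypothetical, and the layout used here (time along the second dimension, outputs stacked the same way) is an assumption of this sketch; the Flux.Recurrence documentation gives the exact layout that layer expects.

```julia
# Toy single-return multiplicative cell (hypothetical): returns (output, state).
function mult_cell(W, x, h)
    m = (W.mx * x) .* (W.mh * h)          # m_t = (W_mx x_t) ∘ (W_mh h_{t-1})
    h_new = tanh.(W.x * x .+ W.h * m)     # toy tanh cell fed with m_t
    return h_new, h_new
end

# Thread the state through time; the sequence is input_size × len (assumed layout).
function run_sequence(W, xs::AbstractMatrix, h0::AbstractVector)
    h = h0
    outputs = Vector{Vector{Float32}}()
    for t in axes(xs, 2)
        out, h = mult_cell(W, xs[:, t], h)
        push!(outputs, out)
    end
    return reduce(hcat, outputs)          # hidden_size × len
end

input_size, hidden_size, len = 2, 4, 6
W = (mx = randn(Float32, hidden_size, input_size), mh = randn(Float32, hidden_size, hidden_size),
     x  = randn(Float32, hidden_size, input_size), h  = randn(Float32, hidden_size, hidden_size))
xs = randn(Float32, input_size, len)
ys = run_sequence(W, xs, zeros(Float32, hidden_size))
```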