Wrappers
Wrappers for layers
RecurrentLayers.StackedRNN — Type

    StackedRNN(rlayer, (input_size, hidden_size), args...;
        num_layers = 1, dropout = 0.0, kwargs...)
Constructs a stack of recurrent layers given the recurrent layer type.
Arguments:

- `rlayer`: Any recurrent layer such as `MGU`, `RHN`, etc., or `Flux.RNN`, `Flux.LSTM`, etc.
- `input_size`: Defines the input dimension for the first layer.
- `hidden_size`: Defines the dimension of the hidden layer.
- `num_layers`: The number of layers to stack. Default is 1.
- `dropout`: Value of dropout to apply between recurrent layers. Default is 0.0.
- `args...`: Additional positional arguments passed to the recurrent layer.

Keyword arguments:

- `kwargs...`: Additional keyword arguments passed to the recurrent layers.
Examples
```julia
julia> using RecurrentLayers

julia> stac_rnn = StackedRNN(MGU, (3=>5); num_layers = 4)
StackedRNN(
  [
    MGU(3 => 10),                        # 90 parameters
    MGU(5 => 10),                        # 110 parameters
    MGU(5 => 10),                        # 110 parameters
    MGU(5 => 10),                        # 110 parameters
  ],
)         # Total: 12 trainable arrays, 420 parameters,
          # plus 4 non-trainable, 20 parameters, summarysize 2.711 KiB.
```
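A short usage sketch may help here. The snippet below assumes `StackedRNN` follows the usual Flux convention for recurrent layers, taking input of size `input_size x seq_len x batch_size`; the dimensions and the call itself are illustrative, not taken from the package's documented examples.

```julia
using RecurrentLayers

# Stack of 4 MGU layers mapping 3 input features to a hidden size of 5.
stac_rnn = StackedRNN(MGU, (3 => 5); num_layers = 4)

# Assumed input convention: input_size x seq_len x batch_size.
inp = rand(Float32, 3, 10, 8)

# Each layer feeds its full output sequence to the next;
# the result should have size hidden_size x seq_len x batch_size.
out = stac_rnn(inp)
```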
Wrappers for cells
RecurrentLayers.Multiplicative — Type

    Multiplicative(rcell, (input_size, hidden_size), args...;
        init_multiplicative_kernel = glorot_uniform,
        init_multiplicative_recurrent_kernel = glorot_uniform,
        kwargs...)
Multiplicative RNN. Wraps a given `cell`, and performs the following forward pass:

\[\begin{aligned}
    \mathbf{m}_t &= (\mathbf{W}_{mx} \mathbf{x}_t) \circ (\mathbf{W}_{mh} \mathbf{h}_{t-1}), \\
    \mathbf{h}_{t} &= \text{cell}(\mathbf{x}_t, \mathbf{m}_t).
\end{aligned}\]

Currently this wrapper does not support the following cells:

- `RHNCell`
- `RHNCellUnit`
- `FSRNNCell`
- `TLSTMCell`
Arguments:

- `rcell`: A recurrent cell constructor such as `MGUCell`, or `Flux.LSTMCell`, etc.
- `input_size`: Defines the input dimension for the first layer.
- `hidden_size`: Defines the dimension of the hidden layer.
- `args...`: Positional arguments for the `rcell`.

Keyword arguments:

- `init_multiplicative_kernel`: Initializer for the multiplicative input kernel. Default is `glorot_uniform`.
- `init_multiplicative_recurrent_kernel`: Initializer for the multiplicative recurrent kernel. Default is `glorot_uniform`.
- `kwargs...`: Keyword arguments for the `rcell`.
Forward:

    mrnn(inp, state)
    mrnn(inp, (state, c_state))
    mrnn(inp)

Arguments:

- `inp`: The input to the `rcell`. It should be a vector of size `input_size` or a matrix of size `input_size x batch_size`.
- `state`: The hidden state of the `rcell`, if single return. It should be a vector of size `hidden_size` or a matrix of size `hidden_size x batch_size`. If not provided, it is assumed to be a vector of zeros, initialized by `Flux.initialstates`.
- `(state, cstate)`: A tuple containing the hidden and cell states of the `rcell`, if double return. They should be vectors of size `hidden_size` or matrices of size `hidden_size x batch_size`. If not provided, they are assumed to be vectors of zeros, initialized by `Flux.initialstates`.
Returns:

Either of

- A tuple `(output, state)`, where both elements are given by the updated state `new_state`, a tensor of size `hidden_size` or `hidden_size x batch_size`. This applies if the `rcell` is single return (e.g. `Flux.RNNCell`).
- A tuple `(output, state)`, where `output = new_state` is the new hidden state and `state = (new_state, new_cstate)` is the new hidden and cell state. They are tensors of size `hidden_size` or `hidden_size x batch_size`. This applies if the `rcell` is double return (e.g. `Flux.LSTMCell`).
Examples
When used to wrap a cell, `Multiplicative` will behave as the wrapped cell, taking input data in the same format and returning states as the `rcell` would.
```julia
julia> using RecurrentLayers

julia> mrnn = Multiplicative(MGUCell, 3 => 5)
Multiplicative(
  5×3 Matrix{Float32},                  # 15 parameters
  5×5 Matrix{Float32},                  # 25 parameters
  MGUCell(3 => 5),                      # 90 parameters
)         # Total: 5 arrays, 130 parameters, 792 bytes.
```
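To make the Forward section above concrete, here is a hedged single-step sketch. It assumes `Flux.initialstates` can be called on the wrapper to obtain a zero state, per the defaults described above; dimensions are illustrative.

```julia
using RecurrentLayers, Flux

# MGUCell is single return, so the wrapper returns (output, state).
mrnn = Multiplicative(MGUCell, 3 => 5)

inp = rand(Float32, 3)            # one step of input, size input_size
state = Flux.initialstates(mrnn)  # assumed zero initial state, size hidden_size

# One forward step; both returned values should have size hidden_size.
output, new_state = mrnn(inp, state)
```

For a double-return cell such as `Flux.LSTMCell`, the second argument would instead be the tuple `(state, c_state)`, matching the second call form shown in Forward.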
In order to make `Multiplicative` act on a full sequence, it is possible to wrap it in a `Flux.Recurrence` layer.
```julia
julia> using RecurrentLayers, Flux

julia> wrap = Recurrence(Multiplicative(AntisymmetricRNNCell, 2 => 4))
Recurrence(
  Multiplicative(
    4×2 Matrix{Float32},                # 8 parameters
    4×4 Matrix{Float32},                # 16 parameters
    AntisymmetricRNNCell(2 => 4, tanh), # 28 parameters
  ),
)         # Total: 5 arrays, 52 parameters, 488 bytes.
```
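Applying the wrapped layer to data can then be sketched as follows, assuming `Flux.Recurrence` iterates the cell over a time dimension in `input_size x seq_len x batch_size` input (the sizes below are illustrative, not from the package's examples).

```julia
using RecurrentLayers, Flux

# Recurrence runs the wrapped cell over every step of a sequence.
wrap = Recurrence(Multiplicative(AntisymmetricRNNCell, 2 => 4))

# Assumed input convention: input_size x seq_len x batch_size.
inp = rand(Float32, 2, 12, 3)

# The output should collect the hidden states for all steps,
# i.e. size hidden_size x seq_len x batch_size.
out = wrap(inp)
```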