Wrappers

Wrappers for layers

RecurrentLayers.StackedRNN - Type
StackedRNN(rlayer, (input_size, hidden_size), args...;
    num_layers = 1, dropout = 0.0, kwargs...)

Constructs a stack of recurrent layers given the recurrent layer type.

Arguments:

  • rlayer: Any recurrent layer such as MGU, RHN, etc., or Flux.RNN, Flux.LSTM, etc.
  • input_size: Defines the input dimension for the first layer.
  • hidden_size: Defines the dimension of the hidden layer.
  • num_layers: The number of layers to stack. Default is 1.
  • dropout: Value of dropout to apply between recurrent layers. Default is 0.0.
  • args...: Additional positional arguments passed to the recurrent layer.

Keyword arguments

  • kwargs...: Additional keyword arguments passed to the recurrent layers.

Examples

julia> using RecurrentLayers

julia> stac_rnn = StackedRNN(MGU, (3=>5); num_layers = 4)
StackedRNN(
  [
    MGU(3 => 10),                       # 90 parameters
    MGU(5 => 10),                       # 110 parameters
    MGU(5 => 10),                       # 110 parameters
    MGU(5 => 10),                       # 110 parameters
  ],
)         # Total: 12 trainable arrays, 420 parameters,
          # plus 4 non-trainable, 20 parameters, summarysize 2.711 KiB.
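
Assuming StackedRNN consumes data like the Flux recurrent layers it stacks (an array of size input_size × len × batch_size, returning the last layer's hidden states), a forward pass might look like the sketch below; the input layout and output size are assumptions, not verified output.

using RecurrentLayers, Flux

stac_rnn = StackedRNN(MGU, (3 => 5); num_layers = 4, dropout = 0.1)

# assumed layout, mirroring Flux recurrent layers: features × time × batch
x = rand(Float32, 3, 12, 8)
h = stac_rnn(x)   # expected to be 5 × 12 × 8 (hidden_size × time × batch)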

Wrappers for cells

RecurrentLayers.Multiplicative - Type
Multiplicative(rcell, input_size => hidden_size, args...;
    kwargs...)

Multiplicative RNN. Wraps a given cell and performs the following forward pass:

\[\begin{aligned}
    \mathbf{m}_t &= (\mathbf{W}_{mx} \mathbf{x}_t) \circ (\mathbf{W}_{mh} \mathbf{h}_{t-1}), \\
    \mathbf{h}_{t} &= \text{cell}(\mathbf{x}_t, \mathbf{m}_t).
\end{aligned}\]

Currently this wrapper does not support the following cells:

  • RHNCell
  • RHNCellUnit
  • FSRNNCell
  • TLSTMCell
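
To make the forward pass formula above concrete, here is a small standalone sketch of the multiplicative pre-state using hypothetical weight matrices; it illustrates the equations, not the wrapper's internal implementation.

# hidden_size = 5, input_size = 3; names are illustrative only
Wmx = randn(Float32, 5, 3)     # multiplicative input kernel
Wmh = randn(Float32, 5, 5)     # multiplicative recurrent kernel
x_t = rand(Float32, 3)         # input at time t
h_prev = randn(Float32, 5)     # previous hidden state

m_t = (Wmx * x_t) .* (Wmh * h_prev)   # elementwise (Hadamard) product of the projections
# the wrapped cell is then evaluated as cell(x_t, m_t) rather than cell(x_t, h_prev)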

Arguments

  • rcell: A recurrent cell constructor such as MGUCell, or Flux.LSTMCell etc.
  • input_size: Defines the input dimension for the first layer.
  • hidden_size: Defines the dimension of the hidden layer.
  • args...: positional arguments for the rcell.

Keyword arguments

  • init_multiplicative_kernel: Initializer for the multiplicative input kernel. Default is glorot_uniform.
  • init_multiplicative_recurrent_kernel: Initializer for the multiplicative recurrent kernel. Default is glorot_uniform.
  • kwargs...: keyword arguments for the rcell.

Forward

mrnn(inp, state)
mrnn(inp, (state, c_state))
mrnn(inp)

Arguments

  • inp: The input to the rcell. It should be a vector of size input_size or a matrix of size input_size x batch_size.
  • state: The hidden state of the rcell, if single return. It should be a vector of size hidden_size or a matrix of size hidden_size x batch_size. If not provided, it is assumed to be a vector of zeros, initialized by Flux.initialstates.
  • (state, c_state): A tuple containing the hidden and cell states of the rcell, if double return. They should be vectors of size hidden_size or matrices of size hidden_size x batch_size. If not provided, they are assumed to be vectors of zeros, initialized by Flux.initialstates.

Returns

Either of

  • A tuple (output, state), where both elements are the updated state new_state, a tensor of size hidden_size or hidden_size x batch_size. This applies if the rcell is single return (e.g. Flux.RNNCell).

  • A tuple (output, state), where output = new_state is the new hidden state and state = (new_state, new_cstate) is the new hidden and cell state. They are tensors of size hidden_size or hidden_size x batch_size. This applies if the rcell is double return (e.g. Flux.LSTMCell).

Examples

When used to wrap a cell, Multiplicative behaves like the wrapped cell, taking input data in the same format and returning states the same way the rcell would.

julia> using RecurrentLayers

julia> mrnn = Multiplicative(MGUCell, 3 => 5)
Multiplicative(
  5×3 Matrix{Float32},                  # 15 parameters
  5×5 Matrix{Float32},                  # 25 parameters
  MGUCell(3 => 5),                      # 90 parameters
)                   # Total: 5 arrays, 130 parameters, 792 bytes.
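
Following the Forward section above, a single step can then be run as in the sketch below; the shapes follow the docstring, but the call is not a captured session.

inp = rand(Float32, 3)       # vector of size input_size
output, state = mrnn(inp)    # state defaults to zeros via Flux.initialstates
# MGUCell is single return, so output and state are both vectors of length 5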

To make Multiplicative act on a full sequence, it can be wrapped in a Flux.Recurrence layer.

julia> using RecurrentLayers, Flux

julia> wrap = Recurrence(Multiplicative(AntisymmetricRNNCell, 2 => 4))
Recurrence(
  Multiplicative(
    4×2 Matrix{Float32},                # 8 parameters
    4×4 Matrix{Float32},                # 16 parameters
    AntisymmetricRNNCell(2 => 4, tanh),  # 28 parameters
  ),
)                   # Total: 5 arrays, 52 parameters, 488 bytes.
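
Assuming Flux.Recurrence applies the wrapped cell over the second (time) dimension and collects the per-step states, running it over a sequence might look like this sketch; the shapes here are assumed rather than verified.

x = rand(Float32, 2, 10)   # input_size × seq_len; a trailing batch dimension also works in Flux
y = wrap(x)                # assumed: per-step hidden states, here of size 4 × 10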