Source code for torchrecurrent.cells.multiplicativelstm_cell

import torch
from torch import Tensor
import torch.nn as nn
from typing import Optional, Callable, Union, Tuple
from ..base import BaseDoubleRecurrentLayer, BaseDoubleRecurrentCell


class MultiplicativeLSTM(BaseDoubleRecurrentLayer):
    r"""Multi-layer multiplicative long short-term memory network.
    [`arXiv <https://arxiv.org/abs/1609.07959>`_]

    Each layer consists of a :class:`MultiplicativeLSTMCell`, which updates
    the hidden and cell states according to:

    .. math::
        \begin{aligned}
            m_t &= (W_{ih}^m x_t + b_{ih}^m) \circ
                (W_{hh}^m h_{t-1} + b_{hh}^m), \\
            \hat{h}_t &= W_{ih}^h x_t + b_{ih}^h +
                W_{mh}^h m_t + b_{mh}^h, \\
            i_t &= \sigma(W_{ih}^i x_t + b_{ih}^i +
                W_{mh}^i m_t + b_{mh}^i), \\
            f_t &= \sigma(W_{ih}^f x_t + b_{ih}^f +
                W_{mh}^f m_t + b_{mh}^f), \\
            o_t &= \sigma(W_{ih}^o x_t + b_{ih}^o +
                W_{mh}^o m_t + b_{mh}^o), \\
            c_t &= f_t \circ c_{t-1} + i_t \circ \tanh(\hat{h}_t), \\
            h_t &= \tanh(c_t) \circ o_t
        \end{aligned}

    where :math:`h_t` is the hidden state, :math:`c_t` the cell state,
    :math:`\sigma` the sigmoid, and :math:`\circ` the Hadamard product.

    In a multilayer multiplicative LSTM, the input :math:`x^{(l)}_t` of the
    :math:`l`-th layer (:math:`l \ge 2`) is the hidden state
    :math:`h^{(l-1)}_t` of the previous layer multiplied by dropout
    :math:`\delta^{(l-1)}_t`, where each :math:`\delta^{(l-1)}_t` is a
    Bernoulli random variable which is 0 with probability :attr:`dropout`.

    Args:
        input_size: The number of expected features in the input `x`.
        hidden_size: The number of features in the hidden and cell states
            `h`, `c`.
        num_layers: Number of recurrent layers. E.g., setting
            ``num_layers=2`` would mean stacking two multiplicative LSTM
            layers, with the second receiving the outputs of the first.
            Default: 1
        dropout: If non-zero, introduces a `Dropout` layer on the outputs of
            each layer except the last layer, with dropout probability equal
            to :attr:`dropout`. Default: 0
        batch_first: If ``True``, then the input and output tensors are
            provided as `(batch, seq, feature)` instead of
            `(seq, batch, feature)`. Default: ``False``
        bias: If ``False``, then the layer does not use input-side biases.
            Default: ``True``
        recurrent_bias: If ``False``, then the layer does not use recurrent
            biases. Default: ``True``
        multiplicative_bias: If ``False``, then the layer does not use
            multiplicative biases. Default: ``True``
        kernel_init: Initializer for `W_{ih}`.
            Default: :func:`torch.nn.init.xavier_uniform_`
        recurrent_kernel_init: Initializer for `W_{hh}`.
            Default: :func:`torch.nn.init.xavier_uniform_`
        multiplicative_kernel_init: Initializer for `W_{mh}`.
            Default: :func:`torch.nn.init.normal_`
        bias_init: Initializer for `b_{ih}` when ``bias=True``.
            Default: :func:`torch.nn.init.zeros_`
        recurrent_bias_init: Initializer for `b_{hh}` when
            ``recurrent_bias=True``. Default: :func:`torch.nn.init.zeros_`
        multiplicative_bias_init: Initializer for `b_{mh}` when
            ``multiplicative_bias=True``.
            Default: :func:`torch.nn.init.zeros_`
        device: The desired device of parameters.
        dtype: The desired floating point type of parameters.

    Inputs: input, (h_0, c_0)
        - **input**: tensor of shape :math:`(L, H_{in})` for unbatched input,
          :math:`(L, N, H_{in})` when ``batch_first=False`` or
          :math:`(N, L, H_{in})` when ``batch_first=True`` containing the
          features of the input sequence. The input can also be a packed
          variable length sequence. See
          :func:`torch.nn.utils.rnn.pack_padded_sequence` or
          :func:`torch.nn.utils.rnn.pack_sequence` for details.
        - **h_0**: tensor of shape :math:`(\text{num\_layers}, H_{out})` for
          unbatched input or :math:`(\text{num\_layers}, N, H_{out})`
          containing the initial hidden state for each element in the input
          sequence. Defaults to zeros if not provided.
        - **c_0**: tensor of shape :math:`(\text{num\_layers}, H_{out})` for
          unbatched input or :math:`(\text{num\_layers}, N, H_{out})`
          containing the initial cell state for each element in the input
          sequence. Defaults to zeros if not provided.

        where:

        .. math::
            \begin{aligned}
                N ={} & \text{batch size} \\
                L ={} & \text{sequence length} \\
                H_{in} ={} & \text{input\_size} \\
                H_{out} ={} & \text{hidden\_size}
            \end{aligned}

    Outputs: output, (h_n, c_n)
        - **output**: tensor of shape :math:`(L, H_{out})` for unbatched
          input, :math:`(L, N, H_{out})` when ``batch_first=False`` or
          :math:`(N, L, H_{out})` when ``batch_first=True`` containing the
          output features `(h_t)` from the last layer of the multiplicative
          LSTM, for each `t`. If a
          :class:`torch.nn.utils.rnn.PackedSequence` has been given as the
          input, the output will also be a packed sequence.
        - **h_n**: tensor of shape :math:`(\text{num\_layers}, H_{out})` for
          unbatched input or :math:`(\text{num\_layers}, N, H_{out})`
          containing the final hidden state for each element in the sequence.
        - **c_n**: tensor of shape :math:`(\text{num\_layers}, H_{out})` for
          unbatched input or :math:`(\text{num\_layers}, N, H_{out})`
          containing the final cell state for each element in the sequence.

    Attributes:
        cells.{k}.weight_ih: the learnable input-hidden weights of the
            :math:`k`-th layer, of shape `(5*hidden_size, input_size)` for
            `k = 0`. Otherwise, the shape is `(5*hidden_size, hidden_size)`.
        cells.{k}.weight_hh: the learnable hidden-hidden weights of the
            :math:`k`-th layer, of shape `(hidden_size, hidden_size)`.
        cells.{k}.weight_mh: the learnable multiplicative-hidden weights of
            the :math:`k`-th layer, of shape `(4*hidden_size, hidden_size)`.
        cells.{k}.bias_ih: the learnable input-hidden biases of the
            :math:`k`-th layer, of shape `(5*hidden_size)`. Only present
            when ``bias=True``.
        cells.{k}.bias_hh: the learnable hidden-hidden biases of the
            :math:`k`-th layer, of shape `(hidden_size)`. Only present when
            ``recurrent_bias=True``.
        cells.{k}.bias_mh: the learnable multiplicative biases of the
            :math:`k`-th layer, of shape `(4*hidden_size)`. Only present
            when ``multiplicative_bias=True``.

    .. note::
        All the weights and biases are initialized according to the provided
        initializers (`kernel_init`, `recurrent_kernel_init`, etc.).

    .. note::
        The ``batch_first`` argument is ignored for unbatched inputs.

    .. seealso::
        :class:`MultiplicativeLSTMCell`

    Examples::

        >>> rnn = MultiplicativeLSTM(10, 20, num_layers=2, dropout=0.1)
        >>> input = torch.randn(5, 3, 10)  # (seq_len, batch, input_size)
        >>> h0 = torch.zeros(2, 3, 20)  # (num_layers, batch, hidden_size)
        >>> c0 = torch.zeros(2, 3, 20)  # (num_layers, batch, hidden_size)
        >>> output, (hn, cn) = rnn(input, (h0, c0))
    """
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        dropout: float = 0.0,
        batch_first: bool = False,
        **kwargs,
    ):
        super().__init__(
            input_size, hidden_size, num_layers, dropout, batch_first
        )
        self.initialize_cells(MultiplicativeLSTMCell, **kwargs)
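

# Illustrative sketch (a hypothetical helper, not part of the public API):
# one multiplicative LSTM step written out with plain tensor ops, following
# the equations in the docstring above. It makes the fused gate layout
# explicit: ``weight_ih`` packs five blocks (multiplicative term, input,
# forget, candidate, output) and ``weight_mh`` packs four (input, forget,
# candidate, output), matching ``MultiplicativeLSTMCell.forward`` below.
def _mlstm_step_sketch(
    x: Tensor,  # (batch, input_size)
    h: Tensor,  # (batch, hidden_size)
    c: Tensor,  # (batch, hidden_size)
    weight_ih: Tensor,  # (5 * hidden_size, input_size)
    weight_hh: Tensor,  # (hidden_size, hidden_size)
    weight_mh: Tensor,  # (4 * hidden_size, hidden_size)
    bias_ih: Tensor,  # (5 * hidden_size,)
    bias_hh: Tensor,  # (hidden_size,)
    bias_mh: Tensor,  # (4 * hidden_size,)
) -> Tuple[Tensor, Tensor]:
    gx_m, gx_i, gx_f, gx_h, gx_o = (x @ weight_ih.t() + bias_ih).chunk(5, dim=-1)
    # m_t = (W_ih^m x_t + b_ih^m) * (W_hh^m h_{t-1} + b_hh^m)
    m = gx_m * (h @ weight_hh.t() + bias_hh)
    gm_i, gm_f, gm_h, gm_o = (m @ weight_mh.t() + bias_mh).chunk(4, dim=-1)
    i = torch.sigmoid(gx_i + gm_i)  # input gate
    f = torch.sigmoid(gx_f + gm_f)  # forget gate
    o = torch.sigmoid(gx_o + gm_o)  # output gate
    # c_t = f_t * c_{t-1} + i_t * tanh(h_hat_t);  h_t = tanh(c_t) * o_t
    new_c = f * c + i * torch.tanh(gx_h + gm_h)
    new_h = o * torch.tanh(new_c)
    return new_h, new_c
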
class MultiplicativeLSTMCell(BaseDoubleRecurrentCell):
    r"""A multiplicative LSTM cell.
    [`arXiv <https://arxiv.org/abs/1609.07959>`_]

    .. math::
        \begin{aligned}
            \mathbf{m}(t) &= \bigl(\mathbf{W}_{ih}^{m}\,\mathbf{x}(t) +
                \mathbf{b}_{ih}^{m}\bigr) \circ
                \bigl(\mathbf{W}_{hh}^{m}\,\mathbf{h}(t-1) +
                \mathbf{b}_{hh}^{m}\bigr), \\
            \hat{\mathbf{h}}(t) &= \mathbf{W}_{ih}^{h}\,\mathbf{x}(t) +
                \mathbf{b}_{ih}^{h} + \mathbf{W}_{mh}^{h}\,\mathbf{m}(t) +
                \mathbf{b}_{mh}^{h}, \\
            \mathbf{i}(t) &= \sigma\bigl(\mathbf{W}_{ih}^{i}\,\mathbf{x}(t) +
                \mathbf{b}_{ih}^{i} + \mathbf{W}_{mh}^{i}\,\mathbf{m}(t) +
                \mathbf{b}_{mh}^{i}\bigr), \\
            \mathbf{f}(t) &= \sigma\bigl(\mathbf{W}_{ih}^{f}\,\mathbf{x}(t) +
                \mathbf{b}_{ih}^{f} + \mathbf{W}_{mh}^{f}\,\mathbf{m}(t) +
                \mathbf{b}_{mh}^{f}\bigr), \\
            \mathbf{o}(t) &= \sigma\bigl(\mathbf{W}_{ih}^{o}\,\mathbf{x}(t) +
                \mathbf{b}_{ih}^{o} + \mathbf{W}_{mh}^{o}\,\mathbf{m}(t) +
                \mathbf{b}_{mh}^{o}\bigr), \\
            \mathbf{c}(t) &= \mathbf{f}(t) \circ \mathbf{c}(t-1) +
                \mathbf{i}(t) \circ \tanh\bigl(\hat{\mathbf{h}}(t)\bigr), \\
            \mathbf{h}(t) &= \tanh\bigl(\mathbf{c}(t)\bigr) \circ \mathbf{o}(t)
        \end{aligned}

    where :math:`\circ` is the Hadamard product and :math:`\sigma` the
    sigmoid.

    Args:
        input_size: The number of expected features in the input ``x``.
        hidden_size: The number of features in the hidden/cell states ``h``
            and ``c``.
        bias: If ``False``, the layer does not use the input-side bias
            ``b_{ih}``. Default: ``True``
        recurrent_bias: If ``False``, the layer does not use the recurrent
            bias ``b_{hh}``. Default: ``True``
        multiplicative_bias: If ``False``, the layer does not use the
            multiplicative bias ``b_{mh}``. Default: ``True``
        kernel_init: Initializer for ``W_{ih}``.
            Default: :func:`torch.nn.init.xavier_uniform_`
        recurrent_kernel_init: Initializer for ``W_{hh}``.
            Default: :func:`torch.nn.init.xavier_uniform_`
        multiplicative_kernel_init: Initializer for ``W_{mh}``.
            Default: :func:`torch.nn.init.normal_`
        bias_init: Initializer for ``b_{ih}`` when ``bias=True``.
            Default: :func:`torch.nn.init.zeros_`
        recurrent_bias_init: Initializer for ``b_{hh}`` when
            ``recurrent_bias=True``. Default: :func:`torch.nn.init.zeros_`
        multiplicative_bias_init: Initializer for ``b_{mh}`` when
            ``multiplicative_bias=True``.
            Default: :func:`torch.nn.init.zeros_`
        device: The desired device of parameters.
        dtype: The desired floating point type of parameters.

    Inputs: input, (h_0, c_0)
        - **input** of shape ``(batch, input_size)`` or ``(input_size,)``:
          tensor containing input features
        - **h_0** of shape ``(batch, hidden_size)`` or ``(hidden_size,)``:
          tensor containing the initial hidden state
        - **c_0** of shape ``(batch, hidden_size)`` or ``(hidden_size,)``:
          tensor containing the initial cell state

        If ``(h_0, c_0)`` is not provided, both default to zeros.

    Outputs: (h_1, c_1)
        - **h_1** of shape ``(batch, hidden_size)`` or ``(hidden_size,)``:
          tensor containing the next hidden state
        - **c_1** of shape ``(batch, hidden_size)`` or ``(hidden_size,)``:
          tensor containing the next cell state

    Variables:
        weight_ih: the learnable input-hidden weights, of shape
            ``(5*hidden_size, input_size)``
        weight_hh: the learnable hidden-hidden weights, of shape
            ``(hidden_size, hidden_size)``
        weight_mh: the learnable multiplicative-hidden weights, of shape
            ``(4*hidden_size, hidden_size)``
        bias_ih: the learnable input-hidden biases, of shape
            ``(5*hidden_size)``
        bias_hh: the learnable hidden-hidden biases, of shape
            ``(hidden_size)``
        bias_mh: the learnable multiplicative biases, of shape
            ``(4*hidden_size)``

    Examples::

        >>> cell = MultiplicativeLSTMCell(10, 20)
        >>> x = torch.randn(5, 3, 10)  # (time_steps, batch, input_size)
        >>> h, c = torch.zeros(3, 20), torch.zeros(3, 20)
        >>> out_h = []
        >>> for t in range(x.size(0)):
        ...     h, c = cell(x[t], (h, c))
        ...     out_h.append(h)
        >>> out_h = torch.stack(out_h, dim=0)  # (time_steps, batch, hidden_size)
    """

    __constants__ = [
        "input_size",
        "hidden_size",
        "bias",
        "recurrent_bias",
        "multiplicative_bias",
        "kernel_init",
        "recurrent_kernel_init",
        "multiplicative_kernel_init",
        "bias_init",
        "recurrent_bias_init",
        "multiplicative_bias_init",
    ]

    weight_ih: Tensor
    weight_hh: Tensor
    weight_mh: Tensor
    bias_ih: Tensor
    bias_hh: Tensor
    bias_mh: Tensor

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        recurrent_bias: bool = True,
        multiplicative_bias: bool = True,
        kernel_init: Callable = nn.init.xavier_uniform_,
        recurrent_kernel_init: Callable = nn.init.xavier_uniform_,
        multiplicative_kernel_init: Callable = nn.init.normal_,
        bias_init: Callable = nn.init.zeros_,
        recurrent_bias_init: Callable = nn.init.zeros_,
        multiplicative_bias_init: Callable = nn.init.zeros_,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__(
            input_size, hidden_size, bias, recurrent_bias,
            device=device, dtype=dtype,
        )
        self.kernel_init = kernel_init
        self.recurrent_kernel_init = recurrent_kernel_init
        self.multiplicative_kernel_init = multiplicative_kernel_init
        self.bias_init = bias_init
        self.recurrent_bias_init = recurrent_bias_init
        self.multiplicative_bias_init = multiplicative_bias_init
        self._register_tensors(
            {
                # Fused parameter blocks; see the gate layout in ``forward``.
                "weight_ih": ((5 * hidden_size, input_size), True),
                "weight_hh": ((hidden_size, hidden_size), True),
                "weight_mh": ((4 * hidden_size, hidden_size), True),
                "bias_ih": ((5 * hidden_size,), bias),
                "bias_hh": ((hidden_size,), recurrent_bias),
                "bias_mh": ((4 * hidden_size,), multiplicative_bias),
            }
        )
        self.init_weights()

    def init_weights(self):
        for name, param in self.named_parameters():
            if "weight_ih" in name:
                self.kernel_init(param)
            elif "weight_hh" in name:
                self.recurrent_kernel_init(param)
            elif "weight_mh" in name:
                self.multiplicative_kernel_init(param)
            elif "bias_ih" in name:
                self.bias_init(param)
            elif "bias_hh" in name:
                self.recurrent_bias_init(param)
            elif "bias_mh" in name:
                self.multiplicative_bias_init(param)

    def forward(
        self,
        inp: Tensor,
        state: Optional[Union[Tensor, Tuple[Tensor, ...]]] = None,
    ) -> Tuple[Tensor, Tensor]:
        state, c_state = self._check_states(state)
        self._validate_input(inp)
        self._validate_states((state, c_state))
        inp, state, c_state, is_batched = self._preprocess_states(
            inp, (state, c_state)
        )
        # Fused input projection; the five chunks feed the multiplicative
        # term and the input/forget/candidate/output pre-activations.
        inp_expanded = inp @ self.weight_ih.t() + self.bias_ih
        gxs1, gxs2, gxs3, gxs4, gxs5 = inp_expanded.chunk(5, dim=1)
        # m_t = (W_ih^m x_t + b_ih^m) * (W_hh^m h_{t-1} + b_hh^m)
        multiplicative_state = gxs1 * (state @ self.weight_hh.t() + self.bias_hh)
        # Fused multiplicative projection; four chunks, one per gate.
        mult_expanded = multiplicative_state @ self.weight_mh.t() + self.bias_mh
        gms1, gms2, gms3, gms4 = mult_expanded.chunk(4, dim=1)
        input_gate = torch.sigmoid(gxs2 + gms1)
        forget_gate = torch.sigmoid(gxs3 + gms2)
        candidate_state = torch.tanh(gxs4 + gms3)
        output_gate = torch.sigmoid(gxs5 + gms4)
        new_cstate = forget_gate * c_state + input_gate * candidate_state
        new_state = output_gate * torch.tanh(new_cstate)
        if not is_batched:
            # Drop the batch dimension added by ``_preprocess_states``.
            new_state = new_state.squeeze(0)
            new_cstate = new_cstate.squeeze(0)
        return new_state, new_cstate
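

# A minimal smoke test (an illustrative sketch, not part of the library API):
# it assumes the base classes implement the shapes documented above and runs
# only when this module is executed directly.
if __name__ == "__main__":
    torch.manual_seed(0)

    # Unroll a single cell one step at a time, as in the docstring example.
    cell = MultiplicativeLSTMCell(10, 20)
    x = torch.randn(5, 3, 10)  # (time_steps, batch, input_size)
    h, c = torch.zeros(3, 20), torch.zeros(3, 20)
    for t in range(x.size(0)):
        h, c = cell(x[t], (h, c))
    print(h.shape, c.shape)  # expected: (3, 20) and (3, 20)

    # Run the full stacked layer over the same sequence.
    rnn = MultiplicativeLSTM(10, 20, num_layers=2)
    h0 = torch.zeros(2, 3, 20)  # (num_layers, batch, hidden_size)
    c0 = torch.zeros(2, 3, 20)
    output, (hn, cn) = rnn(x, (h0, c0))
    print(output.shape, hn.shape, cn.shape)
    # expected: (5, 3, 20), (2, 3, 20), (2, 3, 20)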