# torchrecurrent.cells.nas_cell

# This file is a reimplementation in PyTorch of the NASCell as described in:
# "Neural Architecture Search with Reinforcement Learning".
# The original implementation in TensorFlow can be found here:
# https://www.tensorflow.org/addons/api_docs/python/tfa/rnn/NASCell
# No changes were made that alter the behavior of the cell compared to the original
# implementation; differences may be due to library-specific syntax.
#
# The original implementation is licensed under the Apache License, Version 2.0.
# This reimplementation is also licensed under the Apache License, Version 2.0.

#
# Copyright 2024 Francesco Martinuzzi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
from typing import Optional, Tuple, Union, Callable
from torch import Tensor
from ..base import BaseDoubleRecurrentLayer, BaseDoubleRecurrentCell


class NAS(BaseDoubleRecurrentLayer):
    r"""Multi-layer neural architecture search recurrent network.
    [`arXiv <https://arxiv.org/abs/1611.01578>`_]

    Each layer consists of a :class:`NASCell`, which updates the hidden and
    cell states according to:

    .. math::

        \begin{aligned}
            g(t) &= W_{ih} x(t) + b_{ih} + W_{hh} h(t-1) + b_{hh}, \\
            [g_0, \dots, g_7] &= \mathrm{chunk}_8(g(t)), \\
            o_k(t) &= \begin{cases}
                \sigma(g_k), & k \in \{0, 2, 5, 7\}, \\
                \mathrm{ReLU}(g_k), & k \in \{1, 3\}, \\
                \tanh(g_k), & k \in \{4, 6\},
            \end{cases} \\
            \ell_1(t) &= \tanh(o_0 \circ o_1), \quad
                \ell_2(t) = \tanh(o_2 + o_3), \\
            \ell_3(t) &= \tanh(o_4 \circ o_5), \quad
                \ell_4(t) = \sigma(o_6 + o_7), \\
            \tilde{c}(t) &= \tanh(\ell_1 + c(t-1)), \quad
                c(t) = \tilde{c}(t) \circ \ell_2, \\
            \ell_5(t) &= \tanh(\ell_3 + \ell_4), \quad
                h(t) = \tanh(c(t) \circ \ell_5(t))
        \end{aligned}

    where :math:`h(t)` and :math:`c(t)` are the hidden and cell states at time
    :math:`t`, :math:`\sigma` is the sigmoid function, and :math:`\circ`
    denotes the elementwise product.

    In a multilayer NAS, the input :math:`x^{(l)}_t` of the :math:`l`-th layer
    (:math:`l \ge 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous
    layer multiplied by dropout :math:`\delta^{(l-1)}_t`, where each
    :math:`\delta^{(l-1)}_t` is a Bernoulli random variable which is 0 with
    probability :attr:`dropout`.

    Args:
        input_size: The number of expected features in the input `x`.
        hidden_size: The number of features in the hidden and cell states.
        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
            would mean stacking two NAS layers, with the second receiving the
            outputs of the first. Default: 1
        dropout: If non-zero, introduces a `Dropout` layer on the outputs of
            each layer except the last layer, with dropout probability equal
            to :attr:`dropout`. Default: 0
        batch_first: If ``True``, then the input and output tensors are
            provided as `(batch, seq, feature)` instead of
            `(seq, batch, feature)`. Default: False
        bias: If ``False``, then the layer does not use the input-side bias
            `b_{ih}`. Default: True
        recurrent_bias: If ``False``, then the layer does not use the
            recurrent bias `b_{hh}`. Default: True
        kernel_init: Initializer for `W_{ih}`.
            Default: :func:`torch.nn.init.xavier_uniform_`
        recurrent_kernel_init: Initializer for `W_{hh}`.
            Default: :func:`torch.nn.init.xavier_uniform_`
        bias_init: Initializer for `b_{ih}`.
            Default: :func:`torch.nn.init.zeros_`
        recurrent_bias_init: Initializer for `b_{hh}`.
            Default: :func:`torch.nn.init.zeros_`
        device: The desired device of parameters.
        dtype: The desired floating point type of parameters.

    Inputs: input, (h_0, c_0)
        - **input**: tensor of shape :math:`(L, H_{in})` for unbatched input,
          :math:`(L, N, H_{in})` when ``batch_first=False`` or
          :math:`(N, L, H_{in})` when ``batch_first=True`` containing the
          features of the input sequence. The input can also be a packed
          variable length sequence. See
          :func:`torch.nn.utils.rnn.pack_padded_sequence` or
          :func:`torch.nn.utils.rnn.pack_sequence` for details.
        - **h_0**: tensor of shape :math:`(\text{num_layers}, H_{out})` for
          unbatched input or :math:`(\text{num_layers}, N, H_{out})`
          containing the initial hidden state. Defaults to zeros if not
          provided.
        - **c_0**: tensor of shape :math:`(\text{num_layers}, H_{out})` for
          unbatched input or :math:`(\text{num_layers}, N, H_{out})`
          containing the initial cell state. Defaults to zeros if not
          provided.

        where:

        .. math::

            \begin{aligned}
                N ={} & \text{batch size} \\
                L ={} & \text{sequence length} \\
                H_{in} ={} & \text{input\_size} \\
                H_{out} ={} & \text{hidden\_size}
            \end{aligned}

    Outputs: output, (h_n, c_n)
        - **output**: tensor of shape :math:`(L, H_{out})` for unbatched
          input, :math:`(L, N, H_{out})` when ``batch_first=False`` or
          :math:`(N, L, H_{out})` when ``batch_first=True`` containing the
          output features `(h_t)` from the last layer of the NAS, for each
          `t`. If a :class:`torch.nn.utils.rnn.PackedSequence` has been given
          as the input, the output will also be a packed sequence.
        - **h_n**: tensor of shape :math:`(\text{num_layers}, H_{out})` for
          unbatched input or :math:`(\text{num_layers}, N, H_{out})`
          containing the final hidden state for each element in the sequence.
        - **c_n**: tensor of shape :math:`(\text{num_layers}, H_{out})` for
          unbatched input or :math:`(\text{num_layers}, N, H_{out})`
          containing the final cell state for each element in the sequence.

    Attributes:
        cells.{k}.weight_ih: the learnable input-hidden weights of the
            :math:`k`-th layer, of shape `(8*hidden_size, input_size)` for
            `k = 0`, and `(8*hidden_size, hidden_size)` otherwise.
        cells.{k}.weight_hh: the learnable hidden-hidden weights of the
            :math:`k`-th layer, of shape `(8*hidden_size, hidden_size)`.
        cells.{k}.bias_ih: the learnable input-hidden biases of the
            :math:`k`-th layer, of shape `(8*hidden_size)`. Only present when
            ``bias=True``.
        cells.{k}.bias_hh: the learnable hidden-hidden biases of the
            :math:`k`-th layer, of shape `(8*hidden_size)`. Only present when
            ``recurrent_bias=True``.

    .. note::
        All weights and biases are initialized according to the provided
        initializers (`kernel_init`, `recurrent_kernel_init`, etc.).

    .. note::
        The ``batch_first`` argument is ignored for unbatched inputs.

    .. seealso::
        :class:`NASCell`

    Examples::

        >>> rnn = NAS(10, 20, num_layers=2, dropout=0.1)
        >>> input = torch.randn(5, 3, 10)  # (seq_len, batch, input_size)
        >>> h0 = torch.zeros(2, 3, 20)
        >>> c0 = torch.zeros(2, 3, 20)
        >>> output, (hn, cn) = rnn(input, (h0, c0))
    """
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        dropout: float = 0.0,
        batch_first: bool = False,
        **kwargs,
    ):
        super(NAS, self).__init__(
            input_size, hidden_size, num_layers, dropout, batch_first
        )
        self.initialize_cells(NASCell, **kwargs)
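
# Illustrative usage sketch (comments only; the shapes follow the ``NAS`` class
# docstring above, and the doctest-style output shown is an expectation under
# those documented shapes, not an additional guarantee):
#
#     >>> rnn = NAS(10, 20, num_layers=2, batch_first=True)
#     >>> x = torch.randn(3, 5, 10)      # (batch, seq, input_size)
#     >>> output, (hn, cn) = rnn(x)      # h_0 and c_0 default to zeros
#     >>> output.shape                   # (batch, seq, hidden_size)
#     torch.Size([3, 5, 20])
#     >>> hn.shape, cn.shape             # (num_layers, batch, hidden_size)
#     (torch.Size([2, 3, 20]), torch.Size([2, 3, 20]))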


class NASCell(BaseDoubleRecurrentCell):
    r"""A Neural Architecture Search (NAS) cell.
    [`arXiv <https://arxiv.org/abs/1611.01578>`_]

    .. math::

        \mathbf{g}(t) &= \mathbf{W}_{ih}\,\mathbf{x}(t) + \mathbf{b}_{ih}
            + \mathbf{W}_{hh}\,\mathbf{h}(t-1) + \mathbf{b}_{hh}, \\[6pt]
        [g_0,\dots,g_7] &= \mathrm{chunk}_8\bigl(\mathbf{g}(t)\bigr), \\[6pt]
        o_k(t) &= \begin{cases}
            \sigma(g_k), & k\in\{0,2,5,7\},\\
            \mathrm{ReLU}(g_k), & k\in\{1,3\},\\
            \tanh(g_k), & k\in\{4,6\},
        \end{cases} \\[6pt]
        \ell_1(t) &= \tanh\bigl(o_0\,\circ\,o_1\bigr),\quad
            \ell_2(t) = \tanh\bigl(o_2 + o_3\bigr), \\[3pt]
        \ell_3(t) &= \tanh\bigl(o_4\,\circ\,o_5\bigr),\quad
            \ell_4(t) = \sigma\bigl(o_6 + o_7\bigr), \\[6pt]
        \tilde{c}(t) &= \tanh\bigl(\ell_1 + c(t-1)\bigr),\quad
            c(t) = \tilde{c}(t)\,\circ\,\ell_2, \\[3pt]
        \ell_5(t) &= \tanh\bigl(\ell_3 + \ell_4\bigr),\quad
            h(t) = \tanh\bigl(c(t)\,\circ\,\ell_5(t)\bigr).

    Args:
        input_size: The number of expected features in the input ``x``
        hidden_size: The number of features in the hidden and cell states
        bias: If ``False``, the layer does not use the input-side bias
            ``b_{ih}``. Default: ``True``
        recurrent_bias: If ``False``, the layer does not use the recurrent
            bias ``b_{hh}``. Default: ``True``
        kernel_init: Initializer for ``W_{ih}``.
            Default: :func:`torch.nn.init.xavier_uniform_`
        recurrent_kernel_init: Initializer for ``W_{hh}``.
            Default: :func:`torch.nn.init.xavier_uniform_`
        bias_init: Initializer for ``b_{ih}`` when ``bias=True``.
            Default: :func:`torch.nn.init.zeros_`
        recurrent_bias_init: Initializer for ``b_{hh}`` when
            ``recurrent_bias=True``. Default: :func:`torch.nn.init.zeros_`
        device: The desired device of parameters.
        dtype: The desired floating point type of parameters.

    Inputs: input, (h_0, c_0)
        - **input** of shape ``(batch, input_size)`` or ``(input_size,)``:
          tensor containing input features
        - **h_0** of shape ``(batch, hidden_size)`` or ``(hidden_size,)``:
          initial hidden state
        - **c_0** of shape ``(batch, hidden_size)`` or ``(hidden_size,)``:
          initial cell state

        If **(h_0, c_0)** is not provided, both default to zero.

    Outputs: (h_1, c_1)
        - **h_1** of shape ``(batch, hidden_size)`` or ``(hidden_size,)``:
          next hidden state
        - **c_1** of shape ``(batch, hidden_size)`` or ``(hidden_size,)``:
          next cell state

    Variables:
        weight_ih: the learnable input-hidden weights, of shape
            ``(8*hidden_size, input_size)``
        weight_hh: the learnable hidden-hidden weights, of shape
            ``(8*hidden_size, hidden_size)``
        bias_ih: the learnable input-hidden biases, of shape
            ``(8*hidden_size)``
        bias_hh: the learnable hidden-hidden biases, of shape
            ``(8*hidden_size)``

    Examples::

        >>> cell = NASCell(10, 20)
        >>> x = torch.randn(5, 3, 10)  # (time_steps, batch, input_size)
        >>> h = torch.zeros(3, 20)
        >>> c = torch.zeros(3, 20)
        >>> out_h = []
        >>> for t in range(x.size(0)):
        ...     h, c = cell(x[t], (h, c))
        ...     out_h.append(h)
        >>> out_h = torch.stack(out_h, dim=0)  # (time_steps, batch, hidden_size)
    """

    __constants__ = [
        "input_size",
        "hidden_size",
        "bias",
        "recurrent_bias",
        "kernel_init",
        "recurrent_kernel_init",
        "bias_init",
        "recurrent_bias_init",
    ]

    weight_ih: Tensor
    weight_hh: Tensor
    bias_ih: Tensor
    bias_hh: Tensor
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        recurrent_bias: bool = True,
        kernel_init: Callable = nn.init.xavier_uniform_,
        recurrent_kernel_init: Callable = nn.init.xavier_uniform_,
        bias_init: Callable = nn.init.zeros_,
        recurrent_bias_init: Callable = nn.init.zeros_,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        super(NASCell, self).__init__(
            input_size, hidden_size, bias, recurrent_bias, device=device, dtype=dtype
        )
        self.kernel_init = kernel_init
        self.recurrent_kernel_init = recurrent_kernel_init
        self.bias_init = bias_init
        self.recurrent_bias_init = recurrent_bias_init
        self._default_register_tensors(
            input_size,
            hidden_size,
            ih_mult=8,
            hh_mult=8,
            bias=bias,
            recurrent_bias=recurrent_bias,
        )
        self.init_weights()
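
    # Note (illustrative, hedged): the ``*_init`` arguments are plain callables,
    # so any in-place initializer with the ``torch.nn.init`` signature can be
    # swapped in for the defaults, e.g.
    #
    #     >>> cell = NASCell(10, 20, recurrent_kernel_init=nn.init.orthogonal_)
    #
    # This mirrors the constructor signature above; how ``init_weights`` applies
    # each callable to the registered tensors is defined by the base class.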

    def forward(
        self, inp: Tensor, state: Optional[Union[Tensor, Tuple[Tensor, ...]]] = None
    ) -> Tuple[Tensor, Tensor]:
        state, c_state = self._check_states(state)
        self._validate_input(inp)
        self._validate_states((state, c_state))
        inp, state, c_state, is_batched = self._preprocess_states(
            inp, (state, c_state)
        )

        # One fused affine map produces all eight gate pre-activations.
        gates = (
            inp @ self.weight_ih.t()
            + self.bias_ih
            + state @ self.weight_hh.t()
            + self.bias_hh
        )
        g0, g1, g2, g3, g4, g5, g6, g7 = gates.chunk(8, 1)

        # First layer: the fixed activations chosen by the architecture search.
        layer1_0 = torch.sigmoid(g0)
        layer1_1 = torch.relu(g1)
        layer1_2 = torch.sigmoid(g2)
        layer1_3 = torch.relu(g3)
        layer1_4 = torch.tanh(g4)
        layer1_5 = torch.sigmoid(g5)
        layer1_6 = torch.tanh(g6)
        layer1_7 = torch.sigmoid(g7)

        # Second layer: pairwise combinations of the first-layer outputs.
        l2_0 = torch.tanh(layer1_0 * layer1_1)
        l2_1 = torch.tanh(layer1_2 + layer1_3)
        l2_2 = torch.tanh(layer1_4 * layer1_5)
        l2_3 = torch.sigmoid(layer1_6 + layer1_7)

        # Inject the previous cell state, then form the new cell and hidden states.
        l2_0 = torch.tanh(l2_0 + c_state)
        new_cstate = l2_0 * l2_1
        l3_1 = torch.tanh(l2_2 + l2_3)
        new_state = torch.tanh(new_cstate * l3_1)

        if not is_batched:
            new_state = new_state.squeeze(0)
            new_cstate = new_cstate.squeeze(0)

        return new_state, new_cstate
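

# Minimal smoke-test sketch (illustrative only; runs solely when this module is
# executed directly). It recomputes one NASCell step by hand from the registered
# parameters and compares the result against ``NASCell.forward``; the manual
# computation mirrors the update equations in the class docstring and assumes
# the default ``bias=True`` / ``recurrent_bias=True`` so that ``bias_ih`` and
# ``bias_hh`` exist.
if __name__ == "__main__":
    torch.manual_seed(0)
    cell = NASCell(10, 20)
    x = torch.randn(3, 10)   # (batch, input_size)
    h = torch.zeros(3, 20)   # (batch, hidden_size)
    c = torch.zeros(3, 20)

    h1, c1 = cell(x, (h, c))

    # Manual recomputation of the same step from the documented equations.
    gates = (
        x @ cell.weight_ih.t() + cell.bias_ih
        + h @ cell.weight_hh.t() + cell.bias_hh
    )
    g0, g1, g2, g3, g4, g5, g6, g7 = gates.chunk(8, 1)
    l1 = torch.tanh(torch.sigmoid(g0) * torch.relu(g1))
    l2 = torch.tanh(torch.sigmoid(g2) + torch.relu(g3))
    l3 = torch.tanh(torch.tanh(g4) * torch.sigmoid(g5))
    l4 = torch.sigmoid(torch.tanh(g6) + torch.sigmoid(g7))
    c_tilde = torch.tanh(l1 + c)
    c_manual = c_tilde * l2                         # c(t)
    h_manual = torch.tanh(c_manual * torch.tanh(l3 + l4))  # h(t)

    assert torch.allclose(h1, h_manual, atol=1e-6)
    assert torch.allclose(c1, c_manual, atol=1e-6)
    print("NASCell.forward matches the documented update equations.")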