TorchRecurrent
Welcome to the TorchRecurrent documentation!
TorchRecurrent provides a large collection of recurrent neural network (RNN) cells and layers, all implemented in PyTorch with a consistent API and extended customization options.
Highlights
- 30+ recurrent cells (like torch.nn.LSTMCell) implemented from classic and modern papers
- 30+ high-level recurrent layers (like torch.nn.LSTM) that wrap the cells
- A PyTorch-like interface, so existing models can adopt the cells and layers with minimal changes
- Extra options such as custom initializers, bias settings, and advanced variants
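This page does not document TorchRecurrent's own initializer arguments. As a hedged illustration of what a custom initializer typically looks like in plain PyTorch, the sketch below re-initializes a built-in torch.nn.LSTMCell using torch.nn.init. `init_recurrent_` is a hypothetical helper. Matching on the name `weight_hh` assumes PyTorch's built-in parameter naming (`weight_ih`, `weight_hh`, `bias_ih`, `bias_hh`), and torchrecurrent cells may name their parameters differently.

```python
import torch
import torch.nn as nn


def init_recurrent_(module: nn.Module) -> None:
    """Hypothetical helper: re-initialize a recurrent module in place.

    Assumes PyTorch-style parameter names:
    - biases are set to zero,
    - hidden-to-hidden weights ("weight_hh") get orthogonal init,
    - all other weights get Xavier-uniform init.
    """
    for name, param in module.named_parameters():
        if "bias" in name:
            nn.init.zeros_(param)
        elif "weight_hh" in name:
            nn.init.orthogonal_(param)
        elif "weight" in name:
            nn.init.xavier_uniform_(param)


# Built-in cell used as a stand-in for a torchrecurrent cell.
cell = nn.LSTMCell(input_size=10, hidden_size=20)
init_recurrent_(cell)
```

Because `weight_hh` has shape (4 * hidden_size, hidden_size) for an LSTM cell, `orthogonal_` gives it orthonormal columns rather than a square orthogonal matrix.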
Installation
If you already have PyTorch installed:
pip install torchrecurrent
Quick Example
Using a recurrent cell directly and stepping through the sequence manually:
import torch
from torchrecurrent import LEMCell
cell = LEMCell(input_size=10, hidden_size=20)
x = torch.randn(5, 10)  # (time, input_size): one unbatched sequence
h = torch.zeros(20)     # initial hidden state, (hidden_size,)
outputs = []
for t in range(x.size(0)):
    h = cell(x[t], h)   # one time step: compute the new hidden state
    outputs.append(h)
outputs = torch.stack(outputs, dim=0)  # (time, hidden_size)
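The same manual unrolling extends to a batch of sequences. The sketch below uses PyTorch's built-in torch.nn.GRUCell as a stand-in. Like the LEMCell example, it maps (input, hidden) to a new hidden state. The sketch assumes a torchrecurrent cell accepts batched (batch, input_size) inputs in the same way PyTorch cells do, which this page does not confirm.

```python
import torch
import torch.nn as nn

time_steps, batch, input_size, hidden_size = 5, 3, 10, 20

# Stand-in cell: one call advances every sequence in the batch by one step.
cell = nn.GRUCell(input_size=input_size, hidden_size=hidden_size)

x = torch.randn(time_steps, batch, input_size)  # (time, batch, input_size)
h = torch.zeros(batch, hidden_size)             # (batch, hidden_size)

outputs = []
for t in range(time_steps):
    h = cell(x[t], h)  # x[t]: (batch, input_size) -> h: (batch, hidden_size)
    outputs.append(h)
outputs = torch.stack(outputs, dim=0)  # (time, batch, hidden_size)
```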
Or use the layer abstraction, which runs the loop over time for you:
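import torch
from torchrecurrent import LEM
rnn = LEM(input_size=10, hidden_size=20, num_layers=2)  # two stacked LEM layers
x = torch.randn(3, 5, 10)  # (batch, time, input_size)
output, hn = rnn(x)  # output: outputs for every time step; hn: final hidden state(s)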
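For comparison, the sketch below runs PyTorch's built-in torch.nn.LSTM with the same sizes. In torch.nn.LSTM, the (batch, time, input_size) layout used in the LEM example requires `batch_first=True`, because the default layout is (time, batch, input_size). This page does not state whether LEM defaults to batch-first or accepts the same flag, so check the API reference before relying on that layout.

```python
import torch
import torch.nn as nn

# Built-in layer with the same sizes as the LEM example.
# batch_first=True: input and output are (batch, time, features).
rnn = nn.LSTM(input_size=10, hidden_size=20, num_layers=2, batch_first=True)

x = torch.randn(3, 5, 10)  # (batch, time, input_size)
output, (hn, cn) = rnn(x)  # an LSTM returns both hidden and cell states

print(output.shape)  # torch.Size([3, 5, 20]) -> (batch, time, hidden_size)
print(hn.shape)      # torch.Size([2, 3, 20]) -> (num_layers, batch, hidden_size)
```

Note that `hn` and `cn` keep the (num_layers, batch, hidden_size) layout even when `batch_first=True`.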