# Source code for mighty.models.autoencoder
from collections import namedtuple
import torch.nn as nn
# Lightweight record returned by the autoencoder forward pass: the latent
# code first, the reconstruction of the input second.
AutoencoderOutput = namedtuple(
    "AutoencoderOutput",
    ["latent", "reconstructed"],
)
# [docs]
class AutoencoderLinear(nn.Module):
    """
    A fully connected AutoEncoder with a mirrored encoder and decoder.

    The encoder is a stack of ``Dropout -> Linear -> ReLU`` triples, one per
    consecutive pair of `fc_sizes`. The decoder walks the same sizes in
    reverse order with the same triples, except that the trailing ReLU is
    removed so the reconstruction is not clipped at zero.

    Parameters
    ----------
    *fc_sizes : int
        The sizes of fully connected layers of a resulting AutoEncoder.
        Starts with the input dimension, ends with the embedding dimension.
        At least two sizes are required.
    p_drop : float, optional
        Dropout probability used before every linear layer except the
        first encoder layer.
        Default: 0.5
    p_drop_input : float, optional
        Dropout probability applied to the flattened input, i.e. before the
        first encoder layer only.
        Default: 0.25

    Attributes
    ----------
    encoding_dim : int
        The embedding dimension, ``fc_sizes[-1]``.
    encoder : nn.Sequential
        Maps a flattened input to its latent representation.
    decoder : nn.Sequential
        Maps a latent representation back to the flattened input space.

    Raises
    ------
    ValueError
        If fewer than two layer sizes are given.

    Examples
    --------
    >>> AutoencoderLinear(784, 128)
    AutoencoderLinear(
      (encoder): Sequential(
        (0): Dropout(p=0.25, inplace=False)
        (1): Linear(in_features=784, out_features=128, bias=True)
        (2): ReLU(inplace=True)
      )
      (decoder): Sequential(
        (0): Dropout(p=0.5, inplace=False)
        (1): Linear(in_features=128, out_features=784, bias=True)
      )
    )
    >>> AutoencoderLinear(784, 256, 128)
    AutoencoderLinear(
      (encoder): Sequential(
        (0): Dropout(p=0.25, inplace=False)
        (1): Linear(in_features=784, out_features=256, bias=True)
        (2): ReLU(inplace=True)
        (3): Dropout(p=0.5, inplace=False)
        (4): Linear(in_features=256, out_features=128, bias=True)
        (5): ReLU(inplace=True)
      )
      (decoder): Sequential(
        (0): Dropout(p=0.5, inplace=False)
        (1): Linear(in_features=128, out_features=256, bias=True)
        (2): ReLU(inplace=True)
        (3): Dropout(p=0.5, inplace=False)
        (4): Linear(in_features=256, out_features=784, bias=True)
      )
    )
    """

    def __init__(self, *fc_sizes, p_drop=0.5, p_drop_input=0.25):
        if len(fc_sizes) < 2:
            # Without an (input, embedding) pair there is no layer to build;
            # fail early with a clear message instead of an opaque IndexError.
            raise ValueError(
                "AutoencoderLinear requires at least two layer sizes "
                f"(input and embedding dimensions), got {len(fc_sizes)}"
            )
        super().__init__()
        self.encoding_dim = fc_sizes[-1]

        encoder = []
        layer_dims = zip(fc_sizes[:-1], fc_sizes[1:])
        for layer_id, (in_features, out_features) in enumerate(layer_dims):
            # Select the input dropout by layer position, not by layer width:
            # a hidden layer may legitimately be as wide as the input.
            p = p_drop_input if layer_id == 0 else p_drop
            encoder.append(nn.Dropout(p=p))
            encoder.append(nn.Linear(in_features, out_features))
            encoder.append(nn.ReLU(inplace=True))

        decoder = []
        sizes_reversed = fc_sizes[::-1]
        for in_features, out_features in zip(sizes_reversed[:-1],
                                             sizes_reversed[1:]):
            decoder.append(nn.Dropout(p=p_drop))
            decoder.append(nn.Linear(in_features, out_features))
            decoder.append(nn.ReLU(inplace=True))
        decoder.pop()  # remove the last ReLU layer

        self.encoder = nn.Sequential(*encoder)
        self.decoder = nn.Sequential(*decoder)

    def forward(self, x):
        """
        AutoEncoder forward pass.

        Parameters
        ----------
        x : (B, ...) torch.Tensor
            Input batch, e.g. images of shape (B, C, H, W). All dimensions
            after the first are flattened, so their product must equal the
            input dimension ``fc_sizes[0]``.

        Returns
        -------
        AutoencoderOutput
            A namedtuple with two keys:
              `.latent` - (B, V) latent representation of the input.

              `.reconstructed` - reconstructed input of the same shape
              as `x`.
        """
        input_shape = x.shape
        x = x.flatten(start_dim=1)
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        # Restore the original (unflattened) shape of the input batch.
        decoded = decoded.view(*input_shape)
        return AutoencoderOutput(encoded, decoded)