class AE(nn.Module):
    """Fully connected autoencoder with mirrored encoder and decoder stacks.

    The encoder maps ``input_dim`` features through ``hidden_dims`` down to a
    ``latent_dim`` code. The decoder walks the same widths back up and ends
    in a sigmoid, so every reconstructed feature lies in (0, 1).

    With the default arguments the network is
    784 -> 128 -> 64 -> 36 -> 18 -> 9 -> 18 -> 36 -> 64 -> 128 -> 784,
    with ReLU between linear layers and no activation on the latent code.
    The submodule indices inside ``encoder`` / ``decoder`` match the
    original hand-written layout, so previously saved ``state_dict``s load
    unchanged.

    Args:
        input_dim: Size of the last dimension of the input and of the
            reconstruction. Defaults to ``28 * 28``.
        hidden_dims: Widths of the encoder's hidden layers, widest first.
            The decoder uses them in reverse order. Defaults to
            ``(128, 64, 36, 18)``.
        latent_dim: Size of the bottleneck code. Defaults to ``9``.
    """

    def __init__(self, input_dim=28 * 28, hidden_dims=(128, 64, 36, 18),
                 latent_dim=9):
        super().__init__()
        sizes = [input_dim, *hidden_dims, latent_dim]
        self.encoder = nn.Sequential(*self._mlp_layers(sizes))
        decoder_layers = self._mlp_layers(sizes[::-1])
        # Squash reconstructions into (0, 1); presumably the inputs are
        # scaled to [0, 1] to match -- confirm against the training loss.
        decoder_layers.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*decoder_layers)

    @staticmethod
    def _mlp_layers(sizes):
        """Build Linear layers between consecutive ``sizes``, with ReLU between them.

        No activation follows the final Linear layer. This keeps the
        module indices identical to the original explicit layer lists.
        """
        layers = []
        for i, (n_in, n_out) in enumerate(zip(sizes, sizes[1:])):
            if i:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(n_in, n_out))
        return layers

    def forward(self, x):
        """Encode ``x`` to the latent code and return its reconstruction.

        Args:
            x: Tensor whose last dimension has size ``input_dim``. This
                module does not reshape its input, so images must already
                be flattened (e.g. ``(batch, 784)`` for the defaults --
                assumed usage, confirm with callers).

        Returns:
            Tensor with the same leading dimensions as ``x`` and last
            dimension ``input_dim``, with values in (0, 1).
        """
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded