In [1]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import json
import math
import random
import numpy as np
import scipy as sp
import scipy.stats as st
import scipy.integrate as integrate
from scipy.stats import multivariate_normal
from sklearn import linear_model
from sklearn.utils.testing import ignore_warnings
from sklearn.exceptions import ConvergenceWarning
import statsmodels.api as sm
from matplotlib.colors import LogNorm
import pickle

from joblib import Parallel, delayed
import multiprocessing
from collections import namedtuple
from itertools import count

import cProfile
from datetime import datetime

# Global plotting defaults: seaborn grid style, colour-blind-safe palette,
# and figure / legend size constants.
sns.set_style("whitegrid")
sns.set_palette("colorblind")
palette = sns.color_palette()
figsize = (15,8)
legend_fontsize = 16

from matplotlib import rc
rc('font',**{'family':'sans-serif'})
rc('text', usetex=True)
# Every rc('text.latex', preamble=...) call replaces the previous value, so
# both packages must be passed in ONE preamble; two separate calls would keep
# only the last package.
rc('text.latex', preamble="\n".join([
    r'\usepackage[utf8]{inputenc}',
    r'\usepackage[russian]{babel}',
]))
rc('figure', **{'dpi': 300})



In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.autograd import Variable
from torchvision.utils import save_image

import torchvision.datasets as datasets
from torchvision.utils import make_grid

## VAE

In [3]:
## Load MNIST
image_size = 28
image_shape = (1, image_size, image_size)

# Choose the float tensor type that matches the available hardware.
cuda = torch.cuda.is_available()
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

batch_size = 64

# Training split (downloaded on first use), reshuffled on every pass.
mnist_train = datasets.MNIST(
    "data/mnist", train=True, download=True, transform=transforms.ToTensor()
)
dataloader = torch.utils.data.DataLoader(
    mnist_train, batch_size=batch_size, shuffle=True
)

# Test split, iterated in a fixed order.
mnist_test = datasets.MNIST(
    "data/mnist", train=False, download=True, transform=transforms.ToTensor()
)
testloader = torch.utils.data.DataLoader(
    mnist_test, batch_size=batch_size, shuffle=False
)

In [4]:
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps a flattened input of size ``x_dim`` through two hidden
    layers to the mean and log-variance of a ``z_dim``-dimensional Gaussian;
    the decoder mirrors it and ends with a sigmoid, so outputs lie in [0, 1].

    Args:
        x_dim: size of one flattened input sample.
        h_dim1: width of the outer hidden layer (encoder first / decoder last).
        h_dim2: width of the inner hidden layer.
        z_dim: dimensionality of the latent space.
    """

    def __init__(self, x_dim, h_dim1, h_dim2, z_dim):
        super(VAE, self).__init__()

        # Remembered so forward() can flatten inputs to the configured size
        # instead of assuming a fixed 784-feature input.
        self.x_dim = x_dim

        # encoder
        self.fc1 = nn.Linear(x_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc31 = nn.Linear(h_dim2, z_dim)  # latent mean head
        self.fc32 = nn.Linear(h_dim2, z_dim)  # latent log-variance head
        # decoder
        self.fc4 = nn.Linear(z_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)

    def encoder(self, x):
        """Return ``(mu, log_var)`` for a batch of flattened inputs."""
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return self.fc31(h), self.fc32(h)

    def sampling(self, mu, log_var):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)`` (reparameterisation)."""
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def decoder(self, z):
        """Map latent codes to flattened reconstructions in [0, 1]."""
        h = F.relu(self.fc4(z))
        h = F.relu(self.fc5(h))
        return torch.sigmoid(self.fc6(h))

    def forward(self, x):
        """Encode, sample and decode.

        Args:
            x: input batch; it is reshaped to ``(-1, x_dim)``, so any layout
                whose per-sample element count equals ``x_dim`` is accepted.

        Returns:
            Tuple ``(reconstruction, mu, log_var)``.
        """
        mu, log_var = self.encoder(x.view(-1, self.x_dim))
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var

# Instantiate the VAE for flattened 784-feature inputs with a 2-D latent
# space, and place it on the GPU when one is available.
vae = VAE(x_dim=784, h_dim1=512, h_dim2=256, z_dim=2)
vae = vae.cuda() if torch.cuda.is_available() else vae

In [5]:
# Echo the model so the notebook shows its layer layout.
vae

VAE(
  (fc1): Linear(in_features=784, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=256, bias=True)
  (fc31): Linear(in_features=256, out_features=2, bias=True)
  (fc32): Linear(in_features=256, out_features=2, bias=True)
  (fc4): Linear(in_features=2, out_features=256, bias=True)
  (fc5): Linear(in_features=256, out_features=512, bias=True)
  (fc6): Linear(in_features=512, out_features=784, bias=True)
)

In [6]:
# Echo the dataset behind the training loader as a quick sanity check.
dataloader.dataset

Dataset MNIST
    Number of datapoints: 60000
    Root location: data/mnist
    Split: Train
    StandardTransform
Transform: ToTensor()

In [7]:
# Adam over every VAE parameter; no hyperparameters are passed, so the
# optimizer's own defaults apply.
optimizer = optim.Adam(vae.parameters())
def loss_function(recon_x, x, mu, log_var):
    """Return reconstruction error plus KL divergence, summed over the batch.

    Args:
        recon_x: decoder output with values in [0, 1]; its last dimension is
            the flattened sample size.
        x: target batch; reshaped to ``(-1, recon_x.size(-1))``.
        mu: latent means produced by the encoder.
        log_var: latent log-variances produced by the encoder.

    Returns:
        Scalar tensor: binary cross-entropy (``reduction='sum'``) plus the KL
        term ``-0.5 * sum(1 + log_var - mu^2 - exp(log_var))``.
    """
    # Flatten the targets to the reconstruction's width rather than a
    # hard-coded 784, so the loss works for any input size.
    target = x.view(-1, recon_x.size(-1))
    BCE = F.binary_cross_entropy(recon_x, target, reduction='sum')
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE + KLD

In [8]:
def train(epoch, dataloader=dataloader, testloader=testloader, sample_dir=None, show_test_loss=False):
    """Run one training epoch of the global ``vae`` and report the loss.

    Uses the module-level ``vae``, ``optimizer`` and ``loss_function``.

    Args:
        epoch: epoch number; used in the log line and the sample file name.
        dataloader: loader yielding ``(data, label)`` training batches; labels
            are ignored. The default is bound when this function is defined.
        testloader: loader used for the optional test-set evaluation.
        sample_dir: if not ``None``, a 5x5 grid of images decoded from random
            latent codes is written to ``<sample_dir>/<epoch>.png``; the
            directory is created when missing.
        show_test_loss: if true, also evaluate on ``testloader`` and print
            both per-sample losses; otherwise print only the training loss.
    """
    import os  # local import: only needed to create the sample directory

    # Follow the model's device instead of calling .cuda() unconditionally,
    # so the function also works on CPU-only machines.
    device = next(vae.parameters()).device

    vae.train()
    train_loss = 0
    for data, _ in dataloader:
        data = data.to(device)
        optimizer.zero_grad()

        recon_batch, mu, log_var = vae(data)
        loss = loss_function(recon_batch, data, mu, log_var)

        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    if sample_dir is not None:
        # save_image does not create missing directories.
        os.makedirs(sample_dir, exist_ok=True)
        with torch.no_grad():
            # Latent size is read from the decoder's first layer rather than
            # hard-coded to 2.
            z = torch.randn(25, vae.fc4.in_features).to(device)
            sample = vae.decoder(z)
            save_image(sample.view(25, 1, 28, 28), "%s/%02d.png" % (sample_dir, epoch), nrow=5, normalize=True)
    if show_test_loss:
        vae.eval()
        test_loss = 0
        with torch.no_grad():
            for data, _ in testloader:
                data = data.to(device)
                recon, mu, log_var = vae(data)
                test_loss += loss_function(recon, data, mu, log_var).item()
        print('\tepoch %2d\ttrain set loss %.4f\ttest set loss %.4f' % (epoch, train_loss / len(dataloader.dataset), test_loss / len(testloader.dataset)))
    else:
        print('\tepoch %d\taverage loss %.4f' % (epoch, train_loss / len(dataloader.dataset)))

In [9]:
# Train on MNIST for 50 epochs (numbered from 1), writing a sample grid and
# printing train/test loss after each epoch.
NUM_EPOCHS = 50
for epoch in range(1, NUM_EPOCHS + 1):
    train(
        epoch,
        sample_dir='data/images/vae',
        show_test_loss=True,
    )

	epoch  1	train set loss 174.8316	test set loss 159.1139
	epoch  2	train set loss 155.3211	test set loss 153.4206
	epoch  3	train set loss 150.5982	test set loss 149.1550
	epoch  4	train set loss 147.9355	test set loss 147.7394
	epoch  5	train set loss 145.8305	test set loss 146.3652
	epoch  6	train set loss 144.8009	test set loss 145.1928
	epoch  7	train set loss 143.6917	test set loss 144.0797
	epoch  8	train set loss 142.8939	test set loss 143.3491
	epoch  9	train set loss 142.2042	test set loss 143.0468
	epoch 10	train set loss 141.6058	test set loss 142.4549
	epoch 11	train set loss 141.0041	test set loss 141.6941
	epoch 12	train set loss 140.5090	test set loss 141.3108
	epoch 13	train set loss 140.1729	test set loss 141.6779
	epoch 14	train set loss 139.7718	test set loss 141.1683
	epoch 15	train set loss 139.3873	test set loss 140.2652
	epoch 16	train set loss 139.0587	test set loss 140.6277
	epoch 17	train set loss 138.8009	test set loss 140.1999
	epoch 18	train set loss 138.54

![Url](results_vae.gif "VAE")

In [10]:
# Fashion-MNIST loaders, built with the same batch size and tensor transform
# as the MNIST loaders.
fashion_train = datasets.FashionMNIST(
    "data/mnist", train=True, download=True, transform=transforms.ToTensor()
)
fashionloader = torch.utils.data.DataLoader(
    fashion_train, batch_size=batch_size, shuffle=True
)

# Test split, iterated in a fixed order.
fashion_test = datasets.FashionMNIST(
    "data/mnist", train=False, download=True, transform=transforms.ToTensor()
)
fashiontestloader = torch.utils.data.DataLoader(
    fashion_test, batch_size=batch_size, shuffle=False
)

In [11]:
# Train on Fashion-MNIST for 50 epochs (numbered from 1), writing a sample
# grid and printing train/test loss after each epoch.
# NOTE(review): train() operates on the global `vae` and `optimizer`, and
# nothing re-creates them before this cell, so this run appears to continue
# from the MNIST-trained weights - confirm that is intended.
NUM_EPOCHS = 50
for epoch in range(1, NUM_EPOCHS + 1):
    train(
        epoch,
        dataloader=fashionloader,
        testloader=fashiontestloader,
        sample_dir='data/images/vae_fashion',
        show_test_loss=True,
    )

	epoch  1	train set loss 298.5994	test set loss 281.1881
	epoch  2	train set loss 276.7529	test set loss 276.0567
	epoch  3	train set loss 273.2074	test set loss 272.7264
	epoch  4	train set loss 270.4568	test set loss 271.7459
	epoch  5	train set loss 268.7618	test set loss 269.4839
	epoch  6	train set loss 267.5545	test set loss 268.0630
	epoch  7	train set loss 266.1332	test set loss 267.1256
	epoch  8	train set loss 264.9262	test set loss 266.5327
	epoch  9	train set loss 264.0914	test set loss 266.2383
	epoch 10	train set loss 263.4620	test set loss 264.1531
	epoch 11	train set loss 262.8270	test set loss 263.8486
	epoch 12	train set loss 262.4700	test set loss 263.9861
	epoch 13	train set loss 262.3701	test set loss 263.9473
	epoch 14	train set loss 261.9632	test set loss 262.9336
	epoch 15	train set loss 261.5779	test set loss 263.0345
	epoch 16	train set loss 261.4053	test set loss 263.5711
	epoch 17	train set loss 261.1775	test set loss 262.6611
	epoch 18	train set loss 260.57

![Url](results_vae_fashion.gif "VAE Fashion")