In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import json
import math
import os
import random
import numpy as np
import scipy as sp
import scipy.stats as st
import scipy.integrate as integrate
from scipy.stats import multivariate_normal
from sklearn import linear_model
from sklearn.exceptions import ConvergenceWarning
import statsmodels.api as sm
from matplotlib.colors import LogNorm
import pickle

from joblib import Parallel, delayed
import multiprocessing
from collections import namedtuple
from itertools import count

import cProfile
from datetime import datetime

# Global plotting style: seaborn theme, shared figure constants, and LaTeX text rendering.
sns.set_style("whitegrid")
sns.set_palette("colorblind")
palette = sns.color_palette()
figsize = (15,8)
legend_fontsize = 16

from matplotlib import rc
rc('font', **{'family': 'sans-serif'})
rc('text', usetex=True)
# `text.latex.preamble` is a single string: a second rc() call replaces (does not
# append to) the first, which previously dropped inputenc silently. Load both
# packages in one preamble (they are the ones needed for Russian/UTF-8 text).
rc('text.latex', preamble=r'\usepackage[utf8]{inputenc}\usepackage[russian]{babel}')
rc('figure', **{'dpi': 300})

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.autograd import Variable
from torchvision.utils import save_image

import torchvision.datasets as datasets
from torchvision.utils import make_grid

## VAE

In [None]:
## Load MNIST (train and test splits); ToTensor yields float images in [0, 1].
image_size = 28
image_shape = (1, image_size, image_size)

# Use the GPU when one is available.
cuda = torch.cuda.is_available()
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

batch_size = 64
# Both splits share the same loader settings; only the training split is shuffled.
dataloader, testloader = (
    torch.utils.data.DataLoader(
        datasets.MNIST(
            "data/mnist",
            train=is_train,
            download=True,
            transform=transforms.ToTensor(),
        ),
        batch_size=batch_size,
        shuffle=is_train,
    )
    for is_train in (True, False)
)

In [None]:
class VAE(nn.Module):
    """Fully-connected variational autoencoder with a Gaussian latent space.

    Encoder: x_dim -> h_dim1 -> h_dim2 -> (mu, log_var), each of size z_dim.
    Decoder: z_dim -> h_dim2 -> h_dim1 -> x_dim, followed by a sigmoid so the
    reconstruction lies in (0, 1) and can be scored with binary cross-entropy.

    Args:
        x_dim: number of input features after flattening (784 for 28x28 MNIST).
        h_dim1: width of the outer hidden layer.
        h_dim2: width of the inner hidden layer.
        z_dim: dimensionality of the latent code.
    """

    def __init__(self, x_dim, h_dim1, h_dim2, z_dim):
        super().__init__()
        # Remember the flattened input size so forward() does not hard-code 784.
        self.x_dim = x_dim

        # encoder
        self.fc1 = nn.Linear(x_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc31 = nn.Linear(h_dim2, z_dim)  # mean head
        self.fc32 = nn.Linear(h_dim2, z_dim)  # log-variance head
        # decoder
        self.fc4 = nn.Linear(z_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)

    def encoder(self, x):
        """Map flattened inputs (batch, x_dim) to (mu, log_var), each (batch, z_dim)."""
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return self.fc31(h), self.fc32(h)

    def sampling(self, mu, log_var):
        """Reparameterisation trick: z = mu + eps * exp(log_var / 2), eps ~ N(0, I)."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decoder(self, z):
        """Map latent codes (batch, z_dim) to reconstructions (batch, x_dim) in (0, 1)."""
        h = F.relu(self.fc4(z))
        h = F.relu(self.fc5(h))
        return torch.sigmoid(self.fc6(h))

    def forward(self, x):
        """Encode, sample a latent code, decode.

        Args:
            x: input batch with x_dim elements per example, e.g. (batch, 1, 28, 28);
                it is flattened to (batch, x_dim).

        Returns:
            (reconstruction, mu, log_var).
        """
        mu, log_var = self.encoder(x.reshape(-1, self.x_dim))
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var

# Build the model: 784 flattened pixels -> 512 -> 256 -> 2-D latent code,
# placed on the GPU when one is available (Module.cuda() returns the module).
vae = VAE(x_dim=784, h_dim1=512, h_dim2=256, z_dim=2)
if torch.cuda.is_available():
    vae = vae.cuda()

In [None]:
# Display the model's layer structure (its nn.Module repr).
vae

In [None]:
# Display the MNIST training dataset object (its repr).
dataloader.dataset

In [None]:
# Adam over every VAE parameter, using the optimizer's default hyperparameters.
optimizer = optim.Adam(params=vae.parameters())
def loss_function(recon_x, x, mu, log_var):
    """Negative ELBO: reconstruction BCE plus KL divergence to a standard-normal prior.

    Args:
        recon_x: reconstruction probabilities in [0, 1], e.g. the (batch, 784)
            decoder output of VAE.forward.
        x: target batch with the same number of elements as `recon_x`
            (e.g. (batch, 1, 28, 28) images); it is reshaped to `recon_x.shape`.
        mu, log_var: encoder outputs parameterising q(z|x) = N(mu, exp(log_var)).

    Returns:
        Scalar tensor: the BCE and KL terms, each summed (not averaged) over the batch.
    """
    # Match the target to the reconstruction instead of hard-coding 784, so the
    # loss works for whatever x_dim the model was built with.
    bce = F.binary_cross_entropy(recon_x, x.reshape(recon_x.shape), reduction='sum')
    # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over latent dims and batch.
    kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return bce + kld

In [None]:
def _train_one_epoch(loader, device):
    """Do one pass of gradient updates of the global `vae` over `loader`; return the summed loss."""
    vae.train()
    total = 0.0
    for data, _ in loader:
        data = data.to(device)
        optimizer.zero_grad()

        recon_batch, mu, log_var = vae(data)
        loss = loss_function(recon_batch, data, mu, log_var)

        loss.backward()
        total += loss.item()
        optimizer.step()
    return total


def _save_samples(sample_dir, epoch, device, n_samples=25, nrow=5):
    """Decode `n_samples` latent vectors z ~ N(0, I) and save them as one grid PNG."""
    # save_image does not create directories; create it so a fresh checkout works.
    os.makedirs(sample_dir, exist_ok=True)
    # Latent size the decoder expects (first decoder layer's input width).
    z_dim = vae.fc4.in_features
    with torch.no_grad():
        # Draw on the CPU RNG and then move, matching the original sampling stream.
        z = torch.randn(n_samples, z_dim).to(device)
        sample = vae.decoder(z)
        save_image(sample.view(n_samples, 1, 28, 28), "%s/%02d.png" % (sample_dir, epoch), nrow=nrow, normalize=True)


def _evaluate(loader, device):
    """Return the summed loss of the global `vae` over `loader` without updating weights."""
    vae.eval()
    total = 0.0
    with torch.no_grad():
        for data, _ in loader:
            data = data.to(device)
            recon, mu, log_var = vae(data)
            total += loss_function(recon, data, mu, log_var).item()
    return total


def train(epoch, dataloader=dataloader, testloader=testloader, sample_dir=None, show_test_loss=False):
    """Run one training epoch of the global `vae` with the global `optimizer`.

    Args:
        epoch: epoch number; used in the progress line and the sample file name.
        dataloader: training batches of (images, labels); labels are ignored.
        testloader: batches used for the held-out loss when `show_test_loss` is True.
        sample_dir: if given, decode 25 latent vectors drawn from N(0, I) and save
            them as a 5x5 grid to "<sample_dir>/<epoch as %02d>.png"; the directory
            is created if it does not exist.
        show_test_loss: if True, also evaluate and report the loss on `testloader`.

    Prints the per-example average training loss (and test loss if requested).
    """
    # Follow the model's device rather than hard-coding .cuda(), so the notebook
    # also runs on CPU-only machines.
    device = next(vae.parameters()).device

    train_loss = _train_one_epoch(dataloader, device)

    if sample_dir is not None:
        _save_samples(sample_dir, epoch, device)

    if show_test_loss:
        test_loss = _evaluate(testloader, device)
        print('\tepoch %2d\ttrain set loss %.4f\ttest set loss %.4f' % (epoch, train_loss / len(dataloader.dataset), test_loss / len(testloader.dataset)))
    else:
        print('\tepoch %d\taverage loss %.4f' % (epoch, train_loss / len(dataloader.dataset)))

In [None]:
# Train on MNIST for 50 epochs (numbered 1..50); each epoch saves a 5x5 sample
# grid to data/images/vae/ and prints both train and test loss.
for epoch in range(1, 50 + 1):
    train(
        epoch,
        show_test_loss=True,
        sample_dir='data/images/vae',
    )

![VAE samples generated during MNIST training](results_vae.gif "VAE")

In [None]:
# FashionMNIST train/test loaders with the same settings as the MNIST ones;
# only the training split is shuffled.
fashionloader, fashiontestloader = (
    torch.utils.data.DataLoader(
        datasets.FashionMNIST(
            "data/mnist",
            train=is_train,
            download=True,
            transform=transforms.ToTensor(),
        ),
        batch_size=batch_size,
        shuffle=is_train,
    )
    for is_train in (True, False)
)

In [None]:
# Train a fresh VAE on FashionMNIST. `train` uses the global `vae` and `optimizer`,
# so without re-creating them this cell would keep fine-tuning the MNIST-trained
# network, and its results would depend on the MNIST cell having run first.
# Keep the MNIST model and optimizer under separate names so they are not lost.
vae_mnist, optimizer_mnist = vae, optimizer
vae = VAE(x_dim=784, h_dim1=512, h_dim2=256, z_dim=2)
if torch.cuda.is_available():
    vae.cuda()
optimizer = optim.Adam(vae.parameters())

for epoch in range(1, 51):
    train(epoch, dataloader=fashionloader, testloader=fashiontestloader, sample_dir='data/images/vae_fashion', show_test_loss=True)

![VAE samples generated during FashionMNIST training](results_vae_fashion.gif "VAE Fashion")

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import glob

In [None]:
# Show the sample grid saved after each MNIST training epoch, in epoch order.
for fname in sorted(glob.glob('data/images/vae/*.png')):
    # Print the path instead of using it as a title: with usetex=True, LaTeX
    # rejects characters such as '_' that can appear in file paths.
    print(fname)
    fig, ax = plt.subplots()
    ax.imshow(mpimg.imread(fname))
    ax.axis('off')  # pixel-index ticks carry no information for a sample grid
    plt.show()

In [None]:
# Show the sample grid saved after each FashionMNIST training epoch, in epoch order.
for fname in sorted(glob.glob('data/images/vae_fashion/*.png')):
    # Print the path instead of using it as a title: with usetex=True, LaTeX
    # would reject the '_' in "vae_fashion".
    print(fname)
    fig, ax = plt.subplots()
    ax.imshow(mpimg.imread(fname))
    ax.axis('off')  # pixel-index ticks carry no information for a sample grid
    plt.show()