In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import json
import numpy as np
import scipy as sp
import scipy.stats as st
import scipy.integrate as integrate
import sklearn
from sklearn import linear_model

# Global seaborn styling applied to the figures drawn below.
sns.set_style("whitegrid")
sns.set_palette("colorblind")
palette = sns.color_palette()

# Figure size and legend font size constants (not referenced in the visible cells).
figsize = (10,10)
legend_fontsize = 16

# Seed NumPy's global RNG: the notebook draws centers, samples and a shuffle
# order from `np.random`, so without a fixed seed every "Restart & Run All"
# produces different data and different figures.
SEED = 42
np.random.seed(SEED)

# Классификация

## Сгенерируем данные

In [None]:

# Evaluation grid over [-2, 2) x [-2, 2) with step 0.01; `pos[i, j]` holds the
# (x, y) coordinates of grid node (i, j).
x, y = np.mgrid[-2:2:.01, -2:2:.01]
pos = np.stack((x, y), axis=-1)

# Ten mixture centers per class, drawn around (-1, -1) and (1, 1) with
# covariance 1.5 * I.
centers_1 = np.random.multivariate_normal(mean=[-1, -1], cov=1.5 * np.eye(2), size=10)
centers_2 = np.random.multivariate_normal(mean=[1, 1], cov=1.5 * np.eye(2), size=10)

def mn(x, c, mat):
    """Evaluate the multivariate normal density with mean `c` and covariance `mat` at `x`."""
    gaussian = sp.stats.multivariate_normal(mean=c, cov=mat)
    return gaussian.pdf(x)

def sample_point(centers, cov=.1 * np.identity(2)):
    """Draw one point from an equal-weight Gaussian mixture.

    A center is picked uniformly at random from `centers`, then a single
    sample is drawn from the normal distribution with that center as mean
    and `cov` as covariance. Uses NumPy's global random state.
    """
    chosen = centers[np.random.randint(len(centers))]
    return np.random.multivariate_normal(chosen, cov)

def my_density(pos, centers, cov=None):
    """Density of an equal-weight Gaussian mixture evaluated at `pos`.

    Parameters
    ----------
    pos : array_like
        Point(s) at which to evaluate the density; forwarded unchanged to `mn`.
    centers : sequence
        Means of the mixture components, one component per entry.
    cov : array_like, optional
        Covariance shared by all components. Defaults to ``0.1 * I`` (2x2),
        the value that used to be hard-coded here.

    Returns
    -------
    The average of the component densities at `pos`.
    """
    if cov is None:
        cov = .1 * np.identity(2)
    component_pdfs = [mn(pos, c, cov) for c in centers]
    return np.sum(component_pdfs, axis=0) / float(len(centers))

def my_density_ratio(pos, centers_1, centers_2):
    """Share of the first mixture's density in the sum of both densities at `pos`.

    Returns ``d1 / (d1 + d2)`` where ``d1`` and ``d2`` are `my_density`
    evaluated with `centers_1` and `centers_2`; the value is 0.5 exactly where
    the two densities are equal.
    """
    # Evaluate each mixture once; the first density appears in both the
    # numerator and the denominator.
    density_1 = my_density(pos, centers_1)
    density_2 = my_density(pos, centers_2)
    return density_1 / (density_1 + density_2)

# Per-class mixture densities and the class-1 density share on the grid.
z_1 = my_density(pos, centers_1)
z_2 = my_density(pos, centers_2)
z = my_density_ratio(pos, centers_1, centers_2)

# Number of filled contour levels; also read by plot_model_results.
num_levels = 50
# `ls` and `transparent` are not contour keywords (matplotlib only warns that
# they were unused), so they are dropped.
plt.contourf(x, y, z, num_levels, cmap='Reds')
# The 0.5 level is where both class densities are equal. The keyword is
# `colors`; the singular `color` is not a contour keyword, so the line was
# never drawn in black.
plt.contour(x, y, z, levels=[0.5], colors='black')
plt.show()

In [None]:
# Training set: 100 draws per class from the corresponding mixture, with
# component covariance 0.2 * I passed explicitly to sample_point.
points_1 = np.vstack([sample_point(centers_1, cov=.2 * np.eye(2)) for _draw in range(100)])
points_2 = np.vstack([sample_point(centers_2, cov=.2 * np.eye(2)) for _draw in range(100)])

def plot_twopoints(ax, d1, d2, sizes=[15,25], markers=[ 'o', '*' ], colors=['0.3', '0.3']):
    """Scatter two point sets on `ax`, each with its own size, marker and color.

    `d1` is drawn with the first entry of `sizes` / `markers` / `colors` and
    `d2` with the second. Column 0 of each set is used as x, column 1 as y.
    """
    for k, cloud in enumerate((d1, d2)):
        xs, ys = cloud[:, 0], cloud[:, 1]
        ax.scatter(xs, ys, s=sizes[k], marker=markers[k], color=colors[k])

# Show the training set: mixture centers as large stars, sampled points as
# small circles; the first class in dark blue, the second in red.
fig, ax = plt.subplots()
plot_twopoints(
    ax, centers_1, centers_2,
    sizes=[200] * 2, markers=['*'] * 2, colors=['darkblue', 'r'],
)
plot_twopoints(
    ax, points_1, points_2,
    sizes=[25] * 2, markers=['o'] * 2, colors=['darkblue', 'r'],
)
plt.show()

In [None]:
## Plot model results
def plot_model_results(m, c1=centers_1, c2=centers_2, p1=points_1, p2=points_2):
    """Plot a fitted classifier's decision function together with the data.

    Parameters
    ----------
    m : estimator
        Fitted model exposing ``decision_function`` for 2-D inputs.
    c1, c2 : array_like
        Centers of the two classes, drawn as large stars. The defaults are
        the module-level arrays as they were when this function was defined.
    p1, p2 : array_like
        Sampled points of the two classes, drawn as circles. Their joint
        bounding box defines the plotted region.
    """
    # Stack the two clouds row-wise. `p1 + p2` on NumPy arrays is an
    # element-wise sum, which gave the bounding box of the sums instead of
    # the bounding box of the data.
    xdata = np.vstack([p1, p2])
    x_min, x_max = xdata[:, 0].min(), xdata[:, 0].max()
    y_min, y_max = xdata[:, 1].min(), xdata[:, 1].max()

    # Grid over the bounding box, flattened to (n_points, 2) for the model
    # and reshaped back to the grid for contouring.
    x, y = np.mgrid[x_min:x_max:.01, y_min:y_max:.01]
    pos = np.stack((x, y), axis=-1)
    pred = m.decision_function(pos.reshape(-1, 2)).reshape(x.shape)

    fig, ax = plt.subplots()
    ax.set_xlim((x_min, x_max))
    ax.set_ylim((y_min, y_max))
    ax.contourf(x, y, pred, num_levels, cmap='Reds', alpha=0.5)
    # The decision boundary is the zero level of decision_function (for a
    # linear model: intercept + coef . x = 0), not the 0.5 level.
    ax.contour(x, y, pred, levels=[0])
    plot_twopoints(ax, c1, c2, sizes=[400,400], markers=['*', '*'], colors=['darkblue', 'r'])
    plot_twopoints(ax, p1, p2, sizes=[50,50], markers=['o', 'o'], colors=['darkblue', 'r'])
    plt.show()
    

## LDA и QDA

In [None]:
# Pair every point with its class label (0 for points_1, 1 for points_2) and
# shuffle the pairs in place so the two classes are interleaved.
x_and_y = [(point, 0) for point in points_1]
x_and_y += [(point, 1) for point in points_2]
np.random.shuffle(x_and_y)

# Split the shuffled pairs back into a feature matrix and a label vector.
data_X = np.array([pair[0] for pair in x_and_y])
data_y = np.array([pair[1] for pair in x_and_y])

# Bounding box of the features along each coordinate.
x_min, x_max = data_X[:, 0].min(), data_X[:, 0].max()
y_min, y_max = data_X[:, 1].min(), data_X[:, 1].max()

from sklearn import linear_model
import sklearn

In [None]:
from sklearn import discriminant_analysis

# Fit linear discriminant analysis on the shuffled data and plot its
# decision function.
m = discriminant_analysis.LinearDiscriminantAnalysis()
m.fit(X=data_X, y=data_y)
plot_model_results(m)

In [None]:
# Same data, quadratic discriminant analysis.
m = discriminant_analysis.QuadraticDiscriminantAnalysis()
m.fit(X=data_X, y=data_y)
plot_model_results(m)

## Логистическая регрессия

In [None]:
# Fit logistic regression with default settings, plot its decision function
# and print the fitted weights and intercept.
m = linear_model.LogisticRegression()
m.fit(X=data_X, y=data_y)
plot_model_results(m)
print(m.coef_, m.intercept_)