#!/usr/bin/python3
'''Plot two 2D Gaussian point clouds and the linear discriminant between them.

The discriminant direction is computed as w = Sw^{-1} (m2 - m1), where Sw is
the within-class scatter matrix and m1, m2 are the class means. The resulting
figure is written as an SVG next to this script (see main()).
'''

import functools as ft
import os

import matplotlib.pyplot as plt
import numpy as np


def create_dataset(point_count=10, mean=(0, 0), variance=1, seed=0):
    '''Create a 2D data cluster with specified mean and variance

    @param {int} point_count number of points
    @param {sequence} mean mean (x, y) of the cluster
    @param {float} variance value placed on the diagonal of the covariance
    @param {int} seed random seed number for reproducible results

    @return {list} List of np.ndarray points, each holding (x, y) as float64
    '''
    # A private RandomState keeps the result reproducible without touching
    # numpy's global random state.
    rstate = np.random.RandomState(seed)
    cov = np.identity(2) * variance
    samples = rstate.multivariate_normal(mean, cov, point_count)
    return [np.array(point, np.float64) for point in samples]


def plot_data_points(ax, data_sets):
    '''Plot every data set as a scatter of round markers

    @param {matplotlib.Axis} ax axis object to plot on
    @param {list} data_sets list of data sets, each a list of (x, y) points

    @return {None}
    '''
    # Default colors used for plotting data points
    COLORS = ['#1f77b4', '#ff7f03', '#2ca02c',
              '#d62728', '#9467bd', '#8c564b']
    for class_id, points in enumerate(data_sets):
        x, y = zip(*points)
        # Cycle through the palette so that more data sets than colors do
        # not raise an IndexError.
        ax.plot(x, y, marker='o', linestyle='', markersize=4,
                label='Set {}'.format(class_id+1),
                color=COLORS[class_id % len(COLORS)])


def sw_set(data_set, mean):
    '''Return the Sw matrix for a single set

    Sw is the sum over all points p of the outer product
    (p - mean)(p - mean)^T.

    @param {list} data_set list of (x, y) points as np.ndarray
    @param {np.ndarray} mean mean of the dataset

    @return {np.ndarray} 2x2 Sw matrix
    '''
    Sw = np.zeros((2, 2))
    for p in data_set:
        diff = p - mean
        Sw += np.outer(diff, diff)
    return Sw


def data_mean_cov_set(dataset):
    '''Return the mean and the scatter matrix of a dataset

    NOTE: the second return value is the plain sum of outer products of the
    centered points. It is NOT divided by the number of points, so it equals
    the covariance only up to that normalisation factor.

    @param {list} dataset list of (x, y) points as np.ndarray

    @return {np.ndarray} mean of the dataset
            {np.ndarray} 2x2 scatter (unnormalised covariance) matrix
    '''
    N = len(dataset)
    mean = ft.reduce(np.add, dataset)
    mean = 1.0/N * mean
    # Same accumulation as the per-set Sw matrix, so reuse it.
    return mean, sw_set(dataset, mean)


def sb_set(means):
    '''Return the Sb matrix for two sets

    @param {list} means the two dataset means as np.ndarray

    @return {np.ndarray} Sb matrix, the outer product of the mean difference
                         (means[1] - means[0]) with itself
    '''
    diff = means[1] - means[0]
    return np.outer(diff, diff)


def Sb_Sw(datasets, means):
    '''Return the Sb and Sw matrix for two sets

    Only the first two entries of datasets and means are used.

    @param {list} datasets list of data sets
    @param {list} means means of the datasets

    @return {np.ndarray} Sb matrix
            {np.ndarray} Sw matrix (sum of the per-set Sw matrices)
    '''
    set1, set2 = datasets[0], datasets[1]
    Sw = sw_set(set1, means[0]) + sw_set(set2, means[1])
    Sb = sb_set(means)
    return Sb, Sw


def zero_line(x, w, w0):
    '''Return the y values corresponding to a linear discriminant

    The discriminant is defined by

        y = w^{T} * x + w0 = w_0 * x_0 + w_1 * x_1 + w0 = 0

    x_1 is computed as

        x_1 = -1/w_1 * ( w_0 * x_0 + w0 )

    @param {np.ndarray} x values for which to compute the corresponding
                        y-value
    @param {list} w array of linear discriminant weights; w[1] must be
                    non-zero because the result is divided by it
    @param {float} w0 bias (offset) term of the discriminant

    @return {np.ndarray} y-values of the linear discriminant
    '''
    return -1/w[1] * (w[0] * x + w0)


def set_extent(dataset):
    '''Return the bounding interval of a dataset along each axis

    @param {list} dataset non-empty list of (x, y) points

    @return {list} [x_min, x_max]
            {list} [y_min, y_max]
    '''
    points = np.asarray(dataset)
    xs, ys = points[:, 0], points[:, 1]
    return [xs.min(), xs.max()], [ys.min(), ys.max()]


def sets_extent(datasets):
    '''Return the bounding interval over several datasets along each axis

    @param {list} datasets non-empty list of data sets

    @return {list} [x_min, x_max] over all data sets
            {list} [y_min, y_max] over all data sets
    '''
    x_range, y_range = set_extent(datasets[0])
    # The first set already seeded the ranges; only merge the remaining ones.
    for dset in datasets[1:]:
        [x0, x1], [y0, y1] = set_extent(dset)
        x_range[0] = np.minimum(x_range[0], x0)
        x_range[1] = np.maximum(x_range[1], x1)
        y_range[0] = np.minimum(y_range[0], y0)
        y_range[1] = np.maximum(y_range[1], y1)
    return x_range, y_range


def create_plot(seed1, seed2):
    '''Create the figure with two point clouds and their discriminant line

    Two clusters of 30 points each are drawn around (+1, +1) and (-1, -1)
    with unit variance.

    @param {int} seed1 random seed for the first cluster
    @param {int} seed2 random seed for the second cluster

    @return {matplotlib.figure.Figure} the created figure
            {matplotlib.Axis} the axis holding the plot
    '''
    fig, ax = plt.subplots()

    set1 = create_dataset(30, mean=[+1, +1], variance=1, seed=seed1)
    set2 = create_dataset(30, mean=[-1, -1], variance=1, seed=seed2)
    data = [set1, set2]

    mean1, _ = data_mean_cov_set(set1)
    mean2, _ = data_mean_cov_set(set2)
    means = [mean1, mean2]

    # Only the x extent is needed to sample the discriminant line.
    x_range, _ = sets_extent(data)
    x = np.linspace(x_range[0], x_range[1], 5)

    _, sw = Sb_Sw(data, means)

    N1, N2 = len(data[0]), len(data[1])
    N = N1 + N2
    # Mean over all points of both sets (size-weighted mean of the means).
    mean_total = 1.0/N*(N1 * means[0] + N2 * means[1])

    # Discriminant direction w = Sw^{-1} (m2 - m1).
    w = np.linalg.inv(sw).dot(means[1] - means[0])
    # Choose the bias so that the line w^T x + w0 = 0 passes through the
    # overall mean.
    w0 = - mean_total.dot(w)

    ax.plot(x, zero_line(x, w, w0), color='red')
    plot_data_points(ax, data)

    ax.legend(loc='best', fancybox=True, framealpha=0.5, fontsize='medium')
    ax.grid(True)
    return fig, ax


def main():
    '''Render the plot and save it as an SVG relative to this script'''
    fig, ax = create_plot(4, 8)

    script_dir = os.path.dirname(os.path.realpath(__file__))
    fig.patch.set_alpha(0.0)

    fn = '../img/classification-least-squares.svg'
    fn = os.path.join(script_dir, fn)
    # Make sure the target directory exists so saving does not fail on a
    # fresh checkout.
    os.makedirs(os.path.dirname(fn), exist_ok=True)

    # NOTE(review): bbox_inches=0 is not one of the documented values
    # ('tight', a Bbox, or None); it is kept unchanged here - confirm whether
    # 'tight' or the default was intended.
    fig.savefig(fn, facecolor=fig.get_facecolor(), edgecolor='none',
                bbox_inches=0)


if __name__ == '__main__':
    main()