#!/usr/bin/python3
'''Fit a 2D Gaussian mixture model with Expectation-Maximization (EM)
and render the EM iterations as an animation.
'''
import os

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation


def create_dataset(point_count, mean=None, variance=1, seed=0):
    '''Create a 2D data cluster with specified mean and variance

    @param {int}   point_count number of points
    @param {list}  mean mean of the dataset, defaults to the origin [0, 0]
    @param {float} variance diagonal variance value
    @param {int}   seed random seed number for reproducible results
    @return {np.ndarray} x coordinates of the generated points
            {np.ndarray} y coordinates of the generated points
    '''
    # None sentinel instead of a mutable list default shared across calls
    if mean is None:
        mean = [0, 0]

    rstate = np.random.RandomState(seed)
    cov = np.identity(2) * variance
    x, y = rstate.multivariate_normal(mean, cov, point_count).T
    return x, y


def create_clusters(counts, means, seeds):
    '''Create random and reproducible set of clusters with `counts` points each

    Every cluster is drawn with unit variance.

    @param {int}  counts number of points per cluster
    @param {list} means list of cluster means
    @param {list} seeds list of seeds, one per cluster
    @return {list} list of (x, y) coordinate vector pairs, one per cluster
    '''
    return [create_dataset(counts, mean, 1, seed)
            for mean, seed in zip(means, seeds)]


def gaussian(x, mu, sigma2):
    '''Gaussian probability density function

    A scalar `sigma2` selects the univariate density (evaluated elementwise
    when `x` is an array); otherwise `sigma2` is treated as a DxD covariance
    matrix and the multivariate density of the D-vector `x` is returned.

    @param {float|np.ndarray} x gaussian evaluation location
    @param {float|np.ndarray} mu gaussian mean
    @param {float|np.ndarray} sigma2 gaussian variance or covariance matrix
    @return {float|np.ndarray} gaussian value at x
    '''
    # np.ndim also recognises numpy scalar types, unlike isinstance(int/float)
    if np.ndim(sigma2) == 0:
        coeff = 1.0 / np.sqrt(2 * np.pi * sigma2)
        return coeff * np.exp(-(x - mu)**2 / (2 * sigma2))

    sigma2 = np.asarray(sigma2, dtype=float)
    diff = np.asarray(x, dtype=float) - np.asarray(mu, dtype=float)

    # Dimension is the side length of the covariance matrix, not its
    # element count
    D = sigma2.shape[0]
    det_sigma = np.linalg.det(sigma2)

    # Normalisation: 1 / ((2*pi)^(D/2) * |Sigma|^(1/2))
    coeff = 1.0 / ((2 * np.pi)**(D / 2.0) * np.sqrt(det_sigma))

    # Mahalanobis term diff^T Sigma^-1 diff without forming the inverse
    maha = np.dot(diff, np.linalg.solve(sigma2, diff))
    return coeff * np.exp(-0.5 * maha)


def flatten_clusters(clusters):
    '''Return an iterator of each point in the dataset

    @param {list} clusters list of [x,y] coordinates for each cluster.
                  Treated as incomplete dataset
    @return {iterator} (x, y) coordinates of points in dataset
    '''
    for dset in clusters:
        yield from zip(*dset)


def cluster_point_count(clusters):
    '''Return the total number of points in dataset

    @param {list} clusters list of [x,y] coordinates for each cluster.
                  Treated as incomplete dataset
    @return {int} total number of points in all clusters
    '''
    return sum(len(c[0]) for c in clusters)


def e_step(clusters, pis, mus, sigmas):
    '''(E)xpectation step

    Return the expected value of the responsibilities

    @param {list} clusters list of [x,y] coordinates for each cluster.
                  Treated as incomplete dataset
    @param {list} pis mixture coefficients
    @param {list} mus gaussian means
    @param {list} sigmas gaussian variances
    @return {np.ndarray} NxK responsibilities, row n sums to one
    '''
    # Note: entry [n, k] is the expected value of the indicator z_nk
    #       (nth point belongs to class k)
    N = cluster_point_count(clusters)
    K = len(pis)
    ret = np.zeros((N, K))

    for n, x in enumerate(flatten_clusters(clusters)):
        point = np.array(x)
        # Evaluate each weighted component once; the normaliser is their sum
        weighted = np.array([pis[k] * gaussian(point, mus[k], sigmas[k])
                             for k in range(K)])
        ret[n] = weighted / weighted.sum()

    return ret


def m_step(clusters, pis, mus, sigmas, gammas):
    '''(M)aximization step

    Return mixing coefficients, means, variances

    @param {list} clusters list of [x,y] coordinates for each cluster.
                  Treated as incomplete dataset
    @param {list} pis mixture coefficients (only its length is used)
    @param {list} mus gaussian means (unused, kept for a uniform interface)
    @param {list} sigmas gaussian variances (unused, kept for a uniform
                  interface)
    @param {np.ndarray} gammas NxK responsibilities
    @return {list} updated mixing coefficients
            {list} updated means
            {list} updated variances
    '''
    K = len(pis)
    N = cluster_point_count(clusters)

    # Flatten the dataset once instead of once per component
    points = np.array(list(flatten_clusters(clusters)), dtype=float)

    pis_new, mus_new, sigmas_new = [], [], []
    for k in range(K):
        resp = gammas[:, k]
        Nk = np.sum(resp)

        # mixing coefficients
        pis_new.append(Nk / float(N))

        # means: responsibility-weighted average of the points
        muk = resp.dot(points) / Nk
        mus_new.append(muk)

        # variances: responsibility-weighted sum of outer products
        diff = points - muk
        sigmak = (resp[:, np.newaxis] * diff).T.dot(diff) / Nk
        sigmas_new.append(sigmak)

    return pis_new, mus_new, sigmas_new


def gauss_mixture_grid(xp, yp, pis, mus, sigmas):
    '''Evaluate the Gaussian Mixture model over a grid

    @param {np.ndarray} xp x values on a grid
    @param {np.ndarray} yp y values on a grid
    @param {list} pis mixture coefficients
    @param {list} mus gaussian means
    @param {list} sigmas gaussian variances
    @return {np.ndarray} 2D grid with gaussian mixture model values at [y, x]
    '''
    Z = np.zeros((len(yp), len(xp)))
    K = len(pis)

    for i, y in enumerate(yp):
        for j, x in enumerate(xp):
            point = np.array([x, y])
            Z[i, j] = sum(pis[k] * gaussian(point, mus[k], sigmas[k])
                          for k in range(K))

    return Z


def _remove_contour(cplot):
    '''Remove a filled contour plot from its axes

    Newer matplotlib exposes the contour set as a single artist with a
    `remove` method; older versions only expose the per-level artists
    through `collections`.

    @param {matplotlib.contour.ContourSet} cplot contour plot to remove
    '''
    if hasattr(cplot, 'remove'):
        cplot.remove()
    else:
        for col in cplot.collections:
            col.remove()


def update_plot(frame, history, x_plot, y_plot, fig_data):
    '''FuncAnimation plot update function

    Redraw the means and the mixture model contour of EM iteration `frame`

    @param {int} frame frame number being rendered
    @param {list} history.mus means of the clusters
           {list} history.sigmas variances of the clusters
           {list} history.pis mixture coefficients
    @param {np.1darray} x_plot x-range of values to plot mixture model
    @param {np.1darray} y_plot y-range of values to plot mixture model
    @param {matplotlib.axes} fig_data.ax plot axis object
           {[matplotlib.Line2D]} fig_data.means mean scatter plot objs
           {ContourSet} fig_data.cplot current contour plot (replaced here)
           {ScalarMappable} fig_data.cbar_map mappable behind the colorbar
    @return {list} empty list, no artists are handed back for blitting
    '''
    print('frame {} '.format(frame), end='\r')

    # fetch the model parameters recorded for this iteration
    row = history[frame]
    pis = np.array(row['pis'])
    mus = np.array(row['mus'])
    sigmas = np.array(row['sigmas'])

    ax = fig_data['ax']
    ax.set_title(r'Iteration {}'.format(frame))

    # Line2D.set_data expects sequences, so wrap each coordinate in a list
    for mu, marker in zip(mus, fig_data['means']):
        marker.set_data([mu[0]], [mu[1]])

    # Remove existing contour plot to speed up replotting
    if 'cplot' in fig_data:
        _remove_contour(fig_data['cplot'])

    X, Y = np.meshgrid(x_plot, y_plot)
    Z = gauss_mixture_grid(x_plot, y_plot, pis, mus, sigmas)
    fig_data['cplot'] = ax.contourf(X, Y, Z)
    fig_data['cbar_map'].set_clim(np.min(Z), np.max(Z))

    return []


def plot_em_iterations(data_sets, history):
    '''Plot Expectation Maximization iterations

    @param {list} data_sets list of [x,y] coordinates for each cluster.
                  Treated as incomplete dataset
    @param {list} history.mus means of the clusters
           {list} history.sigmas variances of the clusters
           {list} history.pis mixture coefficients
    @return {animation.FuncAnimation} animation object
            {animation.writers} writer object
    '''
    frame_count = len(history)

    Writer = animation.writers['ffmpeg']
    writer = Writer(fps=15, metadata=dict(artist='Me'), bitrate=1800)

    fig, ax = plt.subplots()
    ax.grid(True)
    ax.set_xlabel('$x_1$')
    ax.set_ylabel('$x_2$')

    # plot clusters as incomplete dataset (uniform color)
    for x, y in data_sets:
        ax.plot(x, y, linestyle='', marker='o', markersize=4, color='c')

    # means: one marker per component, placed at its (x, y) location
    row = history[0]
    means = row['mus']
    means_plot = []
    for m in means:
        p, = ax.plot([m[0]], [m[1]], marker='x', markersize=8,
                     markeredgewidth=4)
        means_plot.append(p)

    # evaluate the mixture model over the currently visible axes range
    xr = ax.get_xlim()
    yr = ax.get_ylim()
    xp = np.linspace(xr[0], xr[1], 50)
    yp = np.linspace(yr[0], yr[1], 50)

    fig_data = {
        'ax': ax,
        'fig': fig,
        'means': means_plot
    }

    X, Y = np.meshgrid(xp, yp)
    Z = gauss_mixture_grid(xp, yp, row['pis'], means, row['sigmas'])
    fig_data['cplot'] = ax.contourf(X, Y, Z, alpha=0.8)

    # Standalone mappable so the colorbar survives contour replacement
    cbar_map = plt.cm.ScalarMappable()
    cbar_map.set_clim(np.min(Z), np.max(Z))
    fig.colorbar(cbar_map, ax=ax, format='%.0e')
    fig_data['cbar_map'] = cbar_map

    # update_plot returns no artists, so blitting would never refresh the
    # figure; redraw the full frame instead
    c_ani = animation.FuncAnimation(fig, update_plot, frame_count,
                                    fargs=(history, xp, yp, fig_data),
                                    interval=20, blit=False)
    return c_ani, writer


def main():
    '''Generate clusters, run EM and save the iterations as a video'''
    means = [[-3, +3], [+3, +3], [+3, -3], [-3, -3], [+0, -0]]
    seeds = [1, 4, 8, 9, 10]
    clusters = create_clusters(20, means, seeds)

    # Initial model: uniform weights, seeded random means, unit covariances
    K = len(clusters)
    pis = [1.0 / K] * K
    mus = [np.random.RandomState(i).uniform(-5, 5, 2) for i in range(K)]
    sigmas = [np.eye(2) for _ in range(K)]

    history = [{
        'mus': mus,
        'pis': pis,
        'sigmas': sigmas
    }]

    for _ in range(100):
        gammas = e_step(clusters, pis, mus, sigmas)
        pis, mus, sigmas = m_step(clusters, pis, mus, sigmas, gammas)
        history.append({
            'mus': mus,
            'pis': pis,
            'sigmas': sigmas
        })

    c_ani, writer = plot_em_iterations(clusters, history)

    script_dir = os.path.dirname(os.path.realpath(__file__))
    fn = os.path.join(script_dir, '../video/cluster-gaussian-mixture.mp4')
    # The writer does not create missing directories
    os.makedirs(os.path.dirname(fn), exist_ok=True)
    c_ani.save(fn, writer=writer)


if __name__ == '__main__':
    main()