EM Clustering
Sources:
Contents
Options
% Demo parameters
K = 5;      % number of clusters
N = 25;     % number of samples per cluster

% size of the output image that will contain the 2D points
H = 400;    % height
W = 400;    % width

% Drawing colors, one row per cluster: clrFG is used for the sample points,
% clrBG is a brightened copy of the same palette used for the regions.
clrFG = lines(K);
clrBG = brighten(clrFG, 0.75);
Data
generate the training samples
% Draw N normally-distributed 2D points around each of K cluster centers.
% The centers are placed on a square grid inside the W-by-H canvas. The grid
% side is ceil(sqrt(K)) so that all K centers fit on the grid; rounding down
% instead puts the surplus centers of a non-square K (e.g. K = 5) on an extra
% row that falls on the bottom edge of the canvas.
K1 = ceil(sqrt(K));             % grid cells per row/column (loop-invariant)
step = [W H] / (K1+1);          % x/y spacing between neighbouring centers
sig = [30 30];                  % standard deviation along x and y
samples = cell(K,1);
for i=1:K
    % 1-based (column,row) grid position of cluster i, scaled to pixels
    mu = ([mod(i-1,K1), fix((i-1)/K1)] + 1) .* step;
    % scale unit-variance noise by sig and shift it to the cluster center
    samples{i} = bsxfun(@plus, bsxfun(@times, randn([N 2]), sig), mu);
end
samples = single(cat(1, samples{:}));  % 2D points (nsamples-by-2)
true cluster labels of the generated samples (nsamples-by-1)
% Label vector matching the row order of the samples: each cluster index
% 1..K is repeated N times, stored as an int32 column.
if ~mexopencv.isOctave()
    groups = int32(repelem((1:K)', N));
else
    %HACK: http://savannah.gnu.org/bugs/?45497
    groups = int32(repelems(1:K, [1:K; repmat(N,1,K)]))';
end
EM Cluster
cluster the data
% Create the EM model and configure it for K clusters
em = cv.EM();
em.ClustersNumber = K;
em.CovarianceMatrixType = 'Spherical';

% termination criteria of the training
em.TermCriteria.maxCount = 300;
em.TermCriteria.epsilon = 0.1;

% Train on the samples; only the second output (the labels) is kept
[~,labels] = em.trainEM(samples);

% convert 0-based to 1-based indices
labels = 1 + int32(labels);
Plot
canvas to draw 2D points
% all-zero 8-bit image with H rows, W columns and 3 channels
img = zeros(H, W, 3, 'uint8');
classify every image pixel
% Enumerate every pixel coordinate of the W-by-H canvas as one [x y] row
% (column-major order of the H-by-W grid), in single precision.
[X,Y] = meshgrid(1:W, 1:H);
XY = single([reshape(X,[],1), reshape(Y,[],1)]);

% Query the trained model for each pixel; only the second output is kept
[~,response] = em.predict2(XY);

% convert 0-based to 1-based indices
response = 1 + int32(response);
draw the clustered pixels
% Paint the pixels assigned to each cluster with that cluster's background
% color (a filled circle of radius 1 at every such pixel).
for i = 1:K
    idx = (response == i);
    if any(idx)     % clusters that received no pixels are skipped
        img = cv.circle(img, XY(idx,:), 1, ...
            'Color',uint8(clrBG(i,:)*255), 'Thickness','Filled');
    end
end
draw the clustered samples
% Draw each training sample as a filled circle of radius 3 at its rounded
% position, using the foreground color of the cluster it was assigned to.
for i = 1:size(samples,1)
    img = cv.circle(img, round(samples(i,:)), 3, ...
        'Color',uint8(clrFG(labels(i),:)*255), ...
        'Thickness','Filled');
end
display the result
% show the rendered canvas and caption it
imshow(img), title('EM-clustering result')
display the original (true) sample labels on top of the predicted cluster regions
% New figure: the per-pixel cluster indices, reshaped to an H-by-W grid, are
% shown as an image through the clrBG colormap.
figure
image([1 W], [1 H], reshape(double(response), [H W]))
colormap(clrBG)

% Overlay the samples, grouped by their true labels
hold on
if ~mexopencv.isOctave()
    gscatter(samples(:,1), samples(:,2), groups, clrFG)
else
    %HACK: GSCATTER not implemented in Octave
    scatter(samples(:,1), samples(:,2), [], double(groups))
end
hold off

% equal aspect ratio, matrix-style axes, limits clipped to the canvas
axis equal ij
axis([1 W 1 H])
grid on
title('Original Samples')
xlabel('X')
ylabel('Y')