SVMSGD Interactive Classification
Train with the SVMSGD algorithm. The classifier can handle a linearly separable two-class dataset.
Sources:
function varargout = train_svmsgd_demo_gui()
    %TRAIN_SVMSGD_DEMO_GUI  Interactive two-class SVMSGD training demo
    %
    %     train_svmsgd_demo_gui()
    %     h = train_svmsgd_demo_gui()
    %
    % ## Output
    % * __h__ (optional) struct of graphics handles with fields
    %   `fig`, `ax` and `img`.
    %
    % Every mouse click adds a labelled sample, retrains the model on all
    % samples collected so far, and redraws the samples together with the
    % decision line (once both classes are present).
    %

    % data (shared with the nested callbacks below)
    sz = [594 841];  % [height width]
    samples = [];    % Set of train samples. Contains points on image
    responses = [];  % Set of responses for train samples

    % create the UI
    h = buildGUI();
    onHelp([],[]);  % display instructions
    if nargout > 0, varargout{1} = h; end

    function h = buildGUI()
        %BUILDGUI  Creates the UI
        %
        % Returns a struct holding the figure, axes and image handles.

        h = struct();
        h.fig = figure('Name','Train SVMSGD', 'Menubar','none', ...
            'Position',[200 200 sz(2) sz(1)]);
        if ~mexopencv.isOctave()
            %HACK: not implemented in Octave
            movegui(h.fig, 'center');
        end
        h.ax = axes('Parent',h.fig, ...
            'Units','normalized', 'Position',[0 0 1 1]);
        img = zeros([sz 3], 'uint8');
        if ~mexopencv.isOctave()
            h.img = imshow(img, 'Parent',h.ax);
        else
            %HACK: https://savannah.gnu.org/bugs/index.php?45473
            axes(h.ax);
            h.img = imshow(img);
        end

        % register mouse button handlers and change cursor
        set(h.fig, 'Pointer','cross', 'WindowKeyPressFcn',@onType, ...
            'WindowButtonDownFcn',@onMouseDown);
    end

    function onHelp(~,~)
        %ONHELP  Display usage help dialog

        hd = helpdlg({
            'Left-click the mouse to add a positive sample.'
            'Right-click the mouse to add a negative sample.'
            'Hot keys:'
            ' h: usage dialog'
            ' r: reset'
            ' s: save current output image as PNG image'
            ' e: export current data as MAT-file'
            ' q: quit the program'
        }, 'Interactive SVMSGD demo');
        set(hd, 'WindowStyle','modal');
    end

    function onReset(~,~)
        %ONRESET  Restart from scratch

        % reset data
        samples = [];
        responses = [];

        % update plot
        img = zeros([sz 3], 'uint8');
        set(h.img, 'CData',img);
        drawnow;
    end

    function onType(~,e)
        %ONTYPE  Event handler for key press on figure

        % handle keys
        switch e.Key
            case 'r'
                onReset([],[]);
            case 'h'
                onHelp([],[]);
            case 's'
                % "img" is local to the other nested functions (the parent
                % never uses it), so it is not visible here; fetch the
                % currently displayed image from the image object instead.
                img = get(h.img, 'CData');
                fname = [tempname() '.png'];
                imwrite(img, fname);
                fprintf('Output saved as "%s"\n', fname);
            case 'e'
                uisave({'samples', 'responses'}, 'data.mat');
            case {'q', 'escape'}
                close(h.fig);
        end
    end

    function onMouseDown(~,~)
        %ONMOUSEDOWN  Event handler for mouse down on figure

        % get current location of mouse pointer
        p = get(h.ax, 'CurrentPoint');
        p = round(p(1,1:2));

        % add point to train set with corresponding positive/negative class
        % (SelectionType 'normal' is labelled +1, anything else -1)
        samples(end+1,:) = p;
        if strcmp(get(h.fig,'SelectionType'), 'normal')
            responses(end+1) = +1;
        else
            responses(end+1) = -1;
        end

        % process (train model and draw results on image)
        [weights, shift] = doTrain(samples, responses(:));
        pts = doFindPointsForLine(sz, weights, shift);
        img = doRedraw(sz, samples, responses, pts);

        % update plot
        set(h.img, 'CData',img);
        drawnow;
    end
end

function [weights, shift] = doTrain(samples, responses)
    %DOTRAIN  Train with SVMSGD algorithm
    %
    %     [weights, shift] = doTrain(samples, responses)
    %
    % ## Input
    % * __samples__, __responses__ train set
    %
    % ## Output
    % * __weights__, __shift__ vector of decision function of SVMSGD
    %   algorithm. Both are empty when fewer than two distinct response
    %   values are present or the model reports it is not trained.
    %

    weights = [];
    shift = [];

    if numel(unique(responses)) < 2
        % ensure we have at least one point from each class
        return;
    end

    model = cv.SVMSGD();
    model.train(samples, responses);
    if model.isTrained()
        weights = model.getWeights();
        shift = model.getShift();
        display(model)
        fprintf('%f*x + %f*y + %f = 0\n', weights(1), weights(2), shift);
    end
end

function pts = doFindPointsForLine(sz, weights, shift)
    %DOFINDPOINTSFORLINE  Find two points for drawing decision function line (w*x = 0)
    %
    %     pts = doFindPointsForLine(sz, weights, shift)
    %
    % ## Input
    % * __sz__ image size `[height width]`
    % * __weights__, __shift__ decision function parameters
    %
    % ## Output
    % * __pts__ one intersection point `[x y]` per row. Empty when
    %   `weights` is empty. The search stops once two points are found,
    %   but fewer than two rows may be returned when fewer borders are hit.
    %

    pts = [];
    if isempty(weights)
        return;
    end

    % axis-aligned border segments, each as [x1 y1; x2 y2]
    segments = {
        [sz(2) 0; sz(2) sz(1)];     % x = width
        [0 sz(1); sz(2) sz(1)];     % y = height
        [0 0; sz(2) 0];             % y = 0
        [0 0; 0 sz(1)]              % x = 0
    };

    % test intersection against each segment until we collect two points
    for i=1:numel(segments)
        pt = doFindCrossPointWithBorders(weights, shift, segments{i});
        if ~isempty(pt)
            pts(end+1,:) = pt;
            if size(pts,1) >= 2
                return;
            end
        end
    end
end

function pt = doFindCrossPointWithBorders(weights, shift, seg)
    %DOFINDCROSSPOINTWITHBORDERS  Find intersection of line (w*x = 0) and segment
    %
    %     pt = doFindCrossPointWithBorders(weights, shift, seg)
    %
    % (y = HEIGHT, 0 <= x <= WIDTH) or (x = WIDTH, 0 <= y <= HEIGHT)
    %
    % decision function line equation: w(1)*x + w(2)*y + s = 0
    % border equations either x=c or y=c
    %
    % ## Input
    % * __weights__, __shift__ decision function parameters
    % * __seg__ axis-aligned segment `[x1 y1; x2 y2]`
    %
    % ## Output
    % * __pt__ intersection `[x y]` (free coordinate rounded down with
    %   `floor`), or empty when the line is parallel to the segment or the
    %   rounded coordinate falls outside the segment extent.
    %

    xmn = min(seg(:,1));
    xmx = max(seg(:,1));
    ymn = min(seg(:,2));
    ymx = max(seg(:,2));

    pt = [];
    if xmn == xmx && weights(2) ~= 0
        % intersect with vertical border: solve the line equation for y
        x = xmn;
        y = floor(-(weights(1) * x + shift) / weights(2));
        if ymn <= y && y <= ymx
            pt = [x y];
        end
    elseif ymn == ymx && weights(1) ~= 0
        % intersect with horizontal border: solve the line equation for x
        y = ymn;
        x = floor(-(weights(2) * y + shift) / weights(1));
        if xmn <= x && x <= xmx
            pt = [x y];
        end
    end
end

function img = doRedraw(sz, samples, responses, pts)
    %DOREDRAW  Redraw point set and decision function line (w*x = 0)
    %
    %     img = doRedraw(sz, samples, responses, pts)
    %
    % ## Input
    % * __sz__ image size `[height width]`
    % * __samples__, __responses__ train set
    % * __pts__ end points of the decision line, one `[x y]` per row
    %
    % ## Output
    % * __img__ `uint8` image of size `[sz 3]` with samples of class +1
    %   drawn with color `[255 0 0]`, samples of class -1 with color
    %   `[0 0 255]`, and the decision line with color `[0 255 0]`.
    %

    img = zeros([sz 3], 'uint8');
    if ~isempty(samples)
        img = cv.circle(img, samples(responses==+1,:), 6, ...
            'Color',[255 0 0], 'Thickness','Filled');
        img = cv.circle(img, samples(responses==-1,:), 6, ...
            'Color',[0 0 255], 'Thickness','Filled');
    end
    % A line needs both end points; pts can hold a single row when only
    % one border intersection was found, so guard before indexing row 2.
    if size(pts,1) >= 2
        img = cv.line(img, pts(1,:), pts(2,:), 'Color',[0 255 0]);
    end
end