GrabCut segmentation demo

Interactive foreground extraction using the GrabCut algorithm.

This program demonstrates GrabCut segmentation: select an object in a region and then grabcut will attempt to segment it out.

Sources:

Contents

Theory

The GrabCut algorithm was designed by Carsten Rother, Vladimir Kolmogorov and Andrew Blake from Microsoft Research Cambridge, UK, in their paper "GrabCut": interactive foreground extraction using iterated graph cuts.

An algorithm was needed for foreground extraction with minimal user interaction, and the result was GrabCut.

How does it work from the user's point of view? Initially the user draws a rectangle around the foreground region (the foreground region should be completely inside the rectangle). Then the algorithm segments it iteratively to get the best result. Done. But in some cases the segmentation won't be fine; for example, it may have marked some foreground region as background and vice versa. In that case, the user needs to do fine touch-ups: just draw some strokes on the image where the faulty results are. The strokes basically say: "Hey, this region should be foreground, you marked it background, correct it in the next iteration", or the opposite for background. Then in the next iteration, you get better results.

See the image below. First, the player and the football are enclosed in a blue rectangle. Then some final touch-ups with white strokes (denoting foreground) and black strokes (denoting background) are made. And we get a nice result.

So what happens in the background?

It is illustrated in the image below (Image Courtesy: http://www.cs.ru.ac.za/research/g02m1682/)

Code

This is an interactive tool using grabcut. You can also watch this youtube video on how to use it.

function varargout = grabcut_demo_gui(im)
    %GRABCUT_DEMO_GUI  Interactive GrabCut segmentation tool
    %
    %     grabcut_demo_gui()
    %     grabcut_demo_gui([])
    %     grabcut_demo_gui(im)
    %     h = grabcut_demo_gui(...)
    %
    % ## Input
    % * __im__ optional image source:
    %   * omitted: the `test/fruits.jpg` image under the mexopencv root is
    %     loaded.
    %   * empty: a file-selection dialog is shown; cancelling it raises an
    %     error.
    %   * string: treated as a filename and read with IMREAD.
    %   * anything else: used directly as the image data.
    %
    % ## Output
    % * __h__ structure of graphics handles returned by the UI builder
    %   (only assigned when an output is requested).
    %
    % The image must be a non-empty 8-bit array with three channels.
    %

    % resolve the source image from the (optional) argument
    if nargin < 1
        % default sample shipped with mexopencv
        src = imread(fullfile(mexopencv.root(), 'test', 'fruits.jpg'));
    elseif isempty(im)
        % prompt for a file, offering every extension IMREAD knows about
        fmts = imformats();
        pattern = strjoin(strcat('*.', [fmts.ext]), ';');
        [fname, fdir] = uigetfile(pattern, 'Select an image');
        if isequal(fdir, 0)
            % UIGETFILE returns 0 when the dialog is cancelled
            error('No file selected');
        end
        src = imread(fullfile(fdir, fname));
    elseif ischar(im)
        src = imread(im);
    else
        src = im;
    end

    % only 8-bit, 3-channel images are accepted
    validateattributes(src, {'uint8'}, ...
        {'ndims',3, 'size',[nan nan 3], 'nonempty'});

    % create the shared state and the window; both are captured by the
    % nested callbacks below
    app = initApp(src);
    h = buildGUI(src, app);

    % wire up the callbacks; for the segmentation button and the figure,
    % events arriving while a callback is still running are discarded
    noReentry = {'Interruptible','off', 'BusyAction','cancel'};
    set(h.btn(1), 'Callback',@onHelp);
    set(h.btn(2), 'Callback',@onReset);
    set(h.btn(3), 'Callback',@onNext, noReentry{:});
    set(h.pop, 'Callback',@onChange);
    set(h.fig, 'WindowKeyPressFcn',@onType, ...
        'WindowButtonDownFcn',@onMouseDown, noReentry{:});

    % hand back the graphics handles on request
    if nargout > 0
        varargout{1} = h;
    end

    % ========== Event Handlers ==========

    function onHelp(~,~)
        %ONHELP  Display usage help dialog
        %
        % Opens a help dialog describing the workflow (rectangle first,
        % then optional touch-up strokes) and the keyboard/mouse shortcuts.
        % Both callback arguments are ignored, so the function can also be
        % called directly with empty placeholders.
        %
        % The hot-key list mirrors the keys handled by the figure's
        % key-press handler (quit, help, reset, next, mode selection and
        % brush thickness).

        helpdlg({
            'This program demonstrates GrabCut segmentation:'
            'select an object in a region and then grabcut'
            'will attempt to segment it out.'
            ''
            'Select a rectangular area around the object you'
            'want to segment.'
            ''
            'Then press "next" to segment the object (once or a few times).'
            ''
            'For finer touch-ups, set the mode and draw lines on the areas you'
            'want, to mark them as foreground/background (sure or probable),'
            'then press "next" again.'
            ''
            'Hot keys:'
            'ESC or q - quit the program'
            'h - show this help'
            'r - restore the original image'
            'n - next iteration'
            '1/2/3/4 - select mode BGD/FGD/PR_BGD/PR_FGD'
            '+/- - increase/decrease brush thickness'
            'left mouse button - First set rectangle, then set pixels'
            '    as either BGD/FGD/PR_BGD/PR_FGD depending on selected'
            '    mode in dropdown menu.'
        });
    end

    function onReset(~,~)
        %ONRESET  Event handler for reset button
        %
        % Returns the tool to its initial state: the rectangle, the drawn
        % strokes and the iteration progress are discarded, the GrabCut
        % mask and models are zeroed, and the original image is displayed
        % again. Both callback arguments are ignored.

        % forget the selection, the strokes and the progress
        app.isInitialized = false;
        app.iterCount = 0;
        app.rectxy = zeros(0,2);
        app.rect = zeros(0,4);
        app.pts = cell(1,4);
        app.pts(:) = {zeros(0,2)};

        % zero the GrabCut arrays in place (their size and class are kept)
        app.fgdModel(:) = 0;
        app.bgdModel(:) = 0;
        app.mask(:) = 0;

        % hide the overlays (NaN data draws nothing) and restore the view
        set(h.line(:), 'XData',NaN, 'YData',NaN);
        set(h.rect, 'XData',NaN, 'YData',NaN);
        set(h.img, 'CData',app.img0);
        set(h.txt, 'String','Iter =  0');
        drawnow;
    end

    function onNext(~,~)
        %ONNEXT  Event handler for next button
        %
        % Runs one GrabCut iteration and shows the result. How the
        % algorithm is invoked depends on the current state:
        %
        % * already initialized: pending strokes (if any) are written into
        %   the mask, then the existing mask and models are refined
        %   ('Eval' mode).
        % * not initialized, strokes drawn: the mask is built from the
        %   rectangle and the strokes ('InitWithMask' mode).
        % * not initialized, rectangle only: 'InitWithRect' mode.
        % * nothing selected: a hint is printed and nothing else happens.
        %
        % The elapsed time of the GrabCut call is printed via TIC/TOC.

        hasStrokes = ~all(cellfun(@isempty, app.pts));

        % nothing to work with yet
        if ~app.isInitialized && ~hasStrokes && isempty(app.rect)
            disp('First select object to segment by drawing a rectangle');
            return;
        end

        % pick the second input and the options of the GrabCut call
        if app.isInitialized
            % fold pending strokes into the mask, then keep refining it
            if hasStrokes
                setLblsInMask();
            end
            gcInput = app.mask;
            gcArgs = {'Mode','Eval', 'IterCount',1, ...
                'BgdModel',app.bgdModel, 'FgdModel',app.fgdModel};
        elseif hasStrokes
            % rectangle first, strokes on top of it, then init from mask
            setRectInMask();
            setLblsInMask();
            gcInput = app.mask;
            gcArgs = {'Mode','InitWithMask', 'IterCount',1};
        else
            % init from the rectangle, with its origin shifted by one
            gcInput = app.rect - [1 1 0 0];
            gcArgs = {'Mode','InitWithRect', 'IterCount',1};
        end

        % run a single iteration (timed)
        tic
        [app.mask, app.bgdModel, app.fgdModel] = cv.grabCut(...
            app.img0, gcInput, gcArgs{:});
        toc

        % from now on the mask is valid; count the iteration
        app.iterCount = app.iterCount + 1;
        app.isInitialized = true;

        % refresh the display
        showImage();
    end

    function onType(~,e)
        %ONTYPE  Event handler for key press on figure
        %
        % ## Input
        % * __e__ key event data; its `Key` (key name) and `Character`
        %   (typed character) fields are used.
        %
        % Handled keys:
        %
        % * `q` / `escape`: close the figure
        % * `h`: show the help dialog
        % * `r`: reset
        % * `n`: run the next iteration
        % * `+` / `-`: grow/shrink the brush thickness by 2, limited to
        %   the range [1, 40]
        % * `1`..`4`: select the drawing mode (and sync the dropdown)

        % Brush thickness is decided from the typed character rather than
        % from the key name: the names 'add'/'subtract' are only reported
        % for the numeric keypad, so matching on them alone would ignore
        % '+'/'-' typed on the main keyboard row.
        ch = e.Character;
        if strcmp(ch, '+') || strcmp(ch, '-')
            if strcmp(ch, '+')
                app.thick = min(app.thick + 2, 40);
            else
                app.thick = max(app.thick - 2, 1);
            end
            % keep the on-screen stroke markers in step with the brush
            set(h.line(:), 'MarkerSize',app.thick*5.4);
            return;
        end

        % remaining shortcuts are identified by key name
        switch e.Key
            case {'q', 'escape'}
                close(h.fig);

            case 'h'
                onHelp([],[]);

            case 'r'
                onReset([],[]);

            case 'n'
                onNext([],[]);

            case {'1', '2', '3', '4'}
                % set brush value, and reflect it in the dropdown menu
                app.currIdx = str2double(e.Key);
                set(h.pop, 'Value',app.currIdx);
        end
    end

    function onChange(~,~)
        %ONCHANGE  Event handler for UI controls
        %
        % Called when the dropdown selection changes. The selected entry
        % index becomes the current drawing value; the entries are listed
        % in the order BGD/FGD/PR_BGD/PR_FGD.

        selected = get(h.pop, 'Value');
        app.currIdx = selected;
    end

    function onMouseDown(~,~)
        %ONMOUSEDOWN  Event handler for mouse down on figure
        %
        % Only left clicks are handled. The tool works in two phases:
        %
        % * no rectangle yet: the user drags out a rectangle, which is then
        %   drawn on the axes (too-small selections are ignored).
        % * rectangle set: free-hand drawing starts; mouse-move and
        %   mouse-up handlers are attached until the button is released.
        %
        % When drawing starts, the pressed location itself is recorded as
        % the first point, so that a click without dragging still leaves a
        % mark. Presses outside the axes limits are not recorded, since
        % their location would otherwise be clamped onto the image border.

        % ignore anything but left mouse clicks
        if ~strcmp(get(h.fig,'SelectionType'), 'normal')
            return;
        end

        % one of two phases: drawing rectangle, or free-drawing of points
        if isempty(app.rect)
            % select and draw rectangle
            select_rectangle();
            if isempty(app.rect), return; end
            set(h.rect, 'XData',app.rectxy(:,1), 'YData',app.rectxy(:,2));
        else
            % attach event handlers, and change mouse pointer
            set(h.fig, 'Pointer','circle', ...
                'WindowButtonMotionFcn',@onMouseMove, ...
                'WindowButtonUpFcn',@onMouseUp);

            % record the pressed location too (only if inside the axes);
            % points are otherwise appended solely on mouse motion
            cp = get(h.ax, 'CurrentPoint');
            xl = get(h.ax, 'XLim');
            yl = get(h.ax, 'YLim');
            pressedInside = cp(1,1) >= xl(1) && cp(1,1) <= xl(2) && ...
                cp(1,2) >= yl(1) && cp(1,2) <= yl(2);
            if pressedInside
                onMouseMove([],[]);
            end
        end
    end

    function onMouseMove(~,~)
        %ONMOUSEMOVE  Event handler for mouse move on figure
        %
        % Appends the current mouse location to the point list of the
        % active brush, and redraws that brush's line so the stroke
        % follows the pointer. Both callback arguments are ignored.

        brush = app.currIdx;

        % grow the stroke of the active brush by the current location
        stroke = [app.pts{brush}; getCurrentPoint()];
        app.pts{brush} = stroke;

        % refresh the matching overlay line
        set(h.line(brush), 'XData',stroke(:,1), 'YData',stroke(:,2));
    end

    function onMouseUp(~,~)
        %ONMOUSEUP  Event handler for mouse up on figure
        %
        % Ends a free-hand stroke: the motion and button-up handlers are
        % removed from the figure, and the default arrow pointer is
        % restored.

        % stop tracking the mouse
        set(h.fig, 'WindowButtonUpFcn','', 'WindowButtonMotionFcn','');

        % back to the regular pointer
        set(h.fig, 'Pointer','arrow');
    end

    % ========== Helper Functions ==========

    function showImage()
        %SHOWIMAGE  Display the current segmentation result
        %
        % Shows the original image, with pixels whose mask value is 0 or 2
        % (BGD or PR_BGD) blacked out once the mask has been initialized.
        % The iteration counter text is updated as well.

        out = app.img0;
        if app.isInitialized
            % black out sure/probable background in all three channels
            isBackground = ismember(app.mask, [0 2]);
            out(repmat(isBackground, [1 1 3])) = 0;
        end

        set(h.txt, 'String',sprintf('Iter = %2d',app.iterCount));
        set(h.img, 'CData',out);
        drawnow;
    end

    function setRectInMask()
        %SETRECTINMASK  Initialize the GC mask from the selected rectangle
        %
        % Every pixel inside the rectangle polygon becomes 3 (PR_FGD), and
        % every other pixel becomes 0 (BGD).

        % rasterize the rectangle outline into an image-sized binary mask
        inside = poly2mask(app.rectxy(:,1), app.rectxy(:,2), ...
            app.sz(1), app.sz(2));

        % 3 (PR_FGD) inside, 0 (BGD) outside; the indexed assignment keeps
        % the size and class of the existing mask
        app.mask(:) = 3 * inside;
    end

    function setLblsInMask()
        %SETLBLSINMASK  Write the drawn strokes into the GC mask
        %
        % For each of the four brushes, filled circles of the current
        % brush thickness are drawn into the mask at the recorded points,
        % using the brush's label value: 0, 1, 2, 3 for BGD, FGD, PR_BGD,
        % PR_FGD. The strokes are then cleared, both in the state and on
        % screen.

        for lbl = 0:3
            % points are shifted by one before being passed on (presumably
            % from 1-based to 0-based coordinates)
            app.mask = cv.circle(app.mask, app.pts{lbl+1}-1, app.thick, ...
                'Color',uint8(lbl), 'Thickness','Filled');
        end

        % the strokes are now part of the mask; drop them
        set(h.line(:), 'XData',NaN, 'YData',NaN);
        app.pts = repmat({zeros(0,2)}, [1 4]);
    end

    function select_rectangle()
        %SELECT_RECTANGLE  Let the user drag out the initial rectangle
        %
        % Uses a rubber-band box for visual feedback, but takes the two
        % corners from the axes' current point before and after dragging.
        % On success `app.rect` ([x y w h]) and `app.rectxy` (closed
        % outline [TL;TR;BR;BL;TL]) are set; selections that are not
        % larger than one unit in both dimensions are ignored.
        %
        %TODO: consider IMRECT from image_toolbox

        startPt = getCurrentPoint();  % mouse location before dragging
        rbbox;                        % output ignored (figure coordinates)
        pause(0.005);                 % CP might not get updated if selection was too fast
        endPt = getCurrentPoint();    % mouse location after dragging

        % order the two corners
        corners = [startPt; endPt];
        topLeft = min(corners);
        botRight = max(corners);
        extent = botRight - topLeft;

        % ignore small rectangles
        if any(extent <= 1)
            return;
        end

        app.rect = [topLeft extent];
        app.rectxy = [
            topLeft
            topLeft + [extent(1) 0]
            botRight
            topLeft + [0 extent(2)]
            topLeft
        ];
    end

    function p = getCurrentPoint()
        %GETCURRENTPOINT  Mouse location in image coordinates
        %
        % ## Output
        % * __p__ `[x y]` location taken from the first row of the axes'
        %   current point, clamped to `[1, width]` by `[1, height]` of the
        %   image.

        cp = get(h.ax, 'CurrentPoint');

        % keep x,y of the first row, limited to the image extent
        upper = app.sz([2 1]);  % [width height]
        p = min(max(cp(1,1:2), [1 1]), upper);
    end
end

% ========== Initializer functions ==========

function app = initApp(img)
    %INITAPP  Initialize app state
    %
    % ## Input
    % * __img__ image to segment.
    %
    % ## Output
    % * __app__ structure holding the whole state of the tool: the
    %   original image and its size, the GrabCut mask and models, the
    %   current brush and its thickness, the drawn points of each brush,
    %   the selected rectangle, the iteration counter and the
    %   initialization flag.
    %
    % The model placeholders are sized 1x65 to match the layout OpenCV uses
    % for GrabCut models (5 mixture components x 13 values each).
    %

    app = struct();
    app.img0 = img;              % original image
    app.sz = size(img);          % image size
    app.mask = zeros(size(img,1), size(img,2), 'uint8'); % GC mask
    app.bgdModel = zeros(1,65);  % GC background model
    app.fgdModel = zeros(1,65);  % GC foreground model
    app.currIdx = 1;             % drawing value (BGD/FGD/PR_BGD/PR_FGD)
    app.thick = 5;               % drawing thickness
    app.pts = repmat({zeros(0,2)}, [1 4]); % drawing points of each brush
    app.rect = zeros(0, 4);      % rectangle [x,y,w,h]
    app.rectxy = zeros(0,2);     % rectangle points [TL;TR;BR;BL;TL]
    app.iterCount = 0;           % iterations counter
    app.isInitialized = false;   % whether GC mask is initialized
end

function h = buildGUI(img, app)
    %BUILDGUI  Creates the UI
    %
    % ## Input
    % * __img__ image to display.
    % * __app__ app state; only the brush thickness is read here, to size
    %   the stroke markers.
    %
    % ## Output
    % * __h__ structure of graphics handles, with fields: `fig`, `ax`,
    %   `img`, `btn` (Help/Reset/Next), `txt` (iteration counter), `pop`
    %   (drawing mode), `line` (one per brush) and `rect`.
    %
    % The window has a fixed layout: the image on top, and a row of
    % controls in a strip below it.
    %

    % layout (pixels)
    minWidth = 350;   % minimum figure width
    barHeight = 29;   % strip below the image, holding the controls
    imgH = size(img,1);
    figW = max(size(img,2), minWidth);

    % figure (no resizing to keep it simple)
    h = struct();
    h.fig = figure('Name','GrabCut Demo', ...
        'NumberTitle','off', 'Menubar','none', 'Resize','off', ...
        'Position',[200 200 figW imgH+barHeight]);
    onOctave = mexopencv.isOctave();
    if ~onOctave
        %HACK: not implemented in Octave
        movegui(h.fig, 'center');
    end

    % axes and image, placed just above the controls strip
    h.ax = axes('Parent',h.fig, ...
        'Units','pixels', 'Position',[1 barHeight+1 figW imgH]);
    if onOctave
        %HACK: https://savannah.gnu.org/bugs/index.php?45473
        axes(h.ax);
        h.img = imshow(img);
    else
        h.img = imshow(img, 'Parent',h.ax);
    end
    %axis(h.ax, 'on');

    % push buttons, laid out left to right (60 wide, 5 apart)
    captions = {'Help', 'Reset', 'Next'};
    for k = 1:numel(captions)
        h.btn(k) = uicontrol('Parent',h.fig, 'Style','pushbutton', ...
            'Position',[5+65*(k-1) 5 60 20], 'String',captions{k});
    end

    % iteration counter and drawing-mode dropdown
    h.txt = uicontrol('Parent',h.fig, 'Style','text', 'FontSize',11, ...
        'Position',[200 5 60 20], 'String','Iter =  0');
    h.pop = uicontrol('Parent',h.fig, 'Style','popupmenu', ...
        'Position',[260 5 80 20], 'String',{'BGD','FGD','PR_BGD','PR_FGD'});

    % overlays, initially invisible (NaN data): one marker-only line per
    % brush, plus the rectangle outline using the fifth color
    palette = 'kwgrb';  % 'rbcmg'
    for k = 1:4
        h.line(k) = line(NaN, NaN, 'Color',palette(k), 'Parent',h.ax, ...
            'LineStyle','none', 'Marker','.', 'MarkerSize',app.thick*5.4);
    end
    h.rect = line(NaN, NaN, 'Color',palette(5), 'Parent',h.ax, 'LineWidth',2);
end