The MNIST dataset of handwritten digits
Demonstrates loading the MNIST dataset.
Sources:
Contents
MNIST dataset
download/extract files if needed
dirMNIST = fullfile(mexopencv.root(), 'test', 'mnist'); if ~isdir(dirMNIST) % download baseURL = 'http://yann.lecun.com/exdb/mnist/'; files = { 'train-images-idx3-ubyte.gz' 'train-labels-idx1-ubyte.gz' 't10k-images-idx3-ubyte.gz' 't10k-labels-idx1-ubyte.gz' }; disp('Downloading MNIST database, and extracting files...'); mkdir(dirMNIST); for i=1:numel(files) gzFile = fullfile(dirMNIST, files{i}); if exist(gzFile, 'file') ~= 2 url = [baseURL, files{i}]; urlwrite(url, gzFile); end %HACK: unfortunately MATLAB's gunzip (with Java's GZIPInputStream) % ignore the stored filename when extracting a GZ file, instead they % simply remove the .gz extension! So we use Apache Commons to get the % original filename (a Java library that ships with MATLAB). % Alternatively we could use system tools like GNU gunzip or 7zip. if false gunzip(gzFile, dirMNIST); elseif false system(['gunzip --keep --name ' gzFile]); else % adapted from gunzip.m in = org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream(... java.io.FileInputStream(java.io.File(gzFile))); fname = char(in.getMetaData().getFilename()); if isempty(fname) [~,fname,~] = fileparts(gzFile); end out = java.io.FileOutputStream(java.io.File(fullfile(dirMNIST, fname))); %org.apache.commons.io.IOUtils.copy(in, out); streamCopier = com.mathworks.mlwidgets.io.InterruptibleStreamCopier.getInterruptibleStreamCopier(); streamCopier.copyStream(in, out); in.close(); out.close(); clear in out streamCopier fname end end end
Load dataset from disk
we give the path to the extracted files to the load method
tic ds = cv.Dataset('OR_mnist'); ds.load(fullfile(dirMNIST,'/')); %HACK: path must end with a slash! toc
Elapsed time is 0.431709 seconds.
Data
dataset contains for each object its image and label as a structure array of images (8-bit 28x28 grayscale) and corresponding labels (0..9)
fprintf('NumSplits = %d\n', ds.getNumSplits()); dtrain = ds.getTrain() dtest = ds.getTest() whos dtrain dtest
NumSplits = 1 dtrain = 1×60000 struct array with fields: label image dtest = 1×10000 struct array with fields: label image Name Size Bytes Class Attributes dtest 1x10000 10160128 struct dtrain 1x60000 60960128 struct
Display
show one instance from each train/test sets
subplot(121), imshow(dtrain(1).image) title(sprintf('label = %d',dtrain(1).label)) subplot(122), imshow(dtest(1).image) title(sprintf('label = %d',dtest(1).label))
Display
show a sample of the first 100 train images corresponding to digit 8
if mexopencv.require('images') idx = find([dtrain.label] == 8); idx(101:end) = []; figure, montage(cat(4, dtrain(idx).image)) end