mexopencv 3.4.1
MEX interface for OpenCV library
SVM_.cpp
Go to the documentation of this file.
1 
8 #include "mexopencv.hpp"
9 #include "mexopencv_ml.hpp"
10 #include "opencv2/ml.hpp"
11 using namespace std;
12 using namespace cv;
13 using namespace cv::ml;
14 
15 // Persistent objects
16 namespace {
18 int last_id = 0;
21 
24  ("C_SVC", cv::ml::SVM::C_SVC)
25  ("NU_SVC", cv::ml::SVM::NU_SVC)
26  ("ONE_CLASS", cv::ml::SVM::ONE_CLASS)
27  ("EPS_SVR", cv::ml::SVM::EPS_SVR)
28  ("NU_SVR", cv::ml::SVM::NU_SVR);
29 
32  (cv::ml::SVM::C_SVC, "C_SVC")
33  (cv::ml::SVM::NU_SVC, "NU_SVC")
34  (cv::ml::SVM::ONE_CLASS, "ONE_CLASS")
35  (cv::ml::SVM::EPS_SVR, "EPS_SVR")
36  (cv::ml::SVM::NU_SVR, "NU_SVR");
37 
40  ("Custom", cv::ml::SVM::CUSTOM)
41  ("Linear", cv::ml::SVM::LINEAR)
42  ("Poly", cv::ml::SVM::POLY)
43  ("RBF", cv::ml::SVM::RBF)
44  ("Sigmoid", cv::ml::SVM::SIGMOID)
45  ("Chi2", cv::ml::SVM::CHI2)
46  ("Intersection", cv::ml::SVM::INTER);
47 
50  (cv::ml::SVM::CUSTOM, "Custom")
51  (cv::ml::SVM::LINEAR, "Linear")
52  (cv::ml::SVM::POLY, "Poly")
53  (cv::ml::SVM::RBF, "RBF")
54  (cv::ml::SVM::SIGMOID, "Sigmoid")
55  (cv::ml::SVM::CHI2, "Chi2")
56  (cv::ml::SVM::INTER, "Intersection");
57 
60  ("C", cv::ml::SVM::C)
61  ("Gamma", cv::ml::SVM::GAMMA)
62  ("P", cv::ml::SVM::P)
63  ("Nu", cv::ml::SVM::NU)
64  ("Coef", cv::ml::SVM::COEF)
65  ("Degree", cv::ml::SVM::DEGREE);
66 
72 {
73  ParamGrid g;
74  if (m.isNumeric() && m.numel()==3) {
75  g.minVal = m.at<double>(0);
76  g.maxVal = m.at<double>(1);
77  g.logStep = m.at<double>(2);
78  }
79  else if (m.isStruct() && m.numel()==1) {
80  if (m.isField("minVal"))
81  g.minVal = m.at("minVal").toDouble();
82  if (m.isField("maxVal"))
83  g.maxVal = m.at("maxVal").toDouble();
84  if (m.isField("logStep"))
85  g.logStep = m.at("logStep").toDouble();
86  }
87  else if (m.isChar())
88  g = SVM::getDefaultGrid(SVMParamType[m.toString()]);
89  else
90  mexErrMsgIdAndTxt("mexopencv:error",
91  "Invalid argument to grid parameter");
92  // SVM::trainAuto permits setting step<=1 if we want to disable optimizing
93  // a certain parameter, in which case the value is taken from the props.
94  // Besides the check is done by function itself, so its not needed here.
95  /*
96  if (!g.check())
97  mexErrMsgIdAndTxt("mexopencv:error",
98  "Invalid argument to grid parameter");
99  */
100  return g;
101 }
102 
106 {
107 public:
111  explicit MatlabFunction(const string &func)
112  : fun_name(func)
113  {}
114 
126  void calc(int vcount, int n, const float* vecs,
127  const float* another, float* results)
128  {
129  // create input to evaluate kernel function
130  mxArray *lhs, *rhs[3];
131  rhs[0] = MxArray(fun_name);
132  rhs[1] = MxArray(Mat(vcount, n, CV_32F, const_cast<float*>(vecs)));
133  rhs[2] = MxArray(Mat(1, n, CV_32F, const_cast<float*>(another)));
134 
135  //TODO: mexCallMATLAB is not thread-safe!
136  // evaluate specified function in MATLAB as:
137  // results = feval("fun_name", vecs, another)
138  if (mexCallMATLAB(1, &lhs, 3, rhs, "feval") == 0) {
139  MxArray res(lhs);
140  CV_Assert(res.isSingle() && !res.isComplex() && res.ndims() == 2);
141  vector<float> v(res.toVector<float>());
142  CV_Assert(v.size() == vcount);
143  std::copy(v.begin(), v.end(), results);
144  }
145  else {
146  //TODO: error
147  std::fill(results, results + vcount, 0.0f);
148  }
149 
150  // cleanup
151  mxDestroyArray(lhs);
152  mxDestroyArray(rhs[0]);
153  mxDestroyArray(rhs[1]);
154  mxDestroyArray(rhs[2]);
155  }
156 
160  int getType() const
161  {
162  return cv::ml::SVM::CUSTOM;
163  }
164 
169  static Ptr<MatlabFunction> create(const string &func)
170  {
171  return makePtr<MatlabFunction>(func);
172  }
173 
174 private:
175  string fun_name;
176 };
177 }
178 
186 void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
187 {
188  // Check the number of arguments
189  nargchk(nrhs>=2 && nlhs<=3);
190 
191  // Argument vector
192  vector<MxArray> rhs(prhs, prhs+nrhs);
193  int id = rhs[0].toInt();
194  string method(rhs[1].toString());
195 
196  // Constructor is called. Create a new object from argument
197  if (method == "new") {
198  nargchk(nrhs==2 && nlhs<=1);
199  obj_[++last_id] = SVM::create();
200  plhs[0] = MxArray(last_id);
201  mexLock();
202  return;
203  }
204 
205  // Big operation switch
206  Ptr<SVM> obj = obj_[id];
207  if (obj.empty())
208  mexErrMsgIdAndTxt("mexopencv:error", "Object not found id=%d", id);
209  if (method == "delete") {
210  nargchk(nrhs==2 && nlhs==0);
211  obj_.erase(id);
212  mexUnlock();
213  }
214  else if (method == "clear") {
215  nargchk(nrhs==2 && nlhs==0);
216  obj->clear();
217  }
218  else if (method == "load") {
219  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs==0);
220  string objname;
221  bool loadFromString = false;
222  for (int i=3; i<nrhs; i+=2) {
223  string key(rhs[i].toString());
224  if (key == "ObjName")
225  objname = rhs[i+1].toString();
226  else if (key == "FromString")
227  loadFromString = rhs[i+1].toBool();
228  else
229  mexErrMsgIdAndTxt("mexopencv:error",
230  "Unrecognized option %s", key.c_str());
231  }
232  //obj_[id] = SVM::load(rhs[2].toString());
233  obj_[id] = (loadFromString ?
234  Algorithm::loadFromString<SVM>(rhs[2].toString(), objname) :
235  Algorithm::load<SVM>(rhs[2].toString(), objname));
236  }
237  else if (method == "save") {
238  nargchk(nrhs==3 && nlhs<=1);
239  string fname(rhs[2].toString());
240  if (nlhs > 0) {
241  // write to memory, and return string
242  FileStorage fs(fname, FileStorage::WRITE + FileStorage::MEMORY);
243  if (!fs.isOpened())
244  mexErrMsgIdAndTxt("mexopencv:error", "Failed to open file");
245  fs << obj->getDefaultName() << "{";
246  obj->write(fs);
247  fs << "}";
248  plhs[0] = MxArray(fs.releaseAndGetString());
249  }
250  else
251  // write to disk
252  obj->save(fname);
253  }
254  else if (method == "empty") {
255  nargchk(nrhs==2 && nlhs<=1);
256  plhs[0] = MxArray(obj->empty());
257  }
258  else if (method == "getDefaultName") {
259  nargchk(nrhs==2 && nlhs<=1);
260  plhs[0] = MxArray(obj->getDefaultName());
261  }
262  else if (method == "getVarCount") {
263  nargchk(nrhs==2 && nlhs<=1);
264  plhs[0] = MxArray(obj->getVarCount());
265  }
266  else if (method == "isClassifier") {
267  nargchk(nrhs==2 && nlhs<=1);
268  plhs[0] = MxArray(obj->isClassifier());
269  }
270  else if (method == "isTrained") {
271  nargchk(nrhs==2 && nlhs<=1);
272  plhs[0] = MxArray(obj->isTrained());
273  }
274  else if (method == "train") {
275  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=1);
276  vector<MxArray> dataOptions;
277  int flags = 0;
278  for (int i=4; i<nrhs; i+=2) {
279  string key(rhs[i].toString());
280  if (key == "Data")
281  dataOptions = rhs[i+1].toVector<MxArray>();
282  else if (key == "Flags")
283  flags = rhs[i+1].toInt();
284  else
285  mexErrMsgIdAndTxt("mexopencv:error",
286  "Unrecognized option %s", key.c_str());
287  }
288  Ptr<TrainData> data;
289  if (rhs[2].isChar())
290  data = loadTrainData(rhs[2].toString(),
291  dataOptions.begin(), dataOptions.end());
292  else
293  data = createTrainData(
294  rhs[2].toMat(CV_32F),
295  rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
296  dataOptions.begin(), dataOptions.end());
297  bool b = obj->train(data, flags);
298  plhs[0] = MxArray(b);
299  }
300  else if (method == "calcError") {
301  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=2);
302  vector<MxArray> dataOptions;
303  bool test = false;
304  for (int i=4; i<nrhs; i+=2) {
305  string key(rhs[i].toString());
306  if (key == "Data")
307  dataOptions = rhs[i+1].toVector<MxArray>();
308  else if (key == "TestError")
309  test = rhs[i+1].toBool();
310  else
311  mexErrMsgIdAndTxt("mexopencv:error",
312  "Unrecognized option %s", key.c_str());
313  }
314  Ptr<TrainData> data;
315  if (rhs[2].isChar())
316  data = loadTrainData(rhs[2].toString(),
317  dataOptions.begin(), dataOptions.end());
318  else
319  data = createTrainData(
320  rhs[2].toMat(CV_32F),
321  rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
322  dataOptions.begin(), dataOptions.end());
323  Mat resp;
324  float err = obj->calcError(data, test, (nlhs>1 ? resp : noArray()));
325  plhs[0] = MxArray(err);
326  if (nlhs>1)
327  plhs[1] = MxArray(resp);
328  }
329  else if (method == "predict") {
330  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=2);
331  int flags = 0;
332  for (int i=3; i<nrhs; i+=2) {
333  string key(rhs[i].toString());
334  if (key == "Flags")
335  flags = rhs[i+1].toInt();
336  else if (key == "RawOutput")
337  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::RAW_OUTPUT);
338  else
339  mexErrMsgIdAndTxt("mexopencv:error",
340  "Unrecognized option %s", key.c_str());
341  }
342  Mat samples(rhs[2].toMat(CV_32F)),
343  results;
344  float f = obj->predict(samples, results, flags);
345  plhs[0] = MxArray(results);
346  if (nlhs>1)
347  plhs[1] = MxArray(f);
348  }
349  else if (method == "trainAuto") {
350  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=1);
351  vector<MxArray> dataOptions;
352  int kFold = 10;
353  bool balanced = false;
354  ParamGrid CGrid = SVM::getDefaultGrid(SVM::C),
355  gammaGrid = SVM::getDefaultGrid(SVM::GAMMA),
356  pGrid = SVM::getDefaultGrid(SVM::P),
357  nuGrid = SVM::getDefaultGrid(SVM::NU),
358  coeffGrid = SVM::getDefaultGrid(SVM::COEF),
359  degreeGrid = SVM::getDefaultGrid(SVM::DEGREE);
360  for (int i=4; i<nrhs; i+=2) {
361  string key(rhs[i].toString());
362  if (key == "Data")
363  dataOptions = rhs[i+1].toVector<MxArray>();
364  else if (key == "KFold")
365  kFold = rhs[i+1].toInt();
366  else if (key == "Balanced")
367  balanced = rhs[i+1].toBool();
368  else if (key == "CGrid")
369  CGrid = toParamGrid(rhs[i+1]);
370  else if (key == "GammaGrid")
371  gammaGrid = toParamGrid(rhs[i+1]);
372  else if (key == "PGrid")
373  pGrid = toParamGrid(rhs[i+1]);
374  else if (key == "NuGrid")
375  nuGrid = toParamGrid(rhs[i+1]);
376  else if (key == "CoeffGrid")
377  coeffGrid = toParamGrid(rhs[i+1]);
378  else if (key == "DegreeGrid")
379  degreeGrid = toParamGrid(rhs[i+1]);
380  else
381  mexErrMsgIdAndTxt("mexopencv:error",
382  "Unrecognized option %s", key.c_str());
383  }
384  Ptr<TrainData> data;
385  if (rhs[2].isChar())
386  data = loadTrainData(rhs[2].toString(),
387  dataOptions.begin(), dataOptions.end());
388  else
389  data = createTrainData(
390  rhs[2].toMat(CV_32F),
391  rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
392  dataOptions.begin(), dataOptions.end());
393  bool b = obj->trainAuto(data, kFold,
394  CGrid, gammaGrid, pGrid, nuGrid, coeffGrid, degreeGrid, balanced);
395  plhs[0] = MxArray(b);
396  }
397  else if (method == "getDecisionFunction") {
398  nargchk(nrhs==3 && nlhs<=3);
399  int index = rhs[2].toInt();
400  Mat alpha, svidx;
401  double rho = obj->getDecisionFunction(index, alpha, svidx);
402  plhs[0] = MxArray(alpha);
403  if (nlhs > 1)
404  plhs[1] = MxArray(svidx);
405  if (nlhs > 2)
406  plhs[2] = MxArray(rho);
407  }
408  else if (method == "getSupportVectors") {
409  nargchk(nrhs==2 && nlhs<=1);
410  plhs[0] = MxArray(obj->getSupportVectors());
411  }
412  else if (method == "getUncompressedSupportVectors") {
413  nargchk(nrhs==2 && nlhs<=1);
414  plhs[0] = MxArray(obj->getUncompressedSupportVectors());
415  }
416  else if (method == "setCustomKernel") {
417  nargchk(nrhs==3 && nlhs==0);
418  obj->setCustomKernel(MatlabFunction::create(rhs[2].toString()));
419  }
420  else if (method == "get") {
421  nargchk(nrhs==3 && nlhs<=1);
422  string prop(rhs[2].toString());
423  if (prop == "Type")
424  plhs[0] = MxArray(InvSVMType[obj->getType()]);
425  else if (prop == "KernelType")
426  plhs[0] = MxArray(InvSVMKernelType[obj->getKernelType()]);
427  else if (prop == "Degree")
428  plhs[0] = MxArray(obj->getDegree());
429  else if (prop == "Gamma")
430  plhs[0] = MxArray(obj->getGamma());
431  else if (prop == "Coef0")
432  plhs[0] = MxArray(obj->getCoef0());
433  else if (prop == "C")
434  plhs[0] = MxArray(obj->getC());
435  else if (prop == "Nu")
436  plhs[0] = MxArray(obj->getNu());
437  else if (prop == "P")
438  plhs[0] = MxArray(obj->getP());
439  else if (prop == "ClassWeights")
440  plhs[0] = MxArray(obj->getClassWeights());
441  else if (prop == "TermCriteria")
442  plhs[0] = MxArray(obj->getTermCriteria());
443  else
444  mexErrMsgIdAndTxt("mexopencv:error",
445  "Unrecognized property %s", prop.c_str());
446  }
447  else if (method == "set") {
448  nargchk(nrhs==4 && nlhs==0);
449  string prop(rhs[2].toString());
450  if (prop == "Type")
451  obj->setType(SVMType[rhs[3].toString()]);
452  else if (prop == "KernelType")
453  obj->setKernel(SVMKernelType[rhs[3].toString()]);
454  else if (prop == "Degree")
455  obj->setDegree(rhs[3].toDouble());
456  else if (prop == "Gamma")
457  obj->setGamma(rhs[3].toDouble());
458  else if (prop == "Coef0")
459  obj->setCoef0(rhs[3].toDouble());
460  else if (prop == "C")
461  obj->setC(rhs[3].toDouble());
462  else if (prop == "Nu")
463  obj->setNu(rhs[3].toDouble());
464  else if (prop == "P")
465  obj->setP(rhs[3].toDouble());
466  else if (prop == "ClassWeights")
467  obj->setClassWeights(rhs[3].toMat());
468  else if (prop == "TermCriteria")
469  obj->setTermCriteria(rhs[3].toTermCriteria());
470  else
471  mexErrMsgIdAndTxt("mexopencv:error",
472  "Unrecognized property %s", prop.c_str());
473  }
474  else
475  mexErrMsgIdAndTxt("mexopencv:error",
476  "Unrecognized operation %s", method.c_str());
477 }
virtual Mat getSupportVectors() const=0
virtual int getType() const=0
mwSize ndims() const
Number of dimensions.
Definition: MxArray.hpp:550
virtual bool isTrained() const=0
T copy(T... args)
virtual void setKernel(int kernelType)=0
#define CV_Assert(...)
virtual void setTermCriteria(const cv::TermCriteria &val)=0
const ConstMap< string, int > SVMKernelType
Option values for SVM Kernel types.
Definition: SVM_.cpp:39
MatlabFunction(const string &func)
Constructor.
Definition: SVM_.cpp:111
mwSize numel() const
Number of elements in an array.
Definition: MxArray.hpp:546
T at(mwIndex index) const
Template for numeric array element accessor.
Definition: MxArray.hpp:1250
void calc(int vcount, int n, const float *vecs, const float *another, float *results)
Evaluates MATLAB kernel function.
Definition: SVM_.cpp:126
bool isSingle() const
Determine whether array represents data as single-precision, floating-point numbers.
Definition: MxArray.hpp:700
LIBMWMEX_API_EXTERN_C void mexLock(void)
Lock a MEX-function so that it cannot be cleared from memory.
virtual void setDegree(double val)=0
const ConstMap< int, string > InvSVMType
Option values for inverse SVM types.
Definition: SVM_.cpp:31
std::string toString() const
Convert MxArray to std::string.
Definition: MxArray.cpp:517
virtual void setCustomKernel(const Ptr< Kernel > &_kernel)=0
virtual cv::Mat getClassWeights() const=0
virtual void setGamma(double val)=0
STL namespace.
bool isChar() const
Determine whether input is string array.
Definition: MxArray.hpp:614
LIBMMWMATRIX_PUBLISHED_API_EXTERN_C void mxDestroyArray(mxArray *pa)
mxArray destructor
T end(T... args)
Represents custom kernel implemented as a MATLAB function.
Definition: SVM_.cpp:105
virtual bool isOpened() const
virtual cv::TermCriteria getTermCriteria() const=0
virtual void setNu(double val)=0
struct mxArray_tag mxArray
Forward declaration for mxArray.
Definition: matrix.h:259
STL class.
virtual void setCoef0(double val)=0
virtual double getDegree() const=0
static Ptr< MatlabFunction > create(const string &func)
Factory function.
Definition: SVM_.cpp:169
string func
name of MATLAB function to evaluate (custom face detector)
Definition: Facemark_.cpp:21
index
virtual bool train(const Ptr< TrainData > &trainData, int flags=0)
virtual bool trainAuto(const Ptr< TrainData > &data, int kFold=10, ParamGrid Cgrid=getDefaultGrid(C), ParamGrid gammaGrid=getDefaultGrid(GAMMA), ParamGrid pGrid=getDefaultGrid(P), ParamGrid nuGrid=getDefaultGrid(NU), ParamGrid coeffGrid=getDefaultGrid(COEF), ParamGrid degreeGrid=getDefaultGrid(DEGREE), bool balanced=false)=0
cv::Ptr< cv::ml::TrainData > loadTrainData(const std::string &filename, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Read a dataset from a CSV file.
virtual double getDecisionFunction(int i, OutputArray alpha, OutputArray svidx) const=0
virtual void clear()
virtual String releaseAndGetString()
#define CV_32F
uint32_t v
bool isNumeric() const
Determine whether array is numeric.
Definition: MxArray.hpp:695
bool isComplex() const
Determine whether data is complex.
Definition: MxArray.hpp:632
virtual void write(FileStorage &fs) const
virtual void setType(int val)=0
InputOutputArray noArray()
#define UPDATE_FLAG(NUM, TF, BIT)
set or clear a bit in flag depending on bool value
Definition: mexopencv.hpp:174
LIBMWMEX_API_EXTERN_C void mexErrMsgIdAndTxt(const char *identifier, const char *err_msg,...)
Issue formatted error message with corresponding error identifier and return to MATLAB prompt...
LIBMWMEX_API_EXTERN_C void mexUnlock(void)
Unlock a locked MEX-function so that it can be cleared from memory.
virtual void setClassWeights(const cv::Mat &val)=0
bool isField(const std::string &fieldName) const
Determine whether a struct array has a specified field.
Definition: MxArray.hpp:743
mxArray object wrapper for data conversion and manipulation.
Definition: MxArray.hpp:123
void nargchk(bool cond)
Alias for input/output arguments number check.
Definition: mexopencv.hpp:181
Mat getUncompressedSupportVectors() const
bool isStruct() const
Determine whether input is structure array.
Definition: MxArray.hpp:708
virtual double getNu() const=0
virtual double getGamma() const=0
std::vector< T > toVector() const
Convert MxArray to std::vector<T> of primitive types.
Definition: MxArray.hpp:1151
bool empty() const
map< int, Ptr< SVM > > obj_
Object container.
Definition: SVM_.cpp:20
#define CV_32S
virtual String getDefaultName() const
virtual void setP(double val)=0
Global constant definitions.
T begin(T... args)
virtual double getC() const=0
int last_id
Last object id to allocate.
Definition: SVM_.cpp:18
virtual float predict(InputArray samples, OutputArray results=noArray(), int flags=0) const=0
T c_str(T... args)
virtual bool empty() const
ParamGrid toParamGrid(const MxArray &m)
Obtain ParamGrid object from MxArray.
Definition: SVM_.cpp:71
const ConstMap< string, int > SVMParamType
Option values for SVM params grid types.
Definition: SVM_.cpp:59
virtual int getKernelType() const=0
const ConstMap< string, int > SVMType
Option values for SVM types.
Definition: SVM_.cpp:23
int getType() const
Return type of SVM formulation.
Definition: SVM_.cpp:160
virtual bool isClassifier() const=0
virtual float calcError(const Ptr< TrainData > &data, bool test, OutputArray resp) const
virtual double getP() const=0
T fill(T... args)
virtual void save(const String &filename) const
cv::Ptr< cv::ml::TrainData > createTrainData(const cv::Mat &samples, const cv::Mat &responses, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Create an instance of TrainData using options in arguments.
std::map wrapper with one-line initialization and lookup method.
Definition: MxArray.hpp:927
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Main entry called from Matlab.
Definition: SVM_.cpp:186
LIBMWMEX_API_EXTERN_C int mexCallMATLAB(int nlhs, mxArray *plhs[], int nrhs, mxArray *prhs[], const char *fcn_name)
call MATLAB function
virtual double getCoef0() const=0
void create(int arows, int acols, int atype, Target target=ARRAY_BUFFER, bool autoRelease=false)
Common definitions for the ml module.
virtual int getVarCount() const=0
cv::Mat toMat() const
virtual void setC(double val)=0
const ConstMap< int, string > InvSVMKernelType
Option values for inverse SVM Kernel types.
Definition: SVM_.cpp:49