mexopencv  3.4.1
MEX interface for OpenCV library
SVMSGD_.cpp
Go to the documentation of this file.
1 
8 #include "mexopencv.hpp"
9 #include "mexopencv_ml.hpp"
10 #include "opencv2/ml.hpp"
11 using namespace std;
12 using namespace cv;
13 using namespace cv::ml;
14 
15 // Persistent objects
16 namespace {
18 int last_id = 0;
21 
24  ("SoftMargin", cv::ml::SVMSGD::SOFT_MARGIN)
25  ("HardMargin", cv::ml::SVMSGD::HARD_MARGIN);
26 
29  (cv::ml::SVMSGD::SOFT_MARGIN, "SoftMargin")
30  (cv::ml::SVMSGD::HARD_MARGIN, "HardMargin");
31 
34  ("SGD", cv::ml::SVMSGD::SGD)
35  ("ASGD", cv::ml::SVMSGD::ASGD);
36 
39  (cv::ml::SVMSGD::SGD, "SGD")
40  (cv::ml::SVMSGD::ASGD, "ASGD");
41 }
42 
50 void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
51 {
52  // Check the number of arguments
53  nargchk(nrhs>=2 && nlhs<=2);
54 
55  // Argument vector
56  vector<MxArray> rhs(prhs, prhs+nrhs);
57  int id = rhs[0].toInt();
58  string method(rhs[1].toString());
59 
60  // Constructor is called. Create a new object from argument
61  if (method == "new") {
62  nargchk(nrhs==2 && nlhs<=1);
64  plhs[0] = MxArray(last_id);
65  mexLock();
66  return;
67  }
68 
69  // Big operation switch
70  Ptr<SVMSGD> obj = obj_[id];
71  if (obj.empty())
72  mexErrMsgIdAndTxt("mexopencv:error", "Object not found id=%d", id);
73  if (method == "delete") {
74  nargchk(nrhs==2 && nlhs==0);
75  obj_.erase(id);
76  mexUnlock();
77  }
78  else if (method == "clear") {
79  nargchk(nrhs==2 && nlhs==0);
80  obj->clear();
81  }
82  else if (method == "load") {
83  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs==0);
84  string objname;
85  bool loadFromString = false;
86  for (int i=3; i<nrhs; i+=2) {
87  string key(rhs[i].toString());
88  if (key == "ObjName")
89  objname = rhs[i+1].toString();
90  else if (key == "FromString")
91  loadFromString = rhs[i+1].toBool();
92  else
93  mexErrMsgIdAndTxt("mexopencv:error",
94  "Unrecognized option %s", key.c_str());
95  }
96  //obj_[id] = SVMSGD::load(rhs[2].toString());
97  obj_[id] = (loadFromString ?
98  Algorithm::loadFromString<SVMSGD>(rhs[2].toString(), objname) :
99  Algorithm::load<SVMSGD>(rhs[2].toString(), objname));
100  }
101  else if (method == "save") {
102  nargchk(nrhs==3 && nlhs<=1);
103  string fname(rhs[2].toString());
104  if (nlhs > 0) {
105  // write to memory, and return string
106  FileStorage fs(fname, FileStorage::WRITE + FileStorage::MEMORY);
107  if (!fs.isOpened())
108  mexErrMsgIdAndTxt("mexopencv:error", "Failed to open file");
109  fs << obj->getDefaultName() << "{";
110  obj->write(fs);
111  fs << "}";
112  plhs[0] = MxArray(fs.releaseAndGetString());
113  }
114  else
115  // write to disk
116  obj->save(fname);
117  }
118  else if (method == "empty") {
119  nargchk(nrhs==2 && nlhs<=1);
120  plhs[0] = MxArray(obj->empty());
121  }
122  else if (method == "getDefaultName") {
123  nargchk(nrhs==2 && nlhs<=1);
124  plhs[0] = MxArray(obj->getDefaultName());
125  }
126  else if (method == "getVarCount") {
127  nargchk(nrhs==2 && nlhs<=1);
128  plhs[0] = MxArray(obj->getVarCount());
129  }
130  else if (method == "isClassifier") {
131  nargchk(nrhs==2 && nlhs<=1);
132  plhs[0] = MxArray(obj->isClassifier());
133  }
134  else if (method == "isTrained") {
135  nargchk(nrhs==2 && nlhs<=1);
136  plhs[0] = MxArray(obj->isTrained());
137  }
138  else if (method == "train") {
139  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=1);
140  vector<MxArray> dataOptions;
141  int flags = 0;
142  for (int i=4; i<nrhs; i+=2) {
143  string key(rhs[i].toString());
144  if (key == "Data")
145  dataOptions = rhs[i+1].toVector<MxArray>();
146  else if (key == "Flags")
147  flags = rhs[i+1].toInt();
148  else
149  mexErrMsgIdAndTxt("mexopencv:error",
150  "Unrecognized option %s", key.c_str());
151  }
152  Ptr<TrainData> data;
153  if (rhs[2].isChar())
154  data = loadTrainData(rhs[2].toString(),
155  dataOptions.begin(), dataOptions.end());
156  else
157  data = createTrainData(
158  rhs[2].toMat(CV_32F),
159  rhs[3].toMat(CV_32F),
160  dataOptions.begin(), dataOptions.end());
161  bool b = obj->train(data, flags);
162  plhs[0] = MxArray(b);
163  }
164  else if (method == "calcError") {
165  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=2);
166  vector<MxArray> dataOptions;
167  bool test = false;
168  for (int i=4; i<nrhs; i+=2) {
169  string key(rhs[i].toString());
170  if (key == "Data")
171  dataOptions = rhs[i+1].toVector<MxArray>();
172  else if (key == "TestError")
173  test = rhs[i+1].toBool();
174  else
175  mexErrMsgIdAndTxt("mexopencv:error",
176  "Unrecognized option %s", key.c_str());
177  }
178  Ptr<TrainData> data;
179  if (rhs[2].isChar())
180  data = loadTrainData(rhs[2].toString(),
181  dataOptions.begin(), dataOptions.end());
182  else
183  data = createTrainData(
184  rhs[2].toMat(CV_32F),
185  rhs[3].toMat(CV_32F),
186  dataOptions.begin(), dataOptions.end());
187  Mat resp;
188  float err = obj->calcError(data, test, (nlhs>1 ? resp : noArray()));
189  plhs[0] = MxArray(err);
190  if (nlhs>1)
191  plhs[1] = MxArray(resp);
192  }
193  else if (method == "predict") {
194  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=2);
195  int flags = 0;
196  for (int i=3; i<nrhs; i+=2) {
197  string key(rhs[i].toString());
198  if (key == "Flags")
199  flags = rhs[i+1].toInt();
200  else if (key == "RawOutput")
201  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::RAW_OUTPUT);
202  else
203  mexErrMsgIdAndTxt("mexopencv:error",
204  "Unrecognized option %s", key.c_str());
205  }
206  Mat samples(rhs[2].toMat(CV_32F)),
207  results;
208  float f = obj->predict(samples, results, flags);
209  plhs[0] = MxArray(results);
210  if (nlhs>1)
211  plhs[1] = MxArray(f);
212  }
213  else if (method == "getShift") {
214  nargchk(nrhs==2 && nlhs<=1);
215  plhs[0] = MxArray(obj->getShift());
216  }
217  else if (method == "getWeights") {
218  nargchk(nrhs==2 && nlhs<=1);
219  plhs[0] = MxArray(obj->getWeights());
220  }
221  else if (method == "setOptimalParameters") {
222  nargchk(nrhs>=2 && (nrhs%2)==0 && nlhs==0);
223  int svmsgdType = cv::ml::SVMSGD::ASGD;
224  int marginType = cv::ml::SVMSGD::SOFT_MARGIN;
225  for (int i=2; i<nrhs; i+=2) {
226  string key(rhs[i].toString());
227  if (key == "SvmsgdType")
228  svmsgdType = SvmsgdTypeMap[rhs[i+1].toString()];
229  else if (key == "MarginType")
230  marginType = MarginTypeMap[rhs[i+1].toString()];
231  else
232  mexErrMsgIdAndTxt("mexopencv:error",
233  "Unrecognized option %s", key.c_str());
234  }
235  obj->setOptimalParameters(svmsgdType, marginType);
236  }
237  else if (method == "get") {
238  nargchk(nrhs==3 && nlhs<=1);
239  string prop(rhs[2].toString());
240  if (prop == "InitialStepSize")
241  plhs[0] = MxArray(obj->getInitialStepSize());
242  else if (prop == "MarginRegularization")
243  plhs[0] = MxArray(obj->getMarginRegularization());
244  else if (prop == "MarginType")
245  plhs[0] = MxArray(InvMarginTypeMap[obj->getMarginType()]);
246  else if (prop == "StepDecreasingPower")
247  plhs[0] = MxArray(obj->getStepDecreasingPower());
248  else if (prop == "SvmsgdType")
249  plhs[0] = MxArray(InvSvmsgdTypeMap[obj->getSvmsgdType()]);
250  else if (prop == "TermCriteria")
251  plhs[0] = MxArray(obj->getTermCriteria());
252  else
253  mexErrMsgIdAndTxt("mexopencv:error",
254  "Unrecognized property %s", prop.c_str());
255  }
256  else if (method == "set") {
257  nargchk(nrhs==4 && nlhs==0);
258  string prop(rhs[2].toString());
259  if (prop == "InitialStepSize")
260  obj->setInitialStepSize(rhs[3].toFloat());
261  else if (prop == "MarginRegularization")
262  obj->setMarginRegularization(rhs[3].toFloat());
263  else if (prop == "MarginType")
264  obj->setMarginType(MarginTypeMap[rhs[3].toString()]);
265  else if (prop == "StepDecreasingPower")
266  obj->setStepDecreasingPower(rhs[3].toFloat());
267  else if (prop == "SvmsgdType")
268  obj->setSvmsgdType(SvmsgdTypeMap[rhs[3].toString()]);
269  else if (prop == "TermCriteria")
270  obj->setTermCriteria(rhs[3].toTermCriteria());
271  else
272  mexErrMsgIdAndTxt("mexopencv:error",
273  "Unrecognized property %s", prop.c_str());
274  }
275  else
276  mexErrMsgIdAndTxt("mexopencv:error",
277  "Unrecognized operation %s", method.c_str());
278 }
virtual void setMarginType(int marginType)=0
virtual bool isTrained() const=0
int last_id
Last object id to allocate.
Definition: SVMSGD_.cpp:18
virtual void setSvmsgdType(int svmsgdType)=0
LIBMWMEX_API_EXTERN_C void mexLock(void)
Lock a MEX-function so that it cannot be cleared from memory.
virtual void setInitialStepSize(float InitialStepSize)=0
virtual int getMarginType() const=0
const ConstMap< string, int > SvmsgdTypeMap
Option values for SVMSGD types.
Definition: SVMSGD_.cpp:33
virtual void setStepDecreasingPower(float stepDecreasingPower)=0
STL namespace.
virtual void setOptimalParameters(int svmsgdType=SVMSGD::ASGD, int marginType=SVMSGD::SOFT_MARGIN)=0
T end(T... args)
virtual bool isOpened() const
virtual void setMarginRegularization(float marginRegularization)=0
struct mxArray_tag mxArray
Forward declaration for mxArray.
Definition: matrix.h:259
STL class.
virtual float getMarginRegularization() const=0
virtual bool train(const Ptr< TrainData > &trainData, int flags=0)
cv::Ptr< cv::ml::TrainData > loadTrainData(const std::string &filename, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Read a dataset from a CSV file.
virtual void clear()
virtual String releaseAndGetString()
#define CV_32F
virtual void write(FileStorage &fs) const
map< int, Ptr< SVMSGD > > obj_
Object container.
Definition: SVMSGD_.cpp:20
InputOutputArray noArray()
#define UPDATE_FLAG(NUM, TF, BIT)
set or clear a bit in flag depending on bool value
Definition: mexopencv.hpp:174
LIBMWMEX_API_EXTERN_C void mexErrMsgIdAndTxt(const char *identifier, const char *err_msg,...)
Issue formatted error message with corresponding error identifier and return to MATLAB prompt...
virtual float getShift()=0
LIBMWMEX_API_EXTERN_C void mexUnlock(void)
Unlock a locked MEX-function so that it can be cleared from memory.
virtual float getStepDecreasingPower() const=0
const ConstMap< string, int > MarginTypeMap
Option values for margin types.
Definition: SVMSGD_.cpp:23
mxArray object wrapper for data conversion and manipulation.
Definition: MxArray.hpp:123
void nargchk(bool cond)
Alias for input/output arguments number check.
Definition: mexopencv.hpp:181
virtual Mat getWeights()=0
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Main entry called from Matlab.
Definition: SVMSGD_.cpp:50
STL class.
bool empty() const
virtual String getDefaultName() const
Global constant definitions.
T begin(T... args)
virtual int getSvmsgdType() const=0
virtual float predict(InputArray samples, OutputArray results=noArray(), int flags=0) const=0
const ConstMap< int, string > InvMarginTypeMap
Option values for inverse margin types.
Definition: SVMSGD_.cpp:28
T c_str(T... args)
virtual bool empty() const
virtual float getInitialStepSize() const=0
const ConstMap< int, string > InvSvmsgdTypeMap
Option values for inverse SVMSGD types.
Definition: SVMSGD_.cpp:38
virtual bool isClassifier() const=0
virtual float calcError(const Ptr< TrainData > &data, bool test, OutputArray resp) const
virtual void save(const String &filename) const
virtual TermCriteria getTermCriteria() const=0
cv::Ptr< cv::ml::TrainData > createTrainData(const cv::Mat &samples, const cv::Mat &responses, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Create an instance of TrainData using options in arguments.
std::map wrapper with one-line initialization and lookup method.
Definition: MxArray.hpp:927
virtual void setTermCriteria(const cv::TermCriteria &val)=0
void create(int arows, int acols, int atype, Target target=ARRAY_BUFFER, bool autoRelease=false)
Common definitions for the ml module.
virtual int getVarCount() const=0
cv::Mat toMat() const