mexopencv  3.4.1
MEX interface for OpenCV library
Boost_.cpp
Go to the documentation of this file.
1 
8 #include "mexopencv.hpp"
9 #include "mexopencv_ml.hpp"
10 #include "opencv2/ml.hpp"
11 using namespace std;
12 using namespace cv;
13 using namespace cv::ml;
14 
15 // Persistent objects
16 namespace {
18 int last_id = 0;
21 
24  ("Discrete", cv::ml::Boost::DISCRETE)
25  ("Real", cv::ml::Boost::REAL)
26  ("Logit", cv::ml::Boost::LOGIT)
27  ("Gentle", cv::ml::Boost::GENTLE);
28 
31  (cv::ml::Boost::DISCRETE, "Discrete")
32  (cv::ml::Boost::REAL, "Real")
33  (cv::ml::Boost::LOGIT, "Logit")
34  (cv::ml::Boost::GENTLE, "Gentle");
35 }
36 
44 void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
45 {
46  // Check the number of arguments
47  nargchk(nrhs>=2 && nlhs<=2);
48 
49  // Argument vector
50  vector<MxArray> rhs(prhs, prhs+nrhs);
51  int id = rhs[0].toInt();
52  string method(rhs[1].toString());
53 
54  // Constructor is called. Create a new object from argument
55  if (method == "new") {
56  nargchk(nrhs==2 && nlhs<=1);
57  obj_[++last_id] = Boost::create();
58  plhs[0] = MxArray(last_id);
59  mexLock();
60  return;
61  }
62 
63  // Big operation switch
64  Ptr<Boost> obj = obj_[id];
65  if (obj.empty())
66  mexErrMsgIdAndTxt("mexopencv:error", "Object not found id=%d", id);
67  if (method == "delete") {
68  nargchk(nrhs==2 && nlhs==0);
69  obj_.erase(id);
70  mexUnlock();
71  }
72  else if (method == "clear") {
73  nargchk(nrhs==2 && nlhs==0);
74  obj->clear();
75  }
76  else if (method == "load") {
77  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs==0);
78  string objname;
79  bool loadFromString = false;
80  for (int i=3; i<nrhs; i+=2) {
81  string key(rhs[i].toString());
82  if (key == "ObjName")
83  objname = rhs[i+1].toString();
84  else if (key == "FromString")
85  loadFromString = rhs[i+1].toBool();
86  else
87  mexErrMsgIdAndTxt("mexopencv:error",
88  "Unrecognized option %s", key.c_str());
89  }
90  //obj_[id] = Boost::load(rhs[2].toString());
91  obj_[id] = (loadFromString ?
92  Algorithm::loadFromString<Boost>(rhs[2].toString(), objname) :
93  Algorithm::load<Boost>(rhs[2].toString(), objname));
94  }
95  else if (method == "save") {
96  nargchk(nrhs==3 && nlhs<=1);
97  string fname(rhs[2].toString());
98  if (nlhs > 0) {
99  // write to memory, and return string
100  FileStorage fs(fname, FileStorage::WRITE + FileStorage::MEMORY);
101  if (!fs.isOpened())
102  mexErrMsgIdAndTxt("mexopencv:error", "Failed to open file");
103  fs << obj->getDefaultName() << "{";
104  obj->write(fs);
105  fs << "}";
106  plhs[0] = MxArray(fs.releaseAndGetString());
107  }
108  else
109  // write to disk
110  obj->save(fname);
111  }
112  else if (method == "empty") {
113  nargchk(nrhs==2 && nlhs<=1);
114  plhs[0] = MxArray(obj->empty());
115  }
116  else if (method == "getDefaultName") {
117  nargchk(nrhs==2 && nlhs<=1);
118  plhs[0] = MxArray(obj->getDefaultName());
119  }
120  else if (method == "getVarCount") {
121  nargchk(nrhs==2 && nlhs<=1);
122  plhs[0] = MxArray(obj->getVarCount());
123  }
124  else if (method == "isClassifier") {
125  nargchk(nrhs==2 && nlhs<=1);
126  plhs[0] = MxArray(obj->isClassifier());
127  }
128  else if (method == "isTrained") {
129  nargchk(nrhs==2 && nlhs<=1);
130  plhs[0] = MxArray(obj->isTrained());
131  }
132  else if (method == "train") {
133  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=1);
134  vector<MxArray> dataOptions;
135  int flags = 0;
136  for (int i=4; i<nrhs; i+=2) {
137  string key(rhs[i].toString());
138  if (key == "Data")
139  dataOptions = rhs[i+1].toVector<MxArray>();
140  else if (key == "Flags")
141  flags = rhs[i+1].toInt();
142  else if (key == "RawOutput")
143  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::RAW_OUTPUT);
144  else if (key == "CompressedInput")
145  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::COMPRESSED_INPUT);
146  else if (key == "PredictSum")
147  UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_SUM);
148  else if (key == "PredictMaxVote")
149  UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_MAX_VOTE);
150  else
151  mexErrMsgIdAndTxt("mexopencv:error",
152  "Unrecognized option %s", key.c_str());
153  }
154  Ptr<TrainData> data;
155  if (rhs[2].isChar())
156  data = loadTrainData(rhs[2].toString(),
157  dataOptions.begin(), dataOptions.end());
158  else
159  data = createTrainData(
160  rhs[2].toMat(CV_32F),
161  rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
162  dataOptions.begin(), dataOptions.end());
163  bool b = obj->train(data, flags);
164  plhs[0] = MxArray(b);
165  }
166  else if (method == "calcError") {
167  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=2);
168  vector<MxArray> dataOptions;
169  bool test = false;
170  for (int i=4; i<nrhs; i+=2) {
171  string key(rhs[i].toString());
172  if (key == "Data")
173  dataOptions = rhs[i+1].toVector<MxArray>();
174  else if (key == "TestError")
175  test = rhs[i+1].toBool();
176  else
177  mexErrMsgIdAndTxt("mexopencv:error",
178  "Unrecognized option %s", key.c_str());
179  }
180  Ptr<TrainData> data;
181  if (rhs[2].isChar())
182  data = loadTrainData(rhs[2].toString(),
183  dataOptions.begin(), dataOptions.end());
184  else
185  data = createTrainData(
186  rhs[2].toMat(CV_32F),
187  rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
188  dataOptions.begin(), dataOptions.end());
189  Mat resp;
190  float err = obj->calcError(data, test, (nlhs>1 ? resp : noArray()));
191  plhs[0] = MxArray(err);
192  if (nlhs>1)
193  plhs[1] = MxArray(resp);
194  }
195  else if (method == "predict") {
196  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=2);
197  int flags = 0;
198  for (int i=3; i<nrhs; i+=2) {
199  string key(rhs[i].toString());
200  if (key == "Flags")
201  flags = rhs[i+1].toInt();
202  else if (key == "RawOutput")
203  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::RAW_OUTPUT);
204  else if (key == "CompressedInput")
205  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::COMPRESSED_INPUT);
206  else if (key == "PreprocessedInput")
207  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::PREPROCESSED_INPUT);
208  else if (key == "PredictAuto") {
209  //UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_AUTO);
210  UPDATE_FLAG(flags, !rhs[i+1].toBool(), DTrees::PREDICT_SUM);
211  UPDATE_FLAG(flags, !rhs[i+1].toBool(), DTrees::PREDICT_MAX_VOTE);
212  }
213  else if (key == "PredictSum")
214  UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_SUM);
215  else if (key == "PredictMaxVote")
216  UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_MAX_VOTE);
217  else
218  mexErrMsgIdAndTxt("mexopencv:error",
219  "Unrecognized option %s", key.c_str());
220  }
221  Mat samples(rhs[2].toMat(CV_32F)),
222  results;
223  float f = obj->predict(samples, results, flags);
224  plhs[0] = MxArray(results);
225  if (nlhs>1)
226  plhs[1] = MxArray(f);
227  }
228  else if (method == "getNodes") {
229  nargchk(nrhs==2 && nlhs<=1);
230  plhs[0] = toStruct(obj->getNodes());
231  }
232  else if (method == "getRoots") {
233  nargchk(nrhs==2 && nlhs<=1);
234  plhs[0] = MxArray(obj->getRoots());
235  }
236  else if (method == "getSplits") {
237  nargchk(nrhs==2 && nlhs<=1);
238  plhs[0] = toStruct(obj->getSplits());
239  }
240  else if (method == "getSubsets") {
241  nargchk(nrhs==2 && nlhs<=1);
242  plhs[0] = MxArray(obj->getSubsets());
243  }
244  else if (method == "get") {
245  nargchk(nrhs==3 && nlhs<=1);
246  string prop(rhs[2].toString());
247  if (prop == "CVFolds")
248  plhs[0] = MxArray(obj->getCVFolds());
249  else if (prop == "MaxCategories")
250  plhs[0] = MxArray(obj->getMaxCategories());
251  else if (prop == "MaxDepth")
252  plhs[0] = MxArray(obj->getMaxDepth());
253  else if (prop == "MinSampleCount")
254  plhs[0] = MxArray(obj->getMinSampleCount());
255  else if (prop == "Priors")
256  plhs[0] = MxArray(obj->getPriors());
257  else if (prop == "RegressionAccuracy")
258  plhs[0] = MxArray(obj->getRegressionAccuracy());
259  else if (prop == "TruncatePrunedTree")
260  plhs[0] = MxArray(obj->getTruncatePrunedTree());
261  else if (prop == "Use1SERule")
262  plhs[0] = MxArray(obj->getUse1SERule());
263  else if (prop == "UseSurrogates")
264  plhs[0] = MxArray(obj->getUseSurrogates());
265  else if (prop == "BoostType")
266  plhs[0] = MxArray(InvBoostType[obj->getBoostType()]);
267  else if (prop == "WeakCount")
268  plhs[0] = MxArray(obj->getWeakCount());
269  else if (prop == "WeightTrimRate")
270  plhs[0] = MxArray(obj->getWeightTrimRate());
271  else
272  mexErrMsgIdAndTxt("mexopencv:error",
273  "Unrecognized property %s", prop.c_str());
274  }
275  else if (method == "set") {
276  nargchk(nrhs==4 && nlhs==0);
277  string prop(rhs[2].toString());
278  if (prop == "CVFolds")
279  obj->setCVFolds(rhs[3].toInt());
280  else if (prop == "MaxCategories")
281  obj->setMaxCategories(rhs[3].toInt());
282  else if (prop == "MaxDepth")
283  obj->setMaxDepth(rhs[3].toInt());
284  else if (prop == "MinSampleCount")
285  obj->setMinSampleCount(rhs[3].toInt());
286  else if (prop == "Priors")
287  obj->setPriors(rhs[3].toMat());
288  else if (prop == "RegressionAccuracy")
289  obj->setRegressionAccuracy(rhs[3].toFloat());
290  else if (prop == "TruncatePrunedTree")
291  obj->setTruncatePrunedTree(rhs[3].toBool());
292  else if (prop == "Use1SERule")
293  obj->setUse1SERule(rhs[3].toBool());
294  else if (prop == "UseSurrogates")
295  obj->setUseSurrogates(rhs[3].toBool());
296  else if (prop == "BoostType")
297  obj->setBoostType(BoostType[rhs[3].toString()]);
298  else if (prop == "WeakCount")
299  obj->setWeakCount(rhs[3].toInt());
300  else if (prop == "WeightTrimRate")
301  obj->setWeightTrimRate(rhs[3].toDouble());
302  else
303  mexErrMsgIdAndTxt("mexopencv:error",
304  "Unrecognized property %s", prop.c_str());
305  }
306  else
307  mexErrMsgIdAndTxt("mexopencv:error",
308  "Unrecognized operation %s", method.c_str());
309 }
virtual int getMaxDepth() const=0
virtual int getBoostType() const=0
virtual bool getUse1SERule() const=0
virtual bool isTrained() const=0
virtual void setMinSampleCount(int val)=0
const ConstMap< int, string > InvBoostType
Option values for Inverse boost types.
Definition: Boost_.cpp:30
LIBMWMEX_API_EXTERN_C void mexLock(void)
Lock a MEX-function so that it cannot be cleared from memory.
virtual void setUse1SERule(bool val)=0
virtual void setRegressionAccuracy(float val)=0
virtual void setUseSurrogates(bool val)=0
STL namespace.
virtual float getRegressionAccuracy() const=0
T end(T... args)
virtual bool isOpened() const
const ConstMap< string, int > BoostType
Option values for Boost types.
Definition: Boost_.cpp:23
struct mxArray_tag mxArray
Forward declaration for mxArray.
Definition: matrix.h:259
STL class.
int last_id
Last object id to allocate.
Definition: Boost_.cpp:18
virtual double getWeightTrimRate() const=0
virtual bool getUseSurrogates() const=0
MxArray toStruct(const std::vector< cv::ml::DTrees::Node > &nodes)
Convert tree nodes to struct array.
virtual bool train(const Ptr< TrainData > &trainData, int flags=0)
cv::Ptr< cv::ml::TrainData > loadTrainData(const std::string &filename, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Read a dataset from a CSV file.
virtual const std::vector< Split > & getSplits() const=0
virtual void clear()
virtual String releaseAndGetString()
#define CV_32F
virtual void setWeightTrimRate(double val)=0
virtual void setBoostType(int val)=0
virtual void write(FileStorage &fs) const
InputOutputArray noArray()
#define UPDATE_FLAG(NUM, TF, BIT)
set or clear a bit in flag depending on bool value
Definition: mexopencv.hpp:174
LIBMWMEX_API_EXTERN_C void mexErrMsgIdAndTxt(const char *identifier, const char *err_msg,...)
Issue formatted error message with corresponding error identifier and return to MATLAB prompt...
virtual const std::vector< Node > & getNodes() const=0
LIBMWMEX_API_EXTERN_C void mexUnlock(void)
Unlock a locked MEX-function so that it can be cleared from memory.
mxArray object wrapper for data conversion and manipulation.
Definition: MxArray.hpp:123
void nargchk(bool cond)
Alias for input/output arguments number check.
Definition: mexopencv.hpp:181
virtual int getMaxCategories() const=0
virtual bool getTruncatePrunedTree() const=0
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Main entry called from Matlab.
Definition: Boost_.cpp:44
virtual int getWeakCount() const=0
STL class.
bool empty() const
virtual cv::Mat getPriors() const=0
virtual const std::vector< int > & getSubsets() const=0
#define CV_32S
virtual int getMinSampleCount() const=0
virtual String getDefaultName() const
Global constant definitions.
T begin(T... args)
virtual float predict(InputArray samples, OutputArray results=noArray(), int flags=0) const=0
T c_str(T... args)
virtual int getCVFolds() const=0
virtual bool empty() const
map< int, Ptr< Boost > > obj_
Object container.
Definition: Boost_.cpp:20
virtual void setTruncatePrunedTree(bool val)=0
virtual bool isClassifier() const=0
virtual float calcError(const Ptr< TrainData > &data, bool test, OutputArray resp) const
virtual void setWeakCount(int val)=0
virtual void save(const String &filename) const
cv::Ptr< cv::ml::TrainData > createTrainData(const cv::Mat &samples, const cv::Mat &responses, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Create an instance of TrainData using options in arguments.
virtual void setMaxDepth(int val)=0
std::map wrapper with one-line initialization and lookup method.
Definition: MxArray.hpp:927
virtual const std::vector< int > & getRoots() const=0
virtual void setPriors(const cv::Mat &val)=0
void create(int arows, int acols, int atype, Target target=ARRAY_BUFFER, bool autoRelease=false)
Common definitions for the ml module.
virtual void setCVFolds(int val)=0
virtual int getVarCount() const=0
cv::Mat toMat() const
virtual void setMaxCategories(int val)=0