mexopencv  3.4.1
MEX interface for OpenCV library
RTrees_.cpp
Go to the documentation of this file.
1 
8 #include "mexopencv.hpp"
9 #include "mexopencv_ml.hpp"
10 #include "opencv2/ml.hpp"
11 using namespace std;
12 using namespace cv;
13 using namespace cv::ml;
14 
15 // Persistent objects
16 namespace {
18 int last_id = 0;
21 }
22 
30 void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
31 {
32  // Check the number of arguments
33  nargchk(nrhs>=2 && nlhs<=2);
34 
35  // Argument vector
36  vector<MxArray> rhs(prhs, prhs+nrhs);
37  int id = rhs[0].toInt();
38  string method(rhs[1].toString());
39 
40  // Constructor is called. Create a new object from argument
41  if (method == "new") {
42  nargchk(nrhs==2 && nlhs<=1);
44  plhs[0] = MxArray(last_id);
45  mexLock();
46  return;
47  }
48 
49  // Big operation switch
50  Ptr<RTrees> obj = obj_[id];
51  if (obj.empty())
52  mexErrMsgIdAndTxt("mexopencv:error", "Object not found id=%d", id);
53  if (method == "delete") {
54  nargchk(nrhs==2 && nlhs==0);
55  obj_.erase(id);
56  mexUnlock();
57  }
58  else if (method == "clear") {
59  nargchk(nrhs==2 && nlhs==0);
60  obj->clear();
61  }
62  else if (method == "load") {
63  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs==0);
64  string objname;
65  bool loadFromString = false;
66  for (int i=3; i<nrhs; i+=2) {
67  string key(rhs[i].toString());
68  if (key == "ObjName")
69  objname = rhs[i+1].toString();
70  else if (key == "FromString")
71  loadFromString = rhs[i+1].toBool();
72  else
73  mexErrMsgIdAndTxt("mexopencv:error",
74  "Unrecognized option %s", key.c_str());
75  }
76  //obj_[id] = RTrees::load(rhs[2].toString());
77  obj_[id] = (loadFromString ?
78  Algorithm::loadFromString<RTrees>(rhs[2].toString(), objname) :
79  Algorithm::load<RTrees>(rhs[2].toString(), objname));
80  }
81  else if (method == "save") {
82  nargchk(nrhs==3 && nlhs<=1);
83  string fname(rhs[2].toString());
84  if (nlhs > 0) {
85  // write to memory, and return string
86  FileStorage fs(fname, FileStorage::WRITE + FileStorage::MEMORY);
87  if (!fs.isOpened())
88  mexErrMsgIdAndTxt("mexopencv:error", "Failed to open file");
89  fs << obj->getDefaultName() << "{";
90  obj->write(fs);
91  fs << "}";
92  plhs[0] = MxArray(fs.releaseAndGetString());
93  }
94  else
95  // write to disk
96  obj->save(fname);
97  }
98  else if (method == "empty") {
99  nargchk(nrhs==2 && nlhs<=1);
100  plhs[0] = MxArray(obj->empty());
101  }
102  else if (method == "getDefaultName") {
103  nargchk(nrhs==2 && nlhs<=1);
104  plhs[0] = MxArray(obj->getDefaultName());
105  }
106  else if (method == "getVarCount") {
107  nargchk(nrhs==2 && nlhs<=1);
108  plhs[0] = MxArray(obj->getVarCount());
109  }
110  else if (method == "isClassifier") {
111  nargchk(nrhs==2 && nlhs<=1);
112  plhs[0] = MxArray(obj->isClassifier());
113  }
114  else if (method == "isTrained") {
115  nargchk(nrhs==2 && nlhs<=1);
116  plhs[0] = MxArray(obj->isTrained());
117  }
118  else if (method == "train") {
119  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=1);
120  vector<MxArray> dataOptions;
121  int flags = 0;
122  for (int i=4; i<nrhs; i+=2) {
123  string key(rhs[i].toString());
124  if (key == "Data")
125  dataOptions = rhs[i+1].toVector<MxArray>();
126  else if (key == "Flags")
127  flags = rhs[i+1].toInt();
128  else if (key == "RawOutput")
129  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::RAW_OUTPUT);
130  else if (key == "PredictSum")
131  UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_SUM);
132  else if (key == "PredictMaxVote")
133  UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_MAX_VOTE);
134  else
135  mexErrMsgIdAndTxt("mexopencv:error",
136  "Unrecognized option %s", key.c_str());
137  }
138  Ptr<TrainData> data;
139  if (rhs[2].isChar())
140  data = loadTrainData(rhs[2].toString(),
141  dataOptions.begin(), dataOptions.end());
142  else
143  data = createTrainData(
144  rhs[2].toMat(CV_32F),
145  rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
146  dataOptions.begin(), dataOptions.end());
147  bool b = obj->train(data, flags);
148  plhs[0] = MxArray(b);
149  }
150  else if (method == "calcError") {
151  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=2);
152  vector<MxArray> dataOptions;
153  bool test = false;
154  for (int i=4; i<nrhs; i+=2) {
155  string key(rhs[i].toString());
156  if (key == "Data")
157  dataOptions = rhs[i+1].toVector<MxArray>();
158  else if (key == "TestError")
159  test = rhs[i+1].toBool();
160  else
161  mexErrMsgIdAndTxt("mexopencv:error",
162  "Unrecognized option %s", key.c_str());
163  }
164  Ptr<TrainData> data;
165  if (rhs[2].isChar())
166  data = loadTrainData(rhs[2].toString(),
167  dataOptions.begin(), dataOptions.end());
168  else
169  data = createTrainData(
170  rhs[2].toMat(CV_32F),
171  rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
172  dataOptions.begin(), dataOptions.end());
173  Mat resp;
174  float err = obj->calcError(data, test, (nlhs>1 ? resp : noArray()));
175  plhs[0] = MxArray(err);
176  if (nlhs>1)
177  plhs[1] = MxArray(resp);
178  }
179  else if (method == "predict") {
180  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=2);
181  int flags = 0;
182  for (int i=3; i<nrhs; i+=2) {
183  string key(rhs[i].toString());
184  if (key == "Flags")
185  flags = rhs[i+1].toInt();
186  else if (key == "RawOutput")
187  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::RAW_OUTPUT);
188  else if (key == "CompressedInput")
189  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::COMPRESSED_INPUT);
190  else if (key == "PreprocessedInput")
191  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::PREPROCESSED_INPUT);
192  else if (key == "PredictAuto") {
193  //UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_AUTO);
194  UPDATE_FLAG(flags, !rhs[i+1].toBool(), DTrees::PREDICT_SUM);
195  UPDATE_FLAG(flags, !rhs[i+1].toBool(), DTrees::PREDICT_MAX_VOTE);
196  }
197  else if (key == "PredictSum")
198  UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_SUM);
199  else if (key == "PredictMaxVote")
200  UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_MAX_VOTE);
201  else
202  mexErrMsgIdAndTxt("mexopencv:error",
203  "Unrecognized option %s", key.c_str());
204  }
205  Mat samples(rhs[2].toMat(CV_32F)),
206  results;
207  float f = obj->predict(samples, results, flags);
208  plhs[0] = MxArray(results);
209  if (nlhs>1)
210  plhs[1] = MxArray(f);
211  }
212  else if (method == "getNodes") {
213  nargchk(nrhs==2 && nlhs<=1);
214  plhs[0] = toStruct(obj->getNodes());
215  }
216  else if (method == "getRoots") {
217  nargchk(nrhs==2 && nlhs<=1);
218  plhs[0] = MxArray(obj->getRoots());
219  }
220  else if (method == "getSplits") {
221  nargchk(nrhs==2 && nlhs<=1);
222  plhs[0] = toStruct(obj->getSplits());
223  }
224  else if (method == "getSubsets") {
225  nargchk(nrhs==2 && nlhs<=1);
226  plhs[0] = MxArray(obj->getSubsets());
227  }
228  else if (method == "getVarImportance") {
229  nargchk(nrhs==2 && nlhs<=1);
230  plhs[0] = MxArray(obj->getVarImportance());
231  }
232  else if (method == "getVotes") {
233  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=1);
234  int flags = 0;
235  for (int i=3; i<nrhs; i+=2) {
236  string key(rhs[i].toString());
237  if (key == "Flags")
238  flags = rhs[i+1].toInt();
239  else if (key == "RawOutput")
240  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::RAW_OUTPUT);
241  else if (key == "CompressedInput")
242  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::COMPRESSED_INPUT);
243  else if (key == "PreprocessedInput")
244  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::PREPROCESSED_INPUT);
245  else if (key == "PredictAuto") {
246  //UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_AUTO);
247  UPDATE_FLAG(flags, !rhs[i+1].toBool(), DTrees::PREDICT_SUM);
248  UPDATE_FLAG(flags, !rhs[i+1].toBool(), DTrees::PREDICT_MAX_VOTE);
249  }
250  else if (key == "PredictSum")
251  UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_SUM);
252  else if (key == "PredictMaxVote")
253  UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_MAX_VOTE);
254  else
255  mexErrMsgIdAndTxt("mexopencv:error",
256  "Unrecognized option %s", key.c_str());
257  }
258  Mat samples(rhs[2].toMat(CV_32F)),
259  results;
260  obj->getVotes(samples, results, flags);
261  plhs[0] = MxArray(results);
262  }
263  else if (method == "get") {
264  nargchk(nrhs==3 && nlhs<=1);
265  string prop(rhs[2].toString());
266  if (prop == "CVFolds")
267  plhs[0] = MxArray(obj->getCVFolds());
268  else if (prop == "MaxCategories")
269  plhs[0] = MxArray(obj->getMaxCategories());
270  else if (prop == "MaxDepth")
271  plhs[0] = MxArray(obj->getMaxDepth());
272  else if (prop == "MinSampleCount")
273  plhs[0] = MxArray(obj->getMinSampleCount());
274  else if (prop == "Priors")
275  plhs[0] = MxArray(obj->getPriors());
276  else if (prop == "RegressionAccuracy")
277  plhs[0] = MxArray(obj->getRegressionAccuracy());
278  else if (prop == "TruncatePrunedTree")
279  plhs[0] = MxArray(obj->getTruncatePrunedTree());
280  else if (prop == "Use1SERule")
281  plhs[0] = MxArray(obj->getUse1SERule());
282  else if (prop == "UseSurrogates")
283  plhs[0] = MxArray(obj->getUseSurrogates());
284  else if (prop == "ActiveVarCount")
285  plhs[0] = MxArray(obj->getActiveVarCount());
286  else if (prop == "CalculateVarImportance")
287  plhs[0] = MxArray(obj->getCalculateVarImportance());
288  else if (prop == "TermCriteria")
289  plhs[0] = MxArray(obj->getTermCriteria());
290  else
291  mexErrMsgIdAndTxt("mexopencv:error",
292  "Unrecognized property %s", prop.c_str());
293  }
294  else if (method == "set") {
295  nargchk(nrhs==4 && nlhs==0);
296  string prop(rhs[2].toString());
297  if (prop == "CVFolds")
298  obj->setCVFolds(rhs[3].toInt());
299  else if (prop == "MaxCategories")
300  obj->setMaxCategories(rhs[3].toInt());
301  else if (prop == "MaxDepth")
302  obj->setMaxDepth(rhs[3].toInt());
303  else if (prop == "MinSampleCount")
304  obj->setMinSampleCount(rhs[3].toInt());
305  else if (prop == "Priors")
306  obj->setPriors(rhs[3].toMat());
307  else if (prop == "RegressionAccuracy")
308  obj->setRegressionAccuracy(rhs[3].toFloat());
309  else if (prop == "TruncatePrunedTree")
310  obj->setTruncatePrunedTree(rhs[3].toBool());
311  else if (prop == "Use1SERule")
312  obj->setUse1SERule(rhs[3].toBool());
313  else if (prop == "UseSurrogates")
314  obj->setUseSurrogates(rhs[3].toBool());
315  else if (prop == "ActiveVarCount")
316  obj->setActiveVarCount(rhs[3].toInt());
317  else if (prop == "CalculateVarImportance")
318  obj->setCalculateVarImportance(rhs[3].toBool());
319  else if (prop == "TermCriteria")
320  obj->setTermCriteria(rhs[3].toTermCriteria());
321  else
322  mexErrMsgIdAndTxt("mexopencv:error",
323  "Unrecognized property %s", prop.c_str());
324  }
325  else
326  mexErrMsgIdAndTxt("mexopencv:error",
327  "Unrecognized operation %s", method.c_str());
328 }
virtual int getMaxDepth() const=0
virtual bool getUse1SERule() const=0
virtual bool isTrained() const=0
virtual void setActiveVarCount(int val)=0
virtual void setMinSampleCount(int val)=0
map< int, Ptr< RTrees > > obj_
Object container.
Definition: RTrees_.cpp:20
LIBMWMEX_API_EXTERN_C void mexLock(void)
Lock a MEX-function so that it cannot be cleared from memory.
virtual void setUse1SERule(bool val)=0
virtual void setRegressionAccuracy(float val)=0
virtual void setUseSurrogates(bool val)=0
STL namespace.
virtual bool getCalculateVarImportance() const=0
virtual float getRegressionAccuracy() const=0
T end(T... args)
virtual bool isOpened() const
struct mxArray_tag mxArray
Forward declaration for mxArray.
Definition: matrix.h:259
STL class.
virtual bool getUseSurrogates() const=0
MxArray toStruct(const std::vector< cv::ml::DTrees::Node > &nodes)
Convert tree nodes to struct array.
virtual bool train(const Ptr< TrainData > &trainData, int flags=0)
cv::Ptr< cv::ml::TrainData > loadTrainData(const std::string &filename, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Read a dataset from a CSV file.
virtual const std::vector< Split > & getSplits() const=0
virtual void clear()
virtual String releaseAndGetString()
#define CV_32F
virtual void setCalculateVarImportance(bool val)=0
virtual TermCriteria getTermCriteria() const=0
virtual void setTermCriteria(const TermCriteria &val)=0
virtual void write(FileStorage &fs) const
virtual Mat getVarImportance() const=0
InputOutputArray noArray()
#define UPDATE_FLAG(NUM, TF, BIT)
set or clear a bit in flag depending on bool value
Definition: mexopencv.hpp:174
LIBMWMEX_API_EXTERN_C void mexErrMsgIdAndTxt(const char *identifier, const char *err_msg,...)
Issue formatted error message with corresponding error identifier and return to MATLAB prompt...
virtual const std::vector< Node > & getNodes() const=0
LIBMWMEX_API_EXTERN_C void mexUnlock(void)
Unlock a locked MEX-function so that it can be cleared from memory.
mxArray object wrapper for data conversion and manipulation.
Definition: MxArray.hpp:123
void nargchk(bool cond)
Alias for input/output arguments number check.
Definition: mexopencv.hpp:181
virtual int getMaxCategories() const=0
virtual bool getTruncatePrunedTree() const=0
STL class.
bool empty() const
virtual cv::Mat getPriors() const=0
void getVotes(InputArray samples, OutputArray results, int flags) const
virtual const std::vector< int > & getSubsets() const=0
#define CV_32S
virtual int getMinSampleCount() const=0
virtual String getDefaultName() const
Global constant definitions.
T begin(T... args)
virtual float predict(InputArray samples, OutputArray results=noArray(), int flags=0) const=0
T c_str(T... args)
virtual int getCVFolds() const=0
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Main entry called from Matlab.
Definition: RTrees_.cpp:30
virtual bool empty() const
virtual void setTruncatePrunedTree(bool val)=0
virtual int getActiveVarCount() const=0
virtual bool isClassifier() const=0
virtual float calcError(const Ptr< TrainData > &data, bool test, OutputArray resp) const
virtual void save(const String &filename) const
cv::Ptr< cv::ml::TrainData > createTrainData(const cv::Mat &samples, const cv::Mat &responses, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Create an instance of TrainData using options in arguments.
virtual void setMaxDepth(int val)=0
virtual const std::vector< int > & getRoots() const=0
virtual void setPriors(const cv::Mat &val)=0
void create(int arows, int acols, int atype, Target target=ARRAY_BUFFER, bool autoRelease=false)
Common definitions for the ml module.
virtual void setCVFolds(int val)=0
virtual int getVarCount() const=0
cv::Mat toMat() const
virtual void setMaxCategories(int val)=0
int last_id
Last object id to allocate.
Definition: RTrees_.cpp:18