mexopencv  3.4.1
MEX interface for OpenCV library
DTrees_.cpp
Go to the documentation of this file.
1 
8 #include "mexopencv.hpp"
9 #include "mexopencv_ml.hpp"
10 #include "opencv2/ml.hpp"
11 using namespace std;
12 using namespace cv;
13 using namespace cv::ml;
14 
15 // Persistent objects
16 namespace {
18 int last_id = 0;
21 }
22 
30 void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
31 {
32  // Check the number of arguments
33  nargchk(nrhs>=2 && nlhs<=2);
34 
35  // Argument vector
36  vector<MxArray> rhs(prhs, prhs+nrhs);
37  int id = rhs[0].toInt();
38  string method(rhs[1].toString());
39 
40  // Constructor is called. Create a new object from argument
41  if (method == "new") {
42  nargchk(nrhs==2 && nlhs<=1);
44  plhs[0] = MxArray(last_id);
45  mexLock();
46  return;
47  }
48 
49  // Big operation switch
50  Ptr<DTrees> obj = obj_[id];
51  if (obj.empty())
52  mexErrMsgIdAndTxt("mexopencv:error", "Object not found id=%d", id);
53  if (method == "delete") {
54  nargchk(nrhs==2 && nlhs==0);
55  obj_.erase(id);
56  mexUnlock();
57  }
58  else if (method == "clear") {
59  nargchk(nrhs==2 && nlhs==0);
60  obj->clear();
61  }
62  else if (method == "load") {
63  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs==0);
64  string objname;
65  bool loadFromString = false;
66  for (int i=3; i<nrhs; i+=2) {
67  string key(rhs[i].toString());
68  if (key == "ObjName")
69  objname = rhs[i+1].toString();
70  else if (key == "FromString")
71  loadFromString = rhs[i+1].toBool();
72  else
73  mexErrMsgIdAndTxt("mexopencv:error",
74  "Unrecognized option %s", key.c_str());
75  }
76  //obj_[id] = DTrees::load(rhs[2].toString());
77  obj_[id] = (loadFromString ?
78  Algorithm::loadFromString<DTrees>(rhs[2].toString(), objname) :
79  Algorithm::load<DTrees>(rhs[2].toString(), objname));
80  }
81  else if (method == "save") {
82  nargchk(nrhs==3 && nlhs<=1);
83  string fname(rhs[2].toString());
84  if (nlhs > 0) {
85  // write to memory, and return string
86  FileStorage fs(fname, FileStorage::WRITE + FileStorage::MEMORY);
87  if (!fs.isOpened())
88  mexErrMsgIdAndTxt("mexopencv:error", "Failed to open file");
89  fs << obj->getDefaultName() << "{";
90  obj->write(fs);
91  fs << "}";
92  plhs[0] = MxArray(fs.releaseAndGetString());
93  }
94  else
95  // write to disk
96  obj->save(fname);
97  }
98  else if (method == "empty") {
99  nargchk(nrhs==2 && nlhs<=1);
100  plhs[0] = MxArray(obj->empty());
101  }
102  else if (method == "getDefaultName") {
103  nargchk(nrhs==2 && nlhs<=1);
104  plhs[0] = MxArray(obj->getDefaultName());
105  }
106  else if (method == "getVarCount") {
107  nargchk(nrhs==2 && nlhs<=1);
108  plhs[0] = MxArray(obj->getVarCount());
109  }
110  else if (method == "isClassifier") {
111  nargchk(nrhs==2 && nlhs<=1);
112  plhs[0] = MxArray(obj->isClassifier());
113  }
114  else if (method == "isTrained") {
115  nargchk(nrhs==2 && nlhs<=1);
116  plhs[0] = MxArray(obj->isTrained());
117  }
118  else if (method == "train") {
119  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=1);
120  vector<MxArray> dataOptions;
121  int flags = 0;
122  for (int i=4; i<nrhs; i+=2) {
123  string key(rhs[i].toString());
124  if (key == "Data")
125  dataOptions = rhs[i+1].toVector<MxArray>();
126  else if (key == "Flags")
127  flags = rhs[i+1].toInt();
128  else
129  mexErrMsgIdAndTxt("mexopencv:error",
130  "Unrecognized option %s", key.c_str());
131  }
132  Ptr<TrainData> data;
133  if (rhs[2].isChar())
134  data = loadTrainData(rhs[2].toString(),
135  dataOptions.begin(), dataOptions.end());
136  else
137  data = createTrainData(
138  rhs[2].toMat(CV_32F),
139  rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
140  dataOptions.begin(), dataOptions.end());
141  bool b = obj->train(data, flags);
142  plhs[0] = MxArray(b);
143  }
144  else if (method == "calcError") {
145  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=2);
146  vector<MxArray> dataOptions;
147  bool test = false;
148  for (int i=4; i<nrhs; i+=2) {
149  string key(rhs[i].toString());
150  if (key == "Data")
151  dataOptions = rhs[i+1].toVector<MxArray>();
152  else if (key == "TestError")
153  test = rhs[i+1].toBool();
154  else
155  mexErrMsgIdAndTxt("mexopencv:error",
156  "Unrecognized option %s", key.c_str());
157  }
158  Ptr<TrainData> data;
159  if (rhs[2].isChar())
160  data = loadTrainData(rhs[2].toString(),
161  dataOptions.begin(), dataOptions.end());
162  else
163  data = createTrainData(
164  rhs[2].toMat(CV_32F),
165  rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
166  dataOptions.begin(), dataOptions.end());
167  Mat resp;
168  float err = obj->calcError(data, test, (nlhs>1 ? resp : noArray()));
169  plhs[0] = MxArray(err);
170  if (nlhs>1)
171  plhs[1] = MxArray(resp);
172  }
173  else if (method == "predict") {
174  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=2);
175  int flags = 0;
176  for (int i=3; i<nrhs; i+=2) {
177  string key(rhs[i].toString());
178  if (key == "Flags")
179  flags = rhs[i+1].toInt();
180  else if (key == "RawOutput")
181  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::RAW_OUTPUT);
182  else if (key == "CompressedInput")
183  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::COMPRESSED_INPUT);
184  else if (key == "PreprocessedInput")
185  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::PREPROCESSED_INPUT);
186  else if (key == "PredictAuto") {
187  //UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_AUTO);
188  UPDATE_FLAG(flags, !rhs[i+1].toBool(), DTrees::PREDICT_SUM);
189  UPDATE_FLAG(flags, !rhs[i+1].toBool(), DTrees::PREDICT_MAX_VOTE);
190  }
191  else if (key == "PredictSum")
192  UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_SUM);
193  else if (key == "PredictMaxVote")
194  UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_MAX_VOTE);
195  else
196  mexErrMsgIdAndTxt("mexopencv:error",
197  "Unrecognized option %s", key.c_str());
198  }
199  Mat samples(rhs[2].toMat(CV_32F)),
200  results;
201  float f = obj->predict(samples, results, flags);
202  plhs[0] = MxArray(results);
203  if (nlhs>1)
204  plhs[1] = MxArray(f);
205  }
206  else if (method == "getNodes") {
207  nargchk(nrhs==2 && nlhs<=1);
208  plhs[0] = toStruct(obj->getNodes());
209  }
210  else if (method == "getRoots") {
211  nargchk(nrhs==2 && nlhs<=1);
212  plhs[0] = MxArray(obj->getRoots());
213  }
214  else if (method == "getSplits") {
215  nargchk(nrhs==2 && nlhs<=1);
216  plhs[0] = toStruct(obj->getSplits());
217  }
218  else if (method == "getSubsets") {
219  nargchk(nrhs==2 && nlhs<=1);
220  plhs[0] = MxArray(obj->getSubsets());
221  }
222  else if (method == "get") {
223  nargchk(nrhs==3 && nlhs<=1);
224  string prop(rhs[2].toString());
225  if (prop == "CVFolds")
226  plhs[0] = MxArray(obj->getCVFolds());
227  else if (prop == "MaxCategories")
228  plhs[0] = MxArray(obj->getMaxCategories());
229  else if (prop == "MaxDepth")
230  plhs[0] = MxArray(obj->getMaxDepth());
231  else if (prop == "MinSampleCount")
232  plhs[0] = MxArray(obj->getMinSampleCount());
233  else if (prop == "Priors")
234  plhs[0] = MxArray(obj->getPriors());
235  else if (prop == "RegressionAccuracy")
236  plhs[0] = MxArray(obj->getRegressionAccuracy());
237  else if (prop == "TruncatePrunedTree")
238  plhs[0] = MxArray(obj->getTruncatePrunedTree());
239  else if (prop == "Use1SERule")
240  plhs[0] = MxArray(obj->getUse1SERule());
241  else if (prop == "UseSurrogates")
242  plhs[0] = MxArray(obj->getUseSurrogates());
243  else
244  mexErrMsgIdAndTxt("mexopencv:error",
245  "Unrecognized property %s", prop.c_str());
246  }
247  else if (method == "set") {
248  nargchk(nrhs==4 && nlhs==0);
249  string prop(rhs[2].toString());
250  if (prop == "CVFolds")
251  obj->setCVFolds(rhs[3].toInt());
252  else if (prop == "MaxCategories")
253  obj->setMaxCategories(rhs[3].toInt());
254  else if (prop == "MaxDepth")
255  obj->setMaxDepth(rhs[3].toInt());
256  else if (prop == "MinSampleCount")
257  obj->setMinSampleCount(rhs[3].toInt());
258  else if (prop == "Priors")
259  obj->setPriors(rhs[3].toMat());
260  else if (prop == "RegressionAccuracy")
261  obj->setRegressionAccuracy(rhs[3].toFloat());
262  else if (prop == "TruncatePrunedTree")
263  obj->setTruncatePrunedTree(rhs[3].toBool());
264  else if (prop == "Use1SERule")
265  obj->setUse1SERule(rhs[3].toBool());
266  else if (prop == "UseSurrogates")
267  obj->setUseSurrogates(rhs[3].toBool());
268  else
269  mexErrMsgIdAndTxt("mexopencv:error",
270  "Unrecognized property %s", prop.c_str());
271  }
272  else
273  mexErrMsgIdAndTxt("mexopencv:error",
274  "Unrecognized operation %s", method.c_str());
275 }
virtual int getMaxDepth() const=0
virtual bool getUse1SERule() const=0
virtual bool isTrained() const=0
virtual void setMinSampleCount(int val)=0
LIBMWMEX_API_EXTERN_C void mexLock(void)
Lock a MEX-function so that it cannot be cleared from memory.
virtual void setUse1SERule(bool val)=0
virtual void setRegressionAccuracy(float val)=0
virtual void setUseSurrogates(bool val)=0
STL namespace.
virtual float getRegressionAccuracy() const=0
T end(T... args)
virtual bool isOpened() const
struct mxArray_tag mxArray
Forward declaration for mxArray.
Definition: matrix.h:259
STL class.
int last_id
Last object id to allocate.
Definition: DTrees_.cpp:18
virtual bool getUseSurrogates() const=0
MxArray toStruct(const std::vector< cv::ml::DTrees::Node > &nodes)
Convert tree nodes to struct array.
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Main entry called from Matlab.
Definition: DTrees_.cpp:30
virtual bool train(const Ptr< TrainData > &trainData, int flags=0)
cv::Ptr< cv::ml::TrainData > loadTrainData(const std::string &filename, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Read a dataset from a CSV file.
virtual const std::vector< Split > & getSplits() const=0
virtual void clear()
virtual String releaseAndGetString()
#define CV_32F
virtual void write(FileStorage &fs) const
InputOutputArray noArray()
#define UPDATE_FLAG(NUM, TF, BIT)
set or clear a bit in flag depending on bool value
Definition: mexopencv.hpp:174
LIBMWMEX_API_EXTERN_C void mexErrMsgIdAndTxt(const char *identifier, const char *err_msg,...)
Issue formatted error message with corresponding error identifier and return to MATLAB prompt...
virtual const std::vector< Node > & getNodes() const=0
LIBMWMEX_API_EXTERN_C void mexUnlock(void)
Unlock a locked MEX-function so that it can be cleared from memory.
mxArray object wrapper for data conversion and manipulation.
Definition: MxArray.hpp:123
void nargchk(bool cond)
Alias for input/output arguments number check.
Definition: mexopencv.hpp:181
virtual int getMaxCategories() const=0
virtual bool getTruncatePrunedTree() const=0
STL class.
bool empty() const
virtual cv::Mat getPriors() const=0
virtual const std::vector< int > & getSubsets() const=0
#define CV_32S
virtual int getMinSampleCount() const=0
virtual String getDefaultName() const
Global constant definitions.
T begin(T... args)
virtual float predict(InputArray samples, OutputArray results=noArray(), int flags=0) const=0
T c_str(T... args)
virtual int getCVFolds() const=0
virtual bool empty() const
virtual void setTruncatePrunedTree(bool val)=0
virtual bool isClassifier() const=0
virtual float calcError(const Ptr< TrainData > &data, bool test, OutputArray resp) const
virtual void save(const String &filename) const
cv::Ptr< cv::ml::TrainData > createTrainData(const cv::Mat &samples, const cv::Mat &responses, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Create an instance of TrainData using options in arguments.
virtual void setMaxDepth(int val)=0
map< int, Ptr< DTrees > > obj_
Object container.
Definition: DTrees_.cpp:20
virtual const std::vector< int > & getRoots() const=0
virtual void setPriors(const cv::Mat &val)=0
void create(int arows, int acols, int atype, Target target=ARRAY_BUFFER, bool autoRelease=false)
Common definitions for the ml module.
virtual void setCVFolds(int val)=0
virtual int getVarCount() const=0
cv::Mat toMat() const
virtual void setMaxCategories(int val)=0