37 int id = rhs[0].toInt();
38 string method(rhs[1].toString());
41 if (method ==
"new") {
53 if (method ==
"delete") {
58 else if (method ==
"clear") {
62 else if (method ==
"load") {
63 nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs==0);
65 bool loadFromString =
false;
66 for (
int i=3; i<nrhs; i+=2) {
67 string key(rhs[i].toString());
69 objname = rhs[i+1].toString();
70 else if (key ==
"FromString")
71 loadFromString = rhs[i+1].toBool();
74 "Unrecognized option %s", key.
c_str());
77 obj_[id] = (loadFromString ?
78 Algorithm::loadFromString<RTrees>(rhs[2].toString(), objname) :
79 Algorithm::load<RTrees>(rhs[2].toString(), objname));
81 else if (method ==
"save") {
83 string fname(rhs[2].toString());
86 FileStorage fs(fname, FileStorage::WRITE + FileStorage::MEMORY);
98 else if (method ==
"empty") {
102 else if (method ==
"getDefaultName") {
106 else if (method ==
"getVarCount") {
110 else if (method ==
"isClassifier") {
114 else if (method ==
"isTrained") {
118 else if (method ==
"train") {
119 nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=1);
122 for (
int i=4; i<nrhs; i+=2) {
123 string key(rhs[i].toString());
125 dataOptions = rhs[i+1].toVector<
MxArray>();
126 else if (key ==
"Flags")
127 flags = rhs[i+1].toInt();
128 else if (key ==
"RawOutput")
129 UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::RAW_OUTPUT);
130 else if (key ==
"PredictSum")
131 UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_SUM);
132 else if (key ==
"PredictMaxVote")
133 UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_MAX_VOTE);
136 "Unrecognized option %s", key.
c_str());
141 dataOptions.
begin(), dataOptions.
end());
146 dataOptions.
begin(), dataOptions.
end());
147 bool b = obj->
train(data, flags);
150 else if (method ==
"calcError") {
151 nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=2);
154 for (
int i=4; i<nrhs; i+=2) {
155 string key(rhs[i].toString());
157 dataOptions = rhs[i+1].toVector<
MxArray>();
158 else if (key ==
"TestError")
159 test = rhs[i+1].toBool();
162 "Unrecognized option %s", key.
c_str());
167 dataOptions.
begin(), dataOptions.
end());
172 dataOptions.
begin(), dataOptions.
end());
179 else if (method ==
"predict") {
180 nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=2);
182 for (
int i=3; i<nrhs; i+=2) {
183 string key(rhs[i].toString());
185 flags = rhs[i+1].toInt();
186 else if (key ==
"RawOutput")
187 UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::RAW_OUTPUT);
188 else if (key ==
"CompressedInput")
189 UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::COMPRESSED_INPUT);
190 else if (key ==
"PreprocessedInput")
191 UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::PREPROCESSED_INPUT);
192 else if (key ==
"PredictAuto") {
194 UPDATE_FLAG(flags, !rhs[i+1].toBool(), DTrees::PREDICT_SUM);
195 UPDATE_FLAG(flags, !rhs[i+1].toBool(), DTrees::PREDICT_MAX_VOTE);
197 else if (key ==
"PredictSum")
198 UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_SUM);
199 else if (key ==
"PredictMaxVote")
200 UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_MAX_VOTE);
203 "Unrecognized option %s", key.
c_str());
207 float f = obj->
predict(samples, results, flags);
212 else if (method ==
"getNodes") {
216 else if (method ==
"getRoots") {
220 else if (method ==
"getSplits") {
224 else if (method ==
"getSubsets") {
228 else if (method ==
"getVarImportance") {
232 else if (method ==
"getVotes") {
233 nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=1);
235 for (
int i=3; i<nrhs; i+=2) {
236 string key(rhs[i].toString());
238 flags = rhs[i+1].toInt();
239 else if (key ==
"RawOutput")
240 UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::RAW_OUTPUT);
241 else if (key ==
"CompressedInput")
242 UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::COMPRESSED_INPUT);
243 else if (key ==
"PreprocessedInput")
244 UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::PREPROCESSED_INPUT);
245 else if (key ==
"PredictAuto") {
247 UPDATE_FLAG(flags, !rhs[i+1].toBool(), DTrees::PREDICT_SUM);
248 UPDATE_FLAG(flags, !rhs[i+1].toBool(), DTrees::PREDICT_MAX_VOTE);
250 else if (key ==
"PredictSum")
251 UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_SUM);
252 else if (key ==
"PredictMaxVote")
253 UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_MAX_VOTE);
256 "Unrecognized option %s", key.
c_str());
260 obj->
getVotes(samples, results, flags);
263 else if (method ==
"get") {
265 string prop(rhs[2].toString());
266 if (prop ==
"CVFolds")
268 else if (prop ==
"MaxCategories")
270 else if (prop ==
"MaxDepth")
272 else if (prop ==
"MinSampleCount")
274 else if (prop ==
"Priors")
276 else if (prop ==
"RegressionAccuracy")
278 else if (prop ==
"TruncatePrunedTree")
280 else if (prop ==
"Use1SERule")
282 else if (prop ==
"UseSurrogates")
284 else if (prop ==
"ActiveVarCount")
286 else if (prop ==
"CalculateVarImportance")
288 else if (prop ==
"TermCriteria")
292 "Unrecognized property %s", prop.
c_str());
294 else if (method ==
"set") {
296 string prop(rhs[2].toString());
297 if (prop ==
"CVFolds")
299 else if (prop ==
"MaxCategories")
301 else if (prop ==
"MaxDepth")
303 else if (prop ==
"MinSampleCount")
305 else if (prop ==
"Priors")
307 else if (prop ==
"RegressionAccuracy")
309 else if (prop ==
"TruncatePrunedTree")
311 else if (prop ==
"Use1SERule")
313 else if (prop ==
"UseSurrogates")
315 else if (prop ==
"ActiveVarCount")
317 else if (prop ==
"CalculateVarImportance")
319 else if (prop ==
"TermCriteria")
323 "Unrecognized property %s", prop.
c_str());
327 "Unrecognized operation %s", method.
c_str());
virtual int getMaxDepth() const=0
virtual bool getUse1SERule() const=0
virtual bool isTrained() const=0
virtual void setActiveVarCount(int val)=0
virtual void setMinSampleCount(int val)=0
map< int, Ptr< RTrees > > obj_
Object container.
LIBMWMEX_API_EXTERN_C void mexLock(void)
Lock a MEX-function so that it cannot be cleared from memory.
virtual void setUse1SERule(bool val)=0
virtual void setRegressionAccuracy(float val)=0
virtual void setUseSurrogates(bool val)=0
virtual bool getCalculateVarImportance() const=0
virtual float getRegressionAccuracy() const=0
virtual bool isOpened() const
struct mxArray_tag mxArray
Forward declaration for mxArray.
virtual bool getUseSurrogates() const=0
MxArray toStruct(const std::vector< cv::ml::DTrees::Node > &nodes)
Convert tree nodes to struct array.
virtual bool train(const Ptr< TrainData > &trainData, int flags=0)
cv::Ptr< cv::ml::TrainData > loadTrainData(const std::string &filename, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Read a dataset from a CSV file.
virtual const std::vector< Split > & getSplits() const=0
virtual String releaseAndGetString()
virtual void setCalculateVarImportance(bool val)=0
virtual TermCriteria getTermCriteria() const=0
virtual void setTermCriteria(const TermCriteria &val)=0
virtual void write(FileStorage &fs) const
virtual Mat getVarImportance() const=0
InputOutputArray noArray()
#define UPDATE_FLAG(NUM, TF, BIT)
set or clear a bit in flag depending on bool value
LIBMWMEX_API_EXTERN_C void mexErrMsgIdAndTxt(const char *identifier, const char *err_msg,...)
Issue formatted error message with corresponding error identifier and return to MATLAB prompt...
virtual const std::vector< Node > & getNodes() const=0
LIBMWMEX_API_EXTERN_C void mexUnlock(void)
Unlock a locked MEX-function so that it can be cleared from memory.
mxArray object wrapper for data conversion and manipulation.
void nargchk(bool cond)
Alias for input/output arguments number check.
virtual int getMaxCategories() const=0
virtual bool getTruncatePrunedTree() const=0
virtual cv::Mat getPriors() const=0
void getVotes(InputArray samples, OutputArray results, int flags) const
virtual const std::vector< int > & getSubsets() const=0
virtual int getMinSampleCount() const=0
virtual String getDefaultName() const
Global constant definitions.
virtual float predict(InputArray samples, OutputArray results=noArray(), int flags=0) const=0
virtual int getCVFolds() const=0
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Main entry called from Matlab.
virtual bool empty() const
virtual void setTruncatePrunedTree(bool val)=0
virtual int getActiveVarCount() const=0
virtual bool isClassifier() const=0
virtual float calcError(const Ptr< TrainData > &data, bool test, OutputArray resp) const
virtual void save(const String &filename) const
cv::Ptr< cv::ml::TrainData > createTrainData(const cv::Mat &samples, const cv::Mat &responses, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Create an instance of TrainData using options in arguments.
virtual void setMaxDepth(int val)=0
virtual const std::vector< int > & getRoots() const=0
virtual void setPriors(const cv::Mat &val)=0
void create(int arows, int acols, int atype, Target target=ARRAY_BUFFER, bool autoRelease=false)
Common definitions for the ml module.
virtual void setCVFolds(int val)=0
virtual int getVarCount() const=0
virtual void setMaxCategories(int val)=0
int last_id
Last object id to allocate.