mexopencv  3.4.1
MEX interface for OpenCV library
NormalBayesClassifier_.cpp
Go to the documentation of this file.
1 
8 #include "mexopencv.hpp"
9 #include "mexopencv_ml.hpp"
10 #include "opencv2/ml.hpp"
11 using namespace std;
12 using namespace cv;
13 using namespace cv::ml;
14 
15 // Persistent objects
16 namespace {
18 int last_id = 0;
21 }
22 
30 void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
31 {
32  // Check the number of arguments
33  nargchk(nrhs>=2 && nlhs<=3);
34 
35  // Argument vector
36  vector<MxArray> rhs(prhs, prhs+nrhs);
37  int id = rhs[0].toInt();
38  string method(rhs[1].toString());
39 
40  // Constructor is called. Create a new object from argument
41  if (method == "new") {
42  nargchk(nrhs==2 && nlhs<=1);
44  plhs[0] = MxArray(last_id);
45  mexLock();
46  return;
47  }
48 
49  // Big operation switch
51  if (obj.empty())
52  mexErrMsgIdAndTxt("mexopencv:error", "Object not found id=%d", id);
53  if (method == "delete") {
54  nargchk(nrhs==2 && nlhs==0);
55  obj_.erase(id);
56  mexUnlock();
57  }
58  else if (method == "clear") {
59  nargchk(nrhs==2 && nlhs==0);
60  obj->clear();
61  }
62  else if (method == "load") {
63  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs==0);
64  string objname;
65  bool loadFromString = false;
66  for (int i=3; i<nrhs; i+=2) {
67  string key(rhs[i].toString());
68  if (key == "ObjName")
69  objname = rhs[i+1].toString();
70  else if (key == "FromString")
71  loadFromString = rhs[i+1].toBool();
72  else
73  mexErrMsgIdAndTxt("mexopencv:error",
74  "Unrecognized option %s", key.c_str());
75  }
76  //obj_[id] = NormalBayesClassifier::load(rhs[2].toString());
77  obj_[id] = (loadFromString ?
78  Algorithm::loadFromString<NormalBayesClassifier>(rhs[2].toString(), objname) :
79  Algorithm::load<NormalBayesClassifier>(rhs[2].toString(), objname));
80  }
81  else if (method == "save") {
82  nargchk(nrhs==3 && nlhs<=1);
83  string fname(rhs[2].toString());
84  if (nlhs > 0) {
85  // write to memory, and return string
86  FileStorage fs(fname, FileStorage::WRITE + FileStorage::MEMORY);
87  if (!fs.isOpened())
88  mexErrMsgIdAndTxt("mexopencv:error", "Failed to open file");
89  fs << obj->getDefaultName() << "{";
90  obj->write(fs);
91  fs << "}";
92  plhs[0] = MxArray(fs.releaseAndGetString());
93  }
94  else
95  // write to disk
96  obj->save(fname);
97  }
98  else if (method == "empty") {
99  nargchk(nrhs==2 && nlhs<=1);
100  plhs[0] = MxArray(obj->empty());
101  }
102  else if (method == "getDefaultName") {
103  nargchk(nrhs==2 && nlhs<=1);
104  plhs[0] = MxArray(obj->getDefaultName());
105  }
106  else if (method == "getVarCount") {
107  nargchk(nrhs==2 && nlhs<=1);
108  plhs[0] = MxArray(obj->getVarCount());
109  }
110  else if (method == "isClassifier") {
111  nargchk(nrhs==2 && nlhs<=1);
112  plhs[0] = MxArray(obj->isClassifier());
113  }
114  else if (method == "isTrained") {
115  nargchk(nrhs==2 && nlhs<=1);
116  plhs[0] = MxArray(obj->isTrained());
117  }
118  else if (method == "train") {
119  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=1);
120  vector<MxArray> dataOptions;
121  int flags = 0;
122  for (int i=4; i<nrhs; i+=2) {
123  string key(rhs[i].toString());
124  if (key == "Data")
125  dataOptions = rhs[i+1].toVector<MxArray>();
126  else if (key == "Flags")
127  flags = rhs[i+1].toInt();
128  else if (key == "UpdateModel")
129  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::UPDATE_MODEL);
130  else
131  mexErrMsgIdAndTxt("mexopencv:error",
132  "Unrecognized option %s", key.c_str());
133  }
134  Ptr<TrainData> data;
135  if (rhs[2].isChar())
136  data = loadTrainData(rhs[2].toString(),
137  dataOptions.begin(), dataOptions.end());
138  else
139  data = createTrainData(
140  rhs[2].toMat(CV_32F),
141  rhs[3].toMat(CV_32S),
142  dataOptions.begin(), dataOptions.end());
143  bool b = obj->train(data, flags);
144  plhs[0] = MxArray(b);
145  }
146  else if (method == "calcError") {
147  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=2);
148  vector<MxArray> dataOptions;
149  bool test = false;
150  for (int i=4; i<nrhs; i+=2) {
151  string key(rhs[i].toString());
152  if (key == "Data")
153  dataOptions = rhs[i+1].toVector<MxArray>();
154  else if (key == "TestError")
155  test = rhs[i+1].toBool();
156  else
157  mexErrMsgIdAndTxt("mexopencv:error",
158  "Unrecognized option %s", key.c_str());
159  }
160  Ptr<TrainData> data;
161  if (rhs[2].isChar())
162  data = loadTrainData(rhs[2].toString(),
163  dataOptions.begin(), dataOptions.end());
164  else
165  data = createTrainData(
166  rhs[2].toMat(CV_32F),
167  rhs[3].toMat(CV_32S),
168  dataOptions.begin(), dataOptions.end());
169  Mat resp;
170  float err = obj->calcError(data, test, (nlhs>1 ? resp : noArray()));
171  plhs[0] = MxArray(err);
172  if (nlhs>1)
173  plhs[1] = MxArray(resp);
174  }
175  else if (method == "predict") {
176  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=2);
177  int flags = 0;
178  for (int i=3; i<nrhs; i+=2) {
179  string key(rhs[i].toString());
180  if (key == "Flags")
181  flags = rhs[i+1].toInt();
182  else if (key == "RawOutput")
183  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::RAW_OUTPUT);
184  else
185  mexErrMsgIdAndTxt("mexopencv:error",
186  "Unrecognized option %s", key.c_str());
187  }
188  Mat samples(rhs[2].toMat(CV_32F)),
189  results;
190  float f = obj->predict(samples, results, flags);
191  plhs[0] = MxArray(results);
192  if (nlhs>1)
193  plhs[1] = MxArray(f);
194  }
195  else if (method == "predictProb") {
196  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=3);
197  int flags = 0;
198  for (int i=3; i<nrhs; i+=2) {
199  string key(rhs[i].toString());
200  if (key == "Flags")
201  flags = rhs[i+1].toInt();
202  else if (key == "RawOutput")
203  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::RAW_OUTPUT);
204  else
205  mexErrMsgIdAndTxt("mexopencv:error",
206  "Unrecognized option %s", key.c_str());
207  }
208  Mat inputs(rhs[2].toMat(CV_32F)),
209  outputs,
210  outputProbs;
211  float f = obj->predictProb(inputs, outputs,
212  (nlhs>1 ? outputProbs : noArray()), flags);
213  plhs[0] = MxArray(outputs);
214  if (nlhs>1)
215  plhs[1] = MxArray(outputProbs);
216  if (nlhs>2)
217  plhs[2] = MxArray(f);
218  }
219  else
220  mexErrMsgIdAndTxt("mexopencv:error",
221  "Unrecognized operation %s", method.c_str());
222 }
map< int, Ptr< NormalBayesClassifier > > obj_
Object container.
virtual bool isTrained() const=0
LIBMWMEX_API_EXTERN_C void mexLock(void)
Lock a MEX-function so that it cannot be cleared from memory.
STL namespace.
T end(T... args)
virtual bool isOpened() const
struct mxArray_tag mxArray
Forward declaration for mxArray.
Definition: matrix.h:259
STL class.
virtual bool train(const Ptr< TrainData > &trainData, int flags=0)
cv::Ptr< cv::ml::TrainData > loadTrainData(const std::string &filename, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Read a dataset from a CSV file.
virtual void clear()
virtual String releaseAndGetString()
#define CV_32F
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Main entry called from Matlab.
virtual void write(FileStorage &fs) const
InputOutputArray noArray()
#define UPDATE_FLAG(NUM, TF, BIT)
set or clear a bit in flag depending on bool value
Definition: mexopencv.hpp:174
LIBMWMEX_API_EXTERN_C void mexErrMsgIdAndTxt(const char *identifier, const char *err_msg,...)
Issue formatted error message with corresponding error identifier and return to MATLAB prompt...
virtual float predictProb(InputArray inputs, OutputArray outputs, OutputArray outputProbs, int flags=0) const=0
LIBMWMEX_API_EXTERN_C void mexUnlock(void)
Unlock a locked MEX-function so that it can be cleared from memory.
mxArray object wrapper for data conversion and manipulation.
Definition: MxArray.hpp:123
void nargchk(bool cond)
Alias for input/output arguments number check.
Definition: mexopencv.hpp:181
STL class.
bool empty() const
#define CV_32S
virtual String getDefaultName() const
Global constant definitions.
T begin(T... args)
virtual float predict(InputArray samples, OutputArray results=noArray(), int flags=0) const=0
T c_str(T... args)
virtual bool empty() const
virtual bool isClassifier() const=0
virtual float calcError(const Ptr< TrainData > &data, bool test, OutputArray resp) const
virtual void save(const String &filename) const
cv::Ptr< cv::ml::TrainData > createTrainData(const cv::Mat &samples, const cv::Mat &responses, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Create an instance of TrainData using options in arguments.
void create(int arows, int acols, int atype, Target target=ARRAY_BUFFER, bool autoRelease=false)
Common definitions for the ml module.
virtual int getVarCount() const=0
cv::Mat toMat() const