mexopencv  3.4.1
MEX interface for OpenCV library
EM_.cpp
Go to the documentation of this file.
1 
8 #include "mexopencv.hpp"
9 #include "mexopencv_ml.hpp"
10 #include "opencv2/ml.hpp"
11 using namespace std;
12 using namespace cv;
13 using namespace cv::ml;
14 
15 // Persistent objects
16 namespace {
// last_id: most recently allocated object id. It starts at 0 and the "new"
// method pre-increments it before use, so the first object gets id 1.
18 int last_id = 0;
// obj_ (map<int, Ptr<EM> >, the id -> model container) is declared on a
// line of the original file that is not present in this listing.
21 
// CovMatType: option string -> cv::ml::EM covariance-matrix type constant,
// used by the "set" method for the "CovarianceMatrixType" property.
// (Its declaration line is not present in this listing.)
24  ("Spherical", cv::ml::EM::COV_MAT_SPHERICAL)
25  ("Diagonal", cv::ml::EM::COV_MAT_DIAGONAL)
26  ("Generic", cv::ml::EM::COV_MAT_GENERIC)
27  ("Default", cv::ml::EM::COV_MAT_DEFAULT);
28 
// CovMatTypeInv: covariance-matrix type constant -> option string, used by
// the "get" method for the "CovarianceMatrixType" property.
// (Its declaration line is not present in this listing.)
// NOTE(review): OpenCV's ml.hpp appears to define COV_MAT_DEFAULT as an alias
// of COV_MAT_DIAGONAL, in which case the same integer key is inserted twice
// below and "get" may report "Default" rather than "Diagonal" for that value,
// depending on ConstMap's insertion semantics -- confirm.
31  (cv::ml::EM::COV_MAT_SPHERICAL, "Spherical")
32  (cv::ml::EM::COV_MAT_DIAGONAL, "Diagonal")
33  (cv::ml::EM::COV_MAT_GENERIC, "Generic")
34  (cv::ml::EM::COV_MAT_DEFAULT, "Default");
35 }
36 
44 void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
45 {
46  // Check the number of arguments
47  nargchk(nrhs>=2 && nlhs<=4);
48 
49  // Argument vector
50  vector<MxArray> rhs(prhs, prhs+nrhs);
51  int id = rhs[0].toInt();
52  string method(rhs[1].toString());
53 
54  // Constructor is called. Create a new object from argument
55  if (method == "new") {
56  nargchk(nrhs==2 && nlhs<=1);
57  obj_[++last_id] = EM::create();
58  plhs[0] = MxArray(last_id);
59  mexLock();
60  return;
61  }
62 
63  // Big operation switch
64  Ptr<EM> obj = obj_[id];
65  if (obj.empty())
66  mexErrMsgIdAndTxt("mexopencv:error", "Object not found id=%d", id);
67  if (method == "delete") {
68  nargchk(nrhs==2 && nlhs==0);
69  obj_.erase(id);
70  mexUnlock();
71  }
72  else if (method == "clear") {
73  nargchk(nrhs==2 && nlhs==0);
74  obj->clear();
75  }
76  else if (method == "load") {
77  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs==0);
78  string objname;
79  bool loadFromString = false;
80  for (int i=3; i<nrhs; i+=2) {
81  string key(rhs[i].toString());
82  if (key == "ObjName")
83  objname = rhs[i+1].toString();
84  else if (key == "FromString")
85  loadFromString = rhs[i+1].toBool();
86  else
87  mexErrMsgIdAndTxt("mexopencv:error",
88  "Unrecognized option %s", key.c_str());
89  }
90  //obj_[id] = EM::load(rhs[2].toString());
91  obj_[id] = (loadFromString ?
92  Algorithm::loadFromString<EM>(rhs[2].toString(), objname) :
93  Algorithm::load<EM>(rhs[2].toString(), objname));
94  }
95  else if (method == "save") {
96  nargchk(nrhs==3 && nlhs<=1);
97  string fname(rhs[2].toString());
98  if (nlhs > 0) {
99  // write to memory, and return string
100  FileStorage fs(fname, FileStorage::WRITE + FileStorage::MEMORY);
101  if (!fs.isOpened())
102  mexErrMsgIdAndTxt("mexopencv:error", "Failed to open file");
103  fs << obj->getDefaultName() << "{";
104  obj->write(fs);
105  fs << "}";
106  plhs[0] = MxArray(fs.releaseAndGetString());
107  }
108  else
109  // write to disk
110  obj->save(fname);
111  }
112  else if (method == "empty") {
113  nargchk(nrhs==2 && nlhs<=1);
114  plhs[0] = MxArray(obj->empty());
115  }
116  else if (method == "getDefaultName") {
117  nargchk(nrhs==2 && nlhs<=1);
118  plhs[0] = MxArray(obj->getDefaultName());
119  }
120  else if (method == "getVarCount") {
121  nargchk(nrhs==2 && nlhs<=1);
122  plhs[0] = MxArray(obj->getVarCount());
123  }
124  else if (method == "isClassifier") {
125  nargchk(nrhs==2 && nlhs<=1);
126  plhs[0] = MxArray(obj->isClassifier());
127  }
128  else if (method == "isTrained") {
129  nargchk(nrhs==2 && nlhs<=1);
130  plhs[0] = MxArray(obj->isTrained());
131  }
132  else if (method == "train") {
133  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=1);
134  vector<MxArray> dataOptions;
135  int flags = 0;
136  for (int i=3; i<nrhs; i+=2) {
137  string key(rhs[i].toString());
138  if (key == "Data")
139  dataOptions = rhs[i+1].toVector<MxArray>();
140  else if (key == "Flags")
141  flags = rhs[i+1].toInt();
142  else
143  mexErrMsgIdAndTxt("mexopencv:error",
144  "Unrecognized option %s", key.c_str());
145  }
146  Ptr<TrainData> data;
147  if (rhs[2].isChar())
148  data = loadTrainData(rhs[2].toString(),
149  dataOptions.begin(), dataOptions.end());
150  else
151  data = createTrainData(
152  rhs[2].toMat(CV_32F), Mat(),
153  dataOptions.begin(), dataOptions.end());
154  bool b = obj->train(data, flags);
155  plhs[0] = MxArray(b);
156  }
157  else if (method == "calcError") {
158  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=2);
159  vector<MxArray> dataOptions;
160  bool test = false;
161  for (int i=4; i<nrhs; i+=2) {
162  string key(rhs[i].toString());
163  if (key == "Data")
164  dataOptions = rhs[i+1].toVector<MxArray>();
165  else if (key == "TestError")
166  test = rhs[i+1].toBool();
167  else
168  mexErrMsgIdAndTxt("mexopencv:error",
169  "Unrecognized option %s", key.c_str());
170  }
171  Ptr<TrainData> data;
172  if (rhs[2].isChar())
173  data = loadTrainData(rhs[2].toString(),
174  dataOptions.begin(), dataOptions.end());
175  else
176  data = createTrainData(
177  rhs[2].toMat(CV_32F),
178  rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
179  dataOptions.begin(), dataOptions.end());
180  Mat resp;
181  float err = obj->calcError(data, test, (nlhs>1 ? resp : noArray()));
182  plhs[0] = MxArray(err);
183  if (nlhs>1)
184  plhs[1] = MxArray(resp);
185  }
186  else if (method == "predict") {
187  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=2);
188  int flags = 0;
189  for (int i=3; i<nrhs; i+=2) {
190  string key(rhs[i].toString());
191  if (key == "Flags")
192  flags = rhs[i+1].toInt();
193  else
194  mexErrMsgIdAndTxt("mexopencv:error",
195  "Unrecognized option %s", key.c_str());
196  }
197  Mat samples(rhs[2].toMat(rhs[2].isSingle() ? CV_32F : CV_64F)),
198  results;
199  float f = obj->predict(samples, results, flags);
200  plhs[0] = MxArray(results);
201  if (nlhs>1)
202  plhs[1] = MxArray(f);
203  }
204  else if (method == "trainEM") {
205  nargchk(nrhs==3 && nlhs<=4);
206  Mat samples(rhs[2].toMat(rhs[2].isSingle() ? CV_32F : CV_64F)),
207  logLikelihoods, labels, probs;
208  bool b = obj->trainEM(samples,
209  (nlhs>0 ? logLikelihoods : noArray()),
210  (nlhs>1 ? labels : noArray()),
211  (nlhs>2 ? probs : noArray()));
212  plhs[0] = MxArray(logLikelihoods);
213  if (nlhs > 1)
214  plhs[1] = MxArray(labels);
215  if (nlhs > 2)
216  plhs[2] = MxArray(probs);
217  if (nlhs > 3)
218  plhs[3] = MxArray(b);
219  }
220  else if (method == "trainE") {
221  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=4);
222  vector<Mat> covs0;
223  Mat weights0;
224  for(int i = 4; i < nrhs; i += 2) {
225  string key(rhs[i].toString());
226  if (key == "Covs0") {
227  //covs0 = rhs[i+1].toVector<Mat>();
228  covs0.clear();
229  vector<MxArray> arr(rhs[i+1].toVector<MxArray>());
230  covs0.reserve(arr.size());
231  for (vector<MxArray>::const_iterator it = arr.begin(); it != arr.end(); ++it)
232  covs0.push_back(it->toMat(
233  it->isSingle() ? CV_32F : CV_64F));
234  }
235  else if (key == "Weights0")
236  weights0 = rhs[i+1].toMat(
237  rhs[i+1].isSingle() ? CV_32F : CV_64F);
238  else
239  mexErrMsgIdAndTxt("mexopencv:error",
240  "Unrecognized option %s", key.c_str());
241  }
242  Mat samples(rhs[2].toMat(rhs[2].isSingle() ? CV_32F : CV_64F)),
243  means0(rhs[3].toMat(rhs[3].isSingle() ? CV_32F : CV_64F)),
244  logLikelihoods, labels, probs;
245  bool b = obj->trainE(samples, means0, covs0, weights0,
246  (nlhs>0 ? logLikelihoods : noArray()),
247  (nlhs>1 ? labels : noArray()),
248  (nlhs>2 ? probs : noArray()));
249  plhs[0] = MxArray(logLikelihoods);
250  if (nlhs > 1)
251  plhs[1] = MxArray(labels);
252  if (nlhs > 2)
253  plhs[2] = MxArray(probs);
254  if (nlhs > 3)
255  plhs[3] = MxArray(b);
256  }
257  else if (method == "trainM") {
258  nargchk(nrhs==4 && nlhs<=4);
259  Mat samples(rhs[2].toMat(rhs[2].isSingle() ? CV_32F : CV_64F)),
260  probs0(rhs[3].toMat(rhs[3].isSingle() ? CV_32F : CV_64F)),
261  logLikelihoods, labels, probs;
262  bool b = obj->trainM(samples, probs0,
263  (nlhs>0 ? logLikelihoods : noArray()),
264  (nlhs>1 ? labels : noArray()),
265  (nlhs>2 ? probs : noArray()));
266  plhs[0] = MxArray(logLikelihoods);
267  if (nlhs > 1)
268  plhs[1] = MxArray(labels);
269  if (nlhs > 2)
270  plhs[2] = MxArray(probs);
271  if (nlhs > 3)
272  plhs[3] = MxArray(b);
273  }
274  else if (method == "predict2") {
275  nargchk(nrhs==3 && nlhs<=3);
276  Mat samples(rhs[2].toMat(rhs[2].isSingle() ? CV_32F : CV_64F)),
277  probs;
278  if (nlhs > 1)
279  probs.create(samples.rows, obj->getClustersNumber(), CV_64F);
280  vector<Vec2d> results;
281  results.reserve(samples.rows);
282  for (size_t i = 0; i < samples.rows; ++i) {
283  Vec2d res = obj->predict2(samples.row(i),
284  (nlhs>1 ? probs.row(i) : noArray()));
285  results.push_back(res);
286  }
287  plhs[0] = MxArray(Mat(results, false).reshape(1,0)); // Nx2
288  if (nlhs > 1)
289  plhs[1] = MxArray(probs); // NxK
290  }
291  else if (method == "getCovs") {
292  nargchk(nrhs==2 && nlhs<=1);
293  vector<Mat> covs;
294  obj->getCovs(covs);
295  plhs[0] = MxArray(covs);
296  }
297  else if (method == "getMeans") {
298  nargchk(nrhs==2 && nlhs<=1);
299  plhs[0] = MxArray(obj->getMeans());
300  }
301  else if (method == "getWeights") {
302  nargchk(nrhs==2 && nlhs<=1);
303  plhs[0] = MxArray(obj->getWeights());
304  }
305  else if (method == "get") {
306  nargchk(nrhs==3 && nlhs<=1);
307  string prop(rhs[2].toString());
308  if (prop == "ClustersNumber")
309  plhs[0] = MxArray(obj->getClustersNumber());
310  else if (prop == "CovarianceMatrixType")
311  plhs[0] = MxArray(CovMatTypeInv[obj->getCovarianceMatrixType()]);
312  else if (prop == "TermCriteria")
313  plhs[0] = MxArray(obj->getTermCriteria());
314  else
315  mexErrMsgIdAndTxt("mexopencv:error",
316  "Unrecognized property %s", prop.c_str());
317  }
318  else if (method == "set") {
319  nargchk(nrhs==4 && nlhs==0);
320  string prop(rhs[2].toString());
321  if (prop == "ClustersNumber")
322  obj->setClustersNumber(rhs[3].toInt());
323  else if (prop == "CovarianceMatrixType")
324  obj->setCovarianceMatrixType(CovMatType[rhs[3].toString()]);
325  else if (prop == "TermCriteria")
326  obj->setTermCriteria(rhs[3].toTermCriteria());
327  else
328  mexErrMsgIdAndTxt("mexopencv:error",
329  "Unrecognized property %s", prop.c_str());
330  }
331  else
332  mexErrMsgIdAndTxt("mexopencv:error",
333  "Unrecognized operation %s", method.c_str());
334 }
virtual Mat getMeans() const=0
virtual bool isTrained() const=0
virtual void getCovs(std::vector< Mat > &covs) const=0
virtual bool trainE(InputArray samples, InputArray means0, InputArray covs0=noArray(), InputArray weights0=noArray(), OutputArray logLikelihoods=noArray(), OutputArray labels=noArray(), OutputArray probs=noArray())=0
LIBMWMEX_API_EXTERN_C void mexLock(void)
Lock a MEX-function so that it cannot be cleared from memory.
STL namespace.
T end(T... args)
virtual bool isOpened() const
virtual Vec2d predict2(InputArray sample, OutputArray probs) const=0
struct mxArray_tag mxArray
Forward declaration for mxArray.
Definition: matrix.h:259
STL class.
virtual void setClustersNumber(int val)=0
virtual int getCovarianceMatrixType() const=0
virtual bool train(const Ptr< TrainData > &trainData, int flags=0)
T push_back(T... args)
cv::Ptr< cv::ml::TrainData > loadTrainData(const std::string &filename, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Read a dataset from a CSV file.
virtual void clear()
virtual String releaseAndGetString()
const ConstMap< int, string > CovMatTypeInv
CovMatTypeInv map for option processing.
Definition: EM_.cpp:30
#define CV_32F
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Main entry called from Matlab.
Definition: EM_.cpp:44
virtual void write(FileStorage &fs) const
InputOutputArray noArray()
#define CV_64F
LIBMWMEX_API_EXTERN_C void mexErrMsgIdAndTxt(const char *identifier, const char *err_msg,...)
Issue formatted error message with corresponding error identifier and return to MATLAB prompt...
LIBMWMEX_API_EXTERN_C void mexUnlock(void)
Unlock a locked MEX-function so that it can be cleared from memory.
map< int, Ptr< EM > > obj_
Object container.
Definition: EM_.cpp:20
virtual void setCovarianceMatrixType(int val)=0
T clear(T... args)
virtual bool trainM(InputArray samples, InputArray probs0, OutputArray logLikelihoods=noArray(), OutputArray labels=noArray(), OutputArray probs=noArray())=0
mxArray object wrapper for data conversion and manipulation.
Definition: MxArray.hpp:123
void nargchk(bool cond)
Alias for input/output arguments number check.
Definition: mexopencv.hpp:181
virtual void setTermCriteria(const TermCriteria &val)=0
virtual float predict(InputArray samples, OutputArray results=noArray(), int flags=0) const=0
virtual int getClustersNumber() const=0
T size(T... args)
STL class.
bool empty() const
#define CV_32S
int last_id
Last object id to allocate.
Definition: EM_.cpp:18
virtual String getDefaultName() const
const ConstMap< string, int > CovMatType
CovMatType map for option processing.
Definition: EM_.cpp:23
Global constant definitions.
T begin(T... args)
T c_str(T... args)
Mat row(int y) const
virtual bool empty() const
void create(int rows, int cols, int type)
virtual bool isClassifier() const=0
virtual bool trainEM(InputArray samples, OutputArray logLikelihoods=noArray(), OutputArray labels=noArray(), OutputArray probs=noArray())=0
virtual float calcError(const Ptr< TrainData > &data, bool test, OutputArray resp) const
virtual Mat getWeights() const=0
virtual void save(const String &filename) const
virtual TermCriteria getTermCriteria() const=0
cv::Ptr< cv::ml::TrainData > createTrainData(const cv::Mat &samples, const cv::Mat &responses, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Create an instance of TrainData using options in arguments.
std::map wrapper with one-line initialization and lookup method.
Definition: MxArray.hpp:927
void create(int arows, int acols, int atype, Target target=ARRAY_BUFFER, bool autoRelease=false)
Common definitions for the ml module.
virtual int getVarCount() const=0
cv::Mat toMat() const
T reserve(T... args)