mexopencv  3.4.1
MEX interface for OpenCV library
KNearest_.cpp
Go to the documentation of this file.
1 
8 #include "mexopencv.hpp"
9 #include "mexopencv_ml.hpp"
10 #include "opencv2/ml.hpp"
11 using namespace std;
12 using namespace cv;
13 using namespace cv::ml;
14 
15 // Persistent objects
16 namespace {
18 int last_id = 0;
21 
24  ("BruteForce", KNearest::BRUTE_FORCE)
25  ("KDTree", KNearest::KDTREE);
26 
29  (KNearest::BRUTE_FORCE, "BruteForce")
30  (KNearest::KDTREE, "KDTree");
31 }
32 
40 void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
41 {
42  // Check the number of arguments
43  nargchk(nrhs>=2 && nlhs<=4);
44 
45  // Argument vector
46  vector<MxArray> rhs(prhs, prhs+nrhs);
47  int id = rhs[0].toInt();
48  string method(rhs[1].toString());
49 
50  // Constructor is called. Create a new object from argument
51  if (method == "new") {
52  nargchk(nrhs==2 && nlhs<=1);
54  plhs[0] = MxArray(last_id);
55  mexLock();
56  return;
57  }
58 
59  // Big operation switch
60  Ptr<KNearest> obj = obj_[id];
61  if (obj.empty())
62  mexErrMsgIdAndTxt("mexopencv:error", "Object not found id=%d", id);
63  if (method == "delete") {
64  nargchk(nrhs==2 && nlhs==0);
65  obj_.erase(id);
66  mexUnlock();
67  }
68  else if (method == "clear") {
69  nargchk(nrhs==2 && nlhs==0);
70  obj->clear();
71  }
72  else if (method == "load") {
73  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs==0);
74  string objname;
75  bool loadFromString = false;
76  for (int i=3; i<nrhs; i+=2) {
77  string key(rhs[i].toString());
78  if (key == "ObjName")
79  objname = rhs[i+1].toString();
80  else if (key == "FromString")
81  loadFromString = rhs[i+1].toBool();
82  else
83  mexErrMsgIdAndTxt("mexopencv:error",
84  "Unrecognized option %s", key.c_str());
85  }
86  obj_[id] = (loadFromString ?
87  Algorithm::loadFromString<KNearest>(rhs[2].toString(), objname) :
88  Algorithm::load<KNearest>(rhs[2].toString(), objname));
89  }
90  else if (method == "save") {
91  nargchk(nrhs==3 && nlhs<=1);
92  string fname(rhs[2].toString());
93  if (nlhs > 0) {
94  // write to memory, and return string
95  FileStorage fs(fname, FileStorage::WRITE + FileStorage::MEMORY);
96  if (!fs.isOpened())
97  mexErrMsgIdAndTxt("mexopencv:error", "Failed to open file");
98  fs << obj->getDefaultName() << "{";
99  obj->write(fs);
100  fs << "}";
101  plhs[0] = MxArray(fs.releaseAndGetString());
102  }
103  else
104  // write to disk
105  obj->save(fname);
106  }
107  else if (method == "empty") {
108  nargchk(nrhs==2 && nlhs<=1);
109  plhs[0] = MxArray(obj->empty());
110  }
111  else if (method == "getDefaultName") {
112  nargchk(nrhs==2 && nlhs<=1);
113  plhs[0] = MxArray(obj->getDefaultName());
114  }
115  else if (method == "getVarCount") {
116  nargchk(nrhs==2 && nlhs<=1);
117  plhs[0] = MxArray(obj->getVarCount());
118  }
119  else if (method == "isClassifier") {
120  nargchk(nrhs==2 && nlhs<=1);
121  plhs[0] = MxArray(obj->isClassifier());
122  }
123  else if (method == "isTrained") {
124  nargchk(nrhs==2 && nlhs<=1);
125  plhs[0] = MxArray(obj->isTrained());
126  }
127  else if (method == "train") {
128  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=1);
129  vector<MxArray> dataOptions;
130  int flags = 0;
131  for (int i=4; i<nrhs; i+=2) {
132  string key(rhs[i].toString());
133  if (key == "Data")
134  dataOptions = rhs[i+1].toVector<MxArray>();
135  else if (key == "Flags")
136  flags = rhs[i+1].toInt();
137  else if (key == "UpdateModel")
138  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::UPDATE_MODEL);
139  else
140  mexErrMsgIdAndTxt("mexopencv:error",
141  "Unrecognized option %s", key.c_str());
142  }
143  Ptr<TrainData> data;
144  if (rhs[2].isChar())
145  data = loadTrainData(rhs[2].toString(),
146  dataOptions.begin(), dataOptions.end());
147  else
148  data = createTrainData(
149  rhs[2].toMat(CV_32F),
150  rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
151  dataOptions.begin(), dataOptions.end());
152  bool b = obj->train(data, flags);
153  plhs[0] = MxArray(b);
154  }
155  else if (method == "calcError") {
156  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=2);
157  vector<MxArray> dataOptions;
158  bool test = false;
159  for (int i=4; i<nrhs; i+=2) {
160  string key(rhs[i].toString());
161  if (key == "Data")
162  dataOptions = rhs[i+1].toVector<MxArray>();
163  else if (key == "TestError")
164  test = rhs[i+1].toBool();
165  else
166  mexErrMsgIdAndTxt("mexopencv:error",
167  "Unrecognized option %s", key.c_str());
168  }
169  Ptr<TrainData> data;
170  if (rhs[2].isChar())
171  data = loadTrainData(rhs[2].toString(),
172  dataOptions.begin(), dataOptions.end());
173  else
174  data = createTrainData(
175  rhs[2].toMat(CV_32F),
176  rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
177  dataOptions.begin(), dataOptions.end());
178  Mat resp;
179  float err = obj->calcError(data, test, (nlhs>1 ? resp : noArray()));
180  plhs[0] = MxArray(err);
181  if (nlhs>1)
182  plhs[1] = MxArray(resp);
183  }
184  else if (method == "predict") {
185  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=2);
186  int flags = 0;
187  for (int i=3; i<nrhs; i+=2) {
188  string key(rhs[i].toString());
189  if (key == "Flags")
190  flags = rhs[i+1].toInt();
191  else
192  mexErrMsgIdAndTxt("mexopencv:error",
193  "Unrecognized option %s", key.c_str());
194  }
195  Mat samples(rhs[2].toMat(CV_32F)),
196  results;
197  float f = obj->predict(samples, results, flags);
198  plhs[0] = MxArray(results);
199  if (nlhs>1)
200  plhs[1] = MxArray(f);
201  }
202  else if (method == "findNearest") {
203  nargchk(nrhs==4 && nlhs<=4);
204  Mat samples(rhs[2].toMat(CV_32F));
205  int k = rhs[3].toInt();
206  Mat results, neighborResponses, dist;
207  float f = obj->findNearest(samples, k, results, neighborResponses, dist);
208  plhs[0] = MxArray(results);
209  if (nlhs>1)
210  plhs[1] = MxArray(neighborResponses);
211  if (nlhs>2)
212  plhs[2] = MxArray(dist);
213  if (nlhs>3)
214  plhs[3] = MxArray(f);
215  }
216  else if (method == "get") {
217  nargchk(nrhs==3 && nlhs<=1);
218  string prop(rhs[2].toString());
219  if (prop == "AlgorithmType")
220  plhs[0] = MxArray(InvKNNAlgType[obj->getAlgorithmType()]);
221  else if (prop == "DefaultK")
222  plhs[0] = MxArray(obj->getDefaultK());
223  else if (prop == "Emax")
224  plhs[0] = MxArray(obj->getEmax());
225  else if (prop == "IsClassifier")
226  plhs[0] = MxArray(obj->getIsClassifier());
227  else
228  mexErrMsgIdAndTxt("mexopencv:error",
229  "Unrecognized property %s", prop.c_str());
230  }
231  else if (method == "set") {
232  nargchk(nrhs==4 && nlhs==0);
233  string prop(rhs[2].toString());
234  if (prop == "AlgorithmType")
235  obj->setAlgorithmType(KNNAlgType[rhs[3].toString()]);
236  else if (prop == "DefaultK")
237  obj->setDefaultK(rhs[3].toInt());
238  else if (prop == "Emax")
239  obj->setEmax(rhs[3].toInt());
240  else if (prop == "IsClassifier")
241  obj->setIsClassifier(rhs[3].toBool());
242  else
243  mexErrMsgIdAndTxt("mexopencv:error",
244  "Unrecognized property %s", prop.c_str());
245  }
246  else
247  mexErrMsgIdAndTxt("mexopencv:error",
248  "Unrecognized operation %s", method.c_str());
249 }
virtual bool isTrained() const=0
map< int, Ptr< KNearest > > obj_
Object container.
Definition: KNearest_.cpp:20
LIBMWMEX_API_EXTERN_C void mexLock(void)
Lock a MEX-function so that it cannot be cleared from memory.
const ConstMap< std::string, int > KNNAlgType
Option values for KNearest algorithm type.
Definition: KNearest_.cpp:23
virtual int getEmax() const=0
virtual void setAlgorithmType(int val)=0
STL namespace.
T end(T... args)
virtual bool isOpened() const
struct mxArray_tag mxArray
Forward declaration for mxArray.
Definition: matrix.h:259
STL class.
int last_id
Last object id to allocate.
Definition: KNearest_.cpp:18
virtual void setDefaultK(int val)=0
virtual bool train(const Ptr< TrainData > &trainData, int flags=0)
cv::Ptr< cv::ml::TrainData > loadTrainData(const std::string &filename, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Read a dataset from a CSV file.
virtual void clear()
virtual String releaseAndGetString()
#define CV_32F
virtual int getDefaultK() const=0
virtual void write(FileStorage &fs) const
InputOutputArray noArray()
#define UPDATE_FLAG(NUM, TF, BIT)
set or clear a bit in flag depending on bool value
Definition: mexopencv.hpp:174
LIBMWMEX_API_EXTERN_C void mexErrMsgIdAndTxt(const char *identifier, const char *err_msg,...)
Issue formatted error message with corresponding error identifier and return to MATLAB prompt...
const ConstMap< int, std::string > InvKNNAlgType
Option values for inverse KNearest algorithm type.
Definition: KNearest_.cpp:28
LIBMWMEX_API_EXTERN_C void mexUnlock(void)
Unlock a locked MEX-function so that it can be cleared from memory.
mxArray object wrapper for data conversion and manipulation.
Definition: MxArray.hpp:123
void nargchk(bool cond)
Alias for input/output arguments number check.
Definition: mexopencv.hpp:181
virtual void setIsClassifier(bool val)=0
STL class.
bool empty() const
#define CV_32S
virtual String getDefaultName() const
Global constant definitions.
T begin(T... args)
virtual int getAlgorithmType() const=0
virtual void setEmax(int val)=0
virtual float predict(InputArray samples, OutputArray results=noArray(), int flags=0) const=0
T c_str(T... args)
virtual bool empty() const
virtual float findNearest(InputArray samples, int k, OutputArray results, OutputArray neighborResponses=noArray(), OutputArray dist=noArray()) const=0
virtual bool isClassifier() const=0
virtual float calcError(const Ptr< TrainData > &data, bool test, OutputArray resp) const
virtual void save(const String &filename) const
cv::Ptr< cv::ml::TrainData > createTrainData(const cv::Mat &samples, const cv::Mat &responses, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Create an instance of TrainData using options in arguments.
std::map wrapper with one-line initialization and lookup method.
Definition: MxArray.hpp:927
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Main entry called from Matlab.
Definition: KNearest_.cpp:40
void create(int arows, int acols, int atype, Target target=ARRAY_BUFFER, bool autoRelease=false)
Common definitions for the ml module.
virtual int getVarCount() const=0
cv::Mat toMat() const
virtual bool getIsClassifier() const=0