mexopencv  3.4.1
MEX interface for OpenCV library
LogisticRegression_.cpp
Go to the documentation of this file.
1 
8 #include "mexopencv.hpp"
9 #include "mexopencv_ml.hpp"
10 #include "opencv2/ml.hpp"
11 using namespace std;
12 using namespace cv;
13 using namespace cv::ml;
14 
15 // Persistent objects
16 namespace {
18 int last_id = 0;
21 
26 
31 
34  ("Disable", cv::ml::LogisticRegression::REG_DISABLE) // Regularization disabled
35  ("L1", cv::ml::LogisticRegression::REG_L1) // L1 norm
36  ("L2", cv::ml::LogisticRegression::REG_L2); // L2 norm
37 
43 }
44 
52 void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
53 {
54  // Check the number of arguments
55  nargchk(nrhs>=2 && nlhs<=2);
56 
57  // Argument vector
58  vector<MxArray> rhs(prhs, prhs+nrhs);
59  int id = rhs[0].toInt();
60  string method(rhs[1].toString());
61 
62  // Constructor is called. Create a new object from argument
63  if (method == "new") {
64  nargchk(nrhs==2 && nlhs<=1);
66  plhs[0] = MxArray(last_id);
67  mexLock();
68  return;
69  }
70 
71  // Big operation switch
72  Ptr<LogisticRegression> obj = obj_[id];
73  if (obj.empty())
74  mexErrMsgIdAndTxt("mexopencv:error", "Object not found id=%d", id);
75  if (method == "delete") {
76  nargchk(nrhs==2 && nlhs==0);
77  obj_.erase(id);
78  mexUnlock();
79  }
80  else if (method == "clear") {
81  nargchk(nrhs==2 && nlhs==0);
82  obj->clear();
83  }
84  else if (method == "load") {
85  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs==0);
86  string objname;
87  bool loadFromString = false;
88  for (int i=3; i<nrhs; i+=2) {
89  string key(rhs[i].toString());
90  if (key == "ObjName")
91  objname = rhs[i+1].toString();
92  else if (key == "FromString")
93  loadFromString = rhs[i+1].toBool();
94  else
95  mexErrMsgIdAndTxt("mexopencv:error",
96  "Unrecognized option %s", key.c_str());
97  }
98  //obj_[id] = LogisticRegression::load(rhs[2].toString());
99  obj_[id] = (loadFromString ?
100  Algorithm::loadFromString<LogisticRegression>(rhs[2].toString(), objname) :
101  Algorithm::load<LogisticRegression>(rhs[2].toString(), objname));
102  }
103  else if (method == "save") {
104  nargchk(nrhs==3 && nlhs<=1);
105  string fname(rhs[2].toString());
106  if (nlhs > 0) {
107  // write to memory, and return string
108  FileStorage fs(fname, FileStorage::WRITE + FileStorage::MEMORY);
109  if (!fs.isOpened())
110  mexErrMsgIdAndTxt("mexopencv:error", "Failed to open file");
111  fs << obj->getDefaultName() << "{";
112  obj->write(fs);
113  fs << "}";
114  plhs[0] = MxArray(fs.releaseAndGetString());
115  }
116  else
117  // write to disk
118  obj->save(fname);
119  }
120  else if (method == "empty") {
121  nargchk(nrhs==2 && nlhs<=1);
122  plhs[0] = MxArray(obj->empty());
123  }
124  else if (method == "getDefaultName") {
125  nargchk(nrhs==2 && nlhs<=1);
126  plhs[0] = MxArray(obj->getDefaultName());
127  }
128  else if (method == "getVarCount") {
129  nargchk(nrhs==2 && nlhs<=1);
130  plhs[0] = MxArray(obj->getVarCount());
131  }
132  else if (method == "isClassifier") {
133  nargchk(nrhs==2 && nlhs<=1);
134  plhs[0] = MxArray(obj->isClassifier());
135  }
136  else if (method == "isTrained") {
137  nargchk(nrhs==2 && nlhs<=1);
138  plhs[0] = MxArray(obj->isTrained());
139  }
140  else if (method == "train") {
141  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=1);
142  vector<MxArray> dataOptions;
143  int flags = 0;
144  for (int i=4; i<nrhs; i+=2) {
145  string key(rhs[i].toString());
146  if (key == "Data")
147  dataOptions = rhs[i+1].toVector<MxArray>();
148  else if (key == "Flags")
149  flags = rhs[i+1].toInt();
150  else
151  mexErrMsgIdAndTxt("mexopencv:error",
152  "Unrecognized option %s", key.c_str());
153  }
154  Ptr<TrainData> data;
155  if (rhs[2].isChar())
156  data = loadTrainData(rhs[2].toString(),
157  dataOptions.begin(), dataOptions.end());
158  else
159  data = createTrainData(
160  rhs[2].toMat(CV_32F),
161  rhs[3].toMat(CV_32F),
162  dataOptions.begin(), dataOptions.end());
163  bool b = obj->train(data, flags);
164  plhs[0] = MxArray(b);
165  }
166  else if (method == "calcError") {
167  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=2);
168  vector<MxArray> dataOptions;
169  bool test = false;
170  for (int i=4; i<nrhs; i+=2) {
171  string key(rhs[i].toString());
172  if (key == "Data")
173  dataOptions = rhs[i+1].toVector<MxArray>();
174  else if (key == "TestError")
175  test = rhs[i+1].toBool();
176  else
177  mexErrMsgIdAndTxt("mexopencv:error",
178  "Unrecognized option %s", key.c_str());
179  }
180  Ptr<TrainData> data;
181  if (rhs[2].isChar())
182  data = loadTrainData(rhs[2].toString(),
183  dataOptions.begin(), dataOptions.end());
184  else
185  data = createTrainData(
186  rhs[2].toMat(CV_32F),
187  rhs[3].toMat(CV_32F),
188  dataOptions.begin(), dataOptions.end());
189  Mat resp;
190  float err = obj->calcError(data, test, (nlhs>1 ? resp : noArray()));
191  plhs[0] = MxArray(err);
192  if (nlhs>1)
193  plhs[1] = MxArray(resp);
194  }
195  else if (method == "predict") {
196  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=2);
197  int flags = 0;
198  for (int i=3; i<nrhs; i+=2) {
199  string key(rhs[i].toString());
200  if (key == "Flags")
201  flags = rhs[i+1].toInt();
202  else if (key == "RawOutput")
203  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::RAW_OUTPUT);
204  else
205  mexErrMsgIdAndTxt("mexopencv:error",
206  "Unrecognized option %s", key.c_str());
207  }
208  Mat samples(rhs[2].toMat(CV_32F)),
209  results;
210  float f = obj->predict(samples, results, flags);
211  plhs[0] = MxArray(results);
212  if (nlhs>1)
213  plhs[1] = MxArray(f);
214  }
215  else if (method == "get_learnt_thetas") {
216  nargchk(nrhs==2 && nlhs<=1);
217  plhs[0] = MxArray(obj->get_learnt_thetas());
218  }
219  else if (method == "get") {
220  nargchk(nrhs==3 && nlhs<=1);
221  string prop(rhs[2].toString());
222  if (prop == "Iterations")
223  plhs[0] = MxArray(obj->getIterations());
224  else if (prop == "LearningRate")
225  plhs[0] = MxArray(obj->getLearningRate());
226  else if (prop == "MiniBatchSize")
227  plhs[0] = MxArray(obj->getMiniBatchSize());
228  else if (prop == "Regularization")
230  else if (prop == "TermCriteria")
231  plhs[0] = MxArray(obj->getTermCriteria());
232  else if (prop == "TrainMethod")
233  plhs[0] = MxArray(InvTrainingMethodType[obj->getTrainMethod()]);
234  else
235  mexErrMsgIdAndTxt("mexopencv:error",
236  "Unrecognized property %s", prop.c_str());
237  }
238  else if (method == "set") {
239  nargchk(nrhs==4 && nlhs==0);
240  string prop(rhs[2].toString());
241  if (prop == "Iterations")
242  obj->setIterations(rhs[3].toInt());
243  else if (prop == "LearningRate")
244  obj->setLearningRate(rhs[3].toDouble());
245  else if (prop == "MiniBatchSize")
246  obj->setMiniBatchSize(rhs[3].toInt());
247  else if (prop == "Regularization")
248  obj->setRegularization(RegularizationType[rhs[3].toString()]);
249  else if (prop == "TermCriteria")
250  obj->setTermCriteria(rhs[3].toTermCriteria());
251  else if (prop == "TrainMethod")
252  obj->setTrainMethod(TrainingMethodType[rhs[3].toString()]);
253  else
254  mexErrMsgIdAndTxt("mexopencv:error",
255  "Unrecognized property %s", prop.c_str());
256  }
257  else
258  mexErrMsgIdAndTxt("mexopencv:error",
259  "Unrecognized operation %s", method.c_str());
260 }
const ConstMap< int, string > InvTrainingMethodType
Option values for Inverse Training methods.
virtual bool isTrained() const=0
virtual float predict(InputArray samples, OutputArray results=noArray(), int flags=0) const=0
LIBMWMEX_API_EXTERN_C void mexLock(void)
Lock a MEX-function so that it cannot be cleared from memory.
virtual int getIterations() const=0
virtual void setMiniBatchSize(int val)=0
virtual void setTrainMethod(int val)=0
STL namespace.
T end(T... args)
virtual void setLearningRate(double val)=0
virtual bool isOpened() const
virtual double getLearningRate() const=0
struct mxArray_tag mxArray
Forward declaration for mxArray.
Definition: matrix.h:259
STL class.
virtual int getRegularization() const=0
const ConstMap< string, int > RegularizationType
Option values for Regularization kinds.
virtual void setTermCriteria(TermCriteria val)=0
virtual bool train(const Ptr< TrainData > &trainData, int flags=0)
cv::Ptr< cv::ml::TrainData > loadTrainData(const std::string &filename, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Read a dataset from a CSV file.
virtual void clear()
virtual String releaseAndGetString()
virtual int getTrainMethod() const=0
#define CV_32F
const ConstMap< string, int > TrainingMethodType
Option values for Training methods.
virtual void setRegularization(int val)=0
map< int, Ptr< LogisticRegression > > obj_
Object container.
virtual void write(FileStorage &fs) const
InputOutputArray noArray()
#define UPDATE_FLAG(NUM, TF, BIT)
set or clear a bit in flag depending on bool value
Definition: mexopencv.hpp:174
LIBMWMEX_API_EXTERN_C void mexErrMsgIdAndTxt(const char *identifier, const char *err_msg,...)
Issue formatted error message with corresponding error identifier and return to MATLAB prompt...
LIBMWMEX_API_EXTERN_C void mexUnlock(void)
Unlock a locked MEX-function so that it can be cleared from memory.
virtual void setIterations(int val)=0
mxArray object wrapper for data conversion and manipulation.
Definition: MxArray.hpp:123
void nargchk(bool cond)
Alias for input/output arguments number check.
Definition: mexopencv.hpp:181
STL class.
bool empty() const
virtual String getDefaultName() const
virtual Mat get_learnt_thetas() const=0
Global constant definitions.
T begin(T... args)
T c_str(T... args)
virtual bool empty() const
virtual bool isClassifier() const=0
virtual float calcError(const Ptr< TrainData > &data, bool test, OutputArray resp) const
virtual void save(const String &filename) const
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Main entry called from Matlab.
cv::Ptr< cv::ml::TrainData > createTrainData(const cv::Mat &samples, const cv::Mat &responses, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Create an instance of TrainData using options in arguments.
virtual TermCriteria getTermCriteria() const=0
std::map wrapper with one-line initialization and lookup method.
Definition: MxArray.hpp:927
void create(int arows, int acols, int atype, Target target=ARRAY_BUFFER, bool autoRelease=false)
Common definitions for the ml module.
virtual int getVarCount() const=0
cv::Mat toMat() const
virtual int getMiniBatchSize() const=0
const ConstMap< int, string > InvRegularizationType
Option values for Inverse Regularization kinds.