mexopencv  3.4.1
MEX interface for OpenCV library
ANN_MLP_.cpp
Go to the documentation of this file.
1 
8 #include "mexopencv.hpp"
9 #include "mexopencv_ml.hpp"
10 #include "opencv2/ml.hpp"
11 using namespace std;
12 using namespace cv;
13 using namespace cv::ml;
14 
15 // Persistent objects
16 namespace {
18 int last_id = 0;
21 
24  ("Backprop", cv::ml::ANN_MLP::BACKPROP)
25  ("RProp", cv::ml::ANN_MLP::RPROP)
26  ("Anneal", cv::ml::ANN_MLP::ANNEAL);
27 
30  (cv::ml::ANN_MLP::BACKPROP, "Backprop")
31  (cv::ml::ANN_MLP::RPROP, "RProp")
32  (cv::ml::ANN_MLP::ANNEAL, "Anneal");
33 
36  ("Identity", cv::ml::ANN_MLP::IDENTITY)
37  ("Sigmoid", cv::ml::ANN_MLP::SIGMOID_SYM)
38  ("Gaussian", cv::ml::ANN_MLP::GAUSSIAN)
39  ("ReLU", cv::ml::ANN_MLP::RELU)
40  ("LeakyReLU", cv::ml::ANN_MLP::LEAKYRELU);
41 
44  (cv::ml::ANN_MLP::IDENTITY, "Identity")
45  (cv::ml::ANN_MLP::SIGMOID_SYM, "Sigmoid")
46  (cv::ml::ANN_MLP::GAUSSIAN, "Gaussian")
47  (cv::ml::ANN_MLP::RELU, "ReLU")
48  (cv::ml::ANN_MLP::LEAKYRELU, "LeakyReLU");
49 }
50 
58 void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
59 {
60  // Check the number of arguments
61  nargchk(nrhs>=2 && nlhs<=2);
62 
63  // Argument vector
64  vector<MxArray> rhs(prhs, prhs+nrhs);
65  int id = rhs[0].toInt();
66  string method(rhs[1].toString());
67 
68  // Constructor is called. Create a new object from argument
69  if (method == "new") {
70  nargchk(nrhs==2 && nlhs<=1);
72  plhs[0] = MxArray(last_id);
73  mexLock();
74  return;
75  }
76 
77  // Big operation switch
78  Ptr<ANN_MLP> obj = obj_[id];
79  if (obj.empty())
80  mexErrMsgIdAndTxt("mexopencv:error", "Object not found id=%d", id);
81  if (method == "delete") {
82  nargchk(nrhs==2 && nlhs==0);
83  obj_.erase(id);
84  mexUnlock();
85  }
86  else if (method == "clear") {
87  nargchk(nrhs==2 && nlhs==0);
88  obj->clear();
89  }
90  else if (method == "load") {
91  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs==0);
92  string objname;
93  bool loadFromString = false;
94  for (int i=3; i<nrhs; i+=2) {
95  string key(rhs[i].toString());
96  if (key == "ObjName")
97  objname = rhs[i+1].toString();
98  else if (key == "FromString")
99  loadFromString = rhs[i+1].toBool();
100  else
101  mexErrMsgIdAndTxt("mexopencv:error",
102  "Unrecognized option %s", key.c_str());
103  }
104  //obj_[id] = ANN_MLP::load(rhs[2].toString());
105  obj_[id] = (loadFromString ?
106  Algorithm::loadFromString<ANN_MLP>(rhs[2].toString(), objname) :
107  Algorithm::load<ANN_MLP>(rhs[2].toString(), objname));
108  }
109  else if (method == "save") {
110  nargchk(nrhs==3 && nlhs<=1);
111  string fname(rhs[2].toString());
112  if (nlhs > 0) {
113  // write to memory, and return string
114  FileStorage fs(fname, FileStorage::WRITE + FileStorage::MEMORY);
115  if (!fs.isOpened())
116  mexErrMsgIdAndTxt("mexopencv:error", "Failed to open file");
117  fs << obj->getDefaultName() << "{";
118  obj->write(fs);
119  fs << "}";
120  plhs[0] = MxArray(fs.releaseAndGetString());
121  }
122  else
123  // write to disk
124  obj->save(fname);
125  }
126  else if (method == "empty") {
127  nargchk(nrhs==2 && nlhs<=1);
128  plhs[0] = MxArray(obj->empty());
129  }
130  else if (method == "getDefaultName") {
131  nargchk(nrhs==2 && nlhs<=1);
132  plhs[0] = MxArray(obj->getDefaultName());
133  }
134  else if (method == "getVarCount") {
135  nargchk(nrhs==2 && nlhs<=1);
136  plhs[0] = MxArray(obj->getVarCount());
137  }
138  else if (method == "isClassifier") {
139  nargchk(nrhs==2 && nlhs<=1);
140  plhs[0] = MxArray(obj->isClassifier());
141  }
142  else if (method == "isTrained") {
143  nargchk(nrhs==2 && nlhs<=1);
144  plhs[0] = MxArray(obj->isTrained());
145  }
146  else if (method == "train") {
147  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=1);
148  vector<MxArray> dataOptions;
149  int flags = 0;
150  for (int i=4; i<nrhs; i+=2) {
151  string key(rhs[i].toString());
152  if (key == "Data")
153  dataOptions = rhs[i+1].toVector<MxArray>();
154  else if (key == "Flags")
155  flags = rhs[i+1].toInt();
156  else if (key == "UpdateWeights")
157  UPDATE_FLAG(flags, rhs[i+1].toBool(), ANN_MLP::UPDATE_WEIGHTS);
158  else if (key == "NoInputScale")
159  UPDATE_FLAG(flags, rhs[i+1].toBool(), ANN_MLP::NO_INPUT_SCALE);
160  else if (key == "NoOutputScale")
161  UPDATE_FLAG(flags, rhs[i+1].toBool(), ANN_MLP::NO_OUTPUT_SCALE);
162  else
163  mexErrMsgIdAndTxt("mexopencv:error",
164  "Unrecognized option %s", key.c_str());
165  }
166  Ptr<TrainData> data;
167  if (rhs[2].isChar())
168  data = loadTrainData(rhs[2].toString(),
169  dataOptions.begin(), dataOptions.end());
170  else
171  data = createTrainData(
172  rhs[2].toMat(CV_32F),
173  rhs[3].toMat(CV_32F),
174  dataOptions.begin(), dataOptions.end());
175  bool b = obj->train(data, flags);
176  plhs[0] = MxArray(b);
177  }
178  else if (method == "calcError") {
179  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=2);
180  vector<MxArray> dataOptions;
181  bool test = false;
182  for (int i=4; i<nrhs; i+=2) {
183  string key(rhs[i].toString());
184  if (key == "Data")
185  dataOptions = rhs[i+1].toVector<MxArray>();
186  else if (key == "TestError")
187  test = rhs[i+1].toBool();
188  else
189  mexErrMsgIdAndTxt("mexopencv:error",
190  "Unrecognized option %s", key.c_str());
191  }
192  Ptr<TrainData> data;
193  if (rhs[2].isChar())
194  data = loadTrainData(rhs[2].toString(),
195  dataOptions.begin(), dataOptions.end());
196  else
197  data = createTrainData(
198  rhs[2].toMat(CV_32F),
199  rhs[3].toMat(CV_32F),
200  dataOptions.begin(), dataOptions.end());
201  Mat resp;
202  float err = obj->calcError(data, test, (nlhs>1 ? resp : noArray()));
203  plhs[0] = MxArray(err);
204  if (nlhs>1)
205  plhs[1] = MxArray(resp);
206  }
207  else if (method == "predict") {
208  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=2);
209  int flags = 0;
210  for (int i=3; i<nrhs; i+=2) {
211  string key(rhs[i].toString());
212  if (key == "Flags")
213  flags = rhs[i+1].toInt();
214  else
215  mexErrMsgIdAndTxt("mexopencv:error",
216  "Unrecognized option %s", key.c_str());
217  }
218  Mat samples(rhs[2].toMat(CV_32F)),
219  results;
220  float f = obj->predict(samples, results, flags);
221  plhs[0] = MxArray(results);
222  if (nlhs>1)
223  plhs[1] = MxArray(f);
224  }
225  else if (method == "getWeights") {
226  nargchk(nrhs==3 && nlhs<=1);
227  int layerIdx = rhs[2].toInt();
228  plhs[0] = MxArray(obj->getWeights(layerIdx));
229  }
230  else if (method == "setActivationFunction" || method == "setTrainMethod") {
231  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs==0);
232  double param1 = 0,
233  param2 = 0;
234  for (int i=3; i<nrhs; i+=2) {
235  string key(rhs[i].toString());
236  if (key == "Param1")
237  param1 = rhs[i+1].toDouble();
238  else if (key == "Param2")
239  param2 = rhs[i+1].toDouble();
240  else
241  mexErrMsgIdAndTxt("mexopencv:error",
242  "Unrecognized option %s", key.c_str());
243  }
244  if (method == "setActivationFunction") {
245  int type = ActivateFunc[rhs[2].toString()];
246  obj->setActivationFunction(type, param1, param2);
247  }
248  else {
249  int tmethod = ANN_MLPTrain[rhs[2].toString()];
250  obj->setTrainMethod(tmethod, param1, param2);
251  }
252  }
253  else if (method == "get") {
254  nargchk(nrhs==3 && nlhs<=1);
255  string prop(rhs[2].toString());
256  if (prop == "TrainMethod")
257  plhs[0] = MxArray(InvANN_MLPTrain[obj->getTrainMethod()]);
258  else if (prop == "LayerSizes")
259  plhs[0] = MxArray(obj->getLayerSizes());
260  else if (prop == "TermCriteria")
261  plhs[0] = MxArray(obj->getTermCriteria());
262  else if (prop == "BackpropWeightScale")
263  plhs[0] = MxArray(obj->getBackpropWeightScale());
264  else if (prop == "BackpropMomentumScale")
265  plhs[0] = MxArray(obj->getBackpropMomentumScale());
266  else if (prop == "RpropDW0")
267  plhs[0] = MxArray(obj->getRpropDW0());
268  else if (prop == "RpropDWPlus")
269  plhs[0] = MxArray(obj->getRpropDWPlus());
270  else if (prop == "RpropDWMinus")
271  plhs[0] = MxArray(obj->getRpropDWMinus());
272  else if (prop == "RpropDWMin")
273  plhs[0] = MxArray(obj->getRpropDWMin());
274  else if (prop == "RpropDWMax")
275  plhs[0] = MxArray(obj->getRpropDWMax());
276  else if (prop == "AnnealInitialT")
277  plhs[0] = MxArray(obj->getAnnealInitialT());
278  else if (prop == "AnnealFinalT")
279  plhs[0] = MxArray(obj->getAnnealFinalT());
280  else if (prop == "AnnealCoolingRatio")
281  plhs[0] = MxArray(obj->getAnnealCoolingRatio());
282  else if (prop == "AnnealItePerStep")
283  plhs[0] = MxArray(obj->getAnnealItePerStep());
284  else
285  mexErrMsgIdAndTxt("mexopencv:error",
286  "Unrecognized property %s", prop.c_str());
287  }
288  else if (method == "set") {
289  nargchk(nrhs==4 && nlhs==0);
290  string prop(rhs[2].toString());
291  if (prop == "TrainMethod")
292  obj->setTrainMethod(ANN_MLPTrain[rhs[3].toString()]);
293  else if (prop == "ActivationFunction")
294  obj->setActivationFunction(ActivateFunc[rhs[3].toString()]);
295  else if (prop == "LayerSizes")
296  obj->setLayerSizes(rhs[3].toMat());
297  else if (prop == "TermCriteria")
298  obj->setTermCriteria(rhs[3].toTermCriteria());
299  else if (prop == "BackpropWeightScale")
300  obj->setBackpropWeightScale(rhs[3].toDouble());
301  else if (prop == "BackpropMomentumScale")
302  obj->setBackpropMomentumScale(rhs[3].toDouble());
303  else if (prop == "RpropDW0")
304  obj->setRpropDW0(rhs[3].toDouble());
305  else if (prop == "RpropDWPlus")
306  obj->setRpropDWPlus(rhs[3].toDouble());
307  else if (prop == "RpropDWMinus")
308  obj->setRpropDWMinus(rhs[3].toDouble());
309  else if (prop == "RpropDWMin")
310  obj->setRpropDWMin(rhs[3].toDouble());
311  else if (prop == "RpropDWMax")
312  obj->setRpropDWMax(rhs[3].toDouble());
313  else if (prop == "AnnealInitialT")
314  obj->setAnnealInitialT(rhs[3].toDouble());
315  else if (prop == "AnnealFinalT")
316  obj->setAnnealFinalT(rhs[3].toDouble());
317  else if (prop == "AnnealCoolingRatio")
318  obj->setAnnealCoolingRatio(rhs[3].toDouble());
319  else if (prop == "AnnealItePerStep")
320  obj->setAnnealItePerStep(rhs[3].toInt());
321  else
322  mexErrMsgIdAndTxt("mexopencv:error",
323  "Unrecognized property %s", prop.c_str());
324  }
325  else
326  mexErrMsgIdAndTxt("mexopencv:error",
327  "Unrecognized operation %s", method.c_str());
328 }
virtual double getRpropDWMax() const=0
virtual bool isTrained() const=0
virtual double getRpropDWMinus() const=0
virtual double getRpropDWMin() const=0
virtual void setRpropDWMinus(double val)=0
LIBMWMEX_API_EXTERN_C void mexLock(void)
Lock a MEX-function so that it cannot be cleared from memory.
void setAnnealItePerStep(int val)
const ConstMap< string, int > ANN_MLPTrain
Option values for ANN_MLP train types.
Definition: ANN_MLP_.cpp:23
virtual void setRpropDWPlus(double val)=0
virtual void setTermCriteria(TermCriteria val)=0
double getAnnealInitialT() const
STL namespace.
void setAnnealFinalT(double val)
T end(T... args)
virtual bool isOpened() const
struct mxArray_tag mxArray
Forward declaration for mxArray.
Definition: matrix.h:259
virtual void setRpropDWMax(double val)=0
virtual double getRpropDWPlus() const=0
virtual cv::Mat getLayerSizes() const=0
STL class.
const ConstMap< int, string > InvANN_MLPTrain
Inverse option values for ANN_MLP train types.
Definition: ANN_MLP_.cpp:29
const ConstMap< int, string > InvActivateFunc
Inverse option values for ANN_MLP activation function.
Definition: ANN_MLP_.cpp:43
virtual void setBackpropWeightScale(double val)=0
map< int, Ptr< ANN_MLP > > obj_
Object container.
Definition: ANN_MLP_.cpp:20
virtual double getBackpropMomentumScale() const=0
virtual bool train(const Ptr< TrainData > &trainData, int flags=0)
cv::Ptr< cv::ml::TrainData > loadTrainData(const std::string &filename, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Read a dataset from a CSV file.
virtual TermCriteria getTermCriteria() const=0
virtual void clear()
virtual String releaseAndGetString()
#define CV_32F
virtual double getBackpropWeightScale() const=0
virtual void write(FileStorage &fs) const
int getAnnealItePerStep() const
InputOutputArray noArray()
#define UPDATE_FLAG(NUM, TF, BIT)
Set or clear a bit in a flag depending on a bool value.
Definition: mexopencv.hpp:174
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Main entry called from Matlab.
Definition: ANN_MLP_.cpp:58
LIBMWMEX_API_EXTERN_C void mexErrMsgIdAndTxt(const char *identifier, const char *err_msg,...)
Issue formatted error message with corresponding error identifier and return to MATLAB prompt...
virtual void setRpropDWMin(double val)=0
LIBMWMEX_API_EXTERN_C void mexUnlock(void)
Unlock a locked MEX-function so that it can be cleared from memory.
mxArray object wrapper for data conversion and manipulation.
Definition: MxArray.hpp:123
void nargchk(bool cond)
Alias for input/output arguments number check.
Definition: mexopencv.hpp:181
double getAnnealFinalT() const
void setAnnealCoolingRatio(double val)
virtual void setBackpropMomentumScale(double val)=0
virtual int getTrainMethod() const=0
virtual void setActivationFunction(int type, double param1=0, double param2=0)=0
STL class.
bool empty() const
virtual void setLayerSizes(InputArray _layer_sizes)=0
const ConstMap< string, int > ActivateFunc
Option values for ANN_MLP activation function.
Definition: ANN_MLP_.cpp:35
virtual String getDefaultName() const
Global constant definitions.
T begin(T... args)
virtual float predict(InputArray samples, OutputArray results=noArray(), int flags=0) const=0
T c_str(T... args)
virtual bool empty() const
int last_id
Last object id to allocate.
Definition: ANN_MLP_.cpp:18
double getAnnealCoolingRatio() const
virtual void setTrainMethod(int method, double param1=0, double param2=0)=0
virtual void setRpropDW0(double val)=0
virtual bool isClassifier() const=0
void setAnnealInitialT(double val)
virtual float calcError(const Ptr< TrainData > &data, bool test, OutputArray resp) const
virtual double getRpropDW0() const=0
virtual void save(const String &filename) const
cv::Ptr< cv::ml::TrainData > createTrainData(const cv::Mat &samples, const cv::Mat &responses, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Create an instance of TrainData using options in arguments.
int type() const
std::map wrapper with one-line initialization and lookup method.
Definition: MxArray.hpp:927
void create(int arows, int acols, int atype, Target target=ARRAY_BUFFER, bool autoRelease=false)
Common definitions for the ml module.
virtual int getVarCount() const=0
virtual Mat getWeights(int layerIdx) const=0
cv::Mat toMat() const