57 int id = rhs[0].toInt();
58 string method(rhs[1].toString());
61 if (method ==
"new") {
73 if (method ==
"delete") {
78 else if (method ==
"clear") {
82 else if (method ==
"load") {
83 nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs==0);
85 bool loadFromString =
false;
86 for (
int i=3; i<nrhs; i+=2) {
87 string key(rhs[i].toString());
89 objname = rhs[i+1].toString();
90 else if (key ==
"FromString")
91 loadFromString = rhs[i+1].toBool();
94 "Unrecognized option %s", key.
c_str());
97 obj_[id] = (loadFromString ?
98 Algorithm::loadFromString<SVMSGD>(rhs[2].toString(), objname) :
99 Algorithm::load<SVMSGD>(rhs[2].toString(), objname));
101 else if (method ==
"save") {
103 string fname(rhs[2].toString());
106 FileStorage fs(fname, FileStorage::WRITE + FileStorage::MEMORY);
118 else if (method ==
"empty") {
122 else if (method ==
"getDefaultName") {
126 else if (method ==
"getVarCount") {
130 else if (method ==
"isClassifier") {
134 else if (method ==
"isTrained") {
138 else if (method ==
"train") {
139 nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=1);
142 for (
int i=4; i<nrhs; i+=2) {
143 string key(rhs[i].toString());
145 dataOptions = rhs[i+1].toVector<
MxArray>();
146 else if (key ==
"Flags")
147 flags = rhs[i+1].toInt();
150 "Unrecognized option %s", key.
c_str());
155 dataOptions.
begin(), dataOptions.
end());
160 dataOptions.
begin(), dataOptions.
end());
161 bool b = obj->
train(data, flags);
164 else if (method ==
"calcError") {
165 nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=2);
168 for (
int i=4; i<nrhs; i+=2) {
169 string key(rhs[i].toString());
171 dataOptions = rhs[i+1].toVector<
MxArray>();
172 else if (key ==
"TestError")
173 test = rhs[i+1].toBool();
176 "Unrecognized option %s", key.
c_str());
181 dataOptions.
begin(), dataOptions.
end());
186 dataOptions.
begin(), dataOptions.
end());
193 else if (method ==
"predict") {
194 nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=2);
196 for (
int i=3; i<nrhs; i+=2) {
197 string key(rhs[i].toString());
199 flags = rhs[i+1].toInt();
200 else if (key ==
"RawOutput")
201 UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::RAW_OUTPUT);
204 "Unrecognized option %s", key.
c_str());
208 float f = obj->
predict(samples, results, flags);
213 else if (method ==
"getShift") {
217 else if (method ==
"getWeights") {
221 else if (method ==
"setOptimalParameters") {
222 nargchk(nrhs>=2 && (nrhs%2)==0 && nlhs==0);
225 for (
int i=2; i<nrhs; i+=2) {
226 string key(rhs[i].toString());
227 if (key ==
"SvmsgdType")
229 else if (key ==
"MarginType")
233 "Unrecognized option %s", key.
c_str());
237 else if (method ==
"get") {
239 string prop(rhs[2].toString());
240 if (prop ==
"InitialStepSize")
242 else if (prop ==
"MarginRegularization")
244 else if (prop ==
"MarginType")
246 else if (prop ==
"StepDecreasingPower")
248 else if (prop ==
"SvmsgdType")
250 else if (prop ==
"TermCriteria")
254 "Unrecognized property %s", prop.
c_str());
256 else if (method ==
"set") {
258 string prop(rhs[2].toString());
259 if (prop ==
"InitialStepSize")
261 else if (prop ==
"MarginRegularization")
263 else if (prop ==
"MarginType")
265 else if (prop ==
"StepDecreasingPower")
267 else if (prop ==
"SvmsgdType")
269 else if (prop ==
"TermCriteria")
273 "Unrecognized property %s", prop.
c_str());
277 "Unrecognized operation %s", method.
c_str());
virtual void setMarginType(int marginType)=0
virtual bool isTrained() const=0
int last_id
Last object id to allocate.
virtual void setSvmsgdType(int svmsgdType)=0
LIBMWMEX_API_EXTERN_C void mexLock(void)
Lock a MEX-function so that it cannot be cleared from memory.
virtual void setInitialStepSize(float InitialStepSize)=0
virtual int getMarginType() const=0
const ConstMap< string, int > SvmsgdTypeMap
Option values for SVMSGD types.
virtual void setStepDecreasingPower(float stepDecreasingPower)=0
virtual void setOptimalParameters(int svmsgdType=SVMSGD::ASGD, int marginType=SVMSGD::SOFT_MARGIN)=0
virtual bool isOpened() const
virtual void setMarginRegularization(float marginRegularization)=0
struct mxArray_tag mxArray
Forward declaration for mxArray.
virtual float getMarginRegularization() const=0
virtual bool train(const Ptr< TrainData > &trainData, int flags=0)
cv::Ptr< cv::ml::TrainData > loadTrainData(const std::string &filename, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Read a dataset from a CSV file.
virtual String releaseAndGetString()
virtual void write(FileStorage &fs) const
map< int, Ptr< SVMSGD > > obj_
Object container.
InputOutputArray noArray()
#define UPDATE_FLAG(NUM, TF, BIT)
set or clear a bit in flag depending on bool value
LIBMWMEX_API_EXTERN_C void mexErrMsgIdAndTxt(const char *identifier, const char *err_msg,...)
Issue formatted error message with corresponding error identifier and return to MATLAB prompt...
virtual float getShift()=0
LIBMWMEX_API_EXTERN_C void mexUnlock(void)
Unlock a locked MEX-function so that it can be cleared from memory.
virtual float getStepDecreasingPower() const=0
const ConstMap< string, int > MarginTypeMap
Option values for margin types.
mxArray object wrapper for data conversion and manipulation.
void nargchk(bool cond)
Alias for input/output arguments number check.
virtual Mat getWeights()=0
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Main entry called from Matlab.
virtual String getDefaultName() const
Global constant definitions.
virtual int getSvmsgdType() const=0
virtual float predict(InputArray samples, OutputArray results=noArray(), int flags=0) const=0
const ConstMap< int, string > InvMarginTypeMap
Option values for inverse margin types.
virtual bool empty() const
virtual float getInitialStepSize() const=0
const ConstMap< int, string > InvSvmsgdTypeMap
Option values for inverse SVMSGD types.
virtual bool isClassifier() const=0
virtual float calcError(const Ptr< TrainData > &data, bool test, OutputArray resp) const
virtual void save(const String &filename) const
virtual TermCriteria getTermCriteria() const=0
cv::Ptr< cv::ml::TrainData > createTrainData(const cv::Mat &samples, const cv::Mat &responses, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Create an instance of TrainData using options in arguments.
std::map wrapper with one-line initialization and lookup method.
virtual void setTermCriteria(const cv::TermCriteria &val)=0
void create(int arows, int acols, int atype, Target target=ARRAY_BUFFER, bool autoRelease=false)
Common definitions for the ml module.
virtual int getVarCount() const=0