mexopencv  3.4.1
MEX interface for OpenCV library
mexopencv_ml.cpp
Go to the documentation of this file.
1 
8 #include "mexopencv_ml.hpp"
9 using std::vector;
10 using std::string;
11 using namespace cv;
12 using namespace cv::ml;
13 
14 
15 // ==================== Option maps ====================
16 
// Entries mapping option strings to cv::ml sample-layout constants.
// NOTE(review): the declaration line that opens this initializer chain is
// missing from this listing; presumably it is SampleTypesMap (the
// ConstMap<string,int> that createTrainData indexes for its "Layout"
// option) -- confirm against the original file.
19  ("Row", cv::ml::ROW_SAMPLE) // each training sample is a row of samples
20  ("Col", cv::ml::COL_SAMPLE); // each training sample occupies a column of samples
21 
// Entries mapping option strings to cv::ml variable-type constants.
// NOTE(review): the declaration line that opens this initializer chain is
// missing from this listing; presumably it is VariableTypeMap (the
// ConstMap<string,int> that createTrainData indexes for its "VarType"
// option) -- confirm against the original file.
// Both full names and one-letter shorthands are provided; createTrainData
// looks up one-character strings when "VarType" is given as a char array.
24  ("Numerical", cv::ml::VAR_NUMERICAL) // same as VAR_ORDERED
25  ("Ordered", cv::ml::VAR_ORDERED) // ordered variables
26  ("Categorical", cv::ml::VAR_CATEGORICAL) // categorical variables
27  ("N", cv::ml::VAR_NUMERICAL) // shorthand for (N)umerical
28  ("O", cv::ml::VAR_ORDERED) // shorthand for (O)rdered
29  ("C", cv::ml::VAR_CATEGORICAL); // shorthand for (C)ategorical
30 
31 
32 // ==================== Struct conversion ====================
33 
// Convert a sequence of tree nodes to a struct array, one element per node.
// NOTE(review): the signature line is missing from this listing; presumably
// this is the body of
//   MxArray toStruct(const vector<DTrees::Node>& nodes)
// which the page's reference notes list -- confirm against the original file.
// Returns the populated struct array.
35 {
// Field names of the output struct. The field count passed to
// MxArray::Struct below (7) must stay in sync with this list.
36  const char* fields[] = {"value", "classIdx", "parent", "left", "right",
37  "defaultDir", "split"};
// Allocate a 1-by-nodes.size() struct array with those fields.
38  MxArray s = MxArray::Struct(fields, 7, 1, nodes.size());
39  for (size_t i = 0; i < nodes.size(); ++i) {
// set(name, value, i) is called with the element index i; presumably it
// writes field `name` of element i -- confirm in MxArray.hpp.
40  s.set("value", nodes[i].value, i);
41  s.set("classIdx", nodes[i].classIdx, i);
42  s.set("parent", nodes[i].parent, i);
43  s.set("left", nodes[i].left, i);
44  s.set("right", nodes[i].right, i);
45  s.set("defaultDir", nodes[i].defaultDir, i);
46  s.set("split", nodes[i].split, i);
47  }
48  return s;
49 }
50 
// Convert a sequence of splits to a struct array, one element per split.
// NOTE(review): the signature line is missing from this listing and is not
// shown in the page's reference notes; presumably this is a toStruct
// overload taking a vector of tree splits named `splits` -- confirm against
// the original file.
// Returns the populated struct array.
52 {
// Field names of the output struct. The field count passed to
// MxArray::Struct below (6) must stay in sync with this list.
53  const char* fields[] = {"varIdx", "inversed", "quality", "next", "c",
54  "subsetOfs"};
// Allocate a 1-by-splits.size() struct array with those fields.
55  MxArray s = MxArray::Struct(fields, 6, 1, splits.size());
56  for (size_t i = 0; i < splits.size(); ++i) {
// set(name, value, i) is called with the element index i; presumably it
// writes field `name` of element i -- confirm in MxArray.hpp.
57  s.set("varIdx", splits[i].varIdx, i);
58  s.set("inversed", splits[i].inversed, i);
59  s.set("quality", splits[i].quality, i);
60  s.set("next", splits[i].next, i);
61  s.set("c", splits[i].c, i);
62  s.set("subsetOfs", splits[i].subsetOfs, i);
63  }
64  return s;
65 }
66 
67 
68 // ==================== TrainData creation ====================
69 
// Create a TrainData instance from in-memory samples and responses, configured
// by a list of key/value options in [first, last).
// NOTE(review): listing lines 70 and 72-73 of the header are missing here; the
// page's reference notes give the full signature as
//   Ptr<TrainData> createTrainData(const Mat& samples, const Mat& responses,
//       vector<MxArray>::const_iterator first,
//       vector<MxArray>::const_iterator last)
// Recognized option keys: "Layout", "VarIdx", "SampleIdx", "SampleWeights",
// "VarType", "MissingMask", "TrainTestSplitCount", "TrainTestSplitRatio",
// "TrainTestSplitShuffle". Any other key raises a "mexopencv:error".
// Returns the pointer produced by TrainData::create, after the optional
// train/test split has been applied.
71  const Mat& samples, const Mat& responses,
74 {
// Options come as key/value pairs, so their count must be even.
75  nargchk((std::distance(first, last) % 2) == 0);
// Defaults used when an option is not supplied.
76  int layout = cv::ml::ROW_SAMPLE;
77  Mat varIdx, sampleIdx, sampleWeights, varType;
78  Mat missing; //TODO: currently not possible through TrainData interface
// Negative values mean "no split requested" (see the checks after create).
79  int splitCount = -1; // [0, nsamples)
80  double splitRatio = -1.0; // [0.0, 1.0)
81  bool splitShuffle = true;
// Walk the options two at a time: *first is the key, *(first + 1) its value.
82  for (; first != last; first += 2) {
83  string key(first->toString());
84  const MxArray& val = *(first + 1);
85  if (key == "Layout")
86  layout = SampleTypesMap[val.toString()];
// Index inputs are converted to CV_8U when given as uint8/logical,
// otherwise to CV_32S.
87  else if (key == "VarIdx")
88  varIdx = val.toMat(
89  (val.isUint8() || val.isLogical()) ? CV_8U : CV_32S);
90  else if (key == "SampleIdx")
91  sampleIdx = val.toMat(
92  (val.isUint8() || val.isLogical()) ? CV_8U : CV_32S);
93  else if (key == "SampleWeights")
94  sampleWeights = val.toMat(CV_32F);
// "VarType" accepts three forms: a cell array of type names, a char array
// with one type letter per variable, or a numeric array used as-is (CV_8U).
95  else if (key == "VarType") {
96  if (val.isCell()) {
97  vector<string> vtypes(val.toVector<string>());
98  varType.create(1, vtypes.size(), CV_8U);
99  for (size_t idx = 0; idx < vtypes.size(); idx++)
100  varType.at<uchar>(idx) = VariableTypeMap[vtypes[idx]];
101  }
102  else if (val.isChar()) {
103  string str(val.toString());
104  varType.create(1, str.size(), CV_8U);
// Each character is looked up as a one-character string key.
105  for (size_t idx = 0; idx < str.size(); idx++)
106  varType.at<uchar>(idx) = VariableTypeMap[string(1,str[idx])];
107  }
108  else if (val.isNumeric())
109  varType = val.toMat(CV_8U);
110  else
111  mexErrMsgIdAndTxt("mexopencv:error", "Invalid VarType value");
112  }
// Parsed for interface completeness only; `missing` is never passed on.
113  else if (key == "MissingMask")
114  missing = val.toMat(CV_8U); //TODO: unused, see TrainData::setData
115  else if (key == "TrainTestSplitCount")
116  splitCount = val.toInt();
117  else if (key == "TrainTestSplitRatio")
118  splitRatio = val.toDouble();
119  else if (key == "TrainTestSplitShuffle")
120  splitShuffle = val.toBool();
121  else
122  mexErrMsgIdAndTxt("mexopencv:error",
123  "Unrecognized option %s", key.c_str());
124  }
125  Ptr<TrainData> p = TrainData::create(samples, layout, responses,
126  varIdx, sampleIdx, sampleWeights, varType);
// A non-negative split count takes precedence over a split ratio when both
// were supplied; with neither, no split is configured here.
127  if (splitCount >= 0)
128  p->setTrainTestSplit(splitCount, splitShuffle);
129  else if (splitRatio >= 0)
130  p->setTrainTestSplitRatio(splitRatio, splitShuffle);
131  return p;
132 }
133 
// Read a dataset from a CSV file into a TrainData instance, configured by a
// list of key/value options in [first, last).
// NOTE(review): listing lines 135-136 of the header are missing here; the
// page's reference notes give the remaining parameters as
//   vector<MxArray>::const_iterator first,
//   vector<MxArray>::const_iterator last
// Recognized option keys: "HeaderLineCount", "ResponseStartIdx",
// "ResponseEndIdx", "VarTypeSpec", "Delimiter", "Missing",
// "TrainTestSplitCount", "TrainTestSplitRatio", "TrainTestSplitShuffle".
// Any other key raises a "mexopencv:error", as does a failed load.
// Returns the loaded TrainData pointer, after the optional train/test split
// has been applied.
134 Ptr<TrainData> loadTrainData(const string& filename,
137 {
// Options come as key/value pairs, so their count must be even.
138  nargchk((std::distance(first, last) % 2) == 0);
// Defaults used when an option is not supplied.
139  int headerLineCount = 1;
140  int responseStartIdx = -1;
141  int responseEndIdx = -1;
142  string varTypeSpec;
143  char delimiter = ',';
144  char missch = '?';
// Negative values mean "no split requested" (see the checks after loading).
145  int splitCount = -1; // [0, nsamples)
146  double splitRatio = -1.0; // [0.0, 1.0)
147  bool splitShuffle = true;
// Walk the options two at a time: *first is the key, *(first + 1) its value.
148  for (; first != last; first += 2) {
149  string key(first->toString());
150  const MxArray& val = *(first + 1);
151  if (key == "HeaderLineCount")
152  headerLineCount = val.toInt();
153  else if (key == "ResponseStartIdx")
154  responseStartIdx = val.toInt();
155  else if (key == "ResponseEndIdx")
156  responseEndIdx = val.toInt();
157  else if (key == "VarTypeSpec")
158  varTypeSpec = val.toString();
// Only the first character of the value is used. Note that an empty value
// selects a space, which differs from the ',' default above.
159  else if (key == "Delimiter")
160  delimiter = (!val.isEmpty()) ? val.toString()[0] : ' ';
// Only the first character of the value is used; an empty value gives '?'.
161  else if (key == "Missing")
162  missch = (!val.isEmpty()) ? val.toString()[0] : '?';
163  else if (key == "TrainTestSplitCount")
164  splitCount = val.toInt();
165  else if (key == "TrainTestSplitRatio")
166  splitRatio = val.toDouble();
167  else if (key == "TrainTestSplitShuffle")
168  splitShuffle = val.toBool();
169  else
170  mexErrMsgIdAndTxt("mexopencv:error",
171  "Unrecognized option %s", key.c_str());
172  }
173  Ptr<TrainData> p = TrainData::loadFromCSV(filename, headerLineCount,
174  responseStartIdx, responseEndIdx, varTypeSpec, delimiter, missch);
// An empty pointer is treated as a load failure and reported as an error.
175  if (p.empty())
176  mexErrMsgIdAndTxt("mexopencv:error",
177  "Failed to load dataset '%s'", filename.c_str());
// A non-negative split count takes precedence over a split ratio when both
// were supplied; with neither, no split is configured here.
178  if (splitCount >= 0)
179  p->setTrainTestSplit(splitCount, splitShuffle);
180  else if (splitRatio >= 0)
181  p->setTrainTestSplitRatio(splitRatio, splitShuffle);
182  return p;
183 
184 }
int toInt() const
Convert MxArray to int.
Definition: MxArray.cpp:489
T distance(T... args)
void split(const Mat &src, Mat *mvbegin)
static Ptr< TrainData > create(InputArray samples, int layout, InputArray responses, InputArray varIdx=noArray(), InputArray sampleIdx=noArray(), InputArray sampleWeights=noArray(), InputArray varType=noArray())
const ConstMap< string, int > VariableTypeMap
Option values for variable types.
#define CV_8U
MxArray toStruct(const vector< DTrees::Node > &nodes)
Convert tree nodes to struct array.
void set(mwIndex index, const T &value)
Template for numeric array element write accessor.
Definition: MxArray.hpp:1310
virtual void setTrainTestSplit(int count, bool shuffle=true)=0
STL class.
#define CV_32F
virtual void setTrainTestSplitRatio(double ratio, bool shuffle=true)=0
LIBMWMEX_API_EXTERN_C void mexErrMsgIdAndTxt(const char *identifier, const char *err_msg,...)
Issue formatted error message with corresponding error identifier and return to MATLAB prompt...
unsigned char uchar
mxArray object wrapper for data conversion and manipulation.
Definition: MxArray.hpp:123
void nargchk(bool cond)
Alias for input/output arguments number check.
Definition: mexopencv.hpp:181
static MxArray Struct(const char **fields=NULL, int nfields=0, mwSize m=1, mwSize n=1)
Create a new struct array.
Definition: MxArray.hpp:312
static Ptr< TrainData > loadFromCSV(const String &filename, int headerLineCount, int responseStartIdx=-1, int responseEndIdx=-1, const String &varTypeSpec=String(), char delimiter=',', char missch='?')
T size(T... args)
STL class.
bool empty() const
#define CV_32S
T c_str(T... args)
Ptr< TrainData > createTrainData(const Mat &samples, const Mat &responses, vector< MxArray >::const_iterator first, vector< MxArray >::const_iterator last)
Create an instance of TrainData using options in arguments.
void create(int rows, int cols, int type)
const ConstMap< string, int > SampleTypesMap
Option values for sample layouts.
Ptr< TrainData > loadTrainData(const string &filename, vector< MxArray >::const_iterator first, vector< MxArray >::const_iterator last)
Read a dataset from a CSV file.
std::map wrapper with one-line initialization and lookup method.
Definition: MxArray.hpp:927
_Tp & at(int i0=0)
Common definitions for the ml module.