// testSVM.cpp
//
// SVM test program - Boris Mansencal, 2015-03-05 12:25 pm

 
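// Build (list the pkg-config libraries AFTER the source file, otherwise
// linkers defaulting to --as-needed drop them; use the `opencv4` pkg-config
// module instead of `opencv` with OpenCV 4):
//g++ -Wall -Wextra -o testSVM testSVM.cpp `pkg-config --cflags --libs opencv`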
#include <cassert>
#include <cstdlib>
#include <exception>
#include <fstream>
#include <iostream>
#include <limits>
#include <string>

#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/ml/ml.hpp>

static
12
void
13
readData(const std::string &inputDataFilename, cv::Mat &data, cv::Mat &responses)
14
{
15
  /*
16
    inputDataFilename must contain (in binary form) [x86-64]
17
    2 ints : rows & cols
18
    rows*cols floats for first class
19
    rows*cols floats for second class
20
21
   */
22
23
  std::ifstream in(inputDataFilename.c_str(), std::ios::in|std::ios::binary);
24
  if (! in) {
25
    std::cerr<<"Error: unable to open file: "<<inputDataFilename<<"\n";
26
    exit(EXIT_FAILURE);
27
  }
28
  
29
  int rows=0, cols=0;
30
  in.read((char *)&rows, sizeof(rows));
31
  in.read((char *)&cols, sizeof(cols));
32
  
33
  if (rows <= 0 || cols <= 0) {
34
    std::cerr<<"Error: invalid data from file: "<<inputDataFilename<<"\n";
35
    exit(EXIT_FAILURE);
36
  }
37
38
  data = cv::Mat(rows*2, cols, CV_32FC1);
39
  const size_t size = data.rows*data.cols*sizeof(float);
40
  in.read((char *)data.ptr<float>(0), size);
41
42
  if (! in) {
43
    std::cerr<<"Error: unable to read "<<size<<" Bytes from file: "<<inputDataFilename<<"\n";
44
    exit(EXIT_FAILURE);
45
  }
46
47
48
  responses = cv::Mat::zeros(data.rows, 1, CV_32SC1);
49
  for (int i=data.rows/2; i<data.rows; ++i) {
50
    *(responses.ptr<int>(i)) = 1;
51
  }
52
53
}
54
55
int
56
main(int argc, char *argv[])
57
{
58
  if (argc != 2) {
59
    std::cerr<<"Usage: "<<argv[0]<<" dataFilename\n";
60
    exit(EXIT_FAILURE);
61
  }
62
63
  const char *dataFilename = argv[1];
64
65
  cv::Mat data, responses;
66
  readData(dataFilename, data, responses);
67
68
  const double degree = 0.5;
69
  const double gamma = 1;
70
  const double coef0 = 1;
71
  const double Cvalue = 1;
72
  const double nu = 0.5;
73
  const double p = 0;
74
  const int term_iter = 1000;
75
  const double term_eps = 0.01;
76
77
#if CV_MAJOR_VERSION*100+CV_MINOR_VERSION*10+CV_SUBMINOR_VERSION < 300
78
  int svm_type = CvSVM::NU_SVC;
79
  int kernel_type = CvSVM::RBF;
80
81
  CvSVMParams params;
82
  params.svm_type = svm_type;
83
  params.kernel_type = kernel_type;
84
  params.degree = degree;
85
  params.gamma = gamma;
86
  params.coef0 = coef0;
87
  params.C = Cvalue;
88
  params.nu = nu;
89
  params.p = p;
90
  params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, term_iter, term_eps);
91
  
92
  CvSVM svmClassifier;
93
  svmClassifier.train(data, responses, cv::Mat(), cv::Mat(), params);
94
95
#else
96
  int svm_type = cv::ml::SVM::NU_SVC;
97
  int kernel_type = cv::ml::SVM::RBF;
98
  
99
  cv::Ptr<cv::ml::SVM> svmClassifier = cv::ml::SVM::create();
100
  svmClassifier->setType(svm_type);
101
  svmClassifier->setKernel(kernel_type);
102
  svmClassifier->setDegree(degree);
103
  svmClassifier->setGamma(gamma);
104
  svmClassifier->setCoef0(coef0);
105
  svmClassifier->setC(Cvalue);
106
  svmClassifier->setNu(nu);
107
  svmClassifier->setP(p);
108
  svmClassifier->setTermCriteria(cv::TermCriteria(cv::TermCriteria::COUNT, term_iter, term_eps));
109
110
  svmClassifier->train(data, cv::ml::ROW_SAMPLE, responses);
111
112
#endif 
113
114
115
  std::cout<<"classif done.\n";
116
117
  return 0;
118
}