1 | |
2 | //
|
3 | // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
|
4 | //
|
5 | // By downloading, copying, installing or using the software you agree to this license.
|
6 | // If you do not agree to this license, do not download, install,
|
7 | // copy or use the software.
|
8 | //
|
9 | //
|
10 | // Intel License Agreement
|
11 | //
|
12 | // Copyright (C) 2000, Intel Corporation, all rights reserved.
|
13 | // Third party copyrights are property of their respective owners.
|
14 | //
|
15 | // Redistribution and use in source and binary forms, with or without modification,
|
16 | // are permitted provided that the following conditions are met:
|
17 | //
|
18 | // * Redistribution's of source code must retain the above copyright notice,
|
19 | // this list of conditions and the following disclaimer.
|
20 | //
|
21 | // * Redistribution's in binary form must reproduce the above copyright notice,
|
22 | // this list of conditions and the following disclaimer in the documentation
|
23 | // and/or other materials provided with the distribution.
|
24 | //
|
25 | // * The name of Intel Corporation may not be used to endorse or promote products
|
26 | // derived from this software without specific prior written permission.
|
27 | //
|
28 | // This software is provided by the copyright holders and contributors "as is" and
|
29 | // any express or implied warranties, including, but not limited to, the implied
|
30 | // warranties of merchantability and fitness for a particular purpose are disclaimed.
|
31 | // In no event shall the Intel Corporation or contributors be liable for any direct,
|
32 | // indirect, incidental, special, exemplary, or consequential damages
|
33 | // (including, but not limited to, procurement of substitute goods or services;
|
34 | // loss of use, data, or profits; or business interruption) however caused
|
35 | // and on any theory of liability, whether in contract, strict liability,
|
36 | // or tort (including negligence or otherwise) arising in any way out of
|
37 | // the use of this software, even if advised of the possibility of such damage.
|
38 | //
|
39 | //M*/
|
40 |
|
41 |
|
42 |
|
43 | #include "precomp.hpp"
|
44 |
|
45 | |
46 | COPYRIGHT NOTICE
|
47 | ----------------
|
48 |
|
49 | The code has been derived from libsvm library (version 2.6)
|
50 | (http://www.csie.ntu.edu.tw/~cjlin/libsvm).
|
51 |
|
52 | Here is the original copyright:
|
53 | ------------------------------------------------------------------------------------------
|
54 | Copyright (c) 2000-2003 Chih-Chung Chang and Chih-Jen Lin
|
55 | All rights reserved.
|
56 |
|
57 | Redistribution and use in source and binary forms, with or without
|
58 | modification, are permitted provided that the following conditions
|
59 | are met:
|
60 |
|
61 | 1. Redistributions of source code must retain the above copyright
|
62 | notice, this list of conditions and the following disclaimer.
|
63 |
|
64 | 2. Redistributions in binary form must reproduce the above copyright
|
65 | notice, this list of conditions and the following disclaimer in the
|
66 | documentation and/or other materials provided with the distribution.
|
67 |
|
68 | 3. Neither name of copyright holders nor the names of its contributors
|
69 | may be used to endorse or promote products derived from this software
|
70 | without specific prior written permission.
|
71 |
|
72 |
|
73 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
74 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
75 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
76 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR
|
77 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
78 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
79 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
80 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
81 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
82 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
83 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
84 | \****************************************************************************************/
|
85 |
|
86 | using namespace cv;
|
87 |
|
// Lower bound, in bytes (40 MiB), of the kernel-row cache budget that
// CvSVMSolver::create() grants to the solver.
#define CV_SVM_MIN_CACHE_SIZE (40*1024*1024)
|
89 |
|
#include <ctype.h>
#include <limits.h>
#include <stdarg.h>
|
92 |
|
// Element type of the cached kernel (Q) matrix rows, plus the matching
// OpenCV depth used when a row buffer is wrapped in a CvMat header.
// The double variant is disabled: the row accessors in this file are
// declared with float*, so it would not compile as-is.
#if 0
typedef double Qfloat;
#define QFLOAT_TYPE CV_64F
#else
typedef float Qfloat;
#define QFLOAT_TYPE CV_32F
#endif
|
100 |
|
101 |
|
102 | bool CvParamGrid::check() const
|
103 | {
|
104 | bool ok = false;
|
105 |
|
106 | CV_FUNCNAME( "CvParamGrid::check" );
|
107 | __BEGIN__;
|
108 |
|
109 | if( min_val > max_val )
|
110 | CV_ERROR( CV_StsBadArg, "Lower bound of the grid must be less then the upper one" );
|
111 | if( min_val < DBL_EPSILON )
|
112 | CV_ERROR( CV_StsBadArg, "Lower bound of the grid must be positive" );
|
113 | if( step < 1. + FLT_EPSILON )
|
114 | CV_ERROR( CV_StsBadArg, "Grid step must greater then 1" );
|
115 |
|
116 | ok = true;
|
117 |
|
118 | __END__;
|
119 |
|
120 | return ok;
|
121 | }
|
122 |
|
123 | CvParamGrid CvSVM::get_default_grid( int param_id )
|
124 | {
|
125 | CvParamGrid grid;
|
126 | if( param_id == CvSVM::C )
|
127 | {
|
128 | grid.min_val = 0.1;
|
129 | grid.max_val = 500;
|
130 | grid.step = 5;
|
131 | }
|
132 | else if( param_id == CvSVM::GAMMA )
|
133 | {
|
134 | grid.min_val = 1e-5;
|
135 | grid.max_val = 0.6;
|
136 | grid.step = 15;
|
137 | }
|
138 | else if( param_id == CvSVM::P )
|
139 | {
|
140 | grid.min_val = 0.01;
|
141 | grid.max_val = 100;
|
142 | grid.step = 7;
|
143 | }
|
144 | else if( param_id == CvSVM::NU )
|
145 | {
|
146 | grid.min_val = 0.01;
|
147 | grid.max_val = 0.2;
|
148 | grid.step = 3;
|
149 | }
|
150 | else if( param_id == CvSVM::COEF )
|
151 | {
|
152 | grid.min_val = 0.1;
|
153 | grid.max_val = 300;
|
154 | grid.step = 14;
|
155 | }
|
156 | else if( param_id == CvSVM::DEGREE )
|
157 | {
|
158 | grid.min_val = 0.01;
|
159 | grid.max_val = 4;
|
160 | grid.step = 7;
|
161 | }
|
162 | else
|
163 | cvError( CV_StsBadArg, "CvSVM::get_default_grid", "Invalid type of parameter "
|
164 | "(use one of CvSVM::C, CvSVM::GAMMA et al.)", __FILE__, __LINE__ );
|
165 | return grid;
|
166 | }
|
167 |
|
168 |
|
169 | CvSVMParams::CvSVMParams() :
|
170 | svm_type(CvSVM::C_SVC), kernel_type(CvSVM::RBF), degree(0),
|
171 | gamma(1), coef0(0), C(1), nu(0), p(0), class_weights(0)
|
172 | {
|
173 | term_crit = cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 1000, FLT_EPSILON );
|
174 | }
|
175 |
|
176 |
|
177 | CvSVMParams::CvSVMParams( int _svm_type, int _kernel_type,
|
178 | double _degree, double _gamma, double _coef0,
|
179 | double _Con, double _nu, double _p,
|
180 | CvMat* _class_weights, CvTermCriteria _term_crit ) :
|
181 | svm_type(_svm_type), kernel_type(_kernel_type),
|
182 | degree(_degree), gamma(_gamma), coef0(_coef0),
|
183 | C(_Con), nu(_nu), p(_p), class_weights(_class_weights), term_crit(_term_crit)
|
184 | {
|
185 | }
|
186 |
|
187 |
|
188 |
|
189 |
|
190 | CvSVMKernel::CvSVMKernel()
|
191 | {
|
192 | clear();
|
193 | }
|
194 |
|
195 |
|
196 | void CvSVMKernel::clear()
|
197 | {
|
198 | params = 0;
|
199 | calc_func = 0;
|
200 | }
|
201 |
|
202 |
|
203 | CvSVMKernel::~CvSVMKernel()
|
204 | {
|
205 | }
|
206 |
|
207 |
|
208 | CvSVMKernel::CvSVMKernel( const CvSVMParams* _params, Calc _calc_func )
|
209 | {
|
210 | clear();
|
211 | create( _params, _calc_func );
|
212 | }
|
213 |
|
214 |
|
215 | bool CvSVMKernel::create( const CvSVMParams* _params, Calc _calc_func )
|
216 | {
|
217 | clear();
|
218 | params = _params;
|
219 | calc_func = _calc_func;
|
220 |
|
221 | if( !calc_func )
|
222 | calc_func = params->kernel_type == CvSVM::RBF ? &CvSVMKernel::calc_rbf :
|
223 | params->kernel_type == CvSVM::POLY ? &CvSVMKernel::calc_poly :
|
224 | params->kernel_type == CvSVM::SIGMOID ? &CvSVMKernel::calc_sigmoid :
|
225 | &CvSVMKernel::calc_linear;
|
226 |
|
227 | return true;
|
228 | }
|
229 |
|
230 |
|
231 | void CvSVMKernel::calc_non_rbf_base( int vcount, int var_count, const float** vecs,
|
232 | const float* another, Qfloat* results,
|
233 | double alpha, double beta )
|
234 | {
|
235 | int j, k;
|
236 | for( j = 0; j < vcount; j++ )
|
237 | {
|
238 | const float* sample = vecs[j];
|
239 | double s = 0;
|
240 | for( k = 0; k <= var_count - 4; k += 4 )
|
241 | s += sample[k]*another[k] + sample[k+1]*another[k+1] +
|
242 | sample[k+2]*another[k+2] + sample[k+3]*another[k+3];
|
243 | for( ; k < var_count; k++ )
|
244 | s += sample[k]*another[k];
|
245 | results[j] = (Qfloat)(s*alpha + beta);
|
246 | }
|
247 | }
|
248 |
|
249 |
|
250 | void CvSVMKernel::calc_linear( int vcount, int var_count, const float** vecs,
|
251 | const float* another, Qfloat* results )
|
252 | {
|
253 | calc_non_rbf_base( vcount, var_count, vecs, another, results, 1, 0 );
|
254 | }
|
255 |
|
256 |
|
257 | void CvSVMKernel::calc_poly( int vcount, int var_count, const float** vecs,
|
258 | const float* another, Qfloat* results )
|
259 | {
|
260 | CvMat R = cvMat( 1, vcount, QFLOAT_TYPE, results );
|
261 | calc_non_rbf_base( vcount, var_count, vecs, another, results, params->gamma, params->coef0 );
|
262 | if( vcount > 0 )
|
263 | cvPow( &R, &R, params->degree );
|
264 | }
|
265 |
|
266 |
|
267 | void CvSVMKernel::calc_sigmoid( int vcount, int var_count, const float** vecs,
|
268 | const float* another, Qfloat* results )
|
269 | {
|
270 | int j;
|
271 | calc_non_rbf_base( vcount, var_count, vecs, another, results,
|
272 | -2*params->gamma, -2*params->coef0 );
|
273 |
|
274 | for( j = 0; j < vcount; j++ )
|
275 | {
|
276 | Qfloat t = results[j];
|
277 | double e = exp(-fabs(t));
|
278 | if( t > 0 )
|
279 | results[j] = (Qfloat)((1. - e)/(1. + e));
|
280 | else
|
281 | results[j] = (Qfloat)((e - 1.)/(e + 1.));
|
282 | }
|
283 | }
|
284 |
|
285 |
|
286 | void CvSVMKernel::calc_rbf( int vcount, int var_count, const float** vecs,
|
287 | const float* another, Qfloat* results )
|
288 | {
|
289 | CvMat R = cvMat( 1, vcount, QFLOAT_TYPE, results );
|
290 | double gamma = -params->gamma;
|
291 | int j, k;
|
292 |
|
293 | for( j = 0; j < vcount; j++ )
|
294 | {
|
295 | const float* sample = vecs[j];
|
296 | double s = 0;
|
297 |
|
298 | for( k = 0; k <= var_count - 4; k += 4 )
|
299 | {
|
300 | double t0 = sample[k] - another[k];
|
301 | double t1 = sample[k+1] - another[k+1];
|
302 |
|
303 | s += t0*t0 + t1*t1;
|
304 |
|
305 | t0 = sample[k+2] - another[k+2];
|
306 | t1 = sample[k+3] - another[k+3];
|
307 |
|
308 | s += t0*t0 + t1*t1;
|
309 | }
|
310 |
|
311 | for( ; k < var_count; k++ )
|
312 | {
|
313 | double t0 = sample[k] - another[k];
|
314 | s += t0*t0;
|
315 | }
|
316 | results[j] = (Qfloat)(s*gamma);
|
317 | }
|
318 |
|
319 | if( vcount > 0 )
|
320 | cvExp( &R, &R );
|
321 | }
|
322 |
|
323 |
|
324 | void CvSVMKernel::calc( int vcount, int var_count, const float** vecs,
|
325 | const float* another, Qfloat* results )
|
326 | {
|
327 | const Qfloat max_val = (Qfloat)(FLT_MAX*1e-3);
|
328 | int j;
|
329 | (this->*calc_func)( vcount, var_count, vecs, another, results );
|
330 | for( j = 0; j < vcount; j++ )
|
331 | {
|
332 | if( results[j] > max_val )
|
333 | results[j] = max_val;
|
334 | }
|
335 | }
|
336 |
|
337 |
|
338 |
|
339 |
|
340 |
|
341 |
|
342 |
|
343 |
|
344 |
|
345 |
|
346 |
|
347 |
|
348 |
|
349 |
|
350 |
|
351 |
|
352 |
|
353 |
|
354 |
|
355 |
|
356 |
|
357 | void CvSVMSolver::clear()
|
358 | {
|
359 | G = 0;
|
360 | alpha = 0;
|
361 | y = 0;
|
362 | b = 0;
|
363 | buf[0] = buf[1] = 0;
|
364 | cvReleaseMemStorage( &storage );
|
365 | kernel = 0;
|
366 | select_working_set_func = 0;
|
367 | calc_rho_func = 0;
|
368 |
|
369 | rows = 0;
|
370 | samples = 0;
|
371 | get_row_func = 0;
|
372 | }
|
373 |
|
374 |
|
375 | CvSVMSolver::CvSVMSolver()
|
376 | {
|
377 | storage = 0;
|
378 | clear();
|
379 | }
|
380 |
|
381 |
|
382 | CvSVMSolver::~CvSVMSolver()
|
383 | {
|
384 | clear();
|
385 | }
|
386 |
|
387 |
|
388 | CvSVMSolver::CvSVMSolver( int _sample_count, int _var_count, const float** _samples, schar* _y,
|
389 | int _alpha_count, double* _alpha, double _Cp, double _Cn,
|
390 | CvMemStorage* _storage, CvSVMKernel* _kernel, GetRow _get_row,
|
391 | SelectWorkingSet _select_working_set, CalcRho _calc_rho )
|
392 | {
|
393 | storage = 0;
|
394 | create( _sample_count, _var_count, _samples, _y, _alpha_count, _alpha, _Cp, _Cn,
|
395 | _storage, _kernel, _get_row, _select_working_set, _calc_rho );
|
396 | }
|
397 |
|
398 |
|
399 | bool CvSVMSolver::create( int _sample_count, int _var_count, const float** _samples, schar* _y,
|
400 | int _alpha_count, double* _alpha, double _Cp, double _Cn,
|
401 | CvMemStorage* _storage, CvSVMKernel* _kernel, GetRow _get_row,
|
402 | SelectWorkingSet _select_working_set, CalcRho _calc_rho )
|
403 | {
|
404 | bool ok = false;
|
405 | int i, svm_type;
|
406 |
|
407 | CV_FUNCNAME( "CvSVMSolver::create" );
|
408 |
|
409 | __BEGIN__;
|
410 |
|
411 | int rows_hdr_size;
|
412 |
|
413 | clear();
|
414 |
|
415 | sample_count = _sample_count;
|
416 | var_count = _var_count;
|
417 | samples = _samples;
|
418 | y = _y;
|
419 | alpha_count = _alpha_count;
|
420 | alpha = _alpha;
|
421 | kernel = _kernel;
|
422 |
|
423 | C[0] = _Cn;
|
424 | C[1] = _Cp;
|
425 | eps = kernel->params->term_crit.epsilon;
|
426 | max_iter = kernel->params->term_crit.max_iter;
|
427 | storage = cvCreateChildMemStorage( _storage );
|
428 |
|
429 | b = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(b[0]));
|
430 | alpha_status = (schar*)cvMemStorageAlloc( storage, alpha_count*sizeof(alpha_status[0]));
|
431 | G = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(G[0]));
|
432 | for( i = 0; i < 2; i++ )
|
433 | buf[i] = (Qfloat*)cvMemStorageAlloc( storage, sample_count*2*sizeof(buf[i][0]) );
|
434 | svm_type = kernel->params->svm_type;
|
435 |
|
436 | select_working_set_func = _select_working_set;
|
437 | if( !select_working_set_func )
|
438 | select_working_set_func = svm_type == CvSVM::NU_SVC || svm_type == CvSVM::NU_SVR ?
|
439 | &CvSVMSolver::select_working_set_nu_svm : &CvSVMSolver::select_working_set;
|
440 |
|
441 | calc_rho_func = _calc_rho;
|
442 | if( !calc_rho_func )
|
443 | calc_rho_func = svm_type == CvSVM::NU_SVC || svm_type == CvSVM::NU_SVR ?
|
444 | &CvSVMSolver::calc_rho_nu_svm : &CvSVMSolver::calc_rho;
|
445 |
|
446 | get_row_func = _get_row;
|
447 | if( !get_row_func )
|
448 | get_row_func = params->svm_type == CvSVM::EPS_SVR ||
|
449 | params->svm_type == CvSVM::NU_SVR ? &CvSVMSolver::get_row_svr :
|
450 | params->svm_type == CvSVM::C_SVC ||
|
451 | params->svm_type == CvSVM::NU_SVC ? &CvSVMSolver::get_row_svc :
|
452 | &CvSVMSolver::get_row_one_class;
|
453 |
|
454 | cache_line_size = sample_count*sizeof(Qfloat);
|
455 |
|
456 |
|
457 | cache_size = MAX( cache_line_size*sample_count/4, CV_SVM_MIN_CACHE_SIZE );
|
458 |
|
459 |
|
460 | rows_hdr_size = sample_count*sizeof(rows[0]);
|
461 | if( rows_hdr_size > storage->block_size )
|
462 | CV_ERROR( CV_StsOutOfRange, "Too small storage block size" );
|
463 |
|
464 | lru_list.prev = lru_list.next = &lru_list;
|
465 | rows = (CvSVMKernelRow*)cvMemStorageAlloc( storage, rows_hdr_size );
|
466 | memset( rows, 0, rows_hdr_size );
|
467 |
|
468 | ok = true;
|
469 |
|
470 | __END__;
|
471 |
|
472 | return ok;
|
473 | }
|
474 |
|
475 |
|
476 | float* CvSVMSolver::get_row_base( int i, bool* _existed )
|
477 | {
|
478 | int i1 = i < sample_count ? i : i - sample_count;
|
479 | CvSVMKernelRow* row = rows + i1;
|
480 | bool existed = row->data != 0;
|
481 | Qfloat* data;
|
482 |
|
483 | if( existed || cache_size <= 0 )
|
484 | {
|
485 | CvSVMKernelRow* del_row = existed ? row : lru_list.prev;
|
486 | data = del_row->data;
|
487 | assert( data != 0 );
|
488 |
|
489 |
|
490 | del_row->data = 0;
|
491 | del_row->prev->next = del_row->next;
|
492 | del_row->next->prev = del_row->prev;
|
493 | }
|
494 | else
|
495 | {
|
496 | data = (Qfloat*)cvMemStorageAlloc( storage, cache_line_size );
|
497 | cache_size -= cache_line_size;
|
498 | }
|
499 |
|
500 |
|
501 | row->data = data;
|
502 | row->prev = &lru_list;
|
503 | row->next = lru_list.next;
|
504 | row->prev->next = row->next->prev = row;
|
505 |
|
506 | if( !existed )
|
507 | {
|
508 | kernel->calc( sample_count, var_count, samples, samples[i1], row->data );
|
509 | }
|
510 |
|
511 | if( _existed )
|
512 | *_existed = existed;
|
513 |
|
514 | return row->data;
|
515 | }
|
516 |
|
517 |
|
518 | float* CvSVMSolver::get_row_svc( int i, float* row, float*, bool existed )
|
519 | {
|
520 | if( !existed )
|
521 | {
|
522 | const schar* _y = y;
|
523 | int j, len = sample_count;
|
524 | assert( _y && i < sample_count );
|
525 |
|
526 | if( _y[i] > 0 )
|
527 | {
|
528 | for( j = 0; j < len; j++ )
|
529 | row[j] = _y[j]*row[j];
|
530 | }
|
531 | else
|
532 | {
|
533 | for( j = 0; j < len; j++ )
|
534 | row[j] = -_y[j]*row[j];
|
535 | }
|
536 | }
|
537 | return row;
|
538 | }
|
539 |
|
540 |
|
541 | float* CvSVMSolver::get_row_one_class( int, float* row, float*, bool )
|
542 | {
|
543 | return row;
|
544 | }
|
545 |
|
546 |
|
547 | float* CvSVMSolver::get_row_svr( int i, float* row, float* dst, bool )
|
548 | {
|
549 | int j, len = sample_count;
|
550 | Qfloat* dst_pos = dst;
|
551 | Qfloat* dst_neg = dst + len;
|
552 | if( i >= len )
|
553 | {
|
554 | Qfloat* temp;
|
555 | CV_SWAP( dst_pos, dst_neg, temp );
|
556 | }
|
557 |
|
558 | for( j = 0; j < len; j++ )
|
559 | {
|
560 | Qfloat t = row[j];
|
561 | dst_pos[j] = t;
|
562 | dst_neg[j] = -t;
|
563 | }
|
564 | return dst;
|
565 | }
|
566 |
|
567 |
|
568 |
|
569 | float* CvSVMSolver::get_row( int i, float* dst )
|
570 | {
|
571 | bool existed = false;
|
572 | float* row = get_row_base( i, &existed );
|
573 | return (this->*get_row_func)( i, row, dst, existed );
|
574 | }
|
575 |
|
576 |
|
// Helpers for the solver member functions below. They expand against the
// alpha_status, alpha, y and C members at the point of use.
// alpha_status[i]: +1 = alpha[i] at its upper bound, -1 = at 0, 0 = free.

#undef is_upper_bound
#define is_upper_bound(i) (alpha_status[(i)] > 0)

#undef is_lower_bound
#define is_lower_bound(i) (alpha_status[(i)] < 0)

#undef is_free
#define is_free(i) (alpha_status[(i)] == 0)

// Box bound of variable i: C[1] (Cp) when y[i] > 0, C[0] (Cn) otherwise.
#undef get_C
#define get_C(i) (C[y[(i)] > 0])

// Recomputes alpha_status[i] from alpha[i] and its bound.
#undef update_alpha_status
#define update_alpha_status(i) \
    alpha_status[(i)] = (schar)(alpha[(i)] >= get_C(i) ? 1 : alpha[(i)] <= 0 ? -1 : 0)

// Expands to nothing; kept as a placeholder.
#undef reconstruct_gradient
#define reconstruct_gradient()
|
595 |
|
596 |
|
597 | bool CvSVMSolver::solve_generic( CvSVMSolutionInfo& si )
|
598 | {
|
599 | int iter = 0;
|
600 | int i, j, k;
|
601 |
|
602 |
|
603 | for( i = 0; i < alpha_count; i++ )
|
604 | {
|
605 | update_alpha_status(i);
|
606 | G[i] = b[i];
|
607 | if( fabs(G[i]) > 1e200 )
|
608 | return false;
|
609 | }
|
610 |
|
611 | for( i = 0; i < alpha_count; i++ )
|
612 | {
|
613 | if( !is_lower_bound(i) )
|
614 | {
|
615 | const Qfloat *Q_i = get_row( i, buf[0] );
|
616 | double alpha_i = alpha[i];
|
617 |
|
618 | for( j = 0; j < alpha_count; j++ )
|
619 | G[j] += alpha_i*Q_i[j];
|
620 | }
|
621 | }
|
622 |
|
623 |
|
624 | for(;;)
|
625 | {
|
626 | const Qfloat *Q_i, *Q_j;
|
627 | double C_i, C_j;
|
628 | double old_alpha_i, old_alpha_j, alpha_i, alpha_j;
|
629 | double delta_alpha_i, delta_alpha_j;
|
630 |
|
631 | #ifdef _DEBUG
|
632 | for( i = 0; i < alpha_count; i++ )
|
633 | {
|
634 | if( fabs(G[i]) > 1e+300 )
|
635 | return false;
|
636 |
|
637 | if( fabs(alpha[i]) > 1e16 )
|
638 | return false;
|
639 | }
|
640 | #endif
|
641 |
|
642 | if( (this->*select_working_set_func)( i, j ) != 0 || iter++ >= max_iter )
|
643 | break;
|
644 |
|
645 | Q_i = get_row( i, buf[0] );
|
646 | Q_j = get_row( j, buf[1] );
|
647 |
|
648 | C_i = get_C(i);
|
649 | C_j = get_C(j);
|
650 |
|
651 | alpha_i = old_alpha_i = alpha[i];
|
652 | alpha_j = old_alpha_j = alpha[j];
|
653 |
|
654 | if( y[i] != y[j] )
|
655 | {
|
656 | double denom = Q_i[i]+Q_j[j]+2*Q_i[j];
|
657 | double delta = (-G[i]-G[j])/MAX(fabs(denom),FLT_EPSILON);
|
658 | double diff = alpha_i - alpha_j;
|
659 | alpha_i += delta;
|
660 | alpha_j += delta;
|
661 |
|
662 | if( diff > 0 && alpha_j < 0 )
|
663 | {
|
664 | alpha_j = 0;
|
665 | alpha_i = diff;
|
666 | }
|
667 | else if( diff <= 0 && alpha_i < 0 )
|
668 | {
|
669 | alpha_i = 0;
|
670 | alpha_j = -diff;
|
671 | }
|
672 |
|
673 | if( diff > C_i - C_j && alpha_i > C_i )
|
674 | {
|
675 | alpha_i = C_i;
|
676 | alpha_j = C_i - diff;
|
677 | }
|
678 | else if( diff <= C_i - C_j && alpha_j > C_j )
|
679 | {
|
680 | alpha_j = C_j;
|
681 | alpha_i = C_j + diff;
|
682 | }
|
683 | }
|
684 | else
|
685 | {
|
686 | double denom = Q_i[i]+Q_j[j]-2*Q_i[j];
|
687 | double delta = (G[i]-G[j])/MAX(fabs(denom),FLT_EPSILON);
|
688 | double sum = alpha_i + alpha_j;
|
689 | alpha_i -= delta;
|
690 | alpha_j += delta;
|
691 |
|
692 | if( sum > C_i && alpha_i > C_i )
|
693 | {
|
694 | alpha_i = C_i;
|
695 | alpha_j = sum - C_i;
|
696 | }
|
697 | else if( sum <= C_i && alpha_j < 0)
|
698 | {
|
699 | alpha_j = 0;
|
700 | alpha_i = sum;
|
701 | }
|
702 |
|
703 | if( sum > C_j && alpha_j > C_j )
|
704 | {
|
705 | alpha_j = C_j;
|
706 | alpha_i = sum - C_j;
|
707 | }
|
708 | else if( sum <= C_j && alpha_i < 0 )
|
709 | {
|
710 | alpha_i = 0;
|
711 | alpha_j = sum;
|
712 | }
|
713 | }
|
714 |
|
715 |
|
716 | alpha[i] = alpha_i;
|
717 | alpha[j] = alpha_j;
|
718 | update_alpha_status(i);
|
719 | update_alpha_status(j);
|
720 |
|
721 |
|
722 | delta_alpha_i = alpha_i - old_alpha_i;
|
723 | delta_alpha_j = alpha_j - old_alpha_j;
|
724 |
|
725 | for( k = 0; k < alpha_count; k++ )
|
726 | G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
|
727 | }
|
728 |
|
729 |
|
730 | (this->*calc_rho_func)( si.rho, si.r );
|
731 |
|
732 |
|
733 | for( i = 0, si.obj = 0; i < alpha_count; i++ )
|
734 | si.obj += alpha[i] * (G[i] + b[i]);
|
735 |
|
736 | si.obj *= 0.5;
|
737 |
|
738 | si.upper_bound_p = C[1];
|
739 | si.upper_bound_n = C[0];
|
740 |
|
741 | return true;
|
742 | }
|
743 |
|
744 |
|
745 |
|
746 | bool
|
747 | CvSVMSolver::select_working_set( int& out_i, int& out_j )
|
748 | {
|
749 |
|
750 |
|
751 |
|
752 | double Gmax1 = -DBL_MAX;
|
753 | int Gmax1_idx = -1;
|
754 |
|
755 | double Gmax2 = -DBL_MAX;
|
756 | int Gmax2_idx = -1;
|
757 |
|
758 | int i;
|
759 |
|
760 | for( i = 0; i < alpha_count; i++ )
|
761 | {
|
762 | double t;
|
763 |
|
764 | if( y[i] > 0 )
|
765 | {
|
766 | if( !is_upper_bound(i) && (t = -G[i]) > Gmax1 )
|
767 | {
|
768 | Gmax1 = t;
|
769 | Gmax1_idx = i;
|
770 | }
|
771 | if( !is_lower_bound(i) && (t = G[i]) > Gmax2 )
|
772 | {
|
773 | Gmax2 = t;
|
774 | Gmax2_idx = i;
|
775 | }
|
776 | }
|
777 | else
|
778 | {
|
779 | if( !is_upper_bound(i) && (t = -G[i]) > Gmax2 )
|
780 | {
|
781 | Gmax2 = t;
|
782 | Gmax2_idx = i;
|
783 | }
|
784 | if( !is_lower_bound(i) && (t = G[i]) > Gmax1 )
|
785 | {
|
786 | Gmax1 = t;
|
787 | Gmax1_idx = i;
|
788 | }
|
789 | }
|
790 | }
|
791 |
|
792 | out_i = Gmax1_idx;
|
793 | out_j = Gmax2_idx;
|
794 |
|
795 | return Gmax1 + Gmax2 < eps;
|
796 | }
|
797 |
|
798 |
|
799 | void
|
800 | CvSVMSolver::calc_rho( double& rho, double& r )
|
801 | {
|
802 | int i, nr_free = 0;
|
803 | double ub = DBL_MAX, lb = -DBL_MAX, sum_free = 0;
|
804 |
|
805 | for( i = 0; i < alpha_count; i++ )
|
806 | {
|
807 | double yG = y[i]*G[i];
|
808 |
|
809 | if( is_lower_bound(i) )
|
810 | {
|
811 | if( y[i] > 0 )
|
812 | ub = MIN(ub,yG);
|
813 | else
|
814 | lb = MAX(lb,yG);
|
815 | }
|
816 | else if( is_upper_bound(i) )
|
817 | {
|
818 | if( y[i] < 0)
|
819 | ub = MIN(ub,yG);
|
820 | else
|
821 | lb = MAX(lb,yG);
|
822 | }
|
823 | else
|
824 | {
|
825 | ++nr_free;
|
826 | sum_free += yG;
|
827 | }
|
828 | }
|
829 |
|
830 | rho = nr_free > 0 ? sum_free/nr_free : (ub + lb)*0.5;
|
831 | r = 0;
|
832 | }
|
833 |
|
834 |
|
835 | bool
|
836 | CvSVMSolver::select_working_set_nu_svm( int& out_i, int& out_j )
|
837 | {
|
838 |
|
839 |
|
840 |
|
841 | double Gmax1 = -DBL_MAX;
|
842 | int Gmax1_idx = -1;
|
843 |
|
844 | double Gmax2 = -DBL_MAX;
|
845 | int Gmax2_idx = -1;
|
846 |
|
847 | double Gmax3 = -DBL_MAX;
|
848 | int Gmax3_idx = -1;
|
849 |
|
850 | double Gmax4 = -DBL_MAX;
|
851 | int Gmax4_idx = -1;
|
852 |
|
853 | int i;
|
854 |
|
855 | for( i = 0; i < alpha_count; i++ )
|
856 | {
|
857 | double t;
|
858 |
|
859 | if( y[i] > 0 )
|
860 | {
|
861 | if( !is_upper_bound(i) && (t = -G[i]) > Gmax1 )
|
862 | {
|
863 | Gmax1 = t;
|
864 | Gmax1_idx = i;
|
865 | }
|
866 | if( !is_lower_bound(i) && (t = G[i]) > Gmax2 )
|
867 | {
|
868 | Gmax2 = t;
|
869 | Gmax2_idx = i;
|
870 | }
|
871 | }
|
872 | else
|
873 | {
|
874 | if( !is_upper_bound(i) && (t = -G[i]) > Gmax3 )
|
875 | {
|
876 | Gmax3 = t;
|
877 | Gmax3_idx = i;
|
878 | }
|
879 | if( !is_lower_bound(i) && (t = G[i]) > Gmax4 )
|
880 | {
|
881 | Gmax4 = t;
|
882 | Gmax4_idx = i;
|
883 | }
|
884 | }
|
885 | }
|
886 |
|
887 | if( MAX(Gmax1 + Gmax2, Gmax3 + Gmax4) < eps )
|
888 | return 1;
|
889 |
|
890 | if( Gmax1 + Gmax2 > Gmax3 + Gmax4 )
|
891 | {
|
892 | out_i = Gmax1_idx;
|
893 | out_j = Gmax2_idx;
|
894 | }
|
895 | else
|
896 | {
|
897 | out_i = Gmax3_idx;
|
898 | out_j = Gmax4_idx;
|
899 | }
|
900 | return 0;
|
901 | }
|
902 |
|
903 |
|
904 | void
|
905 | CvSVMSolver::calc_rho_nu_svm( double& rho, double& r )
|
906 | {
|
907 | int nr_free1 = 0, nr_free2 = 0;
|
908 | double ub1 = DBL_MAX, ub2 = DBL_MAX;
|
909 | double lb1 = -DBL_MAX, lb2 = -DBL_MAX;
|
910 | double sum_free1 = 0, sum_free2 = 0;
|
911 | double r1, r2;
|
912 |
|
913 | int i;
|
914 |
|
915 | for( i = 0; i < alpha_count; i++ )
|
916 | {
|
917 | double G_i = G[i];
|
918 | if( y[i] > 0 )
|
919 | {
|
920 | if( is_lower_bound(i) )
|
921 | ub1 = MIN( ub1, G_i );
|
922 | else if( is_upper_bound(i) )
|
923 | lb1 = MAX( lb1, G_i );
|
924 | else
|
925 | {
|
926 | ++nr_free1;
|
927 | sum_free1 += G_i;
|
928 | }
|
929 | }
|
930 | else
|
931 | {
|
932 | if( is_lower_bound(i) )
|
933 | ub2 = MIN( ub2, G_i );
|
934 | else if( is_upper_bound(i) )
|
935 | lb2 = MAX( lb2, G_i );
|
936 | else
|
937 | {
|
938 | ++nr_free2;
|
939 | sum_free2 += G_i;
|
940 | }
|
941 | }
|
942 | }
|
943 |
|
944 | r1 = nr_free1 > 0 ? sum_free1/nr_free1 : (ub1 + lb1)*0.5;
|
945 | r2 = nr_free2 > 0 ? sum_free2/nr_free2 : (ub2 + lb2)*0.5;
|
946 |
|
947 | rho = (r1 - r2)*0.5;
|
948 | r = (r1 + r2)*0.5;
|
949 | }
|
950 |
|
951 |
|
952 | |
953 | ///////////////////////// construct and solve various formulations ///////////////////////
|
954 | */
|
955 |
|
956 | bool CvSVMSolver::solve_c_svc( int _sample_count, int _var_count, const float** _samples, schar* _y,
|
957 | double _Cp, double _Cn, CvMemStorage* _storage,
|
958 | CvSVMKernel* _kernel, double* _alpha, CvSVMSolutionInfo& _si )
|
959 | {
|
960 | int i;
|
961 |
|
962 | if( !create( _sample_count, _var_count, _samples, _y, _sample_count,
|
963 | _alpha, _Cp, _Cn, _storage, _kernel, &CvSVMSolver::get_row_svc,
|
964 | &CvSVMSolver::select_working_set, &CvSVMSolver::calc_rho ))
|
965 | return false;
|
966 |
|
967 | for( i = 0; i < sample_count; i++ )
|
968 | {
|
969 | alpha[i] = 0;
|
970 | b[i] = -1;
|
971 | }
|
972 |
|
973 | if( !solve_generic( _si ))
|
974 | return false;
|
975 |
|
976 | for( i = 0; i < sample_count; i++ )
|
977 | alpha[i] *= y[i];
|
978 |
|
979 | return true;
|
980 | }
|
981 |
|
982 |
|
983 | bool CvSVMSolver::solve_nu_svc( int _sample_count, int _var_count, const float** _samples, schar* _y,
|
984 | CvMemStorage* _storage, CvSVMKernel* _kernel,
|
985 | double* _alpha, CvSVMSolutionInfo& _si )
|
986 | {
|
987 | int i;
|
988 | double sum_pos, sum_neg, inv_r;
|
989 |
|
990 | if( !create( _sample_count, _var_count, _samples, _y, _sample_count,
|
991 | _alpha, 1., 1., _storage, _kernel, &CvSVMSolver::get_row_svc,
|
992 | &CvSVMSolver::select_working_set_nu_svm, &CvSVMSolver::calc_rho_nu_svm ))
|
993 | return false;
|
994 |
|
995 | sum_pos = kernel->params->nu * sample_count * 0.5;
|
996 | sum_neg = kernel->params->nu * sample_count * 0.5;
|
997 |
|
998 | for( i = 0; i < sample_count; i++ )
|
999 | {
|
1000 | if( y[i] > 0 )
|
1001 | {
|
1002 | alpha[i] = MIN(1.0, sum_pos);
|
1003 | sum_pos -= alpha[i];
|
1004 | }
|
1005 | else
|
1006 | {
|
1007 | alpha[i] = MIN(1.0, sum_neg);
|
1008 | sum_neg -= alpha[i];
|
1009 | }
|
1010 | b[i] = 0;
|
1011 | }
|
1012 |
|
1013 | if( !solve_generic( _si ))
|
1014 | return false;
|
1015 |
|
1016 | inv_r = 1./_si.r;
|
1017 |
|
1018 | for( i = 0; i < sample_count; i++ )
|
1019 | alpha[i] *= y[i]*inv_r;
|
1020 |
|
1021 | _si.rho *= inv_r;
|
1022 | _si.obj *= (inv_r*inv_r);
|
1023 | _si.upper_bound_p = inv_r;
|
1024 | _si.upper_bound_n = inv_r;
|
1025 |
|
1026 | return true;
|
1027 | }
|
1028 |
|
1029 |
|
1030 | bool CvSVMSolver::solve_one_class( int _sample_count, int _var_count, const float** _samples,
|
1031 | CvMemStorage* _storage, CvSVMKernel* _kernel,
|
1032 | double* _alpha, CvSVMSolutionInfo& _si )
|
1033 | {
|
1034 | int i, n;
|
1035 | double nu = _kernel->params->nu;
|
1036 |
|
1037 | if( !create( _sample_count, _var_count, _samples, 0, _sample_count,
|
1038 | _alpha, 1., 1., _storage, _kernel, &CvSVMSolver::get_row_one_class,
|
1039 | &CvSVMSolver::select_working_set, &CvSVMSolver::calc_rho ))
|
1040 | return false;
|
1041 |
|
1042 | y = (schar*)cvMemStorageAlloc( storage, sample_count*sizeof(y[0]) );
|
1043 | n = cvRound( nu*sample_count );
|
1044 |
|
1045 | for( i = 0; i < sample_count; i++ )
|
1046 | {
|
1047 | y[i] = 1;
|
1048 | b[i] = 0;
|
1049 | alpha[i] = i < n ? 1 : 0;
|
1050 | }
|
1051 |
|
1052 | if( n < sample_count )
|
1053 | alpha[n] = nu * sample_count - n;
|
1054 | else
|
1055 | alpha[n-1] = nu * sample_count - (n-1);
|
1056 |
|
1057 | return solve_generic(_si);
|
1058 | }
|
1059 |
|
1060 |
|
1061 | bool CvSVMSolver::solve_eps_svr( int _sample_count, int _var_count, const float** _samples,
|
1062 | const float* _y, CvMemStorage* _storage,
|
1063 | CvSVMKernel* _kernel, double* _alpha, CvSVMSolutionInfo& _si )
|
1064 | {
|
1065 | int i;
|
1066 | double p = _kernel->params->p, kernel_param_c = _kernel->params->C;
|
1067 |
|
1068 | if( !create( _sample_count, _var_count, _samples, 0,
|
1069 | _sample_count*2, 0, kernel_param_c, kernel_param_c, _storage, _kernel, &CvSVMSolver::get_row_svr,
|
1070 | &CvSVMSolver::select_working_set, &CvSVMSolver::calc_rho ))
|
1071 | return false;
|
1072 |
|
1073 | y = (schar*)cvMemStorageAlloc( storage, sample_count*2*sizeof(y[0]) );
|
1074 | alpha = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(alpha[0]) );
|
1075 |
|
1076 | for( i = 0; i < sample_count; i++ )
|
1077 | {
|
1078 | alpha[i] = 0;
|
1079 | b[i] = p - _y[i];
|
1080 | y[i] = 1;
|
1081 |
|
1082 | alpha[i+sample_count] = 0;
|
1083 | b[i+sample_count] = p + _y[i];
|
1084 | y[i+sample_count] = -1;
|
1085 | }
|
1086 |
|
1087 | if( !solve_generic( _si ))
|
1088 | return false;
|
1089 |
|
1090 | for( i = 0; i < sample_count; i++ )
|
1091 | _alpha[i] = alpha[i] - alpha[i+sample_count];
|
1092 |
|
1093 | return true;
|
1094 | }
|
1095 |
|
1096 |
|
1097 | bool CvSVMSolver::solve_nu_svr( int _sample_count, int _var_count, const float** _samples,
|
1098 | const float* _y, CvMemStorage* _storage,
|
1099 | CvSVMKernel* _kernel, double* _alpha, CvSVMSolutionInfo& _si )
|
1100 | {
|
1101 | int i;
|
1102 | double kernel_param_c = _kernel->params->C, sum;
|
1103 |
|
1104 | if( !create( _sample_count, _var_count, _samples, 0,
|
1105 | _sample_count*2, 0, 1., 1., _storage, _kernel, &CvSVMSolver::get_row_svr,
|
1106 | &CvSVMSolver::select_working_set_nu_svm, &CvSVMSolver::calc_rho_nu_svm ))
|
1107 | return false;
|
1108 |
|
1109 | y = (schar*)cvMemStorageAlloc( storage, sample_count*2*sizeof(y[0]) );
|
1110 | alpha = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(alpha[0]) );
|
1111 | sum = kernel_param_c * _kernel->params->nu * sample_count * 0.5;
|
1112 |
|
1113 | for( i = 0; i < sample_count; i++ )
|
1114 | {
|
1115 | alpha[i] = alpha[i + sample_count] = MIN(sum, kernel_param_c);
|
1116 | sum -= alpha[i];
|
1117 |
|
1118 | b[i] = -_y[i];
|
1119 | y[i] = 1;
|
1120 |
|
1121 | b[i + sample_count] = _y[i];
|
1122 | y[i + sample_count] = -1;
|
1123 | }
|
1124 |
|
1125 | if( !solve_generic( _si ))
|
1126 | return false;
|
1127 |
|
1128 | for( i = 0; i < sample_count; i++ )
|
1129 | _alpha[i] = alpha[i] - alpha[i+sample_count];
|
1130 |
|
1131 | return true;
|
1132 | }
|
1133 |
|
1134 |
|
1135 |
|
1136 |
|
1137 | CvSVM::CvSVM()
|
1138 | {
|
1139 | decision_func = 0;
|
1140 | class_labels = 0;
|
1141 | class_weights = 0;
|
1142 | storage = 0;
|
1143 | var_idx = 0;
|
1144 | kernel = 0;
|
1145 | solver = 0;
|
1146 | default_model_name = "my_svm";
|
1147 |
|
1148 | clear();
|
1149 | }
|
1150 |
|
1151 |
|
1152 | CvSVM::~CvSVM()
|
1153 | {
|
1154 | clear();
|
1155 | }
|
1156 |
|
1157 |
|
1158 | void CvSVM::clear()
|
1159 | {
|
1160 | cvFree( &decision_func );
|
1161 | cvReleaseMat( &class_labels );
|
1162 | cvReleaseMat( &class_weights );
|
1163 | cvReleaseMemStorage( &storage );
|
1164 | cvReleaseMat( &var_idx );
|
1165 | delete kernel;
|
1166 | delete solver;
|
1167 | kernel = 0;
|
1168 | solver = 0;
|
1169 | var_all = 0;
|
1170 | sv = 0;
|
1171 | sv_total = 0;
|
1172 | }
|
1173 |
|
1174 |
|
1175 | CvSVM::CvSVM( const CvMat* _train_data, const CvMat* _responses,
|
1176 | const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params )
|
1177 | {
|
1178 | decision_func = 0;
|
1179 | class_labels = 0;
|
1180 | class_weights = 0;
|
1181 | storage = 0;
|
1182 | var_idx = 0;
|
1183 | kernel = 0;
|
1184 | solver = 0;
|
1185 | default_model_name = "my_svm";
|
1186 |
|
1187 | train( _train_data, _responses, _var_idx, _sample_idx, _params );
|
1188 | }
|
1189 |
|
1190 |
|
1191 | int CvSVM::get_support_vector_count() const
|
1192 | {
|
1193 | return sv_total;
|
1194 | }
|
1195 |
|
1196 |
|
1197 | const float* CvSVM::get_support_vector(int i) const
|
1198 | {
|
1199 | return sv && (unsigned)i < (unsigned)sv_total ? sv[i] : 0;
|
1200 | }
|
1201 |
|
1202 |
|
1203 | bool CvSVM::set_params( const CvSVMParams& _params )
|
1204 | {
|
1205 | bool ok = false;
|
1206 |
|
1207 | CV_FUNCNAME( "CvSVM::set_params" );
|
1208 |
|
1209 | __BEGIN__;
|
1210 |
|
1211 | int kernel_type, svm_type;
|
1212 |
|
1213 | params = _params;
|
1214 |
|
1215 | kernel_type = params.kernel_type;
|
1216 | svm_type = params.svm_type;
|
1217 |
|
1218 | if( kernel_type != LINEAR && kernel_type != POLY &&
|
1219 | kernel_type != SIGMOID && kernel_type != RBF )
|
1220 | CV_ERROR( CV_StsBadArg, "Unknown/unsupported kernel type" );
|
1221 |
|
1222 | if( kernel_type == LINEAR )
|
1223 | params.gamma = 1;
|
1224 | else if( params.gamma <= 0 )
|
1225 | CV_ERROR( CV_StsOutOfRange, "gamma parameter of the kernel must be positive" );
|
1226 |
|
1227 | if( kernel_type != SIGMOID && kernel_type != POLY )
|
1228 | params.coef0 = 0;
|
1229 | else if( params.coef0 < 0 )
|
1230 | CV_ERROR( CV_StsOutOfRange, "The kernel parameter <coef0> must be positive or zero" );
|
1231 |
|
1232 | if( kernel_type != POLY )
|
1233 | params.degree = 0;
|
1234 | else if( params.degree <= 0 )
|
1235 | CV_ERROR( CV_StsOutOfRange, "The kernel parameter <degree> must be positive" );
|
1236 |
|
1237 | if( svm_type != C_SVC && svm_type != NU_SVC &&
|
1238 | svm_type != ONE_CLASS && svm_type != EPS_SVR &&
|
1239 | svm_type != NU_SVR )
|
1240 | CV_ERROR( CV_StsBadArg, "Unknown/unsupported SVM type" );
|
1241 |
|
1242 | if( svm_type == ONE_CLASS || svm_type == NU_SVC )
|
1243 | params.C = 0;
|
1244 | else if( params.C <= 0 )
|
1245 | CV_ERROR( CV_StsOutOfRange, "The parameter C must be positive" );
|
1246 |
|
1247 | if( svm_type == C_SVC || svm_type == EPS_SVR )
|
1248 | params.nu = 0;
|
1249 | else if( params.nu <= 0 || params.nu >= 1 )
|
1250 | CV_ERROR( CV_StsOutOfRange, "The parameter nu must be between 0 and 1" );
|
1251 |
|
1252 | if( svm_type != EPS_SVR )
|
1253 | params.p = 0;
|
1254 | else if( params.p <= 0 )
|
1255 | CV_ERROR( CV_StsOutOfRange, "The parameter p must be positive" );
|
1256 |
|
1257 | if( svm_type != C_SVC )
|
1258 | params.class_weights = 0;
|
1259 |
|
1260 | params.term_crit = cvCheckTermCriteria( params.term_crit, DBL_EPSILON, INT_MAX );
|
1261 | params.term_crit.epsilon = MAX( params.term_crit.epsilon, DBL_EPSILON );
|
1262 | ok = true;
|
1263 |
|
1264 | __END__;
|
1265 |
|
1266 | return ok;
|
1267 | }
|
1268 |
|
1269 |
|
1270 |
|
1271 | void CvSVM::create_kernel()
|
1272 | {
|
1273 | kernel = new CvSVMKernel(¶ms,0);
|
1274 | }
|
1275 |
|
1276 |
|
1277 | void CvSVM::create_solver( )
|
1278 | {
|
1279 | solver = new CvSVMSolver;
|
1280 | }
|
1281 |
|
1282 |
|
1283 |
|
1284 | bool CvSVM::train1( int sample_count, int var_count, const float** samples,
|
1285 | const void* _responses, double Cp, double Cn,
|
1286 | CvMemStorage* _storage, double* alpha, double& rho )
|
1287 | {
|
1288 | bool ok = false;
|
1289 |
|
1290 |
|
1291 |
|
1292 | __BEGIN__;
|
1293 |
|
1294 | CvSVMSolutionInfo si;
|
1295 | int svm_type = params.svm_type;
|
1296 |
|
1297 | si.rho = 0;
|
1298 |
|
1299 | ok = svm_type == C_SVC ? solver->solve_c_svc( sample_count, var_count, samples, (schar*)_responses,
|
1300 | Cp, Cn, _storage, kernel, alpha, si ) :
|
1301 | svm_type == NU_SVC ? solver->solve_nu_svc( sample_count, var_count, samples, (schar*)_responses,
|
1302 | _storage, kernel, alpha, si ) :
|
1303 | svm_type == ONE_CLASS ? solver->solve_one_class( sample_count, var_count, samples,
|
1304 | _storage, kernel, alpha, si ) :
|
1305 | svm_type == EPS_SVR ? solver->solve_eps_svr( sample_count, var_count, samples, (float*)_responses,
|
1306 | _storage, kernel, alpha, si ) :
|
1307 | svm_type == NU_SVR ? solver->solve_nu_svr( sample_count, var_count, samples, (float*)_responses,
|
1308 | _storage, kernel, alpha, si ) : false;
|
1309 |
|
1310 | rho = si.rho;
|
1311 |
|
1312 | __END__;
|
1313 |
|
1314 | return ok;
|
1315 | }
|
1316 |
|
1317 |
|
1318 | bool CvSVM::do_train( int svm_type, int sample_count, int var_count, const float** samples,
|
1319 | const CvMat* responses, CvMemStorage* temp_storage, double* alpha )
|
1320 | {
|
1321 | bool ok = false;
|
1322 |
|
1323 | CV_FUNCNAME( "CvSVM::do_train" );
|
1324 |
|
1325 | __BEGIN__;
|
1326 |
|
1327 | CvSVMDecisionFunc* df = 0;
|
1328 | const int sample_size = var_count*sizeof(samples[0][0]);
|
1329 | int i, j, k;
|
1330 |
|
1331 | cvClearMemStorage( storage );
|
1332 |
|
1333 | if( svm_type == ONE_CLASS || svm_type == EPS_SVR || svm_type == NU_SVR )
|
1334 | {
|
1335 | int sv_count = 0;
|
1336 |
|
1337 | CV_CALL( decision_func = df =
|
1338 | (CvSVMDecisionFunc*)cvAlloc( sizeof(df[0]) ));
|
1339 |
|
1340 | df->rho = 0;
|
1341 | if( !train1( sample_count, var_count, samples, svm_type == ONE_CLASS ? 0 :
|
1342 | responses->data.i, 0, 0, temp_storage, alpha, df->rho ))
|
1343 | EXIT;
|
1344 |
|
1345 | for( i = 0; i < sample_count; i++ )
|
1346 | sv_count += fabs(alpha[i]) > 0;
|
1347 |
|
1348 | CV_Assert(sv_count != 0);
|
1349 |
|
1350 | sv_total = df->sv_count = sv_count;
|
1351 | CV_CALL( df->alpha = (double*)cvMemStorageAlloc( storage, sv_count*sizeof(df->alpha[0])) );
|
1352 | CV_CALL( sv = (float**)cvMemStorageAlloc( storage, sv_count*sizeof(sv[0])));
|
1353 |
|
1354 | for( i = k = 0; i < sample_count; i++ )
|
1355 | {
|
1356 | if( fabs(alpha[i]) > 0 )
|
1357 | {
|
1358 | CV_CALL( sv[k] = (float*)cvMemStorageAlloc( storage, sample_size ));
|
1359 | memcpy( sv[k], samples[i], sample_size );
|
1360 | df->alpha[k++] = alpha[i];
|
1361 | }
|
1362 | }
|
1363 | }
|
1364 | else
|
1365 | {
|
1366 | int class_count = class_labels->cols;
|
1367 | int* sv_tab = 0;
|
1368 | const float** temp_samples = 0;
|
1369 | int* class_ranges = 0;
|
1370 | schar* temp_y = 0;
|
1371 | assert( svm_type == CvSVM::C_SVC || svm_type == CvSVM::NU_SVC );
|
1372 |
|
1373 | if( svm_type == CvSVM::C_SVC && params.class_weights )
|
1374 | {
|
1375 | const CvMat* cw = params.class_weights;
|
1376 |
|
1377 | if( !CV_IS_MAT(cw) || (cw->cols != 1 && cw->rows != 1) ||
|
1378 | cw->rows + cw->cols - 1 != class_count ||
|
1379 | (CV_MAT_TYPE(cw->type) != CV_32FC1 && CV_MAT_TYPE(cw->type) != CV_64FC1) )
|
1380 | CV_ERROR( CV_StsBadArg, "params.class_weights must be 1d floating-point vector "
|
1381 | "containing as many elements as the number of classes" );
|
1382 |
|
1383 | CV_CALL( class_weights = cvCreateMat( cw->rows, cw->cols, CV_64F ));
|
1384 | CV_CALL( cvConvert( cw, class_weights ));
|
1385 | CV_CALL( cvScale( class_weights, class_weights, params.C ));
|
1386 | }
|
1387 |
|
1388 | CV_CALL( decision_func = df = (CvSVMDecisionFunc*)cvAlloc(
|
1389 | (class_count*(class_count-1)/2)*sizeof(df[0])));
|
1390 |
|
1391 | CV_CALL( sv_tab = (int*)cvMemStorageAlloc( temp_storage, sample_count*sizeof(sv_tab[0]) ));
|
1392 | memset( sv_tab, 0, sample_count*sizeof(sv_tab[0]) );
|
1393 | CV_CALL( class_ranges = (int*)cvMemStorageAlloc( temp_storage,
|
1394 | (class_count + 1)*sizeof(class_ranges[0])));
|
1395 | CV_CALL( temp_samples = (const float**)cvMemStorageAlloc( temp_storage,
|
1396 | sample_count*sizeof(temp_samples[0])));
|
1397 | CV_CALL( temp_y = (schar*)cvMemStorageAlloc( temp_storage, sample_count));
|
1398 |
|
1399 | class_ranges[class_count] = 0;
|
1400 | cvSortSamplesByClasses( samples, responses, class_ranges, 0 );
|
1401 |
|
1402 | if( class_ranges[class_count] <= 0 )
|
1403 | CV_ERROR( CV_StsBadArg, "While cross-validation one or more of the classes have "
|
1404 | "been fell out of the sample. Try to enlarge <CvSVMParams::k_fold>" );
|
1405 |
|
1406 | if( svm_type == NU_SVC )
|
1407 | {
|
1408 |
|
1409 | for(i = 0; i < class_count; i++ )
|
1410 | {
|
1411 | int ci = class_ranges[i+1] - class_ranges[i];
|
1412 | for( j = i+1; j< class_count; j++ )
|
1413 | {
|
1414 | int cj = class_ranges[j+1] - class_ranges[j];
|
1415 | if( params.nu*(ci + cj)*0.5 > MIN( ci, cj ) )
|
1416 | {
|
1417 |
|
1418 | EXIT;
|
1419 | }
|
1420 | }
|
1421 | }
|
1422 | }
|
1423 |
|
1424 |
|
1425 | for( i = 0; i < class_count; i++ )
|
1426 | {
|
1427 | for( j = i+1; j < class_count; j++, df++ )
|
1428 | {
|
1429 | int si = class_ranges[i], ci = class_ranges[i+1] - si;
|
1430 | int sj = class_ranges[j], cj = class_ranges[j+1] - sj;
|
1431 | double Cp = params.C, Cn = Cp;
|
1432 | int k1 = 0, sv_count = 0;
|
1433 |
|
1434 | for( k = 0; k < ci; k++ )
|
1435 | {
|
1436 | temp_samples[k] = samples[si + k];
|
1437 | temp_y[k] = 1;
|
1438 | }
|
1439 |
|
1440 | for( k = 0; k < cj; k++ )
|
1441 | {
|
1442 | temp_samples[ci + k] = samples[sj + k];
|
1443 | temp_y[ci + k] = -1;
|
1444 | }
|
1445 |
|
1446 | if( class_weights )
|
1447 | {
|
1448 | Cp = class_weights->data.db[i];
|
1449 | Cn = class_weights->data.db[j];
|
1450 | }
|
1451 |
|
1452 | if( !train1( ci + cj, var_count, temp_samples, temp_y,
|
1453 | Cp, Cn, temp_storage, alpha, df->rho ))
|
1454 | EXIT;
|
1455 |
|
1456 | for( k = 0; k < ci + cj; k++ )
|
1457 | sv_count += fabs(alpha[k]) > 0;
|
1458 |
|
1459 | df->sv_count = sv_count;
|
1460 |
|
1461 | CV_CALL( df->alpha = (double*)cvMemStorageAlloc( temp_storage,
|
1462 | sv_count*sizeof(df->alpha[0])));
|
1463 | CV_CALL( df->sv_index = (int*)cvMemStorageAlloc( temp_storage,
|
1464 | sv_count*sizeof(df->sv_index[0])));
|
1465 |
|
1466 | for( k = 0; k < ci; k++ )
|
1467 | {
|
1468 | if( fabs(alpha[k]) > 0 )
|
1469 | {
|
1470 | sv_tab[si + k] = 1;
|
1471 | df->sv_index[k1] = si + k;
|
1472 | df->alpha[k1++] = alpha[k];
|
1473 | }
|
1474 | }
|
1475 |
|
1476 | for( k = 0; k < cj; k++ )
|
1477 | {
|
1478 | if( fabs(alpha[ci + k]) > 0 )
|
1479 | {
|
1480 | sv_tab[sj + k] = 1;
|
1481 | df->sv_index[k1] = sj + k;
|
1482 | df->alpha[k1++] = alpha[ci + k];
|
1483 | }
|
1484 | }
|
1485 | }
|
1486 | }
|
1487 |
|
1488 |
|
1489 | for( i = 0, k = 0; i < sample_count; i++ )
|
1490 | {
|
1491 | if( sv_tab[i] )
|
1492 | sv_tab[i] = ++k;
|
1493 | }
|
1494 |
|
1495 | sv_total = k;
|
1496 | CV_CALL( sv = (float**)cvMemStorageAlloc( storage, sv_total*sizeof(sv[0])));
|
1497 |
|
1498 | for( i = 0, k = 0; i < sample_count; i++ )
|
1499 | {
|
1500 | if( sv_tab[i] )
|
1501 | {
|
1502 | CV_CALL( sv[k] = (float*)cvMemStorageAlloc( storage, sample_size ));
|
1503 | memcpy( sv[k], samples[i], sample_size );
|
1504 | k++;
|
1505 | }
|
1506 | }
|
1507 |
|
1508 | df = (CvSVMDecisionFunc*)decision_func;
|
1509 |
|
1510 |
|
1511 | for( i = 0; i < class_count; i++ )
|
1512 | {
|
1513 | for( j = i+1; j < class_count; j++, df++ )
|
1514 | {
|
1515 | for( k = 0; k < df->sv_count; k++ )
|
1516 | {
|
1517 | df->sv_index[k] = sv_tab[df->sv_index[k]]-1;
|
1518 | assert( (unsigned)df->sv_index[k] < (unsigned)sv_total );
|
1519 | }
|
1520 | }
|
1521 | }
|
1522 | }
|
1523 |
|
1524 | optimize_linear_svm();
|
1525 | ok = true;
|
1526 |
|
1527 | __END__;
|
1528 |
|
1529 | return ok;
|
1530 | }
|
1531 |
|
1532 |
|
1533 | void CvSVM::optimize_linear_svm()
|
1534 | {
|
1535 |
|
1536 | if( params.kernel_type != LINEAR )
|
1537 | return;
|
1538 |
|
1539 | int class_count = class_labels ? class_labels->cols :
|
1540 | params.svm_type == CvSVM::ONE_CLASS ? 1 : 0;
|
1541 |
|
1542 | int i, df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
|
1543 | CvSVMDecisionFunc* df = decision_func;
|
1544 |
|
1545 | for( i = 0; i < df_count; i++ )
|
1546 | {
|
1547 | int sv_count = df[i].sv_count;
|
1548 | if( sv_count != 1 )
|
1549 | break;
|
1550 | }
|
1551 |
|
1552 |
|
1553 |
|
1554 | if( i == df_count )
|
1555 | return;
|
1556 |
|
1557 | int var_count = get_var_count();
|
1558 | cv::AutoBuffer<double> vbuf(var_count);
|
1559 | double* v = vbuf;
|
1560 | float** new_sv = (float**)cvMemStorageAlloc(storage, df_count*sizeof(new_sv[0]));
|
1561 |
|
1562 | for( i = 0; i < df_count; i++ )
|
1563 | {
|
1564 | new_sv[i] = (float*)cvMemStorageAlloc(storage, var_count*sizeof(new_sv[i][0]));
|
1565 | float* dst = new_sv[i];
|
1566 | memset(v, 0, var_count*sizeof(v[0]));
|
1567 | int j, k, sv_count = df[i].sv_count;
|
1568 | for( j = 0; j < sv_count; j++ )
|
1569 | {
|
1570 | const float* src = class_count > 1 && df[i].sv_index ? sv[df[i].sv_index[j]] : sv[j];
|
1571 | double a = df[i].alpha[j];
|
1572 | for( k = 0; k < var_count; k++ )
|
1573 | v[k] += src[k]*a;
|
1574 | }
|
1575 | for( k = 0; k < var_count; k++ )
|
1576 | dst[k] = (float)v[k];
|
1577 | df[i].sv_count = 1;
|
1578 | df[i].alpha[0] = 1.;
|
1579 | if( class_count > 1 && df[i].sv_index )
|
1580 | df[i].sv_index[0] = i;
|
1581 | }
|
1582 |
|
1583 | sv = new_sv;
|
1584 | sv_total = df_count;
|
1585 | }
|
1586 |
|
1587 |
|
1588 | bool CvSVM::train( const CvMat* _train_data, const CvMat* _responses,
|
1589 | const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params )
|
1590 | {
|
1591 | bool ok = false;
|
1592 | CvMat* responses = 0;
|
1593 | CvMemStorage* temp_storage = 0;
|
1594 | const float** samples = 0;
|
1595 |
|
1596 | CV_FUNCNAME( "CvSVM::train" );
|
1597 |
|
1598 | __BEGIN__;
|
1599 |
|
1600 | int svm_type, sample_count, var_count, sample_size;
|
1601 | int block_size = 1 << 16;
|
1602 | double* alpha;
|
1603 |
|
1604 | clear();
|
1605 | CV_CALL( set_params( _params ));
|
1606 |
|
1607 | svm_type = _params.svm_type;
|
1608 |
|
1609 |
|
1610 | CV_CALL( cvPrepareTrainData( "CvSVM::train", _train_data, CV_ROW_SAMPLE,
|
1611 | svm_type != CvSVM::ONE_CLASS ? _responses : 0,
|
1612 | svm_type == CvSVM::C_SVC ||
|
1613 | svm_type == CvSVM::NU_SVC ? CV_VAR_CATEGORICAL :
|
1614 | CV_VAR_ORDERED, _var_idx, _sample_idx,
|
1615 | false, &samples, &sample_count, &var_count, &var_all,
|
1616 | &responses, &class_labels, &var_idx ));
|
1617 |
|
1618 |
|
1619 | sample_size = var_count*sizeof(samples[0][0]);
|
1620 |
|
1621 |
|
1622 |
|
1623 | block_size = MAX( block_size, sample_count*(int)sizeof(CvSVMKernelRow));
|
1624 | block_size = MAX( block_size, sample_count*2*(int)sizeof(double) + 1024 );
|
1625 | block_size = MAX( block_size, sample_size*2 + 1024 );
|
1626 |
|
1627 | CV_CALL( storage = cvCreateMemStorage(block_size + sizeof(CvMemBlock) + sizeof(CvSeqBlock)));
|
1628 | CV_CALL( temp_storage = cvCreateChildMemStorage(storage));
|
1629 | CV_CALL( alpha = (double*)cvMemStorageAlloc(temp_storage, sample_count*sizeof(double)));
|
1630 |
|
1631 | create_kernel();
|
1632 | create_solver();
|
1633 |
|
1634 | if( !do_train( svm_type, sample_count, var_count, samples, responses, temp_storage, alpha ))
|
1635 | EXIT;
|
1636 |
|
1637 | ok = true;
|
1638 |
|
1639 | __END__;
|
1640 |
|
1641 | delete solver;
|
1642 | solver = 0;
|
1643 | cvReleaseMemStorage( &temp_storage );
|
1644 | cvReleaseMat( &responses );
|
1645 | cvFree( &samples );
|
1646 |
|
1647 | if( cvGetErrStatus() < 0 || !ok )
|
1648 | clear();
|
1649 |
|
1650 | return ok;
|
1651 | }
|
1652 |
|
// Per-fold bookkeeping used by CvSVM::train_auto when balancing classes:
// how many samples of the smaller / bigger class the fold holds, and the
// resulting minority fraction `val` used to sort the folds.
struct indexedratio
{
    double val;                         // count_smallest / (count_smallest + count_biggest)
    int ind;                            // fold index
    int count_smallest, count_biggest;

    void eval()
    {
        const int total = count_smallest + count_biggest;
        val = (double)count_smallest / total;
    }
};
|
1660 |
|
1661 | static int CV_CDECL
|
1662 | icvCmpIndexedratio( const void* a, const void* b )
|
1663 | {
|
1664 | return ((const indexedratio*)a)->val < ((const indexedratio*)b)->val ? -1
|
1665 | : ((const indexedratio*)a)->val > ((const indexedratio*)b)->val ? 1
|
1666 | : 0;
|
1667 | }
|
1668 |
|
1669 | bool CvSVM::train_auto( const CvMat* _train_data, const CvMat* _responses,
|
1670 | const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params, int k_fold,
|
1671 | CvParamGrid C_grid, CvParamGrid gamma_grid, CvParamGrid p_grid,
|
1672 | CvParamGrid nu_grid, CvParamGrid coef_grid, CvParamGrid degree_grid,
|
1673 | bool balanced)
|
1674 | {
|
1675 | bool ok = false;
|
1676 | CvMat* responses = 0;
|
1677 | CvMat* responses_local = 0;
|
1678 | CvMemStorage* temp_storage = 0;
|
1679 | const float** samples = 0;
|
1680 | const float** samples_local = 0;
|
1681 |
|
1682 | CV_FUNCNAME( "CvSVM::train_auto" );
|
1683 | __BEGIN__;
|
1684 |
|
1685 | int svm_type, sample_count, var_count, sample_size;
|
1686 | int block_size = 1 << 16;
|
1687 | double* alpha;
|
1688 | RNG* rng = &theRNG();
|
1689 |
|
1690 |
|
1691 | double degree_step = 10, g_step = 10, coef_step = 10, C_step = 10, nu_step = 10, p_step = 10;
|
1692 | double gamma = 0, curr_c = 0, degree = 0, coef = 0, p = 0, nu = 0;
|
1693 | double best_degree = 0, best_gamma = 0, best_coef = 0, best_C = 0, best_nu = 0, best_p = 0;
|
1694 | float min_error = FLT_MAX, error;
|
1695 |
|
1696 | if( _params.svm_type == CvSVM::ONE_CLASS )
|
1697 | {
|
1698 | if(!train( _train_data, _responses, _var_idx, _sample_idx, _params ))
|
1699 | EXIT;
|
1700 | return true;
|
1701 | }
|
1702 |
|
1703 | clear();
|
1704 |
|
1705 | if( k_fold < 2 )
|
1706 | CV_ERROR( CV_StsBadArg, "Parameter <k_fold> must be > 1" );
|
1707 |
|
1708 | CV_CALL(set_params( _params ));
|
1709 | svm_type = _params.svm_type;
|
1710 |
|
1711 |
|
1712 |
|
1713 | if( C_grid.step <= 1 )
|
1714 | {
|
1715 | C_grid.min_val = C_grid.max_val = params.C;
|
1716 | C_grid.step = 10;
|
1717 | }
|
1718 | else
|
1719 | CV_CALL(C_grid.check());
|
1720 |
|
1721 | if( gamma_grid.step <= 1 )
|
1722 | {
|
1723 | gamma_grid.min_val = gamma_grid.max_val = params.gamma;
|
1724 | gamma_grid.step = 10;
|
1725 | }
|
1726 | else
|
1727 | CV_CALL(gamma_grid.check());
|
1728 |
|
1729 | if( p_grid.step <= 1 )
|
1730 | {
|
1731 | p_grid.min_val = p_grid.max_val = params.p;
|
1732 | p_grid.step = 10;
|
1733 | }
|
1734 | else
|
1735 | CV_CALL(p_grid.check());
|
1736 |
|
1737 | if( nu_grid.step <= 1 )
|
1738 | {
|
1739 | nu_grid.min_val = nu_grid.max_val = params.nu;
|
1740 | nu_grid.step = 10;
|
1741 | }
|
1742 | else
|
1743 | CV_CALL(nu_grid.check());
|
1744 |
|
1745 | if( coef_grid.step <= 1 )
|
1746 | {
|
1747 | coef_grid.min_val = coef_grid.max_val = params.coef0;
|
1748 | coef_grid.step = 10;
|
1749 | }
|
1750 | else
|
1751 | CV_CALL(coef_grid.check());
|
1752 |
|
1753 | if( degree_grid.step <= 1 )
|
1754 | {
|
1755 | degree_grid.min_val = degree_grid.max_val = params.degree;
|
1756 | degree_grid.step = 10;
|
1757 | }
|
1758 | else
|
1759 | CV_CALL(degree_grid.check());
|
1760 |
|
1761 |
|
1762 | if( params.kernel_type != CvSVM::POLY )
|
1763 | degree_grid.min_val = degree_grid.max_val = params.degree;
|
1764 | if( params.kernel_type == CvSVM::LINEAR )
|
1765 | gamma_grid.min_val = gamma_grid.max_val = params.gamma;
|
1766 | if( params.kernel_type != CvSVM::POLY && params.kernel_type != CvSVM::SIGMOID )
|
1767 | coef_grid.min_val = coef_grid.max_val = params.coef0;
|
1768 | if( svm_type == CvSVM::NU_SVC || svm_type == CvSVM::ONE_CLASS )
|
1769 | C_grid.min_val = C_grid.max_val = params.C;
|
1770 | if( svm_type == CvSVM::C_SVC || svm_type == CvSVM::EPS_SVR )
|
1771 | nu_grid.min_val = nu_grid.max_val = params.nu;
|
1772 | if( svm_type != CvSVM::EPS_SVR )
|
1773 | p_grid.min_val = p_grid.max_val = params.p;
|
1774 |
|
1775 | CV_ASSERT( g_step > 1 && degree_step > 1 && coef_step > 1);
|
1776 | CV_ASSERT( p_step > 1 && C_step > 1 && nu_step > 1 );
|
1777 |
|
1778 |
|
1779 | CV_CALL(cvPrepareTrainData( "CvSVM::train_auto", _train_data, CV_ROW_SAMPLE,
|
1780 | svm_type != CvSVM::ONE_CLASS ? _responses : 0,
|
1781 | svm_type == CvSVM::C_SVC ||
|
1782 | svm_type == CvSVM::NU_SVC ? CV_VAR_CATEGORICAL :
|
1783 | CV_VAR_ORDERED, _var_idx, _sample_idx,
|
1784 | false, &samples, &sample_count, &var_count, &var_all,
|
1785 | &responses, &class_labels, &var_idx ));
|
1786 |
|
1787 | sample_size = var_count*sizeof(samples[0][0]);
|
1788 |
|
1789 |
|
1790 |
|
1791 | block_size = MAX( block_size, sample_count*(int)sizeof(CvSVMKernelRow));
|
1792 | block_size = MAX( block_size, sample_count*2*(int)sizeof(double) + 1024 );
|
1793 | block_size = MAX( block_size, sample_size*2 + 1024 );
|
1794 |
|
1795 | CV_CALL( storage = cvCreateMemStorage(block_size + sizeof(CvMemBlock) + sizeof(CvSeqBlock)));
|
1796 | CV_CALL(temp_storage = cvCreateChildMemStorage(storage));
|
1797 | CV_CALL(alpha = (double*)cvMemStorageAlloc(temp_storage, sample_count*sizeof(double)));
|
1798 |
|
1799 | create_kernel();
|
1800 | create_solver();
|
1801 |
|
1802 | {
|
1803 | const int testset_size = sample_count/k_fold;
|
1804 | const int trainset_size = sample_count - testset_size;
|
1805 | const int last_testset_size = sample_count - testset_size*(k_fold-1);
|
1806 | const int last_trainset_size = sample_count - last_testset_size;
|
1807 | const bool is_regression = (svm_type == EPS_SVR) || (svm_type == NU_SVR);
|
1808 |
|
1809 | size_t resp_elem_size = CV_ELEM_SIZE(responses->type);
|
1810 | size_t size = 2*last_trainset_size*sizeof(samples[0]);
|
1811 |
|
1812 | samples_local = (const float**) cvAlloc( size );
|
1813 | memset( samples_local, 0, size );
|
1814 |
|
1815 | responses_local = cvCreateMat( 1, trainset_size, CV_MAT_TYPE(responses->type) );
|
1816 | cvZero( responses_local );
|
1817 |
|
1818 |
|
1819 | for(int i = 0; i < sample_count; i++ )
|
1820 | {
|
1821 | int i1 = (*rng)(sample_count);
|
1822 | int i2 = (*rng)(sample_count);
|
1823 | const float* temp;
|
1824 | float t;
|
1825 | int y;
|
1826 |
|
1827 | CV_SWAP( samples[i1], samples[i2], temp );
|
1828 | if( is_regression )
|
1829 | CV_SWAP( responses->data.fl[i1], responses->data.fl[i2], t );
|
1830 | else
|
1831 | CV_SWAP( responses->data.i[i1], responses->data.i[i2], y );
|
1832 | }
|
1833 |
|
1834 | if (!is_regression && class_labels->cols==2 && balanced)
|
1835 | {
|
1836 |
|
1837 | int num_0=0,num_1=0;
|
1838 | for (int i=0; i<sample_count; ++i)
|
1839 | {
|
1840 | if (responses->data.i[i]==class_labels->data.i[0])
|
1841 | ++num_0;
|
1842 | else
|
1843 | ++num_1;
|
1844 | }
|
1845 |
|
1846 | int label_smallest_class;
|
1847 | int label_biggest_class;
|
1848 | if (num_0 < num_1)
|
1849 | {
|
1850 | label_biggest_class = class_labels->data.i[1];
|
1851 | label_smallest_class = class_labels->data.i[0];
|
1852 | }
|
1853 | else
|
1854 | {
|
1855 | label_biggest_class = class_labels->data.i[0];
|
1856 | label_smallest_class = class_labels->data.i[1];
|
1857 | int y;
|
1858 | CV_SWAP(num_0,num_1,y);
|
1859 | }
|
1860 | const double class_ratio = (double) num_0/sample_count;
|
1861 |
|
1862 | indexedratio *ratios=0;
|
1863 | ratios = (indexedratio*) cvAlloc(k_fold*sizeof(*ratios));
|
1864 | for (int k=0, i_begin=0; k<k_fold; ++k, i_begin+=testset_size)
|
1865 | {
|
1866 | int count0=0;
|
1867 | int count1=0;
|
1868 | int i_end = i_begin + (k<k_fold-1 ? testset_size : last_testset_size);
|
1869 | for (int i=i_begin; i<i_end; ++i)
|
1870 | {
|
1871 | if (responses->data.i[i]==label_smallest_class)
|
1872 | ++count0;
|
1873 | else
|
1874 | ++count1;
|
1875 | }
|
1876 | ratios[k].ind = k;
|
1877 | ratios[k].count_smallest = count0;
|
1878 | ratios[k].count_biggest = count1;
|
1879 | ratios[k].eval();
|
1880 | }
|
1881 |
|
1882 | qsort(ratios, k_fold, sizeof(ratios[0]), icvCmpIndexedratio);
|
1883 | double old_dist = 0.0;
|
1884 | for (int k=0; k<k_fold; ++k)
|
1885 | old_dist += fabs(ratios[k].val-class_ratio);
|
1886 | double new_dist = 1.0;
|
1887 |
|
1888 | while (new_dist > 0.0)
|
1889 | {
|
1890 | if (ratios[0].count_biggest==0 || ratios[k_fold-1].count_smallest==0)
|
1891 | break;
|
1892 |
|
1893 | qsort(ratios, k_fold, sizeof(ratios[0]), icvCmpIndexedratio);
|
1894 |
|
1895 |
|
1896 | ratios[0].count_smallest++;
|
1897 | ratios[0].count_biggest--;
|
1898 | ratios[0].eval();
|
1899 | ratios[k_fold-1].count_smallest--;
|
1900 | ratios[k_fold-1].count_biggest++;
|
1901 | ratios[k_fold-1].eval();
|
1902 |
|
1903 | new_dist = 0.0;
|
1904 | for (int k=0; k<k_fold; ++k)
|
1905 | new_dist += fabs(ratios[k].val-class_ratio);
|
1906 |
|
1907 | if (new_dist < old_dist)
|
1908 | {
|
1909 |
|
1910 |
|
1911 |
|
1912 | int i1 = ratios[0].ind * testset_size;
|
1913 | for ( ; i1<sample_count; ++i1)
|
1914 | {
|
1915 | if (responses->data.i[i1]==label_biggest_class)
|
1916 | break;
|
1917 | }
|
1918 |
|
1919 | int i2 = ratios[k_fold-1].ind * testset_size;
|
1920 | for ( ; i2<sample_count; ++i2)
|
1921 | {
|
1922 | if (responses->data.i[i2]==label_smallest_class)
|
1923 | break;
|
1924 | }
|
1925 |
|
1926 | const float* temp;
|
1927 | int y;
|
1928 | CV_SWAP( samples[i1], samples[i2], temp );
|
1929 | CV_SWAP( responses->data.i[i1], responses->data.i[i2], y );
|
1930 | old_dist = new_dist;
|
1931 | }
|
1932 | else
|
1933 | break;
|
1934 | }
|
1935 | cvFree(&ratios);
|
1936 | }
|
1937 |
|
1938 | int* cls_lbls = class_labels ? class_labels->data.i : 0;
|
1939 | curr_c = C_grid.min_val;
|
1940 | do
|
1941 | {
|
1942 | params.C = curr_c;
|
1943 | gamma = gamma_grid.min_val;
|
1944 | do
|
1945 | {
|
1946 | params.gamma = gamma;
|
1947 | p = p_grid.min_val;
|
1948 | do
|
1949 | {
|
1950 | params.p = p;
|
1951 | nu = nu_grid.min_val;
|
1952 | do
|
1953 | {
|
1954 | params.nu = nu;
|
1955 | coef = coef_grid.min_val;
|
1956 | do
|
1957 | {
|
1958 | params.coef0 = coef;
|
1959 | degree = degree_grid.min_val;
|
1960 | do
|
1961 | {
|
1962 | params.degree = degree;
|
1963 |
|
1964 | float** test_samples_ptr = (float**)samples;
|
1965 | uchar* true_resp = responses->data.ptr;
|
1966 | int test_size = testset_size;
|
1967 | int train_size = trainset_size;
|
1968 |
|
1969 | error = 0;
|
1970 | for(int k = 0; k < k_fold; k++ )
|
1971 | {
|
1972 | memcpy( samples_local, samples, sizeof(samples[0])*test_size*k );
|
1973 | memcpy( samples_local + test_size*k, test_samples_ptr + test_size,
|
1974 | sizeof(samples[0])*(sample_count - testset_size*(k+1)) );
|
1975 |
|
1976 | memcpy( responses_local->data.ptr, responses->data.ptr, resp_elem_size*test_size*k );
|
1977 | memcpy( responses_local->data.ptr + resp_elem_size*test_size*k,
|
1978 | true_resp + resp_elem_size*test_size,
|
1979 | resp_elem_size*(sample_count - testset_size*(k+1)) );
|
1980 |
|
1981 | if( k == k_fold - 1 )
|
1982 | {
|
1983 | test_size = last_testset_size;
|
1984 | train_size = last_trainset_size;
|
1985 | responses_local->cols = last_trainset_size;
|
1986 | }
|
1987 |
|
1988 |
|
1989 | if( !do_train( svm_type, train_size, var_count,
|
1990 | (const float**)samples_local, responses_local, temp_storage, alpha ) )
|
1991 | EXIT;
|
1992 |
|
1993 |
|
1994 | for(int i = 0; i < test_size; i++, true_resp += resp_elem_size, test_samples_ptr++ )
|
1995 | {
|
1996 | float resp = predict( *test_samples_ptr, var_count );
|
1997 | error += is_regression ? powf( resp - *(float*)true_resp, 2 )
|
1998 | : ((int)resp != cls_lbls[*(int*)true_resp]);
|
1999 | }
|
2000 | }
|
2001 | if( min_error > error )
|
2002 | {
|
2003 | min_error = error;
|
2004 | best_degree = degree;
|
2005 | best_gamma = gamma;
|
2006 | best_coef = coef;
|
2007 | best_C = curr_c;
|
2008 | best_nu = nu;
|
2009 | best_p = p;
|
2010 | }
|
2011 | degree *= degree_grid.step;
|
2012 | }
|
2013 | while( degree < degree_grid.max_val );
|
2014 | coef *= coef_grid.step;
|
2015 | }
|
2016 | while( coef < coef_grid.max_val );
|
2017 | nu *= nu_grid.step;
|
2018 | }
|
2019 | while( nu < nu_grid.max_val );
|
2020 | p *= p_grid.step;
|
2021 | }
|
2022 | while( p < p_grid.max_val );
|
2023 | gamma *= gamma_grid.step;
|
2024 | }
|
2025 | while( gamma < gamma_grid.max_val );
|
2026 | curr_c *= C_grid.step;
|
2027 | }
|
2028 | while( curr_c < C_grid.max_val );
|
2029 | }
|
2030 |
|
2031 | min_error /= (float) sample_count;
|
2032 |
|
2033 | params.C = best_C;
|
2034 | params.nu = best_nu;
|
2035 | params.p = best_p;
|
2036 | params.gamma = best_gamma;
|
2037 | params.degree = best_degree;
|
2038 | params.coef0 = best_coef;
|
2039 |
|
2040 | CV_CALL(ok = do_train( svm_type, sample_count, var_count, samples, responses, temp_storage, alpha ));
|
2041 |
|
2042 | __END__;
|
2043 |
|
2044 | delete solver;
|
2045 | solver = 0;
|
2046 | cvReleaseMemStorage( &temp_storage );
|
2047 | cvReleaseMat( &responses );
|
2048 | cvReleaseMat( &responses_local );
|
2049 | cvFree( &samples );
|
2050 | cvFree( &samples_local );
|
2051 |
|
2052 | if( cvGetErrStatus() < 0 || !ok )
|
2053 | clear();
|
2054 |
|
2055 | return ok;
|
2056 | }
|
2057 |
|
2058 | float CvSVM::predict( const float* row_sample, int row_len, bool returnDFVal ) const
|
2059 | {
|
2060 | assert( kernel );
|
2061 | assert( row_sample );
|
2062 |
|
2063 | int var_count = get_var_count();
|
2064 | assert( row_len == var_count );
|
2065 | (void)row_len;
|
2066 |
|
2067 | int class_count = class_labels ? class_labels->cols :
|
2068 | params.svm_type == ONE_CLASS ? 1 : 0;
|
2069 |
|
2070 | float result = 0;
|
2071 | cv::AutoBuffer<float> _buffer(sv_total + (class_count+1)*2);
|
2072 | float* buffer = _buffer;
|
2073 |
|
2074 | if( params.svm_type == EPS_SVR ||
|
2075 | params.svm_type == NU_SVR ||
|
2076 | params.svm_type == ONE_CLASS )
|
2077 | {
|
2078 | CvSVMDecisionFunc* df = (CvSVMDecisionFunc*)decision_func;
|
2079 | int i, sv_count = df->sv_count;
|
2080 | double sum = -df->rho;
|
2081 |
|
2082 | kernel->calc( sv_count, var_count, (const float**)sv, row_sample, buffer );
|
2083 | for( i = 0; i < sv_count; i++ )
|
2084 | sum += buffer[i]*df->alpha[i];
|
2085 |
|
2086 | result = params.svm_type == ONE_CLASS ? (float)(sum > 0) : (float)sum;
|
2087 | }
|
2088 | else if( params.svm_type == C_SVC ||
|
2089 | params.svm_type == NU_SVC )
|
2090 | {
|
2091 | CvSVMDecisionFunc* df = (CvSVMDecisionFunc*)decision_func;
|
2092 | int* vote = (int*)(buffer + sv_total);
|
2093 | int i, j, k;
|
2094 |
|
2095 | memset( vote, 0, class_count*sizeof(vote[0]));
|
2096 | kernel->calc( sv_total, var_count, (const float**)sv, row_sample, buffer );
|
2097 | double sum = 0.;
|
2098 |
|
2099 | for( i = 0; i < class_count; i++ )
|
2100 | {
|
2101 | for( j = i+1; j < class_count; j++, df++ )
|
2102 | {
|
2103 | sum = -df->rho;
|
2104 | int sv_count = df->sv_count;
|
2105 | for( k = 0; k < sv_count; k++ )
|
2106 | sum += df->alpha[k]*buffer[df->sv_index[k]];
|
2107 |
|
2108 | vote[sum > 0 ? i : j]++;
|
2109 | }
|
2110 | }
|
2111 |
|
2112 | for( i = 1, k = 0; i < class_count; i++ )
|
2113 | {
|
2114 | if( vote[i] > vote[k] )
|
2115 | k = i;
|
2116 | }
|
2117 | result = returnDFVal && class_count == 2 ? (float)sum : (float)(class_labels->data.i[k]);
|
2118 | }
|
2119 | else
|
2120 | CV_Error( CV_StsBadArg, "INTERNAL ERROR: Unknown SVM type, "
|
2121 | "the SVM structure is probably corrupted" );
|
2122 |
|
2123 | return result;
|
2124 | }
|
2125 |
|
2126 | float CvSVM::predict( const CvMat* sample, bool returnDFVal ) const
|
2127 | {
|
2128 | float result = 0;
|
2129 | float* row_sample = 0;
|
2130 |
|
2131 | CV_FUNCNAME( "CvSVM::predict" );
|
2132 |
|
2133 | __BEGIN__;
|
2134 |
|
2135 | int class_count;
|
2136 |
|
2137 | if( !kernel )
|
2138 | CV_ERROR( CV_StsBadArg, "The SVM should be trained first" );
|
2139 |
|
2140 | class_count = class_labels ? class_labels->cols :
|
2141 | params.svm_type == ONE_CLASS ? 1 : 0;
|
2142 |
|
2143 | CV_CALL( cvPreparePredictData( sample, var_all, var_idx,
|
2144 | class_count, 0, &row_sample ));
|
2145 | result = predict( row_sample, get_var_count(), returnDFVal );
|
2146 |
|
2147 | __END__;
|
2148 |
|
2149 | if( sample && (!CV_IS_MAT(sample) || sample->data.fl != row_sample) )
|
2150 | cvFree( &row_sample );
|
2151 |
|
2152 | return result;
|
2153 | }
|
2154 |
|
2155 | struct predict_body_svm : ParallelLoopBody {
|
2156 | predict_body_svm(const CvSVM* _pointer, float* _result, const CvMat* _samples, CvMat* _results)
|
2157 | {
|
2158 | pointer = _pointer;
|
2159 | result = _result;
|
2160 | samples = _samples;
|
2161 | results = _results;
|
2162 | }
|
2163 |
|
2164 | const CvSVM* pointer;
|
2165 | float* result;
|
2166 | const CvMat* samples;
|
2167 | CvMat* results;
|
2168 |
|
2169 | void operator()( const cv::Range& range ) const
|
2170 | {
|
2171 | for(int i = range.start; i < range.end; i++ )
|
2172 | {
|
2173 | CvMat sample;
|
2174 | cvGetRow( samples, &sample, i );
|
2175 | int r = (int)pointer->predict(&sample);
|
2176 | if (results)
|
2177 | results->data.fl[i] = (float)r;
|
2178 | if (i == 0)
|
2179 | *result = (float)r;
|
2180 | }
|
2181 | }
|
2182 | };
|
2183 |
|
2184 | float CvSVM::predict(const CvMat* samples, CV_OUT CvMat* results) const
|
2185 | {
|
2186 | float result = 0;
|
2187 | cv::parallel_for_(cv::Range(0, samples->rows),
|
2188 | predict_body_svm(this, &result, samples, results)
|
2189 | );
|
2190 | return result;
|
2191 | }
|
2192 |
|
2193 | void CvSVM::predict( cv::InputArray _samples, cv::OutputArray _results ) const
|
2194 | {
|
2195 | _results.create(_samples.size().height, 1, CV_32F);
|
2196 | CvMat samples = _samples.getMat(), results = _results.getMat();
|
2197 | predict(&samples, &results);
|
2198 | }
|
2199 |
|
2200 | CvSVM::CvSVM( const Mat& _train_data, const Mat& _responses,
|
2201 | const Mat& _var_idx, const Mat& _sample_idx, CvSVMParams _params )
|
2202 | {
|
2203 | decision_func = 0;
|
2204 | class_labels = 0;
|
2205 | class_weights = 0;
|
2206 | storage = 0;
|
2207 | var_idx = 0;
|
2208 | kernel = 0;
|
2209 | solver = 0;
|
2210 | default_model_name = "my_svm";
|
2211 |
|
2212 | train( _train_data, _responses, _var_idx, _sample_idx, _params );
|
2213 | }
|
2214 |
|
2215 | bool CvSVM::train( const Mat& _train_data, const Mat& _responses,
|
2216 | const Mat& _var_idx, const Mat& _sample_idx, CvSVMParams _params )
|
2217 | {
|
2218 | CvMat tdata = _train_data, responses = _responses, vidx = _var_idx, sidx = _sample_idx;
|
2219 | return train(&tdata, &responses, vidx.data.ptr ? &vidx : 0, sidx.data.ptr ? &sidx : 0, _params);
|
2220 | }
|
2221 |
|
2222 |
|
2223 | bool CvSVM::train_auto( const Mat& _train_data, const Mat& _responses,
|
2224 | const Mat& _var_idx, const Mat& _sample_idx, CvSVMParams _params, int k_fold,
|
2225 | CvParamGrid C_grid, CvParamGrid gamma_grid, CvParamGrid p_grid,
|
2226 | CvParamGrid nu_grid, CvParamGrid coef_grid, CvParamGrid degree_grid, bool balanced )
|
2227 | {
|
2228 | CvMat tdata = _train_data, responses = _responses, vidx = _var_idx, sidx = _sample_idx;
|
2229 | return train_auto(&tdata, &responses, vidx.data.ptr ? &vidx : 0,
|
2230 | sidx.data.ptr ? &sidx : 0, _params, k_fold, C_grid, gamma_grid, p_grid,
|
2231 | nu_grid, coef_grid, degree_grid, balanced);
|
2232 | }
|
2233 |
|
2234 | float CvSVM::predict( const Mat& _sample, bool returnDFVal ) const
|
2235 | {
|
2236 | CvMat sample = _sample;
|
2237 | return predict(&sample, returnDFVal);
|
2238 | }
|
2239 |
|
2240 |
|
2241 | void CvSVM::write_params( CvFileStorage* fs ) const
|
2242 | {
|
2243 |
|
2244 |
|
2245 | __BEGIN__;
|
2246 |
|
2247 | int svm_type = params.svm_type;
|
2248 | int kernel_type = params.kernel_type;
|
2249 |
|
2250 | const char* svm_type_str =
|
2251 | svm_type == CvSVM::C_SVC ? "C_SVC" :
|
2252 | svm_type == CvSVM::NU_SVC ? "NU_SVC" :
|
2253 | svm_type == CvSVM::ONE_CLASS ? "ONE_CLASS" :
|
2254 | svm_type == CvSVM::EPS_SVR ? "EPS_SVR" :
|
2255 | svm_type == CvSVM::NU_SVR ? "NU_SVR" : 0;
|
2256 | const char* kernel_type_str =
|
2257 | kernel_type == CvSVM::LINEAR ? "LINEAR" :
|
2258 | kernel_type == CvSVM::POLY ? "POLY" :
|
2259 | kernel_type == CvSVM::RBF ? "RBF" :
|
2260 | kernel_type == CvSVM::SIGMOID ? "SIGMOID" : 0;
|
2261 |
|
2262 | if( svm_type_str )
|
2263 | cvWriteString( fs, "svm_type", svm_type_str );
|
2264 | else
|
2265 | cvWriteInt( fs, "svm_type", svm_type );
|
2266 |
|
2267 |
|
2268 | cvStartWriteStruct( fs, "kernel", CV_NODE_MAP + CV_NODE_FLOW );
|
2269 |
|
2270 | if( kernel_type_str )
|
2271 | cvWriteString( fs, "type", kernel_type_str );
|
2272 | else
|
2273 | cvWriteInt( fs, "type", kernel_type );
|
2274 |
|
2275 | if( kernel_type == CvSVM::POLY || !kernel_type_str )
|
2276 | cvWriteReal( fs, "degree", params.degree );
|
2277 |
|
2278 | if( kernel_type != CvSVM::LINEAR || !kernel_type_str )
|
2279 | cvWriteReal( fs, "gamma", params.gamma );
|
2280 |
|
2281 | if( kernel_type == CvSVM::POLY || kernel_type == CvSVM::SIGMOID || !kernel_type_str )
|
2282 | cvWriteReal( fs, "coef0", params.coef0 );
|
2283 |
|
2284 | cvEndWriteStruct(fs);
|
2285 |
|
2286 | if( svm_type == CvSVM::C_SVC || svm_type == CvSVM::EPS_SVR ||
|
2287 | svm_type == CvSVM::NU_SVR || !svm_type_str )
|
2288 | cvWriteReal( fs, "C", params.C );
|
2289 |
|
2290 | if( svm_type == CvSVM::NU_SVC || svm_type == CvSVM::ONE_CLASS ||
|
2291 | svm_type == CvSVM::NU_SVR || !svm_type_str )
|
2292 | cvWriteReal( fs, "nu", params.nu );
|
2293 |
|
2294 | if( svm_type == CvSVM::EPS_SVR || !svm_type_str )
|
2295 | cvWriteReal( fs, "p", params.p );
|
2296 |
|
2297 | cvStartWriteStruct( fs, "term_criteria", CV_NODE_MAP + CV_NODE_FLOW );
|
2298 | if( params.term_crit.type & CV_TERMCRIT_EPS )
|
2299 | cvWriteReal( fs, "epsilon", params.term_crit.epsilon );
|
2300 | if( params.term_crit.type & CV_TERMCRIT_ITER )
|
2301 | cvWriteInt( fs, "iterations", params.term_crit.max_iter );
|
2302 | cvEndWriteStruct( fs );
|
2303 |
|
2304 | __END__;
|
2305 | }
|
2306 |
|
2307 |
|
2308 | static bool isSvmModelApplicable(int sv_total, int var_all, int var_count, int class_count)
|
2309 | {
|
2310 | return (sv_total > 0 && var_count > 0 && var_count <= var_all && class_count >= 0);
|
2311 | }
|
2312 |
|
2313 |
|
2314 | void CvSVM::write( CvFileStorage* fs, const char* name ) const
|
2315 | {
|
2316 | CV_FUNCNAME( "CvSVM::write" );
|
2317 |
|
2318 | __BEGIN__;
|
2319 |
|
2320 | int i, var_count = get_var_count(), df_count;
|
2321 | int class_count = class_labels ? class_labels->cols :
|
2322 | params.svm_type == CvSVM::ONE_CLASS ? 1 : 0;
|
2323 | const CvSVMDecisionFunc* df = decision_func;
|
2324 | if( !isSvmModelApplicable(sv_total, var_all, var_count, class_count) )
|
2325 | CV_ERROR( CV_StsParseError, "SVM model data is invalid, check sv_count, var_* and class_count tags" );
|
2326 |
|
2327 | cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_SVM );
|
2328 |
|
2329 | write_params( fs );
|
2330 |
|
2331 | cvWriteInt( fs, "var_all", var_all );
|
2332 | cvWriteInt( fs, "var_count", var_count );
|
2333 |
|
2334 | if( class_count )
|
2335 | {
|
2336 | cvWriteInt( fs, "class_count", class_count );
|
2337 |
|
2338 | if( class_labels )
|
2339 | cvWrite( fs, "class_labels", class_labels );
|
2340 |
|
2341 | if( class_weights )
|
2342 | cvWrite( fs, "class_weights", class_weights );
|
2343 | }
|
2344 |
|
2345 | if( var_idx )
|
2346 | cvWrite( fs, "var_idx", var_idx );
|
2347 |
|
2348 |
|
2349 | cvWriteInt( fs, "sv_total", sv_total );
|
2350 | cvStartWriteStruct( fs, "support_vectors", CV_NODE_SEQ );
|
2351 | for( i = 0; i < sv_total; i++ )
|
2352 | {
|
2353 | cvStartWriteStruct( fs, 0, CV_NODE_SEQ + CV_NODE_FLOW );
|
2354 | cvWriteRawData( fs, sv[i], var_count, "f" );
|
2355 | cvEndWriteStruct( fs );
|
2356 | }
|
2357 |
|
2358 | cvEndWriteStruct( fs );
|
2359 |
|
2360 |
|
2361 | df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
|
2362 | df = decision_func;
|
2363 |
|
2364 | cvStartWriteStruct( fs, "decision_functions", CV_NODE_SEQ );
|
2365 | for( i = 0; i < df_count; i++ )
|
2366 | {
|
2367 | int sv_count = df[i].sv_count;
|
2368 | cvStartWriteStruct( fs, 0, CV_NODE_MAP );
|
2369 | cvWriteInt( fs, "sv_count", sv_count );
|
2370 | cvWriteReal( fs, "rho", df[i].rho );
|
2371 | cvStartWriteStruct( fs, "alpha", CV_NODE_SEQ+CV_NODE_FLOW );
|
2372 | cvWriteRawData( fs, df[i].alpha, df[i].sv_count, "d" );
|
2373 | cvEndWriteStruct( fs );
|
2374 | if( class_count > 1 )
|
2375 | {
|
2376 | cvStartWriteStruct( fs, "index", CV_NODE_SEQ+CV_NODE_FLOW );
|
2377 | cvWriteRawData( fs, df[i].sv_index, df[i].sv_count, "i" );
|
2378 | cvEndWriteStruct( fs );
|
2379 | }
|
2380 | else
|
2381 | CV_ASSERT( sv_count == sv_total );
|
2382 | cvEndWriteStruct( fs );
|
2383 | }
|
2384 | cvEndWriteStruct( fs );
|
2385 | cvEndWriteStruct( fs );
|
2386 |
|
2387 | __END__;
|
2388 | }
|
2389 |
|
2390 |
|
2391 | void CvSVM::read_params( CvFileStorage* fs, CvFileNode* svm_node )
|
2392 | {
|
2393 | CV_FUNCNAME( "CvSVM::read_params" );
|
2394 |
|
2395 | __BEGIN__;
|
2396 |
|
2397 | int svm_type, kernel_type;
|
2398 | CvSVMParams _params;
|
2399 |
|
2400 | CvFileNode* tmp_node = cvGetFileNodeByName( fs, svm_node, "svm_type" );
|
2401 | CvFileNode* kernel_node;
|
2402 | if( !tmp_node )
|
2403 | CV_ERROR( CV_StsBadArg, "svm_type tag is not found" );
|
2404 |
|
2405 | if( CV_NODE_TYPE(tmp_node->tag) == CV_NODE_INT )
|
2406 | svm_type = cvReadInt( tmp_node, -1 );
|
2407 | else
|
2408 | {
|
2409 | const char* svm_type_str = cvReadString( tmp_node, "" );
|
2410 | svm_type =
|
2411 | strcmp( svm_type_str, "C_SVC" ) == 0 ? CvSVM::C_SVC :
|
2412 | strcmp( svm_type_str, "NU_SVC" ) == 0 ? CvSVM::NU_SVC :
|
2413 | strcmp( svm_type_str, "ONE_CLASS" ) == 0 ? CvSVM::ONE_CLASS :
|
2414 | strcmp( svm_type_str, "EPS_SVR" ) == 0 ? CvSVM::EPS_SVR :
|
2415 | strcmp( svm_type_str, "NU_SVR" ) == 0 ? CvSVM::NU_SVR : -1;
|
2416 |
|
2417 | if( svm_type < 0 )
|
2418 | CV_ERROR( CV_StsParseError, "Missing of invalid SVM type" );
|
2419 | }
|
2420 |
|
2421 | kernel_node = cvGetFileNodeByName( fs, svm_node, "kernel" );
|
2422 | if( !kernel_node )
|
2423 | CV_ERROR( CV_StsParseError, "SVM kernel tag is not found" );
|
2424 |
|
2425 | tmp_node = cvGetFileNodeByName( fs, kernel_node, "type" );
|
2426 | if( !tmp_node )
|
2427 | CV_ERROR( CV_StsParseError, "SVM kernel type tag is not found" );
|
2428 |
|
2429 | if( CV_NODE_TYPE(tmp_node->tag) == CV_NODE_INT )
|
2430 | kernel_type = cvReadInt( tmp_node, -1 );
|
2431 | else
|
2432 | {
|
2433 | const char* kernel_type_str = cvReadString( tmp_node, "" );
|
2434 | kernel_type =
|
2435 | strcmp( kernel_type_str, "LINEAR" ) == 0 ? CvSVM::LINEAR :
|
2436 | strcmp( kernel_type_str, "POLY" ) == 0 ? CvSVM::POLY :
|
2437 | strcmp( kernel_type_str, "RBF" ) == 0 ? CvSVM::RBF :
|
2438 | strcmp( kernel_type_str, "SIGMOID" ) == 0 ? CvSVM::SIGMOID : -1;
|
2439 |
|
2440 | if( kernel_type < 0 )
|
2441 | CV_ERROR( CV_StsParseError, "Missing of invalid SVM kernel type" );
|
2442 | }
|
2443 |
|
2444 | _params.svm_type = svm_type;
|
2445 | _params.kernel_type = kernel_type;
|
2446 | _params.degree = cvReadRealByName( fs, kernel_node, "degree", 0 );
|
2447 | _params.gamma = cvReadRealByName( fs, kernel_node, "gamma", 0 );
|
2448 | _params.coef0 = cvReadRealByName( fs, kernel_node, "coef0", 0 );
|
2449 |
|
2450 | _params.C = cvReadRealByName( fs, svm_node, "C", 0 );
|
2451 | _params.nu = cvReadRealByName( fs, svm_node, "nu", 0 );
|
2452 | _params.p = cvReadRealByName( fs, svm_node, "p", 0 );
|
2453 | _params.class_weights = 0;
|
2454 |
|
2455 | tmp_node = cvGetFileNodeByName( fs, svm_node, "term_criteria" );
|
2456 | if( tmp_node )
|
2457 | {
|
2458 | _params.term_crit.epsilon = cvReadRealByName( fs, tmp_node, "epsilon", -1. );
|
2459 | _params.term_crit.max_iter = cvReadIntByName( fs, tmp_node, "iterations", -1 );
|
2460 | _params.term_crit.type = (_params.term_crit.epsilon >= 0 ? CV_TERMCRIT_EPS : 0) +
|
2461 | (_params.term_crit.max_iter >= 0 ? CV_TERMCRIT_ITER : 0);
|
2462 | }
|
2463 | else
|
2464 | _params.term_crit = cvTermCriteria( CV_TERMCRIT_EPS + CV_TERMCRIT_ITER, 1000, FLT_EPSILON );
|
2465 |
|
2466 | set_params( _params );
|
2467 |
|
2468 | __END__;
|
2469 | }
|
2470 |
|
2471 | void CvSVM::read( CvFileStorage* fs, CvFileNode* svm_node )
|
2472 | {
|
2473 | const double not_found_dbl = DBL_MAX;
|
2474 |
|
2475 | CV_FUNCNAME( "CvSVM::read" );
|
2476 |
|
2477 | __BEGIN__;
|
2478 |
|
2479 | int i, var_count, df_count, class_count;
|
2480 | int block_size = 1 << 16, sv_size;
|
2481 | CvFileNode *sv_node, *df_node;
|
2482 | CvSVMDecisionFunc* df;
|
2483 | CvSeqReader reader;
|
2484 |
|
2485 | if( !svm_node )
|
2486 | CV_ERROR( CV_StsParseError, "The requested element is not found" );
|
2487 |
|
2488 | clear();
|
2489 |
|
2490 |
|
2491 | read_params( fs, svm_node );
|
2492 |
|
2493 |
|
2494 | sv_total = cvReadIntByName( fs, svm_node, "sv_total", -1 );
|
2495 | var_all = cvReadIntByName( fs, svm_node, "var_all", -1 );
|
2496 | var_count = cvReadIntByName( fs, svm_node, "var_count", var_all );
|
2497 | class_count = cvReadIntByName( fs, svm_node, "class_count", 0 );
|
2498 |
|
2499 | if( !isSvmModelApplicable(sv_total, var_all, var_count, class_count) )
|
2500 | CV_ERROR( CV_StsParseError, "SVM model data is invalid, check sv_count, var_* and class_count tags" );
|
2501 |
|
2502 | CV_CALL( class_labels = (CvMat*)cvReadByName( fs, svm_node, "class_labels" ));
|
2503 | CV_CALL( class_weights = (CvMat*)cvReadByName( fs, svm_node, "class_weights" ));
|
2504 | CV_CALL( var_idx = (CvMat*)cvReadByName( fs, svm_node, "var_idx" ));
|
2505 |
|
2506 | if( class_count > 1 && (!class_labels ||
|
2507 | !CV_IS_MAT(class_labels) || class_labels->cols != class_count))
|
2508 | CV_ERROR( CV_StsParseError, "Array of class labels is missing or invalid" );
|
2509 |
|
2510 | if( var_count < var_all && (!var_idx || !CV_IS_MAT(var_idx) || var_idx->cols != var_count) )
|
2511 | CV_ERROR( CV_StsParseError, "var_idx array is missing or invalid" );
|
2512 |
|
2513 |
|
2514 | sv_node = cvGetFileNodeByName( fs, svm_node, "support_vectors" );
|
2515 | if( !sv_node || !CV_NODE_IS_SEQ(sv_node->tag))
|
2516 | CV_ERROR( CV_StsParseError, "Missing or invalid sequence of support vectors" );
|
2517 |
|
2518 | block_size = MAX( block_size, sv_total*(int)sizeof(CvSVMKernelRow));
|
2519 | block_size = MAX( block_size, sv_total*2*(int)sizeof(double));
|
2520 | block_size = MAX( block_size, var_all*(int)sizeof(double));
|
2521 |
|
2522 | CV_CALL( storage = cvCreateMemStorage(block_size + sizeof(CvMemBlock) + sizeof(CvSeqBlock)));
|
2523 | CV_CALL( sv = (float**)cvMemStorageAlloc( storage,
|
2524 | sv_total*sizeof(sv[0]) ));
|
2525 |
|
2526 | CV_CALL( cvStartReadSeq( sv_node->data.seq, &reader, 0 ));
|
2527 | sv_size = var_count*sizeof(sv[0][0]);
|
2528 |
|
2529 | for( i = 0; i < sv_total; i++ )
|
2530 | {
|
2531 | CvFileNode* sv_elem = (CvFileNode*)reader.ptr;
|
2532 | CV_ASSERT( var_count == 1 || (CV_NODE_IS_SEQ(sv_elem->tag) &&
|
2533 | sv_elem->data.seq->total == var_count) );
|
2534 |
|
2535 | CV_CALL( sv[i] = (float*)cvMemStorageAlloc( storage, sv_size ));
|
2536 | CV_CALL( cvReadRawData( fs, sv_elem, sv[i], "f" ));
|
2537 | CV_NEXT_SEQ_ELEM( sv_node->data.seq->elem_size, reader );
|
2538 | }
|
2539 |
|
2540 |
|
2541 | df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
|
2542 | df_node = cvGetFileNodeByName( fs, svm_node, "decision_functions" );
|
2543 | if( !df_node || !CV_NODE_IS_SEQ(df_node->tag) ||
|
2544 | df_node->data.seq->total != df_count )
|
2545 | CV_ERROR( CV_StsParseError, "decision_functions is missing or is not a collection "
|
2546 | "or has a wrong number of elements" );
|
2547 |
|
2548 | CV_CALL( df = decision_func = (CvSVMDecisionFunc*)cvAlloc( df_count*sizeof(df[0]) ));
|
2549 | cvStartReadSeq( df_node->data.seq, &reader, 0 );
|
2550 |
|
2551 | for( i = 0; i < df_count; i++ )
|
2552 | {
|
2553 | CvFileNode* df_elem = (CvFileNode*)reader.ptr;
|
2554 | CvFileNode* alpha_node = cvGetFileNodeByName( fs, df_elem, "alpha" );
|
2555 |
|
2556 | int sv_count = cvReadIntByName( fs, df_elem, "sv_count", -1 );
|
2557 | if( sv_count <= 0 )
|
2558 | CV_ERROR( CV_StsParseError, "sv_count is missing or non-positive" );
|
2559 | df[i].sv_count = sv_count;
|
2560 |
|
2561 | df[i].rho = cvReadRealByName( fs, df_elem, "rho", not_found_dbl );
|
2562 | if( fabs(df[i].rho - not_found_dbl) < DBL_EPSILON )
|
2563 | CV_ERROR( CV_StsParseError, "rho is missing" );
|
2564 |
|
2565 | if( !alpha_node )
|
2566 | CV_ERROR( CV_StsParseError, "alpha is missing in the decision function" );
|
2567 |
|
2568 | CV_CALL( df[i].alpha = (double*)cvMemStorageAlloc( storage,
|
2569 | sv_count*sizeof(df[i].alpha[0])));
|
2570 | CV_ASSERT( sv_count == 1 || (CV_NODE_IS_SEQ(alpha_node->tag) &&
|
2571 | alpha_node->data.seq->total == sv_count) );
|
2572 | CV_CALL( cvReadRawData( fs, alpha_node, df[i].alpha, "d" ));
|
2573 |
|
2574 | if( class_count > 1 )
|
2575 | {
|
2576 | CvFileNode* index_node = cvGetFileNodeByName( fs, df_elem, "index" );
|
2577 | if( !index_node )
|
2578 | CV_ERROR( CV_StsParseError, "index is missing in the decision function" );
|
2579 | CV_CALL( df[i].sv_index = (int*)cvMemStorageAlloc( storage,
|
2580 | sv_count*sizeof(df[i].sv_index[0])));
|
2581 | CV_ASSERT( sv_count == 1 || (CV_NODE_IS_SEQ(index_node->tag) &&
|
2582 | index_node->data.seq->total == sv_count) );
|
2583 | CV_CALL( cvReadRawData( fs, index_node, df[i].sv_index, "i" ));
|
2584 | }
|
2585 | else
|
2586 | df[i].sv_index = 0;
|
2587 |
|
2588 | CV_NEXT_SEQ_ELEM( df_node->data.seq->elem_size, reader );
|
2589 | }
|
2590 |
|
2591 | if( cvReadIntByName(fs, svm_node, "optimize_linear", 1) != 0 )
|
2592 | optimize_linear_svm();
|
2593 | create_kernel();
|
2594 |
|
2595 | __END__;
|
2596 | }
|
2597 |
|
2598 | |
2599 |
|
2600 | static void*
|
2601 | icvCloneSVM( const void* _src )
|
2602 | {
|
2603 | CvSVMModel* dst = 0;
|
2604 |
|
2605 | CV_FUNCNAME( "icvCloneSVM" );
|
2606 |
|
2607 | __BEGIN__;
|
2608 |
|
2609 | const CvSVMModel* src = (const CvSVMModel*)_src;
|
2610 | int var_count, class_count;
|
2611 | int i, sv_total, df_count;
|
2612 | int sv_size;
|
2613 |
|
2614 | if( !CV_IS_SVM(src) )
|
2615 | CV_ERROR( !src ? CV_StsNullPtr : CV_StsBadArg, "Input pointer is NULL or invalid" );
|
2616 |
|
2617 | // 0. create initial CvSVMModel structure
|
2618 | CV_CALL( dst = icvCreateSVM() );
|
2619 | dst->params = src->params;
|
2620 | dst->params.weight_labels = 0;
|
2621 | dst->params.weights = 0;
|
2622 |
|
2623 | dst->var_all = src->var_all;
|
2624 | if( src->class_labels )
|
2625 | dst->class_labels = cvCloneMat( src->class_labels );
|
2626 | if( src->class_weights )
|
2627 | dst->class_weights = cvCloneMat( src->class_weights );
|
2628 | if( src->comp_idx )
|
2629 | dst->comp_idx = cvCloneMat( src->comp_idx );
|
2630 |
|
2631 | var_count = src->comp_idx ? src->comp_idx->cols : src->var_all;
|
2632 | class_count = src->class_labels ? src->class_labels->cols :
|
2633 | src->params.svm_type == CvSVM::ONE_CLASS ? 1 : 0;
|
2634 | sv_total = dst->sv_total = src->sv_total;
|
2635 | CV_CALL( dst->storage = cvCreateMemStorage( src->storage->block_size ));
|
2636 | CV_CALL( dst->sv = (float**)cvMemStorageAlloc( dst->storage,
|
2637 | sv_total*sizeof(dst->sv[0]) ));
|
2638 |
|
2639 | sv_size = var_count*sizeof(dst->sv[0][0]);
|
2640 |
|
2641 | for( i = 0; i < sv_total; i++ )
|
2642 | {
|
2643 | CV_CALL( dst->sv[i] = (float*)cvMemStorageAlloc( dst->storage, sv_size ));
|
2644 | memcpy( dst->sv[i], src->sv[i], sv_size );
|
2645 | }
|
2646 |
|
2647 | df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
|
2648 |
|
2649 | CV_CALL( dst->decision_func = cvAlloc( df_count*sizeof(CvSVMDecisionFunc) ));
|
2650 |
|
2651 | for( i = 0; i < df_count; i++ )
|
2652 | {
|
2653 | const CvSVMDecisionFunc *sdf =
|
2654 | (const CvSVMDecisionFunc*)src->decision_func+i;
|
2655 | CvSVMDecisionFunc *ddf =
|
2656 | (CvSVMDecisionFunc*)dst->decision_func+i;
|
2657 | int sv_count = sdf->sv_count;
|
2658 | ddf->sv_count = sv_count;
|
2659 | ddf->rho = sdf->rho;
|
2660 | CV_CALL( ddf->alpha = (double*)cvMemStorageAlloc( dst->storage,
|
2661 | sv_count*sizeof(ddf->alpha[0])));
|
2662 | memcpy( ddf->alpha, sdf->alpha, sv_count*sizeof(ddf->alpha[0]));
|
2663 |
|
2664 | if( class_count > 1 )
|
2665 | {
|
2666 | CV_CALL( ddf->sv_index = (int*)cvMemStorageAlloc( dst->storage,
|
2667 | sv_count*sizeof(ddf->sv_index[0])));
|
2668 | memcpy( ddf->sv_index, sdf->sv_index, sv_count*sizeof(ddf->sv_index[0]));
|
2669 | }
|
2670 | else
|
2671 | ddf->sv_index = 0;
|
2672 | }
|
2673 |
|
2674 | __END__;
|
2675 |
|
2676 | if( cvGetErrStatus() < 0 && dst )
|
2677 | icvReleaseSVM( &dst );
|
2678 |
|
2679 | return dst;
|
2680 | }
|
2681 |
|
2682 | static int icvRegisterSVMType()
|
2683 | {
|
2684 | CvTypeInfo info;
|
2685 | memset( &info, 0, sizeof(info) );
|
2686 |
|
2687 | info.flags = 0;
|
2688 | info.header_size = sizeof( info );
|
2689 | info.is_instance = icvIsSVM;
|
2690 | info.release = (CvReleaseFunc)icvReleaseSVM;
|
2691 | info.read = icvReadSVM;
|
2692 | info.write = icvWriteSVM;
|
2693 | info.clone = icvCloneSVM;
|
2694 | info.type_name = CV_TYPE_NAME_ML_SVM;
|
2695 | cvRegisterType( &info );
|
2696 |
|
2697 | return 1;
|
2698 | }
|
2699 |
|
2700 |
|
2701 | static int svm = icvRegisterSVMType();
|
2702 |
|
2703 | /* The function trains SVM model with optimal parameters, obtained by using cross-validation.
|
2704 | The parameters to be estimated should be indicated by setting theirs values to FLT_MAX.
|
2705 | The optimal parameters are saved in <model_params> */
|
2706 | CV_IMPL CvStatModel*
|
2707 | cvTrainSVM_CrossValidation( const CvMat* train_data, int tflag,
|
2708 | const CvMat* responses,
|
2709 | CvStatModelParams* model_params,
|
2710 | const CvStatModelParams* cross_valid_params,
|
2711 | const CvMat* comp_idx,
|
2712 | const CvMat* sample_idx,
|
2713 | const CvParamGrid* degree_grid,
|
2714 | const CvParamGrid* gamma_grid,
|
2715 | const CvParamGrid* coef_grid,
|
2716 | const CvParamGrid* C_grid,
|
2717 | const CvParamGrid* nu_grid,
|
2718 | const CvParamGrid* p_grid )
|
2719 | {
|
2720 | CvStatModel* svm = 0;
|
2721 |
|
2722 | CV_FUNCNAME("cvTainSVMCrossValidation");
|
2723 | __BEGIN__;
|
2724 |
|
2725 | double degree_step = 7,
|
2726 | g_step = 15,
|
2727 | coef_step = 14,
|
2728 | C_step = 20,
|
2729 | nu_step = 5,
|
2730 | p_step = 7; // all steps must be > 1
|
2731 | double degree_begin = 0.01, degree_end = 2;
|
2732 | double g_begin = 1e-5, g_end = 0.5;
|
2733 | double coef_begin = 0.1, coef_end = 300;
|
2734 | double C_begin = 0.1, C_end = 6000;
|
2735 | double nu_begin = 0.01, nu_end = 0.4;
|
2736 | double p_begin = 0.01, p_end = 100;
|
2737 |
|
2738 | double rate = 0, gamma = 0, C = 0, degree = 0, coef = 0, p = 0, nu = 0;
|
2739 |
|
2740 | double best_rate = 0;
|
2741 | double best_degree = degree_begin;
|
2742 | double best_gamma = g_begin;
|
2743 | double best_coef = coef_begin;
|
2744 | double best_C = C_begin;
|
2745 | double best_nu = nu_begin;
|
2746 | double best_p = p_begin;
|
2747 |
|
2748 | CvSVMModelParams svm_params, *psvm_params;
|
2749 | CvCrossValidationParams* cv_params = (CvCrossValidationParams*)cross_valid_params;
|
2750 | int svm_type, kernel;
|
2751 | int is_regression;
|
2752 |
|
2753 | if( !model_params )
|
2754 | CV_ERROR( CV_StsBadArg, "" );
|
2755 | if( !cv_params )
|
2756 | CV_ERROR( CV_StsBadArg, "" );
|
2757 |
|
2758 | svm_params = *(CvSVMModelParams*)model_params;
|
2759 | psvm_params = (CvSVMModelParams*)model_params;
|
2760 | svm_type = svm_params.svm_type;
|
2761 | kernel = svm_params.kernel_type;
|
2762 |
|
2763 | svm_params.degree = svm_params.degree > 0 ? svm_params.degree : 1;
|
2764 | svm_params.gamma = svm_params.gamma > 0 ? svm_params.gamma : 1;
|
2765 | svm_params.coef0 = svm_params.coef0 > 0 ? svm_params.coef0 : 1e-6;
|
2766 | svm_params.C = svm_params.C > 0 ? svm_params.C : 1;
|
2767 | svm_params.nu = svm_params.nu > 0 ? svm_params.nu : 1;
|
2768 | svm_params.p = svm_params.p > 0 ? svm_params.p : 1;
|
2769 |
|
2770 | if( degree_grid )
|
2771 | {
|
2772 | if( !(degree_grid->max_val == 0 && degree_grid->min_val == 0 &&
|
2773 | degree_grid->step == 0) )
|
2774 | {
|
2775 | if( degree_grid->min_val > degree_grid->max_val )
|
2776 | CV_ERROR( CV_StsBadArg,
|
2777 | "low bound of grid should be less then the upper one");
|
2778 | if( degree_grid->step <= 1 )
|
2779 | CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
|
2780 | degree_begin = degree_grid->min_val;
|
2781 | degree_end = degree_grid->max_val;
|
2782 | degree_step = degree_grid->step;
|
2783 | }
|
2784 | }
|
2785 | else
|
2786 | degree_begin = degree_end = svm_params.degree;
|
2787 |
|
2788 | if( gamma_grid )
|
2789 | {
|
2790 | if( !(gamma_grid->max_val == 0 && gamma_grid->min_val == 0 &&
|
2791 | gamma_grid->step == 0) )
|
2792 | {
|
2793 | if( gamma_grid->min_val > gamma_grid->max_val )
|
2794 | CV_ERROR( CV_StsBadArg,
|
2795 | "low bound of grid should be less then the upper one");
|
2796 | if( gamma_grid->step <= 1 )
|
2797 | CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
|
2798 | g_begin = gamma_grid->min_val;
|
2799 | g_end = gamma_grid->max_val;
|
2800 | g_step = gamma_grid->step;
|
2801 | }
|
2802 | }
|
2803 | else
|
2804 | g_begin = g_end = svm_params.gamma;
|
2805 |
|
2806 | if( coef_grid )
|
2807 | {
|
2808 | if( !(coef_grid->max_val == 0 && coef_grid->min_val == 0 &&
|
2809 | coef_grid->step == 0) )
|
2810 | {
|
2811 | if( coef_grid->min_val > coef_grid->max_val )
|
2812 | CV_ERROR( CV_StsBadArg,
|
2813 | "low bound of grid should be less then the upper one");
|
2814 | if( coef_grid->step <= 1 )
|
2815 | CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
|
2816 | coef_begin = coef_grid->min_val;
|
2817 | coef_end = coef_grid->max_val;
|
2818 | coef_step = coef_grid->step;
|
2819 | }
|
2820 | }
|
2821 | else
|
2822 | coef_begin = coef_end = svm_params.coef0;
|
2823 |
|
2824 | if( C_grid )
|
2825 | {
|
2826 | if( !(C_grid->max_val == 0 && C_grid->min_val == 0 && C_grid->step == 0))
|
2827 | {
|
2828 | if( C_grid->min_val > C_grid->max_val )
|
2829 | CV_ERROR( CV_StsBadArg,
|
2830 | "low bound of grid should be less then the upper one");
|
2831 | if( C_grid->step <= 1 )
|
2832 | CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
|
2833 | C_begin = C_grid->min_val;
|
2834 | C_end = C_grid->max_val;
|
2835 | C_step = C_grid->step;
|
2836 | }
|
2837 | }
|
2838 | else
|
2839 | C_begin = C_end = svm_params.C;
|
2840 |
|
2841 | if( nu_grid )
|
2842 | {
|
2843 | if(!(nu_grid->max_val == 0 && nu_grid->min_val == 0 && nu_grid->step==0))
|
2844 | {
|
2845 | if( nu_grid->min_val > nu_grid->max_val )
|
2846 | CV_ERROR( CV_StsBadArg,
|
2847 | "low bound of grid should be less then the upper one");
|
2848 | if( nu_grid->step <= 1 )
|
2849 | CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
|
2850 | nu_begin = nu_grid->min_val;
|
2851 | nu_end = nu_grid->max_val;
|
2852 | nu_step = nu_grid->step;
|
2853 | }
|
2854 | }
|
2855 | else
|
2856 | nu_begin = nu_end = svm_params.nu;
|
2857 |
|
2858 | if( p_grid )
|
2859 | {
|
2860 | if( !(p_grid->max_val == 0 && p_grid->min_val == 0 && p_grid->step == 0))
|
2861 | {
|
2862 | if( p_grid->min_val > p_grid->max_val )
|
2863 | CV_ERROR( CV_StsBadArg,
|
2864 | "low bound of grid should be less then the upper one");
|
2865 | if( p_grid->step <= 1 )
|
2866 | CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
|
2867 | p_begin = p_grid->min_val;
|
2868 | p_end = p_grid->max_val;
|
2869 | p_step = p_grid->step;
|
2870 | }
|
2871 | }
|
2872 | else
|
2873 | p_begin = p_end = svm_params.p;
|
2874 |
|
2875 | // these parameters are not used:
|
2876 | if( kernel != CvSVM::POLY )
|
2877 | degree_begin = degree_end = svm_params.degree;
|
2878 |
|
2879 | if( kernel == CvSVM::LINEAR )
|
2880 | g_begin = g_end = svm_params.gamma;
|
2881 |
|
2882 | if( kernel != CvSVM::POLY && kernel != CvSVM::SIGMOID )
|
2883 | coef_begin = coef_end = svm_params.coef0;
|
2884 |
|
2885 | if( svm_type == CvSVM::NU_SVC || svm_type == CvSVM::ONE_CLASS )
|
2886 | C_begin = C_end = svm_params.C;
|
2887 |
|
2888 | if( svm_type == CvSVM::C_SVC || svm_type == CvSVM::EPS_SVR )
|
2889 | nu_begin = nu_end = svm_params.nu;
|
2890 |
|
2891 | if( svm_type != CvSVM::EPS_SVR )
|
2892 | p_begin = p_end = svm_params.p;
|
2893 |
|
2894 | is_regression = cv_params->is_regression;
|
2895 | best_rate = is_regression ? FLT_MAX : 0;
|
2896 |
|
2897 | assert( g_step > 1 && degree_step > 1 && coef_step > 1);
|
2898 | assert( p_step > 1 && C_step > 1 && nu_step > 1 );
|
2899 |
|
2900 | for( degree = degree_begin; degree <= degree_end; degree *= degree_step )
|
2901 | {
|
2902 | svm_params.degree = degree;
|
2903 | //printf("degree = %.3f\n", degree );
|
2904 | for( gamma= g_begin; gamma <= g_end; gamma *= g_step )
|
2905 | {
|
2906 | svm_params.gamma = gamma;
|
2907 | //printf(" gamma = %.3f\n", gamma );
|
2908 | for( coef = coef_begin; coef <= coef_end; coef *= coef_step )
|
2909 | {
|
2910 | svm_params.coef0 = coef;
|
2911 | //printf(" coef = %.3f\n", coef );
|
2912 | for( C = C_begin; C <= C_end; C *= C_step )
|
2913 | {
|
2914 | svm_params.C = C;
|
2915 | //printf(" C = %.3f\n", C );
|
2916 | for( nu = nu_begin; nu <= nu_end; nu *= nu_step )
|
2917 | {
|
2918 | svm_params.nu = nu;
|
2919 | //printf(" nu = %.3f\n", nu );
|
2920 | for( p = p_begin; p <= p_end; p *= p_step )
|
2921 | {
|
2922 | int well;
|
2923 | svm_params.p = p;
|
2924 | //printf(" p = %.3f\n", p );
|
2925 |
|
2926 | CV_CALL(rate = cvCrossValidation( train_data, tflag, responses, &cvTrainSVM,
|
2927 | cross_valid_params, (CvStatModelParams*)&svm_params, comp_idx, sample_idx ));
|
2928 |
|
2929 | well = rate > best_rate && !is_regression || rate < best_rate && is_regression;
|
2930 | if( well || (rate == best_rate && C < best_C) )
|
2931 | {
|
2932 | best_rate = rate;
|
2933 | best_degree = degree;
|
2934 | best_gamma = gamma;
|
2935 | best_coef = coef;
|
2936 | best_C = C;
|
2937 | best_nu = nu;
|
2938 | best_p = p;
|
2939 | }
|
2940 | //printf(" rate = %.2f\n", rate );
|
2941 | }
|
2942 | }
|
2943 | }
|
2944 | }
|
2945 | }
|
2946 | }
|
2947 | //printf("The best:\nrate = %.2f%% degree = %f gamma = %f coef = %f c = %f nu = %f p = %f\n",
|
2948 | // best_rate, best_degree, best_gamma, best_coef, best_C, best_nu, best_p );
|
2949 |
|
2950 | psvm_params->C = best_C;
|
2951 | psvm_params->nu = best_nu;
|
2952 | psvm_params->p = best_p;
|
2953 | psvm_params->gamma = best_gamma;
|
2954 | psvm_params->degree = best_degree;
|
2955 | psvm_params->coef0 = best_coef;
|
2956 |
|
2957 | CV_CALL(svm = cvTrainSVM( train_data, tflag, responses, model_params, comp_idx, sample_idx ));
|
2958 |
|
2959 | __END__;
|
2960 |
|
2961 | return svm;
|
2962 | }
|
2963 |
|
2964 | #endif
|
2965 |
|
2966 |
|