svm.cpp

Modified the function train_auto in svm.cpp - Juan Manuel Baruffaldi, 2014-08-19 02:14 am

Download (91.1 kB)

 
1
/*M///////////////////////////////////////////////////////////////////////////////////////
2
//
3
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4
//
5
//  By downloading, copying, installing or using the software you agree to this license.
6
//  If you do not agree to this license, do not download, install,
7
//  copy or use the software.
8
//
9
//
10
//                        Intel License Agreement
11
//
12
// Copyright (C) 2000, Intel Corporation, all rights reserved.
13
// Third party copyrights are property of their respective owners.
14
//
15
// Redistribution and use in source and binary forms, with or without modification,
16
// are permitted provided that the following conditions are met:
17
//
18
//   * Redistribution's of source code must retain the above copyright notice,
19
//     this list of conditions and the following disclaimer.
20
//
21
//   * Redistribution's in binary form must reproduce the above copyright notice,
22
//     this list of conditions and the following disclaimer in the documentation
23
//     and/or other materials provided with the distribution.
24
//
25
//   * The name of Intel Corporation may not be used to endorse or promote products
26
//     derived from this software without specific prior written permission.
27
//
28
// This software is provided by the copyright holders and contributors "as is" and
29
// any express or implied warranties, including, but not limited to, the implied
30
// warranties of merchantability and fitness for a particular purpose are disclaimed.
31
// In no event shall the Intel Corporation or contributors be liable for any direct,
32
// indirect, incidental, special, exemplary, or consequential damages
33
// (including, but not limited to, procurement of substitute goods or services;
34
// loss of use, data, or profits; or business interruption) however caused
35
// and on any theory of liability, whether in contract, strict liability,
36
// or tort (including negligence or otherwise) arising in any way out of
37
// the use of this software, even if advised of the possibility of such damage.
38
//
39
//M*/
40
41
//Reformed function CvSVM::train_auto by Baruffaldi Juan Manuel -> [email protected]
42
43
#include "precomp.hpp"
44
45
/****************************************************************************************\
46
                                COPYRIGHT NOTICE
47
                                ----------------
48
49
  The code has been derived from libsvm library (version 2.6)
50
  (http://www.csie.ntu.edu.tw/~cjlin/libsvm).
51
52
  Here is the original copyright:
53
------------------------------------------------------------------------------------------
54
    Copyright (c) 2000-2003 Chih-Chung Chang and Chih-Jen Lin
55
    All rights reserved.
56
57
    Redistribution and use in source and binary forms, with or without
58
    modification, are permitted provided that the following conditions
59
    are met:
60
61
    1. Redistributions of source code must retain the above copyright
62
    notice, this list of conditions and the following disclaimer.
63
64
    2. Redistributions in binary form must reproduce the above copyright
65
    notice, this list of conditions and the following disclaimer in the
66
    documentation and/or other materials provided with the distribution.
67
68
    3. Neither name of copyright holders nor the names of its contributors
69
    may be used to endorse or promote products derived from this software
70
    without specific prior written permission.
71
72
73
    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
74
    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
75
    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
76
    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE REGENTS OR
77
    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
78
    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
79
    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
80
    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
81
    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
82
    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
83
    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
84
\****************************************************************************************/
85
86
using namespace cv;
87
88
#define CV_SVM_MIN_CACHE_SIZE  (40 << 20)  /* 40Mb */
89
90
#include <stdarg.h>
91
#include <ctype.h>
92
93
#if 1
94
typedef float Qfloat;
95
#define QFLOAT_TYPE CV_32F
96
#else
97
typedef double Qfloat;
98
#define QFLOAT_TYPE CV_64F
99
#endif
100
101
// Param Grid
102
bool CvParamGrid::check() const
103
{
104
    bool ok = false;
105
106
    CV_FUNCNAME( "CvParamGrid::check" );
107
    __BEGIN__;
108
109
    if( min_val > max_val )
110
        CV_ERROR( CV_StsBadArg, "Lower bound of the grid must be less then the upper one" );
111
    if( min_val < DBL_EPSILON )
112
        CV_ERROR( CV_StsBadArg, "Lower bound of the grid must be positive" );
113
    if( step < 1. + FLT_EPSILON )
114
        CV_ERROR( CV_StsBadArg, "Grid step must greater then 1" );
115
116
    ok = true;
117
118
    __END__;
119
120
    return ok;
121
}
122
123
// Default logarithmic search grids used by train_auto for each tunable
// SVM parameter. Raises a CV error for an unknown parameter id.
CvParamGrid CvSVM::get_default_grid( int param_id )
{
    CvParamGrid grid;
    switch( param_id )
    {
    case CvSVM::C:
        grid.min_val = 0.1;  grid.max_val = 500;  grid.step = 5;  // total iterations = 5
        break;
    case CvSVM::GAMMA:
        grid.min_val = 1e-5; grid.max_val = 0.6;  grid.step = 15; // total iterations = 4
        break;
    case CvSVM::P:
        grid.min_val = 0.01; grid.max_val = 100;  grid.step = 7;  // total iterations = 4
        break;
    case CvSVM::NU:
        grid.min_val = 0.01; grid.max_val = 0.2;  grid.step = 3;  // total iterations = 3
        break;
    case CvSVM::COEF:
        grid.min_val = 0.1;  grid.max_val = 300;  grid.step = 14; // total iterations = 3
        break;
    case CvSVM::DEGREE:
        grid.min_val = 0.01; grid.max_val = 4;    grid.step = 7;  // total iterations = 3
        break;
    default:
        cvError( CV_StsBadArg, "CvSVM::get_default_grid", "Invalid type of parameter "
            "(use one of CvSVM::C, CvSVM::GAMMA et al.)", __FILE__, __LINE__ );
    }
    return grid;
}
167
168
// SVM training parameters
169
// Default SVM training parameters: C-SVC with an RBF kernel, gamma = 1,
// C = 1, and a combined 1000-iteration / FLT_EPSILON stopping criterion.
CvSVMParams::CvSVMParams()
    : svm_type(CvSVM::C_SVC), kernel_type(CvSVM::RBF),
      degree(0), gamma(1), coef0(0),
      C(1), nu(0), p(0), class_weights(0)
{
    term_crit = cvTermCriteria( CV_TERMCRIT_ITER + CV_TERMCRIT_EPS, 1000, FLT_EPSILON );
}
175
176
177
// Explicit parameter constructor; the values are stored verbatim and
// validated later, when training starts.
CvSVMParams::CvSVMParams( int _svm_type, int _kernel_type,
    double _degree, double _gamma, double _coef0,
    double _Con, double _nu, double _p,
    CvMat* _class_weights, CvTermCriteria _term_crit )
    : svm_type(_svm_type), kernel_type(_kernel_type),
      degree(_degree), gamma(_gamma), coef0(_coef0),
      C(_Con), nu(_nu), p(_p),
      class_weights(_class_weights), term_crit(_term_crit)
{
}
186
187
188
/////////////////////////////////////// SVM kernel ///////////////////////////////////////
189
190
// Default constructor: the kernel starts unbound; call create() before use.
CvSVMKernel::CvSVMKernel()
{
    clear();
}
194
195
196
void CvSVMKernel::clear()
197
{
198
    params = 0;
199
    calc_func = 0;
200
}
201
202
203
CvSVMKernel::~CvSVMKernel()
{
    // Nothing to release: the parameter struct is owned by the caller.
}
206
207
208
// Convenience constructor: immediately binds parameters and kernel function.
CvSVMKernel::CvSVMKernel( const CvSVMParams* _params, Calc _calc_func )
{
    clear();
    create( _params, _calc_func );
}
213
214
215
bool CvSVMKernel::create( const CvSVMParams* _params, Calc _calc_func )
216
{
217
    clear();
218
    params = _params;
219
    calc_func = _calc_func;
220
221
    if( !calc_func )
222
        calc_func = params->kernel_type == CvSVM::RBF ? &CvSVMKernel::calc_rbf :
223
                    params->kernel_type == CvSVM::POLY ? &CvSVMKernel::calc_poly :
224
                    params->kernel_type == CvSVM::SIGMOID ? &CvSVMKernel::calc_sigmoid :
225
                    &CvSVMKernel::calc_linear;
226
227
    return true;
228
}
229
230
231
void CvSVMKernel::calc_non_rbf_base( int vcount, int var_count, const float** vecs,
232
                                     const float* another, Qfloat* results,
233
                                     double alpha, double beta )
234
{
235
    int j, k;
236
    for( j = 0; j < vcount; j++ )
237
    {
238
        const float* sample = vecs[j];
239
        double s = 0;
240
        for( k = 0; k <= var_count - 4; k += 4 )
241
            s += sample[k]*another[k] + sample[k+1]*another[k+1] +
242
                 sample[k+2]*another[k+2] + sample[k+3]*another[k+3];
243
        for( ; k < var_count; k++ )
244
            s += sample[k]*another[k];
245
        results[j] = (Qfloat)(s*alpha + beta);
246
    }
247
}
248
249
250
void CvSVMKernel::calc_linear( int vcount, int var_count, const float** vecs,
251
                               const float* another, Qfloat* results )
252
{
253
    calc_non_rbf_base( vcount, var_count, vecs, another, results, 1, 0 );
254
}
255
256
257
void CvSVMKernel::calc_poly( int vcount, int var_count, const float** vecs,
258
                             const float* another, Qfloat* results )
259
{
260
    CvMat R = cvMat( 1, vcount, QFLOAT_TYPE, results );
261
    calc_non_rbf_base( vcount, var_count, vecs, another, results, params->gamma, params->coef0 );
262
    if( vcount > 0 )
263
        cvPow( &R, &R, params->degree );
264
}
265
266
267
void CvSVMKernel::calc_sigmoid( int vcount, int var_count, const float** vecs,
268
                                const float* another, Qfloat* results )
269
{
270
    int j;
271
    calc_non_rbf_base( vcount, var_count, vecs, another, results,
272
                       -2*params->gamma, -2*params->coef0 );
273
    // TODO: speedup this
274
    for( j = 0; j < vcount; j++ )
275
    {
276
        Qfloat t = results[j];
277
        double e = exp(-fabs(t));
278
        if( t > 0 )
279
            results[j] = (Qfloat)((1. - e)/(1. + e));
280
        else
281
            results[j] = (Qfloat)((e - 1.)/(e + 1.));
282
    }
283
}
284
285
286
void CvSVMKernel::calc_rbf( int vcount, int var_count, const float** vecs,
287
                            const float* another, Qfloat* results )
288
{
289
    CvMat R = cvMat( 1, vcount, QFLOAT_TYPE, results );
290
    double gamma = -params->gamma;
291
    int j, k;
292
293
    for( j = 0; j < vcount; j++ )
294
    {
295
        const float* sample = vecs[j];
296
        double s = 0;
297
298
        for( k = 0; k <= var_count - 4; k += 4 )
299
        {
300
            double t0 = sample[k] - another[k];
301
            double t1 = sample[k+1] - another[k+1];
302
303
            s += t0*t0 + t1*t1;
304
305
            t0 = sample[k+2] - another[k+2];
306
            t1 = sample[k+3] - another[k+3];
307
308
            s += t0*t0 + t1*t1;
309
        }
310
311
        for( ; k < var_count; k++ )
312
        {
313
            double t0 = sample[k] - another[k];
314
            s += t0*t0;
315
        }
316
        results[j] = (Qfloat)(s*gamma);
317
    }
318
319
    if( vcount > 0 )
320
        cvExp( &R, &R );
321
}
322
323
324
void CvSVMKernel::calc( int vcount, int var_count, const float** vecs,
325
                        const float* another, Qfloat* results )
326
{
327
    const Qfloat max_val = (Qfloat)(FLT_MAX*1e-3);
328
    int j;
329
    (this->*calc_func)( vcount, var_count, vecs, another, results );
330
    for( j = 0; j < vcount; j++ )
331
    {
332
        if( results[j] > max_val )
333
            results[j] = max_val;
334
    }
335
}
336
337
338
// Generalized SMO+SVMlight algorithm
339
// Solves:
340
//
341
//  min [0.5(\alpha^T Q \alpha) + b^T \alpha]
342
//
343
//      y^T \alpha = \delta
344
//      y_i = +1 or -1
345
//      0 <= alpha_i <= Cp for y_i = 1
346
//      0 <= alpha_i <= Cn for y_i = -1
347
//
348
// Given:
349
//
350
//  Q, b, y, Cp, Cn, and an initial feasible point \alpha
351
//  l is the size of vectors and matrices
352
//  eps is the stopping criterion
353
//
354
// solution will be put in \alpha, objective value will be put in obj
355
//
356
357
void CvSVMSolver::clear()
358
{
359
    G = 0;
360
    alpha = 0;
361
    y = 0;
362
    b = 0;
363
    buf[0] = buf[1] = 0;
364
    cvReleaseMemStorage( &storage );
365
    kernel = 0;
366
    select_working_set_func = 0;
367
    calc_rho_func = 0;
368
369
    rows = 0;
370
    samples = 0;
371
    get_row_func = 0;
372
}
373
374
375
// Default constructor: no storage yet; clear() zeroes everything else.
CvSVMSolver::CvSVMSolver()
{
    storage = 0;
    clear();
}
380
381
382
CvSVMSolver::~CvSVMSolver()
{
    clear();  // releases the child memory storage and all cached rows
}
386
387
388
// Convenience constructor: forwards every argument to create().
CvSVMSolver::CvSVMSolver( int _sample_count, int _var_count, const float** _samples, schar* _y,
                int _alpha_count, double* _alpha, double _Cp, double _Cn,
                CvMemStorage* _storage, CvSVMKernel* _kernel, GetRow _get_row,
                SelectWorkingSet _select_working_set, CalcRho _calc_rho )
{
    storage = 0;
    create( _sample_count, _var_count, _samples, _y, _alpha_count, _alpha, _Cp, _Cn,
            _storage, _kernel, _get_row, _select_working_set, _calc_rho );
}
397
398
399
bool CvSVMSolver::create( int _sample_count, int _var_count, const float** _samples, schar* _y,
400
                int _alpha_count, double* _alpha, double _Cp, double _Cn,
401
                CvMemStorage* _storage, CvSVMKernel* _kernel, GetRow _get_row,
402
                SelectWorkingSet _select_working_set, CalcRho _calc_rho )
403
{
404
    bool ok = false;
405
    int i, svm_type;
406
407
    CV_FUNCNAME( "CvSVMSolver::create" );
408
409
    __BEGIN__;
410
411
    int rows_hdr_size;
412
413
    clear();
414
415
    sample_count = _sample_count;
416
    var_count = _var_count;
417
    samples = _samples;
418
    y = _y;
419
    alpha_count = _alpha_count;
420
    alpha = _alpha;
421
    kernel = _kernel;
422
423
    C[0] = _Cn;
424
    C[1] = _Cp;
425
    eps = kernel->params->term_crit.epsilon;
426
    max_iter = kernel->params->term_crit.max_iter;
427
    storage = cvCreateChildMemStorage( _storage );
428
429
    b = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(b[0]));
430
    alpha_status = (schar*)cvMemStorageAlloc( storage, alpha_count*sizeof(alpha_status[0]));
431
    G = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(G[0]));
432
    for( i = 0; i < 2; i++ )
433
        buf[i] = (Qfloat*)cvMemStorageAlloc( storage, sample_count*2*sizeof(buf[i][0]) );
434
    svm_type = kernel->params->svm_type;
435
436
    select_working_set_func = _select_working_set;
437
    if( !select_working_set_func )
438
        select_working_set_func = svm_type == CvSVM::NU_SVC || svm_type == CvSVM::NU_SVR ?
439
        &CvSVMSolver::select_working_set_nu_svm : &CvSVMSolver::select_working_set;
440
441
    calc_rho_func = _calc_rho;
442
    if( !calc_rho_func )
443
        calc_rho_func = svm_type == CvSVM::NU_SVC || svm_type == CvSVM::NU_SVR ?
444
            &CvSVMSolver::calc_rho_nu_svm : &CvSVMSolver::calc_rho;
445
446
    get_row_func = _get_row;
447
    if( !get_row_func )
448
        get_row_func = params->svm_type == CvSVM::EPS_SVR ||
449
                       params->svm_type == CvSVM::NU_SVR ? &CvSVMSolver::get_row_svr :
450
                       params->svm_type == CvSVM::C_SVC ||
451
                       params->svm_type == CvSVM::NU_SVC ? &CvSVMSolver::get_row_svc :
452
                       &CvSVMSolver::get_row_one_class;
453
454
    cache_line_size = sample_count*sizeof(Qfloat);
455
    // cache size = max(num_of_samples^2*sizeof(Qfloat)*0.25, 64Kb)
456
    // (assuming that for large training sets ~25% of Q matrix is used)
457
    cache_size = MAX( cache_line_size*sample_count/4, CV_SVM_MIN_CACHE_SIZE );
458
459
    // the size of Q matrix row headers
460
    rows_hdr_size = sample_count*sizeof(rows[0]);
461
    if( rows_hdr_size > storage->block_size )
462
        CV_ERROR( CV_StsOutOfRange, "Too small storage block size" );
463
464
    lru_list.prev = lru_list.next = &lru_list;
465
    rows = (CvSVMKernelRow*)cvMemStorageAlloc( storage, rows_hdr_size );
466
    memset( rows, 0, rows_hdr_size );
467
468
    ok = true;
469
470
    __END__;
471
472
    return ok;
473
}
474
475
476
float* CvSVMSolver::get_row_base( int i, bool* _existed )
477
{
478
    int i1 = i < sample_count ? i : i - sample_count;
479
    CvSVMKernelRow* row = rows + i1;
480
    bool existed = row->data != 0;
481
    Qfloat* data;
482
483
    if( existed || cache_size <= 0 )
484
    {
485
        CvSVMKernelRow* del_row = existed ? row : lru_list.prev;
486
        data = del_row->data;
487
        assert( data != 0 );
488
489
        // delete row from the LRU list
490
        del_row->data = 0;
491
        del_row->prev->next = del_row->next;
492
        del_row->next->prev = del_row->prev;
493
    }
494
    else
495
    {
496
        data = (Qfloat*)cvMemStorageAlloc( storage, cache_line_size );
497
        cache_size -= cache_line_size;
498
    }
499
500
    // insert row into the LRU list
501
    row->data = data;
502
    row->prev = &lru_list;
503
    row->next = lru_list.next;
504
    row->prev->next = row->next->prev = row;
505
506
    if( !existed )
507
    {
508
        kernel->calc( sample_count, var_count, samples, samples[i1], row->data );
509
    }
510
511
    if( _existed )
512
        *_existed = existed;
513
514
    return row->data;
515
}
516
517
518
float* CvSVMSolver::get_row_svc( int i, float* row, float*, bool existed )
519
{
520
    if( !existed )
521
    {
522
        const schar* _y = y;
523
        int j, len = sample_count;
524
        assert( _y && i < sample_count );
525
526
        if( _y[i] > 0 )
527
        {
528
            for( j = 0; j < len; j++ )
529
                row[j] = _y[j]*row[j];
530
        }
531
        else
532
        {
533
            for( j = 0; j < len; j++ )
534
                row[j] = -_y[j]*row[j];
535
        }
536
    }
537
    return row;
538
}
539
540
541
// One-class SVM needs no post-processing: the raw kernel row is the Q row.
float* CvSVMSolver::get_row_one_class( int, float* row, float*, bool )
{
    return row;
}
545
546
547
float* CvSVMSolver::get_row_svr( int i, float* row, float* dst, bool )
548
{
549
    int j, len = sample_count;
550
    Qfloat* dst_pos = dst;
551
    Qfloat* dst_neg = dst + len;
552
    if( i >= len )
553
    {
554
        Qfloat* temp;
555
        CV_SWAP( dst_pos, dst_neg, temp );
556
    }
557
558
    for( j = 0; j < len; j++ )
559
    {
560
        Qfloat t = row[j];
561
        dst_pos[j] = t;
562
        dst_neg[j] = -t;
563
    }
564
    return dst;
565
}
566
567
568
569
float* CvSVMSolver::get_row( int i, float* dst )
570
{
571
    bool existed = false;
572
    float* row = get_row_base( i, &existed );
573
    return (this->*get_row_func)( i, row, dst, existed );
574
}
575
576
577
// Helper macros over the per-sample alpha_status array:
//   +1 — alpha_i sits at its upper bound C,
//   -1 — alpha_i sits at the lower bound 0,
//    0 — alpha_i is free (strictly between the bounds).
#undef is_upper_bound
#define is_upper_bound(i) (alpha_status[i] > 0)

#undef is_lower_bound
#define is_lower_bound(i) (alpha_status[i] < 0)

#undef is_free
#define is_free(i) (alpha_status[i] == 0)

// Per-class penalty: C[1] for positive samples, C[0] for negative ones.
#undef get_C
#define get_C(i) (C[y[i]>0])

// Re-derive alpha_status[i] from the current alpha[i] value.
#undef update_alpha_status
#define update_alpha_status(i) \
    alpha_status[i] = (schar)(alpha[i] >= get_C(i) ? 1 : alpha[i] <= 0 ? -1 : 0)

// Placeholder: gradient reconstruction (shrinking) is not implemented.
#undef reconstruct_gradient
#define reconstruct_gradient() /* empty for now */
595
596
597
// Core generalized SMO optimization loop shared by every formulation.
// On entry alpha[] holds a feasible starting point and b[] the linear
// term of the objective; on success the solution is left in alpha[] and
// si receives rho, r, the objective value and the box bounds.
// Returns false if the gradient or alphas look divergent.
bool CvSVMSolver::solve_generic( CvSVMSolutionInfo& si )
{
    int iter = 0;
    int i, j, k;

    // 1. initialize gradient and alpha status
    for( i = 0; i < alpha_count; i++ )
    {
        update_alpha_status(i);
        G[i] = b[i];
        if( fabs(G[i]) > 1e200 )
            return false;  // invalid/divergent linear term — bail out
    }

    // only non-zero alphas contribute to the initial gradient
    for( i = 0; i < alpha_count; i++ )
    {
        if( !is_lower_bound(i) )
        {
            const Qfloat *Q_i = get_row( i, buf[0] );
            double alpha_i = alpha[i];

            for( j = 0; j < alpha_count; j++ )
                G[j] += alpha_i*Q_i[j];
        }
    }

    // 2. optimization loop
    for(;;)
    {
        const Qfloat *Q_i, *Q_j;
        double C_i, C_j;
        double old_alpha_i, old_alpha_j, alpha_i, alpha_j;
        double delta_alpha_i, delta_alpha_j;

#ifdef _DEBUG
        // debug-only sanity check for numerical blow-up
        for( i = 0; i < alpha_count; i++ )
        {
            if( fabs(G[i]) > 1e+300 )
                return false;

            if( fabs(alpha[i]) > 1e16 )
                return false;
        }
#endif

        // stop when working-set selection reports optimality (non-zero
        // return) or the iteration budget is exhausted
        if( (this->*select_working_set_func)( i, j ) != 0 || iter++ >= max_iter )
            break;

        Q_i = get_row( i, buf[0] );
        Q_j = get_row( j, buf[1] );

        C_i = get_C(i);
        C_j = get_C(j);

        alpha_i = old_alpha_i = alpha[i];
        alpha_j = old_alpha_j = alpha[j];

        if( y[i] != y[j] )
        {
            // opposite labels: alpha_i - alpha_j is conserved by the
            // equality constraint; take the unconstrained Newton step
            // (denominator guarded against zero curvature) ...
            double denom = Q_i[i]+Q_j[j]+2*Q_i[j];
            double delta = (-G[i]-G[j])/MAX(fabs(denom),FLT_EPSILON);
            double diff = alpha_i - alpha_j;
            alpha_i += delta;
            alpha_j += delta;

            // ... then clip back into the box [0,C_i] x [0,C_j]
            if( diff > 0 && alpha_j < 0 )
            {
                alpha_j = 0;
                alpha_i = diff;
            }
            else if( diff <= 0 && alpha_i < 0 )
            {
                alpha_i = 0;
                alpha_j = -diff;
            }

            if( diff > C_i - C_j && alpha_i > C_i )
            {
                alpha_i = C_i;
                alpha_j = C_i - diff;
            }
            else if( diff <= C_i - C_j && alpha_j > C_j )
            {
                alpha_j = C_j;
                alpha_i = C_j + diff;
            }
        }
        else
        {
            // equal labels: alpha_i + alpha_j is conserved
            double denom = Q_i[i]+Q_j[j]-2*Q_i[j];
            double delta = (G[i]-G[j])/MAX(fabs(denom),FLT_EPSILON);
            double sum = alpha_i + alpha_j;
            alpha_i -= delta;
            alpha_j += delta;

            // clip back into the box along the sum-preserving direction
            if( sum > C_i && alpha_i > C_i )
            {
                alpha_i = C_i;
                alpha_j = sum - C_i;
            }
            else if( sum <= C_i && alpha_j < 0)
            {
                alpha_j = 0;
                alpha_i = sum;
            }

            if( sum > C_j && alpha_j > C_j )
            {
                alpha_j = C_j;
                alpha_i = sum - C_j;
            }
            else if( sum <= C_j && alpha_i < 0 )
            {
                alpha_i = 0;
                alpha_j = sum;
            }
        }

        // update alpha
        alpha[i] = alpha_i;
        alpha[j] = alpha_j;
        update_alpha_status(i);
        update_alpha_status(j);

        // update G incrementally: only rows i and j changed
        delta_alpha_i = alpha_i - old_alpha_i;
        delta_alpha_j = alpha_j - old_alpha_j;

        for( k = 0; k < alpha_count; k++ )
            G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
    }

    // calculate rho
    (this->*calc_rho_func)( si.rho, si.r );

    // calculate objective value: 0.5 * sum_i alpha_i (G_i + b_i)
    for( i = 0, si.obj = 0; i < alpha_count; i++ )
        si.obj += alpha[i] * (G[i] + b[i]);

    si.obj *= 0.5;

    si.upper_bound_p = C[1];
    si.upper_bound_n = C[0];

    return true;
}
743
744
745
// return 1 if already optimal, return 0 otherwise
746
bool
747
CvSVMSolver::select_working_set( int& out_i, int& out_j )
748
{
749
    // return i,j which maximize -grad(f)^T d , under constraint
750
    // if alpha_i == C, d != +1
751
    // if alpha_i == 0, d != -1
752
    double Gmax1 = -DBL_MAX;        // max { -grad(f)_i * d | y_i*d = +1 }
753
    int Gmax1_idx = -1;
754
755
    double Gmax2 = -DBL_MAX;        // max { -grad(f)_i * d | y_i*d = -1 }
756
    int Gmax2_idx = -1;
757
758
    int i;
759
760
    for( i = 0; i < alpha_count; i++ )
761
    {
762
        double t;
763
764
        if( y[i] > 0 )    // y = +1
765
        {
766
            if( !is_upper_bound(i) && (t = -G[i]) > Gmax1 )  // d = +1
767
            {
768
                Gmax1 = t;
769
                Gmax1_idx = i;
770
            }
771
            if( !is_lower_bound(i) && (t = G[i]) > Gmax2 )  // d = -1
772
            {
773
                Gmax2 = t;
774
                Gmax2_idx = i;
775
            }
776
        }
777
        else        // y = -1
778
        {
779
            if( !is_upper_bound(i) && (t = -G[i]) > Gmax2 )  // d = +1
780
            {
781
                Gmax2 = t;
782
                Gmax2_idx = i;
783
            }
784
            if( !is_lower_bound(i) && (t = G[i]) > Gmax1 )  // d = -1
785
            {
786
                Gmax1 = t;
787
                Gmax1_idx = i;
788
            }
789
        }
790
    }
791
792
    out_i = Gmax1_idx;
793
    out_j = Gmax2_idx;
794
795
    return Gmax1 + Gmax2 < eps;
796
}
797
798
799
void
800
CvSVMSolver::calc_rho( double& rho, double& r )
801
{
802
    int i, nr_free = 0;
803
    double ub = DBL_MAX, lb = -DBL_MAX, sum_free = 0;
804
805
    for( i = 0; i < alpha_count; i++ )
806
    {
807
        double yG = y[i]*G[i];
808
809
        if( is_lower_bound(i) )
810
        {
811
            if( y[i] > 0 )
812
                ub = MIN(ub,yG);
813
            else
814
                lb = MAX(lb,yG);
815
        }
816
        else if( is_upper_bound(i) )
817
        {
818
            if( y[i] < 0)
819
                ub = MIN(ub,yG);
820
            else
821
                lb = MAX(lb,yG);
822
        }
823
        else
824
        {
825
            ++nr_free;
826
            sum_free += yG;
827
        }
828
    }
829
830
    rho = nr_free > 0 ? sum_free/nr_free : (ub + lb)*0.5;
831
    r = 0;
832
}
833
834
835
bool
836
CvSVMSolver::select_working_set_nu_svm( int& out_i, int& out_j )
837
{
838
    // return i,j which maximize -grad(f)^T d , under constraint
839
    // if alpha_i == C, d != +1
840
    // if alpha_i == 0, d != -1
841
    double Gmax1 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = +1, d = +1 }
842
    int Gmax1_idx = -1;
843
844
    double Gmax2 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = +1, d = -1 }
845
    int Gmax2_idx = -1;
846
847
    double Gmax3 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = -1, d = +1 }
848
    int Gmax3_idx = -1;
849
850
    double Gmax4 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = -1, d = -1 }
851
    int Gmax4_idx = -1;
852
853
    int i;
854
855
    for( i = 0; i < alpha_count; i++ )
856
    {
857
        double t;
858
859
        if( y[i] > 0 )    // y == +1
860
        {
861
            if( !is_upper_bound(i) && (t = -G[i]) > Gmax1 )  // d = +1
862
            {
863
                Gmax1 = t;
864
                Gmax1_idx = i;
865
            }
866
            if( !is_lower_bound(i) && (t = G[i]) > Gmax2 )  // d = -1
867
            {
868
                Gmax2 = t;
869
                Gmax2_idx = i;
870
            }
871
        }
872
        else        // y == -1
873
        {
874
            if( !is_upper_bound(i) && (t = -G[i]) > Gmax3 )  // d = +1
875
            {
876
                Gmax3 = t;
877
                Gmax3_idx = i;
878
            }
879
            if( !is_lower_bound(i) && (t = G[i]) > Gmax4 )  // d = -1
880
            {
881
                Gmax4 = t;
882
                Gmax4_idx = i;
883
            }
884
        }
885
    }
886
887
    if( MAX(Gmax1 + Gmax2, Gmax3 + Gmax4) < eps )
888
        return 1;
889
890
    if( Gmax1 + Gmax2 > Gmax3 + Gmax4 )
891
    {
892
        out_i = Gmax1_idx;
893
        out_j = Gmax2_idx;
894
    }
895
    else
896
    {
897
        out_i = Gmax3_idx;
898
        out_j = Gmax4_idx;
899
    }
900
    return 0;
901
}
902
903
904
void
905
CvSVMSolver::calc_rho_nu_svm( double& rho, double& r )
906
{
907
    int nr_free1 = 0, nr_free2 = 0;
908
    double ub1 = DBL_MAX, ub2 = DBL_MAX;
909
    double lb1 = -DBL_MAX, lb2 = -DBL_MAX;
910
    double sum_free1 = 0, sum_free2 = 0;
911
    double r1, r2;
912
913
    int i;
914
915
    for( i = 0; i < alpha_count; i++ )
916
    {
917
        double G_i = G[i];
918
        if( y[i] > 0 )
919
        {
920
            if( is_lower_bound(i) )
921
                ub1 = MIN( ub1, G_i );
922
            else if( is_upper_bound(i) )
923
                lb1 = MAX( lb1, G_i );
924
            else
925
            {
926
                ++nr_free1;
927
                sum_free1 += G_i;
928
            }
929
        }
930
        else
931
        {
932
            if( is_lower_bound(i) )
933
                ub2 = MIN( ub2, G_i );
934
            else if( is_upper_bound(i) )
935
                lb2 = MAX( lb2, G_i );
936
            else
937
            {
938
                ++nr_free2;
939
                sum_free2 += G_i;
940
            }
941
        }
942
    }
943
944
    r1 = nr_free1 > 0 ? sum_free1/nr_free1 : (ub1 + lb1)*0.5;
945
    r2 = nr_free2 > 0 ? sum_free2/nr_free2 : (ub2 + lb2)*0.5;
946
947
    rho = (r1 - r2)*0.5;
948
    r = (r1 + r2)*0.5;
949
}
950
951
952
/*
953
///////////////////////// construct and solve various formulations ///////////////////////
954
*/
955
956
bool CvSVMSolver::solve_c_svc( int _sample_count, int _var_count, const float** _samples, schar* _y,
957
                               double _Cp, double _Cn, CvMemStorage* _storage,
958
                               CvSVMKernel* _kernel, double* _alpha, CvSVMSolutionInfo& _si )
959
{
960
    int i;
961
962
    if( !create( _sample_count, _var_count, _samples, _y, _sample_count,
963
                 _alpha, _Cp, _Cn, _storage, _kernel, &CvSVMSolver::get_row_svc,
964
                 &CvSVMSolver::select_working_set, &CvSVMSolver::calc_rho ))
965
        return false;
966
967
    for( i = 0; i < sample_count; i++ )
968
    {
969
        alpha[i] = 0;
970
        b[i] = -1;
971
    }
972
973
    if( !solve_generic( _si ))
974
        return false;
975
976
    for( i = 0; i < sample_count; i++ )
977
        alpha[i] *= y[i];
978
979
    return true;
980
}
981
982
983
bool CvSVMSolver::solve_nu_svc( int _sample_count, int _var_count, const float** _samples, schar* _y,
984
                                CvMemStorage* _storage, CvSVMKernel* _kernel,
985
                                double* _alpha, CvSVMSolutionInfo& _si )
986
{
987
    int i;
988
    double sum_pos, sum_neg, inv_r;
989
990
    if( !create( _sample_count, _var_count, _samples, _y, _sample_count,
991
                 _alpha, 1., 1., _storage, _kernel, &CvSVMSolver::get_row_svc,
992
                 &CvSVMSolver::select_working_set_nu_svm, &CvSVMSolver::calc_rho_nu_svm ))
993
        return false;
994
995
    sum_pos = kernel->params->nu * sample_count * 0.5;
996
    sum_neg = kernel->params->nu * sample_count * 0.5;
997
998
    for( i = 0; i < sample_count; i++ )
999
    {
1000
        if( y[i] > 0 )
1001
        {
1002
            alpha[i] = MIN(1.0, sum_pos);
1003
            sum_pos -= alpha[i];
1004
        }
1005
        else
1006
        {
1007
            alpha[i] = MIN(1.0, sum_neg);
1008
            sum_neg -= alpha[i];
1009
        }
1010
        b[i] = 0;
1011
    }
1012
1013
    if( !solve_generic( _si ))
1014
        return false;
1015
1016
    inv_r = 1./_si.r;
1017
1018
    for( i = 0; i < sample_count; i++ )
1019
        alpha[i] *= y[i]*inv_r;
1020
1021
    _si.rho *= inv_r;
1022
    _si.obj *= (inv_r*inv_r);
1023
    _si.upper_bound_p = inv_r;
1024
    _si.upper_bound_n = inv_r;
1025
1026
    return true;
1027
}
1028
1029
1030
// Solves the one-class SVM QP problem (support of a distribution).
// No labels are supplied: an all-+1 label vector is synthesized, the
// first cvRound(nu*l) alphas start saturated at the upper bound, and
// the fractional remainder of nu*l is placed on the boundary element.
bool CvSVMSolver::solve_one_class( int _sample_count, int _var_count, const float** _samples,
                                   CvMemStorage* _storage, CvSVMKernel* _kernel,
                                   double* _alpha, CvSVMSolutionInfo& _si )
{
    int i, n;
    double nu = _kernel->params->nu;

    if( !create( _sample_count, _var_count, _samples, 0, _sample_count,
                 _alpha, 1., 1., _storage, _kernel, &CvSVMSolver::get_row_one_class,
                 &CvSVMSolver::select_working_set, &CvSVMSolver::calc_rho ))
        return false;

    // create() was given no labels, so allocate the synthetic ones here.
    y = (schar*)cvMemStorageAlloc( storage, sample_count*sizeof(y[0]) );
    n = cvRound( nu*sample_count );

    for( i = 0; i < sample_count; i++ )
    {
        y[i] = 1;
        b[i] = 0;
        alpha[i] = i < n ? 1 : 0;  // first n coefficients at the upper bound
    }

    // Assign the leftover fractional part so that sum(alpha) == nu*l.
    // NOTE(review): if sample_count == 0 the else-branch would index
    // alpha[-1]; callers presumably guarantee sample_count > 0 — confirm.
    if( n < sample_count )
        alpha[n] = nu * sample_count - n;
    else
        alpha[n-1] = nu * sample_count - (n-1);

    return solve_generic(_si);
}
1059
1060
1061
// Solves the epsilon-SVR QP problem.  Regression is expressed as a
// problem of doubled size: sample i appears once with label +1
// (alpha_i) and once with label -1 (alpha_i*); the output coefficient
// is their difference.
bool CvSVMSolver::solve_eps_svr( int _sample_count, int _var_count, const float** _samples,
                                 const float* _y, CvMemStorage* _storage,
                                 CvSVMKernel* _kernel, double* _alpha, CvSVMSolutionInfo& _si )
{
    int i;
    double p = _kernel->params->p, kernel_param_c = _kernel->params->C;

    // alpha_count = 2*sample_count; 0 is passed for the alpha pointer
    // because the doubled internal array is allocated below.
    if( !create( _sample_count, _var_count, _samples, 0,
                 _sample_count*2, 0, kernel_param_c, kernel_param_c, _storage, _kernel, &CvSVMSolver::get_row_svr,
                 &CvSVMSolver::select_working_set, &CvSVMSolver::calc_rho ))
        return false;

    y = (schar*)cvMemStorageAlloc( storage, sample_count*2*sizeof(y[0]) );
    alpha = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(alpha[0]) );

    for( i = 0; i < sample_count; i++ )
    {
        // "positive" half: alpha_i with linear term p - y_i
        alpha[i] = 0;
        b[i] = p - _y[i];
        y[i] = 1;

        // "negative" half: alpha_i* with linear term p + y_i
        alpha[i+sample_count] = 0;
        b[i+sample_count] = p + _y[i];
        y[i+sample_count] = -1;
    }

    if( !solve_generic( _si ))
        return false;

    // Final coefficient for sample i is alpha_i - alpha_i*.
    for( i = 0; i < sample_count; i++ )
        _alpha[i] = alpha[i] - alpha[i+sample_count];

    return true;
}
1095
1096
1097
// Solves the nu-SVR QP problem.  Like eps-SVR, the problem is doubled
// (alpha_i / alpha_i*), but instead of an epsilon tube the total
// coefficient mass C*nu*l/2 is fixed per half.
bool CvSVMSolver::solve_nu_svr( int _sample_count, int _var_count, const float** _samples,
                                const float* _y, CvMemStorage* _storage,
                                CvSVMKernel* _kernel, double* _alpha, CvSVMSolutionInfo& _si )
{
    int i;
    double kernel_param_c = _kernel->params->C, sum;

    // NOTE(review): create() receives unit upper bounds while the
    // initialization below clips each alpha at C — confirm the nu-SVM
    // solver path handles the bound consistently.
    if( !create( _sample_count, _var_count, _samples, 0,
                 _sample_count*2, 0, 1., 1., _storage, _kernel, &CvSVMSolver::get_row_svr,
                 &CvSVMSolver::select_working_set_nu_svm, &CvSVMSolver::calc_rho_nu_svm ))
        return false;

    y = (schar*)cvMemStorageAlloc( storage, sample_count*2*sizeof(y[0]) );
    alpha = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(alpha[0]) );
    // Initial feasible point: spread C*nu*l/2 over the alphas,
    // clipping each at C.
    sum = kernel_param_c * _kernel->params->nu * sample_count * 0.5;

    for( i = 0; i < sample_count; i++ )
    {
        alpha[i] = alpha[i + sample_count] = MIN(sum, kernel_param_c);
        sum -= alpha[i];

        b[i] = -_y[i];
        y[i] = 1;

        b[i + sample_count] = _y[i];
        y[i + sample_count] = -1;
    }

    if( !solve_generic( _si ))
        return false;

    // Final coefficient for sample i is alpha_i - alpha_i*.
    for( i = 0; i < sample_count; i++ )
        _alpha[i] = alpha[i] - alpha[i+sample_count];

    return true;
}
1133
1134
1135
//////////////////////////////////////////////////////////////////////////////////////////
1136
1137
// Default constructor: zero all owned pointers, then let clear() put
// the model into the canonical untrained state.
CvSVM::CvSVM()
    : decision_func(0), class_labels(0), class_weights(0),
      storage(0), var_idx(0), kernel(0), solver(0)
{
    default_model_name = "my_svm";
    clear();
}
1150
1151
1152
// Destructor: all owned resources are released by clear().
CvSVM::~CvSVM()
{
    clear();
}
1156
1157
1158
// Releases every heap resource owned by the model and resets it to the
// untrained state.  Safe to call repeatedly: cvFree/cvRelease* null
// their argument, and delete on a null pointer is a no-op.
void CvSVM::clear()
{
    cvFree( &decision_func );
    cvReleaseMat( &class_labels );
    cvReleaseMat( &class_weights );
    cvReleaseMemStorage( &storage );
    cvReleaseMat( &var_idx );
    delete kernel;
    delete solver;
    kernel = 0;
    solver = 0;
    var_all = 0;
    sv = 0;        // sv points into 'storage', released above
    sv_total = 0;
}
1173
1174
1175
// Training constructor: equivalent to default construction followed by
// an immediate call to train() on the supplied data.
CvSVM::CvSVM( const CvMat* _train_data, const CvMat* _responses,
    const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params )
    : decision_func(0), class_labels(0), class_weights(0),
      storage(0), var_idx(0), kernel(0), solver(0)
{
    default_model_name = "my_svm";
    train( _train_data, _responses, _var_idx, _sample_idx, _params );
}
1189
1190
1191
// Returns the number of support vectors of the trained model
// (0 if the model has not been trained yet).
int CvSVM::get_support_vector_count() const
{
    return sv_total;
}
1195
1196
1197
const float* CvSVM::get_support_vector(int i) const
1198
{
1199
    return sv && (unsigned)i < (unsigned)sv_total ? sv[i] : 0;
1200
}
1201
1202
1203
// Validates _params and stores the checked copy in 'params'.
// Parameters irrelevant for the chosen kernel/SVM type are reset to
// fixed values; invalid values raise CV_StsBadArg/CV_StsOutOfRange.
// Returns true on success.
bool CvSVM::set_params( const CvSVMParams& _params )
{
    bool ok = false;

    CV_FUNCNAME( "CvSVM::set_params" );

    __BEGIN__;

    int kernel_type, svm_type;

    params = _params;

    kernel_type = params.kernel_type;
    svm_type = params.svm_type;

    // --- kernel parameters ---
    if( kernel_type != LINEAR && kernel_type != POLY &&
        kernel_type != SIGMOID && kernel_type != RBF )
        CV_ERROR( CV_StsBadArg, "Unknown/unsupported kernel type" );

    if( kernel_type == LINEAR )
        params.gamma = 1;      // gamma is meaningless for a linear kernel
    else if( params.gamma <= 0 )
        CV_ERROR( CV_StsOutOfRange, "gamma parameter of the kernel must be positive" );

    if( kernel_type != SIGMOID && kernel_type != POLY )
        params.coef0 = 0;      // coef0 applies only to sigmoid/poly kernels
    else if( params.coef0 < 0 )
        CV_ERROR( CV_StsOutOfRange, "The kernel parameter <coef0> must be positive or zero" );

    if( kernel_type != POLY )
        params.degree = 0;     // degree applies only to the polynomial kernel
    else if( params.degree <= 0 )
        CV_ERROR( CV_StsOutOfRange, "The kernel parameter <degree> must be positive" );

    // --- SVM-type parameters ---
    if( svm_type != C_SVC && svm_type != NU_SVC &&
        svm_type != ONE_CLASS && svm_type != EPS_SVR &&
        svm_type != NU_SVR )
        CV_ERROR( CV_StsBadArg, "Unknown/unsupported SVM type" );

    if( svm_type == ONE_CLASS || svm_type == NU_SVC )
        params.C = 0;          // C is unused by nu-parameterized classifiers
    else if( params.C <= 0 )
        CV_ERROR( CV_StsOutOfRange, "The parameter C must be positive" );

    if( svm_type == C_SVC || svm_type == EPS_SVR )
        params.nu = 0;         // nu is unused by C-parameterized types
    else if( params.nu <= 0 || params.nu >= 1 )
        CV_ERROR( CV_StsOutOfRange, "The parameter nu must be between 0 and 1" );

    if( svm_type != EPS_SVR )
        params.p = 0;          // the epsilon-tube width only applies to eps-SVR
    else if( params.p <= 0 )
        CV_ERROR( CV_StsOutOfRange, "The parameter p must be positive" );

    if( svm_type != C_SVC )
        params.class_weights = 0;  // per-class C weighting exists only for C-SVC

    // Clamp the termination criteria to usable values.
    params.term_crit = cvCheckTermCriteria( params.term_crit, DBL_EPSILON, INT_MAX );
    params.term_crit.epsilon = MAX( params.term_crit.epsilon, DBL_EPSILON );
    ok = true;

    __END__;

    return ok;
}
1268
1269
1270
1271
// Allocates the kernel object used during training/prediction.
// Called by train()/train_auto() after set_params().
void CvSVM::create_kernel()
{
    kernel = new CvSVMKernel(&params,0);
}
1275
1276
1277
// Allocates the SMO solver; it is released again once training ends
// (see the end of train()).
void CvSVM::create_solver( )
{
    solver = new CvSVMSolver;
}
1281
1282
1283
// switching function
1284
bool CvSVM::train1( int sample_count, int var_count, const float** samples,
1285
                    const void* _responses, double Cp, double Cn,
1286
                    CvMemStorage* _storage, double* alpha, double& rho )
1287
{
1288
    bool ok = false;
1289
1290
    //CV_FUNCNAME( "CvSVM::train1" );
1291
1292
    __BEGIN__;
1293
1294
    CvSVMSolutionInfo si;
1295
    int svm_type = params.svm_type;
1296
1297
    si.rho = 0;
1298
1299
    ok = svm_type == C_SVC ? solver->solve_c_svc( sample_count, var_count, samples, (schar*)_responses,
1300
                                                  Cp, Cn, _storage, kernel, alpha, si ) :
1301
         svm_type == NU_SVC ? solver->solve_nu_svc( sample_count, var_count, samples, (schar*)_responses,
1302
                                                    _storage, kernel, alpha, si ) :
1303
         svm_type == ONE_CLASS ? solver->solve_one_class( sample_count, var_count, samples,
1304
                                                          _storage, kernel, alpha, si ) :
1305
         svm_type == EPS_SVR ? solver->solve_eps_svr( sample_count, var_count, samples, (float*)_responses,
1306
                                                      _storage, kernel, alpha, si ) :
1307
         svm_type == NU_SVR ? solver->solve_nu_svr( sample_count, var_count, samples, (float*)_responses,
1308
                                                    _storage, kernel, alpha, si ) : false;
1309
1310
    rho = si.rho;
1311
1312
    __END__;
1313
1314
    return ok;
1315
}
1316
1317
1318
// Core training routine shared by train() and train_auto().
// For one-class/regression types a single decision function is built;
// for C-SVC/nu-SVC one binary classifier is trained for every pair of
// classes (one-vs-one) and the support vectors shared between pairs
// are deduplicated through 'sv_tab'.  'alpha' is a caller-provided
// scratch buffer; 'temp_storage' holds per-call temporaries while the
// final model (sv, decision_func alphas) lives in 'storage'.
bool CvSVM::do_train( int svm_type, int sample_count, int var_count, const float** samples,
                    const CvMat* responses, CvMemStorage* temp_storage, double* alpha )
{
    bool ok = false;

    CV_FUNCNAME( "CvSVM::do_train" );

    __BEGIN__;

    CvSVMDecisionFunc* df = 0;
    const int sample_size = var_count*sizeof(samples[0][0]);
    int i, j, k;

    cvClearMemStorage( storage );

    if( svm_type == ONE_CLASS || svm_type == EPS_SVR || svm_type == NU_SVR )
    {
        // Single decision function; responses are not used for one-class.
        int sv_count = 0;

        CV_CALL( decision_func = df =
            (CvSVMDecisionFunc*)cvAlloc( sizeof(df[0]) ));

        df->rho = 0;
        if( !train1( sample_count, var_count, samples, svm_type == ONE_CLASS ? 0 :
            responses->data.i, 0, 0, temp_storage, alpha, df->rho ))
            EXIT;

        // Count the nonzero coefficients = support vectors.
        for( i = 0; i < sample_count; i++ )
            sv_count += fabs(alpha[i]) > 0;

        CV_Assert(sv_count != 0);

        sv_total = df->sv_count = sv_count;
        CV_CALL( df->alpha = (double*)cvMemStorageAlloc( storage, sv_count*sizeof(df->alpha[0])) );
        CV_CALL( sv = (float**)cvMemStorageAlloc( storage, sv_count*sizeof(sv[0])));

        // Copy the support vectors and their coefficients into 'storage'.
        for( i = k = 0; i < sample_count; i++ )
        {
            if( fabs(alpha[i]) > 0 )
            {
                CV_CALL( sv[k] = (float*)cvMemStorageAlloc( storage, sample_size ));
                memcpy( sv[k], samples[i], sample_size );
                df->alpha[k++] = alpha[i];
            }
        }
    }
    else
    {
        // Multi-class classification: one-vs-one pairwise training.
        int class_count = class_labels->cols;
        int* sv_tab = 0;
        const float** temp_samples = 0;
        int* class_ranges = 0;
        schar* temp_y = 0;
        assert( svm_type == CvSVM::C_SVC || svm_type == CvSVM::NU_SVC );

        if( svm_type == CvSVM::C_SVC && params.class_weights )
        {
            const CvMat* cw = params.class_weights;

            if( !CV_IS_MAT(cw) || (cw->cols != 1 && cw->rows != 1) ||
                cw->rows + cw->cols - 1 != class_count ||
                (CV_MAT_TYPE(cw->type) != CV_32FC1 && CV_MAT_TYPE(cw->type) != CV_64FC1) )
                CV_ERROR( CV_StsBadArg, "params.class_weights must be 1d floating-point vector "
                    "containing as many elements as the number of classes" );

            // Store class weights pre-multiplied by C: data.db[i] is the
            // effective per-class C used in the pairwise loop below.
            CV_CALL( class_weights = cvCreateMat( cw->rows, cw->cols, CV_64F ));
            CV_CALL( cvConvert( cw, class_weights ));
            CV_CALL( cvScale( class_weights, class_weights, params.C ));
        }

        CV_CALL( decision_func = df = (CvSVMDecisionFunc*)cvAlloc(
            (class_count*(class_count-1)/2)*sizeof(df[0])));

        CV_CALL( sv_tab = (int*)cvMemStorageAlloc( temp_storage, sample_count*sizeof(sv_tab[0]) ));
        memset( sv_tab, 0, sample_count*sizeof(sv_tab[0]) );
        CV_CALL( class_ranges = (int*)cvMemStorageAlloc( temp_storage,
                            (class_count + 1)*sizeof(class_ranges[0])));
        CV_CALL( temp_samples = (const float**)cvMemStorageAlloc( temp_storage,
                            sample_count*sizeof(temp_samples[0])));
        CV_CALL( temp_y = (schar*)cvMemStorageAlloc( temp_storage, sample_count));

        // Group the samples by class; class_ranges[i]..class_ranges[i+1]
        // becomes the index range of class i.
        class_ranges[class_count] = 0;
        cvSortSamplesByClasses( samples, responses, class_ranges, 0 );
        //check that while cross-validation there were the samples from all the classes
        if( class_ranges[class_count] <= 0 )
            CV_ERROR( CV_StsBadArg, "While cross-validation one or more of the classes have "
            "been fell out of the sample. Try to enlarge <CvSVMParams::k_fold>" );

        if( svm_type == NU_SVC )
        {
            // check if nu is feasible
            for(i = 0; i < class_count; i++ )
            {
                int ci = class_ranges[i+1] - class_ranges[i];
                for( j = i+1; j< class_count; j++ )
                {
                    int cj = class_ranges[j+1] - class_ranges[j];
                    if( params.nu*(ci + cj)*0.5 > MIN( ci, cj ) )
                    {
                        // !!!TODO!!! add some diagnostic
                        EXIT; // exit immediately; will release the model and return NULL pointer
                    }
                }
            }
        }

        // train n*(n-1)/2 classifiers
        for( i = 0; i < class_count; i++ )
        {
            for( j = i+1; j < class_count; j++, df++ )
            {
                int si = class_ranges[i], ci = class_ranges[i+1] - si;
                int sj = class_ranges[j], cj = class_ranges[j+1] - sj;
                double Cp = params.C, Cn = Cp;
                int k1 = 0, sv_count = 0;

                // Gather class i samples as +1 ...
                for( k = 0; k < ci; k++ )
                {
                    temp_samples[k] = samples[si + k];
                    temp_y[k] = 1;
                }

                // ... and class j samples as -1.
                for( k = 0; k < cj; k++ )
                {
                    temp_samples[ci + k] = samples[sj + k];
                    temp_y[ci + k] = -1;
                }

                if( class_weights )
                {
                    Cp = class_weights->data.db[i];
                    Cn = class_weights->data.db[j];
                }

                if( !train1( ci + cj, var_count, temp_samples, temp_y,
                             Cp, Cn, temp_storage, alpha, df->rho ))
                    EXIT;

                // Count support vectors of this pairwise classifier.
                for( k = 0; k < ci + cj; k++ )
                    sv_count += fabs(alpha[k]) > 0;

                df->sv_count = sv_count;

                CV_CALL( df->alpha = (double*)cvMemStorageAlloc( temp_storage,
                                                sv_count*sizeof(df->alpha[0])));
                CV_CALL( df->sv_index = (int*)cvMemStorageAlloc( temp_storage,
                                                sv_count*sizeof(df->sv_index[0])));

                // Record which original samples became support vectors;
                // sv_index holds global sample indices for now.
                for( k = 0; k < ci; k++ )
                {
                    if( fabs(alpha[k]) > 0 )
                    {
                        sv_tab[si + k] = 1;
                        df->sv_index[k1] = si + k;
                        df->alpha[k1++] = alpha[k];
                    }
                }

                for( k = 0; k < cj; k++ )
                {
                    if( fabs(alpha[ci + k]) > 0 )
                    {
                        sv_tab[sj + k] = 1;
                        df->sv_index[k1] = sj + k;
                        df->alpha[k1++] = alpha[ci + k];
                    }
                }
            }
        }

        // allocate support vectors and initialize sv_tab
        // (sv_tab[i] becomes 1-based index of sample i in the sv array)
        for( i = 0, k = 0; i < sample_count; i++ )
        {
            if( sv_tab[i] )
                sv_tab[i] = ++k;
        }

        sv_total = k;
        CV_CALL( sv = (float**)cvMemStorageAlloc( storage, sv_total*sizeof(sv[0])));

        for( i = 0, k = 0; i < sample_count; i++ )
        {
            if( sv_tab[i] )
            {
                CV_CALL( sv[k] = (float*)cvMemStorageAlloc( storage, sample_size ));
                memcpy( sv[k], samples[i], sample_size );
                k++;
            }
        }

        df = (CvSVMDecisionFunc*)decision_func;

        // set sv pointers: translate global sample indices into 0-based
        // indices into the deduplicated sv array
        for( i = 0; i < class_count; i++ )
        {
            for( j = i+1; j < class_count; j++, df++ )
            {
                for( k = 0; k < df->sv_count; k++ )
                {
                    df->sv_index[k] = sv_tab[df->sv_index[k]]-1;
                    assert( (unsigned)df->sv_index[k] < (unsigned)sv_total );
                }
            }
        }
    }

    optimize_linear_svm();
    ok = true;

    __END__;

    return ok;
}
1531
1532
1533
// For linear kernels each decision function's weighted sum of support
// vectors can be precomputed into one "virtual" support vector, so
// prediction needs a single dot product per decision function.
void CvSVM::optimize_linear_svm()
{
    // we optimize only linear SVM: compress all the support vectors into one.
    if( params.kernel_type != LINEAR )
        return;

    int class_count = class_labels ? class_labels->cols :
            params.svm_type == CvSVM::ONE_CLASS ? 1 : 0;

    int i, df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
    CvSVMDecisionFunc* df = decision_func;

    // 'i' is inspected after the loop, hence declared outside of it.
    for( i = 0; i < df_count; i++ )
    {
        int sv_count = df[i].sv_count;
        if( sv_count != 1 )
            break;
    }

    // if every decision functions uses a single support vector;
    // it's already compressed. skip it then.
    if( i == df_count )
        return;

    int var_count = get_var_count();
    cv::AutoBuffer<double> vbuf(var_count);
    double* v = vbuf;
    float** new_sv = (float**)cvMemStorageAlloc(storage, df_count*sizeof(new_sv[0]));

    for( i = 0; i < df_count; i++ )
    {
        new_sv[i] = (float*)cvMemStorageAlloc(storage, var_count*sizeof(new_sv[i][0]));
        float* dst = new_sv[i];
        // accumulate in double precision, cast to float once at the end
        memset(v, 0, var_count*sizeof(v[0]));
        int j, k, sv_count = df[i].sv_count;
        for( j = 0; j < sv_count; j++ )
        {
            // sv_index is only present for multi-class models
            const float* src = class_count > 1 && df[i].sv_index ? sv[df[i].sv_index[j]] : sv[j];
            double a = df[i].alpha[j];
            for( k = 0; k < var_count; k++ )
                v[k] += src[k]*a;
        }
        for( k = 0; k < var_count; k++ )
            dst[k] = (float)v[k];
        // The combined vector carries weight 1 by construction.
        df[i].sv_count = 1;
        df[i].alpha[0] = 1.;
        if( class_count > 1 && df[i].sv_index )
            df[i].sv_index[0] = i;
    }

    sv = new_sv;
    sv_total = df_count;
}
1586
1587
1588
// Trains the model: validates parameters, prepares the training data,
// sizes the internal storage, then runs do_train().  The model is
// cleared again on failure.  Returns true on success.
bool CvSVM::train( const CvMat* _train_data, const CvMat* _responses,
    const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params )
{
    bool ok = false;
    CvMat* responses = 0;
    CvMemStorage* temp_storage = 0;
    const float** samples = 0;

    CV_FUNCNAME( "CvSVM::train" );

    __BEGIN__;

    int svm_type, sample_count, var_count, sample_size;
    int block_size = 1 << 16;
    double* alpha;

    clear();
    CV_CALL( set_params( _params ));

    svm_type = _params.svm_type;

    /* Prepare training data and related parameters */
    // classification types use categorical responses, regression uses
    // ordered ones; one-class training takes no responses at all
    CV_CALL( cvPrepareTrainData( "CvSVM::train", _train_data, CV_ROW_SAMPLE,
                                 svm_type != CvSVM::ONE_CLASS ? _responses : 0,
                                 svm_type == CvSVM::C_SVC ||
                                 svm_type == CvSVM::NU_SVC ? CV_VAR_CATEGORICAL :
                                 CV_VAR_ORDERED, _var_idx, _sample_idx,
                                 false, &samples, &sample_count, &var_count, &var_all,
                                 &responses, &class_labels, &var_idx ));


    sample_size = var_count*sizeof(samples[0][0]);

    // make the storage block size large enough to fit all
    // the temporary vectors and output support vectors.
    block_size = MAX( block_size, sample_count*(int)sizeof(CvSVMKernelRow));
    block_size = MAX( block_size, sample_count*2*(int)sizeof(double) + 1024 );
    block_size = MAX( block_size, sample_size*2 + 1024 );

    CV_CALL( storage = cvCreateMemStorage(block_size + sizeof(CvMemBlock) + sizeof(CvSeqBlock)));
    CV_CALL( temp_storage = cvCreateChildMemStorage(storage));
    CV_CALL( alpha = (double*)cvMemStorageAlloc(temp_storage, sample_count*sizeof(double)));

    create_kernel();
    create_solver();

    if( !do_train( svm_type, sample_count, var_count, samples, responses, temp_storage, alpha ))
        EXIT;

    ok = true; // model has been trained successfully

    __END__;

    // the solver is only needed while training; release it together
    // with the temporary buffers (this runs on failure paths too)
    delete solver;
    solver = 0;
    cvReleaseMemStorage( &temp_storage );
    cvReleaseMat( &responses );
    cvFree( &samples );

    if( cvGetErrStatus() < 0 || !ok )
        clear();

    return ok;
}
1652
1653
// Helper record for train_auto's balanced k-fold splitting: for one
// fold it stores the per-class sample counts and the resulting
// smallest-class ratio used to sort/re-balance the folds.
struct indexedratio
{
    double val;                         // fraction of smallest-class samples in the fold
    int ind;                            // fold index
    int count_smallest, count_biggest;  // per-class sample counts inside the fold
    // Recomputes 'val' from the counts.  Guards against an empty fold:
    // the original 0/0 produced NaN, which violates the strict weak
    // ordering required by the qsort comparator (icvCmpIndexedratio).
    void eval()
    {
        int total = count_smallest + count_biggest;
        val = total > 0 ? (double)count_smallest/total : 0.;
    }
};
1660
1661
static int CV_CDECL
1662
icvCmpIndexedratio( const void* a, const void* b )
1663
{
1664
    return ((const indexedratio*)a)->val < ((const indexedratio*)b)->val ? -1
1665
    : ((const indexedratio*)a)->val > ((const indexedratio*)b)->val ? 1
1666
    : 0;
1667
}
1668
1669
bool CvSVM::train_auto( const CvMat* _train_data, const CvMat* _responses,
1670
    const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params, int k_fold,
1671
    CvParamGrid C_grid, CvParamGrid gamma_grid, CvParamGrid p_grid,
1672
    CvParamGrid nu_grid, CvParamGrid coef_grid, CvParamGrid degree_grid,
1673
    bool balanced)
1674
{
1675
    bool ok = false;
1676
    CvMat* responses = 0;
1677
    CvMat* responses_local = 0;
1678
    CvMemStorage* temp_storage = 0;
1679
    const float** samples = 0;
1680
    const float** samples_local = 0;
1681
1682
    CV_FUNCNAME( "CvSVM::train_auto" );
1683
    __BEGIN__;
1684
1685
    int svm_type, sample_count, var_count, sample_size;
1686
    int block_size = 1 << 16;
1687
    double* alpha;
1688
    RNG* rng = &theRNG();
1689
1690
    // all steps are logarithmic and must be > 1
1691
    double degree_step = 10, g_step = 10, coef_step = 10, C_step = 10, nu_step = 10, p_step = 10;
1692
    double gamma = 0, curr_c = 0, degree = 0, coef = 0, p = 0, nu = 0;
1693
    double best_degree = 0, best_gamma = 0, best_coef = 0, best_C = 0, best_nu = 0, best_p = 0;
1694
    float min_error = FLT_MAX, error;
1695
1696
    if( _params.svm_type == CvSVM::ONE_CLASS )
1697
    {
1698
        if(!train( _train_data, _responses, _var_idx, _sample_idx, _params ))
1699
            EXIT;
1700
        return true;
1701
    }
1702
1703
    clear();
1704
1705
    if( k_fold < 2 )
1706
        CV_ERROR( CV_StsBadArg, "Parameter <k_fold> must be > 1" );
1707
1708
    CV_CALL(set_params( _params ));
1709
    svm_type = _params.svm_type;
1710
1711
    // All the parameters except, possibly, <coef0> are positive.
1712
    // <coef0> is nonnegative
1713
    if( C_grid.step <= 1 )
1714
    {
1715
        C_grid.min_val = C_grid.max_val = params.C;
1716
        C_grid.step = 10;
1717
    }
1718
    else
1719
        CV_CALL(C_grid.check());
1720
1721
    if( gamma_grid.step <= 1 )
1722
    {
1723
        gamma_grid.min_val = gamma_grid.max_val = params.gamma;
1724
        gamma_grid.step = 10;
1725
    }
1726
    else
1727
        CV_CALL(gamma_grid.check());
1728
1729
    if( p_grid.step <= 1 )
1730
    {
1731
        p_grid.min_val = p_grid.max_val = params.p;
1732
        p_grid.step = 10;
1733
    }
1734
    else
1735
        CV_CALL(p_grid.check());
1736
1737
    if( nu_grid.step <= 1 )
1738
    {
1739
        nu_grid.min_val = nu_grid.max_val = params.nu;
1740
        nu_grid.step = 10;
1741
    }
1742
    else
1743
        CV_CALL(nu_grid.check());
1744
1745
    if( coef_grid.step <= 1 )
1746
    {
1747
        coef_grid.min_val = coef_grid.max_val = params.coef0;
1748
        coef_grid.step = 10;
1749
    }
1750
    else
1751
        CV_CALL(coef_grid.check());
1752
1753
    if( degree_grid.step <= 1 )
1754
    {
1755
        degree_grid.min_val = degree_grid.max_val = params.degree;
1756
        degree_grid.step = 10;
1757
    }
1758
    else
1759
        CV_CALL(degree_grid.check());
1760
1761
    // these parameters are not used:
1762
    if( params.kernel_type != CvSVM::POLY )
1763
        degree_grid.min_val = degree_grid.max_val = params.degree;
1764
    if( params.kernel_type == CvSVM::LINEAR )
1765
        gamma_grid.min_val = gamma_grid.max_val = params.gamma;
1766
    if( params.kernel_type != CvSVM::POLY && params.kernel_type != CvSVM::SIGMOID )
1767
        coef_grid.min_val = coef_grid.max_val = params.coef0;
1768
    if( svm_type == CvSVM::NU_SVC || svm_type == CvSVM::ONE_CLASS )
1769
        C_grid.min_val = C_grid.max_val = params.C;
1770
    if( svm_type == CvSVM::C_SVC || svm_type == CvSVM::EPS_SVR )
1771
        nu_grid.min_val = nu_grid.max_val = params.nu;
1772
    if( svm_type != CvSVM::EPS_SVR )
1773
        p_grid.min_val = p_grid.max_val = params.p;
1774
1775
    CV_ASSERT( g_step > 1 && degree_step > 1 && coef_step > 1);
1776
    CV_ASSERT( p_step > 1 && C_step > 1 && nu_step > 1 );
1777
1778
    /* Prepare training data and related parameters */
1779
    CV_CALL(cvPrepareTrainData( "CvSVM::train_auto", _train_data, CV_ROW_SAMPLE,
1780
                                 svm_type != CvSVM::ONE_CLASS ? _responses : 0,
1781
                                 svm_type == CvSVM::C_SVC ||
1782
                                 svm_type == CvSVM::NU_SVC ? CV_VAR_CATEGORICAL :
1783
                                 CV_VAR_ORDERED, _var_idx, _sample_idx,
1784
                                 false, &samples, &sample_count, &var_count, &var_all,
1785
                                 &responses, &class_labels, &var_idx ));
1786
1787
    sample_size = var_count*sizeof(samples[0][0]);
1788
1789
    // make the storage block size large enough to fit all
1790
    // the temporary vectors and output support vectors.
1791
    block_size = MAX( block_size, sample_count*(int)sizeof(CvSVMKernelRow));
1792
    block_size = MAX( block_size, sample_count*2*(int)sizeof(double) + 1024 );
1793
    block_size = MAX( block_size, sample_size*2 + 1024 );
1794
1795
    CV_CALL( storage = cvCreateMemStorage(block_size + sizeof(CvMemBlock) + sizeof(CvSeqBlock)));
1796
    CV_CALL(temp_storage = cvCreateChildMemStorage(storage));
1797
    CV_CALL(alpha = (double*)cvMemStorageAlloc(temp_storage, sample_count*sizeof(double)));
1798
1799
    create_kernel();
1800
    create_solver();
1801
1802
    {
1803
    const int testset_size = sample_count/k_fold;
1804
    const int trainset_size = sample_count - testset_size;
1805
    const int last_testset_size = sample_count - testset_size*(k_fold-1);
1806
    const int last_trainset_size = sample_count - last_testset_size;
1807
    const bool is_regression = (svm_type == EPS_SVR) || (svm_type == NU_SVR);
1808
1809
    size_t resp_elem_size = CV_ELEM_SIZE(responses->type);
1810
    size_t size = 2*last_trainset_size*sizeof(samples[0]);
1811
1812
    samples_local = (const float**) cvAlloc( size );
1813
    memset( samples_local, 0, size );
1814
1815
    responses_local = cvCreateMat( 1, trainset_size, CV_MAT_TYPE(responses->type) );
1816
    cvZero( responses_local );
1817
1818
    // randomly permute samples and responses
1819
    for(int i = 0; i < sample_count; i++ )
1820
    {
1821
        int i1 = (*rng)(sample_count);
1822
        int i2 = (*rng)(sample_count);
1823
        const float* temp;
1824
        float t;
1825
        int y;
1826
1827
        CV_SWAP( samples[i1], samples[i2], temp );
1828
        if( is_regression )
1829
            CV_SWAP( responses->data.fl[i1], responses->data.fl[i2], t );
1830
        else
1831
            CV_SWAP( responses->data.i[i1], responses->data.i[i2], y );
1832
    }
1833
1834
    if (!is_regression && class_labels->cols==2 && balanced)
1835
    {
1836
        // count class samples
1837
        int num_0=0,num_1=0;
1838
        for (int i=0; i<sample_count; ++i)
1839
        {
1840
            if (responses->data.i[i]==class_labels->data.i[0])
1841
                ++num_0;
1842
            else
1843
                ++num_1;
1844
        }
1845
1846
        int label_smallest_class;
1847
        int label_biggest_class;
1848
        if (num_0 < num_1)
1849
        {
1850
            label_biggest_class = class_labels->data.i[1];
1851
            label_smallest_class = class_labels->data.i[0];
1852
        }
1853
        else
1854
        {
1855
            label_biggest_class = class_labels->data.i[0];
1856
            label_smallest_class = class_labels->data.i[1];
1857
            int y;
1858
            CV_SWAP(num_0,num_1,y);
1859
        }
1860
        const double class_ratio = (double) num_0/sample_count;
1861
        // calculate class ratio of each fold
1862
        indexedratio *ratios=0;
1863
        ratios = (indexedratio*) cvAlloc(k_fold*sizeof(*ratios));
1864
        for (int k=0, i_begin=0; k<k_fold; ++k, i_begin+=testset_size)
1865
        {
1866
            int count0=0;
1867
            int count1=0;
1868
            int i_end = i_begin + (k<k_fold-1 ? testset_size : last_testset_size);
1869
            for (int i=i_begin; i<i_end; ++i)
1870
            {
1871
                if (responses->data.i[i]==label_smallest_class)
1872
                    ++count0;
1873
                else
1874
                    ++count1;
1875
            }
1876
            ratios[k].ind = k;
1877
            ratios[k].count_smallest = count0;
1878
            ratios[k].count_biggest = count1;
1879
            ratios[k].eval();
1880
        }
1881
        // initial distance
1882
        qsort(ratios, k_fold, sizeof(ratios[0]), icvCmpIndexedratio);
1883
        double old_dist = 0.0;
1884
        for (int k=0; k<k_fold; ++k)
1885
            old_dist += fabs(ratios[k].val-class_ratio); //fabs for double
1886
        double new_dist = 1.0;
1887
        // iterate to make the folds more balanced
1888
        while (new_dist > 0.0)
1889
        {
1890
            if (ratios[0].count_biggest==0 || ratios[k_fold-1].count_smallest==0)
1891
                break; // we are not able to swap samples anymore
1892
                
1893
            qsort(ratios, k_fold, sizeof(ratios[0]), icvCmpIndexedratio); //reorder for swap to balanced
1894
                
1895
            // what if we swap the samples, calculate the new distance
1896
            ratios[0].count_smallest++;
1897
            ratios[0].count_biggest--;
1898
            ratios[0].eval();
1899
            ratios[k_fold-1].count_smallest--;
1900
            ratios[k_fold-1].count_biggest++;
1901
            ratios[k_fold-1].eval();
1902
            
1903
            new_dist = 0.0;
1904
            for (int k=0; k<k_fold; ++k)
1905
                new_dist += fabs(ratios[k].val-class_ratio);
1906
                
1907
            if (new_dist < old_dist)
1908
            {
1909
                 //Swapping ratios[0].ind <-> ratios[k_fold-1].ind
1910
                // swapping really improves, so swap the samples
1911
                // index of the biggest_class sample from the minimum ratio fold
1912
                int i1 = ratios[0].ind * testset_size;
1913
                for ( ; i1<sample_count; ++i1)
1914
                {
1915
                    if (responses->data.i[i1]==label_biggest_class)
1916
                        break;
1917
                }
1918
                // index of the smallest_class sample from the maximum ratio fold
1919
                int i2 = ratios[k_fold-1].ind * testset_size;
1920
                for ( ; i2<sample_count; ++i2)
1921
                {
1922
                    if (responses->data.i[i2]==label_smallest_class)
1923
                        break;
1924
                }
1925
                // swap
1926
                const float* temp;
1927
                int y;
1928
                CV_SWAP( samples[i1], samples[i2], temp );
1929
                CV_SWAP( responses->data.i[i1], responses->data.i[i2], y );
1930
                old_dist = new_dist;
1931
            }
1932
            else
1933
                break; // does not improve, so break the loop
1934
        }
1935
        cvFree(&ratios);
1936
    }
1937
1938
    int* cls_lbls = class_labels ? class_labels->data.i : 0;
1939
    curr_c = C_grid.min_val;
1940
    do
1941
    {
1942
      params.C = curr_c;
1943
      gamma = gamma_grid.min_val;
1944
      do
1945
      {
1946
        params.gamma = gamma;
1947
        p = p_grid.min_val;
1948
        do
1949
        {
1950
          params.p = p;
1951
          nu = nu_grid.min_val;
1952
          do
1953
          {
1954
            params.nu = nu;
1955
            coef = coef_grid.min_val;
1956
            do
1957
            {
1958
              params.coef0 = coef;
1959
              degree = degree_grid.min_val;
1960
              do
1961
              {
1962
                params.degree = degree;
1963
1964
                float** test_samples_ptr = (float**)samples;
1965
                uchar* true_resp = responses->data.ptr;
1966
                int test_size = testset_size;
1967
                int train_size = trainset_size;
1968
1969
                error = 0;
1970
                for(int k = 0; k < k_fold; k++ )
1971
                {
1972
                    memcpy( samples_local, samples, sizeof(samples[0])*test_size*k );
1973
                    memcpy( samples_local + test_size*k, test_samples_ptr + test_size,
1974
                        sizeof(samples[0])*(sample_count - testset_size*(k+1)) );
1975
1976
                    memcpy( responses_local->data.ptr, responses->data.ptr, resp_elem_size*test_size*k );
1977
                    memcpy( responses_local->data.ptr + resp_elem_size*test_size*k,
1978
                        true_resp + resp_elem_size*test_size,
1979
                        resp_elem_size*(sample_count - testset_size*(k+1)) );
1980
1981
                    if( k == k_fold - 1 )
1982
                    {
1983
                        test_size = last_testset_size;
1984
                        train_size = last_trainset_size;
1985
                        responses_local->cols = last_trainset_size;
1986
                    }
1987
1988
                    // Train SVM on <train_size> samples
1989
                    if( !do_train( svm_type, train_size, var_count,
1990
                        (const float**)samples_local, responses_local, temp_storage, alpha ) )
1991
                        EXIT;
1992
1993
                    // Compute test set error on <test_size> samples
1994
                    for(int i = 0; i < test_size; i++, true_resp += resp_elem_size, test_samples_ptr++ )
1995
                    {
1996
                        float resp = predict( *test_samples_ptr, var_count );
1997
                        error += is_regression ? powf( resp - *(float*)true_resp, 2 )
1998
                            : ((int)resp != cls_lbls[*(int*)true_resp]);
1999
                    }
2000
                }
2001
                if( min_error > error )
2002
                {
2003
                    min_error   = error;
2004
                    best_degree = degree;
2005
                    best_gamma  = gamma;
2006
                    best_coef   = coef;
2007
                    best_C      = curr_c;
2008
                    best_nu     = nu;
2009
                    best_p      = p;
2010
                }
2011
                degree *= degree_grid.step;
2012
              }
2013
              while( degree < degree_grid.max_val );
2014
              coef *= coef_grid.step;
2015
            }
2016
            while( coef < coef_grid.max_val );
2017
            nu *= nu_grid.step;
2018
          }
2019
          while( nu < nu_grid.max_val );
2020
          p *= p_grid.step;
2021
        }
2022
        while( p < p_grid.max_val );
2023
        gamma *= gamma_grid.step;
2024
      }
2025
      while( gamma < gamma_grid.max_val );
2026
      curr_c *= C_grid.step;
2027
    }
2028
    while( curr_c < C_grid.max_val );
2029
    }
2030
2031
    min_error /= (float) sample_count;
2032
2033
    params.C      = best_C;
2034
    params.nu     = best_nu;
2035
    params.p      = best_p;
2036
    params.gamma  = best_gamma;
2037
    params.degree = best_degree;
2038
    params.coef0  = best_coef;
2039
2040
    CV_CALL(ok = do_train( svm_type, sample_count, var_count, samples, responses, temp_storage, alpha ));
2041
2042
    __END__;
2043
2044
    delete solver;
2045
    solver = 0;
2046
    cvReleaseMemStorage( &temp_storage );
2047
    cvReleaseMat( &responses );
2048
    cvReleaseMat( &responses_local );
2049
    cvFree( &samples );
2050
    cvFree( &samples_local );
2051
2052
    if( cvGetErrStatus() < 0 || !ok )
2053
        clear();
2054
2055
    return ok;
2056
}
2057
2058
float CvSVM::predict( const float* row_sample, int row_len, bool returnDFVal ) const
2059
{
2060
    assert( kernel );
2061
    assert( row_sample );
2062
2063
    int var_count = get_var_count();
2064
    assert( row_len == var_count );
2065
    (void)row_len;
2066
2067
    int class_count = class_labels ? class_labels->cols :
2068
                  params.svm_type == ONE_CLASS ? 1 : 0;
2069
2070
    float result = 0;
2071
    cv::AutoBuffer<float> _buffer(sv_total + (class_count+1)*2);
2072
    float* buffer = _buffer;
2073
2074
    if( params.svm_type == EPS_SVR ||
2075
        params.svm_type == NU_SVR ||
2076
        params.svm_type == ONE_CLASS )
2077
    {
2078
        CvSVMDecisionFunc* df = (CvSVMDecisionFunc*)decision_func;
2079
        int i, sv_count = df->sv_count;
2080
        double sum = -df->rho;
2081
2082
        kernel->calc( sv_count, var_count, (const float**)sv, row_sample, buffer );
2083
        for( i = 0; i < sv_count; i++ )
2084
            sum += buffer[i]*df->alpha[i];
2085
2086
        result = params.svm_type == ONE_CLASS ? (float)(sum > 0) : (float)sum;
2087
    }
2088
    else if( params.svm_type == C_SVC ||
2089
             params.svm_type == NU_SVC )
2090
    {
2091
        CvSVMDecisionFunc* df = (CvSVMDecisionFunc*)decision_func;
2092
        int* vote = (int*)(buffer + sv_total);
2093
        int i, j, k;
2094
2095
        memset( vote, 0, class_count*sizeof(vote[0]));
2096
        kernel->calc( sv_total, var_count, (const float**)sv, row_sample, buffer );
2097
        double sum = 0.;
2098
2099
        for( i = 0; i < class_count; i++ )
2100
        {
2101
            for( j = i+1; j < class_count; j++, df++ )
2102
            {
2103
                sum = -df->rho;
2104
                int sv_count = df->sv_count;
2105
                for( k = 0; k < sv_count; k++ )
2106
                    sum += df->alpha[k]*buffer[df->sv_index[k]];
2107
2108
                vote[sum > 0 ? i : j]++;
2109
            }
2110
        }
2111
2112
        for( i = 1, k = 0; i < class_count; i++ )
2113
        {
2114
            if( vote[i] > vote[k] )
2115
                k = i;
2116
        }
2117
        result = returnDFVal && class_count == 2 ? (float)sum : (float)(class_labels->data.i[k]);
2118
    }
2119
    else
2120
        CV_Error( CV_StsBadArg, "INTERNAL ERROR: Unknown SVM type, "
2121
                                "the SVM structure is probably corrupted" );
2122
2123
    return result;
2124
}
2125
2126
float CvSVM::predict( const CvMat* sample, bool returnDFVal ) const
2127
{
2128
    float result = 0;
2129
    float* row_sample = 0;
2130
2131
    CV_FUNCNAME( "CvSVM::predict" );
2132
2133
    __BEGIN__;
2134
2135
    int class_count;
2136
2137
    if( !kernel )
2138
        CV_ERROR( CV_StsBadArg, "The SVM should be trained first" );
2139
2140
    class_count = class_labels ? class_labels->cols :
2141
                  params.svm_type == ONE_CLASS ? 1 : 0;
2142
2143
    CV_CALL( cvPreparePredictData( sample, var_all, var_idx,
2144
                                   class_count, 0, &row_sample ));
2145
    result = predict( row_sample, get_var_count(), returnDFVal );
2146
2147
    __END__;
2148
2149
    if( sample && (!CV_IS_MAT(sample) || sample->data.fl != row_sample) )
2150
        cvFree( &row_sample );
2151
2152
    return result;
2153
}
2154
2155
struct predict_body_svm : ParallelLoopBody {
2156
    predict_body_svm(const CvSVM* _pointer, float* _result, const CvMat* _samples, CvMat* _results)
2157
    {
2158
        pointer = _pointer;
2159
        result = _result;
2160
        samples = _samples;
2161
        results = _results;
2162
    }
2163
2164
    const CvSVM* pointer;
2165
    float* result;
2166
    const CvMat* samples;
2167
    CvMat* results;
2168
2169
    void operator()( const cv::Range& range ) const
2170
    {
2171
        for(int i = range.start; i < range.end; i++ )
2172
        {
2173
            CvMat sample;
2174
            cvGetRow( samples, &sample, i );
2175
            int r = (int)pointer->predict(&sample);
2176
            if (results)
2177
                results->data.fl[i] = (float)r;
2178
            if (i == 0)
2179
                *result = (float)r;
2180
    }
2181
    }
2182
};
2183
2184
float CvSVM::predict(const CvMat* samples, CV_OUT CvMat* results) const
2185
{
2186
    float result = 0;
2187
    cv::parallel_for_(cv::Range(0, samples->rows),
2188
             predict_body_svm(this, &result, samples, results)
2189
    );
2190
    return result;
2191
}
2192
2193
void CvSVM::predict( cv::InputArray _samples, cv::OutputArray _results ) const
{
    // cv::Mat front-end: allocate one float result per input row and
    // forward to the CvMat-based batch predictor.
    _results.create(_samples.size().height, 1, CV_32F);
    CvMat samples = _samples.getMat();
    CvMat results = _results.getMat();
    predict(&samples, &results);
}
2199
2200
// Convenience constructor: zero-initializes the model state and
// immediately trains on the given data.
CvSVM::CvSVM( const Mat& _train_data, const Mat& _responses,
              const Mat& _var_idx, const Mat& _sample_idx, CvSVMParams _params )
{
    // Same zero-state initialization performed by the default constructor.
    decision_func = 0;
    class_labels = 0;
    class_weights = 0;
    var_idx = 0;
    storage = 0;
    kernel = 0;
    solver = 0;
    default_model_name = "my_svm";

    train( _train_data, _responses, _var_idx, _sample_idx, _params );
}
2214
2215
bool CvSVM::train( const Mat& _train_data, const Mat& _responses,
2216
                  const Mat& _var_idx, const Mat& _sample_idx, CvSVMParams _params )
2217
{
2218
    CvMat tdata = _train_data, responses = _responses, vidx = _var_idx, sidx = _sample_idx;
2219
    return train(&tdata, &responses, vidx.data.ptr ? &vidx : 0, sidx.data.ptr ? &sidx : 0, _params);
2220
}
2221
2222
2223
bool CvSVM::train_auto( const Mat& _train_data, const Mat& _responses,
2224
                       const Mat& _var_idx, const Mat& _sample_idx, CvSVMParams _params, int k_fold,
2225
                       CvParamGrid C_grid, CvParamGrid gamma_grid, CvParamGrid p_grid,
2226
                       CvParamGrid nu_grid, CvParamGrid coef_grid, CvParamGrid degree_grid, bool balanced )
2227
{
2228
    CvMat tdata = _train_data, responses = _responses, vidx = _var_idx, sidx = _sample_idx;
2229
    return train_auto(&tdata, &responses, vidx.data.ptr ? &vidx : 0,
2230
                      sidx.data.ptr ? &sidx : 0, _params, k_fold, C_grid, gamma_grid, p_grid,
2231
                      nu_grid, coef_grid, degree_grid, balanced);
2232
}
2233
2234
float CvSVM::predict( const Mat& _sample, bool returnDFVal ) const
2235
{
2236
    CvMat sample = _sample;
2237
    return predict(&sample, returnDFVal);
2238
}
2239
2240
2241
// Serialize the training parameters. Known enum values are written as
// symbolic strings; unknown values fall back to raw integers, in which
// case every numeric parameter is written unconditionally.
void CvSVM::write_params( CvFileStorage* fs ) const
{
    //CV_FUNCNAME( "CvSVM::write_params" );

    __BEGIN__;

    int svm_type = params.svm_type;
    int kernel_type = params.kernel_type;

    const char* svm_type_name =
        svm_type == CvSVM::C_SVC ? "C_SVC" :
        svm_type == CvSVM::NU_SVC ? "NU_SVC" :
        svm_type == CvSVM::ONE_CLASS ? "ONE_CLASS" :
        svm_type == CvSVM::EPS_SVR ? "EPS_SVR" :
        svm_type == CvSVM::NU_SVR ? "NU_SVR" : 0;

    const char* kernel_type_name =
        kernel_type == CvSVM::LINEAR ? "LINEAR" :
        kernel_type == CvSVM::POLY ? "POLY" :
        kernel_type == CvSVM::RBF ? "RBF" :
        kernel_type == CvSVM::SIGMOID ? "SIGMOID" : 0;

    if( svm_type_name )
        cvWriteString( fs, "svm_type", svm_type_name );
    else
        cvWriteInt( fs, "svm_type", svm_type );

    // Kernel description: type plus only the parameters that kernel uses.
    cvStartWriteStruct( fs, "kernel", CV_NODE_MAP + CV_NODE_FLOW );

    if( kernel_type_name )
        cvWriteString( fs, "type", kernel_type_name );
    else
        cvWriteInt( fs, "type", kernel_type );

    if( kernel_type == CvSVM::POLY || !kernel_type_name )
        cvWriteReal( fs, "degree", params.degree );

    if( kernel_type != CvSVM::LINEAR || !kernel_type_name )
        cvWriteReal( fs, "gamma", params.gamma );

    if( kernel_type == CvSVM::POLY || kernel_type == CvSVM::SIGMOID || !kernel_type_name )
        cvWriteReal( fs, "coef0", params.coef0 );

    cvEndWriteStruct(fs);

    // SVM-type-specific regularization parameters.
    if( svm_type == CvSVM::C_SVC || svm_type == CvSVM::EPS_SVR ||
        svm_type == CvSVM::NU_SVR || !svm_type_name )
        cvWriteReal( fs, "C", params.C );

    if( svm_type == CvSVM::NU_SVC || svm_type == CvSVM::ONE_CLASS ||
        svm_type == CvSVM::NU_SVR || !svm_type_name )
        cvWriteReal( fs, "nu", params.nu );

    if( svm_type == CvSVM::EPS_SVR || !svm_type_name )
        cvWriteReal( fs, "p", params.p );

    // Termination criteria: only the active components are stored.
    cvStartWriteStruct( fs, "term_criteria", CV_NODE_MAP + CV_NODE_FLOW );
    if( params.term_crit.type & CV_TERMCRIT_EPS )
        cvWriteReal( fs, "epsilon", params.term_crit.epsilon );
    if( params.term_crit.type & CV_TERMCRIT_ITER )
        cvWriteInt( fs, "iterations", params.term_crit.max_iter );
    cvEndWriteStruct( fs );

    __END__;
}
2306
2307
2308
// Sanity check for (de)serialized model dimensions: a usable model needs at
// least one support vector, a positive feature count no larger than the
// total variable count, and a non-negative class count (0 is allowed for
// regression / one-class models).
static bool isSvmModelApplicable(int sv_total, int var_all, int var_count, int class_count)
{
    if( sv_total <= 0 || var_count <= 0 )
        return false;
    if( var_count > var_all )
        return false;
    return class_count >= 0;
}
2312
2313
2314
// Serialize the whole trained model: parameters, class metadata, the joint
// pool of support vectors, and one decision function per class pair.
void CvSVM::write( CvFileStorage* fs, const char* name ) const
{
    CV_FUNCNAME( "CvSVM::write" );

    __BEGIN__;

    int i, df_count;
    int var_count = get_var_count();
    int class_count = class_labels ? class_labels->cols :
                      params.svm_type == CvSVM::ONE_CLASS ? 1 : 0;
    const CvSVMDecisionFunc* df = decision_func;

    // Refuse to serialize an empty or internally inconsistent model.
    if( !isSvmModelApplicable(sv_total, var_all, var_count, class_count) )
        CV_ERROR( CV_StsParseError, "SVM model data is invalid, check sv_count, var_* and class_count tags" );

    cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_SVM );

    write_params( fs );

    cvWriteInt( fs, "var_all", var_all );
    cvWriteInt( fs, "var_count", var_count );

    if( class_count )
    {
        cvWriteInt( fs, "class_count", class_count );

        if( class_labels )
            cvWrite( fs, "class_labels", class_labels );

        if( class_weights )
            cvWrite( fs, "class_weights", class_weights );
    }

    if( var_idx )
        cvWrite( fs, "var_idx", var_idx );

    // Joint collection of support vectors, one flow sequence per vector.
    cvWriteInt( fs, "sv_total", sv_total );
    cvStartWriteStruct( fs, "support_vectors", CV_NODE_SEQ );
    for( i = 0; i < sv_total; i++ )
    {
        cvStartWriteStruct( fs, 0, CV_NODE_SEQ + CV_NODE_FLOW );
        cvWriteRawData( fs, sv[i], var_count, "f" );
        cvEndWriteStruct( fs );
    }
    cvEndWriteStruct( fs );

    // Decision functions: class_count*(class_count-1)/2 of them for
    // multi-class models, a single one otherwise.
    df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
    df = decision_func;

    cvStartWriteStruct( fs, "decision_functions", CV_NODE_SEQ );
    for( i = 0; i < df_count; i++ )
    {
        int sv_count = df[i].sv_count;
        cvStartWriteStruct( fs, 0, CV_NODE_MAP );
        cvWriteInt( fs, "sv_count", sv_count );
        cvWriteReal( fs, "rho", df[i].rho );

        cvStartWriteStruct( fs, "alpha", CV_NODE_SEQ+CV_NODE_FLOW );
        cvWriteRawData( fs, df[i].alpha, df[i].sv_count, "d" );
        cvEndWriteStruct( fs );

        if( class_count > 1 )
        {
            // Indices into the shared support-vector pool.
            cvStartWriteStruct( fs, "index", CV_NODE_SEQ+CV_NODE_FLOW );
            cvWriteRawData( fs, df[i].sv_index, df[i].sv_count, "i" );
            cvEndWriteStruct( fs );
        }
        else
            // Single-function models use the whole pool directly.
            CV_ASSERT( sv_count == sv_total );

        cvEndWriteStruct( fs );
    }
    cvEndWriteStruct( fs );
    cvEndWriteStruct( fs );

    __END__;
}
2389
2390
2391
// Deserialize the training parameters written by write_params().
// Accepts both symbolic (string) and raw integer encodings for the SVM
// and kernel types. Throws CV_StsParseError on missing/invalid tags.
// Fix: the two error messages previously read "Missing of invalid ..." —
// corrected to "Missing or invalid ...".
void CvSVM::read_params( CvFileStorage* fs, CvFileNode* svm_node )
{
    CV_FUNCNAME( "CvSVM::read_params" );

    __BEGIN__;

    int svm_type, kernel_type;
    CvSVMParams _params;

    CvFileNode* tmp_node = cvGetFileNodeByName( fs, svm_node, "svm_type" );
    CvFileNode* kernel_node;
    if( !tmp_node )
        CV_ERROR( CV_StsBadArg, "svm_type tag is not found" );

    if( CV_NODE_TYPE(tmp_node->tag) == CV_NODE_INT )
        svm_type = cvReadInt( tmp_node, -1 );
    else
    {
        // Symbolic encoding: map the type name back to its enum value.
        const char* svm_type_str = cvReadString( tmp_node, "" );
        svm_type =
            strcmp( svm_type_str, "C_SVC" ) == 0 ? CvSVM::C_SVC :
            strcmp( svm_type_str, "NU_SVC" ) == 0 ? CvSVM::NU_SVC :
            strcmp( svm_type_str, "ONE_CLASS" ) == 0 ? CvSVM::ONE_CLASS :
            strcmp( svm_type_str, "EPS_SVR" ) == 0 ? CvSVM::EPS_SVR :
            strcmp( svm_type_str, "NU_SVR" ) == 0 ? CvSVM::NU_SVR : -1;

        if( svm_type < 0 )
            CV_ERROR( CV_StsParseError, "Missing or invalid SVM type" );
    }

    kernel_node = cvGetFileNodeByName( fs, svm_node, "kernel" );
    if( !kernel_node )
        CV_ERROR( CV_StsParseError, "SVM kernel tag is not found" );

    tmp_node = cvGetFileNodeByName( fs, kernel_node, "type" );
    if( !tmp_node )
        CV_ERROR( CV_StsParseError, "SVM kernel type tag is not found" );

    if( CV_NODE_TYPE(tmp_node->tag) == CV_NODE_INT )
        kernel_type = cvReadInt( tmp_node, -1 );
    else
    {
        const char* kernel_type_str = cvReadString( tmp_node, "" );
        kernel_type =
            strcmp( kernel_type_str, "LINEAR" ) == 0 ? CvSVM::LINEAR :
            strcmp( kernel_type_str, "POLY" ) == 0 ? CvSVM::POLY :
            strcmp( kernel_type_str, "RBF" ) == 0 ? CvSVM::RBF :
            strcmp( kernel_type_str, "SIGMOID" ) == 0 ? CvSVM::SIGMOID : -1;

        if( kernel_type < 0 )
            CV_ERROR( CV_StsParseError, "Missing or invalid SVM kernel type" );
    }

    _params.svm_type = svm_type;
    _params.kernel_type = kernel_type;
    // Kernel parameters default to 0 when absent (matches write_params,
    // which omits parameters the kernel does not use).
    _params.degree = cvReadRealByName( fs, kernel_node, "degree", 0 );
    _params.gamma = cvReadRealByName( fs, kernel_node, "gamma", 0 );
    _params.coef0 = cvReadRealByName( fs, kernel_node, "coef0", 0 );

    _params.C = cvReadRealByName( fs, svm_node, "C", 0 );
    _params.nu = cvReadRealByName( fs, svm_node, "nu", 0 );
    _params.p = cvReadRealByName( fs, svm_node, "p", 0 );
    // Class weights are not persisted; always reset.
    _params.class_weights = 0;

    tmp_node = cvGetFileNodeByName( fs, svm_node, "term_criteria" );
    if( tmp_node )
    {
        // Reconstruct the criteria type from which components were stored.
        _params.term_crit.epsilon = cvReadRealByName( fs, tmp_node, "epsilon", -1. );
        _params.term_crit.max_iter = cvReadIntByName( fs, tmp_node, "iterations", -1 );
        _params.term_crit.type = (_params.term_crit.epsilon >= 0 ? CV_TERMCRIT_EPS : 0) +
                               (_params.term_crit.max_iter >= 0 ? CV_TERMCRIT_ITER : 0);
    }
    else
        _params.term_crit = cvTermCriteria( CV_TERMCRIT_EPS + CV_TERMCRIT_ITER, 1000, FLT_EPSILON );

    set_params( _params );

    __END__;
}
2470
2471
// Deserialize a complete model previously written by write(): parameters,
// class metadata, the support-vector pool, and the decision functions.
// Throws CV_StsParseError when any mandatory tag is missing or malformed.
void CvSVM::read( CvFileStorage* fs, CvFileNode* svm_node )
{
    // Sentinel used to detect an absent "rho" tag.
    const double not_found_dbl = DBL_MAX;

    CV_FUNCNAME( "CvSVM::read" );

    __BEGIN__;

    int i, var_count, df_count, class_count;
    int block_size = 1 << 16, sv_size;
    CvFileNode *sv_node, *df_node;
    CvSVMDecisionFunc* df;
    CvSeqReader seq_reader;

    if( !svm_node )
        CV_ERROR( CV_StsParseError, "The requested element is not found" );

    // Drop any previously loaded/trained model before reading.
    clear();

    read_params( fs, svm_node );

    // Top-level dimensions.
    sv_total = cvReadIntByName( fs, svm_node, "sv_total", -1 );
    var_all = cvReadIntByName( fs, svm_node, "var_all", -1 );
    var_count = cvReadIntByName( fs, svm_node, "var_count", var_all );
    class_count = cvReadIntByName( fs, svm_node, "class_count", 0 );

    if( !isSvmModelApplicable(sv_total, var_all, var_count, class_count) )
        CV_ERROR( CV_StsParseError, "SVM model data is invalid, check sv_count, var_* and class_count tags" );

    CV_CALL( class_labels = (CvMat*)cvReadByName( fs, svm_node, "class_labels" ));
    CV_CALL( class_weights = (CvMat*)cvReadByName( fs, svm_node, "class_weights" ));
    CV_CALL( var_idx = (CvMat*)cvReadByName( fs, svm_node, "var_idx" ));

    // Cross-validate the optional matrices against the declared dimensions.
    if( class_count > 1 && (!class_labels ||
        !CV_IS_MAT(class_labels) || class_labels->cols != class_count))
        CV_ERROR( CV_StsParseError, "Array of class labels is missing or invalid" );

    if( var_count < var_all && (!var_idx || !CV_IS_MAT(var_idx) || var_idx->cols != var_count) )
        CV_ERROR( CV_StsParseError, "var_idx array is missing or invalid" );

    // --- Support vectors -------------------------------------------------
    sv_node = cvGetFileNodeByName( fs, svm_node, "support_vectors" );
    if( !sv_node || !CV_NODE_IS_SEQ(sv_node->tag))
        CV_ERROR( CV_StsParseError, "Missing or invalid sequence of support vectors" );

    // Size the storage block generously enough for later use by the kernel.
    block_size = MAX( block_size, sv_total*(int)sizeof(CvSVMKernelRow));
    block_size = MAX( block_size, sv_total*2*(int)sizeof(double));
    block_size = MAX( block_size, var_all*(int)sizeof(double));

    CV_CALL( storage = cvCreateMemStorage(block_size + sizeof(CvMemBlock) + sizeof(CvSeqBlock)));
    CV_CALL( sv = (float**)cvMemStorageAlloc( storage,
                                sv_total*sizeof(sv[0]) ));

    CV_CALL( cvStartReadSeq( sv_node->data.seq, &seq_reader, 0 ));
    sv_size = var_count*sizeof(sv[0][0]);

    for( i = 0; i < sv_total; i++ )
    {
        CvFileNode* sv_elem = (CvFileNode*)seq_reader.ptr;
        CV_ASSERT( var_count == 1 || (CV_NODE_IS_SEQ(sv_elem->tag) &&
                   sv_elem->data.seq->total == var_count) );

        CV_CALL( sv[i] = (float*)cvMemStorageAlloc( storage, sv_size ));
        CV_CALL( cvReadRawData( fs, sv_elem, sv[i], "f" ));
        CV_NEXT_SEQ_ELEM( sv_node->data.seq->elem_size, seq_reader );
    }

    // --- Decision functions ---------------------------------------------
    df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
    df_node = cvGetFileNodeByName( fs, svm_node, "decision_functions" );
    if( !df_node || !CV_NODE_IS_SEQ(df_node->tag) ||
        df_node->data.seq->total != df_count )
        CV_ERROR( CV_StsParseError, "decision_functions is missing or is not a collection "
                  "or has a wrong number of elements" );

    CV_CALL( df = decision_func = (CvSVMDecisionFunc*)cvAlloc( df_count*sizeof(df[0]) ));
    cvStartReadSeq( df_node->data.seq, &seq_reader, 0 );

    for( i = 0; i < df_count; i++ )
    {
        CvFileNode* df_elem = (CvFileNode*)seq_reader.ptr;
        CvFileNode* alpha_node = cvGetFileNodeByName( fs, df_elem, "alpha" );

        int sv_count = cvReadIntByName( fs, df_elem, "sv_count", -1 );
        if( sv_count <= 0 )
            CV_ERROR( CV_StsParseError, "sv_count is missing or non-positive" );
        df[i].sv_count = sv_count;

        df[i].rho = cvReadRealByName( fs, df_elem, "rho", not_found_dbl );
        if( fabs(df[i].rho - not_found_dbl) < DBL_EPSILON )
            CV_ERROR( CV_StsParseError, "rho is missing" );

        if( !alpha_node )
            CV_ERROR( CV_StsParseError, "alpha is missing in the decision function" );

        CV_CALL( df[i].alpha = (double*)cvMemStorageAlloc( storage,
                                        sv_count*sizeof(df[i].alpha[0])));
        CV_ASSERT( sv_count == 1 || (CV_NODE_IS_SEQ(alpha_node->tag) &&
                   alpha_node->data.seq->total == sv_count) );
        CV_CALL( cvReadRawData( fs, alpha_node, df[i].alpha, "d" ));

        if( class_count > 1 )
        {
            // Multi-class functions reference SVs via an index list.
            CvFileNode* index_node = cvGetFileNodeByName( fs, df_elem, "index" );
            if( !index_node )
                CV_ERROR( CV_StsParseError, "index is missing in the decision function" );
            CV_CALL( df[i].sv_index = (int*)cvMemStorageAlloc( storage,
                                            sv_count*sizeof(df[i].sv_index[0])));
            CV_ASSERT( sv_count == 1 || (CV_NODE_IS_SEQ(index_node->tag) &&
                   index_node->data.seq->total == sv_count) );
            CV_CALL( cvReadRawData( fs, index_node, df[i].sv_index, "i" ));
        }
        else
            df[i].sv_index = 0;

        CV_NEXT_SEQ_ELEM( df_node->data.seq->elem_size, seq_reader );
    }

    // Collapse linear-kernel SVs into a single compressed vector unless
    // the file explicitly disables it.
    if( cvReadIntByName(fs, svm_node, "optimize_linear", 1) != 0 )
        optimize_linear_svm();
    create_kernel();

    __END__;
}
2597
2598
#if 0
2599
2600
static void*
2601
icvCloneSVM( const void* _src )
2602
{
2603
    CvSVMModel* dst = 0;
2604
2605
    CV_FUNCNAME( "icvCloneSVM" );
2606
2607
    __BEGIN__;
2608
2609
    const CvSVMModel* src = (const CvSVMModel*)_src;
2610
    int var_count, class_count;
2611
    int i, sv_total, df_count;
2612
    int sv_size;
2613
2614
    if( !CV_IS_SVM(src) )
2615
        CV_ERROR( !src ? CV_StsNullPtr : CV_StsBadArg, "Input pointer is NULL or invalid" );
2616
2617
    // 0. create initial CvSVMModel structure
2618
    CV_CALL( dst = icvCreateSVM() );
2619
    dst->params = src->params;
2620
    dst->params.weight_labels = 0;
2621
    dst->params.weights = 0;
2622
2623
    dst->var_all = src->var_all;
2624
    if( src->class_labels )
2625
        dst->class_labels = cvCloneMat( src->class_labels );
2626
    if( src->class_weights )
2627
        dst->class_weights = cvCloneMat( src->class_weights );
2628
    if( src->comp_idx )
2629
        dst->comp_idx = cvCloneMat( src->comp_idx );
2630
2631
    var_count = src->comp_idx ? src->comp_idx->cols : src->var_all;
2632
    class_count = src->class_labels ? src->class_labels->cols :
2633
                  src->params.svm_type == CvSVM::ONE_CLASS ? 1 : 0;
2634
    sv_total = dst->sv_total = src->sv_total;
2635
    CV_CALL( dst->storage = cvCreateMemStorage( src->storage->block_size ));
2636
    CV_CALL( dst->sv = (float**)cvMemStorageAlloc( dst->storage,
2637
                                    sv_total*sizeof(dst->sv[0]) ));
2638
2639
    sv_size = var_count*sizeof(dst->sv[0][0]);
2640
2641
    for( i = 0; i < sv_total; i++ )
2642
    {
2643
        CV_CALL( dst->sv[i] = (float*)cvMemStorageAlloc( dst->storage, sv_size ));
2644
        memcpy( dst->sv[i], src->sv[i], sv_size );
2645
    }
2646
2647
    df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
2648
2649
    CV_CALL( dst->decision_func = cvAlloc( df_count*sizeof(CvSVMDecisionFunc) ));
2650
2651
    for( i = 0; i < df_count; i++ )
2652
    {
2653
        const CvSVMDecisionFunc *sdf =
2654
            (const CvSVMDecisionFunc*)src->decision_func+i;
2655
        CvSVMDecisionFunc *ddf =
2656
            (CvSVMDecisionFunc*)dst->decision_func+i;
2657
        int sv_count = sdf->sv_count;
2658
        ddf->sv_count = sv_count;
2659
        ddf->rho = sdf->rho;
2660
        CV_CALL( ddf->alpha = (double*)cvMemStorageAlloc( dst->storage,
2661
                                        sv_count*sizeof(ddf->alpha[0])));
2662
        memcpy( ddf->alpha, sdf->alpha, sv_count*sizeof(ddf->alpha[0]));
2663
2664
        if( class_count > 1 )
2665
        {
2666
            CV_CALL( ddf->sv_index = (int*)cvMemStorageAlloc( dst->storage,
2667
                                                sv_count*sizeof(ddf->sv_index[0])));
2668
            memcpy( ddf->sv_index, sdf->sv_index, sv_count*sizeof(ddf->sv_index[0]));
2669
        }
2670
        else
2671
            ddf->sv_index = 0;
2672
    }
2673
2674
    __END__;
2675
2676
    if( cvGetErrStatus() < 0 && dst )
2677
        icvReleaseSVM( &dst );
2678
2679
    return dst;
2680
}
2681
2682
static int icvRegisterSVMType()
2683
{
2684
    CvTypeInfo info;
2685
    memset( &info, 0, sizeof(info) );
2686
2687
    info.flags = 0;
2688
    info.header_size = sizeof( info );
2689
    info.is_instance = icvIsSVM;
2690
    info.release = (CvReleaseFunc)icvReleaseSVM;
2691
    info.read = icvReadSVM;
2692
    info.write = icvWriteSVM;
2693
    info.clone = icvCloneSVM;
2694
    info.type_name = CV_TYPE_NAME_ML_SVM;
2695
    cvRegisterType( &info );
2696
2697
    return 1;
2698
}
2699
2700
2701
// Static initializer: forces icvRegisterSVMType() to run at module load time
// so the SVM type is always registered. The stored value itself is unused.
static int svm = icvRegisterSVMType();
2702
2703
/* The function trains SVM model with optimal parameters, obtained by using cross-validation.
2704
The parameters to be estimated should be indicated by setting theirs values to FLT_MAX.
2705
The optimal parameters are saved in <model_params> */
2706
CV_IMPL CvStatModel*
2707
cvTrainSVM_CrossValidation( const CvMat* train_data, int tflag,
2708
            const CvMat* responses,
2709
            CvStatModelParams* model_params,
2710
            const CvStatModelParams* cross_valid_params,
2711
            const CvMat* comp_idx,
2712
            const CvMat* sample_idx,
2713
            const CvParamGrid* degree_grid,
2714
            const CvParamGrid* gamma_grid,
2715
            const CvParamGrid* coef_grid,
2716
            const CvParamGrid* C_grid,
2717
            const CvParamGrid* nu_grid,
2718
            const CvParamGrid* p_grid )
2719
{
2720
    CvStatModel* svm = 0;
2721
2722
    CV_FUNCNAME("cvTainSVMCrossValidation");
2723
    __BEGIN__;
2724
2725
    double degree_step = 7,
2726
           g_step      = 15,
2727
           coef_step   = 14,
2728
           C_step      = 20,
2729
           nu_step     = 5,
2730
           p_step      = 7; // all steps must be > 1
2731
    double degree_begin = 0.01, degree_end = 2;
2732
    double g_begin      = 1e-5, g_end      = 0.5;
2733
    double coef_begin   = 0.1,  coef_end   = 300;
2734
    double C_begin      = 0.1,  C_end      = 6000;
2735
    double nu_begin     = 0.01,  nu_end    = 0.4;
2736
    double p_begin      = 0.01, p_end      = 100;
2737
2738
    double rate = 0, gamma = 0, C = 0, degree = 0, coef = 0, p = 0, nu = 0;
2739
2740
    double best_rate    = 0;
2741
    double best_degree  = degree_begin;
2742
    double best_gamma   = g_begin;
2743
    double best_coef    = coef_begin;
2744
    double best_C       = C_begin;
2745
    double best_nu      = nu_begin;
2746
    double best_p       = p_begin;
2747
2748
    CvSVMModelParams svm_params, *psvm_params;
2749
    CvCrossValidationParams* cv_params = (CvCrossValidationParams*)cross_valid_params;
2750
    int svm_type, kernel;
2751
    int is_regression;
2752
2753
    if( !model_params )
2754
        CV_ERROR( CV_StsBadArg, "" );
2755
    if( !cv_params )
2756
        CV_ERROR( CV_StsBadArg, "" );
2757
2758
    svm_params = *(CvSVMModelParams*)model_params;
2759
    psvm_params = (CvSVMModelParams*)model_params;
2760
    svm_type = svm_params.svm_type;
2761
    kernel = svm_params.kernel_type;
2762
2763
    svm_params.degree = svm_params.degree > 0 ? svm_params.degree : 1;
2764
    svm_params.gamma = svm_params.gamma > 0 ? svm_params.gamma : 1;
2765
    svm_params.coef0 = svm_params.coef0 > 0 ? svm_params.coef0 : 1e-6;
2766
    svm_params.C = svm_params.C > 0 ? svm_params.C : 1;
2767
    svm_params.nu = svm_params.nu > 0 ? svm_params.nu : 1;
2768
    svm_params.p = svm_params.p > 0 ? svm_params.p : 1;
2769
2770
    if( degree_grid )
2771
    {
2772
        if( !(degree_grid->max_val == 0 && degree_grid->min_val == 0 &&
2773
              degree_grid->step == 0) )
2774
        {
2775
            if( degree_grid->min_val > degree_grid->max_val )
2776
                CV_ERROR( CV_StsBadArg,
2777
                "low bound of grid should be less then the upper one");
2778
            if( degree_grid->step <= 1 )
2779
                CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2780
            degree_begin = degree_grid->min_val;
2781
            degree_end   = degree_grid->max_val;
2782
            degree_step  = degree_grid->step;
2783
        }
2784
    }
2785
    else
2786
        degree_begin = degree_end = svm_params.degree;
2787
2788
    if( gamma_grid )
2789
    {
2790
        if( !(gamma_grid->max_val == 0 && gamma_grid->min_val == 0 &&
2791
              gamma_grid->step == 0) )
2792
        {
2793
            if( gamma_grid->min_val > gamma_grid->max_val )
2794
                CV_ERROR( CV_StsBadArg,
2795
                "low bound of grid should be less then the upper one");
2796
            if( gamma_grid->step <= 1 )
2797
                CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2798
            g_begin = gamma_grid->min_val;
2799
            g_end   = gamma_grid->max_val;
2800
            g_step  = gamma_grid->step;
2801
        }
2802
    }
2803
    else
2804
        g_begin = g_end = svm_params.gamma;
2805
2806
    if( coef_grid )
2807
    {
2808
        if( !(coef_grid->max_val == 0 && coef_grid->min_val == 0 &&
2809
              coef_grid->step == 0) )
2810
        {
2811
            if( coef_grid->min_val > coef_grid->max_val )
2812
                CV_ERROR( CV_StsBadArg,
2813
                "low bound of grid should be less then the upper one");
2814
            if( coef_grid->step <= 1 )
2815
                CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2816
            coef_begin = coef_grid->min_val;
2817
            coef_end   = coef_grid->max_val;
2818
            coef_step  = coef_grid->step;
2819
        }
2820
    }
2821
    else
2822
        coef_begin = coef_end = svm_params.coef0;
2823
2824
    if( C_grid )
2825
    {
2826
        if( !(C_grid->max_val == 0 && C_grid->min_val == 0 && C_grid->step == 0))
2827
        {
2828
            if( C_grid->min_val > C_grid->max_val )
2829
                CV_ERROR( CV_StsBadArg,
2830
                "low bound of grid should be less then the upper one");
2831
            if( C_grid->step <= 1 )
2832
                CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2833
            C_begin = C_grid->min_val;
2834
            C_end   = C_grid->max_val;
2835
            C_step  = C_grid->step;
2836
        }
2837
    }
2838
    else
2839
        C_begin = C_end = svm_params.C;
2840
2841
    if( nu_grid )
2842
    {
2843
        if(!(nu_grid->max_val == 0 && nu_grid->min_val == 0 && nu_grid->step==0))
2844
        {
2845
            if( nu_grid->min_val > nu_grid->max_val )
2846
                CV_ERROR( CV_StsBadArg,
2847
                "low bound of grid should be less then the upper one");
2848
            if( nu_grid->step <= 1 )
2849
                CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2850
            nu_begin = nu_grid->min_val;
2851
            nu_end   = nu_grid->max_val;
2852
            nu_step  = nu_grid->step;
2853
        }
2854
    }
2855
    else
2856
        nu_begin = nu_end = svm_params.nu;
2857
2858
    if( p_grid )
2859
    {
2860
        if( !(p_grid->max_val == 0 && p_grid->min_val == 0 && p_grid->step == 0))
2861
        {
2862
            if( p_grid->min_val > p_grid->max_val )
2863
                CV_ERROR( CV_StsBadArg,
2864
                "low bound of grid should be less then the upper one");
2865
            if( p_grid->step <= 1 )
2866
                CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2867
            p_begin = p_grid->min_val;
2868
            p_end   = p_grid->max_val;
2869
            p_step  = p_grid->step;
2870
        }
2871
    }
2872
    else
2873
        p_begin = p_end = svm_params.p;
2874
2875
    // these parameters are not used:
2876
    if( kernel != CvSVM::POLY )
2877
        degree_begin = degree_end = svm_params.degree;
2878
2879
   if( kernel == CvSVM::LINEAR )
2880
        g_begin = g_end = svm_params.gamma;
2881
2882
    if( kernel != CvSVM::POLY && kernel != CvSVM::SIGMOID )
2883
        coef_begin = coef_end = svm_params.coef0;
2884
2885
    if( svm_type == CvSVM::NU_SVC || svm_type == CvSVM::ONE_CLASS )
2886
        C_begin = C_end = svm_params.C;
2887
2888
    if( svm_type == CvSVM::C_SVC || svm_type == CvSVM::EPS_SVR )
2889
        nu_begin = nu_end = svm_params.nu;
2890
2891
    if( svm_type != CvSVM::EPS_SVR )
2892
        p_begin = p_end = svm_params.p;
2893
2894
    is_regression = cv_params->is_regression;
2895
    best_rate = is_regression ? FLT_MAX : 0;
2896
2897
    assert( g_step > 1 && degree_step > 1 && coef_step > 1);
2898
    assert( p_step > 1 && C_step > 1 && nu_step > 1 );
2899
2900
    for( degree = degree_begin; degree <= degree_end; degree *= degree_step )
2901
    {
2902
      svm_params.degree = degree;
2903
      //printf("degree = %.3f\n", degree );
2904
      for( gamma= g_begin; gamma <= g_end; gamma *= g_step )
2905
      {
2906
        svm_params.gamma = gamma;
2907
        //printf("   gamma = %.3f\n", gamma );
2908
        for( coef = coef_begin; coef <= coef_end; coef *= coef_step )
2909
        {
2910
          svm_params.coef0 = coef;
2911
          //printf("      coef = %.3f\n", coef );
2912
          for( C = C_begin; C <= C_end; C *= C_step )
2913
          {
2914
            svm_params.C = C;
2915
            //printf("         C = %.3f\n", C );
2916
            for( nu = nu_begin; nu <= nu_end; nu *= nu_step )
2917
            {
2918
              svm_params.nu = nu;
2919
              //printf("            nu = %.3f\n", nu );
2920
              for( p = p_begin; p <= p_end; p *= p_step )
2921
              {
2922
                int well;
2923
                svm_params.p = p;
2924
                //printf("               p = %.3f\n", p );
2925
2926
                CV_CALL(rate = cvCrossValidation( train_data, tflag, responses, &cvTrainSVM,
2927
                    cross_valid_params, (CvStatModelParams*)&svm_params, comp_idx, sample_idx ));
2928
2929
                well =  rate > best_rate && !is_regression || rate < best_rate && is_regression;
2930
                if( well || (rate == best_rate && C < best_C) )
2931
                {
2932
                    best_rate   = rate;
2933
                    best_degree = degree;
2934
                    best_gamma  = gamma;
2935
                    best_coef   = coef;
2936
                    best_C      = C;
2937
                    best_nu     = nu;
2938
                    best_p      = p;
2939
                }
2940
                //printf("                  rate = %.2f\n", rate );
2941
              }
2942
            }
2943
          }
2944
        }
2945
      }
2946
    }
2947
    //printf("The best:\nrate = %.2f%% degree = %f gamma = %f coef = %f c = %f nu = %f p = %f\n",
2948
      //  best_rate, best_degree, best_gamma, best_coef, best_C, best_nu, best_p );
2949
2950
    psvm_params->C      = best_C;
2951
    psvm_params->nu     = best_nu;
2952
    psvm_params->p      = best_p;
2953
    psvm_params->gamma  = best_gamma;
2954
    psvm_params->degree = best_degree;
2955
    psvm_params->coef0  = best_coef;
2956
2957
    CV_CALL(svm = cvTrainSVM( train_data, tflag, responses, model_params, comp_idx, sample_idx ));
2958
2959
    __END__;
2960
2961
    return svm;
2962
}
2963
2964
#endif
2965
2966
/* End of file. */