// miniflann.cpp
//
// Modified miniflann.cpp - Sebastien Wybo, 2011-08-25 03:05 PM
// (attachment size: 26.8 kB)
#include "precomp.hpp"
// Set to 1 (for example with -DMINIFLANN_SUPPORT_EXOTIC_DISTANCE_TYPES=1) to
// compile in the MaxDistance, HistIntersectionDistance, HellingerDistance,
// ChiSquareDistance and KL_Divergence code paths in addition to L2 and L1.
// Guarded so the build system can override it without editing this file.
#ifndef MINIFLANN_SUPPORT_EXOTIC_DISTANCE_TYPES
#define MINIFLANN_SUPPORT_EXOTIC_DISTANCE_TYPES 0
#endif
// Returns the cvflann parameter map that backs a cv::flann::IndexParams.
// p.params points at the map allocated in the IndexParams constructor; the
// map is returned mutable even though the wrapper is passed as const.
static cvflann::IndexParams& get_params(const cv::flann::IndexParams& p)
{
    cvflann::IndexParams* impl = (cvflann::IndexParams*)(p.params);
    return *impl;
}
namespace cv
11
{
12
    
13
namespace flann
14
{
15
16
// Allocates the empty cvflann parameter map owned by this wrapper; it is
// freed again in ~IndexParams().
IndexParams::IndexParams()
{
    ::cvflann::IndexParams* impl = new ::cvflann::IndexParams();
    params = impl;
}
// Frees the parameter map allocated by the constructor.
// NOTE(review): copying an IndexParams would share 'params' and double-delete
// unless the class declaration forbids/implements copies - confirm in header.
IndexParams::~IndexParams()
{
    ::cvflann::IndexParams* impl = &get_params(*this);
    delete impl;
}
template<typename T>
27
T getParam(const IndexParams& _p, const std::string& key, const T& defaultVal=T())
28
{
29
    ::cvflann::IndexParams& p = get_params(_p);
30
    ::cvflann::IndexParams::const_iterator it = p.find(key);
31
    if( it == p.end() )
32
        return defaultVal;
33
    return it->second.cast<T>();
34
}
35
36
template<typename T>
37
void setParam(IndexParams& _p, const std::string& key, const T& value)
38
{
39
    ::cvflann::IndexParams& p = get_params(_p);
40
    p[key] = value;
41
}    
42
    
43
std::string IndexParams::getString(const std::string& key, const std::string& defaultVal) const
44
{
45
    return getParam(*this, key, defaultVal);
46
}
47
    
48
int IndexParams::getInt(const std::string& key, int defaultVal) const
49
{
50
    return getParam(*this, key, defaultVal);
51
}
52
    
53
double IndexParams::getDouble(const std::string& key, double defaultVal) const
54
{
55
    return getParam(*this, key, defaultVal);
56
}
57
58
    
59
void IndexParams::setString(const std::string& key, const std::string& value)
60
{
61
    setParam(*this, key, value);
62
}
63
64
void IndexParams::setInt(const std::string& key, int value)
65
{
66
    setParam(*this, key, value);
67
}
68
69
void IndexParams::setDouble(const std::string& key, double value)
70
{
71
    setParam(*this, key, value);
72
}
73
74
void IndexParams::setFloat(const std::string& key, float value)
75
{
76
    setParam(*this, key, value);
77
}
78
79
void IndexParams::setBool(const std::string& key, bool value)
80
{
81
    setParam(*this, key, value);
82
}
83
84
void IndexParams::setAlgorithm(int value)
85
{
86
    setParam(*this, "algorithm", (cvflann::flann_algorithm_t)value);
87
}
88
    
89
void IndexParams::getAll(std::vector<std::string>& names,
90
            std::vector<int>& types,
91
            std::vector<std::string>& strValues,
92
            std::vector<double>& numValues) const
93
{
94
    names.clear();
95
    types.clear();
96
    strValues.clear();
97
    numValues.clear();
98
    
99
    ::cvflann::IndexParams& p = get_params(*this);
100
    ::cvflann::IndexParams::const_iterator it = p.begin(), it_end = p.end();
101
    
102
    for( ; it != it_end; ++it )
103
    {
104
        names.push_back(it->first);
105
        try
106
        {
107
            std::string val = it->second.cast<std::string>();
108
            types.push_back(CV_USRTYPE1);
109
            strValues.push_back(val);
110
            numValues.push_back(-1);
111
            continue;
112
        }
113
        catch (...) {}
114
        
115
        strValues.push_back(it->second.type().name());
116
        
117
        try
118
        {
119
            double val = it->second.cast<double>();
120
            types.push_back( CV_64F );
121
            numValues.push_back(val);
122
            continue;
123
        }
124
        catch (...) {}
125
        try
126
        {
127
            float val = it->second.cast<float>();
128
            types.push_back( CV_32F );
129
            numValues.push_back(val);
130
            continue;
131
        }
132
        catch (...) {}
133
        try
134
        {
135
            int val = it->second.cast<int>();
136
            types.push_back( CV_32S );
137
            numValues.push_back(val);
138
            continue;
139
        }
140
        catch (...) {}
141
        try
142
        {
143
            short val = it->second.cast<short>();
144
            types.push_back( CV_16S );
145
            numValues.push_back(val);
146
            continue;
147
        }
148
        catch (...) {}
149
        try
150
        {
151
            ushort val = it->second.cast<ushort>();
152
            types.push_back( CV_16U );
153
            numValues.push_back(val);
154
            continue;
155
        }
156
        catch (...) {}
157
        try
158
        {
159
            char val = it->second.cast<char>();
160
            types.push_back( CV_8S );
161
            numValues.push_back(val);
162
            continue;
163
        }
164
        catch (...) {}
165
        try
166
        {
167
            uchar val = it->second.cast<uchar>();
168
            types.push_back( CV_8U );
169
            numValues.push_back(val);
170
            continue;
171
        }
172
        catch (...) {}
173
        try
174
        {
175
            bool val = it->second.cast<bool>();
176
            types.push_back( CV_MAKETYPE(CV_USRTYPE1,2) );
177
            numValues.push_back(val);
178
            continue;
179
        }
180
        catch (...) {}
181
        try
182
        {
183
            cvflann::flann_algorithm_t val = it->second.cast<cvflann::flann_algorithm_t>();
184
            types.push_back( CV_MAKETYPE(CV_USRTYPE1,3) );
185
            numValues.push_back(val);
186
            continue;
187
        }
188
        catch (...) {}
189
190
191
        types.push_back(-1); // unknown type
192
        numValues.push_back(-1);
193
    }
194
}
195
    
196
    
197
// Parameters for a randomized kd-tree index.
//   trees - number of randomized trees to use.
KDTreeIndexParams::KDTreeIndexParams(int trees)
{
    ::cvflann::IndexParams& cfg = get_params(*this);
    cfg["trees"] = trees;
    cfg["algorithm"] = ::cvflann::FLANN_INDEX_KDTREE;
}
// Parameters for the linear (brute-force) index; no tunables besides the algorithm.
LinearIndexParams::LinearIndexParams()
{
    get_params(*this)["algorithm"] = ::cvflann::FLANN_INDEX_LINEAR;
}
// Parameters for a composite index combining randomized kd-trees with a
// hierarchical k-means tree.
//   trees        - number of randomized trees (kd-tree part)
//   branching    - branching factor (k-means part)
//   iterations   - max iterations in one k-means clustering
//   centers_init - how the initial k-means cluster centers are picked
//   cb_index     - cluster boundary index used when searching the k-means tree
CompositeIndexParams::CompositeIndexParams(int trees, int branching, int iterations,
                             ::cvflann::flann_centers_init_t centers_init, float cb_index )
{
    ::cvflann::IndexParams& p = get_params(*this);
    // Must be FLANN_INDEX_COMPOSITE: with FLANN_INDEX_KMEANS the "trees" value
    // below is ignored and a plain k-means index is built instead.
    p["algorithm"] = ::cvflann::FLANN_INDEX_COMPOSITE;
    // number of randomized trees to use (for kdtree)
    p["trees"] = trees;
    // branching factor
    p["branching"] = branching;
    // max iterations to perform in one kmeans clustering (kmeans tree)
    p["iterations"] = iterations;
    // algorithm used for picking the initial cluster centers for kmeans tree
    p["centers_init"] = centers_init;
    // cluster boundary index. Used when searching the kmeans tree
    p["cb_index"] = cb_index;
}
// Parameters for an autotuned index (FLANN_INDEX_AUTOTUNED).
AutotunedIndexParams::AutotunedIndexParams(float target_precision, float build_weight,
                                           float memory_weight, float sample_fraction)
{
    ::cvflann::IndexParams& cfg = get_params(*this);
    cfg["algorithm"]        = ::cvflann::FLANN_INDEX_AUTOTUNED;
    cfg["target_precision"] = target_precision; // desired precision (used for autotuning, -1 otherwise)
    cfg["build_weight"]     = build_weight;     // weighting factor for tree build time
    cfg["memory_weight"]    = memory_weight;    // weighting factor for index memory
    cfg["sample_fraction"]  = sample_fraction;  // fraction of the dataset used for autotuning
}
// Parameters for a hierarchical k-means tree index.
KMeansIndexParams::KMeansIndexParams(int branching, int iterations,
                  ::cvflann::flann_centers_init_t centers_init, float cb_index )
{
    ::cvflann::IndexParams& cfg = get_params(*this);
    cfg["algorithm"]    = ::cvflann::FLANN_INDEX_KMEANS;
    cfg["branching"]    = branching;    // branching factor
    cfg["iterations"]   = iterations;   // max iterations in one k-means clustering
    cfg["centers_init"] = centers_init; // how the initial cluster centers are picked
    cfg["cb_index"]     = cb_index;     // cluster boundary index, used when searching the tree
}
// Parameters for a multi-probe LSH index.
// The three sizes are stored as 'unsigned' (presumably the type LshIndex reads
// them as - TODO confirm); note IndexParams::getAll has no unsigned branch and
// reports these entries with type -1.
LshIndexParams::LshIndexParams(int table_number, int key_size, int multi_probe_level)
{
    ::cvflann::IndexParams& cfg = get_params(*this);
    cfg["algorithm"]         = ::cvflann::FLANN_INDEX_LSH;
    cfg["table_number"]      = static_cast<unsigned>(table_number);      // number of hash tables
    cfg["key_size"]          = static_cast<unsigned>(key_size);          // key length in the hash tables
    cfg["multi_probe_level"] = static_cast<unsigned>(multi_probe_level); // multi-probe levels (0 = standard LSH)
}
// Parameters that make Index::build() load a previously saved index from
// _filename instead of building one.
SavedIndexParams::SavedIndexParams(const std::string& _filename)
{
    ::cvflann::IndexParams& cfg = get_params(*this);
    cfg["filename"] = _filename;
    cfg["algorithm"] = ::cvflann::FLANN_INDEX_SAVED;
}
// Per-query search parameters.
SearchParams::SearchParams( int checks, float eps, bool sorted )
{
    ::cvflann::IndexParams& cfg = get_params(*this);
    cfg["sorted"] = sorted; // radius search only: neighbours sorted by distance (default: true)
    cfg["eps"]    = eps;    // search for eps-approximate neighbours (default: 0)
    cfg["checks"] = checks; // leafs to visit when searching for neighbours (-1 for unlimited)
}
template<typename Distance, typename IndexType> void
293
buildIndex_(void*& index, const Mat& data, const IndexParams& params, const Distance& dist = Distance())
294
{
295
    typedef typename Distance::ElementType ElementType;
296
    if(DataType<ElementType>::type != data.type())
297
        CV_Error_(CV_StsUnsupportedFormat, ("type=%d\n", data.type()));
298
    if(!data.isContinuous())
299
        CV_Error(CV_StsBadArg, "Only continuous arrays are supported");
300
    
301
    ::cvflann::Matrix<ElementType> dataset((ElementType*)data.data, data.rows, data.cols);
302
    IndexType* _index = new IndexType(dataset, get_params(params), dist);
303
    _index->buildIndex();
304
    index = _index;
305
}
306
307
template<typename Distance> void
308
buildIndex(void*& index, const Mat& data, const IndexParams& params, const Distance& dist = Distance())
309
{
310
    buildIndex_<Distance, ::cvflann::Index<Distance> >(index, data, params, dist);
311
}
312
313
// Distance functor and index type used whenever algo == FLANN_INDEX_LSH
// (Index::load only accepts CV_8U feature data for LSH).
typedef ::cvflann::HammingLUT HammingDistance;
typedef ::cvflann::LshIndex< ::cvflann::HammingLUT > LshIndex;
// Creates an empty index: no search structure is allocated until build() or
// load() succeeds.
Index::Index()
{
    distType = ::cvflann::FLANN_DIST_L2;
    algo = ::cvflann::FLANN_INDEX_LINEAR;
    featureType = CV_32F;
    index = 0;
}
// Creates an index and immediately builds it over _data (see build()).
Index::Index(InputArray _data, const IndexParams& params, ::cvflann::flann_distance_t _distType)
{
    // Start from the empty state so build() -> release() finds nothing to free.
    distType = ::cvflann::FLANN_DIST_L2;
    algo = ::cvflann::FLANN_INDEX_LINEAR;
    featureType = CV_32F;
    index = 0;
    build(_data, params, _distType);
}
// (Re)builds the index over _data using the "algorithm" entry of params
// (default FLANN_INDEX_LINEAR) and the given distance.
// - FLANN_INDEX_SAVED: loads the index stored in params["filename"] for _data;
//   raises an error if that fails instead of leaving an empty index behind.
// - FLANN_INDEX_LSH: builds a Hamming-distance LSH index (distance is ignored).
// - otherwise: builds over float data with L2/L1 (or the exotic distances when
//   MINIFLANN_SUPPORT_EXOTIC_DISTANCE_TYPES is enabled).
void Index::build(InputArray _data, const IndexParams& params, ::cvflann::flann_distance_t _distType)
{
    release();
    // Keep the space in "< ::cvflann": in C++03 "<:" is the digraph for '[',
    // so "getParam<::cvflann..." does not parse with pre-C++11 compilers.
    algo = getParam< ::cvflann::flann_algorithm_t >(params, "algorithm", ::cvflann::FLANN_INDEX_LINEAR);
    if( algo == ::cvflann::FLANN_INDEX_SAVED )
    {
        std::string filename = getParam<std::string>(params, "filename", std::string());
        // load() reports failure through its return value; build() has no
        // return value, so surface the failure instead of silently ending up
        // with an empty index.
        if( !load(_data, filename) )
            CV_Error_( CV_StsError, ("Can not load FLANN index from file %s\n", filename.c_str()) );
        return;
    }

    Mat data = _data.getMat();
    index = 0;
    featureType = data.type();
    distType = _distType;

    if( algo == ::cvflann::FLANN_INDEX_LSH )
    {
        buildIndex_<HammingDistance, LshIndex>(index, data, params);
        return;
    }

    switch( distType )
    {
    case ::cvflann::FLANN_DIST_L2:
        buildIndex< ::cvflann::L2<float> >(index, data, params);
        break;
    case ::cvflann::FLANN_DIST_L1:
        buildIndex< ::cvflann::L1<float> >(index, data, params);
        break;
#if MINIFLANN_SUPPORT_EXOTIC_DISTANCE_TYPES
    case ::cvflann::FLANN_DIST_MAX:
        buildIndex< ::cvflann::MaxDistance<float> >(index, data, params);
        break;
    case ::cvflann::FLANN_DIST_HIST_INTERSECT:
        buildIndex< ::cvflann::HistIntersectionDistance<float> >(index, data, params);
        break;
    case ::cvflann::FLANN_DIST_HELLINGER:
        buildIndex< ::cvflann::HellingerDistance<float> >(index, data, params);
        break;
    case ::cvflann::FLANN_DIST_CHI_SQUARE:
        buildIndex< ::cvflann::ChiSquareDistance<float> >(index, data, params);
        break;
    case ::cvflann::FLANN_DIST_KL:
        buildIndex< ::cvflann::KL_Divergence<float> >(index, data, params);
        break;
#endif
    default:
        CV_Error(CV_StsBadArg, "Unknown/unsupported distance type");
    }
}
// Destroys an index previously stored as an untyped pointer; IndexType must be
// the exact type it was created with. A null pointer is a no-op.
template<typename IndexType> void deleteIndex_(void* index)
{
    IndexType* typed = static_cast<IndexType*>(index);
    delete typed;
}
template<typename Distance> void deleteIndex(void* index)
390
{
391
    deleteIndex_< ::cvflann::Index<Distance> >(index);
392
}
393
    
394
// Frees the underlying cvflann index, if one was built or loaded.
Index::~Index()
{
    this->release();
}
void Index::release()
400
{
401
    if( !index )
402
        return;
403
    if( algo == ::cvflann::FLANN_INDEX_LSH )
404
    {
405
        deleteIndex_<LshIndex>(index);
406
    }
407
    else
408
    {
409
        CV_Assert( featureType == CV_32F );
410
        switch( distType )
411
        {
412
        case ::cvflann::FLANN_DIST_L2:
413
            deleteIndex< ::cvflann::L2<float> >(index);
414
            break;
415
        case ::cvflann::FLANN_DIST_L1:
416
            deleteIndex< ::cvflann::L1<float> >(index);
417
            break;
418
#if MINIFLANN_SUPPORT_EXOTIC_DISTANCE_TYPES
419
        case ::cvflann::FLANN_DIST_MAX:
420
            deleteIndex< ::cvflann::MaxDistance<float> >(index);
421
            break;
422
        case ::cvflann::FLANN_DIST_HIST_INTERSECT:
423
            deleteIndex< ::cvflann::HistIntersectionDistance<float> >(index);
424
            break;
425
        case ::cvflann::FLANN_DIST_HELLINGER:
426
            deleteIndex< ::cvflann::HellingerDistance<float> >(index);
427
            break;
428
        case ::cvflann::FLANN_DIST_CHI_SQUARE:
429
            deleteIndex< ::cvflann::ChiSquareDistance<float> >(index);
430
            break;
431
        case ::cvflann::FLANN_DIST_KL:
432
            deleteIndex< ::cvflann::KL_Divergence<float> >(index);
433
            break;
434
#endif
435
        default:
436
            CV_Error(CV_StsBadArg, "Unknown/unsupported distance type");
437
        }
438
    }
439
    index = 0;
440
}
441
442
template<typename Distance, typename IndexType>
443
void runKnnSearch_(void* index, const Mat& query, Mat& indices, Mat& dists,
444
                  int knn, const SearchParams& params)
445
{
446
    typedef typename Distance::ElementType ElementType;
447
    typedef typename Distance::ResultType DistanceType;
448
    int type = DataType<ElementType>::type;
449
    int dtype = DataType<DistanceType>::type;
450
    CV_Assert(query.type() == type && indices.type() == CV_32S && dists.type() == dtype);
451
    CV_Assert(query.isContinuous() && indices.isContinuous() && dists.isContinuous());
452
    
453
    ::cvflann::Matrix<ElementType> _query((ElementType*)query.data, query.rows, query.cols);
454
    ::cvflann::Matrix<int> _indices((int*)indices.data, indices.rows, indices.cols);
455
    ::cvflann::Matrix<DistanceType> _dists((DistanceType*)dists.data, dists.rows, dists.cols);
456
    
457
    ((IndexType*)index)->knnSearch(_query, _indices, _dists, knn,
458
                                   (const ::cvflann::SearchParams&)get_params(params));
459
}
460
    
461
template<typename Distance>
462
void runKnnSearch(void* index, const Mat& query, Mat& indices, Mat& dists,
463
                  int knn, const SearchParams& params)
464
{
465
    runKnnSearch_<Distance, ::cvflann::Index<Distance> >(index, query, indices, dists, knn, params);
466
}
467
468
template<typename Distance, typename IndexType>
469
int runRadiusSearch_(void* index, const Mat& query, Mat& indices, Mat& dists,
470
                    double radius, const SearchParams& params)
471
{
472
    typedef typename Distance::ElementType ElementType;
473
    typedef typename Distance::ResultType DistanceType;
474
    int type = DataType<ElementType>::type;
475
    int dtype = DataType<DistanceType>::type;
476
    CV_Assert(query.type() == type && indices.type() == CV_32S && dists.type() == dtype);
477
    CV_Assert(query.isContinuous() && indices.isContinuous() && dists.isContinuous());
478
    
479
    ::cvflann::Matrix<ElementType> _query((ElementType*)query.data, query.rows, query.cols);
480
    ::cvflann::Matrix<int> _indices((int*)indices.data, indices.rows, indices.cols);
481
    ::cvflann::Matrix<DistanceType> _dists((DistanceType*)dists.data, dists.rows, dists.cols);
482
    
483
    return ((IndexType*)index)->radiusSearch(_query, _indices, _dists,
484
                                            saturate_cast<DistanceType>(radius),
485
                                            (const ::cvflann::SearchParams&)get_params(params));
486
}
487
488
template<typename Distance>
489
int runRadiusSearch(void* index, const Mat& query, Mat& indices, Mat& dists,
490
                     double radius, const SearchParams& params)
491
{
492
    return runRadiusSearch_<Distance, ::cvflann::Index<Distance> >(index, query, indices, dists, radius, params);
493
}
494
        
495
    
496
static void createIndicesDists(OutputArray _indices, OutputArray _dists,
497
                               Mat& indices, Mat& dists, int rows,
498
                               int minCols, int maxCols, int dtype)
499
{
500
    if( _indices.needed() )
501
    {
502
        indices = _indices.getMat();
503
        if( !indices.isContinuous() || indices.type() != CV_32S ||
504
            indices.rows != rows || indices.cols < minCols || indices.cols > maxCols )
505
        {
506
            if( !indices.isContinuous() )
507
               _indices.release();
508
            _indices.create( rows, minCols, CV_32S );
509
            indices = _indices.getMat();
510
        }
511
    }
512
    else
513
        indices.create( rows, minCols, CV_32S );
514
    
515
    if( _dists.needed() )
516
    {
517
        dists = _dists.getMat();
518
        if( !dists.isContinuous() || dists.type() != dtype ||
519
           dists.rows != rows || dists.cols < minCols || dists.cols > maxCols )
520
        {
521
            if( !indices.isContinuous() )
522
                _dists.release();
523
            _dists.create( rows, minCols, dtype );
524
            dists = _dists.getMat();
525
        }
526
    }
527
    else
528
        dists.create( rows, minCols, dtype );
529
}
530
531
    
532
// Finds the 'knn' nearest neighbours of every row of _query. Results are
// written to _indices (CV_32S) and _dists (CV_32S for LSH, CV_32F otherwise),
// each with query.rows rows and knn columns.
void Index::knnSearch(InputArray _query, OutputArray _indices,
               OutputArray _dists, int knn, const SearchParams& params)
{
    Mat query = _query.getMat();
    Mat indices, dists;
    const bool isLsh = algo == ::cvflann::FLANN_INDEX_LSH;
    // LSH (Hamming) distances are integers; the other distances produce floats.
    createIndicesDists( _indices, _dists, indices, dists, query.rows, knn, knn, isLsh ? CV_32S : CV_32F );

    if( isLsh )
    {
        runKnnSearch_<HammingDistance, LshIndex>(index, query, indices, dists, knn, params);
        return;
    }

    switch( distType )
    {
    case ::cvflann::FLANN_DIST_L2:
        runKnnSearch< ::cvflann::L2<float> >(index, query, indices, dists, knn, params);
        break;
    case ::cvflann::FLANN_DIST_L1:
        runKnnSearch< ::cvflann::L1<float> >(index, query, indices, dists, knn, params);
        break;
#if MINIFLANN_SUPPORT_EXOTIC_DISTANCE_TYPES
    case ::cvflann::FLANN_DIST_MAX:
        runKnnSearch< ::cvflann::MaxDistance<float> >(index, query, indices, dists, knn, params);
        break;
    case ::cvflann::FLANN_DIST_HIST_INTERSECT:
        runKnnSearch< ::cvflann::HistIntersectionDistance<float> >(index, query, indices, dists, knn, params);
        break;
    case ::cvflann::FLANN_DIST_HELLINGER:
        runKnnSearch< ::cvflann::HellingerDistance<float> >(index, query, indices, dists, knn, params);
        break;
    case ::cvflann::FLANN_DIST_CHI_SQUARE:
        runKnnSearch< ::cvflann::ChiSquareDistance<float> >(index, query, indices, dists, knn, params);
        break;
    case ::cvflann::FLANN_DIST_KL:
        runKnnSearch< ::cvflann::KL_Divergence<float> >(index, query, indices, dists, knn, params);
        break;
#endif
    default:
        CV_Error(CV_StsBadArg, "Unknown/unsupported distance type");
    }
}
int Index::radiusSearch(InputArray _query, OutputArray _indices,
577
                        OutputArray _dists, double radius, int maxResults,
578
                        const SearchParams& params)
579
{
580
    Mat query = _query.getMat(), indices, dists;
581
    int dtype = algo == ::cvflann::FLANN_INDEX_LSH ? CV_32S : CV_32F;
582
    CV_Assert( maxResults > 0 );
583
    createIndicesDists( _indices, _dists, indices, dists, query.rows, maxResults, INT_MAX, dtype );
584
    
585
    if( algo == ::cvflann::FLANN_INDEX_LSH )
586
        CV_Error( CV_StsNotImplemented, "LSH index does not support radiusSearch operation" );
587
    
588
    switch( distType )
589
    {
590
    case ::cvflann::FLANN_DIST_L2:
591
        return runRadiusSearch< ::cvflann::L2<float> >(index, query, indices, dists, radius, params);
592
    case ::cvflann::FLANN_DIST_L1:
593
        return runRadiusSearch< ::cvflann::L1<float> >(index, query, indices, dists, radius, params);
594
#if MINIFLANN_SUPPORT_EXOTIC_DISTANCE_TYPES
595
    case ::cvflann::FLANN_DIST_MAX:
596
        return runRadiusSearch< ::cvflann::MaxDistance<float> >(index, query, indices, dists, radius, params);
597
    case ::cvflann::FLANN_DIST_HIST_INTERSECT:
598
        return runRadiusSearch< ::cvflann::HistIntersectionDistance<float> >(index, query, indices, dists, radius, params);
599
    case ::cvflann::FLANN_DIST_HELLINGER:
600
        return runRadiusSearch< ::cvflann::HellingerDistance<float> >(index, query, indices, dists, radius, params);
601
    case ::cvflann::FLANN_DIST_CHI_SQUARE:
602
        return runRadiusSearch< ::cvflann::ChiSquareDistance<float> >(index, query, indices, dists, radius, params);
603
    case ::cvflann::FLANN_DIST_KL:
604
        return runRadiusSearch< ::cvflann::KL_Divergence<float> >(index, query, indices, dists, radius, params);
605
#endif
606
    default:
607
        CV_Error(CV_StsBadArg, "Unknown/unsupported distance type");
608
    }
609
    return -1;
610
}
611
612
// Distance type currently associated with this index.
::cvflann::flann_distance_t Index::getDistance() const
{
    return this->distType;
}
// Index algorithm currently associated with this index.
::cvflann::flann_algorithm_t Index::getAlgorithm() const
{
    return this->algo;
}
template<typename IndexType> void saveIndex_(const Index* index0, const void* index, FILE* fout)
623
{
624
    IndexType* _index = (IndexType*)index;
625
    ::cvflann::save_header(fout, *_index);
626
    // some compilers may store short enumerations as bytes,
627
    // so make sure we always write integers (which are 4-byte values in any modern C compiler)
628
    int idistType = (int)index0->getDistance();
629
    ::cvflann::save_value<int>(fout, idistType);
630
    _index->saveIndex(fout);
631
}
632
633
template<typename Distance> void saveIndex(const Index* index0, const void* index, FILE* fout)
634
{
635
    saveIndex_< ::cvflann::Index<Distance> >(index0, index, fout);
636
}   
637
    
638
void Index::save(const std::string& filename) const
639
{
640
    FILE* fout = fopen(filename.c_str(), "wb");
641
    if (fout == NULL)
642
        CV_Error_( CV_StsError, ("Can not open file %s for writing FLANN index\n", filename.c_str()) );
643
    
644
    if( algo == ::cvflann::FLANN_INDEX_LSH )
645
    {
646
        saveIndex_<LshIndex>(this, index, fout);
647
        return;
648
    }
649
    
650
    switch( distType )
651
    {
652
    case ::cvflann::FLANN_DIST_L2:
653
        saveIndex< ::cvflann::L2<float> >(this, index, fout);
654
        break;
655
    case ::cvflann::FLANN_DIST_L1:
656
        saveIndex< ::cvflann::L1<float> >(this, index, fout);
657
        break;
658
#if MINIFLANN_SUPPORT_EXOTIC_DISTANCE_TYPES
659
    case ::cvflann::FLANN_DIST_MAX:
660
        saveIndex< ::cvflann::MaxDistance<float> >(this, index, fout);
661
        break;
662
    case ::cvflann::FLANN_DIST_HIST_INTERSECT:
663
        saveIndex< ::cvflann::HistIntersectionDistance<float> >(this, index, fout);
664
        break;
665
    case ::cvflann::FLANN_DIST_HELLINGER:
666
        saveIndex< ::cvflann::HellingerDistance<float> >(this, index, fout);
667
        break;
668
    case ::cvflann::FLANN_DIST_CHI_SQUARE:
669
        saveIndex< ::cvflann::ChiSquareDistance<float> >(this, index, fout);
670
        break;
671
    case ::cvflann::FLANN_DIST_KL:
672
        saveIndex< ::cvflann::KL_Divergence<float> >(this, index, fout);
673
        break;
674
#endif
675
    default:
676
        fclose(fout);
677
        fout = 0;
678
        CV_Error(CV_StsBadArg, "Unknown/unsupported distance type");
679
    }
680
    if( fout )
681
        fclose(fout);
682
}
683
684
685
template<typename Distance, typename IndexType>
686
bool loadIndex_(Index* index0, void*& index, const Mat& data, FILE* fin, const Distance& dist=Distance())
687
{
688
    typedef typename Distance::ElementType ElementType;
689
    CV_Assert(DataType<ElementType>::type == data.type() && data.isContinuous());
690
    
691
    ::cvflann::Matrix<ElementType> dataset((ElementType*)data.data, data.rows, data.cols);
692
    
693
    ::cvflann::IndexParams params;
694
    params["algorithm"] = index0->getAlgorithm();
695
    IndexType* _index = new IndexType(dataset, params, dist);
696
    _index->loadIndex(fin);
697
    index = _index;
698
    return true;
699
}
700
701
template<typename Distance>
702
bool loadIndex(Index* index0, void*& index, const Mat& data, FILE* fin, const Distance& dist=Distance())
703
{
704
    return loadIndex_<Distance, ::cvflann::Index<Distance> >(index0, index, data, fin, dist);
705
}    
706
    
707
// Loads an index written by Index::save() from 'filename' for the dataset
// _data, which must have the rows, cols and type recorded in the file header.
// Returns false (printing a diagnostic to stderr for the last three cases) if:
//   - the file cannot be opened;
//   - the header does not match _data;
//   - the feature type is unsupported for the index type;
//   - the distance type is unsupported.
// Exceptions raised while parsing propagate, and the file is closed on every path.
bool Index::load(InputArray _data, const std::string& filename)
{
    Mat data = _data.getMat();
    bool ok = true;
    release();
    FILE* fin = fopen(filename.c_str(), "rb");
    if (fin == NULL)
        return false;

    try
    {
        ::cvflann::IndexHeader header = ::cvflann::load_header(fin);
        algo = header.index_type;
        featureType = header.data_type == ::cvflann::FLANN_UINT8 ? CV_8U :
                      header.data_type == ::cvflann::FLANN_INT8 ? CV_8S :
                      header.data_type == ::cvflann::FLANN_UINT16 ? CV_16U :
                      header.data_type == ::cvflann::FLANN_INT16 ? CV_16S :
                      header.data_type == ::cvflann::FLANN_INT32 ? CV_32S :
                      header.data_type == ::cvflann::FLANN_FLOAT32 ? CV_32F :
                      header.data_type == ::cvflann::FLANN_FLOAT64 ? CV_64F : -1;

        if( (int)header.rows != data.rows || (int)header.cols != data.cols ||
            featureType != data.type() )
        {
            fprintf(stderr, "Reading FLANN index error: the saved data size (%d, %d) or type (%d) is different from the passed one (%d, %d), %d\n",
                    (int)header.rows, (int)header.cols, featureType, data.rows, data.cols, data.type());
            fclose(fin);
            return false;
        }

        // LSH indices are built over CV_8U data; every other index over CV_32F.
        if( !((algo == ::cvflann::FLANN_INDEX_LSH && featureType == CV_8U) ||
              (algo != ::cvflann::FLANN_INDEX_LSH && featureType == CV_32F)) )
        {
            fprintf(stderr, "Reading FLANN index error: unsupported feature type %d for the index type %d\n", featureType, (int)algo);
            fclose(fin);
            return false;
        }

        // The distance type is stored as a 4-byte int right after the header.
        int idistType = 0;
        ::cvflann::load_value(fin, idistType);
        distType = (::cvflann::flann_distance_t)idistType;

        if( algo == ::cvflann::FLANN_INDEX_LSH )
        {
            loadIndex_<HammingDistance, LshIndex>(this, index, data, fin);
        }
        else
        {
            switch( distType )
            {
            case ::cvflann::FLANN_DIST_L2:
                loadIndex< ::cvflann::L2<float> >(this, index, data, fin);
                break;
            case ::cvflann::FLANN_DIST_L1:
                loadIndex< ::cvflann::L1<float> >(this, index, data, fin);
                break;
#if MINIFLANN_SUPPORT_EXOTIC_DISTANCE_TYPES
            case ::cvflann::FLANN_DIST_MAX:
                loadIndex< ::cvflann::MaxDistance<float> >(this, index, data, fin);
                break;
            case ::cvflann::FLANN_DIST_HIST_INTERSECT:
                loadIndex< ::cvflann::HistIntersectionDistance<float> >(this, index, data, fin);
                break;
            case ::cvflann::FLANN_DIST_HELLINGER:
                loadIndex< ::cvflann::HellingerDistance<float> >(this, index, data, fin);
                break;
            case ::cvflann::FLANN_DIST_CHI_SQUARE:
                loadIndex< ::cvflann::ChiSquareDistance<float> >(this, index, data, fin);
                break;
            case ::cvflann::FLANN_DIST_KL:
                loadIndex< ::cvflann::KL_Divergence<float> >(this, index, data, fin);
                break;
#endif
            default:
                fprintf(stderr, "Reading FLANN index error: unsupported distance type %d\n", (int)distType);
                ok = false;
            }
        }
    }
    catch(...)
    {
        // Header/index parsing errors must not leak the file handle.
        fclose(fin);
        throw;
    }

    fclose(fin);
    return ok;
}
}
789
    
790
}