bagofwords_scan.cpp

Joel Mckay, 2012-05-04 09:18 pm

Download (18.2 kB)

 
1
/* Warning: this sample is still rough. I cleaned it up a bit, but have not tested it yet. */
2
3
/*****************************************************************************************
4
        This program reads in a generic trained VOC2010 sample xml params, vocabulary,
5
        and configuration. It works in conjunction with OpenCV's sample code 
6
        bagofwords_classification.cpp training class, and is mostly based on its class design.        
7
8
        2012  Joel Mckay        
9
        [email protected]
10
 
11
        Disclaimer:
12
        THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
13
        "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 
14
        LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
15
        FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE 
16
        COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 
17
        INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
18
        BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 
19
        LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 
20
        CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 
21
        LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY
22
        WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY
23
        OF SUCH DAMAGE. 
24
25
 *****************************************************************************************/
26
27
28
#include "global_headers.hpp"
29
 
30
                        
31
// File and directory names looked up relative to the trained-data output directory.
// The directory names carry their own leading '/', so they are appended without a separator.
const std::string paramsFile( "params.xml" );                        // training configuration
const std::string vocabularyFile( "vocabulary.xml.gz" );             // visual vocabulary
const std::string bowImageDescriptorsDir( "/bowImageDescriptors" );  // per-image BoW descriptors
const std::string svmsDir( "/svms" );                                // one trained SVM per object class
const std::string plotsDir( "/plots" );                              // plot output
36
37
/****************************************************************************************\
38
*                          OpenCV's  Sample on image classification                             *
39
\****************************************************************************************/
40
//
41
// This part of the code was lightly refactored from the OpenCV sample.
42
//
43
struct DDMParams
44
{
45
    DDMParams() : detectorType("SURF"), descriptorType("SURF"), matcherType("BruteForce") {}
46
    DDMParams( const string _detectorType, const string _descriptorType, const string& _matcherType ) :
47
        detectorType(_detectorType), descriptorType(_descriptorType), matcherType(_matcherType){}
48
    void read( const FileNode& fn )
49
    {
50
        fn["detectorType"] >> detectorType;
51
        fn["descriptorType"] >> descriptorType;
52
        fn["matcherType"] >> matcherType;
53
    }
54
    void write( FileStorage& fs ) const
55
    {
56
        fs << "detectorType" << detectorType;
57
        fs << "descriptorType" << descriptorType;
58
        fs << "matcherType" << matcherType;
59
    }
60
    void print() const
61
    {
62
        cout << "detectorType: " << detectorType << endl;
63
        cout << "descriptorType: " << descriptorType << endl;
64
        cout << "matcherType: " << matcherType << endl;
65
    }
66
67
    string detectorType;
68
    string descriptorType;
69
    string matcherType;
70
};
71
72
struct VocabTrainParams
73
{ 
74
    VocabTrainParams() : trainObjClass("chair"), vocabSize(VISUAL_VOCABULARY_FOR_BOG), memoryUse(VISUAL_VOCABULARY_MEMORY_LIMIT), descProportion(VISUAL_VOCABULARY_DESCRIPTORS_FROM_EACH_IMAGE_PROPORTION_OF_TOTAL) {}
75
    VocabTrainParams( const string _trainObjClass, size_t _vocabSize, size_t _memoryUse, float _descProportion ) :
76
            trainObjClass(_trainObjClass), vocabSize(_vocabSize), memoryUse(_memoryUse), descProportion(_descProportion) {}
77
    void read( const FileNode& fn )
78
    {
79
        fn["trainObjClass"] >> trainObjClass;
80
        fn["vocabSize"] >> vocabSize;
81
        fn["memoryUse"] >> memoryUse;
82
        fn["descProportion"] >> descProportion;
83
    }
84
    void write( FileStorage& fs ) const
85
    {
86
        fs << "trainObjClass" << trainObjClass;
87
        fs << "vocabSize" << vocabSize;
88
        fs << "memoryUse" << memoryUse;
89
        fs << "descProportion" << descProportion;
90
    }
91
    void print() const
92
    {
93
        cout << "trainObjClass: " << trainObjClass << endl;
94
        cout << "vocabSize: " << vocabSize << endl;
95
        cout << "memoryUse: " << memoryUse << endl;
96
        cout << "descProportion: " << descProportion << endl;
97
    }
98
99
100
    string trainObjClass; // Object class used for training visual vocabulary.
101
                          // It shouldn't matter which object class is specified here - visual vocab will still be the same.
102
    int vocabSize; //number of visual words in vocabulary to train
103
    int memoryUse; // Memory to preallocate (in MB) when training vocab.
104
                      // Change this depending on the size of the dataset/available memory.
105
    float descProportion; // Specifies the number of descriptors to use from each image as a proportion of the total num descs.
106
};
107
108
struct SVMTrainParamsExt
109
{
110
    SVMTrainParamsExt() : descPercent(VISUAL_VOCABULARY_DESCRIPTORS_FROM_EACH_TRAINING_PROPORTION_IMAGE), targetRatio(VISUAL_VOCABULARY_TRAINING_SUCEESS_TARGET), balanceClasses(true) {}
111
    SVMTrainParamsExt( float _descPercent, float _targetRatio, bool _balanceClasses ) :
112
            descPercent(_descPercent), targetRatio(_targetRatio), balanceClasses(_balanceClasses) {}
113
    void read( const FileNode& fn )
114
    {
115
        fn["descPercent"] >> descPercent;
116
        fn["targetRatio"] >> targetRatio;
117
        fn["balanceClasses"] >> balanceClasses;
118
    }
119
    void write( FileStorage& fs ) const
120
    {
121
        fs << "descPercent" << descPercent;
122
        fs << "targetRatio" << targetRatio;
123
        fs << "balanceClasses" << balanceClasses;
124
    }
125
    void print() const
126
    {
127
        cout << "descPercent: " << descPercent << endl;
128
        cout << "targetRatio: " << targetRatio << endl;
129
        cout << "balanceClasses: " << balanceClasses << endl;
130
    }
131
132
    float descPercent; // Percentage of extracted descriptors to use for training.
133
    float targetRatio; // Try to get this ratio of positive to negative samples (minimum).
134
    bool balanceClasses;    // Balance class weights by number of samples in each (if true cSvmTrainTargetRatio is ignored).
135
};
136
137
 
138
139
void printUsedParams( const string& mediaPath, const string& resDir,
140
                      const DDMParams& ddmParams, const VocabTrainParams& vocabTrainParams,
141
                      const SVMTrainParamsExt& svmTrainParamsExt )
142
{
143
    cout << "CURRENT SCANNER CONFIGURATION" << endl;
144
    cout << "----------------------------------------------------------------" << endl;
145
    cout << "mediaPath: " << mediaPath << endl;
146
    cout << "resDir: " << resDir << endl;
147
    cout << endl; ddmParams.print();
148
    cout << endl; vocabTrainParams.print();
149
    cout << endl; svmTrainParamsExt.print();
150
    cout << "----------------------------------------------------------------" << endl << endl;
151
}
152
153
bool readVocabulary( const string& filename, Mat& vocabulary )
154
{
155
    #if defined(DEBUG_MODE)
156
    cout << "Reading vocabulary...";
157
    #endif
158
        
159
    FileStorage fs( filename, FileStorage::READ );
160
    if( fs.isOpened() )
161
    {
162
        fs["vocabulary"] >> vocabulary; 
163
        return true;
164
    }
165
    return false;
166
}
167
 
168
 
169
  bool writeBowImageDescriptor( const string& file, const Mat& bowImageDescriptor )
170
{
171
    FileStorage fs( file, FileStorage::WRITE );
172
    if( fs.isOpened() )
173
    {
174
        fs << "imageDescriptor" << bowImageDescriptor;
175
        return true;
176
    }
177
    return false;
178
}
179
 
180
 
181
 
182
/***********************************************************************************/
183
//This function scans the SVM dir to push the specific file names one at a time to a vector list
184
185
// Lower-cases a single character. The unsigned char cast keeps ::tolower well defined
// for bytes >= 0x80 when plain char is signed.
static char loadListLowerChar( char c )
{
    return static_cast<char>( ::tolower( static_cast<unsigned char>( c ) ) );
}

// Collects object-class names from the regular files in `dir`.
//
// A file is accepted when the text after its LAST '.' is "xml" or "gz" (case-insensitive);
// the class name is the text before its FIRST '.', so "chair.xml.gz" yields "chair".
// Each accepted name is appended to *m_object_classes unless it is already present.
//
// Skipped: entries that cannot be stat()ed, anything that is not a regular file, names
// without any '.', and names that start with '.' (they would give an empty class name).
//
// dir               directory to scan (not recursive)
// m_object_classes  output list, appended to; nothing is done when it is NULL
//
// An unreadable directory prints "Error opening <dir>" to stdout and leaves the list untouched.
void loadListFromDir( std::string dir, std::vector<std::string>* m_object_classes )
{
    if( m_object_classes == NULL )
        return;

    DIR* dp = opendir( dir.c_str() );        //try to open the directory
    if( dp == NULL )
    {
        std::cout << "Error opening " << dir << std::endl;
        return;
    }

    const int maxEntries = 100000;           //hard cap on the number of entries examined
    struct dirent* dirp = NULL;
    struct stat filestat;

    for( int scanned = 0; scanned < maxEntries && (dirp = readdir( dp )) != NULL; ++scanned )
    {
        const std::string filename = dirp->d_name;
        const std::string filepath = dir + "/" + filename;

        // file invalid or not a regular file? we'll skip it...
        if( stat( filepath.c_str(), &filestat ) != 0 || !S_ISREG( filestat.st_mode ) )
            continue;

        // no '.' at all means no extension; without this check a file literally named
        // "xml" would be mistaken for an XML file
        const std::string::size_type lastDot = filename.find_last_of( '.' );
        if( lastDot == std::string::npos )
            continue;

        std::string fileext = filename.substr( lastDot + 1 );
        std::transform( fileext.begin(), fileext.end(), fileext.begin(), loadListLowerChar );
        if( fileext != "xml" && fileext != "gz" )        //gzip or XML ?
            continue;

        const std::string filebasename = filename.substr( 0, filename.find_first_of( '.' ) );
        if( filebasename.empty() )                        //hidden file such as ".foo.xml"
            continue;

        // "chair.xml" and "chair.xml.gz" name the same class; list it only once
        if( std::find( m_object_classes->begin(), m_object_classes->end(), filebasename )
                != m_object_classes->end() )
            continue;

        #if defined(DEBUG_MODE)
                std::cout << "Loaded: " << filebasename << "  " << filepath.c_str() << std::endl;
        #endif

        m_object_classes->push_back( filebasename );
    }

    closedir( dp );        //done loading list
}
237
238
239
240
        
241
/***********************************************************************************/
242
243
int main(int argc, char** argv)
244
{
245
    if( argc != 3 && argc != 6 )
246
    { 
247
        echo <<"\nbagofwords_scan </path/to/some/video/file.avi> </path/to/the/trained/BOW/VOCDATA/output> <SURF> <OpponentSURF> <BruteForce>"   << endl;
248
        exit(-1);
249
    }
250
    
251
    CvMemStorage* storageTmp = cvCreateMemStorage(0);
252
    const string mediaPath = argv[1], resPath = argv[2];
253
254
    // Read default parameters file
255
    string vocName;
256
    DDMParams ddmParams;
257
    VocabTrainParams vocabTrainParams;
258
    SVMTrainParamsExt svmTrainParamsExt;
259
  
260
    FileStorage paramsFS( resPath + "/" + paramsFile, FileStorage::READ );
261
    if( paramsFS.isOpened() )
262
    {  
263
            const FileNode& fn=paramsFS.root();                        //parse the XML file for the type of trained data settings
264
                
265
            fn["vocName"] >> vocName;
266
            FileNode currFn = fn;
267
268
            currFn = fn["ddmParams"];
269
            ddmParams.read( currFn );
270
271
            currFn = fn["vocabTrainParams"];
272
            vocabTrainParams.read( currFn );
273
274
            currFn = fn["svmTrainParamsExt"];
275
            svmTrainParamsExt.read( currFn ); 
276
            
277
    }else{
278
        cout << "\n Could open the file " <<  resPath << "/" << paramsFile << endl;
279
        exit(-1); 
280
    } 
281
282
    // Create detector, descriptor, matcher.
283
    Ptr<FeatureDetector> featureDetector = FeatureDetector::create( ddmParams.detectorType );
284
    Ptr<DescriptorExtractor> descExtractor = DescriptorExtractor::create( ddmParams.descriptorType );
285
    
286
        cout << "\nHeisenbug: descExtractor " << descExtractor->descriptorType() << "=" << CV_32FC1 << " ?\n"; 
287
    
288
    
289
    Ptr<BOWImgDescriptorExtractor> bowExtractor;
290
    if( featureDetector.empty() || descExtractor.empty() )
291
    {
292
        cout << "featureDetector or descExtractor was not created" << endl;
293
        exit(-1);
294
    }else{
295
        Ptr<DescriptorMatcher> descMatcher = DescriptorMatcher::create( ddmParams.matcherType );
296
        if( featureDetector.empty() || descExtractor.empty() || descMatcher.empty() )
297
        {
298
            cout << "descMatcher was not created" << endl;
299
            exit(-1);
300
        }
301
        bowExtractor = new BOWImgDescriptorExtractor( descExtractor, descMatcher );
302
    }
303
304
     
305
    
306
    // Print configuration to screen
307
        printUsedParams( mediaPath, resPath, ddmParams, vocabTrainParams, svmTrainParamsExt ); 
308
        cout << "\n Threshold for calculated " << MINIMUM_BOW_CONFIDENCE_SCORE << " class confidence...\n" << endl;
309
310
    // 1. Load visual word pre-calculated vocabulary file from previous run
311
        Mat vocabulary;
312
            string vocabularyFilename = resPath + "/" + vocabularyFile;
313
        if( !readVocabulary( vocabularyFilename, vocabulary) )
314
        {        
315
                cout << "\n Could not load vocabulary file! \n" << vocabularyFilename << endl;
316
                return -1;
317
        }
318
319
        bowExtractor->setVocabulary( vocabulary );
320
        #if defined(DEBUG_MODE)
321
        cout << "\nSet Vocabulary: rows=" << vocabulary.rows << "  cols="<< vocabulary.cols << endl << endl;
322
        #endif
323
        
324
    // 2. check for classifier and run a query for each object
325
        //define available object_classes for VOC2010 dataset etc...
326
        //by scanning for svm trained object classes
327
        vector<string> m_object_classes;  
328
        string svmFileLocation = resPath + svmsDir ;  
329
        loadListFromDir(svmFileLocation, &m_object_classes);
330
        
331
        std::vector<BogClassifierTracker> bogClasses;        //track all hits to the BOG classes
332
 
333
          
334
        //TODO: Prepare to query objects  (ptr as we may add a context subset filter later)
335
         const vector<string>& objClasses=m_object_classes; 
336
        
337
        
338
    #if defined(DEBUG_MODE)
339
                cout << "\n Loaded Vocabulary: bowExtractor->descriptorSize()=" << bowExtractor->descriptorSize() << endl;
340
    #endif
341
    CV_Assert( !bowExtractor->getVocabulary().empty() );        //poll Vocabulary is valid ?
342
        
343
         
344
        cout << "Load SVM files for selected Visual Vocabulary:" << endl;
345
            for( size_t classIdx = 0; (classIdx < objClasses.size()); ++classIdx )
346
            {
347
          
348
                /* first check if a previously trained svm for the current class has been saved to file */
349
                string svmFilename = resPath + svmsDir + "/" + objClasses[classIdx] + ".xml.gz";
350
                 
351
                FileStorage fs( svmFilename, FileStorage::READ);
352
                if( fs.isOpened() )
353
                {
354
                        // Load a classifier from the trainer dataset
355
                        BogClassifierTracker BOGCLassRecord = BogClassifierTracker(objClasses[classIdx], svmFilename);
356
                        bogClasses.push_back(BOGCLassRecord);
357
                        
358
                        #if defined(DEBUG_MODE)
359
                        cout << "*** LOADING SVM CLASSIFIER FOR CLASS " << bogClasses[classIdx].nameOfClass << " ***" << endl;
360
                        cout << svmFilename << endl;
361
                        #endif
362
                         
363
                        cout <<   bogClasses[classIdx].nameOfClass << " " << std::flush;
364
                        fs.release();
365
                } 
366
367
            }
368
        cout << "\n---------------------------------------------------------------" << endl;
369
    
370
    
371
    
372
        /* probe reference video for valid data */  
373
        IplImage        *imgBuffer, *img; 
374
        CvCapture *capture=cvCreateFileCapture(mediaPath.c_str()); 
375
        int frameCounter = 0;
376
         
377
        /******************** Open target file buffer? ******************/
378
        cvGrabFrame(capture);
379
        imgBuffer = cvRetrieveFrame(capture);                                     
380
        if( imgBuffer == 0 ) {
381
                fprintf( stderr, "Cannot load video target file %s!\n", mediaPath.c_str());
382
                exit(-1); 
383
        }         
384
        img = cvCreateImage(cvGetSize(imgBuffer), IPL_DEPTH_8U, 3); // imgBuffer->depth,   imgBuffer->nChannels); //frame copy
385
        
386
        /* create new image for the grayscale version */
387
        IplImage *imgBufferGray = cvCreateImage( cvGetSize(imgBuffer), IPL_DEPTH_8U, 1 );
388
 
389
        
390
        
391
        //cvNamedWindow("image_src", 1);
392
        
393
        while(cvGrabFrame(capture))   
394
        {        
395
                
396
                imgBuffer = cvRetrieveFrame(capture);      //buffer frame
397
        //        cvShowImage("image_src",imgBuffer);         
398
        //        cvWaitKey(0); 
399
                
400
                frameCounter++;
401
                cout << "\rFrame number " << frameCounter << "                                             ";
402
           
403
                Mat imgMat = imgBuffer;                 //Fast copy Pointer construct from buffer (not parallel)                                
404
                //Mat imgMat(imgBuffer);                 //copy and construct from buffer (to go parallel later)                
405
                                          
406
                size_t i = 0;
407
                vector<KeyPoint> keypoints;
408
                vector<Mat> bowImageDescriptors;
409
                
410
                #if defined(DEBUG_MODE)
411
                cout << "\nComputing descriptors for image... " ;
412
                #endif
413
                featureDetector->detect( imgMat, keypoints );                //svn  r8280 breaks this call
414
415
                
416
                #if defined(DEBUG_MODE)
417
                cout << "\nGenerating BoW vector... " << endl ;
418
                #endif
419
                bowImageDescriptors.resize( (i+1) ); //images size = 1 
420
                bowExtractor->compute( imgMat, keypoints, bowImageDescriptors[i] );
421
         
422
                float imageKeypointsSize = keypoints.size();
423
                 
424
                // Skip images for descriptors that could not be calculated 
425
                if( bowImageDescriptors[i].empty() || (bowImageDescriptors[i].cols == 0) || (bowImageDescriptors[i].rows == 0)  || (imageKeypointsSize < 1))
426
                {
427
                        cout << "\n Error: bow image descriptor empty.\n" << endl;                //coomon if the image is a black screen etc...
428
                        //exit(-1);
429
                }else{ 
430
                        #if defined(DEBUG_MODE)
431
                        cout << "\nNote: bowImageDescriptors.size=" << bowImageDescriptors.size() 
432
                                << " col=" << bowImageDescriptors[i].cols 
433
                                << " row="<< bowImageDescriptors[i].rows << endl;
434
                        #endif
435
                        
436
                        //display frame keypoints for selected SVM checker
437
                        drawKeypoints(imgMat, keypoints, imgMat, Scalar(0,255,255)); 
438
                     
439
                        
440
                        float signMul = -1.f;                //1.f
441
                        for( size_t imageIdx = 0; imageIdx < bogClasses.size(); imageIdx++ )
442
                        { 
443
                                
444
                                 // Use the bag of words vectors to calculate classifier output for each image in test set
445
                                #if defined(DEBUG_MODE)
446
                                        cout << "\nFrame " << frameCounter 
447
                                                << ": CALCULATING CONFIDENCE SCORE FOR CLASS " << bogClasses[imageIdx].nameOfClass << endl;
448
                                #endif
449
                         
450
                                #if defined(DEBUG_MODE)
451
                                float svmFeaturesUsed = (*bogClasses[imageIdx].svm).get_var_count();
452
                                #endif
453
                                float scoreVal = (*bogClasses[imageIdx].svm).predict( bowImageDescriptors[i], true );
454
                                float classVal=0;
455
                                
456
 
457
                                //no change in output seen
458
                        //        if( imageIdx == 0 )
459
                                {
460
                                    // In the first iteration, determine the sign of the positive class 
461
                                    classVal = (*bogClasses[imageIdx].svm).predict(bowImageDescriptors[i], false ); 
462
                                        
463
                                    signMul = (classVal < 0) == (scoreVal < 0) ? 1.f : -1.f;
464
                                } 
465
                                
466
                                // svm output of decision function 
467
                                float confidence = signMul * scoreVal;
468
                                #if defined(DEBUG_MODE)
469
                                cout << "\nConfidence=" << confidence << endl;
470
                                #endif
471
 
472
                                
473
                                #if defined(DEBUG_MODE)
474
                                cout << "\n classVal=" << classVal << "  scoreVal=" << scoreVal << endl;
475
                                #endif
476
                                
477
                                #if defined(DEBUG_MODE)
478
                                cout << "\n Keypoints=" << imageKeypointsSize << "    Used Points=" << svmFeaturesUsed <<endl;
479
                                #endif 
480
                                                
481
                                if( (confidence > MINIMUM_BOW_CONFIDENCE_SCORE) && (confidence < MAXIMUM_BOW_CONFIDENCE_SCORE))
482
                                {
483
                                        
484
                                        // Show support vector count
485
                                        int supportVectorCount     = (*bogClasses[imageIdx].svm).get_support_vector_count();
486
                                        
487
                                        #if defined(DEBUG_MODE)
488
                                                cout << "\n  support vector count: " << supportVectorCount << endl ; 
489
                                        #endif 
490
                                        
491
                                        #if defined(DEBUG_MODE)
492
                                                cout << "\n score: " << bogClasses[imageIdx].nameOfClass  << " = " << confidence << " [ " << scoreVal << " ] ";        
493
                                        #endif
494
495
                                }else {
496
                                        #if defined(DEBUG_MODE)
497
                                                cout << "\n Skipped: " << bogClasses[imageIdx].nameOfClass  << " = " << confidence << " [ " << scoreVal << " ] ";        
498
                                        #endif
499
                                }
500
                        }
501
                          
502
                        
503
                        #if defined(DEBUG_MODE)
504
                        imshow("image_keypoints",imgMat);
505
                        int sc = waitKey(1000); 
506
                        cout << "\n---------------------------------------------------------------" << endl;        
507
                        #else
508
                        imshow("image_keypoints",imgMat);
509
                        int sc = waitKey(0);
510
                        #endif
511
                                        
512
                        //debug
513
                        /*
514
                        if( !writeBowImageDescriptor( "example.jpg.xml.gz", bowImageDescriptors[i] ) )
515
                        {
516
                                cout << "Error: file example can not be opened to write bow image descriptor" << endl;
517
                                exit(-1);
518
                        }
519
                        */
520
            
521
                } 
522
                
523
        }
524
        
525
 
526
        cvReleaseCapture(&capture); 
527
        return 0;
528
}