OpenCV 3.0中的SVM训练 mnist 手写字体识别(5)

 

1 /* 2 svm_type – 3 指定SVM的类型,下面是可能的取值: 4 CvSVM::C_SVC C类支持向量分类机。 n类分组 (n \geq 2),允许用异常值惩罚因子C进行不完全分类。 5 CvSVM::NU_SVC \nu类支持向量分类机。n类似然不完全分类的分类器。参数为 \nu 取代C(其值在区间【0,1】中,nu越大,决策边界越平滑)。 6 CvSVM::ONE_CLASS 单分类器,所有的训练数据提取自同一个类里,然后SVM建立了一个分界线以分割该类在特征空间中所占区域和其它类在特征空间中所占区域。 7 CvSVM::EPS_SVR \epsilon类支持向量回归机。训练集中的特征向量和拟合出来的超平面的距离需要小于p。异常值惩罚因子C被采用。 8 CvSVM::NU_SVR \nu类支持向量回归机。 \nu 代替了 p。 9 10 可从 [LibSVM] 获取更多细节。 11 12 kernel_type – 13 SVM的内核类型,下面是可能的取值: 14 CvSVM::LINEAR 线性内核。没有任何向映射至高维空间,线性区分(或回归)在原始特征空间中被完成,这是最快的选择。K(x_i, x_j) = x_i^T x_j. 15 CvSVM::POLY 多项式内核: K(x_i, x_j) = (\gamma x_i^T x_j + coef0)^{degree}, \gamma > 0. 16 CvSVM::RBF 基于径向的函数,对于大多数情况都是一个较好的选择: K(x_i, x_j) = e^{-\gamma ||x_i - x_j||^2}, \gamma > 0. 17 CvSVM::SIGMOID Sigmoid函数内核:K(x_i, x_j) = \tanh(\gamma x_i^T x_j + coef0). 18 19 degree – 内核函数(POLY)的参数degree。 20 21 gamma – 内核函数(POLY/ RBF/ SIGMOID)的参数\gamma。 22 23 coef0 – 内核函数(POLY/ SIGMOID)的参数coef0。 24 25 Cvalue – SVM类型(C_SVC/ EPS_SVR/ NU_SVR)的参数C。 26 27 nu – SVM类型(NU_SVC/ ONE_CLASS/ NU_SVR)的参数 \nu。 28 29 p – SVM类型(EPS_SVR)的参数 \epsilon。 30 31 class_weights – C_SVC中的可选权重,赋给指定的类,乘以C以后变成 class\_weights_i * C。所以这些权重影响不同类别的错误分类惩罚项。权重越大,某一类别的误分类数据的惩罚项就越大。 32 33 term_crit – SVM的迭代训练过程的中止条件,解决部分受约束二次最优问题。您可以指定的公差和/或最大迭代次数。 34 35 */ 36 37 38 #include "mnist.h" 39 40 #include <opencv2/core.hpp> 41 #include <opencv2/imgproc.hpp> 42 #include "opencv2/imgcodecs.hpp" 43 #include <opencv2/highgui.hpp> 44 #include <opencv2/ml.hpp> 45 46 #include <string> 47 #include <iostream> 48 49 using namespace std; 50 using namespace cv; 51 using namespace cv::ml; 52 53 string trainImage = "mnist_dataset/train-images.idx3-ubyte"; 54 string trainLabel = "mnist_dataset/train-labels.idx1-ubyte"; 55 string testImage = "mnist_dataset/t10k-images.idx3-ubyte"; 56 string testLabel = "mnist_dataset/t10k-labels.idx1-ubyte"; 57 //string testImage = "mnist_dataset/train-images.idx3-ubyte"; 58 //string testLabel = "mnist_dataset/train-labels.idx1-ubyte"; 59 60 //计时器 61 double cost_time_; 62 clock_t start_time_; 63 clock_t end_time_; 64 65 int main() 66 { 67 68 //--------------------- 1. Set up training data --------------------------------------- 69 Mat trainData; 70 Mat labels; 71 trainData = read_mnist_image(trainImage); 72 labels = read_mnist_label(trainLabel); 73 74 cout << trainData.rows << " " << trainData.cols << endl; 75 cout << labels.rows << " " << labels.cols << endl; 76 77 //------------------------ 2. Set up the support vector machines parameters -------------------- 78 Ptr<SVM> svm = SVM::create(); 79 svm->setType(SVM::C_SVC); 80 svm->setKernel(SVM::RBF); 81 //svm->setDegree(10.0); 82 svm->setGamma(0.01); 83 //svm->setCoef0(1.0); 84 svm->setC(10.0); 85 //svm->setNu(0.5); 86 //svm->setP(0.1); 87 svm->setTermCriteria(TermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON)); 88 89 //------------------------ 3. Train the svm ---------------------------------------------------- 90 cout << "Starting training process" << endl; 91 start_time_ = clock(); 92 svm->train(trainData, ROW_SAMPLE, labels); 93 end_time_ = clock(); 94 cost_time_ = (end_time_ - start_time_) / CLOCKS_PER_SEC; 95 cout << "Finished training process...cost " << cost_time_ << " seconds..." << endl; 96 97 //------------------------ 4. save the svm ---------------------------------------------------- 98 svm->save("mnist_dataset/mnist_svm.xml"); 99 cout << "save as /mnist_dataset/mnist_svm.xml" << endl; 100 101 102 //------------------------ 5. load the svm ---------------------------------------------------- 103 cout << "开始导入SVM文件...\n"; 104 Ptr<SVM> svm1 = StatModel::load<SVM>("mnist_dataset/mnist_svm.xml"); 105 cout << "成功导入SVM文件...\n"; 106 107 108 //------------------------ 6. read the test dataset ------------------------------------------- 109 cout << "开始导入测试数据...\n"; 110 Mat testData; 111 Mat tLabel; 112 testData = read_mnist_image(testImage); 113 tLabel = read_mnist_label(testLabel); 114 cout << "成功导入测试数据!!!\n"; 115 116 117 float count = 0; 118 for (int i = 0; i < testData.rows; i++) { 119 Mat sample = testData.row(i); 120 float res = svm1->predict(sample); 121 res = std::abs(res - tLabel.at<unsigned int>(i, 0)) <= FLT_EPSILON ? 1.f : 0.f; 122 count += res; 123 } 124 cout << "正确的识别个数 count = " << count << endl; 125 cout << "错误率为..." << (10000 - count + 0.0) / 10000 * 100.0 << "%....\n"; 126 127 system("pause"); 128 return 0; 129 }

OpenCV官方教程中文版(For Python) PDF 

内容版权声明:除非注明,否则皆为本站原创文章。

转载注明出处:https://www.heiqu.com/546444a28c3e5e56e941a035d61e2712.html