然后就查看了opencv的文档,当传入数据是Mat 而不是cvMat时,可以利用predict的返回值(float)来判断预测是否正确。
运行结果:
1)1000个训练数据/1000个测试数据
2)2000个训练数据/2000个测试数据
3)5000个训练数据/5000个测试数据
4)10000个训练数据/10000个测试数据
5)60000个训练数据/10000个测试数据
最后,关于运行时间(在程序正确的前提下,训练时长和初始的参数设置有关),给出我最的运行结果(1000张图是11s左右,6000张是1300s ~ 2000s左右)
代码:
1 #ifndef MNIST_H 2 #define MNIST_H 3 4 #include <iostream> 5 #include <string> 6 #include <fstream> 7 #include <ctime> 8 #include <opencv2/opencv.hpp> 9 10 using namespace cv; 11 using namespace std; 12 13 //小端存储转换 14 int reverseInt(int i); 15 16 //读取image数据集信息 17 Mat read_mnist_image(const string fileName); 18 19 //读取label数据集信息 20 Mat read_mnist_label(const string fileName); 21 22 #endif
mnist.h
1 #include "mnist.h" 2 3 //计时器 4 double cost_time; 5 clock_t start_time; 6 clock_t end_time; 7 8 //测试item个数 9 int testNum = 10000; 10 11 int reverseInt(int i) { 12 unsigned char c1, c2, c3, c4; 13 14 c1 = i & 255; 15 c2 = (i >> 8) & 255; 16 c3 = (i >> 16) & 255; 17 c4 = (i >> 24) & 255; 18 19 return ((int)c1 << 24) + ((int)c2 << 16) + ((int)c3 << 8) + c4; 20 } 21 22 Mat read_mnist_image(const string fileName) { 23 int magic_number = 0; 24 int number_of_images = 0; 25 int n_rows = 0; 26 int n_cols = 0; 27 28 Mat DataMat; 29 30 ifstream file(fileName, ios::binary); 31 if (file.is_open()) 32 { 33 cout << "成功打开图像集 ... \n"; 34 35 file.read((char*)&magic_number, sizeof(magic_number)); 36 file.read((char*)&number_of_images, sizeof(number_of_images)); 37 file.read((char*)&n_rows, sizeof(n_rows)); 38 file.read((char*)&n_cols, sizeof(n_cols)); 39 //cout << magic_number << " " << number_of_images << " " << n_rows << " " << n_cols << endl; 40 41 magic_number = reverseInt(magic_number); 42 number_of_images = reverseInt(number_of_images); 43 n_rows = reverseInt(n_rows); 44 n_cols = reverseInt(n_cols); 45 cout << "MAGIC NUMBER = " << magic_number 46 << " ;NUMBER OF IMAGES = " << number_of_images 47 << " ; NUMBER OF ROWS = " << n_rows 48 << " ; NUMBER OF COLS = " << n_cols << endl; 49 50 //-test- 51 //number_of_images = testNum; 52 //输出第一张和最后一张图,检测读取数据无误 53 Mat s = Mat::zeros(n_rows, n_rows * n_cols, CV_32FC1); 54 Mat e = Mat::zeros(n_rows, n_rows * n_cols, CV_32FC1); 55 56 cout << "开始读取Image数据......\n"; 57 start_time = clock(); 58 DataMat = Mat::zeros(number_of_images, n_rows * n_cols, CV_32FC1); 59 for (int i = 0; i < number_of_images; i++) { 60 for (int j = 0; j < n_rows * n_cols; j++) { 61 unsigned char temp = 0; 62 file.read((char*)&temp, sizeof(temp)); 63 float pixel_value = float((temp + 0.0) / 255.0); 64 DataMat.at<float>(i, j) = pixel_value; 65 66 //打印第一张和最后一张图像数据 67 if (i == 0) { 68 s.at<float>(j / n_cols, j % n_cols) = pixel_value; 69 } 70 else if (i == number_of_images - 1) { 71 e.at<float>(j / n_cols, j % n_cols) = pixel_value; 72 } 73 } 74 } 75 end_time = clock(); 76 cost_time = (end_time - start_time) / CLOCKS_PER_SEC; 77 cout << "读取Image数据完毕......" << cost_time << "s\n"; 78 79 imshow("first image", s); 80 imshow("last image", e); 81 waitKey(0); 82 } 83 file.close(); 84 return DataMat; 85 } 86 87 Mat read_mnist_label(const string fileName) { 88 int magic_number; 89 int number_of_items; 90 91 Mat LabelMat; 92 93 ifstream file(fileName, ios::binary); 94 if (file.is_open()) 95 { 96 cout << "成功打开Label集 ... \n"; 97 98 file.read((char*)&magic_number, sizeof(magic_number)); 99 file.read((char*)&number_of_items, sizeof(number_of_items)); 100 magic_number = reverseInt(magic_number); 101 number_of_items = reverseInt(number_of_items); 102 103 cout << "MAGIC NUMBER = " << magic_number << " ; NUMBER OF ITEMS = " << number_of_items << endl; 104 105 //-test- 106 //number_of_items = testNum; 107 //记录第一个label和最后一个label 108 unsigned int s = 0, e = 0; 109 110 cout << "开始读取Label数据......\n"; 111 start_time = clock(); 112 LabelMat = Mat::zeros(number_of_items, 1, CV_32SC1); 113 for (int i = 0; i < number_of_items; i++) { 114 unsigned char temp = 0; 115 file.read((char*)&temp, sizeof(temp)); 116 LabelMat.at<unsigned int>(i, 0) = (unsigned int)temp; 117 118 //打印第一个和最后一个label 119 if (i == 0) s = (unsigned int)temp; 120 else if (i == number_of_items - 1) e = (unsigned int)temp; 121 } 122 end_time = clock(); 123 cost_time = (end_time - start_time) / CLOCKS_PER_SEC; 124 cout << "读取Label数据完毕......" << cost_time << "s\n"; 125 126 cout << "first label = " << s << endl; 127 cout << "last label = " << e << endl; 128 } 129 file.close(); 130 return LabelMat; 131 }
mnist.cpp