机器学习实战(笔记)------------KNN算法 (2)

假设我们有一些手写数字,以如下形式保存:

00000000000001100000000000000000 00000000000011111100000000000000 00000000000111111111000000000000 00000000011111111111000000000000 00000001111111111111100000000000 00000000111111100011110000000000 00000001111110000001110000000000 00000001111110000001110000000000 00000011111100000001110000000000 00000011111100000001111000000000 00000011111100000000011100000000 00000011111100000000011100000000 00000011111000000000001110000000 00000011111000000000001110000000 00000001111100000000000111000000 00000001111100000000000111000000 00000001111100000000000111000000 00000011111000000000000111000000 00000011111000000000000111000000 00000000111100000000000011100000 00000000111100000000000111100000 00000000111100000000000111100000 00000000111100000000001111100000 00000000011110000000000111110000 00000000011111000000001111100000 00000000011111000000011111100000 00000000011111000000111111000000 00000000011111100011111111000000 00000000000111111111111110000000 00000000000111111111111100000000 00000000000011111111110000000000 00000000000000111110000000000000

这是一个32*32的矩阵,利用0代表背景,1来代表手写数字
对于这些数据,我们也可以利用KNN算法来识别写的是0~9中的哪里数字
注:存储数据的文件,例如:0_0.txt代码数字0的第一个手写样本数据

数据预处理:转换成测试向量

  数据使用3232的矩阵形式存储,为了能够使用我们实现的KNN分类器,我们必须将其转化成1 1024的向量形式进行表示,也可以叫做降维,将二维数据转换成了一维数据

def img2vector(filename): fr=open(filename) returnVect=zeros((1,1024)) for i in range(32): linestr=fr.readline() for j in range(32): returnVect[0,i*32+j]=int(linestr[j]) return returnVect 使用KNN算法进行分类

  转换成向量以后,我们就可以使用我们实现的KNN分类器进行分类了

import operator from os import listdir import matplotlib import matplotlib.pyplot as plt from numpy import array, shape, tile, zeros def handwritingClassTest(): hwlabels=[] traingfilelist=listdir('digits/trainingDigits') m=len(traingfilelist) trainingDataMat=zeros((m,1024)) for i in range(m): filenameStr=traingfilelist[i] fileStr=filenameStr.split('.')[0] label=int(fileStr.split('_')[0]) hwlabels.append(label) trainingDataMat[i,:]=img2vector ('digits/trainingDigits/%s' % filenameStr) errorCount=0.0 testfilelist=listdir('digits/testDigits') mTest=len(testfilelist) for i in range(mTest): filenameStr=testfilelist[i] fileStr=filenameStr.split('.')[0] label=int(fileStr.split('_')[0]) testVector=img2vector('digits/testDigits/%s' %filenameStr) result=classify(testVector,trainingDataMat,hwlabels,3) print('come back with: %d,the real answer is: %d' % (int(result),label)) if(int(result)!=label): errorCount=errorCount+1.0 print('total number errors is :%f' % errorCount) print('error rate is :%f'% (errorCount/float(mTest)))

os.listdir()

利用该方法,可以得到指定目录里面的所有文件名

内容版权声明:除非注明,否则皆为本站原创文章。

转载注明出处:https://www.heiqu.com/zyygfy.html