Continuing from the previous post.
3. Classifying a single sample
'''
function: classify the input sample by voting among its k nearest neighbors
input:
    1. the input feature vector
    2. the feature matrix
    3. the label list
    4. the value of k
return: the result label
'''
def ClassifySampleByKNN(featureVectorIn, featureMatrix, labelList, kValue):
    # calculate the distance between the input feature vector and each row of the feature matrix
    disValArray = CalcEucDistance(featureVectorIn, featureMatrix)
    # get the indices that sort the distances in ascending order
    theIndexListOfSortedDist = disValArray.argsort()
    # take the first k indices and vote for the label
    labelAndCount = {}
    for i in range(kValue):
        theLabelIndex = theIndexListOfSortedDist[i]
        theLabel = labelList[theLabelIndex]
        labelAndCount[theLabel] = labelAndCount.get(theLabel, 0) + 1
    # rank the (label, count) pairs by count, highest first
    sortedLabelAndCount = sorted(labelAndCount.items(), key=lambda x: x[1], reverse=True)
    return sortedLabelAndCount[0][0]

A rather distinctive line is this one:
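    # sort and return the index
    theIndexListOfSortedDist = disValArray.argsort()

disValArray is a one-dimensional NumPy array that stores only the Euclidean distance values. argsort does not reorder the array itself; it returns the original indices in the order that would sort the values ascending, which is very convenient here. The other notable line is the call to sorted, which ranks the dictionary's entries by value using a lambda expression, a functional-programming idiom. The operator module (operator.itemgetter) can achieve the same result.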
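To make these two idioms concrete, here is a small self-contained sketch. The distances and labels are made-up toy data, and `vote_by_itemgetter` is a hypothetical helper name that does not appear in the original script. It only illustrates what `argsort` returns and how `operator.itemgetter(1)` could stand in for the lambda.

```python
import operator

import numpy as np

# Toy distances from one query sample to five training samples (made-up values).
dis_val_array = np.array([0.9, 0.1, 0.5, 0.3, 0.7])
label_list = ['A', 'B', 'A', 'B', 'A']

# argsort leaves the array untouched and returns the indices that would
# sort it in ascending order, so index 1 (distance 0.1) comes first.
sorted_index = dis_val_array.argsort()


def vote_by_itemgetter(sorted_index, label_list, k):
    """Count the labels of the k nearest samples and return the most frequent one."""
    label_and_count = {}
    for idx in sorted_index[:k]:
        label = label_list[idx]
        label_and_count[label] = label_and_count.get(label, 0) + 1
    # operator.itemgetter(1) picks the count out of each (label, count) pair,
    # which is the same job `lambda x: x[1]` does in the original code.
    ranked = sorted(label_and_count.items(), key=operator.itemgetter(1), reverse=True)
    return ranked[0][0]


nearest_label = vote_by_itemgetter(sorted_index, label_list, 3)
```

With k = 3 the three nearest samples carry the labels B, B and A, so the vote goes to B. If two labels tie, this sketch makes no promise about which one wins.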
4. Classifying a test sample file and computing the error rate
'''
function: classify the samples in the test file with the KNN algorithm
input:
    1. the name of the training sample file
    2. the name of the testing sample file
    3. the K value for KNN
    4. the name of the log file
'''
def ClassifySampleFileByKNN(sampleFileNameForTrain, sampleFileNameForTest, kValue, logFileName):
    # the with statement closes the log file even if an error occurs
    with open(logFileName, 'w') as logFile:
        # load the feature matrices and normalize them
        feaMatTrain, labelListTrain = LoadFeatureMatrixAndLabels(sampleFileNameForTrain)
        norFeaMatTrain = AutoNormalizeFeatureMatrix(feaMatTrain)
        feaMatTest, labelListTest = LoadFeatureMatrixAndLabels(sampleFileNameForTest)
        norFeaMatTest = AutoNormalizeFeatureMatrix(feaMatTest)
        # classify each test sample and write the result to the log
        errorNumber = 0.0
        testSampleNum = norFeaMatTest.shape[0]
        for i in range(testSampleNum):
            label = ClassifySampleByKNN(norFeaMatTest[i, :], norFeaMatTrain, labelListTrain, kValue)
            if label == labelListTest[i]:
                logFile.write("%d:right\n" % i)
            else:
                logFile.write("%d:wrong\n" % i)
                errorNumber += 1
        errorRate = errorNumber / testSampleNum
        logFile.write("the error rate: %f" % errorRate)
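The error-rate bookkeeping in the function above can also be checked in isolation. The following is only a sketch on made-up labels, not the author's code; `compute_error_rate` is a hypothetical helper introduced for illustration.

```python
def compute_error_rate(predicted_labels, true_labels):
    """Return the fraction of predictions that differ from the true labels."""
    if len(predicted_labels) != len(true_labels):
        raise ValueError("label lists must have the same length")
    if not true_labels:
        raise ValueError("cannot compute an error rate on zero samples")
    error_number = sum(1 for p, t in zip(predicted_labels, true_labels) if p != t)
    # float() forces true division on Python 2 as well as Python 3.
    return float(error_number) / len(true_labels)


# Made-up predictions for four test samples; the last one is wrong.
error_rate = compute_error_rate([1, 2, 3, 1], [1, 2, 3, 2])
```

Here one of four predictions is wrong, so the rate is 0.25. The original function gets the same effect by starting errorNumber at 0.0, so that the division is a floating-point one.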
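5. The entry point
This plays the role of the main function in C/C++. When kNN.py is run directly as a script, the function definitions above are executed first and then the block below runs; when the file is imported as a module, the block is skipped: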
if __name__ == '__main__':
    print("You are running KNN.py")
    ClassifySampleFileByKNN('datingSetOne.txt', 'datingSetTwo.txt', 3, 'log.txt')
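As a minimal illustration of how this guard behaves, the sketch below wraps the message in a hypothetical `main` function (the original script has no such function) so that the script and import cases are easy to tell apart.

```python
def main():
    """Entry point; returns the message so the behaviour is easy to check."""
    return "You are running KNN.py"


# __name__ equals '__main__' only when the file is executed as a script.
# When the file is imported as a module, __name__ is the module's name
# and the block below is skipped.
if __name__ == '__main__':
    print(main())
```

Keeping the work inside a function like this means another module can import the file and call `main()` explicitly without triggering it on import.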
To be continued.
[Playing with Machine Learning in Python] KNN * Code * Part 2
Original article: http://blog.csdn.net/xceman1997/article/details/44994215