[94] | 1 | %function [testPerf,rankmat,rank] = nnclassFn(train,test,trainClass,answer) |
---|
| 2 | % |
---|
| 3 | %Reads in training examples, test examples, class labels of training |
---|
| 4 | %examples, and correct class of test examples. Data are in columns of train |
---|
| 5 | %and test, and labels are column vectors. |
---|
| 6 | % |
---|
| 7 | %Note: You will need to create label vectors. TrainClass is a column |
---|
| 8 | %vector of integers indicating the identity of the training examples. |
---|
| 9 | %e.g. for faces of 3 people with two views each, TrainClass = [1 1 2 2 3 3 ]'; |
---|
| 10 | %Answer contains the correct labels of the test images, which enables |
---|
| 11 | %us to compute percent correct. |
---|
| 12 | % |
---|
| 13 | %Gets matrix of normalized dot products. Outputs nearest neighbor |
---|
| 14 | %classification of test examples and percent correct. |
---|
| 15 | %rankmat gives the top 3 matches for each test image. rank is a vector |
---|
| 16 | %containing the percent of times the correct match is in the top N matches. |
---|
| 17 | |
---|
| 18 | |
---|
function [testPerf,rankmat,rank] = nnclassFn(train,test,trainClass,answer,topN)
%NNCLASSFN Nearest-neighbour classification using negated cosine similarity.
%
%   [testPerf,rankmat,rank] = nnclassFn(train,test,trainClass,answer)
%   [testPerf,rankmat,rank] = nnclassFn(train,test,trainClass,answer,topN)
%
%   Inputs
%     train      - training examples, one example per column
%     test       - test examples, one example per column
%     trainClass - numeric class label of every training column
%                  (row or column vector; forced to a column internally)
%     answer     - correct numeric class label of every test column
%                  (row or column vector; forced to a column internally)
%     topN       - optional number of ranked matches to report.
%                  Defaults to 3 (the value previously hard-coded) and is
%                  clamped to the number of training examples.
%
%   Outputs
%     testPerf - fraction of test examples whose nearest training example
%                carries the correct label
%     rankmat  - numTest-by-topN matrix; column k holds the label of the
%                k-th closest training example for each test example
%     rank     - 1-by-topN vector; rank(k) is the fraction of test examples
%                whose correct label appears among the first k matches
%
%   NOTE(review): cosFn is defined elsewhere. It is assumed to return an
%   Ntest-by-Ntrain matrix of cosines, as the original comment states.

if nargin < 5 || isempty(topN)
    topN = 3;
end

% Force label vectors to columns so orientation can never change the
% result (a row-vector answer used to error or broadcast into a matrix).
trainClass = trainClass(:);
answer     = answer(:);

numTest  = size(test,2);
numTrain = size(train,2);

if numTrain < 1
    error('nnclassFn:noTraining', 'train must contain at least one column.');
end
if numel(trainClass) ~= numTrain
    error('nnclassFn:labelCount', ...
        'trainClass must have one label per column of train.');
end
if numel(answer) ~= numTest
    error('nnclassFn:answerCount', ...
        'answer must have one label per column of test.');
end
if ~isscalar(topN) || topN < 1 || topN ~= floor(topN)
    error('nnclassFn:topN', 'topN must be a positive integer.');
end

% Cannot rank more matches than there are training examples.
topN = min(topN, numTrain);

%Get distances to training examples
%dists = eucDist(test,train); %Outputs a Ntest x Ntrain matrix of Euc dist
% Negate the cosines so that an ascending sort puts the most similar first.
dists = -1 * cosFn(test,train);

% Sort each row explicitly along dimension 2. sort(dists') picks the first
% non-singleton dimension, which is the wrong one when numTrain == 1.
[sortedDists,nearest] = sort(dists,2); %#ok<ASGLU> % nearest is numTest x numTrain

% Labels of the topN closest training examples for every test example.
% reshape guards the vector-indexing case (numTest == 1 or topN == 1),
% where MATLAB would otherwise return the orientation of trainClass.
rankmat = reshape(trainClass(nearest(:,1:topN)), numTest, topN);

% hits(t,k) is true when the k-th match of test t has the correct label.
hits = (rankmat == repmat(answer,1,topN));

% Rank-1 accuracy. Not echoed to the console.
testPerf = sum(hits(:,1),1) / numTest;

% Cumulative match score: a test counts at rank k if ANY of its first k
% matches is correct. Summing per-position hits would count a class with
% several training views more than once and could exceed 1.
rank = sum(cumsum(hits,2) > 0, 1) / numTest;
---|
| 52 | |
---|
| 53 | %For FERET test, want probeID (answer), then rank, then matched ID no., |
---|
| 54 | %then FA flag, then "matching score". This will be a matrix with: |
---|
| 55 | %probe rank match FAflag matching score |
---|
| 56 | %i 1 trainClass(nearest(i,:)) Sdist(:,i)>4.7 1./Sdist(:,i) |
---|
| 57 | %i 2 OR rankmat(i,:)' |
---|
| 58 | %i 3 |
---|
| 59 | %i 4 |
---|
| 60 | |
---|