%function [testPerf,rankmat,rank] = nnclassFn(train,test,trainClass,answer)
%
%Reads in training examples, test examples, class labels of the training
%examples, and correct classes of the test examples. Data are in the columns
%of train and test, and labels are column vectors.
%
%Note: You will need to create the label vectors. trainClass is a column
%vector of integers indicating the identity of the training examples,
%e.g. for faces of 3 people with two views each, trainClass = [1 1 2 2 3 3]';
%answer contains the correct labels of the test images, which enables
%us to compute the fraction correct.
%
%Gets the matrix of normalized dot products (cosines). Outputs the nearest
%neighbor classification of the test examples and the fraction correct
%(a value between 0 and 1). rankmat gives the classes of the top N matches
%for each test image, where N is set by the loop at the end of the function.
%rank is a vector containing the fraction of times the correct match is in
%the top N matches.
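%
%Example (a minimal sketch with made-up data, not from the original author;
%it assumes cosFn.m is on the path and returns an Ntest x Ntrain matrix of
%cosines between the columns of its two arguments):
%  train      = [1 0.9 0 0.1; 0 0.1 1 0.9]; %4 training examples in columns
%  trainClass = [1 1 2 2]';                 %2 classes, two views each
%  test       = [1 0; 0.2 1];               %2 test examples in columns
%  answer     = [1 2]';                     %correct classes of the test examples
%  [testPerf,rankmat,rank] = nnclassFn(train,test,trainClass,answer);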


function [testPerf,rankmat,rank] = nnclassFn(train,test,trainClass,answer)

numTest = size(test,2);
numTrain = size(train,2);

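%Get distances to training examples
%dists = eucDist(test,train); %Outputs an Ntest x Ntrain matrix of Euclidean distances
%Negate the cosines so that a smaller value means a closer match, which
%lets an ascending sort put the most similar training example first.
dists = -1 * cosFn(test,train); %Outputs an Ntest x Ntrain matrix of negated cosines
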
%Sort the rows of dists to find the nearest training example. Sorting
%along dimension 1 explicitly keeps this correct when numTrain is 1.
[Sdist,nearest] = sort(dists.',1); %cols of Sdist are distances in ascending order
%1st row of nearest is the index of the closest training example

%Create vector with nearest example, and vector with class label.
Nnbr = nearest(1,:); %First row of nearest contains the NN
%Nnbr = nearest(2,:);
testClass = trainClass(Nnbr);

%Compare as column vectors so the result does not depend on how the
%caller oriented the label vectors.
testClass = testClass(:);
answer = answer(:);
testPerf = sum(testClass == answer) / numel(answer); %fraction correct

%Get rank = fraction correct in top N:
numRanks = 3; %N, the number of top matches to keep
rankmat = zeros(numTest,numRanks);
rank = zeros(1,numRanks);
found = false(numTest,1); %true once the correct class has appeared
for i = 1:numRanks
    rankmat(:,i) = trainClass(nearest(i,:)');
    %OR with earlier ranks so a test example is counted only once, even
    %when several of its top matches share the correct class.
    found = found | (rankmat(:,i) == answer(:));
    rank(i) = sum(found) / numTest;
end

%For the FERET test, want probeID (answer), then rank, then matched ID no.,
%then FA flag, then "matching score". This will be a matrix with columns:
%  probe  rank  match                     FAflag          matching score
%  i      1     trainClass(nearest(i,:))  Sdist(:,i)>4.7  1./Sdist(:,i)
%  i      2     OR rankmat(i,:)'
%  i      3
%  i      4
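
%-------------------------------------------------------------------------
%Hypothetical sketch, not part of the original code: cosFn is called above
%but is not defined in this file. Assuming it returns the matrix of cosines
%between the columns of its two arguments (one row per column of the first
%argument), it could look like the local function below. The name
%cosFnSketch is made up so that it does not shadow the real cosFn on the
%path, and nothing above calls it.
function C = cosFnSketch(A,B)
normA = sqrt(sum(A.^2,1));        %1 x Na row of column norms
normB = sqrt(sum(B.^2,1));        %1 x Nb row of column norms
normA(normA==0) = 1;              %all-zero columns give cosine 0, not NaN
normB(normB==0) = 1;
C = (A' * B) ./ (normA' * normB); %Na x Nb matrix of normalized dot products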