Commit 4623bb9

Add k-nearest neighbors algorithm.

1 parent b13291d · commit 4623bb9

File tree

7 files changed: +190 −126 lines


‎README.md

Lines changed: 1 addition & 1 deletion
@@ -143,7 +143,7 @@ a set of rules that precisely define a sequence of operations.
    * `B` [Caesar Cipher](src/algorithms/cryptography/caesar-cipher) - simple substitution cipher
  * **Machine Learning**
    * `B` [NanoNeuron](https://github.com/trekhleb/nano-neuron) - 7 simple JS functions that illustrate how machines can actually learn (forward/backward propagation)
-   * `B` [KNN](src/algorithms/ML/KNN) - K Nearest Neighbors
+   * `B` [k-NN](src/algorithms/ml/knn) - k-nearest neighbors classification algorithm
  * **Uncategorized**
    * `B` [Tower of Hanoi](src/algorithms/uncategorized/hanoi-tower)
    * `B` [Square Matrix Rotation](src/algorithms/uncategorized/square-matrix-rotation) - in-place algorithm

‎src/algorithms/ML/KNN/README.md

Lines changed: 0 additions & 23 deletions
This file was deleted.

‎src/algorithms/ML/KNN/__test__/knn.test.js

Lines changed: 0 additions & 42 deletions
This file was deleted.

‎src/algorithms/ML/KNN/knn.js

Lines changed: 0 additions & 60 deletions
This file was deleted.

‎src/algorithms/ml/knn/README.md

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
# k-Nearest Neighbors Algorithm

The **k-nearest neighbors algorithm (k-NN)** is a supervised machine learning algorithm. It is a classification algorithm: it determines the class of a sample vector based on labeled sample data.

In k-NN classification, the output is a class membership. An object is classified by a plurality vote of its neighbors, with the object being assigned to the class most common among its `k` nearest neighbors (`k` is a positive integer, typically small). If `k = 1`, then the object is simply assigned to the class of that single nearest neighbor.

The idea is to calculate the similarity between two data points based on a distance metric. [Euclidean distance](https://en.wikipedia.org/wiki/Euclidean_distance) is the metric used most often for this task.

![Euclidean distance between two points](https://upload.wikimedia.org/wikipedia/commons/5/55/Euclidean_distance_2d.svg)

_Image source: [Wikipedia](https://en.wikipedia.org/wiki/Euclidean_distance)_

The algorithm is as follows:

1. Check for errors such as invalid data or labels.
2. Calculate the Euclidean distance from the point to classify to every point in the training data.
3. Sort the distances, along with their classes, in ascending order.
4. Take the first `k` entries and find the mode of their classes, i.e. the most common class among them.
5. Report that class as the result.

Here is a visualization of k-NN classification for better understanding:

![KNN Visualization 1](https://upload.wikimedia.org/wikipedia/commons/e/e7/KnnClassification.svg)

_Image source: [Wikipedia](https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm)_

The test sample (green dot) should be classified either as a blue square or as a red triangle. If `k = 3` (solid-line circle) it is assigned to the red triangles, because there are `2` triangles and only `1` square inside the inner circle. If `k = 5` (dashed-line circle) it is assigned to the blue squares (`3` squares vs. `2` triangles inside the outer circle).

Another k-NN classification example:

![KNN Visualization 2](https://media.geeksforgeeks.org/wp-content/uploads/graph2-2.png)

_Image source: [GeeksForGeeks](https://media.geeksforgeeks.org/wp-content/uploads/graph2-2.png)_

Here, as we can see, unknown points are classified by their proximity to already-labeled points.

Note that odd values of `k` are preferred, in order to avoid ties. Usually `k` is taken as `3` or `5`.

## References

- [k-nearest neighbors algorithm on Wikipedia](https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm)
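The steps listed in the new README can be sketched as a compact standalone function. This is a minimal illustration only, independent of the `kNN.js` file added in this commit; the sample points and labels below are made up to reproduce the k = 3 vs. k = 5 flip from the visualization (label `2` plays the role of the triangles, label `1` the squares):

```javascript
// Minimal k-NN sketch following the README steps: compute Euclidean
// distances, sort them ascending, take the first k, and vote by majority.
function classify(dataSet, labels, toClassify, k = 3) {
  if (!dataSet || !labels || !toClassify) throw new Error('Missing arguments');

  // Step 2: distance from toClassify to every training point.
  const distances = dataSet.map((point, i) => ({
    dist: Math.sqrt(point.reduce((sum, x, d) => sum + (x - toClassify[d]) ** 2, 0)),
    label: labels[i],
  }));

  // Step 3: sort ascending by distance.
  distances.sort((a, b) => a.dist - b.dist);

  // Steps 4-5: majority vote among the first k entries.
  const votes = {};
  let winner = distances[0].label;
  distances.slice(0, k).forEach(({ label }) => {
    votes[label] = (votes[label] || 0) + 1;
    if (votes[label] > votes[winner]) winner = label;
  });
  return winner;
}

// Two points of class 2 sit closest to the query, so k = 3 picks class 2,
// but class 1 dominates once k grows to 5.
const dataSet = [[1, 1], [1, 2], [3, 3], [3, 4], [4, 3]];
const labels = [2, 2, 1, 1, 1];
console.log(classify(dataSet, labels, [1.5, 1.5], 3)); // → 2
console.log(classify(dataSet, labels, [1.5, 1.5], 5)); // → 1
```

Increasing `k` smooths the decision boundary at the cost of sensitivity to nearby points, which is exactly the trade-off the two circles in the figure above depict.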
‎src/algorithms/ml/knn/__test__/kNN.test.js

Lines changed: 71 additions & 0 deletions

@@ -0,0 +1,71 @@
import kNN from '../kNN';

describe('kNN', () => {
  it('should throw an error on invalid data', () => {
    expect(() => {
      kNN();
    }).toThrowError('Either dataSet or labels or toClassify were not set');
  });

  it('should throw an error on invalid labels', () => {
    const noLabels = () => {
      kNN([[1, 1]]);
    };
    expect(noLabels).toThrowError('Either dataSet or labels or toClassify were not set');
  });

  it('should throw an error on not giving classification vector', () => {
    const noClassification = () => {
      kNN([[1, 1]], [1]);
    };
    expect(noClassification).toThrowError('Either dataSet or labels or toClassify were not set');
  });

  it('should throw an error on inconsistent vector lengths', () => {
    const inconsistent = () => {
      kNN([[1, 1]], [1], [1]);
    };
    expect(inconsistent).toThrowError('Inconsistent vector lengths');
  });

  it('should find the nearest neighbour', () => {
    let dataSet;
    let labels;
    let toClassify;
    let expectedClass;

    dataSet = [[1, 1], [2, 2]];
    labels = [1, 2];
    toClassify = [1, 1];
    expectedClass = 1;
    expect(kNN(dataSet, labels, toClassify)).toBe(expectedClass);

    dataSet = [[1, 1], [6, 2], [3, 3], [4, 5], [9, 2], [2, 4], [8, 7]];
    labels = [1, 2, 1, 2, 1, 2, 1];
    toClassify = [1.25, 1.25];
    expectedClass = 1;
    expect(kNN(dataSet, labels, toClassify)).toBe(expectedClass);

    dataSet = [[1, 1], [6, 2], [3, 3], [4, 5], [9, 2], [2, 4], [8, 7]];
    labels = [1, 2, 1, 2, 1, 2, 1];
    toClassify = [1.25, 1.25];
    expectedClass = 2;
    expect(kNN(dataSet, labels, toClassify, 5)).toBe(expectedClass);
  });

  it('should find the nearest neighbour with equal distances', () => {
    const dataSet = [[0, 0], [1, 1], [0, 2]];
    const labels = [1, 3, 3];
    const toClassify = [0, 1];
    const expectedClass = 3;
    expect(kNN(dataSet, labels, toClassify)).toBe(expectedClass);
  });

  it('should find the nearest neighbour in 3D space', () => {
    const dataSet = [[0, 0, 0], [0, 1, 1], [0, 0, 2]];
    const labels = [1, 3, 3];
    const toClassify = [0, 0, 1];
    const expectedClass = 3;
    expect(kNN(dataSet, labels, toClassify)).toBe(expectedClass);
  });
});

‎src/algorithms/ml/knn/kNN.js

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
/**
 * Calculates the Euclidean distance between two vectors.
 *
 * @param {number[]} x1
 * @param {number[]} x2
 * @returns {number}
 */
function euclideanDistance(x1, x2) {
  // Checking for errors.
  if (x1.length !== x2.length) {
    throw new Error('Inconsistent vector lengths');
  }
  // Calculate the Euclidean distance between the two vectors and return it.
  let squaresTotal = 0;
  for (let i = 0; i < x1.length; i += 1) {
    squaresTotal += (x1[i] - x2[i]) ** 2;
  }
  return Number(Math.sqrt(squaresTotal).toFixed(2));
}

/**
 * Classifies a point in space based on the k-nearest neighbors algorithm.
 *
 * @param {number[][]} dataSet - array of data points, i.e. [[0, 1], [3, 4], [5, 7]]
 * @param {number[]} labels - array of classes (labels), i.e. [1, 1, 2]
 * @param {number[]} toClassify - the point in space that needs to be classified, i.e. [5, 4]
 * @param {number} k - number of nearest neighbors which will be taken into account (preferably odd)
 * @return {number} - the class of the point
 */
export default function kNN(
  dataSet,
  labels,
  toClassify,
  k = 3,
) {
  if (!dataSet || !labels || !toClassify) {
    throw new Error('Either dataSet or labels or toClassify were not set');
  }

  // Calculate the distance from toClassify to each point in dataSet.
  // Store each distance together with the point's label in the distances list.
  const distances = [];
  for (let i = 0; i < dataSet.length; i += 1) {
    distances.push({
      dist: euclideanDistance(dataSet[i], toClassify),
      label: labels[i],
    });
  }

  // Sort the distances list (from the closest point to the furthest ones)
  // and take the first k entries.
  const kNearest = distances.sort((a, b) => {
    if (a.dist === b.dist) {
      return 0;
    }
    return a.dist < b.dist ? -1 : 1;
  }).slice(0, k);

  // Count the number of instances of each class among the top k members.
  const labelsCounter = {};
  let topClass = 0;
  let topClassCount = 0;
  for (let i = 0; i < kNearest.length; i += 1) {
    if (kNearest[i].label in labelsCounter) {
      labelsCounter[kNearest[i].label] += 1;
    } else {
      labelsCounter[kNearest[i].label] = 1;
    }
    if (labelsCounter[kNearest[i].label] > topClassCount) {
      topClassCount = labelsCounter[kNearest[i].label];
      topClass = kNearest[i].label;
    }
  }

  // Return the class with the highest count.
  return topClass;
}
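One detail worth noting in the committed helper: distances are rounded to two decimal places via `toFixed(2)`, so points whose true distances differ only beyond the second decimal compare as exactly equidistant. A standalone copy of the helper (reproduced here so the snippet runs on its own) illustrates the effect:

```javascript
// Standalone copy of the distance helper from this commit,
// which rounds the result to two decimal places.
function euclideanDistance(x1, x2) {
  if (x1.length !== x2.length) {
    throw new Error('Inconsistent vector lengths');
  }
  let squaresTotal = 0;
  for (let i = 0; i < x1.length; i += 1) {
    squaresTotal += (x1[i] - x2[i]) ** 2;
  }
  return Number(Math.sqrt(squaresTotal).toFixed(2));
}

console.log(euclideanDistance([0, 0], [3, 4]));     // → 5
console.log(euclideanDistance([0, 0], [1, 1]));     // → 1.41 (√2, rounded)
console.log(euclideanDistance([0, 0], [1, 1.001])); // → 1.41 as well: after rounding,
                                                    //   the two points look equidistant
```

Because ties are resolved by whichever label is counted first, this rounding can occasionally change which neighbor wins the vote near the decision boundary.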

0 commit comments
