This package provides a utility for creating a classifier using the K-Nearest Neighbors algorithm.
This package is different from the other packages in this repository in that it doesn't provide a model with weights, but rather a utility for constructing a KNN model using activations from another model or any other tensors you can associate with a class/label.
You can see example code here.
<html>
<head>
<!-- Load TensorFlow.js -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
<!-- Load MobileNet -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet"></script>
<!-- Load KNN Classifier -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/knn-classifier"></script>
</head>
<body>
<img id='class0' src='/images/class0.jpg '/>
<img id='class1' src='/images/class1.jpg '/>
<img id='test' src='/images/test.jpg '/>
</body>
<!-- Place your code in the script tag below. You can also use an external .js file -->
<script>
const init = async function() {
// Create the classifier.
const classifier = knnClassifier.create();
// Load mobilenet.
const mobilenetModule = await mobilenet.load();
// Add MobileNet activations to the model repeatedly for all classes.
const img0 = tf.browser.fromPixels(document.getElementById('class0'));
const logits0 = mobilenetModule.infer(img0, 'conv_preds');
classifier.addExample(logits0, 0);
const img1 = tf.browser.fromPixels(document.getElementById('class1'));
const logits1 = mobilenetModule.infer(img1, 'conv_preds');
classifier.addExample(logits1, 1);
// Make a prediction.
const x = tf.browser.fromPixels(document.getElementById('test'));
const xlogits = mobilenetModule.infer(x, 'conv_preds');
console.log('Predictions:');
const result = await classifier.predictClass(xlogits);
console.log(result);
}
init();
</script>
</html>
const tf = require('@tensorflow/tfjs');
const mobilenetModule = require('@tensorflow-models/mobilenet');
const knnClassifier = require('@tensorflow-models/knn-classifier');
// Create the classifier.
const classifier = knnClassifier.create();
// Load mobilenet.
const mobilenet = await mobilenetModule.load();
// Add MobileNet activations to the model repeatedly for all classes.
const img0 = tf.browser.fromPixels(document.getElementById('class0'));
const logits0 = mobilenet.infer(img0, 'conv_preds');
classifier.addExample(logits0, 0);
const img1 = tf.browser.fromPixels(document.getElementById('class1'));
const logits1 = mobilenet.infer(img1, 'conv_preds');
classifier.addExample(logits1, 1);
// Make a prediction.
const x = tf.browser.fromPixels(document.getElementById('test'));
const xlogits = mobilenet.infer(x, 'conv_preds');
console.log('Predictions:');
console.log(classifier.predictClass(xlogits));
knnClassifier
is the module name, which is automatically included when you use
the <script src> method.
classifier = knnClassifier.create()
Returns a KNNImageClassifier
.
classifier.addExample(
example: tf.Tensor,
label: number|string
): void;
Args:
- example: An example to add to the dataset, usually an activation from another model.
- label: The label (class name) of the example.
classifier.predictClass(
input: tf.Tensor,
k = 3
): Promise<{label: string, classIndex: number, confidences: {[classId: number]: number}}>;
Args:
- input: An example to make a prediction on, usually an activation from another model.
- k: The K value to use in K-nearest neighbors. The algorithm will first find the K nearest examples from those it was previously shown, and then choose the class that appears the most as the final prediction for the input example. Defaults to 3. If examples < k, k = examples.
Returns an object where:
label
: the label (class name) with the most confidence.classIndex
: the 0-based index of the class (for backwards compatibility).confidences
: maps each label to their confidence score.
classifier.clearClass(label: number|string)
Args:
- label: The label to clear all examples for.
classifier.clearAllClasses()
classifier.getClassExampleCount(): {[label: string]: number}
Returns an object that maps label name to example count for that label.
classifier.getClassifierDataset(): {[label: string]: Tensor2D}
classifier.setClassifierDataset(dataset: {[label: string]: Tensor2D})
Args:
- dataset: The label dataset matrices map. Can be retrieved from getClassDatsetMatrices. Useful for restoring state.
classifier.getNumClasses(): number
Clears up WebGL memory. Useful if you no longer need the classifier in your application.
classifier.dispose()