-
Notifications
You must be signed in to change notification settings - Fork 44
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add BruteKNN * update docs and tests * mix format * move get_batches inside BruteKNN * raise error when k > n --------- Co-authored-by: Krsto Proroković <[email protected]>
- Loading branch information
Showing
2 changed files
with
383 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,260 @@ | ||
defmodule Scholar.Neighbors.BruteKNN do | ||
@moduledoc """ | ||
Brute-Force k-Nearest Neighbor Search Algorithm. | ||
In order to find the k-nearest neighbors the algorithm calculates | ||
the distance between the query point and each of the data samples. | ||
Therefore, its time complexity is $O(MN)$ for $N$ samples and $M$ query points. | ||
It uses $O(BN)$ memory for batch size $B$. | ||
Larger batch sizes will lead to faster predictions, but will consume more memory. | ||
""" | ||
import Nx.Defn | ||
import Scholar.Shared | ||
require Nx | ||
|
||
@derive {Nx.Container, keep: [:num_neighbors, :metric, :batch_size], containers: [:data]} | ||
defstruct [:num_neighbors, :metric, :data, :batch_size] | ||
|
||
opts = [ | ||
num_neighbors: [ | ||
required: true, | ||
type: :pos_integer, | ||
doc: "The number of nearest neighbors." | ||
], | ||
metric: [ | ||
type: {:or, [{:custom, Scholar.Options, :metric, []}, {:fun, 2}]}, | ||
default: {:minkowski, 2}, | ||
doc: ~S""" | ||
The function that measures distance between two points. Possible values: | ||
* `{:minkowski, p}` - Minkowski metric. By changing value of `p` parameter (a positive number or `:infinity`) | ||
we can set Manhattan (`1`), Euclidean (`2`), Chebyshev (`:infinity`), or any arbitrary $L_p$ metric. | ||
* `:cosine` - Cosine metric. | ||
* Anonymous function of arity 2 that takes two rank-1 tensors of same dimension and returns a scalar. | ||
""" | ||
], | ||
batch_size: [ | ||
type: :pos_integer, | ||
doc: "The number of samples in a batch." | ||
] | ||
] | ||
|
||
@opts_schema NimbleOptions.new!(opts) | ||
|
||
@doc """ | ||
Fits a brute-force k-NN model. | ||
## Options | ||
#{NimbleOptions.docs(@opts_schema)} | ||
## Examples | ||
iex> data = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]]) | ||
iex> model = Scholar.Neighbors.BruteKNN.fit(data, num_neighbors: 2) | ||
iex> model.num_neighbors | ||
2 | ||
iex> model.data | ||
#Nx.Tensor< | ||
s64[5][2] | ||
[ | ||
[1, 2], | ||
[2, 3], | ||
[3, 4], | ||
[4, 5], | ||
[5, 6] | ||
] | ||
> | ||
""" | ||
deftransform fit(data, opts) do | ||
if Nx.rank(data) != 2 do | ||
raise ArgumentError, | ||
"expected input tensor to have shape {num_samples, num_features}, | ||
got tensor with shape: #{inspect(Nx.shape(data))}" | ||
end | ||
|
||
opts = NimbleOptions.validate!(opts, @opts_schema) | ||
k = opts[:num_neighbors] | ||
|
||
if k > Nx.axis_size(data, 0) do | ||
raise ArgumentError, | ||
""" | ||
expected num_neighbors to be less than or equal to \ | ||
num_samples = #{Nx.axis_size(data, 0)}, got: #{k} | ||
""" | ||
end | ||
|
||
metric = | ||
case opts[:metric] do | ||
{:minkowski, p} -> | ||
&Scholar.Metrics.Distance.minkowski(&1, &2, p: p) | ||
|
||
:cosine -> | ||
&Scholar.Metrics.Distance.cosine/2 | ||
|
||
fun when is_function(fun, 2) -> | ||
fun | ||
end | ||
|
||
%__MODULE__{ | ||
num_neighbors: k, | ||
metric: metric, | ||
data: data, | ||
batch_size: opts[:batch_size] | ||
} | ||
end | ||
|
||
@doc """ | ||
Computes nearest neighbors of query tensor using brute-force search. | ||
Returns the neighbors indices and distances from query points. | ||
## Examples | ||
iex> data = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]]) | ||
iex> model = Scholar.Neighbors.BruteKNN.fit(data, num_neighbors: 2) | ||
iex> query = Nx.tensor([[1, 3], [4, 2], [3, 6]]) | ||
iex> {neighbors, distances} = Scholar.Neighbors.BruteKNN.predict(model, query) | ||
iex> neighbors | ||
#Nx.Tensor< | ||
u64[3][2] | ||
[ | ||
[0, 1], | ||
[1, 2], | ||
[3, 2] | ||
] | ||
> | ||
iex> distances | ||
#Nx.Tensor< | ||
f32[3][2] | ||
[ | ||
[1.0, 1.0], | ||
[2.2360680103302, 2.2360680103302], | ||
[1.4142135381698608, 2.0] | ||
] | ||
> | ||
""" | ||
deftransform predict(%__MODULE__{} = model, query) do | ||
if Nx.rank(query) != 2 do | ||
raise ArgumentError, | ||
"expected query tensor to have shape {num_queries, num_features}, | ||
got tensor with shape: #{inspect(Nx.shape(query))}" | ||
end | ||
|
||
if Nx.axis_size(model.data, 1) != Nx.axis_size(query, 1) do | ||
raise ArgumentError, | ||
""" | ||
expected query tensor to have the same dimension as tensor used for fitting the model, \ | ||
got #{inspect(Nx.axis_size(model.data, 1))} \ | ||
and #{inspect(Nx.axis_size(query, 1))} | ||
""" | ||
end | ||
|
||
predict_n(model, query) | ||
end | ||
|
||
defn predict_n(%__MODULE__{} = model, query) do | ||
k = model.num_neighbors | ||
metric = model.metric | ||
data = model.data | ||
type = Nx.Type.merge(to_float_type(data), to_float_type(query)) | ||
query_size = Nx.axis_size(query, 0) | ||
|
||
batch_size = | ||
case model.batch_size do | ||
nil -> query_size | ||
_ -> min(model.batch_size, query_size) | ||
end | ||
|
||
{batches, leftover} = get_batches(query, batch_size: batch_size) | ||
num_batches = Nx.axis_size(batches, 0) | ||
|
||
{neighbor_indices, neighbor_distances, _} = | ||
while { | ||
neighbor_indices = Nx.broadcast(Nx.u64(0), {query_size, k}), | ||
neighbor_distances = | ||
Nx.broadcast(Nx.as_type(:nan, type), {query_size, k}), | ||
{ | ||
data, | ||
batches, | ||
i = Nx.u64(0) | ||
} | ||
}, | ||
i < num_batches do | ||
batch = batches[i] | ||
|
||
{batch_indices, batch_distances} = | ||
brute_force_search(data, batch, num_neighbors: k, metric: metric) | ||
|
||
neighbor_indices = Nx.put_slice(neighbor_indices, [i * batch_size, 0], batch_indices) | ||
|
||
neighbor_distances = | ||
Nx.put_slice(neighbor_distances, [i * batch_size, 0], batch_distances) | ||
|
||
{neighbor_indices, neighbor_distances, {data, batches, i + 1}} | ||
end | ||
|
||
{neighbor_indices, neighbor_distances} = | ||
case leftover do | ||
nil -> | ||
{neighbor_indices, neighbor_distances} | ||
|
||
_ -> | ||
leftover_size = Nx.axis_size(leftover, 0) | ||
|
||
leftover = | ||
Nx.slice_along_axis(query, query_size - leftover_size, leftover_size, axis: 0) | ||
|
||
{leftover_indices, leftover_distances} = | ||
brute_force_search(data, leftover, num_neighbors: k, metric: metric) | ||
|
||
neighbor_indices = | ||
Nx.put_slice(neighbor_indices, [num_batches * batch_size, 0], leftover_indices) | ||
|
||
neighbor_distances = | ||
Nx.put_slice(neighbor_distances, [num_batches * batch_size, 0], leftover_distances) | ||
|
||
{neighbor_indices, neighbor_distances} | ||
end | ||
|
||
{neighbor_indices, neighbor_distances} | ||
end | ||
|
||
defn get_batches(tensor, opts) do | ||
{size, dim} = Nx.shape(tensor) | ||
batch_size = opts[:batch_size] | ||
num_batches = div(size, batch_size) | ||
leftover_size = rem(size, batch_size) | ||
|
||
batches = | ||
tensor | ||
|> Nx.slice_along_axis(0, num_batches * batch_size, axis: 0) | ||
|> Nx.reshape({num_batches, batch_size, dim}) | ||
|
||
leftover = | ||
if leftover_size > 0 do | ||
Nx.slice_along_axis(tensor, num_batches * batch_size, leftover_size, axis: 0) | ||
else | ||
nil | ||
end | ||
|
||
{batches, leftover} | ||
end | ||
|
||
defnp brute_force_search(data, query, opts) do | ||
k = opts[:num_neighbors] | ||
metric = opts[:metric] | ||
{m, d} = Nx.shape(data) | ||
n = Nx.axis_size(query, 0) | ||
x = query |> Nx.new_axis(1) |> Nx.broadcast({n, m, d}) |> Nx.vectorize([:query, :data]) | ||
y = data |> Nx.new_axis(0) |> Nx.broadcast({n, m, d}) |> Nx.vectorize([:query, :data]) | ||
distances = metric.(x, y) |> Nx.devectorize() |> Nx.rename(nil) | ||
|
||
neighbor_indices = | ||
Nx.argsort(distances, axis: 1, type: :u64) |> Nx.slice_along_axis(0, k, axis: 1) | ||
|
||
neighbor_distances = Nx.take_along_axis(distances, neighbor_indices, axis: 1) | ||
{neighbor_indices, neighbor_distances} | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
defmodule Scholar.Neighbors.BruteKNNTest do | ||
use ExUnit.Case, async: true | ||
alias Scholar.Neighbors.BruteKNN | ||
doctest BruteKNN | ||
|
||
defp data do | ||
Nx.tensor([ | ||
[10, 15], | ||
[46, 63], | ||
[68, 21], | ||
[40, 33], | ||
[25, 54], | ||
[15, 43], | ||
[44, 58], | ||
[45, 40], | ||
[62, 69], | ||
[53, 67] | ||
]) | ||
end | ||
|
||
defp query do | ||
Nx.tensor([ | ||
[12, 23], | ||
[55, 30], | ||
[41, 57], | ||
[64, 72], | ||
[26, 39] | ||
]) | ||
end | ||
|
||
defp result do | ||
neighbor_indices = | ||
Nx.tensor( | ||
[ | ||
[0, 5, 3], | ||
[7, 3, 2], | ||
[6, 1, 9], | ||
[8, 9, 1], | ||
[5, 4, 3] | ||
], | ||
type: :u64 | ||
) | ||
|
||
neighbor_distances = | ||
Nx.tensor([ | ||
[8.246211051940918, 20.2237491607666, 29.73213768005371], | ||
[14.142135620117188, 15.29705810546875, 15.81138801574707], | ||
[3.1622776985168457, 7.8102498054504395, 15.620499610900879], | ||
[3.605551242828369, 12.083045959472656, 20.124610900878906], | ||
[11.704699516296387, 15.033296585083008, 15.231546401977539] | ||
]) | ||
|
||
{neighbor_indices, neighbor_distances} | ||
end | ||
|
||
describe "fit" do | ||
test "default" do | ||
data = data() | ||
k = 3 | ||
model = BruteKNN.fit(data, num_neighbors: k) | ||
assert model.num_neighbors == 3 | ||
assert model.data == data | ||
assert model.batch_size == nil | ||
end | ||
|
||
test "custom metric and batch_size" do | ||
data = data() | ||
k = 3 | ||
metric = &Scholar.Metrics.Distance.minkowski/2 | ||
batch_size = 2 | ||
model = BruteKNN.fit(data, num_neighbors: k, metric: metric, batch_size: batch_size) | ||
assert model.num_neighbors == k | ||
assert model.metric == metric | ||
assert model.data == data | ||
assert model.batch_size == batch_size | ||
end | ||
end | ||
|
||
describe "predict" do | ||
test "batch_size = 1" do | ||
query = query() | ||
k = 3 | ||
model = BruteKNN.fit(data(), num_neighbors: k, batch_size: 1) | ||
{neighbors_true, distances_true} = result() | ||
{neighbors_pred, distances_pred} = BruteKNN.predict(model, query) | ||
assert neighbors_pred == neighbors_true | ||
assert distances_pred == distances_true | ||
end | ||
|
||
test "batch_size = 2" do | ||
query = query() | ||
k = 3 | ||
model = BruteKNN.fit(data(), num_neighbors: k, batch_size: 2) | ||
{neighbors_true, distances_true} = result() | ||
{neighbors_pred, distances_pred} = BruteKNN.predict(model, query) | ||
assert neighbors_pred == neighbors_true | ||
assert distances_pred == distances_true | ||
end | ||
|
||
test "batch_size = 5" do | ||
query = query() | ||
k = 3 | ||
model = BruteKNN.fit(data(), num_neighbors: k, batch_size: 5) | ||
{neighbors_true, distances_true} = result() | ||
{neighbors_pred, distances_pred} = BruteKNN.predict(model, query) | ||
assert neighbors_pred == neighbors_true | ||
assert distances_pred == distances_true | ||
end | ||
|
||
test "batch_size = 10" do | ||
query = query() | ||
k = 3 | ||
model = BruteKNN.fit(data(), num_neighbors: k, batch_size: 10) | ||
{neighbors_true, distances_true} = result() | ||
{neighbors_pred, distances_pred} = BruteKNN.predict(model, query) | ||
|
||
assert neighbors_pred == | ||
neighbors_true | ||
|
||
assert distances_pred == distances_true | ||
end | ||
end | ||
end |