Skip to content

Commit

Permalink
Brute k-NN (#257)
Browse files Browse the repository at this point in the history
* add BruteKNN

* update docs and tests

* mix format

* move get_batches inside BruteKNN

* raise error when k > n

---------

Co-authored-by: Krsto Proroković <[email protected]>
  • Loading branch information
krstopro and Krsto Proroković authored Apr 16, 2024
1 parent 0d1bcc1 commit 5c1786f
Show file tree
Hide file tree
Showing 2 changed files with 383 additions and 0 deletions.
260 changes: 260 additions & 0 deletions lib/scholar/neighbors/brute_knn.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
defmodule Scholar.Neighbors.BruteKNN do
@moduledoc """
Brute-Force k-Nearest Neighbor Search Algorithm.
In order to find the k-nearest neighbors the algorithm calculates
the distance between the query point and each of the data samples.
Therefore, its time complexity is $O(MN)$ for $N$ samples and $M$ query points.
It uses $O(BN)$ memory for batch size $B$.
Larger batch sizes will lead to faster predictions, but will consume more memory.
"""
import Nx.Defn
import Scholar.Shared
require Nx

@derive {Nx.Container, keep: [:num_neighbors, :metric, :batch_size], containers: [:data]}
defstruct [:num_neighbors, :metric, :data, :batch_size]

opts = [
num_neighbors: [
required: true,
type: :pos_integer,
doc: "The number of nearest neighbors."
],
metric: [
type: {:or, [{:custom, Scholar.Options, :metric, []}, {:fun, 2}]},
default: {:minkowski, 2},
doc: ~S"""
The function that measures distance between two points. Possible values:
* `{:minkowski, p}` - Minkowski metric. By changing value of `p` parameter (a positive number or `:infinity`)
we can set Manhattan (`1`), Euclidean (`2`), Chebyshev (`:infinity`), or any arbitrary $L_p$ metric.
* `:cosine` - Cosine metric.
* Anonymous function of arity 2 that takes two rank-1 tensors of same dimension and returns a scalar.
"""
],
batch_size: [
type: :pos_integer,
doc: "The number of samples in a batch."
]
]

@opts_schema NimbleOptions.new!(opts)

@doc """
Fits a brute-force k-NN model.
## Options
#{NimbleOptions.docs(@opts_schema)}
## Examples
iex> data = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])
iex> model = Scholar.Neighbors.BruteKNN.fit(data, num_neighbors: 2)
iex> model.num_neighbors
2
iex> model.data
#Nx.Tensor<
s64[5][2]
[
[1, 2],
[2, 3],
[3, 4],
[4, 5],
[5, 6]
]
>
"""
deftransform fit(data, opts) do
if Nx.rank(data) != 2 do
raise ArgumentError,
"expected input tensor to have shape {num_samples, num_features},
got tensor with shape: #{inspect(Nx.shape(data))}"
end

opts = NimbleOptions.validate!(opts, @opts_schema)
k = opts[:num_neighbors]

if k > Nx.axis_size(data, 0) do
raise ArgumentError,
"""
expected num_neighbors to be less than or equal to \
num_samples = #{Nx.axis_size(data, 0)}, got: #{k}
"""
end

metric =
case opts[:metric] do
{:minkowski, p} ->
&Scholar.Metrics.Distance.minkowski(&1, &2, p: p)

:cosine ->
&Scholar.Metrics.Distance.cosine/2

fun when is_function(fun, 2) ->
fun
end

%__MODULE__{
num_neighbors: k,
metric: metric,
data: data,
batch_size: opts[:batch_size]
}
end

@doc """
Computes nearest neighbors of query tensor using brute-force search.
Returns the neighbors indices and distances from query points.
## Examples
iex> data = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])
iex> model = Scholar.Neighbors.BruteKNN.fit(data, num_neighbors: 2)
iex> query = Nx.tensor([[1, 3], [4, 2], [3, 6]])
iex> {neighbors, distances} = Scholar.Neighbors.BruteKNN.predict(model, query)
iex> neighbors
#Nx.Tensor<
u64[3][2]
[
[0, 1],
[1, 2],
[3, 2]
]
>
iex> distances
#Nx.Tensor<
f32[3][2]
[
[1.0, 1.0],
[2.2360680103302, 2.2360680103302],
[1.4142135381698608, 2.0]
]
>
"""
deftransform predict(%__MODULE__{} = model, query) do
if Nx.rank(query) != 2 do
raise ArgumentError,
"expected query tensor to have shape {num_queries, num_features},
got tensor with shape: #{inspect(Nx.shape(query))}"
end

if Nx.axis_size(model.data, 1) != Nx.axis_size(query, 1) do
raise ArgumentError,
"""
expected query tensor to have the same dimension as tensor used for fitting the model, \
got #{inspect(Nx.axis_size(model.data, 1))} \
and #{inspect(Nx.axis_size(query, 1))}
"""
end

predict_n(model, query)
end

defn predict_n(%__MODULE__{} = model, query) do
k = model.num_neighbors
metric = model.metric
data = model.data
type = Nx.Type.merge(to_float_type(data), to_float_type(query))
query_size = Nx.axis_size(query, 0)

batch_size =
case model.batch_size do
nil -> query_size
_ -> min(model.batch_size, query_size)
end

{batches, leftover} = get_batches(query, batch_size: batch_size)
num_batches = Nx.axis_size(batches, 0)

{neighbor_indices, neighbor_distances, _} =
while {
neighbor_indices = Nx.broadcast(Nx.u64(0), {query_size, k}),
neighbor_distances =
Nx.broadcast(Nx.as_type(:nan, type), {query_size, k}),
{
data,
batches,
i = Nx.u64(0)
}
},
i < num_batches do
batch = batches[i]

{batch_indices, batch_distances} =
brute_force_search(data, batch, num_neighbors: k, metric: metric)

neighbor_indices = Nx.put_slice(neighbor_indices, [i * batch_size, 0], batch_indices)

neighbor_distances =
Nx.put_slice(neighbor_distances, [i * batch_size, 0], batch_distances)

{neighbor_indices, neighbor_distances, {data, batches, i + 1}}
end

{neighbor_indices, neighbor_distances} =
case leftover do
nil ->
{neighbor_indices, neighbor_distances}

_ ->
leftover_size = Nx.axis_size(leftover, 0)

leftover =
Nx.slice_along_axis(query, query_size - leftover_size, leftover_size, axis: 0)

{leftover_indices, leftover_distances} =
brute_force_search(data, leftover, num_neighbors: k, metric: metric)

neighbor_indices =
Nx.put_slice(neighbor_indices, [num_batches * batch_size, 0], leftover_indices)

neighbor_distances =
Nx.put_slice(neighbor_distances, [num_batches * batch_size, 0], leftover_distances)

{neighbor_indices, neighbor_distances}
end

{neighbor_indices, neighbor_distances}
end

defn get_batches(tensor, opts) do
{size, dim} = Nx.shape(tensor)
batch_size = opts[:batch_size]
num_batches = div(size, batch_size)
leftover_size = rem(size, batch_size)

batches =
tensor
|> Nx.slice_along_axis(0, num_batches * batch_size, axis: 0)
|> Nx.reshape({num_batches, batch_size, dim})

leftover =
if leftover_size > 0 do
Nx.slice_along_axis(tensor, num_batches * batch_size, leftover_size, axis: 0)
else
nil
end

{batches, leftover}
end

defnp brute_force_search(data, query, opts) do
k = opts[:num_neighbors]
metric = opts[:metric]
{m, d} = Nx.shape(data)
n = Nx.axis_size(query, 0)
x = query |> Nx.new_axis(1) |> Nx.broadcast({n, m, d}) |> Nx.vectorize([:query, :data])
y = data |> Nx.new_axis(0) |> Nx.broadcast({n, m, d}) |> Nx.vectorize([:query, :data])
distances = metric.(x, y) |> Nx.devectorize() |> Nx.rename(nil)

neighbor_indices =
Nx.argsort(distances, axis: 1, type: :u64) |> Nx.slice_along_axis(0, k, axis: 1)

neighbor_distances = Nx.take_along_axis(distances, neighbor_indices, axis: 1)
{neighbor_indices, neighbor_distances}
end
end
123 changes: 123 additions & 0 deletions test/scholar/neighbors/brute_knn_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
defmodule Scholar.Neighbors.BruteKNNTest do
use ExUnit.Case, async: true
alias Scholar.Neighbors.BruteKNN
doctest BruteKNN

defp data do
Nx.tensor([
[10, 15],
[46, 63],
[68, 21],
[40, 33],
[25, 54],
[15, 43],
[44, 58],
[45, 40],
[62, 69],
[53, 67]
])
end

defp query do
Nx.tensor([
[12, 23],
[55, 30],
[41, 57],
[64, 72],
[26, 39]
])
end

defp result do
neighbor_indices =
Nx.tensor(
[
[0, 5, 3],
[7, 3, 2],
[6, 1, 9],
[8, 9, 1],
[5, 4, 3]
],
type: :u64
)

neighbor_distances =
Nx.tensor([
[8.246211051940918, 20.2237491607666, 29.73213768005371],
[14.142135620117188, 15.29705810546875, 15.81138801574707],
[3.1622776985168457, 7.8102498054504395, 15.620499610900879],
[3.605551242828369, 12.083045959472656, 20.124610900878906],
[11.704699516296387, 15.033296585083008, 15.231546401977539]
])

{neighbor_indices, neighbor_distances}
end

describe "fit" do
test "default" do
data = data()
k = 3
model = BruteKNN.fit(data, num_neighbors: k)
assert model.num_neighbors == 3
assert model.data == data
assert model.batch_size == nil
end

test "custom metric and batch_size" do
data = data()
k = 3
metric = &Scholar.Metrics.Distance.minkowski/2
batch_size = 2
model = BruteKNN.fit(data, num_neighbors: k, metric: metric, batch_size: batch_size)
assert model.num_neighbors == k
assert model.metric == metric
assert model.data == data
assert model.batch_size == batch_size
end
end

describe "predict" do
test "batch_size = 1" do
query = query()
k = 3
model = BruteKNN.fit(data(), num_neighbors: k, batch_size: 1)
{neighbors_true, distances_true} = result()
{neighbors_pred, distances_pred} = BruteKNN.predict(model, query)
assert neighbors_pred == neighbors_true
assert distances_pred == distances_true
end

test "batch_size = 2" do
query = query()
k = 3
model = BruteKNN.fit(data(), num_neighbors: k, batch_size: 2)
{neighbors_true, distances_true} = result()
{neighbors_pred, distances_pred} = BruteKNN.predict(model, query)
assert neighbors_pred == neighbors_true
assert distances_pred == distances_true
end

test "batch_size = 5" do
query = query()
k = 3
model = BruteKNN.fit(data(), num_neighbors: k, batch_size: 5)
{neighbors_true, distances_true} = result()
{neighbors_pred, distances_pred} = BruteKNN.predict(model, query)
assert neighbors_pred == neighbors_true
assert distances_pred == distances_true
end

test "batch_size = 10" do
query = query()
k = 3
model = BruteKNN.fit(data(), num_neighbors: k, batch_size: 10)
{neighbors_true, distances_true} = result()
{neighbors_pred, distances_pred} = BruteKNN.predict(model, query)

assert neighbors_pred ==
neighbors_true

assert distances_pred == distances_true
end
end
end

0 comments on commit 5c1786f

Please sign in to comment.