[Draft] Create blocked Jacobi method for eigen decomposition #1510
base: main
Conversation
Draft commit to introduce the idea. Todo:
* Handle complex numbers
* Reject malformed matrices
defn eigh(matrix) do
  matrix
  |> Nx.revectorize([collapsed_axes: :auto],
    target_shape: {Nx.axis_size(matrix, -2), Nx.axis_size(matrix, -1)}
  )
  |> decompose()
  |> then(fn {w, v} ->
    revectorize_result({w, v}, matrix)
  end)
end
I believe this should be the only defn in this module and the others would be defnp. Or something close to that.
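A minimal sketch of that layout, using a hypothetical module name (BlockEighSketch) and a placeholder body, just to illustrate the suggested defn/defnp split:

defmodule BlockEighSketch do
  import Nx.Defn

  # Only the entry point is a public defn.
  defn eigh(matrix) do
    decompose(matrix)
  end

  # Every internal step (sweeps, rotations, norms) stays private as defnp.
  # Placeholder body: diagonal as stand-in eigenvalues, identity as stand-in eigenvectors.
  defnp decompose(matrix) do
    n = Nx.axis_size(matrix, -1)
    {Nx.take_diagonal(matrix), Nx.eye(n, type: :f32)}
  end
end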
end

# Initialize tensors to hold eigenvectors
v_tl = Nx.eye(mid, type: :f32)
Why force f32 here? Is this another case where the algorithm just fails on f64? Perhaps this should be masked underneath the implementation if that's the case.
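If f64 turns out to be numerically fine, one way to hide this under the implementation is to derive the type from the input rather than hard-coding it. A sketch, assuming matrix and mid are in scope (and noting that inside defn this may need to go through a deftransform helper):

# Use the input's floating-point type for the eigenvector accumulators,
# so f64 inputs keep full precision instead of being forced to :f32.
type = Nx.Type.to_floating(Nx.type(matrix))
v_tl = Nx.eye(mid, type: type)
v_br = Nx.eye(mid, type: type)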
#
# The inner loop performs n - 1 "sweep" rounds, which is enough permutations to allow
# all sub-matrices to share the needed values.
{_, _, tl, _tr, _bl, br, v_tl, v_tr, v_bl, v_br, _} =
You can use a pattern for organizing the while state that we do quite a lot:

{{tl, br, v_tl, v_tr, v_bl, v_br}, _}

where you leave the outputs in a first-position tuple and the other state in a second position, so pattern matching on the statement is easier, as well as understanding what's output and what's not.
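A hedged sketch of that reorganization, with names taken from the quoted state, an illustrative function name, and a placeholder loop body:

defnp sweep(tl, tr, bl, br, v_tl, v_tr, v_bl, v_br, n) do
  # Outputs live in the first tuple, transient state in the second,
  # so the final match only names what the caller actually needs.
  {{tl, br, v_tl, v_tr, v_bl, v_br}, _aux} =
    while {{tl, br, v_tl, v_tr, v_bl, v_br}, {tr, bl, i = 0}}, i < n - 1 do
      # ... apply one round of Jacobi rotations and permutations here ...
      {{tl, br, v_tl, v_tr, v_bl, v_br}, {tr, bl, i + 1}}
    end

  {tl, br, v_tl, v_tr, v_bl, v_br}
end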
Looking good. I'm also leaving a few stylistic suggestions for readability.
c = Nx.take_diagonal(tl)

tau = (a - c) / (2 * b)
t = Nx.sqrt(1 + Nx.pow(tau, 2))
Suggested change:
- t = Nx.sqrt(1 + Nx.pow(tau, 2))
+ t = Nx.sqrt(1 + tau ** 2)
tau = (a - c) / (2 * b)
t = Nx.sqrt(1 + Nx.pow(tau, 2))
t = Nx.select(Nx.greater_equal(tau, 0), 1 / (tau + t), 1 / (tau - t))
Suggested change:
- t = Nx.select(Nx.greater_equal(tau, 0), 1 / (tau + t), 1 / (tau - t))
+ t = Nx.select(tau >= 0, 1 / (tau + t), 1 / (tau - t))
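Worth noting for readers: inside defn the regular operators are rewritten to their Nx counterparts (** to Nx.pow/2, >= to Nx.greater_equal/2), so the suggestions above are purely cosmetic. A small self-contained illustration with hypothetical module and function names:

defmodule RotationSketch do
  import Nx.Defn

  # Computes the rotation tangent t from tau, matching the quoted code.
  defn rotation_t(tau) do
    t = Nx.sqrt(1 + tau ** 2)
    Nx.select(tau >= 0, 1 / (tau + t), 1 / (tau - t))
  end
end

For example, RotationSketch.rotation_t(Nx.tensor([0.5, -2.0])) evaluates the formula element-wise.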
t = Nx.sqrt(1 + Nx.pow(tau, 2))
t = Nx.select(Nx.greater_equal(tau, 0), 1 / (tau + t), 1 / (tau - t))

pred = Nx.less_equal(Nx.abs(b), 0.1 * 1.0e-4 * Nx.min(Nx.abs(a), Nx.abs(c)))
Suggested change:
- pred = Nx.less_equal(Nx.abs(b), 0.1 * 1.0e-4 * Nx.min(Nx.abs(a), Nx.abs(c)))
+ pred = Nx.abs(b) <= 1.0e-5 * Nx.min(Nx.abs(a), Nx.abs(c))
pred = Nx.less_equal(Nx.abs(b), 0.1 * 1.0e-4 * Nx.min(Nx.abs(a), Nx.abs(c)))
t = Nx.select(pred, 0.0, t)

c = 1.0 / Nx.sqrt(1.0 + Nx.pow(t, 2))
Suggested change:
- c = 1.0 / Nx.sqrt(1.0 + Nx.pow(t, 2))
+ c = 1.0 / Nx.sqrt(1.0 + t ** 2)
end

defn sq_norm(tl, tr, bl, br) do
  Nx.sum(Nx.pow(tl, 2) + Nx.pow(tr, 2) + Nx.pow(bl, 2) + Nx.pow(br, 2))
Suggested change:
- Nx.sum(Nx.pow(tl, 2) + Nx.pow(tr, 2) + Nx.pow(bl, 2) + Nx.pow(br, 2))
+ Nx.sum(tl ** 2 + tr ** 2 + bl ** 2 + br ** 2)
o_tl = Nx.put_diagonal(tl, diag)
o_br = Nx.put_diagonal(br, diag)

Nx.sum(Nx.pow(o_tl, 2) + Nx.pow(tr, 2) + Nx.pow(bl, 2) + Nx.pow(o_br, 2))
Suggested change:
- Nx.sum(Nx.pow(o_tl, 2) + Nx.pow(tr, 2) + Nx.pow(bl, 2) + Nx.pow(o_br, 2))
+ Nx.sum(o_tl ** 2 + tr ** 2 + bl ** 2 + o_br ** 2)
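For context, the two quoted norms typically feed a convergence check: stop sweeping once the off-diagonal mass is a negligible fraction of the total. A hedged sketch, where the module/function names and the eps handling are illustrative rather than the PR's actual API:

defmodule OffNormSketch do
  import Nx.Defn

  # Squared Frobenius norm of the full blocked matrix.
  defn sq_norm(tl, tr, bl, br) do
    Nx.sum(tl ** 2 + tr ** 2 + bl ** 2 + br ** 2)
  end

  # Same norm with the diagonals of the diagonal blocks zeroed out, i.e.
  # only the off-diagonal mass the Jacobi sweeps still have to eliminate.
  defn off_norm(tl, tr, bl, br) do
    zeros = Nx.take_diagonal(tl) * 0
    o_tl = Nx.put_diagonal(tl, zeros)
    o_br = Nx.put_diagonal(br, zeros)
    Nx.sum(o_tl ** 2 + tr ** 2 + bl ** 2 + o_br ** 2)
  end

  # Returns a boolean (u8) scalar: converged when off-diagonal mass <= eps * total.
  defn converged(tl, tr, bl, br, eps) do
    off_norm(tl, tr, bl, br) <= eps * sq_norm(tl, tr, bl, br)
  end
end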
#1027
Past a 10x10 symmetric matrix, the current implementation of eigh takes about 450ms for a 20x20 matrix and 154s for a 100x100 matrix, while the new implementation takes 0.5ms and 11ms respectively.

This is a defn version of the method used by XLA: https://github.com/openxla/xla/blob/main/xla/service/eigh_expander.cc

There is still a todo list and code cleanup/drying to do, but I wanted to pitch this before getting too far into the process. While this method has a static submatrix size with no recursion, it can be built on to recreate the recursive blocked eigh used by JAX. This approach has less complexity and seemed like a nice way to make eigh performant without having to exactly copy the JAX method.

The gist of the method is to break the matrix into four submatrices, apply the Jacobi rotations across all rows and columns each iteration, and then join the results; a minimal sketch of the 2x2 rotation step follows below.
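As mentioned above, here is a minimal, hedged sketch of the per-block 2x2 rotation, with the formulas copied from the quoted diff and illustrative module/function names (per the description, the PR applies the same formulas across all rows and columns each iteration):

defmodule JacobiRotationSketch do
  import Nx.Defn

  # For a 2x2 symmetric block [[a, b], [b, c]], compute the cosine/sine of the
  # Jacobi rotation that zeroes the off-diagonal entry b.
  defn rotation(a, b, c) do
    tau = (a - c) / (2 * b)
    t = Nx.sqrt(1 + tau ** 2)
    t = Nx.select(tau >= 0, 1 / (tau + t), 1 / (tau - t))
    # Treat b as already zero when it is negligible next to the diagonal.
    t = Nx.select(Nx.abs(b) <= 1.0e-5 * Nx.min(Nx.abs(a), Nx.abs(c)), 0.0, t)
    cos = 1.0 / Nx.sqrt(1.0 + t ** 2)
    sin = t * cos
    {cos, sin}
  end
end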
Draft commit to introduce the idea.
Todo:
* Handle complex numbers
* Reject malformed matrices
Current issues:
Please let me know if this is of any use! <3