Slow contractions when scalars are present and output is large #189
Comments
Just to explain what is happening: scalars are a bit of an edge case, in that they look like 'disconnected' subgraphs (i.e. a node connected to nothing else), and so they are often processed separately because they look like outer products. Since the challenge is usually ordering non-outer contractions, and these are (usually) the most expensive operations,
Notionally the optimal way to handle them is trivial: just multiply all scalars together and then multiply this into whichever input, intermediate or output tensor is smallest. Probably it will not be an intermediate, so a suggestion (other than to use the
Some side notes on benchmarking:
If you do that, you can see that the different methods and times fall into two cases, those that do the multiplication first (n.b. at this v small sized
import numpy as np
import time
import opt_einsum as oe
na = nc = 10000
nb = 50
n_iter = 10
A = np.random.random((na,nb))
B = np.random.random((nb,nc))
t_total = 0.
expr = oe.contract_expression(',ij,jk->ik', (), A.shape, B.shape, optimize='dp')
for i in range(n_iter):
    start = time.time()
    C = expr(0.5, A, B)
    end = time.time()
    t_total += end - start
print('AB->C scalar-dp',(t_total)/n_iter)
print(expr)
print()
t_total = 0.
expr = oe.contract_expression(',ij,jk->ik', (), A.shape, B.shape, optimize='optimal')
for i in range(n_iter):
    start = time.time()
    C = expr(0.5, A, B)
    end = time.time()
    t_total += end - start
print('AB->C scalar-optimal',(t_total)/n_iter)
print(expr)
print()
t_total = 0.
for i in range(n_iter):
    start = time.time()
    A @ B
    end = time.time()
    t_total += end - start
print('gemm',(t_total)/n_iter)
t_total = 0.
for i in range(n_iter):
    start = time.time()
    np.multiply(C, 0.5, out = C)
    end = time.time()
    t_total += end - start
print('C->0.5*C out',(t_total)/n_iter)
t_total = 0.
for i in range(n_iter):
    start = time.time()
    np.einsum(',ij', 0.5, A)
    end = time.time()
    t_total += end - start
print('einsum multiply small', (t_total) / n_iter)
t_total = 0.
for i in range(n_iter):
    start = time.time()
    np.einsum(',ij', 0.5, C)
    end = time.time()
    t_total += end - start
print('einsum multiply large', (t_total) / n_iter)
Part of the problem is also just that
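The idea described in this comment, multiplying all scalars together and folding the result into the smallest tensor, can be sketched roughly as follows. This is only an illustration under stated assumptions: `fold_scalars` is a hypothetical helper (not part of opt_einsum), and it only considers input tensors, not intermediates or the output.

```python
import numpy as np


def fold_scalars(scalars, tensors):
    """Hypothetical helper: combine all scalars, then scale the smallest input once."""
    # np.prod of an empty list is 1.0, so "no scalars" leaves the data unchanged.
    factor = float(np.prod(scalars))
    # Pick the input with the fewest elements, so the scaling pass is cheapest.
    smallest = min(range(len(tensors)), key=lambda i: tensors[i].size)
    scaled = list(tensors)
    scaled[smallest] = factor * tensors[smallest]
    return scaled


rng = np.random.default_rng(0)
A = rng.random((200, 5))   # 1000 elements, the smaller input
B = rng.random((5, 300))   # 1500 elements

# Two scalars (0.5 and 4.0) are combined into a single factor of 2.0.
A_s, B_s = fold_scalars([0.5, 4.0], [A, B])
C_folded = A_s @ B_s
C_reference = 2.0 * (A @ B)
print(np.allclose(C_folded, C_reference))
```

Here the scaling pass touches only the 1000 elements of `A` rather than the 60000 elements of the output.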
Thank you so much! I reorganized my test code to use larger dimensions and to exclude path-finding time. The timing is similar to that in the top post in this thread.
output
I also found the thing can be traced to
Maybe that is somehow related to the
I will close this issue shortly.
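For readers who want to reproduce the "exclude path-finding time" setup with NumPy alone, a minimal sketch is shown below. It is not the code used in this thread: the shapes, the three-matrix chain and the variable names are illustrative assumptions, and it uses `np.einsum_path` to do the path search once before timing the contraction on its own.

```python
import time
import numpy as np

rng = np.random.default_rng(0)
A = rng.random((60, 8))
B = rng.random((8, 70))
D = rng.random((70, 5))

# Step 1: search for the contraction path once and time it separately.
start = time.perf_counter()
path, description = np.einsum_path('ij,jk,kl->il', A, B, D, optimize='greedy')
t_path = time.perf_counter() - start

# Step 2: time only the contraction, reusing the precomputed path.
n_iter = 10
start = time.perf_counter()
for _ in range(n_iter):
    result = np.einsum('ij,jk,kl->il', A, B, D, optimize=path)
t_contract = (time.perf_counter() - start) / n_iter

print('path search:', t_path)
print('contraction:', t_contract)
```

In opt_einsum, `oe.contract_expression` (as used in the benchmark earlier in this thread) plays the same role: the path is found when the expression is built, so the timed calls only execute the contraction.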
Yes, to clarify,
I think it might be fine to leave this issue open. Maybe a more specific title like "slow contractions when scalars are present and output is large" would be helpful though.
The question in this issue is: suppose I have a contraction involving a scalar, \sum_b 0.5 A[a,b] B[b,c] = C[a,c]. In principle, the timing should be similar to that of \sum_b A[a,b] B[b,c] = C[a,c] (by changing the alpha coefficient in DGEMM). I do see similar timing, but only with the auto/optimal path search. None of the other algorithms achieves it in the following example. In complicated contractions, the default path search may not be sufficient, so I would like to ask whether there is any approach that can find the good contraction when a scalar is present without resorting to optimize = 'optimal'.
I may look for a workaround, namely applying np.multiply on top of \sum_b A[a,b] B[b,c] = C[a,c], since the cost will be O(N^2) rather than O(N^3). In contractions where the intermediate dimension is much smaller than the other dimensions, the wall time of np.multiply is comparable (~50%) with the contraction time.
The output is
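As a rough sketch of the np.multiply workaround described in this post (not the original benchmark and not its output), the snippet below contracts first and then scales the result in place, and compares that with scaling an input before the contraction. The sizes and variable names are illustrative assumptions, and the operation counts are simple estimates rather than measurements.

```python
import numpy as np

rng = np.random.default_rng(0)
na, nb, nc = 400, 5, 300   # small intermediate dimension nb, as in the issue
A = rng.random((na, nb))
B = rng.random((nb, nc))

# Workaround: contract first, then scale the output in place.
C_post = A @ B
np.multiply(C_post, 0.5, out=C_post)   # one pass over na * nc elements

# Alternative: scale the smaller input first, then contract.
C_pre = A @ (0.5 * B)                  # scaling pass over only nb * nc elements

print(np.allclose(C_post, C_pre))

# Rough operation counts (estimates, not timings).
gemm_flops = 2 * na * nb * nc          # multiply-adds in the contraction
scale_output_ops = na * nc             # multiplications when scaling the output
scale_input_ops = nb * nc              # multiplications when scaling B instead
print(gemm_flops, scale_output_ops, scale_input_ops)
```

With a small `nb`, scaling the output is a far larger pass over memory than scaling an input, which is consistent with the observation above that np.multiply on the output takes a noticeable fraction of the contraction time.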