Manual and automatic slicing #95
This goes a bit with automatic looping over GEMM-like products such as … Valid slice inputs could be: …
The FLOP cost shouldn't increase, I believe?
So in the extreme case, you just sum over all the indices, which ends up being exactly the same as pure einsum. For this case the largest intermediate is now the largest input, whilst the FLOP cost has probably increased dramatically. Slicing just a few indices moves you in that direction just a bit: assuming you already have the lowest-memory contraction path, it allows you to save a bit more memory, but in general doing … Here's an illustration:

```python
import numpy as np
import opt_einsum as oe

# set up a contraction
eq = 'ab,bc,ca->'
x, y, z = (np.random.randn(100, 100) for _ in 'xyz')
path = [(0, 1), (0, 1)]
oe.contract(eq, x, y, z, optimize=path)
# array(140.59548574)

# work out the normal cost
info = oe.contract_path(eq, x, y, z, optimize=path)[1]
print(info.opt_cost)
# 2020000
```

Now imagine we want to slice over `'b'`; ultimately that would look like this:

```python
sum([
    oe.contract(eq, x[:, [i]], y[[i], :], z, optimize=path)
    for i in range(100)
])
# 140.595485736799
```

Each of these is now completely independent (and could be done in parallel), and you can imagine that for many contractions the peak memory will be lower. The cost of a single one of these is now:

```python
info_slice = oe.contract_path(eq, x[:, [0]], y[[0], :], z, optimize=path)[1]
print(info_slice.opt_cost)
# 40000
```

But since we need to do 100 of these, our FLOP cost has actually risen (here by almost a factor of two):

```python
(100 * 40000) / 2020000
# 1.9801980198019802
```

Finding the indices which balance the memory reduction / parallelization against the total FLOPs increase is the name of the game! But you can imagine that reducing a contraction to fit on a GPU, or parallelizing it over a cluster, can really outweigh a moderate FLOPs increase. Moreover, sometimes a contraction is too big to fit in memory and you really have no other option.
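The manual loop over `'b'` above generalizes to any set of summed indices. Below is a minimal sketch using plain NumPy, with `np.einsum` standing in for `oe.contract`; the helper names `gen_sliced_arrays` and `sliced_contract` are hypothetical, and the sketch assumes an explicit `->` in the equation and only slices indices that are absent from the output. Each sliced axis is kept with size 1 (just as `x[:, [i]]` does) so the same equation applies unchanged to every slice.

```python
import itertools

import numpy as np


def gen_sliced_arrays(eq, arrays, sliced_inds):
    """Yield one list of sliced operands per combination of sliced index values."""
    inputs, output = eq.split('->')
    terms = inputs.split(',')
    if any(ix in output for ix in sliced_inds):
        raise ValueError("only indices that are summed over can be sliced")
    # map every index to its dimension size
    sizes = {ix: d for term, arr in zip(terms, arrays) for ix, d in zip(term, arr.shape)}
    for values in itertools.product(*(range(sizes[ix]) for ix in sliced_inds)):
        fixed = dict(zip(sliced_inds, values))
        # keep each sliced axis with size 1 so `eq` stays valid unchanged
        yield [
            arr[tuple(slice(fixed[ix], fixed[ix] + 1) if ix in fixed else slice(None) for ix in term)]
            for term, arr in zip(terms, arrays)
        ]


def sliced_contract(eq, arrays, sliced_inds):
    """Sum the independent contractions of every slice."""
    return sum(np.einsum(eq, *ops) for ops in gen_sliced_arrays(eq, arrays, sliced_inds))


eq = 'ab,bc,ca->'
x, y, z = (np.random.randn(100, 100) for _ in 'xyz')
assert np.allclose(sliced_contract(eq, [x, y, z], ['b']), np.einsum(eq, x, y, z))
```

Because every yielded operand list is independent, the generator could equally feed a thread pool or a distributed scheduler instead of the serial `sum`.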
This goes back a bit to …: FLOP cost is not the only metric, especially when it comes to GPUs. For that, it seems a weighted metric would help (outer product: 2x cost, Hadamard: 1x cost, GEMM: 0.75x cost, TPU-op: 0.1x cost). The slicing idea seems good to me; automatic construction of Python loops should be fairly straightforward. The depth here is pretty low overall, I think. Do you foresee any issues?
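A weighted metric along these lines could be prototyped by classifying each pairwise contraction and scaling its loop count. This is only a sketch under stated assumptions: the weights are the illustrative ones from the comment above (the TPU-op weight is left out because it would need backend-specific detection), the classification is crude (no shared index: outer product; shared indices all kept: Hadamard-like; any shared index summed: GEMM-like), and `classify`, `loop_count` and `weighted_cost` are hypothetical helpers, not part of opt_einsum.

```python
from math import prod

# illustrative weights from the discussion; TPU-specific ops are not modelled
WEIGHTS = {'outer': 2.0, 'hadamard': 1.0, 'gemm': 0.75}


def classify(lhs, rhs, out):
    """Crudely classify one pairwise contraction from its index strings."""
    shared = set(lhs) & set(rhs)
    if not shared:
        return 'outer'
    if shared <= set(out):
        return 'hadamard'  # every shared index is kept, nothing is summed
    return 'gemm'  # at least one shared index is summed over


def loop_count(lhs, rhs, sizes):
    """One loop iteration per combination of all indices involved."""
    return prod(sizes[ix] for ix in set(lhs) | set(rhs))


def weighted_cost(steps, sizes):
    """Sum weighted loop counts over (lhs, rhs, out) pairwise steps."""
    return sum(
        WEIGHTS[classify(lhs, rhs, out)] * loop_count(lhs, rhs, sizes)
        for lhs, rhs, out in steps
    )


sizes = dict.fromkeys('abc', 100)
# the two pairwise steps of 'ab,bc,ca->' with path [(0, 1), (0, 1)]
steps = [('ab', 'bc', 'ac'), ('ac', 'ca', '')]
print(weighted_cost(steps, sizes))  # 757500.0
```

Under these weights the same path is cheaper than its raw loop count because both steps are GEMM-like; a path forced through an outer product would be penalized instead.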
I've been making some nice progress on this. There are quite a lot of considerations, so I think it makes sense to focus on an explicit interface first. Thought I'd post an example of the kind of design I am considering, for future reference and in case anyone has feedback!

```python
import opt_einsum as oe
from opt_einsum import SliceFinder, SlicedContraction

# there'll be sensible defaults for all these options, which guide how to heuristically
# choose the indices to slice over (in general an NP-hard problem!)
sf = SliceFinder(path_info, parallel=True, power_factor=1.5, temperature=0.5)
inds_to_slice = sf.search(max_repeats=128, target_memory=2**27)

# automatic index finding separated out from performing the sliced contraction
sc = SlicedContraction(eq, arrays, inds_to_slice, optimize=path_info.path)
print(sc.info)
# {
#     'num_slices': 512,
#     'sliced_cost': 94129464192,  # can be compared to path_info.opt_cost
#     'largest_intermediate': 39741827364,  # etc.
# }

# simply perform all contractions, now using less memory for e.g. GPU.
# this will internally generate a ContractionExpression and specify constant tensors etc
x = sc.contract(backend='jax')

# ... or explicitly perform contractions independently (could be distributed),
# where `pool` is e.g. a concurrent.futures executor
futures = [
    pool.submit(oe.contract, sc.eq, *sliced_arrays, optimize=sc.path)
    for sliced_arrays in sc.gen_sliced_arrays()
]
x = sum(f.result() for f in futures)
```

Ultimately it might be nice (and I think straightforward) to have an option like …

Possible issues: …
Cool, this looks like it is coming along well! I wonder if it would be worth having deeper Dask integration to automatically batch this out. The interface prototype seems good; it would be good to see the algorithm, as that seems like the real trick. Sensible defaults for …
Cool, if this seems like a decent interface I'll start plugging in the actual algorithm.
Yep,
One 'sensible' default setting is just to keep on slicing indices until the total FLOP cost rises above, say, 110% of the original contraction, …
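That stopping rule can be sketched as a greedy search under a deliberately simple cost model. Assumptions: the path is a pairwise path in opt_einsum's `(i, j)` list format (the pair is removed and the result appended), cost is counted as loop iterations without opt_einsum's constant factors, and memory is measured as the largest intermediate. At each step the sketch slices whichever remaining summed index most shrinks the largest intermediate, as long as the total cost (per-slice cost times number of slices) stays within `max_overhead` of the unsliced cost. `path_stats` and `greedy_slice` are hypothetical names; this is not the algorithm used by opt_einsum or any other library.

```python
from math import prod


def path_stats(terms, output, path, sizes):
    """Return (loop count, largest intermediate size) for a pairwise path."""
    terms = [set(t) for t in terms]
    cost, largest = 0, 0
    for i, j in path:
        # pop the higher position first so the lower one stays valid
        a, b = terms.pop(max(i, j)), terms.pop(min(i, j))
        keep = set(output).union(*terms)
        new = (a | b) & keep
        cost += prod(sizes[ix] for ix in a | b)
        largest = max(largest, prod(sizes[ix] for ix in new))
        terms.append(new)
    return cost, largest


def greedy_slice(eq, sizes, path, max_overhead=1.1):
    """Greedily pick summed indices to slice while the cost overhead stays bounded."""
    inputs, output = eq.split('->')
    terms = inputs.split(',')
    base_cost, current = path_stats(terms, output, path, sizes)
    candidates = sorted(set(inputs) - set(output) - {','})
    sliced = []
    while True:
        best = None
        for ix in candidates:
            if ix in sliced:
                continue
            trial = sliced + [ix]
            # inside a single slice every sliced index has size 1
            trial_sizes = {k: (1 if k in trial else d) for k, d in sizes.items()}
            cost, largest = path_stats(terms, output, path, trial_sizes)
            cost *= prod(sizes[k] for k in trial)  # times the number of slices
            bound = current if best is None else best[1]
            if cost <= max_overhead * base_cost and largest < bound:
                best = (ix, largest)
        if best is None:
            return sliced
        sliced.append(best[0])
        current = best[1]


sizes = dict.fromkeys('abc', 100)
print(greedy_slice('ab,bc,ca->', sizes, [(0, 1), (0, 1)]))  # ['a', 'c']
```

For `'ab,bc,ca->'` this model slices `a` and `c` at no extra cost, whereas slicing `b` (as in the earlier illustration) roughly doubles the cost and is rejected under the 110% limit.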
@jcmgray It seems we have a good use case for this. After sitting on it a bit, how easy do you think it would be to port the slicing component over?
So I would definitely still like to add this in. There are a few different use cases that require thinking about the API. My thoughts are that maybe all this functionality can be cleanly added to the current functions.

1 - Automatic slicing and contraction

I think being able to just do the following and have it use slicing might be very useful (?): …

I.e. have slices found and performed automatically in the background (maybe with a warning if the induced overhead rises above, e.g., 2x). I'm pretty sure this is always advantageous as compared to the …

What would

```python
path, info = oe.contract_path(eq, *arrays, memory_limit=2**28)
```

mean? I think the …

2 - Explicit slicing but automatic contraction

Given the above point, if this path and info are computed first, they can then be used like

```python
oe.contract(eq, *arrays, optimize=info.path, slice=info.sliced_inds)
```

This …

```python
# memory_limit would really just get translated to
res = oe.contract(eq, *arrays, slice=oe.SliceFinder(target_size=2**28))

# aim for 128 separate contractions (mainly useful for parallelization purposes)
path, info = oe.contract_path(eq, *arrays, slice=oe.SliceFinder(target_slices=128))

# make the contraction as small as possible without inducing overhead greater than 2x
res = oe.contract(eq, *arrays, slice=oe.SliceFinder(target_overhead=2))
```

3 - Explicit slicing and explicit contraction

The final case is when one wants to have some control over how to perform the individual contractions, e.g. distributed on a cluster or something. I think it makes sense to use the …

```python
expr = oe.contract_expression(eq, *shapes, path=info.path, slice=info.sliced_inds)

# can still call as normal (and do sum over slices automatically)
res = expr(arrays, backend='cupy')

# or manually contract each slice
res = 0
for i in range(path.multiplicity):
    res += expr[i](arrays)
```

Or eventually maybe just let the user specify a …

```python
res = expr(arrays, parallel=pool)
```

Or …

Some possible future nuances: …
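The `expr[i]`, `multiplicity` and `parallel=pool` ideas in option 3 can be prototyped without touching opt_einsum. The sketch below is hypothetical throughout: `SlicedExpr` is a stand-in class, not an opt_einsum API; it uses `np.einsum` for each slice, assumes an explicit `->` in the equation, and accepts as `parallel` any object with an `Executor.map`-style method, such as `concurrent.futures.ThreadPoolExecutor`. Slice number `i` is decoded into one value per sliced index with `np.unravel_index`.

```python
from concurrent.futures import ThreadPoolExecutor
from math import prod

import numpy as np


class SlicedExpr:
    """Hypothetical sliced contraction expression (not an opt_einsum class)."""

    def __init__(self, eq, shapes, sliced_inds):
        inputs, output = eq.split('->')
        if any(ix in output for ix in sliced_inds):
            raise ValueError("only summed indices can be sliced")
        self.eq = eq
        self.terms = inputs.split(',')
        sizes = {ix: d for term, shape in zip(self.terms, shapes) for ix, d in zip(term, shape)}
        self.sliced_inds = list(sliced_inds)
        self.slice_shape = tuple(sizes[ix] for ix in self.sliced_inds)
        self.multiplicity = prod(self.slice_shape)  # total number of slices

    def __getitem__(self, i):
        """Return a function that contracts only slice number i."""
        values = np.unravel_index(i, self.slice_shape) if self.slice_shape else ()
        fixed = dict(zip(self.sliced_inds, values))

        def contract_slice(arrays):
            # keep sliced axes with size 1 so the equation is unchanged
            ops = [
                arr[tuple(slice(fixed[ix], fixed[ix] + 1) if ix in fixed else slice(None) for ix in term)]
                for term, arr in zip(self.terms, arrays)
            ]
            return np.einsum(self.eq, *ops)

        return contract_slice

    def __call__(self, arrays, parallel=None):
        """Sum over all slices, optionally mapping them over an executor."""
        slice_ids = range(self.multiplicity)
        if parallel is None:
            return sum(self[i](arrays) for i in slice_ids)
        return sum(parallel.map(lambda i: self[i](arrays), slice_ids))


eq = 'abm,bcm,cd->ad'
rng = np.random.default_rng(0)
arrays = [rng.standard_normal(s) for s in [(3, 4, 5), (4, 2, 5), (2, 3)]]
expr = SlicedExpr(eq, [arr.shape for arr in arrays], ['m'])

# manually contract each slice ...
res = 0
for i in range(expr.multiplicity):
    res += expr[i](arrays)

# ... or let an executor map over the slices
with ThreadPoolExecutor(max_workers=4) as pool:
    res_par = expr(arrays, parallel=pool)

print(np.allclose(res, np.einsum(eq, *arrays)), np.allclose(res_par, res))  # True True
```

Threads are used here only to keep the sketch self-contained; the same per-slice callables could be handed to any executor that offers a `map` method.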
I like (1), but I want to be fairly cautious about it, and it feels like something that can be worked through once the slicing infrastructure is complete. Another item to consider is to use the language of dask.array, since this is the most popular distribution library that I am aware of. This would fundamentally mean chunks instead of slices, but their handling of edge chunks, for example, could be adopted. In addition to the slicing cases you outlined, it may be the case that only specific contractions require slicing, further complicating the path finding.
Indeed (1) could be implemented later as it just involves translating the call to (2).
Ah, do you mean calling things …? However, what are 'edge chunks' and how does dask handle them?
Right, if no slicing is required, the contraction should just dispatch to the usual method; that's one motivation for having this functionality incorporated into …
Yes to … What if we take this in stages and implement this in …?
For the moment, I only really know how to efficiently think about the chunk size 1 case (as you can simply remove the edge from the graph), so with that starting point there will be no edge chunks at least!
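For context on edge chunks: when an index of size `n` is split into chunks of size `k` that do not divide `n`, the final chunk is smaller. dask.array, for example, represents a length-10 axis with `chunks=4` as chunks of sizes 4, 4 and 2. A hypothetical NumPy-only sketch of chunked slicing over one summed index is below; `chunk_bounds` and `chunked_contract` are illustrative names, and an explicit `->` in the equation is assumed. With `chunk=1` it reduces to ordinary size-1 slicing.

```python
import numpy as np


def chunk_bounds(size, chunk):
    """Split range(size) into [start, stop) blocks; the last may be a smaller edge chunk."""
    return [(start, min(start + chunk, size)) for start in range(0, size, chunk)]


def chunked_contract(eq, arrays, ix, chunk):
    """Contract `eq`, summing over index `ix` in blocks of at most `chunk` values."""
    inputs, output = eq.split('->')
    if ix in output:
        raise ValueError("only a summed index can be chunked")
    terms = inputs.split(',')
    size = next(arr.shape[term.index(ix)] for term, arr in zip(terms, arrays) if ix in term)
    total = 0
    for start, stop in chunk_bounds(size, chunk):
        # every operand carrying `ix` is restricted to the same block of values
        ops = [
            arr[tuple(slice(start, stop) if c == ix else slice(None) for c in term)]
            for term, arr in zip(terms, arrays)
        ]
        total = total + np.einsum(eq, *ops)
    return total


print(chunk_bounds(10, 4))  # [(0, 4), (4, 8), (8, 10)]
```

Since the index stays in each chunk's equation with a reduced dimension, the edge chunk needs no special handling here; it simply produces a smaller block.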
Sounds like a plan. For what it's worth, in the case of actually contracting with dask, I finally got round to (hopefully) fixing …
Tensor slicing (essentially explicitly doing an outer sum over some indices rather than including them in the pairwise contractions) is a pretty useful way of decreasing the memory cost of performing a contraction (at some computational cost increase) in order to fit it in memory or on the GPU and/or massively parallelizing it.
I thought I'd chuck out the following kind of functionality and see if it might be useful to people?
- `contract('abm,bcm,cd', ..., slice=['m'])`
- `contract_sliced`

I have proof-of-principle versions of all of these, which could be a starting point (for me or someone else...). Here's what a full example might look like: …