Manual and automatic slicing #95

jcmgray opened this issue Aug 7, 2019 · 12 comments

@jcmgray (Collaborator) commented Aug 7, 2019

Tensor slicing (essentially doing an explicit outer sum over some indices rather than including them in the pairwise contractions) is a pretty useful way of decreasing the memory cost of a contraction (at some increase in computational cost), either to fit it in memory or on the GPU, or to massively parallelize it.

I thought I'd chuck out the following kind of functionality and see if it might be useful to people:

  1. Be able to supply a list of which indices to slice over, e.g. contract('abm,bcm,cd', ..., slice=['m'])
  2. Return an iterator over the slices so that you can perform them in parallel as you wish, maybe contract_sliced
  3. Perform some inference on a path to determine which indices should be sliced (generally the ones that appear on the largest intermediate). You could specify the maximum memory and/or the minimum number of slices, etc.

I have proof-of-principle versions of all of these which could be a starting point (for me or someone else...). Here's what a full example might look like:

# find an initial path
path, info = oe.contract_path(eq, *arrays, optimize='random-greedy')

slices = info.suggest_slices(max_memory=2**20, min_num_slices=None)
print(info.slice_statistics(slices))
# slice: ['a', 'd', 'e']
# total size: 64 
# peak memory reduction: 30.45x
# flops increase 2.53x etc.

# perform contraction EDIT: (not sure about this general syntax)
sliced_arrays = oe.gen_sliced_arrays(eq, *arrays, slices=slices)

# each item of sliced_arrays will be the input tensors with some
#    combination of the ['a', 'd', 'e'] dimensions selected
sum(my_parallel_pool.map(
    lambda x: oe.contract(eq, *x, optimize=path),
    sliced_arrays
))
@dgasmith (Owner) commented Aug 8, 2019

This goes a bit with automatic looping over GEMM-like products such as aij,ajk->aik. max_memory would be hard to satisfy exactly, I think; I guess you could try to automatically loop over the largest intermediates?
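
For concreteness, here is a tiny numpy-only illustration (just an assumed example, not proposed API) of what "looping over a GEMM-like product" means for aij,ajk->aik:

import numpy as np

# 'aij,ajk->aik' is a batched GEMM: looping over the batch index 'a'
# and doing an ordinary matrix multiply per slice gives the same result
x = np.random.randn(4, 5, 6)
y = np.random.randn(4, 6, 7)

looped = np.stack([x[a] @ y[a] for a in range(4)])
np.testing.assert_allclose(looped, np.einsum('aij,ajk->aik', x, y))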

Valid slices inputs could be:

  • [{"index": m, "slices": [(0, 10), (10, 20), ...]}, ...]
  • [{"index": m, "stride": 1, ...]
  • max_memory
  • best
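
As a rough illustration of how the stride form could expand into the explicit slices form (expand_stride is a hypothetical helper, not an existing function):

# hypothetical helper: expand a {"index": ..., "stride": ...} spec into the
# explicit {"index": ..., "slices": [(start, stop), ...]} form
def expand_stride(index, dim, stride):
    return {
        "index": index,
        "slices": [(i, min(i + stride, dim)) for i in range(0, dim, stride)],
    }

expand_stride("m", 10, 4)
# {'index': 'm', 'slices': [(0, 4), (4, 8), (8, 10)]}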

The FLOP cost shouldn't increase I believe?

@jcmgray (Collaborator, Author) commented Aug 8, 2019

So in the extreme case you just sum over all of the indices, which ends up being exactly the same as pure einsum. In that case the largest intermediate is now just the largest input, whilst the FLOP cost has probably increased dramatically. Slicing just a few indices moves you a little in that direction: assuming you already have the lowest-memory contraction path, it allows you to save a bit more memory, but in general doing the sliced contraction d times is more expensive.

Here's an illustration:

import numpy as np
import opt_einsum as oe

# setup a contraction
eq = 'ab,bc,ca->'
x, y, z = (np.random.randn(100, 100) for _ in 'xyz')
path = [(0, 1), (0, 1)]
oe.contract(eq, x, y, z, optimize=path)
# array(140.59548574)

# work out the normal cost
info = oe.contract_path(eq, x, y, z, optimize=path)[1]
print(info.opt_cost)
# 2020000

Now imagine we want to slice over 'b'; ultimately that would look like this:

sum([
    oe.contract(eq, x[:, [i]], y[[i], :], z, optimize=path) 
    for i in range(100)
])
# 140.595485736799

Each of these is now completely independent (and could be done in parallel), and you can imagine that for many contractions the peak memory will be lower. The cost of a single one of these is now:

info_slice = oe.contract_path(eq, x[:, [0]], y[[0], :], z, optimize=path)[1]
print(info_slice.opt_cost)
# 40000

But since we need to do 100 of these, our FLOP cost has actually risen a bit:

(100 * 40000) / 2020000
# 1.9801980198019802

Finding the indices which balance the memory reduction / parallelization against the total FLOP increase is the name of the game! But you can imagine that reducing a contraction to fit on a GPU, or parallelizing it over a cluster, can really outweigh a moderate FLOP increase. Moreover, sometimes a contraction is simply too big to fit in memory and you really have no other option.
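
For illustration, here is a rough sketch of scoring a single candidate index this way (not an existing opt_einsum function; it assumes PathInfo exposes opt_cost, size_dict and largest_intermediate, as current versions do):

import opt_einsum as oe

def slicing_overhead(eq, arrays, path, index):
    # cost and peak intermediate of the unsliced contraction along `path`
    info = oe.contract_path(eq, *arrays, optimize=path)[1]
    inputs = eq.split('->')[0].split(',')

    # take a single (size-1) slice of every input that carries `index`
    sliced = [
        a[tuple(slice(0, 1) if ix == index else slice(None) for ix in term)]
        for term, a in zip(inputs, arrays)
    ]
    info_slice = oe.contract_path(eq, *sliced, optimize=path)[1]

    d = info.size_dict[index]  # the sliced contraction must be repeated d times
    flops_overhead = d * info_slice.opt_cost / info.opt_cost
    memory_factor = info.largest_intermediate / info_slice.largest_intermediate
    return flops_overhead, memory_factor

# e.g. for the example above, slicing_overhead(eq, [x, y, z], path, 'b')
# reports a FLOP overhead of roughly 2x, matching the numbers worked out by hand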

@dgasmith (Owner)

This goes back a bit to FLOP count not being the only metric, especially when it comes to GPUs. For that, it seems a weighted metric would help (outer product: 2x cost, Hadamard: 1x cost, GEMM: 0.75x cost, TPU-op: 0.1x cost).

The slicing idea seems good to me; automatic construction of Python loops should be fairly straightforward. The depth here is pretty low overall, I think. Do you foresee any issues?

@jcmgray (Collaborator, Author) commented Oct 30, 2019

I've been making some nice progress on this. There are quite a lot of considerations, so I think it makes sense to focus on an explicit interface first. Thought I'd post an example of the kind of design I am considering, for future reference and in case anyone has feedback!

from opt_einsum import SliceFinder, SlicedContraction

# there'll be sensible defaults for all these options, which guide how to heuristically 
# choose the indices to slice over (in general an NP-Hard problem!)
sf = SliceFinder(path_info, parallel=True, power_factor=1.5, temperature=0.5)
inds_to_slice = sf.search(max_repeats=128, target_memory=2**27)

# automatic index finding separated out from performing the sliced contraction
sc = SlicedContraction(eq, arrays, inds_to_slice, optimize=path_info.path)
print(sc.info)
# {
#     'num_slices': 512,
#     'sliced_cost': 94129464192,  # can be compared to path_info.opt_cost
#     'largest_intermediate': 39741827364,  # etc.
# }

# simply perform all contractions, now using less memory for e.g. GPU.
# this will internally generate a ContractionExpression and specify constant tensors etc
x = sc.contract(backend='jax')

# ... or explicitly perform contractions independently (could be distributed)
futures = [
    pool.submit(oe.contract, sc.eq, *sliced_arrays, optimize=sc.path)
    for sliced_arrays in sc.gen_sliced_arrays()
]
x = sum(f.result() for f in futures)

Ultimately it might be nice (and I think straightforward) to have an option like contract(..., slice='auto') that could perform this stuff sensibly in the background.

Possible issues:

  • If output indices are sliced over, the contraction is no longer a simple sum; instead you are computing different blocks of the output tensor, which then need to be combined in an order-specific way (see the sketch below). I think it makes sense to simply not support this for now.
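
For reference, a tiny numpy sketch (an assumed illustration, not current opt_einsum behaviour) of why sliced output indices need block reassembly rather than a plain sum:

import numpy as np
import opt_einsum as oe

eq = 'ab,bc->ac'
x, y = np.random.randn(8, 8), np.random.randn(8, 8)

# slicing the *output* index 'a': each slice fills a different block of the
# result, so it must be written into the right place rather than just summed
out = np.empty((8, 8))
for i in range(8):
    out[i:i + 1, :] = oe.contract(eq, x[i:i + 1, :], y)

np.testing.assert_allclose(out, x @ y)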

@dgasmith (Owner)

Cool, this looks like it is coming along well! I wonder if it would be worth having deeper Dask integration to automatically batch this out. The interface prototype seems good; it would be good to see the algorithm, as that seems like the real trick.

Sensible defaults for contract seem difficult, though we seem to have done pretty well in the past at guessing a broad use case.

@jcmgray (Collaborator, Author) commented Oct 31, 2019

Cool, if this seems like a decent interface I'll start plugging in the actual algorithm.

I wonder if it would be worth having deeper Dask integration to automatically batch this out.

Yep, to_chunked_dask_arrays or something would be a very natural fit! The indices to slice are good choices for indices to chunk, and, depending on the performance of dask's caching / compute-graph handling, there is then theoretically no FLOP increase.
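
In the meantime, a hand-rolled sketch of that idea (to_chunked_dask_arrays doesn't exist yet; this just chunks along the would-be sliced index and assumes opt_einsum's dask backend handles the rest):

import numpy as np
import dask.array as da
import opt_einsum as oe

eq = 'ab,bc,ca->'
x, y, z = (np.random.randn(100, 100) for _ in range(3))

# chunk only along the index we would have sliced ('b')
dx = da.from_array(x, chunks=(100, 10))   # 'a' whole, 'b' chunked
dy = da.from_array(y, chunks=(10, 100))   # 'b' chunked, 'c' whole
dz = da.from_array(z, chunks=(100, 100))  # untouched

# opt_einsum dispatches to dask, which handles the chunked sum over 'b'
out = oe.contract(eq, dx, dy, dz, backend='dask')
print(float(out.compute()))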

Sensible defaults for contract seem difficult, though we seem to have done pretty well in the past at guessing a broad use case.

One 'sensible' default setting is just to keep on slicing indices until the total FLOP cost rises above, say, 110% of the original contraction; a minimal greedy version of that idea is sketched below.
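
A minimal greedy sketch of that kind of default (a hypothetical helper, not a proposed opt_einsum API; it simply re-runs contract_path on size-1 slices to estimate costs, and does not treat output indices specially):

import numpy as np
import opt_einsum as oe

def greedy_slice_inds(eq, arrays, path, max_overhead=1.1):
    """Keep slicing the index that least increases total FLOPs, stopping
    before the overhead exceeds `max_overhead` times the original cost."""
    inputs = eq.split('->')[0].split(',')
    base = oe.contract_path(eq, *arrays, optimize=path)[1]

    def total_cost(sliced_inds):
        # cost of a single slice times the number of slices (chunk size 1)
        cut = [
            a[tuple(slice(0, 1) if ix in sliced_inds else slice(None) for ix in term)]
            for term, a in zip(inputs, arrays)
        ]
        nslices = int(np.prod([base.size_dict[ix] for ix in sliced_inds]))
        return nslices * oe.contract_path(eq, *cut, optimize=path)[1].opt_cost

    sliced = set()
    while True:
        candidates = [ix for ix in base.size_dict if ix not in sliced]
        if not candidates:
            return sliced
        best = min(candidates, key=lambda ix: total_cost(sliced | {ix}))
        if total_cost(sliced | {best}) > max_overhead * base.opt_cost:
            return sliced
        sliced.add(best)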

@dgasmith (Owner)

@jcmgray Seems we have a good use case for this. After sitting on it a bit, how easy do you think it would be to port the slicing component over?

@jcmgray (Collaborator, Author) commented Jun 30, 2020

So I would definitely still like to add this in. There are a few different use cases that require thinking about the API. My thoughts are that maybe all this functionality can be cleanly added to the current functions.

1 - Automatic slicing and contraction

I think being able to just do the following and have it use slicing might be very useful (?):

oe.contract(eq, *arrays, memory_limit=2**28)

I.e. have slices found and performed automatically in the background (maybe with a warning if the induced overhead rises above e.g. 2x). I'm pretty sure this is always advantageous compared to the memory_limit options for the paths, and it might simplify things to remove all that logic in favor of slicing.

What would

path, info = oe.contract_path(eq, *arrays, memory_limit=2**28)

mean? I think the PathInfo object should additionally carry sliced_inds and a 'multiplicity' (i.e. the number of slices; it could also be called nslices). The .opt_cost is then the multiplicity times the individual slice contraction cost (and I think most other attributes are easily translated).

2 - Explicit slicing but automatic contraction

Given the above point, if this path and info is computed first, it can then be used like

oe.contract(eq, *arrays, optimize=info.path, slice=info.sliced_inds)

This slice kwarg could either be an explicit list of indices like this, or be used to manually specify options for how to find slices:

# memory_limit would really just get translated to
res = oe.contract(eq, *arrays, slice=oe.SliceFinder(target_size=2**28))

# aim for 128 separate contractions (mainly useful for parallelization purposes)
path, info = oe.contract_path(eq, *arrays, slice=oe.SliceFinder(target_slices=128))

# make the contraction as small as possible without inducing overhead greater than 2x
res = oe.contract(eq, *arrays, slice=oe.SliceFinder(target_overhead=2))

3 - Explicit slicing and explicit contraction

The final case is when one wants to have some control over how the individual contractions are performed - e.g. distributed on a cluster or something. I think it makes sense to use the ContractionExpression for this.

expr = oe.contract_expression(eq, *shapes, path=info.path, slice=info.sliced_inds)

# can still call as normal (and do sum over slices automatically) 
res = expr(arrays, backend='cupy')

# or manually contract each slice
res = 0
for i in range(path.multiplicity):
    res += expr[i](arrays)

Or eventually maybe just let the user specify a pool:

res = expr(arrays, parallel=pool)

Or contract_slices_dask or similar; these could come later.

Some possible future nuances

  • Whether to allow output indices
  • How the actual SliceFinder is called, i.e. allowing custom slice finders
  • It might be nice to support returning a contraction path in conjunction with the indices to slice simultaneously (which happens when you are using methods such as https://arxiv.org/abs/2005.06787), rather than only slicing once the path is found

@dgasmith (Owner) commented Jul 1, 2020

I like 1), but I want to be fairly cautious about it and it feels like something that can be worked through once the slicing infrastructure is complete.

Another item to consider is to use the language of dask.array since this is the most popular distribution library that I am aware of. This would mean chunks instead of slices fundamentally, but their handling of edge chunks for example could be adopted.

In addition to the slicing cases you outlined it may be the case that only specific contractions require slicing further complicating the path finding.

@jcmgray (Collaborator, Author) commented Jul 2, 2020

I like 1), but I want to be fairly cautious about it and it feels like something that can be worked through once the slicing infrastructure is complete.

Indeed (1) could be implemented later as it just involves translating the call to (2).

Another item to consider is to use the language of dask.array since this is the most popular distribution library that I am aware of. This would mean chunks instead of slices fundamentally, but their handling of edge chunks for example could be adopted.

Ah, do you mean calling things chunked_inds/ChunkFinder etc.? That might be a good call. The current implementation I have only does full slicing (i.e. chunk size 1), but maybe changing that could be a future upgrade, and we could keep an eye on the API to allow such a thing (chunked_inds={'a': 1, 'b': 2}).

However, what are 'edge chunks' and how does dask handle them?

In addition to the slicing cases you outlined it may be the case that only specific contractions require slicing further complicating the path finding.

Right, if no slicing is required the contraction should just dispatch to the usual method - that's one motivation for having this functionality incorporated into contract, contract_path and contract_expression rather than a separate SlicedContractor as I had originally suggested.

@dgasmith (Owner) commented Jul 2, 2020

Yes to ChunkFinder, and chunked_inds={'a': 1, 'b': 2} was quite useful when I was using Dask. 'Edge chunks' simply refers to the common case of a non-integer number of chunks for a given size (e.g. a dimension of 13 with a chunk size of 4 leaves a remainder of 1). I hadn't read that article in some time, but they used to discuss this and apparently do not these days ;/ Maybe follow their lead as well and omit this for simplicity.
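
For reference, dask's own chunk normalization shows how those edge chunks fall out (just an illustration of existing dask behaviour):

import numpy as np
import dask.array as da

# a dimension of 13 with a chunk size of 4 leaves a remainder chunk of 1
d = da.from_array(np.ones(13), chunks=4)
print(d.chunks)
# ((4, 4, 4, 1),)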

What if we take this in stages and implement it in contract_path to start with (even if it doesn't go anywhere), to see the level of complexity it adds to that particular function?

@jcmgray (Collaborator, Author) commented Jul 3, 2020

For the moment, I only really know how to efficiently think about the chunksize 1 case (as you can simply remove the edge from the graph), so with that starting point there will be no edge chunks at least!

What if we take this in stages and implement this in contract_path to start with (even if doesn't go anywhere) to see the level of complexity it adds to that particular function?

Sounds like a plan.

For what it's worth, in the case of actually contracting with dask, I finally got round to (hopefully) fixing dask.tensordot for more than 32 involved indices (dask/dask#6368), which, at least in my case, should make it practical to actually perform the sliced contractions with dask itself as well.
