Fix reshaping of split sharded tensor (#291)
Some cases of inserting dimensions of size 1 during reshape were buggy. Added more tests to cover them.
sogartar authored Oct 18, 2024
1 parent 2850358 commit 38f3808
Showing 3 changed files with 72 additions and 4 deletions.
20 changes: 17 additions & 3 deletions sharktank/sharktank/ops/sharded_impls.py
@@ -1161,22 +1161,36 @@ def _reshape_get_single_split_dim(
     from_shape: List[int], to_shape: List[int]
 ) -> Optional[Tuple[int, int]]:
     """If a reshape would split a single dimension, return its index and the length of the new dimensions.
-    If the reshape is not of that kind return `None`."""
+    If the reshape is not of that kind return `None`.
+    E.g.
+    _reshape_get_single_split_dim(from_shape=(2, 12, 5), to_shape=(2, 3, 4, 5))
+    results in
+    (1, 2)"""
     from_shape, to_shape = _reshape_infer_dynamic_dim(from_shape, to_shape)

     if len(to_shape) < len(from_shape):
         return None
     i = longest_equal_range(from_shape, to_shape)
+    split_dims_length = len(to_shape) - len(from_shape) + 1
     if i == len(from_shape):
-        return i
+        return (
+            i,
+            split_dims_length,
+        )
     j = len(to_shape) - longest_equal_range(reversed(from_shape), reversed(to_shape))
     assert i < j
     expected_split_dim_size = math.prod(to_shape[i:j])
+    if expected_split_dim_size == 1:
+        # 1's were inserted.
+        return (
+            i,
+            split_dims_length,
+        )
     if expected_split_dim_size != from_shape[i]:
         return None
     return (
         i,
-        j - i,
+        split_dims_length,
     )


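For reference, a self-contained sketch of the patched helper (reproducing both functions outside the package and omitting the `_reshape_infer_dynamic_dim` step, so shapes must be fully static) shows what it returns in the cases this commit fixes:

import math
from typing import Optional, Sequence, Tuple

def longest_equal_range(l1: Sequence, l2: Sequence) -> int:
    # Length of the longest common prefix of the two sequences.
    for i, (a, b) in enumerate(zip(l1, l2)):
        if a != b:
            return i
    return min(len(l1), len(l2))

def reshape_get_single_split_dim(
    from_shape: Sequence[int], to_shape: Sequence[int]
) -> Optional[Tuple[int, int]]:
    # Mirrors the patched sharktank helper, minus dynamic-dim inference.
    if len(to_shape) < len(from_shape):
        return None
    i = longest_equal_range(from_shape, to_shape)
    split_dims_length = len(to_shape) - len(from_shape) + 1
    if i == len(from_shape):
        # Only trailing size-1 dims were appended.
        return (i, split_dims_length)
    j = len(to_shape) - longest_equal_range(
        tuple(reversed(from_shape)), tuple(reversed(to_shape))
    )
    assert i < j
    expected_split_dim_size = math.prod(to_shape[i:j])
    if expected_split_dim_size == 1:
        # Size-1 dims were inserted before dimension i.
        return (i, split_dims_length)
    if expected_split_dim_size != from_shape[i]:
        return None
    return (i, split_dims_length)

# The docstring example: dim 1 (size 12) becomes 2 new dims (3, 4).
assert reshape_get_single_split_dim((2, 12, 5), (2, 3, 4, 5)) == (1, 2)
# A size-1 dim inserted before the old dim 1.
assert reshape_get_single_split_dim((4, 5, 6, 7), (4, 1, 5, 6, 7)) == (1, 2)
# Two trailing size-1 dims appended; i == len(from_shape) here.
assert reshape_get_single_split_dim((4, 5, 6, 7), (4, 5, 6, 7, 1, 1)) == (4, 3)

The `tuple(...)` wrap around `reversed(...)` is a sketch-level convenience so that `len()` is always defined inside `longest_equal_range`.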
2 changes: 1 addition & 1 deletion sharktank/sharktank/utils/misc.py
@@ -15,7 +15,7 @@ def longest_equal_range(l1: List[Any], l2: List[Any]) -> int:
     for i, (a, b) in enumerate(zip(l1, l2)):
         if a != b:
             return i
-    return len(zip(l1, l2))
+    return min(len(l1), len(l2))


 def iterables_equal(iterable1: Iterable, iterable2: Iterable) -> bool:
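This one-line change matters because `zip(...)` objects don't implement `len()`: the old return raised "TypeError: object of type 'zip' has no len()" whenever one list was a prefix of the other, which is exactly the case hit when a reshape only appends trailing size-1 dimensions. A quick check of the fixed behavior:

from typing import Any, List

def longest_equal_range(l1: List[Any], l2: List[Any]) -> int:
    for i, (a, b) in enumerate(zip(l1, l2)):
        if a != b:
            return i
    # Old code was `return len(zip(l1, l2))`, which raises
    # "TypeError: object of type 'zip' has no len()".
    return min(len(l1), len(l2))

assert longest_equal_range([4, 5, 6], [4, 5, 6, 7]) == 3  # prefix case, used to crash
assert longest_equal_range([2, 12, 5], [2, 3, 4, 5]) == 1  # mismatch at index 1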
54 changes: 54 additions & 0 deletions sharktank/tests/ops/sharded_test.py
@@ -810,6 +810,42 @@ def testSplitTensorSplitDimIsLeadingFlattenDim(self):
         actual_result = ops.reshape(sharded_tensor, new_shape)
         assert expected_result.is_deep_equal(actual_result)

+    def testSplitTensorInsertSize1DimBeforeSplitDim(self):
+        tensor = torch.rand(4, 5, 6, 7)
+        new_shape = [4, 1, 5, 6, 7]
+        unsharded_expected_result = torch.reshape(tensor, new_shape)
+        shard_dim = 2
+        expected_result = ops.reshard_split(
+            unsharded_expected_result, dim=shard_dim + 1, count=2
+        )
+        sharded_tensor = ops.reshard_split(tensor, dim=shard_dim, count=2)
+        actual_result = ops.reshape(sharded_tensor, new_shape)
+        assert expected_result.is_deep_equal(actual_result)
+
+    def testSplitTensorInsertMultipleSize1DimsBeforeSplitDim(self):
+        tensor = torch.rand(4, 5, 6, 7)
+        new_shape = [4, 1, 1, 5, 6, 7]
+        unsharded_expected_result = torch.reshape(tensor, new_shape)
+        shard_dim = 2
+        expected_result = ops.reshard_split(
+            unsharded_expected_result, dim=shard_dim + 2, count=2
+        )
+        sharded_tensor = ops.reshard_split(tensor, dim=shard_dim, count=2)
+        actual_result = ops.reshape(sharded_tensor, new_shape)
+        assert expected_result.is_deep_equal(actual_result)
+
+    def testSplitTensorInsertMultipleSize1TrailingDimsNotRightAfterSplitDim(self):
+        tensor = torch.rand(4, 5, 6, 7)
+        new_shape = [4, 5, 6, 7, 1, 1]
+        unsharded_expected_result = torch.reshape(tensor, new_shape)
+        shard_dim = 2
+        expected_result = ops.reshard_split(
+            unsharded_expected_result, dim=shard_dim, count=2
+        )
+        sharded_tensor = ops.reshard_split(tensor, dim=shard_dim, count=2)
+        actual_result = ops.reshape(sharded_tensor, new_shape)
+        assert expected_result.is_deep_equal(actual_result)
+
     def testSplitTensorUnflattenNonSplitDim(self):
         tensor = torch.rand(3, 20, 6)
         new_shape = [3, 4, 5, 6]
@@ -819,6 +855,15 @@ def testSplitTensorUnflattenNonSplitDim(self):
         actual_result = ops.reshape(sharded_tensor, new_shape)
         assert expected_result.is_deep_equal(actual_result)

+    def testSplitTensorUnflattenTrailingNonSplitDim(self):
+        tensor = torch.rand(3, 4, 30)
+        new_shape = [3, 4, 5, 6]
+        unsharded_expected_result = torch.reshape(tensor, new_shape)
+        expected_result = ops.reshard_split(unsharded_expected_result, dim=1, count=2)
+        sharded_tensor = ops.reshard_split(tensor, dim=1, count=2)
+        actual_result = ops.reshape(sharded_tensor, new_shape)
+        assert expected_result.is_deep_equal(actual_result)
+
     def testSplitTensorUnflattenSplitDim(self):
         tensor = torch.rand(3, 20, 6)
         new_shape = [3, 4, 5, 6]
@@ -828,6 +873,15 @@ def testSplitTensorUnflattenSplitDim(self):
         actual_result = ops.reshape(sharded_tensor, new_shape)
         assert expected_result.is_deep_equal(actual_result)

+    def testSplitTensorUnflattenTrailingSplitDim(self):
+        tensor = torch.rand(2, 3, 20)
+        new_shape = [2, 3, 4, 5]
+        unsharded_expected_result = torch.reshape(tensor, new_shape)
+        expected_result = ops.reshard_split(unsharded_expected_result, dim=2, count=2)
+        sharded_tensor = ops.reshard_split(tensor, dim=2, count=2)
+        actual_result = ops.reshape(sharded_tensor, new_shape)
+        assert expected_result.is_deep_equal(actual_result)
+

 class ReshardSplitTest(unittest.TestCase):
     def testReshardReplicated(self):
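For intuition about the dimension bookkeeping the new tests pin down: inserting size-1 dimensions shifts every later dimension index, so the shard dimension of the result must shift too. In plain PyTorch terms (no sharding involved), a minimal illustration:

import torch

t = torch.rand(4, 5, 6, 7)    # think of this as split across 2 devices along dim 2
r = t.reshape(4, 1, 5, 6, 7)  # a size-1 dim inserted before the split dim
assert r.shape == (4, 1, 5, 6, 7)
# Dim 2 of `t` is dim 3 of `r`, which is why the tests above expect the
# reshaped sharded tensor to be split along `shard_dim + 1`.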
