[IFRT] Extend Sharding disassembly operations to distinguish between addressable-shards and all-shards processing

IFRT's assembly/disassembly operations
(`Client::AssembleArrayFromSingleDeviceArray`,
`Array::DisassembleIntoSingleDeviceArrays`, and related methods in `Sharding`)
treated all shards equally, without distinguishing whether each shard's device
is addressable. This caused practical problems:

* When the user only has single-device arrays for addressable devices and
assembles them into a multi-shard array, the user is forced to use a `Sharding`
that only contains addressable devices. However, with SPMD, it is common to use
a `Sharding` that can express both addressable and non-addressable shards
(e.g., `HloSharding`).

* When the user has a multi-shard array that spans both addressable and
non-addressable devices, disassembling the array into single-device arrays
would create single-device arrays with no addressable devices. This is often
not well supported in user code, which sometimes makes a strong assumption
that any array contains at least one addressable device.

On the other hand, making assembly/disassembly handle only addressable shards
is not future-proof. An MPMD setup (not all inputs use a single device mesh)
can see an array with no addressable devices. Thus, changing
assembly/disassembly semantics to handle only addressable shards is too
restrictive.

To resolve this single-device array addressability issue, this change makes it
explicit whether only addressable shards will be processed or all shards will
be processed in assembly/disassembly operations.

This change focuses on extending the `Sharding` interface and implementing
addressable-shards processing. The default behavior is still to process all
shards.
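
The filtering rule described above can be sketched in isolation. The following is a minimal, self-contained illustration and not the IFRT implementation: `Device`, `Shard`, and `DisassembleShards` are hypothetical stand-ins, shapes are plain strings, and the enumerator name `kAddressableShards` is an assumption (only `kAllShards` appears in this diff).

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for an IFRT device; only addressability matters here.
struct Device {
  int id;
  bool addressable;
  bool IsAddressable() const { return addressable; }
};

// `kAllShards` is taken from the diff; `kAddressableShards` is an assumed
// name for the other mode.
enum class SingleDeviceShardSemantics { kAddressableShards, kAllShards };

// One (shard shape, device) pair per returned shard. Shapes are strings to
// keep the sketch small.
using Shard = std::pair<std::string, const Device*>;

// Returns one entry per device under kAllShards, or only the entries whose
// device is addressable otherwise. Device order is preserved.
std::vector<Shard> DisassembleShards(
    const std::vector<Device>& devices,
    const std::vector<std::string>& shard_shapes,
    SingleDeviceShardSemantics semantics) {
  assert(devices.size() == shard_shapes.size());
  std::vector<Shard> result;
  result.reserve(devices.size());  // Upper bound on the number of shards.
  for (std::size_t i = 0; i < devices.size(); ++i) {
    if (semantics == SingleDeviceShardSemantics::kAllShards ||
        devices[i].IsAddressable()) {
      result.push_back({shard_shapes[i], &devices[i]});
    }
  }
  return result;
}

// The overload without a semantics argument keeps the old behavior by
// delegating with kAllShards.
std::vector<Shard> DisassembleShards(
    const std::vector<Device>& devices,
    const std::vector<std::string>& shard_shapes) {
  return DisassembleShards(devices, shard_shapes,
                           SingleDeviceShardSemantics::kAllShards);
}
```

In this sketch, existing callers of the two-argument overload see no change, while a caller that passes the addressable-only mode never receives a shard on a non-addressable device.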

Using this new capability for `Array` assembly and disassembly is a separate
change that will be sent out soon.

Updating the IFRT user code to request only addressable devices will also be
done in subsequent changes.

PiperOrigin-RevId: 685822245
hyeontaek authored and Google-ML-Automation committed Oct 14, 2024
1 parent 63311bd commit 032b62e
Showing 7 changed files with 1,329 additions and 281 deletions.
17 changes: 17 additions & 0 deletions xla/python/ifrt/mock.h
@@ -335,11 +335,28 @@ class MockSharding : public llvm::RTTIExtends<MockSharding, Sharding> {
(absl::StatusOr<
std::vector<std::pair<Shape, std::shared_ptr<const Sharding>>>>),
Disassemble, (const Shape& shape), (const, final));
MOCK_METHOD(
(absl::StatusOr<
std::vector<std::pair<Shape, std::shared_ptr<const Sharding>>>>),
Disassemble,
(const Shape& shape,
SingleDeviceShardSemantics single_device_shard_semantics),
(const, final));
MOCK_METHOD((absl::StatusOr<std::vector<
std::pair<DynamicShape, std::shared_ptr<const Sharding>>>>),
Disassemble, (const DynamicShape& dynamic_shape), (const final));
MOCK_METHOD((absl::StatusOr<std::vector<
std::pair<DynamicShape, std::shared_ptr<const Sharding>>>>),
Disassemble,
(const DynamicShape& dynamic_shape,
SingleDeviceShardSemantics single_device_shard_semantics),
(const final));
MOCK_METHOD(absl::StatusOr<std::vector<IndexDomain>>, IndexDomains,
(const Shape& shape), (const, final));
MOCK_METHOD(absl::StatusOr<std::vector<IndexDomain>>, IndexDomains,
(const Shape& shape,
SingleDeviceShardSemantics single_device_shard_semantics),
(const, final));
MOCK_METHOD(std::string, DebugString, (), (const, final));
MOCK_METHOD(absl::StatusOr<Shape>, GetShardShape, (const Shape& shape),
(const, final));
224 changes: 201 additions & 23 deletions xla/python/ifrt/sharding.cc
@@ -237,26 +237,63 @@ SingleDeviceSharding::WithDeviceAssignment(
absl::StatusOr<std::vector<std::pair<Shape, std::shared_ptr<const Sharding>>>>
SingleDeviceSharding::Disassemble(const Shape& shape) const {
DCHECK(this);
return std::vector<std::pair<Shape, std::shared_ptr<const Sharding>>>{
{shape, SingleDeviceSharding::Create(devices_->devices().front(),
memory_kind_)}};
return Disassemble(shape, SingleDeviceShardSemantics::kAllShards);
}

absl::StatusOr<std::vector<std::pair<Shape, std::shared_ptr<const Sharding>>>>
SingleDeviceSharding::Disassemble(
const Shape& shape,
SingleDeviceShardSemantics single_device_shard_semantics) const {
DCHECK(this);
std::vector<std::pair<Shape, std::shared_ptr<const Sharding>>> result;
if (single_device_shard_semantics == SingleDeviceShardSemantics::kAllShards ||
devices_->devices().front()->IsAddressable()) {
result.reserve(1);
result.push_back({shape, SingleDeviceSharding::Create(
devices_->devices().front(), memory_kind_)});
}
return result;
}

absl::StatusOr<
std::vector<std::pair<DynamicShape, std::shared_ptr<const Sharding>>>>
SingleDeviceSharding::Disassemble(const DynamicShape& dynamic_shape) const {
DCHECK(this);
return std::vector<std::pair<DynamicShape, std::shared_ptr<const Sharding>>>{
{dynamic_shape, SingleDeviceSharding::Create(devices_->devices().front(),
memory_kind_)}};
return Disassemble(dynamic_shape, SingleDeviceShardSemantics::kAllShards);
}
absl::StatusOr<
std::vector<std::pair<DynamicShape, std::shared_ptr<const Sharding>>>>
SingleDeviceSharding::Disassemble(
const DynamicShape& dynamic_shape,
SingleDeviceShardSemantics single_device_shard_semantics) const {
DCHECK(this);
std::vector<std::pair<DynamicShape, std::shared_ptr<const Sharding>>> result;
if (single_device_shard_semantics == SingleDeviceShardSemantics::kAllShards ||
devices_->devices().front()->IsAddressable()) {
result.reserve(1);
result.push_back(
{dynamic_shape, SingleDeviceSharding::Create(
devices_->devices().front(), memory_kind_)});
}
return result;
}

absl::StatusOr<std::vector<IndexDomain>> SingleDeviceSharding::IndexDomains(
const Shape& shape) const {
DCHECK(this);
return IndexDomains(shape, SingleDeviceShardSemantics::kAllShards);
}

absl::StatusOr<std::vector<IndexDomain>> SingleDeviceSharding::IndexDomains(
const Shape& shape,
SingleDeviceShardSemantics single_device_shard_semantics) const {
DCHECK(this);
std::vector<IndexDomain> result;
result.reserve(1);
result.push_back(IndexDomain(shape));
if (single_device_shard_semantics == SingleDeviceShardSemantics::kAllShards ||
devices_->devices().front()->IsAddressable()) {
result.reserve(1);
result.push_back(IndexDomain(shape));
}
return result;
}

@@ -308,6 +345,14 @@ absl::StatusOr<std::unique_ptr<Sharding>> OpaqueSharding::WithDeviceAssignment(
absl::StatusOr<std::vector<std::pair<Shape, std::shared_ptr<const Sharding>>>>
OpaqueSharding::Disassemble(const Shape& shape) const {
DCHECK(this);
return Disassemble(shape, SingleDeviceShardSemantics::kAllShards);
}

absl::StatusOr<std::vector<std::pair<Shape, std::shared_ptr<const Sharding>>>>
OpaqueSharding::Disassemble(
const Shape& shape,
SingleDeviceShardSemantics single_device_shard_semantics) const {
DCHECK(this);
return InvalidArgument(
"OpaqueSharding does not have shard shape information");
}
@@ -316,13 +361,29 @@ absl::StatusOr<
std::vector<std::pair<DynamicShape, std::shared_ptr<const Sharding>>>>
OpaqueSharding::Disassemble(const DynamicShape& dynamic_shape) const {
DCHECK(this);
return Disassemble(dynamic_shape, SingleDeviceShardSemantics::kAllShards);
}

absl::StatusOr<
std::vector<std::pair<DynamicShape, std::shared_ptr<const Sharding>>>>
OpaqueSharding::Disassemble(
const DynamicShape& dynamic_shape,
SingleDeviceShardSemantics single_device_shard_semantics) const {
DCHECK(this);
return InvalidArgument(
"OpaqueSharding does not have shard shape information");
}

absl::StatusOr<std::vector<IndexDomain>> OpaqueSharding::IndexDomains(
const Shape& shape) const {
DCHECK(this);
return IndexDomains(shape, SingleDeviceShardSemantics::kAllShards);
}

absl::StatusOr<std::vector<IndexDomain>> OpaqueSharding::IndexDomains(
const Shape& shape,
SingleDeviceShardSemantics single_device_shard_semantics) const {
DCHECK(this);
return InvalidArgument(
"OpaqueSharding does not have index domain information");
}
@@ -413,6 +474,14 @@ ConcreteSharding::WithDeviceAssignment(
absl::StatusOr<std::vector<std::pair<Shape, std::shared_ptr<const Sharding>>>>
ConcreteSharding::Disassemble(const Shape& shape) const {
DCHECK(this);
return Disassemble(shape, SingleDeviceShardSemantics::kAllShards);
}

absl::StatusOr<std::vector<std::pair<Shape, std::shared_ptr<const Sharding>>>>
ConcreteSharding::Disassemble(
const Shape& shape,
SingleDeviceShardSemantics single_device_shard_semantics) const {
DCHECK(this);
if (!has_static_shape()) {
return InvalidArgument(
"ConcreteSharding holds dynamic shape, but was asked "
@@ -428,11 +497,19 @@ ConcreteSharding::Disassemble(const Shape& shape) const {
std::vector<std::pair<Shape, std::shared_ptr<const Sharding>>> result;
const std::vector<Shape>& shard_shapes =
std::get<std::vector<Shape>>(shard_shapes_);
if (single_device_shard_semantics == SingleDeviceShardSemantics::kAllShards) {
result.reserve(devices_->size());
} else {
result.reserve(devices_->AddressableDeviceList()->size());
}
const absl::Span<Device* const> devices = devices_->devices();
result.reserve(devices.size());
for (int i = 0; i < devices.size(); ++i) {
result.push_back({shard_shapes[i],
SingleDeviceSharding::Create(devices[i], memory_kind_)});
if (single_device_shard_semantics ==
SingleDeviceShardSemantics::kAllShards ||
devices[i]->IsAddressable()) {
result.push_back({shard_shapes[i], SingleDeviceSharding::Create(
devices[i], memory_kind_)});
}
}
return result;
}
@@ -441,6 +518,15 @@ absl::StatusOr<
std::vector<std::pair<DynamicShape, std::shared_ptr<const Sharding>>>>
ConcreteSharding::Disassemble(const DynamicShape& dynamic_shape) const {
DCHECK(this);
return Disassemble(dynamic_shape, SingleDeviceShardSemantics::kAllShards);
}

absl::StatusOr<
std::vector<std::pair<DynamicShape, std::shared_ptr<const Sharding>>>>
ConcreteSharding::Disassemble(
const DynamicShape& dynamic_shape,
SingleDeviceShardSemantics single_device_shard_semantics) const {
DCHECK(this);
if (!has_dynamic_shape()) {
return InvalidArgument(
"ConcreteSharding holds static shape, but was asked "
@@ -458,17 +544,33 @@ ConcreteSharding::Disassemble(const DynamicShape& dynamic_shape) const {
const std::vector<DynamicShape>& shard_dynamic_shapes =
std::get<std::vector<DynamicShape>>(shard_shapes_);
const absl::Span<Device* const> devices = devices_->devices();
result.reserve(devices.size());
if (single_device_shard_semantics == SingleDeviceShardSemantics::kAllShards) {
result.reserve(devices_->size());
} else {
result.reserve(devices_->AddressableDeviceList()->size());
}
for (int i = 0; i < devices.size(); ++i) {
result.push_back({shard_dynamic_shapes[i],
SingleDeviceSharding::Create(devices[i], memory_kind_)});
if (single_device_shard_semantics ==
SingleDeviceShardSemantics::kAllShards ||
devices[i]->IsAddressable()) {
result.push_back(
{shard_dynamic_shapes[i],
SingleDeviceSharding::Create(devices[i], memory_kind_)});
}
}
return result;
}

absl::StatusOr<std::vector<IndexDomain>> ConcreteSharding::IndexDomains(
const Shape& shape) const {
DCHECK(this);
return IndexDomains(shape, SingleDeviceShardSemantics::kAllShards);
}

absl::StatusOr<std::vector<IndexDomain>> ConcreteSharding::IndexDomains(
const Shape& shape,
SingleDeviceShardSemantics single_device_shard_semantics) const {
DCHECK(this);
return InvalidArgument(
"ConcreteSharding does not have index domain information");
}
@@ -552,6 +654,14 @@ ConcreteEvenSharding::WithDeviceAssignment(
absl::StatusOr<std::vector<std::pair<Shape, std::shared_ptr<const Sharding>>>>
ConcreteEvenSharding::Disassemble(const Shape& shape) const {
DCHECK(this);
return Disassemble(shape, SingleDeviceShardSemantics::kAllShards);
}

absl::StatusOr<std::vector<std::pair<Shape, std::shared_ptr<const Sharding>>>>
ConcreteEvenSharding::Disassemble(
const Shape& shape,
SingleDeviceShardSemantics single_device_shard_semantics) const {
DCHECK(this);
if (shape != shape_) {
return InvalidArgument(
"ConcreteEvenSharding can only disassemble shape %s, but was asked "
@@ -560,17 +670,35 @@ ConcreteEvenSharding::Disassemble(const Shape& shape) const {
}
std::vector<std::pair<Shape, std::shared_ptr<const Sharding>>> result;
const absl::Span<Device* const> devices = devices_->devices();
result.reserve(devices.size());
if (single_device_shard_semantics == SingleDeviceShardSemantics::kAllShards) {
result.reserve(devices_->size());
} else {
result.reserve(devices_->AddressableDeviceList()->size());
}
for (int i = 0; i < devices.size(); ++i) {
result.push_back(
{shard_shape_, SingleDeviceSharding::Create(devices[i], memory_kind_)});
if (single_device_shard_semantics ==
SingleDeviceShardSemantics::kAllShards ||
devices[i]->IsAddressable()) {
result.push_back({shard_shape_, SingleDeviceSharding::Create(
devices[i], memory_kind_)});
}
}
return result;
}

absl::StatusOr<
std::vector<std::pair<DynamicShape, std::shared_ptr<const Sharding>>>>
ConcreteEvenSharding::Disassemble(const DynamicShape& dynamic_shape) const {
DCHECK(this);
return Disassemble(dynamic_shape, SingleDeviceShardSemantics::kAllShards);
}

absl::StatusOr<
std::vector<std::pair<DynamicShape, std::shared_ptr<const Sharding>>>>
ConcreteEvenSharding::Disassemble(
const DynamicShape& dynamic_shape,
SingleDeviceShardSemantics single_device_shard_semantics) const {
DCHECK(this);
return InvalidArgument(
"ConcreteEvenSharding can only disassemble static shape, but was asked "
"to disassemble dynamic shape %s",
@@ -580,6 +708,12 @@ ConcreteEvenSharding::Disassemble(const DynamicShape& dynamic_shape) const {
absl::StatusOr<std::vector<IndexDomain>> ConcreteEvenSharding::IndexDomains(
const Shape& shape) const {
DCHECK(this);
return IndexDomains(shape, SingleDeviceShardSemantics::kAllShards);
}
absl::StatusOr<std::vector<IndexDomain>> ConcreteEvenSharding::IndexDomains(
const Shape& shape,
SingleDeviceShardSemantics single_device_shard_semantics) const {
DCHECK(this);
return InvalidArgument(
"ConcreteEvenSharding does not have index domain information");
}
@@ -622,12 +756,29 @@ ShardingParamSharding::ShardingParamSharding(
absl::StatusOr<std::vector<std::pair<Shape, std::shared_ptr<const Sharding>>>>
ShardingParamSharding::Disassemble(const Shape& shape) const {
DCHECK(this);
return Disassemble(shape, SingleDeviceShardSemantics::kAllShards);
}

absl::StatusOr<std::vector<std::pair<Shape, std::shared_ptr<const Sharding>>>>
ShardingParamSharding::Disassemble(
const Shape& shape,
SingleDeviceShardSemantics single_device_shard_semantics) const {
DCHECK(this);
TF_ASSIGN_OR_RETURN(Shape local_shape, GetShardShape(shape));

std::vector<std::pair<Shape, std::shared_ptr<const Sharding>>> result;
if (single_device_shard_semantics == SingleDeviceShardSemantics::kAllShards) {
result.reserve(devices_->size());
} else {
result.reserve(devices_->AddressableDeviceList()->size());
}
for (Device* device : devices_->devices()) {
result.push_back(
{local_shape, SingleDeviceSharding::Create(device, memory_kind_)});
if (single_device_shard_semantics ==
SingleDeviceShardSemantics::kAllShards ||
device->IsAddressable()) {
result.push_back(
{local_shape, SingleDeviceSharding::Create(device, memory_kind_)});
}
}

return result;
@@ -684,6 +835,16 @@ ShardingParamSharding::WithDeviceAssignment(
absl::StatusOr<
std::vector<std::pair<DynamicShape, std::shared_ptr<const Sharding>>>>
ShardingParamSharding::Disassemble(const DynamicShape& dynamic_shape) const {
DCHECK(this);
return Disassemble(dynamic_shape, SingleDeviceShardSemantics::kAllShards);
}

absl::StatusOr<
std::vector<std::pair<DynamicShape, std::shared_ptr<const Sharding>>>>
ShardingParamSharding::Disassemble(
const DynamicShape& dynamic_shape,
SingleDeviceShardSemantics single_device_shard_semantics) const {
DCHECK(this);
return InvalidArgument(
"ShardingParamSharding can only disassemble static shape, but was asked "
"to disassemble dynamic shape %s",
@@ -693,6 +854,13 @@ ShardingParamSharding::Disassemble(const DynamicShape& dynamic_shape) const {
absl::StatusOr<std::vector<IndexDomain>> ShardingParamSharding::IndexDomains(
const Shape& shape) const {
DCHECK(this);
return IndexDomains(shape, SingleDeviceShardSemantics::kAllShards);
}

absl::StatusOr<std::vector<IndexDomain>> ShardingParamSharding::IndexDomains(
const Shape& shape,
SingleDeviceShardSemantics single_device_shard_semantics) const {
DCHECK(this);

// Calculate the origins of tiles, ignoring device assignments.
TF_ASSIGN_OR_RETURN(Shape local_shape, GetShardShape(shape));
@@ -718,12 +886,22 @@ absl::StatusOr<std::vector<IndexDomain>> ShardingParamSharding::IndexDomains(
DCHECK_EQ(device_to_index.size() % origins.size(), 0);
int replication = device_to_index.size() / origins.size();

DCHECK_EQ(device_to_index.size(), devices_->size());
std::vector<IndexDomain> result;
result.reserve(device_to_index.size());
if (single_device_shard_semantics == SingleDeviceShardSemantics::kAllShards) {
result.reserve(devices_->size());
} else {
result.reserve(devices_->AddressableDeviceList()->size());
}
const absl::Span<Device* const> devices = devices_->devices();
for (int i = 0; i < device_to_index.size(); ++i) {
int index = device_to_index[i];
DCHECK_NE(index, kInvalidIndex);
result.push_back(IndexDomain(origins[index / replication], local_shape));
if (single_device_shard_semantics ==
SingleDeviceShardSemantics::kAllShards ||
devices[i]->IsAddressable()) {
int index = device_to_index[i];
DCHECK_NE(index, kInvalidIndex);
result.push_back(IndexDomain(origins[index / replication], local_shape));
}
}
return result;
}