Skip to content

Commit

Permalink
Lowering dynamic slice fusion to command buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
shawnwang18 committed Oct 15, 2024
1 parent 4c3bfad commit 87e5f79
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 60 deletions.
1 change: 1 addition & 0 deletions xla/service/gpu/transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,7 @@ cc_library(
srcs = ["command_buffer_scheduling.cc"],
hdrs = ["command_buffer_scheduling.h"],
deps = [
":dynamic_slice_fusion_rewriter",
"//xla:shape_util",
"//xla:util",
"//xla/ffi:ffi_api",
Expand Down
113 changes: 63 additions & 50 deletions xla/service/gpu/transforms/command_buffer_scheduling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ limitations under the License.
#include <variant>
#include <vector>

#include "dynamic_slice_fusion_rewriter.h"
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
Expand Down Expand Up @@ -244,20 +245,33 @@ static bool IsCommand(const HloInstruction* hlo,
return config.enabled_commands.contains(DebugOptions::CUDNN);
}
const auto& custom_config = backend_config.custom_fusion_config();
if (custom_config.name() == "address_computation") {
auto fusion_analysis =
HloFusionAnalysis::Create(*hlo, config.device_description);
const HloFusionAdaptor& adaptor = fusion_analysis.fusion();
auto hero_adaptor =
HloBfsFindIf(adaptor.GetRoots(), adaptor, [](auto node) {
return node.opcode() == HloOpcode::kCustomCall ||
node.opcode() == HloOpcode::kReduceScatter;
});
const HloInstruction* hero = &hero_adaptor->instruction();
return IsCommand(hero, config) || IsAsyncStartCommand(hero, config);
auto fusion_analysis =
HloFusionAnalysis::Create(*hlo, config.device_description);
const HloFusionAdaptor& adaptor = fusion_analysis.fusion();
auto hero_adaptor =
HloBfsFindIf(adaptor.GetRoots(), adaptor, [](auto node) {
return node.opcode() == HloOpcode::kCustomCall ||
node.opcode() == HloOpcode::kReduceScatter;
});
const HloInstruction* hero = &hero_adaptor->instruction();

if (!(IsCommand(hero, config) || IsAsyncStartCommand(hero, config))) {
return false;
}
if (custom_config.name() == "dynamic_address_computation") {
return false;
if (!config.enabled_commands.contains(DebugOptions::DYNAMIC_SLICE)) {
return false;
}
// Check all offsets of slice instructions are constant or loop
// iterations
bool all_slice_valid = llvm::all_of(
fusion->called_computation()->instructions(),
[](const HloInstruction* inst) {
auto* slice_inst = DynCast<HloDynamicIndexInstruction>(inst);
if (!slice_inst) return true;
return HasConstantOrLoopIterationOffsets(*slice_inst);
});
return all_slice_valid;
}
return config.enabled_commands.contains(DebugOptions::FUSION);
}
Expand Down Expand Up @@ -337,11 +351,11 @@ CommandBufferScheduling::CollectCommandBufferSequences(

auto& instructions = schedule.instructions();

// Collect the sequence of instructions that contains the async start and its
// corresponding done instruction. If there is another start instruction
// between the original start and done, we may potentially extend the sequence
// to include its corresponding done instruction. For example, if we call this
// function on async-start_a in the following sequence:
// Collect the sequence of instructions that contains the async start and
// its corresponding done instruction. If there is another start instruction
// between the original start and done, we may potentially extend the
// sequence to include its corresponding done instruction. For example, if
// we call this function on async-start_a in the following sequence:
//
// async_start_a
// async_start_b
Expand Down Expand Up @@ -369,8 +383,8 @@ CommandBufferScheduling::CollectCommandBufferSequences(
return seq;
};

// Check that instructions are safe to be captured by command buffer, and that
// we do not capture unmatched async done instruction.
// Check that instructions are safe to be captured by command buffer, and
// that we do not capture unmatched async done instruction.
auto check_async_region = [&](const HloInstructionSequence& seq) {
if (!absl::c_all_of(seq.instructions(), [&](HloInstruction* inst) {
return IsNoOp(inst) || IsCommand(inst, config) ||
Expand All @@ -397,10 +411,10 @@ CommandBufferScheduling::CollectCommandBufferSequences(
for (size_t i = 0; i < instructions.size(); i++) {
HloInstruction* inst = instructions.at(i);

// We add no-op instructions to the current sequence only if they act as glue
// between commands. We do not create command sequences consisting only of
// no-op instructions. The first and last instructions in a command buffer are
// always load-bearing commands.
// We add no-op instructions to the current sequence only if they act as
// glue between commands. We do not create command sequences consisting
// only of no-op instructions. The first and last instructions in a command
// buffer are always load-bearing commands.
if (IsNoOp(inst) && num_commands_in_current_seq) {
current_seq.push_back(inst);
continue;
Expand Down Expand Up @@ -437,12 +451,11 @@ CommandBufferScheduling::CollectCommandBufferSequences(
return sequences;
}

// This function moves kParameter and kConstant instructions in a computation to
// the beginning of the computation. This simplifies the construction of command
// buffer computations because we don't need to deal with parameters and
// constants that have users outside of a command buffer.
// Returns true if there is a change in the order of instructions, false
// otherwise.
// This function moves kParameter and kConstant instructions in a computation
// to the beginning of the computation. This simplifies the construction of
// command buffer computations because we don't need to deal with parameters
// and constants that have users outside of a command buffer. Returns true if
// there is a change in the order of instructions, false otherwise.
absl::StatusOr<bool> CommandBufferScheduling::MoveParametersAndConstantsToFront(
HloComputation* computation) {
HloInstructionSequence new_sequence;
Expand All @@ -454,9 +467,9 @@ absl::StatusOr<bool> CommandBufferScheduling::MoveParametersAndConstantsToFront(
new_sequence.push_back(inst);

// Because we move instruction to the front of the computation we can't
// have any control predecessors, however silently dropping them is unsafe
// as we can have transitive dependencies that define schedule order, so
// we forward control predecessors to all users.
// have any control predecessors, however silently dropping them is
// unsafe as we can have transitive dependencies that define schedule
// order, so we forward control predecessors to all users.
for (HloInstruction* control_predecessor : inst->control_predecessors()) {
for (HloInstruction* user : inst->users()) {
TF_RETURN_IF_ERROR(control_predecessor->AddControlDependencyTo(user));
Expand Down Expand Up @@ -495,18 +508,18 @@ absl::StatusOr<CommandBuffer> CommandBufferScheduling::PrepareCommandBuffer(
absl::flat_hash_set<HloInstruction*> in_command_buffer(instructions.begin(),
instructions.end());

// The sequence might use results of instructions that are not captured by the
// sequence. We pass those results as parameters and map the producers of the
// results to their corresponding parameter instructions.
// The sequence might use results of instructions that are not captured by
// the sequence. We pass those results as parameters and map the producers
// of the results to their corresponding parameter instructions.
absl::flat_hash_map<HloInstruction*, HloParameterInstruction*> parameters;

// Mapping from command buffer instructions to their clones in the command
// buffer computation body.
absl::flat_hash_map<HloInstruction*, HloInstruction*> inst_mapping;

// Maps HLO instructions in the original computation to instructions in the
// command buffer: (a) a parameter corresponding to captured value (b) cloned
// instruction corresponding to a command.
// command buffer: (a) a parameter corresponding to captured value (b)
// cloned instruction corresponding to a command.
auto mapped_operands = [&](HloInstruction* instr) {
absl::InlinedVector<HloInstruction*, 4> operands;
for (HloInstruction* operand : instr->operands()) {
Expand Down Expand Up @@ -608,9 +621,9 @@ absl::StatusOr<HloComputation*> CommandBufferScheduling::RewriteCommandBuffer(
if (command_buffer.results.empty())
return absl::InternalError("command buffer results must not be empty");

// If we have more than one result we return them as tuple, and get individual
// values using `get-tuple-element` instructions. Otherwise we simply return
// a result from a command buffer computation.
// If we have more than one result we return them as tuple, and get
// individual values using `get-tuple-element` instructions. Otherwise we
// simply return a result from a command buffer computation.
Shape cmd_buffer_result_shape;
bool has_single_result = command_buffer.results.size() == 1;

Expand Down Expand Up @@ -644,9 +657,9 @@ absl::StatusOr<HloComputation*> CommandBufferScheduling::RewriteCommandBuffer(
// As we are running after scheduling we have to keep it valid.
HloSchedule& schedule = parent->parent()->schedule();

// Update schedule to replace the last instruction with a command buffer call.
// Removal of the rest of the instructions in the sequence is handled by
// schedule update below.
// Update schedule to replace the last instruction with a command buffer
// call. Removal of the rest of the instructions in the sequence is handled
// by schedule update below.
HloInstructionSequence& sequence = schedule.GetOrCreateSequence(parent);
sequence.replace_instruction(seq.instructions().back(), call);

Expand Down Expand Up @@ -702,8 +715,8 @@ absl::StatusOr<HloComputation*> CommandBufferScheduling::RewriteCommandBuffer(
TF_RETURN_IF_ERROR(inst->DropAllControlDeps());
}

// Traverse in reverse order as original sequence was topologically sorted and
// we can't remove instructions with users.
// Traverse in reverse order as original sequence was topologically sorted
// and we can't remove instructions with users.
for (int32_t i = seq.instructions().size() - 1; i >= 0; i--) {
TF_RETURN_IF_ERROR(parent->RemoveInstruction(seq.instructions()[i]));
}
Expand All @@ -721,10 +734,10 @@ absl::StatusOr<bool> CommandBufferScheduling::Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) {
// We run command buffer scheduling after a regular scheduling to guarantee
// that command buffers will not change execution order and buffer assignment
// compared to a regular execution. Some operations (i.e. async collectives)
// can't be captured into command buffers, and forming too large command
// buffers too early can impact async operations scheduling.
// that command buffers will not change execution order and buffer
// assignment compared to a regular execution. Some operations (i.e. async
// collectives) can't be captured into command buffers, and forming too
// large command buffers too early can impact async operations scheduling.
if (!module->has_schedule()) return Internal("module is not scheduled");

const DebugOptions& debug_options = module->config().debug_options();
Expand Down Expand Up @@ -776,7 +789,7 @@ absl::StatusOr<bool> CommandBufferScheduling::Run(
if (std::min(device_description_.runtime_version(),
device_description_.driver_version()) <
se::SemanticVersion{12, 3, 0}) {
erase(kRequireTracing); // cuStreamBeginCaptureToGraph
erase(kRequireTracing); // cuStreamBeginCaptureToGraph
erase(kRequireConditionals); // on-device control flow
}
};
Expand Down
18 changes: 8 additions & 10 deletions xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -234,16 +234,6 @@ bool IsHandledConstantForDynamicSliceFusion(const HloInstruction& offset) {
return false;
}

// This checks whether a dynamic index operation has all offsets that are either
// constant or loop iteration offsets.
bool HasConstantOrLoopIterationOffsets(
const HloDynamicIndexInstruction& instr) {
return llvm::all_of(instr.index_operands(), [](const HloInstruction* offset) {
return IsLoopIterationNumber(*offset) ||
IsHandledConstantForDynamicSliceFusion(*offset);
});
}

UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr) {
UseDefDataflowPaths sliced_operand_paths;

Expand Down Expand Up @@ -511,6 +501,14 @@ absl::StatusOr<HloInstruction*> CreateFusionInstruction(

} // namespace

// Reports whether all offsets of `instr` are in a supported form: each index
// operand must be either a loop iteration number or a constant accepted by
// IsHandledConstantForDynamicSliceFusion.
bool HasConstantOrLoopIterationOffsets(
    const HloDynamicIndexInstruction& instr) {
  // Predicate applied to each individual offset operand.
  const auto is_supported_offset = [](const HloInstruction* offset) {
    if (IsLoopIterationNumber(*offset)) return true;
    return IsHandledConstantForDynamicSliceFusion(*offset);
  };
  return llvm::all_of(instr.index_operands(), is_supported_offset);
}

absl::StatusOr<bool> DynamicSliceFusionRewriter::Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) {
Expand Down
5 changes: 5 additions & 0 deletions xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/pass/hlo_pass_interface.h"
#include "xla/hlo/ir/hlo_instructions.h"

namespace xla {
namespace gpu {
Expand Down Expand Up @@ -85,6 +86,10 @@ class DynamicSliceFusionRewriter : public HloModulePass {
std::string platform_name_;
};

// Returns true if every offset (index operand) of the dynamic index operation
// `instr` is either a constant or a loop iteration offset, and false as soon
// as one offset is neither.
bool HasConstantOrLoopIterationOffsets(const HloDynamicIndexInstruction& instr);

} // namespace gpu
} // namespace xla

Expand Down
1 change: 1 addition & 0 deletions xla/xla.proto
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,7 @@ message DebugOptions {
CONDITIONALS = 5;
CUSTOM_CALL = 6;
CUBLASLT = 7;
DYNAMIC_SLICE = 8;
}

// Determine the types of commands that are recorded into command buffers.
Expand Down

0 comments on commit 87e5f79

Please sign in to comment.