From 187c45e03b44c92b37caff3f2d1ba615538f0e59 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 28 Oct 2024 15:26:19 -0700 Subject: [PATCH 1/3] [shortfin] Implement fiber local program forking. (#332) API changes: * sf.Program(..., fiber=) no longer takes a fiber. The fiber is now moved to the invocation (i.e. function(..., fiber=) * sf.Program now requires an explicit devices= (as opposed to implicitly inferring this from the fiber). The prior behavior can be achieved by sf.Program(, devices=fiber.raw_devices). Semantic changes: * Programs are now created with an isolation level (NONE, PER_FIBER, PER_CALL), defaulting to PER_FIBER. The prior behavior was like NONE. * In PER_FIBER mode, only one in-flight call to a program can happen at a time on a single fiber. When a program is used across fibers, each fiber will get their own context fork (which is a light-weight clone of the context that shares immutable data). * In PER_CALL mode, backing contexts are pooled on each calling fiber and each call gets their own context. This supports multiple concurrent invocations to a context at the expense of each call having its own mutable state (i.e. this is only suitable for immutable programs). --- shortfin/python/lib_ext.cc | 38 ++-- shortfin/src/shortfin/local/fiber.h | 17 ++ shortfin/src/shortfin/local/program.cc | 184 ++++++++++++++---- shortfin/src/shortfin/local/program.h | 97 +++++++-- shortfin/tests/invocation/conftest.py | 34 ++-- .../invocation/mobilenet_program_test.py | 97 ++++++--- 6 files changed, 357 insertions(+), 110 deletions(-) diff --git a/shortfin/python/lib_ext.cc b/shortfin/python/lib_ext.cc index dca171d67..c73bf5a93 100644 --- a/shortfin/python/lib_ext.cc +++ b/shortfin/python/lib_ext.cc @@ -248,8 +248,9 @@ void PyAddProgramInvocationArg(py::capsule &inv_capsule, py::handle arg) { } local::ProgramInvocation::Future PyFunctionCall(local::ProgramFunction &self, - py::args args) { - auto inv = self.CreateInvocation(); + py::args args, + local::Fiber &fiber) { + auto inv = self.CreateInvocation(fiber.shared_from_this()); py::capsule inv_capsule(inv.get()); for (py::handle arg : args) { PyAddProgramInvocationArg(inv_capsule, arg); @@ -592,13 +593,14 @@ void BindLocal(py::module_ &m) { py::class_(m, "Program") .def(py::new_([](std::span modules, - local::Fiber &fiber, bool trace_execution) { + std::vector devices, + bool trace_execution) { local::Program::Options options; + options.devices = devices; options.trace_execution = trace_execution; - return local::Program::Load(fiber.shared_from_this(), modules, - std::move(options)); + return local::Program::Load(modules, std::move(options)); }), - py::arg("modules"), py::arg("fiber"), py::kw_only(), + py::arg("modules"), py::kw_only(), py::arg("devices"), py::arg("trace_execution") = false) .def_prop_ro("exports", &local::Program::exports) .def("lookup_function", &local::Program::LookupRequiredFunction) @@ -607,9 +609,14 @@ void BindLocal(py::module_ &m) { .def_prop_ro("name", &local::ProgramFunction::name) .def_prop_ro("calling_convention", &local::ProgramFunction::calling_convention) - .def("invocation", &local::ProgramFunction::CreateInvocation, - DOCSTRING_PROGRAM_FUNCTION_INVOCATION) - .def("__call__", PyFunctionCall, py::arg("args")) + .def( + "invocation", + [](local::ProgramFunction &self, local::Fiber &fiber) { + return self.CreateInvocation(fiber.shared_from_this()); + }, + DOCSTRING_PROGRAM_FUNCTION_INVOCATION) + .def("__call__", PyFunctionCall, py::arg("args"), py::kw_only(), + py::arg("fiber")) .def("__repr__", &local::ProgramFunction::to_s); py::class_(m, "ProgramModule") .def_prop_ro("exports", &local::ProgramModule::exports) @@ -718,8 +725,17 @@ void BindLocal(py::module_ &m) { }; py::class_(m, "Fiber") .def("__repr__", &local::Fiber::to_s) - .def_prop_ro("raw_devices", &local::Fiber::raw_devices, - py::rv_policy::reference_internal) + .def_prop_ro( + "raw_devices", + [](local::Fiber &self) { + std::vector devices; + devices.reserve(self.raw_devices().size()); + for (auto it : self.raw_devices()) { + devices.push_back(it.second); + } + return devices; + }, + py::rv_policy::reference_internal) .def( "raw_device", [](local::Fiber &self, int index) { return self.raw_device(index); }, diff --git a/shortfin/src/shortfin/local/fiber.h b/shortfin/src/shortfin/local/fiber.h index dd63b30f4..afd65b346 100644 --- a/shortfin/src/shortfin/local/fiber.h +++ b/shortfin/src/shortfin/local/fiber.h @@ -146,6 +146,23 @@ class SHORTFIN_API Fiber : public std::enable_shared_from_this { std::unordered_map device_class_count_; // Ordered devices named as ``. std::vector> devices_; + + // Program isolation control. + // This data structure is manipulated by APIs on the Program class hierarchy. + // It maps a parent context pointer to an isolate accounting struct. This + // struct contains a strong reference to the parent_context and a vector + // of fork contexts. For PER_FIBER invocations, there will only ever be either + // zero or one fork_contexts: when no calls have been issued there will be one + // and if a call is outstanding, there will be zero. This is used to guard + // concurrent access. For PER_CALL invocations, there will be as many + // fork_contexts as are needed to satisfy the peak number of calls in flight + // at any time. + // The program_isolate_mu_ must be held to manipulate the accounting structs. + iree::slim_mutex program_isolate_mu_; + std::unordered_map> + program_isolates_; + friend struct detail::ProgramIsolate; }; } // namespace shortfin::local diff --git a/shortfin/src/shortfin/local/program.cc b/shortfin/src/shortfin/local/program.cc index 2f2d95ab4..038cd106a 100644 --- a/shortfin/src/shortfin/local/program.cc +++ b/shortfin/src/shortfin/local/program.cc @@ -36,12 +36,12 @@ void GetVmModuleExports(iree_vm_module_t *vm_module, // -------------------------------------------------------------------------- // ProgramFunction::ProgramFunction( - std::shared_ptr fiber, iree::vm_context_ptr vm_context, - iree_vm_function_t vm_function, + iree::vm_context_ptr vm_context, iree_vm_function_t vm_function, + ProgramIsolation isolation, std::optional invocation_model) - : fiber_(std::move(fiber)), - vm_context_(std::move(vm_context)), + : vm_context_(std::move(vm_context)), vm_function_(vm_function), + isolation_(isolation), invocation_model_(invocation_model ? *invocation_model : GetInvocationModelFromFunction(vm_function)) {} @@ -73,9 +73,19 @@ std::string_view ProgramFunction::calling_convention() const { iree_vm_function_signature(&vm_function_).calling_convention); } -ProgramInvocation::Ptr ProgramFunction::CreateInvocation() { - return ProgramInvocation::New(fiber_, vm_context_, vm_function_, - invocation_model_); +ProgramInvocation::Ptr ProgramFunction::CreateInvocation( + std::shared_ptr fiber) { + // Low-overhead NONE isolation handling (saves some ref-count twiddling). + if (isolation_ == ProgramIsolation::NONE) { + return ProgramInvocation::New(std::move(fiber), vm_context_, vm_function_, + invocation_model_, /*isolate=*/nullptr); + } + + // Create an isolated invocation. + auto [isolated_context, isolate] = + detail::ProgramIsolate::AcquireIsolate(*fiber, vm_context_, isolation_); + return ProgramInvocation::New(std::move(fiber), std::move(isolated_context), + vm_function_, invocation_model_, isolate); } std::string ProgramFunction::to_s() const { @@ -106,7 +116,7 @@ ProgramModule ProgramModule::Load(System &system, system.vm_instance(), contents.const_buffer(), contents.deallocator(), system.host_allocator(), module.for_output())); contents.release(); // Must be invoked on success path only. - return ProgramModule(std::move(module)); + return ProgramModule(system.shared_from_this(), std::move(module)); } ProgramModule ProgramModule::ParameterProvider( @@ -126,7 +136,7 @@ ProgramModule ProgramModule::ParameterProvider( SHORTFIN_THROW_IF_ERROR(iree_io_parameters_module_create( system.vm_instance(), providers.size(), providers.data(), system.host_allocator(), module.for_output())); - return ProgramModule(std::move(module)); + return ProgramModule(system.shared_from_this(), std::move(module)); } std::string_view ProgramModule::name() const { @@ -158,14 +168,27 @@ std::vector ProgramModule::exports() const { // Program // -------------------------------------------------------------------------- // -Program Program::Load(std::shared_ptr fiber, - std::span modules, Options options) { +Program Program::Load(std::span modules, + Options &&options) { std::vector all_modules; std::vector raw_devices; + System *system = nullptr; // By default, bind all devices in the fiber in order to the program. - for (auto &it : fiber->raw_devices()) { - raw_devices.push_back(it.second->hal_device()); + for (auto &it : options.devices) { + raw_devices.push_back(it->hal_device()); + } + + for (auto &mod : modules) { + if (system && &mod.system() != system) { + throw std::invalid_argument( + "Cannot create Program from modules loaded from multiple system " + "instances"); + } + system = &mod.system(); + } + if (!system) { + throw std::invalid_argument("Cannot create Program with no modules"); } // Add a HAL module. @@ -177,12 +200,11 @@ Program Program::Load(std::shared_ptr fiber, // functionality (or module versions; iree_vm_module_dependency_t has the // minimum version required so you can switch between them, and whether they // are optional/required). - auto &system = fiber->system(); iree::vm_module_ptr hal_module; - SHORTFIN_THROW_IF_ERROR( - iree_hal_module_create(system.vm_instance(), raw_devices.size(), - raw_devices.data(), IREE_HAL_MODULE_FLAG_NONE, - system.host_allocator(), hal_module.for_output())); + SHORTFIN_THROW_IF_ERROR(iree_hal_module_create( + system->vm_instance(), raw_devices.size(), raw_devices.data(), + IREE_HAL_MODULE_FLAG_NONE, system->host_allocator(), + hal_module.for_output())); all_modules.push_back(hal_module); // Add explicit modules. @@ -195,10 +217,10 @@ Program Program::Load(std::shared_ptr fiber, iree_vm_context_flags_t flags = IREE_VM_CONTEXT_FLAG_CONCURRENT; if (options.trace_execution) flags |= IREE_VM_CONTEXT_FLAG_TRACE_EXECUTION; SHORTFIN_THROW_IF_ERROR(iree_vm_context_create_with_modules( - system.vm_instance(), flags, all_modules.size(), all_modules.data(), - system.host_allocator(), context.for_output())); + system->vm_instance(), flags, all_modules.size(), all_modules.data(), + system->host_allocator(), context.for_output())); - return Program(std::move(fiber), std::move(context)); + return Program(std::move(context), options.isolation); } std::optional Program::LookupFunction(std::string_view name) { @@ -217,7 +239,7 @@ std::optional Program::LookupFunction(std::string_view name) { // TODO: Torch import is not setting the coarse-fences abi.model on // its functions. Get it from there instead of just assuming based on // name. - return ProgramFunction(fiber_, vm_context_, f, + return ProgramFunction(vm_context_, f, isolation_, ProgramInvocationModel::COARSE_FENCES); } else if (!iree_status_is_not_found(status)) { SHORTFIN_THROW_IF_ERROR(status); @@ -229,7 +251,7 @@ std::optional Program::LookupFunction(std::string_view name) { vm_context_, to_iree_string_view(name), &f); if (iree_status_is_not_found(status)) return {}; SHORTFIN_THROW_IF_ERROR(status); - return ProgramFunction(fiber_, vm_context_, f); + return ProgramFunction(vm_context_, f, isolation_); } ProgramFunction Program::LookupRequiredFunction(std::string_view name) { @@ -260,6 +282,15 @@ std::vector Program::exports() const { return results; } +void Program::PrepareIsolate(Fiber &fiber) { + if (isolation_ == ProgramIsolation::NONE) return; + auto [context, isolate] = + detail::ProgramIsolate::AcquireIsolate(fiber, vm_context_, isolation_); + if (isolate) { + detail::ProgramIsolate::ReleaseIsolate(fiber, std::move(context), isolate); + } +} + // -------------------------------------------------------------------------- // // ProgramInvocation // -------------------------------------------------------------------------- // @@ -287,18 +318,23 @@ void ProgramInvocation::Deleter::operator()(ProgramInvocation *inst) { } ProgramInvocation::ProgramInvocation() = default; -ProgramInvocation::~ProgramInvocation() { - if (!scheduled()) { - // This instance was dropped on the floor before scheduling. - // Clean up the initialization parameters. - iree::vm_context_ptr drop = - iree::vm_context_ptr::steal_reference(state.params.context); +ProgramInvocation::~ProgramInvocation() { ReleaseContext(); } + +void ProgramInvocation::ReleaseContext() { + if (vm_context_) { + if (isolate_) { + detail::ProgramIsolate::ReleaseIsolate(*fiber_, std::move(vm_context_), + isolate_); + } else { + vm_context_.reset(); + } } } ProgramInvocation::Ptr ProgramInvocation::New( std::shared_ptr fiber, iree::vm_context_ptr vm_context, - iree_vm_function_t &vm_function, ProgramInvocationModel invocation_model) { + iree_vm_function_t &vm_function, ProgramInvocationModel invocation_model, + detail::ProgramIsolate *isolate) { auto sig = iree_vm_function_signature(&vm_function); iree_host_size_t arg_count; iree_host_size_t result_count; @@ -337,8 +373,8 @@ ProgramInvocation::Ptr ProgramInvocation::New( static_cast(inst_storage.release())), Deleter()); inst->fiber_ = std::move(fiber); - inst->state.params.context = - vm_context.release(); // Ref transfer to ProgramInvocation. + inst->vm_context_ = std::move(vm_context); + inst->isolate_ = isolate; inst->state.params.function = vm_function; inst->state.params.invocation_model = invocation_model; inst->result_list_ = result_list; @@ -421,7 +457,6 @@ ProgramInvocation::Future ProgramInvocation::Invoke( Params params = invocation->state.params; auto schedule = [](ProgramInvocation *raw_invocation, Worker *worker, - iree_vm_context_t *owned_context, iree_vm_function_t function, ProgramInvocationModel invocation_model, std::optional failure_future) { @@ -440,6 +475,7 @@ ProgramInvocation::Future ProgramInvocation::Invoke( ProgramInvocation::Ptr invocation( static_cast(user_data)); ProgramInvocation *raw_invocation = invocation.get(); + raw_invocation->ReleaseContext(); if (iree_status_is_ok(status)) { raw_invocation->future_->set_result(std::move(invocation)); } else { @@ -469,7 +505,7 @@ ProgramInvocation::Future ProgramInvocation::Invoke( if (iree_status_is_ok(status)) { status = iree_vm_async_invoke(worker->loop(), &invocation->state.async_invoke_state, - owned_context, function, + invocation->vm_context_.get(), function, /*flags=*/IREE_VM_INVOCATION_FLAG_NONE, /*policy=*/nullptr, /*inputs=*/invocation->arg_list(), @@ -478,10 +514,6 @@ ProgramInvocation::Future ProgramInvocation::Invoke( /*user_data=*/invocation.get()); } - // Regardless of status, the context reference we were holding is no - // longer needed. Drop it on the floor. - iree::vm_context_ptr::steal_reference(owned_context); - // On success, then the complete callback takes ownership of the // invocation, so we release it here and return. We have to treat // the invocation as possibly deallocated at this point, since the @@ -490,9 +522,11 @@ ProgramInvocation::Future ProgramInvocation::Invoke( invocation.release(); } else if (failure_future) { // Requested to set any failure on the future. + invocation->ReleaseContext(); failure_future->set_failure(status); } else { // Synchronous: just throw. + invocation->ReleaseContext(); SHORTFIN_THROW_IF_ERROR(status); } }; @@ -504,14 +538,13 @@ ProgramInvocation::Future ProgramInvocation::Invoke( if (&worker == Worker::GetCurrent()) { // On the same worker: fast-path directly to the loop. - schedule(invocation.release(), &worker, params.context, params.function, + schedule(invocation.release(), &worker, params.function, params.invocation_model, /*failure_future=*/{}); } else { // Cross worker coordination: submit an external task to bootstrap. - auto bound_schedule = - std::bind(schedule, invocation.release(), &worker, params.context, - params.function, params.invocation_model, - /*failure_future=*/fork_future); + auto bound_schedule = std::bind(schedule, invocation.release(), &worker, + params.function, params.invocation_model, + /*failure_future=*/fork_future); worker.CallThreadsafe(bound_schedule); } @@ -623,4 +656,69 @@ void StaticProgramParameters::Load(std::filesystem::path file_path, to_iree_string_view(options.format), file_handle.get(), index_.get())); } +// -------------------------------------------------------------------------- // +// ProgramIsolate +// -------------------------------------------------------------------------- // + +std::pair +detail::ProgramIsolate::AcquireIsolate(Fiber &fiber, + iree::vm_context_ptr root_context, + ProgramIsolation isolation) { + assert(isolation != ProgramIsolation::NONE && + "cannot AcquireIsolate when isolation == NONE"); + // Some isolation required. + detail::ProgramIsolate *isolate = nullptr; + { + iree::slim_mutex_lock_guard lock(fiber.program_isolate_mu_); + auto found_it = fiber.program_isolates_.find(root_context.get()); + if (found_it != fiber.program_isolates_.end()) { + isolate = found_it->second.get(); + } + if (isolate && !isolate->fork_contexts.empty()) { + // Fast path: there is an existing isolate and a context avaialable. + auto isolated_context = std::move(isolate->fork_contexts.back()); + isolate->fork_contexts.pop_back(); + return std::make_pair(std::move(isolated_context), isolate); + } else if (!isolate) { + // Initialize a new isolate accounting struct while in the lock. + // Note that this can cause a fault for PER_FIBER mode if the call + // to fork fails below as it will leave the isolate with no available + // context and every future call will raise an exception indicating that + // the context is busy (vs trying to create a new one). This is deemed + // an acceptable situation for a system fault (which is the only reason + // a fork will fail). + auto [inserted_it, inserted] = + fiber.program_isolates_.insert(std::make_pair( + root_context.get(), + std::make_unique(root_context))); + isolate = inserted_it->second.get(); + } else if (isolation == ProgramIsolation::PER_FIBER) { + throw std::logic_error( + "Cannot make concurrent invocations of a PER_FIBER program from " + "the same Fiber. This typically means that two invocations were " + "attempted on the same program on the same fiber without an " + "await. Consider fixing adding appropriate sequencing or switching " + "to either PER_CALL or NONE isolation if appropriate for the use " + "case. This exception can also occur if the first invocation to this " + "Program failed, leaving no initialized Program for this fiber."); + } + } + + // Slow-path: fork needed (and possibly new isolate registration needed). + iree::vm_context_ptr new_context; + SHORTFIN_THROW_IF_ERROR(iree_vm_context_fork( + root_context.get(), fiber.host_allocator(), new_context.for_output())); + return std::make_pair(std::move(new_context), isolate); +} + +void detail::ProgramIsolate::ReleaseIsolate(Fiber &fiber, + iree::vm_context_ptr context, + detail::ProgramIsolate *isolate) { + assert(isolate && "attempt to release null isolate"); + { + iree::slim_mutex_lock_guard lock(fiber.program_isolate_mu_); + isolate->fork_contexts.push_back(std::move(context)); + } +} + } // namespace shortfin::local diff --git a/shortfin/src/shortfin/local/program.h b/shortfin/src/shortfin/local/program.h index bc5ae05dc..ea4f0cc3f 100644 --- a/shortfin/src/shortfin/local/program.h +++ b/shortfin/src/shortfin/local/program.h @@ -26,6 +26,10 @@ class BaseProgramParameters; class Fiber; class System; +namespace detail { +struct ProgramIsolate; +} // namespace detail + enum class ProgramInvocationModel { // Uses the coarse-fences invocation model. In this model, the last two // arguments are a wait and signal fence, which are used for function-level @@ -37,6 +41,24 @@ enum class ProgramInvocationModel { UNKNOWN, }; +// The level of isolation that a program has with respect to concurrent use. +enum class ProgramIsolation { + // There is no isolation: Callers are completely on their own to only issue + // concurrent invocations if supported. + NONE = 0, + + // Each fiber in the system that makes calls into the program will have its + // own shallow fork of the module. This is done on-demand and the root + // program is retained for the lifetime of any referencing fibers. + // Concurrent calls on the same fiber are considered programming errors and + // will be flagged as such at an appropriate debug level. + PER_FIBER = 1, + + // Each call triggers a shallow fork of the module. This is the most expensive + // but safest way to ensure complete isolation of stateless invocations. + PER_CALL = 2, +}; + // State related to making an invocation of a function on a program. // // Since ownership of this object is transferred to the loop/callback and @@ -67,7 +89,8 @@ class SHORTFIN_API ProgramInvocation { static Ptr New(std::shared_ptr fiber, iree::vm_context_ptr vm_context, iree_vm_function_t &vm_function, - ProgramInvocationModel invocation_model); + ProgramInvocationModel invocation_model, + detail::ProgramIsolate *isolate); ProgramInvocation(const ProgramInvocation &) = delete; ProgramInvocation &operator=(const ProgramInvocation &) = delete; ProgramInvocation &operator=(ProgramInvocation &&) = delete; @@ -133,6 +156,11 @@ class SHORTFIN_API ProgramInvocation { private: ProgramInvocation(); void CheckNotScheduled(); + // Eagerly releases context when it is known that no further use of it can + // be made (allowing it to be returned to a pool prior to the invocation + // actually being recycled). Object destruction also does this, but possibly + // extending the context lifetime. + void ReleaseContext(); // Returns a pointer to the trailing arg list. iree_vm_list_t *arg_list(); @@ -156,8 +184,6 @@ class SHORTFIN_API ProgramInvocation { // This must not contain entities that require destruction or cannot be // trivially copied. struct Params { - // Context is retained upon construction and released when scheduled. - iree_vm_context_t *context; iree_vm_function_t function; ProgramInvocationModel invocation_model; }; @@ -169,6 +195,8 @@ class SHORTFIN_API ProgramInvocation { } state; std::shared_ptr fiber_; + iree::vm_context_ptr vm_context_; + detail::ProgramIsolate *isolate_; iree_vm_list_t *result_list_ = nullptr; std::optional future_; iree::hal_fence_ptr wait_fence_; @@ -187,7 +215,7 @@ class SHORTFIN_API ProgramFunction { std::string_view calling_convention() const; ProgramInvocationModel invocation_model() const { return invocation_model_; } - ProgramInvocation::Ptr CreateInvocation(); + ProgramInvocation::Ptr CreateInvocation(std::shared_ptr fiber); std::string to_s() const; @@ -195,17 +223,16 @@ class SHORTFIN_API ProgramFunction { operator iree_vm_function_t &() { return vm_function_; } private: - ProgramFunction(std::shared_ptr fiber, iree::vm_context_ptr vm_context, - iree_vm_function_t vm_function, + ProgramFunction(iree::vm_context_ptr vm_context, + iree_vm_function_t vm_function, ProgramIsolation isolation, std::optional invocation_model = {}); static ProgramInvocationModel GetInvocationModelFromFunction( iree_vm_function_t &f); - // The context that this function was resolved against. - std::shared_ptr fiber_; iree::vm_context_ptr vm_context_; iree_vm_function_t vm_function_; + ProgramIsolation isolation_; ProgramInvocationModel invocation_model_; friend class Program; }; @@ -231,6 +258,7 @@ class SHORTFIN_API ProgramModule { std::string to_s() const; iree_vm_module_t *vm_module() const { return vm_module_; } std::string_view name() const; + System &system() const { return *system_; } // Loads a dynamic bytecode module (VMFB) from a path on the file system. static ProgramModule Load(System &system, const std::filesystem::path &path, @@ -246,10 +274,12 @@ class SHORTFIN_API ProgramModule { std::vector exports() const; protected: - explicit ProgramModule(iree::vm_module_ptr vm_module) - : vm_module_(std::move(vm_module)) {} + explicit ProgramModule(std::shared_ptr system, + iree::vm_module_ptr vm_module) + : system_(std::move(system)), vm_module_(std::move(vm_module)) {} private: + std::shared_ptr system_; iree::vm_module_ptr vm_module_; }; @@ -269,15 +299,19 @@ class SHORTFIN_API Program { struct Options { Options() {} + // Ordered list of devices to bind this program to. + std::span devices; + + // The isolation level to apply to program invocation. + ProgramIsolation isolation = ProgramIsolation::PER_FIBER; + // Enables program-wide execution tracing (to stderr). bool trace_execution = false; }; - // Loads a program attached to a fiber with a list of user provided modules - // and options. - static Program Load(std::shared_ptr fiber, - std::span modules, - Options options = {}); + // Load a program from a list of modules and options. + static Program Load(std::span modules, + Options &&options); // Looks up a public function by fully qualified name (i.e. module.function). // Returns nothing if not found. @@ -290,12 +324,16 @@ class SHORTFIN_API Program { // Gets the name of all exported functions. std::vector exports() const; + // Eagerly does any per-fiber isolation preparation for the program at a + // convenient point (usually init time) to avoid first-invocation overhead. + void PrepareIsolate(Fiber &fiber); + private: - explicit Program(std::shared_ptr fiber, - iree::vm_context_ptr vm_context) - : fiber_(std::move(fiber)), vm_context_(std::move(vm_context)) {} - std::shared_ptr fiber_; + explicit Program(iree::vm_context_ptr vm_context, ProgramIsolation isolation) + : vm_context_(std::move(vm_context)), isolation_(isolation) {} + iree::vm_context_ptr vm_context_; + ProgramIsolation isolation_; friend class Fiber; }; @@ -354,6 +392,27 @@ class SHORTFIN_API StaticProgramParameters : public BaseProgramParameters { iree::io_parameter_index_ptr index_; }; +namespace detail { +// See Fiber::program_isolates_. +struct ProgramIsolate { + ProgramIsolate(iree::vm_context_ptr parent_context) + : parent_context(std::move(parent_context)) {} + iree::vm_context_ptr parent_context; + std::vector fork_contexts; + + // Acquires an isolate for the given fiber. This will return a context which + // may be the original program context or may be a forked child that is + // available for use. It is only valid to call this when isolation != NONE. + static std::pair + AcquireIsolate(Fiber &fiber, iree::vm_context_ptr root_context, + ProgramIsolation isolation); + + // Releases an isolate obtained from a fiber in AcquireIsolate. + static void ReleaseIsolate(Fiber &fiber, iree::vm_context_ptr context, + ProgramIsolate *isolate); +}; +}; // namespace detail + } // namespace shortfin::local #endif // SHORTFIN_LOCAL_PROGRAM_H diff --git a/shortfin/tests/invocation/conftest.py b/shortfin/tests/invocation/conftest.py index c366c7f82..148ae064d 100644 --- a/shortfin/tests/invocation/conftest.py +++ b/shortfin/tests/invocation/conftest.py @@ -22,15 +22,16 @@ def mobilenet_onnx_path(tmp_path_factory): import onnx except ModuleNotFoundError: raise pytest.skip("onnx python package not available") - print("Downloading mobilenet.onnx") parent_dir = tmp_path_factory.mktemp("mobilenet_onnx") orig_onnx_path = parent_dir / "mobilenet_orig.onnx" - urllib.request.urlretrieve( - "https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv2-12.onnx", - orig_onnx_path, - ) upgraded_onnx_path = parent_dir / "mobilenet.onnx" - upgrade_onnx(orig_onnx_path, upgraded_onnx_path) + if not upgraded_onnx_path.exists(): + print("Downloading mobilenet.onnx") + urllib.request.urlretrieve( + "https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv2-12.onnx", + orig_onnx_path, + ) + upgrade_onnx(orig_onnx_path, upgraded_onnx_path) return upgraded_onnx_path @@ -41,15 +42,18 @@ def mobilenet_compiled_cpu_path(mobilenet_onnx_path): import iree.compiler.tools.import_onnx.__main__ as import_onnx except ModuleNotFoundError: raise pytest.skip("iree.compiler packages not available") - print("Compiling mobilenet") mlir_path = mobilenet_onnx_path.parent / "mobilenet.mlir" vmfb_path = mobilenet_onnx_path.parent / "mobilenet_cpu.vmfb" - args = import_onnx.parse_arguments(["-o", str(mlir_path), str(mobilenet_onnx_path)]) - import_onnx.main(args) - tools.compile_file( - str(mlir_path), - output_file=str(vmfb_path), - target_backends=["llvm-cpu"], - input_type="onnx", - ) + if not vmfb_path.exists(): + print("Compiling mobilenet") + args = import_onnx.parse_arguments( + ["-o", str(mlir_path), str(mobilenet_onnx_path)] + ) + import_onnx.main(args) + tools.compile_file( + str(mlir_path), + output_file=str(vmfb_path), + target_backends=["llvm-cpu"], + input_type="onnx", + ) return vmfb_path diff --git a/shortfin/tests/invocation/mobilenet_program_test.py b/shortfin/tests/invocation/mobilenet_program_test.py index 4275fe9e2..84903fb8f 100644 --- a/shortfin/tests/invocation/mobilenet_program_test.py +++ b/shortfin/tests/invocation/mobilenet_program_test.py @@ -5,6 +5,8 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import array +import asyncio +import time import functools import pytest @@ -21,38 +23,89 @@ def lsys(): @pytest.fixture -def fiber(lsys): +def fiber0(lsys): return lsys.create_fiber() @pytest.fixture -def device(fiber): - return fiber.device(0) +def device(fiber0): + return fiber0.device(0) -def test_invoke_mobilenet(lsys, fiber, mobilenet_compiled_cpu_path): - device = fiber.device(0) +@pytest.fixture +def mobilenet_program_function( + lsys, mobilenet_compiled_cpu_path +) -> tuple[sf.ProgramFunction]: + program_module = lsys.load_module(mobilenet_compiled_cpu_path) + program = sf.Program([program_module], devices=lsys.devices) + main_function = program["module.torch-jit-export"] + return main_function + + +def get_mobilenet_ref_input(device) -> sfnp.device_array: dummy_data = array.array( "f", ([0.2] * (224 * 224)) + ([0.4] * (224 * 224)) + ([-0.2] * (224 * 224)) ) - program_module = lsys.load_module(mobilenet_compiled_cpu_path) - program = sf.Program([program_module], fiber=fiber) - main_function = program["module.torch-jit-export"] + device_input = sfnp.device_array(device, [1, 3, 224, 224], sfnp.float32) + staging_input = device_input.for_transfer() + with staging_input.map(discard=True) as m: + m.fill(dummy_data) + device_input.copy_from(staging_input) + return device_input + + +async def assert_mobilenet_ref_output(device, device_output): + host_output = device_output.for_transfer() + host_output.copy_from(device_output) + await device + flat_output = host_output.items + absmean = functools.reduce( + lambda x, y: x + abs(y) / len(flat_output), flat_output, 0.0 + ) + print("RESULT:", absmean) + assert absmean == pytest.approx(5.01964943873882) + + +def test_invoke_mobilenet(lsys, fiber0, mobilenet_program_function): + device = fiber0.device(0) async def main(): - device_input = sfnp.device_array(device, [1, 3, 224, 224], sfnp.float32) - staging_input = device_input.for_transfer() - with staging_input.map(discard=True) as m: - m.fill(dummy_data) - device_input.copy_from(staging_input) - (device_output,) = await main_function(device_input) - host_output = device_output.for_transfer() - host_output.copy_from(device_output) - await device - flat_output = host_output.items - absmean = functools.reduce( - lambda x, y: x + abs(y) / len(flat_output), flat_output, 0.0 - ) - assert absmean == pytest.approx(5.01964943873882) + device_input = get_mobilenet_ref_input(device) + (device_output,) = await mobilenet_program_function(device_input, fiber=fiber0) + await assert_mobilenet_ref_output(device, device_output) + + lsys.run(main()) + + +def test_invoke_mobilenet_multi_fiber(lsys, mobilenet_program_function): + class InferProcess(sf.Process): + async def run(self): + start_time = time.time() + + def duration(): + return round((time.time() - start_time) * 1000.0) + + print(f"{self}: Start") + device = self.fiber.device(0) + device_input = get_mobilenet_ref_input(device) + (device_output,) = await mobilenet_program_function( + device_input, fiber=self.fiber + ) + print(f"{self}: Program complete (+{duration()}ms)") + await assert_mobilenet_ref_output(device, device_output) + print(f"{self} End (+{duration()}ms)") + + async def main(): + start_time = time.time() + + def duration(): + return round((time.time() - start_time) * 1000.0) + + fibers = [lsys.create_fiber() for _ in range(5)] + print("Fibers:", fibers) + processes = [InferProcess(fiber=f).launch() for f in fibers] + print("Waiting for processes:", processes) + await asyncio.gather(*processes) + print(f"All processes complete: (+{duration()}ms)") lsys.run(main()) From 89e26c00b30d61a74f6569ee1770acc002833bc1 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 28 Oct 2024 18:49:33 -0700 Subject: [PATCH 2/3] [llama] Update kv cache to have read/write functions (#280) Made the interfaces of both caches line up. This allows us to interface with the caches via their utility functions instead of modifying the model behavior. Some roughness still exists in their parameters but the irrelevant details are ignored for each implementation. Need to still add some slicing / ignoring on the page ids to make more flexible. --- sharktank/sharktank/layers/kv_cache.py | 132 ++++- .../layers/paged_llama_attention_block.py | 156 ++---- sharktank/sharktank/types/tensors.py | 2 +- sharktank/tests/layers/kv_cache_test.py | 502 ++++++++++++++++++ .../layers/sharded_paged_kv_cache_test.py | 2 + 5 files changed, 671 insertions(+), 123 deletions(-) create mode 100644 sharktank/tests/layers/kv_cache_test.py diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py index bed0b451d..048bc364c 100644 --- a/sharktank/sharktank/layers/kv_cache.py +++ b/sharktank/sharktank/layers/kv_cache.py @@ -92,6 +92,7 @@ def __init__( attn_head_count: int, attn_head_dim: int, seq_length: int, + shard_count: int = 1, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, ): @@ -100,6 +101,7 @@ def __init__( self.attn_head_count = attn_head_count self.attn_head_dim = attn_head_dim self.seq_length = seq_length + self.shard_count = shard_count self.device = device self.dtype = dtype @@ -113,15 +115,109 @@ def allocate(self, *, bs: int) -> list[torch.Tensor]: Each tensor has shape: [bs, sl, attn_head_count, attn_head_dim] """ - return [ + allocations = [ torch.empty( - [bs, self.seq_length, self.attn_head_count, self.attn_head_dim], + [ + bs, + self.seq_length, + self.attn_head_count, + self.attn_head_dim, + ], dtype=self.dtype, device=self.device, ) for _ in range(2 * self.transformer_block_count) ] + if self.shard_count == 1: + return allocations + + return [ + ops.reshard_split(allocation, dim=2, count=self.shard_count) + for allocation in allocations + ] + + def read( + self, + state: list[Union[torch.Tensor, SplitPrimitiveTensor]], + *, + read_into_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], + transformer_block_index: int, + seq_len: int, + page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None, + ): + """Reads cache partitions from the page table for the given page_ids. + + Args: + state: State struct as returned from allocate(). + read_into_partitions: List of cache partitions to read into in-place. + transformer_block_index: The index of the transformer block accessing + the cache. + page_ids: Tensor of [bs, max_seqlen // block_pos_stride] of page ids + to access. + + Returns a tuple of cache partitions (i.e. k and v caches for the transformer + block), linearized. Note that this reference approach to reading by + materializing linearly may not be terribly efficient unless if the + compiler can fuse the gather. + """ + read_count = len(read_into_partitions) + reads = [] + for i in range(read_count): + reads.append( + state[transformer_block_index * read_count + i][:, :seq_len, :, :] + ) + + return tuple(reads) + + def write_timestep( + self, + state: list[Union[torch.Tensor, SplitPrimitiveTensor]], + # List of [bs, 1, attn_head_count, attn_head_dim] + cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], + *, + transformer_block_index: int, + # [bs] + seq_positions: Union[torch.Tensor, ReplicatedTensor], + # [bs, max_seqlen // block_pos_stride] + page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None, + ): + """Writes a single batched timestep across all cache partitions. + + Note that this internally loops over the batch size, which cannot be + dynamic. + """ + bs, _, _, _ = cache_partitions[0].shape + update_count = len(cache_partitions) + + for b in range(bs): + row_index = torch.tensor(b, dtype=torch.int64) + row_start_pos = seq_positions[row_index] + + for i, update in enumerate(cache_partitions): + cache = state[transformer_block_index * update_count + i] + cache.index_put_((row_index, row_start_pos), update[row_index, 0]) + + def write( + self, + state: list[Union[torch.Tensor, SplitPrimitiveTensor]], + cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], + *, + transformer_block_index: int, + page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None, + ): + """Writes cache partitions from a linear layout to the page table. + + This is the inverse of the linear read. The same caveat applies if the + in-place scatter cannot be fused. + """ + update_count = len(cache_partitions) + + for idx, update_src in enumerate(cache_partitions): + cache_dest = state[transformer_block_index * update_count + idx] + _, batch_seq_len, _, _ = update_src.shape + cache_dest[:, :batch_seq_len, :, :] = update_src + class PagedKVCache(BaseKVCache): """Implementation of a KV cache on top of a 'page table'. @@ -238,24 +334,19 @@ def allocate( """Allocates tensor state for a page table for the given capacity in pages. """ + shards = [ + torch.empty( + [page_count, self.page_slab_flat_dim], + dtype=self.dtype, + device=self.device, + ) + for _ in range(self.shard_count) + ] + if self.shard_count == 1: - return [ - torch.empty( - [page_count, self.page_slab_flat_dim], - dtype=self.dtype, - device=self.device, - ) - ] - else: - shards = [ - torch.empty( - [page_count, self.page_slab_flat_dim], - dtype=self.dtype, - device=self.device, - ) - for _ in range(self.shard_count) - ] - return [SplitPrimitiveTensor(ts=shards, shard_dim=1)] + return shards + + return [SplitPrimitiveTensor(ts=shards, shard_dim=1)] def read( self, @@ -263,6 +354,7 @@ def read( *, read_into_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], transformer_block_index: int, + seq_len: int, page_ids: Union[torch.Tensor, ReplicatedTensor], ): """Reads cache partitions from the page table for the given page_ids. @@ -331,6 +423,8 @@ def read_cache_partition( for index, read_into_partition in enumerate(read_into_partitions): read_cache_partition(index, read_into_partition) + return tuple([p[:, :seq_len, :] for p in read_into_partitions]) + def write_timestep( self, state: list[Union[torch.Tensor, SplitPrimitiveTensor]], diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py index 59ed7b43a..958dc954e 100644 --- a/sharktank/sharktank/layers/paged_llama_attention_block.py +++ b/sharktank/sharktank/layers/paged_llama_attention_block.py @@ -113,27 +113,16 @@ def forward( # Full sequence length. kv_seq_len = seq_block_ids.shape[1] * self.cache.block_seq_stride - if self.cache.is_paged: - xk, xv = self.transact_cache_paged( - xk_cache_update=xk, - xv_cache_update=xv, - seq_block_ids=seq_block_ids, - kv_seq_len=kv_seq_len, - start_positions=start_positions, - cache_state=cache_state, - xk_temp=xk_temp, - xv_temp=xv_temp, - ) - elif self.cache.is_direct: - xk, xv = self.transact_cache_direct( - xk_cache_update=xk, - xv_cache_update=xv, - start_positions=start_positions, - kv_seq_len=kv_seq_len, - cache_state=cache_state, - ) - else: - raise NotImplementedError(f"Unsupported KV cache type: {type(self.cache)}") + xk, xv = self.transact_cache( + xk_cache_update=xk, + xv_cache_update=xv, + seq_block_ids=seq_block_ids, + kv_seq_len=kv_seq_len, + start_positions=start_positions, + cache_state=cache_state, + xk_temp=xk_temp, + xv_temp=xv_temp, + ) # Expand kv heads for GQA. gqa_n_rep = self.head_count // self.head_count_kv @@ -202,58 +191,20 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor: h = h + attn_output return h - def transact_cache_direct( - self, - *, - cache_state: list[torch.Tensor], - xk_cache_update: torch.Tensor, - xv_cache_update: torch.Tensor, - kv_seq_len: int, - start_positions: Optional[torch.Tensor] = None, - ): - bs, batch_seq_len, _, _ = xk_cache_update.shape - cache_k = cache_state[self.block_index * 2] - cache_v = cache_state[self.block_index * 2 + 1] - - if start_positions is None: - # Prefill. Write the entire cache. - cache_k[:, :batch_seq_len] = xk_cache_update - cache_v[:, :batch_seq_len] = xv_cache_update - return xk_cache_update, xv_cache_update - else: - # Decode. Write a single timestep. - # TODO: This needs to be reworked with index ops. - assert xk_cache_update.shape[1] == 1 - assert xv_cache_update.shape[1] == 1 - for b in range(bs): - # Make a tensor because indices must be all tensors, so we can avoid - # doing start_positions[row_index].item(), which generates a lot of SymInts. - row_index = torch.tensor( - b, dtype=torch.int64, device=xk_cache_update.device - ) - row_start_pos = start_positions[row_index] - cache_k.index_put_( - (row_index, row_start_pos), xk_cache_update[row_index, 0] - ) - cache_v.index_put_( - (row_index, row_start_pos), xv_cache_update[row_index, 0] - ) - return cache_k[:, :kv_seq_len], cache_v[:, :kv_seq_len] - - def transact_cache_paged( + def transact_cache( self, *, xk_cache_update: torch.Tensor, xv_cache_update: torch.Tensor, cache_state: list[torch.Tensor], # [bs, batch_seq_len // block_seq_stride] - seq_block_ids: torch.Tensor, + seq_block_ids: Optional[torch.Tensor], kv_seq_len: int, start_positions: Optional[torch.Tensor] = None, xk_temp: Optional[torch.Tensor] = None, xv_temp: Optional[torch.Tensor] = None, ): - cache = self.cache.paged + cache = self.cache # Manage the cache. if start_positions is None: # Prefill: Write the entire cache. @@ -264,46 +215,45 @@ def transact_cache_paged( page_ids=seq_block_ids, ) return xk_cache_update, xv_cache_update - else: - # Decode at ragged start positions. - # We need to initialize/read the K/V from the cache for the whole - # sequence. Note that at this point, it is possible to fork and - # use a memory efficient attention kernel that can do indirect - # reads, skipping this materialization. This path is taken for - # a decode step. - assert xk_temp is not None and xv_temp is not None - assert xk_cache_update.shape[1] == 1 - assert xv_cache_update.shape[1] == 1 - assert kv_seq_len == seq_block_ids.shape[1] * cache.block_seq_stride - - # Write our one updated cache row into the cache. - cache.write_timestep( - cache_state, - cache_partitions=[ - xk_cache_update, - xv_cache_update, - ], - transformer_block_index=self.block_index, - seq_positions=start_positions, - page_ids=seq_block_ids, - ) - # Restore from the cache. - cache.read( - cache_state, - read_into_partitions=[ - xk_temp[:, 0:kv_seq_len, ...], - xv_temp[:, 0:kv_seq_len, ...], - ], - transformer_block_index=self.block_index, - page_ids=seq_block_ids, - ) + # Decode at ragged start positions. + # We need to initialize/read the K/V from the cache for the whole + # sequence. Note that at this point, it is possible to fork and + # use a memory efficient attention kernel that can do indirect + # reads, skipping this materialization. This path is taken for + # a decode step. + assert xk_temp is not None and xv_temp is not None + assert xk_cache_update.shape[1] == 1 + assert xv_cache_update.shape[1] == 1 + assert kv_seq_len == seq_block_ids.shape[1] * cache.block_seq_stride + + # Write our one updated cache row into the cache. + cache.write_timestep( + cache_state, + cache_partitions=[ + xk_cache_update, + xv_cache_update, + ], + transformer_block_index=self.block_index, + seq_positions=start_positions, + page_ids=seq_block_ids, + ) + + # Restore from the cache. + xk, xv = cache.read( + cache_state, + read_into_partitions=[ + xk_temp[:, 0:kv_seq_len, ...], + xv_temp[:, 0:kv_seq_len, ...], + ], + transformer_block_index=self.block_index, + page_ids=seq_block_ids, + seq_len=kv_seq_len, + ) - # For computation, we create a subview of the xk/xv tensors to have - # a sequence length covering the blocked size. This must include - # the newly added row (the caller is responsible for ensuring that - # every block has at least one row left). We'll compute on this - # ragged view and use an appropriate mask. - xk = xk_temp[:, 0:kv_seq_len, ...] - xv = xv_temp[:, 0:kv_seq_len, ...] - return xk, xv + # For computation, we create a subview of the xk/xv tensors to have + # a sequence length covering the blocked size. This must include + # the newly added row (the caller is responsible for ensuring that + # every block has at least one row left). We'll compute on this + # ragged view and use an appropriate mask. + return xk, xv diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index 226ffd777..7b3d2e04b 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -990,7 +990,7 @@ def _is_slicing_split_dim(self, key): else: # Any other collection is a indexing only dimension 0. return self.shard_dim == 0 - if len(key) < self.shard_dim: + if len(key) <= self.shard_dim: return False if not isinstance(key[self.shard_dim], slice): return True diff --git a/sharktank/tests/layers/kv_cache_test.py b/sharktank/tests/layers/kv_cache_test.py new file mode 100644 index 000000000..65b42c986 --- /dev/null +++ b/sharktank/tests/layers/kv_cache_test.py @@ -0,0 +1,502 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import unittest + +import torch + +from sharktank.ops import replicate, reshard_split, unshard +from sharktank.layers import * +from sharktank.types import * + + +def test_direct(): + bs = 4 + seq_length = 24 + attn_head_count = 4 + attn_head_dim = 16 + transformer_block_count = 4 + cache = DirectKVCache( + block_seq_stride=4, + transformer_block_count=transformer_block_count, + attn_head_count=attn_head_count, + attn_head_dim=attn_head_dim, + seq_length=seq_length, + dtype=torch.float32, + device=None, + ) + + allocation = cache.allocate(bs=bs) + allocation = [torch.full(t.shape, 0.0, out=t) for t in allocation] + + write_seq_length = seq_length - 5 + + # Write a prefill in: + write_ones = torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), 1.0, dtype=torch.float32 + ) + write_twos = torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), 2.0, dtype=torch.float32 + ) + cache.write( + allocation, cache_partitions=[write_ones, write_twos], transformer_block_index=1 + ) + + # Check the written values have updated: + read_empty = [ + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + ] + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length, + ) + torch.testing.assert_close(write_ones, read_back[0]) + torch.testing.assert_close(write_twos, read_back[1]) + + # Check the others are still zero: + for i in range(transformer_block_count): + if i == 1: + continue + read_ones = [ + torch.zeros( + (bs, write_seq_length, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + torch.zeros( + (bs, write_seq_length, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + ] + read_ones = cache.read( + allocation, + read_into_partitions=read_ones, + transformer_block_index=i, + seq_len=write_seq_length, + ) + torch.testing.assert_close(read_ones[0], torch.full(read_ones[0].shape, 0.0)) + torch.testing.assert_close(read_ones[1], torch.full(read_ones[0].shape, 0.0)) + + # Write timestep + write_threes = torch.full( + (bs, 1, attn_head_count, attn_head_dim), 3.0, dtype=torch.float32 + ) + write_fours = torch.full( + (bs, 1, attn_head_count, attn_head_dim), 4.0, dtype=torch.float32 + ) + write_pos = torch.full((bs,), write_seq_length, dtype=torch.int64) + cache.write_timestep( + allocation, + cache_partitions=[write_threes, write_fours], + transformer_block_index=1, + seq_positions=write_pos, + ) + + read_empty = [ + torch.zeros( + (bs, write_seq_length + 1, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + torch.zeros( + (bs, write_seq_length + 1, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + ] + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length + 1, + ) + + check_concat_0 = torch.concat([write_ones, write_threes], dim=1) + check_concat_1 = torch.concat([write_twos, write_fours], dim=1) + + torch.testing.assert_close(check_concat_0, read_back[0]) + torch.testing.assert_close(check_concat_1, read_back[1]) + + +def test_sharded_direct(): + bs = 4 + seq_length = 24 + attn_head_count = 8 + attn_head_dim = 16 + transformer_block_count = 4 + shard_count = 4 + cache = DirectKVCache( + block_seq_stride=4, + transformer_block_count=transformer_block_count, + attn_head_count=attn_head_count, + attn_head_dim=attn_head_dim, + seq_length=seq_length, + shard_count=shard_count, + dtype=torch.float32, + device=None, + ) + + allocation = cache.allocate(bs=bs) + # allocation = [torch.full(t.shape, 0.0, out=t) for t in allocation] + + write_seq_length = seq_length - 5 + + # Write a prefill in: + write_ones = reshard_split( + torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), + 1.0, + dtype=torch.float32, + ), + dim=2, + count=shard_count, + ) + + write_twos = reshard_split( + torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), + 2.0, + dtype=torch.float32, + ), + dim=2, + count=shard_count, + ) + + cache.write( + allocation, cache_partitions=[write_ones, write_twos], transformer_block_index=1 + ) + + # Check the written values have updated: + read_empty = [ + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + ] + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length, + ) + torch.testing.assert_close(unshard(write_ones), unshard(read_back[0])) + torch.testing.assert_close(unshard(write_twos), unshard(read_back[1])) + + # Write timestep + write_threes = reshard_split( + torch.full((bs, 1, attn_head_count, attn_head_dim), 3.0, dtype=torch.float32), + dim=2, + count=shard_count, + ) + write_fours = reshard_split( + torch.full((bs, 1, attn_head_count, attn_head_dim), 4.0, dtype=torch.float32), + dim=2, + count=shard_count, + ) + + write_pos = replicate( + torch.full((bs,), write_seq_length, dtype=torch.int64), shard_count + ) + cache.write_timestep( + allocation, + cache_partitions=[write_threes, write_fours], + transformer_block_index=1, + seq_positions=write_pos, + ) + + read_empty = [ + torch.zeros( + (bs, write_seq_length + 1, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + torch.zeros( + (bs, write_seq_length + 1, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + ] + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length + 1, + ) + + check_concat_0 = torch.concat([unshard(write_ones), unshard(write_threes)], dim=1) + check_concat_1 = torch.concat([unshard(write_twos), unshard(write_fours)], dim=1) + + torch.testing.assert_close(check_concat_0, unshard(read_back[0])) + torch.testing.assert_close(check_concat_1, unshard(read_back[1])) + + +def test_paged(): + bs = 4 + seq_length = 24 + attn_head_count = 4 + attn_head_dim = 16 + transformer_block_count = 4 + block_seq_stride = 4 + cache = PagedKVCache( + block_seq_stride=block_seq_stride, + transformer_block_count=transformer_block_count, + attn_head_count=attn_head_count, + attn_head_dim=attn_head_dim, + dtype=torch.float32, + device=None, + ) + + write_seq_length = seq_length - 4 + page_count = bs * seq_length // block_seq_stride + page_ids = torch.arange(page_count, dtype=torch.int64) + page_ids = page_ids.view(bs, seq_length // block_seq_stride) + write_page_ids = page_ids[:, : write_seq_length // block_seq_stride] + + allocation = cache.allocate(page_count=page_count) + allocation = [torch.full(t.shape, 0.0, out=t) for t in allocation] + + # Write a prefill in: + write_ones = torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), 1.0, dtype=torch.float32 + ) + write_twos = torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), 2.0, dtype=torch.float32 + ) + + cache.write( + allocation, + cache_partitions=[write_ones, write_twos], + transformer_block_index=1, + page_ids=write_page_ids, + ) + + # Check the written values have updated: + read_empty = [ + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + ] + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length, + page_ids=write_page_ids, + ) + torch.testing.assert_close(write_ones, read_back[0]) + torch.testing.assert_close(write_twos, read_back[1]) + + # Check the others are still zero: + for i in range(transformer_block_count): + if i == 1: + continue + read_ones = [ + torch.zeros( + (bs, write_seq_length, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + torch.zeros( + (bs, write_seq_length, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + ] + read_ones = cache.read( + allocation, + read_into_partitions=read_ones, + transformer_block_index=i, + seq_len=write_seq_length, + page_ids=write_page_ids, + ) + torch.testing.assert_close(read_ones[0], torch.full(read_ones[0].shape, 0.0)) + torch.testing.assert_close(read_ones[1], torch.full(read_ones[0].shape, 0.0)) + + # Write timestep + write_threes = torch.full( + (bs, 1, attn_head_count, attn_head_dim), 3.0, dtype=torch.float32 + ) + write_fours = torch.full( + (bs, 1, attn_head_count, attn_head_dim), 4.0, dtype=torch.float32 + ) + write_pos = torch.full((bs,), write_seq_length, dtype=torch.int64) + cache.write_timestep( + allocation, + cache_partitions=[write_threes, write_fours], + transformer_block_index=1, + seq_positions=write_pos, + page_ids=page_ids, + ) + + read_empty = [ + torch.zeros( + (bs, write_seq_length + block_seq_stride, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + torch.zeros( + (bs, write_seq_length + block_seq_stride, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + ] + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length + 1, + page_ids=page_ids, + ) + + check_concat_0 = torch.concat([write_ones, write_threes], dim=1) + check_concat_1 = torch.concat([write_twos, write_fours], dim=1) + + torch.testing.assert_close(check_concat_0, read_back[0]) + torch.testing.assert_close(check_concat_1, read_back[1]) + + +def test_sharded_paged(): + bs = 4 + seq_length = 24 + attn_head_count = 8 + attn_head_dim = 16 + transformer_block_count = 4 + block_seq_stride = 4 + shard_count = 4 + cache = PagedKVCache( + block_seq_stride=block_seq_stride, + transformer_block_count=transformer_block_count, + attn_head_count=attn_head_count, + attn_head_dim=attn_head_dim, + shard_count=shard_count, + dtype=torch.float32, + device=None, + ) + + write_seq_length = seq_length - 4 + page_count = bs * seq_length // block_seq_stride + page_ids = torch.arange(page_count, dtype=torch.int64) + page_ids = page_ids.view(bs, seq_length // block_seq_stride) + page_ids = replicate(page_ids, shard_count) + write_page_ids = page_ids[:, : write_seq_length // block_seq_stride] + + allocation = cache.allocate(page_count=page_count) + + # Write a prefill in: + write_ones = reshard_split( + torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), + 1.0, + dtype=torch.float32, + ), + dim=2, + count=shard_count, + ) + write_twos = reshard_split( + torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), + 2.0, + dtype=torch.float32, + ), + dim=2, + count=shard_count, + ) + + cache.write( + allocation, + cache_partitions=[write_ones, write_twos], + transformer_block_index=1, + page_ids=write_page_ids, + ) + + # Check the written values have updated: + empty_k = reshard_split( + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + dim=2, + count=shard_count, + ) + + empty_v = reshard_split( + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + dim=2, + count=shard_count, + ) + + read_empty = [empty_k, empty_v] + + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length, + page_ids=write_page_ids, + ) + torch.testing.assert_close(unshard(write_ones), unshard(read_back[0])) + torch.testing.assert_close(unshard(write_twos), unshard(read_back[1])) + + # Write timestep + write_threes = reshard_split( + torch.full((bs, 1, attn_head_count, attn_head_dim), 3.0, dtype=torch.float32), + dim=2, + count=shard_count, + ) + + write_fours = reshard_split( + torch.full((bs, 1, attn_head_count, attn_head_dim), 4.0, dtype=torch.float32), + dim=2, + count=shard_count, + ) + + write_pos = replicate( + torch.full((bs,), write_seq_length, dtype=torch.int64), shard_count + ) + + cache.write_timestep( + allocation, + cache_partitions=[write_threes, write_fours], + transformer_block_index=1, + seq_positions=write_pos, + page_ids=page_ids, + ) + + empty_k = reshard_split( + torch.zeros( + (bs, write_seq_length + block_seq_stride, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + dim=2, + count=shard_count, + ) + + empty_v = reshard_split( + torch.zeros( + (bs, write_seq_length + block_seq_stride, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + dim=2, + count=shard_count, + ) + + read_back = cache.read( + allocation, + read_into_partitions=[empty_k, empty_v], + transformer_block_index=1, + seq_len=write_seq_length + 1, + page_ids=page_ids, + ) + + check_concat_0 = torch.concat([unshard(write_ones), unshard(write_threes)], dim=1) + check_concat_1 = torch.concat([unshard(write_twos), unshard(write_fours)], dim=1) + + torch.testing.assert_close(check_concat_0, unshard(read_back[0])) + torch.testing.assert_close(check_concat_1, unshard(read_back[1])) diff --git a/sharktank/tests/layers/sharded_paged_kv_cache_test.py b/sharktank/tests/layers/sharded_paged_kv_cache_test.py index d58874f25..d7b6a0b33 100644 --- a/sharktank/tests/layers/sharded_paged_kv_cache_test.py +++ b/sharktank/tests/layers/sharded_paged_kv_cache_test.py @@ -123,6 +123,7 @@ def testRead(self): read_into_partitions=read_into_partitions, transformer_block_index=transformer_block_index, page_ids=page_ids, + seq_len=self.block_seq_len * self.block_seq_stride, ) sharded_read_into_partitions = deepcopy( [ @@ -136,6 +137,7 @@ def testRead(self): read_into_partitions=sharded_read_into_partitions, transformer_block_index=transformer_block_index, page_ids=sharded_page_ids, + seq_len=self.block_seq_len * self.block_seq_stride, ) for unsharded, sharded in zip( read_into_partitions, sharded_read_into_partitions From f2b1a015ed648f48e0a55132fdcd9774e04c9340 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 28 Oct 2024 19:39:31 -0700 Subject: [PATCH 3/3] [shortfin] Make error copyable. (#348) * Just eagerly serializes the exception. Trying to defer this makes the type non copyable and is a dubious optimization given that we never use this type for flow control. * Should fix MSVC warnings and ill-defined behavior there. --- shortfin/src/shortfin/support/iree_helpers.cc | 9 +++++---- shortfin/src/shortfin/support/iree_helpers.h | 17 +++++++---------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/shortfin/src/shortfin/support/iree_helpers.cc b/shortfin/src/shortfin/support/iree_helpers.cc index 417a9f443..17430bb71 100644 --- a/shortfin/src/shortfin/support/iree_helpers.cc +++ b/shortfin/src/shortfin/support/iree_helpers.cc @@ -86,13 +86,14 @@ error::error(std::string message, iree_status_t failing_status) message_(std::move(message)), failing_status_(failing_status) { message_.append(": "); + AppendStatusMessage(); } -error::error(iree_status_t failing_status) : failing_status_(failing_status) {} -void error::AppendStatus() const noexcept { - if (status_appended_) return; - status_appended_ = false; +error::error(iree_status_t failing_status) : failing_status_(failing_status) { + AppendStatusMessage(); +} +void error::AppendStatusMessage() { iree_allocator_t allocator = iree_allocator_system(); char *status_buffer = nullptr; iree_host_size_t length = 0; diff --git a/shortfin/src/shortfin/support/iree_helpers.h b/shortfin/src/shortfin/support/iree_helpers.h index 446f32f41..f8d3f1398 100644 --- a/shortfin/src/shortfin/support/iree_helpers.h +++ b/shortfin/src/shortfin/support/iree_helpers.h @@ -277,24 +277,21 @@ class SHORTFIN_API error : public std::exception { public: error(std::string message, iree_status_t failing_status); error(iree_status_t failing_status); - error(const error &) = delete; + error(const error &other) + : code_(other.code_), + message_(other.message_), + failing_status_(iree_status_clone(other.failing_status_)) {} error &operator=(const error &) = delete; ~error() { iree_status_ignore(failing_status_); } - const char *what() const noexcept override { - if (!status_appended_) { - AppendStatus(); - } - return message_.c_str(); - }; + const char *what() const noexcept override { return message_.c_str(); }; iree_status_code_t code() const { return code_; } private: - void AppendStatus() const noexcept; + void AppendStatusMessage(); iree_status_code_t code_; - mutable std::string message_; + std::string message_; mutable iree_status_t failing_status_; - mutable bool status_appended_ = false; }; #define SHORTFIN_IMPL_HANDLE_IF_API_ERROR(var, ...) \