TfrtCpuClient fails to allocate int2/int4 buffer #16795

Open
jonatanklosko opened this issue Sep 4, 2024 · 4 comments
@jonatanklosko

jonatanklosko commented Sep 4, 2024

Allocating a 1D tensor of type int4 via BufferFromHostBuffer with byte_strides = std::nullopt fails with the message:

dims and input_strides_in_bytes must have equal sizes, got 1 and 0

(coming from here)

From debugging, I believe the issue is here:

options.input_layout = TransposePlan::Striding{*byte_strides};

This code path is taken even though byte_strides is std::nullopt, so the empty optional is dereferenced and the resulting strides come out empty, which triggers the size mismatch above.

Here's a snippet to reproduce:

#include <cstdint>
#include <iostream>
#include <optional>
#include <utility>

#include "xla/pjrt/tfrt_cpu_pjrt_client.h"
#include "xla/shape_util.h"

int main() {
  // Create the CPU client.
  auto status = xla::GetTfrtCpuClient(false);
  auto client = std::move(status.value());

  // Host buffer backing the 1D int4 tensor.
  char data[2] = {0};
  auto shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::S4, {4});
  auto type = shape.element_type();
  auto dims = shape.dimensions();
  // Empty optional (std::nullopt): no explicit strides are provided.
  auto byte_strides = std::optional<absl::Span<const int64_t>>{};
  auto semantics = xla::PjRtClient::HostBufferSemantics::kImmutableZeroCopy;
  auto on_done_with_host_buffer = []() {};
  auto device = client->LookupDevice(xla::PjRtGlobalDeviceId(0)).value();
  auto buffer_result = client->BufferFromHostBuffer(data, type, dims, byte_strides, semantics, on_done_with_host_buffer, device);

  if (buffer_result.ok()) {
    std::cout << "Successful allocation" << std::endl;
  } else {
    std::cout << "Allocation failed: " << buffer_result.status().message() << std::endl;
  }

  return 0;
}

The above fails, but once we change byte_strides to an explicit {1}, it passes.
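
The explicit {1} is simply the dense stride of a 1D array whose elements each occupy one byte. As a minimal sketch of how a caller might build such strides for any rank while the std::nullopt path fails, the helper below computes major-to-minor byte strides. DenseByteStrides is a hypothetical name, not an XLA function, and the one-byte element size for int4 is an assumption based on the unpacked storage described later in this thread.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of XLA): returns dense, major-to-minor byte
// strides for an array with the given dimensions and element size.
// For int4 on XLA:CPU we ASSUME element_size_bytes == 1 (unpacked storage).
std::vector<int64_t> DenseByteStrides(const std::vector<int64_t>& dims,
                                      int64_t element_size_bytes) {
  assert(element_size_bytes > 0);
  std::vector<int64_t> strides(dims.size());
  int64_t stride = element_size_bytes;
  // Walk from the minor-most (last) dimension to the major-most (first).
  for (std::size_t i = dims.size(); i-- > 0;) {
    strides[i] = stride;
    stride *= dims[i];
  }
  return strides;
}
```

For the reproduction above, DenseByteStrides({4}, 1) yields {1}, which could be wrapped in an absl::Span<const int64_t> and passed as byte_strides instead of the empty optional.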

@jonatanklosko
Author

Hey! A follow-up question: what would be the proper way to create a PjRt buffer from packed int2/int4 host data, and to read it back into packed form?

For example, we currently read the data via a literal (PjRtBuffer::ToLiteralSync + Literal::untyped_data), but my understanding is that Literal doesn't support sub-byte elements (ref), so we always read the unpacked data.

@ezhulenev
Member

At this point XLA:CPU also doesn't support packed buffers; it always unpacks them to one byte per element (similar to Literal storage). We do plan to add "real" sub-byte support to XLA:CPU, but it's probably months away.
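
Until packed buffers are supported, a caller holding packed int4 data would have to convert between the packed form and the one-byte-per-element form on the host. The following is only a sketch of that conversion under stated assumptions: PackInt4 and UnpackInt4 are hypothetical helpers, not XLA APIs, and the nibble order (low nibble holds the even-indexed element) is an assumption that should be checked against whatever produced the packed data.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: expands packed signed 4-bit values into one int8_t per
// element. ASSUMPTION: the low nibble of each byte holds the even-indexed
// element and the high nibble holds the odd-indexed one.
std::vector<int8_t> UnpackInt4(const std::vector<uint8_t>& packed,
                               std::size_t count) {
  assert(count <= packed.size() * 2);
  std::vector<int8_t> out(count);
  for (std::size_t i = 0; i < count; ++i) {
    uint8_t byte = packed[i / 2];
    int nibble = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
    // Sign-extend the two's complement nibble: 8..15 map to -8..-1.
    out[i] = static_cast<int8_t>(nibble >= 8 ? nibble - 16 : nibble);
  }
  return out;
}

// Hypothetical helper: the inverse of UnpackInt4, using the same assumed
// nibble order. Each value must fit in the int4 range [-8, 7].
std::vector<uint8_t> PackInt4(const std::vector<int8_t>& values) {
  std::vector<uint8_t> packed((values.size() + 1) / 2, 0);
  for (std::size_t i = 0; i < values.size(); ++i) {
    assert(values[i] >= -8 && values[i] <= 7);
    uint8_t nibble = static_cast<uint8_t>(values[i]) & 0x0F;
    packed[i / 2] |= (i % 2 == 0) ? nibble : static_cast<uint8_t>(nibble << 4);
  }
  return packed;
}
```

With these helpers, unpacking would happen before handing data to BufferFromHostBuffer and packing would happen after reading the literal back, at the cost of an extra host-side copy in each direction.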

@jonatanklosko
Author

I see, thank you for the update :)

@balancap

@ezhulenev Thank you for the explanation. Do you have a more precise timeline, or is there a discussion/RFC open on the topic?
Sub-byte FP4 and FP6 (as well as E8M0) have just been merged in https://github.com/jax-ml/ml_dtypes (PRs 166 & 181) to support MX block formats, and I would like to understand when we could get support for these in XLA (with the optimal packed memory layout).
