Move common ET/Aten op stuff to ops/library.h
Differential Revision: D63983380

Pull Request resolved: #1116
metascroy authored Oct 18, 2024
1 parent 677af6f commit 67c12dd
Showing 8 changed files with 81 additions and 75 deletions.
34 changes: 34 additions & 0 deletions torchao/experimental/ops/library.h
@@ -0,0 +1,34 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#if defined(USE_ATEN) && !defined(USE_EXECUTORCH)
#pragma message("USE_ATEN")
#include <torch/library.h>
#include <torch/script.h>
#include <torch/torch.h>
using Tensor = at::Tensor;
#define TORCHAO_CHECK(cond, msg) TORCH_CHECK(cond, msg)

#elif defined(USE_EXECUTORCH) && !defined(USE_ATEN)
#pragma message("USE_EXECUTORCH")
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
#include <executorch/runtime/kernel/kernel_includes.h>
using Tensor = torch::executor::Tensor;
using RuntimeContext = torch::executor::KernelRuntimeContext;
#define TORCHAO_CHECK(cond, msg) ET_CHECK_MSG(cond, msg)

#elif !defined(USE_EXECUTORCH) && !defined(USE_ATEN)
#pragma message("Neither USE_ATEN or USE_EXECUTORCH defined")
#include <stdexcept>

#define TORCHAO_CHECK(cond, message) \
if (!(cond)) { \
throw std::runtime_error(message); \
}

#else
#error "Cannot define both USE_ATEN or USE_EXECUTORCH"
#endif
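
With library.h in place, an op translation unit can be written once against the shared alias and check macro. A minimal consumer sketch (illustrative; not a file in this commit), compiled with -DUSE_ATEN or -DUSE_EXECUTORCH:

    #include <torchao/experimental/ops/library.h>

    // 'Tensor' resolves to at::Tensor or torch::executor::Tensor depending on
    // which backend macro is defined; TORCHAO_CHECK maps to TORCH_CHECK or
    // ET_CHECK_MSG accordingly.
    void check_2d(const Tensor& t) {
      TORCHAO_CHECK(t.dim() == 2, "expected a 2D tensor");
    }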
@@ -6,7 +6,7 @@

#pragma once
#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h>
-#include <torchao/experimental/ops/macro.h>
+#include <torchao/experimental/ops/library.h>
#include <torchao/experimental/ops/memory.h>
#include <cassert>
#include <optional>
@@ -6,7 +6,7 @@

#include <stdint.h>
#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h>
-#include <torchao/experimental/ops/macro.h>
+#include <torchao/experimental/ops/library.h>
#include <torchao/experimental/ops/parallel.h>
#include <algorithm>
#include <cassert>
@@ -10,31 +10,13 @@
#include <torchao/experimental/kernels/cpu/aarch64/linear/linear.h>
#endif // defined(__aarch64__) || defined(__ARM_NEON)

+#include <torchao/experimental/ops/library.h>
#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h>
#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h>
#include <torchao/experimental/ops/packed_weights_header.h>
#include <optional>
#include <vector>

-#if defined(USE_ATEN) && !defined(USE_EXECUTORCH)
-#pragma message("USE_ATEN")
-#include <torch/library.h>
-#include <torch/script.h>
-#include <torch/torch.h>
-using Tensor = at::Tensor;
-#define CHECK_MSG(cond, msg) TORCH_CHECK(cond, msg)
-
-#elif defined(USE_EXECUTORCH) && !defined(USE_ATEN)
-#pragma message("USE_EXECUTORCH")
-#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
-#include <executorch/runtime/kernel/kernel_includes.h>
-using Tensor = torch::executor::Tensor;
-using RuntimeContext = torch::executor::KernelRuntimeContext;
-#define CHECK_MSG(cond, msg) ET_CHECK_MSG(cond, msg)
-#else
-#error "Must define either USE_ATEN or USE_EXECUTORCH"
-#endif

namespace {

// This selects a UkernelConfig based on the packed weight header
@@ -52,7 +34,7 @@ get_ukernel_config(torchao::ops::PackedWeightsHeader header) {
channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot;

// Check packing params match the kernel
-CHECK_MSG(
+TORCHAO_CHECK(
header ==
torchao::ops::linear_8bit_act_xbit_weight::
get_packed_weights_header_universal(
@@ -80,9 +62,9 @@ get_ukernel_config(torchao::ops::PackedWeightsHeader header) {
&ukernel::kernel<weight_nbit, has_weight_zeros, has_bias, has_clamp>;
return config;
break;
-default:
-CHECK_MSG(false, "Unsupported packed weights format");
#endif // defined(__aarch64__) || defined(__ARM_NEON)
+default:
+TORCHAO_CHECK(false, "Unsupported packed weights format");
}
}

@@ -103,8 +85,9 @@ Tensor pack_weights_cpu(
const Tensor& weight_scales,
const std::optional<Tensor>& weight_zeros,
int64_t group_size) {
-CHECK_MSG(weight_qvals.dtype() == torch::kInt8, "weight_qvals must be int8");
-CHECK_MSG(weight_qvals.dim() == 2, "weight_qvals must be 2D");
+TORCHAO_CHECK(
+    weight_qvals.dtype() == torch::kInt8, "weight_qvals must be int8");
+TORCHAO_CHECK(weight_qvals.dim() == 2, "weight_qvals must be 2D");

// In PyTorch, weights are nxk in row-major format (with activations being
// right-multiplied).
@@ -114,25 +97,25 @@ Tensor pack_weights_cpu(
int n = weight_qvals.size(0);
int k = weight_qvals.size(1);

-CHECK_MSG(
+TORCHAO_CHECK(
    weight_scales.dtype() == torch::kFloat32,
    "weight_scales must be float32");
-CHECK_MSG(weight_scales.dim() == 1, "weight_scales must be 1D");
-CHECK_MSG(group_size >= 1, "group_size must be >= 1");
-CHECK_MSG(
+TORCHAO_CHECK(weight_scales.dim() == 1, "weight_scales must be 1D");
+TORCHAO_CHECK(group_size >= 1, "group_size must be >= 1");
+TORCHAO_CHECK(
    weight_scales.size(0) == ((n * k) / group_size),
    "expected 1 scale per group");

-CHECK_MSG(
+TORCHAO_CHECK(
has_weight_zeros == weight_zeros.has_value(),
"has_weight_zeros must match weight_zeros.has_value()");
const int8_t* weight_zeros_ptr = nullptr;
if constexpr (has_weight_zeros) {
-CHECK_MSG(
+TORCHAO_CHECK(
    weight_zeros.value().dtype() == torch::kInt8,
    "weight_zeros must be int8");
-CHECK_MSG(weight_zeros.value().dim() == 1, "weight_zeros must be 1D");
-CHECK_MSG(
+TORCHAO_CHECK(weight_zeros.value().dim() == 1, "weight_zeros must be 1D");
+TORCHAO_CHECK(
weight_zeros.value().size(0) == ((n * k) / group_size),
"expected 1 zero per group");
weight_zeros_ptr = weight_zeros.value().const_data_ptr<int8_t>();
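
The checks above pin down the packing contract. A sketch of conforming inputs (illustrative values only, derived from the checks shown):

    import torch

    n, k, group_size = 8, 64, 32
    weight_qvals = torch.randint(-8, 8, (n, k), dtype=torch.int8)         # 2D, int8
    weight_scales = torch.rand(n * k // group_size, dtype=torch.float32)  # one scale per group
    weight_zeros = torch.zeros(n * k // group_size, dtype=torch.int8)     # one zero per group, optional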
@@ -206,7 +189,7 @@ Tensor pack_weights_meta(
const Tensor& weight_scales,
const std::optional<Tensor>& weight_zeros,
int64_t group_size) {
-CHECK_MSG(group_size >= 1, "group_size must be >= 1");
+TORCHAO_CHECK(group_size >= 1, "group_size must be >= 1");
int n = weight_qvals.size(0);
int k = weight_qvals.size(1);

@@ -269,46 +252,47 @@ Tensor linear_out_cpu(
int n = n_tensor.size(1);
int k = k_tensor.size(1);
int group_size = group_size_tensor.size(1);
-CHECK_MSG(n >= 1, "n must be >= 1");
-CHECK_MSG(k >= 1, "k must be >= 1");
-CHECK_MSG(group_size >= 1, "group_size must be >= 1");
+TORCHAO_CHECK(n >= 1, "n must be >= 1");
+TORCHAO_CHECK(k >= 1, "k must be >= 1");
+TORCHAO_CHECK(group_size >= 1, "group_size must be >= 1");

#ifdef USE_ATEN
-CHECK_MSG(
+TORCHAO_CHECK(
activations.dtype() == torch::kFloat32, "activations must be float32");
#endif // USE_ATEN

-CHECK_MSG(activations.dim() == 2, "activations must be 2D");
+TORCHAO_CHECK(activations.dim() == 2, "activations must be 2D");
int m = activations.size(0);
int k_ = activations.size(1);
-CHECK_MSG(k == k_, "activation shape is incompatible with packed weights.");
+TORCHAO_CHECK(
+    k == k_, "activation shape is incompatible with packed weights.");

#ifdef USE_ATEN
-CHECK_MSG(out.dtype() == torch::kFloat32, "out must be float32");
+TORCHAO_CHECK(out.dtype() == torch::kFloat32, "out must be float32");
#endif // USE_ATEN

#ifdef USE_ATEN
out.resize_({m, n});
#endif // USE_ATEN

#ifdef USE_EXECUTORCH
-CHECK_MSG(out.dim() == 2, "out must be 2D");
-CHECK_MSG(out.size(0) == m, "out shape is incorrect");
-CHECK_MSG(out.size(1) == n, "out shape is incorrect");
+TORCHAO_CHECK(out.dim() == 2, "out must be 2D");
+TORCHAO_CHECK(out.size(0) == m, "out shape is incorrect");
+TORCHAO_CHECK(out.size(1) == n, "out shape is incorrect");
#endif // USE_EXECUTORCH

using namespace torchao::ops::linear_8bit_act_xbit_weight;

-CHECK_MSG(packed_weights.dim() == 1, "packed_weights must be 1D");
+TORCHAO_CHECK(packed_weights.dim() == 1, "packed_weights must be 1D");
#ifdef USE_ATEN
-CHECK_MSG(
+TORCHAO_CHECK(
packed_weights.dtype() == torch::kInt8, "packed_weights must be int8");
#endif // USE_ATEN
-CHECK_MSG(
+TORCHAO_CHECK(
packed_weights.size(0) >= torchao::ops::PackedWeightsHeader::size(),
"packed_weights is not big enough to read the header.");
-auto header = torchao::ops::PackedWeightsHeader::read(
-    packed_weights.const_data_ptr());
+auto header =
+    torchao::ops::PackedWeightsHeader::read(packed_weights.const_data_ptr());

auto ukernel_config = get_ukernel_config<
weight_nbit,
@@ -392,13 +376,14 @@ Tensor linear_meta(
const Tensor& k_tensor) {
int n = n_tensor.size(1);
int k = k_tensor.size(1);
-CHECK_MSG(n >= 1, "n must be >= 1");
-CHECK_MSG(k >= 1, "k must be >= 1");
+TORCHAO_CHECK(n >= 1, "n must be >= 1");
+TORCHAO_CHECK(k >= 1, "k must be >= 1");

-CHECK_MSG(activations.dim() == 2, "activations must be 2D");
+TORCHAO_CHECK(activations.dim() == 2, "activations must be 2D");
int m = activations.size(0);
int k_ = activations.size(1);
-CHECK_MSG(k == k_, "activation shape is incompatible with packed weights.");
+TORCHAO_CHECK(
+    k == k_, "activation shape is incompatible with packed weights.");
return torch::empty({m, n}).to("meta");
}
#endif // USE_ATEN
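
Background on the meta variant (context, not from this diff): meta tensors carry shape and dtype but no storage, so returning torch::empty({m, n}).to("meta") is all that is needed for shape propagation during export and compile. In Python terms:

    import torch

    x = torch.empty(3, 4, device="meta")  # no data is allocated
    print(x.shape, x.dtype)               # torch.Size([3, 4]) torch.float32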
@@ -5,12 +5,12 @@
// LICENSE file in the root directory of this source tree.

#pragma once
-#include <torchao/experimental/ops/macro.h>
+#include <torchao/experimental/ops/library.h>
#include <torchao/experimental/ops/packed_weights_header.h>

namespace torchao::ops::linear_8bit_act_xbit_weight {

-torchao::ops::PackedWeightsHeader get_packed_weights_header_universal(
+inline torchao::ops::PackedWeightsHeader get_packed_weights_header_universal(
int weight_nbit,
bool has_weight_zeros,
bool has_bias,
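
The `inline` added above matters because get_packed_weights_header_universal is defined in a header that multiple translation units now include; without it, each .cpp would emit its own strong definition and linking would fail with duplicate symbols. The general pattern (illustrative names, not from the repo):

    // header.h -- safe to include from many .cpp files
    #pragma once
    inline int answer() { return 42; }  // 'inline' permits identical definitions across TUs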
13 changes: 0 additions & 13 deletions torchao/experimental/ops/macro.h

This file was deleted.

6 changes: 3 additions & 3 deletions torchao/experimental/quant_api.py
@@ -262,7 +262,7 @@ def _replace_linear_with_quantized_linear(module: nn.Module, kwargs={}):
)


-class Int8DynActIntxWeightQuantizer:
+class Int8DynActIntxWeightLinearQuantizer:
def __init__(
self,
device,
@@ -274,14 +274,14 @@ def __init__(
):
if device != "cpu":
raise NotImplementedError(
"Only device=cpu is currently supported in Int8DynActLowbitWeightQuantizer"
"Only device=cpu is currently supported in Int8DynActIntxWeightLinearQuantizer"
)
else:
self.device = device

if precision != torch.float32:
raise NotImplementedError(
"Only precision=torch.float32 is currently supported in Int8DynActLowbitWeightQuantizer"
"Only precision=torch.float32 is currently supported in Int8DynActIntxWeightLinearQuantizer"
)
else:
self.precision = precision
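
Callers simply swap in the new class name; a usage sketch mirroring the tests below (the import path, quantize() method, and groupsize parameter are assumptions based on torchao's Quantizer interface, not shown in this diff):

    import torch
    from torchao.experimental.quant_api import Int8DynActIntxWeightLinearQuantizer  # path assumed

    quantizer = Int8DynActIntxWeightLinearQuantizer(
        device="cpu",
        precision=torch.float32,
        bitwidth=4,      # illustrative
        groupsize=128,   # illustrative; parameter name assumed
    )
    quantized_model = quantizer.quantize(model)  # model: an nn.Module with Linear layers (assumed)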
@@ -19,7 +19,7 @@
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from quant_api import (
_Int8DynActIntxWeightQuantizedLinearFallback,
-    Int8DynActIntxWeightQuantizer,
+    Int8DynActIntxWeightLinearQuantizer,
)


@@ -74,7 +74,7 @@ def test_accuracy(self):
for has_weight_zeros in [True, False]:
print(f"Testing nbit={nbit}, has_weight_zeros={has_weight_zeros}")
quantized_model = copy.deepcopy(model)
-quantizer = Int8DynActIntxWeightQuantizer(
+quantizer = Int8DynActIntxWeightLinearQuantizer(
device="cpu",
precision=torch.float32,
bitwidth=nbit,
@@ -122,7 +122,7 @@ def test_export_compile_aoti(self):
activations = torch.randn(m, k0, dtype=torch.float32)

print("Quantizing model")
-quantizer = Int8DynActIntxWeightQuantizer(
+quantizer = Int8DynActIntxWeightLinearQuantizer(
device="cpu",
precision=torch.float32,
bitwidth=nbit,
