Skip to content

Add dynamic shape support for lowbit kernels #1942

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions torchao/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,9 @@
# They can also be built outside of the torchao install process by
# running the script `torchao/experimental/build_torchao_ops.sh <aten|executorch>`
# For more information, see https://github.com/pytorch/ao/blob/main/torchao/experimental/docs/readme.md
experimental_lib = list(Path(__file__).parent.glob("libtorchao_ops_aten.*"))
if len(experimental_lib) > 0:
assert (
len(experimental_lib) == 1
), f"Expected at most one libtorchao_ops_aten.* file, found {len(experimental_lib)}"
torch.ops.load_library(str(experimental_lib[0]))
except:
logging.debug("Skipping import of cpp extensions")
from torchao.experimental.op_lib import * # noqa: F403
except Exception as e:
logging.debug(f"Skipping import of cpp extensions: {e}")

from torchao.quantization import (
autoquant,
Expand Down
57 changes: 57 additions & 0 deletions torchao/experimental/op_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from pathlib import Path

import torch
from torch import Tensor
from torch.library import impl

# Load C++ ops: the shared library is produced by the torchao build process
# (or by torchao/experimental/build_torchao_ops.sh) and is expected to sit in
# the torchao package root (one directory above this file).
lib_path = Path(__file__).parent.parent
libs = list(lib_path.glob("libtorchao_ops_aten.*"))
# Exactly one build artifact must exist; zero means the ops were never built,
# more than one suggests stale artifacts from a previous build.
assert (
    len(libs) == 1
), f"Expected to find one libtorchao_ops_aten.* library at {lib_path}, but found {len(libs)}"
torch.ops.load_library(str(libs[0]))


# Define meta ops. To support dynamic shapes, some meta ops need to
# be defined in python instead of C++.
torchao_lib = torch.library.Library("torchao", "IMPL")
for weight_nbit in range(1, 9):
    # NOTE: the op-name f-strings below are evaluated at decoration time, so
    # each iteration registers against the correct {weight_nbit} op; the
    # kernel bodies themselves do not capture weight_nbit, so the classic
    # loop-variable late-binding pitfall does not apply here.

    @impl(torchao_lib, f"_linear_8bit_act_{weight_nbit}bit_weight", "Meta")
    def _(
        activations: Tensor,
        packed_weights: Tensor,
        group_size: int,
        k: int,
        n: int,
    ):
        """Meta kernel: computes only the output shape/dtype (m, n); never touches data."""
        assert (
            activations.dim() == 2
        ), f"activations must be 2D, got {activations.dim()}D"
        m, k_ = activations.shape
        assert (
            k_ == k
        ), f"activations have {k_} columns, but packed weights expect k={k}"
        return torch.empty(m, n, dtype=activations.dtype, device="meta")

    @impl(torchao_lib, f"_embedding_{weight_nbit}bit", "Meta")
    def _(
        packed_weight_qvals: Tensor,
        num_embeddings: int,
        embedding_dim: int,
        weight_scales: Tensor,
        weight_zeros: Tensor,
        indices: Tensor,
    ):
        """Meta kernel: output is (len(indices), embedding_dim) float32."""
        assert indices.dim() == 1, f"indices must be 1D, got {indices.dim()}D"
        num_out = indices.shape[0]
        return torch.empty(num_out, embedding_dim, dtype=torch.float32, device="meta")

    @impl(torchao_lib, f"_shared_embedding_{weight_nbit}bit", "Meta")
    def _(packed_weights: Tensor, group_size: int, n: int, k: int, indices: Tensor):
        """Meta kernel: output is (len(indices), k) float32 — k is the embedding dim."""
        assert indices.dim() == 1, f"indices must be 1D, got {indices.dim()}D"
        num_out = indices.shape[0]
        return torch.empty(num_out, k, dtype=torch.float32, device="meta")
70 changes: 10 additions & 60 deletions torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,18 +92,12 @@ void check_embedding_inputs(
template <int weight_nbit>
Tensor embedding_out_cpu(
const Tensor& packed_weight_qvals,
// TODO(T200095131): convert to
// int64_t when supported by AOTI
// Currently they are tensors with size
// equal to (0, the int they wrap)
const Tensor& num_embeddings_tensor,
const Tensor& embedding_dim_tensor,
const int64_t& num_embeddings,
const int64_t& embedding_dim,
const Tensor& weight_scales,
const Tensor& weight_zeros,
const Tensor& indices,
Tensor& out) {
int num_embeddings = num_embeddings_tensor.size(1);
int embedding_dim = embedding_dim_tensor.size(1);
int group_size;
check_embedding_inputs<weight_nbit>(
packed_weight_qvals,
Expand All @@ -117,16 +111,8 @@ Tensor embedding_out_cpu(
int num_out = indices.size(0);
const int8_t* weight_zeros_ptr = weight_zeros.const_data_ptr<int8_t>();

#ifdef USE_ATEN
TORCHAO_CHECK(out.dtype() == torch::kFloat32, "out must be float32");
out.resize_({num_out, embedding_dim});
#endif // USE_ATEN

#ifdef USE_EXECUTORCH
TORCHAO_CHECK(out.dim() == 2, "out must be 2D");
TORCHAO_CHECK(out.size(0) == num_out, "out shape is incorrect");
TORCHAO_CHECK(out.size(1) == embedding_dim, "out shape is incorrect");
#endif // USE_EXECUTORCH
// Explicit cast from int64_t to int is required for Executorch
TORCHAO_RESIZE_TENSOR(out, {(int)num_out, (int)embedding_dim});

const int32_t* index32_ptr = nullptr;
const int64_t* index64_ptr = nullptr;
Expand Down Expand Up @@ -169,20 +155,16 @@ Tensor embedding_out_cpu(
template <int weight_nbit>
Tensor embedding_cpu(
const Tensor& packed_weight_qvals,
// TODO(T200095131): convert to
// int64_t when supported by AOTI
// Currently they are tensors with size
// equal to (0, the int they wrap)
const Tensor& num_embeddings_tensor,
const Tensor& embedding_dim_tensor,
const int64_t& num_embeddings,
const int64_t& embedding_dim,
const Tensor& weight_scales,
const Tensor& weight_zeros,
const Tensor& indices) {
Tensor output_tensor = torch::empty({}, torch::kFloat32);
embedding_out_cpu<weight_nbit>(
packed_weight_qvals,
num_embeddings_tensor,
embedding_dim_tensor,
num_embeddings,
embedding_dim,
weight_scales,
weight_zeros,
indices,
Expand All @@ -191,25 +173,6 @@ Tensor embedding_cpu(
}
#endif // USE_ATEN

#ifdef USE_ATEN
template <int weight_nbit>
Tensor embedding_meta(
const Tensor& packed_weight_qvals,
// TODO(T200095131): convert to
// int64_t when supported by AOTI
// Currently they are tensors with size
// equal to (0, the int they wrap)
const Tensor& num_embeddings_tensor,
const Tensor& embedding_dim_tensor,
const Tensor& weight_scales,
const Tensor& weight_zeros,
const Tensor& indices) {
int embedding_dim = embedding_dim_tensor.size(1);
int num_out = indices.size(0);
return torch::empty({num_out, embedding_dim}).to("meta");
}
#endif // USE_ATEN

#ifdef USE_ATEN
template <int weight_nbit>
Tensor pack_embedding_cpu(const Tensor& weight_qvals) {
Expand Down Expand Up @@ -261,10 +224,10 @@ Tensor pack_embedding_meta(const Tensor& weight_qvals) {
TORCHAO_CHECK(
embedding_dim % 8 == 0, "embedding_dim must be a multiple of 8 to pack");
int packed_embedding_dim = embedding_dim * weight_nbit / 8;
auto options = torch::TensorOptions().device(c10::DeviceType::Meta).dtype(torch::kInt8);
return torch::empty(
torchao::ops::PackedWeightsHeader::size() +
(num_embeddings * packed_embedding_dim))
.to("meta");
(num_embeddings * packed_embedding_dim), options);
}
#endif // USE_ATEN

Expand Down Expand Up @@ -371,17 +334,4 @@ Tensor shared_embedding_cpu(
}
#endif // USE_ATEN

#ifdef USE_ATEN
template <int weight_nbit>
Tensor shared_embedding_meta(
const Tensor& packed_weights,
const int64_t& group_size,
const int64_t& n, // same as num_embeddings
const int64_t& k, // same as embedding_dim
const Tensor& indices) {
int num_out = indices.size(0);
return torch::empty({num_out, k}).to("meta");
}
#endif // USE_ATEN

#endif // defined(USE_ATEN) || defined(USE_EXECUTORCH)
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
m.def("_pack_embedding_" #weight_nbit "bit(Tensor weight_qvals) -> Tensor"); \
m.def( \
"_embedding_" #weight_nbit \
"bit(Tensor packed_weight_qvals, Tensor num_embeddings_tensor, Tensor embedding_dim_tensor, Tensor weight_scales, Tensor weight_zeros, Tensor indices) -> Tensor"); \
"bit(Tensor packed_weight_qvals, int num_embeddings, int embedding_dim, Tensor weight_scales, Tensor weight_zeros, Tensor indices) -> Tensor"); \
m.def( \
"_embedding_" #weight_nbit \
"bit.out(Tensor packed_weight_qvals, Tensor num_embeddings_tensor, Tensor embedding_dim_tensor, Tensor weight_scales, Tensor weight_zeros, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)"); \
"bit.out(Tensor packed_weight_qvals, int num_embeddings, int embedding_dim, Tensor weight_scales, Tensor weight_zeros, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)"); \
m.def( \
"_shared_embedding_" #weight_nbit \
"bit.out(Tensor packed_weights, int group_size, int n, int k, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)"); \
Expand All @@ -38,11 +38,7 @@
#define DEFINE_META_IMPL(weight_nbit) \
m.impl( \
"_pack_embedding_" #weight_nbit "bit", \
&pack_embedding_meta<weight_nbit>); \
m.impl("_embedding_" #weight_nbit "bit", &embedding_meta<weight_nbit>); \
m.impl( \
"_shared_embedding_" #weight_nbit "bit", \
&shared_embedding_meta<weight_nbit>);
&pack_embedding_meta<weight_nbit>);

TORCH_LIBRARY_FRAGMENT(torchao, m) {
DEFINE_OP(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@
Tensor _op_out_##weight_nbit( \
RuntimeContext& ctx, \
const Tensor& packed_weight_qvals, \
const Tensor& num_embeddings_tensor, \
const Tensor& embedding_dim_tensor, \
const int64_t& num_embeddings, \
const int64_t& embedding_dim, \
const Tensor& weight_scales, \
const Tensor& weight_zeros, \
const Tensor& indices, \
Tensor& out) { \
(void)ctx; \
embedding_out_cpu<weight_nbit>( \
packed_weight_qvals, \
num_embeddings_tensor, \
embedding_dim_tensor, \
num_embeddings, \
embedding_dim, \
weight_scales, \
weight_zeros, \
indices, \
Expand Down
4 changes: 4 additions & 0 deletions torchao/experimental/ops/library.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,20 @@ using Tensor = at::Tensor;
#define Tensor_dtype_kInt32 torch::kInt32
#define Tensor_dtype_kInt64 torch::kInt64
#define TORCHAO_CHECK(cond, msg) TORCH_CHECK(cond, msg)
#define TORCHAO_RESIZE_TENSOR(tensor, ...) tensor.resize_({__VA_ARGS__})

#elif defined(USE_EXECUTORCH) && !defined(USE_ATEN)
#pragma message("USE_EXECUTORCH")
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
using Tensor = torch::executor::Tensor;
using RuntimeContext = torch::executor::KernelRuntimeContext;
#define Tensor_dtype_kInt32 torch::executor::ScalarType::Int
#define Tensor_dtype_kInt64 torch::executor::ScalarType::Long
#define TORCHAO_CHECK(cond, msg) ET_CHECK_MSG(cond, msg)
#define TORCHAO_RESIZE_TENSOR(tensor, ...) \
ET_CHECK_MSG(torch::executor::resize_tensor(tensor, {__VA_ARGS__}) == torch::executor::Error::Ok, "resize failed")

#elif !defined(USE_EXECUTORCH) && !defined(USE_ATEN)
#pragma message("Neither USE_ATEN or USE_EXECUTORCH defined")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ Tensor pack_weights_meta(
torchao::ops::PackedWeightsHeader::size() +
get_packed_weight_data_size(
ukernel_config, n, k, group_size, has_weight_zeros, has_bias);
return torch::empty({static_cast<int64_t>(packed_weight_data_size)})
.to("meta");
auto options = torch::TensorOptions().device(c10::DeviceType::Meta).dtype(torch::kInt8);
return torch::empty({static_cast<int64_t>(packed_weight_data_size)}, options);
}
#endif // USE_ATEN

Expand Down Expand Up @@ -166,15 +166,8 @@ Tensor linear_out_cpu(
TORCHAO_CHECK(out.dtype() == torch::kFloat32, "out must be float32");
#endif // USE_ATEN

#ifdef USE_ATEN
out.resize_({m, n});
#endif // USE_ATEN

#ifdef USE_EXECUTORCH
TORCHAO_CHECK(out.dim() == 2, "out must be 2D");
TORCHAO_CHECK(out.size(0) == m, "out shape is incorrect");
TORCHAO_CHECK(out.size(1) == n, "out shape is incorrect");
#endif // USE_EXECUTORCH
// Explicit cast from int64_t to int is required for Executorch
TORCHAO_RESIZE_TENSOR(out, {(int)m, (int)n});

using namespace torchao::ops::linear_8bit_act_xbit_weight;

Expand Down Expand Up @@ -254,24 +247,4 @@ Tensor linear_cpu(
}
#endif // USE_ATEN

#ifdef USE_ATEN
template <int weight_nbit>
Tensor linear_meta(
const Tensor& activations,
const Tensor& packed_weights,
const int64_t& group_size,
const int64_t& n,
const int64_t& k) {
TORCHAO_CHECK(n >= 1, "n must be >= 1");
TORCHAO_CHECK(k >= 1, "k must be >= 1");

TORCHAO_CHECK(activations.dim() == 2, "activations must be 2D");
int m = activations.size(0);
int k_ = activations.size(1);
TORCHAO_CHECK(
k == k_, "activation shape is incompatible with packed weights.");
return torch::empty({m, n}).to("meta");
}
#endif // USE_ATEN

} // namespace
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,7 @@
&pack_weights_meta<weight_nbit>); \
m.impl( \
"_pack_8bit_act_" #weight_nbit "bit_weight", \
&pack_weights_meta<weight_nbit>); \
m.impl( \
"_linear_8bit_act_" #weight_nbit "bit_weight", \
&linear_meta<weight_nbit>);
&pack_weights_meta<weight_nbit>)

TORCH_LIBRARY_FRAGMENT(torchao, m) {
DEFINE_OP(1);
Expand Down
8 changes: 2 additions & 6 deletions torchao/experimental/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,12 +384,8 @@ def quantize_and_pack_weights(self, weights, group_size, has_weight_zeros):
self.register_buffer(
"packed_weight_qvals", self.pack_weights_op(weight_qvals.to(torch.int8))
)
self.register_buffer(
"num_embeddings", torch.empty(0, num_embeddings, dtype=torch.int8)
)
self.register_buffer(
"embedding_dim", torch.empty(0, embedding_dim, dtype=torch.int8)
)
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.register_buffer("weight_scales", weight_scales)
self.register_buffer("weight_zeros", weight_zeros.to(torch.int8))

Expand Down
Loading
Loading