Skip to content

Packed BF16 datatype #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: sandbox
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class LLVMFuncOp;
/// of the libc).
LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(ModuleOp moduleOp);
Expand Down
68 changes: 68 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3791,3 +3791,71 @@ structured_op: !LinalgStructuredOpConfig
scalar_const: '2.3283063999999999E-10 : f64'
- !ScalarExpression
scalar_arg: min
--- !LinalgOpConfig
# Batch-reduce matmul: C[m, n] += cast(U, A[b, m, k]) * cast(U, B[b, k, n]).
# Symbols: s0 = batch, s1 = M, s2 = N, s3 = K (same order as d0..d3 below).
metadata: !LinalgOpMetadata
  name: reduce_batch_matmul
  cpp_class_name: ReduceBatchMatmulOp
  doc: |-
    Performs a batch-reduce matrix multiplication of two 3D inputs.
    The partial multiplication results are reduced over the batch dimension
    into a 2D output.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
  implements:
  - LinalgContractionOpInterface
structured_op: !LinalgStructuredOpConfig
  args:
  - !LinalgOperandDefConfig
    name: A
    kind: input_tensor
    type_var: T1
    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
  - !LinalgOperandDefConfig
    name: B
    kind: input_tensor
    type_var: T2
    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
  - !LinalgOperandDefConfig
    name: C
    kind: output_tensor
    type_var: U
    # The output is 2-D (M x N): its rank must match the two results of its
    # indexing map (d1, d2); the batch dimension is reduced away.
    shape_map: affine_map<()[s0, s1, s2, s3] -> (s1, s2)>
  indexing_maps: !LinalgIndexingMapsConfig
    static_indexing_maps:
    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d1, d2)>
  # d0 (batch) and d3 (K) are reduced; d1 (M) and d2 (N) are parallel.
  iterator_types:
  - reduction
  - parallel
  - parallel
  - reduction
  assignments:
  - !ScalarAssign
    arg: C
    value: !ScalarExpression
      scalar_fn:
        kind: binary
        fn_name: add
        operands:
        - !ScalarExpression
          scalar_arg: C
        - !ScalarExpression
          scalar_fn:
            kind: binary
            fn_name: mul
            operands:
            - !ScalarExpression
              scalar_fn:
                kind: type
                fn_name: cast_signed
                type_var: U
                operands:
                - !ScalarExpression
                  scalar_arg: A
            - !ScalarExpression
              scalar_fn:
                kind: type
                fn_name: cast_signed
                type_var: U
                operands:
                - !ScalarExpression
                  scalar_arg: B
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@ struct SCFTileAndFuseResult {
SmallVector<Operation *> tiledAndFusedOps;
SmallVector<scf::ForOp> loops;
};

/// Signature of an optional callback used to vet a producer before it is
/// fused. It receives the iteration domain of the root (consumer) op, the
/// candidate producer operation, and a builder, and returns a LogicalResult.
/// In `returningMatchAndRewrite`, a producer for which the callback returns
/// failure is skipped rather than fused; a null callback disables the check.
using checkProducerFn =
std::function<LogicalResult(ArrayRef<Range> rootIterationDomain,
Operation *producer, OpBuilder &builder)>;

struct TileConsumerAndFuseProducersUsingSCFForOp
: public OpInterfaceRewritePattern<TilingInterface> {

Expand All @@ -127,7 +132,8 @@ struct TileConsumerAndFuseProducersUsingSCFForOp
/// `matchAndRewrite` implementation that returns the significant transformed
/// pieces of IR.
FailureOr<SCFTileAndFuseResult>
returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const;
returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter,
checkProducerFn = nullptr) const;

LogicalResult matchAndRewrite(TilingInterface op,
PatternRewriter &rewriter) const override {
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/IR/Builders.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class Builder {

// Types.
FloatType getBF16Type();
FloatType getPackedBF16Type();
FloatType getF16Type();
FloatType getF32Type();
FloatType getF64Type();
Expand Down
7 changes: 6 additions & 1 deletion mlir/include/mlir/IR/BuiltinTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class FloatType : public Type {

// Convenience factories.
static FloatType getBF16(MLIRContext *ctx);
static FloatType getPackedBF16(MLIRContext *ctx);
static FloatType getF16(MLIRContext *ctx);
static FloatType getF32(MLIRContext *ctx);
static FloatType getF64(MLIRContext *ctx);
Expand Down Expand Up @@ -374,13 +375,17 @@ inline bool BaseMemRefType::isValidElementType(Type type) {

inline bool FloatType::classof(Type type) {
return type.isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
Float80Type, Float128Type>();
Float80Type, Float128Type, PackedBF16Type>();
}

/// Convenience factory: returns the `BFloat16Type` obtained from `ctx`,
/// typed as a generic `FloatType`.
inline FloatType FloatType::getBF16(MLIRContext *ctx) {
return BFloat16Type::get(ctx);
}

/// Convenience factory: returns the `PackedBF16Type` obtained from `ctx`,
/// typed as a generic `FloatType`.
inline FloatType FloatType::getPackedBF16(MLIRContext *ctx) {
return PackedBF16Type::get(ctx);
}

/// Convenience factory: returns the `Float16Type` obtained from `ctx`,
/// typed as a generic `FloatType`.
inline FloatType FloatType::getF16(MLIRContext *ctx) {
return Float16Type::get(ctx);
}
Expand Down
8 changes: 8 additions & 0 deletions mlir/include/mlir/IR/BuiltinTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,14 @@ def Builtin_Float128 : Builtin_FloatType<"Float128"> {
let summary = "128-bit floating-point type";
}

//===----------------------------------------------------------------------===//
// PackedBF16Type

// Builtin float type record instantiated from `Builtin_FloatType` with the
// name "PackedBF16".
// NOTE(review): the summary does not say how values are packed or how this
// type differs from `bf16` — consider documenting the layout here.
def Builtin_PackedBF16 : Builtin_FloatType<"PackedBF16"> {
let summary = "Packed BF16 format";
}


//===----------------------------------------------------------------------===//
// FunctionType
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/IR/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ class Type {
bool isF64() const;
bool isF80() const;
bool isF128() const;
bool isPackedBF16() const;

/// Return true if this is an integer type with the specified width.
bool isInteger(unsigned width) const;
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/AsmParser/TokenKinds.def
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ TOK_KEYWORD(affine_set)
TOK_KEYWORD(array)
TOK_KEYWORD(attributes)
TOK_KEYWORD(bf16)
TOK_KEYWORD(pbf16)
TOK_KEYWORD(ceildiv)
TOK_KEYWORD(complex)
TOK_KEYWORD(dense)
Expand Down
6 changes: 5 additions & 1 deletion mlir/lib/AsmParser/TypeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
case Token::kw_vector:
case Token::inttype:
case Token::kw_bf16:
case Token::kw_pbf16:
case Token::kw_f16:
case Token::kw_f32:
case Token::kw_f64:
Expand Down Expand Up @@ -249,7 +250,7 @@ Type Parser::parseMemRefType() {
/// | none-type
///
/// index-type ::= `index`
/// float-type ::= `f16` | `bf16` | `f32` | `f64` | `f80` | `f128`
/// float-type ::= `f16` | `bf16` | `pbf16` | `f32` | `f64` | `f80` | `f128`
/// none-type ::= `none`
///
Type Parser::parseNonFunctionType() {
Expand Down Expand Up @@ -289,6 +290,9 @@ Type Parser::parseNonFunctionType() {
case Token::kw_bf16:
consumeToken(Token::kw_bf16);
return builder.getBF16Type();
case Token::kw_pbf16:
consumeToken(Token::kw_pbf16);
return builder.getPackedBF16Type();
case Token::kw_f16:
consumeToken(Token::kw_f16);
return builder.getF16Type();
Expand Down
5 changes: 4 additions & 1 deletion mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1082,7 +1082,10 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
VectorType vectorType = printType.dyn_cast<VectorType>();
Type eltType = vectorType ? vectorType.getElementType() : printType;
Operation *printer;
if (eltType.isF32()) {
if (eltType.isBF16()) {
printer =
LLVM::lookupOrCreatePrintBF16Fn(printOp->getParentOfType<ModuleOp>());
} else if (eltType.isF32()) {
printer =
LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
} else if (eltType.isF64()) {
Expand Down
7 changes: 7 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ using namespace mlir::LLVM;
/// part of the libc).
static constexpr llvm::StringRef kPrintI64 = "printI64";
static constexpr llvm::StringRef kPrintU64 = "printU64";
static constexpr llvm::StringRef kPrintBF16 = "printBF16";
static constexpr llvm::StringRef kPrintF32 = "printF32";
static constexpr llvm::StringRef kPrintF64 = "printF64";
static constexpr llvm::StringRef kPrintOpen = "printOpen";
Expand Down Expand Up @@ -66,6 +67,12 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(ModuleOp moduleOp) {
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}

/// Returns the `printBF16` function of `moduleOp` via `lookupOrCreateFn`,
/// requesting a single bf16 parameter and a void result.
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(ModuleOp moduleOp) {
  auto *ctx = moduleOp->getContext();
  // Build the signature pieces up front so the call below reads as
  // (module, symbol name, parameter type, result type).
  auto bf16ParamType = FloatType::getBF16(ctx);
  auto voidResultType = LLVM::LLVMVoidType::get(ctx);
  return lookupOrCreateFn(moduleOp, kPrintBF16, bf16ParamType, voidResultType);
}

LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(ModuleOp moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintF32,
Float32Type::get(moduleOp->getContext()),
Expand Down
10 changes: 6 additions & 4 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,7 @@ bool mlir::LLVM::isCompatibleOuterType(Type type) {
Float64Type,
Float80Type,
Float128Type,
PackedBF16Type,
LLVMArrayType,
LLVMFunctionType,
LLVMLabelType,
Expand Down Expand Up @@ -865,8 +866,9 @@ bool mlir::LLVM::isCompatibleType(Type type) {
}

bool mlir::LLVM::isCompatibleFloatingPointType(Type type) {
return type.isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
Float80Type, Float128Type, LLVMPPCFP128Type>();
return type
.isa<BFloat16Type, Float16Type, Float32Type, Float64Type, Float80Type,
Float128Type, LLVMPPCFP128Type, PackedBF16Type>();
}

bool mlir::LLVM::isCompatibleVectorType(Type type) {
Expand All @@ -880,7 +882,7 @@ bool mlir::LLVM::isCompatibleVectorType(Type type) {
if (auto intType = elementType.dyn_cast<IntegerType>())
return intType.isSignless();
return elementType.isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
Float80Type, Float128Type>();
Float80Type, Float128Type, PackedBF16Type>();
}
return false;
}
Expand Down Expand Up @@ -965,7 +967,7 @@ llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
"expected a type compatible with the LLVM dialect");

return llvm::TypeSwitch<Type, llvm::TypeSize>(type)
.Case<BFloat16Type, Float16Type>(
.Case<BFloat16Type, Float16Type, PackedBF16Type>(
[](Type) { return llvm::TypeSize::Fixed(16); })
.Case<Float32Type>([](Type) { return llvm::TypeSize::Fixed(32); })
.Case<Float64Type, LLVMX86MMXType>(
Expand Down
7 changes: 6 additions & 1 deletion mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,13 +397,14 @@ static void replaceIterArgs(scf::ForOp outerFor, scf::ForOp innerFor,

FailureOr<scf::SCFTileAndFuseResult>
scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
TilingInterface op, PatternRewriter &rewriter) const {
TilingInterface op, PatternRewriter &rewriter, checkProducerFn fn) const {
// This transformation is only valid for ops that return values (i.e. not
// valid to use with operations that have memref operands).
if (!op->getNumResults()) {
return rewriter.notifyMatchFailure(
op, "invalid pattern for op with no results");
}
SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);

// 1. First tile the consumer.
SCFTileAndFuseResult tileAndFuseResult;
Expand Down Expand Up @@ -446,6 +447,10 @@ scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
if (!fusableProducer)
continue;

if (fn &&
failed(fn(iterationDomain, fusableProducer->getDefiningOp(), rewriter)))
continue;

// 2c. Generate the tiled implementation of the producer of the source
rewriter.setInsertionPoint(candidateSliceOp);
FailureOr<Value> fusedProducerValue =
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/IR/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2078,6 +2078,7 @@ void AsmPrinter::Impl::printType(Type type) {
})
.Case<IndexType>([&](Type) { os << "index"; })
.Case<BFloat16Type>([&](Type) { os << "bf16"; })
.Case<PackedBF16Type>([&](Type) { os << "pbf16"; })
.Case<Float16Type>([&](Type) { os << "f16"; })
.Case<Float32Type>([&](Type) { os << "f32"; })
.Case<Float64Type>([&](Type) { os << "f64"; })
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/IR/Builders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {

/// Returns the bf16 float type for this builder's context by forwarding to
/// `FloatType::getBF16`.
FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }

/// Returns the packed-bf16 float type for this builder's context by
/// forwarding to `FloatType::getPackedBF16`.
FloatType Builder::getPackedBF16Type() {
return FloatType::getPackedBF16(context);
}

/// Returns the f16 float type for this builder's context by forwarding to
/// `FloatType::getF16`.
FloatType Builder::getF16Type() { return FloatType::getF16(context); }

/// Returns the f32 float type for this builder's context by forwarding to
/// `FloatType::getF32`.
FloatType Builder::getF32Type() { return FloatType::getF32(context); }
Expand Down
Loading