Skip to content

[VFABI] Add support for vector functions that return struct types #119000

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 1 addition & 13 deletions llvm/include/llvm/Analysis/VectorUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "llvm/Analysis/LoopAccessAnalysis.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/VFABIDemangler.h"
#include "llvm/IR/VectorUtils.h"
#include "llvm/Support/CheckedArithmetic.h"

namespace llvm {
Expand Down Expand Up @@ -127,19 +128,6 @@ namespace Intrinsic {
typedef unsigned ID;
}

/// A helper function for converting Scalar types to vector types. If
/// the incoming type is void, we return void. If the EC represents a
/// scalar, we return the scalar type.
inline Type *ToVectorTy(Type *Scalar, ElementCount EC) {
if (Scalar->isVoidTy() || Scalar->isMetadataTy() || EC.isScalar())
return Scalar;
return VectorType::get(Scalar, EC);
}

inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
return ToVectorTy(Scalar, ElementCount::getFixed(VF));
}

/// Identify if the intrinsic is trivially vectorizable.
/// This method returns true if the intrinsic's argument types are all scalars
/// for the scalar form of the intrinsic and all vectors (or scalars handled by
Expand Down
44 changes: 44 additions & 0 deletions llvm/include/llvm/IR/CallWideningUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
//===---- CallWideningUtils.h - Utils for widening scalar to vector calls --==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_IR_CALLWIDENINGUTILS_H
#define LLVM_IR_CALLWIDENINGUTILS_H

#include "llvm/IR/DerivedTypes.h"

namespace llvm {

/// A helper for converting to wider (vector) types. For scalar types, this is
/// equivalent to calling `ToVectorTy`. For struct types, this returns a new
/// struct where each element type has been widened to a vector type. Note: Only
/// unpacked literal struct types are supported.
Type *ToWideTy(Type *Ty, ElementCount EC);

/// A helper for converting wide types to narrow (non-vector) types. For vector
/// types, this is equivalent to calling .getScalarType(). For struct types,
/// this returns a new struct where each element type has been converted to a
/// scalar type. Note: Only unpacked literal struct types are supported.
Type *ToNarrowTy(Type *Ty);

/// Returns the types contained in `Ty`. For struct types, it returns the
/// elements, all other types are returned directly.
SmallVector<Type *, 2> getContainedTypes(Type *Ty);

/// Returns true if `Ty` is a vector type or a struct of vector types where all
/// vector types share the same VF.
bool isWideTy(Type *Ty);

/// Returns the vectorization factor for a widened type.
inline ElementCount getWideTypeVF(Type *Ty) {
assert(isWideTy(Ty) && "expected widened type");
return cast<VectorType>(getContainedTypes(Ty).front())->getElementCount();
}

} // namespace llvm

#endif
32 changes: 32 additions & 0 deletions llvm/include/llvm/IR/VectorUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
//===----------- VectorUtils.h - Vector type utility functions -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_IR_VECTORUTILS_H
#define LLVM_IR_VECTORUTILS_H

#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/DerivedTypes.h"

namespace llvm {

/// A helper function for converting Scalar types to vector types. If
/// the incoming type is void, we return void. If the EC represents a
/// scalar, we return the scalar type.
inline Type *ToVectorTy(Type *Scalar, ElementCount EC) {
if (Scalar->isVoidTy() || Scalar->isMetadataTy() || EC.isScalar())
return Scalar;
return VectorType::get(Scalar, EC);
}

inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
return ToVectorTy(Scalar, ElementCount::getFixed(VF));
}

} // namespace llvm

#endif
1 change: 1 addition & 0 deletions llvm/lib/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ add_llvm_component_library(LLVMCore
AutoUpgrade.cpp
BasicBlock.cpp
BuiltinGCs.cpp
CallWideningUtils.cpp
Comdat.cpp
ConstantFold.cpp
ConstantFPRange.cpp
Expand Down
73 changes: 73 additions & 0 deletions llvm/lib/IR/CallWideningUtils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
//===----------- VectorUtils.cpp - Vector type utility functions ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/IR/CallWideningUtils.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/IR/VectorUtils.h"

using namespace llvm;

/// A helper for converting to wider (vector) types. For scalar types, this is
/// equivalent to calling `ToVectorTy`. For struct types, this returns a new
/// struct where each element type has been widened to a vector type. Note: Only
/// unpacked literal struct types are supported.
Type *llvm::ToWideTy(Type *Ty, ElementCount EC) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be worth adding a note that for now the only strategy we support is one where we widen the individual struct members, but in theory it could be extended to other strategies, for example a struct { i32, i32 } with a given VF could return a vector of VF x 2 elements.

Copy link
Member Author

@MacDue MacDue Dec 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm unclear on how other strategies could work in practice here. If you widened struct { i32, i32 } to <4 x i32> where that's meant to represent the interleaved access of two structs, you hit the problem that now there's no defined mapping from <4 x i32> back to a narrow type. It could have come from a struct or a scalar i32.

if (EC.isScalar())
return Ty;
auto *StructTy = dyn_cast<StructType>(Ty);
if (!StructTy)
return ToVectorTy(Ty, EC);
assert(StructTy->isLiteral() && !StructTy->isPacked() &&
"expected unpacked struct literal");
return StructType::get(
Ty->getContext(),
map_to_vector(StructTy->elements(), [&](Type *ElTy) -> Type * {
return VectorType::get(ElTy, EC);
}));
}

/// A helper for converting wide types to narrow (non-vector) types. For vector
/// types, this is equivalent to calling .getScalarType(). For struct types,
/// this returns a new struct where each element type has been converted to a
/// scalar type. Note: Only unpacked literal struct types are supported.
Type *llvm::ToNarrowTy(Type *Ty) {
auto *StructTy = dyn_cast<StructType>(Ty);
if (!StructTy)
return Ty->getScalarType();
assert(StructTy->isLiteral() && !StructTy->isPacked() &&
"expected unpacked struct literal");
return StructType::get(
Ty->getContext(),
map_to_vector(StructTy->elements(), [](Type *ElTy) -> Type * {
return ElTy->getScalarType();
}));
}

/// Returns the types contained in `Ty`. For struct types, it returns the
/// elements, all other types are returned directly.
SmallVector<Type *, 2> llvm::getContainedTypes(Type *Ty) {
auto *StructTy = dyn_cast<StructType>(Ty);
if (StructTy)
return to_vector<2>(StructTy->elements());
return {Ty};
}

/// Returns true if `Ty` is a vector type or a struct of vector types where all
/// vector types share the same VF.
bool llvm::isWideTy(Type *Ty) {
auto *StructTy = dyn_cast<StructType>(Ty);
if (StructTy && (!StructTy->isLiteral() || StructTy->isPacked()))
return false;
auto ContainedTys = getContainedTypes(Ty);
if (ContainedTys.empty() || !ContainedTys.front()->isVectorTy())
return false;
ElementCount VF = cast<VectorType>(ContainedTys.front())->getElementCount();
return all_of(ContainedTys, [&](Type *Ty) {
return Ty->isVectorTy() && cast<VectorType>(Ty)->getElementCount() == VF;
});
}
18 changes: 11 additions & 7 deletions llvm/lib/IR/VFABIDemangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/IR/CallWideningUtils.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
Expand Down Expand Up @@ -346,12 +347,15 @@ getScalableECFromSignature(const FunctionType *Signature, const VFISAKind ISA,
// Also check the return type if not void.
Type *RetTy = Signature->getReturnType();
if (!RetTy->isVoidTy()) {
std::optional<ElementCount> ReturnEC = getElementCountForTy(ISA, RetTy);
// If we have an unknown scalar element type we can't find a reasonable VF.
if (!ReturnEC)
return std::nullopt;
if (ElementCount::isKnownLT(*ReturnEC, MinEC))
MinEC = *ReturnEC;
for (Type *RetTy : getContainedTypes(RetTy)) {
std::optional<ElementCount> ReturnEC = getElementCountForTy(ISA, RetTy);
// If we have an unknown scalar element type we can't find a reasonable
// VF.
if (!ReturnEC)
return std::nullopt;
if (ElementCount::isKnownLT(*ReturnEC, MinEC))
MinEC = *ReturnEC;
}
}

// The SVE Vector function call ABI bases the VF on the widest element types
Expand Down Expand Up @@ -566,7 +570,7 @@ FunctionType *VFABI::createFunctionType(const VFInfo &Info,

auto *RetTy = ScalarFTy->getReturnType();
if (!RetTy->isVoidTy())
RetTy = VectorType::get(RetTy, VF);
RetTy = ToWideTy(RetTy, VF);
return FunctionType::get(RetTy, VecTypes, false);
}

Expand Down
1 change: 1 addition & 0 deletions llvm/unittests/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ add_llvm_unittest(IRTests
AttributesTest.cpp
BasicBlockTest.cpp
BasicBlockDbgInfoTest.cpp
CallWideningUtilsTest.cpp
CFGBuilder.cpp
ConstantFPRangeTest.cpp
ConstantRangeTest.cpp
Expand Down
149 changes: 149 additions & 0 deletions llvm/unittests/IR/CallWideningUtilsTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
//===------- CallWideningUtilsTest.cpp - Call widening utils tests --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/IR/CallWideningUtils.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/LLVMContext.h"
#include "gtest/gtest.h"

using namespace llvm;

namespace {

class CallWideningUtilsTest : public ::testing::Test {};

TEST(CallWideningUtilsTest, TestToWideTy) {
LLVMContext C;

Type *ITy = Type::getInt32Ty(C);
Type *FTy = Type::getFloatTy(C);
Type *HomogeneousStructTy = StructType::get(FTy, FTy, FTy);
Type *MixedStructTy = StructType::get(FTy, ITy);
Type *VoidTy = Type::getVoidTy(C);

for (ElementCount VF :
{ElementCount::getFixed(4), ElementCount::getScalable(2)}) {
Type *IntVec = ToWideTy(ITy, VF);
EXPECT_TRUE(isa<VectorType>(IntVec));
EXPECT_EQ(IntVec, VectorType::get(ITy, VF));

Type *FloatVec = ToWideTy(FTy, VF);
EXPECT_TRUE(isa<VectorType>(FloatVec));
EXPECT_EQ(FloatVec, VectorType::get(FTy, VF));

Type *WideHomogeneousStructTy = ToWideTy(HomogeneousStructTy, VF);
EXPECT_TRUE(isa<StructType>(WideHomogeneousStructTy));
EXPECT_TRUE(
cast<StructType>(WideHomogeneousStructTy)->containsHomogeneousTypes());
EXPECT_TRUE(cast<StructType>(WideHomogeneousStructTy)->getNumElements() ==
3);
EXPECT_TRUE(cast<StructType>(WideHomogeneousStructTy)->getElementType(0) ==
VectorType::get(FTy, VF));

Type *WideMixedStructTy = ToWideTy(MixedStructTy, VF);
EXPECT_TRUE(isa<StructType>(WideMixedStructTy));
EXPECT_TRUE(cast<StructType>(WideMixedStructTy)->getNumElements() == 2);
EXPECT_TRUE(cast<StructType>(WideMixedStructTy)->getElementType(0) ==
VectorType::get(FTy, VF));
EXPECT_TRUE(cast<StructType>(WideMixedStructTy)->getElementType(1) ==
VectorType::get(ITy, VF));

EXPECT_EQ(ToWideTy(VoidTy, VF), VoidTy);
}

ElementCount ScalarVF = ElementCount::getFixed(1);
for (Type *Ty : {ITy, FTy, HomogeneousStructTy, MixedStructTy, VoidTy}) {
EXPECT_EQ(ToWideTy(Ty, ScalarVF), Ty);
}
}

TEST(CallWideningUtilsTest, TestToNarrowTy) {
LLVMContext C;

Type *ITy = Type::getInt32Ty(C);
Type *FTy = Type::getFloatTy(C);
Type *HomogeneousStructTy = StructType::get(FTy, FTy, FTy);
Type *MixedStructTy = StructType::get(FTy, ITy);
Type *VoidTy = Type::getVoidTy(C);

for (ElementCount VF : {ElementCount::getFixed(1), ElementCount::getFixed(4),
ElementCount::getScalable(2)}) {
for (Type *Ty : {ITy, FTy, HomogeneousStructTy, MixedStructTy, VoidTy}) {
// ToNarrowTy should be the inverse of ToWideTy.
EXPECT_EQ(ToNarrowTy(ToWideTy(Ty, VF)), Ty);
};
}
}

TEST(CallWideningUtilsTest, TestGetContainedTypes) {
LLVMContext C;

Type *ITy = Type::getInt32Ty(C);
Type *FTy = Type::getFloatTy(C);
Type *HomogeneousStructTy = StructType::get(FTy, FTy, FTy);
Type *MixedStructTy = StructType::get(FTy, ITy);
Type *VoidTy = Type::getVoidTy(C);

EXPECT_EQ(getContainedTypes(ITy), SmallVector<Type *>({ITy}));
EXPECT_EQ(getContainedTypes(FTy), SmallVector<Type *>({FTy}));
EXPECT_EQ(getContainedTypes(VoidTy), SmallVector<Type *>({VoidTy}));
EXPECT_EQ(getContainedTypes(HomogeneousStructTy),
SmallVector<Type *>({FTy, FTy, FTy}));
EXPECT_EQ(getContainedTypes(MixedStructTy), SmallVector<Type *>({FTy, ITy}));
}

TEST(CallWideningUtilsTest, TestIsWideTy) {
LLVMContext C;

Type *ITy = Type::getInt32Ty(C);
Type *FTy = Type::getFloatTy(C);
Type *NarrowStruct = StructType::get(FTy, ITy);
Type *VoidTy = Type::getVoidTy(C);

EXPECT_FALSE(isWideTy(ITy));
EXPECT_FALSE(isWideTy(NarrowStruct));
EXPECT_FALSE(isWideTy(VoidTy));

ElementCount VF = ElementCount::getFixed(4);
EXPECT_TRUE(isWideTy(ToWideTy(ITy, VF)));
EXPECT_TRUE(isWideTy(ToWideTy(NarrowStruct, VF)));

Type *MixedVFStruct =
StructType::get(VectorType::get(ITy, ElementCount::getFixed(2)),
VectorType::get(ITy, ElementCount::getFixed(4)));
EXPECT_FALSE(isWideTy(MixedVFStruct));

// Currently only literals types are considered wide.
Type *NamedWideStruct = StructType::create("Named", VectorType::get(ITy, VF),
VectorType::get(ITy, VF));
EXPECT_FALSE(isWideTy(NamedWideStruct));

// Currently only unpacked types are considered wide.
Type *PackedWideStruct = StructType::get(
C, ArrayRef<Type *>{VectorType::get(ITy, VF), VectorType::get(ITy, VF)},
/*isPacked=*/true);
EXPECT_FALSE(isWideTy(PackedWideStruct));
}

TEST(CallWideningUtilsTest, TestGetWideTypeVF) {
LLVMContext C;

Type *ITy = Type::getInt32Ty(C);
Type *FTy = Type::getFloatTy(C);
Type *HomogeneousStructTy = StructType::get(FTy, FTy, FTy);
Type *MixedStructTy = StructType::get(FTy, ITy);

for (ElementCount VF :
{ElementCount::getFixed(4), ElementCount::getScalable(2)}) {
for (Type *Ty : {ITy, FTy, HomogeneousStructTy, MixedStructTy}) {
EXPECT_EQ(getWideTypeVF(ToWideTy(Ty, VF)), VF);
};
}
}

} // namespace
Loading
Loading