Skip to content

Commit 1ee740a

Browse files
authored
[VFABI] Add support for vector functions that return struct types (#119000)
This patch updates the `VFABIDemangler` to support vector functions that return struct types. For example, a vector variant of `sincos` that returns a vector of sine values and a vector of cosine values within a struct. This patch also adds some helpers for vectorizing types (including struct types). Some of these are used in the `VFABIDemangler`, and others will be used in subsequent patches, so this patch simply adds tests for them.
1 parent 16c02df commit 1ee740a

File tree

8 files changed

+399
-20
lines changed

8 files changed

+399
-20
lines changed

llvm/include/llvm/Analysis/VectorUtils.h

+1-13
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "llvm/Analysis/LoopAccessAnalysis.h"
1919
#include "llvm/IR/Module.h"
2020
#include "llvm/IR/VFABIDemangler.h"
21+
#include "llvm/IR/VectorTypeUtils.h"
2122
#include "llvm/Support/CheckedArithmetic.h"
2223

2324
namespace llvm {
@@ -127,19 +128,6 @@ namespace Intrinsic {
127128
typedef unsigned ID;
128129
}
129130

130-
/// A helper function for converting Scalar types to vector types. If
131-
/// the incoming type is void, we return void. If the EC represents a
132-
/// scalar, we return the scalar type.
133-
inline Type *ToVectorTy(Type *Scalar, ElementCount EC) {
134-
if (Scalar->isVoidTy() || Scalar->isMetadataTy() || EC.isScalar())
135-
return Scalar;
136-
return VectorType::get(Scalar, EC);
137-
}
138-
139-
inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
140-
return ToVectorTy(Scalar, ElementCount::getFixed(VF));
141-
}
142-
143131
/// Identify if the intrinsic is trivially vectorizable.
144132
/// This method returns true if the intrinsic's argument types are all scalars
145133
/// for the scalar form of the intrinsic and all vectors (or scalars handled by
+94
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
//===------- VectorTypeUtils.h - Vector type utility functions -*- C++ -*-====//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef LLVM_IR_VECTORTYPEUTILS_H
10+
#define LLVM_IR_VECTORTYPEUTILS_H
11+
12+
#include "llvm/IR/DerivedTypes.h"
13+
14+
namespace llvm {
15+
16+
/// A helper function for converting Scalar types to vector types. If
17+
/// the incoming type is void, we return void. If the EC represents a
18+
/// scalar, we return the scalar type.
19+
inline Type *ToVectorTy(Type *Scalar, ElementCount EC) {
20+
if (Scalar->isVoidTy() || Scalar->isMetadataTy() || EC.isScalar())
21+
return Scalar;
22+
return VectorType::get(Scalar, EC);
23+
}
24+
25+
inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
26+
return ToVectorTy(Scalar, ElementCount::getFixed(VF));
27+
}
28+
29+
/// A helper for converting structs of scalar types to structs of vector types.
30+
/// Note:
31+
/// - If \p EC is scalar, \p StructTy is returned unchanged
32+
/// - Only unpacked literal struct types are supported
33+
Type *toVectorizedStructTy(StructType *StructTy, ElementCount EC);
34+
35+
/// A helper for converting structs of vector types to structs of scalar types.
36+
/// Note: Only unpacked literal struct types are supported.
37+
Type *toScalarizedStructTy(StructType *StructTy);
38+
39+
/// Returns true if `StructTy` is an unpacked literal struct where all elements
40+
/// are vectors of matching element count. This does not include empty structs.
41+
bool isVectorizedStructTy(StructType *StructTy);
42+
43+
/// A helper for converting to vectorized types. For scalar types, this is
44+
/// equivalent to calling `ToVectorTy`. For struct types, this returns a new
45+
/// struct where each element type has been widened to a vector type.
46+
/// Note:
47+
/// - If the incoming type is void, we return void
48+
/// - If \p EC is scalar, \p Ty is returned unchanged
49+
/// - Only unpacked literal struct types are supported
50+
inline Type *toVectorizedTy(Type *Ty, ElementCount EC) {
51+
if (StructType *StructTy = dyn_cast<StructType>(Ty))
52+
return toVectorizedStructTy(StructTy, EC);
53+
return ToVectorTy(Ty, EC);
54+
}
55+
56+
/// A helper for converting vectorized types to scalarized (non-vector) types.
57+
/// For vector types, this is equivalent to calling .getScalarType(). For struct
58+
/// types, this returns a new struct where each element type has been converted
59+
/// to a scalar type. Note: Only unpacked literal struct types are supported.
60+
inline Type *toScalarizedTy(Type *Ty) {
61+
if (StructType *StructTy = dyn_cast<StructType>(Ty))
62+
return toScalarizedStructTy(StructTy);
63+
return Ty->getScalarType();
64+
}
65+
66+
/// Returns true if `Ty` is a vector type or a struct of vector types where all
67+
/// vector types share the same VF.
68+
inline bool isVectorizedTy(Type *Ty) {
69+
if (StructType *StructTy = dyn_cast<StructType>(Ty))
70+
return isVectorizedStructTy(StructTy);
71+
return Ty->isVectorTy();
72+
}
73+
74+
/// Returns the types contained in `Ty`. For struct types, it returns the
75+
/// elements, all other types are returned directly.
76+
inline ArrayRef<Type *> getContainedTypes(Type *const &Ty) {
77+
if (auto *StructTy = dyn_cast<StructType>(Ty))
78+
return StructTy->elements();
79+
return ArrayRef<Type *>(&Ty, 1);
80+
}
81+
82+
/// Returns the number of vector elements for a vectorized type.
83+
inline ElementCount getVectorizedTypeVF(Type *Ty) {
84+
assert(isVectorizedTy(Ty) && "expected vectorized type");
85+
return cast<VectorType>(getContainedTypes(Ty).front())->getElementCount();
86+
}
87+
88+
inline bool isUnpackedStructLiteral(StructType *StructTy) {
89+
return StructTy->isLiteral() && !StructTy->isPacked();
90+
}
91+
92+
} // namespace llvm
93+
94+
#endif

llvm/lib/IR/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ add_llvm_component_library(LLVMCore
7373
Value.cpp
7474
ValueSymbolTable.cpp
7575
VectorBuilder.cpp
76+
VectorTypeUtils.cpp
7677
Verifier.cpp
7778
VFABIDemangler.cpp
7879
RuntimeLibcalls.cpp

llvm/lib/IR/VFABIDemangler.cpp

+15-6
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "llvm/ADT/SmallString.h"
1212
#include "llvm/ADT/StringSwitch.h"
1313
#include "llvm/IR/Module.h"
14+
#include "llvm/IR/VectorTypeUtils.h"
1415
#include "llvm/Support/Debug.h"
1516
#include "llvm/Support/raw_ostream.h"
1617
#include <limits>
@@ -346,12 +347,20 @@ getScalableECFromSignature(const FunctionType *Signature, const VFISAKind ISA,
346347
// Also check the return type if not void.
347348
Type *RetTy = Signature->getReturnType();
348349
if (!RetTy->isVoidTy()) {
349-
std::optional<ElementCount> ReturnEC = getElementCountForTy(ISA, RetTy);
350-
// If we have an unknown scalar element type we can't find a reasonable VF.
351-
if (!ReturnEC)
350+
// If the return type is a struct, only allow unpacked struct literals.
351+
StructType *StructTy = dyn_cast<StructType>(RetTy);
352+
if (StructTy && !isUnpackedStructLiteral(StructTy))
352353
return std::nullopt;
353-
if (ElementCount::isKnownLT(*ReturnEC, MinEC))
354-
MinEC = *ReturnEC;
354+
355+
for (Type *RetTy : getContainedTypes(RetTy)) {
356+
std::optional<ElementCount> ReturnEC = getElementCountForTy(ISA, RetTy);
357+
// If we have an unknown scalar element type we can't find a reasonable
358+
// VF.
359+
if (!ReturnEC)
360+
return std::nullopt;
361+
if (ElementCount::isKnownLT(*ReturnEC, MinEC))
362+
MinEC = *ReturnEC;
363+
}
355364
}
356365

357366
// The SVE Vector function call ABI bases the VF on the widest element types
@@ -566,7 +575,7 @@ FunctionType *VFABI::createFunctionType(const VFInfo &Info,
566575

567576
auto *RetTy = ScalarFTy->getReturnType();
568577
if (!RetTy->isVoidTy())
569-
RetTy = VectorType::get(RetTy, VF);
578+
RetTy = toVectorizedTy(RetTy, VF);
570579
return FunctionType::get(RetTy, VecTypes, false);
571580
}
572581

llvm/lib/IR/VectorTypeUtils.cpp

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
//===------- VectorTypeUtils.cpp - Vector type utility functions ----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "llvm/IR/VectorTypeUtils.h"
10+
#include "llvm/ADT/SmallVectorExtras.h"
11+
12+
using namespace llvm;
13+
14+
/// A helper for converting structs of scalar types to structs of vector types.
15+
/// Note: Only unpacked literal struct types are supported.
16+
Type *llvm::toVectorizedStructTy(StructType *StructTy, ElementCount EC) {
17+
if (EC.isScalar())
18+
return StructTy;
19+
assert(isUnpackedStructLiteral(StructTy) &&
20+
"expected unpacked struct literal");
21+
assert(all_of(StructTy->elements(), VectorType::isValidElementType) &&
22+
"expected all element types to be valid vector element types");
23+
return StructType::get(
24+
StructTy->getContext(),
25+
map_to_vector(StructTy->elements(), [&](Type *ElTy) -> Type * {
26+
return VectorType::get(ElTy, EC);
27+
}));
28+
}
29+
30+
/// A helper for converting structs of vector types to structs of scalar types.
31+
/// Note: Only unpacked literal struct types are supported.
32+
Type *llvm::toScalarizedStructTy(StructType *StructTy) {
33+
assert(isUnpackedStructLiteral(StructTy) &&
34+
"expected unpacked struct literal");
35+
return StructType::get(
36+
StructTy->getContext(),
37+
map_to_vector(StructTy->elements(), [](Type *ElTy) -> Type * {
38+
return ElTy->getScalarType();
39+
}));
40+
}
41+
42+
/// Returns true if `StructTy` is an unpacked literal struct where all elements
43+
/// are vectors of matching element count. This does not include empty structs.
44+
bool llvm::isVectorizedStructTy(StructType *StructTy) {
45+
if (!isUnpackedStructLiteral(StructTy))
46+
return false;
47+
auto ElemTys = StructTy->elements();
48+
if (ElemTys.empty() || !ElemTys.front()->isVectorTy())
49+
return false;
50+
ElementCount VF = cast<VectorType>(ElemTys.front())->getElementCount();
51+
return all_of(ElemTys, [&](Type *Ty) {
52+
return Ty->isVectorTy() && cast<VectorType>(Ty)->getElementCount() == VF;
53+
});
54+
}

llvm/unittests/IR/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ add_llvm_unittest(IRTests
5151
ValueMapTest.cpp
5252
ValueTest.cpp
5353
VectorBuilderTest.cpp
54+
VectorTypeUtilsTest.cpp
5455
VectorTypesTest.cpp
5556
VerifierTest.cpp
5657
VFABIDemanglerTest.cpp

llvm/unittests/IR/VFABIDemanglerTest.cpp

+84-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@ class VFABIParserTest : public ::testing::Test {
4040
VFInfo Info;
4141
/// Reset the data needed for the test.
4242
void reset(const StringRef ScalarFTyStr) {
43-
M = parseAssemblyString("declare void @dummy()", Err, Ctx);
43+
M = parseAssemblyString("%dummy_named_struct = type { double, double }\n"
44+
"declare void @dummy()",
45+
Err, Ctx);
4446
EXPECT_NE(M.get(), nullptr)
4547
<< "Loading an invalid module.\n " << Err.getMessage() << "\n";
4648
Type *Ty = parseType(ScalarFTyStr, Err, *(M));
@@ -753,6 +755,87 @@ TEST_F(VFABIParserTest, ParseVoidReturnTypeSVE) {
753755
EXPECT_EQ(VectorName, "vector_foo");
754756
}
755757

758+
TEST_F(VFABIParserTest, ParseWideStructReturnTypeSVE) {
759+
EXPECT_TRUE(
760+
invokeParser("_ZGVsMxv_foo(vector_foo)", "{double, double}(float)"));
761+
EXPECT_EQ(ISA, VFISAKind::SVE);
762+
EXPECT_TRUE(isMasked());
763+
ElementCount NXV2 = ElementCount::getScalable(2);
764+
FunctionType *FTy = FunctionType::get(
765+
StructType::get(VectorType::get(Type::getDoubleTy(Ctx), NXV2),
766+
VectorType::get(Type::getDoubleTy(Ctx), NXV2)),
767+
{
768+
VectorType::get(Type::getFloatTy(Ctx), NXV2),
769+
VectorType::get(Type::getInt1Ty(Ctx), NXV2),
770+
},
771+
false);
772+
EXPECT_EQ(getFunctionType(), FTy);
773+
EXPECT_EQ(Parameters.size(), 2U);
774+
EXPECT_EQ(Parameters[0], VFParameter({0, VFParamKind::Vector}));
775+
EXPECT_EQ(Parameters[1], VFParameter({1, VFParamKind::GlobalPredicate}));
776+
EXPECT_EQ(VF, NXV2);
777+
EXPECT_EQ(ScalarName, "foo");
778+
EXPECT_EQ(VectorName, "vector_foo");
779+
}
780+
781+
TEST_F(VFABIParserTest, ParseWideStructMixedReturnTypeSVE) {
782+
EXPECT_TRUE(invokeParser("_ZGVsMxv_foo(vector_foo)", "{float, i64}(float)"));
783+
EXPECT_EQ(ISA, VFISAKind::SVE);
784+
EXPECT_TRUE(isMasked());
785+
ElementCount NXV2 = ElementCount::getScalable(2);
786+
FunctionType *FTy = FunctionType::get(
787+
StructType::get(VectorType::get(Type::getFloatTy(Ctx), NXV2),
788+
VectorType::get(Type::getInt64Ty(Ctx), NXV2)),
789+
{
790+
VectorType::get(Type::getFloatTy(Ctx), NXV2),
791+
VectorType::get(Type::getInt1Ty(Ctx), NXV2),
792+
},
793+
false);
794+
EXPECT_EQ(getFunctionType(), FTy);
795+
EXPECT_EQ(Parameters.size(), 2U);
796+
EXPECT_EQ(Parameters[0], VFParameter({0, VFParamKind::Vector}));
797+
EXPECT_EQ(Parameters[1], VFParameter({1, VFParamKind::GlobalPredicate}));
798+
EXPECT_EQ(VF, NXV2);
799+
EXPECT_EQ(ScalarName, "foo");
800+
EXPECT_EQ(VectorName, "vector_foo");
801+
}
802+
803+
TEST_F(VFABIParserTest, ParseWideStructReturnTypeNEON) {
804+
EXPECT_TRUE(
805+
invokeParser("_ZGVnN4v_foo(vector_foo)", "{float, float}(float)"));
806+
EXPECT_EQ(ISA, VFISAKind::AdvancedSIMD);
807+
EXPECT_FALSE(isMasked());
808+
ElementCount V4 = ElementCount::getFixed(4);
809+
FunctionType *FTy = FunctionType::get(
810+
StructType::get(VectorType::get(Type::getFloatTy(Ctx), V4),
811+
VectorType::get(Type::getFloatTy(Ctx), V4)),
812+
{
813+
VectorType::get(Type::getFloatTy(Ctx), V4),
814+
},
815+
false);
816+
EXPECT_EQ(getFunctionType(), FTy);
817+
EXPECT_EQ(Parameters.size(), 1U);
818+
EXPECT_EQ(Parameters[0], VFParameter({0, VFParamKind::Vector}));
819+
EXPECT_EQ(VF, V4);
820+
EXPECT_EQ(ScalarName, "foo");
821+
EXPECT_EQ(VectorName, "vector_foo");
822+
}
823+
824+
TEST_F(VFABIParserTest, ParseUnsupportedStructReturnTypesSVE) {
825+
// Struct with array element type.
826+
EXPECT_FALSE(
827+
invokeParser("_ZGVsMxv_foo(vector_foo)", "{double, [4 x float]}(float)"));
828+
// Nested struct type.
829+
EXPECT_FALSE(
830+
invokeParser("_ZGVsMxv_foo(vector_foo)", "{{float, float}}(float)"));
831+
// Packed struct type.
832+
EXPECT_FALSE(
833+
invokeParser("_ZGVsMxv_foo(vector_foo)", "<{double, float}>(float)"));
834+
// Named struct type.
835+
EXPECT_FALSE(
836+
invokeParser("_ZGVsMxv_foo(vector_foo)", "%dummy_named_struct(float)"));
837+
}
838+
756839
// Make sure we reject unsupported parameter types.
757840
TEST_F(VFABIParserTest, ParseUnsupportedElementTypeSVE) {
758841
EXPECT_FALSE(invokeParser("_ZGVsMxv_foo(vector_foo)", "void(i128)"));

0 commit comments

Comments
 (0)