Skip to content

Commit eae6288

Browse files
authored
[CIR][Lowering] Handling Lowering of multiple dimension array correctly (#961)
Close #957. The previous algorithm for converting a multi-dimensional array to a tensor filled in the values one by one and inserted zero values under certain conditions. As the issue above shows, it had problems handling multi-dimensional arrays, so the generated values did not have the same shape as the original array. The new algorithm pre-fills the values vector ahead of time with the correct number of elements, then writes each value into its slot, so we only need to maintain the index to write to. I feel the new version has better performance (it avoids allocation) and slightly better readability.
1 parent 4c446b3 commit eae6288

File tree

2 files changed

+81
-62
lines changed

2 files changed

+81
-62
lines changed

clang/lib/CIR/Lowering/LoweringHelpers.cpp

+45-62
Original file line numberDiff line numberDiff line change
@@ -45,84 +45,67 @@ template <> mlir::APFloat getZeroInitFromType(mlir::Type Ty) {
4545
llvm_unreachable("NYI");
4646
}
4747

48-
// return the nested type and quantity of elements for cir.array type.
49-
// e.g: for !cir.array<!cir.array<!s32i x 3> x 1>
50-
// it returns !s32i as return value and stores 3 to elemQuantity.
51-
mlir::Type getNestedTypeAndElemQuantity(mlir::Type Ty, unsigned &elemQuantity) {
52-
assert(mlir::isa<mlir::cir::ArrayType>(Ty) && "expected ArrayType");
53-
54-
elemQuantity = 1;
55-
mlir::Type nestTy = Ty;
56-
while (auto ArrTy = mlir::dyn_cast<mlir::cir::ArrayType>(nestTy)) {
57-
nestTy = ArrTy.getEltType();
58-
elemQuantity *= ArrTy.getSize();
59-
}
60-
61-
return nestTy;
62-
}
63-
64-
template <typename StorageTy>
65-
void fillTrailingZeros(mlir::cir::ConstArrayAttr attr,
66-
llvm::SmallVectorImpl<StorageTy> &values) {
67-
auto numTrailingZeros = attr.getTrailingZerosNum();
68-
if (numTrailingZeros) {
69-
auto localArrayTy = mlir::dyn_cast<mlir::cir::ArrayType>(attr.getType());
70-
assert(localArrayTy && "expected !cir.array");
71-
72-
auto nestTy = localArrayTy.getEltType();
73-
if (!mlir::isa<mlir::cir::ArrayType>(nestTy))
74-
values.insert(values.end(), numTrailingZeros,
75-
getZeroInitFromType<StorageTy>(nestTy));
76-
}
77-
}
78-
48+
/// \param attr the ConstArrayAttr to convert
49+
/// \param values the output parameter, the values array to fill
50+
/// \param currentDims the shape of the tensor we're going to convert to
51+
/// \param dimIndex the current dimension we're processing
52+
/// \param currentIndex the current index in the values array
7953
template <typename AttrTy, typename StorageTy>
80-
void convertToDenseElementsAttrImpl(mlir::cir::ConstArrayAttr attr,
81-
llvm::SmallVectorImpl<StorageTy> &values) {
54+
void convertToDenseElementsAttrImpl(
55+
mlir::cir::ConstArrayAttr attr, llvm::SmallVectorImpl<StorageTy> &values,
56+
const llvm::SmallVectorImpl<int64_t> &currentDims, int64_t dimIndex,
57+
int64_t currentIndex) {
8258
if (auto stringAttr = mlir::dyn_cast<mlir::StringAttr>(attr.getElts())) {
8359
if (auto arrayType = mlir::dyn_cast<mlir::cir::ArrayType>(attr.getType())) {
8460
for (auto element : stringAttr) {
8561
auto intAttr = mlir::cir::IntAttr::get(arrayType.getEltType(), element);
86-
values.push_back(mlir::dyn_cast<AttrTy>(intAttr).getValue());
62+
values[currentIndex++] = mlir::dyn_cast<AttrTy>(intAttr).getValue();
8763
}
8864
return;
8965
}
9066
}
9167

68+
dimIndex++;
69+
std::size_t elementsSizeInCurrentDim = 1;
70+
for (std::size_t i = dimIndex; i < currentDims.size(); i++)
71+
elementsSizeInCurrentDim *= currentDims[i];
72+
9273
auto arrayAttr = mlir::cast<mlir::ArrayAttr>(attr.getElts());
9374
for (auto eltAttr : arrayAttr) {
9475
if (auto valueAttr = mlir::dyn_cast<AttrTy>(eltAttr)) {
95-
values.push_back(valueAttr.getValue());
96-
} else if (auto subArrayAttr =
97-
mlir::dyn_cast<mlir::cir::ConstArrayAttr>(eltAttr)) {
98-
convertToDenseElementsAttrImpl<AttrTy>(subArrayAttr, values);
99-
if (mlir::dyn_cast<mlir::StringAttr>(subArrayAttr.getElts()))
100-
fillTrailingZeros(subArrayAttr, values);
101-
} else if (auto zeroAttr = mlir::dyn_cast<mlir::cir::ZeroAttr>(eltAttr)) {
102-
unsigned numStoredZeros = 0;
103-
auto nestTy =
104-
getNestedTypeAndElemQuantity(zeroAttr.getType(), numStoredZeros);
105-
values.insert(values.end(), numStoredZeros,
106-
getZeroInitFromType<StorageTy>(nestTy));
107-
} else {
108-
llvm_unreachable("unknown element in ConstArrayAttr");
76+
values[currentIndex++] = valueAttr.getValue();
77+
continue;
10978
}
110-
}
11179

112-
// Only fill in trailing zeros at the local cir.array level where the element
113-
// type isn't another array (for the mult-dim case).
114-
fillTrailingZeros(attr, values);
80+
if (auto subArrayAttr =
81+
mlir::dyn_cast<mlir::cir::ConstArrayAttr>(eltAttr)) {
82+
convertToDenseElementsAttrImpl<AttrTy>(subArrayAttr, values, currentDims,
83+
dimIndex, currentIndex);
84+
currentIndex += elementsSizeInCurrentDim;
85+
continue;
86+
}
87+
88+
if (mlir::isa<mlir::cir::ZeroAttr>(eltAttr))
89+
continue;
90+
91+
llvm_unreachable("unknown element in ConstArrayAttr");
92+
}
11593
}
11694

11795
template <typename AttrTy, typename StorageTy>
118-
mlir::DenseElementsAttr
119-
convertToDenseElementsAttr(mlir::cir::ConstArrayAttr attr,
120-
const llvm::SmallVectorImpl<int64_t> &dims,
121-
mlir::Type type) {
122-
auto values = llvm::SmallVector<StorageTy, 8>{};
123-
convertToDenseElementsAttrImpl<AttrTy>(attr, values);
124-
return mlir::DenseElementsAttr::get(mlir::RankedTensorType::get(dims, type),
125-
llvm::ArrayRef(values));
96+
mlir::DenseElementsAttr convertToDenseElementsAttr(
97+
mlir::cir::ConstArrayAttr attr, const llvm::SmallVectorImpl<int64_t> &dims,
98+
mlir::Type elementType, mlir::Type convertedElementType) {
99+
unsigned vector_size = 1;
100+
for (auto dim : dims)
101+
vector_size *= dim;
102+
auto values = llvm::SmallVector<StorageTy, 8>(
103+
vector_size, getZeroInitFromType<StorageTy>(elementType));
104+
convertToDenseElementsAttrImpl<AttrTy>(attr, values, dims, /*currentDim=*/0,
105+
/*initialIndex=*/0);
106+
return mlir::DenseElementsAttr::get(
107+
mlir::RankedTensorType::get(dims, convertedElementType),
108+
llvm::ArrayRef(values));
126109
}
127110

128111
std::optional<mlir::Attribute>
@@ -151,10 +134,10 @@ lowerConstArrayAttr(mlir::cir::ConstArrayAttr constArr,
151134
converter->convertType(type));
152135
if (mlir::isa<mlir::cir::IntType>(type))
153136
return convertToDenseElementsAttr<mlir::cir::IntAttr, mlir::APInt>(
154-
constArr, dims, converter->convertType(type));
137+
constArr, dims, type, converter->convertType(type));
155138
if (mlir::isa<mlir::cir::CIRFPTypeInterface>(type))
156139
return convertToDenseElementsAttr<mlir::cir::FPAttr, mlir::APFloat>(
157-
constArr, dims, converter->convertType(type));
140+
constArr, dims, type, converter->convertType(type));
158141

159142
return std::nullopt;
160143
}

clang/test/CIR/Lowering/multi-array.c

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t.ll
2+
// RUN: FileCheck --input-file=%t.ll %s -check-prefix=LLVM
3+
4+
unsigned char table[10][5] =
5+
{
6+
{1,0},
7+
{7,6,5},
8+
};
9+
10+
// LLVM: @table = {{.*}}[10 x [5 x i8]] {{.*}}[5 x i8] c"\01\00\00\00\00", [5 x i8] c"\07\06\05\00\00", [5 x i8] zeroinitializer
11+
12+
unsigned char table2[15][16] =
13+
{
14+
{1,0},
15+
{1,1,0},
16+
{3,2,1,0},
17+
{3,2,1,1,0},
18+
{3,2,3,2,1,0},
19+
{3,0,1,3,2,5,4},
20+
{7,6,5,4,3,2,1,1,1,1,1,1,1,1,1},
21+
};
22+
23+
// LLVM: @table2 = {{.*}}[15 x [16 x i8]] {{.*}}[16 x i8] c"\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00", [16 x i8] c"\01\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00", [16 x i8] c"\03\02\01\00\00\00\00\00\00\00\00\00\00\00\00\00", [16 x i8] c"\03\02\01\01\00\00\00\00\00\00\00\00\00\00\00\00", [16 x i8] c"\03\02\03\02\01\00\00\00\00\00\00\00\00\00\00\00", [16 x i8] c"\03\00\01\03\02\05\04\00\00\00\00\00\00\00\00\00", [16 x i8] c"\07\06\05\04\03\02\01\01\01\01\01\01\01\01\01\00", [16 x i8] zeroinitializer
24+
25+
unsigned char table3[15][16] =
26+
{
27+
{1,1},
28+
{1,2,2},
29+
{2,2,2,2},
30+
{2,2,2,3,3},
31+
{2,2,3,3,3,3},
32+
{2,3,3,3,3,3,3},
33+
{3,3,3,3,3,3,3,4,5,6,7,8,9,10,11},
34+
};
35+
36+
// LLVM: @table3 = {{.*}}[15 x [16 x i8]] {{.*}}[16 x i8] c"\01\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00", [16 x i8] c"\01\02\02\00\00\00\00\00\00\00\00\00\00\00\00\00", [16 x i8] c"\02\02\02\02\00\00\00\00\00\00\00\00\00\00\00\00", [16 x i8] c"\02\02\02\03\03\00\00\00\00\00\00\00\00\00\00\00", [16 x i8] c"\02\02\03\03\03\03\00\00\00\00\00\00\00\00\00\00", [16 x i8] c"\02\03\03\03\03\03\03\00\00\00\00\00\00\00\00\00", [16 x i8] c"\03\03\03\03\03\03\03\04\05\06\07\08\09\0A\0B\00", [16 x i8] zeroinitializer

0 commit comments

Comments
 (0)