Skip to content

Commit 31001be

Browse files
LemonBoytstellar
authored andcommitted
[LoopVectorize] Refine hasIrregularType predicate
The `hasIrregularType` predicate checks whether an array of N values of type Ty is "bitcast-compatible" with a <N x Ty> vector. The previous check returned invalid results in some cases where there's some padding between the array elements: eg. a 4-element array of u7 values is considered as compatible with <4 x u7>, even though the vector is only loading/storing 28 bits instead of 32. The problem causes LLVM to generate incorrect code for some targets: for AArch64 the vector loads/stores are lowered in terms of ubfx/bfi, effectively losing the top (N * padding bits). Reviewed By: lebedev.ri Differential Revision: https://reviews.llvm.org/D97465 (cherry picked from commit 4f02493)
1 parent 9ae9ab1 commit 31001be

File tree

2 files changed

+34
-15
lines changed

2 files changed

+34
-15
lines changed

Diff for: llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

+7-15
Original file line numberDiff line numberDiff line change
@@ -372,19 +372,11 @@ static Type *getMemInstValueType(Value *I) {
372372

373373
/// A helper function that returns true if the given type is irregular. The
374374
/// type is irregular if its allocated size doesn't equal the store size of an
375-
/// element of the corresponding vector type at the given vectorization factor.
376-
static bool hasIrregularType(Type *Ty, const DataLayout &DL, ElementCount VF) {
377-
// Determine if an array of VF elements of type Ty is "bitcast compatible"
378-
// with a <VF x Ty> vector.
379-
if (VF.isVector()) {
380-
auto *VectorTy = VectorType::get(Ty, VF);
381-
return TypeSize::get(VF.getKnownMinValue() *
382-
DL.getTypeAllocSize(Ty).getFixedValue(),
383-
VF.isScalable()) != DL.getTypeStoreSize(VectorTy);
384-
}
385-
386-
// If the vectorization factor is one, we just check if an array of type Ty
387-
// requires padding between elements.
375+
/// element of the corresponding vector type.
376+
static bool hasIrregularType(Type *Ty, const DataLayout &DL) {
377+
// Determine if an array of N elements of type Ty is "bitcast compatible"
378+
// with a <N x Ty> vector.
379+
// This is only true if there is no padding between the array elements.
388380
return DL.getTypeAllocSizeInBits(Ty) != DL.getTypeSizeInBits(Ty);
389381
}
390382

@@ -5212,7 +5204,7 @@ bool LoopVectorizationCostModel::interleavedAccessCanBeWidened(
52125204
// requires padding and will be scalarized.
52135205
auto &DL = I->getModule()->getDataLayout();
52145206
auto *ScalarTy = getMemInstValueType(I);
5215-
if (hasIrregularType(ScalarTy, DL, VF))
5207+
if (hasIrregularType(ScalarTy, DL))
52165208
return false;
52175209

52185210
// Check if masking is required.
@@ -5259,7 +5251,7 @@ bool LoopVectorizationCostModel::memoryInstructionCanBeWidened(
52595251
// requires padding and will be scalarized.
52605252
auto &DL = I->getModule()->getDataLayout();
52615253
auto *ScalarTy = LI ? LI->getType() : SI->getValueOperand()->getType();
5262-
if (hasIrregularType(ScalarTy, DL, VF))
5254+
if (hasIrregularType(ScalarTy, DL))
52635255
return false;
52645256

52655257
return true;

Diff for: llvm/test/Transforms/LoopVectorize/irregular_type.ll

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
; RUN: opt %s -loop-vectorize -force-vector-width=4 -S | FileCheck %s
2+
3+
; Ensure the array loads/stores are not optimized into vector operations when
4+
; the element type has padding bits.
5+
6+
; CHECK: foo
7+
; CHECK: vector.body
8+
; CHECK-NOT: load <4 x i7>
9+
; CHECK-NOT: store <4 x i7>
10+
; CHECK: for.body
11+
define void @foo(i7* %a, i64 %n) {
12+
entry:
13+
br label %for.body
14+
15+
for.body:
16+
%indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
17+
%arrayidx = getelementptr inbounds i7, i7* %a, i64 %indvars.iv
18+
%0 = load i7, i7* %arrayidx, align 1
19+
%sub = add nuw nsw i7 %0, 0
20+
store i7 %sub, i7* %arrayidx, align 1
21+
%indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
22+
%cmp = icmp eq i64 %indvars.iv.next, %n
23+
br i1 %cmp, label %for.exit, label %for.body
24+
25+
for.exit:
26+
ret void
27+
}

0 commit comments

Comments
 (0)