Skip to content

Commit 58e8147

Browse files
authored
[flang][openacc] Use original input for base address with optional (#80931)
In #80317 the data op generation was updated to use correctly the #0 result from the hlfir.delcare op. In case of optional that are not descriptor, it is preferable to use the original input for the varPtr value of the OpenACC data op. This patch also make sure that the descriptor value of optional is only accessed when present.
1 parent 5aeabf2 commit 58e8147

File tree

3 files changed

+124
-27
lines changed

3 files changed

+124
-27
lines changed

flang/lib/Lower/DirectivesCommon.h

Lines changed: 73 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,13 @@ namespace lower {
5252
/// operations.
5353
struct AddrAndBoundsInfo {
5454
explicit AddrAndBoundsInfo() {}
55-
explicit AddrAndBoundsInfo(mlir::Value addr) : addr(addr) {}
56-
explicit AddrAndBoundsInfo(mlir::Value addr, mlir::Value isPresent)
57-
: addr(addr), isPresent(isPresent) {}
55+
explicit AddrAndBoundsInfo(mlir::Value addr, mlir::Value rawInput)
56+
: addr(addr), rawInput(rawInput) {}
57+
explicit AddrAndBoundsInfo(mlir::Value addr, mlir::Value rawInput,
58+
mlir::Value isPresent)
59+
: addr(addr), rawInput(rawInput), isPresent(isPresent) {}
5860
mlir::Value addr = nullptr;
61+
mlir::Value rawInput = nullptr;
5962
mlir::Value isPresent = nullptr;
6063
};
6164

@@ -615,20 +618,30 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
615618
fir::FirOpBuilder &builder,
616619
Fortran::lower::SymbolRef sym, mlir::Location loc) {
617620
mlir::Value symAddr = converter.getSymbolAddress(sym);
621+
mlir::Value rawInput = symAddr;
618622
if (auto declareOp =
619-
mlir::dyn_cast_or_null<hlfir::DeclareOp>(symAddr.getDefiningOp()))
623+
mlir::dyn_cast_or_null<hlfir::DeclareOp>(symAddr.getDefiningOp())) {
620624
symAddr = declareOp.getResults()[0];
625+
rawInput = declareOp.getResults()[1];
626+
}
621627

622628
// TODO: Might need revisiting to handle for non-shared clauses
623629
if (!symAddr) {
624630
if (const auto *details =
625-
sym->detailsIf<Fortran::semantics::HostAssocDetails>())
631+
sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
626632
symAddr = converter.getSymbolAddress(details->symbol());
633+
rawInput = symAddr;
634+
}
627635
}
628636

629637
if (!symAddr)
630638
llvm::report_fatal_error("could not retrieve symbol address");
631639

640+
mlir::Value isPresent;
641+
if (Fortran::semantics::IsOptional(sym))
642+
isPresent =
643+
builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), rawInput);
644+
632645
if (auto boxTy =
633646
fir::unwrapRefType(symAddr.getType()).dyn_cast<fir::BaseBoxType>()) {
634647
if (boxTy.getEleTy().isa<fir::RecordType>())
@@ -638,8 +651,6 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
638651
// `fir.ref<fir.class<T>>` type.
639652
if (symAddr.getType().isa<fir::ReferenceType>()) {
640653
if (Fortran::semantics::IsOptional(sym)) {
641-
mlir::Value isPresent =
642-
builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), symAddr);
643654
mlir::Value addr =
644655
builder.genIfOp(loc, {boxTy}, isPresent, /*withElseRegion=*/true)
645656
.genThen([&]() {
@@ -652,14 +663,13 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
652663
builder.create<fir::ResultOp>(loc, mlir::ValueRange{absent});
653664
})
654665
.getResults()[0];
655-
return AddrAndBoundsInfo(addr, isPresent);
666+
return AddrAndBoundsInfo(addr, rawInput, isPresent);
656667
}
657668
mlir::Value addr = builder.create<fir::LoadOp>(loc, symAddr);
658-
return AddrAndBoundsInfo(addr);
659-
;
669+
return AddrAndBoundsInfo(addr, rawInput, isPresent);
660670
}
661671
}
662-
return AddrAndBoundsInfo(symAddr);
672+
return AddrAndBoundsInfo(symAddr, rawInput, isPresent);
663673
}
664674

665675
template <typename BoundsOp, typename BoundsType>
@@ -807,7 +817,7 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
807817
Fortran::lower::StatementContext &stmtCtx,
808818
const std::list<Fortran::parser::SectionSubscript> &subscripts,
809819
std::stringstream &asFortran, fir::ExtendedValue &dataExv,
810-
bool dataExvIsAssumedSize, mlir::Value baseAddr,
820+
bool dataExvIsAssumedSize, AddrAndBoundsInfo &info,
811821
bool treatIndexAsSection = false) {
812822
int dimension = 0;
813823
mlir::Type idxTy = builder.getIndexType();
@@ -831,11 +841,30 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
831841
mlir::Value stride = one;
832842
bool strideInBytes = false;
833843

834-
if (fir::unwrapRefType(baseAddr.getType()).isa<fir::BaseBoxType>()) {
835-
mlir::Value d = builder.createIntegerConstant(loc, idxTy, dimension);
836-
auto dimInfo = builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy,
837-
baseAddr, d);
838-
stride = dimInfo.getByteStride();
844+
if (fir::unwrapRefType(info.addr.getType()).isa<fir::BaseBoxType>()) {
845+
if (info.isPresent) {
846+
stride =
847+
builder
848+
.genIfOp(loc, idxTy, info.isPresent, /*withElseRegion=*/true)
849+
.genThen([&]() {
850+
mlir::Value d =
851+
builder.createIntegerConstant(loc, idxTy, dimension);
852+
auto dimInfo = builder.create<fir::BoxDimsOp>(
853+
loc, idxTy, idxTy, idxTy, info.addr, d);
854+
builder.create<fir::ResultOp>(loc, dimInfo.getByteStride());
855+
})
856+
.genElse([&] {
857+
mlir::Value zero =
858+
builder.createIntegerConstant(loc, idxTy, 0);
859+
builder.create<fir::ResultOp>(loc, zero);
860+
})
861+
.getResults()[0];
862+
} else {
863+
mlir::Value d = builder.createIntegerConstant(loc, idxTy, dimension);
864+
auto dimInfo = builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy,
865+
idxTy, info.addr, d);
866+
stride = dimInfo.getByteStride();
867+
}
839868
strideInBytes = true;
840869
}
841870

@@ -919,7 +948,26 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
919948
}
920949
}
921950

922-
extent = fir::factory::readExtent(builder, loc, dataExv, dimension);
951+
if (info.isPresent &&
952+
fir::unwrapRefType(info.addr.getType()).isa<fir::BaseBoxType>()) {
953+
extent =
954+
builder
955+
.genIfOp(loc, idxTy, info.isPresent, /*withElseRegion=*/true)
956+
.genThen([&]() {
957+
mlir::Value ext = fir::factory::readExtent(
958+
builder, loc, dataExv, dimension);
959+
builder.create<fir::ResultOp>(loc, ext);
960+
})
961+
.genElse([&] {
962+
mlir::Value zero =
963+
builder.createIntegerConstant(loc, idxTy, 0);
964+
builder.create<fir::ResultOp>(loc, zero);
965+
})
966+
.getResults()[0];
967+
} else {
968+
extent = fir::factory::readExtent(builder, loc, dataExv, dimension);
969+
}
970+
923971
if (dataExvIsAssumedSize && dimension + 1 == dataExvRank) {
924972
extent = zero;
925973
if (ubound && lbound) {
@@ -976,6 +1024,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
9761024
dataExv = converter.genExprAddr(operandLocation, *exprBase,
9771025
stmtCtx);
9781026
info.addr = fir::getBase(dataExv);
1027+
info.rawInput = info.addr;
9791028
asFortran << (*exprBase).AsFortran();
9801029
} else {
9811030
const Fortran::parser::Name &name =
@@ -993,14 +1042,15 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
9931042
bounds = genBoundsOps<BoundsOp, BoundsType>(
9941043
builder, operandLocation, converter, stmtCtx,
9951044
arrayElement->subscripts, asFortran, dataExv,
996-
dataExvIsAssumedSize, info.addr, treatIndexAsSection);
1045+
dataExvIsAssumedSize, info, treatIndexAsSection);
9971046
}
9981047
asFortran << ')';
9991048
} else if (auto structComp = Fortran::parser::Unwrap<
10001049
Fortran::parser::StructureComponent>(designator)) {
10011050
fir::ExtendedValue compExv =
10021051
converter.genExprAddr(operandLocation, *expr, stmtCtx);
10031052
info.addr = fir::getBase(compExv);
1053+
info.rawInput = info.addr;
10041054
if (fir::unwrapRefType(info.addr.getType())
10051055
.isa<fir::SequenceType>())
10061056
bounds = genBaseBoundsOps<BoundsOp, BoundsType>(
@@ -1012,14 +1062,15 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
10121062
*Fortran::parser::GetLastName(*structComp).symbol);
10131063
if (isOptional)
10141064
info.isPresent = builder.create<fir::IsPresentOp>(
1015-
operandLocation, builder.getI1Type(), info.addr);
1065+
operandLocation, builder.getI1Type(), info.rawInput);
10161066

10171067
if (auto loadOp = mlir::dyn_cast_or_null<fir::LoadOp>(
10181068
info.addr.getDefiningOp())) {
10191069
if (fir::isAllocatableType(loadOp.getType()) ||
10201070
fir::isPointerType(loadOp.getType()))
10211071
info.addr = builder.create<fir::BoxAddrOp>(operandLocation,
10221072
info.addr);
1073+
info.rawInput = info.addr;
10231074
}
10241075

10251076
// If the component is an allocatable or pointer the result of
@@ -1029,6 +1080,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
10291080
if (auto boxAddrOp = mlir::dyn_cast_or_null<fir::BoxAddrOp>(
10301081
info.addr.getDefiningOp())) {
10311082
info.addr = boxAddrOp.getVal();
1083+
info.rawInput = info.addr;
10321084
bounds = genBoundsOpsFromBox<BoundsOp, BoundsType>(
10331085
builder, operandLocation, converter, compExv, info);
10341086
}
@@ -1043,6 +1095,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
10431095
fir::ExtendedValue compExv =
10441096
converter.genExprAddr(operandLocation, *expr, stmtCtx);
10451097
info.addr = fir::getBase(compExv);
1098+
info.rawInput = info.addr;
10461099
asFortran << (*expr).AsFortran();
10471100
} else if (const auto *dataRef{
10481101
std::get_if<Fortran::parser::DataRef>(

flang/lib/Lower/OpenACC.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,12 @@ static Op createDataEntryOp(fir::FirOpBuilder &builder, mlir::Location loc,
6767
mlir::Value varPtrPtr;
6868
if (auto boxTy = baseAddr.getType().dyn_cast<fir::BaseBoxType>()) {
6969
if (isPresent) {
70+
mlir::Type ifRetTy = boxTy.getEleTy();
71+
if (!fir::isa_ref_type(ifRetTy))
72+
ifRetTy = fir::ReferenceType::get(ifRetTy);
7073
baseAddr =
7174
builder
72-
.genIfOp(loc, {boxTy.getEleTy()}, isPresent,
75+
.genIfOp(loc, {ifRetTy}, isPresent,
7376
/*withElseRegion=*/true)
7477
.genThen([&]() {
7578
mlir::Value boxAddr =
@@ -78,7 +81,7 @@ static Op createDataEntryOp(fir::FirOpBuilder &builder, mlir::Location loc,
7881
})
7982
.genElse([&] {
8083
mlir::Value absent =
81-
builder.create<fir::AbsentOp>(loc, boxTy.getEleTy());
84+
builder.create<fir::AbsentOp>(loc, ifRetTy);
8285
builder.create<fir::ResultOp>(loc, mlir::ValueRange{absent});
8386
})
8487
.getResults()[0];
@@ -295,9 +298,16 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
295298
asFortran, bounds,
296299
/*treatIndexAsSection=*/true);
297300

298-
Op op = createDataEntryOp<Op>(
299-
builder, operandLocation, info.addr, asFortran, bounds, structured,
300-
implicit, dataClause, info.addr.getType(), info.isPresent);
301+
// If the input value is optional and is not a descriptor, we use the
302+
// rawInput directly.
303+
mlir::Value baseAddr =
304+
((info.addr.getType() != fir::unwrapRefType(info.rawInput.getType())) &&
305+
info.isPresent)
306+
? info.rawInput
307+
: info.addr;
308+
Op op = createDataEntryOp<Op>(builder, operandLocation, baseAddr, asFortran,
309+
bounds, structured, implicit, dataClause,
310+
baseAddr.getType(), info.isPresent);
301311
dataOperands.push_back(op.getAccPtr());
302312
}
303313
}

flang/test/Lower/OpenACC/acc-bounds.f90

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,8 @@ subroutine acc_optional_data(a)
126126

127127
! CHECK-LABEL: func.func @_QMopenacc_boundsPacc_optional_data(
128128
! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>> {fir.bindc_name = "a", fir.optional}) {
129-
! CHECK: %[[ARG0_DECL:.*]]:2 = hlfir.declare %arg0 {fortran_attrs = #fir.var_attrs<optional, pointer>, uniq_name = "_QMopenacc_boundsFacc_optional_dataEa"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)
130-
! CHECK: %[[IS_PRESENT:.*]] = fir.is_present %[[ARG0_DECL]]#0 : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> i1
129+
! CHECK: %[[ARG0_DECL:.*]]:2 = hlfir.declare %[[ARG0]] {fortran_attrs = #fir.var_attrs<optional, pointer>, uniq_name = "_QMopenacc_boundsFacc_optional_dataEa"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)
130+
! CHECK: %[[IS_PRESENT:.*]] = fir.is_present %[[ARG0_DECL]]#1 : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> i1
131131
! CHECK: %[[BOX:.*]] = fir.if %[[IS_PRESENT]] -> (!fir.box<!fir.ptr<!fir.array<?xf32>>>) {
132132
! CHECK: %[[LOAD:.*]] = fir.load %[[ARG0_DECL]]#0 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
133133
! CHECK: fir.result %[[LOAD]] : !fir.box<!fir.ptr<!fir.array<?xf32>>>
@@ -153,4 +153,38 @@ subroutine acc_optional_data(a)
153153
! CHECK: %[[ATTACH:.*]] = acc.attach varPtr(%[[BOX_ADDR]] : !fir.ptr<!fir.array<?xf32>>) bounds(%[[BOUND]]) -> !fir.ptr<!fir.array<?xf32>> {name = "a"}
154154
! CHECK: acc.data dataOperands(%[[ATTACH]] : !fir.ptr<!fir.array<?xf32>>)
155155

156+
subroutine acc_optional_data2(a, n)
157+
integer :: n
158+
real, optional :: a(n)
159+
!$acc data no_create(a)
160+
!$acc end data
161+
end subroutine
162+
163+
! CHECK-LABEL: func.func @_QMopenacc_boundsPacc_optional_data2(
164+
! CHECK-SAME: %[[A:.*]]: !fir.ref<!fir.array<?xf32>> {fir.bindc_name = "a", fir.optional}, %[[N:.*]]: !fir.ref<i32> {fir.bindc_name = "n"}) {
165+
! CHECK: %[[DECL_A:.*]]:2 = hlfir.declare %[[A]](%{{.*}}) {fortran_attrs = #fir.var_attrs<optional>, uniq_name = "_QMopenacc_boundsFacc_optional_data2Ea"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
166+
! CHECK: %[[NO_CREATE:.*]] = acc.nocreate varPtr(%[[DECL_A]]#1 : !fir.ref<!fir.array<?xf32>>) bounds(%10) -> !fir.ref<!fir.array<?xf32>> {name = "a"}
167+
! CHECK: acc.data dataOperands(%[[NO_CREATE]] : !fir.ref<!fir.array<?xf32>>) {
168+
169+
subroutine acc_optional_data3(a, n)
170+
integer :: n
171+
real, optional :: a(n)
172+
!$acc data no_create(a(1:n))
173+
!$acc end data
174+
end subroutine
175+
176+
! CHECK-LABEL: func.func @_QMopenacc_boundsPacc_optional_data3(
177+
! CHECK-SAME: %[[A:.*]]: !fir.ref<!fir.array<?xf32>> {fir.bindc_name = "a", fir.optional}, %[[N:.*]]: !fir.ref<i32> {fir.bindc_name = "n"}) {
178+
! CHECK: %[[DECL_A:.*]]:2 = hlfir.declare %[[A]](%{{.*}}) {fortran_attrs = #fir.var_attrs<optional>, uniq_name = "_QMopenacc_boundsFacc_optional_data3Ea"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
179+
! CHECK: %[[PRES:.*]] = fir.is_present %[[DECL_A]]#1 : (!fir.ref<!fir.array<?xf32>>) -> i1
180+
! CHECK: %[[STRIDE:.*]] = fir.if %[[PRES]] -> (index) {
181+
! CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[DECL_A]]#0, %c0{{.*}} : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
182+
! CHECK: fir.result %[[DIMS]]#2 : index
183+
! CHECK: } else {
184+
! CHECK: fir.result %c0{{.*}} : index
185+
! CHECK: }
186+
! CHECK: %[[BOUNDS:.*]] = acc.bounds lowerbound(%c0{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%[[STRIDE]] : index) startIdx(%c1 : index) {strideInBytes = true}
187+
! CHECK: %[[NOCREATE:.*]] = acc.nocreate varPtr(%[[DECL_A]]#1 : !fir.ref<!fir.array<?xf32>>) bounds(%14) -> !fir.ref<!fir.array<?xf32>> {name = "a(1:n)"}
188+
! CHECK: acc.data dataOperands(%[[NOCREATE]] : !fir.ref<!fir.array<?xf32>>) {
189+
156190
end module

0 commit comments

Comments
 (0)