Skip to content

[AArch64][SVE2] Generate urshr rounding shift rights #78374

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 129 additions & 26 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2689,6 +2689,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::RSHRNB_I)
MAKE_CASE(AArch64ISD::CTTZ_ELTS)
MAKE_CASE(AArch64ISD::CALL_ARM64EC_TO_X64)
MAKE_CASE(AArch64ISD::URSHR_I_PRED)
}
#undef MAKE_CASE
return nullptr;
Expand Down Expand Up @@ -2973,6 +2974,7 @@ static SDValue convertToScalableVector(SelectionDAG &DAG, EVT VT, SDValue V);
static SDValue convertFromScalableVector(SelectionDAG &DAG, EVT VT, SDValue V);
static SDValue convertFixedMaskToScalableVector(SDValue Mask,
SelectionDAG &DAG);
static SDValue getPredicateForVector(SelectionDAG &DAG, SDLoc &DL, EVT VT);
static SDValue getPredicateForScalableVector(SelectionDAG &DAG, SDLoc &DL,
EVT VT);

Expand Down Expand Up @@ -13838,6 +13840,51 @@ SDValue AArch64TargetLowering::LowerTRUNCATE(SDValue Op,
return SDValue();
}

// Determine whether this SRL can be lowered to a rounding shift right
// instruction. ResVT may be a truncated type; it specifies how many low bits
// of the shifted value are actually consumed. On success, ShiftValue receives
// the shift amount and RShOperand the value being rounded and shifted.
static bool canLowerSRLToRoundingShiftForVT(SDValue Shift, EVT ResVT,
                                            SelectionDAG &DAG,
                                            unsigned &ShiftValue,
                                            SDValue &RShOperand) {
  if (Shift->getOpcode() != ISD::SRL)
    return false;

  EVT VT = Shift.getValueType();
  assert(VT.isScalableVT());

  // The shift amount must be a constant splat.
  auto SplatShiftAmt =
      dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Shift->getOperand(1)));
  if (!SplatShiftAmt)
    return false;

  ShiftValue = SplatShiftAmt->getZExtValue();
  if (ShiftValue < 1 || ShiftValue > ResVT.getScalarSizeInBits())
    return false;

  SDValue RoundingAdd = Shift->getOperand(0);
  if (RoundingAdd->getOpcode() != ISD::ADD || !RoundingAdd->hasOneUse())
    return false;

  assert(ResVT.getScalarSizeInBits() <= VT.getScalarSizeInBits() &&
         "ResVT must be truncated or same type as the shift.");
  // A wrap in the rounding add could corrupt bits that survive the
  // truncation, unless enough high bits are dropped to absorb the carry.
  uint64_t DroppedBits = VT.getScalarSizeInBits() - ResVT.getScalarSizeInBits();
  if (ShiftValue > DroppedBits && !RoundingAdd->getFlags().hasNoUnsignedWrap())
    return false;

  // The addend must be a splat of the rounding constant 2^(ShiftValue - 1).
  auto SplatAddend = dyn_cast_or_null<ConstantSDNode>(
      DAG.getSplatValue(RoundingAdd->getOperand(1)));
  if (!SplatAddend)
    return false;
  if (SplatAddend->getZExtValue() != 1ULL << (ShiftValue - 1))
    return false;

  RShOperand = RoundingAdd->getOperand(0);
  return true;
}

SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
SelectionDAG &DAG) const {
EVT VT = Op.getValueType();
Expand All @@ -13863,6 +13910,15 @@ SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
Op.getOperand(0), Op.getOperand(1));
case ISD::SRA:
case ISD::SRL:
if (VT.isScalableVector() && Subtarget->hasSVE2orSME()) {
SDValue RShOperand;
unsigned ShiftValue;
if (canLowerSRLToRoundingShiftForVT(Op, VT, DAG, ShiftValue, RShOperand))
return DAG.getNode(AArch64ISD::URSHR_I_PRED, DL, VT,
getPredicateForVector(DAG, DL, VT), RShOperand,
DAG.getTargetConstant(ShiftValue, DL, MVT::i32));
}

if (VT.isScalableVector() ||
useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable())) {
unsigned Opc = Op.getOpcode() == ISD::SRA ? AArch64ISD::SRA_PRED
Expand Down Expand Up @@ -17687,9 +17743,6 @@ static SDValue performReinterpretCastCombine(SDNode *N) {

static SDValue performSVEAndCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI) {
if (DCI.isBeforeLegalizeOps())
return SDValue();

SelectionDAG &DAG = DCI.DAG;
SDValue Src = N->getOperand(0);
unsigned Opc = Src->getOpcode();
Expand Down Expand Up @@ -17745,6 +17798,9 @@ static SDValue performSVEAndCombine(SDNode *N,
return DAG.getNode(Opc, DL, N->getValueType(0), And);
}

if (DCI.isBeforeLegalizeOps())
return SDValue();

// If both sides of AND operations are i1 splat_vectors then
// we can produce just i1 splat_vector as the result.
if (isAllActivePredicate(DAG, N->getOperand(0)))
Expand Down Expand Up @@ -20192,6 +20248,9 @@ static SDValue performIntrinsicCombine(SDNode *N,
case Intrinsic::aarch64_sve_uqsub_x:
return DAG.getNode(ISD::USUBSAT, SDLoc(N), N->getValueType(0),
N->getOperand(1), N->getOperand(2));
case Intrinsic::aarch64_sve_urshr:
return DAG.getNode(AArch64ISD::URSHR_I_PRED, SDLoc(N), N->getValueType(0),
N->getOperand(1), N->getOperand(2), N->getOperand(3));
Comment on lines +20251 to +20253
Copy link
Collaborator

@paulwalker-arm paulwalker-arm Feb 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't look sound to me. Either the naming of the AArch64ISD node is wrong or there needs to be a PatFrags that contains this node and the intrinsic. I say this because the _PRED nodes have no requirement when it comes to the result of inactive lanes whereas the aarch64_sve_urshr intrinsic has a very specific requirement.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For naming I looked at a few other instructions that have similar behavior as urshr, i.e. inactive elements in the destination vector remain unmodified, and they were also named as _PRED.

I am quite new to the isel backend. Can you please explain what difference having a PatFrag would make compared to the code above?

Copy link
Collaborator

@paulwalker-arm paulwalker-arm Feb 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you point to an example of where a _PRED node expects the results of inactive lanes to take a known value? because that really shouldn't be the case (there's a comment at the top of AArch64SVEInstrInfo.td and AArch64ISelLowering.h that details the naming strategy). The intent of the _PRED nodes is to allow predication to be represented at the DAG level rather than waiting until instruction selection. They have no requirement for the results of inactive lanes to free up instruction selection to allow the best use of unpredicated and/or reversed instructions.

The naming is important because people will assume the documented rules implementing DAG combines or make changes to instruction selection and thus if they're not followed it's very likely to introduce bugs. If it's important for the ISD node to model the results of the inactive lanes in accordance with the underlying SVE instruction then it should be named as such (e.g. URSHR_I_MERGE_OP1).

This is generally not the case and typically at the ISD level the result of inactive lanes is not important (often because an all active predicate is passed in) and thus the _PRED suffix is used. When this is the case we still want to minimise the number of ISel patterns and so a PatFrags is created to match both the ISD node and the intrinsic to the same instruction (e.g. AArch64mla_m1).

Copy link
Contributor Author

@UsmanNadeem UsmanNadeem Feb 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I get your point now. I will post a follow-up fix.

case Intrinsic::aarch64_sve_asrd:
return DAG.getNode(AArch64ISD::SRAD_MERGE_OP1, SDLoc(N), N->getValueType(0),
N->getOperand(1), N->getOperand(2), N->getOperand(3));
Expand Down Expand Up @@ -20808,6 +20867,51 @@ static SDValue performUnpackCombine(SDNode *N, SelectionDAG &DAG,
return SDValue();
}

// Returns true if N is a UZP1 that narrows a legal scalable integer type to
// the scalable type with half-width elements (i.e. a truncate-and-concat).
static bool isHalvingTruncateAndConcatOfLegalIntScalableType(SDNode *N) {
  if (N->getOpcode() != AArch64ISD::UZP1)
    return false;
  EVT DstVT = N->getValueType(0);
  EVT SrcVT = N->getOperand(0)->getValueType(0);
  if (DstVT == MVT::nxv16i8)
    return SrcVT == MVT::nxv8i16;
  if (DstVT == MVT::nxv8i16)
    return SrcVT == MVT::nxv4i32;
  if (DstVT == MVT::nxv4i32)
    return SrcVT == MVT::nxv2i64;
  return false;
}

// Try to combine rounding shifts where the operands come from an extend, and
// the result is truncated and combined into one vector.
// uzp1(rshrnb(uunpklo(X),C), rshrnb(uunpkhi(X), C)) -> urshr(X, C)
static SDValue tryCombineExtendRShTrunc(SDNode *N, SelectionDAG &DAG) {
  assert(N->getOpcode() == AArch64ISD::UZP1 && "Only UZP1 expected.");
  SDValue Op0 = N->getOperand(0);
  SDValue Op1 = N->getOperand(1);
  EVT ResVT = N->getValueType(0);

  unsigned RshOpc = Op0.getOpcode();
  if (RshOpc != AArch64ISD::RSHRNB_I)
    return SDValue();

  // Same op code and imm value?
  SDValue ShiftValue = Op0.getOperand(1);
  if (RshOpc != Op1.getOpcode() || ShiftValue != Op1.getOperand(1))
    return SDValue();

  // Same unextended operand value?
  SDValue Lo = Op0.getOperand(0);
  SDValue Hi = Op1.getOperand(0);
  // BOTH operands must be the expected unpack nodes. The original `&&` guard
  // only bailed out when neither matched, letting a single matching unpack
  // through and combining unrelated nodes into an incorrect URSHR.
  if (Lo.getOpcode() != AArch64ISD::UUNPKLO ||
      Hi.getOpcode() != AArch64ISD::UUNPKHI)
    return SDValue();
  SDValue OrigArg = Lo.getOperand(0);
  if (OrigArg != Hi.getOperand(0))
    return SDValue();

  SDLoc DL(N);
  return DAG.getNode(AArch64ISD::URSHR_I_PRED, DL, ResVT,
                     getPredicateForVector(DAG, DL, ResVT), OrigArg,
                     ShiftValue);
}

// Try to simplify:
// t1 = nxv8i16 add(X, 1 << (ShiftValue - 1))
// t2 = nxv8i16 srl(t1, ShiftValue)
Expand All @@ -20820,9 +20924,7 @@ static SDValue performUnpackCombine(SDNode *N, SelectionDAG &DAG,
static SDValue trySimplifySrlAddToRshrnb(SDValue Srl, SelectionDAG &DAG,
const AArch64Subtarget *Subtarget) {
EVT VT = Srl->getValueType(0);

if (!VT.isScalableVector() || !Subtarget->hasSVE2() ||
Srl->getOpcode() != ISD::SRL)
if (!VT.isScalableVector() || !Subtarget->hasSVE2())
return SDValue();

EVT ResVT;
Expand All @@ -20835,29 +20937,14 @@ static SDValue trySimplifySrlAddToRshrnb(SDValue Srl, SelectionDAG &DAG,
else
return SDValue();

auto SrlOp1 =
dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Srl->getOperand(1)));
if (!SrlOp1)
return SDValue();
unsigned ShiftValue = SrlOp1->getZExtValue();
if (ShiftValue < 1 || ShiftValue > ResVT.getScalarSizeInBits())
return SDValue();

SDValue Add = Srl->getOperand(0);
if (Add->getOpcode() != ISD::ADD || !Add->hasOneUse())
return SDValue();
auto AddOp1 =
dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Add->getOperand(1)));
if (!AddOp1)
return SDValue();
uint64_t AddValue = AddOp1->getZExtValue();
if (AddValue != 1ULL << (ShiftValue - 1))
return SDValue();

SDLoc DL(Srl);
unsigned ShiftValue;
SDValue RShOperand;
if (!canLowerSRLToRoundingShiftForVT(Srl, ResVT, DAG, ShiftValue, RShOperand))
return SDValue();
SDValue Rshrnb = DAG.getNode(
AArch64ISD::RSHRNB_I, DL, ResVT,
{Add->getOperand(0), DAG.getTargetConstant(ShiftValue, DL, MVT::i32)});
{RShOperand, DAG.getTargetConstant(ShiftValue, DL, MVT::i32)});
return DAG.getNode(ISD::BITCAST, DL, VT, Rshrnb);
}

Expand Down Expand Up @@ -20895,6 +20982,9 @@ static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
}
}

if (SDValue Urshr = tryCombineExtendRShTrunc(N, DAG))
return Urshr;

if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Op0, DAG, Subtarget))
return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Rshrnb, Op1);

Expand Down Expand Up @@ -20925,6 +21015,19 @@ static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
if (!IsLittleEndian)
return SDValue();

// uzp1(bitcast(x), bitcast(y)) -> uzp1(x, y)
// Example:
// nxv4i32 = uzp1 bitcast(nxv4i32 x to nxv2i64), bitcast(nxv4i32 y to nxv2i64)
// to
// nxv4i32 = uzp1 nxv4i32 x, nxv4i32 y
if (isHalvingTruncateAndConcatOfLegalIntScalableType(N) &&
Op0.getOpcode() == ISD::BITCAST && Op1.getOpcode() == ISD::BITCAST) {
if (Op0.getOperand(0).getValueType() == Op1.getOperand(0).getValueType()) {
return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0.getOperand(0),
Op1.getOperand(0));
}
}

if (ResVT != MVT::v2i32 && ResVT != MVT::v4i16 && ResVT != MVT::v8i8)
return SDValue();

Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ enum NodeType : unsigned {
SQSHLU_I,
SRSHR_I,
URSHR_I,
URSHR_I_PRED,

// Vector narrowing shift by immediate (bottom)
RSHRNB_I,
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ def SDT_AArch64Arith_Imm : SDTypeProfile<1, 3, [
]>;

def AArch64asrd_m1 : SDNode<"AArch64ISD::SRAD_MERGE_OP1", SDT_AArch64Arith_Imm>;
def AArch64urshri_p : SDNode<"AArch64ISD::URSHR_I_PRED", SDT_AArch64Arith_Imm>;

def SDT_AArch64IntExtend : SDTypeProfile<1, 4, [
SDTCisVec<0>, SDTCisVec<1>, SDTCisVec<2>, SDTCisVT<3, OtherVT>, SDTCisVec<4>,
Expand Down Expand Up @@ -3539,7 +3540,7 @@ let Predicates = [HasSVE2orSME] in {
defm SQSHL_ZPmI : sve_int_bin_pred_shift_imm_left_dup<0b0110, "sqshl", "SQSHL_ZPZI", int_aarch64_sve_sqshl>;
defm UQSHL_ZPmI : sve_int_bin_pred_shift_imm_left_dup<0b0111, "uqshl", "UQSHL_ZPZI", int_aarch64_sve_uqshl>;
defm SRSHR_ZPmI : sve_int_bin_pred_shift_imm_right< 0b1100, "srshr", "SRSHR_ZPZI", int_aarch64_sve_srshr>;
defm URSHR_ZPmI : sve_int_bin_pred_shift_imm_right< 0b1101, "urshr", "URSHR_ZPZI", int_aarch64_sve_urshr>;
defm URSHR_ZPmI : sve_int_bin_pred_shift_imm_right< 0b1101, "urshr", "URSHR_ZPZI", AArch64urshri_p>;
defm SQSHLU_ZPmI : sve_int_bin_pred_shift_imm_left< 0b1111, "sqshlu", "SQSHLU_ZPZI", int_aarch64_sve_sqshlu>;

// SVE2 integer add/subtract long
Expand Down Expand Up @@ -3584,7 +3585,7 @@ let Predicates = [HasSVE2orSME] in {
defm SSRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b00, "ssra", AArch64ssra>;
defm USRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b01, "usra", AArch64usra>;
defm SRSRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b10, "srsra", int_aarch64_sve_srsra, int_aarch64_sve_srshr>;
defm URSRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b11, "ursra", int_aarch64_sve_ursra, int_aarch64_sve_urshr>;
defm URSRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b11, "ursra", int_aarch64_sve_ursra, AArch64urshri_p>;

// SVE2 complex integer add
defm CADD_ZZI : sve2_int_cadd<0b0, "cadd", int_aarch64_sve_cadd_x>;
Expand Down
Loading