@@ -711,6 +711,13 @@ SystemZTargetLowering::SystemZTargetLowering(const TargetMachine &TM,
711
711
setOperationAction (ISD::BITCAST, MVT::f32, Custom);
712
712
}
713
713
714
+ // Expand FP16 <=> FP32 conversions to libcalls and handle FP16 loads and
715
+ // stores in GPRs.
716
+ setOperationAction (ISD::FP16_TO_FP, MVT::f32, Expand);
717
+ setOperationAction (ISD::FP_TO_FP16, MVT::f32, Expand);
718
+ setLoadExtAction (ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
719
+ setTruncStoreAction (MVT::f32, MVT::f16, Expand);
720
+
714
721
// VASTART and VACOPY need to deal with the SystemZ-specific varargs
715
722
// structure, but VAEND is a no-op.
716
723
setOperationAction (ISD::VASTART, MVT::Other, Custom);
@@ -784,6 +791,20 @@ bool SystemZTargetLowering::useSoftFloat() const {
784
791
return Subtarget.hasSoftFloat ();
785
792
}
786
793
794
+ MVT SystemZTargetLowering::getRegisterTypeForCallingConv (
795
+ LLVMContext &Context, CallingConv::ID CC,
796
+ EVT VT) const {
797
+ // 128-bit single-element vector types are passed like other vectors,
798
+ // not like their element type.
799
+ if (VT.isVector () && VT.getSizeInBits () == 128 &&
800
+ VT.getVectorNumElements () == 1 )
801
+ return MVT::v16i8;
802
+ // Keep f16 so that they can be recognized and handled.
803
+ if (VT == MVT::f16)
804
+ return MVT::f16;
805
+ return TargetLowering::getRegisterTypeForCallingConv (Context, CC, VT);
806
+ }
807
+
787
808
EVT SystemZTargetLowering::getSetCCResultType (const DataLayout &DL,
788
809
LLVMContext &, EVT VT) const {
789
810
if (!VT.isVector ())
@@ -1602,6 +1623,15 @@ bool SystemZTargetLowering::splitValueIntoRegisterParts(
1602
1623
return true ;
1603
1624
}
1604
1625
1626
+ // Convert f16 to f32 (Out-arg).
1627
+ if (PartVT == MVT::f16) {
1628
+ assert (NumParts == 1 && " " );
1629
+ SDValue I16Val = DAG.getBitcast (MVT::i16, Val);
1630
+ SDValue I32Val = DAG.getAnyExtOrTrunc (I16Val, DL, MVT::i32);
1631
+ Parts[0 ] = DAG.getBitcast (MVT::f32, I32Val);
1632
+ return true ;
1633
+ }
1634
+
1605
1635
return false ;
1606
1636
}
1607
1637
@@ -1617,6 +1647,18 @@ SDValue SystemZTargetLowering::joinRegisterPartsIntoValue(
1617
1647
return SDValue ();
1618
1648
}
1619
1649
1650
+ // F32Val holds a f16 value in f32, return it as an f16 (In-arg). The
1651
+ // CopyFromReg was made into an f32 as required as FP32 registers are used
1652
+ // for arguments, now convert it to f16.
1653
+ static SDValue convertF32ToF16 (SDValue F32Val, SelectionDAG &DAG,
1654
+ const SDLoc &DL) {
1655
+ assert (F32Val->getOpcode () == ISD::CopyFromReg &&
1656
+ " Only expecting to handle f16 with CopyFromReg here." );
1657
+ SDValue I32Val = DAG.getBitcast (MVT::i32, F32Val);
1658
+ SDValue I16Val = DAG.getAnyExtOrTrunc (I32Val, DL, MVT::i16);
1659
+ return DAG.getBitcast (MVT::f16, I16Val);
1660
+ }
1661
+
1620
1662
SDValue SystemZTargetLowering::LowerFormalArguments (
1621
1663
SDValue Chain, CallingConv::ID CallConv, bool IsVarArg,
1622
1664
const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &DL,
@@ -1656,6 +1698,7 @@ SDValue SystemZTargetLowering::LowerFormalArguments(
1656
1698
NumFixedGPRs += 1 ;
1657
1699
RC = &SystemZ::GR64BitRegClass;
1658
1700
break ;
1701
+ case MVT::f16:
1659
1702
case MVT::f32:
1660
1703
NumFixedFPRs += 1 ;
1661
1704
RC = &SystemZ::FP32BitRegClass;
@@ -1680,7 +1723,11 @@ SDValue SystemZTargetLowering::LowerFormalArguments(
1680
1723
1681
1724
Register VReg = MRI.createVirtualRegister (RC);
1682
1725
MRI.addLiveIn (VA.getLocReg (), VReg);
1683
- ArgValue = DAG.getCopyFromReg (Chain, DL, VReg, LocVT);
1726
+ // Special handling is needed for f16.
1727
+ MVT ArgVT = VA.getLocVT () == MVT::f16 ? MVT::f32 : VA.getLocVT ();
1728
+ ArgValue = DAG.getCopyFromReg (Chain, DL, VReg, ArgVT);
1729
+ if (VA.getLocVT () == MVT::f16)
1730
+ ArgValue = convertF32ToF16 (ArgValue, DAG, DL);
1684
1731
} else {
1685
1732
assert (VA.isMemLoc () && " Argument not register or memory" );
1686
1733
@@ -1700,9 +1747,12 @@ SDValue SystemZTargetLowering::LowerFormalArguments(
1700
1747
// from this parameter. Unpromoted ints and floats are
1701
1748
// passed as right-justified 8-byte values.
1702
1749
SDValue FIN = DAG.getFrameIndex (FI, PtrVT);
1703
- if (VA.getLocVT () == MVT::i32 || VA.getLocVT () == MVT::f32)
1750
+ if (VA.getLocVT () == MVT::i32 || VA.getLocVT () == MVT::f32 ||
1751
+ VA.getLocVT () == MVT::f16) {
1752
+ unsigned SlotOffs = VA.getLocVT () == MVT::f16 ? 6 : 4 ;
1704
1753
FIN = DAG.getNode (ISD::ADD, DL, PtrVT, FIN,
1705
- DAG.getIntPtrConstant (4 , DL));
1754
+ DAG.getIntPtrConstant (SlotOffs, DL));
1755
+ }
1706
1756
ArgValue = DAG.getLoad (LocVT, DL, Chain, FIN,
1707
1757
MachinePointerInfo::getFixedStack (MF, FI));
1708
1758
}
@@ -2121,10 +2171,14 @@ SystemZTargetLowering::LowerCall(CallLoweringInfo &CLI,
2121
2171
// Copy all of the result registers out of their specified physreg.
2122
2172
for (CCValAssign &VA : RetLocs) {
2123
2173
// Copy the value out, gluing the copy to the end of the call sequence.
2174
+ // Special handling is needed for f16.
2175
+ MVT ArgVT = VA.getLocVT () == MVT::f16 ? MVT::f32 : VA.getLocVT ();
2124
2176
SDValue RetValue = DAG.getCopyFromReg (Chain, DL, VA.getLocReg (),
2125
- VA. getLocVT () , Glue);
2177
+ ArgVT , Glue);
2126
2178
Chain = RetValue.getValue (1 );
2127
2179
Glue = RetValue.getValue (2 );
2180
+ if (VA.getLocVT () == MVT::f16)
2181
+ RetValue = convertF32ToF16 (RetValue, DAG, DL);
2128
2182
2129
2183
// Convert the value of the return register into the value that's
2130
2184
// being returned.
0 commit comments