Skip to content

Add support for setting and querying fast math flags. #383

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions deps/LLVMExtra/include/LLVMExtra.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,5 +191,28 @@ LLVMBool LLVMPostDominatorTreeInstructionDominates(LLVMPostDominatorTreeRef Tree
LLVMValueRef InstA, LLVMValueRef InstB);


// fastmath (backport of llvm/llvm-project#75123)
//
// C declarations for querying and setting fast-math flags on a value. The
// whole section is compiled only when LLVM_VERSION_MAJOR is below 18.
// NOTE(review): presumably LLVM 18+ already ships these declarations in its
// own C API, which is why they are guarded here - confirm.
#if LLVM_VERSION_MAJOR < 18
// Individual fast-math flags. Every flag is a distinct bit so that several of
// them can be OR-ed together into one LLVMFastMathFlags mask.
enum {
// allow reassociation of operations
LLVMFastMathAllowReassoc = (1 << 0),
// assume arguments and results are not NaN
LLVMFastMathNoNaNs = (1 << 1),
// assume arguments and results are not Inf
LLVMFastMathNoInfs = (1 << 2),
// treat the sign of zero arguments and results as insignificant
LLVMFastMathNoSignedZeros = (1 << 3),
// allow use of a reciprocal rather than performing a division
LLVMFastMathAllowReciprocal = (1 << 4),
// allow contraction of operations
LLVMFastMathAllowContract = (1 << 5),
// allow substitution of approximate calculations for functions
LLVMFastMathApproxFunc = (1 << 6),
// empty mask: no flag set
LLVMFastMathNone = 0,
// every flag listed above, OR-ed together
LLVMFastMathAll = LLVMFastMathAllowReassoc | LLVMFastMathNoNaNs | LLVMFastMathNoInfs |
LLVMFastMathNoSignedZeros | LLVMFastMathAllowReciprocal |
LLVMFastMathAllowContract | LLVMFastMathApproxFunc,
};
// Bitmask built from the LLVMFastMath* enumerators above.
typedef unsigned LLVMFastMathFlags;

// Returns the fast-math flags of FPMathInst as a bitmask. The implementation
// casts the value to an Instruction, so callers should first check it with
// LLVMCanValueUseFastMathFlags.
LLVMFastMathFlags LLVMGetFastMathFlags(LLVMValueRef FPMathInst);
// Hands FMF to Instruction::setFastMathFlags on FPMathInst (also cast to an
// Instruction by the implementation).
// NOTE(review): whether flags that are already set get cleared depends on that
// LLVM method - confirm before relying on this to remove flags.
void LLVMSetFastMathFlags(LLVMValueRef FPMathInst, LLVMFastMathFlags FMF);
// Returns non-zero when Inst is an FPMathOperator, i.e. when the two functions
// above may be used on it.
LLVMBool LLVMCanValueUseFastMathFlags(LLVMValueRef Inst);
#endif


LLVM_C_EXTERN_C_END
#endif
56 changes: 56 additions & 0 deletions deps/LLVMExtra/lib/Core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -578,3 +578,59 @@ LLVMBool LLVMPostDominatorTreeInstructionDominates(LLVMPostDominatorTreeRef Tree
LLVMValueRef InstA, LLVMValueRef InstB) {
return unwrap(Tree)->dominates(unwrap<Instruction>(InstA), unwrap<Instruction>(InstB));
}


// fastmath (backport of llvm/llvm-project#75123)

#if LLVM_VERSION_MAJOR < 18

// Convert a C API fast-math bitmask into the C++ FastMathFlags type.
// Every setter is called with `true` exactly when the matching
// LLVMFastMath* bit is present in `FMF`, and with `false` otherwise.
static FastMathFlags mapFromLLVMFastMathFlags(LLVMFastMathFlags FMF) {
  // Tests one bit of the incoming mask.
  auto HasBit = [FMF](unsigned Bit) { return (FMF & Bit) != 0; };

  FastMathFlags Result;
  Result.setAllowReassoc(HasBit(LLVMFastMathAllowReassoc));
  Result.setNoNaNs(HasBit(LLVMFastMathNoNaNs));
  Result.setNoInfs(HasBit(LLVMFastMathNoInfs));
  Result.setNoSignedZeros(HasBit(LLVMFastMathNoSignedZeros));
  Result.setAllowReciprocal(HasBit(LLVMFastMathAllowReciprocal));
  Result.setAllowContract(HasBit(LLVMFastMathAllowContract));
  Result.setApproxFunc(HasBit(LLVMFastMathApproxFunc));
  return Result;
}

// Convert C++ FastMathFlags into the C API bitmask: each flag that is enabled
// in `FMF` contributes its LLVMFastMath* bit, starting from the empty mask.
static LLVMFastMathFlags mapToLLVMFastMathFlags(FastMathFlags FMF) {
  LLVMFastMathFlags Bits = LLVMFastMathNone;
  Bits |= FMF.allowReassoc() ? LLVMFastMathAllowReassoc : LLVMFastMathNone;
  Bits |= FMF.noNaNs() ? LLVMFastMathNoNaNs : LLVMFastMathNone;
  Bits |= FMF.noInfs() ? LLVMFastMathNoInfs : LLVMFastMathNone;
  Bits |= FMF.noSignedZeros() ? LLVMFastMathNoSignedZeros : LLVMFastMathNone;
  Bits |= FMF.allowReciprocal() ? LLVMFastMathAllowReciprocal : LLVMFastMathNone;
  Bits |= FMF.allowContract() ? LLVMFastMathAllowContract : LLVMFastMathNone;
  Bits |= FMF.approxFunc() ? LLVMFastMathApproxFunc : LLVMFastMathNone;
  return Bits;
}

// Read the fast-math flags of `FPMathInst` and return them as a C API bitmask.
// The value is cast to Instruction, so it has to be one.
LLVMFastMathFlags LLVMGetFastMathFlags(LLVMValueRef FPMathInst) {
  Instruction *Inst = unwrap<Instruction>(FPMathInst);
  return mapToLLVMFastMathFlags(Inst->getFastMathFlags());
}

// Translate the C API bitmask `FMF` and pass it to
// Instruction::setFastMathFlags on `FPMathInst`. The value is cast to
// Instruction, so it has to be one.
void LLVMSetFastMathFlags(LLVMValueRef FPMathInst, LLVMFastMathFlags FMF) {
  Instruction *Inst = unwrap<Instruction>(FPMathInst);
  Inst->setFastMathFlags(mapFromLLVMFastMathFlags(FMF));
}

// Report whether `V` is an FPMathOperator, i.e. whether fast-math flags can be
// queried from or set on it.
LLVMBool LLVMCanValueUseFastMathFlags(LLVMValueRef V) {
  return isa<FPMathOperator>(unwrap<Value>(V));
}

#endif
26 changes: 26 additions & 0 deletions lib/13/libLLVM_extra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -337,3 +337,29 @@ function LLVMPostDominatorTreeInstructionDominates(Tree, InstA, InstB)
ccall((:LLVMPostDominatorTreeInstructionDominates, libLLVMExtra), LLVMBool, (LLVMPostDominatorTreeRef, LLVMValueRef, LLVMValueRef), Tree, InstA, InstB)
end

# Fast-math flag bits. The values match the anonymous enum declared in
# deps/LLVMExtra/include/LLVMExtra.h; each flag is a distinct bit so that
# several of them can be OR-ed into one mask.
@cenum __JL_Ctag_52::UInt32 begin
# allow reassociation of operations
LLVMFastMathAllowReassoc = 1
# assume arguments and results are not NaN
LLVMFastMathNoNaNs = 2
# assume arguments and results are not Inf
LLVMFastMathNoInfs = 4
# treat the sign of zero arguments and results as insignificant
LLVMFastMathNoSignedZeros = 8
# allow use of a reciprocal rather than performing a division
LLVMFastMathAllowReciprocal = 16
# allow contraction of operations
LLVMFastMathAllowContract = 32
# allow substitution of approximate calculations for functions
LLVMFastMathApproxFunc = 64
# empty mask: no flag set
LLVMFastMathNone = 0
# all seven flags above OR-ed together
LLVMFastMathAll = 127
end

# Bitmask type used for the flag argument and return value of the wrappers below.
const LLVMFastMathFlags = Cuint

# Binding for the C function `LLVMGetFastMathFlags` in libLLVMExtra.
# Returns the flags of `FPMathInst` as an `LLVMFastMathFlags` bitmask.
function LLVMGetFastMathFlags(FPMathInst)
    return ccall(
        (:LLVMGetFastMathFlags, libLLVMExtra),
        LLVMFastMathFlags,
        (LLVMValueRef,),
        FPMathInst
    )
end

# Binding for the C function `LLVMSetFastMathFlags` in libLLVMExtra.
# Passes the bitmask `FMF` for `FPMathInst` to the C side; the C function returns nothing.
function LLVMSetFastMathFlags(FPMathInst, FMF)
    return ccall(
        (:LLVMSetFastMathFlags, libLLVMExtra),
        Cvoid,
        (LLVMValueRef, LLVMFastMathFlags),
        FPMathInst, FMF
    )
end

# Binding for the C function `LLVMCanValueUseFastMathFlags` in libLLVMExtra.
# Returns an `LLVMBool` telling whether `Inst` can carry fast-math flags.
function LLVMCanValueUseFastMathFlags(Inst)
    return ccall(
        (:LLVMCanValueUseFastMathFlags, libLLVMExtra),
        LLVMBool,
        (LLVMValueRef,),
        Inst
    )
end

26 changes: 26 additions & 0 deletions lib/14/libLLVM_extra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -337,3 +337,29 @@ function LLVMPostDominatorTreeInstructionDominates(Tree, InstA, InstB)
ccall((:LLVMPostDominatorTreeInstructionDominates, libLLVMExtra), LLVMBool, (LLVMPostDominatorTreeRef, LLVMValueRef, LLVMValueRef), Tree, InstA, InstB)
end

# Fast-math flag bits. The values match the anonymous enum declared in
# deps/LLVMExtra/include/LLVMExtra.h; each flag is a distinct bit so that
# several of them can be OR-ed into one mask.
@cenum __JL_Ctag_52::UInt32 begin
# allow reassociation of operations
LLVMFastMathAllowReassoc = 1
# assume arguments and results are not NaN
LLVMFastMathNoNaNs = 2
# assume arguments and results are not Inf
LLVMFastMathNoInfs = 4
# treat the sign of zero arguments and results as insignificant
LLVMFastMathNoSignedZeros = 8
# allow use of a reciprocal rather than performing a division
LLVMFastMathAllowReciprocal = 16
# allow contraction of operations
LLVMFastMathAllowContract = 32
# allow substitution of approximate calculations for functions
LLVMFastMathApproxFunc = 64
# empty mask: no flag set
LLVMFastMathNone = 0
# all seven flags above OR-ed together
LLVMFastMathAll = 127
end

# Bitmask type used for the flag argument and return value of the wrappers below.
const LLVMFastMathFlags = Cuint

# Binding for the C function `LLVMGetFastMathFlags` in libLLVMExtra.
# Returns the flags of `FPMathInst` as an `LLVMFastMathFlags` bitmask.
function LLVMGetFastMathFlags(FPMathInst)
    return ccall(
        (:LLVMGetFastMathFlags, libLLVMExtra),
        LLVMFastMathFlags,
        (LLVMValueRef,),
        FPMathInst
    )
end

# Binding for the C function `LLVMSetFastMathFlags` in libLLVMExtra.
# Passes the bitmask `FMF` for `FPMathInst` to the C side; the C function returns nothing.
function LLVMSetFastMathFlags(FPMathInst, FMF)
    return ccall(
        (:LLVMSetFastMathFlags, libLLVMExtra),
        Cvoid,
        (LLVMValueRef, LLVMFastMathFlags),
        FPMathInst, FMF
    )
end

# Binding for the C function `LLVMCanValueUseFastMathFlags` in libLLVMExtra.
# Returns an `LLVMBool` telling whether `Inst` can carry fast-math flags.
function LLVMCanValueUseFastMathFlags(Inst)
    return ccall(
        (:LLVMCanValueUseFastMathFlags, libLLVMExtra),
        LLVMBool,
        (LLVMValueRef,),
        Inst
    )
end

26 changes: 26 additions & 0 deletions lib/15/libLLVM_extra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,32 @@ function LLVMPostDominatorTreeInstructionDominates(Tree, InstA, InstB)
ccall((:LLVMPostDominatorTreeInstructionDominates, libLLVMExtra), LLVMBool, (LLVMPostDominatorTreeRef, LLVMValueRef, LLVMValueRef), Tree, InstA, InstB)
end

# Fast-math flag bits. The values match the anonymous enum declared in
# deps/LLVMExtra/include/LLVMExtra.h; each flag is a distinct bit so that
# several of them can be OR-ed into one mask.
@cenum __JL_Ctag_53::UInt32 begin
# allow reassociation of operations
LLVMFastMathAllowReassoc = 1
# assume arguments and results are not NaN
LLVMFastMathNoNaNs = 2
# assume arguments and results are not Inf
LLVMFastMathNoInfs = 4
# treat the sign of zero arguments and results as insignificant
LLVMFastMathNoSignedZeros = 8
# allow use of a reciprocal rather than performing a division
LLVMFastMathAllowReciprocal = 16
# allow contraction of operations
LLVMFastMathAllowContract = 32
# allow substitution of approximate calculations for functions
LLVMFastMathApproxFunc = 64
# empty mask: no flag set
LLVMFastMathNone = 0
# all seven flags above OR-ed together
LLVMFastMathAll = 127
end

# Bitmask type used for the flag argument and return value of the wrappers below.
const LLVMFastMathFlags = Cuint

# Binding for the C function `LLVMGetFastMathFlags` in libLLVMExtra.
# Returns the flags of `FPMathInst` as an `LLVMFastMathFlags` bitmask.
function LLVMGetFastMathFlags(FPMathInst)
    return ccall(
        (:LLVMGetFastMathFlags, libLLVMExtra),
        LLVMFastMathFlags,
        (LLVMValueRef,),
        FPMathInst
    )
end

# Binding for the C function `LLVMSetFastMathFlags` in libLLVMExtra.
# Passes the bitmask `FMF` for `FPMathInst` to the C side; the C function returns nothing.
function LLVMSetFastMathFlags(FPMathInst, FMF)
    return ccall(
        (:LLVMSetFastMathFlags, libLLVMExtra),
        Cvoid,
        (LLVMValueRef, LLVMFastMathFlags),
        FPMathInst, FMF
    )
end

# Binding for the C function `LLVMCanValueUseFastMathFlags` in libLLVMExtra.
# Returns an `LLVMBool` telling whether `Inst` can carry fast-math flags.
function LLVMCanValueUseFastMathFlags(Inst)
    return ccall(
        (:LLVMCanValueUseFastMathFlags, libLLVMExtra),
        LLVMBool,
        (LLVMValueRef,),
        Inst
    )
end

# Field-less placeholder type; in the lines shown it only serves as the pointee
# type of `LLVMPreservedAnalysesRef` below.
mutable struct LLVMOpaquePreservedAnalyses end

# Pointer alias for preserved-analyses handles.
# NOTE(review): presumably these pointers come from the C library - confirm.
const LLVMPreservedAnalysesRef = Ptr{LLVMOpaquePreservedAnalyses}
Expand Down
26 changes: 26 additions & 0 deletions lib/16/libLLVM_extra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,32 @@ function LLVMPostDominatorTreeInstructionDominates(Tree, InstA, InstB)
ccall((:LLVMPostDominatorTreeInstructionDominates, libLLVMExtra), LLVMBool, (LLVMPostDominatorTreeRef, LLVMValueRef, LLVMValueRef), Tree, InstA, InstB)
end

# Fast-math flag bits. The values match the anonymous enum declared in
# deps/LLVMExtra/include/LLVMExtra.h; each flag is a distinct bit so that
# several of them can be OR-ed into one mask.
@cenum __JL_Ctag_53::UInt32 begin
# allow reassociation of operations
LLVMFastMathAllowReassoc = 1
# assume arguments and results are not NaN
LLVMFastMathNoNaNs = 2
# assume arguments and results are not Inf
LLVMFastMathNoInfs = 4
# treat the sign of zero arguments and results as insignificant
LLVMFastMathNoSignedZeros = 8
# allow use of a reciprocal rather than performing a division
LLVMFastMathAllowReciprocal = 16
# allow contraction of operations
LLVMFastMathAllowContract = 32
# allow substitution of approximate calculations for functions
LLVMFastMathApproxFunc = 64
# empty mask: no flag set
LLVMFastMathNone = 0
# all seven flags above OR-ed together
LLVMFastMathAll = 127
end

# Bitmask type used for the flag argument and return value of the wrappers below.
const LLVMFastMathFlags = Cuint

# Binding for the C function `LLVMGetFastMathFlags` in libLLVMExtra.
# Returns the flags of `FPMathInst` as an `LLVMFastMathFlags` bitmask.
function LLVMGetFastMathFlags(FPMathInst)
    return ccall(
        (:LLVMGetFastMathFlags, libLLVMExtra),
        LLVMFastMathFlags,
        (LLVMValueRef,),
        FPMathInst
    )
end

# Binding for the C function `LLVMSetFastMathFlags` in libLLVMExtra.
# Passes the bitmask `FMF` for `FPMathInst` to the C side; the C function returns nothing.
function LLVMSetFastMathFlags(FPMathInst, FMF)
    return ccall(
        (:LLVMSetFastMathFlags, libLLVMExtra),
        Cvoid,
        (LLVMValueRef, LLVMFastMathFlags),
        FPMathInst, FMF
    )
end

# Binding for the C function `LLVMCanValueUseFastMathFlags` in libLLVMExtra.
# Returns an `LLVMBool` telling whether `Inst` can carry fast-math flags.
function LLVMCanValueUseFastMathFlags(Inst)
    return ccall(
        (:LLVMCanValueUseFastMathFlags, libLLVMExtra),
        LLVMBool,
        (LLVMValueRef,),
        Inst
    )
end

# Field-less placeholder type; in the lines shown it only serves as the pointee
# type of `LLVMPreservedAnalysesRef` below.
mutable struct LLVMOpaquePreservedAnalyses end

# Pointer alias for preserved-analyses handles.
# NOTE(review): presumably these pointers come from the C library - confirm.
const LLVMPreservedAnalysesRef = Ptr{LLVMOpaquePreservedAnalyses}
Expand Down
60 changes: 60 additions & 0 deletions src/core/instructions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -335,3 +335,63 @@ function Base.append!(iter::PhiIncomingSet, args::Vector{Tuple{V, BasicBlock}} w
end

Base.push!(iter::PhiIncomingSet, args::Tuple{<:Value, BasicBlock}) = append!(iter, [args])


## floating point operations

export fast_math, fast_math!

"""
fast_math(inst::Instruction)

Get the fast math flags on an instruction.
"""
function fast_math(inst::Instruction)
    # Guard: ask the C API whether this instruction can carry fast-math flags.
    Bool(API.LLVMCanValueUseFastMathFlags(inst)) ||
        throw(ArgumentError("Instruction cannot use fast math flags"))

    mask = API.LLVMGetFastMathFlags(inst)
    # `true` when `bit` is present in the mask reported for `inst`.
    hasflag(bit) = mask & bit != 0

    # The field order is part of the returned NamedTuple's type, so it is kept as is.
    return (;
        nnan = hasflag(API.LLVMFastMathNoNaNs),
        ninf = hasflag(API.LLVMFastMathNoInfs),
        nsz = hasflag(API.LLVMFastMathNoSignedZeros),
        arcp = hasflag(API.LLVMFastMathAllowReciprocal),
        contract = hasflag(API.LLVMFastMathAllowContract),
        afn = hasflag(API.LLVMFastMathApproxFunc),
        reassoc = hasflag(API.LLVMFastMathAllowReassoc),
    )
end

"""
fast_math!(inst::Instruction; [flag=...], [all=...])

Set the fast math flags on an instruction. If `all` is `true`, then all flags are set.

The following flags are supported:
- `nnan`: assume arguments and results are not NaN
- `ninf`: assume arguments and results are not Inf
- `nsz`: treat the sign of zero arguments and results as insignificant
- `arcp`: allow use of reciprocal rather than perform division
- `contract`: allow contraction of operations
- `afn`: allow substitution of approximate calculations for functions
- `reassoc`: allow reassociation of operations
"""
function fast_math!(inst::Instruction; nnan=false, ninf=false, nsz=false, arcp=false,
                    contract=false, afn=false, reassoc=false, all=false)
    # Guard: ask the C API whether this instruction can carry fast-math flags.
    Bool(API.LLVMCanValueUseFastMathFlags(inst)) ||
        throw(ArgumentError("Instruction cannot use fast math flags"))

    # `all` takes precedence: the individual keywords are not consulted.
    all && return API.LLVMSetFastMathFlags(inst, API.LLVMFastMathAll)

    # Collect the bit of every requested flag into one mask.
    requested = (
        (nnan, API.LLVMFastMathNoNaNs),
        (ninf, API.LLVMFastMathNoInfs),
        (nsz, API.LLVMFastMathNoSignedZeros),
        (arcp, API.LLVMFastMathAllowReciprocal),
        (contract, API.LLVMFastMathAllowContract),
        (afn, API.LLVMFastMathApproxFunc),
        (reassoc, API.LLVMFastMathAllowReassoc),
    )
    mask = 0
    for (enabled, bit) in requested
        enabled && (mask |= bit)
    end
    return API.LLVMSetFastMathFlags(inst, mask)
end
73 changes: 73 additions & 0 deletions test/instructions_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -504,4 +504,77 @@ end
end
end


@testset "fast math" begin
    @dispose ctx=Context() mod=LLVM.Module("my_module") begin
        # Emit `x + 42 - 42` on the function's float parameter. No fast-math flags
        # are set here: they are applied after a first optimization run so that
        # their effect can be observed.
        param_types = [LLVM.FloatType()]
        ret_type = LLVM.FloatType()
        fun_type = LLVM.FunctionType(ret_type, param_types)
        fun = LLVM.Function(mod, "add_sub", fun_type)
        @dispose builder=IRBuilder() begin
            entry = BasicBlock(fun, "entry")
            position!(builder, entry)
            # add and subtract 42
            a = fadd!(builder, parameters(fun)[1], LLVM.ConstantFP(Float32(42.)), "a")
            b = fsub!(builder, a, LLVM.ConstantFP(Float32(42.)), "b")
            ret!(builder, b)
        end
        verify(mod)

        # Run an O3 pipeline over `mod`, using the new pass manager when it is
        # available and the legacy one otherwise.
        function optimize(mod)
            if LLVM.has_newpm()
                host_triple = triple()
                host_t = Target(triple=host_triple)
                @dispose tm=TargetMachine(host_t, host_triple) pb=PassBuilder(tm) begin
                    NewPMModulePassManager(pb) do mpm
                        parse!(pb, mpm, "default<O3>")
                        run!(mpm, mod, tm)
                    end
                end
            else
                # Dispose the builder together with the pass manager so that
                # neither outlives this call.
                @dispose pmb=PassManagerBuilder() mpm=ModulePassManager() begin
                    optlevel!(pmb, 3)
                    populate!(mpm, pmb)
                    run!(mpm, mod)
                end
            end
        end
        optimize(mod)
        verify(mod)

        # without fast-math flags both operations must survive; the subtraction
        # is expected to show up as a second fadd
        @test length(blocks(fun)) == 1
        bb = blocks(fun)[1]
        instns = collect(instructions(bb))
        @test length(instns) == 3
        @test instns[1] isa LLVM.FAddInst
        @test instns[2] isa LLVM.FAddInst
        @test instns[3] isa LLVM.RetInst

        # make them fast math; the return instruction cannot carry flags
        @test !fast_math(instns[1]).contract
        fast_math!(instns[1]; all=true)
        @test fast_math(instns[1]).contract
        fast_math!(instns[2]; all=true)
        @test_throws ArgumentError fast_math(instns[3])
        @test_throws ArgumentError fast_math!(instns[3]; all=true)

        # optimize again
        optimize(mod)
        verify(mod)

        # observe there's only a single return now
        @test length(blocks(fun)) == 1
        bb = blocks(fun)[1]
        instns = collect(instructions(bb))
        @test length(instns) == 1
        @test instns[1] isa LLVM.RetInst
    end
end

end