Skip to content

Commit bbc6d16

Browse files
authored
Rollup merge of #135504 - veluca93:target-feature-cast-to-fn-ptr, r=oli-obk
Allow coercing safe-to-call target_feature functions to safe fn pointers r? oli-obk `@oli-obk:` this is based on your PR #134353 :-) See #134090 (comment) for the motivation behind this change.
2 parents 4aae8d1 + 8fee6a7 commit bbc6d16

File tree

7 files changed

+114
-23
lines changed

7 files changed

+114
-23
lines changed

Diff for: compiler/rustc_borrowck/src/type_check/mod.rs

+14-1
Original file line numberDiff line numberDiff line change
@@ -1654,7 +1654,20 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
16541654
match *cast_kind {
16551655
CastKind::PointerCoercion(PointerCoercion::ReifyFnPointer, coercion_source) => {
16561656
let is_implicit_coercion = coercion_source == CoercionSource::Implicit;
1657-
let src_sig = op.ty(body, tcx).fn_sig(tcx);
1657+
let src_ty = op.ty(body, tcx);
1658+
let mut src_sig = src_ty.fn_sig(tcx);
1659+
if let ty::FnDef(def_id, _) = src_ty.kind()
1660+
&& let ty::FnPtr(_, target_hdr) = *ty.kind()
1661+
&& tcx.codegen_fn_attrs(def_id).safe_target_features
1662+
&& target_hdr.safety.is_safe()
1663+
&& let Some(safe_sig) = tcx.adjust_target_feature_sig(
1664+
*def_id,
1665+
src_sig,
1666+
body.source.def_id(),
1667+
)
1668+
{
1669+
src_sig = safe_sig;
1670+
}
16581671

16591672
// HACK: This shouldn't be necessary... We can remove this when we actually
16601673
// get binders with where clauses, then elaborate implied bounds into that

Diff for: compiler/rustc_hir_typeck/src/coercion.rs

+12-11
Original file line numberDiff line numberDiff line change
@@ -920,7 +920,7 @@ impl<'f, 'tcx> Coerce<'f, 'tcx> {
920920

921921
match b.kind() {
922922
ty::FnPtr(_, b_hdr) => {
923-
let a_sig = a.fn_sig(self.tcx);
923+
let mut a_sig = a.fn_sig(self.tcx);
924924
if let ty::FnDef(def_id, _) = *a.kind() {
925925
// Intrinsics are not coercible to function pointers
926926
if self.tcx.intrinsic(def_id).is_some() {
@@ -932,19 +932,20 @@ impl<'f, 'tcx> Coerce<'f, 'tcx> {
932932
return Err(TypeError::ForceInlineCast);
933933
}
934934

935-
let fn_attrs = self.tcx.codegen_fn_attrs(def_id);
936-
if matches!(fn_attrs.inline, InlineAttr::Force { .. }) {
937-
return Err(TypeError::ForceInlineCast);
938-
}
939-
940-
// FIXME(target_feature): Safe `#[target_feature]` functions could be cast to safe fn pointers (RFC 2396),
941-
// as you can already write that "cast" in user code by wrapping a target_feature fn call in a closure,
942-
// which is safe. This is sound because you already need to be executing code that is satisfying the target
943-
// feature constraints..
944935
if b_hdr.safety.is_safe()
945936
&& self.tcx.codegen_fn_attrs(def_id).safe_target_features
946937
{
947-
return Err(TypeError::TargetFeatureCast(def_id));
938+
// Allow the coercion if the current function has all the features that would be
939+
// needed to call the coercee safely.
940+
if let Some(safe_sig) = self.tcx.adjust_target_feature_sig(
941+
def_id,
942+
a_sig,
943+
self.fcx.body_id.into(),
944+
) {
945+
a_sig = safe_sig;
946+
} else {
947+
return Err(TypeError::TargetFeatureCast(def_id));
948+
}
948949
}
949950
}
950951

Diff for: compiler/rustc_middle/src/ty/context.rs

+32-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ use crate::dep_graph::{DepGraph, DepKindStruct};
6060
use crate::infer::canonical::{CanonicalParamEnvCache, CanonicalVarInfo, CanonicalVarInfos};
6161
use crate::lint::lint_level;
6262
use crate::metadata::ModChild;
63-
use crate::middle::codegen_fn_attrs::CodegenFnAttrs;
63+
use crate::middle::codegen_fn_attrs::{CodegenFnAttrs, TargetFeature};
6464
use crate::middle::{resolve_bound_vars, stability};
6565
use crate::mir::interpret::{self, Allocation, ConstAllocation};
6666
use crate::mir::{Body, Local, Place, PlaceElem, ProjectionKind, Promoted};
@@ -1776,6 +1776,37 @@ impl<'tcx> TyCtxt<'tcx> {
17761776
pub fn dcx(self) -> DiagCtxtHandle<'tcx> {
17771777
self.sess.dcx()
17781778
}
1779+
1780+
pub fn is_target_feature_call_safe(
1781+
self,
1782+
callee_features: &[TargetFeature],
1783+
body_features: &[TargetFeature],
1784+
) -> bool {
1785+
// If the called function has target features the calling function hasn't,
1786+
// the call requires `unsafe`. Don't check this on wasm
1787+
// targets, though. For more information on wasm see the
1788+
// is_like_wasm check in hir_analysis/src/collect.rs
1789+
self.sess.target.options.is_like_wasm
1790+
|| callee_features
1791+
.iter()
1792+
.all(|feature| body_features.iter().any(|f| f.name == feature.name))
1793+
}
1794+
1795+
/// Returns the safe version of the signature of the given function, if calling it
1796+
/// would be safe in the context of the given caller.
1797+
pub fn adjust_target_feature_sig(
1798+
self,
1799+
fun_def: DefId,
1800+
fun_sig: ty::Binder<'tcx, ty::FnSig<'tcx>>,
1801+
caller: DefId,
1802+
) -> Option<ty::Binder<'tcx, ty::FnSig<'tcx>>> {
1803+
let fun_features = &self.codegen_fn_attrs(fun_def).target_features;
1804+
let callee_features = &self.codegen_fn_attrs(caller).target_features;
1805+
if self.is_target_feature_call_safe(&fun_features, &callee_features) {
1806+
return Some(fun_sig.map_bound(|sig| ty::FnSig { safety: hir::Safety::Safe, ..sig }));
1807+
}
1808+
None
1809+
}
17791810
}
17801811

17811812
impl<'tcx> TyCtxtAt<'tcx> {

Diff for: compiler/rustc_mir_build/src/check_unsafety.rs

+3-8
Original file line numberDiff line numberDiff line change
@@ -495,14 +495,9 @@ impl<'a, 'tcx> Visitor<'a, 'tcx> for UnsafetyVisitor<'a, 'tcx> {
495495
};
496496
self.requires_unsafe(expr.span, CallToUnsafeFunction(func_id));
497497
} else if let &ty::FnDef(func_did, _) = fn_ty.kind() {
498-
// If the called function has target features the calling function hasn't,
499-
// the call requires `unsafe`. Don't check this on wasm
500-
// targets, though. For more information on wasm see the
501-
// is_like_wasm check in hir_analysis/src/collect.rs
502-
if !self.tcx.sess.target.options.is_like_wasm
503-
&& !callee_features.iter().all(|feature| {
504-
self.body_target_features.iter().any(|f| f.name == feature.name)
505-
})
498+
if !self
499+
.tcx
500+
.is_target_feature_call_safe(callee_features, self.body_target_features)
506501
{
507502
let missing: Vec<_> = callee_features
508503
.iter()

Diff for: tests/ui/rfcs/rfc-2396-target_feature-11/fn-ptr.rs

+14
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,23 @@
22

33
#![feature(target_feature_11)]
44

5+
#[target_feature(enable = "avx")]
6+
fn foo_avx() {}
7+
58
#[target_feature(enable = "sse2")]
69
fn foo() {}
710

11+
#[target_feature(enable = "sse2")]
12+
fn bar() {
13+
let foo: fn() = foo; // this is OK, as we have the necessary target features.
14+
let foo: fn() = foo_avx; //~ ERROR mismatched types
15+
}
16+
817
fn main() {
18+
if std::is_x86_feature_detected!("sse2") {
19+
unsafe {
20+
bar();
21+
}
22+
}
923
let foo: fn() = foo; //~ ERROR mismatched types
1024
}

Diff for: tests/ui/rfcs/rfc-2396-target_feature-11/fn-ptr.stderr

+17-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,20 @@
11
error[E0308]: mismatched types
2-
--> $DIR/fn-ptr.rs:9:21
2+
--> $DIR/fn-ptr.rs:14:21
3+
|
4+
LL | #[target_feature(enable = "avx")]
5+
| --------------------------------- `#[target_feature]` added here
6+
...
7+
LL | let foo: fn() = foo_avx;
8+
| ---- ^^^^^^^ cannot coerce functions with `#[target_feature]` to safe function pointers
9+
| |
10+
| expected due to this
11+
|
12+
= note: expected fn pointer `fn()`
13+
found fn item `#[target_features] fn() {foo_avx}`
14+
= note: functions with `#[target_feature]` can only be coerced to `unsafe` function pointers
15+
16+
error[E0308]: mismatched types
17+
--> $DIR/fn-ptr.rs:23:21
318
|
419
LL | #[target_feature(enable = "sse2")]
520
| ---------------------------------- `#[target_feature]` added here
@@ -13,6 +28,6 @@ LL | let foo: fn() = foo;
1328
found fn item `#[target_features] fn() {foo}`
1429
= note: functions with `#[target_feature]` can only be coerced to `unsafe` function pointers
1530

16-
error: aborting due to 1 previous error
31+
error: aborting due to 2 previous errors
1732

1833
For more information about this error, try `rustc --explain E0308`.
+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//@ only-x86_64
2+
//@ run-pass
3+
4+
#![feature(target_feature_11)]
5+
6+
#[target_feature(enable = "sse2")]
7+
fn foo() -> bool {
8+
true
9+
}
10+
11+
#[target_feature(enable = "sse2")]
12+
fn bar() -> fn() -> bool {
13+
foo
14+
}
15+
16+
fn main() {
17+
if !std::is_x86_feature_detected!("sse2") {
18+
return;
19+
}
20+
let f = unsafe { bar() };
21+
assert!(f());
22+
}

0 commit comments

Comments
 (0)