From 8fee6a77394ffda819e58a648d4b44e1e566f34b Mon Sep 17 00:00:00 2001 From: Luca Versari Date: Tue, 14 Jan 2025 22:51:09 +0100 Subject: [PATCH] Coerce safe-to-call target_feature functions to fn pointers. --- compiler/rustc_borrowck/src/type_check/mod.rs | 15 ++++++++- compiler/rustc_hir_typeck/src/coercion.rs | 23 ++++++------- compiler/rustc_middle/src/ty/context.rs | 33 ++++++++++++++++++- .../rustc_mir_build/src/check_unsafety.rs | 11 ++----- .../rfcs/rfc-2396-target_feature-11/fn-ptr.rs | 14 ++++++++ .../rfc-2396-target_feature-11/fn-ptr.stderr | 19 +++++++++-- .../return-fn-ptr.rs | 22 +++++++++++++ 7 files changed, 114 insertions(+), 23 deletions(-) create mode 100644 tests/ui/rfcs/rfc-2396-target_feature-11/return-fn-ptr.rs diff --git a/compiler/rustc_borrowck/src/type_check/mod.rs b/compiler/rustc_borrowck/src/type_check/mod.rs index a1979c8b8aba1..eca8a688ff4a2 100644 --- a/compiler/rustc_borrowck/src/type_check/mod.rs +++ b/compiler/rustc_borrowck/src/type_check/mod.rs @@ -1654,7 +1654,20 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> { match *cast_kind { CastKind::PointerCoercion(PointerCoercion::ReifyFnPointer, coercion_source) => { let is_implicit_coercion = coercion_source == CoercionSource::Implicit; - let src_sig = op.ty(body, tcx).fn_sig(tcx); + let src_ty = op.ty(body, tcx); + let mut src_sig = src_ty.fn_sig(tcx); + if let ty::FnDef(def_id, _) = src_ty.kind() + && let ty::FnPtr(_, target_hdr) = *ty.kind() + && tcx.codegen_fn_attrs(def_id).safe_target_features + && target_hdr.safety.is_safe() + && let Some(safe_sig) = tcx.adjust_target_feature_sig( + *def_id, + src_sig, + body.source.def_id(), + ) + { + src_sig = safe_sig; + } // HACK: This shouldn't be necessary... We can remove this when we actually // get binders with where clauses, then elaborate implied bounds into that diff --git a/compiler/rustc_hir_typeck/src/coercion.rs b/compiler/rustc_hir_typeck/src/coercion.rs index ec7c1efa38e2e..6945dbc321697 100644 --- a/compiler/rustc_hir_typeck/src/coercion.rs +++ b/compiler/rustc_hir_typeck/src/coercion.rs @@ -920,7 +920,7 @@ impl<'f, 'tcx> Coerce<'f, 'tcx> { match b.kind() { ty::FnPtr(_, b_hdr) => { - let a_sig = a.fn_sig(self.tcx); + let mut a_sig = a.fn_sig(self.tcx); if let ty::FnDef(def_id, _) = *a.kind() { // Intrinsics are not coercible to function pointers if self.tcx.intrinsic(def_id).is_some() { @@ -932,19 +932,20 @@ impl<'f, 'tcx> Coerce<'f, 'tcx> { return Err(TypeError::ForceInlineCast); } - let fn_attrs = self.tcx.codegen_fn_attrs(def_id); - if matches!(fn_attrs.inline, InlineAttr::Force { .. }) { - return Err(TypeError::ForceInlineCast); - } - - // FIXME(target_feature): Safe `#[target_feature]` functions could be cast to safe fn pointers (RFC 2396), - // as you can already write that "cast" in user code by wrapping a target_feature fn call in a closure, - // which is safe. This is sound because you already need to be executing code that is satisfying the target - // feature constraints.. if b_hdr.safety.is_safe() && self.tcx.codegen_fn_attrs(def_id).safe_target_features { - return Err(TypeError::TargetFeatureCast(def_id)); + // Allow the coercion if the current function has all the features that would be + // needed to call the coercee safely. + if let Some(safe_sig) = self.tcx.adjust_target_feature_sig( + def_id, + a_sig, + self.fcx.body_id.into(), + ) { + a_sig = safe_sig; + } else { + return Err(TypeError::TargetFeatureCast(def_id)); + } } } diff --git a/compiler/rustc_middle/src/ty/context.rs b/compiler/rustc_middle/src/ty/context.rs index fab0047babfd0..7035e641f39e0 100644 --- a/compiler/rustc_middle/src/ty/context.rs +++ b/compiler/rustc_middle/src/ty/context.rs @@ -60,7 +60,7 @@ use crate::dep_graph::{DepGraph, DepKindStruct}; use crate::infer::canonical::{CanonicalParamEnvCache, CanonicalVarInfo, CanonicalVarInfos}; use crate::lint::lint_level; use crate::metadata::ModChild; -use crate::middle::codegen_fn_attrs::CodegenFnAttrs; +use crate::middle::codegen_fn_attrs::{CodegenFnAttrs, TargetFeature}; use crate::middle::{resolve_bound_vars, stability}; use crate::mir::interpret::{self, Allocation, ConstAllocation}; use crate::mir::{Body, Local, Place, PlaceElem, ProjectionKind, Promoted}; @@ -1776,6 +1776,37 @@ impl<'tcx> TyCtxt<'tcx> { pub fn dcx(self) -> DiagCtxtHandle<'tcx> { self.sess.dcx() } + + pub fn is_target_feature_call_safe( + self, + callee_features: &[TargetFeature], + body_features: &[TargetFeature], + ) -> bool { + // If the called function has target features the calling function hasn't, + // the call requires `unsafe`. Don't check this on wasm + // targets, though. For more information on wasm see the + // is_like_wasm check in hir_analysis/src/collect.rs + self.sess.target.options.is_like_wasm + || callee_features + .iter() + .all(|feature| body_features.iter().any(|f| f.name == feature.name)) + } + + /// Returns the safe version of the signature of the given function, if calling it + /// would be safe in the context of the given caller. + pub fn adjust_target_feature_sig( + self, + fun_def: DefId, + fun_sig: ty::Binder<'tcx, ty::FnSig<'tcx>>, + caller: DefId, + ) -> Option>> { + let fun_features = &self.codegen_fn_attrs(fun_def).target_features; + let callee_features = &self.codegen_fn_attrs(caller).target_features; + if self.is_target_feature_call_safe(&fun_features, &callee_features) { + return Some(fun_sig.map_bound(|sig| ty::FnSig { safety: hir::Safety::Safe, ..sig })); + } + None + } } impl<'tcx> TyCtxtAt<'tcx> { diff --git a/compiler/rustc_mir_build/src/check_unsafety.rs b/compiler/rustc_mir_build/src/check_unsafety.rs index 6279d0f94af63..5eed9ef798d01 100644 --- a/compiler/rustc_mir_build/src/check_unsafety.rs +++ b/compiler/rustc_mir_build/src/check_unsafety.rs @@ -495,14 +495,9 @@ impl<'a, 'tcx> Visitor<'a, 'tcx> for UnsafetyVisitor<'a, 'tcx> { }; self.requires_unsafe(expr.span, CallToUnsafeFunction(func_id)); } else if let &ty::FnDef(func_did, _) = fn_ty.kind() { - // If the called function has target features the calling function hasn't, - // the call requires `unsafe`. Don't check this on wasm - // targets, though. For more information on wasm see the - // is_like_wasm check in hir_analysis/src/collect.rs - if !self.tcx.sess.target.options.is_like_wasm - && !callee_features.iter().all(|feature| { - self.body_target_features.iter().any(|f| f.name == feature.name) - }) + if !self + .tcx + .is_target_feature_call_safe(callee_features, self.body_target_features) { let missing: Vec<_> = callee_features .iter() diff --git a/tests/ui/rfcs/rfc-2396-target_feature-11/fn-ptr.rs b/tests/ui/rfcs/rfc-2396-target_feature-11/fn-ptr.rs index 364b4d3581276..d7c17299d061c 100644 --- a/tests/ui/rfcs/rfc-2396-target_feature-11/fn-ptr.rs +++ b/tests/ui/rfcs/rfc-2396-target_feature-11/fn-ptr.rs @@ -2,9 +2,23 @@ #![feature(target_feature_11)] +#[target_feature(enable = "avx")] +fn foo_avx() {} + #[target_feature(enable = "sse2")] fn foo() {} +#[target_feature(enable = "sse2")] +fn bar() { + let foo: fn() = foo; // this is OK, as we have the necessary target features. + let foo: fn() = foo_avx; //~ ERROR mismatched types +} + fn main() { + if std::is_x86_feature_detected!("sse2") { + unsafe { + bar(); + } + } let foo: fn() = foo; //~ ERROR mismatched types } diff --git a/tests/ui/rfcs/rfc-2396-target_feature-11/fn-ptr.stderr b/tests/ui/rfcs/rfc-2396-target_feature-11/fn-ptr.stderr index a2bda229d10e1..1228404120a4c 100644 --- a/tests/ui/rfcs/rfc-2396-target_feature-11/fn-ptr.stderr +++ b/tests/ui/rfcs/rfc-2396-target_feature-11/fn-ptr.stderr @@ -1,5 +1,20 @@ error[E0308]: mismatched types - --> $DIR/fn-ptr.rs:9:21 + --> $DIR/fn-ptr.rs:14:21 + | +LL | #[target_feature(enable = "avx")] + | --------------------------------- `#[target_feature]` added here +... +LL | let foo: fn() = foo_avx; + | ---- ^^^^^^^ cannot coerce functions with `#[target_feature]` to safe function pointers + | | + | expected due to this + | + = note: expected fn pointer `fn()` + found fn item `#[target_features] fn() {foo_avx}` + = note: functions with `#[target_feature]` can only be coerced to `unsafe` function pointers + +error[E0308]: mismatched types + --> $DIR/fn-ptr.rs:23:21 | LL | #[target_feature(enable = "sse2")] | ---------------------------------- `#[target_feature]` added here @@ -13,6 +28,6 @@ LL | let foo: fn() = foo; found fn item `#[target_features] fn() {foo}` = note: functions with `#[target_feature]` can only be coerced to `unsafe` function pointers -error: aborting due to 1 previous error +error: aborting due to 2 previous errors For more information about this error, try `rustc --explain E0308`. diff --git a/tests/ui/rfcs/rfc-2396-target_feature-11/return-fn-ptr.rs b/tests/ui/rfcs/rfc-2396-target_feature-11/return-fn-ptr.rs new file mode 100644 index 0000000000000..b49493d66096d --- /dev/null +++ b/tests/ui/rfcs/rfc-2396-target_feature-11/return-fn-ptr.rs @@ -0,0 +1,22 @@ +//@ only-x86_64 +//@ run-pass + +#![feature(target_feature_11)] + +#[target_feature(enable = "sse2")] +fn foo() -> bool { + true +} + +#[target_feature(enable = "sse2")] +fn bar() -> fn() -> bool { + foo +} + +fn main() { + if !std::is_x86_feature_detected!("sse2") { + return; + } + let f = unsafe { bar() }; + assert!(f()); +}