From 58782a8842bdd74e3304092251a1d06e5b6b550b Mon Sep 17 00:00:00 2001 From: Oli Scherer Date: Mon, 12 Dec 2022 10:48:02 +0000 Subject: [PATCH] Harden the pre-tyctxt query system against accidental recomputation --- compiler/rustc_data_structures/src/steal.rs | 5 + compiler/rustc_driver/src/lib.rs | 11 +- compiler/rustc_interface/src/queries.rs | 108 ++++++++++-------- src/librustdoc/doctest.rs | 8 +- src/librustdoc/lib.rs | 5 +- src/tools/miri/src/bin/miri.rs | 2 +- .../obtain-borrowck/driver.rs | 2 +- 7 files changed, 78 insertions(+), 63 deletions(-) diff --git a/compiler/rustc_data_structures/src/steal.rs b/compiler/rustc_data_structures/src/steal.rs index a3ece6550473c..9a0fd52677d13 100644 --- a/compiler/rustc_data_structures/src/steal.rs +++ b/compiler/rustc_data_structures/src/steal.rs @@ -40,6 +40,11 @@ impl Steal { ReadGuard::map(borrow, |opt| opt.as_ref().unwrap()) } + #[track_caller] + pub fn get_mut(&mut self) -> &mut T { + self.value.get_mut().as_mut().expect("attempt to read from stolen value") + } + #[track_caller] pub fn steal(&self) -> T { let value_ref = &mut *self.value.try_write().expect("stealing value which is locked"); diff --git a/compiler/rustc_driver/src/lib.rs b/compiler/rustc_driver/src/lib.rs index 3cbe0052359b3..508d5cfc8e312 100644 --- a/compiler/rustc_driver/src/lib.rs +++ b/compiler/rustc_driver/src/lib.rs @@ -309,8 +309,8 @@ fn run_compiler( if let Some(ppm) = &sess.opts.pretty { if ppm.needs_ast_map() { - let expanded_crate = queries.expansion()?.peek().0.clone(); - queries.global_ctxt()?.peek_mut().enter(|tcx| { + let expanded_crate = queries.expansion()?.borrow().0.clone(); + queries.global_ctxt()?.enter(|tcx| { pretty::print_after_hir_lowering( tcx, compiler.input(), @@ -321,7 +321,7 @@ fn run_compiler( Ok(()) })?; } else { - let krate = queries.parse()?.take(); + let krate = queries.parse()?.steal(); pretty::print_after_parsing( sess, compiler.input(), @@ -343,7 +343,8 @@ fn run_compiler( } { - let (_, lint_store) = &*queries.register_plugins()?.peek(); + let plugins = queries.register_plugins()?; + let (_, lint_store) = &*plugins.borrow(); // Lint plugins are registered; now we can process command line flags. if sess.opts.describe_lints { @@ -371,7 +372,7 @@ fn run_compiler( return early_exit(); } - queries.global_ctxt()?.peek_mut().enter(|tcx| { + queries.global_ctxt()?.enter(|tcx| { let result = tcx.analysis(()); if sess.opts.unstable_opts.save_analysis { let crate_name = tcx.crate_name(LOCAL_CRATE); diff --git a/compiler/rustc_interface/src/queries.rs b/compiler/rustc_interface/src/queries.rs index 1d0c7f5b7a388..041bb9eb7a1cb 100644 --- a/compiler/rustc_interface/src/queries.rs +++ b/compiler/rustc_interface/src/queries.rs @@ -5,6 +5,7 @@ use crate::passes::{self, BoxedResolver, QueryContext}; use rustc_ast as ast; use rustc_codegen_ssa::traits::CodegenBackend; use rustc_codegen_ssa::CodegenResults; +use rustc_data_structures::steal::Steal; use rustc_data_structures::svh::Svh; use rustc_data_structures::sync::{Lrc, OnceCell, WorkerLocal}; use rustc_hir::def_id::LOCAL_CRATE; @@ -19,43 +20,53 @@ use rustc_session::{output::find_crate_name, Session}; use rustc_span::symbol::sym; use rustc_span::Symbol; use std::any::Any; -use std::cell::{Ref, RefCell, RefMut}; +use std::cell::{RefCell, RefMut}; use std::rc::Rc; use std::sync::Arc; /// Represent the result of a query. /// -/// This result can be stolen with the [`take`] method and generated with the [`compute`] method. +/// This result can be stolen once with the [`steal`] method and generated with the [`compute`] method. /// -/// [`take`]: Self::take +/// [`steal`]: Steal::steal /// [`compute`]: Self::compute pub struct Query { - result: RefCell>>, + /// `None` means no value has been computed yet. + result: RefCell>>>, } impl Query { - fn compute Result>(&self, f: F) -> Result<&Query> { - self.result.borrow_mut().get_or_insert_with(f).as_ref().map(|_| self).map_err(|&err| err) + fn compute Result>(&self, f: F) -> Result> { + RefMut::filter_map( + self.result.borrow_mut(), + |r: &mut Option>>| -> Option<&mut Steal> { + r.get_or_insert_with(|| f().map(Steal::new)).as_mut().ok() + }, + ) + .map_err(|r| *r.as_ref().unwrap().as_ref().map(|_| ()).unwrap_err()) + .map(QueryResult) } +} + +pub struct QueryResult<'a, T>(RefMut<'a, Steal>); + +impl<'a, T> std::ops::Deref for QueryResult<'a, T> { + type Target = RefMut<'a, Steal>; - /// Takes ownership of the query result. Further attempts to take or peek the query - /// result will panic unless it is generated by calling the `compute` method. - pub fn take(&self) -> T { - self.result.borrow_mut().take().expect("missing query result").unwrap() + fn deref(&self) -> &Self::Target { + &self.0 } +} - /// Borrows the query result using the RefCell. Panics if the result is stolen. - pub fn peek(&self) -> Ref<'_, T> { - Ref::map(self.result.borrow(), |r| { - r.as_ref().unwrap().as_ref().expect("missing query result") - }) +impl<'a, T> std::ops::DerefMut for QueryResult<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 } +} - /// Mutably borrows the query result using the RefCell. Panics if the result is stolen. - pub fn peek_mut(&self) -> RefMut<'_, T> { - RefMut::map(self.result.borrow_mut(), |r| { - r.as_mut().unwrap().as_mut().expect("missing query result") - }) +impl<'a, 'tcx> QueryResult<'a, QueryContext<'tcx>> { + pub fn enter(mut self, f: impl FnOnce(TyCtxt<'tcx>) -> T) -> T { + (*self.0).get_mut().enter(f) } } @@ -111,24 +122,24 @@ impl<'tcx> Queries<'tcx> { self.compiler.codegen_backend() } - fn dep_graph_future(&self) -> Result<&Query>> { + fn dep_graph_future(&self) -> Result>> { self.dep_graph_future.compute(|| { let sess = self.session(); Ok(sess.opts.build_dep_graph().then(|| rustc_incremental::load_dep_graph(sess))) }) } - pub fn parse(&self) -> Result<&Query> { + pub fn parse(&self) -> Result> { self.parse.compute(|| { passes::parse(self.session(), &self.compiler.input) .map_err(|mut parse_error| parse_error.emit()) }) } - pub fn register_plugins(&self) -> Result<&Query<(ast::Crate, Lrc)>> { + pub fn register_plugins(&self) -> Result)>> { self.register_plugins.compute(|| { - let crate_name = *self.crate_name()?.peek(); - let krate = self.parse()?.take(); + let crate_name = *self.crate_name()?.borrow(); + let krate = self.parse()?.steal(); let empty: &(dyn Fn(&Session, &mut LintStore) + Sync + Send) = &|_, _| {}; let (krate, lint_store) = passes::register_plugins( @@ -150,11 +161,11 @@ impl<'tcx> Queries<'tcx> { }) } - pub fn crate_name(&self) -> Result<&Query> { + pub fn crate_name(&self) -> Result> { self.crate_name.compute(|| { Ok({ let parse_result = self.parse()?; - let krate = parse_result.peek(); + let krate = parse_result.borrow(); // parse `#[crate_name]` even if `--crate-name` was passed, to make sure it matches. find_crate_name(self.session(), &krate.attrs, &self.compiler.input) }) @@ -163,11 +174,12 @@ impl<'tcx> Queries<'tcx> { pub fn expansion( &self, - ) -> Result<&Query<(Lrc, Rc>, Lrc)>> { + ) -> Result, Rc>, Lrc)>> + { trace!("expansion"); self.expansion.compute(|| { - let crate_name = *self.crate_name()?.peek(); - let (krate, lint_store) = self.register_plugins()?.take(); + let crate_name = *self.crate_name()?.borrow(); + let (krate, lint_store) = self.register_plugins()?.steal(); let _timer = self.session().timer("configure_and_expand"); let sess = self.session(); let mut resolver = passes::create_resolver( @@ -183,10 +195,10 @@ impl<'tcx> Queries<'tcx> { }) } - fn dep_graph(&self) -> Result<&Query> { + fn dep_graph(&self) -> Result> { self.dep_graph.compute(|| { let sess = self.session(); - let future_opt = self.dep_graph_future()?.take(); + let future_opt = self.dep_graph_future()?.steal(); let dep_graph = future_opt .and_then(|future| { let (prev_graph, prev_work_products) = @@ -199,10 +211,11 @@ impl<'tcx> Queries<'tcx> { }) } - pub fn prepare_outputs(&self) -> Result<&Query> { + pub fn prepare_outputs(&self) -> Result> { self.prepare_outputs.compute(|| { - let (krate, boxed_resolver, _) = &*self.expansion()?.peek(); - let crate_name = *self.crate_name()?.peek(); + let expansion = self.expansion()?; + let (krate, boxed_resolver, _) = &*expansion.borrow(); + let crate_name = *self.crate_name()?.borrow(); passes::prepare_outputs( self.session(), self.compiler, @@ -213,12 +226,12 @@ impl<'tcx> Queries<'tcx> { }) } - pub fn global_ctxt(&'tcx self) -> Result<&Query>> { + pub fn global_ctxt(&'tcx self) -> Result>> { self.global_ctxt.compute(|| { - let crate_name = *self.crate_name()?.peek(); - let outputs = self.prepare_outputs()?.take(); - let dep_graph = self.dep_graph()?.peek().clone(); - let (krate, resolver, lint_store) = self.expansion()?.take(); + let crate_name = *self.crate_name()?.borrow(); + let outputs = self.prepare_outputs()?.steal(); + let dep_graph = self.dep_graph()?.borrow().clone(); + let (krate, resolver, lint_store) = self.expansion()?.steal(); Ok(passes::create_global_ctxt( self.compiler, lint_store, @@ -235,9 +248,9 @@ impl<'tcx> Queries<'tcx> { }) } - pub fn ongoing_codegen(&'tcx self) -> Result<&Query>> { + pub fn ongoing_codegen(&'tcx self) -> Result>> { self.ongoing_codegen.compute(|| { - self.global_ctxt()?.peek_mut().enter(|tcx| { + self.global_ctxt()?.enter(|tcx| { tcx.analysis(()).ok(); // Don't do code generation if there were any errors @@ -293,12 +306,10 @@ impl<'tcx> Queries<'tcx> { let sess = self.session().clone(); let codegen_backend = self.codegen_backend().clone(); - let dep_graph = self.dep_graph()?.peek().clone(); - let (crate_hash, prepare_outputs) = self - .global_ctxt()? - .peek_mut() - .enter(|tcx| (tcx.crate_hash(LOCAL_CRATE), tcx.output_filenames(()).clone())); - let ongoing_codegen = self.ongoing_codegen()?.take(); + let (crate_hash, prepare_outputs, dep_graph) = self.global_ctxt()?.enter(|tcx| { + (tcx.crate_hash(LOCAL_CRATE), tcx.output_filenames(()).clone(), tcx.dep_graph.clone()) + }); + let ongoing_codegen = self.ongoing_codegen()?.steal(); Ok(Linker { sess, @@ -382,6 +393,7 @@ impl Compiler { // NOTE: intentionally does not compute the global context if it hasn't been built yet, // since that likely means there was a parse error. if let Some(Ok(gcx)) = &mut *queries.global_ctxt.result.borrow_mut() { + let gcx = gcx.get_mut(); // We assume that no queries are run past here. If there are new queries // after this point, they'll show up as "" in self-profiling data. { diff --git a/src/librustdoc/doctest.rs b/src/librustdoc/doctest.rs index 81d9c46447a37..5f8c777f32a5c 100644 --- a/src/librustdoc/doctest.rs +++ b/src/librustdoc/doctest.rs @@ -115,9 +115,7 @@ pub(crate) fn run(options: RustdocOptions) -> Result<(), ErrorGuaranteed> { let (tests, unused_extern_reports, compiling_test_count) = interface::run_compiler(config, |compiler| { compiler.enter(|queries| { - let mut global_ctxt = queries.global_ctxt()?.take(); - - let collector = global_ctxt.enter(|tcx| { + let collector = queries.global_ctxt()?.enter(|tcx| { let crate_attrs = tcx.hir().attrs(CRATE_HIR_ID); let opts = scrape_test_config(crate_attrs); @@ -156,9 +154,7 @@ pub(crate) fn run(options: RustdocOptions) -> Result<(), ErrorGuaranteed> { let unused_extern_reports = collector.unused_extern_reports.clone(); let compiling_test_count = collector.compiling_test_count.load(Ordering::SeqCst); - let ret: Result<_, ErrorGuaranteed> = - Ok((collector.tests, unused_extern_reports, compiling_test_count)); - ret + Ok((collector.tests, unused_extern_reports, compiling_test_count)) }) })?; diff --git a/src/librustdoc/lib.rs b/src/librustdoc/lib.rs index ef1d7da5a341c..ed77de200a9b7 100644 --- a/src/librustdoc/lib.rs +++ b/src/librustdoc/lib.rs @@ -800,7 +800,8 @@ fn main_args(at_args: &[String]) -> MainResult { // FIXME(#83761): Resolver cloning can lead to inconsistencies between data in the // two copies because one of the copies can be modified after `TyCtxt` construction. let (resolver, resolver_caches) = { - let (krate, resolver, _) = &*abort_on_err(queries.expansion(), sess).peek(); + let expansion = abort_on_err(queries.expansion(), sess); + let (krate, resolver, _) = &*expansion.borrow(); let resolver_caches = resolver.borrow_mut().access(|resolver| { collect_intra_doc_links::early_resolve_intra_doc_links( resolver, @@ -817,7 +818,7 @@ fn main_args(at_args: &[String]) -> MainResult { sess.fatal("Compilation failed, aborting rustdoc"); } - let mut global_ctxt = abort_on_err(queries.global_ctxt(), sess).peek_mut(); + let global_ctxt = abort_on_err(queries.global_ctxt(), sess); global_ctxt.enter(|tcx| { let (krate, render_opts, mut cache) = sess.time("run_global_ctxt", || { diff --git a/src/tools/miri/src/bin/miri.rs b/src/tools/miri/src/bin/miri.rs index 8c01748613cf8..6a147de3be2ea 100644 --- a/src/tools/miri/src/bin/miri.rs +++ b/src/tools/miri/src/bin/miri.rs @@ -61,7 +61,7 @@ impl rustc_driver::Callbacks for MiriCompilerCalls { ) -> Compilation { compiler.session().abort_if_errors(); - queries.global_ctxt().unwrap().peek_mut().enter(|tcx| { + queries.global_ctxt().unwrap().enter(|tcx| { init_late_loggers(tcx); if !tcx.sess.crate_types().contains(&CrateType::Executable) { tcx.sess.fatal("miri only makes sense on bin crates"); diff --git a/tests/run-make-fulldeps/obtain-borrowck/driver.rs b/tests/run-make-fulldeps/obtain-borrowck/driver.rs index a6c60df83a63c..9cd504f004dec 100644 --- a/tests/run-make-fulldeps/obtain-borrowck/driver.rs +++ b/tests/run-make-fulldeps/obtain-borrowck/driver.rs @@ -62,7 +62,7 @@ impl rustc_driver::Callbacks for CompilerCalls { queries: &'tcx Queries<'tcx>, ) -> Compilation { compiler.session().abort_if_errors(); - queries.global_ctxt().unwrap().peek_mut().enter(|tcx| { + queries.global_ctxt().unwrap().enter(|tcx| { // Collect definition ids of MIR bodies. let hir = tcx.hir(); let mut bodies = Vec::new();