Skip to content

Commit

Permalink
Auto merge of rust-lang#129458 - EnzymeAD:enzyme-frontend, r=jieyouxu
Browse files Browse the repository at this point in the history
Autodiff Upstreaming - enzyme frontend

This is an upstream PR for the `autodiff` rustc_builtin_macro that is part of the autodiff feature.

For the full implementation, see: rust-lang#129175

**Content:**
It contains a new `#[autodiff(<args>)]` rustc_builtin_macro, as well as a `#[rustc_autodiff]` builtin attribute.
The autodiff macro is applied on function `f` and will expand to a second function `df` (name given by user).
It will add a dummy body to `df` to make sure it type-checks. The body will later be replaced by enzyme on llvm-ir level,
we therefore don't really care about the content. Most of the changes (700 from 1.2k) are in `compiler/rustc_builtin_macros/src/autodiff.rs`, which expand the macro. Nothing except expansion is implemented for now.
I have a fallback implementation for relevant functions in case that rustc should be build without autodiff support. The default for now will be off, although we want to flip it later (once everything landed) to on for nightly. For the sake of CI, I have flipped the defaults, I'll revert this before merging.

**Dummy function Body:**
The first line is an `inline_asm` nop to make inlining less likely (I have additional checks to prevent this in the middle end of rustc. If `f` gets inlined too early, we can't pass it to enzyme and thus can't differentiate it.
If `df` gets inlined too early, the call site will just compute this dummy code instead of the derivatives, a correctness issue. The following black_box lines make sure that none of the input arguments is getting optimized away before we replace the body.

**Motivation:**
The user facing autodiff macro can verify the user input. Then I write it as args to the rustc_attribute, so from here on I can know that these values should be sensible. A rustc_attribute also turned out to be quite nice to attach this information to the corresponding function and carry it till the backend.
This is also just an experiment, I expect to adjust the user facing autodiff macro based on user feedback, to improve usability.

As a simple example of what this will do, we can see this expansion:
From:
```
#[autodiff(df, Reverse, Duplicated, Const, Active)]
pub fn f1(x: &[f64], y: f64) -> f64 {
    unimplemented!()
}
```
to
```
#[rustc_autodiff]
#[inline(never)]
pub fn f1(x: &[f64], y: f64) -> f64 {
    ::core::panicking::panic("not implemented")
}
#[rustc_autodiff(Reverse, Duplicated, Const, Active,)]
#[inline(never)]
pub fn df(x: &[f64], dx: &mut [f64], y: f64, dret: f64) -> f64 {
    unsafe { asm!("NOP"); };
    ::core::hint::black_box(f1(x, y));
    ::core::hint::black_box((dx, dret));
    ::core::hint::black_box(f1(x, y))
}
```
I will add a few more tests once I figured out why rustc rebuilds every time I touch a test.

Tracking:

- rust-lang#124509

try-job: dist-x86_64-msvc
  • Loading branch information
bors committed Oct 15, 2024
2 parents 9322d18 + 7c37d2d commit 785c830
Show file tree
Hide file tree
Showing 32 changed files with 2,128 additions and 1 deletion.
283 changes: 283 additions & 0 deletions compiler/rustc_ast/src/expand/autodiff_attrs.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,283 @@
//! This crate handles the user facing autodiff macro. For each `#[autodiff(...)]` attribute,
//! we create an [`AutoDiffItem`] which contains the source and target function names. The source
//! is the function to which the autodiff attribute is applied, and the target is the function
//! getting generated by us (with a name given by the user as the first autodiff arg).
use std::fmt::{self, Display, Formatter};
use std::str::FromStr;

use crate::expand::typetree::TypeTree;
use crate::expand::{Decodable, Encodable, HashStable_Generic};
use crate::ptr::P;
use crate::{Ty, TyKind};

/// Forward and Reverse Mode are well known names for automatic differentiation implementations.
/// Enzyme does support both, but with different semantics, see DiffActivity. The First variants
/// are a hack to support higher order derivatives. We need to compute first order derivatives
/// before we compute second order derivatives, otherwise we would differentiate our placeholder
/// functions. The proper solution is to recognize and resolve this DAG of autodiff invocations,
/// as it's already done in the C++ and Julia frontend of Enzyme.
///
/// (FIXME) remove *First variants.
/// Documentation for using [reverse](https://enzyme.mit.edu/rust/rev.html) and
/// [forward](https://enzyme.mit.edu/rust/fwd.html) mode is available online.
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub enum DiffMode {
/// No autodiff is applied (used during error handling).
Error,
/// The primal function which we will differentiate.
Source,
/// The target function, to be created using forward mode AD.
Forward,
/// The target function, to be created using reverse mode AD.
Reverse,
/// The target function, to be created using forward mode AD.
/// This target function will also be used as a source for higher order derivatives,
/// so compute it before all Forward/Reverse targets and optimize it through llvm.
ForwardFirst,
/// The target function, to be created using reverse mode AD.
/// This target function will also be used as a source for higher order derivatives,
/// so compute it before all Forward/Reverse targets and optimize it through llvm.
ReverseFirst,
}

/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity.
/// However, under forward mode we overwrite the previous shadow value, while for reverse mode
/// we add to the previous shadow value. To not surprise users, we picked different names.
/// Dual numbers is also a quite well known name for forward mode AD types.
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub enum DiffActivity {
/// Implicit or Explicit () return type, so a special case of Const.
None,
/// Don't compute derivatives with respect to this input/output.
Const,
/// Reverse Mode, Compute derivatives for this scalar input/output.
Active,
/// Reverse Mode, Compute derivatives for this scalar output, but don't compute
/// the original return value.
ActiveOnly,
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
/// with it.
Dual,
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
/// with it. Drop the code which updates the original input/output for maximum performance.
DualOnly,
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
Duplicated,
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
/// Drop the code which updates the original input for maximum performance.
DuplicatedOnly,
/// All Integers must be Const, but these are used to mark the integer which represents the
/// length of a slice/vec. This is used for safety checks on slices.
FakeActivitySize,
}
/// We generate one of these structs for each `#[autodiff(...)]` attribute.
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct AutoDiffItem {
/// The name of the function getting differentiated
pub source: String,
/// The name of the function being generated
pub target: String,
pub attrs: AutoDiffAttrs,
/// Describe the memory layout of input types
pub inputs: Vec<TypeTree>,
/// Describe the memory layout of the output type
pub output: TypeTree,
}
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct AutoDiffAttrs {
/// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
/// e.g. in the [JAX
/// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
pub mode: DiffMode,
pub ret_activity: DiffActivity,
pub input_activity: Vec<DiffActivity>,
}

impl DiffMode {
pub fn is_rev(&self) -> bool {
matches!(self, DiffMode::Reverse | DiffMode::ReverseFirst)
}
pub fn is_fwd(&self) -> bool {
matches!(self, DiffMode::Forward | DiffMode::ForwardFirst)
}
}

impl Display for DiffMode {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match self {
DiffMode::Error => write!(f, "Error"),
DiffMode::Source => write!(f, "Source"),
DiffMode::Forward => write!(f, "Forward"),
DiffMode::Reverse => write!(f, "Reverse"),
DiffMode::ForwardFirst => write!(f, "ForwardFirst"),
DiffMode::ReverseFirst => write!(f, "ReverseFirst"),
}
}
}

/// Active(Only) is valid in reverse-mode AD for scalar float returns (f16/f32/...).
/// Dual(Only) is valid in forward-mode AD for scalar float returns (f16/f32/...).
/// Const is valid for all cases and means that we don't compute derivatives wrt. this output.
/// That usually means we have a &mut or *mut T output and compute derivatives wrt. that arg,
/// but this is too complex to verify here. Also it's just a logic error if users get this wrong.
pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
if activity == DiffActivity::None {
// Only valid if primal returns (), but we can't check that here.
return true;
}
match mode {
DiffMode::Error => false,
DiffMode::Source => false,
DiffMode::Forward | DiffMode::ForwardFirst => {
activity == DiffActivity::Dual
|| activity == DiffActivity::DualOnly
|| activity == DiffActivity::Const
}
DiffMode::Reverse | DiffMode::ReverseFirst => {
activity == DiffActivity::Const
|| activity == DiffActivity::Active
|| activity == DiffActivity::ActiveOnly
}
}
}

/// For indirections (ptr/ref) we can't use Active, since Active allocates a shadow value
/// for the given argument, but we generally can't know the size of such a type.
/// For scalar types (f16/f32/f64/f128) we can use Active and we can't use Duplicated,
/// since Duplicated expects a mutable ref/ptr and we would thus end up with a shadow value
/// who is an indirect type, which doesn't match the primal scalar type. We can't prevent
/// users here from marking scalars as Duplicated, due to type aliases.
pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool {
use DiffActivity::*;
// It's always allowed to mark something as Const, since we won't compute derivatives wrt. it.
if matches!(activity, Const) {
return true;
}
if matches!(activity, Dual | DualOnly) {
return true;
}
// FIXME(ZuseZ4) We should make this more robust to also
// handle type aliases. Once that is done, we can be more restrictive here.
if matches!(activity, Active | ActiveOnly) {
return true;
}
matches!(ty.kind, TyKind::Ptr(_) | TyKind::Ref(..))
&& matches!(activity, Duplicated | DuplicatedOnly)
}
pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
use DiffActivity::*;
return match mode {
DiffMode::Error => false,
DiffMode::Source => false,
DiffMode::Forward | DiffMode::ForwardFirst => {
matches!(activity, Dual | DualOnly | Const)
}
DiffMode::Reverse | DiffMode::ReverseFirst => {
matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)
}
};
}

impl Display for DiffActivity {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
DiffActivity::None => write!(f, "None"),
DiffActivity::Const => write!(f, "Const"),
DiffActivity::Active => write!(f, "Active"),
DiffActivity::ActiveOnly => write!(f, "ActiveOnly"),
DiffActivity::Dual => write!(f, "Dual"),
DiffActivity::DualOnly => write!(f, "DualOnly"),
DiffActivity::Duplicated => write!(f, "Duplicated"),
DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
DiffActivity::FakeActivitySize => write!(f, "FakeActivitySize"),
}
}
}

impl FromStr for DiffMode {
type Err = ();

fn from_str(s: &str) -> Result<DiffMode, ()> {
match s {
"Error" => Ok(DiffMode::Error),
"Source" => Ok(DiffMode::Source),
"Forward" => Ok(DiffMode::Forward),
"Reverse" => Ok(DiffMode::Reverse),
"ForwardFirst" => Ok(DiffMode::ForwardFirst),
"ReverseFirst" => Ok(DiffMode::ReverseFirst),
_ => Err(()),
}
}
}
impl FromStr for DiffActivity {
type Err = ();

fn from_str(s: &str) -> Result<DiffActivity, ()> {
match s {
"None" => Ok(DiffActivity::None),
"Active" => Ok(DiffActivity::Active),
"ActiveOnly" => Ok(DiffActivity::ActiveOnly),
"Const" => Ok(DiffActivity::Const),
"Dual" => Ok(DiffActivity::Dual),
"DualOnly" => Ok(DiffActivity::DualOnly),
"Duplicated" => Ok(DiffActivity::Duplicated),
"DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
_ => Err(()),
}
}
}

impl AutoDiffAttrs {
pub fn has_ret_activity(&self) -> bool {
self.ret_activity != DiffActivity::None
}
pub fn has_active_only_ret(&self) -> bool {
self.ret_activity == DiffActivity::ActiveOnly
}

pub fn error() -> Self {
AutoDiffAttrs {
mode: DiffMode::Error,
ret_activity: DiffActivity::None,
input_activity: Vec::new(),
}
}
pub fn source() -> Self {
AutoDiffAttrs {
mode: DiffMode::Source,
ret_activity: DiffActivity::None,
input_activity: Vec::new(),
}
}

pub fn is_active(&self) -> bool {
self.mode != DiffMode::Error
}

pub fn is_source(&self) -> bool {
self.mode == DiffMode::Source
}
pub fn apply_autodiff(&self) -> bool {
!matches!(self.mode, DiffMode::Error | DiffMode::Source)
}

pub fn into_item(
self,
source: String,
target: String,
inputs: Vec<TypeTree>,
output: TypeTree,
) -> AutoDiffItem {
AutoDiffItem { source, target, inputs, output, attrs: self }
}
}

impl fmt::Display for AutoDiffItem {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Differentiating {} -> {}", self.source, self.target)?;
write!(f, " with attributes: {:?}", self.attrs)?;
write!(f, " with inputs: {:?}", self.inputs)?;
write!(f, " with output: {:?}", self.output)
}
}
2 changes: 2 additions & 0 deletions compiler/rustc_ast/src/expand/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ use rustc_span::symbol::Ident;
use crate::MetaItem;

pub mod allocator;
pub mod autodiff_attrs;
pub mod typetree;

#[derive(Debug, Clone, Encodable, Decodable, HashStable_Generic)]
pub struct StrippedCfgItem<ModId = DefId> {
Expand Down
90 changes: 90 additions & 0 deletions compiler/rustc_ast/src/expand/typetree.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
//! This module contains the definition of the `TypeTree` and `Type` structs.
//! They are thin Rust wrappers around the TypeTrees used by Enzyme as the LLVM based autodiff
//! backend. The Enzyme TypeTrees currently have various limitations and should be rewritten, so the
//! Rust frontend obviously has the same limitations. The main motivation of TypeTrees is to
//! represent how a type looks like "in memory". Enzyme can deduce this based on usage patterns in
//! the user code, but this is extremely slow and not even always sufficient. As such we lower some
//! information from rustc to help Enzyme. For a full explanation of their design it is necessary to
//! analyze the implementation in Enzyme core itself. As a rough summary, `-1` in Enzyme speech means
//! everywhere. That is `{0:-1: Float}` means at index 0 you have a ptr, if you dereference it it
//! will be floats everywhere. Thus `* f32`. If you have `{-1:int}` it means int's everywhere,
//! e.g. [i32; N]. `{0:-1:-1 float}` then means one pointer at offset 0, if you dereference it there
//! will be only pointers, if you dereference these new pointers they will point to array of floats.
//! Generally, it allows byte-specific descriptions.
//! FIXME: This description might be partly inaccurate and should be extended, along with
//! adding documentation to the corresponding Enzyme core code.
//! FIXME: Rewrite the TypeTree logic in Enzyme core to reduce the need for the rustc frontend to
//! provide typetree information.
//! FIXME: We should also re-evaluate where we create TypeTrees from Rust types, since MIR
//! representations of some types might not be accurate. For example a vector of floats might be
//! represented as a vector of u8s in MIR in some cases.
use std::fmt;

use crate::expand::{Decodable, Encodable, HashStable_Generic};

#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub enum Kind {
Anything,
Integer,
Pointer,
Half,
Float,
Double,
Unknown,
}

#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct TypeTree(pub Vec<Type>);

impl TypeTree {
pub fn new() -> Self {
Self(Vec::new())
}
pub fn all_ints() -> Self {
Self(vec![Type { offset: -1, size: 1, kind: Kind::Integer, child: TypeTree::new() }])
}
pub fn int(size: usize) -> Self {
let mut ints = Vec::with_capacity(size);
for i in 0..size {
ints.push(Type {
offset: i as isize,
size: 1,
kind: Kind::Integer,
child: TypeTree::new(),
});
}
Self(ints)
}
}

#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct FncTree {
pub args: Vec<TypeTree>,
pub ret: TypeTree,
}

#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct Type {
pub offset: isize,
pub size: usize,
pub kind: Kind,
pub child: TypeTree,
}

impl Type {
pub fn add_offset(self, add: isize) -> Self {
let offset = match self.offset {
-1 => add,
x => add + x,
};

Self { size: self.size, kind: self.kind, child: self.child, offset }
}
}

impl fmt::Display for Type {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
<Self as fmt::Debug>::fmt(self, f)
}
}
4 changes: 4 additions & 0 deletions compiler/rustc_builtin_macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ name = "rustc_builtin_macros"
version = "0.0.0"
edition = "2021"


[lints.rust]
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(llvm_enzyme)'] }

[lib]
doctest = false

Expand Down
Loading

0 comments on commit 785c830

Please sign in to comment.