Skip to content

Commit

Permalink
Clean update some codegen, start adding back dedup
Browse files Browse the repository at this point in the history
  • Loading branch information
mzeitlin11 committed Dec 26, 2023
1 parent e09c203 commit acc796e
Show file tree
Hide file tree
Showing 18 changed files with 309 additions and 132 deletions.
13 changes: 4 additions & 9 deletions openapi/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use indoc::formatdoc;
use crate::components::{get_components, Components};
use crate::crate_table::write_crate_table;
use crate::crates::{get_crate_doc_comment, Crate, ALL_CRATES};
use crate::object_writing::{gen_obj, gen_requests, ObjectGenInfo};
use crate::rust_object::ObjectMetadata;
use crate::object_writing::{gen_obj, gen_requests};
use crate::rust_object::{ObjectKind, ObjectMetadata};
use crate::spec::Spec;
use crate::spec_inference::infer_doc_comment;
use crate::stripe_object::StripeObject;
Expand Down Expand Up @@ -63,11 +63,7 @@ impl CodeGen {
let crate_mod_path = crate_path.join("mod.rs");
for (ident, typ_info) in &self.components.extra_types {
let mut out = String::new();
let mut metadata = ObjectMetadata::new(ident.clone(), typ_info.gen_info);
if let Some(doc) = &typ_info.doc {
metadata = metadata.doc(doc.clone());
}
self.components.write_object(&typ_info.obj, &metadata, &mut out);
self.components.write_object(&typ_info.obj, &typ_info.metadata, &mut out);
write_to_file(out, crate_path.join(format!("{}.rs", typ_info.mod_path)))?;
append_to_file(
format!("pub mod {0}; pub use {0}::{1};", typ_info.mod_path, ident),
Expand Down Expand Up @@ -198,8 +194,7 @@ impl CodeGen {
let base_obj = comp.rust_obj();
let schema = self.spec.get_component_schema(comp.path());
let doc_comment = infer_doc_comment(schema, comp.stripe_doc_url.as_deref());
let meta =
ObjectMetadata::new(comp.ident().clone(), ObjectGenInfo::new_deser()).doc(doc_comment);
let meta = ObjectMetadata::new(comp.ident().clone(), ObjectKind::Type).doc(doc_comment);

gen_obj(base_obj, &meta, comp, &self.components)
}
Expand Down
29 changes: 21 additions & 8 deletions openapi/src/components.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@ use tracing::{debug, info};

use crate::crate_inference::validate_crate_info;
use crate::crates::{infer_crate_by_package, maybe_infer_crate_by_path, Crate};
use crate::object_writing::ObjectGenInfo;
use crate::overrides::Overrides;
use crate::printable::{PrintableContainer, PrintableType};
use crate::requests::parse_requests;
use crate::rust_object::RustObject;
use crate::rust_object::{ObjectMetadata, RustObject};
use crate::rust_type::{Container, PathToType, RustType};
use crate::spec::Spec;
use crate::stripe_object::{
Expand All @@ -24,9 +23,8 @@ use crate::visitor::Visit;

#[derive(Clone, Debug)]
pub struct TypeSpec {
pub doc: Option<String>,
pub gen_info: ObjectGenInfo,
pub obj: RustObject,
pub metadata: ObjectMetadata,
pub mod_path: String,
}

Expand Down Expand Up @@ -119,6 +117,9 @@ impl Components {
ident: ident.clone(),
}
}
RustType::Path { path: PathToType::Dedupped { .. }, is_ref: _ } => {
todo!()
}
RustType::Simple(typ) => PrintableType::Simple(*typ),
RustType::Container(typ) => {
let inner = Box::new(self.construct_printable_type(typ.value_typ()));
Expand Down Expand Up @@ -189,6 +190,15 @@ impl Components {
None
}

// fn run_deduplication_pass(&mut self) {
// for comp in self.components.values_mut() {
// let extra_typs = deduplicate_types(comp);
// for (ident, typ) in extra_typs {
// comp.deduplicated_objects.insert(ident, typ);
// }
// }
// }

#[tracing::instrument(level = "debug", skip(self))]
fn apply_overrides(&mut self) -> anyhow::Result<()> {
let mut overrides = Overrides::new(self)?;
Expand All @@ -197,12 +207,11 @@ impl Components {
}
for (obj, override_meta) in overrides.overrides {
self.extra_types.insert(
override_meta.ident,
override_meta.metadata.ident.clone(),
TypeSpec {
doc: Some(override_meta.doc),
mod_path: override_meta.mod_path,
gen_info: ObjectGenInfo::new_deser(),
obj,
metadata: override_meta.metadata,
mod_path: override_meta.mod_path,
},
);
}
Expand Down Expand Up @@ -270,6 +279,7 @@ pub fn get_components(spec: &Spec) -> anyhow::Result<Components> {
data,
krate: inferred_krate.map(CrateInfo::new),
stripe_doc_url: None,
deduplicated_objects: IndexMap::default(),
},
);
}
Expand All @@ -284,6 +294,9 @@ pub fn get_components(spec: &Spec) -> anyhow::Result<Components> {
components.apply_overrides()?;
debug!("Finished applying overrides");

// components.run_deduplication_pass();
// info!("Finished deduplication pass");

Ok(components)
}

Expand Down
169 changes: 169 additions & 0 deletions openapi/src/deduplication.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
use heck::ToLowerCamelCase;
use indexmap::map::Entry;
use indexmap::IndexMap;
use tracing::debug;

use crate::components::TypeSpec;
use crate::rust_object::{ObjectMetadata, RustObject};
use crate::rust_type::{PathToType, RustType};
use crate::stripe_object::StripeObject;
use crate::types::{ComponentPath, RustIdent};
use crate::visitor::{Visit, VisitMut};

/// Visitor that groups the objects it encounters by value.
///
/// Every visited object that comes with metadata is recorded under the object
/// itself (using its `Hash`/`Eq` implementation), so an entry holding two or
/// more metadata items marks an object that occurs more than once.
#[derive(Debug, Default)]
struct CollectDuplicateObjects {
    /// Metadata of each recorded occurrence, keyed by the object.
    objs: IndexMap<RustObject, Vec<ObjectMetadata>>,
}

impl Visit<'_> for CollectDuplicateObjects {
    /// Records `meta` (when present) under `obj`, then descends into `obj`.
    fn visit_obj(&mut self, obj: &RustObject, meta: Option<&ObjectMetadata>) {
        if let Some(occurrence) = meta {
            let seen = match self.objs.entry(obj.clone()) {
                Entry::Occupied(slot) => slot.into_mut(),
                Entry::Vacant(slot) => slot.insert(Vec::new()),
            };
            seen.push(occurrence.clone());
        }
        // Objects without metadata are not recorded, but are still traversed.
        obj.visit(self);
    }
}

/// Information recorded for an object that has been selected for deduplication.
#[derive(Debug)]
struct DeduppedObjectInfo {
/// Identifier inferred for the shared type that replaces the duplicates.
ident: RustIdent,
/// Generation requirements for the shared type; `deduplicate_types` builds this by
/// folding the `gen_info` of every duplicate occurrence.
// NOTE(review): `ObjectGenInfo` is not among this file's `use` items -- confirm the type
// still exists and import it before enabling this module.
gen_info: ObjectGenInfo,
}

#[derive(Debug)]
struct DeduplicateObjects {
objs: IndexMap<RustObject, DeduppedObjectInfo>,
component_path: ComponentPath,
}

impl DeduplicateObjects {
pub fn new(path: ComponentPath) -> Self {
Self { objs: Default::default(), component_path: path }
}
}

impl VisitMut for DeduplicateObjects {
fn visit_typ_mut(&mut self, typ: &mut RustType)
where
Self: Sized,
{
if let Some((obj, _)) = typ.as_object_mut() {
if let Some(dedup_spec) = self.objs.get(obj) {
*typ = RustType::Path {
path: PathToType::Dedupped {
path: self.component_path.clone(),
ident: dedup_spec.ident.clone(),
},
is_ref: false,
}
}
}
typ.visit_mut(self);
}
}

/// Extracts the name implied by a doc comment of the form `"The <name> of ..."`.
///
/// The comment is split on ASCII whitespace. When its first word is exactly `The`
/// and its third word is exactly `of`, the second word is returned. Any other
/// shape, including a comment with fewer than three words, yields `None`.
fn doc_implied_name(doc: &str) -> Option<&str> {
    let mut tokens = doc.split_ascii_whitespace();
    match (tokens.next(), tokens.next(), tokens.next()) {
        (Some("The"), Some(name), Some("of")) => Some(name),
        _ => None,
    }
}

/// Try to infer an identifier given metadata about identical objects.
///
/// Candidates are tried in order, and each is accepted only when every entry in
/// `meta` agrees on it:
/// 1. the `title`, unless it is the overly generic `"param"`;
/// 2. the `field_name`;
/// 3. the name implied by a doc comment of the form `"The <name> of ..."`.
///
/// Returns `None` when `meta` is empty or when no candidate is shared by all entries.
fn infer_ident_for_duplicate_objects(meta: &[ObjectMetadata]) -> Option<RustIdent> {
    // Nothing can be inferred from an empty slice; return `None` rather than panicking.
    let first = meta.first()?;
    if let Some(title) = &first.title {
        // `param` is used very generally and will not be a helpful name to infer
        if title != "param" && meta.iter().all(|m| m.title.as_ref() == Some(title)) {
            return Some(RustIdent::create(title));
        }
    }
    if let Some(field) = &first.field_name {
        if meta.iter().all(|m| m.field_name.as_ref() == Some(field)) {
            return Some(RustIdent::create(field));
        }
    }

    if let Some(desc) = &first.doc {
        if let Some(doc_name) = doc_implied_name(desc) {
            // Every occurrence must imply the same name through its own doc comment.
            if meta
                .iter()
                .all(|m| m.doc.as_ref().and_then(|doc| doc_implied_name(doc)) == Some(doc_name))
            {
                return Some(RustIdent::create(doc_name));
            }
        }
    }
    None
}

/// Replaces objects that occur more than once within `comp` by paths to shared types.
///
/// Returns the extracted types keyed by the identifier inferred for each of them.
/// Duplicates for which no identifier can be inferred, or whose inferred identifier
/// is already taken, are left in place.
///
/// # Panics
///
/// Panics if an identifier is inserted into the result map twice.
// NOTE(review): this function uses `ObjectGenInfo`, which is not among this file's `use`
// items, and builds `TypeSpec` with `doc` / `gen_info` fields -- confirm both against the
// current definitions before enabling this module.
#[tracing::instrument(level = "debug", skip(comp), fields(path = %comp.path()))]
pub fn deduplicate_types(comp: &mut StripeObject) -> IndexMap<RustIdent, TypeSpec> {
    let component_path = comp.path().clone();
    let mut extracted = IndexMap::new();

    // One round of deduplication can enable another, so rounds are repeated until a
    // round finds nothing new to deduplicate.
    loop {
        let mut collector = CollectDuplicateObjects::default();
        comp.visit(&mut collector);

        let mut dedupper = DeduplicateObjects::new(component_path.clone());
        for (obj, occurrences) in collector.objs {
            // A single occurrence is not a duplicate.
            if occurrences.len() < 2 {
                continue;
            }
            let Some(inferred) = infer_ident_for_duplicate_objects(&occurrences) else {
                continue;
            };
            // Never register a second deduplicated type under an identifier that is
            // already in use, either in this round or in an earlier one.
            let name_taken = dedupper.objs.values().any(|o| o.ident == inferred)
                || extracted.contains_key(&inferred);
            if name_taken {
                continue;
            }
            // Combine the requirements of all occurrences, e.g. if one of them required
            // `Deserialize`, the shared type must not lose that.
            let gen_info = occurrences.iter().fold(ObjectGenInfo::new(), |merged, occurrence| {
                merged.with_shared_requirements(&occurrence.gen_info)
            });

            dedupper.objs.insert(obj.clone(), DeduppedObjectInfo { ident: inferred, gen_info });
        }
        // Nothing new could be deduplicated in this round, so we are done.
        if dedupper.objs.is_empty() {
            break;
        }

        comp.visit_mut(&mut dedupper);
        for (obj, info) in dedupper.objs {
            let mod_path = info.ident.to_lower_camel_case();
            let previous = extracted.insert(
                info.ident.clone(),
                TypeSpec { doc: None, gen_info: info.gen_info, obj, mod_path },
            );
            if previous.is_some() {
                panic!("Tried to add duplicate ident {}", info.ident);
            }
        }
    }
    if !extracted.is_empty() {
        debug!("Deduplicated {} types", extracted.len());
    }
    extracted
}
1 change: 1 addition & 0 deletions openapi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod components;
mod crate_inference;
mod crate_table;
pub mod crates;
// mod deduplication;
mod graph;
mod ids;
mod object_writing;
Expand Down
52 changes: 10 additions & 42 deletions openapi/src/object_writing.rs
Original file line number Diff line number Diff line change
@@ -1,49 +1,17 @@
use std::fmt::{Debug, Write};
use std::fmt::Write;

use crate::components::Components;
use crate::ids::write_object_id;
use crate::printable::Lifetime;
use crate::rust_object::{as_enum_of_objects, ObjectMetadata, RustObject};
use crate::rust_type::RustType;
use crate::stripe_object::{RequestSpec, StripeObject};
use crate::templates::derives::Derives;
use crate::templates::object_trait::{write_object_trait, write_object_trait_for_enum};
use crate::templates::utils::write_doc_comment;
use crate::templates::ObjectWriter;

const ADD_UNKNOWN_VARIANT_THRESHOLD: usize = 12;

/// Options describing what should be generated for an object.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct ObjectGenInfo {
    /// Derive configuration for the generated type.
    pub derives: Derives,
    /// Whether a constructor should be generated alongside the type.
    pub include_constructor: bool,
}

impl ObjectGenInfo {
    /// Starts from `Derives::new()` with constructor generation turned off.
    pub const fn new() -> Self {
        Self { include_constructor: false, derives: Derives::new() }
    }

    /// Like [`ObjectGenInfo::new`], but with both deserialization and serialization enabled.
    pub fn new_deser() -> Self {
        let base = Self::new();
        let with_deserialize = base.deserialize(true);
        with_deserialize.serialize(true)
    }

    /// Sets whether serialization is requested on the derive configuration.
    pub fn serialize(mut self, serialize: bool) -> Self {
        let derives = &mut self.derives;
        derives.serialize(serialize);
        self
    }

    /// Sets whether deserialization is requested on the derive configuration.
    pub fn deserialize(mut self, deserialize: bool) -> Self {
        let derives = &mut self.derives;
        derives.deserialize(deserialize);
        self
    }

    /// Requests that a constructor be generated.
    pub fn include_constructor(self) -> Self {
        Self { include_constructor: true, ..self }
    }
}

impl Components {
fn write_rust_type_objs(&self, typ: &RustType, out: &mut String) {
let Some((obj, meta)) = typ.extract_object() else {
Expand All @@ -53,24 +21,24 @@ impl Components {
}

pub fn write_object(&self, obj: &RustObject, metadata: &ObjectMetadata, out: &mut String) {
let info = metadata.gen_info;
if let Some(doc) = &metadata.doc {
let comment = write_doc_comment(doc, 0);
let _ = write!(out, "{comment}");
}

// If the object contains any references, we'll need to print with a lifetime
let has_ref = obj.has_reference(self);
let lifetime = if has_ref { Some(Lifetime::new()) } else { None };
let ident = &metadata.ident;

if let Some(doc) = &metadata.doc {
let comment = write_doc_comment(doc, 0);
let _ = write!(out, "{comment}");
}
let mut writer = ObjectWriter::new(self, ident);
writer.lifetime(lifetime).derives(info.derives).derives_mut().copy(obj.is_copy(self));
let mut writer = ObjectWriter::new(self, ident, metadata.kind);
writer.lifetime(lifetime).derive_copy(obj.is_copy(self));

match obj {
RustObject::Struct(fields) => {
let should_derive_default = fields.iter().all(|field| field.rust_type.is_option());
writer.derives_mut().default(should_derive_default);
writer.write_struct(out, fields, info.include_constructor);
writer.derive_default(should_derive_default);
writer.write_struct(out, fields);

for field in fields {
if let Some((obj, meta)) = field.rust_type.extract_object() {
Expand Down
Loading

0 comments on commit acc796e

Please sign in to comment.