diff --git a/prost-build/Cargo.toml b/prost-build/Cargo.toml index aaf6698d7..d6f67faa3 100644 --- a/prost-build/Cargo.toml +++ b/prost-build/Cargo.toml @@ -17,7 +17,7 @@ rust-version = "1.70" [features] default = ["format"] -format = ["dep:prettyplease", "dep:syn"] +format = ["dep:prettyplease"] cleanup-markdown = ["dep:pulldown-cmark", "dep:pulldown-cmark-to-cmark"] [dependencies] @@ -27,14 +27,16 @@ itertools = { version = ">=0.10, <=0.12", default-features = false, features = [ log = "0.4.4" multimap = { version = ">=0.8, <=0.10", default-features = false } petgraph = { version = "0.6", default-features = false } +proc-macro2 = "1" prost = { version = "0.12.4", path = "..", default-features = false } prost-types = { version = "0.12.4", path = "../prost-types", default-features = false } +quote = "1" +syn = { version = "2", features = ["full", "extra-traits"] } tempfile = "3" once_cell = "1.17.1" regex = { version = "1.8.1", default-features = false, features = ["std", "unicode-bool"] } prettyplease = { version = "0.2", optional = true } -syn = { version = "2", features = ["full"], optional = true } # These two must be kept in sync, used for `cleanup-markdown` feature. pulldown-cmark = { version = "0.9.1", optional = true, default-features = false } @@ -42,3 +44,4 @@ pulldown-cmark-to-cmark = { version = "10.0.1", optional = true } [dev-dependencies] env_logger = { version = "0.10", default-features = false } +pretty_assertions = "1" diff --git a/prost-build/src/ast.rs b/prost-build/src/ast.rs index 9a6a0de99..3c71e53ed 100644 --- a/prost-build/src/ast.rs +++ b/prost-build/src/ast.rs @@ -42,13 +42,10 @@ impl Comments { /// Appends the comments to a buffer with indentation. /// /// Each level of indentation corresponds to four space (' ') characters. - pub fn append_with_indent(&self, indent_level: u8, buf: &mut String) { + pub fn append_with_indent(&self, buf: &mut String) { // Append blocks of detached comments. for detached_block in &self.leading_detached { for line in detached_block { - for _ in 0..indent_level { - buf.push_str(" "); - } buf.push_str("//"); buf.push_str(&Self::sanitize_line(line)); buf.push('\n'); @@ -58,9 +55,6 @@ impl Comments { // Append leading comments. for line in &self.leading { - for _ in 0..indent_level { - buf.push_str(" "); - } buf.push_str("///"); buf.push_str(&Self::sanitize_line(line)); buf.push('\n'); @@ -68,17 +62,11 @@ impl Comments { // Append an empty comment line if there are leading and trailing comments. if !self.leading.is_empty() && !self.trailing.is_empty() { - for _ in 0..indent_level { - buf.push_str(" "); - } buf.push_str("///\n"); } // Append trailing comments. for line in &self.trailing { - for _ in 0..indent_level { - buf.push_str(" "); - } buf.push_str("///"); buf.push_str(&Self::sanitize_line(line)); buf.push('\n'); @@ -262,7 +250,7 @@ mod tests { }; let mut actual = "".to_string(); - input.append_with_indent(0, &mut actual); + input.append_with_indent(&mut actual); assert_eq!(t.expected, actual, "failed {}", t.name); } @@ -306,7 +294,7 @@ mod tests { }; let mut actual = "".to_string(); - input.append_with_indent(0, &mut actual); + input.append_with_indent(&mut actual); assert_eq!(t.expected, actual, "failed {}", t.name); } @@ -400,7 +388,7 @@ mod tests { }; let mut actual = "".to_string(); - input.append_with_indent(0, &mut actual); + input.append_with_indent(&mut actual); assert_eq!(t.expected, actual, "failed {}", t.name); } diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs deleted file mode 100644 index 5acb90d1c..000000000 --- a/prost-build/src/code_generator.rs +++ /dev/null @@ -1,1161 +0,0 @@ -use std::ascii; -use std::borrow::Cow; -use std::collections::{HashMap, HashSet}; -use std::iter; - -use itertools::{Either, Itertools}; -use log::debug; -use multimap::MultiMap; -use prost_types::field_descriptor_proto::{Label, Type}; -use prost_types::source_code_info::Location; -use prost_types::{ - DescriptorProto, EnumDescriptorProto, EnumValueDescriptorProto, FieldDescriptorProto, - FieldOptions, FileDescriptorProto, OneofDescriptorProto, ServiceDescriptorProto, - SourceCodeInfo, -}; - -use crate::ast::{Comments, Method, Service}; -use crate::extern_paths::ExternPaths; -use crate::ident::{strip_enum_prefix, to_snake, to_upper_camel}; -use crate::message_graph::MessageGraph; -use crate::{BytesType, Config, MapType}; - -mod c_escaping; -use c_escaping::unescape_c_escape_string; - -#[derive(PartialEq)] -enum Syntax { - Proto2, - Proto3, -} - -pub struct CodeGenerator<'a> { - config: &'a mut Config, - package: String, - type_path: Vec, - source_info: Option, - syntax: Syntax, - message_graph: &'a MessageGraph, - extern_paths: &'a ExternPaths, - depth: u8, - path: Vec, - buf: &'a mut String, -} - -fn push_indent(buf: &mut String, depth: u8) { - for _ in 0..depth { - buf.push_str(" "); - } -} - -fn prost_path(config: &Config) -> &str { - config.prost_path.as_deref().unwrap_or("::prost") -} - -impl<'a> CodeGenerator<'a> { - pub fn generate( - config: &mut Config, - message_graph: &MessageGraph, - extern_paths: &ExternPaths, - file: FileDescriptorProto, - buf: &mut String, - ) { - let source_info = file.source_code_info.map(|mut s| { - s.location.retain(|loc| { - let len = loc.path.len(); - len > 0 && len % 2 == 0 - }); - s.location.sort_by(|a, b| a.path.cmp(&b.path)); - s - }); - - let syntax = match file.syntax.as_ref().map(String::as_str) { - None | Some("proto2") => Syntax::Proto2, - Some("proto3") => Syntax::Proto3, - Some(s) => panic!("unknown syntax: {}", s), - }; - - let mut code_gen = CodeGenerator { - config, - package: file.package.unwrap_or_default(), - type_path: Vec::new(), - source_info, - syntax, - message_graph, - extern_paths, - depth: 0, - path: Vec::new(), - buf, - }; - - debug!( - "file: {:?}, package: {:?}", - file.name.as_ref().unwrap(), - code_gen.package - ); - - code_gen.path.push(4); - for (idx, message) in file.message_type.into_iter().enumerate() { - code_gen.path.push(idx as i32); - code_gen.append_message(message); - code_gen.path.pop(); - } - code_gen.path.pop(); - - code_gen.path.push(5); - for (idx, desc) in file.enum_type.into_iter().enumerate() { - code_gen.path.push(idx as i32); - code_gen.append_enum(desc); - code_gen.path.pop(); - } - code_gen.path.pop(); - - if code_gen.config.service_generator.is_some() { - code_gen.path.push(6); - for (idx, service) in file.service.into_iter().enumerate() { - code_gen.path.push(idx as i32); - code_gen.push_service(service); - code_gen.path.pop(); - } - - if let Some(service_generator) = code_gen.config.service_generator.as_mut() { - service_generator.finalize(code_gen.buf); - } - - code_gen.path.pop(); - } - } - - fn append_message(&mut self, message: DescriptorProto) { - debug!(" message: {:?}", message.name()); - - let message_name = message.name().to_string(); - let fq_message_name = self.fq_name(&message_name); - - // Skip external types. - if self.extern_paths.resolve_ident(&fq_message_name).is_some() { - return; - } - - // Split the nested message types into a vector of normal nested message types, and a map - // of the map field entry types. The path index of the nested message types is preserved so - // that comments can be retrieved. - type NestedTypes = Vec<(DescriptorProto, usize)>; - type MapTypes = HashMap; - let (nested_types, map_types): (NestedTypes, MapTypes) = message - .nested_type - .into_iter() - .enumerate() - .partition_map(|(idx, nested_type)| { - if nested_type - .options - .as_ref() - .and_then(|options| options.map_entry) - .unwrap_or(false) - { - let key = nested_type.field[0].clone(); - let value = nested_type.field[1].clone(); - assert_eq!("key", key.name()); - assert_eq!("value", value.name()); - - let name = format!("{}.{}", &fq_message_name, nested_type.name()); - Either::Right((name, (key, value))) - } else { - Either::Left((nested_type, idx)) - } - }); - - // Split the fields into a vector of the normal fields, and oneof fields. - // Path indexes are preserved so that comments can be retrieved. - type Fields = Vec<(FieldDescriptorProto, usize)>; - type OneofFields = MultiMap; - let (fields, mut oneof_fields): (Fields, OneofFields) = message - .field - .into_iter() - .enumerate() - .partition_map(|(idx, field)| { - if field.proto3_optional.unwrap_or(false) { - Either::Left((field, idx)) - } else if let Some(oneof_index) = field.oneof_index { - Either::Right((oneof_index, (field, idx))) - } else { - Either::Left((field, idx)) - } - }); - - self.append_doc(&fq_message_name, None); - self.append_type_attributes(&fq_message_name); - self.append_message_attributes(&fq_message_name); - self.push_indent(); - self.buf - .push_str("#[allow(clippy::derive_partial_eq_without_eq)]\n"); - self.buf.push_str(&format!( - "#[derive(Clone, PartialEq, {}::Message)]\n", - prost_path(self.config) - )); - self.append_skip_debug(&fq_message_name); - self.push_indent(); - self.buf.push_str("pub struct "); - self.buf.push_str(&to_upper_camel(&message_name)); - self.buf.push_str(" {\n"); - - self.depth += 1; - self.path.push(2); - for (field, idx) in fields { - self.path.push(idx as i32); - match field - .type_name - .as_ref() - .and_then(|type_name| map_types.get(type_name)) - { - Some((key, value)) => self.append_map_field(&fq_message_name, field, key, value), - None => self.append_field(&fq_message_name, field), - } - self.path.pop(); - } - self.path.pop(); - - self.path.push(8); - for (idx, oneof) in message.oneof_decl.iter().enumerate() { - let idx = idx as i32; - - let fields = match oneof_fields.get_vec(&idx) { - Some(fields) => fields, - None => continue, - }; - - self.path.push(idx); - self.append_oneof_field(&message_name, &fq_message_name, oneof, fields); - self.path.pop(); - } - self.path.pop(); - - self.depth -= 1; - self.push_indent(); - self.buf.push_str("}\n"); - - if !message.enum_type.is_empty() || !nested_types.is_empty() || !oneof_fields.is_empty() { - self.push_mod(&message_name); - self.path.push(3); - for (nested_type, idx) in nested_types { - self.path.push(idx as i32); - self.append_message(nested_type); - self.path.pop(); - } - self.path.pop(); - - self.path.push(4); - for (idx, nested_enum) in message.enum_type.into_iter().enumerate() { - self.path.push(idx as i32); - self.append_enum(nested_enum); - self.path.pop(); - } - self.path.pop(); - - for (idx, oneof) in message.oneof_decl.into_iter().enumerate() { - let idx = idx as i32; - // optional fields create a synthetic oneof that we want to skip - let fields = match oneof_fields.remove(&idx) { - Some(fields) => fields, - None => continue, - }; - self.append_oneof(&fq_message_name, oneof, idx, fields); - } - - self.pop_mod(); - } - - if self.config.enable_type_names { - self.append_type_name(&message_name, &fq_message_name); - } - } - - fn append_type_name(&mut self, message_name: &str, fq_message_name: &str) { - self.buf.push_str(&format!( - "impl {}::Name for {} {{\n", - self.config.prost_path.as_deref().unwrap_or("::prost"), - to_upper_camel(message_name) - )); - self.depth += 1; - - self.buf.push_str(&format!( - "const NAME: &'static str = \"{}\";\n", - message_name, - )); - self.buf.push_str(&format!( - "const PACKAGE: &'static str = \"{}\";\n", - self.package, - )); - - let prost_path = self.config.prost_path.as_deref().unwrap_or("::prost"); - let string_path = format!("{prost_path}::alloc::string::String"); - - let full_name = format!( - "{}{}{}{}{message_name}", - self.package.trim_matches('.'), - if self.package.is_empty() { "" } else { "." }, - self.type_path.join("."), - if self.type_path.is_empty() { "" } else { "." }, - ); - let domain_name = self - .config - .type_name_domains - .get_first(fq_message_name) - .map_or("", |name| name.as_str()); - - self.buf.push_str(&format!( - r#"fn full_name() -> {string_path} {{ "{full_name}".into() }}"#, - )); - - self.buf.push_str(&format!( - r#"fn type_url() -> {string_path} {{ "{domain_name}/{full_name}".into() }}"#, - )); - - self.depth -= 1; - self.buf.push_str("}\n"); - } - - fn append_type_attributes(&mut self, fq_message_name: &str) { - assert_eq!(b'.', fq_message_name.as_bytes()[0]); - for attribute in self.config.type_attributes.get(fq_message_name) { - push_indent(self.buf, self.depth); - self.buf.push_str(attribute); - self.buf.push('\n'); - } - } - - fn append_message_attributes(&mut self, fq_message_name: &str) { - assert_eq!(b'.', fq_message_name.as_bytes()[0]); - for attribute in self.config.message_attributes.get(fq_message_name) { - push_indent(self.buf, self.depth); - self.buf.push_str(attribute); - self.buf.push('\n'); - } - } - - fn should_skip_debug(&self, fq_message_name: &str) -> bool { - assert_eq!(b'.', fq_message_name.as_bytes()[0]); - self.config.skip_debug.get(fq_message_name).next().is_some() - } - - fn append_skip_debug(&mut self, fq_message_name: &str) { - if self.should_skip_debug(fq_message_name) { - push_indent(self.buf, self.depth); - self.buf.push_str("#[prost(skip_debug)]"); - self.buf.push('\n'); - } - } - - fn append_enum_attributes(&mut self, fq_message_name: &str) { - assert_eq!(b'.', fq_message_name.as_bytes()[0]); - for attribute in self.config.enum_attributes.get(fq_message_name) { - push_indent(self.buf, self.depth); - self.buf.push_str(attribute); - self.buf.push('\n'); - } - } - - fn append_field_attributes(&mut self, fq_message_name: &str, field_name: &str) { - assert_eq!(b'.', fq_message_name.as_bytes()[0]); - for attribute in self - .config - .field_attributes - .get_field(fq_message_name, field_name) - { - push_indent(self.buf, self.depth); - self.buf.push_str(attribute); - self.buf.push('\n'); - } - } - - fn append_field(&mut self, fq_message_name: &str, field: FieldDescriptorProto) { - let type_ = field.r#type(); - let repeated = field.label == Some(Label::Repeated as i32); - let deprecated = self.deprecated(&field); - let optional = self.optional(&field); - let ty = self.resolve_type(&field, fq_message_name); - - let boxed = !repeated - && ((type_ == Type::Message || type_ == Type::Group) - && self - .message_graph - .is_nested(field.type_name(), fq_message_name)) - || (self - .config - .boxed - .get_first_field(fq_message_name, field.name()) - .is_some()); - - debug!( - " field: {:?}, type: {:?}, boxed: {}", - field.name(), - ty, - boxed - ); - - self.append_doc(fq_message_name, Some(field.name())); - - if deprecated { - self.push_indent(); - self.buf.push_str("#[deprecated]\n"); - } - - self.push_indent(); - self.buf.push_str("#[prost("); - let type_tag = self.field_type_tag(&field); - self.buf.push_str(&type_tag); - - if type_ == Type::Bytes { - let bytes_type = self - .config - .bytes_type - .get_first_field(fq_message_name, field.name()) - .copied() - .unwrap_or_default(); - self.buf - .push_str(&format!("={:?}", bytes_type.annotation())); - } - - match field.label() { - Label::Optional => { - if optional { - self.buf.push_str(", optional"); - } - } - Label::Required => self.buf.push_str(", required"), - Label::Repeated => { - self.buf.push_str(", repeated"); - if can_pack(&field) - && !field - .options - .as_ref() - .map_or(self.syntax == Syntax::Proto3, |options| options.packed()) - { - self.buf.push_str(", packed=\"false\""); - } - } - } - - if boxed { - self.buf.push_str(", boxed"); - } - self.buf.push_str(", tag=\""); - self.buf.push_str(&field.number().to_string()); - - if let Some(ref default) = field.default_value { - self.buf.push_str("\", default=\""); - if type_ == Type::Bytes { - self.buf.push_str("b\\\""); - for b in unescape_c_escape_string(default) { - self.buf.extend( - ascii::escape_default(b).flat_map(|c| (c as char).escape_default()), - ); - } - self.buf.push_str("\\\""); - } else if type_ == Type::Enum { - let mut enum_value = to_upper_camel(default); - if self.config.strip_enum_prefix { - // Field types are fully qualified, so we extract - // the last segment and strip it from the left - // side of the default value. - let enum_type = field - .type_name - .as_ref() - .and_then(|ty| ty.split('.').last()) - .unwrap(); - - enum_value = strip_enum_prefix(&to_upper_camel(enum_type), &enum_value) - } - self.buf.push_str(&enum_value); - } else { - self.buf.push_str(&default.escape_default().to_string()); - } - } - - self.buf.push_str("\")]\n"); - self.append_field_attributes(fq_message_name, field.name()); - self.push_indent(); - self.buf.push_str("pub "); - self.buf.push_str(&to_snake(field.name())); - self.buf.push_str(": "); - - let prost_path = prost_path(self.config); - - if repeated { - self.buf - .push_str(&format!("{}::alloc::vec::Vec<", prost_path)); - } else if optional { - self.buf.push_str("::core::option::Option<"); - } - if boxed { - self.buf - .push_str(&format!("{}::alloc::boxed::Box<", prost_path)); - } - self.buf.push_str(&ty); - if boxed { - self.buf.push('>'); - } - if repeated || optional { - self.buf.push('>'); - } - self.buf.push_str(",\n"); - } - - fn append_map_field( - &mut self, - fq_message_name: &str, - field: FieldDescriptorProto, - key: &FieldDescriptorProto, - value: &FieldDescriptorProto, - ) { - let key_ty = self.resolve_type(key, fq_message_name); - let value_ty = self.resolve_type(value, fq_message_name); - - debug!( - " map field: {:?}, key type: {:?}, value type: {:?}", - field.name(), - key_ty, - value_ty - ); - - self.append_doc(fq_message_name, Some(field.name())); - self.push_indent(); - - let map_type = self - .config - .map_type - .get_first_field(fq_message_name, field.name()) - .copied() - .unwrap_or_default(); - let key_tag = self.field_type_tag(key); - let value_tag = self.map_value_type_tag(value); - - self.buf.push_str(&format!( - "#[prost({}=\"{}, {}\", tag=\"{}\")]\n", - map_type.annotation(), - key_tag, - value_tag, - field.number() - )); - self.append_field_attributes(fq_message_name, field.name()); - self.push_indent(); - self.buf.push_str(&format!( - "pub {}: {}<{}, {}>,\n", - to_snake(field.name()), - map_type.rust_type(), - key_ty, - value_ty - )); - } - - fn append_oneof_field( - &mut self, - message_name: &str, - fq_message_name: &str, - oneof: &OneofDescriptorProto, - fields: &[(FieldDescriptorProto, usize)], - ) { - let name = format!( - "{}::{}", - to_snake(message_name), - to_upper_camel(oneof.name()) - ); - self.append_doc(fq_message_name, None); - self.push_indent(); - self.buf.push_str(&format!( - "#[prost(oneof=\"{}\", tags=\"{}\")]\n", - name, - fields.iter().map(|(field, _)| field.number()).join(", ") - )); - self.append_field_attributes(fq_message_name, oneof.name()); - self.push_indent(); - self.buf.push_str(&format!( - "pub {}: ::core::option::Option<{}>,\n", - to_snake(oneof.name()), - name - )); - } - - fn append_oneof( - &mut self, - fq_message_name: &str, - oneof: OneofDescriptorProto, - idx: i32, - fields: Vec<(FieldDescriptorProto, usize)>, - ) { - self.path.push(8); - self.path.push(idx); - self.append_doc(fq_message_name, None); - self.path.pop(); - self.path.pop(); - - let oneof_name = format!("{}.{}", fq_message_name, oneof.name()); - self.append_type_attributes(&oneof_name); - self.append_enum_attributes(&oneof_name); - self.push_indent(); - self.buf - .push_str("#[allow(clippy::derive_partial_eq_without_eq)]\n"); - self.buf.push_str(&format!( - "#[derive(Clone, PartialEq, {}::Oneof)]\n", - prost_path(self.config) - )); - self.append_skip_debug(fq_message_name); - self.push_indent(); - self.buf.push_str("pub enum "); - self.buf.push_str(&to_upper_camel(oneof.name())); - self.buf.push_str(" {\n"); - - self.path.push(2); - self.depth += 1; - for (field, idx) in fields { - let type_ = field.r#type(); - - self.path.push(idx as i32); - self.append_doc(fq_message_name, Some(field.name())); - self.path.pop(); - - self.push_indent(); - let ty_tag = self.field_type_tag(&field); - self.buf.push_str(&format!( - "#[prost({}, tag=\"{}\")]\n", - ty_tag, - field.number() - )); - self.append_field_attributes(&oneof_name, field.name()); - - self.push_indent(); - let ty = self.resolve_type(&field, fq_message_name); - - let boxed = ((type_ == Type::Message || type_ == Type::Group) - && self - .message_graph - .is_nested(field.type_name(), fq_message_name)) - || (self - .config - .boxed - .get_first_field(&oneof_name, field.name()) - .is_some()); - - debug!( - " oneof: {:?}, type: {:?}, boxed: {}", - field.name(), - ty, - boxed - ); - - if boxed { - self.buf.push_str(&format!( - "{}(::prost::alloc::boxed::Box<{}>),\n", - to_upper_camel(field.name()), - ty - )); - } else { - self.buf - .push_str(&format!("{}({}),\n", to_upper_camel(field.name()), ty)); - } - } - self.depth -= 1; - self.path.pop(); - - self.push_indent(); - self.buf.push_str("}\n"); - } - - fn location(&self) -> Option<&Location> { - let source_info = self.source_info.as_ref()?; - let idx = source_info - .location - .binary_search_by_key(&&self.path[..], |location| &location.path[..]) - .unwrap(); - Some(&source_info.location[idx]) - } - - fn append_doc(&mut self, fq_name: &str, field_name: Option<&str>) { - let append_doc = if let Some(field_name) = field_name { - self.config - .disable_comments - .get_first_field(fq_name, field_name) - .is_none() - } else { - self.config.disable_comments.get(fq_name).next().is_none() - }; - if append_doc { - if let Some(comments) = self.location().map(Comments::from_location) { - comments.append_with_indent(self.depth, self.buf); - } - } - } - - fn append_enum(&mut self, desc: EnumDescriptorProto) { - debug!(" enum: {:?}", desc.name()); - - let proto_enum_name = desc.name(); - let enum_name = to_upper_camel(proto_enum_name); - - let enum_values = &desc.value; - let fq_proto_enum_name = self.fq_name(proto_enum_name); - - if self - .extern_paths - .resolve_ident(&fq_proto_enum_name) - .is_some() - { - return; - } - - self.append_doc(&fq_proto_enum_name, None); - self.append_type_attributes(&fq_proto_enum_name); - self.append_enum_attributes(&fq_proto_enum_name); - self.push_indent(); - let dbg = if self.should_skip_debug(&fq_proto_enum_name) { - "" - } else { - "Debug, " - }; - self.buf.push_str(&format!( - "#[derive(Clone, Copy, {}PartialEq, Eq, Hash, PartialOrd, Ord, {}::Enumeration)]\n", - dbg, - prost_path(self.config), - )); - self.push_indent(); - self.buf.push_str("#[repr(i32)]\n"); - self.push_indent(); - self.buf.push_str("pub enum "); - self.buf.push_str(&enum_name); - self.buf.push_str(" {\n"); - - let variant_mappings = - build_enum_value_mappings(&enum_name, self.config.strip_enum_prefix, enum_values); - - self.depth += 1; - self.path.push(2); - for variant in variant_mappings.iter() { - self.path.push(variant.path_idx as i32); - - self.append_doc(&fq_proto_enum_name, Some(variant.proto_name)); - self.append_field_attributes(&fq_proto_enum_name, variant.proto_name); - self.push_indent(); - self.buf.push_str(&variant.generated_variant_name); - self.buf.push_str(" = "); - self.buf.push_str(&variant.proto_number.to_string()); - self.buf.push_str(",\n"); - - self.path.pop(); - } - - self.path.pop(); - self.depth -= 1; - - self.push_indent(); - self.buf.push_str("}\n"); - - self.push_indent(); - self.buf.push_str("impl "); - self.buf.push_str(&enum_name); - self.buf.push_str(" {\n"); - self.depth += 1; - self.path.push(2); - - self.push_indent(); - self.buf.push_str( - "/// String value of the enum field names used in the ProtoBuf definition.\n", - ); - self.push_indent(); - self.buf.push_str("///\n"); - self.push_indent(); - self.buf.push_str( - "/// The values are not transformed in any way and thus are considered stable\n", - ); - self.push_indent(); - self.buf.push_str( - "/// (if the ProtoBuf definition does not change) and safe for programmatic use.\n", - ); - self.push_indent(); - self.buf - .push_str("pub fn as_str_name(&self) -> &'static str {\n"); - self.depth += 1; - - self.push_indent(); - self.buf.push_str("match self {\n"); - self.depth += 1; - - for variant in variant_mappings.iter() { - self.push_indent(); - self.buf.push_str(&enum_name); - self.buf.push_str("::"); - self.buf.push_str(&variant.generated_variant_name); - self.buf.push_str(" => \""); - self.buf.push_str(variant.proto_name); - self.buf.push_str("\",\n"); - } - - self.depth -= 1; - self.push_indent(); - self.buf.push_str("}\n"); // End of match - - self.depth -= 1; - self.push_indent(); - self.buf.push_str("}\n"); // End of as_str_name() - - self.push_indent(); - self.buf - .push_str("/// Creates an enum from field names used in the ProtoBuf definition.\n"); - - self.push_indent(); - self.buf - .push_str("pub fn from_str_name(value: &str) -> ::core::option::Option {\n"); - self.depth += 1; - - self.push_indent(); - self.buf.push_str("match value {\n"); - self.depth += 1; - - for variant in variant_mappings.iter() { - self.push_indent(); - self.buf.push('\"'); - self.buf.push_str(variant.proto_name); - self.buf.push_str("\" => Some(Self::"); - self.buf.push_str(&variant.generated_variant_name); - self.buf.push_str("),\n"); - } - self.push_indent(); - self.buf.push_str("_ => None,\n"); - - self.depth -= 1; - self.push_indent(); - self.buf.push_str("}\n"); // End of match - - self.depth -= 1; - self.push_indent(); - self.buf.push_str("}\n"); // End of from_str_name() - - self.path.pop(); - self.depth -= 1; - self.push_indent(); - self.buf.push_str("}\n"); // End of impl - } - - fn push_service(&mut self, service: ServiceDescriptorProto) { - let name = service.name().to_owned(); - debug!(" service: {:?}", name); - - let comments = self - .location() - .map(Comments::from_location) - .unwrap_or_default(); - - self.path.push(2); - let methods = service - .method - .into_iter() - .enumerate() - .map(|(idx, mut method)| { - debug!(" method: {:?}", method.name()); - - self.path.push(idx as i32); - let comments = self - .location() - .map(Comments::from_location) - .unwrap_or_default(); - self.path.pop(); - - let name = method.name.take().unwrap(); - let input_proto_type = method.input_type.take().unwrap(); - let output_proto_type = method.output_type.take().unwrap(); - let input_type = self.resolve_ident(&input_proto_type); - let output_type = self.resolve_ident(&output_proto_type); - let client_streaming = method.client_streaming(); - let server_streaming = method.server_streaming(); - - Method { - name: to_snake(&name), - proto_name: name, - comments, - input_type, - output_type, - input_proto_type, - output_proto_type, - options: method.options.unwrap_or_default(), - client_streaming, - server_streaming, - } - }) - .collect(); - self.path.pop(); - - let service = Service { - name: to_upper_camel(&name), - proto_name: name, - package: self.package.clone(), - comments, - methods, - options: service.options.unwrap_or_default(), - }; - - if let Some(service_generator) = self.config.service_generator.as_mut() { - service_generator.generate(service, self.buf) - } - } - - fn push_indent(&mut self) { - push_indent(self.buf, self.depth); - } - - fn push_mod(&mut self, module: &str) { - self.push_indent(); - self.buf.push_str("/// Nested message and enum types in `"); - self.buf.push_str(module); - self.buf.push_str("`.\n"); - - self.push_indent(); - self.buf.push_str("pub mod "); - self.buf.push_str(&to_snake(module)); - self.buf.push_str(" {\n"); - - self.type_path.push(module.into()); - - self.depth += 1; - } - - fn pop_mod(&mut self) { - self.depth -= 1; - - self.type_path.pop(); - - self.push_indent(); - self.buf.push_str("}\n"); - } - - fn resolve_type(&self, field: &FieldDescriptorProto, fq_message_name: &str) -> String { - match field.r#type() { - Type::Float => String::from("f32"), - Type::Double => String::from("f64"), - Type::Uint32 | Type::Fixed32 => String::from("u32"), - Type::Uint64 | Type::Fixed64 => String::from("u64"), - Type::Int32 | Type::Sfixed32 | Type::Sint32 | Type::Enum => String::from("i32"), - Type::Int64 | Type::Sfixed64 | Type::Sint64 => String::from("i64"), - Type::Bool => String::from("bool"), - Type::String => format!("{}::alloc::string::String", prost_path(self.config)), - Type::Bytes => self - .config - .bytes_type - .get_first_field(fq_message_name, field.name()) - .copied() - .unwrap_or_default() - .rust_type() - .to_owned(), - Type::Group | Type::Message => self.resolve_ident(field.type_name()), - } - } - - fn resolve_ident(&self, pb_ident: &str) -> String { - // protoc should always give fully qualified identifiers. - assert_eq!(".", &pb_ident[..1]); - - if let Some(proto_ident) = self.extern_paths.resolve_ident(pb_ident) { - return proto_ident; - } - - let mut local_path = self - .package - .split('.') - .chain(self.type_path.iter().map(String::as_str)) - .peekable(); - - // If no package is specified the start of the package name will be '.' - // and split will return an empty string ("") which breaks resolution - // The fix to this is to ignore the first item if it is empty. - if local_path.peek().map_or(false, |s| s.is_empty()) { - local_path.next(); - } - - let mut ident_path = pb_ident[1..].split('.'); - let ident_type = ident_path.next_back().unwrap(); - let mut ident_path = ident_path.peekable(); - - // Skip path elements in common. - while local_path.peek().is_some() && local_path.peek() == ident_path.peek() { - local_path.next(); - ident_path.next(); - } - - local_path - .map(|_| "super".to_string()) - .chain(ident_path.map(to_snake)) - .chain(iter::once(to_upper_camel(ident_type))) - .join("::") - } - - fn field_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str> { - match field.r#type() { - Type::Float => Cow::Borrowed("float"), - Type::Double => Cow::Borrowed("double"), - Type::Int32 => Cow::Borrowed("int32"), - Type::Int64 => Cow::Borrowed("int64"), - Type::Uint32 => Cow::Borrowed("uint32"), - Type::Uint64 => Cow::Borrowed("uint64"), - Type::Sint32 => Cow::Borrowed("sint32"), - Type::Sint64 => Cow::Borrowed("sint64"), - Type::Fixed32 => Cow::Borrowed("fixed32"), - Type::Fixed64 => Cow::Borrowed("fixed64"), - Type::Sfixed32 => Cow::Borrowed("sfixed32"), - Type::Sfixed64 => Cow::Borrowed("sfixed64"), - Type::Bool => Cow::Borrowed("bool"), - Type::String => Cow::Borrowed("string"), - Type::Bytes => Cow::Borrowed("bytes"), - Type::Group => Cow::Borrowed("group"), - Type::Message => Cow::Borrowed("message"), - Type::Enum => Cow::Owned(format!( - "enumeration={:?}", - self.resolve_ident(field.type_name()) - )), - } - } - - fn map_value_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str> { - match field.r#type() { - Type::Enum => Cow::Owned(format!( - "enumeration({})", - self.resolve_ident(field.type_name()) - )), - _ => self.field_type_tag(field), - } - } - - fn optional(&self, field: &FieldDescriptorProto) -> bool { - if field.proto3_optional.unwrap_or(false) { - return true; - } - - if field.label() != Label::Optional { - return false; - } - - match field.r#type() { - Type::Message => true, - _ => self.syntax == Syntax::Proto2, - } - } - - /// Returns `true` if the field options includes the `deprecated` option. - fn deprecated(&self, field: &FieldDescriptorProto) -> bool { - field - .options - .as_ref() - .map_or(false, FieldOptions::deprecated) - } - - /// Returns the fully-qualified name, starting with a dot - fn fq_name(&self, message_name: &str) -> String { - format!( - "{}{}{}{}.{}", - if self.package.is_empty() { "" } else { "." }, - self.package.trim_matches('.'), - if self.type_path.is_empty() { "" } else { "." }, - self.type_path.join("."), - message_name, - ) - } -} - -/// Returns `true` if the repeated field type can be packed. -fn can_pack(field: &FieldDescriptorProto) -> bool { - matches!( - field.r#type(), - Type::Float - | Type::Double - | Type::Int32 - | Type::Int64 - | Type::Uint32 - | Type::Uint64 - | Type::Sint32 - | Type::Sint64 - | Type::Fixed32 - | Type::Fixed64 - | Type::Sfixed32 - | Type::Sfixed64 - | Type::Bool - | Type::Enum - ) -} - -struct EnumVariantMapping<'a> { - path_idx: usize, - proto_name: &'a str, - proto_number: i32, - generated_variant_name: String, -} - -fn build_enum_value_mappings<'a>( - generated_enum_name: &str, - do_strip_enum_prefix: bool, - enum_values: &'a [EnumValueDescriptorProto], -) -> Vec> { - let mut numbers = HashSet::new(); - let mut generated_names = HashMap::new(); - let mut mappings = Vec::new(); - - for (idx, value) in enum_values.iter().enumerate() { - // Skip duplicate enum values. Protobuf allows this when the - // 'allow_alias' option is set. - if !numbers.insert(value.number()) { - continue; - } - - let mut generated_variant_name = to_upper_camel(value.name()); - if do_strip_enum_prefix { - generated_variant_name = - strip_enum_prefix(generated_enum_name, &generated_variant_name); - } - - if let Some(old_v) = generated_names.insert(generated_variant_name.to_owned(), value.name()) - { - panic!("Generated enum variant names overlap: `{}` variant name to be used both by `{}` and `{}` ProtoBuf enum values", - generated_variant_name, old_v, value.name()); - } - - mappings.push(EnumVariantMapping { - path_idx: idx, - proto_name: value.name(), - proto_number: value.number(), - generated_variant_name, - }) - } - mappings -} - -impl MapType { - /// The `prost-derive` annotation type corresponding to the map type. - fn annotation(&self) -> &'static str { - match self { - MapType::HashMap => "map", - MapType::BTreeMap => "btree_map", - } - } - - /// The fully-qualified Rust type corresponding to the map type. - fn rust_type(&self) -> &'static str { - match self { - MapType::HashMap => "::std::collections::HashMap", - MapType::BTreeMap => "::prost::alloc::collections::BTreeMap", - } - } -} - -impl BytesType { - /// The `prost-derive` annotation type corresponding to the bytes type. - fn annotation(&self) -> &'static str { - match self { - BytesType::Vec => "vec", - BytesType::Bytes => "bytes", - } - } - - /// The fully-qualified Rust type corresponding to the bytes type. - fn rust_type(&self) -> &'static str { - match self { - BytesType::Vec => "::prost::alloc::vec::Vec", - BytesType::Bytes => "::prost::bytes::Bytes", - } - } -} diff --git a/prost-build/src/code_generator/enums.rs b/prost-build/src/code_generator/enums.rs new file mode 100644 index 000000000..561bc8660 --- /dev/null +++ b/prost-build/src/code_generator/enums.rs @@ -0,0 +1,197 @@ +use super::*; + +impl CodeGenerator<'_> { + pub(super) fn push_enums(&mut self, enum_types: Vec) { + self.path.push(FileDescriptorProtoLocations::ENUM_TYPE); + for (idx, desc) in enum_types.into_iter().enumerate() { + self.path.push(idx as i32); + if let Some(resolved_enum) = self.resolve_enum(desc) { + self.buf.push_str(&resolved_enum.to_string()); + } + self.path.pop(); + } + self.path.pop(); + } + + pub(super) fn resolve_enum(&mut self, desc: EnumDescriptorProto) -> Option { + debug!(" enum: {:?}", desc.name()); + + let proto_enum_name = desc.name(); + let enum_name = to_upper_camel(proto_enum_name); + let fq_proto_enum_name = + FullyQualifiedName::new(&self.package, &self.type_path, proto_enum_name); + + if self + .extern_paths + .resolve_ident(&fq_proto_enum_name) + .is_some() + { + return None; + } + + let enum_docs = self.resolve_docs(&fq_proto_enum_name, None); + let enum_attributes = self.resolve_enum_attributes(&fq_proto_enum_name); + let prost_path = self.prost_type_path("Enumeration"); + let optional_debug = + (!self.should_skip_debug(&fq_proto_enum_name)).then_some(quote! {#[derive(Debug)]}); + let variant_mappings = EnumVariantMapping::build_enum_value_mappings( + &enum_name, + self.config.strip_enum_prefix, + &desc.value, + ); + let enum_variants = resolve_enum_variants(self, &variant_mappings, &fq_proto_enum_name); + let enum_name_syn = enum_name.parse_syn::(); + let arms_1 = variant_mappings.iter().map(|variant| { + format!( + "{}::{} => \"{}\"", + enum_name_syn, variant.generated_variant_name, variant.proto_name + ) + .parse_syn::() + }); + let arms_2 = variant_mappings.iter().map(|variant| { + format!( + "\"{}\" => Some(Self::{})", + variant.proto_name, variant.generated_variant_name + ) + .parse_syn::() + }); + + return Some(quote! { + #(#enum_docs)* + #enum_attributes + #optional_debug + #[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, #prost_path)] + #[repr(i32)] + pub enum #enum_name_syn { + #(#enum_variants,)* + } + + impl #enum_name_syn { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + #(#arms_1,)* + } + } + + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + #(#arms_2,)* + _ => None, + } + } + } + }); + + fn resolve_enum_variants( + code_generator: &mut CodeGenerator<'_>, + variant_mappings: &[EnumVariantMapping], + fq_proto_enum_name: &FullyQualifiedName, + ) -> Vec { + let mut variants = Vec::with_capacity(variant_mappings.len()); + + code_generator.path.push(EnumDescriptorLocations::VALUE); + + for variant in variant_mappings.iter() { + code_generator.path.push(variant.path_idx as i32); + + let documentation = + code_generator.resolve_docs(fq_proto_enum_name, Some(variant.proto_name)); + + let field_attributes = + code_generator.resolve_field_attributes(fq_proto_enum_name, variant.proto_name); + + let variant = format!( + "{} = {}", + variant.generated_variant_name, variant.proto_number + ) + .parse_syn::(); + + variants.push(quote! { + #(#documentation)* + #field_attributes + #variant + }); + + code_generator.path.pop(); + } + + code_generator.path.pop(); + + variants + } + } + + pub(super) fn resolve_enum_attributes( + &self, + fq_message_name: &FullyQualifiedName, + ) -> TokenStream { + let type_attributes = self.config.type_attributes.get(fq_message_name.as_ref()); + let enum_attributes = self.config.enum_attributes.get(fq_message_name.as_ref()); + quote! { + #(#(#type_attributes)*)* + #(#(#enum_attributes)*)* + } + } +} + +use variant_mapping::EnumVariantMapping; +mod variant_mapping { + use std::collections::HashSet; + + use prost_types::EnumValueDescriptorProto; + + use super::*; + + pub(super) struct EnumVariantMapping<'a> { + pub(super) path_idx: usize, + pub(super) proto_name: &'a str, + pub(super) proto_number: i32, + pub(super) generated_variant_name: String, + } + + impl EnumVariantMapping<'_> { + pub(super) fn build_enum_value_mappings<'a>( + generated_enum_name: &str, + do_strip_enum_prefix: bool, + enum_values: &'a [EnumValueDescriptorProto], + ) -> Vec> { + let mut numbers = HashSet::new(); + let mut generated_names = HashMap::new(); + let mut mappings = Vec::new(); + + for (idx, value) in enum_values.iter().enumerate() { + // Skip duplicate enum values. Protobuf allows this when the + // 'allow_alias' option is set. + if !numbers.insert(value.number()) { + continue; + } + + let mut generated_variant_name = to_upper_camel(value.name()); + if do_strip_enum_prefix { + generated_variant_name = + strip_enum_prefix(generated_enum_name, &generated_variant_name); + } + + if let Some(old_v) = + generated_names.insert(generated_variant_name.to_owned(), value.name()) + { + panic!("Generated enum variant names overlap: `{}` variant name to be used both by `{}` and `{}` ProtoBuf enum values", + generated_variant_name, old_v, value.name()); + } + + mappings.push(EnumVariantMapping { + path_idx: idx, + proto_name: value.name(), + proto_number: value.number(), + generated_variant_name, + }) + } + mappings + } + } +} diff --git a/prost-build/src/code_generator/locations.rs b/prost-build/src/code_generator/locations.rs new file mode 100644 index 000000000..3c3c1b3a4 --- /dev/null +++ b/prost-build/src/code_generator/locations.rs @@ -0,0 +1,28 @@ +pub(super) struct FileDescriptorProtoLocations; + +impl FileDescriptorProtoLocations { + pub const MESSAGE_TYPE: i32 = 4; + pub const ENUM_TYPE: i32 = 5; + pub const SERVICE: i32 = 6; +} + +pub(super) struct DescriptorLocations; + +impl DescriptorLocations { + pub const FIELD: i32 = 2; + pub const NESTED_TYPE: i32 = 3; + pub const ENUM_TYPE: i32 = 4; + pub const ONEOF_DECL: i32 = 8; +} + +pub(super) struct EnumDescriptorLocations; + +impl EnumDescriptorLocations { + pub const VALUE: i32 = 2; +} + +pub(super) struct ServiceDescriptorProtoLocations; + +impl ServiceDescriptorProtoLocations { + pub const METHOD: i32 = 2; +} diff --git a/prost-build/src/code_generator/messages.rs b/prost-build/src/code_generator/messages.rs new file mode 100644 index 000000000..24b6dda9d --- /dev/null +++ b/prost-build/src/code_generator/messages.rs @@ -0,0 +1,556 @@ +use super::*; + +mod oneof; + +type OneofFields = MultiMap; +type MapTypes = HashMap; + +impl CodeGenerator<'_> { + pub(super) fn push_messages(&mut self, message_types: Vec) { + self.path.push(FileDescriptorProtoLocations::MESSAGE_TYPE); + for (idx, message) in message_types.into_iter().enumerate() { + self.path.push(idx as i32); + if let Some(resolved_message) = self.resolve_message(message) { + self.buf.push_str(&resolved_message.to_string()); + } + self.path.pop(); + } + self.path.pop(); + } + + fn resolve_message(&mut self, message: DescriptorProto) -> Option { + debug!(" message: {:?}", message.name()); + + let message_name = message.name().to_string(); + let fq_message_name = + FullyQualifiedName::new(&self.package, &self.type_path, &message_name); + + // Skip external types. + if self.extern_paths.resolve_ident(&fq_message_name).is_some() { + return None; + } + + // Split the nested message types into a vector of normal nested message types, and a map + // of the map field entry types. The path index of the nested message types is preserved so + // that comments can be retrieved. + type NestedTypes = Vec<(DescriptorProto, usize)>; + let (nested_types, map_types): (NestedTypes, MapTypes) = message + .nested_type + .into_iter() + .enumerate() + .partition_map(|(idx, nested_type)| { + if nested_type + .options + .as_ref() + .and_then(|options| options.map_entry) + .unwrap_or(false) + { + let key = nested_type.field[0].clone(); + let value = nested_type.field[1].clone(); + assert_eq!("key", key.name()); + assert_eq!("value", value.name()); + Either::Right(( + fq_message_name + .join(nested_type.name()) + .as_ref() + .to_string(), + (key, value), + )) + } else { + Either::Left((nested_type, idx)) + } + }); + + // Split the fields into a vector of the normal fields, and oneof fields. + // Path indexes are preserved so that comments can be retrieved. + type Fields = Vec<(FieldDescriptorProto, usize)>; + let (fields, oneof_fields): (Fields, OneofFields) = message + .field + .into_iter() + .enumerate() + .partition_map(|(idx, field)| { + if field.proto3_optional.unwrap_or(false) { + Either::Left((field, idx)) + } else if let Some(oneof_index) = field.oneof_index { + Either::Right((oneof_index, (field, idx))) + } else { + Either::Left((field, idx)) + } + }); + + let documentation = self.resolve_docs(&fq_message_name, None); + let resolved_fields = self.resolve_message_fields(&fields, &map_types, &fq_message_name); + let resolved_oneof_fields = self.resolve_oneof_fields( + &message.oneof_decl, + &oneof_fields, + &message_name, + &fq_message_name, + ); + + let ident = to_upper_camel(&message_name).parse_syn::(); + + let nested = self.recursive_nested( + &message_name, + message.enum_type, + nested_types, + oneof_fields, + &message.oneof_decl, + &fq_message_name, + ); + + let maybe_type_name = self + .config + .enable_type_names + .then_some(self.resolve_type_name(&message_name, &fq_message_name)); + + let type_attributes = self.config.type_attributes.get(fq_message_name.as_ref()); + let message_attributes = self.config.message_attributes.get(fq_message_name.as_ref()); + + let prost_path = self.prost_type_path("Message"); + let maybe_skip_debug = self.resolve_skip_debug(&fq_message_name); + + Some(quote! { + #(#documentation)* + #(#(#type_attributes)*)* + #(#(#message_attributes)*)* + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, #prost_path)] + #maybe_skip_debug + pub struct #ident { + #(#resolved_fields,)* + #(#resolved_oneof_fields,)* + } + + #nested + + #maybe_type_name + }) + } + + fn recursive_nested( + &mut self, + message_name: &str, + enum_type: Vec, + nested_types: Vec<(DescriptorProto, usize)>, + oneof_fields: OneofFields, + oneof_declarations: &[OneofDescriptorProto], + fq_message_name: &FullyQualifiedName, + ) -> Option { + if !enum_type.is_empty() || !nested_types.is_empty() || !oneof_fields.is_empty() { + let comment = format!("/// Nested message and enum types in `{}`.", message_name) + .parse_outer_attributes(); + + let ident = to_snake(message_name).parse_syn::(); + self.type_path.push(message_name.to_string()); + + let resolved_messages = resolve_nested_messages(self, nested_types); + let resolved_enums = resolve_nested_enums(self, enum_type); + let resolved_oneofs = + self.resolve_oneofs(oneof_declarations, oneof_fields, fq_message_name); + + self.type_path.pop(); + + return Some(quote! { + #(#comment)* + pub mod #ident { + #(#resolved_messages)* + #(#resolved_enums)* + #(#resolved_oneofs)* + } + }); + } else { + return None; + } + + fn resolve_nested_messages( + code_generator: &mut CodeGenerator<'_>, + nested_types: Vec<(DescriptorProto, usize)>, + ) -> Vec { + let mut messages = Vec::with_capacity(nested_types.len()); + + code_generator.path.push(DescriptorLocations::NESTED_TYPE); + for (nested_type, idx) in nested_types { + code_generator.path.push(idx as i32); + if let Some(message) = code_generator.resolve_message(nested_type) { + messages.push(message); + } + code_generator.path.pop(); + } + code_generator.path.pop(); + + messages + } + + fn resolve_nested_enums( + codegen: &mut CodeGenerator<'_>, + enum_type: Vec, + ) -> Vec { + let mut enums = Vec::with_capacity(enum_type.len()); + + codegen.path.push(DescriptorLocations::ENUM_TYPE); + for (idx, nested_enum) in enum_type.into_iter().enumerate() { + codegen.path.push(idx as i32); + if let Some(resolved_enum) = codegen.resolve_enum(nested_enum) { + enums.push(resolved_enum); + } + codegen.path.pop(); + } + codegen.path.pop(); + + enums + } + } + + fn resolve_message_fields( + &mut self, + fields: &[(FieldDescriptorProto, usize)], + map_types: &MapTypes, + fq_message_name: &FullyQualifiedName, + ) -> Vec { + let mut resolved_fields = Vec::with_capacity(fields.len()); + + self.path.push(DescriptorLocations::FIELD); + for (field, idx) in fields { + self.path.push(*idx as i32); + + let field = match field + .type_name + .as_ref() + .and_then(|type_name| map_types.get(type_name)) + { + Some((key, value)) => self.resolve_map_field(fq_message_name, field, key, value), + None => self.resolve_field(fq_message_name, field), + }; + + resolved_fields.push(field); + self.path.pop(); + } + self.path.pop(); + + resolved_fields + } + + fn resolve_type_name( + &mut self, + message_name: &str, + fq_message_name: &FullyQualifiedName, + ) -> TokenStream { + let name_path = self.prost_type_path("Name"); + let message_name_syn = message_name.parse_syn::(); + let package_name = &self.package; + let string_path = self.prost_type_path("alloc::string::String"); + let fully_qualified_name = + FullyQualifiedName::new(&self.package, &self.type_path, message_name); + let domain_name = self + .config + .type_name_domains + .get_first(fq_message_name.as_ref()) + .map_or("", |name| name.as_str()); + + let fq_name_str = fully_qualified_name.as_ref().trim_start_matches('.'); + let type_url = format!("{}/{}", domain_name, fq_name_str); + + quote! { + impl #name_path for #message_name_syn { + const NAME: &'static str = #message_name; + const PACKAGE: &'static str = #package_name; + + fn full_name() -> #string_path { #fq_name_str.into() } + fn type_url() -> #string_path { #type_url.into() } + } + } + } + + fn resolve_field( + &self, + fq_message_name: &FullyQualifiedName, + field: &FieldDescriptorProto, + ) -> TokenStream { + let type_ = field.r#type(); + let repeated = field.label == Some(Label::Repeated as i32); + let optional = self.optional(field); + let ty = self.resolve_type(field, fq_message_name); + let boxed = !repeated && self.should_box_field(field, fq_message_name, fq_message_name); + + debug!( + " field: {:?}, type: {:?}, boxed: {}", + field.name(), + ty, + boxed + ); + + let documentation = self.resolve_docs(fq_message_name, Some(field.name())); + let maybe_deprecated = field + .options + .as_ref() + .is_some_and(FieldOptions::deprecated) + .then_some(quote! { #[deprecated] }); + let field_type_attr = match type_ { + Type::Bytes => { + let bytes_type = self + .config + .bytes_type + .get_first_field(fq_message_name, field.name()) + .copied() + .unwrap_or_default(); + + Cow::from(format!( + "{}=\"{}\"", + self.field_type_tag(field), + bytes_type.annotation() + )) + } + _ => self.field_type_tag(field), + } + .parse_syn::(); + + let maybe_label = { + match field.label() { + Label::Optional => optional.then_some(quote! { optional, }), + Label::Required => Some(quote! { required, }), + Label::Repeated => Some( + match can_pack(field) + && !field + .options + .as_ref() + .map_or(self.syntax == Syntax::Proto3, |options| options.packed()) + { + true => quote! { repeated, packed="false", }, + false => quote! { repeated, }, + }, + ), + } + }; + let maybe_boxed = boxed.then_some(quote! { boxed, }); + let field_number_string = field.number().to_string(); + let maybe_default = field.default_value.as_ref().map(|default| { + let default_value = match type_ { + Type::Bytes => { + let mut bytes_string = String::new(); + bytes_string.push_str("b\\\""); + for b in unescape_c_escape_string(default) { + bytes_string.extend( + ascii::escape_default(b).flat_map(|c| (c as char).escape_default()), + ); + } + bytes_string.push_str("\\\""); + bytes_string + } + Type::Enum => { + let mut enum_value = to_upper_camel(default); + if self.config.strip_enum_prefix { + let enum_type = field + .type_name + .as_ref() + .and_then(|ty| ty.split('.').last()) + .expect("field type not fully qualified"); + + enum_value = strip_enum_prefix(&to_upper_camel(enum_type), &enum_value) + } + + enum_value + } + _ => default.escape_default().to_string(), + }; + format!("default=\"{}\"", default_value).parse_syn::() + }); + + let field_attributes = self.resolve_field_attributes(fq_message_name, field.name()); + let field_identifier = to_snake(field.name()).parse_syn::(); + + let maybe_wrapped = if repeated { + Some(self.prost_type_path("alloc::vec::Vec")) + } else if optional { + Some("::core::option::Option".parse_syn::()) + } else { + None + }; + let maybe_boxed_type = boxed.then_some(self.prost_type_path("alloc::boxed::Box")); + + let inner_field_type = ty.parse_syn::(); + + let field_type = match (maybe_wrapped, &maybe_boxed_type) { + (Some(wrapper), Some(boxed)) => quote! { #wrapper<#boxed<#inner_field_type>> }, + (Some(wrapper), None) => quote! { #wrapper<#inner_field_type> }, + (None, Some(boxed)) => quote! { #boxed<#inner_field_type> }, + (None, None) => quote! { #inner_field_type }, + }; + + quote! { + #(#documentation)* + #maybe_deprecated + #[prost(#field_type_attr, #maybe_label #maybe_boxed tag=#field_number_string, #maybe_default)] + #field_attributes + pub #field_identifier: #field_type + } + } + + fn resolve_map_field( + &mut self, + fq_message_name: &FullyQualifiedName, + field: &FieldDescriptorProto, + key: &FieldDescriptorProto, + value: &FieldDescriptorProto, + ) -> TokenStream { + let key_ty = self.resolve_type(key, fq_message_name); + let value_ty = self.resolve_type(value, fq_message_name); + + debug!( + " map field: {:?}, key type: {:?}, value type: {:?}", + field.name(), + key_ty, + value_ty + ); + + let documentation = self.resolve_docs(fq_message_name, Some(field.name())); + let map_type = self + .config + .map_type + .get_first_field(fq_message_name, field.name()) + .copied() + .unwrap_or_default(); + let key_tag = self.field_type_tag(key); + let value_tag = self.map_value_type_tag(value); + let meta_name_value = format!("{}=\"{}, {}\"", map_type.annotation(), key_tag, value_tag) + .parse_syn::(); + let field_number_string = field.number().to_string(); + let field_attributes = self.resolve_field_attributes(fq_message_name, field.name()); + let field_name_syn = to_snake(field.name()).parse_syn::(); + let map_rust_type = map_type.rust_type().parse_syn::(); + let key_rust_type = key_ty.parse_syn::(); + let value_rust_type = value_ty.parse_syn::(); + + quote! { + #(#documentation)* + #[prost(#meta_name_value, tag=#field_number_string)] + #field_attributes + pub #field_name_syn: #map_rust_type<#key_rust_type, #value_rust_type> + } + } +} + +// Helpers +impl CodeGenerator<'_> { + fn resolve_skip_debug(&self, fq_message_name: &FullyQualifiedName) -> Option { + self.should_skip_debug(fq_message_name) + .then_some(quote! { #[prost(skip_debug)] }) + } + + fn should_box_field( + &self, + field: &FieldDescriptorProto, + fq_message_name: &FullyQualifiedName, + first_field: &FullyQualifiedName, + ) -> bool { + ((matches!(field.r#type(), Type::Message | Type::Group)) + && self + .message_graph + .is_nested(field.type_name(), fq_message_name.as_ref())) + || (self + .config + .boxed + .get_first_field(first_field, field.name()) + .is_some()) + } + + // TODO: to syn::Type + fn resolve_type( + &self, + field: &FieldDescriptorProto, + fq_message_name: &FullyQualifiedName, + ) -> String { + match field.r#type() { + Type::Float => String::from("f32"), + Type::Double => String::from("f64"), + Type::Uint32 | Type::Fixed32 => String::from("u32"), + Type::Uint64 | Type::Fixed64 => String::from("u64"), + Type::Int32 | Type::Sfixed32 | Type::Sint32 | Type::Enum => String::from("i32"), + Type::Int64 | Type::Sfixed64 | Type::Sint64 => String::from("i64"), + Type::Bool => String::from("bool"), + Type::String => format!("{}::alloc::string::String", self.resolve_prost_path()), + Type::Bytes => self + .config + .bytes_type + .get_first_field(fq_message_name, field.name()) + .copied() + .unwrap_or_default() + .rust_type() + .to_owned(), + Type::Group | Type::Message => { + self.resolve_ident(&FullyQualifiedName::from_type_name(field.type_name())) + } + } + } + + fn map_value_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str> { + match field.r#type() { + Type::Enum => Cow::Owned(format!( + "enumeration({})", + self.resolve_ident(&FullyQualifiedName::from_type_name(field.type_name())) + )), + _ => self.field_type_tag(field), + } + } + + fn field_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str> { + match field.r#type() { + Type::Float => Cow::Borrowed("float"), + Type::Double => Cow::Borrowed("double"), + Type::Int32 => Cow::Borrowed("int32"), + Type::Int64 => Cow::Borrowed("int64"), + Type::Uint32 => Cow::Borrowed("uint32"), + Type::Uint64 => Cow::Borrowed("uint64"), + Type::Sint32 => Cow::Borrowed("sint32"), + Type::Sint64 => Cow::Borrowed("sint64"), + Type::Fixed32 => Cow::Borrowed("fixed32"), + Type::Fixed64 => Cow::Borrowed("fixed64"), + Type::Sfixed32 => Cow::Borrowed("sfixed32"), + Type::Sfixed64 => Cow::Borrowed("sfixed64"), + Type::Bool => Cow::Borrowed("bool"), + Type::String => Cow::Borrowed("string"), + Type::Bytes => Cow::Borrowed("bytes"), + Type::Group => Cow::Borrowed("group"), + Type::Message => Cow::Borrowed("message"), + Type::Enum => Cow::Owned(format!( + "enumeration=\"{}\"", + self.resolve_ident(&FullyQualifiedName::from_type_name(field.type_name())) + )), + } + } + + fn optional(&self, field: &FieldDescriptorProto) -> bool { + if field.proto3_optional.unwrap_or(false) { + return true; + } + + if field.label() != Label::Optional { + return false; + } + + match field.r#type() { + Type::Message => true, + _ => self.syntax == Syntax::Proto2, + } + } +} + +/// Returns `true` if the repeated field type can be packed. +fn can_pack(field: &FieldDescriptorProto) -> bool { + matches!( + field.r#type(), + Type::Float + | Type::Double + | Type::Int32 + | Type::Int64 + | Type::Uint32 + | Type::Uint64 + | Type::Sint32 + | Type::Sint64 + | Type::Fixed32 + | Type::Fixed64 + | Type::Sfixed32 + | Type::Sfixed64 + | Type::Bool + | Type::Enum + ) +} diff --git a/prost-build/src/code_generator/messages/oneof.rs b/prost-build/src/code_generator/messages/oneof.rs new file mode 100644 index 000000000..7c9a2f6c7 --- /dev/null +++ b/prost-build/src/code_generator/messages/oneof.rs @@ -0,0 +1,169 @@ +use super::*; + +impl CodeGenerator<'_> { + pub(super) fn resolve_oneofs( + &mut self, + oneof_declarations: &[OneofDescriptorProto], + mut oneof_fields: OneofFields, + fq_message_name: &FullyQualifiedName, + ) -> Vec { + let mut oneofs = Vec::with_capacity(oneof_declarations.len()); + + for (idx, oneof) in oneof_declarations.iter().enumerate() { + let idx = idx as i32; + // optional fields create a synthetic oneof that we want to skip + let fields = match oneof_fields.remove(&idx) { + Some(fields) => fields, + None => continue, + }; + oneofs.push(self.append_oneof(fq_message_name, oneof, idx, fields)); + } + + oneofs + } + + pub(super) fn resolve_oneof_fields( + &mut self, + oneof_declarations: &[OneofDescriptorProto], + oneof_fields: &OneofFields, + message_name: &str, + fq_message_name: &FullyQualifiedName, + ) -> Vec { + let mut resolved_onefields = Vec::with_capacity(oneof_declarations.len()); + + self.path.push(DescriptorLocations::ONEOF_DECL); + for (idx, oneof) in oneof_declarations.iter().enumerate() { + let idx = idx as i32; + + let fields = match oneof_fields.get_vec(&idx) { + Some(fields) => fields, + None => continue, + }; + + self.path.push(idx); + + resolved_onefields.push(self.resolve_oneof_field( + message_name, + fq_message_name, + oneof, + fields, + )); + + self.path.pop(); + } + self.path.pop(); + + resolved_onefields + } + + fn resolve_oneof_field( + &mut self, + message_name: &str, + fq_message_name: &FullyQualifiedName, + oneof: &OneofDescriptorProto, + fields: &[(FieldDescriptorProto, usize)], + ) -> TokenStream { + let documentation = self.resolve_docs(fq_message_name, None); + let oneof_name = format!( + "{}::{}", + to_snake(message_name), + to_upper_camel(oneof.name()) + ); + let tags = fields.iter().map(|(field, _)| field.number()).join(", "); + let field_attributes = self.resolve_field_attributes(fq_message_name, oneof.name()); + let field_name = to_snake(oneof.name()).parse_syn::(); + let oneof_type_name = oneof_name.parse_syn::(); + + quote! { + #(#documentation)* + #[prost(oneof=#oneof_name, tags=#tags)] + #field_attributes + pub #field_name: ::core::option::Option<#oneof_type_name> + } + } + + fn append_oneof( + &mut self, + fq_message_name: &FullyQualifiedName, + oneof: &OneofDescriptorProto, + idx: i32, + fields: Vec<(FieldDescriptorProto, usize)>, + ) -> TokenStream { + self.path.push(DescriptorLocations::ONEOF_DECL); + self.path.push(idx); + let documentation = self.resolve_docs(fq_message_name, None); + self.path.pop(); + self.path.pop(); + + let oneof_name = fq_message_name.join(oneof.name()); + let enum_attributes = self.resolve_enum_attributes(&oneof_name); + let maybe_skip_debug = self.resolve_skip_debug(fq_message_name); + let enum_name = to_upper_camel(oneof.name()).parse_syn::(); + let variants = self.oneof_variants(&fields, fq_message_name, &oneof_name); + + let one_of_path = self.prost_type_path("Oneof"); + quote! { + #(#documentation)* + #enum_attributes + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, #one_of_path)] + #maybe_skip_debug + pub enum #enum_name { + #(#variants,)* + } + } + } + + fn oneof_variants( + &mut self, + fields: &[(FieldDescriptorProto, usize)], + fq_message_name: &FullyQualifiedName, + oneof_name: &FullyQualifiedName, + ) -> Vec { + let mut variants = Vec::with_capacity(fields.len()); + + self.path.push(DescriptorLocations::FIELD); + for (field, idx) in fields { + self.path.push((*idx).try_into().expect("idx overflow")); + let documentation = self.resolve_docs(fq_message_name, Some(field.name())); + self.path.pop(); + + let ty_tag = self.field_type_tag(field).parse_syn::(); + let field_number_string = field.number().to_string(); + let field_attributes = self.resolve_field_attributes(oneof_name, field.name()); + let enum_variant = { + let rust_type = self.resolve_type(field, fq_message_name); + let type_path = rust_type.parse_syn::(); + let field_name = to_upper_camel(field.name()).parse_syn::(); + + let boxed = self.should_box_field(field, fq_message_name, oneof_name); + + debug!( + " oneof: {}, type: {}, boxed: {}", + field.name(), + rust_type, + boxed + ); + + match boxed { + true => quote! { + #field_name(::prost::alloc::boxed::Box<#type_path>) + }, + false => quote! { + #field_name(#type_path) + }, + } + }; + + variants.push(quote! { + #(#documentation)* + #[prost(#ty_tag, tag=#field_number_string)] + #field_attributes + #enum_variant + }); + } + self.path.pop(); + + variants + } +} diff --git a/prost-build/src/code_generator/mod.rs b/prost-build/src/code_generator/mod.rs new file mode 100644 index 000000000..5903df2e6 --- /dev/null +++ b/prost-build/src/code_generator/mod.rs @@ -0,0 +1,208 @@ +use std::ascii; +use std::borrow::Cow; +use std::collections::HashMap; +use std::iter; + +use itertools::{Either, Itertools}; +use log::debug; +use multimap::MultiMap; +use proc_macro2::TokenStream; +use prost_types::field_descriptor_proto::{Label, Type}; +use prost_types::{ + DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, FieldOptions, FileDescriptorProto, + OneofDescriptorProto, ServiceDescriptorProto, SourceCodeInfo, +}; +use quote::quote; +use syn::{Attribute, TypePath}; + +use crate::ast::{Comments, Method, Service}; +use crate::extern_paths::ExternPaths; +use crate::ident::{strip_enum_prefix, to_snake, to_upper_camel}; +use crate::message_graph::MessageGraph; +use crate::SynHelpers; +use crate::{Config, FullyQualifiedName}; + +mod c_escaping; +use c_escaping::unescape_c_escape_string; + +mod enums; +mod messages; +mod services; + +mod syntax; +use syntax::Syntax; + +// IMPROVEMENT: would be nice to have this auto-generated +mod locations; +use locations::*; + +pub struct CodeGenerator<'a> { + config: &'a mut Config, + package: String, + type_path: Vec, + source_info: Option, + syntax: Syntax, + message_graph: &'a MessageGraph, + extern_paths: &'a ExternPaths, + path: Vec, + buf: &'a mut String, +} + +impl<'a> CodeGenerator<'a> { + fn new( + config: &'a mut Config, + message_graph: &'a MessageGraph, + extern_paths: &'a ExternPaths, + source_code_info: Option, + package: Option, + syntax: Option, + buf: &'a mut String, + ) -> Self { + let source_info = source_code_info.map(|mut s| { + s.location.retain(|loc| { + let len = loc.path.len(); + len > 0 && len % 2 == 0 + }); + s.location.sort_by(|a, b| a.path.cmp(&b.path)); + s + }); + + Self { + config, + package: package.unwrap_or_default(), + type_path: Vec::new(), + source_info, + syntax: syntax.as_ref().map(String::as_str).into(), + message_graph, + extern_paths, + path: Vec::new(), + buf, + } + } + + pub fn generate( + config: &mut Config, + message_graph: &MessageGraph, + extern_paths: &ExternPaths, + file: FileDescriptorProto, + buf: &mut String, + ) { + let mut code_gen = CodeGenerator::new( + config, + message_graph, + extern_paths, + file.source_code_info, + file.package, + file.syntax, + buf, + ); + + debug!( + "file: {:?}, package: {:?}", + file.name.as_ref().unwrap(), + code_gen.package + ); + + code_gen.push_messages(file.message_type); + code_gen.push_enums(file.enum_type); + code_gen.push_services(file.service); + } + + fn should_skip_debug(&self, fq_message_name: &FullyQualifiedName) -> bool { + self.config + .skip_debug + .get(fq_message_name.as_ref()) + .next() + .is_some() + } + + fn resolve_field_attributes( + &self, + fully_qualified_name: &FullyQualifiedName, + field_name: &str, + ) -> TokenStream { + let fq_str = fully_qualified_name.as_ref(); + let field_attributes = self.config.field_attributes.get_field(fq_str, field_name); + + quote! { + #(#(#field_attributes)*)* + } + } + + fn comments_from_location(&self) -> Option { + let source_info = self.source_info.as_ref()?; + let idx = source_info + .location + .binary_search_by_key(&&self.path[..], |location| &location.path[..]) + .unwrap(); + Some(Comments::from_location(&source_info.location[idx])) + } + + fn resolve_docs( + &self, + fq_name: &FullyQualifiedName, + field_name: Option<&str>, + ) -> Vec { + let mut comment_string = String::new(); + let disable_comments = &self.config.disable_comments; + let append_doc = match field_name { + Some(field_name) => disable_comments.get_first_field(fq_name, field_name), + None => disable_comments.get(fq_name.as_ref()).next(), + } + .is_none(); + + if append_doc { + if let Some(comments) = self.comments_from_location() { + comments.append_with_indent(&mut comment_string); + } + } + + match comment_string.is_empty() { + true => Vec::new(), + false => comment_string.parse_outer_attributes(), + } + } + + fn resolve_ident(&self, pb_ident: &FullyQualifiedName) -> String { + if let Some(proto_ident) = self.extern_paths.resolve_ident(pb_ident) { + return proto_ident; + } + + let mut local_path = self + .package + .split('.') + .chain(self.type_path.iter().map(String::as_str)) + .peekable(); + + // If no package is specified the start of the package name will be '.' + // and split will return an empty string ("") which breaks resolution + // The fix to this is to ignore the first item if it is empty. + if local_path.peek().map_or(false, |s| s.is_empty()) { + local_path.next(); + } + + let mut ident_path = pb_ident.path_iterator(); + let ident_type = ident_path.next_back().unwrap(); + let mut ident_path = ident_path.peekable(); + + // Skip path elements in common. + while local_path.peek().is_some() && local_path.peek() == ident_path.peek() { + local_path.next(); + ident_path.next(); + } + + local_path + .map(|_| "super".to_string()) + .chain(ident_path.map(to_snake)) + .chain(iter::once(to_upper_camel(ident_type))) + .join("::") + } + + fn resolve_prost_path(&self) -> &str { + self.config.prost_path.as_deref().unwrap_or("::prost") + } + + fn prost_type_path(&self, item: &str) -> TypePath { + format!("{}::{}", self.resolve_prost_path(), item).parse_syn() + } +} diff --git a/prost-build/src/code_generator/services.rs b/prost-build/src/code_generator/services.rs new file mode 100644 index 000000000..4fbb4e588 --- /dev/null +++ b/prost-build/src/code_generator/services.rs @@ -0,0 +1,78 @@ +use super::*; + +impl CodeGenerator<'_> { + pub(super) fn push_services(&mut self, services: Vec) { + if self.config.service_generator.is_some() { + self.path.push(FileDescriptorProtoLocations::SERVICE); + for (idx, service) in services.into_iter().enumerate() { + self.path.push(idx as i32); + self.push_service(service); + self.path.pop(); + } + + if let Some(service_generator) = self.config.service_generator.as_mut() { + service_generator.finalize(self.buf); + } + + self.path.pop(); + } + } + + fn push_service(&mut self, service: ServiceDescriptorProto) { + let name = service.name().to_owned(); + debug!(" service: {:?}", name); + + let comments = self.comments_from_location().unwrap_or_default(); + + self.path.push(ServiceDescriptorProtoLocations::METHOD); + let methods = service + .method + .into_iter() + .enumerate() + .map(|(idx, mut method)| { + debug!(" method: {:?}", method.name()); + + self.path.push(idx as i32); + let comments = self.comments_from_location().unwrap_or_default(); + self.path.pop(); + + let name = method.name.take().unwrap(); + let input_proto_type = method.input_type.take().unwrap(); + let output_proto_type = method.output_type.take().unwrap(); + let input_type = + self.resolve_ident(&FullyQualifiedName::from_type_name(&input_proto_type)); + let output_type = + self.resolve_ident(&FullyQualifiedName::from_type_name(&output_proto_type)); + let client_streaming = method.client_streaming(); + let server_streaming = method.server_streaming(); + + Method { + name: to_snake(&name), + proto_name: name, + comments, + input_type, + output_type, + input_proto_type, + output_proto_type, + options: method.options.unwrap_or_default(), + client_streaming, + server_streaming, + } + }) + .collect(); + self.path.pop(); + + let service = Service { + name: to_upper_camel(&name), + proto_name: name, + package: self.package.clone(), + comments, + methods, + options: service.options.unwrap_or_default(), + }; + + if let Some(service_generator) = self.config.service_generator.as_mut() { + service_generator.generate(service, self.buf) + } + } +} diff --git a/prost-build/src/code_generator/syntax.rs b/prost-build/src/code_generator/syntax.rs new file mode 100644 index 000000000..ca44be423 --- /dev/null +++ b/prost-build/src/code_generator/syntax.rs @@ -0,0 +1,15 @@ +#[derive(PartialEq)] +pub(super) enum Syntax { + Proto2, + Proto3, +} + +impl From> for Syntax { + fn from(optional_str: Option<&str>) -> Self { + match optional_str { + None | Some("proto2") => Syntax::Proto2, + Some("proto3") => Syntax::Proto3, + Some(s) => panic!("unknown syntax: {}", s), + } + } +} diff --git a/prost-build/src/collections.rs b/prost-build/src/collections.rs new file mode 100644 index 000000000..63be4d627 --- /dev/null +++ b/prost-build/src/collections.rs @@ -0,0 +1,57 @@ +/// The map collection type to output for Protobuf `map` fields. +#[non_exhaustive] +#[derive(Default, Clone, Copy, Debug, PartialEq)] +pub(crate) enum MapType { + /// The [`std::collections::HashMap`] type. + #[default] + HashMap, + /// The [`std::collections::BTreeMap`] type. + BTreeMap, +} + +/// The bytes collection type to output for Protobuf `bytes` fields. +#[non_exhaustive] +#[derive(Default, Clone, Copy, Debug, PartialEq)] +pub(crate) enum BytesType { + /// The [`alloc::collections::Vec::`] type. + #[default] + Vec, + /// The [`bytes::Bytes`] type. + Bytes, +} + +impl MapType { + /// The `prost-derive` annotation type corresponding to the map type. + pub fn annotation(&self) -> &'static str { + match self { + MapType::HashMap => "map", + MapType::BTreeMap => "btree_map", + } + } + + /// The fully-qualified Rust type corresponding to the map type. + pub fn rust_type(&self) -> &'static str { + match self { + MapType::HashMap => "::std::collections::HashMap", + MapType::BTreeMap => "::prost::alloc::collections::BTreeMap", + } + } +} + +impl BytesType { + /// The `prost-derive` annotation type corresponding to the bytes type. + pub fn annotation(&self) -> &'static str { + match self { + BytesType::Vec => "vec", + BytesType::Bytes => "bytes", + } + } + + /// The fully-qualified Rust type corresponding to the bytes type. + pub fn rust_type(&self) -> &'static str { + match self { + BytesType::Vec => "::prost::alloc::vec::Vec", + BytesType::Bytes => "::prost::bytes::Bytes", + } + } +} diff --git a/prost-build/src/config.rs b/prost-build/src/config.rs index 670c0befe..e6ae44cf9 100644 --- a/prost-build/src/config.rs +++ b/prost-build/src/config.rs @@ -13,6 +13,7 @@ use log::trace; use prost::Message; use prost_types::{FileDescriptorProto, FileDescriptorSet}; +use syn::Attribute; use crate::code_generator::CodeGenerator; use crate::extern_paths::ExternPaths; @@ -22,6 +23,7 @@ use crate::BytesType; use crate::MapType; use crate::Module; use crate::ServiceGenerator; +use crate::SynHelpers; /// Configuration options for Protobuf code generation. /// @@ -31,10 +33,10 @@ pub struct Config { pub(crate) service_generator: Option>, pub(crate) map_type: PathMap, pub(crate) bytes_type: PathMap, - pub(crate) type_attributes: PathMap, - pub(crate) message_attributes: PathMap, - pub(crate) enum_attributes: PathMap, - pub(crate) field_attributes: PathMap, + pub(crate) type_attributes: PathMap>, + pub(crate) message_attributes: PathMap>, + pub(crate) enum_attributes: PathMap>, + pub(crate) field_attributes: PathMap>, pub(crate) boxed: PathMap<()>, pub(crate) prost_types: bool, pub(crate) strip_enum_prefix: bool, @@ -208,8 +210,11 @@ impl Config { P: AsRef, A: AsRef, { - self.field_attributes - .insert(path.as_ref().to_string(), attribute.as_ref().to_string()); + self.field_attributes.insert( + path.as_ref().to_string(), + // TEMP(gibbz00): return error instead? + attribute.parse_outer_attributes(), + ); self } @@ -257,8 +262,11 @@ impl Config { P: AsRef, A: AsRef, { - self.type_attributes - .insert(path.as_ref().to_string(), attribute.as_ref().to_string()); + self.type_attributes.insert( + path.as_ref().to_string(), + // TEMP(gibbz00): return error instead? + attribute.parse_outer_attributes(), + ); self } @@ -296,8 +304,11 @@ impl Config { P: AsRef, A: AsRef, { - self.message_attributes - .insert(path.as_ref().to_string(), attribute.as_ref().to_string()); + self.message_attributes.insert( + path.as_ref().to_string(), + // TEMP(gibbz00): return error instead? + attribute.parse_outer_attributes(), + ); self } @@ -345,8 +356,11 @@ impl Config { P: AsRef, A: AsRef, { - self.enum_attributes - .insert(path.as_ref().to_string(), attribute.as_ref().to_string()); + self.enum_attributes.insert( + path.as_ref().to_string(), + // TEMP(gibbz00): return error instead? + attribute.parse_outer_attributes(), + ); self } diff --git a/prost-build/src/extern_paths.rs b/prost-build/src/extern_paths.rs index 27c8d6d71..1b0168114 100644 --- a/prost-build/src/extern_paths.rs +++ b/prost-build/src/extern_paths.rs @@ -2,7 +2,10 @@ use std::collections::{hash_map, HashMap}; use itertools::Itertools; -use crate::ident::{to_snake, to_upper_camel}; +use crate::{ + ident::{to_snake, to_upper_camel}, + FullyQualifiedName, +}; fn validate_proto_path(path: &str) -> Result<(), String> { if path.chars().next().map(|c| c != '.').unwrap_or(true) { @@ -78,9 +81,8 @@ impl ExternPaths { Ok(()) } - pub fn resolve_ident(&self, pb_ident: &str) -> Option { - // protoc should always give fully qualified identifiers. - assert_eq!(".", &pb_ident[..1]); + pub fn resolve_ident(&self, pb_ident: &FullyQualifiedName) -> Option { + let pb_ident = pb_ident.as_ref(); if let Some(rust_path) = self.extern_paths.get(pb_ident) { return Some(rust_path.clone()); @@ -136,7 +138,10 @@ mod tests { .unwrap(); let case = |proto_ident: &str, resolved_ident: &str| { - assert_eq!(paths.resolve_ident(proto_ident).unwrap(), resolved_ident); + assert_eq!( + paths.resolve_ident(&proto_ident.into()).unwrap(), + resolved_ident + ); }; case(".foo", "::foo1"); @@ -150,9 +155,9 @@ mod tests { case(".a.b.c.d.e.f", "::abc::def"); case(".a.b.c.d.e.f.g.FooBar.Baz", "::abc::def::g::foo_bar::Baz"); - assert!(paths.resolve_ident(".a").is_none()); - assert!(paths.resolve_ident(".a.b").is_none()); - assert!(paths.resolve_ident(".a.c").is_none()); + assert!(paths.resolve_ident(&".a".into()).is_none()); + assert!(paths.resolve_ident(&".a.b".into()).is_none()); + assert!(paths.resolve_ident(&".a.c".into()).is_none()); } #[test] @@ -160,7 +165,10 @@ mod tests { let paths = ExternPaths::new(&[], true).unwrap(); let case = |proto_ident: &str, resolved_ident: &str| { - assert_eq!(paths.resolve_ident(proto_ident).unwrap(), resolved_ident); + assert_eq!( + paths.resolve_ident(&proto_ident.into()).unwrap(), + resolved_ident + ); }; case(".google.protobuf.Value", "::prost_types::Value"); diff --git a/prost-build/src/fixtures/helloworld/_expected_helloworld_formatted.rs b/prost-build/src/fixtures/helloworld/_expected_helloworld_formatted.rs index 3f688c7e0..c160287f2 100644 --- a/prost-build/src/fixtures/helloworld/_expected_helloworld_formatted.rs +++ b/prost-build/src/fixtures/helloworld/_expected_helloworld_formatted.rs @@ -14,7 +14,8 @@ pub struct Response { pub say: ::prost::alloc::string::String, } #[some_enum_attr(u8)] -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[derive(Debug)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum ServingStatus { Unknown = 0, diff --git a/prost-build/src/fully_qualified_name.rs b/prost-build/src/fully_qualified_name.rs new file mode 100644 index 000000000..36ff24cf6 --- /dev/null +++ b/prost-build/src/fully_qualified_name.rs @@ -0,0 +1,52 @@ +use itertools::Itertools; + +// Invariant: should always begin with a '.' (dot) +#[derive(Debug)] +pub struct FullyQualifiedName(String); + +impl FullyQualifiedName { + pub fn new(package_string: &str, type_path: &[impl AsRef], message_name: &str) -> Self { + Self(format!( + "{}{}{}{}{}{}", + if package_string.is_empty() { "" } else { "." }, + package_string.trim_matches('.'), + if type_path.is_empty() { "" } else { "." }, + type_path + .iter() + .map(AsRef::as_ref) + .map(|type_path_str| type_path_str.trim_start_matches('.')) + .join("."), + if message_name.is_empty() { "" } else { "." }, + message_name, + )) + } + + pub fn from_type_name(type_name: &str) -> Self { + Self::new("", &[type_name], "") + } + + pub fn path_iterator(&self) -> impl DoubleEndedIterator { + self.0[1..].split('.') + } + + pub fn join(&self, message_name: &str) -> Self { + Self(format!("{}.{}", self.0, message_name)) + } +} + +impl AsRef for FullyQualifiedName { + fn as_ref(&self) -> &str { + &self.0 + } +} + +#[cfg(test)] +mod test_helpers { + use super::*; + + impl From<&str> for FullyQualifiedName { + fn from(str: &str) -> Self { + Self(str.to_string()) + } + } +} diff --git a/prost-build/src/lib.rs b/prost-build/src/lib.rs index 7b4f43cba..a3338f6d3 100644 --- a/prost-build/src/lib.rs +++ b/prost-build/src/lib.rs @@ -139,6 +139,15 @@ use prost_types::FileDescriptorSet; mod ast; pub use crate::ast::{Comments, Method, Service}; +mod collections; +pub(crate) use collections::{BytesType, MapType}; + +mod fully_qualified_name; +pub(crate) use fully_qualified_name::FullyQualifiedName; + +mod syn_helpers; +pub(crate) use syn_helpers::SynHelpers; + mod code_generator; mod extern_paths; mod ident; @@ -196,28 +205,6 @@ pub trait ServiceGenerator { fn finalize_package(&mut self, _package: &str, _buf: &mut String) {} } -/// The map collection type to output for Protobuf `map` fields. -#[non_exhaustive] -#[derive(Default, Clone, Copy, Debug, PartialEq)] -enum MapType { - /// The [`std::collections::HashMap`] type. - #[default] - HashMap, - /// The [`std::collections::BTreeMap`] type. - BTreeMap, -} - -/// The bytes collection type to output for Protobuf `bytes` fields. -#[non_exhaustive] -#[derive(Default, Clone, Copy, Debug, PartialEq)] -enum BytesType { - /// The [`alloc::collections::Vec::`] type. - #[default] - Vec, - /// The [`bytes::Bytes`] type. - Bytes, -} - /// Compile `.proto` files into Rust files during a Cargo build. /// /// The generated `.rs` files are written to the Cargo `OUT_DIR` directory, suitable for use with @@ -306,12 +293,12 @@ mod tests { impl ServiceGenerator for ServiceTraitGenerator { fn generate(&mut self, service: Service, buf: &mut String) { // Generate a trait for the service. - service.comments.append_with_indent(0, buf); + service.comments.append_with_indent(buf); buf.push_str(&format!("trait {} {{\n", &service.name)); // Generate the service methods. for method in service.methods { - method.comments.append_with_indent(1, buf); + method.comments.append_with_indent(buf); buf.push_str(&format!( " fn {}(_: {}) -> {};\n", method.name, method.input_type, method.output_type @@ -428,11 +415,7 @@ mod tests { let expected_content = read_all_content("src/fixtures/helloworld/_expected_helloworld.rs") .replace("\r\n", "\n"); let content = read_all_content(&out_file).replace("\r\n", "\n"); - assert_eq!( - expected_content, content, - "Unexpected content: \n{}", - content - ); + pretty_assertions::assert_eq!(expected_content, content,); } #[test] @@ -518,11 +501,7 @@ mod tests { read_all_content("src/fixtures/field_attributes/_expected_field_attributes.rs") .replace("\r\n", "\n"); - assert_eq!( - expected_content, content, - "Unexpected content: \n{}", - content - ); + pretty_assertions::assert_eq!(expected_content, content,); } #[test] diff --git a/prost-build/src/path.rs b/prost-build/src/path.rs index f6897005d..a7aa9d6f9 100644 --- a/prost-build/src/path.rs +++ b/prost-build/src/path.rs @@ -2,6 +2,8 @@ use std::iter; +use crate::FullyQualifiedName; + /// Maps a fully-qualified Protobuf path to a value using path matchers. #[derive(Debug, Default)] pub(crate) struct PathMap { @@ -35,8 +37,12 @@ impl PathMap { /// Returns the first value found matching the path `fq_path.field` /// If nothing matches the path, suffix paths will be tried, then prefix paths, then the global path - pub(crate) fn get_first_field<'a>(&'a self, fq_path: &'_ str, field: &'_ str) -> Option<&'a T> { - self.find_best_matching(&format!("{}.{}", fq_path, field)) + pub(crate) fn get_first_field<'a>( + &'a self, + fq_path: &'_ FullyQualifiedName, + field: &'_ str, + ) -> Option<&'a T> { + self.find_best_matching(&format!("{}.{}", fq_path.as_ref(), field)) } /// Removes all matchers from the path map. @@ -195,27 +201,29 @@ mod tests { fn test_get_best() { let mut path_map = PathMap::default(); + let fq_name = FullyQualifiedName::new("a", &["b"], "c"); + // worst is global path_map.insert(".".to_owned(), 1); assert_eq!(Some(&1), path_map.get_first(".a.b.c.d")); assert_eq!(Some(&1), path_map.get_first("b.c.d")); - assert_eq!(Some(&1), path_map.get_first_field(".a.b.c", "d")); + assert_eq!(Some(&1), path_map.get_first_field(&fq_name, "d")); // then prefix path_map.insert(".a.b".to_owned(), 2); assert_eq!(Some(&2), path_map.get_first(".a.b.c.d")); - assert_eq!(Some(&2), path_map.get_first_field(".a.b.c", "d")); + assert_eq!(Some(&2), path_map.get_first_field(&fq_name, "d")); // then suffix path_map.insert("c.d".to_owned(), 3); assert_eq!(Some(&3), path_map.get_first(".a.b.c.d")); assert_eq!(Some(&3), path_map.get_first("b.c.d")); - assert_eq!(Some(&3), path_map.get_first_field(".a.b.c", "d")); + assert_eq!(Some(&3), path_map.get_first_field(&fq_name, "d")); // best is full path path_map.insert(".a.b.c.d".to_owned(), 4); assert_eq!(Some(&4), path_map.get_first(".a.b.c.d")); - assert_eq!(Some(&4), path_map.get_first_field(".a.b.c", "d")); + assert_eq!(Some(&4), path_map.get_first_field(&fq_name, "d")); } #[test] diff --git a/prost-build/src/syn_helpers.rs b/prost-build/src/syn_helpers.rs new file mode 100644 index 000000000..78bec2b1c --- /dev/null +++ b/prost-build/src/syn_helpers.rs @@ -0,0 +1,27 @@ +use syn::parse::{Parse, Parser}; + +#[allow(clippy::expect_fun_call)] +pub trait SynHelpers: AsRef { + /// Used internally for syn parsing where any errors are allowed to be immediatedly unwrapped. + fn parse_syn(&self) -> T { + let input_str = self.as_ref(); + syn::parse_str(input_str).expect(&build_error_string::(input_str)) + } + + fn parse_outer_attributes(&self) -> Vec { + let input_str = self.as_ref(); + syn::Attribute::parse_outer + .parse_str(input_str) + .expect(&build_error_string::(input_str)) + } +} + +impl> SynHelpers for T {} + +fn build_error_string(input_str: &str) -> String { + format!( + "unable to parse {} as {}", + input_str, + std::any::type_name::() + ) +} diff --git a/prost-types/src/compiler.rs b/prost-types/src/compiler.rs index 0a3b46804..35dfd60c2 100644 --- a/prost-types/src/compiler.rs +++ b/prost-types/src/compiler.rs @@ -135,17 +135,8 @@ pub mod code_generator_response { pub generated_code_info: ::core::option::Option, } /// Sync with code_generator.h. - #[derive( - Clone, - Copy, - Debug, - PartialEq, - Eq, - Hash, - PartialOrd, - Ord, - ::prost::Enumeration - )] + #[derive(Debug)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum Feature { None = 0, diff --git a/prost-types/src/protobuf.rs b/prost-types/src/protobuf.rs index edc1361be..e543057a1 100644 --- a/prost-types/src/protobuf.rs +++ b/prost-types/src/protobuf.rs @@ -181,17 +181,8 @@ pub struct FieldDescriptorProto { } /// Nested message and enum types in `FieldDescriptorProto`. pub mod field_descriptor_proto { - #[derive( - Clone, - Copy, - Debug, - PartialEq, - Eq, - Hash, - PartialOrd, - Ord, - ::prost::Enumeration - )] + #[derive(Debug)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum Type { /// 0 is reserved for errors. @@ -279,17 +270,8 @@ pub mod field_descriptor_proto { } } } - #[derive( - Clone, - Copy, - Debug, - PartialEq, - Eq, - Hash, - PartialOrd, - Ord, - ::prost::Enumeration - )] + #[derive(Debug)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum Label { /// 0 is reserved for errors @@ -561,17 +543,8 @@ pub struct FileOptions { /// Nested message and enum types in `FileOptions`. pub mod file_options { /// Generated classes can be optimized for speed or code size. - #[derive( - Clone, - Copy, - Debug, - PartialEq, - Eq, - Hash, - PartialOrd, - Ord, - ::prost::Enumeration - )] + #[derive(Debug)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum OptimizeMode { /// Generate complete code for parsing, serialization, @@ -750,17 +723,8 @@ pub struct FieldOptions { } /// Nested message and enum types in `FieldOptions`. pub mod field_options { - #[derive( - Clone, - Copy, - Debug, - PartialEq, - Eq, - Hash, - PartialOrd, - Ord, - ::prost::Enumeration - )] + #[derive(Debug)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum CType { /// Default mode. @@ -790,17 +754,8 @@ pub mod field_options { } } } - #[derive( - Clone, - Copy, - Debug, - PartialEq, - Eq, - Hash, - PartialOrd, - Ord, - ::prost::Enumeration - )] + #[derive(Debug)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum JsType { /// Use the default type. @@ -908,17 +863,8 @@ pub mod method_options { /// Is this method side-effect-free (or safe in HTTP parlance), or idempotent, /// or neither? HTTP based RPC implementation may choose GET verb for safe /// methods, and PUT verb for idempotent methods instead of the default POST. - #[derive( - Clone, - Copy, - Debug, - PartialEq, - Eq, - Hash, - PartialOrd, - Ord, - ::prost::Enumeration - )] + #[derive(Debug)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum IdempotencyLevel { IdempotencyUnknown = 0, @@ -1374,17 +1320,8 @@ pub struct Field { /// Nested message and enum types in `Field`. pub mod field { /// Basic field types. - #[derive( - Clone, - Copy, - Debug, - PartialEq, - Eq, - Hash, - PartialOrd, - Ord, - ::prost::Enumeration - )] + #[derive(Debug)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum Kind { /// Field type unknown. @@ -1481,17 +1418,8 @@ pub mod field { } } /// Whether a field is optional, required, or repeated. - #[derive( - Clone, - Copy, - Debug, - PartialEq, - Eq, - Hash, - PartialOrd, - Ord, - ::prost::Enumeration - )] + #[derive(Debug)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum Cardinality { /// For fields with unknown cardinality. @@ -1581,7 +1509,8 @@ pub struct Option { pub value: ::core::option::Option, } /// The syntax in which a protocol buffer element is defined. -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[derive(Debug)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum Syntax { /// Syntax `proto2`. @@ -2169,7 +2098,8 @@ pub struct ListValue { /// `Value` type union. /// /// The JSON representation for `NullValue` is JSON `null`. -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[derive(Debug)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum NullValue { /// Null value. diff --git a/tests-2015/Cargo.toml b/tests-2015/Cargo.toml index cdb48cabf..4d29c8d99 100644 --- a/tests-2015/Cargo.toml +++ b/tests-2015/Cargo.toml @@ -29,6 +29,7 @@ protobuf = { path = "../protobuf" } [dev-dependencies] diff = "0.1" +pretty_assertions = "1" prost-build = { path = "../prost-build" } tempfile = "3" diff --git a/tests/Cargo.toml b/tests/Cargo.toml index edc2fa86e..fd7f92c80 100644 --- a/tests/Cargo.toml +++ b/tests/Cargo.toml @@ -16,7 +16,6 @@ std = [] [dependencies] anyhow = "1.0.1" -# bytes = "1" cfg-if = "1" prost = { path = ".." } prost-types = { path = "../prost-types" } @@ -24,6 +23,7 @@ protobuf = { path = "../protobuf" } [dev-dependencies] diff = "0.1" +pretty_assertions = "1" prost-build = { path = "../prost-build", features = ["cleanup-markdown"] } tempfile = "3" diff --git a/tests/src/bootstrap.rs b/tests/src/bootstrap.rs index d329cad58..16d9028fc 100644 --- a/tests/src/bootstrap.rs +++ b/tests/src/bootstrap.rs @@ -92,6 +92,6 @@ fn bootstrap() { .unwrap(); } - assert_eq!(protobuf, bootstrapped_protobuf); - assert_eq!(compiler, bootstrapped_compiler); + pretty_assertions::assert_eq!(protobuf, bootstrapped_protobuf); + pretty_assertions::assert_eq!(compiler, bootstrapped_compiler); } diff --git a/tests/src/lib.rs b/tests/src/lib.rs index 00926afe0..ad0612b89 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -643,7 +643,7 @@ mod tests { #[test] fn test_default_string_escape() { let msg = default_string_escape::Person::default(); - assert_eq!(msg.name, r#"["unknown"]"#); + assert_eq!(r#"["unknown"]"#, msg.name); } #[test]