From ddc89383f0dba617657f63ab4179b51d4538e6b2 Mon Sep 17 00:00:00 2001 From: Giang Minh Date: Mon, 6 May 2024 15:19:09 +0700 Subject: [PATCH 1/2] fix option enum by move default getter to get_fallback, get will return Option --- prost-build/src/code_generator.rs | 109 ++++++++++++++++-------------- prost-build/src/config.rs | 7 +- prost-build/src/message_graph.rs | 4 +- prost-derive/src/field/scalar.rs | 46 +++++++++++-- 4 files changed, 104 insertions(+), 62 deletions(-) diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index 6ca8581ab..daf8277f7 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -63,7 +63,7 @@ impl Field { } fn rust_name(&self) -> String { - to_snake(self.descriptor.name()) + to_snake(self.descriptor.name_fallback()) } } @@ -83,7 +83,7 @@ impl OneofField { } fn rust_name(&self) -> String { - to_snake(self.descriptor.name()) + to_snake(self.descriptor.name_fallback()) } } @@ -156,9 +156,9 @@ impl<'a> CodeGenerator<'a> { } fn append_message(&mut self, message: DescriptorProto) { - debug!(" message: {:?}", message.name()); + debug!(" message: {:?}", message.name_fallback()); - let message_name = message.name().to_string(); + let message_name = message.name_fallback().to_string(); let fq_message_name = self.fq_name(&message_name); // Skip external types. @@ -184,10 +184,10 @@ impl<'a> CodeGenerator<'a> { { let key = nested_type.field[0].clone(); let value = nested_type.field[1].clone(); - assert_eq!("key", key.name()); - assert_eq!("value", value.name()); + assert_eq!("key", key.name_fallback()); + assert_eq!("value", value.name_fallback()); - let name = format!("{}.{}", &fq_message_name, nested_type.name()); + let name = format!("{}.{}", &fq_message_name, nested_type.name_fallback()); Either::Right((name, (key, value))) } else { Either::Left((nested_type, idx)) @@ -398,7 +398,7 @@ impl<'a> CodeGenerator<'a> { } fn append_field(&mut self, fq_message_name: &str, field: &Field) { - let type_ = field.descriptor.r#type(); + let type_ = field.descriptor.r#type_fallback(); let repeated = field.descriptor.label == Some(Label::Repeated as i32); let deprecated = self.deprecated(&field.descriptor); let optional = self.optional(&field.descriptor); @@ -412,7 +412,7 @@ impl<'a> CodeGenerator<'a> { boxed ); - self.append_doc(fq_message_name, Some(field.descriptor.name())); + self.append_doc(fq_message_name, Some(field.descriptor.name_fallback())); if deprecated { self.push_indent(); @@ -428,14 +428,14 @@ impl<'a> CodeGenerator<'a> { let bytes_type = self .config .bytes_type - .get_first_field(fq_message_name, field.descriptor.name()) + .get_first_field(fq_message_name, field.descriptor.name_fallback()) .copied() .unwrap_or_default(); self.buf .push_str(&format!("={:?}", bytes_type.annotation())); } - match field.descriptor.label() { + match field.descriptor.label_fallback() { Label::Optional => { if optional { self.buf.push_str(", optional"); @@ -449,7 +449,9 @@ impl<'a> CodeGenerator<'a> { .descriptor .options .as_ref() - .map_or(self.syntax == Syntax::Proto3, |options| options.packed()) + .map_or(self.syntax == Syntax::Proto3, |options| { + options.packed_fallback() + }) { self.buf.push_str(", packed=\"false\""); } @@ -460,7 +462,8 @@ impl<'a> CodeGenerator<'a> { self.buf.push_str(", boxed"); } self.buf.push_str(", tag=\""); - self.buf.push_str(&field.descriptor.number().to_string()); + self.buf + .push_str(&field.descriptor.number_fallback().to_string()); if let Some(ref default) = field.descriptor.default_value { self.buf.push_str("\", default=\""); @@ -494,7 +497,7 @@ impl<'a> CodeGenerator<'a> { } self.buf.push_str("\")]\n"); - self.append_field_attributes(fq_message_name, field.descriptor.name()); + self.append_field_attributes(fq_message_name, field.descriptor.name_fallback()); self.push_indent(); self.buf.push_str("pub "); self.buf.push_str(&field.rust_name()); @@ -539,13 +542,13 @@ impl<'a> CodeGenerator<'a> { value_ty ); - self.append_doc(fq_message_name, Some(field.descriptor.name())); + self.append_doc(fq_message_name, Some(field.descriptor.name_fallback())); self.push_indent(); let map_type = self .config .map_type - .get_first_field(fq_message_name, field.descriptor.name()) + .get_first_field(fq_message_name, field.descriptor.name_fallback()) .copied() .unwrap_or_default(); let key_tag = self.field_type_tag(key); @@ -556,9 +559,9 @@ impl<'a> CodeGenerator<'a> { map_type.annotation(), key_tag, value_tag, - field.descriptor.number() + field.descriptor.number_fallback() )); - self.append_field_attributes(fq_message_name, field.descriptor.name()); + self.append_field_attributes(fq_message_name, field.descriptor.name_fallback()); self.push_indent(); self.buf.push_str(&format!( "pub {}: {}<{}, {}>,\n", @@ -578,7 +581,7 @@ impl<'a> CodeGenerator<'a> { let type_name = format!( "{}::{}", to_snake(message_name), - to_upper_camel(oneof.descriptor.name()) + to_upper_camel(oneof.descriptor.name_fallback()) ); self.append_doc(fq_message_name, None); self.push_indent(); @@ -588,10 +591,10 @@ impl<'a> CodeGenerator<'a> { oneof .fields .iter() - .map(|field| field.descriptor.number()) + .map(|field| field.descriptor.number_fallback()) .join(", "), )); - self.append_field_attributes(fq_message_name, oneof.descriptor.name()); + self.append_field_attributes(fq_message_name, oneof.descriptor.name_fallback()); self.push_indent(); self.buf.push_str(&format!( "pub {}: ::core::option::Option<{}>,\n", @@ -607,7 +610,7 @@ impl<'a> CodeGenerator<'a> { self.path.pop(); self.path.pop(); - let oneof_name = format!("{}.{}", fq_message_name, oneof.descriptor.name()); + let oneof_name = format!("{}.{}", fq_message_name, oneof.descriptor.name_fallback()); self.append_type_attributes(&oneof_name); self.append_enum_attributes(&oneof_name); self.push_indent(); @@ -620,14 +623,15 @@ impl<'a> CodeGenerator<'a> { self.append_skip_debug(fq_message_name); self.push_indent(); self.buf.push_str("pub enum "); - self.buf.push_str(&to_upper_camel(oneof.descriptor.name())); + self.buf + .push_str(&to_upper_camel(oneof.descriptor.name_fallback())); self.buf.push_str(" {\n"); self.path.push(2); self.depth += 1; for field in &oneof.fields { self.path.push(field.path_index); - self.append_doc(fq_message_name, Some(field.descriptor.name())); + self.append_doc(fq_message_name, Some(field.descriptor.name_fallback())); self.path.pop(); self.push_indent(); @@ -635,9 +639,9 @@ impl<'a> CodeGenerator<'a> { self.buf.push_str(&format!( "#[prost({}, tag=\"{}\")]\n", ty_tag, - field.descriptor.number() + field.descriptor.number_fallback() )); - self.append_field_attributes(&oneof_name, field.descriptor.name()); + self.append_field_attributes(&oneof_name, field.descriptor.name_fallback()); self.push_indent(); let ty = self.resolve_type(&field.descriptor, fq_message_name); @@ -645,7 +649,7 @@ impl<'a> CodeGenerator<'a> { let boxed = self.boxed( &field.descriptor, fq_message_name, - Some(oneof.descriptor.name()), + Some(oneof.descriptor.name_fallback()), ); debug!( @@ -658,13 +662,13 @@ impl<'a> CodeGenerator<'a> { if boxed { self.buf.push_str(&format!( "{}(::prost::alloc::boxed::Box<{}>),\n", - to_upper_camel(field.descriptor.name()), + to_upper_camel(field.descriptor.name_fallback()), ty )); } else { self.buf.push_str(&format!( "{}({}),\n", - to_upper_camel(field.descriptor.name()), + to_upper_camel(field.descriptor.name_fallback()), ty )); } @@ -704,7 +708,7 @@ impl<'a> CodeGenerator<'a> { fn append_enum(&mut self, desc: EnumDescriptorProto) { debug!(" enum: {:?}", desc.name()); - let proto_enum_name = desc.name(); + let proto_enum_name = desc.name_fallback(); let enum_name = to_upper_camel(proto_enum_name); let enum_values = &desc.value; @@ -851,7 +855,7 @@ impl<'a> CodeGenerator<'a> { } fn push_service(&mut self, service: ServiceDescriptorProto) { - let name = service.name().to_owned(); + let name = service.name_fallback().to_owned(); debug!(" service: {:?}", name); let comments = self @@ -879,8 +883,8 @@ impl<'a> CodeGenerator<'a> { let output_proto_type = method.output_type.take().unwrap(); let input_type = self.resolve_ident(&input_proto_type); let output_type = self.resolve_ident(&output_proto_type); - let client_streaming = method.client_streaming(); - let server_streaming = method.server_streaming(); + let client_streaming = method.client_streaming_fallback(); + let server_streaming = method.server_streaming_fallback(); Method { name: to_snake(&name), @@ -942,7 +946,7 @@ impl<'a> CodeGenerator<'a> { } fn resolve_type(&self, field: &FieldDescriptorProto, fq_message_name: &str) -> String { - match field.r#type() { + match field.r#type_fallback() { Type::Float => String::from("f32"), Type::Double => String::from("f64"), Type::Uint32 | Type::Fixed32 => String::from("u32"), @@ -954,12 +958,12 @@ impl<'a> CodeGenerator<'a> { Type::Bytes => self .config .bytes_type - .get_first_field(fq_message_name, field.name()) + .get_first_field(fq_message_name, field.name_fallback()) .copied() .unwrap_or_default() .rust_type() .to_owned(), - Type::Group | Type::Message => self.resolve_ident(field.type_name()), + Type::Group | Type::Message => self.resolve_ident(field.type_name_fallback()), } } @@ -1002,7 +1006,7 @@ impl<'a> CodeGenerator<'a> { } fn field_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str> { - match field.r#type() { + match field.r#type_fallback() { Type::Float => Cow::Borrowed("float"), Type::Double => Cow::Borrowed("double"), Type::Int32 => Cow::Borrowed("int32"), @@ -1022,16 +1026,16 @@ impl<'a> CodeGenerator<'a> { Type::Message => Cow::Borrowed("message"), Type::Enum => Cow::Owned(format!( "enumeration={:?}", - self.resolve_ident(field.type_name()) + self.resolve_ident(field.type_name_fallback()) )), } } fn map_value_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str> { - match field.r#type() { + match field.r#type_fallback() { Type::Enum => Cow::Owned(format!( "enumeration({})", - self.resolve_ident(field.type_name()) + self.resolve_ident(field.type_name_fallback()) )), _ => self.field_type_tag(field), } @@ -1042,11 +1046,11 @@ impl<'a> CodeGenerator<'a> { return true; } - if field.label() != Label::Optional { + if field.label_fallback() != Label::Optional { return false; } - match field.r#type() { + match field.r#type_fallback() { Type::Message => true, _ => self.syntax == Syntax::Proto2, } @@ -1064,12 +1068,12 @@ impl<'a> CodeGenerator<'a> { oneof: Option<&str>, ) -> bool { let repeated = field.label == Some(Label::Repeated as i32); - let fd_type = field.r#type(); + let fd_type = field.r#type_fallback(); if !repeated && (fd_type == Type::Message || fd_type == Type::Group) && self .message_graph - .is_nested(field.type_name(), fq_message_name) + .is_nested(field.type_name_fallback(), fq_message_name) { return true; } @@ -1080,7 +1084,7 @@ impl<'a> CodeGenerator<'a> { if self .config .boxed - .get_first_field(&config_path, field.name()) + .get_first_field(&config_path, field.name_fallback()) .is_some() { if repeated { @@ -1100,7 +1104,7 @@ impl<'a> CodeGenerator<'a> { field .options .as_ref() - .map_or(false, FieldOptions::deprecated) + .map_or(false, FieldOptions::deprecated_fallback) } /// Returns the fully-qualified name, starting with a dot @@ -1119,7 +1123,7 @@ impl<'a> CodeGenerator<'a> { /// Returns `true` if the repeated field type can be packed. fn can_pack(field: &FieldDescriptorProto) -> bool { matches!( - field.r#type(), + field.r#type_fallback(), Type::Float | Type::Double | Type::Int32 @@ -1160,22 +1164,23 @@ fn build_enum_value_mappings<'a>( continue; } - let mut generated_variant_name = to_upper_camel(value.name()); + let mut generated_variant_name = to_upper_camel(value.name_fallback()); if do_strip_enum_prefix { generated_variant_name = strip_enum_prefix(generated_enum_name, &generated_variant_name); } - if let Some(old_v) = generated_names.insert(generated_variant_name.to_owned(), value.name()) + if let Some(old_v) = + generated_names.insert(generated_variant_name.to_owned(), value.name_fallback()) { panic!("Generated enum variant names overlap: `{}` variant name to be used both by `{}` and `{}` ProtoBuf enum values", - generated_variant_name, old_v, value.name()); + generated_variant_name, old_v, value.name_fallback()); } mappings.push(EnumVariantMapping { path_idx: idx, - proto_name: value.name(), - proto_number: value.number(), + proto_name: value.name_fallback(), + proto_number: value.number_fallback(), generated_variant_name, }) } diff --git a/prost-build/src/config.rs b/prost-build/src/config.rs index a696e404b..fcffb6324 100644 --- a/prost-build/src/config.rs +++ b/prost-build/src/config.rs @@ -783,7 +783,7 @@ impl Config { .into_iter() .map(|descriptor| { ( - Module::from_protobuf_package_name(descriptor.package()), + Module::from_protobuf_package_name(descriptor.package_fallback()), descriptor, ) }) @@ -1031,7 +1031,10 @@ impl Config { for (request_module, request_fd) in requests { // Only record packages that have services if !request_fd.service.is_empty() { - packages.insert(request_module.clone(), request_fd.package().to_string()); + packages.insert( + request_module.clone(), + request_fd.package_fallback().to_string(), + ); } let buf = modules .entry(request_module.clone()) diff --git a/prost-build/src/message_graph.rs b/prost-build/src/message_graph.rs index ac0ad1523..bda40d4dd 100644 --- a/prost-build/src/message_graph.rs +++ b/prost-build/src/message_graph.rs @@ -58,8 +58,8 @@ impl MessageGraph { let msg_index = self.get_or_insert_index(msg_name.clone()); for field in &msg.field { - if field.r#type() == field_descriptor_proto::Type::Message - && field.label() != field_descriptor_proto::Label::Repeated + if field.r#type_fallback() == field_descriptor_proto::Type::Message + && field.label_fallback() != field_descriptor_proto::Label::Repeated { let field_index = self.get_or_insert_index(field.type_name.clone().unwrap()); self.graph.add_edge(msg_index, field_index, ()); diff --git a/prost-derive/src/field/scalar.rs b/prost-derive/src/field/scalar.rs index 6be16cd70..b17dec711 100644 --- a/prost-derive/src/field/scalar.rs +++ b/prost-derive/src/field/scalar.rs @@ -306,20 +306,35 @@ impl Field { } } Kind::Optional(ref default) => { - let get_doc = format!( + let get_doc_fallback = format!( "Returns the enum value of `{}`, \ or the default if the field is unset or set to an invalid enum value.", ident_str, ); + let get_doc = format!( + "Returns the enum value of `{}`, \ + or None if the field is unset or set to an invalid enum value.", + ident_str, + ); + let get_fallback = + Ident::new(&format!("{}_fallback", ident_str), Span::call_site()); quote! { - #[doc=#get_doc] - pub fn #get(&self) -> #ty { + #[doc=#get_doc_fallback] + pub fn #get_fallback(&self) -> #ty { self.#ident.and_then(|x| { let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x); result.ok() }).unwrap_or(#default) } + #[doc=#get_doc] + pub fn #get(&self) -> ::core::option::Option<#ty> { + self.#ident.and_then(|x| { + let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x); + result.ok() + }) + } + #[doc=#set_doc] pub fn #set(&mut self, value: #ty) { self.#ident = ::core::option::Option::Some(value as i32); @@ -360,19 +375,38 @@ impl Field { quote!(::core::option::Option::Some(ref val) => &val[..],) }; - let get_doc = format!( + let match_some2 = if self.ty.is_numeric() { + quote!(::core::option::Option::Some(val) => Some(val),) + } else { + quote!(::core::option::Option::Some(ref val) => Some(&val[..]),) + }; + + let get_doc_fallback = format!( "Returns the value of `{0}`, or the default value if `{0}` is unset.", ident_str, ); + let get_doc = format!( + "Returns the value of `{0}`, or None if `{0}` is unset.", + ident_str, + ); + let get_fallback = Ident::new(&format!("{}_fallback", ident_str), Span::call_site()); Some(quote! { - #[doc=#get_doc] - pub fn #get(&self) -> #ty { + #[doc=#get_doc_fallback] + pub fn #get_fallback(&self) -> #ty { match self.#ident { #match_some ::core::option::Option::None => #default, } } + + #[doc=#get_doc] + pub fn #get(&self) -> ::core::option::Option<#ty> { + match self.#ident { + #match_some2 + ::core::option::Option::None => None, + } + } }) } else { None From 22465a8feba58548f0a7026df59eaa43cbff48a6 Mon Sep 17 00:00:00 2001 From: Giang Minh Date: Mon, 6 May 2024 15:28:25 +0700 Subject: [PATCH 2/2] fix tests --- tests/src/lib.rs | 9 ++++-- tests/src/unittest.rs | 64 +++++++++++++++++++++---------------------- 2 files changed, 38 insertions(+), 35 deletions(-) diff --git a/tests/src/lib.rs b/tests/src/lib.rs index 6674ddddd..87fffa249 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -553,13 +553,16 @@ mod tests { #[test] fn test_default_enum() { let msg = default_enum_value::Test::default(); - assert_eq!(msg.privacy_level_1(), default_enum_value::PrivacyLevel::One); assert_eq!( - msg.privacy_level_3(), + msg.privacy_level_1_fallback(), + default_enum_value::PrivacyLevel::One + ); + assert_eq!( + msg.privacy_level_3_fallback(), default_enum_value::PrivacyLevel::PrivacyLevelThree ); assert_eq!( - msg.privacy_level_4(), + msg.privacy_level_4_fallback(), default_enum_value::PrivacyLevel::PrivacyLevelprivacyLevelFour ); } diff --git a/tests/src/unittest.rs b/tests/src/unittest.rs index 7c8c85d66..ff6cd92f7 100644 --- a/tests/src/unittest.rs +++ b/tests/src/unittest.rs @@ -11,38 +11,38 @@ fn extreme_default_values() { assert_eq!( b"\0\x01\x07\x08\x0C\n\r\t\x0B\\\'\"\xFE", - pb.escaped_bytes() + pb.escaped_bytes_fallback() ); - assert_eq!(0xFFFFFFFF, pb.large_uint32()); - assert_eq!(0xFFFFFFFFFFFFFFFF, pb.large_uint64()); - assert_eq!(-0x7FFFFFFF, pb.small_int32()); - assert_eq!(-0x7FFFFFFFFFFFFFFF, pb.small_int64()); - assert_eq!(-0x80000000, pb.really_small_int32()); - assert_eq!(-0x8000000000000000, pb.really_small_int64()); - - assert_eq!(pb.utf8_string(), "\u{1234}"); - - assert_eq!(0.0, pb.zero_float()); - assert_eq!(1.0, pb.one_float()); - assert_eq!(1.5, pb.small_float()); - assert_eq!(-1.0, pb.negative_one_float()); - assert_eq!(-1.5, pb.negative_float()); - assert_eq!(2E8, pb.large_float()); - assert_eq!(-8e-28, pb.small_negative_float()); - - assert_eq!(f64::INFINITY, pb.inf_double()); - assert_eq!(f64::NEG_INFINITY, pb.neg_inf_double()); - assert_ne!(pb.nan_double(), pb.nan_double()); - assert_eq!(f32::INFINITY, pb.inf_float()); - assert_eq!(f32::NEG_INFINITY, pb.neg_inf_float()); - assert_ne!(pb.nan_float(), pb.nan_float()); - - assert_eq!("? ? ?? ?? ??? ??/ ??-", pb.cpp_trigraph()); - - assert_eq!("hel\x00lo", pb.string_with_zero()); - assert_eq!(b"wor\x00ld", pb.bytes_with_zero()); - assert_eq!("ab\x00c", pb.string_piece_with_zero()); - assert_eq!("12\x003", pb.cord_with_zero()); - assert_eq!("${unknown}", pb.replacement_string()); + assert_eq!(0xFFFFFFFF, pb.large_uint32_fallback()); + assert_eq!(0xFFFFFFFFFFFFFFFF, pb.large_uint64_fallback()); + assert_eq!(-0x7FFFFFFF, pb.small_int32_fallback()); + assert_eq!(-0x7FFFFFFFFFFFFFFF, pb.small_int64_fallback()); + assert_eq!(-0x80000000, pb.really_small_int32_fallback()); + assert_eq!(-0x8000000000000000, pb.really_small_int64_fallback()); + + assert_eq!(pb.utf8_string_fallback(), "\u{1234}"); + + assert_eq!(0.0, pb.zero_float_fallback()); + assert_eq!(1.0, pb.one_float_fallback()); + assert_eq!(1.5, pb.small_float_fallback()); + assert_eq!(-1.0, pb.negative_one_float_fallback()); + assert_eq!(-1.5, pb.negative_float_fallback()); + assert_eq!(2E8, pb.large_float_fallback()); + assert_eq!(-8e-28, pb.small_negative_float_fallback()); + + assert_eq!(f64::INFINITY, pb.inf_double_fallback()); + assert_eq!(f64::NEG_INFINITY, pb.neg_inf_double_fallback()); + assert_ne!(pb.nan_double_fallback(), pb.nan_double_fallback()); + assert_eq!(f32::INFINITY, pb.inf_float_fallback()); + assert_eq!(f32::NEG_INFINITY, pb.neg_inf_float_fallback()); + assert_ne!(pb.nan_float_fallback(), pb.nan_float_fallback()); + + assert_eq!("? ? ?? ?? ??? ??/ ??-", pb.cpp_trigraph_fallback()); + + assert_eq!("hel\x00lo", pb.string_with_zero_fallback()); + assert_eq!(b"wor\x00ld", pb.bytes_with_zero_fallback()); + assert_eq!("ab\x00c", pb.string_piece_with_zero_fallback()); + assert_eq!("12\x003", pb.cord_with_zero_fallback()); + assert_eq!("${unknown}", pb.replacement_string_fallback()); }