1use prost_wkt::MessageSerde;
2use serde::de::{Deserialize, Deserializer};
3use serde::ser::{Serialize, SerializeStruct, Serializer};
4
5include!(concat!(env!("OUT_DIR"), "/pbany/google.protobuf.rs"));
6
7use prost::{DecodeError, Message, EncodeError, Name};
8
9use std::borrow::Cow;
10
/// Error produced while packing a message into, or unpacking it from, an [`Any`].
///
/// It carries only a human-readable description. Compare errors with `==` or
/// render them with `Display`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AnyError {
    description: Cow<'static, str>,
}

impl AnyError {
    /// Creates an error with the given description.
    ///
    /// Accepts anything convertible into `Cow<'static, str>`. A `&'static str`
    /// is stored without allocating; an owned `String` is moved in.
    pub fn new<S>(description: S) -> Self
    where
        S: Into<Cow<'static, str>>,
    {
        AnyError {
            description: description.into(),
        }
    }
}

impl std::error::Error for AnyError {
    // `Error::description` is deprecated in favour of `Display`. It is still
    // overridden so that legacy callers get the real description rather than
    // the standard library's placeholder text.
    fn description(&self) -> &str {
        &self.description
    }
}

impl std::fmt::Display for AnyError {
    /// Writes `failed to convert Any: <description>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // This error type is only produced for `Any` conversions. The prefix
        // previously said "Value", a leftover from the `Value` error type.
        f.write_str("failed to convert Any: ")?;
        f.write_str(&self.description)
    }
}
39
40impl From<prost::DecodeError> for AnyError {
41 fn from(error: DecodeError) -> Self {
42 AnyError::new(format!("Error decoding message: {error:?}"))
43 }
44}
45
46impl From<prost::EncodeError> for AnyError {
47 fn from(error: prost::EncodeError) -> Self {
48 AnyError::new(format!("Error encoding message: {error:?}"))
49 }
50}
51
52impl Any {
53 pub fn try_pack<T>(message: T) -> Result<Self, AnyError>
58 where
59 T: Message + MessageSerde + Default,
60 {
61 let type_url = MessageSerde::type_url(&message).to_string();
62 let mut buf = Vec::with_capacity(message.encoded_len());
64 message.encode(&mut buf)?;
65 let encoded = Any {
66 type_url,
67 value: buf,
68 };
69 Ok(encoded)
70 }
71
72 pub fn unpack_as<T: Message>(self, mut target: T) -> Result<T, AnyError> {
79 let instance = target.merge(self.value.as_slice()).map(|_| target)?;
80 Ok(instance)
81 }
82
83 pub fn try_unpack(self) -> Result<Box<dyn prost_wkt::MessageSerde>, AnyError> {
90 ::prost_wkt::inventory::iter::<::prost_wkt::MessageSerdeDecoderEntry>
91 .into_iter()
92 .find(|entry| self.type_url == entry.type_url)
93 .ok_or_else(|| format!("Failed to deserialize {}. Make sure prost-wkt-build is executed.", self.type_url))
94 .and_then(|entry| {
95 (entry.decoder)(&self.value).map_err(|error| {
96 format!(
97 "Failed to deserialize {}. Make sure it implements prost::Message. Error reported: {}",
98 self.type_url,
99 error
100 )
101 })
102 })
103 .map_err(AnyError::new)
104 }
105
106 pub fn from_msg<M>(msg: &M) -> Result<Self, EncodeError>
109 where
110 M: Name,
111 {
112 let type_url = M::type_url();
113 let mut value = Vec::new();
114 Message::encode(msg, &mut value)?;
115 Ok(Any { type_url, value })
116 }
117
118 #[allow(clippy::all)]
122 pub fn to_msg<M>(&self) -> Result<M, DecodeError>
123 where
124 M: Default + Name + Sized,
125 {
126 let expected_type_url = M::type_url();
127
128 match (
129 TypeUrl::new(&expected_type_url),
130 TypeUrl::new(&self.type_url),
131 ) {
132 (Some(expected), Some(actual)) => {
133 if expected == actual {
134 return Ok(M::decode(&*self.value)?);
135 }
136 }
137 _ => (),
138 }
139
140 let mut err = DecodeError::new(format!(
141 "expected type URL: \"{}\" (got: \"{}\")",
142 expected_type_url, &self.type_url
143 ));
144 err.push("unexpected type URL", "type_url");
145 Err(err)
146 }
147
148}
149
150impl Serialize for Any {
151 fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
152 where
153 S: Serializer,
154 {
155 match self.clone().try_unpack() {
156 Ok(result) => serde::ser::Serialize::serialize(result.as_ref(), serializer),
157 Err(_) => {
158 let mut state = serializer.serialize_struct("Any", 3)?;
159 state.serialize_field("@type", &self.type_url)?;
160 state.serialize_field("value", &self.value)?;
161 state.end()
162 }
163 }
164 }
165}
166
167impl<'de> Deserialize<'de> for Any {
168 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
169 where
170 D: Deserializer<'de>,
171 {
172 let erased: Box<dyn prost_wkt::MessageSerde> =
173 serde::de::Deserialize::deserialize(deserializer)?;
174 let type_url = erased.type_url().to_string();
175 let value = erased.try_encoded().map_err(|err| {
176 serde::de::Error::custom(format!("Failed to encode message: {err:?}"))
177 })?;
178 Ok(Any { type_url, value })
179 }
180}
181
/// A parsed protobuf type URL.
///
/// Two URLs compare equal when their fully qualified type names match,
/// regardless of the prefix before the last `/`.
#[derive(Debug, Eq, PartialEq)]
struct TypeUrl<'a> {
    /// Everything after the final `/` in the URL, e.g. `google.protobuf.Duration`.
    full_name: &'a str,
}

impl<'a> TypeUrl<'a> {
    /// Parses `s` as a type URL.
    ///
    /// Returns `None` if `s` contains no `/`, or if the segment after the last
    /// `/` starts with `.` (a non-canonical, leading-dot name).
    fn new(s: &'a str) -> core::option::Option<Self> {
        let (_prefix, full_name) = s.rsplit_once('/')?;
        if full_name.starts_with('.') {
            None
        } else {
            Some(Self { full_name })
        }
    }
}
218
#[cfg(test)]
mod tests {
    use crate::pbany::*;
    use prost::{DecodeError, EncodeError, Message};
    use prost_wkt::*;
    use serde::*;
    use serde_json::json;

    /// Minimal single-field message used to exercise packing and unpacking.
    #[derive(Clone, Eq, PartialEq, ::prost::Message, Serialize, Deserialize)]
    #[serde(default, rename_all = "camelCase")]
    pub struct Foo {
        #[prost(string, tag = "1")]
        pub string: std::string::String,
    }

    impl Name for Foo {
        const NAME: &'static str = "Foo";
        const PACKAGE: &'static str = "any.test";
    }

    // Registers `Foo` with typetag under its full type URL so it can be
    // (de)serialized as a `Box<dyn MessageSerde>`.
    #[typetag::serde(name = "type.googleapis.com/any.test.Foo")]
    impl prost_wkt::MessageSerde for Foo {
        fn message_name(&self) -> &'static str {
            "Foo"
        }

        fn package_name(&self) -> &'static str {
            "any.test"
        }

        fn type_url(&self) -> &'static str {
            "type.googleapis.com/any.test.Foo"
        }

        fn new_instance(&self, data: Vec<u8>) -> Result<Box<dyn MessageSerde>, DecodeError> {
            // `decode` starts from `Default` and merges the bytes into it.
            let decoded = <Self as Message>::decode(data.as_slice())?;
            Ok(Box::new(decoded))
        }

        fn try_encoded(&self) -> Result<Vec<u8>, EncodeError> {
            Ok(Message::encode_to_vec(self))
        }
    }

    #[test]
    fn pack_unpack_test() {
        let original = Foo {
            string: "Hello World!".to_string(),
        };
        let packed = Any::try_pack(original.clone()).unwrap();
        println!("{packed:?}");
        let restored = packed.unpack_as(Foo::default()).unwrap();
        println!("{restored:?}");
        assert_eq!(restored, original)
    }

    #[test]
    fn pack_unpack_with_downcast_test() {
        let original = Foo {
            string: "Hello World!".to_string(),
        };
        let packed = Any::try_pack(original.clone()).unwrap();
        println!("{packed:?}");
        let restored = packed.unpack_as(Foo::default()).unwrap();
        let erased: &dyn MessageSerde = &restored;
        let concrete = erased.downcast_ref::<Foo>().unwrap();
        assert_eq!(concrete, &original);
    }

    #[test]
    fn deserialize_default_test() {
        let url = "type.googleapis.com/any.test.Foo";
        let input = json!({
            "@type": url,
            "value": {}
        });
        let erased: Box<dyn MessageSerde> = serde_json::from_value(input).unwrap();
        let concrete: &Foo = erased.downcast_ref::<Foo>().unwrap();
        println!("Deserialize default: {concrete:?}");
        assert_eq!(concrete, &Foo::default())
    }

    #[test]
    fn check_prost_any_serialization() {
        let timestamp = crate::Timestamp::date(2000, 1, 1).unwrap();
        let packed = Any::from_msg(&timestamp).unwrap();
        assert_eq!(
            &packed.type_url,
            "type.googleapis.com/google.protobuf.Timestamp"
        );

        let roundtripped = packed.to_msg::<crate::Timestamp>().unwrap();
        assert_eq!(timestamp, roundtripped);

        // A mismatched target type must be rejected by the type-URL check.
        assert!(packed.to_msg::<crate::Duration>().is_err());
    }
}