prost_wkt_types/
pbany.rs

1use prost_wkt::MessageSerde;
2use serde::de::{Deserialize, Deserializer};
3use serde::ser::{Serialize, SerializeStruct, Serializer};
4
5include!(concat!(env!("OUT_DIR"), "/pbany/google.protobuf.rs"));
6
7use prost::{DecodeError, Message, EncodeError, Name};
8
9use std::borrow::Cow;
10
/// Error produced when packing a message into, or unpacking a message out of, an `Any`.
///
/// It carries a human-readable description. The `Display` implementation prefixes
/// that description with `failed to convert Any: `.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AnyError {
    description: Cow<'static, str>,
}

impl AnyError {
    /// Creates an error from either a `&'static str` or an owned `String` description.
    pub fn new<S>(description: S) -> Self
    where
        S: Into<Cow<'static, str>>,
    {
        AnyError {
            description: description.into(),
        }
    }
}

impl std::error::Error for AnyError {
    // Kept so callers of the (deprecated) `Error::description` API still get the
    // real description rather than the std default placeholder.
    fn description(&self) -> &str {
        &self.description
    }
}

impl std::fmt::Display for AnyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The prefix names `Any`, the type this error belongs to (it previously said
        // `Value`, a copy-paste leftover from the `Value` conversion error).
        write!(f, "failed to convert Any: {}", self.description)
    }
}
39
40impl From<prost::DecodeError> for AnyError {
41    fn from(error: DecodeError) -> Self {
42        AnyError::new(format!("Error decoding message: {error:?}"))
43    }
44}
45
46impl From<prost::EncodeError> for AnyError {
47    fn from(error: prost::EncodeError) -> Self {
48        AnyError::new(format!("Error encoding message: {error:?}"))
49    }
50}
51
52impl Any {
53    //#[deprecated(since = "0.5.0", note = "please use `from_msg` instead")]
54    /// Packs a message into an `Any` containing a `type_url` which will take the format
55    /// of `type.googleapis.com/package_name.struct_name`, and a value containing the
56    /// encoded message.
57    pub fn try_pack<T>(message: T) -> Result<Self, AnyError>
58    where
59        T: Message + MessageSerde + Default,
60    {
61        let type_url = MessageSerde::type_url(&message).to_string();
62        // Serialize the message into a value
63        let mut buf = Vec::with_capacity(message.encoded_len());
64        message.encode(&mut buf)?;
65        let encoded = Any {
66            type_url,
67            value: buf,
68        };
69        Ok(encoded)
70    }
71
72    //#[deprecated(since = "0.5.0", note = "please use `to_msg` instead")]
73    /// Unpacks the contents of the `Any` into the provided message type. Example usage:
74    ///
75    /// ```ignore
76    /// let back: Foo = any.unpack_as(Foo::default())?;
77    /// ```
78    pub fn unpack_as<T: Message>(self, mut target: T) -> Result<T, AnyError> {
79        let instance = target.merge(self.value.as_slice()).map(|_| target)?;
80        Ok(instance)
81    }
82
83    /// Unpacks the contents of the `Any` into the `MessageSerde` trait object. Example
84    /// usage:
85    ///
86    /// ```ignore
87    /// let back: Box<dyn MessageSerde> = any.try_unpack()?;
88    /// ```
89    pub fn try_unpack(self) -> Result<Box<dyn prost_wkt::MessageSerde>, AnyError> {
90        ::prost_wkt::inventory::iter::<::prost_wkt::MessageSerdeDecoderEntry>
91            .into_iter()
92            .find(|entry| self.type_url == entry.type_url)
93            .ok_or_else(|| format!("Failed to deserialize {}. Make sure prost-wkt-build is executed.", self.type_url))
94            .and_then(|entry| {
95                (entry.decoder)(&self.value).map_err(|error| {
96                    format!(
97                        "Failed to deserialize {}. Make sure it implements prost::Message. Error reported: {}",
98                        self.type_url,
99                        error
100                    )
101                })
102            })
103            .map_err(AnyError::new)
104    }
105
106    /// From Prost's [`Any`] implementation.
107    /// Serialize the given message type `M` as [`Any`].
108    pub fn from_msg<M>(msg: &M) -> Result<Self, EncodeError>
109    where
110        M: Name,
111    {
112        let type_url = M::type_url();
113        let mut value = Vec::new();
114        Message::encode(msg, &mut value)?;
115        Ok(Any { type_url, value })
116    }
117
118    /// From Prost's [`Any`] implementation.
119    /// Decode the given message type `M` from [`Any`], validating that it has
120    /// the expected type URL.
121    #[allow(clippy::all)]
122    pub fn to_msg<M>(&self) -> Result<M, DecodeError>
123    where
124        M: Default + Name + Sized,
125    {
126        let expected_type_url = M::type_url();
127
128        match (
129            TypeUrl::new(&expected_type_url),
130            TypeUrl::new(&self.type_url),
131        ) {
132            (Some(expected), Some(actual)) => {
133                if expected == actual {
134                    return Ok(M::decode(&*self.value)?);
135                }
136            }
137            _ => (),
138        }
139
140        let mut err = DecodeError::new(format!(
141            "expected type URL: \"{}\" (got: \"{}\")",
142            expected_type_url, &self.type_url
143        ));
144        err.push("unexpected type URL", "type_url");
145        Err(err)
146    }
147
148}
149
150impl Serialize for Any {
151    fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
152    where
153        S: Serializer,
154    {
155        match self.clone().try_unpack() {
156            Ok(result) => serde::ser::Serialize::serialize(result.as_ref(), serializer),
157            Err(_) => {
158                let mut state = serializer.serialize_struct("Any", 3)?;
159                state.serialize_field("@type", &self.type_url)?;
160                state.serialize_field("value", &self.value)?;
161                state.end()
162            }
163        }
164    }
165}
166
167impl<'de> Deserialize<'de> for Any {
168    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
169    where
170        D: Deserializer<'de>,
171    {
172        let erased: Box<dyn prost_wkt::MessageSerde> =
173            serde::de::Deserialize::deserialize(deserializer)?;
174        let type_url = erased.type_url().to_string();
175        let value = erased.try_encoded().map_err(|err| {
176            serde::de::Error::custom(format!("Failed to encode message: {err:?}"))
177        })?;
178        Ok(Any { type_url, value })
179    }
180}
181
/// Parsed view of a protobuf `Any` type URL, e.g. `type.googleapis.com/google.protobuf.Duration`.
///
/// Only the last path segment, the fully qualified message name, is retained. Two URLs
/// with different prefixes but the same message name therefore compare equal.
///
/// [`TypeUrl::new`] applies these rules:
/// - the string must contain at least one `/`;
/// - everything after the last `/` is taken as the fully qualified type name;
/// - that name must be canonical, i.e. it must not start with `.`.
///
/// An empty name (a URL ending in `/`) is not rejected.
///
/// Per the protobuf `Any` specification, a missing scheme implies `https`. Schemes other
/// than `http`/`https` may carry implementation-specific meaning. This type never
/// inspects the scheme.
#[derive(Debug, Eq, PartialEq)]
struct TypeUrl<'a> {
    /// Fully qualified name of the type, e.g. `google.protobuf.Duration`.
    full_name: &'a str,
}

impl<'a> TypeUrl<'a> {
    /// Parses `s` into a `TypeUrl`.
    ///
    /// Returns `None` if `s` has no `/`, or if the name after the last `/` starts with `.`.
    fn new(s: &'a str) -> core::option::Option<Self> {
        let (_prefix, full_name) = s.rsplit_once('/')?;
        if full_name.starts_with('.') {
            None
        } else {
            Some(TypeUrl { full_name })
        }
    }
}
218
#[cfg(test)]
mod tests {
    use crate::pbany::*;
    use prost::{DecodeError, EncodeError, Message};
    use prost_wkt::*;
    use serde::*;
    use serde_json::json;

    /// Minimal hand-written message used to exercise `Any` packing and (de)serialization.
    #[derive(Clone, Eq, PartialEq, ::prost::Message, Serialize, Deserialize)]
    #[serde(default, rename_all = "camelCase")]
    pub struct Foo {
        #[prost(string, tag = "1")]
        pub string: std::string::String,
    }

    impl Name for Foo {
        const NAME: &'static str = "Foo";
        const PACKAGE: &'static str = "any.test";
    }

    #[typetag::serde(name = "type.googleapis.com/any.test.Foo")]
    impl prost_wkt::MessageSerde for Foo {
        fn message_name(&self) -> &'static str {
            "Foo"
        }

        fn package_name(&self) -> &'static str {
            "any.test"
        }

        fn type_url(&self) -> &'static str {
            "type.googleapis.com/any.test.Foo"
        }
        fn new_instance(&self, data: Vec<u8>) -> Result<Box<dyn MessageSerde>, DecodeError> {
            let mut target = Self::default();
            Message::merge(&mut target, data.as_slice())?;
            let erased: Box<dyn MessageSerde> = Box::new(target);
            Ok(erased)
        }

        fn try_encoded(&self) -> Result<Vec<u8>, EncodeError> {
            let mut buf = Vec::with_capacity(Message::encoded_len(self));
            Message::encode(self, &mut buf)?;
            Ok(buf)
        }
    }

    #[test]
    fn pack_unpack_test() {
        let msg = Foo {
            string: "Hello World!".to_string(),
        };
        let any = Any::try_pack(msg.clone()).unwrap();
        println!("{any:?}");
        let unpacked = any.unpack_as(Foo::default()).unwrap();
        println!("{unpacked:?}");
        assert_eq!(unpacked, msg)
    }

    #[test]
    fn pack_unpack_with_downcast_test() {
        let msg = Foo {
            string: "Hello World!".to_string(),
        };
        let any = Any::try_pack(msg.clone()).unwrap();
        println!("{any:?}");
        let unpacked: &dyn MessageSerde = &any.unpack_as(Foo::default()).unwrap();
        let downcast = unpacked.downcast_ref::<Foo>().unwrap();
        assert_eq!(downcast, &msg);
    }

    #[test]
    fn deserialize_default_test() {
        let type_url = "type.googleapis.com/any.test.Foo";
        let data = json!({
            "@type": type_url,
            "value": {}
        });
        let erased: Box<dyn MessageSerde> = serde_json::from_value(data).unwrap();
        let foo: &Foo = erased.downcast_ref::<Foo>().unwrap();
        println!("Deserialize default: {foo:?}");
        assert_eq!(foo, &Foo::default())
    }

    #[test]
    fn check_prost_any_serialization() {
        let message = crate::Timestamp::date(2000, 1, 1).unwrap();
        let any = Any::from_msg(&message).unwrap();
        assert_eq!(
            &any.type_url,
            "type.googleapis.com/google.protobuf.Timestamp"
        );

        let message2 = any.to_msg::<crate::Timestamp>().unwrap();
        assert_eq!(message, message2);

        // Wrong type URL
        assert!(any.to_msg::<crate::Duration>().is_err());
    }

    /// `to_msg` compares only the fully qualified type name, so it must accept the
    /// output of `try_pack` (whose URL comes from `MessageSerde::type_url`).
    #[test]
    fn to_msg_decodes_try_pack_output() {
        let msg = Foo {
            string: "Hello World!".to_string(),
        };
        let any = Any::try_pack(msg.clone()).unwrap();
        assert_eq!(any.to_msg::<Foo>().unwrap(), msg);
        assert!(any.to_msg::<crate::Timestamp>().is_err());
    }

    /// An `Any` whose type URL has no registered decoder is serialized as raw fields.
    #[test]
    fn serialize_unregistered_type_falls_back_to_raw_fields() {
        let any = Any {
            type_url: "type.googleapis.com/any.test.Unregistered".to_string(),
            value: vec![1, 2, 3],
        };
        let serialized = serde_json::to_value(&any).unwrap();
        assert_eq!(
            serialized,
            json!({
                "@type": "type.googleapis.com/any.test.Unregistered",
                "value": [1, 2, 3]
            })
        );
    }

    #[test]
    fn type_url_parsing() {
        assert_eq!(
            TypeUrl::new("type.googleapis.com/any.test.Foo").map(|url| url.full_name),
            Some("any.test.Foo")
        );
        assert_eq!(TypeUrl::new("any.test.Foo"), None);
        assert_eq!(TypeUrl::new("type.googleapis.com/.any.test.Foo"), None);
    }
}