1use super::compression::{
2 compress, CompressionEncoding, CompressionSettings, SingleMessageCompressionOverride,
3};
4use super::{BufferSettings, EncodeBuf, Encoder, DEFAULT_MAX_SEND_MESSAGE_SIZE, HEADER_SIZE};
5use crate::Status;
6use bytes::{BufMut, Bytes, BytesMut};
7use http::HeaderMap;
8use http_body::{Body, Frame};
9use pin_project::pin_project;
10use std::{
11 pin::Pin,
12 task::{ready, Context, Poll},
13};
14use tokio_stream::{adapters::Fuse, Stream, StreamExt};
15
/// Stream combinator that encodes each message from `source` into a
/// length-prefixed gRPC frame and yields the frames in batches of bytes.
///
/// Frames accumulate in `buf` and are handed out when the source is pending or
/// finished, or once `buf` reaches the encoder's yield threshold.
#[pin_project(project = EncodedBytesProj)]
#[derive(Debug)]
struct EncodedBytes<T, U> {
    /// Message source, wrapped in `Fuse` so that polling it after it has ended
    /// keeps returning `None`.
    #[pin]
    source: Fuse<U>,
    /// Serializes each message and supplies the buffer sizing settings.
    encoder: T,
    /// Compression applied to every frame; `None` sends frames uncompressed.
    compression_encoding: Option<CompressionEncoding>,
    /// Maximum encoded message length; `None` falls back to
    /// `DEFAULT_MAX_SEND_MESSAGE_SIZE`.
    max_message_size: Option<usize>,
    /// Encoded frames waiting to be yielded.
    buf: BytesMut,
    /// Scratch buffer holding a message's uncompressed encoding before it is
    /// compressed into `buf`.
    uncompression_buf: BytesMut,
    /// Error held back until the frames buffered ahead of it have been yielded.
    error: Option<Status>,
}
33
34impl<T: Encoder, U: Stream> EncodedBytes<T, U> {
35 fn new(
36 encoder: T,
37 source: U,
38 compression_encoding: Option<CompressionEncoding>,
39 compression_override: SingleMessageCompressionOverride,
40 max_message_size: Option<usize>,
41 ) -> Self {
42 let buffer_settings = encoder.buffer_settings();
43 let buf = BytesMut::with_capacity(buffer_settings.buffer_size);
44
45 let compression_encoding =
46 if compression_override == SingleMessageCompressionOverride::Disable {
47 None
48 } else {
49 compression_encoding
50 };
51
52 let uncompression_buf = if compression_encoding.is_some() {
53 BytesMut::with_capacity(buffer_settings.buffer_size)
54 } else {
55 BytesMut::new()
56 };
57
58 Self {
59 source: source.fuse(),
60 encoder,
61 compression_encoding,
62 max_message_size,
63 buf,
64 uncompression_buf,
65 error: None,
66 }
67 }
68}
69
70impl<T, U> Stream for EncodedBytes<T, U>
71where
72 T: Encoder<Error = Status>,
73 U: Stream<Item = Result<T::Item, Status>>,
74{
75 type Item = Result<Bytes, Status>;
76
77 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
78 let EncodedBytesProj {
79 mut source,
80 encoder,
81 compression_encoding,
82 max_message_size,
83 buf,
84 uncompression_buf,
85 error,
86 } = self.project();
87 let buffer_settings = encoder.buffer_settings();
88
89 if let Some(status) = error.take() {
90 return Poll::Ready(Some(Err(status)));
91 }
92
93 loop {
94 match source.as_mut().poll_next(cx) {
95 Poll::Pending if buf.is_empty() => {
96 return Poll::Pending;
97 }
98 Poll::Ready(None) if buf.is_empty() => {
99 return Poll::Ready(None);
100 }
101 Poll::Pending | Poll::Ready(None) => {
102 return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze())));
103 }
104 Poll::Ready(Some(Ok(item))) => {
105 if let Err(status) = encode_item(
106 encoder,
107 buf,
108 uncompression_buf,
109 *compression_encoding,
110 *max_message_size,
111 buffer_settings,
112 item,
113 ) {
114 return Poll::Ready(Some(Err(status)));
115 }
116
117 if buf.len() >= buffer_settings.yield_threshold {
118 return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze())));
119 }
120 }
121 Poll::Ready(Some(Err(status))) => {
122 if buf.is_empty() {
123 return Poll::Ready(Some(Err(status)));
124 }
125 *error = Some(status);
126 return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze())));
127 }
128 }
129 }
130 }
131}
132
133fn encode_item<T>(
134 encoder: &mut T,
135 buf: &mut BytesMut,
136 uncompression_buf: &mut BytesMut,
137 compression_encoding: Option<CompressionEncoding>,
138 max_message_size: Option<usize>,
139 buffer_settings: BufferSettings,
140 item: T::Item,
141) -> Result<(), Status>
142where
143 T: Encoder<Error = Status>,
144{
145 let offset = buf.len();
146
147 buf.reserve(HEADER_SIZE);
148 unsafe {
149 buf.advance_mut(HEADER_SIZE);
150 }
151
152 if let Some(encoding) = compression_encoding {
153 uncompression_buf.clear();
154
155 encoder
156 .encode(item, &mut EncodeBuf::new(uncompression_buf))
157 .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?;
158
159 let uncompressed_len = uncompression_buf.len();
160
161 compress(
162 CompressionSettings {
163 encoding,
164 buffer_growth_interval: buffer_settings.buffer_size,
165 },
166 uncompression_buf,
167 buf,
168 uncompressed_len,
169 )
170 .map_err(|err| Status::internal(format!("Error compressing: {}", err)))?;
171 } else {
172 encoder
173 .encode(item, &mut EncodeBuf::new(buf))
174 .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?;
175 }
176
177 finish_encoding(compression_encoding, max_message_size, &mut buf[offset..])
179}
180
181fn finish_encoding(
182 compression_encoding: Option<CompressionEncoding>,
183 max_message_size: Option<usize>,
184 buf: &mut [u8],
185) -> Result<(), Status> {
186 let len = buf.len() - HEADER_SIZE;
187 let limit = max_message_size.unwrap_or(DEFAULT_MAX_SEND_MESSAGE_SIZE);
188 if len > limit {
189 return Err(Status::out_of_range(format!(
190 "Error, encoded message length too large: found {} bytes, the limit is: {} bytes",
191 len, limit
192 )));
193 }
194
195 if len > u32::MAX as usize {
196 return Err(Status::resource_exhausted(format!(
197 "Cannot return body with more than 4GB of data but got {len} bytes"
198 )));
199 }
200 {
201 let mut buf = &mut buf[..HEADER_SIZE];
202 buf.put_u8(compression_encoding.is_some() as u8);
203 buf.put_u32(len as u32);
204 }
205
206 Ok(())
207}
208
/// Which side of the call an `EncodeBody` encodes for.
///
/// This decides how a failure is surfaced. A client returns it as a body error.
/// A server converts it into a trailers frame and also emits status trailers
/// when the stream ends normally.
#[derive(Debug)]
enum Role {
    /// Request body: errors are returned as `Err`, and no trailers are produced.
    Client,
    /// Response body: the stream ends with a trailers frame carrying the status.
    Server,
}
214
/// HTTP body that yields the encoded gRPC frames of a message stream as data
/// frames. When built for a server, it finishes with a trailers frame carrying
/// the call status.
#[pin_project]
#[derive(Debug)]
pub struct EncodeBody<T, U> {
    /// Stream producing batches of encoded frames.
    #[pin]
    inner: EncodedBytes<T, U>,
    /// Role and end-of-stream bookkeeping.
    state: EncodeState,
}
223
/// Bookkeeping that `EncodeBody` keeps alongside its frame stream.
#[derive(Debug)]
struct EncodeState {
    /// Status taken by `trailers()` when building the final trailers; an
    /// `OK` status is used when this is `None`.
    /// NOTE(review): nothing in this file appears to set it — confirm whether
    /// it is still needed.
    error: Option<Status>,
    /// Whether this body belongs to a client or a server.
    role: Role,
    /// Set once trailers have been produced (server role only); reported by
    /// `Body::is_end_stream`.
    is_end_stream: bool,
}
230
231impl<T: Encoder, U: Stream> EncodeBody<T, U> {
232 pub fn new_client(
235 encoder: T,
236 source: U,
237 compression_encoding: Option<CompressionEncoding>,
238 max_message_size: Option<usize>,
239 ) -> Self {
240 Self {
241 inner: EncodedBytes::new(
242 encoder,
243 source,
244 compression_encoding,
245 SingleMessageCompressionOverride::default(),
246 max_message_size,
247 ),
248 state: EncodeState {
249 error: None,
250 role: Role::Client,
251 is_end_stream: false,
252 },
253 }
254 }
255
256 pub fn new_server(
259 encoder: T,
260 source: U,
261 compression_encoding: Option<CompressionEncoding>,
262 compression_override: SingleMessageCompressionOverride,
263 max_message_size: Option<usize>,
264 ) -> Self {
265 Self {
266 inner: EncodedBytes::new(
267 encoder,
268 source,
269 compression_encoding,
270 compression_override,
271 max_message_size,
272 ),
273 state: EncodeState {
274 error: None,
275 role: Role::Server,
276 is_end_stream: false,
277 },
278 }
279 }
280}
281
282impl EncodeState {
283 fn trailers(&mut self) -> Option<Result<HeaderMap, Status>> {
284 match self.role {
285 Role::Client => None,
286 Role::Server => {
287 if self.is_end_stream {
288 return None;
289 }
290
291 self.is_end_stream = true;
292 let status = if let Some(status) = self.error.take() {
293 status
294 } else {
295 Status::ok("")
296 };
297 Some(status.to_header_map())
298 }
299 }
300 }
301}
302
303impl<T, U> Body for EncodeBody<T, U>
304where
305 T: Encoder<Error = Status>,
306 U: Stream<Item = Result<T::Item, Status>>,
307{
308 type Data = Bytes;
309 type Error = Status;
310
311 fn is_end_stream(&self) -> bool {
312 self.state.is_end_stream
313 }
314
315 fn poll_frame(
316 self: Pin<&mut Self>,
317 cx: &mut Context<'_>,
318 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
319 let self_proj = self.project();
320 match ready!(self_proj.inner.poll_next(cx)) {
321 Some(Ok(d)) => Some(Ok(Frame::data(d))).into(),
322 Some(Err(status)) => match self_proj.state.role {
323 Role::Client => Some(Err(status)).into(),
324 Role::Server => {
325 self_proj.state.is_end_stream = true;
326 Some(Ok(Frame::trailers(status.to_header_map()?))).into()
327 }
328 },
329 None => self_proj
330 .state
331 .trailers()
332 .map(|t| t.map(Frame::trailers))
333 .into(),
334 }
335 }
336}