1use crate::{
21 BoxStream, GetOptions, GetResult, GetResultPayload, ListResult, MultipartUpload, ObjectMeta,
22 ObjectStore, Path, PutMultipartOpts, PutOptions, PutPayload, PutResult, Result, StreamExt,
23 UploadPart,
24};
25use async_trait::async_trait;
26use bytes::Bytes;
27use futures::{FutureExt, Stream};
28use std::ops::Range;
29use std::pin::Pin;
30use std::sync::Arc;
31use std::task::{Context, Poll};
32use tokio::sync::{OwnedSemaphorePermit, Semaphore};
33
34#[derive(Debug)]
47pub struct LimitStore<T: ObjectStore> {
48 inner: T,
49 max_requests: usize,
50 semaphore: Arc<Semaphore>,
51}
52
53impl<T: ObjectStore> LimitStore<T> {
54 pub fn new(inner: T, max_requests: usize) -> Self {
58 Self {
59 inner,
60 max_requests,
61 semaphore: Arc::new(Semaphore::new(max_requests)),
62 }
63 }
64}
65
66impl<T: ObjectStore> std::fmt::Display for LimitStore<T> {
67 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 write!(f, "LimitStore({}, {})", self.max_requests, self.inner)
69 }
70}
71
72#[async_trait]
73impl<T: ObjectStore> ObjectStore for LimitStore<T> {
74 async fn put(&self, location: &Path, payload: PutPayload) -> Result<PutResult> {
75 let _permit = self.semaphore.acquire().await.unwrap();
76 self.inner.put(location, payload).await
77 }
78
79 async fn put_opts(
80 &self,
81 location: &Path,
82 payload: PutPayload,
83 opts: PutOptions,
84 ) -> Result<PutResult> {
85 let _permit = self.semaphore.acquire().await.unwrap();
86 self.inner.put_opts(location, payload, opts).await
87 }
88 async fn put_multipart(&self, location: &Path) -> Result<Box<dyn MultipartUpload>> {
89 let upload = self.inner.put_multipart(location).await?;
90 Ok(Box::new(LimitUpload {
91 semaphore: Arc::clone(&self.semaphore),
92 upload,
93 }))
94 }
95
96 async fn put_multipart_opts(
97 &self,
98 location: &Path,
99 opts: PutMultipartOpts,
100 ) -> Result<Box<dyn MultipartUpload>> {
101 let upload = self.inner.put_multipart_opts(location, opts).await?;
102 Ok(Box::new(LimitUpload {
103 semaphore: Arc::clone(&self.semaphore),
104 upload,
105 }))
106 }
107
108 async fn get(&self, location: &Path) -> Result<GetResult> {
109 let permit = Arc::clone(&self.semaphore).acquire_owned().await.unwrap();
110 let r = self.inner.get(location).await?;
111 Ok(permit_get_result(r, permit))
112 }
113
114 async fn get_opts(&self, location: &Path, options: GetOptions) -> Result<GetResult> {
115 let permit = Arc::clone(&self.semaphore).acquire_owned().await.unwrap();
116 let r = self.inner.get_opts(location, options).await?;
117 Ok(permit_get_result(r, permit))
118 }
119
120 async fn get_range(&self, location: &Path, range: Range<usize>) -> Result<Bytes> {
121 let _permit = self.semaphore.acquire().await.unwrap();
122 self.inner.get_range(location, range).await
123 }
124
125 async fn get_ranges(&self, location: &Path, ranges: &[Range<usize>]) -> Result<Vec<Bytes>> {
126 let _permit = self.semaphore.acquire().await.unwrap();
127 self.inner.get_ranges(location, ranges).await
128 }
129
130 async fn head(&self, location: &Path) -> Result<ObjectMeta> {
131 let _permit = self.semaphore.acquire().await.unwrap();
132 self.inner.head(location).await
133 }
134
135 async fn delete(&self, location: &Path) -> Result<()> {
136 let _permit = self.semaphore.acquire().await.unwrap();
137 self.inner.delete(location).await
138 }
139
140 fn delete_stream<'a>(
141 &'a self,
142 locations: BoxStream<'a, Result<Path>>,
143 ) -> BoxStream<'a, Result<Path>> {
144 self.inner.delete_stream(locations)
145 }
146
147 fn list(&self, prefix: Option<&Path>) -> BoxStream<'_, Result<ObjectMeta>> {
148 let prefix = prefix.cloned();
149 let fut = Arc::clone(&self.semaphore)
150 .acquire_owned()
151 .map(move |permit| {
152 let s = self.inner.list(prefix.as_ref());
153 PermitWrapper::new(s, permit.unwrap())
154 });
155 fut.into_stream().flatten().boxed()
156 }
157
158 fn list_with_offset(
159 &self,
160 prefix: Option<&Path>,
161 offset: &Path,
162 ) -> BoxStream<'_, Result<ObjectMeta>> {
163 let prefix = prefix.cloned();
164 let offset = offset.clone();
165 let fut = Arc::clone(&self.semaphore)
166 .acquire_owned()
167 .map(move |permit| {
168 let s = self.inner.list_with_offset(prefix.as_ref(), &offset);
169 PermitWrapper::new(s, permit.unwrap())
170 });
171 fut.into_stream().flatten().boxed()
172 }
173
174 async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result<ListResult> {
175 let _permit = self.semaphore.acquire().await.unwrap();
176 self.inner.list_with_delimiter(prefix).await
177 }
178
179 async fn copy(&self, from: &Path, to: &Path) -> Result<()> {
180 let _permit = self.semaphore.acquire().await.unwrap();
181 self.inner.copy(from, to).await
182 }
183
184 async fn rename(&self, from: &Path, to: &Path) -> Result<()> {
185 let _permit = self.semaphore.acquire().await.unwrap();
186 self.inner.rename(from, to).await
187 }
188
189 async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> {
190 let _permit = self.semaphore.acquire().await.unwrap();
191 self.inner.copy_if_not_exists(from, to).await
192 }
193
194 async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> {
195 let _permit = self.semaphore.acquire().await.unwrap();
196 self.inner.rename_if_not_exists(from, to).await
197 }
198}
199
200fn permit_get_result(r: GetResult, permit: OwnedSemaphorePermit) -> GetResult {
201 let payload = match r.payload {
202 v @ GetResultPayload::File(_, _) => v,
203 GetResultPayload::Stream(s) => {
204 GetResultPayload::Stream(PermitWrapper::new(s, permit).boxed())
205 }
206 };
207 GetResult { payload, ..r }
208}
209
210struct PermitWrapper<T> {
212 inner: T,
213 #[allow(dead_code)]
214 permit: OwnedSemaphorePermit,
215}
216
217impl<T> PermitWrapper<T> {
218 fn new(inner: T, permit: OwnedSemaphorePermit) -> Self {
219 Self { inner, permit }
220 }
221}
222
223impl<T: Stream + Unpin> Stream for PermitWrapper<T> {
224 type Item = T::Item;
225
226 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
227 Pin::new(&mut self.inner).poll_next(cx)
228 }
229
230 fn size_hint(&self) -> (usize, Option<usize>) {
231 self.inner.size_hint()
232 }
233}
234
235#[derive(Debug)]
237pub struct LimitUpload {
238 upload: Box<dyn MultipartUpload>,
239 semaphore: Arc<Semaphore>,
240}
241
242impl LimitUpload {
243 pub fn new(upload: Box<dyn MultipartUpload>, max_concurrency: usize) -> Self {
245 Self {
246 upload,
247 semaphore: Arc::new(Semaphore::new(max_concurrency)),
248 }
249 }
250}
251
252#[async_trait]
253impl MultipartUpload for LimitUpload {
254 fn put_part(&mut self, data: PutPayload) -> UploadPart {
255 let upload = self.upload.put_part(data);
256 let s = Arc::clone(&self.semaphore);
257 Box::pin(async move {
258 let _permit = s.acquire().await.unwrap();
259 upload.await
260 })
261 }
262
263 async fn complete(&mut self) -> Result<PutResult> {
264 let _permit = self.semaphore.acquire().await.unwrap();
265 self.upload.complete().await
266 }
267
268 async fn abort(&mut self) -> Result<()> {
269 let _permit = self.semaphore.acquire().await.unwrap();
270 self.upload.abort().await
271 }
272}
273
#[cfg(test)]
mod tests {
    use crate::integration::*;
    use crate::limit::LimitStore;
    use crate::memory::InMemory;
    use crate::ObjectStore;
    use futures::stream::StreamExt;
    use std::pin::Pin;
    use std::time::Duration;
    use tokio::time::timeout;

    #[tokio::test]
    async fn limit_test() {
        let max_requests = 10;
        let integration = LimitStore::new(InMemory::new(), max_requests);

        // The shared integration suite must behave normally through the limiter.
        put_get_delete_list(&integration).await;
        get_opts(&integration).await;
        list_uses_directories_correctly(&integration).await;
        list_with_delimiter(&integration).await;
        rename_and_copy(&integration).await;
        stream_get(&integration).await;

        // Exhaust every permit: peeking forces each list stream to acquire one
        // and keep it for as long as the stream is alive.
        let mut held = Vec::with_capacity(max_requests);
        for _ in 0..max_requests {
            let mut listing = integration.list(None).peekable();
            Pin::new(&mut listing).peek().await;
            held.push(listing);
        }

        // With no permits left, another listing cannot make progress.
        let wait = Duration::from_millis(20);
        let blocked = integration.list(None).collect::<Vec<_>>();
        assert!(timeout(wait, blocked).await.is_err());

        // Dropping one held stream returns its permit, so a new listing completes.
        held.pop();
        integration.list(None).collect::<Vec<_>>().await;
    }
}