object_store/
limit.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! An object store that limits the maximum concurrency of the wrapped implementation
19
20use crate::{
21    BoxStream, GetOptions, GetResult, GetResultPayload, ListResult, MultipartUpload, ObjectMeta,
22    ObjectStore, Path, PutMultipartOpts, PutOptions, PutPayload, PutResult, Result, StreamExt,
23    UploadPart,
24};
25use async_trait::async_trait;
26use bytes::Bytes;
27use futures::{FutureExt, Stream};
28use std::ops::Range;
29use std::pin::Pin;
30use std::sync::Arc;
31use std::task::{Context, Poll};
32use tokio::sync::{OwnedSemaphorePermit, Semaphore};
33
/// Store wrapper that wraps an inner store and limits the maximum number of concurrent
/// object store operations. Each call to an [`ObjectStore`] member function is
/// considered a single operation, even if it may result in more than one network call.
///
/// ```
/// # use object_store::memory::InMemory;
/// # use object_store::limit::LimitStore;
///
/// // Create an in-memory `ObjectStore` limited to 20 concurrent requests
/// let store = LimitStore::new(InMemory::new(), 20);
/// ```
///
#[derive(Debug)]
pub struct LimitStore<T: ObjectStore> {
    /// The wrapped store that every operation is forwarded to
    inner: T,
    /// The permit count the semaphore was created with; reported by `Display`
    max_requests: usize,
    /// Hands out one permit per in-flight operation. Kept behind an `Arc` so it
    /// can be shared with [`LimitUpload`] and used to take owned permits.
    semaphore: Arc<Semaphore>,
}
52
53impl<T: ObjectStore> LimitStore<T> {
54    /// Create new limit store that will limit the maximum
55    /// number of outstanding concurrent requests to
56    /// `max_requests`
57    pub fn new(inner: T, max_requests: usize) -> Self {
58        Self {
59            inner,
60            max_requests,
61            semaphore: Arc::new(Semaphore::new(max_requests)),
62        }
63    }
64}
65
66impl<T: ObjectStore> std::fmt::Display for LimitStore<T> {
67    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68        write!(f, "LimitStore({}, {})", self.max_requests, self.inner)
69    }
70}
71
72#[async_trait]
73impl<T: ObjectStore> ObjectStore for LimitStore<T> {
74    async fn put(&self, location: &Path, payload: PutPayload) -> Result<PutResult> {
75        let _permit = self.semaphore.acquire().await.unwrap();
76        self.inner.put(location, payload).await
77    }
78
79    async fn put_opts(
80        &self,
81        location: &Path,
82        payload: PutPayload,
83        opts: PutOptions,
84    ) -> Result<PutResult> {
85        let _permit = self.semaphore.acquire().await.unwrap();
86        self.inner.put_opts(location, payload, opts).await
87    }
88    async fn put_multipart(&self, location: &Path) -> Result<Box<dyn MultipartUpload>> {
89        let upload = self.inner.put_multipart(location).await?;
90        Ok(Box::new(LimitUpload {
91            semaphore: Arc::clone(&self.semaphore),
92            upload,
93        }))
94    }
95
96    async fn put_multipart_opts(
97        &self,
98        location: &Path,
99        opts: PutMultipartOpts,
100    ) -> Result<Box<dyn MultipartUpload>> {
101        let upload = self.inner.put_multipart_opts(location, opts).await?;
102        Ok(Box::new(LimitUpload {
103            semaphore: Arc::clone(&self.semaphore),
104            upload,
105        }))
106    }
107
108    async fn get(&self, location: &Path) -> Result<GetResult> {
109        let permit = Arc::clone(&self.semaphore).acquire_owned().await.unwrap();
110        let r = self.inner.get(location).await?;
111        Ok(permit_get_result(r, permit))
112    }
113
114    async fn get_opts(&self, location: &Path, options: GetOptions) -> Result<GetResult> {
115        let permit = Arc::clone(&self.semaphore).acquire_owned().await.unwrap();
116        let r = self.inner.get_opts(location, options).await?;
117        Ok(permit_get_result(r, permit))
118    }
119
120    async fn get_range(&self, location: &Path, range: Range<usize>) -> Result<Bytes> {
121        let _permit = self.semaphore.acquire().await.unwrap();
122        self.inner.get_range(location, range).await
123    }
124
125    async fn get_ranges(&self, location: &Path, ranges: &[Range<usize>]) -> Result<Vec<Bytes>> {
126        let _permit = self.semaphore.acquire().await.unwrap();
127        self.inner.get_ranges(location, ranges).await
128    }
129
130    async fn head(&self, location: &Path) -> Result<ObjectMeta> {
131        let _permit = self.semaphore.acquire().await.unwrap();
132        self.inner.head(location).await
133    }
134
135    async fn delete(&self, location: &Path) -> Result<()> {
136        let _permit = self.semaphore.acquire().await.unwrap();
137        self.inner.delete(location).await
138    }
139
140    fn delete_stream<'a>(
141        &'a self,
142        locations: BoxStream<'a, Result<Path>>,
143    ) -> BoxStream<'a, Result<Path>> {
144        self.inner.delete_stream(locations)
145    }
146
147    fn list(&self, prefix: Option<&Path>) -> BoxStream<'_, Result<ObjectMeta>> {
148        let prefix = prefix.cloned();
149        let fut = Arc::clone(&self.semaphore)
150            .acquire_owned()
151            .map(move |permit| {
152                let s = self.inner.list(prefix.as_ref());
153                PermitWrapper::new(s, permit.unwrap())
154            });
155        fut.into_stream().flatten().boxed()
156    }
157
158    fn list_with_offset(
159        &self,
160        prefix: Option<&Path>,
161        offset: &Path,
162    ) -> BoxStream<'_, Result<ObjectMeta>> {
163        let prefix = prefix.cloned();
164        let offset = offset.clone();
165        let fut = Arc::clone(&self.semaphore)
166            .acquire_owned()
167            .map(move |permit| {
168                let s = self.inner.list_with_offset(prefix.as_ref(), &offset);
169                PermitWrapper::new(s, permit.unwrap())
170            });
171        fut.into_stream().flatten().boxed()
172    }
173
174    async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result<ListResult> {
175        let _permit = self.semaphore.acquire().await.unwrap();
176        self.inner.list_with_delimiter(prefix).await
177    }
178
179    async fn copy(&self, from: &Path, to: &Path) -> Result<()> {
180        let _permit = self.semaphore.acquire().await.unwrap();
181        self.inner.copy(from, to).await
182    }
183
184    async fn rename(&self, from: &Path, to: &Path) -> Result<()> {
185        let _permit = self.semaphore.acquire().await.unwrap();
186        self.inner.rename(from, to).await
187    }
188
189    async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> {
190        let _permit = self.semaphore.acquire().await.unwrap();
191        self.inner.copy_if_not_exists(from, to).await
192    }
193
194    async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> {
195        let _permit = self.semaphore.acquire().await.unwrap();
196        self.inner.rename_if_not_exists(from, to).await
197    }
198}
199
200fn permit_get_result(r: GetResult, permit: OwnedSemaphorePermit) -> GetResult {
201    let payload = match r.payload {
202        v @ GetResultPayload::File(_, _) => v,
203        GetResultPayload::Stream(s) => {
204            GetResultPayload::Stream(PermitWrapper::new(s, permit).boxed())
205        }
206    };
207    GetResult { payload, ..r }
208}
209
/// Combines an [`OwnedSemaphorePermit`] with some other type, so that the permit
/// is dropped together with the wrapped value
struct PermitWrapper<T> {
    /// The wrapped value; the `Stream` impl for this type forwards to it
    inner: T,
    // Never read: it is stored only so that it is dropped along with `inner`,
    // hence the lint suppression
    #[allow(dead_code)]
    permit: OwnedSemaphorePermit,
}
216
217impl<T> PermitWrapper<T> {
218    fn new(inner: T, permit: OwnedSemaphorePermit) -> Self {
219        Self { inner, permit }
220    }
221}
222
223impl<T: Stream + Unpin> Stream for PermitWrapper<T> {
224    type Item = T::Item;
225
226    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
227        Pin::new(&mut self.inner).poll_next(cx)
228    }
229
230    fn size_hint(&self) -> (usize, Option<usize>) {
231        self.inner.size_hint()
232    }
233}
234
/// A [`MultipartUpload`] wrapper that limits the maximum number of concurrent requests
#[derive(Debug)]
pub struct LimitUpload {
    /// The wrapped upload that every call is forwarded to
    upload: Box<dyn MultipartUpload>,
    /// Bounds concurrent calls on `upload`. Either created by [`LimitUpload::new`]
    /// or shared with the `LimitStore` that produced this upload.
    semaphore: Arc<Semaphore>,
}
241
242impl LimitUpload {
243    /// Create a new [`LimitUpload`] limiting `upload` to `max_concurrency` concurrent requests
244    pub fn new(upload: Box<dyn MultipartUpload>, max_concurrency: usize) -> Self {
245        Self {
246            upload,
247            semaphore: Arc::new(Semaphore::new(max_concurrency)),
248        }
249    }
250}
251
252#[async_trait]
253impl MultipartUpload for LimitUpload {
254    fn put_part(&mut self, data: PutPayload) -> UploadPart {
255        let upload = self.upload.put_part(data);
256        let s = Arc::clone(&self.semaphore);
257        Box::pin(async move {
258            let _permit = s.acquire().await.unwrap();
259            upload.await
260        })
261    }
262
263    async fn complete(&mut self) -> Result<PutResult> {
264        let _permit = self.semaphore.acquire().await.unwrap();
265        self.upload.complete().await
266    }
267
268    async fn abort(&mut self) -> Result<()> {
269        let _permit = self.semaphore.acquire().await.unwrap();
270        self.upload.abort().await
271    }
272}
273
#[cfg(test)]
mod tests {
    use crate::integration::*;
    use crate::limit::LimitStore;
    use crate::memory::InMemory;
    use crate::ObjectStore;
    use futures::stream::StreamExt;
    use std::pin::Pin;
    use std::time::Duration;
    use tokio::time::timeout;

    /// Runs the shared integration suite against a limited in-memory store, then
    /// checks that a request blocks once every permit is held and proceeds again
    /// after one is released
    #[tokio::test]
    async fn limit_test() {
        let max_requests = 10;
        let store = LimitStore::new(InMemory::new(), max_requests);

        put_get_delete_list(&store).await;
        get_opts(&store).await;
        list_uses_directories_correctly(&store).await;
        list_with_delimiter(&store).await;
        rename_and_copy(&store).await;
        stream_get(&store).await;

        // Tie up every permit by keeping `max_requests` started list streams alive
        let mut held = Vec::with_capacity(max_requests);
        for _ in 0..max_requests {
            let mut listing = store.list(None).peekable();
            // Poll once to ensure the semaphore is acquired
            Pin::new(&mut listing).peek().await;
            held.push(listing);
        }

        // Expect to not be able to make another request: it should still be
        // pending when the timeout fires
        let blocked = store.list(None).collect::<Vec<_>>();
        let wait = Duration::from_millis(20);
        assert!(timeout(wait, blocked).await.is_err());

        // Drop one of the streams
        drop(held.pop());

        // Can now make another request
        store.list(None).collect::<Vec<_>>().await;
    }
}
317}