1use pin_project_lite::pin_project;
10use std::collections::BTreeSet;
11use std::pin::Pin;
12use std::sync::Arc;
13use std::sync::atomic::{AtomicBool, Ordering, fence};
14use std::task::{Poll, ready};
15use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf};
16use wu_manber::TwoByteWM;
17
/// Shared state behind a [`ReferencePattern`]: the candidate byte strings,
/// the length of the longest one, and the Wu-Manber searcher built over them.
///
/// [`ReferencePattern`] holds this behind an `Arc`, so clones of a pattern
/// share one instance.
pub struct ReferencePatternInner<P> {
    /// The candidates to search for. Match results are indexed by position in
    /// this list.
    candidates: Vec<P>,
    /// Length in bytes of the longest candidate (0 when there are none).
    longest_candidate: usize,
    /// Searcher over `candidates`; `None` when there are no candidates.
    searcher: Option<TwoByteWM>,
}
28
29#[derive(Clone)]
30pub struct ReferencePattern<P> {
31 inner: Arc<ReferencePatternInner<P>>,
32}
33
34impl<P> ReferencePattern<P> {
35 pub fn candidates(&self) -> &[P] {
36 &self.inner.candidates
37 }
38
39 pub fn longest_candidate(&self) -> usize {
40 self.inner.longest_candidate
41 }
42}
43
44impl<P: AsRef<[u8]>> ReferencePattern<P> {
45 pub fn new(candidates: Vec<P>) -> Self {
48 let searcher = if candidates.is_empty() {
49 None
50 } else {
51 Some(TwoByteWM::new(&candidates))
52 };
53 let longest_candidate = candidates.iter().fold(0, |v, c| v.max(c.as_ref().len()));
54
55 ReferencePattern {
56 inner: Arc::new(ReferencePatternInner {
57 searcher,
58 candidates,
59 longest_candidate,
60 }),
61 }
62 }
63}
64
65impl<P> From<Vec<P>> for ReferencePattern<P>
66where
67 P: AsRef<[u8]>,
68{
69 fn from(candidates: Vec<P>) -> Self {
70 Self::new(candidates)
71 }
72}
73
/// Records which candidates of a [`ReferencePattern`] have been seen across
/// one or more calls to [`ReferenceScanner::scan`].
///
/// The match flags are atomics, so `scan` only needs `&self`.
pub struct ReferenceScanner<P> {
    /// The candidates being searched for.
    pattern: ReferencePattern<P>,
    /// One flag per candidate, indexed like `pattern.candidates()`. A flag is
    /// set to `true` once that candidate has been found.
    matches: Vec<AtomicBool>,
}
80
81impl<P: AsRef<[u8]>> ReferenceScanner<P> {
82 pub fn new<IP: Into<ReferencePattern<P>>>(pattern: IP) -> Self {
85 let pattern = pattern.into();
86 let mut matches = Vec::new();
87 for _ in 0..pattern.candidates().len() {
88 matches.push(AtomicBool::new(false));
89 }
90 ReferenceScanner { pattern, matches }
91 }
92
93 pub fn scan<S: AsRef<[u8]>>(&self, haystack: S) {
96 if haystack.as_ref().len() < self.pattern.longest_candidate() {
97 return;
98 }
99
100 if let Some(searcher) = &self.pattern.inner.searcher {
101 for m in searcher.find(haystack) {
102 self.matches[m.pat_idx].store(true, Ordering::Relaxed);
103 }
104 fence(Ordering::Release);
105 }
106 }
107
108 pub fn pattern(&self) -> &ReferencePattern<P> {
109 &self.pattern
110 }
111
112 pub fn matches(&self) -> Vec<bool> {
113 fence(Ordering::Acquire);
114 self.matches
115 .iter()
116 .map(|m| m.load(Ordering::Relaxed))
117 .collect()
118 }
119
120 pub fn candidate_matches(&self) -> impl Iterator<Item = &P> {
121 let candidates = self.pattern.candidates();
122 fence(Ordering::Acquire);
123 Iterator::zip(candidates.iter(), self.matches.iter())
124 .filter_map(|(candidate, found)| found.load(Ordering::Relaxed).then_some(candidate))
125 }
126}
127
128impl<P: Clone + Ord + AsRef<[u8]>> ReferenceScanner<P> {
129 pub fn finalise(self) -> BTreeSet<P> {
131 self.candidate_matches().cloned().collect()
132 }
133}
134
/// Buffer capacity in bytes (8 KiB) that [`ReferenceReader::new`] requests.
const DEFAULT_BUF_SIZE: usize = 8 << 10;
136
pin_project! {
    /// An async reader adapter that hands out the bytes of `reader` unchanged
    /// while feeding everything it reads to `scanner`.
    pub struct ReferenceReader<'a, P, R> {
        // Scanner that records candidate matches in the data read.
        scanner: &'a ReferenceScanner<P>,
        // Internal buffer. Once the consumer has used more than
        // `longest_candidate - 1` bytes, that many trailing bytes are kept at
        // the front so the next scan also covers matches spanning two reads.
        buffer: Vec<u8>,
        // Number of bytes at the start of `buffer` that must not be handed
        // out to the consumer (already consumed, or retained overlap).
        consumed: usize,
        // The wrapped reader.
        #[pin]
        reader: R,
    }
}
146
147impl<'a, P, R> ReferenceReader<'a, P, R>
148where
149 P: AsRef<[u8]>,
150{
151 pub fn new(scanner: &'a ReferenceScanner<P>, reader: R) -> Self {
152 Self::with_capacity(DEFAULT_BUF_SIZE, scanner, reader)
153 }
154
155 pub fn with_capacity(capacity: usize, scanner: &'a ReferenceScanner<P>, reader: R) -> Self {
156 let capacity = capacity.max(scanner.pattern().longest_candidate());
158 ReferenceReader {
159 scanner,
160 buffer: Vec::with_capacity(capacity),
161 consumed: 0,
162 reader,
163 }
164 }
165}
166
167impl<P, R> AsyncRead for ReferenceReader<'_, P, R>
168where
169 R: AsyncRead,
170 P: AsRef<[u8]>,
171{
172 fn poll_read(
173 mut self: Pin<&mut Self>,
174 cx: &mut std::task::Context<'_>,
175 buf: &mut tokio::io::ReadBuf<'_>,
176 ) -> Poll<std::io::Result<()>> {
177 let internal_buf = ready!(self.as_mut().poll_fill_buf(cx))?;
178 let amt = buf.remaining().min(internal_buf.len());
179 buf.put_slice(&internal_buf[..amt]);
180 self.consume(amt);
181 Poll::Ready(Ok(()))
182 }
183}
184
185impl<P, R> AsyncBufRead for ReferenceReader<'_, P, R>
186where
187 R: AsyncRead,
188 P: AsRef<[u8]>,
189{
190 fn poll_fill_buf(
191 self: Pin<&mut Self>,
192 cx: &mut std::task::Context<'_>,
193 ) -> Poll<std::io::Result<&[u8]>> {
194 #[allow(clippy::manual_saturating_arithmetic)] let overlap = self
196 .scanner
197 .pattern
198 .longest_candidate()
199 .checked_sub(1)
200 .unwrap_or(0);
203 let mut this = self.project();
204 if *this.consumed < this.buffer.len() {
206 return Poll::Ready(Ok(&this.buffer[*this.consumed..]));
207 }
208 if *this.consumed > overlap {
210 let start = this.buffer.len() - overlap;
211 this.buffer.copy_within(start.., 0);
212 this.buffer.truncate(overlap);
213 *this.consumed = overlap;
214 }
215 loop {
217 let filled = {
218 let mut buf = ReadBuf::uninit(this.buffer.spare_capacity_mut());
219 ready!(this.reader.as_mut().poll_read(cx, &mut buf))?;
220 buf.filled().len()
221 };
222 unsafe {
224 this.buffer.set_len(filled + this.buffer.len());
225 }
226 if filled == 0 || this.buffer.len() > overlap {
227 break;
228 }
229 }
230
231 #[allow(clippy::needless_borrows_for_generic_args)] this.scanner.scan(&this.buffer);
233
234 Poll::Ready(Ok(&this.buffer[*this.consumed..]))
235 }
236
237 fn consume(self: Pin<&mut Self>, amt: usize) {
238 debug_assert!(self.consumed + amt <= self.buffer.len());
239 let this = self.project();
240 *this.consumed += amt;
241 }
242}
243
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use tokio::io::AsyncReadExt as _;
    use tokio_test::io::Builder;

    use super::*;

    // Derivation text (`Derive(...)`) for `hello-2.12.1`, used as the
    // haystack in every test below.
    const HELLO_DRV: &str = r#"Derive([("out","/nix/store/33l4p0pn0mybmqzaxfkpppyh7vx1c74p-hello-2.12.1","","")],[("/nix/store/6z1jfnqqgyqr221zgbpm30v91yfj3r45-bash-5.1-p16.drv",["out"]),("/nix/store/ap9g09fxbicj836zm88d56dn3ff4clxl-stdenv-linux.drv",["out"]),("/nix/store/pf80kikyxr63wrw56k00i1kw6ba76qik-hello-2.12.1.tar.gz.drv",["out"])],["/nix/store/9krlzvny65gdc8s7kpb6lkx8cd02c25b-default-builder.sh"],"x86_64-linux","/nix/store/4xw8n979xpivdc46a9ndcvyhwgif00hz-bash-5.1-p16/bin/bash",["-e","/nix/store/9krlzvny65gdc8s7kpb6lkx8cd02c25b-default-builder.sh"],[("buildInputs",""),("builder","/nix/store/4xw8n979xpivdc46a9ndcvyhwgif00hz-bash-5.1-p16/bin/bash"),("cmakeFlags",""),("configureFlags",""),("depsBuildBuild",""),("depsBuildBuildPropagated",""),("depsBuildTarget",""),("depsBuildTargetPropagated",""),("depsHostHost",""),("depsHostHostPropagated",""),("depsTargetTarget",""),("depsTargetTargetPropagated",""),("doCheck","1"),("doInstallCheck",""),("mesonFlags",""),("name","hello-2.12.1"),("nativeBuildInputs",""),("out","/nix/store/33l4p0pn0mybmqzaxfkpppyh7vx1c74p-hello-2.12.1"),("outputs","out"),("patches",""),("pname","hello"),("propagatedBuildInputs",""),("propagatedNativeBuildInputs",""),("src","/nix/store/pa10z4ngm0g83kx9mssrqzz30s84vq7k-hello-2.12.1.tar.gz"),("stdenv","/nix/store/cp65c8nk29qq5cl1wyy5qyw103cwmax7-stdenv-linux"),("strictDeps",""),("system","x86_64-linux"),("version","2.12.1")])"#;

    // With no candidates, scanning finds nothing.
    #[test]
    fn test_no_patterns() {
        let scanner: ReferenceScanner<String> = ReferenceScanner::new(vec![]);

        scanner.scan(HELLO_DRV);

        let result = scanner.finalise();

        assert_eq!(result.len(), 0);
    }

    // A single candidate that occurs in HELLO_DRV is reported.
    #[test]
    fn test_single_match() {
        let scanner = ReferenceScanner::new(vec![
            "/nix/store/4xw8n979xpivdc46a9ndcvyhwgif00hz-bash-5.1-p16".to_string(),
        ]);
        scanner.scan(HELLO_DRV);

        let result = scanner.finalise();

        assert_eq!(result.len(), 1);
        assert!(result.contains("/nix/store/4xw8n979xpivdc46a9ndcvyhwgif00hz-bash-5.1-p16"));
    }

    // Only the candidates that occur in HELLO_DRV are reported.
    #[test]
    fn test_multiple_matches() {
        let candidates = vec![
            // These three occur in HELLO_DRV:
            "/nix/store/33l4p0pn0mybmqzaxfkpppyh7vx1c74p-hello-2.12.1".to_string(),
            "/nix/store/pf80kikyxr63wrw56k00i1kw6ba76qik-hello-2.12.1.tar.gz.drv".to_string(),
            "/nix/store/cp65c8nk29qq5cl1wyy5qyw103cwmax7-stdenv-linux".to_string(),
            // This one does not:
            "/nix/store/fn7zvafq26f0c8b17brs7s95s10ibfzs-emacs-28.2.drv".to_string(),
        ];

        let scanner = ReferenceScanner::new(candidates.clone());
        scanner.scan(HELLO_DRV);

        let result = scanner.finalise();
        assert_eq!(result.len(), 3);

        for c in candidates[..3].iter() {
            assert!(result.contains(c));
        }
    }

    // Reading through a `ReferenceReader` must return the input unchanged.
    // It must find the same matches however the mock reader chunks the input
    // and whatever buffer capacity is requested. `with_capacity` raises the
    // capacity to at least the longest candidate length.
    #[rstest]
    #[case::normal(8096, 8096)]
    #[case::small_capacity(8096, 1)]
    #[case::small_read(1, 8096)]
    #[case::all_small(1, 1)]
    #[tokio::test]
    async fn test_reference_reader(#[case] chunk_size: usize, #[case] capacity: usize) {
        let candidates = vec![
            // These three occur in HELLO_DRV:
            "33l4p0pn0mybmqzaxfkpppyh7vx1c74p",
            "pf80kikyxr63wrw56k00i1kw6ba76qik",
            "cp65c8nk29qq5cl1wyy5qyw103cwmax7",
            // This one does not:
            "fn7zvafq26f0c8b17brs7s95s10ibfzs",
        ];
        let pattern = ReferencePattern::new(candidates.clone());
        let scanner = ReferenceScanner::new(pattern);
        let mut mock = Builder::new();
        for c in HELLO_DRV.as_bytes().chunks(chunk_size) {
            mock.read(c);
        }
        let mock = mock.build();
        let mut reader = ReferenceReader::with_capacity(capacity, &scanner, mock);
        let mut s = String::new();
        reader.read_to_string(&mut s).await.unwrap();
        assert_eq!(s, HELLO_DRV);

        let result = scanner.finalise();
        assert_eq!(result.len(), 3);

        for c in candidates[..3].iter() {
            assert!(result.contains(c));
        }
    }

    // A `ReferenceReader` without candidates still passes the data through
    // and reports no matches.
    #[tokio::test]
    async fn test_reference_reader_no_patterns() {
        let pattern = ReferencePattern::new(Vec::<&str>::new());
        let scanner = ReferenceScanner::new(pattern);
        let mut mock = Builder::new();
        mock.read(HELLO_DRV.as_bytes());
        let mock = mock.build();
        let mut reader = ReferenceReader::new(&scanner, mock);
        let mut s = String::new();
        reader.read_to_string(&mut s).await.unwrap();
        assert_eq!(s, HELLO_DRV);

        let result = scanner.finalise();
        assert_eq!(result.len(), 0);
    }
}