1use std::fs::File;
12use std::mem::size_of;
13use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
14use std::os::unix::net::{UnixDatagram, UnixStream};
15use std::ptr::{copy_nonoverlapping, null_mut, write_unaligned};
16
17use crate::errno::{Error, Result};
18use libc::{
19 c_long, c_void, cmsghdr, iovec, msghdr, recvmsg, sendmsg, MSG_NOSIGNAL, SCM_RIGHTS, SOL_SOCKET,
20};
21use std::os::raw::c_int;
22
/// Rounds `$len` up to the next multiple of `size_of::<c_long>()`; mirrors the C
/// `CMSG_ALIGN` macro used to lay out control-message headers and payloads.
macro_rules! CMSG_ALIGN {
    ($len:expr) => {{
        let align = size_of::<c_long>();
        (($len) as usize + align - 1) & !(align - 1)
    }};
}
31
/// Total number of buffer bytes needed for one control message carrying `$len` payload
/// bytes: the header followed by the payload padded with `CMSG_ALIGN!`.
macro_rules! CMSG_SPACE {
    ($len:expr) => {{
        let header_len = size_of::<cmsghdr>();
        header_len + CMSG_ALIGN!($len)
    }};
}
37
38#[allow(non_snake_case)]
42#[inline(always)]
43fn CMSG_DATA(cmsg_buffer: *mut cmsghdr) -> *mut RawFd {
44 cmsg_buffer.wrapping_offset(1) as *mut RawFd
46}
47
/// Value for `cmsghdr::cmsg_len` of a control message carrying `$len` payload bytes:
/// the header size plus the unpadded payload length.
#[cfg(not(target_env = "musl"))]
macro_rules! CMSG_LEN {
    ($len:expr) => {
        ($len) + size_of::<cmsghdr>()
    };
}

/// musl variant of `CMSG_LEN!`: musl's `cmsg_len` is a `u32`, so the total is checked to
/// fit before narrowing (panicking otherwise).
#[cfg(target_env = "musl")]
macro_rules! CMSG_LEN {
    ($len:expr) => {{
        let sz = ($len) + size_of::<cmsghdr>();
        assert!(sz <= (std::u32::MAX as usize));
        sz as u32
    }};
}
63
64#[cfg(not(target_env = "musl"))]
65fn new_msghdr(iovecs: &mut [iovec]) -> msghdr {
66 msghdr {
67 msg_name: null_mut(),
68 msg_namelen: 0,
69 msg_iov: iovecs.as_mut_ptr(),
70 msg_iovlen: iovecs.len(),
71 msg_control: null_mut(),
72 msg_controllen: 0,
73 msg_flags: 0,
74 }
75}
76
77#[cfg(target_env = "musl")]
78fn new_msghdr(iovecs: &mut [iovec]) -> msghdr {
79 assert!(iovecs.len() <= (std::i32::MAX as usize));
80 let mut msg: msghdr = unsafe { std::mem::zeroed() };
81 msg.msg_name = null_mut();
82 msg.msg_iov = iovecs.as_mut_ptr();
83 msg.msg_iovlen = iovecs.len() as i32;
84 msg.msg_control = null_mut();
85 msg
86}
87
88#[cfg(not(target_env = "musl"))]
89fn set_msg_controllen(msg: &mut msghdr, cmsg_capacity: usize) {
90 msg.msg_controllen = cmsg_capacity;
91}
92
93#[cfg(target_env = "musl")]
94fn set_msg_controllen(msg: &mut msghdr, cmsg_capacity: usize) {
95 assert!(cmsg_capacity <= (std::u32::MAX as usize));
96 msg.msg_controllen = cmsg_capacity as u32;
97}
98
99#[cfg_attr(
102 feature = "cargo-clippy",
103 allow(clippy::cast_ptr_alignment, clippy::unnecessary_cast)
104)]
105fn get_next_cmsg(msghdr: &msghdr, cmsg: &cmsghdr, cmsg_ptr: *mut cmsghdr) -> *mut cmsghdr {
106 let next_cmsg = (cmsg_ptr as *mut u8).wrapping_add(CMSG_ALIGN!(cmsg.cmsg_len)) as *mut cmsghdr;
107 if next_cmsg
108 .wrapping_offset(1)
109 .wrapping_sub(msghdr.msg_control as usize) as usize
110 > msghdr.msg_controllen as usize
111 {
112 null_mut()
113 } else {
114 next_cmsg
115 }
116}
117
118const CMSG_BUFFER_INLINE_CAPACITY: usize = CMSG_SPACE!(size_of::<RawFd>() * 32);
119
120enum CmsgBuffer {
121 Inline([u64; (CMSG_BUFFER_INLINE_CAPACITY + 7) / 8]),
122 Heap(Box<[cmsghdr]>),
123}
124
125impl CmsgBuffer {
126 fn with_capacity(capacity: usize) -> CmsgBuffer {
127 let cap_in_cmsghdr_units =
128 (capacity.checked_add(size_of::<cmsghdr>()).unwrap() - 1) / size_of::<cmsghdr>();
129 if capacity <= CMSG_BUFFER_INLINE_CAPACITY {
130 CmsgBuffer::Inline([0u64; (CMSG_BUFFER_INLINE_CAPACITY + 7) / 8])
131 } else {
132 CmsgBuffer::Heap(
133 vec![
134 cmsghdr {
135 cmsg_len: 0,
136 cmsg_level: 0,
137 cmsg_type: 0,
138 #[cfg(all(target_env = "musl", target_pointer_width = "64"))]
139 __pad1: 0,
140 };
141 cap_in_cmsghdr_units
142 ]
143 .into_boxed_slice(),
144 )
145 }
146 }
147
148 fn as_mut_ptr(&mut self) -> *mut cmsghdr {
149 match self {
150 CmsgBuffer::Inline(a) => a.as_mut_ptr() as *mut cmsghdr,
151 CmsgBuffer::Heap(a) => a.as_mut_ptr(),
152 }
153 }
154}
155
156fn raw_sendmsg<D: IntoIovec>(fd: RawFd, out_data: &[D], out_fds: &[RawFd]) -> Result<usize> {
157 let cmsg_capacity = CMSG_SPACE!(size_of::<RawFd>() * out_fds.len());
158 let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity);
159
160 let mut iovecs = Vec::with_capacity(out_data.len());
161 for data in out_data {
162 iovecs.push(iovec {
163 iov_base: data.as_ptr() as *mut c_void,
164 iov_len: data.size(),
165 });
166 }
167
168 let mut msg = new_msghdr(&mut iovecs);
169
170 if !out_fds.is_empty() {
171 let cmsg = cmsghdr {
172 cmsg_len: CMSG_LEN!(size_of::<RawFd>() * out_fds.len()),
173 cmsg_level: SOL_SOCKET,
174 cmsg_type: SCM_RIGHTS,
175 #[cfg(all(target_env = "musl", target_pointer_width = "64"))]
176 __pad1: 0,
177 };
178 unsafe {
180 write_unaligned(cmsg_buffer.as_mut_ptr() as *mut cmsghdr, cmsg);
182 copy_nonoverlapping(
185 out_fds.as_ptr(),
186 CMSG_DATA(cmsg_buffer.as_mut_ptr()),
187 out_fds.len(),
188 );
189 }
190
191 msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void;
192 set_msg_controllen(&mut msg, cmsg_capacity);
193 }
194
195 let write_count = unsafe { sendmsg(fd, &msg, MSG_NOSIGNAL) };
198
199 if write_count == -1 {
200 Err(Error::last())
201 } else {
202 Ok(write_count as usize)
203 }
204}
205
206#[cfg_attr(feature = "cargo-clippy", allow(clippy::unnecessary_cast))]
207unsafe fn raw_recvmsg(
208 fd: RawFd,
209 iovecs: &mut [iovec],
210 in_fds: &mut [RawFd],
211) -> Result<(usize, usize)> {
212 let cmsg_capacity = CMSG_SPACE!(size_of::<RawFd>() * in_fds.len());
213 let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity);
214 let mut msg = new_msghdr(iovecs);
215
216 if !in_fds.is_empty() {
217 msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void;
219 set_msg_controllen(&mut msg, cmsg_capacity);
220 }
221
222 let total_read = recvmsg(fd, &mut msg, 0);
226 if total_read == -1 {
227 return Err(Error::last());
228 }
229
230 if total_read == 0 && (msg.msg_controllen as usize) < size_of::<cmsghdr>() {
231 return Ok((0, 0));
232 }
233
234 let mut cmsg_ptr = msg.msg_control as *mut cmsghdr;
237 let mut copied_fds_count = 0;
238 let mut teardown_control_data = msg.msg_flags & libc::MSG_CTRUNC != 0;
241
242 while !cmsg_ptr.is_null() {
243 let cmsg = (cmsg_ptr as *mut cmsghdr).read_unaligned();
247 if cmsg.cmsg_level == SOL_SOCKET && cmsg.cmsg_type == SCM_RIGHTS {
248 let fds_count: usize = ((cmsg.cmsg_len - CMSG_LEN!(0)) as usize) / size_of::<RawFd>();
249 let fds_to_be_copied_count = std::cmp::min(in_fds.len() - copied_fds_count, fds_count);
253 teardown_control_data |= fds_count > fds_to_be_copied_count;
254 if teardown_control_data {
255 for fd_offset in 0..fds_count {
260 let raw_fds_ptr = CMSG_DATA(cmsg_ptr);
261 let raw_fd = *(raw_fds_ptr.wrapping_add(fd_offset)) as c_int;
264 libc::close(raw_fd);
265 }
266 } else {
267 copy_nonoverlapping(
270 CMSG_DATA(cmsg_ptr),
271 in_fds[copied_fds_count..(copied_fds_count + fds_to_be_copied_count)]
272 .as_mut_ptr(),
273 fds_to_be_copied_count,
274 );
275
276 copied_fds_count += fds_to_be_copied_count;
277 }
278 }
279
280 if teardown_control_data {
282 for fd in in_fds.iter().take(copied_fds_count) {
283 libc::close(*fd);
286 }
287
288 return Err(Error::new(libc::ENOBUFS));
289 }
290
291 cmsg_ptr = get_next_cmsg(&msg, &cmsg, cmsg_ptr);
292 }
293
294 Ok((total_read as usize, copied_fds_count))
295}
296
297pub trait ScmSocket {
339 fn socket_fd(&self) -> RawFd;
341
342 fn send_with_fd<D: IntoIovec>(&self, buf: D, fd: RawFd) -> Result<usize> {
351 self.send_with_fds(&[buf], &[fd])
352 }
353
354 fn send_with_fds<D: IntoIovec>(&self, bufs: &[D], fds: &[RawFd]) -> Result<usize> {
363 raw_sendmsg(self.socket_fd(), bufs, fds)
364 }
365
366 fn recv_with_fd(&self, buf: &mut [u8]) -> Result<(usize, Option<File>)> {
374 let mut fd = [0];
375 let mut iovecs = [iovec {
376 iov_base: buf.as_mut_ptr() as *mut c_void,
377 iov_len: buf.len(),
378 }];
379
380 let (read_count, fd_count) = unsafe { self.recv_with_fds(&mut iovecs[..], &mut fd)? };
383 let file = if fd_count == 0 {
384 None
385 } else {
386 Some(unsafe { File::from_raw_fd(fd[0]) })
389 };
390 Ok((read_count, file))
391 }
392
393 unsafe fn recv_with_fds(
412 &self,
413 iovecs: &mut [iovec],
414 fds: &mut [RawFd],
415 ) -> Result<(usize, usize)> {
416 raw_recvmsg(self.socket_fd(), iovecs, fds)
417 }
418}
419
420impl ScmSocket for UnixDatagram {
421 fn socket_fd(&self) -> RawFd {
422 self.as_raw_fd()
423 }
424}
425
426impl ScmSocket for UnixStream {
427 fn socket_fd(&self) -> RawFd {
428 self.as_raw_fd()
429 }
430}
431
/// A contiguous byte buffer that can be described by an `iovec` for gathered sends.
///
/// # Safety
///
/// The values of `as_ptr()` and `size()` are handed directly to `sendmsg` as an `iovec`.
/// Implementors must therefore ensure the pointer is valid for reads of `size()` bytes
/// while the value is borrowed.
pub unsafe trait IntoIovec {
    /// Address of the first byte of the buffer.
    fn as_ptr(&self) -> *const c_void;

    /// Length of the buffer in bytes.
    fn size(&self) -> usize;
}
446
447unsafe impl<'a> IntoIovec for &'a [u8] {
450 #[cfg_attr(feature = "cargo-clippy", allow(clippy::useless_asref))]
452 fn as_ptr(&self) -> *const c_void {
453 self.as_ref().as_ptr() as *const c_void
454 }
455
456 fn size(&self) -> usize {
457 self.len()
458 }
459}
460
#[cfg(test)]
mod tests {
    #![allow(clippy::undocumented_unsafe_blocks)]
    use super::*;
    use crate::eventfd::EventFd;

    use std::io::Write;
    use std::mem::size_of;
    use std::os::raw::c_long;
    use std::os::unix::net::UnixDatagram;

    use libc::cmsghdr;

    /// Builds an `iovec` covering the whole of `buf`.
    fn iovec_for(buf: &mut [u8]) -> iovec {
        iovec {
            iov_base: buf.as_mut_ptr() as *mut c_void,
            iov_len: buf.len(),
        }
    }

    /// Writes `value` to `file` as its 8 native-endian bytes (safe replacement for viewing
    /// the `u64` through a raw byte slice).
    fn write_u64(file: &mut File, value: u64) {
        file.write_all(&value.to_ne_bytes())
            .expect("failed to write to sent fd");
    }

    /// Sends one byte together with four eventfds, then checks that receiving them into a
    /// descriptor buffer of only `capacity` (< 4) entries fails.
    fn send_four_recv_fewer(capacity: usize) {
        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");

        // Kept alive until the end so the sent descriptors stay valid.
        let evts: Vec<EventFd> = (0..4)
            .map(|_| EventFd::new(0).expect("failed to create eventfd"))
            .collect();
        let raw_fds: Vec<RawFd> = evts.iter().map(|evt| evt.as_raw_fd()).collect();
        let write_count = s1
            .send_with_fds(&[[237].as_ref()], &raw_fds)
            .expect("failed to send fd");

        assert_eq!(write_count, 1);

        let mut files = vec![0; capacity];
        let mut buf = [0u8];
        let mut iovecs = [iovec_for(&mut buf)];
        assert!(unsafe { s2.recv_with_fds(&mut iovecs[..], &mut files).is_err() });
    }

    #[test]
    fn buffer_len() {
        assert_eq!(CMSG_SPACE!(0), size_of::<cmsghdr>());
        assert_eq!(
            CMSG_SPACE!(size_of::<RawFd>()),
            size_of::<cmsghdr>() + size_of::<c_long>()
        );
        if size_of::<RawFd>() == 4 {
            assert_eq!(
                CMSG_SPACE!(2 * size_of::<RawFd>()),
                size_of::<cmsghdr>() + size_of::<c_long>()
            );
            assert_eq!(
                CMSG_SPACE!(3 * size_of::<RawFd>()),
                size_of::<cmsghdr>() + size_of::<c_long>() * 2
            );
            assert_eq!(
                CMSG_SPACE!(4 * size_of::<RawFd>()),
                size_of::<cmsghdr>() + size_of::<c_long>() * 2
            );
        } else if size_of::<RawFd>() == 8 {
            assert_eq!(
                CMSG_SPACE!(2 * size_of::<RawFd>()),
                size_of::<cmsghdr>() + size_of::<c_long>() * 2
            );
            assert_eq!(
                CMSG_SPACE!(3 * size_of::<RawFd>()),
                size_of::<cmsghdr>() + size_of::<c_long>() * 3
            );
            assert_eq!(
                CMSG_SPACE!(4 * size_of::<RawFd>()),
                size_of::<cmsghdr>() + size_of::<c_long>() * 4
            );
        }
    }

    #[test]
    fn send_recv_no_fd() {
        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");

        let write_count = s1
            .send_with_fds(&[[1u8, 1, 2].as_ref(), [21u8, 34, 55].as_ref()], &[])
            .expect("failed to send data");

        assert_eq!(write_count, 6);

        let mut buf = [0u8; 6];
        let mut files = [0; 1];
        let mut iovecs = [iovec_for(&mut buf)];
        let (read_count, file_count) = unsafe {
            s2.recv_with_fds(&mut iovecs[..], &mut files)
                .expect("failed to recv data")
        };

        assert_eq!(read_count, 6);
        assert_eq!(file_count, 0);
        assert_eq!(buf, [1, 1, 2, 21, 34, 55]);
    }

    #[test]
    fn send_recv_only_fd() {
        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");

        let evt = EventFd::new(0).expect("failed to create eventfd");
        let write_count = s1
            .send_with_fd([].as_ref(), evt.as_raw_fd())
            .expect("failed to send fd");

        assert_eq!(write_count, 0);

        let (read_count, file_opt) = s2.recv_with_fd(&mut []).expect("failed to recv fd");

        let mut file = file_opt.unwrap();

        assert_eq!(read_count, 0);
        assert!(file.as_raw_fd() >= 0);
        assert_ne!(file.as_raw_fd(), s1.as_raw_fd());
        assert_ne!(file.as_raw_fd(), s2.as_raw_fd());
        assert_ne!(file.as_raw_fd(), evt.as_raw_fd());

        write_u64(&mut file, 1203);

        assert_eq!(evt.read().expect("failed to read from eventfd"), 1203);
    }

    #[test]
    fn send_recv_with_fd() {
        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");

        let evt = EventFd::new(0).expect("failed to create eventfd");
        let write_count = s1
            .send_with_fds(&[[237].as_ref()], &[evt.as_raw_fd()])
            .expect("failed to send fd");

        assert_eq!(write_count, 1);

        let mut files = [0; 2];
        let mut buf = [0u8];
        let mut iovecs = [iovec_for(&mut buf)];
        let (read_count, file_count) = unsafe {
            s2.recv_with_fds(&mut iovecs[..], &mut files)
                .expect("failed to recv fd")
        };

        assert_eq!(read_count, 1);
        assert_eq!(buf[0], 237);
        assert_eq!(file_count, 1);
        assert!(files[0] >= 0);
        assert_ne!(files[0], s1.as_raw_fd());
        assert_ne!(files[0], s2.as_raw_fd());
        assert_ne!(files[0], evt.as_raw_fd());

        let mut file = unsafe { File::from_raw_fd(files[0]) };

        write_u64(&mut file, 1203);

        assert_eq!(evt.read().expect("failed to read from eventfd"), 1203);
    }

    #[test]
    fn send_more_recv_less1() {
        send_four_recv_fewer(2);
    }

    #[test]
    fn send_more_recv_less2() {
        send_four_recv_fewer(1);
    }
}