From 7254e51bed29ccd47ee1547ea597fc1ebe3e51e9 Mon Sep 17 00:00:00 2001 From: Jonas Zaddach Date: Sun, 25 Jul 2021 15:02:08 +0200 Subject: [PATCH] Fix failing tests and move descriptor to own module --- src/descriptor.rs | 170 ++++++++++++++++++++++++++++++++++ src/ipc.rs | 16 ++-- src/lib.rs | 1 + src/platform/common/fd.rs | 100 -------------------- src/platform/common/mod.rs | 1 - src/platform/inprocess/mod.rs | 35 +++---- src/platform/macos/mod.rs | 20 ++-- src/platform/mod.rs | 10 -- src/platform/unix/mod.rs | 18 ++-- src/platform/windows/mod.rs | 163 ++++++++++++++++++++++++++++---- src/test.rs | 2 +- 11 files changed, 362 insertions(+), 174 deletions(-) create mode 100644 src/descriptor.rs delete mode 100644 src/platform/common/fd.rs delete mode 100644 src/platform/common/mod.rs diff --git a/src/descriptor.rs b/src/descriptor.rs new file mode 100644 index 00000000..867b7e95 --- /dev/null +++ b/src/descriptor.rs @@ -0,0 +1,170 @@ +use std::io; +use std::thread; +use std::mem; +use std::default::Default; +use std::fs::File; +use std::cell::RefCell; + +#[cfg(windows)] +pub use { + std::os::windows::io::RawHandle as RawDescriptor, + std::os::windows::io::AsRawHandle, + std::os::windows::io::IntoRawHandle, + std::os::windows::io::FromRawHandle, +}; + +#[cfg(unix)] +pub use { + std::os::unix::io::RawFd as RawDescriptor, + std::os::unix::io::AsRawFd, + std::os::unix::io::IntoRawFd, + std::os::unix::io::FromRawFd, +}; + +#[cfg(windows)] +const INVALID_RAW_DESCRIPTOR: RawDescriptor = winapi::um::handleapi::INVALID_HANDLE_VALUE; + +#[cfg(windows)] +fn raw_descriptor_close(descriptor: &RawDescriptor) -> Result<(), io::Error> { + unsafe { + let result = winapi::um::handleapi::CloseHandle(*descriptor); + if result == 0 { + Err(io::Error::last_os_error()) + } + else { + Ok(()) + } + } +} + +#[cfg(unix)] +const INVALID_RAW_DESCRIPTOR: RawDescriptor = -1; + +#[cfg(unix)] +fn raw_descriptor_close(descriptor: &RawDescriptor) -> Result<(), io::Error> { + unsafe { + let result = libc::close(*descriptor); + if result == 0 { + Ok(()) + } + else { + Err(io::Error::last_os_error()) + } + } +} + +#[derive(Debug)] +pub struct OwnedDescriptor(RefCell); + +unsafe impl Send for OwnedDescriptor { } +unsafe impl Sync for OwnedDescriptor { } + +impl Drop for OwnedDescriptor { + fn drop(&mut self) { + if *self.0.borrow() != INVALID_RAW_DESCRIPTOR { + let result = raw_descriptor_close(&*self.0.borrow()); + assert!( thread::panicking() || result.is_ok() ); + } + } +} + +impl OwnedDescriptor { + pub fn new(descriptor: RawDescriptor) -> OwnedDescriptor { + OwnedDescriptor(RefCell::new(descriptor)) + } + + pub fn consume(& self) -> OwnedDescriptor { + OwnedDescriptor::new(self.0.replace(INVALID_RAW_DESCRIPTOR)) + } +} + +impl Default for OwnedDescriptor { + fn default() -> OwnedDescriptor { + OwnedDescriptor::new(INVALID_RAW_DESCRIPTOR) + } +} + +#[cfg(windows)] +impl IntoRawHandle for OwnedDescriptor { + fn into_raw_handle(self) -> RawDescriptor { + let handle = *self.0.borrow(); + mem::forget(self); + handle + } +} + +#[cfg(windows)] +impl AsRawHandle for OwnedDescriptor { + fn as_raw_handle(& self) -> RawDescriptor { + *self.0.borrow() + } +} + +#[cfg(windows)] +impl FromRawHandle for OwnedDescriptor { + unsafe fn from_raw_handle(handle: RawDescriptor) -> OwnedDescriptor { + OwnedDescriptor::new(handle) + } +} + +#[cfg(windows)] +impl Into for OwnedDescriptor { + fn into(self) -> File { + unsafe { + File::from_raw_handle(self.into_raw_handle()) + } + } +} + +#[cfg(windows)] +impl From for OwnedDescriptor { + fn from(file: File) -> Self { + OwnedDescriptor::new(file.into_raw_handle()) + } +} + +#[cfg(unix)] +impl IntoRawFd for OwnedDescriptor { + fn into_raw_fd(self) -> RawDescriptor { + let fd = self.0.replace(INVALID_RAW_DESCRIPTOR); + mem::forget(self); + fd + } +} + +#[cfg(unix)] +impl AsRawFd for OwnedDescriptor { + fn as_raw_fd(& self) -> RawDescriptor { + *self.0.borrow() + } +} + +#[cfg(unix)] +impl FromRawFd for OwnedDescriptor { + unsafe fn from_raw_fd(fd: RawDescriptor) -> OwnedDescriptor { + OwnedDescriptor::new(fd) + } +} + +#[cfg(unix)] +impl Into for OwnedDescriptor { + fn into(self) -> File { + unsafe { + File::from_raw_fd(self.into_raw_fd()) + } + } +} + +#[cfg(unix)] +impl From for OwnedDescriptor { + fn from(file: File) -> Self { + OwnedDescriptor::new(file.into_raw_fd()) + } +} + +#[cfg(unix)] +impl PartialEq for OwnedDescriptor { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} diff --git a/src/ipc.rs b/src/ipc.rs index 0b96fe4a..3f5bf64d 100644 --- a/src/ipc.rs +++ b/src/ipc.rs @@ -9,7 +9,7 @@ use crate::platform::{self, OsIpcChannel, OsIpcReceiver, OsIpcReceiverSet, OsIpcSender}; use crate::platform::{OsIpcOneShotServer, OsIpcSelectionResult, OsIpcSharedMemory, OsOpaqueIpcChannel}; -use crate::platform::Descriptor; +use crate::descriptor::OwnedDescriptor; use bincode; use serde::{Deserialize, Deserializer, Serialize, Serializer}; @@ -31,7 +31,7 @@ thread_local! { RefCell>> = RefCell::new(Vec::new()) } thread_local! { - static OS_IPC_DESCRIPTORS_FOR_DESERIALIZATION: RefCell> = RefCell::new(Vec::new()) + static OS_IPC_DESCRIPTORS_FOR_DESERIALIZATION: RefCell> = RefCell::new(Vec::new()) } thread_local! { static OS_IPC_CHANNELS_FOR_SERIALIZATION: RefCell> = RefCell::new(Vec::new()) @@ -41,7 +41,7 @@ thread_local! { RefCell::new(Vec::new()) } thread_local! { - static OS_IPC_DESCRIPTORS_FOR_SERIALIZATION: RefCell> = RefCell::new(Vec::new()) + static OS_IPC_DESCRIPTORS_FOR_SERIALIZATION: RefCell> = RefCell::new(Vec::new()) } #[derive(Debug)] @@ -621,7 +621,7 @@ pub struct OpaqueIpcMessage { data: Vec, os_ipc_channels: Vec, os_ipc_shared_memory_regions: Vec>, - os_ipc_descriptors: Vec, + os_ipc_descriptors: Vec, } impl Debug for OpaqueIpcMessage { @@ -637,7 +637,7 @@ impl OpaqueIpcMessage { fn new(data: Vec, os_ipc_channels: Vec, os_ipc_shared_memory_regions: Vec, - os_ipc_descriptors: Vec) + os_ipc_descriptors: Vec) -> OpaqueIpcMessage { OpaqueIpcMessage { data: data, @@ -924,7 +924,7 @@ fn deserialize_os_ipc_receiver<'de, D>(deserializer: D) } -impl Serialize for Descriptor { +impl Serialize for OwnedDescriptor { fn serialize(&self, serializer: S) -> Result where S: Serializer { let index = OS_IPC_DESCRIPTORS_FOR_SERIALIZATION.with(|os_ipc_descriptors_for_serialization| { let mut os_ipc_descriptors_for_serialization = @@ -937,12 +937,12 @@ impl Serialize for Descriptor { } } -impl<'de> Deserialize<'de> for Descriptor { +impl<'de> Deserialize<'de> for OwnedDescriptor { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de> { let index: usize = Deserialize::deserialize(deserializer)?; OS_IPC_DESCRIPTORS_FOR_DESERIALIZATION.with(|os_ipc_descriptors_for_deserialization| { - os_ipc_descriptors_for_deserialization.borrow_mut().get_mut(index).map(|x| x.consume()).ok_or(serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(index as u64), &"index for Descriptor")) + os_ipc_descriptors_for_deserialization.borrow_mut().get_mut(index).map(|x| x.consume()).ok_or(serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(index as u64), &"index for OwnedDescriptor")) }) } } diff --git a/src/lib.rs b/src/lib.rs index f7011662..bae5a892 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -86,6 +86,7 @@ extern crate winapi; pub mod ipc; pub mod platform; pub mod router; +pub mod descriptor; #[cfg(test)] mod test; diff --git a/src/platform/common/fd.rs b/src/platform/common/fd.rs deleted file mode 100644 index 493c36aa..00000000 --- a/src/platform/common/fd.rs +++ /dev/null @@ -1,100 +0,0 @@ -// use std::ops::{ -// Deref, -// DerefMut, -// }; -use std::os::unix::io::{ - AsRawFd, - RawFd, - IntoRawFd, - FromRawFd, -}; -use std::fmt; -use std::cmp::{PartialEq}; -use std::mem; -use std::fs::File; -use std::cell::RefCell; - -pub struct OwnedFd(RefCell); - -impl Drop for OwnedFd { - fn drop(&mut self) { - if *self.0.borrow() != -1 { - unsafe { - let _ = libc::close(*self.0.borrow()); - } - } - } -} - -// impl Deref for OwnedFd { -// type Target = RawFd; - -// fn deref(&self) -> &Self::Target { -// &self.0 -// } -// } - -// impl DerefMut for OwnedFd { -// fn deref_mut(&mut self) -> &mut Self::Target { -// &mut self.0 -// } -// } - -impl IntoRawFd for OwnedFd { - fn into_raw_fd(self) -> RawFd { - let fd = *self.0.borrow(); - mem::forget(self); - fd - } -} - -impl AsRawFd for OwnedFd { - fn as_raw_fd(& self) -> RawFd { - *self.0.borrow() - } -} - -impl FromRawFd for OwnedFd { - unsafe fn from_raw_fd(fd: RawFd) -> OwnedFd { - OwnedFd::new(fd) - } -} - -impl Into for OwnedFd { - fn into(self) -> File { - unsafe { - File::from_raw_fd(self.into_raw_fd()) - } - } -} - -impl From for OwnedFd { - fn from(file: File) -> Self { - OwnedFd::new(file.into_raw_fd()) - } -} - -impl OwnedFd { - pub fn new(fd: RawFd) -> OwnedFd { - OwnedFd(RefCell::new(fd)) - } - - pub fn consume(&self) -> OwnedFd { - let fd = self.0.replace(-1); - OwnedFd::new(fd) - } -} - -impl PartialEq for OwnedFd { - fn eq(&self, other: &Self) -> bool { - *self.0.borrow() == *other.0.borrow() - } -} - -impl fmt::Debug for OwnedFd { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("") - .field(&self.0) - .finish() - } -} diff --git a/src/platform/common/mod.rs b/src/platform/common/mod.rs deleted file mode 100644 index 99821ee4..00000000 --- a/src/platform/common/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod fd; \ No newline at end of file diff --git a/src/platform/inprocess/mod.rs b/src/platform/inprocess/mod.rs index 7d5e369f..8b4a69f2 100644 --- a/src/platform/inprocess/mod.rs +++ b/src/platform/inprocess/mod.rs @@ -21,6 +21,7 @@ use std::cmp::{PartialEq}; use std::ops::{Deref, RangeFrom}; use std::usize; use uuid::Uuid; +use crate::descriptor::OwnedDescriptor; #[derive(Clone)] struct ServerRecord { @@ -52,7 +53,7 @@ lazy_static! { static ref ONE_SHOT_SERVERS: Mutex> = Mutex::new(HashMap::new()); } -struct ChannelMessage(Vec, Vec, Vec); +struct ChannelMessage(Vec, Vec, Vec, Vec); pub fn channel() -> Result<(OsIpcSender, OsIpcReceiver), ChannelError> { let (base_sender, base_receiver) = crossbeam_channel::unbounded::(); @@ -85,12 +86,12 @@ impl OsIpcReceiver { pub fn recv( &self - ) -> Result<(Vec, Vec, Vec), ChannelError> { + ) -> Result<(Vec, Vec, Vec, Vec), ChannelError> { let r = self.receiver.borrow(); let r = r.as_ref().unwrap(); match r.recv() { - Ok(ChannelMessage(d, c, s)) => { - Ok((d, c.into_iter().map(OsOpaqueIpcChannel::new).collect(), s)) + Ok(ChannelMessage(d, c, s, fd)) => { + Ok((d, c.into_iter().map(OsOpaqueIpcChannel::new).collect(), s, fd)) } Err(_) => Err(ChannelError::ChannelClosedError), } @@ -98,12 +99,12 @@ impl OsIpcReceiver { pub fn try_recv( &self - ) -> Result<(Vec, Vec, Vec), ChannelError> { + ) -> Result<(Vec, Vec, Vec, Vec), ChannelError> { let r = self.receiver.borrow(); let r = r.as_ref().unwrap(); match r.try_recv() { - Ok(ChannelMessage(d, c, s)) => { - Ok((d, c.into_iter().map(OsOpaqueIpcChannel::new).collect(), s)) + Ok(ChannelMessage(d, c, s, fd)) => { + Ok((d, c.into_iter().map(OsOpaqueIpcChannel::new).collect(), s, fd)) }, Err(e) => { match e { @@ -149,10 +150,11 @@ impl OsIpcSender { data: &[u8], ports: Vec, shared_memory_regions: Vec, + descriptors: Vec, ) -> Result<(), ChannelError> { Ok(self.sender .borrow() - .send(ChannelMessage(data.to_vec(), ports, shared_memory_regions)).map_err(|_| ChannelError::BrokenPipeError)?) + .send(ChannelMessage(data.to_vec(), ports, shared_memory_regions, descriptors)).map_err(|_| ChannelError::BrokenPipeError)?) } } @@ -198,9 +200,9 @@ impl OsIpcReceiverSet { let res = select.select(); let r_index = res.index(); let r_id = self.receiver_ids[r_index]; - if let Ok(ChannelMessage(data, channels, shmems)) = res.recv(&borrows[r_index as usize]) { + if let Ok(ChannelMessage(data, channels, shmems, descriptors)) = res.recv(&borrows[r_index as usize]) { let channels = channels.into_iter().map(OsOpaqueIpcChannel::new).collect(); - return Ok(vec![OsIpcSelectionResult::DataReceived(r_id, data, channels, shmems)]) + return Ok(vec![OsIpcSelectionResult::DataReceived(r_id, data, channels, shmems, descriptors)]) } else { Remove(r_index, r_id) } @@ -212,15 +214,15 @@ impl OsIpcReceiverSet { } pub enum OsIpcSelectionResult { - DataReceived(u64, Vec, Vec, Vec), + DataReceived(u64, Vec, Vec, Vec, Vec), ChannelClosed(u64), } impl OsIpcSelectionResult { - pub fn unwrap(self) -> (u64, Vec, Vec, Vec) { + pub fn unwrap(self) -> (u64, Vec, Vec, Vec, Vec) { match self { - OsIpcSelectionResult::DataReceived(id, data, channels, shared_memory_regions) => { - (id, data, channels, shared_memory_regions) + OsIpcSelectionResult::DataReceived(id, data, channels, shared_memory_regions, descriptors) => { + (id, data, channels, shared_memory_regions, descriptors) } OsIpcSelectionResult::ChannelClosed(id) => { panic!("OsIpcSelectionResult::unwrap(): receiver ID {} was closed!", id) @@ -255,6 +257,7 @@ impl OsIpcOneShotServer { Vec, Vec, Vec, + Vec, ), ChannelError, > { @@ -266,8 +269,8 @@ impl OsIpcOneShotServer { .clone(); record.accept(); ONE_SHOT_SERVERS.lock().unwrap().remove(&self.name).unwrap(); - let (data, channels, shmems) = self.receiver.recv()?; - Ok((self.receiver, data, channels, shmems)) + let (data, channels, shmems, descriptors) = self.receiver.recv()?; + Ok((self.receiver, data, channels, shmems, descriptors)) } } diff --git a/src/platform/macos/mod.rs b/src/platform/macos/mod.rs index 003b1f68..1d8b4226 100644 --- a/src/platform/macos/mod.rs +++ b/src/platform/macos/mod.rs @@ -14,7 +14,7 @@ use self::mach_sys::{mach_msg_timeout_t, mach_port_limits_t, mach_port_msgcount_ use self::mach_sys::{mach_port_right_t, mach_port_t, mach_task_self_, vm_inherit_t}; use self::mach_sys::mach_port_deallocate; use self::mach_sys::fileport_t; -use crate::platform::Descriptor; +use crate::descriptor::OwnedDescriptor; use bincode; use libc::{self, c_char, c_uint, c_void, size_t}; @@ -36,8 +36,6 @@ use std::os::raw::c_int; use std::os::unix::io::AsRawFd; -use crate::platform::common::fd::OwnedFd; - mod mach_sys; /// The size that we preallocate on the stack to receive messages. If the message is larger than @@ -358,7 +356,7 @@ impl OsIpcReceiver { } fn recv_with_blocking_mode(&self, blocking_mode: BlockingMode) - -> Result<(Vec, Vec, Vec, Vec), + -> Result<(Vec, Vec, Vec, Vec), MachError> { select(self.port.get(), blocking_mode).and_then(|result| { match result { @@ -371,12 +369,12 @@ impl OsIpcReceiver { } pub fn recv(&self) - -> Result<(Vec, Vec, Vec, Vec),MachError> { + -> Result<(Vec, Vec, Vec, Vec),MachError> { self.recv_with_blocking_mode(BlockingMode::Blocking) } pub fn try_recv(&self) - -> Result<(Vec, Vec, Vec, Vec),MachError> { + -> Result<(Vec, Vec, Vec, Vec),MachError> { self.recv_with_blocking_mode(BlockingMode::Nonblocking) } } @@ -507,7 +505,7 @@ impl OsIpcSender { data: &[u8], ports: Vec, mut shared_memory_regions: Vec, - descriptors: Vec) + descriptors: Vec) -> Result<(),MachError> { let mut data = SendData::from(data); if let Some(data) = data.take_shared_memory() { @@ -708,12 +706,12 @@ impl Drop for OsIpcReceiverSet { } pub enum OsIpcSelectionResult { - DataReceived(u64, Vec, Vec, Vec, Vec), + DataReceived(u64, Vec, Vec, Vec, Vec), ChannelClosed(u64), } impl OsIpcSelectionResult { - pub fn unwrap(self) -> (u64, Vec, Vec, Vec, Vec) { + pub fn unwrap(self) -> (u64, Vec, Vec, Vec, Vec) { match self { OsIpcSelectionResult::DataReceived(id, data, channels, shared_memory_regions, descriptors) => { (id, data, channels, shared_memory_regions, descriptors) @@ -831,7 +829,7 @@ fn select(port: mach_port_t, blocking_mode: BlockingMode) for idx in port_count .. (port_count + descriptor_count) { let fd = mach_fileport_makefd(raw_ports[idx])?; - descriptors.push(OwnedFd::new(fd)); + descriptors.push(OwnedDescriptor::new(fd)); } let has_inline_data_ptr = port_counts.offset(1) as *mut bool; @@ -886,7 +884,7 @@ impl OsIpcOneShotServer { Vec, Vec, Vec, - Vec),MachError> { + Vec),MachError> { let (bytes, channels, shared_memory_regions, descriptors) = self.receiver.recv()?; Ok((self.receiver.consume(), bytes, channels, shared_memory_regions, descriptors)) } diff --git a/src/platform/mod.rs b/src/platform/mod.rs index 0ae0b0d8..47c4e024 100644 --- a/src/platform/mod.rs +++ b/src/platform/mod.rs @@ -7,12 +7,6 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -#[cfg(all(not(feature = "force-inprocess"), any(target_os = "linux", - target_os = "openbsd", - target_os = "freebsd", - target_os = "macos")))] -mod common; - #[cfg(all(not(feature = "force-inprocess"), any(target_os = "linux", target_os = "openbsd", target_os = "freebsd")))] @@ -22,7 +16,6 @@ mod unix; target_os = "freebsd")))] mod os { pub use super::unix::*; - pub type Descriptor = super::common::fd::OwnedFd; } #[cfg(all(not(feature = "force-inprocess"), target_os = "macos"))] @@ -30,7 +23,6 @@ mod macos; #[cfg(all(not(feature = "force-inprocess"), target_os = "macos"))] mod os { pub use super::macos::*; - pub type Descriptor = super::common::fd::OwnedFd; } #[cfg(all(not(feature = "force-inprocess"), target_os = "windows"))] @@ -38,7 +30,6 @@ mod windows; #[cfg(all(not(feature = "force-inprocess"), target_os = "windows"))] mod os { pub use super::windows::*; - pub type Descriptor = super::windows::handle::WinHandle; } #[cfg(any( @@ -57,7 +48,6 @@ mod os { pub use self::os::{OsIpcChannel, OsIpcOneShotServer, OsIpcReceiver, OsIpcReceiverSet}; pub use self::os::{OsIpcSelectionResult, OsIpcSender, OsIpcSharedMemory}; pub use self::os::{OsOpaqueIpcChannel, channel}; -pub use self::os::{Descriptor}; #[cfg(test)] mod test; diff --git a/src/platform/unix/mod.rs b/src/platform/unix/mod.rs index ce961205..810c96b5 100644 --- a/src/platform/unix/mod.rs +++ b/src/platform/unix/mod.rs @@ -7,8 +7,8 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. +use crate::descriptor::OwnedDescriptor; use crate::ipc; -use crate::platform::Descriptor; use bincode; use fnv::FnvHasher; use libc::{EAGAIN, EWOULDBLOCK}; @@ -142,12 +142,12 @@ impl OsIpcReceiver { } pub fn recv(&self) - -> Result<(Vec, Vec, Vec, Vec),UnixError> { + -> Result<(Vec, Vec, Vec, Vec),UnixError> { recv(self.fd.get(), BlockingMode::Blocking) } pub fn try_recv(&self) - -> Result<(Vec, Vec, Vec, Vec),UnixError> { + -> Result<(Vec, Vec, Vec, Vec),UnixError> { recv(self.fd.get(), BlockingMode::Nonblocking) } } @@ -237,7 +237,7 @@ impl OsIpcSender { data: &[u8], channels: Vec, shared_memory_regions: Vec, - descriptors: Vec) + descriptors: Vec) -> Result<(),UnixError> { let header = Header { @@ -547,12 +547,12 @@ impl OsIpcReceiverSet { } pub enum OsIpcSelectionResult { - DataReceived(u64, Vec, Vec, Vec, Vec), + DataReceived(u64, Vec, Vec, Vec, Vec), ChannelClosed(u64), } impl OsIpcSelectionResult { - pub fn unwrap(self) -> (u64, Vec, Vec, Vec, Vec) { + pub fn unwrap(self) -> (u64, Vec, Vec, Vec, Vec) { match self { OsIpcSelectionResult::DataReceived(id, data, channels, shared_memory_regions, descriptors) => { (id, data, channels, shared_memory_regions, descriptors) @@ -643,7 +643,7 @@ impl OsIpcOneShotServer { Vec, Vec, Vec, - Vec),UnixError> { + Vec),UnixError> { unsafe { let sockaddr: *mut sockaddr = ptr::null_mut(); let sockaddr_len: *mut socklen_t = ptr::null_mut(); @@ -919,7 +919,7 @@ enum BlockingMode { } fn recv(fd: c_int, blocking_mode: BlockingMode) - -> Result<(Vec, Vec, Vec, Vec),UnixError> { + -> Result<(Vec, Vec, Vec, Vec),UnixError> { let (mut channels, mut shared_memory_regions, mut descriptors) = (Vec::new(), Vec::new(), Vec::new()); @@ -967,7 +967,7 @@ fn recv(fd: c_int, blocking_mode: BlockingMode) } for index in (header.channel_fd_num + header.shared_memory_fd_num) .. (header.channel_fd_num + header.shared_memory_fd_num + header.descriptor_num) { - descriptors.push(Descriptor::new(*cmsg_fds.offset(index as isize))); + descriptors.push(OwnedDescriptor::new(*cmsg_fds.offset(index as isize))); } } diff --git a/src/platform/windows/mod.rs b/src/platform/windows/mod.rs index 3e87f7cc..4e3e36a0 100644 --- a/src/platform/windows/mod.rs +++ b/src/platform/windows/mod.rs @@ -14,6 +14,7 @@ use crate::ipc; use libc::intptr_t; use std::cell::{Cell, RefCell}; use std::cmp::PartialEq; +use std::default::Default; use std::env; use std::error::Error as StdError; use std::ffi::CString; @@ -22,22 +23,19 @@ use std::io; use std::marker::{Send, Sync, PhantomData}; use std::mem; use std::ops::{Deref, DerefMut, RangeFrom}; +use std::os::windows::io::IntoRawHandle; use std::ptr; use std::slice; use std::thread; use uuid::Uuid; -use winapi::shared::minwindef::{LPVOID}; -use winapi; - - use winapi::um::winnt::{HANDLE}; use winapi::um::handleapi::{INVALID_HANDLE_VALUE}; +use winapi::shared::minwindef::{LPVOID}; +use winapi; mod aliased_cell; -pub mod handle; use self::aliased_cell::AliasedCell; -use crate::platform::Descriptor; -use self::handle::{WinHandle, dup_handle, dup_handle_to_process, move_handle_to_process}; +use crate::descriptor::OwnedDescriptor; lazy_static! { static ref CURRENT_PROCESS_ID: winapi::shared::ntdef::ULONG = unsafe { winapi::um::processthreadsapi::GetCurrentProcessId() }; @@ -261,6 +259,135 @@ fn make_pipe_name(pipe_id: &Uuid) -> CString { CString::new(format!("\\\\.\\pipe\\rust-ipc-{}", pipe_id.to_string())).unwrap() } +/// Duplicate a given handle from this process to the target one, passing the +/// given flags to DuplicateHandle. +/// +/// Unlike win32 DuplicateHandle, this will preserve INVALID_HANDLE_VALUE (which is +/// also the pseudohandle for the current process). +fn dup_handle_to_process_with_flags(handle: &WinHandle, other_process: &WinHandle, flags: winapi::shared::minwindef::DWORD) + -> Result +{ + if !handle.is_valid() { + return Ok(WinHandle::invalid()); + } + + unsafe { + let mut new_handle: HANDLE = INVALID_HANDLE_VALUE; + let ok = winapi::um::handleapi::DuplicateHandle(CURRENT_PROCESS_HANDLE.as_raw(), handle.as_raw(), + other_process.as_raw(), &mut new_handle, + 0, winapi::shared::minwindef::FALSE, flags); + if ok == winapi::shared::minwindef::FALSE { + Err(WinError::last("DuplicateHandle")) + } else { + Ok(WinHandle::new(new_handle)) + } + } +} + +/// Duplicate a handle in the current process. +fn dup_handle(handle: &WinHandle) -> Result { + dup_handle_to_process(handle, &WinHandle::new(CURRENT_PROCESS_HANDLE.as_raw())) +} + +/// Duplicate a handle to the target process. +fn dup_handle_to_process(handle: &WinHandle, other_process: &WinHandle) -> Result { + dup_handle_to_process_with_flags(handle, other_process, winapi::um::winnt::DUPLICATE_SAME_ACCESS) +} + +/// Duplicate a handle to the target process, closing the source handle. +fn move_handle_to_process(handle: WinHandle, other_process: &WinHandle) -> Result { + let result = dup_handle_to_process_with_flags(&handle, other_process, + winapi::um::winnt::DUPLICATE_CLOSE_SOURCE | winapi::um::winnt::DUPLICATE_SAME_ACCESS); + // Since the handle was moved to another process, the original is no longer valid; + // so we probably shouldn't try to close it explicitly? + mem::forget(handle); + result +} + +#[derive(Debug)] +struct WinHandle { + h: HANDLE +} + +unsafe impl Send for WinHandle { } +unsafe impl Sync for WinHandle { } + +impl Drop for WinHandle { + fn drop(&mut self) { + unsafe { + if self.is_valid() { + let result = winapi::um::handleapi::CloseHandle(self.h); + assert!(thread::panicking() || result != 0); + } + } + } +} + +impl Default for WinHandle { + fn default() -> WinHandle { + WinHandle { h: INVALID_HANDLE_VALUE } + } +} + +impl From for WinHandle { + fn from(descriptor: OwnedDescriptor) -> WinHandle { + WinHandle::new(descriptor.into_raw_handle()) + } +} + +const WINDOWS_APP_MODULE_NAME: &'static str = "api-ms-win-core-handle-l1-1-0"; +const COMPARE_OBJECT_HANDLES_FUNCTION_NAME: &'static str = "CompareObjectHandles"; + +lazy_static! { + static ref WINDOWS_APP_MODULE_NAME_CSTRING: CString = CString::new(WINDOWS_APP_MODULE_NAME).unwrap(); + static ref COMPARE_OBJECT_HANDLES_FUNCTION_NAME_CSTRING: CString = CString::new(COMPARE_OBJECT_HANDLES_FUNCTION_NAME).unwrap(); +} + +#[cfg(feature = "windows-shared-memory-equality")] +impl PartialEq for WinHandle { + fn eq(&self, other: &WinHandle) -> bool { + unsafe { + // Calling LoadLibraryA every time seems to be ok since libraries are refcounted and multiple calls won't produce multiple instances. + let module_handle = winapi::um::libloaderapi::LoadLibraryA(WINDOWS_APP_MODULE_NAME_CSTRING.as_ptr()); + if module_handle.is_null() { + panic!("Error loading library {}. {}", WINDOWS_APP_MODULE_NAME, WinError::error_string(GetLastError())); + } + let proc = winapi::um::libloaderapi::GetProcAddress(module_handle, COMPARE_OBJECT_HANDLES_FUNCTION_NAME_CSTRING.as_ptr()); + if proc.is_null() { + panic!("Error calling GetProcAddress to use {}. {}", COMPARE_OBJECT_HANDLES_FUNCTION_NAME, WinError::error_string(GetLastError())); + } + let compare_object_handles: unsafe extern "stdcall" fn(HANDLE, HANDLE) -> winapi::shared::minwindef::BOOL = std::mem::transmute(proc); + compare_object_handles(self.h, other.h) != 0 + } + } +} + +impl WinHandle { + fn new(h: HANDLE) -> WinHandle { + WinHandle { h: h } + } + + fn invalid() -> WinHandle { + WinHandle { h: INVALID_HANDLE_VALUE } + } + + fn is_valid(&self) -> bool { + self.h != INVALID_HANDLE_VALUE + } + + fn as_raw(&self) -> HANDLE { + self.h + } + + fn take_raw(&mut self) -> HANDLE { + mem::replace(&mut self.h, INVALID_HANDLE_VALUE) + } + + fn take(&mut self) -> WinHandle { + WinHandle::new(self.take_raw()) + } +} + /// Helper struct for all data being aliased by the kernel during async reads. #[derive(Debug)] struct AsyncData { @@ -642,7 +769,7 @@ impl MessageReader { } } - fn get_message(&mut self) -> Result, Vec, Vec, Vec)>, + fn get_message(&mut self) -> Result, Vec, Vec, Vec)>, WinError> { // Never touch the buffer while it's still mutably aliased by the kernel! if self.r#async.is_some() { @@ -655,7 +782,7 @@ impl MessageReader { let mut channels: Vec = vec![]; let mut shmems: Vec = vec![]; let mut big_data = None; - let mut descriptors: Vec = vec![]; + let mut descriptors: Vec = vec![]; if let Some(oob) = message.oob_data() { win32_trace!("[$ {:?}] msg with total {} bytes, {} channels, {} shmems, big data handle {:?}", @@ -673,7 +800,7 @@ impl MessageReader { } for handle in oob.descriptor_handles { - descriptors.push(WinHandle::new(handle as HANDLE)); + descriptors.push(OwnedDescriptor::new(handle as HANDLE)); } if oob.big_data_receiver_handle.is_some() { @@ -897,7 +1024,7 @@ impl OsIpcReceiver { // the implementation in select() is used. It does much the same thing, but across multiple // channels. fn receive_message(&self, mut blocking_mode: BlockingMode) - -> Result<(Vec, Vec, Vec, Vec),WinError> { + -> Result<(Vec, Vec, Vec, Vec),WinError> { let mut reader = self.reader.borrow_mut(); assert!(reader.entry_id.is_none(), "receive_message is only valid before this OsIpcReceiver was added to a Set"); @@ -928,13 +1055,13 @@ impl OsIpcReceiver { } pub fn recv(&self) - -> Result<(Vec, Vec, Vec, Vec),WinError> { + -> Result<(Vec, Vec, Vec, Vec),WinError> { win32_trace!("recv"); self.receive_message(BlockingMode::Blocking) } pub fn try_recv(&self) - -> Result<(Vec, Vec, Vec, Vec),WinError> { + -> Result<(Vec, Vec, Vec, Vec),WinError> { win32_trace!("try_recv"); self.receive_message(BlockingMode::Nonblocking) } @@ -1109,7 +1236,7 @@ impl OsIpcSender { data: &[u8], ports: Vec, shared_memory_regions: Vec, - descriptors: Vec) + descriptors: Vec) -> Result<(),WinError> { // We limit the max size we can send here; we can fix this @@ -1150,7 +1277,7 @@ impl OsIpcSender { } for descriptor in descriptors { - let mut raw_remote_handle = move_handle_to_process(descriptor, &server_h)?; + let mut raw_remote_handle = move_handle_to_process(descriptor.into(), &server_h)?; oob.descriptor_handles.push(raw_remote_handle.take_raw() as intptr_t); } @@ -1220,7 +1347,7 @@ impl OsIpcSender { } pub enum OsIpcSelectionResult { - DataReceived(u64, Vec, Vec, Vec, Vec), + DataReceived(u64, Vec, Vec, Vec, Vec), ChannelClosed(u64), } @@ -1442,7 +1569,7 @@ impl OsIpcReceiverSet { } impl OsIpcSelectionResult { - pub fn unwrap(self) -> (u64, Vec, Vec, Vec, Vec) { + pub fn unwrap(self) -> (u64, Vec, Vec, Vec, Vec) { match self { OsIpcSelectionResult::DataReceived(id, data, channels, shared_memory_regions, descriptors) => { (id, data, channels, shared_memory_regions, descriptors) @@ -1583,7 +1710,7 @@ impl OsIpcOneShotServer { Vec, Vec, Vec, - Vec),WinError> { + Vec),WinError> { let receiver = self.receiver; receiver.accept()?; let (data, channels, shmems, descriptors) = receiver.recv()?; diff --git a/src/test.rs b/src/test.rs index f9d80f2b..609b99a7 100644 --- a/src/test.rs +++ b/src/test.rs @@ -665,7 +665,7 @@ fn test_transfer_descriptor() { std::mem::drop(file); let file = std::fs::File::open(& temp_file_path).unwrap(); - let person_and_descriptor = (person, crate::platform::Descriptor::from(file)); + let person_and_descriptor = (person, crate::descriptor::OwnedDescriptor::from(file)); let (tx, rx) = ipc::channel().unwrap(); tx.send(person_and_descriptor).unwrap(); let received_person_and_descriptor = rx.recv().unwrap();