Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: make IfWatcher::new synchronous #24

Merged
merged 3 commits into from
Aug 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed
- Add `IfWatcher::poll_next`. Implement `Stream` instead of `Future` for `IfWatcher`. See [PR 23].
- Make `IfWatcher::new` synchronous. See [PR 24].
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure how to describe this better.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me.


[PR 23]: https://github.com/mxinden/if-watch/pull/23
[PR 24]: https://github.com/mxinden/if-watch/pull/24

## [1.1.1]

Expand Down
2 changes: 1 addition & 1 deletion examples/if_watch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use if_watch::IfWatcher;
fn main() {
env_logger::init();
futures::executor::block_on(async {
let mut set = IfWatcher::new().await.unwrap();
let mut set = IfWatcher::new().unwrap();
loop {
let event = set.select_next_some().await;
println!("Got event {:?}", event);
Expand Down
2 changes: 1 addition & 1 deletion src/apple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pub struct IfWatcher {
}

impl IfWatcher {
pub async fn new() -> Result<Self> {
pub fn new() -> Result<Self> {
let (tx, rx) = mpsc::channel(1);
std::thread::spawn(|| background_task(tx));
let mut watcher = Self {
Expand Down
2 changes: 1 addition & 1 deletion src/fallback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pub struct IfWatcher {

impl IfWatcher {
/// Create a watcher
pub async fn new() -> Result<Self> {
pub fn new() -> Result<Self> {
Ok(Self {
addrs: Default::default(),
queue: Default::default(),
Expand Down
10 changes: 5 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ pub struct IfWatcher(platform_impl::IfWatcher);

impl IfWatcher {
/// Create a watcher
pub async fn new() -> Result<Self> {
Ok(Self(platform_impl::IfWatcher::new().await?))
pub fn new() -> Result<Self> {
platform_impl::IfWatcher::new().map(Self)
}

/// Iterate over current networks.
Expand Down Expand Up @@ -92,7 +92,7 @@ mod tests {
#[test]
fn test_ip_watch() {
futures::executor::block_on(async {
let mut set = IfWatcher::new().await.unwrap();
let mut set = IfWatcher::new().unwrap();
let event = set.select_next_some().await.unwrap();
println!("Got event {:?}", event);
});
Expand All @@ -103,8 +103,8 @@ mod tests {
futures::executor::block_on(async {
fn is_send<T: Send>(_: T) {}
is_send(IfWatcher::new());
is_send(IfWatcher::new().await.unwrap());
is_send(Pin::new(&mut IfWatcher::new().await.unwrap()));
is_send(IfWatcher::new().unwrap());
is_send(Pin::new(&mut IfWatcher::new().unwrap()));
});
}
}
65 changes: 23 additions & 42 deletions src/linux.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
use crate::{IfEvent, IpNet, Ipv4Net, Ipv6Net};
use fnv::FnvHashSet;
use futures::channel::mpsc::UnboundedReceiver;
use futures::future::Either;
use futures::ready;
use futures::stream::{Stream, TryStreamExt};
use futures::StreamExt;
use rtnetlink::constants::{RTMGRP_IPV4_IFADDR, RTMGRP_IPV6_IFADDR};
use rtnetlink::packet::address::nlas::Nla;
use rtnetlink::packet::{AddressMessage, RtnlMessage};
use rtnetlink::proto::{Connection, NetlinkMessage, NetlinkPayload};
use rtnetlink::proto::{Connection, NetlinkPayload};
use rtnetlink::sys::{AsyncSocket, SmolSocket, SocketAddr};
use std::collections::VecDeque;
use std::future::Future;
Expand All @@ -18,7 +17,7 @@ use std::task::{Context, Poll};

pub struct IfWatcher {
conn: Connection<RtnlMessage, SmolSocket>,
messages: UnboundedReceiver<(NetlinkMessage<RtnlMessage>, SocketAddr)>,
messages: Pin<Box<dyn Stream<Item = Result<RtnlMessage>> + Send>>,
addrs: FnvHashSet<IpNet>,
queue: VecDeque<IfEvent>,
}
Expand All @@ -32,40 +31,27 @@ impl std::fmt::Debug for IfWatcher {
}

impl IfWatcher {
pub async fn new() -> Result<Self> {
pub fn new() -> Result<Self> {
let (mut conn, handle, messages) = rtnetlink::new_connection_with_socket::<SmolSocket>()?;
let groups = RTMGRP_IPV4_IFADDR | RTMGRP_IPV6_IFADDR;
let addr = SocketAddr::new(0, groups);
conn.socket_mut().socket_mut().bind(&addr)?;
let mut stream = handle.address().get().execute();
let mut addrs = FnvHashSet::default();
let mut queue = VecDeque::default();

loop {
let fut = futures::future::select(conn, stream.try_next());
match fut.await {
Either::Left(_) => {
return Err(std::io::Error::new(
ErrorKind::BrokenPipe,
"rtnetlink socket closed",
))
}
Either::Right((x, c)) => {
conn = c;
match x {
Ok(Some(msg)) => {
for net in iter_nets(msg) {
if addrs.insert(net) {
queue.push_back(IfEvent::Up(net));
}
}
}
Ok(None) => break,
Err(err) => return Err(Error::new(ErrorKind::Other, err)),
}
}
let get_addrs_stream = handle
.address()
.get()
.execute()
.map_ok(RtnlMessage::NewAddress)
.map_err(|err| Error::new(ErrorKind::Other, err));
let msg_stream = messages.filter_map(|(msg, _)| async {
match msg.payload {
NetlinkPayload::Error(err) => Some(Err(err.to_io())),
NetlinkPayload::InnerMessage(msg) => Some(Ok(msg)),
_ => None,
}
}
});
let messages = get_addrs_stream.chain(msg_stream).boxed();
let addrs = FnvHashSet::default();
let queue = VecDeque::default();
Ok(Self {
conn,
messages,
Expand Down Expand Up @@ -102,15 +88,10 @@ impl IfWatcher {
if Pin::new(&mut self.conn).poll(cx).is_ready() {
return Poll::Ready(Err(socket_err()));
}
let (message, _) =
ready!(Pin::new(&mut self.messages).poll_next(cx)).ok_or_else(socket_err)?;
match message.payload {
NetlinkPayload::Error(err) => return Poll::Ready(Err(err.to_io())),
NetlinkPayload::InnerMessage(msg) => match msg {
RtnlMessage::NewAddress(msg) => self.add_address(msg),
RtnlMessage::DelAddress(msg) => self.rem_address(msg),
_ => {}
},
let message = ready!(self.messages.poll_next_unpin(cx)).ok_or_else(socket_err)??;
match message {
RtnlMessage::NewAddress(msg) => self.add_address(msg),
RtnlMessage::DelAddress(msg) => self.rem_address(msg),
_ => {}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/win.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub struct IfWatcher {

impl IfWatcher {
/// Create a watcher
pub async fn new() -> Result<Self> {
pub fn new() -> Result<Self> {
let resync = Arc::new(AtomicBool::new(true));
let waker = Arc::new(AtomicWaker::new());
Ok(Self {
Expand Down