Skip to content

Commit

Permalink
Add ALPN support to Deno.connectTls().
Browse files Browse the repository at this point in the history
* Make `connectTls()` accept `alpnProtocols` option
  (string[]), just like `listenTls()`.

* Add `getAgreedAlpnProtocol(): Promise<string | null>` method
  to query the protocol agreed with the peer via ALPN.

The `alpnProtocols` is added as an unstable API. Both methods
aren't yet documented because of that.
  • Loading branch information
1st1 committed Nov 16, 2021
1 parent ec9f5d5 commit a95fef9
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 0 deletions.
10 changes: 10 additions & 0 deletions ext/net/02_tls.js
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,15 @@
return core.opAsync("op_tls_handshake", rid);
}

function opTlsGetAlpnProtocol(rid) {
return core.opAsync("op_tls_get_alpn_protocol", rid);
}

class TlsConn extends Conn {
getAgreedAlpnProtocol() {
return opTlsGetAlpnProtocol(this.rid);
}

handshake() {
return opTlsHandshake(this.rid);
}
Expand All @@ -41,6 +49,7 @@
caCerts = [],
certChain = undefined,
privateKey = undefined,
alpnProtocols,
}) {
const res = await opConnectTls({
port,
Expand All @@ -50,6 +59,7 @@
caCerts,
certChain,
privateKey,
alpnProtocols,
});
return new TlsConn(res.rid, res.remoteAddr, res.localAddr);
}
Expand Down
46 changes: 46 additions & 0 deletions ext/net/ops_tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,13 @@ impl TlsStream {
fn poll_handshake(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.inner_mut().poll_handshake(cx)
}

fn get_alpn_protocol(&mut self) -> Option<String> {
match self.inner_mut().tls.get_alpn_protocol() {
None => None,
Some(s) => Some(std::str::from_utf8(s).unwrap().to_string())
}
}
}

impl AsyncRead for TlsStream {
Expand Down Expand Up @@ -517,6 +524,10 @@ impl ReadHalf {
.tls_stream
.into_inner()
}

fn get_alpn_protocol(&mut self) -> Option<String> {
self.shared.get_alpn_protocol()
}
}

impl AsyncRead for ReadHalf {
Expand Down Expand Up @@ -658,6 +669,11 @@ impl Shared {
fn drop_shared_waker(self_ptr: *const ()) {
let _ = unsafe { Weak::from_raw(self_ptr as *const Self) };
}

fn get_alpn_protocol(self: &Arc<Self>) -> Option<String> {
let mut tls_stream = self.tls_stream.lock();
return tls_stream.get_alpn_protocol();
}
}

struct ImplementReadTrait<'a, T>(&'a mut T);
Expand Down Expand Up @@ -691,6 +707,7 @@ pub fn init<P: NetPermissions + 'static>() -> Vec<OpPair> {
("op_tls_listen", op_sync(op_tls_listen::<P>)),
("op_tls_accept", op_async(op_tls_accept)),
("op_tls_handshake", op_async(op_tls_handshake)),
("op_tls_get_alpn_protocol", op_async(op_tls_get_alpn_protocol)),
]
}

Expand Down Expand Up @@ -753,6 +770,13 @@ impl TlsStreamResource {
}
Ok(())
}

pub async fn get_alpn_protocol(self: &Rc<Self>) ->
Result<Option<String>, AnyError>
{
let mut rd = RcRef::map(self, |r| &r.rd).borrow_mut().await;
Ok(rd.get_alpn_protocol())
}
}

impl Resource for TlsStreamResource {
Expand Down Expand Up @@ -787,6 +811,7 @@ pub struct ConnectTlsArgs {
ca_certs: Vec<String>,
cert_chain: Option<String>,
private_key: Option<String>,
alpn_protocols: Option<Vec<String>>,
}

#[derive(Deserialize)]
Expand Down Expand Up @@ -905,6 +930,9 @@ where
if args.private_key.is_some() {
super::check_unstable2(&state, "ConnectTlsOptions.privateKey");
}
if args.alpn_protocols.is_some() {
super::check_unstable2(&state, "ConnectTlsOptions.alpnProtocols");
}

{
let mut s = state.borrow_mut();
Expand Down Expand Up @@ -948,6 +976,12 @@ where
unsafely_ignore_certificate_errors,
)?;

if let Some(alpn_protocols) = args.alpn_protocols {
super::check_unstable2(&state, "Deno.connectTls#alpnProtocols");
tls_config.alpn_protocols =
alpn_protocols.into_iter().map(|s| s.into_bytes()).collect();
}

if args.cert_chain.is_some() || args.private_key.is_some() {
let cert_chain = args
.cert_chain
Expand Down Expand Up @@ -1151,3 +1185,15 @@ pub async fn op_tls_handshake(
.get::<TlsStreamResource>(rid)?;
resource.handshake().await
}

pub async fn op_tls_get_alpn_protocol(
state: Rc<RefCell<OpState>>,
rid: ResourceId,
_: (),
) -> Result<Option<String>, AnyError> {
let resource = state
.borrow()
.resource_table
.get::<TlsStreamResource>(rid)?;
resource.get_alpn_protocol().await
}

0 comments on commit a95fef9

Please sign in to comment.