Skip to content

Commit

Permalink
Refactor Azure credential provider to remove unnecessary print statem…
Browse files Browse the repository at this point in the history
…ents and improve token handling
  • Loading branch information
RobinLin666 committed Sep 13, 2024
1 parent a0682bc commit 92144b2
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 25 deletions.
3 changes: 1 addition & 2 deletions object_store/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,13 @@ ring = { version = "0.17", default-features = false, features = ["std"], optiona
rustls-pemfile = { version = "2.0", default-features = false, features = ["std"], optional = true }
tokio = { version = "1.25.0", features = ["sync", "macros", "rt", "time", "io-util"] }
md-5 = { version = "0.10.6", default-features = false, optional = true }
jsonwebtoken = { version = "9.3.0", default-features = false, optional = true }

[target.'cfg(target_family="unix")'.dev-dependencies]
nix = { version = "0.29.0", features = ["fs"] }

[features]
cloud = ["serde", "serde_json", "quick-xml", "hyper", "reqwest", "reqwest/json", "reqwest/stream", "chrono/serde", "base64", "rand", "ring"]
azure = ["cloud", "jsonwebtoken"]
azure = ["cloud"]
gcp = ["cloud", "rustls-pemfile"]
aws = ["cloud", "md-5"]
http = ["cloud"]
Expand Down
45 changes: 22 additions & 23 deletions object_store/src/azure/credential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,16 @@ use crate::client::{CredentialProvider, TokenProvider};
use crate::util::hmac_sha256;
use crate::RetryConfig;
use async_trait::async_trait;
use base64::prelude::BASE64_STANDARD;
use base64::prelude::{BASE64_STANDARD, BASE64_URL_SAFE_NO_PAD};
use base64::Engine;
use chrono::{DateTime, SecondsFormat, Utc};
use jsonwebtoken::{decode, get_current_timestamp, Algorithm, DecodingKey, Validation};
use reqwest::header::{
HeaderMap, HeaderName, HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_ENCODING, CONTENT_LANGUAGE,
CONTENT_LENGTH, CONTENT_TYPE, DATE, IF_MATCH, IF_MODIFIED_SINCE, IF_NONE_MATCH,
IF_UNMODIFIED_SINCE, RANGE,
};
use reqwest::{Client, Method, Request, RequestBuilder};
use serde::{Deserialize, Serialize};
use serde::Deserialize;
use snafu::{ResultExt, Snafu};
use std::borrow::Cow;
use std::collections::HashMap;
Expand Down Expand Up @@ -951,7 +950,7 @@ pub struct FabricTokenOAuthProvider {
token_expiry: Option<u64>,
}

#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Deserialize)]
struct Claims {
exp: u64,
}
Expand All @@ -965,18 +964,14 @@ impl FabricTokenOAuthProvider {
fabric_cluster_identifier: impl Into<String>,
storage_access_token: Option<String>,
) -> Self {
let (storage_access_token, token_expiry) = if let Some(token) = storage_access_token {
if let Some(expiry) = Self::validate_and_get_expiry(&token) {
if expiry > get_current_timestamp() + TOKEN_MIN_TTL {
let (storage_access_token, token_expiry) = match storage_access_token {
Some(token) => match Self::validate_and_get_expiry(&token) {
Some(expiry) if expiry > Self::get_current_timestamp() + TOKEN_MIN_TTL => {
(Some(token), Some(expiry))
} else {
(None, None)
}
} else {
(None, None)
}
} else {
(None, None)
_ => (None, None),
},
None => (None, None),
};

Self {
Expand All @@ -990,13 +985,17 @@ impl FabricTokenOAuthProvider {
}

fn validate_and_get_expiry(token: &str) -> Option<u64> {
let mut validation: Validation = Validation::new(Algorithm::HS256);
validation.insecure_disable_signature_validation();
validation.set_audience(&[AZURE_STORAGE_RESOURCE]);
let key = DecodingKey::from_secret(&[]);
decode::<Claims>(token, &key, &validation)
.ok()
.map(|data| data.claims.exp)
let payload = token.split('.').nth(1)?;
let decoded_bytes = BASE64_URL_SAFE_NO_PAD.decode(payload).ok()?;
let decoded_str = str::from_utf8(&decoded_bytes).ok()?;
let claims: Claims = serde_json::from_str(decoded_str).ok()?;
Some(claims.exp)
}

fn get_current_timestamp() -> u64 {
SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.map_or(0, |d| d.as_secs())
}
}

Expand All @@ -1012,7 +1011,7 @@ impl TokenProvider for FabricTokenOAuthProvider {
) -> crate::Result<TemporaryToken<Arc<AzureCredential>>> {
if let Some(storage_access_token) = &self.storage_access_token {
if let Some(expiry) = self.token_expiry {
let exp_in = expiry - get_current_timestamp();
let exp_in = expiry - Self::get_current_timestamp();
if exp_in > TOKEN_MIN_TTL {
return Ok(TemporaryToken {
token: Arc::new(AzureCredential::BearerToken(storage_access_token.clone())),
Expand All @@ -1039,7 +1038,7 @@ impl TokenProvider for FabricTokenOAuthProvider {
.await
.context(TokenResponseBodySnafu)?;
let exp_in = Self::validate_and_get_expiry(&access_token)
.map_or(3600, |expiry| expiry - get_current_timestamp());
.map_or(3600, |expiry| expiry - Self::get_current_timestamp());
Ok(TemporaryToken {
token: Arc::new(AzureCredential::BearerToken(access_token)),
expiry: Some(Instant::now() + Duration::from_secs(exp_in)),
Expand Down

0 comments on commit 92144b2

Please sign in to comment.