Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Databricks CLI & Azure CLI authentication and a basic integration test #37

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions .github/workflows/acceptance.yaml

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to fix is_<cloud> functions to handle NULL:

is_azure <- function() {
grepl(".azuredatabricks.net", cfg$host)
}
# is_gcp returns true if client is configured for GCP
is_gcp <- function() {
grepl(".gcp.databricks.com", cfg$host)
}
# is_aws returns true if client is configured for AWS
is_aws <- function() {
!is_azure() & !is_gcp()
}

e.g.

is_azure <- function() {
  if (!is.null(cfg$host)) {
    grepl(".azuredatabricks.net", cfg$host)
  } else {
    FALSE
  }
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a bit irrelevant to this PR, can you please send one?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also, in this case, we won't have null there, per se

Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
name: acceptance

on:
pull_request:
types: [opened, synchronize]

permissions:
id-token: write
contents: read
pull-requests: write

jobs:
integration:
if: github.event_name == 'pull_request'
environment: admin
runs-on: larger
steps:
- name: Checkout Code
uses: actions/checkout@v2.5.0

- name: Unshallow
run: git fetch --prune --unshallow

- uses: actions/checkout@v3

- uses: r-lib/actions/setup-r@v2
with:
r-version: release
use-public-rspm: true

- uses: azure/login@v1
with:
client-id: ${{ secrets.ARM_CLIENT_ID }}
tenant-id: ${{ secrets.ARM_TENANT_ID }}
allow-no-subscriptions: true

- uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: devtools
env:
R_COMPILE_AND_INSTALL_PACKAGES: never

- name: Run tests
run: Rscript -e "devtools::test()"
env:
CLOUD_ENV: "${{ vars.CLOUD_ENV }}"
DATABRICKS_HOST: "${{ secrets.DATABRICKS_HOST }}"

169 changes: 164 additions & 5 deletions R/api_client.R
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,11 @@ DatabricksClient <- function(profile = NULL, host = NULL, token = NULL, config_f

# cfg is the current unified authentication config of direct parameters,
# environment variables, and values loaded from ~/.databrickscfg file
cfg <- list(host = coalesce(host, Sys.getenv("DATABRICKS_HOST"), from_cli$host),
token = coalesce(token, Sys.getenv("DATABRICKS_TOKEN"), from_cli$token),
client_id = coalesce(Sys.getenv("DATABRICKS_CLIENT_ID"), from_cli$client_id),
client_secret = coalesce(Sys.getenv("DATABRICKS_CLIENT_SECRET"), from_cli$client_secret))
cfg <- new.env()
cfg$host = coalesce(host, from_cli$host, Sys.getenv("DATABRICKS_HOST"))
cfg$token = coalesce(token, from_cli$token, Sys.getenv("DATABRICKS_TOKEN"))
cfg$client_id = coalesce(from_cli$client_id, Sys.getenv("DATABRICKS_CLIENT_ID"))
cfg$client_secret = coalesce(from_cli$client_secret, Sys.getenv("DATABRICKS_CLIENT_SECRET"))

# add the missing https:// prefix to bare, ODBC-style hosts
if (!is.null(cfg$host) && !startsWith(cfg$host, "http")) {
Expand All @@ -105,7 +106,7 @@ DatabricksClient <- function(profile = NULL, host = NULL, token = NULL, config_f
used <- c()
sensitive <- c("token", "password", "client_secret", "google_credentials",
"azure_client_secret")
for (attr in names(cfg)) {
for (attr in sort(names(cfg))) {
value <- cfg[[attr]]
if (is.null(value)) {
next
Expand Down Expand Up @@ -140,6 +141,32 @@ DatabricksClient <- function(profile = NULL, host = NULL, token = NULL, config_f
return(function() {
c(Authentication = paste("Bearer", cfg$token))
})
}, `databricks-cli` = function() {
token_source <- .databricks_cli_token_source(cfg)
if (is.null(token_source)) {
return(NULL)
}
result <- try(token_source$token(), silent = TRUE)
if (inherits(result, "try-error")) {
return(NULL)
}
return(function() {
token <- token_source$token()
return(token$headers())
})
}, `azure-cli` = function() {
if (!is_azure()) {
return(NULL)
}
token_source <- .azure_cli_token_source("2ff814a6-3304-4ab8-85cb-cd0e6f879c1d")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be removed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope. It holds private state used for token refreshing internally when it's about to expire

result <- try(token_source$token(), silent = TRUE)
if (inherits(result, "try-error")) {
return(NULL)
}
return(function() {
token <- token_source$token()
return(token$headers())
})
})

# authenticate follows the semantics of Unified Client Authentication and
Expand Down Expand Up @@ -212,3 +239,135 @@ DatabricksClient <- function(profile = NULL, host = NULL, token = NULL, config_f

return(list(is_aws = is_aws, is_azure = is_azure, is_gcp = is_gcp, do = do, debug_string = debug_string))
}

.create_token <- function(access_token, token_type = NULL, expiry = NULL) {
if (is.na(expiry)) {
expiry <- Sys.time() + as.difftime(300, units = "secs")
}
state <- new.env()
state$access_token <- access_token
state$token_type <- token_type
state$expiry <- expiry
headers <- function() {
return(c(Authorization = paste(state$token_type, state$access_token)))
}
expired <- function() {
if (is.null(state$expiry)) {
return(FALSE)
}
# Azure Databricks rejects tokens that expire in 30 seconds or less, so we
# refresh the token 40 seconds before it expires.
potentially_expired <- state$expiry - as.difftime(40, units = "secs")
now <- Sys.time()
is_expired <- potentially_expired < now
return(is_expired)
}
valid <- function() {
if (is.null(state$access_token)) {
return(FALSE)
}
if (expired()) {
return(FALSE)
}
return(TRUE)
}
return(list(headers = headers, valid = valid))
}

.refreshable_token_source <- function(refresh) {
state <- new.env()
state$token <- NULL
# this is not thread-safe, but R is single-threaded
return(list(token = function() {
tok <- state$token
if (!is.null(tok)) {
is_valid <- tok$valid()
if (is_valid) {
return(tok)
}
}
state$token <- refresh()
return(state$token)
}))
}

.token_source <- function() {
return(list(token = function() {
stop("token method must be implemented", call. = FALSE)
}))
}

.cli_token_source <- function(cmd, token_type_field, access_token_field, expiry_field) {
parse_expiry <- function(expiry) {
formats <- c("%Y-%m-%dT%H:%M:%OS", "%Y-%m-%dT%H:%M:%S")
for (fmt in formats) {
tryCatch({
x <- as.POSIXct(expiry, format = fmt)
return(x)
}, error = function(e) {
# TODO: improve this
last_error <<- e
})
}
if (exists("last_error")) {
stop(last_error)
}
}
return(.refreshable_token_source(function() {
tryCatch({
# TODO: do better handling, so that we don't see a warning message, when
# Databricks OAuth is not configured for the given host on the given
# machine
out <- system2(command = cmd, stdout = TRUE, stderr = TRUE)
tryCatch({
it <- jsonlite::fromJSON(out)
expiry <- it[[expiry_field]]
expires_on <- parse_expiry(expiry)
access_token <- it[[access_token_field]]
token_type <- it[[token_type_field]]
return(.create_token(access_token, token_type, expires_on))
}, error = function(e) {
if (inherits(e, "error")) {
message <- if (length(e$message) > 0) e$message else ""
stop(paste("cannot unmarshal CLI result:", out, message))
} else if (inherits(e, "try-error")) {
stdout <- if (length(e$message) > 0) e$message else ""
stderr <- if (length(e$stderr) > 0) e$stderr else ""
message <- if (nchar(stdout) > 0) stdout else stderr
stop(paste("cannot get access token:", message))
}
})
}, error = function(e) {
stop("cannot execute CLI command:", e)
})
}))
}

.databricks_cli_token_source <- function(cfg) {
if (is.null(cfg$host)) {
return(NULL)
}
args <- c("auth", "token", "--host", cfg$host)
cli_path <- tryCatch({
# Try to find 'databricks' in PATH
(Sys.which("databricks"))
}, error = function(e) {
# If 'databricks' is not found, try to find 'databricks.exe'
if (Sys.info()["sysname"] == "Windows") {
(Sys.which("databricks.exe"))
} else {
(NULL)
}
})
if (is.null(cli_path)) {
return(NULL)
}
cmd <- c(cli_path, args)
return(.cli_token_source(cmd, "token_type", "access_token", "expiry"))
}

.azure_cli_token_source <- function(resource) {
cmd <- c("az", "account", "get-access-token", "--resource", resource, "--output",
"json")
return(.cli_token_source(cmd, "tokenType", "accessToken", "expiresOn"))
}
2 changes: 1 addition & 1 deletion tests/testthat/test_api_client.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ test_that("loads configuration file", {

test_that("parses configuration profile", {
client <- databricks:::DatabricksClient(config_file = "./data/awscfg", profile="client-secret")
expected <- "host=https://another.cloud.databricks.com/, client_id=xxx, client_secret=***"
expected <- "client_id=xxx, client_secret=***, host=https://another.cloud.databricks.com/"
expect_equal(expected, client$debug_string())
})

Expand Down
9 changes: 9 additions & 0 deletions tests/testthat/test_current_user.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
library(testthat)

skip_if(Sys.getenv("DATABRICKS_HOST") == "", "Not integration test")

test_that("detects current user", {
client <- databricks::DatabricksClient()
user <- databricks::me(client)
expect_false(is.null(user$userName))
})
Loading