Skip to content

Commit

Permalink
Support object as value in extra_credential
Browse files Browse the repository at this point in the history
  • Loading branch information
huw0 committed Jun 12, 2024
1 parent bb35f1c commit a309986
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
35 changes: 35 additions & 0 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,6 +867,41 @@ def test_extra_credential_value_encoding(mock_get_and_post):
assert headers[constants.HEADER_EXTRA_CREDENTIAL] == "foo=bar+%E7%9A%84"


def test_extra_credential_value_object(mock_get_and_post):
_, post = mock_get_and_post

class TestCredential(object):
value = "initial"

def __str__(self):
return self.value

credential = TestCredential()

req = TrinoRequest(
host="coordinator",
port=constants.DEFAULT_TLS_PORT,
client_session=ClientSession(
user="test",
extra_credential=[("foo", credential)]
)
)

req.post("SELECT 1")
_, post_kwargs = post.call_args
headers = post_kwargs["headers"]
assert constants.HEADER_EXTRA_CREDENTIAL in headers
assert headers[constants.HEADER_EXTRA_CREDENTIAL] == "foo=initial"

# Make a second request, assert that credential has changed
credential.value = "changed"
req.post("SELECT 1")
_, post_kwargs = post.call_args
headers = post_kwargs["headers"]
assert constants.HEADER_EXTRA_CREDENTIAL in headers
assert headers[constants.HEADER_EXTRA_CREDENTIAL] == "foo=changed"


class MockGssapiCredentials:
def __init__(self, name: gssapi.Name, usage: str):
self.name = name
Expand Down
3 changes: 2 additions & 1 deletion trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,8 @@ def http_headers(self) -> Dict[str, str]:
# extra credential value is encoded per spec (application/x-www-form-urlencoded MIME format)
headers[constants.HEADER_EXTRA_CREDENTIAL] = \
", ".join(
[f"{tup[0]}={urllib.parse.quote_plus(tup[1])}" for tup in self._client_session.extra_credential])
[f"{tup[0]}={urllib.parse.quote_plus(str(tup[1]))}"
for tup in self._client_session.extra_credential])

return headers

Expand Down

0 comments on commit a309986

Please sign in to comment.