BeGreen

Reputation: 951

Python DefaultAzureCredential get_token set expiration or renew token

I'm using DefaultAzureCredential from azure-identity to connect to Azure with service principal environment variables (AZURE_CLIENT_SECRET, AZURE_TENANT_ID, AZURE_CLIENT_ID).
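DefaultAzureCredential tries EnvironmentCredential first, and that credential reads these three variables for client-secret authentication. The minimal stdlib sketch below checks that all three are set before a token is requested. The helper name `missing_sp_env_vars` is hypothetical.

```python
import os
from typing import List

# Variables EnvironmentCredential (the first credential DefaultAzureCredential
# tries) reads for service principal client-secret authentication.
SP_SECRET_ENV_VARS = ("AZURE_TENANT_ID", "AZURE_CLIENT_ID", "AZURE_CLIENT_SECRET")


def missing_sp_env_vars(environ=None) -> List[str]:
    """Return the service principal variables that are unset or empty."""
    env = os.environ if environ is None else environ
    return [name for name in SP_SECRET_ENV_VARS if not env.get(name)]


# Fail fast before calling DefaultAzureCredential().get_token(...)
missing = missing_sp_env_vars()
if missing:
    print(f"Missing service principal variables: {', '.join(missing)}")
```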

I can get a token for a specific scope, such as Azure Databricks, like this:

from azure.identity import DefaultAzureCredential

# 2ff814a6-3304-4ab8-85cb-cd0e6f879c1d is the application ID of the Azure Databricks resource
dbx_scope = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d/.default"
token = DefaultAzureCredential().get_token(dbx_scope).token

In my experience, get_token returns a token with a time to live of 1 to 2 hours.
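The object returned by get_token is an AccessToken whose expires_on field is a Unix timestamp in seconds, so the actual lifetime can be checked instead of guessed. The small sketch below relies only on those two fields. `DemoToken` is a hypothetical stand-in so the snippet runs without Azure access.

```python
import time
from typing import NamedTuple, Optional


class DemoToken(NamedTuple):
    """Stand-in with the same fields as the AccessToken returned by get_token."""
    token: str
    expires_on: int  # Unix timestamp, in seconds


def seconds_until_expiry(access_token, now: Optional[float] = None) -> float:
    """Seconds left before `access_token` expires (negative once it has expired)."""
    current = time.time() if now is None else now
    return access_token.expires_on - current


# With a real credential, pass the object returned by get_token instead, e.g.
#   tok = DefaultAzureCredential().get_token(dbx_scope)
demo = DemoToken("dummy", expires_on=int(time.time()) + 3600)
print(f"Token valid for about {seconds_until_expiry(demo) / 60:.0f} more minutes")
```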

So if a long-running process uses the resource for more than 2 hours, the token expires and the whole Spark job is lost.

Is there a way to make the generated token last longer? The official documentation shows that get_token accepts **kwargs, but I can't find any resources explaining how to use them or which arguments they accept.
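As far as I know, the keyword arguments of get_token (such as claims and tenant_id) do not control the token lifetime; the identity platform decides it. A client can, however, request a new token shortly before the old one expires. Below is a sketch of such a cache. `RefreshingToken` and `margin_seconds` are hypothetical names, and the token source is assumed to be any callable that returns an object with token and expires_on.

```python
import threading
import time
from typing import Any, Callable, Optional


class RefreshingToken:
    """Cache an access token and fetch a new one shortly before it expires.

    `fetch` is any zero-argument callable returning an object with `.token`
    and `.expires_on` (Unix seconds), e.g. `lambda: cred.get_token(scope)`.
    """

    def __init__(self, fetch: Callable[[], Any], margin_seconds: float = 300,
                 clock: Callable[[], float] = time.time) -> None:
        self._fetch = fetch
        self._margin = margin_seconds
        self._clock = clock  # injectable so the logic can be tested without waiting
        self._cached: Optional[Any] = None
        self._lock = threading.Lock()  # avoid parallel refreshes from several threads

    @property
    def value(self) -> str:
        with self._lock:
            # Refresh on first use or once fewer than `margin_seconds` remain.
            expires_soon = (
                self._cached is None
                or self._cached.expires_on - self._clock() <= self._margin
            )
            if expires_soon:
                self._cached = self._fetch()
            return self._cached.token
```

For example, create a single `cred = DefaultAzureCredential()`, wrap it as `RefreshingToken(lambda: cred.get_token(dbx_scope))`, and read `.value` before each request. This only helps code that asks for the token again. A token that has already been handed to a long-running Spark job cannot be extended this way, which is what motivates the PAT approach in the answer.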


Answers (1)

BeGreen

Reputation: 951

There seems to be no option to make this Azure AD token last longer, so I created a class that manages Databricks personal access tokens (PATs) through the Databricks Token API 2.0: https://docs.databricks.com/dev-tools/api/latest/tokens.html
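The create call comes down to one authenticated POST. The dependency-free sketch below only builds the request with urllib and does not send it. The function name is hypothetical. The request body contains just the comment and lifetime_seconds fields and leaves out the application_id used in the class.

```python
import json
import urllib.request


def build_create_pat_request(host: str, aad_token: str, lifetime_seconds: int = 86400,
                             comment: str = "Token created from datalab-framework"
                             ) -> urllib.request.Request:
    """Build (but do not send) a POST to the Databricks Token API 2.0 create endpoint."""
    body = json.dumps({"lifetime_seconds": lifetime_seconds, "comment": comment})
    return urllib.request.Request(
        url=f"https://{host}/api/2.0/token/create",
        data=body.encode("utf-8"),
        headers={
            "Authorization": f"Bearer {aad_token}",  # short-lived Azure AD token
            "Content-Type": "application/json",
        },
        method="POST",
    )
```

Sending it with `urllib.request.urlopen(req)` and reading `token_value` from the JSON response would return the PAT, which is what the class does with requests.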

Thankfully, PATs are removed automatically once they expire, so I don't have to clean up old ones.
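If you still want to see which listed PATs are valid, you can filter the entries returned by token/list by expiry. The sketch below assumes each entry has an `expiry_time` in epoch milliseconds, with -1 meaning no expiry; check this against the API reference. `active_pats` is a hypothetical helper.

```python
import time
from typing import Dict, List, Optional


def active_pats(token_infos: List[Dict], now_ms: Optional[int] = None) -> List[Dict]:
    """Keep PATs that have not expired yet.

    Assumes `expiry_time` is in epoch milliseconds and -1 (or a missing
    field) means the PAT never expires.
    """
    current = int(time.time() * 1000) if now_ms is None else now_ms
    return [
        info for info in token_infos
        if info.get("expiry_time", -1) == -1 or info["expiry_time"] > current
    ]
```

For example, `active_pats(manager.list_databricks_pats())` keeps only the PATs that are still usable.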

import json
from typing import Dict, List, Optional

import requests
from azure.identity import DefaultAzureCredential


class DatabricksTokenManager:
    """Databricks Token Manager. Based on https://docs.databricks.com/dev-tools/api/latest/index.html
    It uses `DefaultAzureCredential` to get a short-lived Azure AD token for Databricks,
    then uses that token to manage Databricks PATs.
    """

    def __init__(self, databricks_host: str) -> None:
        """Init DatabricksTokenManager

        Args:
            databricks_host (str): Databricks host without "https://" or a trailing "/"
        """
        self.databricks_host = databricks_host
        # Short-lived Azure AD token, fetched once when the manager is created
        self._token = self._get_databricks_token()
        self._pat: Optional[str] = None

    @property
    def token(self) -> str:
        """Token property

        Returns:
            str: token value
        """
        return self._token

    @property
    def pat(self) -> Optional[str]:
        """PAT property

        Returns:
            Optional[str]: PAT value, or None if no PAT has been created yet
        """
        return self._pat

    def _get_databricks_token(self) -> str:
        """Get auto generated token from Default Azure Credentials.
        If you are running this code locally, run `az login` first or set the service principal environment variables.

        Returns:
            str: Databricks temporary Token
        """
        dbx_scope = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d/.default"
        return DefaultAzureCredential().get_token(dbx_scope).token


    def list_databricks_pats(self) -> List[Dict]:
        """List all PATs for this user in Databricks

        Returns:
            list: List of dicts containing PAT info
        """
        headers = {
            "Authorization": f"Bearer {self.token}",
        }
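        response = requests.get(
            f"https://{self.databricks_host}/api/2.0/token/list", headers=headers
        )
        response.raise_for_status()
        # Fall back to an empty list in case "token_infos" is missing from the response
        return response.json().get("token_infos", [])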

    def create_databricks_pat(self, comment=None) -> str:
        """Create and return a new PAT from Databricks

        Args:
            comment (str, optional): Comment attached to the PAT. Defaults to "Token created from datalab-framework".
        Returns:
            str: PAT value
        """
        if comment is None:
            comment = "Token created from datalab-framework"

        headers = {
            "Content-type": "application/json",
            "Authorization": f"Bearer {self.token}",
        }
        json_data = {
            "application_id": "ce3b7e02-a406-4afc-8123-3de02807e729",
            "comment": comment,
            "lifetime_seconds": 86400, # 24 Hours
        }
        response = requests.post(
            f"https://{self.databricks_host}/api/2.0/token/create",
            headers=headers,
            json=json_data,
        )
        response.raise_for_status()
        self._pat = response.json()["token_value"]
        return self._pat

    def remove_databricks_pat(self, pat_id: str) -> None:
        """Remove a PAT from Databricks

        Args:
            pat_id (str): PAT ID
        """
        headers = {
            "Authorization": f"Bearer {self.token}",
            # The body is JSON, so declare it as JSON (not form-urlencoded)
            "Content-Type": "application/json",
        }
        data = {"token_id": pat_id}
        response = requests.post(
            f"https://{self.databricks_host}/api/2.0/token/delete",
            headers=headers,
            data=json.dumps(data),
        )
        response.raise_for_status()
