Obsidian Phoenix
Obsidian Phoenix

Reputation: 4155

Connect to Multiple Azure DBs with AAD/MFA Credentials

I'm trying to write a .NET Core console app that connects to multiple Azure SQL databases and executes some scripts against them. Our company requires Azure AD logins with MFA for the databases.

I've managed to get it to log in successfully, using the information here:

Setup

static void Main(string[] args)
{
    // Register the custom provider for the ActiveDirectoryIntegrated method only.
    // ActiveDirectoryInteractive or ActiveDirectoryPassword are the other methods that
    // could be registered here instead.
    // NOTE(review): GetConnection sets Authentication = ActiveDirectoryInteractive, which this
    // registration does not cover, so those connections will not be routed to this provider.
    SqlAuthenticationProvider.SetProvider(
        SqlAuthenticationMethod.ActiveDirectoryIntegrated,
        new ActiveDirectoryAuthProvider());
}

/// <summary>
/// Custom SQL authentication provider that obtains an Azure AD access token for
/// connections using <c>ActiveDirectoryIntegrated</c>, authenticating as the username
/// read from <c>GlobalSettings.CredentialsSettings</c>.
/// </summary>
public class ActiveDirectoryAuthProvider : SqlAuthenticationProvider
{
    // Azure AD application (client) id used when requesting tokens - replace with your own value.
    private readonly string _clientId = "MyClientID";

    /// <summary>
    /// Acquires an access token for <c>parameters.Resource</c> from <c>parameters.Authority</c>.
    /// </summary>
    /// <param name="parameters">Authentication request built by the SQL client.</param>
    /// <returns>The access token and its expiry, wrapped for the SQL client.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="parameters"/> is null.</exception>
    /// <exception cref="InvalidOperationException">
    /// The requested method is anything other than ActiveDirectoryIntegrated.
    /// </exception>
    public override async TT.Task<SC.SqlAuthenticationToken>
        AcquireTokenAsync(SC.SqlAuthenticationParameters parameters)
    {
        if (parameters == null)
        {
            throw new ArgumentNullException(nameof(parameters));
        }

        AD.AuthenticationContext authContext =
            new AD.AuthenticationContext(parameters.Authority);
        // Correlate ADAL's requests with this SQL connection attempt.
        authContext.CorrelationId = parameters.ConnectionId;
        AD.AuthenticationResult result;

        switch (parameters.AuthenticationMethod)
        {
            case SC.SqlAuthenticationMethod.ActiveDirectoryIntegrated:
                Console.WriteLine("In method 'AcquireTokenAsync', case_1 == '.ActiveDirectoryIntegrated'.");
                Console.WriteLine($"Resource: {parameters.Resource}");

                result = await authContext.AcquireTokenAsync(
                    parameters.Resource,
                    _clientId,
                    new AD.UserCredential(GlobalSettings.CredentialsSettings.Username));
                break;

            default:
                throw new InvalidOperationException(
                    $"Authentication method '{parameters.AuthenticationMethod}' is not supported by {nameof(ActiveDirectoryAuthProvider)}.");
        }

        return new SC.SqlAuthenticationToken(result.AccessToken, result.ExpiresOn);
    }

    /// <summary>Reports whether this provider can serve <paramref name="authenticationMethod"/>.</summary>
    /// <remarks>
    /// Must agree with the cases handled in AcquireTokenAsync. Reporting ActiveDirectoryInteractive
    /// here would allow registering the provider for a method whose requests always throw.
    /// </remarks>
    public override bool IsSupported(SC.SqlAuthenticationMethod authenticationMethod)
    {
        return authenticationMethod == SC.SqlAuthenticationMethod.ActiveDirectoryIntegrated;
    }
}

Connection

/// <summary>
/// Opens a connection to <paramref name="initialCatalog"/> on <paramref name="dataSource"/>
/// using Azure AD interactive authentication.
/// </summary>
/// <param name="dataSource">Server to connect to; defaults to "MyServer".</param>
/// <param name="initialCatalog">
/// Database to open; defaults to "MyDatabase". Pass another name to reach the other
/// databases on the same server.
/// </param>
/// <returns>An open <see cref="SqlConnection"/>; the caller owns it and must dispose it.</returns>
private SqlConnection GetConnection(string dataSource = "MyServer", string initialCatalog = "MyDatabase")
{
    var builder = new SqlConnectionStringBuilder
    {
        DataSource = dataSource,
        Encrypt = true,
        // NOTE(review): TrustServerCertificate = true skips validation of the server's TLS
        // certificate; confirm this is really needed for the target server.
        TrustServerCertificate = true,
        // NOTE(review): PersistSecurityInfo = true keeps security-sensitive keywords in
        // ConnectionString after Open; confirm this is needed.
        PersistSecurityInfo = true,
        Authentication = SqlAuthenticationMethod.ActiveDirectoryInteractive,
        InitialCatalog = initialCatalog
    };

    var conn = new SqlConnection(builder.ConnectionString);
    try
    {
        conn.Open();
    }
    catch
    {
        // Do not leak the connection object when Open throws; the caller never receives it.
        conn.Dispose();
        throw;
    }

    return conn;
}

This works, and I can run the queries as expected. However, whenever the application connects to a new database (on the same server), it opens a browser window at login.microsoftonline.com asking me to select my account and sign in.

Is there any way to require this browser sign-in only once for all the databases? They are all on the same Azure SQL instance.

Upvotes: 1

Views: 657

Answers (1)

Obsidian Phoenix
Obsidian Phoenix

Reputation: 4155

So, there was a bit of PEBKAC (user error) in the code. Although the connection sets builder.Authentication = SqlAuthenticationMethod.ActiveDirectoryInteractive;, the provider class is registered for, and only handles, ActiveDirectoryIntegrated. So my AD provider class was never actually called. Also, the example code as posted would never have worked either, because its case statement only exists for ActiveDirectoryIntegrated - I've stripped that out of my local copy.

What I actually needed was a proper ActiveDirectoryInteractive implementation. Once I had one, the app authenticated only once, and all the subsequent database connections worked without additional browser prompts.

Setup

static void Main(string[] args)
{
    // Route ActiveDirectoryInteractive authentication through the custom caching provider.
    // ActiveDirectoryIntegrated or ActiveDirectoryPassword are the alternatives that could be
    // registered here instead.
    SqlAuthenticationProvider.SetProvider(
        SqlAuthenticationMethod.ActiveDirectoryInteractive,
        new ActiveDirectoryAuthProvider());
}

ActiveDirectoryAuthProvider

/// <summary>
/// SQL authentication provider for <c>ActiveDirectoryInteractive</c> that keeps an in-memory
/// ADAL token cache for the lifetime of the instance. A token is first looked up silently in
/// that cache, and the browser sign-in (via <see cref="CustomWebUi"/>) is used only when the
/// silent lookup fails.
/// </summary>
public class ActiveDirectoryAuthProvider : SqlAuthenticationProvider
{
    private readonly string _clientId = "MyClientID";

    // Loopback redirect URI passed to ADAL for the interactive sign-in.
    private Uri _redirectURL { get; set; } = new Uri("http://localhost:8089");

    // Context reused across calls while the requested authority stays the same.
    private AD.AuthenticationContext AuthContext { get; set; }

    // Shared by every context this instance creates, so tokens survive between connections.
    private TokenCache Cache { get; set; }

    public ActiveDirectoryAuthProvider()
    {
        Cache = new TokenCache();
    }

    /// <summary>
    /// Returns a token for <c>parameters.Resource</c>. The cache is tried silently first, and an
    /// interactive browser sign-in is used only if the silent lookup fails.
    /// </summary>
    /// <param name="parameters">Authentication request built by the SQL client.</param>
    /// <returns>The access token and its expiry, wrapped for the SQL client.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="parameters"/> is null.</exception>
    public override async TT.Task<SC.SqlAuthenticationToken> AcquireTokenAsync(SC.SqlAuthenticationParameters parameters)
    {
        if (parameters == null)
        {
            throw new ArgumentNullException(nameof(parameters));
        }

        AD.AuthenticationContext authContext = GetAuthContext(parameters.Authority);
        authContext.CorrelationId = parameters.ConnectionId;

        // new UserIdentifier(...) rejects a null/blank id, and the connection string may not
        // specify a User ID, so fall back to "any user" in that case. The same identifier is
        // used for both lookups so the silent path cannot return another user's cached token.
        UserIdentifier user = string.IsNullOrWhiteSpace(parameters.UserId)
            ? UserIdentifier.AnyUser
            : new UserIdentifier(parameters.UserId, UserIdentifierType.RequiredDisplayableId);

        AD.AuthenticationResult result;
        try
        {
            result = await authContext.AcquireTokenSilentAsync(
                parameters.Resource,
                _clientId,
                user);
        }
        catch (AdalSilentTokenAcquisitionException)
        {
            // Nothing usable in the cache: fall back to the browser sign-in.
            result = await authContext.AcquireTokenAsync(
                parameters.Resource,
                _clientId,
                _redirectURL,
                new AD.PlatformParameters(PromptBehavior.Auto, new CustomWebUi()),
                user);
        }

        return new SC.SqlAuthenticationToken(result.AccessToken, result.ExpiresOn);
    }

    public override bool IsSupported(SC.SqlAuthenticationMethod authenticationMethod)
    {
        return authenticationMethod == SC.SqlAuthenticationMethod.ActiveDirectoryInteractive;
    }

    // Returns the stored context when it targets the same authority. Otherwise it creates one
    // over the shared Cache and stores it for later calls.
    private AD.AuthenticationContext GetAuthContext(string authority)
    {
        AD.AuthenticationContext current = AuthContext;
        if (current == null || !SameAuthority(current.Authority, authority))
        {
            current = new AD.AuthenticationContext(authority, Cache);
            AuthContext = current;
        }

        return current;
    }

    // Compares authority URLs case-insensitively, ignoring a trailing slash.
    private static bool SameAuthority(string left, string right) =>
        string.Equals(left?.TrimEnd('/'), right?.TrimEnd('/'), StringComparison.OrdinalIgnoreCase);
}

There are a few things different here:

  1. I've added an in-memory Token Cache
  2. I've moved the AuthContext to a property on the class so it can be kept in memory between calls
  3. I've set the _redirectURL property = http://localhost:8089
  4. I've added a silent token lookup (AcquireTokenSilentAsync) before falling back to the interactive browser login

Finally, I've created my own implementation of ICustomWebUi, which opens the browser login page and captures the response:

CustomWebUi

/// <summary>
/// ADAL custom web UI for a console app. It opens the system browser at the authorization URL
/// and captures the browser's redirect to a localhost port with a
/// <see cref="SingleMessageTcpListener"/>.
/// </summary>
internal class CustomWebUi : ICustomWebUi
{
    /// <summary>
    /// Opens <paramref name="authorizationUri"/> in the default browser and waits for a single
    /// redirect request on <c>localhost:redirectUri.Port</c>.
    /// </summary>
    /// <param name="authorizationUri">Authorization URL to open in the browser.</param>
    /// <param name="redirectUri">Redirect URI; only its port is used, to choose where to listen.</param>
    /// <returns>A localhost URI carrying the query string of the redirect request.</returns>
    /// <exception cref="ArgumentNullException">Either argument is null.</exception>
    public async Task<Uri> AcquireAuthorizationCodeAsync(Uri authorizationUri, Uri redirectUri)
    {
        if (authorizationUri == null)
        {
            throw new ArgumentNullException(nameof(authorizationUri));
        }

        if (redirectUri == null)
        {
            throw new ArgumentNullException(nameof(redirectUri));
        }

        using (var listener = new SingleMessageTcpListener(redirectUri.Port))
        {
            Uri authCode = null;

            // Call the listener before launching the browser so the port is already being
            // listened on when the redirect arrives.
            var listenerTask = listener.ListenToSingleRequestAndRespondAsync(u => {
                authCode = u;
                
                return @"
<html>
<body>
    <p>Successfully Authenticated, you may now close this window</p>
</body>
</html>";
            }, System.Threading.CancellationToken.None);

            // AbsoluteUri keeps the query string percent-encoded. Uri.ToString() returns an
            // unescaped form that can corrupt parameter values when handed to the browser.
            var ps = new ProcessStartInfo(authorizationUri.AbsoluteUri)
            { 
                UseShellExecute = true, 
                Verb = "open" 
            };

            // Process.Start can return null (e.g. when the URL is handed to an existing browser
            // process); dispose the handle when one is returned.
            Process.Start(ps)?.Dispose();

            await listenerTask;

            return authCode;
        }            
    }
}

Because the redirect URI points back to localhost and this code runs inside a console application, the app has to listen on that port, capture the redirect, and then send a page back to the browser to show that sign-in succeeded.

To listen on the port, I used a listener class borrowed from Microsoft's MSAL GitHub repository:

SingleMessageTcpListener

/// <summary>
/// Listens on <c>localhost:port</c> for a single HTTP request. It extracts the query string of
/// that request's <c>GET /?...</c> line into a <see cref="Uri"/>, and writes a caller-produced
/// body back to the client as an HTTP 200 response.
/// </summary>
/// <remarks>
/// The underlying TCP listener might capture multiple requests, but only the first one is handled.
///
/// Cribbed this class from https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/9e0f57b53edfdcf027cbff401d3ca6c02e95ef1b/tests/devapps/NetCoreTestApp/Experimental/SingleMessageTcpListener.cs
/// </remarks>
internal class SingleMessageTcpListener : IDisposable
{
    // Matches the request line of the redirect, e.g. "GET /?code=...&state=... HTTP/1.1",
    // and captures everything between "/?" and " HTTP".
    private static readonly Regex GetQueryRegex = new Regex(@"GET \/\?(.*) HTTP");

    // Blank line that ends the HTTP request headers.
    private const string HeaderTerminator = "\r\n\r\n";

    // Upper bound on how much of a request is buffered before parsing.
    private const int MaxRequestBytes = 64 * 1024;

    private readonly int _port;
    private readonly System.Net.Sockets.TcpListener _tcpListener;

    /// <summary>Creates a listener bound to the loopback address on <paramref name="port"/>.</summary>
    /// <param name="port">TCP port between 1 and 65535; port 80 is rejected.</param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="port"/> is out of range or 80.</exception>
    public SingleMessageTcpListener(int port)
    {
        if (port < 1 || port == 80 || port > IPEndPoint.MaxPort)
        {
            // The single-string constructor would treat the message as the parameter name,
            // so pass the name, the value and the message separately.
            throw new ArgumentOutOfRangeException(
                nameof(port),
                port,
                "Expected a valid port number (1-65535), not 80");
        }

        _port = port;
        _tcpListener = new System.Net.Sockets.TcpListener(IPAddress.Loopback, _port);
    }

    /// <summary>
    /// Starts listening, accepts one client, parses its request and answers it with the body
    /// returned by <paramref name="responseProducer"/>.
    /// </summary>
    /// <param name="responseProducer">Receives the parsed request URI and returns the response body.</param>
    /// <param name="cancellationToken">Stops the listener (and so aborts the accept) when cancelled.</param>
    /// <exception cref="OperationCanceledException">Cancellation was requested.</exception>
    /// <exception cref="InvalidOperationException">The request is not a <c>GET /?...</c> request.</exception>
    public async Task ListenToSingleRequestAndRespondAsync(
        Func<Uri, string> responseProducer,
        CancellationToken cancellationToken)
    {
        if (responseProducer == null)
        {
            throw new ArgumentNullException(nameof(responseProducer));
        }

        cancellationToken.ThrowIfCancellationRequested();

        // Stopping the listener is the only way to abort a pending accept. Dispose the
        // registration so the callback does not outlive this call.
        using (cancellationToken.Register(() => _tcpListener.Stop()))
        {
            _tcpListener.Start();

            using (TcpClient tcpClient = await AcceptTcpClientAsync(cancellationToken).ConfigureAwait(false))
            {
                await ExtractUriAndRespondAsync(tcpClient, responseProducer, cancellationToken).ConfigureAwait(false);
            }
        }
    }

    /// <summary>
    /// AcceptTcpClientAsync does not natively support cancellation, so use this wrapper. Make sure
    /// the cancellation token is registered to stop the listener.
    /// </summary>
    /// <remarks>See https://stackoverflow.com/questions/19220957/tcplistener-how-to-stop-listening-while-awaiting-accepttcpclientasync</remarks>
    private async Task<TcpClient> AcceptTcpClientAsync(CancellationToken token)
    {
        try
        {
            return await _tcpListener.AcceptTcpClientAsync().ConfigureAwait(false);
        }
        catch (Exception ex) when (token.IsCancellationRequested)
        {
            throw new OperationCanceledException("Cancellation was requested while awaiting TCP client connection.", ex);
        }
    }

    // Reads the request, parses its URI and writes the produced response back on the same stream.
    private async Task ExtractUriAndRespondAsync(
        TcpClient tcpClient,
        Func<Uri, string> responseProducer,
        CancellationToken cancellationToken)
    {
        cancellationToken.ThrowIfCancellationRequested();

        string httpRequest = await GetTcpResponseAsync(tcpClient, cancellationToken).ConfigureAwait(false);
        Uri uri = ExtractUriFromHttpRequest(httpRequest);

        // write an "OK, please close the browser" message
        await WriteResponseAsync(responseProducer(uri), tcpClient.GetStream(), cancellationToken)
            .ConfigureAwait(false);
    }

    // Turns "GET /?a=1&b=2 HTTP/1.1" into http://localhost:<port>/?a=1&b=2.
    private Uri ExtractUriFromHttpRequest(string httpRequest)
    {
        Match match = GetQueryRegex.Match(httpRequest);
        if (!match.Success)
        {
            throw new InvalidOperationException("Not a GET query");
        }

        var uriBuilder = new UriBuilder
        {
            Query = match.Groups[1].Value,
            Port = _port
        };

        return uriBuilder.Uri;
    }

    // Reads the raw request as ASCII text. Reading stops at the blank line that ends the headers,
    // when the client closes its side, or once MaxRequestBytes have been buffered. A long request
    // line can arrive in several segments, and DataAvailable may briefly be false between them,
    // so DataAvailable is not used as the stop condition.
    private static async Task<string> GetTcpResponseAsync(TcpClient client, CancellationToken cancellationToken)
    {
        NetworkStream networkStream = client.GetStream();

        byte[] readBuffer = new byte[1024];
        var request = new StringBuilder();

        while (request.Length < MaxRequestBytes)
        {
            int numberOfBytesRead = await networkStream.ReadAsync(readBuffer, 0, readBuffer.Length, cancellationToken)
                .ConfigureAwait(false);
            if (numberOfBytesRead == 0)
            {
                break; // the client closed the connection
            }

            request.Append(Encoding.ASCII.GetString(readBuffer, 0, numberOfBytesRead));
            if (request.ToString().IndexOf(HeaderTerminator, StringComparison.Ordinal) >= 0)
            {
                break;
            }
        }

        return request.ToString();
    }

    // Writes a complete HTTP/1.1 200 response. The body is sent as UTF-8 HTML with an explicit
    // length, and the connection is marked as closing since only one request is served.
    private static async Task WriteResponseAsync(
        string message,
        NetworkStream stream,
        CancellationToken cancellationToken)
    {
        byte[] body = Encoding.UTF8.GetBytes(message ?? string.Empty);
        string headers =
            "HTTP/1.1 200 OK\r\n" +
            "Content-Type: text/html; charset=utf-8\r\n" +
            "Content-Length: " + body.Length.ToString(System.Globalization.CultureInfo.InvariantCulture) + "\r\n" +
            "Connection: close\r\n" +
            "\r\n";
        byte[] head = Encoding.ASCII.GetBytes(headers);

        await stream.WriteAsync(head, 0, head.Length, cancellationToken).ConfigureAwait(false);
        await stream.WriteAsync(body, 0, body.Length, cancellationToken).ConfigureAwait(false);
        await stream.FlushAsync(cancellationToken).ConfigureAwait(false);
    }

    public void Dispose()
    {
        _tcpListener?.Stop();
    }
}

With all this in place, the browser opens only when connecting to the first database for a resource, and the token is then reused for subsequent connections.

Upvotes: 1

Related Questions