Add changes for ARM/AAD Token renewal

This commit is contained in:
Senthamil Sindhu 2024-11-21 16:01:49 -08:00
parent d12f10afd8
commit 66a1ca321d
5 changed files with 103 additions and 16 deletions

View File

@ -8,11 +8,12 @@ import { AuthType } from "../AuthType";
import { BackendApi, PriorityLevel } from "../Common/Constants";
import * as Logger from "../Common/Logger";
import { Platform, configContext } from "../ConfigContext";
import { userContext } from "../UserContext";
import { updateUserContext, userContext } from "../UserContext";
import { logConsoleError } from "../Utils/NotificationConsoleUtils";
import * as PriorityBasedExecutionUtils from "../Utils/PriorityBasedExecutionUtils";
import { EmulatorMasterKey, HttpHeaders } from "./Constants";
import { getErrorMessage } from "./ErrorHandlingUtils";
import { runCommand } from "hooks/useDatabaseAccounts";
const _global = typeof self === "undefined" ? window : self;
@ -32,6 +33,7 @@ export const tokenProvider = async (requestInfo: Cosmos.RequestInfo) => {
return null;
}
const AUTH_PREFIX = `type=aad&ver=1.0&sig=`;
console.log("AAD Token ", userContext.aadToken);
const authorizationToken = `${AUTH_PREFIX}${userContext.aadToken}`;
return authorizationToken;
}
@ -209,10 +211,14 @@ export function client(): Cosmos.CosmosClient {
_defaultHeaders["x-ms-cosmos-priority-level"] = PriorityLevel.Default;
}
const wrappedTokenProvider = async (requestInfo: Cosmos.RequestInfo) => {
return await runCommand(tokenProvider, requestInfo);
};
const options: Cosmos.CosmosClientOptions = {
endpoint: endpoint() || "https://cosmos.azure.com", // CosmosClient gets upset if we pass a bad URL. This should never actually get called
key: userContext.dataPlaneRbacEnabled ? "" : userContext.masterKey,
tokenProvider,
tokenProvider: wrappedTokenProvider,
userAgentSuffix: "Azure Portal",
defaultHeaders: _defaultHeaders,
connectionPolicy: {

View File

@ -81,6 +81,7 @@ export interface UserContext {
readonly endpoint?: string;
readonly aadToken?: string;
readonly accessToken?: string;
readonly armToken?: string;
readonly authorizationToken?: string;
readonly resourceToken?: string;
readonly subscriptionType?: SubscriptionType;

View File

@ -3,6 +3,7 @@ import { useBoolean } from "@fluentui/react-hooks";
import * as React from "react";
import { configContext } from "../ConfigContext";
import { acquireTokenWithMsal, getMsalInstance } from "../Utils/AuthorizationUtils";
import { updateUserContext } from "UserContext";
const msalInstance = await getMsalInstance();
@ -79,7 +80,7 @@ export function useAADAuth(): ReturnType {
authority: `${configContext.AAD_ENDPOINT}${tenantId}`,
scopes: [`${configContext.ARM_ENDPOINT}/.default`],
});
updateUserContext({ armToken: armToken });
setArmToken(armToken);
setAuthFailure(null);
} catch (error) {

View File

@ -1,6 +1,8 @@
import { HttpHeaders } from "Common/Constants";
import { QueryRequestOptions, QueryResponse } from "Contracts/AzureResourceGraph";
import useSWR from "swr";
import { updateUserContext, userContext } from "UserContext";
import { acquireTokenWithMsal, getMsalInstance } from "Utils/AuthorizationUtils";
import { configContext } from "../ConfigContext";
import { DatabaseAccount } from "../Contracts/DataModels";
/* eslint-disable @typescript-eslint/no-explicit-any */
@ -33,12 +35,9 @@ export async function fetchDatabaseAccounts(subscriptionId: string, accessToken:
return accounts.sort((a, b) => a.name.localeCompare(b.name));
}
export async function fetchDatabaseAccountsFromGraph(
subscriptionId: string,
accessToken: string,
): Promise<DatabaseAccount[]> {
export async function fetchDatabaseAccountsFromGraph(subscriptionId: string): Promise<DatabaseAccount[]> {
const headers = new Headers();
const bearer = `Bearer ${accessToken}`;
const bearer = `Bearer ${userContext.armToken}`;
headers.append("Authorization", bearer);
headers.append(HttpHeaders.contentType, "application/json");
@ -46,8 +45,9 @@ export async function fetchDatabaseAccountsFromGraph(
const apiVersion = "2021-03-01";
const managementResourceGraphAPIURL = `${configContext.ARM_ENDPOINT}providers/Microsoft.ResourceGraph/resources?api-version=${apiVersion}`;
const databaseAccounts: DatabaseAccount[] = [];
let databaseAccounts: DatabaseAccount[] = [];
let skipToken: string;
console.log("Old ARM Token - fetchDatabaseAccountsFromGraph function", userContext.armToken);
do {
const body = {
query: databaseAccountsQuery,
@ -85,10 +85,81 @@ export async function fetchDatabaseAccountsFromGraph(
return databaseAccounts.sort((a, b) => a.name.localeCompare(b.name));
}
export function useDatabaseAccounts(subscriptionId: string, armToken: string): DatabaseAccount[] | undefined {
export function useDatabaseAccounts(subscriptionId: string): DatabaseAccount[] | undefined {
const { data } = useSWR(
() => (armToken && subscriptionId ? ["databaseAccounts", subscriptionId, armToken] : undefined),
(_, subscriptionId, armToken) => fetchDatabaseAccountsFromGraph(subscriptionId, armToken),
() => (subscriptionId ? ["databaseAccounts", subscriptionId] : undefined),
(_, subscriptionId) => runCommand(fetchDatabaseAccountsFromGraph, subscriptionId),
);
return data;
}
// Define the types for your responses
interface DatabaseAccount {
name: string;
id: string;
// Add other relevant fields as per your use case
}
interface QueryRequestOptions {
$top?: number;
$skipToken?: string;
$allowPartialScopes?: boolean;
}
// Define the configuration context and headers if not already defined
const configContext = {
ARM_ENDPOINT: "https://management.azure.com/",
AAD_ENDPOINT: "https://login.microsoftonline.com/",
};
interface QueryResponse {
data?: any[];
$skipToken?: string;
}
export async function runCommand<T>(fn: (...args: any[]) => Promise<T>, ...args: any[]): Promise<T> {
try {
// Attempt to execute the function passed as an argument
const result = await fn(...args);
console.log("Successfully executed function:", fn.name, result);
return result;
} catch (error) {
// Handle any error that is thrown during the execution of the function
if (error) {
console.log("Creating new token");
const msalInstance = await getMsalInstance();
const cachedAccount = msalInstance.getAllAccounts()?.[0];
const cachedTenantId = localStorage.getItem("cachedTenantId");
msalInstance.setActiveAccount(cachedAccount);
// TODO: Add condition to check if the ARM token needs to be renewed, then we need to run the code below for creating the ARM token
console.log("Creating new ARM token");
const newAccessToken = await acquireTokenWithMsal(msalInstance, {
authority: `${configContext.AAD_ENDPOINT}${cachedTenantId}`,
scopes: [`${configContext.ARM_ENDPOINT}/.default`],
});
updateUserContext({ armToken: newAccessToken });
// TODO: add condition to check if AAD token needs to be renewed (i.e) Token provider has failed with expired AAD token and create a new AAD Token using the below code
// const hrefEndpoint = new URL(userContext.databaseAccount.properties.documentEndpoint).href.replace(/\/$/, "/.default");
// console.log('Creating new AAD token');
// let aadToken = await acquireTokenWithMsal(msalInstance, {
// forceRefresh: true,
// authority: `${configContext.AAD_ENDPOINT}${cachedTenantId}`,
// scopes: [hrefEndpoint],
// });
// updateUserContext({aadToken: aadToken});
//console.log('Latest AAD Token', fn.name, userContext.aadToken);
const result = await fn(...args);
return result;
} else {
console.error("An error occurred:", error.message);
throw new error();
}
}
}

View File

@ -3,6 +3,8 @@ import { QueryRequestOptions, QueryResponse } from "Contracts/AzureResourceGraph
import useSWR from "swr";
import { configContext } from "../ConfigContext";
import { Subscription } from "../Contracts/DataModels";
import { runCommand } from "hooks/useDatabaseAccounts";
import { userContext } from "UserContext";
/* eslint-disable @typescript-eslint/no-explicit-any */
interface SubscriptionListResult {
@ -35,9 +37,9 @@ export async function fetchSubscriptions(accessToken: string): Promise<Subscript
return subscriptions.sort((a, b) => a.displayName.localeCompare(b.displayName));
}
export async function fetchSubscriptionsFromGraph(accessToken: string): Promise<Subscription[]> {
export async function fetchSubscriptionsFromGraph(): Promise<Subscription[]> {
const headers = new Headers();
const bearer = `Bearer ${accessToken}`;
const bearer = `Bearer ${userContext.armToken}`;
headers.append("Authorization", bearer);
headers.append(HttpHeaders.contentType, "application/json");
@ -48,6 +50,7 @@ export async function fetchSubscriptionsFromGraph(accessToken: string): Promise<
const subscriptions: Subscription[] = [];
let skipToken: string;
console.log("Old ARM Token fetchSubscriptionsFromGraph fn", userContext.armToken);
do {
const body = {
query: subscriptionsQuery,
@ -86,9 +89,14 @@ export async function fetchSubscriptionsFromGraph(accessToken: string): Promise<
}
export function useSubscriptions(armToken: string): Subscription[] | undefined {
const { data } = useSWR(
const { data, error } = useSWR(
() => (armToken ? ["subscriptions", armToken] : undefined),
(_, armToken) => fetchSubscriptionsFromGraph(armToken),
(_) => runCommand(fetchSubscriptionsFromGraph),
);
if (error) {
console.error("Error fetching subscriptions:", error);
}
return data;
}