diff --git a/src/Common/CosmosClient.ts b/src/Common/CosmosClient.ts index 79ed76434..79e41dda4 100644 --- a/src/Common/CosmosClient.ts +++ b/src/Common/CosmosClient.ts @@ -8,6 +8,7 @@ import { PriorityLevel } from "../Common/Constants"; import * as Logger from "../Common/Logger"; import { Platform, configContext } from "../ConfigContext"; import { updateUserContext, userContext } from "../UserContext"; +import { isDataplaneRbacSupported } from "../Utils/APITypeUtils"; import { logConsoleError } from "../Utils/NotificationConsoleUtils"; import * as PriorityBasedExecutionUtils from "../Utils/PriorityBasedExecutionUtils"; import { EmulatorMasterKey, HttpHeaders } from "./Constants"; @@ -18,7 +19,7 @@ const _global = typeof self === "undefined" ? window : self; export const tokenProvider = async (requestInfo: Cosmos.RequestInfo) => { const { verb, resourceId, resourceType, headers } = requestInfo; - const dataPlaneRBACOptionEnabled = userContext.dataPlaneRbacEnabled && userContext.apiType === "SQL"; + const dataPlaneRBACOptionEnabled = userContext.dataPlaneRbacEnabled && isDataplaneRbacSupported(userContext.apiType); if (userContext.features.enableAadDataPlane || dataPlaneRBACOptionEnabled) { Logger.logInfo( `AAD Data Plane Feature flag set to ${userContext.features.enableAadDataPlane} for account with disable local auth ${userContext.databaseAccount.properties.disableLocalAuth} `, diff --git a/src/Explorer/Panes/SettingsPane/SettingsPane.tsx b/src/Explorer/Panes/SettingsPane/SettingsPane.tsx index a19c89be2..1626f09d1 100644 --- a/src/Explorer/Panes/SettingsPane/SettingsPane.tsx +++ b/src/Explorer/Panes/SettingsPane/SettingsPane.tsx @@ -32,6 +32,7 @@ import { } from "Shared/StorageUtility"; import * as StringUtility from "Shared/StringUtility"; import { updateUserContext, userContext } from "UserContext"; +import { isDataplaneRbacSupported } from "Utils/APITypeUtils"; import { acquireMsalTokenForAccount } from "Utils/AuthorizationUtils"; import { logConsoleError, logConsoleInfo } from "Utils/NotificationConsoleUtils"; import * as PriorityBasedExecutionUtils from "Utils/PriorityBasedExecutionUtils"; @@ -183,7 +184,7 @@ export const SettingsPane: FunctionComponent<{ explorer: Explorer }> = ({ const shouldShowCrossPartitionOption = userContext.apiType !== "Gremlin" && !isEmulator; const shouldShowParallelismOption = userContext.apiType !== "Gremlin" && !isEmulator; const showEnableEntraIdRbac = - userContext.apiType === "SQL" && + isDataplaneRbacSupported(userContext.apiType) && userContext.authType === AuthType.AAD && configContext.platform !== Platform.Fabric && !isEmulator; diff --git a/src/Utils/APITypeUtils.ts b/src/Utils/APITypeUtils.ts index b25ca1c29..aa88ecf66 100644 --- a/src/Utils/APITypeUtils.ts +++ b/src/Utils/APITypeUtils.ts @@ -89,3 +89,7 @@ export const getItemName = (): string => { return "Items"; } }; + +export const isDataplaneRbacSupported = (apiType: string): boolean => { + return apiType === "SQL" || apiType === "Tables"; +}; diff --git a/src/hooks/useKnockoutExplorer.ts b/src/hooks/useKnockoutExplorer.ts index 2d12f3af4..70a772559 100644 --- a/src/hooks/useKnockoutExplorer.ts +++ b/src/hooks/useKnockoutExplorer.ts @@ -13,6 +13,7 @@ import { readSubComponentState, } from "Shared/AppStatePersistenceUtility"; import { LocalStorageUtility, StorageKey } from "Shared/StorageUtility"; +import { isDataplaneRbacSupported } from "Utils/APITypeUtils"; import { logConsoleError } from "Utils/NotificationConsoleUtils"; import { useQueryCopilot } from "hooks/useQueryCopilot"; import { ReactTabKind, useTabs } from "hooks/useTabs"; @@ -299,7 +300,7 @@ async function configureHostedWithAAD(config: AAD): Promise { ); if (!userContext.features.enableAadDataPlane) { Logger.logInfo(`AAD Feature flag is not enabled for account ${account.name}`, "Explorer/configureHostedWithAAD"); - if (userContext.apiType === "SQL") { + if (isDataplaneRbacSupported(userContext.apiType)) { if (LocalStorageUtility.hasItem(StorageKey.DataPlaneRbacEnabled)) { const isDataPlaneRbacSetting = LocalStorageUtility.getEntryString(StorageKey.DataPlaneRbacEnabled); Logger.logInfo(