Enable RBAC support for MongoDB and Cassandra APIs (#2198)

* enable RBAC support for Mongo & Cassandra API

* fix formatting issue

* Handling AAD integration for Mongo Shell

* remove empty aadToken error

* fix formatting issue

* added environment specific scope endpoints
This commit is contained in:
BChoudhury-ms 2025-09-19 01:25:35 +05:30 committed by GitHub
parent cfb5db4df6
commit 76e63818d3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 371 additions and 91 deletions

View File

@ -138,6 +138,14 @@ export enum MongoBackendEndpointType {
remote, remote,
} }
export class AadScopeEndpoints {
public static readonly Development: string = "https://cosmos.azure.com";
public static readonly MPAC: string = "https://cosmos.azure.com";
public static readonly Prod: string = "https://cosmos.azure.com";
public static readonly Fairfax: string = "https://cosmos.azure.us";
public static readonly Mooncake: string = "https://cosmos.azure.cn";
}
export class PortalBackendEndpoints { export class PortalBackendEndpoints {
public static readonly Development: string = "https://localhost:7235"; public static readonly Development: string = "https://localhost:7235";
public static readonly Mpac: string = "https://cdb-ms-mpac-pbe.cosmos.azure.com"; public static readonly Mpac: string = "https://cdb-ms-mpac-pbe.cosmos.azure.com";
@ -255,6 +263,7 @@ export class HttpHeaders {
public static activityId: string = "x-ms-activity-id"; public static activityId: string = "x-ms-activity-id";
public static apiType: string = "x-ms-cosmos-apitype"; public static apiType: string = "x-ms-cosmos-apitype";
public static authorization: string = "authorization"; public static authorization: string = "authorization";
public static entraIdToken: string = "x-ms-entraid-token";
public static collectionIndexTransformationProgress: string = public static collectionIndexTransformationProgress: string =
"x-ms-documentdb-collection-index-transformation-progress"; "x-ms-documentdb-collection-index-transformation-progress";
public static continuation: string = "x-ms-continuation"; public static continuation: string = "x-ms-continuation";

View File

@ -28,3 +28,39 @@ describe("Environment Utility Test", () => {
expect(EnvironmentUtility.getEnvironment()).toBe(EnvironmentUtility.Environment.Development); expect(EnvironmentUtility.getEnvironment()).toBe(EnvironmentUtility.Environment.Development);
}); });
}); });
describe("normalizeArmEndpoint", () => {
it("should append '/' if not present", () => {
expect(EnvironmentUtility.normalizeArmEndpoint("https://example.com")).toBe("https://example.com/");
});
it("should return the same uri if '/' is present at the end", () => {
expect(EnvironmentUtility.normalizeArmEndpoint("https://example.com/")).toBe("https://example.com/");
});
it("should handle empty string", () => {
expect(EnvironmentUtility.normalizeArmEndpoint("")).toBe("");
});
});
describe("getEnvironment", () => {
it("should return Prod environment", () => {
updateConfigContext({
PORTAL_BACKEND_ENDPOINT: PortalBackendEndpoints.Prod,
});
expect(EnvironmentUtility.getEnvironment()).toBe(EnvironmentUtility.Environment.Prod);
});
it("should return Fairfax environment", () => {
updateConfigContext({
PORTAL_BACKEND_ENDPOINT: PortalBackendEndpoints.Fairfax,
});
expect(EnvironmentUtility.getEnvironment()).toBe(EnvironmentUtility.Environment.Fairfax);
});
it("should return Mooncake environment", () => {
updateConfigContext({
PORTAL_BACKEND_ENDPOINT: PortalBackendEndpoints.Mooncake,
});
expect(EnvironmentUtility.getEnvironment()).toBe(EnvironmentUtility.Environment.Mooncake);
});
});

View File

@ -1,4 +1,5 @@
import { PortalBackendEndpoints } from "Common/Constants"; import { AadScopeEndpoints, PortalBackendEndpoints } from "Common/Constants";
import * as Logger from "Common/Logger";
import { configContext } from "ConfigContext"; import { configContext } from "ConfigContext";
export function normalizeArmEndpoint(uri: string): string { export function normalizeArmEndpoint(uri: string): string {
@ -27,3 +28,17 @@ export const getEnvironment = (): Environment => {
return environmentMap[configContext.PORTAL_BACKEND_ENDPOINT]; return environmentMap[configContext.PORTAL_BACKEND_ENDPOINT];
}; };
export const getEnvironmentScopeEndpoint = (): string => {
const environment = getEnvironment();
const endpoint = AadScopeEndpoints[environment];
if (!endpoint) {
throw new Error("Cannot determine AAD scope endpoint");
}
const hrefEndpoint = new URL(endpoint).href.replace(/\/+$/, "/.default");
Logger.logInfo(
`Using AAD scope endpoint: ${hrefEndpoint}, Environment: ${environment}`,
"EnvironmentUtility/getEnvironmentScopeEndpoint",
);
return hrefEndpoint;
};

View File

@ -7,6 +7,7 @@ import { MessageTypes } from "../Contracts/ExplorerContracts";
import { Collection } from "../Contracts/ViewModels"; import { Collection } from "../Contracts/ViewModels";
import DocumentId from "../Explorer/Tree/DocumentId"; import DocumentId from "../Explorer/Tree/DocumentId";
import { userContext } from "../UserContext"; import { userContext } from "../UserContext";
import { isDataplaneRbacEnabledForProxyApi } from "../Utils/AuthorizationUtils";
import { logConsoleError } from "../Utils/NotificationConsoleUtils"; import { logConsoleError } from "../Utils/NotificationConsoleUtils";
import { ApiType, ContentType, HttpHeaders, HttpStatusCodes } from "./Constants"; import { ApiType, ContentType, HttpHeaders, HttpStatusCodes } from "./Constants";
import { MinimalQueryIterator } from "./IteratorUtilities"; import { MinimalQueryIterator } from "./IteratorUtilities";
@ -22,7 +23,13 @@ function authHeaders() {
if (userContext.authType === AuthType.EncryptedToken) { if (userContext.authType === AuthType.EncryptedToken) {
return { [HttpHeaders.guestAccessToken]: userContext.accessToken }; return { [HttpHeaders.guestAccessToken]: userContext.accessToken };
} else { } else {
return { [HttpHeaders.authorization]: userContext.authorizationToken }; const headers: { [key: string]: string } = {
[HttpHeaders.authorization]: userContext.authorizationToken,
};
if (isDataplaneRbacEnabledForProxyApi(userContext)) {
headers[HttpHeaders.entraIdToken] = userContext.aadToken;
}
return headers;
} }
} }

View File

@ -5,6 +5,7 @@
*/ */
import { CommandBar as FluentCommandBar, ICommandBarItemProps } from "@fluentui/react"; import { CommandBar as FluentCommandBar, ICommandBarItemProps } from "@fluentui/react";
import { useNotebook } from "Explorer/Notebook/useNotebook"; import { useNotebook } from "Explorer/Notebook/useNotebook";
import { useDataPlaneRbac } from "Explorer/Panes/SettingsPane/SettingsPane";
import { KeyboardActionGroup, useKeyboardActionGroup } from "KeyboardShortcuts"; import { KeyboardActionGroup, useKeyboardActionGroup } from "KeyboardShortcuts";
import { isFabric } from "Platform/Fabric/FabricUtil"; import { isFabric } from "Platform/Fabric/FabricUtil";
import { userContext } from "UserContext"; import { userContext } from "UserContext";
@ -30,7 +31,7 @@ export interface CommandBarStore {
} }
export const useCommandBar: UseStore<CommandBarStore> = create((set) => ({ export const useCommandBar: UseStore<CommandBarStore> = create((set) => ({
contextButtons: [], contextButtons: [] as CommandButtonComponentProps[],
setContextButtons: (contextButtons: CommandButtonComponentProps[]) => set((state) => ({ ...state, contextButtons })), setContextButtons: (contextButtons: CommandButtonComponentProps[]) => set((state) => ({ ...state, contextButtons })),
isHidden: false, isHidden: false,
setIsHidden: (isHidden: boolean) => set((state) => ({ ...state, isHidden })), setIsHidden: (isHidden: boolean) => set((state) => ({ ...state, isHidden })),
@ -43,6 +44,15 @@ export const CommandBar: React.FC<Props> = ({ container }: Props) => {
const backgroundColor = StyleConstants.BaseLight; const backgroundColor = StyleConstants.BaseLight;
const setKeyboardHandlers = useKeyboardActionGroup(KeyboardActionGroup.COMMAND_BAR); const setKeyboardHandlers = useKeyboardActionGroup(KeyboardActionGroup.COMMAND_BAR);
// Subscribe to the store changes that affect button creation
const dataPlaneRbacEnabled = useDataPlaneRbac((state) => state.dataPlaneRbacEnabled);
const aadTokenUpdated = useDataPlaneRbac((state) => state.aadTokenUpdated);
// Memoize the expensive button creation
const staticButtons = React.useMemo(() => {
return CommandBarComponentButtonFactory.createStaticCommandBarButtons(container, selectedNodeState);
}, [container, selectedNodeState, dataPlaneRbacEnabled, aadTokenUpdated]);
if (userContext.apiType === "Postgres" || userContext.apiType === "VCoreMongo") { if (userContext.apiType === "Postgres" || userContext.apiType === "VCoreMongo") {
const buttons = const buttons =
userContext.apiType === "Postgres" userContext.apiType === "Postgres"
@ -62,7 +72,6 @@ export const CommandBar: React.FC<Props> = ({ container }: Props) => {
); );
} }
const staticButtons = CommandBarComponentButtonFactory.createStaticCommandBarButtons(container, selectedNodeState);
const contextButtons = (buttons || []).concat( const contextButtons = (buttons || []).concat(
CommandBarComponentButtonFactory.createContextCommandBarButtons(container, selectedNodeState), CommandBarComponentButtonFactory.createContextCommandBarButtons(container, selectedNodeState),
); );

View File

@ -1,7 +1,6 @@
import { KeyboardAction } from "KeyboardShortcuts"; import { KeyboardAction } from "KeyboardShortcuts";
import { isDataplaneRbacSupported } from "Utils/APITypeUtils"; import { isDataplaneRbacSupported } from "Utils/APITypeUtils";
import * as React from "react"; import * as React from "react";
import { useEffect, useState } from "react";
import AddSqlQueryIcon from "../../../../images/AddSqlQuery_16x16.svg"; import AddSqlQueryIcon from "../../../../images/AddSqlQuery_16x16.svg";
import AddStoredProcedureIcon from "../../../../images/AddStoredProcedure.svg"; import AddStoredProcedureIcon from "../../../../images/AddStoredProcedure.svg";
import AddTriggerIcon from "../../../../images/AddTrigger.svg"; import AddTriggerIcon from "../../../../images/AddTrigger.svg";
@ -68,15 +67,7 @@ export function createStaticCommandBarButtons(
} }
if (isDataplaneRbacSupported(userContext.apiType)) { if (isDataplaneRbacSupported(userContext.apiType)) {
const [loginButtonProps, setLoginButtonProps] = useState<CommandButtonComponentProps | undefined>(undefined); const loginButtonProps = createLoginForEntraIDButton(container);
const dataPlaneRbacEnabled = useDataPlaneRbac((state) => state.dataPlaneRbacEnabled);
const aadTokenUpdated = useDataPlaneRbac((state) => state.aadTokenUpdated);
useEffect(() => {
const buttonProps = createLoginForEntraIDButton(container);
setLoginButtonProps(buttonProps);
}, [dataPlaneRbacEnabled, aadTokenUpdated, container]);
if (loginButtonProps) { if (loginButtonProps) {
addDivider(); addDivider();
buttons.push(loginButtonProps); buttons.push(loginButtonProps);

View File

@ -13,7 +13,7 @@ import { updateDocument } from "../../Common/dataAccess/updateDocument";
import { configContext } from "../../ConfigContext"; import { configContext } from "../../ConfigContext";
import * as ViewModels from "../../Contracts/ViewModels"; import * as ViewModels from "../../Contracts/ViewModels";
import { userContext } from "../../UserContext"; import { userContext } from "../../UserContext";
import { getAuthorizationHeader } from "../../Utils/AuthorizationUtils"; import { getAuthorizationHeader, isDataplaneRbacEnabledForProxyApi } from "../../Utils/AuthorizationUtils";
import * as NotificationConsoleUtils from "../../Utils/NotificationConsoleUtils"; import * as NotificationConsoleUtils from "../../Utils/NotificationConsoleUtils";
import { logConsoleInfo, logConsoleProgress } from "../../Utils/NotificationConsoleUtils"; import { logConsoleInfo, logConsoleProgress } from "../../Utils/NotificationConsoleUtils";
import Explorer from "../Explorer"; import Explorer from "../Explorer";
@ -551,6 +551,10 @@ export class CassandraAPIDataClient extends TableDataClient {
const authorizationHeaderMetadata: ViewModels.AuthorizationTokenHeaderMetadata = getAuthorizationHeader(); const authorizationHeaderMetadata: ViewModels.AuthorizationTokenHeaderMetadata = getAuthorizationHeader();
xhr.setRequestHeader(authorizationHeaderMetadata.header, authorizationHeaderMetadata.token); xhr.setRequestHeader(authorizationHeaderMetadata.header, authorizationHeaderMetadata.token);
if (isDataplaneRbacEnabledForProxyApi(userContext)) {
xhr.setRequestHeader(Constants.HttpHeaders.entraIdToken, userContext.aadToken);
}
return true; return true;
}; };

View File

@ -24,7 +24,7 @@ export const EXIT_COMMAND_MONGO = ` printf "\\033[1;31mSession ended. Please clo
* This command runs mongosh in no-database and quiet mode, * This command runs mongosh in no-database and quiet mode,
* and evaluates the `disableTelemetry()` function to turn off telemetry collection. * and evaluates the `disableTelemetry()` function to turn off telemetry collection.
*/ */
export const DISABLE_TELEMETRY_COMMAND = `mongosh --nodb --quiet --eval "disableTelemetry()"`; export const DISABLE_TELEMETRY_COMMAND = `mongosh --nodb --quiet --eval 'disableTelemetry()'`;
/** /**
* Abstract class that defines the interface for shell-specific handlers * Abstract class that defines the interface for shell-specific handlers
@ -97,7 +97,7 @@ export abstract class AbstractShellHandler {
* is not already present in the environment. * is not already present in the environment.
*/ */
protected mongoShellSetupCommands(): string[] { protected mongoShellSetupCommands(): string[] {
const PACKAGE_VERSION: string = "2.5.5"; const PACKAGE_VERSION: string = "2.5.6";
return [ return [
"if ! command -v mongosh &> /dev/null; then echo '⚠️ mongosh not found. Installing...'; fi", "if ! command -v mongosh &> /dev/null; then echo '⚠️ mongosh not found. Installing...'; fi",
`if ! command -v mongosh &> /dev/null; then curl -LO https://downloads.mongodb.com/compass/mongosh-${PACKAGE_VERSION}-linux-x64.tgz; fi`, `if ! command -v mongosh &> /dev/null; then curl -LO https://downloads.mongodb.com/compass/mongosh-${PACKAGE_VERSION}-linux-x64.tgz; fi`,

View File

@ -18,6 +18,12 @@ interface DatabaseAccount {
interface UserContextType { interface UserContextType {
databaseAccount: DatabaseAccount; databaseAccount: DatabaseAccount;
features: {
enableAadDataPlane: boolean;
};
apiType: string;
dataPlaneRbacEnabled: boolean;
aadToken?: string;
} }
// Mock dependencies // Mock dependencies
@ -29,6 +35,8 @@ jest.mock("../../../../UserContext", () => ({
mongoEndpoint: "https://test-mongo.documents.azure.com:443/", mongoEndpoint: "https://test-mongo.documents.azure.com:443/",
}, },
}, },
features: { enableAadDataPlane: false },
apiType: "Mongo",
}, },
})); }));
@ -70,7 +78,7 @@ describe("MongoShellHandler", () => {
expect(Array.isArray(commands)).toBe(true); expect(Array.isArray(commands)).toBe(true);
expect(commands.length).toBe(7); expect(commands.length).toBe(7);
expect(commands[1]).toContain("mongosh-2.5.5-linux-x64.tgz"); expect(commands[1]).toContain("mongosh-2.5.6-linux-x64.tgz");
}); });
}); });
@ -88,11 +96,12 @@ describe("MongoShellHandler", () => {
kind: "test-kind", kind: "test-kind",
properties: { mongoEndpoint: "https://test-mongo.documents.azure.com:443/" }, properties: { mongoEndpoint: "https://test-mongo.documents.azure.com:443/" },
}; };
(userContext as UserContextType).dataPlaneRbacEnabled = false;
const command = mongoShellHandler.getConnectionCommand(); const command = mongoShellHandler.getConnectionCommand();
expect(command).toBe( expect(command).toBe(
'mongosh --nodb --quiet --eval "disableTelemetry()" && mongosh mongodb://test-mongo.documents.azure.com:10255?appName=CosmosExplorerTerminal --username test-account --password test-key --tls --tlsAllowInvalidCertificates', "mongosh --nodb --quiet --eval 'disableTelemetry()'; mongosh mongodb://test-mongo.documents.azure.com:10255?appName=CosmosExplorerTerminal --username test-account --password test-key --tls --tlsAllowInvalidCertificates",
); );
expect(CommonUtils.getHostFromUrl).toHaveBeenCalledWith("https://test-mongo.documents.azure.com:443/"); expect(CommonUtils.getHostFromUrl).toHaveBeenCalledWith("https://test-mongo.documents.azure.com:443/");
@ -115,12 +124,47 @@ describe("MongoShellHandler", () => {
}; };
const command = mongoShellHandler.getConnectionCommand(); const command = mongoShellHandler.getConnectionCommand();
expect(command).toBe("echo 'Database name not found.'"); expect(command).toBe("echo 'Database name not found.'");
// Restore original // Restore original
(userContext as UserContextType).databaseAccount = originalDatabaseAccount; (userContext as UserContextType).databaseAccount = originalDatabaseAccount;
}); });
it("should return echo if endpoint is missing", () => {
const testKey = "test-key";
(userContext as UserContextType).databaseAccount = {
id: "test-id",
name: "", // Empty name to simulate missing name
location: "test-location",
type: "test-type",
kind: "test-kind",
properties: { mongoEndpoint: "" },
};
const mongoShellHandler = new MongoShellHandler(testKey);
const command = mongoShellHandler.getConnectionCommand();
expect(command).toBe("echo 'MongoDB endpoint not found.'");
});
it("should use _getAadConnectionCommand when _isEntraIdEnabled is true", () => {
const testKey = "aad-key";
(userContext as UserContextType).databaseAccount = {
id: "test-id",
name: "test-account",
location: "test-location",
type: "test-type",
kind: "test-kind",
properties: { mongoEndpoint: "https://test-mongo.documents.azure.com:443/" },
};
(userContext as UserContextType).dataPlaneRbacEnabled = true;
const mongoShellHandler = new MongoShellHandler(testKey);
const command = mongoShellHandler.getConnectionCommand();
expect(command).toContain(
"mongosh 'mongodb://test-account:aad-key@test-account.mongo.cosmos.azure.com:10255/?ssl=true&replicaSet=globaldb&authMechanism=PLAIN&retryWrites=false' --tls --tlsAllowInvalidCertificates",
);
expect(command.startsWith("mongosh --nodb")).toBeTruthy();
});
}); });
describe("getTerminalSuppressedData", () => { describe("getTerminalSuppressedData", () => {

View File

@ -1,4 +1,5 @@
import { userContext } from "../../../../UserContext"; import { userContext } from "../../../../UserContext";
import { isDataplaneRbacEnabledForProxyApi } from "../../../../Utils/AuthorizationUtils";
import { filterAndCleanTerminalOutput, getHostFromUrl, getMongoShellRemoveInfoText } from "../Utils/CommonUtils"; import { filterAndCleanTerminalOutput, getHostFromUrl, getMongoShellRemoveInfoText } from "../Utils/CommonUtils";
import { AbstractShellHandler, DISABLE_TELEMETRY_COMMAND, EXIT_COMMAND_MONGO } from "./AbstractShellHandler"; import { AbstractShellHandler, DISABLE_TELEMETRY_COMMAND, EXIT_COMMAND_MONGO } from "./AbstractShellHandler";
@ -6,12 +7,23 @@ export class MongoShellHandler extends AbstractShellHandler {
private _key: string; private _key: string;
private _endpoint: string | undefined; private _endpoint: string | undefined;
private _removeInfoText: string[] = getMongoShellRemoveInfoText(); private _removeInfoText: string[] = getMongoShellRemoveInfoText();
private _isEntraIdEnabled: boolean = isDataplaneRbacEnabledForProxyApi(userContext);
constructor(private key: string) { constructor(private key: string) {
super(); super();
this._key = key; this._key = key;
this._endpoint = userContext?.databaseAccount?.properties?.mongoEndpoint; this._endpoint = userContext?.databaseAccount?.properties?.mongoEndpoint;
} }
private _getKeyConnectionCommand(dbName: string): string {
return `mongosh mongodb://${getHostFromUrl(this._endpoint)}:10255?appName=${
this.APP_NAME
} --username ${dbName} --password ${this._key} --tls --tlsAllowInvalidCertificates`;
}
private _getAadConnectionCommand(dbName: string): string {
return `mongosh 'mongodb://${dbName}:${this._key}@${dbName}.mongo.cosmos.azure.com:10255/?ssl=true&replicaSet=globaldb&authMechanism=PLAIN&retryWrites=false' --tls --tlsAllowInvalidCertificates`;
}
public getShellName(): string { public getShellName(): string {
return "MongoDB"; return "MongoDB";
} }
@ -29,19 +41,11 @@ export class MongoShellHandler extends AbstractShellHandler {
if (!dbName) { if (!dbName) {
return "echo 'Database name not found.'"; return "echo 'Database name not found.'";
} }
return ( const connectionCommand = this._isEntraIdEnabled
DISABLE_TELEMETRY_COMMAND + ? this._getAadConnectionCommand(dbName)
" && " + : this._getKeyConnectionCommand(dbName);
"mongosh mongodb://" + const fullCommand = `${DISABLE_TELEMETRY_COMMAND}; ${connectionCommand}`;
getHostFromUrl(this._endpoint) + return fullCommand;
":10255?appName=" +
this.APP_NAME +
" --username " +
dbName +
" --password " +
this._key +
" --tls --tlsAllowInvalidCertificates"
);
} }
public getTerminalSuppressedData(): string[] { public getTerminalSuppressedData(): string[] {

View File

@ -7,12 +7,24 @@ import { PostgresShellHandler } from "./PostgresShellHandler";
import { getHandler, getKey } from "./ShellTypeFactory"; import { getHandler, getKey } from "./ShellTypeFactory";
import { VCoreMongoShellHandler } from "./VCoreMongoShellHandler"; import { VCoreMongoShellHandler } from "./VCoreMongoShellHandler";
interface UserContextType {
databaseAccount: { name: string };
subscriptionId: string;
resourceGroup: string;
features: { enableAadDataPlane: boolean };
dataPlaneRbacEnabled: boolean;
aadToken?: string;
apiType?: string;
}
// Mock dependencies // Mock dependencies
jest.mock("../../../../UserContext", () => ({ jest.mock("../../../../UserContext", () => ({
userContext: { userContext: {
databaseAccount: { name: "testDbName" }, databaseAccount: { name: "testDbName" },
subscriptionId: "testSubId", subscriptionId: "testSubId",
resourceGroup: "testResourceGroup", resourceGroup: "testResourceGroup",
features: { enableAadDataPlane: false },
dataPlaneRbacEnabled: false,
}, },
})); }));
@ -109,5 +121,33 @@ describe("ShellTypeHandlerFactory", () => {
expect(key).toBe(mockKey); expect(key).toBe(mockKey);
expect(listKeys).toHaveBeenCalledWith("testSubId", "testResourceGroup", "testDbName"); expect(listKeys).toHaveBeenCalledWith("testSubId", "testResourceGroup", "testDbName");
}); });
it("should return MongoShellHandler with primaryMasterKey for TerminalKind.Mongo when RBAC is disabled", async () => {
(listKeys as jest.Mock).mockResolvedValue({ primaryMasterKey: "primaryKey123" });
(userContext as UserContextType).features.enableAadDataPlane = false;
(userContext as UserContextType).dataPlaneRbacEnabled = false;
const handler = await getHandler(TerminalKind.Mongo);
expect(handler).toBeInstanceOf(MongoShellHandler);
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore
expect(handler.key).toBe("primaryKey123");
});
it("should return MongoShellHandler with aadToken for TerminalKind.Mongo when RBAC is enabled", async () => {
(userContext as UserContextType).aadToken = "aadToken123";
(userContext as UserContextType).features.enableAadDataPlane = true;
(userContext as UserContextType).dataPlaneRbacEnabled = true;
(userContext as UserContextType).apiType = "Mongo";
const handler = await getHandler(TerminalKind.Mongo);
expect(handler).toBeInstanceOf(MongoShellHandler);
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore
expect(handler.key).toBe("aadToken123");
});
it("should throw error for unsupported shell type", async () => {
await expect(getHandler("UnknownShell" as unknown as TerminalKind)).rejects.toThrow(
"Unsupported shell type: UnknownShell",
);
});
}); });
}); });

View File

@ -1,6 +1,7 @@
import { TerminalKind } from "../../../../Contracts/ViewModels"; import { TerminalKind } from "../../../../Contracts/ViewModels";
import { userContext } from "../../../../UserContext"; import { userContext } from "../../../../UserContext";
import { listKeys } from "../../../../Utils/arm/generatedClients/cosmos/databaseAccounts"; import { listKeys } from "../../../../Utils/arm/generatedClients/cosmos/databaseAccounts";
import { isDataplaneRbacEnabledForProxyApi } from "../../../../Utils/AuthorizationUtils";
import { AbstractShellHandler } from "./AbstractShellHandler"; import { AbstractShellHandler } from "./AbstractShellHandler";
import { CassandraShellHandler } from "./CassandraShellHandler"; import { CassandraShellHandler } from "./CassandraShellHandler";
import { MongoShellHandler } from "./MongoShellHandler"; import { MongoShellHandler } from "./MongoShellHandler";
@ -30,6 +31,9 @@ export async function getKey(): Promise<string> {
if (!dbName) { if (!dbName) {
return ""; return "";
} }
if (isDataplaneRbacEnabledForProxyApi(userContext)) {
return userContext.aadToken || "";
}
const keys = await listKeys(userContext.subscriptionId, userContext.resourceGroup, dbName); const keys = await listKeys(userContext.subscriptionId, userContext.resourceGroup, dbName);
return keys?.primaryMasterKey || ""; return keys?.primaryMasterKey || "";

View File

@ -45,7 +45,7 @@ describe("VCoreMongoShellHandler", () => {
expect(Array.isArray(commands)).toBe(true); expect(Array.isArray(commands)).toBe(true);
expect(commands.length).toBe(7); expect(commands.length).toBe(7);
expect(commands[1]).toContain("mongosh-2.5.5-linux-x64.tgz"); expect(commands[1]).toContain("mongosh-2.5.6-linux-x64.tgz");
expect(commands[0]).toContain("mongosh not found"); expect(commands[0]).toContain("mongosh not found");
}); });

View File

@ -92,6 +92,18 @@ export class AttachAddon implements ITerminalAddon {
* @param {Terminal} terminal - The XTerm terminal instance * @param {Terminal} terminal - The XTerm terminal instance
*/ */
public addMessageListener(terminal: Terminal): void { public addMessageListener(terminal: Terminal): void {
let messageBuffer = "";
let bufferTimeout: NodeJS.Timeout | null = null;
const BUFFER_TIMEOUT = 50; // ms - short timeout for prompt detection
const processBuffer = () => {
if (messageBuffer.length > 0) {
this.handleCompleteTerminalData(terminal, messageBuffer);
messageBuffer = "";
}
bufferTimeout = null;
};
this._disposables.push( this._disposables.push(
addSocketListener(this._socket, "message", (ev) => { addSocketListener(this._socket, "message", (ev) => {
let data: ArrayBuffer | string = ev.data; let data: ArrayBuffer | string = ev.data;
@ -103,57 +115,136 @@ export class AttachAddon implements ITerminalAddon {
data = enc.decode(ev.data as ArrayBuffer); data = enc.decode(ev.data as ArrayBuffer);
} }
// for example of json object look in TerminalHelper in the socket.onMessage // Handle status messages
if (data.includes(startStatusJson) && data.includes(endStatusJson)) { let processedStatusData = data;
// process as one line
const statusData = data.split(startStatusJson)[1].split(endStatusJson)[0];
data = data.replace(statusData, "");
data = data.replace(startStatusJson, "");
data = data.replace(endStatusJson, "");
} else if (data.includes(startStatusJson)) {
// check for start
const partialStatusData = data.split(startStatusJson)[1];
this._socketData += partialStatusData;
data = data.replace(partialStatusData, "");
data = data.replace(startStatusJson, "");
} else if (data.includes(endStatusJson)) {
// check for end and process the command
const partialStatusData = data.split(endStatusJson)[0];
this._socketData += partialStatusData;
data = data.replace(partialStatusData, "");
data = data.replace(endStatusJson, "");
this._socketData = "";
} else if (this._socketData.length > 0) {
// check if the line is all data then just concatenate
this._socketData += data;
data = "";
}
if (this._allowTerminalWrite && data.includes(this._startMarker)) { // Process status messages with delimiters
this._allowTerminalWrite = false; // eslint-disable-next-line no-constant-condition
terminal.write(`Preparing ${this._shellHandler.getShellName()} environment...\r\n`); while (true) {
} const startIndex = processedStatusData.indexOf(startStatusJson);
if (startIndex === -1) {
if (this._allowTerminalWrite) { break;
const updatedData =
typeof this._shellHandler?.updateTerminalData === "function"
? this._shellHandler.updateTerminalData(data)
: data;
const suppressedData = this._shellHandler?.getTerminalSuppressedData();
const shouldNotWrite = suppressedData.filter(Boolean).some((item) => updatedData.includes(item));
if (!shouldNotWrite) {
terminal.write(updatedData);
} }
const afterStart = processedStatusData.substring(startIndex + startStatusJson.length);
const endIndex = afterStart.indexOf(endStatusJson);
if (endIndex === -1) {
// Incomplete status message
this._socketData += processedStatusData.substring(startIndex);
processedStatusData = processedStatusData.substring(0, startIndex);
break;
}
// Remove processed status message
processedStatusData =
processedStatusData.substring(0, startIndex) + afterStart.substring(endIndex + endStatusJson.length);
} }
if (data.includes(this._shellHandler.getConnectionCommand())) { // Add to message buffer
this._allowTerminalWrite = true; messageBuffer += processedStatusData;
// Clear existing timeout
if (bufferTimeout) {
clearTimeout(bufferTimeout);
bufferTimeout = null;
}
// Check if this looks like a complete message/command
const isComplete = this.isMessageComplete(messageBuffer, processedStatusData);
if (isComplete) {
// Message marked as complete, processing immediately
processBuffer();
} else {
// Set timeout to process buffer after delay
bufferTimeout = setTimeout(processBuffer, BUFFER_TIMEOUT);
} }
}), }),
); );
// Clean up timeout on dispose
this._disposables.push({
dispose: () => {
if (bufferTimeout) {
clearTimeout(bufferTimeout);
}
},
});
}
private isMessageComplete(fullBuffer: string, currentChunk: string): boolean {
// Immediate completion indicators
const immediateCompletionPatterns = [
/\n$/, // Ends with newline
/\r$/, // Ends with carriage return
/\r\n$/, // Ends with CRLF
/; \} \|\| true;$/, // Your command pattern
/disown -a && exit$/, // Exit commands
/printf.*?\\033\[0m\\n"$/, // Your printf pattern
];
// Check current chunk for immediate completion
for (const pattern of immediateCompletionPatterns) {
if (pattern.test(currentChunk)) {
return true;
}
}
// ANSI sequence detection - these might be complete prompts
const ansiPromptPatterns = [
/\[\d+G\[0J.*>\s*\[\d+G$/, // Your specific pattern: [1G[0J...> [26G
/\[\d+;\d+H/, // Cursor position sequences
/\]\s*\[\d+G$/, // Ends with cursor positioning
/>\s*\[\d+G$/, // Prompt followed by cursor position
];
// Check if buffer ends with what looks like a complete prompt
for (const pattern of ansiPromptPatterns) {
if (pattern.test(fullBuffer)) {
return true;
}
}
// Check for MongoDB shell prompts specifically
const mongoPromptPatterns = [
/globaldb \[primary\] \w+>\s*\[\d+G$/, // MongoDB replica set prompt
/>\s*\[\d+G$/, // General prompt with cursor positioning
/\w+>\s*$/, // Simple shell prompt
];
for (const pattern of mongoPromptPatterns) {
if (pattern.test(fullBuffer)) {
return true;
}
}
return false;
}
private handleCompleteTerminalData(terminal: Terminal, data: string): void {
if (this._allowTerminalWrite && data.includes(this._startMarker)) {
this._allowTerminalWrite = false;
terminal.write(`Preparing ${this._shellHandler.getShellName()} environment...\r\n`);
}
if (this._allowTerminalWrite) {
const updatedData =
typeof this._shellHandler?.updateTerminalData === "function"
? this._shellHandler.updateTerminalData(data)
: data;
const suppressedData = this._shellHandler?.getTerminalSuppressedData();
const shouldNotWrite = suppressedData.filter(Boolean).some((item) => updatedData.includes(item));
if (!shouldNotWrite) {
terminal.write(updatedData);
}
}
if (data.includes(this._shellHandler.getConnectionCommand())) {
this._allowTerminalWrite = true;
}
} }
public dispose(): void { public dispose(): void {

View File

@ -146,10 +146,16 @@ describe("Documents tab (Mongo API)", () => {
updateConfigContext({ platform: Platform.Hosted }); updateConfigContext({ platform: Platform.Hosted });
const props: IDocumentsTabComponentProps = createMockProps(); const props: IDocumentsTabComponentProps = createMockProps();
wrapper = mount(<DocumentsTabComponent {...props} />); wrapper = mount(<DocumentsTabComponent {...props} />);
wrapper = await waitForComponentToPaint(wrapper);
}); // Wait for all pending promises
await act(async () => {
await new Promise((resolve) => setTimeout(resolve, 100));
});
// Wait for any async operations to complete
wrapper = await waitForComponentToPaint(wrapper, 100);
}, 10000);
afterEach(() => { afterEach(() => {
wrapper.unmount(); wrapper.unmount();

View File

@ -91,5 +91,11 @@ export const getItemName = (): string => {
}; };
export const isDataplaneRbacSupported = (apiType: string): boolean => { export const isDataplaneRbacSupported = (apiType: string): boolean => {
return apiType === "SQL" || apiType === "Tables" || apiType === "Gremlin"; return (
apiType === "SQL" || apiType === "Tables" || apiType === "Gremlin" || apiType === "Mongo" || apiType === "Cassandra"
);
};
export const hasProxyServer = (apiType: string): boolean => {
return apiType === "Mongo" || apiType === "Cassandra";
}; };

View File

@ -104,7 +104,7 @@ describe("AuthorizationUtils", () => {
it("should return true if dataPlaneRbacEnabled is set to true and API supports RBAC", () => { it("should return true if dataPlaneRbacEnabled is set to true and API supports RBAC", () => {
setAadDataPlane(false); setAadDataPlane(false);
["SQL", "Tables", "Gremlin"].forEach((type) => { ["SQL", "Tables", "Gremlin", "Mongo", "Cassandra"].forEach((type) => {
updateUserContext({ updateUserContext({
dataPlaneRbacEnabled: true, dataPlaneRbacEnabled: true,
apiType: type as ApiType, apiType: type as ApiType,
@ -115,7 +115,7 @@ describe("AuthorizationUtils", () => {
it("should return false if dataPlaneRbacEnabled is set to true and API does not support RBAC", () => { it("should return false if dataPlaneRbacEnabled is set to true and API does not support RBAC", () => {
setAadDataPlane(false); setAadDataPlane(false);
["Mongo", "Cassandra", "Postgres", "VCoreMongo"].forEach((type) => { ["Postgres", "VCoreMongo"].forEach((type) => {
updateUserContext({ updateUserContext({
dataPlaneRbacEnabled: true, dataPlaneRbacEnabled: true,
apiType: type as ApiType, apiType: type as ApiType,

View File

@ -1,6 +1,7 @@
import * as msal from "@azure/msal-browser"; import * as msal from "@azure/msal-browser";
import { getEnvironmentScopeEndpoint } from "Common/EnvironmentUtility";
import { Action, ActionModifiers } from "Shared/Telemetry/TelemetryConstants"; import { Action, ActionModifiers } from "Shared/Telemetry/TelemetryConstants";
import { isDataplaneRbacSupported } from "Utils/APITypeUtils"; import { hasProxyServer, isDataplaneRbacSupported } from "Utils/APITypeUtils";
import { AuthType } from "../AuthType"; import { AuthType } from "../AuthType";
import * as Constants from "../Common/Constants"; import * as Constants from "../Common/Constants";
import * as Logger from "../Common/Logger"; import * as Logger from "../Common/Logger";
@ -74,10 +75,12 @@ export async function acquireMsalTokenForAccount(
if (userContext.databaseAccount.properties?.documentEndpoint === undefined) { if (userContext.databaseAccount.properties?.documentEndpoint === undefined) {
throw new Error("Database account has no document endpoint defined"); throw new Error("Database account has no document endpoint defined");
} }
const hrefEndpoint = new URL(userContext.databaseAccount.properties.documentEndpoint).href.replace( let hrefEndpoint = "";
/\/+$/, if (isDataplaneRbacEnabledForProxyApi(userContext)) {
"/.default", hrefEndpoint = getEnvironmentScopeEndpoint();
); } else {
hrefEndpoint = new URL(userContext.databaseAccount.properties.documentEndpoint).href.replace(/\/+$/, "/.default");
}
const msalInstance = await getMsalInstance(); const msalInstance = await getMsalInstance();
const knownAccounts = msalInstance.getAllAccounts(); const knownAccounts = msalInstance.getAllAccounts();
// If user_hint is provided, we will try to use it to find the account. // If user_hint is provided, we will try to use it to find the account.
@ -183,7 +186,11 @@ export async function acquireTokenWithMsal(
export function useDataplaneRbacAuthorization(userContext: UserContext): boolean { export function useDataplaneRbacAuthorization(userContext: UserContext): boolean {
return ( return (
userContext.features.enableAadDataPlane || userContext.features?.enableAadDataPlane ||
(userContext.dataPlaneRbacEnabled && isDataplaneRbacSupported(userContext.apiType)) (userContext.dataPlaneRbacEnabled && isDataplaneRbacSupported(userContext.apiType))
); );
} }
export function isDataplaneRbacEnabledForProxyApi(userContext: UserContext): boolean {
return useDataplaneRbacAuthorization(userContext) && hasProxyServer(userContext.apiType);
}

View File

@ -1,4 +1,5 @@
import * as Constants from "Common/Constants"; import * as Constants from "Common/Constants";
import { getEnvironmentScopeEndpoint } from "Common/EnvironmentUtility";
import { createUri } from "Common/UrlUtility"; import { createUri } from "Common/UrlUtility";
import { DATA_EXPLORER_RPC_VERSION } from "Contracts/DataExplorerMessagesContract"; import { DATA_EXPLORER_RPC_VERSION } from "Contracts/DataExplorerMessagesContract";
import { FabricMessageTypes } from "Contracts/FabricMessageTypes"; import { FabricMessageTypes } from "Contracts/FabricMessageTypes";
@ -62,6 +63,7 @@ import {
acquireTokenWithMsal, acquireTokenWithMsal,
getAuthorizationHeader, getAuthorizationHeader,
getMsalInstance, getMsalInstance,
isDataplaneRbacEnabledForProxyApi,
} from "../Utils/AuthorizationUtils"; } from "../Utils/AuthorizationUtils";
import { isInvalidParentFrameOrigin, shouldProcessMessage } from "../Utils/MessageValidation"; import { isInvalidParentFrameOrigin, shouldProcessMessage } from "../Utils/MessageValidation";
import { get, getReadOnlyKeys, listKeys } from "../Utils/arm/generatedClients/cosmos/databaseAccounts"; import { get, getReadOnlyKeys, listKeys } from "../Utils/arm/generatedClients/cosmos/databaseAccounts";
@ -331,7 +333,12 @@ async function configureHostedWithAAD(config: AAD): Promise<Explorer> {
const resourceGroup = accountResourceId && accountResourceId.split("resourceGroups/")[1].split("/")[0]; const resourceGroup = accountResourceId && accountResourceId.split("resourceGroups/")[1].split("/")[0];
let aadToken; let aadToken;
if (account.properties?.documentEndpoint) { if (account.properties?.documentEndpoint) {
const hrefEndpoint = new URL(account.properties.documentEndpoint).href.replace(/\/$/, "/.default"); let hrefEndpoint = "";
if (isDataplaneRbacEnabledForProxyApi(userContext)) {
hrefEndpoint = getEnvironmentScopeEndpoint();
} else {
hrefEndpoint = new URL(account.properties.documentEndpoint).href.replace(/\/$/, "/.default");
}
const msalInstance = await getMsalInstance(); const msalInstance = await getMsalInstance();
const cachedAccount = msalInstance.getAllAccounts()?.[0]; const cachedAccount = msalInstance.getAllAccounts()?.[0];
msalInstance.setActiveAccount(cachedAccount); msalInstance.setActiveAccount(cachedAccount);