From a771fe21d6097a8e94cab02a452c7d96250c26de Mon Sep 17 00:00:00 2001 From: Yulong Ruan Date: Thu, 5 Sep 2024 14:18:54 +0800 Subject: [PATCH] Expose a general function for agent execution (#268) * feat: add assistant client to public and server Signed-off-by: Yulong Ruan * tweaks Signed-off-by: Yulong Ruan * fix path Signed-off-by: Yulong Ruan * update changelog Signed-off-by: Yulong Ruan * feat: support both execute agent by name and by id Signed-off-by: Yulong Ruan * fix(public): support execute agent by name or id Signed-off-by: Yulong Ruan --------- Signed-off-by: Yulong Ruan Signed-off-by: gaobinlong Co-authored-by: gaobinlong --- CHANGELOG.md | 5 ++ common/constants/llm.ts | 4 ++ public/plugin.tsx | 7 +++ public/services/assistant_client.ts | 35 ++++++++++++ public/services/assistant_service.ts | 26 +++++++++ public/types.ts | 2 + server/plugin.ts | 15 ++++- server/routes/agent_routes.ts | 43 ++++++++++++++ server/routes/text2viz_routes.ts | 53 ++++++----------- server/services/assistant_client.ts | 85 ++++++++++++++++++++++++++++ server/services/assistant_service.ts | 30 ++++++++++ server/types.ts | 2 + 12 files changed, 268 insertions(+), 39 deletions(-) create mode 100644 public/services/assistant_client.ts create mode 100644 public/services/assistant_service.ts create mode 100644 server/routes/agent_routes.ts create mode 100644 server/services/assistant_client.ts create mode 100644 server/services/assistant_service.ts diff --git a/CHANGELOG.md b/CHANGELOG.md index 9d035c63..599dffdd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,11 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) +### Unreleased +- fix: make sure $schema always added to LLM generated vega json object([252](https://github.com/opensearch-project/dashboards-assistant/pull/252)) +- feat: expose a general function for agent execution([268](https://github.com/opensearch-project/dashboards-assistant/pull/268)) +- Fix CVE-2024-4067 ([#269](https://github.com/opensearch-project/dashboards-assistant/pull/269)) + ### 📈 Features/Enhancements - Add support for registerMessageParser ([#5](https://github.com/opensearch-project/dashboards-assistant/pull/5)) diff --git a/common/constants/llm.ts b/common/constants/llm.ts index 63ac5a59..ff0673f0 100644 --- a/common/constants/llm.ts +++ b/common/constants/llm.ts @@ -24,6 +24,10 @@ export const TEXT2VIZ_API = { TEXT2VEGA: `${API_BASE}/text2vega`, }; +export const AGENT_API = { + EXECUTE: `${API_BASE}/agent/_execute`, +}; + export const NOTEBOOK_API = { CREATE_NOTEBOOK: `${NOTEBOOK_PREFIX}/note`, SET_PARAGRAPH: `${NOTEBOOK_PREFIX}/set_paragraphs/`, diff --git a/public/plugin.tsx b/public/plugin.tsx index e1309fe0..f3f85667 100644 --- a/public/plugin.tsx +++ b/public/plugin.tsx @@ -42,6 +42,8 @@ import { import { ConfigSchema } from '../common/types/config'; import { DataSourceService } from './services/data_source_service'; import { ASSISTANT_API, DEFAULT_USER_NAME } from '../common/constants/llm'; +import { IncontextInsightProps } from './components/incontext_insight'; +import { AssistantService } from './services/assistant_service'; export const [getCoreStart, setCoreStart] = createGetterSetter('CoreStart'); @@ -71,6 +73,7 @@ export class AssistantPlugin incontextInsightRegistry: IncontextInsightRegistry | undefined; private dataSourceService: DataSourceService; private resetChatSubscription: Subscription | undefined; + private assistantService = new AssistantService(); constructor(initializerContext: PluginInitializerContext) { this.config = initializerContext.config.get(); @@ -81,6 +84,7 @@ export class AssistantPlugin core: CoreSetup, setupDeps: AssistantPluginSetupDependencies ): AssistantSetup { + this.assistantService.setup(); this.incontextInsightRegistry = new IncontextInsightRegistry(); setIncontextInsightRegistry(this.incontextInsightRegistry); const messageRenderers: Record = {}; @@ -211,17 +215,20 @@ export class AssistantPlugin } public start(core: CoreStart): AssistantStart { + const assistantServiceStart = this.assistantService.start(core.http); setCoreStart(core); setChrome(core.chrome); setNotifications(core.notifications); return { dataSource: this.dataSourceService.start(), + assistantClient: assistantServiceStart.client, }; } public stop() { this.dataSourceService.stop(); + this.assistantService.stop(); this.resetChatSubscription?.unsubscribe(); } } diff --git a/public/services/assistant_client.ts b/public/services/assistant_client.ts new file mode 100644 index 00000000..fd990137 --- /dev/null +++ b/public/services/assistant_client.ts @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { API_BASE } from '../../common/constants/llm'; +import { HttpSetup } from '../../../../src/core/public'; + +interface Options { + dataSourceId?: string; +} + +export class AssistantClient { + constructor(private http: HttpSetup) {} + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + executeAgent = (agentId: string, parameters: Record, options?: Options) => { + return this.http.fetch({ + method: 'POST', + path: `${API_BASE}/agent/_execute`, + body: JSON.stringify(parameters), + query: { dataSourceId: options?.dataSourceId, agentId }, + }); + }; + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + executeAgentByName = (agentName: string, parameters: Record, options?: Options) => { + return this.http.fetch({ + method: 'POST', + path: `${API_BASE}/agent/_execute`, + body: JSON.stringify(parameters), + query: { dataSourceId: options?.dataSourceId, agentName }, + }); + }; +} diff --git a/public/services/assistant_service.ts b/public/services/assistant_service.ts new file mode 100644 index 00000000..117b67d9 --- /dev/null +++ b/public/services/assistant_service.ts @@ -0,0 +1,26 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { HttpSetup } from '../../../../src/core/public'; +import { AssistantClient } from './assistant_client'; + +export interface AssistantServiceStart { + client: AssistantClient; +} + +export class AssistantService { + constructor() {} + + setup() {} + + start(http: HttpSetup): AssistantServiceStart { + const assistantClient = new AssistantClient(http); + return { + client: assistantClient, + }; + } + + stop() {} +} diff --git a/public/types.ts b/public/types.ts index aa95eef6..e0bf2a9b 100644 --- a/public/types.ts +++ b/public/types.ts @@ -16,6 +16,7 @@ import { } from '../../../src/plugins/visualizations/public'; import { DataPublicPluginSetup, DataPublicPluginStart } from '../../../src/plugins/data/public'; import { AppMountParameters, CoreStart } from '../../../src/core/public'; +import { AssistantClient } from './services/assistant_client'; export interface RenderProps { props: MessageContentProps; @@ -67,6 +68,7 @@ export interface AssistantSetup { export interface AssistantStart { dataSource: DataSourceServiceContract; + assistantClient: AssistantClient; } export type StartServices = CoreStart & diff --git a/server/plugin.ts b/server/plugin.ts index 48777819..0499bc5c 100644 --- a/server/plugin.ts +++ b/server/plugin.ts @@ -17,10 +17,13 @@ import { BasicInputOutputParser } from './parsers/basic_input_output_parser'; import { VisualizationCardParser } from './parsers/visualization_card_parser'; import { registerChatRoutes } from './routes/chat_routes'; import { registerText2VizRoutes } from './routes/text2viz_routes'; +import { AssistantService } from './services/assistant_service'; +import { registerAgentRoutes } from './routes/agent_routes'; export class AssistantPlugin implements Plugin { private readonly logger: Logger; private messageParsers: MessageParser[] = []; + private assistantService = new AssistantService(); constructor(private readonly initializerContext: PluginInitializerContext) { this.logger = initializerContext.logger.get(); @@ -33,6 +36,8 @@ export class AssistantPlugin implements Plugin { @@ -42,6 +47,8 @@ export class AssistantPlugin implements Plugin ({ @@ -72,6 +79,7 @@ export class AssistantPlugin implements Plugin { const findIndex = this.messageParsers.findIndex((item) => item.id === parserId); @@ -86,8 +94,11 @@ export class AssistantPlugin implements Plugin { + try { + const assistantClient = assistantService.getScopedClient(req, context); + if ('agentId' in req.query) { + const response = await assistantClient.executeAgent(req.query.agentId, req.body); + return res.ok({ body: response }); + } + const response = await assistantClient.executeAgentByName(req.query.agentName, req.body); + return res.ok({ body: response }); + } catch (e) { + return res.internalError(); + } + }) + ); +} diff --git a/server/routes/text2viz_routes.ts b/server/routes/text2viz_routes.ts index bf7d58ef..752240df 100644 --- a/server/routes/text2viz_routes.ts +++ b/server/routes/text2viz_routes.ts @@ -6,14 +6,12 @@ import { schema } from '@osd/config-schema'; import { IRouter } from '../../../../src/core/server'; import { TEXT2VIZ_API } from '../../common/constants/llm'; -import { getOpenSearchClientTransport } from '../utils/get_opensearch_client_transport'; -import { ML_COMMONS_BASE_API } from '../utils/constants'; -import { getAgent } from './get_agent'; +import { AssistantServiceSetup } from '../services/assistant_service'; const TEXT2VEGA_AGENT_CONFIG_ID = 'text2vega'; const TEXT2PPL_AGENT_CONFIG_ID = 'text2ppl'; -export function registerText2VizRoutes(router: IRouter) { +export function registerText2VizRoutes(router: IRouter, assistantService: AssistantServiceSetup) { router.post( { path: TEXT2VIZ_API.TEXT2VEGA, @@ -30,25 +28,15 @@ export function registerText2VizRoutes(router: IRouter) { }, }, router.handleLegacyErrors(async (context, req, res) => { - const client = await getOpenSearchClientTransport({ - context, - dataSourceId: req.query.dataSourceId, - }); - const agentId = await getAgent(TEXT2VEGA_AGENT_CONFIG_ID, client); - const response = await client.request({ - method: 'POST', - path: `${ML_COMMONS_BASE_API}/agents/${agentId}/_execute`, - body: { - parameters: { - input: req.body.input, - ppl: req.body.ppl, - dataSchema: req.body.dataSchema, - sampleData: req.body.sampleData, - }, - }, - }); - + const assistantClient = assistantService.getScopedClient(req, context); try { + const response = await assistantClient.executeAgentByName(TEXT2VEGA_AGENT_CONFIG_ID, { + input: req.body.input, + ppl: req.body.ppl, + dataSchema: req.body.dataSchema, + sampleData: req.body.sampleData, + }); + // let result = response.body.inference_results[0].output[0].dataAsMap; let result = JSON.parse(response.body.inference_results[0].output[0].result); // sometimes llm returns {response: } instead of @@ -80,22 +68,13 @@ export function registerText2VizRoutes(router: IRouter) { }, }, router.handleLegacyErrors(async (context, req, res) => { - const client = await getOpenSearchClientTransport({ - context, - dataSourceId: req.query.dataSourceId, - }); - const agentId = await getAgent(TEXT2PPL_AGENT_CONFIG_ID, client); - const response = await client.request({ - method: 'POST', - path: `${ML_COMMONS_BASE_API}/agents/${agentId}/_execute`, - body: { - parameters: { - question: req.body.question, - index: req.body.index, - }, - }, - }); + const assistantClient = assistantService.getScopedClient(req, context); try { + const response = await assistantClient.executeAgentByName(TEXT2PPL_AGENT_CONFIG_ID, { + question: req.body.question, + index: req.body.index, + }); + const result = JSON.parse(response.body.inference_results[0].output[0].result); return res.ok({ body: result }); } catch (e) { diff --git a/server/services/assistant_client.ts b/server/services/assistant_client.ts new file mode 100644 index 00000000..30593db8 --- /dev/null +++ b/server/services/assistant_client.ts @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { ApiResponse } from '@opensearch-project/opensearch'; + +import { + OpenSearchClient, + OpenSearchDashboardsRequest, + RequestHandlerContext, +} from '../../../../src/core/server'; +import { ML_COMMONS_BASE_API } from '../utils/constants'; +import { getAgent } from '../routes/get_agent'; + +interface AgentExecuteResponse { + inference_results: Array<{ + output: Array<{ result: string }>; + }>; +} + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +const isDataSourceQuery = (query: any): query is { dataSourceId: string } => { + if ('dataSourceId' in query && query.dataSourceId) { + return true; + } + return false; +}; + +export class AssistantClient { + private client?: OpenSearchClient; + + constructor( + private request: OpenSearchDashboardsRequest, + private context: RequestHandlerContext & { + dataSource?: { + opensearch: { + getClient: (dataSourceId: string) => Promise; + }; + }; + } + ) {} + + executeAgent = async ( + agentId: string, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + parameters: Record + ): Promise> => { + const client = await this.getOpenSearchClient(); + + const response = await client.transport.request({ + method: 'POST', + path: `${ML_COMMONS_BASE_API}/agents/${agentId}/_execute`, + body: { + parameters, + }, + }); + + return response as ApiResponse; + }; + + executeAgentByName = async ( + agentName: string, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + parameters: Record + ) => { + const client = await this.getOpenSearchClient(); + const agentId = await getAgent(agentName, client.transport); + return this.executeAgent(agentId, parameters); + }; + + private async getOpenSearchClient() { + if (!this.client) { + let client = this.context.core.opensearch.client.asCurrentUser; + if (isDataSourceQuery(this.request.query) && this.context.dataSource) { + client = await this.context.dataSource.opensearch.getClient( + this.request.query.dataSourceId + ); + } + this.client = client; + } + + return this.client; + } +} diff --git a/server/services/assistant_service.ts b/server/services/assistant_service.ts new file mode 100644 index 00000000..e10265ba --- /dev/null +++ b/server/services/assistant_service.ts @@ -0,0 +1,30 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { OpenSearchDashboardsRequest, RequestHandlerContext } from '../../../../src/core/server'; +import { AssistantClient } from './assistant_client'; + +export interface AssistantServiceSetup { + getScopedClient: ( + request: OpenSearchDashboardsRequest, + context: RequestHandlerContext + ) => AssistantClient; +} + +export class AssistantService { + constructor() {} + + setup(): AssistantServiceSetup { + return { + getScopedClient: (request: OpenSearchDashboardsRequest, context: RequestHandlerContext) => { + return new AssistantClient(request, context); + }, + }; + } + + start() {} + + stop() {} +} diff --git a/server/types.ts b/server/types.ts index a47b42bd..7f4e509d 100644 --- a/server/types.ts +++ b/server/types.ts @@ -5,8 +5,10 @@ import { IMessage, Interaction } from '../common/types/chat_saved_object_attributes'; import { Logger, HttpAuth } from '../../../src/core/server'; +import { AssistantServiceSetup } from './services/assistant_service'; export interface AssistantPluginSetup { + assistantService: AssistantServiceSetup; registerMessageParser: (message: MessageParser) => void; removeMessageParser: (parserId: MessageParser['id']) => void; }