-
Notifications
You must be signed in to change notification settings - Fork 128
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for Bedrock Converse API (Anthropic Messages API, Claude 3.5 Sonnet) #2851
Changes from all commits
fd8a923
5137bf3
3f1c41a
c1e33b5
ad211f3
e6b21da
06b74cb
65f59d4
b923d00
fa3b964
b28934c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -75,7 +75,6 @@ protected void setMlClient(MachineLearningInternalClient mlClient) { | |
* @return | ||
*/ | ||
@Override | ||
|
||
public void doChatCompletion(ChatCompletionInput chatCompletionInput, ActionListener<ChatCompletionOutput> listener) { | ||
MLInputDataset dataset = RemoteInferenceInputDataSet.builder().parameters(getInputParameters(chatCompletionInput)).build(); | ||
MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataset).build(); | ||
|
@@ -113,14 +112,15 @@ protected Map<String, String> getInputParameters(ChatCompletionInput chatComplet | |
inputParameters.put(CONNECTOR_INPUT_PARAMETER_MODEL, chatCompletionInput.getModel()); | ||
String messages = PromptUtil | ||
.getChatCompletionPrompt( | ||
chatCompletionInput.getModelProvider(), | ||
chatCompletionInput.getSystemPrompt(), | ||
chatCompletionInput.getUserInstructions(), | ||
chatCompletionInput.getQuestion(), | ||
chatCompletionInput.getChatHistory(), | ||
chatCompletionInput.getContexts() | ||
chatCompletionInput.getContexts(), | ||
chatCompletionInput.getLlmMessages() | ||
); | ||
inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages); | ||
// log.info("Messages to LLM: {}", messages); | ||
} else if (chatCompletionInput.getModelProvider() == ModelProvider.BEDROCK | ||
|| chatCompletionInput.getModelProvider() == ModelProvider.COHERE | ||
|| chatCompletionInput.getLlmResponseField() != null) { | ||
|
@@ -136,6 +136,19 @@ protected Map<String, String> getInputParameters(ChatCompletionInput chatComplet | |
chatCompletionInput.getContexts() | ||
) | ||
); | ||
} else if (chatCompletionInput.getModelProvider() == ModelProvider.BEDROCK_CONVERSE) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of keeping adding new "else if" blocks here, why not just define an annotation for different methods and use reflection to process different model servers at runtime? In that way I think we could reduce the code complexity here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Given the time constraint, I would do the refactoring at a later time. I have been thinking about a good way to avoid this style of handling different cases of LLM vendors and APIs, but I was waiting for some general patterns to emerge, which I think is this Message API; but again, there are still small differences between LLM providers. |
||
// Bedrock Converse API does not include the system prompt as part of the Messages block. | ||
String messages = PromptUtil | ||
.getChatCompletionPrompt( | ||
chatCompletionInput.getModelProvider(), | ||
null, | ||
chatCompletionInput.getUserInstructions(), | ||
chatCompletionInput.getQuestion(), | ||
chatCompletionInput.getChatHistory(), | ||
chatCompletionInput.getContexts(), | ||
chatCompletionInput.getLlmMessages() | ||
); | ||
inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages); | ||
} else { | ||
throw new IllegalArgumentException( | ||
"Unknown/unsupported model provider: " | ||
|
@@ -144,7 +157,6 @@ protected Map<String, String> getInputParameters(ChatCompletionInput chatComplet | |
); | ||
} | ||
|
||
// log.info("LLM input parameters: {}", inputParameters.toString()); | ||
return inputParameters; | ||
} | ||
|
||
|
@@ -184,6 +196,20 @@ protected ChatCompletionOutput buildChatCompletionOutput(ModelProvider provider, | |
} else if (provider == ModelProvider.COHERE) { | ||
answerField = "text"; | ||
fillAnswersOrErrors(dataAsMap, answers, errors, answerField, errorField, defaultErrorMessageField); | ||
} else if (provider == ModelProvider.BEDROCK_CONVERSE) { | ||
Map output = (Map) dataAsMap.get("output"); | ||
Map message = (Map) output.get("message"); | ||
if (message != null) { | ||
List content = (List) message.get("content"); | ||
String answer = (String) ((Map) content.get(0)).get("text"); | ||
answers.add(answer); | ||
} else { | ||
Map error = (Map) output.get("error"); | ||
if (error == null) { | ||
throw new RuntimeException("Unexpected output: " + output); | ||
} | ||
errors.add((String) error.get("message")); | ||
} | ||
} else { | ||
throw new IllegalArgumentException( | ||
"Unknown/unsupported model provider: " + provider + ". You must provide a valid model provider or llm_response_field." | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need license for the picture or pdf in this PR? @kolchfa-aws
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@austintlee is this image generated by OpenAI's image generation?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@austintlee Could you tell me the original source of this image please?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FYI, that image is from OpenAI's API documentation - https://platform.openai.com/docs/api-reference/chat/create. It's in the example they give for including images in chat completion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The actual source is Wikipedia.