
Add support for Bedrock Converse API (Anthropic Messages API, Claude 3.5 Sonnet) #2851

Merged (11 commits) on Sep 6, 2024
5 changes: 5 additions & 0 deletions plugin/build.gradle
@@ -576,3 +576,8 @@ task bwcTestSuite(type: StandaloneRestIntegTestTask) {
     dependsOn tasks.named("${baseName}#rollingUpgradeClusterTask")
     dependsOn tasks.named("${baseName}#fullRestartClusterTask")
 }
+
+forbiddenPatterns {
+    exclude '**/*.pdf'
+    exclude '**/*.jpg'
+}
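A hedged aside: these exclusions presumably keep the build's forbidden-patterns text scan from failing on the binary .pdf and .jpg test resources added by this change.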

Large diffs are not rendered by default.

Binary file not shown.
@mingshl (Collaborator), Sep 5, 2024:

Do we need a license for the picture or PDF in this PR? @kolchfa-aws

Collaborator:

@austintlee Is this image generated by OpenAI's image generation?

Contributor:

@austintlee Could you tell me the original source of this image, please?

Collaborator (Author):

FYI, that image is from OpenAI's API documentation (https://platform.openai.com/docs/api-reference/chat/create). It's in the example they give for including images in a chat completion.

Collaborator (Author):

The actual source is Wikipedia.

@@ -179,7 +179,8 @@ public void processResponseAsync(
                 chatHistory,
                 searchResults,
                 timeout,
-                params.getLlmResponseField()
+                params.getLlmResponseField(),
+                params.getLlmMessages()
             ),
             null,
             llmQuestion,
@@ -202,7 +203,8 @@ public void processResponseAsync(
                 chatHistory,
                 searchResults,
                 timeout,
-                params.getLlmResponseField()
+                params.getLlmResponseField(),
+                params.getLlmMessages()
             ),
             conversationId,
             llmQuestion,
@@ -18,6 +18,8 @@
 package org.opensearch.searchpipelines.questionanswering.generative.ext;

 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Objects;

 import org.opensearch.core.ParseField;
@@ -30,6 +32,7 @@
 import org.opensearch.core.xcontent.XContentBuilder;
 import org.opensearch.core.xcontent.XContentParser;
 import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants;
+import org.opensearch.searchpipelines.questionanswering.generative.llm.MessageBlock;

 import com.google.common.base.Preconditions;

@@ -81,6 +84,8 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
     // that contains the chat completion text, i.e. "answer".
     private static final ParseField LLM_RESPONSE_FIELD = new ParseField("llm_response_field");

+    private static final ParseField LLM_MESSAGES_FIELD = new ParseField("llm_messages");
+
     public static final int SIZE_NULL_VALUE = -1;

     static {
@@ -94,6 +99,7 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
         PARSER.declareIntOrNull(GenerativeQAParameters::setInteractionSize, SIZE_NULL_VALUE, INTERACTION_SIZE);
         PARSER.declareIntOrNull(GenerativeQAParameters::setTimeout, SIZE_NULL_VALUE, TIMEOUT);
         PARSER.declareStringOrNull(GenerativeQAParameters::setLlmResponseField, LLM_RESPONSE_FIELD);
+        PARSER.declareObjectArray(GenerativeQAParameters::setMessageBlock, (p, c) -> MessageBlock.fromXContent(p), LLM_MESSAGES_FIELD);
     }

     @Setter
@@ -132,6 +138,10 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
     @Getter
     private String llmResponseField;

+    @Setter
+    @Getter
+    private List<MessageBlock> llmMessages = new ArrayList<>();
+
     public GenerativeQAParameters(
         String conversationId,
         String llmModel,
@@ -142,6 +152,32 @@ public GenerativeQAParameters(
         Integer interactionSize,
         Integer timeout,
         String llmResponseField
     ) {
+        this(
+            conversationId,
+            llmModel,
+            llmQuestion,
+            systemPrompt,
+            userInstructions,
+            contextSize,
+            interactionSize,
+            timeout,
+            llmResponseField,
+            null
+        );
+    }
+
+    public GenerativeQAParameters(
+        String conversationId,
+        String llmModel,
+        String llmQuestion,
+        String systemPrompt,
+        String userInstructions,
+        Integer contextSize,
+        Integer interactionSize,
+        Integer timeout,
+        String llmResponseField,
+        List<MessageBlock> llmMessages
+    ) {
         this.conversationId = conversationId;
         this.llmModel = llmModel;
@@ -156,6 +192,9 @@ public GenerativeQAParameters(
         this.interactionSize = (interactionSize == null) ? SIZE_NULL_VALUE : interactionSize;
         this.timeout = (timeout == null) ? SIZE_NULL_VALUE : timeout;
         this.llmResponseField = llmResponseField;
+        if (llmMessages != null) {
+            this.llmMessages.addAll(llmMessages);
+        }
     }

     public GenerativeQAParameters(StreamInput input) throws IOException {
@@ -168,6 +207,7 @@ public GenerativeQAParameters(StreamInput input) throws IOException {
         this.interactionSize = input.readInt();
         this.timeout = input.readInt();
         this.llmResponseField = input.readOptionalString();
+        this.llmMessages.addAll(input.readList(MessageBlock::new));
     }

     @Override
@@ -181,7 +221,8 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
             .field(CONTEXT_SIZE.getPreferredName(), this.contextSize)
             .field(INTERACTION_SIZE.getPreferredName(), this.interactionSize)
             .field(TIMEOUT.getPreferredName(), this.timeout)
-            .field(LLM_RESPONSE_FIELD.getPreferredName(), this.llmResponseField);
+            .field(LLM_RESPONSE_FIELD.getPreferredName(), this.llmResponseField)
+            .field(LLM_MESSAGES_FIELD.getPreferredName(), this.llmMessages);
     }

     @Override
@@ -197,6 +238,7 @@ public void writeTo(StreamOutput out) throws IOException {
         out.writeInt(interactionSize);
         out.writeInt(timeout);
         out.writeOptionalString(llmResponseField);
+        out.writeList(llmMessages);
     }

     public static GenerativeQAParameters parse(XContentParser parser) throws IOException {
@@ -223,4 +265,8 @@ public boolean equals(Object o) {
             && (this.timeout == other.getTimeout())
             && Objects.equals(this.llmResponseField, other.getLlmResponseField());
     }
+
+    public void setMessageBlock(List<MessageBlock> blockList) {
+        this.llmMessages = blockList;
+    }
 }
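For orientation, here is a minimal usage sketch of the new ten-argument constructor. All values are illustrative assumptions (model id, prompts, sizes), and the empty llmMessages list stands in for real MessageBlock entries, whose schema is defined in a file not rendered on this page:

import java.util.List;

import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters;

public class GenerativeQAParametersSketch {
    public static GenerativeQAParameters build() {
        // The old nine-argument constructor now delegates here with a null message list.
        return new GenerativeQAParameters(
            null,                                  // conversationId: none yet, start a new conversation
            "bedrock-converse/anthropic.claude-3-5-sonnet-20240620-v1:0", // illustrative model id
            "What is the capital of France?",      // llmQuestion
            "You are a helpful assistant.",        // systemPrompt
            null,                                  // userInstructions
            5,                                     // contextSize
            5,                                     // interactionSize
            30,                                    // timeout, in seconds
            null,                                  // llmResponseField
            List.of()                              // llmMessages: empty placeholder for MessageBlock entries
        );
    }
}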
@@ -44,4 +44,5 @@ public class ChatCompletionInput {
     private String userInstructions;
     private Llm.ModelProvider modelProvider;
     private String llmResponseField;
+    private List<MessageBlock> llmMessages;
 }
@@ -75,7 +75,6 @@ protected void setMlClient(MachineLearningInternalClient mlClient) {
      * @return
      */
     @Override
-
     public void doChatCompletion(ChatCompletionInput chatCompletionInput, ActionListener<ChatCompletionOutput> listener) {
         MLInputDataset dataset = RemoteInferenceInputDataSet.builder().parameters(getInputParameters(chatCompletionInput)).build();
         MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataset).build();
@@ -113,14 +112,15 @@ protected Map<String, String> getInputParameters(ChatCompletionInput chatComplet
             inputParameters.put(CONNECTOR_INPUT_PARAMETER_MODEL, chatCompletionInput.getModel());
             String messages = PromptUtil
                 .getChatCompletionPrompt(
+                    chatCompletionInput.getModelProvider(),
                     chatCompletionInput.getSystemPrompt(),
                     chatCompletionInput.getUserInstructions(),
                     chatCompletionInput.getQuestion(),
                     chatCompletionInput.getChatHistory(),
-                    chatCompletionInput.getContexts()
+                    chatCompletionInput.getContexts(),
+                    chatCompletionInput.getLlmMessages()
                 );
             inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages);
-            // log.info("Messages to LLM: {}", messages);
         } else if (chatCompletionInput.getModelProvider() == ModelProvider.BEDROCK
             || chatCompletionInput.getModelProvider() == ModelProvider.COHERE
             || chatCompletionInput.getLlmResponseField() != null) {
@@ -136,6 +136,19 @@ protected Map<String, String> getInputParameters(ChatCompletionInput chatComplet
                     chatCompletionInput.getContexts()
                 )
             );
+        } else if (chatCompletionInput.getModelProvider() == ModelProvider.BEDROCK_CONVERSE) {
Review thread on the line above:

Collaborator:

Instead of continually adding new "else if" blocks here, why not define an annotation for the different methods and use reflection to dispatch to the different model servers at runtime? That way, I think we could reduce the code complexity here.

Collaborator (Author):

Given the time constraint, I would rather do the refactoring at a later time. I have been thinking about a good way to avoid this style of handling the different LLM vendors and APIs, but I was waiting for some general patterns to emerge, which I think is this Messages API; even so, there are still small differences between LLM providers.
+            // Bedrock Converse API does not include the system prompt as part of the Messages block.
+            String messages = PromptUtil
+                .getChatCompletionPrompt(
+                    chatCompletionInput.getModelProvider(),
+                    null,
+                    chatCompletionInput.getUserInstructions(),
+                    chatCompletionInput.getQuestion(),
+                    chatCompletionInput.getChatHistory(),
+                    chatCompletionInput.getContexts(),
+                    chatCompletionInput.getLlmMessages()
+                );
+            inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages);
         } else {
             throw new IllegalArgumentException(
                 "Unknown/unsupported model provider: "
@@ -144,7 +157,6 @@ protected Map<String, String> getInputParameters(ChatCompletionInput chatComplet
             );
         }

-        // log.info("LLM input parameters: {}", inputParameters.toString());
         return inputParameters;
     }
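A note on the null systemPrompt in the new branch: the Bedrock Converse API accepts the system prompt as a separate top-level "system" field rather than as a chat message, so the connector's request template is assumed to carry it outside the messages block. A hedged sketch of the request body shape (values illustrative, not from this PR):

{
  "system": [ { "text": "You are a helpful assistant." } ],
  "messages": [
    { "role": "user", "content": [ { "text": "What is the capital of France?" } ] }
  ]
}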

@@ -184,6 +196,20 @@ protected ChatCompletionOutput buildChatCompletionOutput(ModelProvider provider,
         } else if (provider == ModelProvider.COHERE) {
             answerField = "text";
             fillAnswersOrErrors(dataAsMap, answers, errors, answerField, errorField, defaultErrorMessageField);
+        } else if (provider == ModelProvider.BEDROCK_CONVERSE) {
+            Map output = (Map) dataAsMap.get("output");
+            Map message = (Map) output.get("message");
+            if (message != null) {
+                List content = (List) message.get("content");
+                String answer = (String) ((Map) content.get(0)).get("text");
+                answers.add(answer);
+            } else {
+                Map error = (Map) output.get("error");
+                if (error == null) {
+                    throw new RuntimeException("Unexpected output: " + output);
+                }
+                errors.add((String) error.get("message"));
+            }
         } else {
             throw new IllegalArgumentException(
                 "Unknown/unsupported model provider: " + provider + ". You must provide a valid model provider or llm_response_field."
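To make the parsing above concrete, here is a self-contained, hypothetical sketch of the response shape this branch expects. The field names mirror the code above; the class name and answer text are invented:

import java.util.List;
import java.util.Map;

public class ConverseOutputShapeSketch {
    public static void main(String[] args) {
        // Successful Converse-style response: output -> message -> content[0] -> text
        Map<String, Object> dataAsMap = Map.of(
            "output", Map.of(
                "message", Map.of(
                    "content", List.of(Map.of("text", "Paris is the capital of France."))
                )
            )
        );
        Map<?, ?> output = (Map<?, ?>) dataAsMap.get("output");
        Map<?, ?> message = (Map<?, ?>) output.get("message");
        List<?> content = (List<?>) message.get("content");
        String answer = (String) ((Map<?, ?>) content.get(0)).get("text");
        System.out.println(answer);
        // An error response instead carries output -> error -> message,
        // which the code above surfaces via the errors list.
    }
}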
@@ -28,7 +28,8 @@ public interface Llm {
     enum ModelProvider {
         OPENAI,
         BEDROCK,
-        COHERE
+        COHERE,
+        BEDROCK_CONVERSE
     }

     void doChatCompletion(ChatCompletionInput input, ActionListener<ChatCompletionOutput> listener);
@@ -29,6 +29,7 @@ public class LlmIOUtil {

     public static final String BEDROCK_PROVIDER_PREFIX = "bedrock/";
     public static final String COHERE_PROVIDER_PREFIX = "cohere/";
+    public static final String BEDROCK_CONVERSE__PROVIDER_PREFIX = "bedrock-converse/";

     public static ChatCompletionInput createChatCompletionInput(
         String llmModel,
@@ -49,7 +50,8 @@ public static ChatCompletionInput createChatCompletionInput(
             chatHistory,
             contexts,
             timeoutInSeconds,
-            llmResponseField
+            llmResponseField,
+            null
         );
     }

@@ -61,7 +63,8 @@ public static ChatCompletionInput createChatCompletionInput(
         List<Interaction> chatHistory,
         List<String> contexts,
         int timeoutInSeconds,
-        String llmResponseField
+        String llmResponseField,
+        List<MessageBlock> llmMessages
     ) {
         Llm.ModelProvider provider = null;
         if (llmResponseField == null) {
@@ -71,6 +74,8 @@ public static ChatCompletionInput createChatCompletionInput(
                 provider = Llm.ModelProvider.BEDROCK;
             } else if (llmModel.startsWith(COHERE_PROVIDER_PREFIX)) {
                 provider = Llm.ModelProvider.COHERE;
+            } else if (llmModel.startsWith(BEDROCK_CONVERSE__PROVIDER_PREFIX)) {
+                provider = Llm.ModelProvider.BEDROCK_CONVERSE;
             }
         }
     }
@@ -83,7 +88,8 @@ public static ChatCompletionInput createChatCompletionInput(
             systemPrompt,
             userInstructions,
             provider,
-            llmResponseField
+            llmResponseField,
+            llmMessages
         );
     }
 }
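In practice, a caller opts into the new path purely through the model id prefix. A hedged one-liner (the model id is illustrative):

// "bedrock-converse/..." model ids now resolve to Llm.ModelProvider.BEDROCK_CONVERSE.
String llmModel = "bedrock-converse/anthropic.claude-3-5-sonnet-20240620-v1:0";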