Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Vertex AI] Make GenerativeModel and Chat into Swift actors #13545

Merged
merged 5 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions FirebaseVertexAI/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# 11.2.0
- [fixed] Resolved a decoding error for citations without a `uri` and added
support for decoding `title` fields, which were previously ignored. (#13518)
- [changed] **Breaking Change**: The methods for starting streaming requests
(`generateContentStream` and `sendMessageStream`) and creating a chat instance
(`startChat`) are now asynchronous and must be called with `await`. (#13545)

# 10.29.0
- [feature] Added community support for watchOS. (#13215)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,20 @@ class ConversationViewModel: ObservableObject {
}

private var model: GenerativeModel
private var chat: Chat
private var chat: Chat? = nil
private var stopGenerating = false

private var chatTask: Task<Void, Never>?

init() {
model = VertexAI.vertexAI().generativeModel(modelName: "gemini-1.5-flash")
chat = model.startChat()
}

func sendMessage(_ text: String, streaming: Bool = true) async {
error = nil
if chat == nil {
chat = await model.startChat()
}
if streaming {
await internalSendMessageStreaming(text)
} else {
Expand All @@ -52,7 +54,7 @@ class ConversationViewModel: ObservableObject {
func startNewChat() {
stop()
error = nil
chat = model.startChat()
chat = nil
messages.removeAll()
}

Expand All @@ -79,7 +81,10 @@ class ConversationViewModel: ObservableObject {
messages.append(systemMessage)

do {
let responseStream = chat.sendMessageStream(text)
guard let chat else {
throw ChatError.notInitialized
}
let responseStream = await chat.sendMessageStream(text)
for try await chunk in responseStream {
messages[messages.count - 1].pending = false
if let text = chunk.text {
Expand Down Expand Up @@ -112,10 +117,12 @@ class ConversationViewModel: ObservableObject {
messages.append(systemMessage)

do {
var response: GenerateContentResponse?
response = try await chat.sendMessage(text)
guard let chat = chat else {
throw ChatError.notInitialized
}
let response = try await chat.sendMessage(text)

if let responseText = response?.text {
if let responseText = response.text {
// replace pending message with backend response
messages[messages.count - 1].message = responseText
messages[messages.count - 1].pending = false
Expand All @@ -127,4 +134,8 @@ class ConversationViewModel: ObservableObject {
}
}
}

enum ChatError: Error {
case notInitialized
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class FunctionCallingViewModel: ObservableObject {
private var functionCalls = [FunctionCall]()

private var model: GenerativeModel
private var chat: Chat
private var chat: Chat? = nil

private var chatTask: Task<Void, Never>?

Expand Down Expand Up @@ -62,7 +62,6 @@ class FunctionCallingViewModel: ObservableObject {
),
])]
)
chat = model.startChat()
}

func sendMessage(_ text: String, streaming: Bool = true) async {
Expand All @@ -75,6 +74,10 @@ class FunctionCallingViewModel: ObservableObject {
busy = false
}

if chat == nil {
chat = await model.startChat()
}

// first, add the user's message to the chat
let userMessage = ChatMessage(message: text, participant: .user)
messages.append(userMessage)
Expand Down Expand Up @@ -103,7 +106,7 @@ class FunctionCallingViewModel: ObservableObject {
func startNewChat() {
stop()
error = nil
chat = model.startChat()
chat = nil
andrewheard marked this conversation as resolved.
Show resolved Hide resolved
messages.removeAll()
}

Expand All @@ -114,14 +117,17 @@ class FunctionCallingViewModel: ObservableObject {

private func internalSendMessageStreaming(_ text: String) async throws {
let functionResponses = try await processFunctionCalls()
guard let chat else {
throw ChatError.notInitialized
}
let responseStream: AsyncThrowingStream<GenerateContentResponse, Error>
if functionResponses.isEmpty {
responseStream = chat.sendMessageStream(text)
responseStream = await chat.sendMessageStream(text)
} else {
for functionResponse in functionResponses {
messages.insert(functionResponse.chatMessage(), at: messages.count - 1)
}
responseStream = chat.sendMessageStream(functionResponses.modelContent())
responseStream = await chat.sendMessageStream(functionResponses.modelContent())
}
for try await chunk in responseStream {
processResponseContent(content: chunk)
Expand All @@ -130,6 +136,9 @@ class FunctionCallingViewModel: ObservableObject {

private func internalSendMessage(_ text: String) async throws {
let functionResponses = try await processFunctionCalls()
guard let chat else {
throw ChatError.notInitialized
}
let response: GenerateContentResponse
if functionResponses.isEmpty {
response = try await chat.sendMessage(text)
Expand Down Expand Up @@ -181,6 +190,10 @@ class FunctionCallingViewModel: ObservableObject {
return functionResponses
}

enum ChatError: Error {
case notInitialized
}

// MARK: - Callable Functions

func getExchangeRate(args: JSONObject) -> JSONObject {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class PhotoReasoningViewModel: ObservableObject {
}
}

let outputContentStream = model.generateContentStream(prompt, images)
let outputContentStream = await model.generateContentStream(prompt, images)

// stream response
for try await outputContent in outputContentStream {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class SummarizeViewModel: ObservableObject {

let prompt = "Summarize the following text for me: \(inputText)"

let outputContentStream = model.generateContentStream(prompt)
let outputContentStream = await model.generateContentStream(prompt)

// stream response
for try await outputContent in outputContentStream {
Expand Down
4 changes: 2 additions & 2 deletions FirebaseVertexAI/Sources/Chat.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import Foundation
/// An object that represents a back-and-forth chat with a model, capturing the history and saving
/// the context in memory between each message sent.
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public class Chat {
public actor Chat {
private let model: GenerativeModel

/// Initializes a new chat representing a 1:1 conversation between model and user.
Expand Down Expand Up @@ -121,7 +121,7 @@ public class Chat {

// Send the history alongside the new message as context.
let request = history + newContent
let stream = model.generateContentStream(request)
let stream = await model.generateContentStream(request)
do {
for try await chunk in stream {
// Capture any content that's streaming. This should be populated if there's no error.
Expand Down
44 changes: 21 additions & 23 deletions FirebaseVertexAI/Sources/GenerativeModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import Foundation
/// A type that represents a remote multimodal model (like Gemini), with the ability to generate
/// content based on various input types.
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public final class GenerativeModel {
public final actor GenerativeModel {
/// The resource name of the model in the backend; has the format "models/model-name".
let modelResourceName: String

Expand Down Expand Up @@ -217,33 +217,31 @@ public final class GenerativeModel {
isStreaming: true,
options: requestOptions)

var responseIterator = generativeAIService.loadRequestStream(request: generateContentRequest)
.makeAsyncIterator()
let responseStream = generativeAIService.loadRequestStream(request: generateContentRequest)

return AsyncThrowingStream {
let response: GenerateContentResponse?
do {
response = try await responseIterator.next()
} catch {
throw GenerativeModel.generateContentError(from: error)
}
for try await response in responseStream {
// Check the prompt feedback to see if the prompt was blocked.
if response.promptFeedback?.blockReason != nil {
throw GenerateContentError.promptBlocked(response: response)
}

// The responseIterator will return `nil` when it's done.
guard let response = response else {
// If the stream ended early unexpectedly, throw an error.
if let finishReason = response.candidates.first?.finishReason, finishReason != .stop {
andrewheard marked this conversation as resolved.
Show resolved Hide resolved
throw GenerateContentError.responseStoppedEarly(
reason: finishReason,
response: response
)
} else {
// Response was valid content, pass it along and continue.
return response
}
}
// This is the end of the stream! Signal it by sending `nil`.
return nil
}

// Check the prompt feedback to see if the prompt was blocked.
if response.promptFeedback?.blockReason != nil {
throw GenerateContentError.promptBlocked(response: response)
}

// If the stream ended early unexpectedly, throw an error.
if let finishReason = response.candidates.first?.finishReason, finishReason != .stop {
throw GenerateContentError.responseStoppedEarly(reason: finishReason, response: response)
} else {
// Response was valid content, pass it along and continue.
return response
} catch {
throw GenerativeModel.generateContentError(from: error)
}
}
}
Expand Down
11 changes: 6 additions & 5 deletions FirebaseVertexAI/Tests/Unit/ChatTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -64,19 +64,20 @@ final class ChatTests: XCTestCase {
)
let chat = Chat(model: model, history: [])
let input = "Test input"
let stream = chat.sendMessageStream(input)
let stream = await chat.sendMessageStream(input)

// Ensure the values are parsed correctly
for try await value in stream {
XCTAssertNotNil(value.text)
}

XCTAssertEqual(chat.history.count, 2)
XCTAssertEqual(chat.history[0].parts[0].text, input)
let history = await chat.history
XCTAssertEqual(history.count, 2)
XCTAssertEqual(history[0].parts[0].text, input)

let finalText = "1 2 3 4 5 6 7 8"
let assembledExpectation = ModelContent(role: "model", parts: finalText)
XCTAssertEqual(chat.history[0].parts[0].text, input)
XCTAssertEqual(chat.history[1], assembledExpectation)
XCTAssertEqual(history[0].parts[0].text, input)
XCTAssertEqual(history[1], assembledExpectation)
}
}
Loading
Loading