From abf5dc43adc4c26b91488a7565682d8e684383c0 Mon Sep 17 00:00:00 2001 From: Dustin Do Date: Wed, 17 Jul 2024 20:01:56 +0700 Subject: [PATCH] feat(api): improve gpt prompt (#135) Includes user's categories and language (Eng only for now). --- apps/api/v1/routes/transactions.ts | 14 ++++++++++-- apps/api/v1/services/ai.service.ts | 32 +++++++++++++++++++++------- apps/mobile/mutations/transaction.ts | 6 +----- 3 files changed, 37 insertions(+), 15 deletions(-) diff --git a/apps/api/v1/routes/transactions.ts b/apps/api/v1/routes/transactions.ts index d2c04eb8..046f8368 100644 --- a/apps/api/v1/routes/transactions.ts +++ b/apps/api/v1/routes/transactions.ts @@ -6,7 +6,11 @@ import { getLogger } from '../../lib/log' import { getAuthUserStrict } from '../middlewares/auth' import { generateTransactionDataFromFile } from '../services/ai.service' import { canUserReadBudget, findBudget } from '../services/budget.service' -import { canUserReadCategory, findCategory } from '../services/category.service' +import { + canUserReadCategory, + findCategoriesOfUser, + findCategory, +} from '../services/category.service' import { canUserCreateTransaction, canUserDeleteTransaction, @@ -275,6 +279,7 @@ const router = new Hono() }) .post('/ai', async (c) => { + const user = getAuthUserStrict(c) const body = await c.req.parseBody() const file = body.file as File | undefined @@ -282,8 +287,13 @@ const router = new Hono() return c.json({ message: 'file not found' }, 400) } + const userCategories = await findCategoriesOfUser({ user }) + try { - const transactionData = await generateTransactionDataFromFile({ file }) + const transactionData = await generateTransactionDataFromFile({ + file, + categories: userCategories, + }) return c.json(transactionData) } catch { return c.json( diff --git a/apps/api/v1/services/ai.service.ts b/apps/api/v1/services/ai.service.ts index eb097f1b..094f5cb4 100644 --- a/apps/api/v1/services/ai.service.ts +++ b/apps/api/v1/services/ai.service.ts @@ -1,3 +1,4 @@ +import type { CategoryType } from '@prisma/client' import { OpenAI } from 'openai' import { getLogger } from '../../lib/log' @@ -40,7 +41,18 @@ export async function deleteFile({ fileId }: { fileId: string }) { export async function generateTransactionDataFromFile({ file: inputFile, -}: { file: File }) { + noteLanguage = 'English', + categories, +}: { + file: File + noteLanguage?: string + categories?: { + id: string + name: string + icon?: string | null + type?: CategoryType + }[] +}) { const log = getLogger(`ai.service:${generateTransactionDataFromFile.name}`) const file = await uploadVisionFile({ file: inputFile }) @@ -58,12 +70,18 @@ export async function generateTransactionDataFromFile({ log.info('Created thread with uploaded file. Thread ID: %s', thread.id) log.debug('Created thread with uploaded file. Thread details: %o', thread) - log.debug('Running assistant on thread. Assistant ID: %s', ASSISTANT_ID) + + const additionalInstructions = `note must be in ${noteLanguage} language but do not translate names. Categories: ${JSON.stringify(categories?.map((cat) => ({ id: cat.id, name: cat.name, icon: cat.icon, type: cat.type })) || [])}` + + log.debug( + 'Running assistant on thread.\nAssistant ID: %s\nAdditional instructions: %s', + ASSISTANT_ID, + additionalInstructions, + ) const run = await openai.beta.threads.runs.createAndPoll(thread.id, { assistant_id: ASSISTANT_ID, - // TODO: PUT USER CATEGORIES HERE LATER - // additional_instructions: '', + additional_instructions: additionalInstructions, }) log.info('Ran assistant on thread. Run ID: %s', run.id) @@ -81,19 +99,17 @@ export async function generateTransactionDataFromFile({ const messages = await openai.beta.threads.messages.list(run.thread_id) const firstMessage = messages.data[0].content[0] - log.debug('First message: %o', firstMessage) - const aiTransactionData = firstMessage.type === 'text' ? JSON.parse(firstMessage.text.value) : null log.info('AI transaction data: %o', aiTransactionData) - await cleanup() + cleanup() return aiTransactionData } log.error('Assistant run failed. Run details: %o', run) - await cleanup() + cleanup() throw new Error('Assistant run failed') } diff --git a/apps/mobile/mutations/transaction.ts b/apps/mobile/mutations/transaction.ts index bab6638b..b33e7a6e 100644 --- a/apps/mobile/mutations/transaction.ts +++ b/apps/mobile/mutations/transaction.ts @@ -86,12 +86,8 @@ export async function getAITransactionData(fileUri: string) { const transaction = zUpdateTransaction.parse({ ...body, - date: body?.datetime, + date: body?.date ? new Date(body.date) : undefined, }) - if (!transaction.amount) { - throw new Error('Cannot extract transaction data') - } - return transaction }