diff --git a/dev/src/field-value.ts b/dev/src/field-value.ts index eb054067c..dc1960f99 100644 --- a/dev/src/field-value.ts +++ b/dev/src/field-value.ts @@ -30,6 +30,42 @@ import { import api = proto.google.firestore.v1; +export class VectorValue implements firestore.VectorValue { + private readonly _values: number[]; + constructor(values: number[] | undefined) { + // Making a copy of the parameter. + this._values = (values || []).map(n => n); + } + + public toArray(): number[] { + return this._values.map(n => n); + } + + /** + * @private + */ + toProto(serializer: Serializer): api.IValue { + return serializer.encodeVector(this._values); + } + + /** + * @private + */ + static fromProto(valueArray: api.IValue): VectorValue { + const values = valueArray.arrayValue?.values?.map(v => { + return v.doubleValue!; + }); + return new VectorValue(values); + } + + /** + * @private + */ + isEqual(other: VectorValue): boolean { + return this._values === other._values; + } +} + /** * Sentinel values that can be used when writing documents with set(), create() * or update(). @@ -40,6 +76,10 @@ export class FieldValue implements firestore.FieldValue { /** @private */ constructor() {} + static vector(values?: number[]): VectorValue { + return new VectorValue(values); + } + /** * Returns a sentinel for use with update() or set() with {merge:true} to mark * a field for deletion. diff --git a/dev/src/serializer.ts b/dev/src/serializer.ts index dae07e8a3..bb172ff07 100644 --- a/dev/src/serializer.ts +++ b/dev/src/serializer.ts @@ -19,7 +19,7 @@ import {DocumentData} from '@google-cloud/firestore'; import * as proto from '../protos/firestore_v1_proto_api'; import {detectValueType} from './convert'; -import {DeleteTransform, FieldTransform} from './field-value'; +import {DeleteTransform, FieldTransform, VectorValue} from './field-value'; import {GeoPoint} from './geo-point'; import {DocumentReference, Firestore} from './index'; import {FieldPath, QualifiedResourcePath} from './path'; @@ -38,6 +38,10 @@ import api = proto.google.firestore.v1; */ const MAX_DEPTH = 20; +const RESERVED_MAP_KEY = '__type__'; +const RESERVED_MAP_KEY_VECTOR_VALUE = '__vector__'; +const VECTOR_MAP_VECTORS_KEY = 'value'; + /** * An interface for Firestore types that can be serialized to Protobuf. * @@ -168,6 +172,10 @@ export class Serializer { }; } + if (val instanceof VectorValue) { + return val.toProto(this); + } + if (isObject(val)) { const toProto = val['toProto']; if (typeof toProto === 'function') { @@ -217,6 +225,31 @@ export class Serializer { throw new Error(`Cannot encode value: ${val}`); } + /** + * @private + */ + encodeVector(rawVector: number[]): api.IValue { + // A Firestore Vector is a map with reserved key/value pairs. + return { + mapValue: { + fields: { + [RESERVED_MAP_KEY]: { + stringValue: RESERVED_MAP_KEY_VECTOR_VALUE, + }, + [VECTOR_MAP_VECTORS_KEY]: { + arrayValue: { + values: rawVector.map(value => { + return { + doubleValue: value, + }; + }), + }, + }, + }, + }, + }; + } + /** * Decodes a single Firestore 'Value' Protobuf. * @@ -263,15 +296,25 @@ export class Serializer { return null; } case 'mapValue': { - const obj: DocumentData = {}; const fields = proto.mapValue!.fields; if (fields) { - for (const prop of Object.keys(fields)) { - obj[prop] = this.decodeValue(fields[prop]); + const props = Object.keys(fields); + if ( + props.indexOf(RESERVED_MAP_KEY) !== -1 && + this.decodeValue(fields[RESERVED_MAP_KEY]) === + RESERVED_MAP_KEY_VECTOR_VALUE + ) { + return VectorValue.fromProto(fields[VECTOR_MAP_VECTORS_KEY]); + } else { + const obj: DocumentData = {}; + for (const prop of Object.keys(fields)) { + obj[prop] = this.decodeValue(fields[prop]); + } + return obj; } + } else { + return {}; } - - return obj; } case 'geoPointValue': { return GeoPoint.fromProto(proto.geoPointValue!); @@ -367,6 +410,8 @@ export function validateUserInput( 'If you want to ignore undefined values, enable `ignoreUndefinedProperties`.' ); } + } else if (value instanceof VectorValue) { + // OK } else if (value instanceof DeleteTransform) { if (inArray) { throw new Error( diff --git a/dev/system-test/firestore.ts b/dev/system-test/firestore.ts index 2b63a2846..3fb8280e7 100644 --- a/dev/system-test/firestore.ts +++ b/dev/system-test/firestore.ts @@ -1018,6 +1018,30 @@ describe('DocumentReference class', () => { return promise; }); + it.only('can write and read vector embeddings', async () => { + const ref = randomCol.doc(); + await ref.create({ + vectorEmpty: FieldValue.vector(), + vector1: FieldValue.vector([1, 2, 3.99]), + }); + await ref.set({ + vectorEmpty: FieldValue.vector(), + vector1: FieldValue.vector([1, 2, 3.99]), + vector2: FieldValue.vector([0, 0, 0]), + }); + await ref.update({ + vector3: FieldValue.vector([-1, -200, -999]), + }); + + const snap1 = await ref.get(); + expect(snap1.get('vectorEmpty')).to.deep.equal(FieldValue.vector()); + expect(snap1.get('vector1')).to.deep.equal(FieldValue.vector([1, 2, 3.99])); + expect(snap1.get('vector2')).to.deep.equal(FieldValue.vector([0, 0, 0])); + expect(snap1.get('vector3')).to.deep.equal( + FieldValue.vector([-1, -200, -999]) + ); + }); + describe('watch', () => { const currentDeferred = new DeferredPromise(); @@ -1311,6 +1335,91 @@ describe('DocumentReference class', () => { const result2 = await ref2.get(); expect(result2.data()).to.deep.equal([1, 2, 3]); }); + + it.only('can listen to documents with vectors', async () => { + const ref = randomCol.doc(); + const initialDeferred = new Deferred(); + const createDeferred = new Deferred(); + const setDeferred = new Deferred(); + const updateDeferred = new Deferred(); + const deleteDeferred = new Deferred(); + + const expected = [ + initialDeferred, + createDeferred, + setDeferred, + updateDeferred, + deleteDeferred, + ]; + let idx = 0; + let document: DocumentSnapshot | null = null; + + const unlisten = randomCol + .where('purpose', '==', 'vector tests') + .onSnapshot(snap => { + expected[idx].resolve(); + idx += 1; + if (snap.docs.length > 0) { + document = snap.docs[0]; + } else { + document = null; + } + }); + + await initialDeferred.promise; + expect(document).to.be.null; + + await ref.create({ + purpose: 'vector tests', + vectorEmpty: FieldValue.vector(), + vector1: FieldValue.vector([1, 2, 3.99]), + }); + + await createDeferred.promise; + expect(document).to.be.not.null; + expect(document!.get('vectorEmpty')).to.deep.equal(FieldValue.vector()); + expect(document!.get('vector1')).to.deep.equal( + FieldValue.vector([1, 2, 3.99]) + ); + + await ref.set({ + purpose: 'vector tests', + vectorEmpty: FieldValue.vector(), + vector1: FieldValue.vector([1, 2, 3.99]), + vector2: FieldValue.vector([0, 0, 0]), + }); + await setDeferred.promise; + expect(document).to.be.not.null; + expect(document!.get('vectorEmpty')).to.deep.equal(FieldValue.vector()); + expect(document!.get('vector1')).to.deep.equal( + FieldValue.vector([1, 2, 3.99]) + ); + expect(document!.get('vector2')).to.deep.equal( + FieldValue.vector([0, 0, 0]) + ); + + await ref.update({ + vector3: FieldValue.vector([-1, -200, -999]), + }); + await updateDeferred.promise; + expect(document).to.be.not.null; + expect(document!.get('vectorEmpty')).to.deep.equal(FieldValue.vector()); + expect(document!.get('vector1')).to.deep.equal( + FieldValue.vector([1, 2, 3.99]) + ); + expect(document!.get('vector2')).to.deep.equal( + FieldValue.vector([0, 0, 0]) + ); + expect(document!.get('vector3')).to.deep.equal( + FieldValue.vector([-1, -200, -999]) + ); + + await ref.delete(); + await deleteDeferred.promise; + expect(document).to.be.null; + + unlisten(); + }); }); describe('runs query on a large collection', () => { diff --git a/dev/test/document.ts b/dev/test/document.ts index f37bf0050..dc8a62b71 100644 --- a/dev/test/document.ts +++ b/dev/test/document.ts @@ -471,6 +471,43 @@ describe('serialize document', () => { return ref.set({ref}); }); }); + + it('is able to translate FirestoreVector to internal representation with set', () => { + const overrides: ApiOverride = { + commit: request => { + requestEquals( + request, + set({ + document: document('documentId', 'embedding1', { + mapValue: { + fields: { + __type__: { + stringValue: '__vector__', + }, + value: { + arrayValue: { + values: [ + {doubleValue: 0}, + {doubleValue: 1}, + {doubleValue: 2}, + ], + }, + }, + }, + }, + }), + }) + ); + return response(writeResult(1)); + }, + }; + + return createInstance(overrides).then(firestore => { + return firestore.doc('collectionId/documentId').set({ + embedding1: FieldValue.vector([0, 1, 2]), + }); + }); + }); }); describe('deserialize document', () => { @@ -599,6 +636,46 @@ describe('deserialize document', () => { }); }); + it('deserializes FirestoreVector', () => { + const overrides: ApiOverride = { + batchGetDocuments: () => { + return stream( + found( + document('documentId', 'embedding', { + mapValue: { + fields: { + __type__: { + stringValue: '__vector__', + }, + value: { + arrayValue: { + values: [ + {doubleValue: -41.0}, + {doubleValue: 0}, + {doubleValue: 42}, + ], + }, + }, + }, + }, + }) + ) + ); + }, + }; + + return createInstance(overrides).then(firestore => { + return firestore + .doc('collectionId/documentId') + .get() + .then(res => { + expect(res.get('embedding')).to.deep.equal( + FieldValue.vector([-41.0, 0, 42]) + ); + }); + }); + }); + it("doesn't deserialize unsupported types", () => { const overrides: ApiOverride = { batchGetDocuments: () => { diff --git a/dev/test/types.ts b/dev/test/types.ts index 2240cdaf9..a2b765553 100644 --- a/dev/test/types.ts +++ b/dev/test/types.ts @@ -120,7 +120,7 @@ describe('FirestoreTypeConverter', () => { await newDocRef.set({stringProperty: 'foo', numberProperty: 42}); await newDocRef.update({a: 'newFoo', b: 43}); const snapshot = await newDocRef.get(); - const data: MyModelType = snapshot.data()!; + const data = snapshot.data()!; expect(data.stringProperty).to.equal('newFoo'); expect(data.numberProperty).to.equal(43); } diff --git a/types/firestore.d.ts b/types/firestore.d.ts index 9f4c2928c..d5a3798eb 100644 --- a/types/firestore.d.ts +++ b/types/firestore.d.ts @@ -2481,6 +2481,23 @@ declare namespace FirebaseFirestore { ): boolean; } + /** + * Represent a vector type in Firestore documents. + */ + export class VectorValue { + private constructor(values: number[] | undefined); + + /** + * Returns a copy of the raw number array form of the vector. + */ + toArray(): number[]; + + /** + * Returns true if the two `VectorValue` has the same raw number arrays, returns false otherwise. + */ + isEqual(other: VectorValue): boolean; + } + /** * Sentinel values that can be used when writing document fields with set(), * create() or update(). @@ -2551,6 +2568,11 @@ declare namespace FirebaseFirestore { */ static arrayRemove(...elements: any[]): FieldValue; + /** + * @return A new `VectorValue` constructed with a copy of the given array of number. + */ + static vector(values?: number[]): VectorValue; + /** * Returns true if this `FieldValue` is equal to the provided one. *