From ae7aca32a690a4b21a3da793258ce17560b551e7 Mon Sep 17 00:00:00 2001 From: Adrien Grand Date: Mon, 16 Sep 2024 18:14:13 +0200 Subject: [PATCH] First iteration on removing copy() / vectorValue() in favor of dictionary(). --- .../codecs/hnsw/DefaultFlatVectorScorer.java | 10 ++--- .../apache/lucene/index/ByteVectorValues.java | 43 ++++++++++++++---- .../lucene/index/FloatVectorValues.java | 45 +++++++++++++++---- 3 files changed, 77 insertions(+), 21 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java index 3e506037969..4ce1a4e7ccb 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java @@ -125,15 +125,15 @@ public String toString() { /** RandomVectorScorerSupplier for Float vector */ private static final class FloatScoringSupplier implements RandomVectorScorerSupplier { private final FloatVectorValues vectors; - private final FloatVectorValues vectors1; - private final FloatVectorValues vectors2; + private final FloatVectorValues.Dictionary dict1; + private final FloatVectorValues.Dictionary dict2; private final VectorSimilarityFunction similarityFunction; private FloatScoringSupplier( FloatVectorValues vectors, VectorSimilarityFunction similarityFunction) throws IOException { this.vectors = vectors; - vectors1 = vectors.copy(); - vectors2 = vectors.copy(); + dict1 = vectors.dictionary(); + dict2 = vectors.dictionary(); this.similarityFunction = similarityFunction; } @@ -142,7 +142,7 @@ public RandomVectorScorer scorer(int ord) { return new RandomVectorScorer.AbstractRandomVectorScorer(vectors) { @Override public float score(int node) throws IOException { - return similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(node)); + return similarityFunction.compare(dict1.vectorValue(ord), dict2.vectorValue(node)); } }; } diff --git a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java index bb84ba51ef8..aafa1f195c1 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java @@ -29,6 +29,20 @@ */ public abstract class ByteVectorValues extends KnnVectorValues { + /** + * A dictionary of dense byte vectors. + */ + public static abstract class Dictionary { + + /** + * Return the vector value for the given vector ordinal which must be in [0, size() - 1], + * otherwise IndexOutOfBoundsException is thrown. The returned array may be shared across calls. + * + * @return the vector value + */ + public abstract byte[] vectorValue(int ord) throws IOException; + } + /** Sole constructor */ protected ByteVectorValues() {} @@ -37,13 +51,26 @@ public ByteVectorValues copy() throws IOException { return this; } + /** Retrieve a {@link Dictionary} of vectors. */ + public Dictionary dictionary() throws IOException { + ByteVectorValues copy = copy(); + return new Dictionary() { + @Override + public byte[] vectorValue(int ord) throws IOException { + return copy.vectorValue(ord); + } + }; + } + /** * Return the vector value for the given vector ordinal which must be in [0, size() - 1], * otherwise IndexOutOfBoundsException is thrown. The returned array may be shared across calls. * * @return the vector value */ - public abstract byte[] vectorValue(int ord) throws IOException; + public byte[] vectorValue(int ord) throws IOException { + return dictionary().vectorValue(ord); + } /** * Checks the Vector Encoding of a field @@ -101,13 +128,13 @@ public int dimension() { } @Override - public byte[] vectorValue(int targetOrd) { - return vectors.get(targetOrd); - } - - @Override - public ByteVectorValues copy() { - return this; + public Dictionary dictionary() throws IOException { + return new Dictionary() { + @Override + public byte[] vectorValue(int ord) throws IOException { + return vectors.get(ord); + } + }; } @Override diff --git a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java index d35cdcea5c8..a5d867e2c3f 100644 --- a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java @@ -29,21 +29,50 @@ */ public abstract class FloatVectorValues extends KnnVectorValues { + /** + * A dictionary of dense float vectors. + */ + public static abstract class Dictionary { + + /** + * Return the vector value for the given vector ordinal which must be in [0, size() - 1], + * otherwise IndexOutOfBoundsException is thrown. The returned array may be shared across calls. + * + * @return the vector value + */ + public abstract float[] vectorValue(int ord) throws IOException; + } + /** Sole constructor */ protected FloatVectorValues() {} + @Deprecated @Override public FloatVectorValues copy() throws IOException { return this; } + /** Retrieve a {@link Dictionary} of vectors. */ + public Dictionary dictionary() throws IOException { + FloatVectorValues copy = copy(); + return new Dictionary() { + @Override + public float[] vectorValue(int ord) throws IOException { + return copy.vectorValue(ord); + } + }; + } + /** * Return the vector value for the given vector ordinal which must be in [0, size() - 1], * otherwise IndexOutOfBoundsException is thrown. The returned array may be shared across calls. * * @return the vector value */ - public abstract float[] vectorValue(int ord) throws IOException; + @Deprecated + public float[] vectorValue(int ord) throws IOException { + return dictionary().vectorValue(ord); + } /** * Checks the Vector Encoding of a field @@ -102,13 +131,13 @@ public int dimension() { } @Override - public float[] vectorValue(int targetOrd) { - return vectors.get(targetOrd); - } - - @Override - public FloatVectorValues copy() { - return this; + public Dictionary dictionary() throws IOException { + return new Dictionary() { + @Override + public float[] vectorValue(int ord) throws IOException { + return vectors.get(ord); + } + }; } @Override