diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecModel.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecModel.java index 8189de4dd6c4..68fd3b5884b0 100644 --- a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecModel.java +++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecModel.java @@ -18,10 +18,10 @@ package org.apache.lucene.analysis.synonym.word2vec; import java.io.IOException; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRefHash; import org.apache.lucene.util.TermAndVector; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; /** * Word2VecModel is a class representing the parsed Word2Vec model containing the vectors for each @@ -29,7 +29,7 @@ * * @lucene.experimental */ -public class Word2VecModel implements RandomAccessVectorValues.Floats { +public class Word2VecModel extends FloatVectorValues { private final int dictionarySize; private final int vectorDimension; diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java index 52972e9dcda4..0d7fd520a303 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java @@ -22,10 +22,10 @@ import java.util.Objects; import java.util.SplittableRandom; import java.util.concurrent.TimeUnit; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.InfoStream; import org.apache.lucene.util.hnsw.NeighborQueue; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; /** * Builder for HNSW graph. See {@link Lucene90OnHeapHnswGraph} for a gloss on the algorithm and the @@ -49,7 +49,7 @@ public final class Lucene90HnswGraphBuilder { private final Lucene90NeighborArray scratch; private final VectorSimilarityFunction similarityFunction; - private final RandomAccessVectorValues.Floats vectorValues; + private final FloatVectorValues vectorValues; private final SplittableRandom random; private final Lucene90BoundsChecker bound; final Lucene90OnHeapHnswGraph hnsw; @@ -58,7 +58,7 @@ public final class Lucene90HnswGraphBuilder { // we need two sources of vectors in order to perform diversity check comparisons without // colliding - private final RandomAccessVectorValues.Floats buildVectors; + private final FloatVectorValues buildVectors; /** * Reads all the vectors from vector values, builds a graph connecting them by their dense @@ -73,7 +73,7 @@ public final class Lucene90HnswGraphBuilder { * to ensure repeatable construction. */ public Lucene90HnswGraphBuilder( - RandomAccessVectorValues.Floats vectors, + FloatVectorValues vectors, VectorSimilarityFunction similarityFunction, int maxConn, int beamWidth, @@ -97,14 +97,14 @@ public Lucene90HnswGraphBuilder( } /** - * Reads all the vectors from two copies of a {@link RandomAccessVectorValues}. Providing two - * copies enables efficient retrieval without extra data copying, while avoiding collision of the + * Reads all the vectors from two copies of a {@link FloatVectorValues}. Providing two copies + * enables efficient retrieval without extra data copying, while avoiding collision of the * returned values. * * @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet * accessor for the vectors */ - public Lucene90OnHeapHnswGraph build(RandomAccessVectorValues.Floats vectors) throws IOException { + public Lucene90OnHeapHnswGraph build(FloatVectorValues vectors) throws IOException { if (vectors == vectorValues) { throw new IllegalArgumentException( "Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()"); @@ -230,7 +230,7 @@ private boolean diversityCheck( float[] candidate, float score, Lucene90NeighborArray neighbors, - RandomAccessVectorValues.Floats vectorValues) + FloatVectorValues vectorValues) throws IOException { bound.set(score); for (int i = 0; i < neighbors.size(); i++) { diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java index 665d31403214..1196ed3fdb64 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java @@ -20,7 +20,6 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; -import java.util.Arrays; import java.util.HashMap; import java.util.Map; import java.util.SplittableRandom; @@ -34,7 +33,6 @@ import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.ChecksumIndexInput; @@ -44,7 +42,6 @@ import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.lucene.util.hnsw.NeighborQueue; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; /** * Reads vectors from the index segments along with index data structures supporting KNN search. @@ -355,8 +352,7 @@ int size() { } /** Read the vector values from the index input. This supports both iterated and random access. */ - static class OffHeapFloatVectorValues extends FloatVectorValues - implements RandomAccessVectorValues.Floats { + static class OffHeapFloatVectorValues extends FloatVectorValues { final int dimension; final int[] ordToDoc; @@ -367,9 +363,6 @@ static class OffHeapFloatVectorValues extends FloatVectorValues final float[] value; final VectorSimilarityFunction similarityFunction; - int ord = -1; - int doc = -1; - OffHeapFloatVectorValues( int dimension, int[] ordToDoc, @@ -394,42 +387,6 @@ public int size() { return ordToDoc.length; } - @Override - public float[] vectorValue() throws IOException { - return vectorValue(ord); - } - - @Override - public int docID() { - return doc; - } - - @Override - public int nextDoc() { - if (++ord >= size()) { - doc = NO_MORE_DOCS; - } else { - doc = ordToDoc[ord]; - } - return doc; - } - - @Override - public int advance(int target) { - assert docID() < target; - ord = Arrays.binarySearch(ordToDoc, ord + 1, ordToDoc.length, target); - if (ord < 0) { - ord = -(ord + 1); - } - assert ord <= ordToDoc.length; - if (ord == ordToDoc.length) { - doc = NO_MORE_DOCS; - } else { - doc = ordToDoc[ord]; - } - return doc; - } - @Override public OffHeapFloatVectorValues copy() { return new OffHeapFloatVectorValues(dimension, ordToDoc, similarityFunction, dataIn.clone()); @@ -446,21 +403,32 @@ public float[] vectorValue(int targetOrd) throws IOException { return value; } + @Override + public int ordToDoc(int ord) { + return ordToDoc[ord]; + } + + @Override + public DocIndexIterator iterator() { + return createSparseIterator(); + } + @Override public VectorScorer scorer(float[] target) { if (size() == 0) { return null; } OffHeapFloatVectorValues values = this.copy(); + DocIndexIterator iterator = values.iterator(); return new VectorScorer() { @Override public float score() throws IOException { - return values.similarityFunction.compare(values.vectorValue(), target); + return values.similarityFunction.compare(values.vectorValue(iterator.index()), target); } @Override - public DocIdSetIterator iterator() { - return values; + public DocIndexIterator iterator() { + return iterator; } }; } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java index 52f2146e836b..845987c2957c 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java @@ -23,12 +23,12 @@ import java.util.ArrayList; import java.util.List; import java.util.SplittableRandom; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.Bits; import org.apache.lucene.util.SparseFixedBitSet; import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.lucene.util.hnsw.NeighborQueue; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; /** * An {@link HnswGraph} where all nodes and connections are held in memory. This class is used to @@ -74,7 +74,7 @@ public static NeighborQueue search( float[] query, int topK, int numSeed, - RandomAccessVectorValues.Floats vectors, + FloatVectorValues vectors, VectorSimilarityFunction similarityFunction, HnswGraph graphValues, Bits acceptOrds, diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java index 81f8d97a9a0c..a140b4fd7f39 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java @@ -46,7 +46,6 @@ import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.lucene.util.hnsw.HnswGraphSearcher; import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; /** @@ -398,8 +397,7 @@ int ordToDoc(int ord) { } /** Read the vector values from the index input. This supports both iterated and random access. */ - static class OffHeapFloatVectorValues extends FloatVectorValues - implements RandomAccessVectorValues.Floats { + static class OffHeapFloatVectorValues extends FloatVectorValues { private final int dimension; private final int size; @@ -410,9 +408,6 @@ static class OffHeapFloatVectorValues extends FloatVectorValues private final float[] value; private final VectorSimilarityFunction similarityFunction; - private int ord = -1; - private int doc = -1; - OffHeapFloatVectorValues( int dimension, int size, @@ -439,49 +434,6 @@ public int size() { return size; } - @Override - public float[] vectorValue() throws IOException { - dataIn.seek((long) ord * byteSize); - dataIn.readFloats(value, 0, value.length); - return value; - } - - @Override - public int docID() { - return doc; - } - - @Override - public int nextDoc() { - if (++ord >= size) { - doc = NO_MORE_DOCS; - } else { - doc = ordToDocOperator.applyAsInt(ord); - } - return doc; - } - - @Override - public int advance(int target) { - assert docID() < target; - - if (ordToDoc == null) { - ord = target; - } else { - ord = Arrays.binarySearch(ordToDoc, ord + 1, ordToDoc.length, target); - if (ord < 0) { - ord = -(ord + 1); - } - } - - if (ord < size) { - doc = ordToDocOperator.applyAsInt(ord); - } else { - doc = NO_MORE_DOCS; - } - return doc; - } - @Override public OffHeapFloatVectorValues copy() { return new OffHeapFloatVectorValues( @@ -495,21 +447,32 @@ public float[] vectorValue(int targetOrd) throws IOException { return value; } + @Override + public int ordToDoc(int ord) { + return ordToDocOperator.applyAsInt(ord); + } + + @Override + public DocIndexIterator iterator() { + return createSparseIterator(); + } + @Override public VectorScorer scorer(float[] target) { if (size == 0) { return null; } OffHeapFloatVectorValues values = this.copy(); + DocIndexIterator iterator = values.iterator(); return new VectorScorer() { @Override public float score() throws IOException { - return values.similarityFunction.compare(values.vectorValue(), target); + return values.similarityFunction.compare(values.vectorValue(iterator.index()), target); } @Override public DocIdSetIterator iterator() { - return values; + return iterator; } }; } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java index 19dc82cc46d5..c53594f36a4d 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java @@ -26,12 +26,10 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.RandomAccessInput; import org.apache.lucene.util.Bits; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.packed.DirectMonotonicReader; /** Read the vector values from the index input. This supports both iterated and random access. */ -abstract class OffHeapFloatVectorValues extends FloatVectorValues - implements RandomAccessVectorValues.Floats { +abstract class OffHeapFloatVectorValues extends FloatVectorValues { protected final int dimension; protected final int size; @@ -95,8 +93,6 @@ static OffHeapFloatVectorValues load( static class DenseOffHeapVectorValues extends OffHeapFloatVectorValues { - private int doc = -1; - public DenseOffHeapVectorValues( int dimension, int size, @@ -106,32 +102,13 @@ public DenseOffHeapVectorValues( } @Override - public float[] vectorValue() throws IOException { - return vectorValue(doc); - } - - @Override - public int docID() { - return doc; - } - - @Override - public int nextDoc() throws IOException { - return advance(doc + 1); - } - - @Override - public int advance(int target) throws IOException { - assert docID() < target; - if (target >= size) { - return doc = NO_MORE_DOCS; - } - return doc = target; + public DenseOffHeapVectorValues copy() throws IOException { + return new DenseOffHeapVectorValues(dimension, size, vectorSimilarityFunction, slice.clone()); } @Override - public DenseOffHeapVectorValues copy() throws IOException { - return new DenseOffHeapVectorValues(dimension, size, vectorSimilarityFunction, slice.clone()); + public DocIndexIterator iterator() { + return createDenseIterator(); } @Override @@ -142,15 +119,17 @@ public Bits getAcceptOrds(Bits acceptDocs) { @Override public VectorScorer scorer(float[] query) throws IOException { DenseOffHeapVectorValues values = this.copy(); + DocIndexIterator iterator = values.iterator(); return new VectorScorer() { @Override public float score() throws IOException { - return values.vectorSimilarityFunction.compare(values.vectorValue(), query); + return values.vectorSimilarityFunction.compare( + values.vectorValue(iterator.index()), query); } @Override public DocIdSetIterator iterator() { - return values; + return iterator; } }; } @@ -186,33 +165,17 @@ public SparseOffHeapVectorValues( fieldEntry.size()); } - @Override - public float[] vectorValue() throws IOException { - return vectorValue(disi.index()); - } - - @Override - public int docID() { - return disi.docID(); - } - - @Override - public int nextDoc() throws IOException { - return disi.nextDoc(); - } - - @Override - public int advance(int target) throws IOException { - assert docID() < target; - return disi.advance(target); - } - @Override public SparseOffHeapVectorValues copy() throws IOException { return new SparseOffHeapVectorValues( fieldEntry, dataIn, vectorSimilarityFunction, slice.clone()); } + @Override + public DocIndexIterator iterator() { + return IndexedDISI.asDocIndexIterator(disi); + } + @Override public int ordToDoc(int ord) { return (int) ordToDoc.get(ord); @@ -239,15 +202,17 @@ public int length() { @Override public VectorScorer scorer(float[] query) throws IOException { SparseOffHeapVectorValues values = this.copy(); + DocIndexIterator iterator = values.iterator(); return new VectorScorer() { @Override public float score() throws IOException { - return values.vectorSimilarityFunction.compare(values.vectorValue(), query); + return values.vectorSimilarityFunction.compare( + values.vectorValue(iterator.index()), query); } @Override public DocIdSetIterator iterator() { - return values; + return iterator; } }; } @@ -259,8 +224,6 @@ public EmptyOffHeapVectorValues(int dimension) { super(dimension, 0, VectorSimilarityFunction.COSINE, null); } - private int doc = -1; - @Override public int dimension() { return super.dimension(); @@ -271,26 +234,6 @@ public int size() { return 0; } - @Override - public float[] vectorValue() throws IOException { - throw new UnsupportedOperationException(); - } - - @Override - public int docID() { - return doc; - } - - @Override - public int nextDoc() throws IOException { - return advance(doc + 1); - } - - @Override - public int advance(int target) throws IOException { - return doc = NO_MORE_DOCS; - } - @Override public OffHeapFloatVectorValues copy() throws IOException { throw new UnsupportedOperationException(); diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java index 0c909e3839df..aedda7a6258c 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java @@ -28,12 +28,10 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.RandomAccessInput; import org.apache.lucene.util.Bits; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.packed.DirectMonotonicReader; /** Read the vector values from the index input. This supports both iterated and random access. */ -abstract class OffHeapByteVectorValues extends ByteVectorValues - implements RandomAccessVectorValues.Bytes { +abstract class OffHeapByteVectorValues extends ByteVectorValues { protected final int dimension; protected final int size; @@ -108,8 +106,6 @@ static OffHeapByteVectorValues load( static class DenseOffHeapVectorValues extends OffHeapByteVectorValues { - private int doc = -1; - public DenseOffHeapVectorValues( int dimension, int size, @@ -119,36 +115,17 @@ public DenseOffHeapVectorValues( super(dimension, size, slice, vectorSimilarityFunction, byteSize); } - @Override - public byte[] vectorValue() throws IOException { - return vectorValue(doc); - } - - @Override - public int docID() { - return doc; - } - - @Override - public int nextDoc() throws IOException { - return advance(doc + 1); - } - - @Override - public int advance(int target) throws IOException { - assert docID() < target; - if (target >= size) { - return doc = NO_MORE_DOCS; - } - return doc = target; - } - @Override public DenseOffHeapVectorValues copy() throws IOException { return new DenseOffHeapVectorValues( dimension, size, slice.clone(), vectorSimilarityFunction, byteSize); } + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + @Override public Bits getAcceptOrds(Bits acceptDocs) { return acceptDocs; @@ -157,15 +134,16 @@ public Bits getAcceptOrds(Bits acceptDocs) { @Override public VectorScorer scorer(byte[] query) throws IOException { DenseOffHeapVectorValues copy = this.copy(); + DocIndexIterator iterator = copy.iterator(); return new VectorScorer() { @Override public float score() throws IOException { - return vectorSimilarityFunction.compare(copy.vectorValue(), query); + return vectorSimilarityFunction.compare(copy.vectorValue(iterator.index()), query); } @Override public DocIdSetIterator iterator() { - return copy; + return iterator; } }; } @@ -202,27 +180,6 @@ public SparseOffHeapVectorValues( fieldEntry.size()); } - @Override - public byte[] vectorValue() throws IOException { - return vectorValue(disi.index()); - } - - @Override - public int docID() { - return disi.docID(); - } - - @Override - public int nextDoc() throws IOException { - return disi.nextDoc(); - } - - @Override - public int advance(int target) throws IOException { - assert docID() < target; - return disi.advance(target); - } - @Override public SparseOffHeapVectorValues copy() throws IOException { return new SparseOffHeapVectorValues( @@ -234,6 +191,11 @@ public int ordToDoc(int ord) { return (int) ordToDoc.get(ord); } + @Override + public DocIndexIterator iterator() { + return fromDISI(disi); + } + @Override public Bits getAcceptOrds(Bits acceptDocs) { if (acceptDocs == null) { @@ -255,15 +217,16 @@ public int length() { @Override public VectorScorer scorer(byte[] query) throws IOException { SparseOffHeapVectorValues copy = this.copy(); + IndexedDISI disi = copy.disi; return new VectorScorer() { @Override public float score() throws IOException { - return vectorSimilarityFunction.compare(copy.vectorValue(), query); + return vectorSimilarityFunction.compare(copy.vectorValue(disi.index()), query); } @Override public DocIdSetIterator iterator() { - return copy; + return disi; } }; } @@ -275,8 +238,6 @@ public EmptyOffHeapVectorValues(int dimension) { super(dimension, 0, null, VectorSimilarityFunction.COSINE, 0); } - private int doc = -1; - @Override public int dimension() { return super.dimension(); @@ -287,26 +248,6 @@ public int size() { return 0; } - @Override - public byte[] vectorValue() throws IOException { - throw new UnsupportedOperationException(); - } - - @Override - public int docID() { - return doc; - } - - @Override - public int nextDoc() throws IOException { - return advance(doc + 1); - } - - @Override - public int advance(int target) throws IOException { - return doc = NO_MORE_DOCS; - } - @Override public OffHeapByteVectorValues copy() throws IOException { throw new UnsupportedOperationException(); diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java index 91f97b8a41fa..02664837982b 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java @@ -26,12 +26,10 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.RandomAccessInput; import org.apache.lucene.util.Bits; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.packed.DirectMonotonicReader; /** Read the vector values from the index input. This supports both iterated and random access. */ -abstract class OffHeapFloatVectorValues extends FloatVectorValues - implements RandomAccessVectorValues.Floats { +abstract class OffHeapFloatVectorValues extends FloatVectorValues { protected final int dimension; protected final int size; @@ -104,8 +102,6 @@ static OffHeapFloatVectorValues load( static class DenseOffHeapVectorValues extends OffHeapFloatVectorValues { - private int doc = -1; - public DenseOffHeapVectorValues( int dimension, int size, @@ -115,36 +111,17 @@ public DenseOffHeapVectorValues( super(dimension, size, slice, vectorSimilarityFunction, byteSize); } - @Override - public float[] vectorValue() throws IOException { - return vectorValue(doc); - } - - @Override - public int docID() { - return doc; - } - - @Override - public int nextDoc() throws IOException { - return advance(doc + 1); - } - - @Override - public int advance(int target) throws IOException { - assert docID() < target; - if (target >= size) { - return doc = NO_MORE_DOCS; - } - return doc = target; - } - @Override public DenseOffHeapVectorValues copy() throws IOException { return new DenseOffHeapVectorValues( dimension, size, slice.clone(), vectorSimilarityFunction, byteSize); } + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + @Override public Bits getAcceptOrds(Bits acceptDocs) { return acceptDocs; @@ -153,15 +130,18 @@ public Bits getAcceptOrds(Bits acceptDocs) { @Override public VectorScorer scorer(float[] query) throws IOException { DenseOffHeapVectorValues values = this.copy(); + DocIndexIterator iterator = values.iterator(); + return new VectorScorer() { @Override public float score() throws IOException { - return values.vectorSimilarityFunction.compare(values.vectorValue(), query); + return values.vectorSimilarityFunction.compare( + values.vectorValue(iterator.index()), query); } @Override public DocIdSetIterator iterator() { - return values; + return iterator; } }; } @@ -198,33 +178,17 @@ public SparseOffHeapVectorValues( fieldEntry.size()); } - @Override - public float[] vectorValue() throws IOException { - return vectorValue(disi.index()); - } - - @Override - public int docID() { - return disi.docID(); - } - - @Override - public int nextDoc() throws IOException { - return disi.nextDoc(); - } - - @Override - public int advance(int target) throws IOException { - assert docID() < target; - return disi.advance(target); - } - @Override public SparseOffHeapVectorValues copy() throws IOException { return new SparseOffHeapVectorValues( fieldEntry, dataIn, slice.clone(), vectorSimilarityFunction, byteSize); } + @Override + public DocIndexIterator iterator() { + return IndexedDISI.asDocIndexIterator(disi); + } + @Override public int ordToDoc(int ord) { return (int) ordToDoc.get(ord); @@ -251,15 +215,17 @@ public int length() { @Override public VectorScorer scorer(float[] query) throws IOException { SparseOffHeapVectorValues values = this.copy(); + DocIndexIterator iterator = values.iterator(); return new VectorScorer() { @Override public float score() throws IOException { - return values.vectorSimilarityFunction.compare(values.vectorValue(), query); + return values.vectorSimilarityFunction.compare( + values.vectorValue(iterator.index()), query); } @Override public DocIdSetIterator iterator() { - return values; + return iterator; } }; } @@ -271,8 +237,6 @@ public EmptyOffHeapVectorValues(int dimension) { super(dimension, 0, null, VectorSimilarityFunction.COSINE, 0); } - private int doc = -1; - @Override public int dimension() { return super.dimension(); @@ -283,26 +247,6 @@ public int size() { return 0; } - @Override - public float[] vectorValue() throws IOException { - throw new UnsupportedOperationException(); - } - - @Override - public int docID() { - return doc; - } - - @Override - public int nextDoc() throws IOException { - return advance(doc + 1); - } - - @Override - public int advance(int target) throws IOException { - return doc = NO_MORE_DOCS; - } - @Override public OffHeapFloatVectorValues copy() throws IOException { throw new UnsupportedOperationException(); diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java index 39828524d264..f60411752d20 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java @@ -29,13 +29,13 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.IOUtils; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; /** * Writes vector values and knn graphs to index segments. @@ -188,12 +188,13 @@ private static int[] writeVectorData(IndexOutput output, FloatVectorValues vecto int count = 0; ByteBuffer binaryVector = ByteBuffer.allocate(vectors.dimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); - for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc(), count++) { + KnnVectorValues.DocIndexIterator iter = vectors.iterator(); + for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) { // write vector - float[] vectorValue = vectors.vectorValue(); + float[] vectorValue = vectors.vectorValue(iter.index()); binaryVector.asFloatBuffer().put(vectorValue); output.writeBytes(binaryVector.array(), binaryVector.limit()); - docIds[count] = docV; + docIds[count++] = docV; } if (docIds.length > count) { @@ -234,7 +235,7 @@ private void writeGraphOffsets(IndexOutput out, long[] offsets) throws IOExcepti private void writeGraph( IndexOutput graphData, - RandomAccessVectorValues.Floats vectorValues, + FloatVectorValues vectorValues, VectorSimilarityFunction similarityFunction, long graphDataOffset, long[] offsets, diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/TestLucene90HnswVectorsFormat.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/TestLucene90HnswVectorsFormat.java index 37e2745c2471..b4840c9fd5b2 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/TestLucene90HnswVectorsFormat.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/TestLucene90HnswVectorsFormat.java @@ -12,7 +12,7 @@ * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and - * limitations under the License. + * limIndexedDISIitations under the License. */ package org.apache.lucene.backward_codecs.lucene90; diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java index dbb9a71b4218..5ef85a8419c2 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java @@ -25,6 +25,7 @@ import java.util.SplittableRandom; import java.util.concurrent.TimeUnit; import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.InfoStream; @@ -32,7 +33,6 @@ import org.apache.lucene.util.hnsw.HnswGraphBuilder; import org.apache.lucene.util.hnsw.HnswGraphSearcher; import org.apache.lucene.util.hnsw.NeighborQueue; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; /** @@ -57,7 +57,7 @@ public final class Lucene91HnswGraphBuilder { private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer(); private final VectorSimilarityFunction similarityFunction; - private final RandomAccessVectorValues.Floats vectorValues; + private final FloatVectorValues vectorValues; private final SplittableRandom random; private final Lucene91BoundsChecker bound; private final HnswGraphSearcher graphSearcher; @@ -68,7 +68,7 @@ public final class Lucene91HnswGraphBuilder { // we need two sources of vectors in order to perform diversity check comparisons without // colliding - private RandomAccessVectorValues.Floats buildVectors; + private FloatVectorValues buildVectors; /** * Reads all the vectors from vector values, builds a graph connecting them by their dense @@ -83,7 +83,7 @@ public final class Lucene91HnswGraphBuilder { * to ensure repeatable construction. */ public Lucene91HnswGraphBuilder( - RandomAccessVectorValues.Floats vectors, + FloatVectorValues vectors, VectorSimilarityFunction similarityFunction, int maxConn, int beamWidth, @@ -113,14 +113,14 @@ public Lucene91HnswGraphBuilder( } /** - * Reads all the vectors from two copies of a {@link RandomAccessVectorValues}. Providing two - * copies enables efficient retrieval without extra data copying, while avoiding collision of the + * Reads all the vectors from two copies of a {@link FloatVectorValues}. Providing two copies + * enables efficient retrieval without extra data copying, while avoiding collision of the * returned values. * - * @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet + * @param vectors the vectors for which to build a nearest neighbors graph. Must be an independent * accessor for the vectors */ - public Lucene91OnHeapHnswGraph build(RandomAccessVectorValues.Floats vectors) throws IOException { + public Lucene91OnHeapHnswGraph build(FloatVectorValues vectors) throws IOException { if (vectors == vectorValues) { throw new IllegalArgumentException( "Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()"); @@ -254,7 +254,7 @@ private boolean diversityCheck( float[] candidate, float score, Lucene91NeighborArray neighbors, - RandomAccessVectorValues.Floats vectorValues) + FloatVectorValues vectorValues) throws IOException { bound.set(score); for (int i = 0; i < neighbors.size(); i++) { diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java index 37b752503817..a984a3ef1f8b 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java @@ -17,8 +17,6 @@ package org.apache.lucene.backward_codecs.lucene91; -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; - import java.io.IOException; import java.nio.ByteBuffer; import java.nio.ByteOrder; @@ -30,6 +28,7 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.DocIdSetIterator; @@ -37,7 +36,6 @@ import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.hnsw.HnswGraph; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; /** * Writes vector values and knn graphs to index segments. @@ -183,9 +181,10 @@ private static DocsWithFieldSet writeVectorData(IndexOutput output, FloatVectorV DocsWithFieldSet docsWithField = new DocsWithFieldSet(); ByteBuffer binaryVector = ByteBuffer.allocate(vectors.dimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); - for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc()) { + KnnVectorValues.DocIndexIterator iter = vectors.iterator(); + for (int docV = iter.nextDoc(); docV != DocIdSetIterator.NO_MORE_DOCS; docV = iter.nextDoc()) { // write vector - float[] vectorValue = vectors.vectorValue(); + float[] vectorValue = vectors.vectorValue(iter.index()); binaryVector.asFloatBuffer().put(vectorValue); output.writeBytes(binaryVector.array(), binaryVector.limit()); docsWithField.add(docV); @@ -243,7 +242,7 @@ private void writeMeta( } private Lucene91OnHeapHnswGraph writeGraph( - RandomAccessVectorValues.Floats vectorValues, VectorSimilarityFunction similarityFunction) + FloatVectorValues vectorValues, VectorSimilarityFunction similarityFunction) throws IOException { // build graph diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java index caa8fc3da149..bf1c89a536d8 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java @@ -18,7 +18,6 @@ package org.apache.lucene.backward_codecs.lucene92; import static org.apache.lucene.backward_codecs.lucene92.Lucene92RWHnswVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT; -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; import java.nio.ByteBuffer; @@ -33,6 +32,7 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.DocIdSetIterator; @@ -43,7 +43,6 @@ import org.apache.lucene.util.hnsw.HnswGraphBuilder; import org.apache.lucene.util.hnsw.NeighborArray; import org.apache.lucene.util.hnsw.OnHeapHnswGraph; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.packed.DirectMonotonicWriter; @@ -190,9 +189,12 @@ private static DocsWithFieldSet writeVectorData(IndexOutput output, FloatVectorV DocsWithFieldSet docsWithField = new DocsWithFieldSet(); ByteBuffer binaryVector = ByteBuffer.allocate(vectors.dimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); - for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc()) { + KnnVectorValues.DocIndexIterator iterator = vectors.iterator(); + for (int docV = iterator.nextDoc(); + docV != DocIdSetIterator.NO_MORE_DOCS; + docV = iterator.nextDoc()) { // write vector - float[] vectorValue = vectors.vectorValue(); + float[] vectorValue = vectors.vectorValue(iterator.index()); binaryVector.asFloatBuffer().put(vectorValue); output.writeBytes(binaryVector.array(), binaryVector.limit()); docsWithField.add(docV); @@ -277,7 +279,7 @@ private void writeMeta( } private OnHeapHnswGraph writeGraph( - RandomAccessVectorValues.Floats vectorValues, VectorSimilarityFunction similarityFunction) + FloatVectorValues vectorValues, VectorSimilarityFunction similarityFunction) throws IOException { DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer(); // build graph diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java index 1cb445cab776..01698da79893 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java @@ -36,6 +36,7 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; @@ -52,7 +53,6 @@ import org.apache.lucene.util.hnsw.HnswGraphBuilder; import org.apache.lucene.util.hnsw.NeighborArray; import org.apache.lucene.util.hnsw.OnHeapHnswGraph; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.packed.DirectMonotonicWriter; @@ -216,9 +216,7 @@ private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocM final int[] docIdOffsets = new int[sortMap.size()]; int offset = 1; // 0 means no vector for this (field, document) DocIdSetIterator iterator = fieldData.docsWithField.iterator(); - for (int docID = iterator.nextDoc(); - docID != DocIdSetIterator.NO_MORE_DOCS; - docID = iterator.nextDoc()) { + for (int docID = iterator.nextDoc(); docID != NO_MORE_DOCS; docID = iterator.nextDoc()) { int newDocID = sortMap.oldToNew(docID); docIdOffsets[newDocID] = offset++; } @@ -556,9 +554,7 @@ private void writeMeta( final DirectMonotonicWriter ordToDocWriter = DirectMonotonicWriter.getInstance(meta, vectorData, count, DIRECT_MONOTONIC_BLOCK_SHIFT); DocIdSetIterator iterator = docsWithField.iterator(); - for (int doc = iterator.nextDoc(); - doc != DocIdSetIterator.NO_MORE_DOCS; - doc = iterator.nextDoc()) { + for (int doc = iterator.nextDoc(); doc != NO_MORE_DOCS; doc = iterator.nextDoc()) { ordToDocWriter.add(doc); } ordToDocWriter.finish(); @@ -590,11 +586,10 @@ private void writeMeta( private static DocsWithFieldSet writeByteVectorData( IndexOutput output, ByteVectorValues byteVectorValues) throws IOException { DocsWithFieldSet docsWithField = new DocsWithFieldSet(); - for (int docV = byteVectorValues.nextDoc(); - docV != NO_MORE_DOCS; - docV = byteVectorValues.nextDoc()) { + KnnVectorValues.DocIndexIterator iter = byteVectorValues.iterator(); + for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) { // write vector - byte[] binaryValue = byteVectorValues.vectorValue(); + byte[] binaryValue = byteVectorValues.vectorValue(iter.index()); assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize; output.writeBytes(binaryValue, binaryValue.length); docsWithField.add(docV); @@ -608,14 +603,13 @@ private static DocsWithFieldSet writeByteVectorData( private static DocsWithFieldSet writeVectorData( IndexOutput output, FloatVectorValues floatVectorValues) throws IOException { DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + KnnVectorValues.DocIndexIterator iter = floatVectorValues.iterator(); ByteBuffer binaryVector = ByteBuffer.allocate(floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize) .order(ByteOrder.LITTLE_ENDIAN); - for (int docV = floatVectorValues.nextDoc(); - docV != NO_MORE_DOCS; - docV = floatVectorValues.nextDoc()) { + for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) { // write vector - float[] vectorValue = floatVectorValues.vectorValue(); + float[] vectorValue = floatVectorValues.vectorValue(iter.index()); binaryVector.asFloatBuffer().put(vectorValue); output.writeBytes(binaryVector.array(), binaryVector.limit()); docsWithField.add(docV); @@ -672,11 +666,11 @@ public float[] copyValue(float[] value) { case BYTE -> defaultFlatVectorScorer.getRandomVectorScorerSupplier( fieldInfo.getVectorSimilarityFunction(), - RandomAccessVectorValues.fromBytes((List) vectors, dim)); + ByteVectorValues.fromBytes((List) vectors, dim)); case FLOAT32 -> defaultFlatVectorScorer.getRandomVectorScorerSupplier( fieldInfo.getVectorSimilarityFunction(), - RandomAccessVectorValues.fromFloats((List) vectors, dim)); + FloatVectorValues.fromFloats((List) vectors, dim)); }; hnswGraphBuilder = HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed); diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java index 37c39d311d6b..c855d8f5e073 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java @@ -39,6 +39,7 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; @@ -56,7 +57,6 @@ import org.apache.lucene.util.hnsw.IncrementalHnswGraphMerger; import org.apache.lucene.util.hnsw.NeighborArray; import org.apache.lucene.util.hnsw.OnHeapHnswGraph; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.packed.DirectMonotonicWriter; @@ -221,9 +221,7 @@ private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocM final int[] docIdOffsets = new int[sortMap.size()]; int offset = 1; // 0 means no vector for this (field, document) DocIdSetIterator iterator = fieldData.docsWithField.iterator(); - for (int docID = iterator.nextDoc(); - docID != DocIdSetIterator.NO_MORE_DOCS; - docID = iterator.nextDoc()) { + for (int docID = iterator.nextDoc(); docID != NO_MORE_DOCS; docID = iterator.nextDoc()) { int newDocID = sortMap.oldToNew(docID); docIdOffsets[newDocID] = offset++; } @@ -482,18 +480,18 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE mergeState.knnVectorsReaders[i], mergeState.docMaps[i], mergeState.liveDocs[i]); } } - DocIdSetIterator mergedVectorIterator = null; + KnnVectorValues mergedVectorValues = null; switch (fieldInfo.getVectorEncoding()) { case BYTE -> - mergedVectorIterator = + mergedVectorValues = KnnVectorsWriter.MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState); case FLOAT32 -> - mergedVectorIterator = + mergedVectorValues = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); } graph = merger.merge( - mergedVectorIterator, segmentWriteState.infoStream, docsWithField.cardinality()); + mergedVectorValues, segmentWriteState.infoStream, docsWithField.cardinality()); vectorIndexNodeOffsets = writeGraph(graph); } long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset; @@ -636,14 +634,13 @@ private void writeMeta( private static DocsWithFieldSet writeByteVectorData( IndexOutput output, ByteVectorValues byteVectorValues) throws IOException { DocsWithFieldSet docsWithField = new DocsWithFieldSet(); - for (int docV = byteVectorValues.nextDoc(); - docV != NO_MORE_DOCS; - docV = byteVectorValues.nextDoc()) { + KnnVectorValues.DocIndexIterator iter = byteVectorValues.iterator(); + for (int docId = iter.nextDoc(); docId != NO_MORE_DOCS; docId = iter.nextDoc()) { // write vector - byte[] binaryValue = byteVectorValues.vectorValue(); + byte[] binaryValue = byteVectorValues.vectorValue(iter.index()); assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize; output.writeBytes(binaryValue, binaryValue.length); - docsWithField.add(docV); + docsWithField.add(docId); } return docsWithField; } @@ -657,11 +654,10 @@ private static DocsWithFieldSet writeVectorData( ByteBuffer buffer = ByteBuffer.allocate(floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize) .order(ByteOrder.LITTLE_ENDIAN); - for (int docV = floatVectorValues.nextDoc(); - docV != NO_MORE_DOCS; - docV = floatVectorValues.nextDoc()) { + KnnVectorValues.DocIndexIterator iter = floatVectorValues.iterator(); + for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) { // write vector - float[] value = floatVectorValues.vectorValue(); + float[] value = floatVectorValues.vectorValue(iter.index()); buffer.asFloatBuffer().put(value); output.writeBytes(buffer.array(), buffer.limit()); docsWithField.add(docV); @@ -718,11 +714,11 @@ public float[] copyValue(float[] value) { case BYTE -> defaultFlatVectorScorer.getRandomVectorScorerSupplier( fieldInfo.getVectorSimilarityFunction(), - RandomAccessVectorValues.fromBytes((List) vectors, dim)); + ByteVectorValues.fromBytes((List) vectors, dim)); case FLOAT32 -> defaultFlatVectorScorer.getRandomVectorScorerSupplier( fieldInfo.getVectorSimilarityFunction(), - RandomAccessVectorValues.fromFloats((List) vectors, dim)); + FloatVectorValues.fromFloats((List) vectors, dim)); }; hnswGraphBuilder = HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed); diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestBasicBackwardsCompatibility.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestBasicBackwardsCompatibility.java index 8d35a1128be9..cf50b9e1526d 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestBasicBackwardsCompatibility.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestBasicBackwardsCompatibility.java @@ -52,6 +52,7 @@ import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.IndexableField; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LogByteSizeMergePolicy; import org.apache.lucene.index.MultiBits; @@ -477,10 +478,14 @@ public static void searchIndex( FloatVectorValues values = ctx.reader().getFloatVectorValues(KNN_VECTOR_FIELD); if (values != null) { assertEquals(KNN_VECTOR_FIELD_TYPE.vectorDimension(), values.dimension()); - for (int doc = values.nextDoc(); doc != NO_MORE_DOCS; doc = values.nextDoc()) { + KnnVectorValues.DocIndexIterator it = values.iterator(); + for (int doc = it.nextDoc(); doc != NO_MORE_DOCS; doc = it.nextDoc()) { float[] expectedVector = {KNN_VECTOR[0], KNN_VECTOR[1], KNN_VECTOR[2] + 0.1f * cnt}; assertArrayEquals( - "vectors do not match for doc=" + cnt, expectedVector, values.vectorValue(), 0); + "vectors do not match for doc=" + cnt, + expectedVector, + values.vectorValue(it.index()), + 0); cnt++; } } diff --git a/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorScorerBenchmark.java b/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorScorerBenchmark.java index c4d3040f2835..0a4da1f48867 100644 --- a/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorScorerBenchmark.java +++ b/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorScorerBenchmark.java @@ -25,6 +25,7 @@ import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.Directory; import org.apache.lucene.store.IOContext; @@ -32,7 +33,6 @@ import org.apache.lucene.store.IndexOutput; import org.apache.lucene.store.MMapDirectory; import org.apache.lucene.util.IOUtils; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.openjdk.jmh.annotations.*; @@ -55,7 +55,7 @@ public class VectorScorerBenchmark { Directory dir; IndexInput in; - RandomAccessVectorValues vectorValues; + KnnVectorValues vectorValues; byte[] vec1, vec2; RandomVectorScorer scorer; @@ -95,7 +95,7 @@ public float binaryDotProductMemSeg() throws IOException { return scorer.score(1); } - static RandomAccessVectorValues vectorValues( + static KnnVectorValues vectorValues( int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException { return new OffHeapByteVectorValues.DenseOffHeapVectorValues( dims, size, in.slice("test", 0, in.length()), dims, new ThrowingFlatVectorScorer(), sim); @@ -105,23 +105,19 @@ static final class ThrowingFlatVectorScorer implements FlatVectorsScorer { @Override public RandomVectorScorerSupplier getRandomVectorScorerSupplier( - VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues) { + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues) { throw new UnsupportedOperationException(); } @Override public RandomVectorScorer getRandomVectorScorer( - VectorSimilarityFunction similarityFunction, - RandomAccessVectorValues vectorValues, - float[] target) { + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target) { throw new UnsupportedOperationException(); } @Override public RandomVectorScorer getRandomVectorScorer( - VectorSimilarityFunction similarityFunction, - RandomAccessVectorValues vectorValues, - byte[] target) { + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target) { throw new UnsupportedOperationException(); } } diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/bitvectors/FlatBitVectorsScorer.java b/lucene/codecs/src/java/org/apache/lucene/codecs/bitvectors/FlatBitVectorsScorer.java index b8ff37c2654a..8ffcc1c8d50e 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/bitvectors/FlatBitVectorsScorer.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/bitvectors/FlatBitVectorsScorer.java @@ -19,10 +19,11 @@ import java.io.IOException; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.Bits; import org.apache.lucene.util.VectorUtil; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; @@ -30,45 +31,39 @@ public class FlatBitVectorsScorer implements FlatVectorsScorer { @Override public RandomVectorScorerSupplier getRandomVectorScorerSupplier( - VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues) + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues) throws IOException { - assert vectorValues instanceof RandomAccessVectorValues.Bytes; - if (vectorValues instanceof RandomAccessVectorValues.Bytes byteVectorValues) { + assert vectorValues instanceof ByteVectorValues; + if (vectorValues instanceof ByteVectorValues byteVectorValues) { return new BitRandomVectorScorerSupplier(byteVectorValues); } - throw new IllegalArgumentException( - "vectorValues must be an instance of RandomAccessVectorValues.Bytes"); + throw new IllegalArgumentException("vectorValues must be an instance of ByteVectorValues"); } @Override public RandomVectorScorer getRandomVectorScorer( - VectorSimilarityFunction similarityFunction, - RandomAccessVectorValues vectorValues, - float[] target) + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target) throws IOException { throw new IllegalArgumentException("bit vectors do not support float[] targets"); } @Override public RandomVectorScorer getRandomVectorScorer( - VectorSimilarityFunction similarityFunction, - RandomAccessVectorValues vectorValues, - byte[] target) + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target) throws IOException { - assert vectorValues instanceof RandomAccessVectorValues.Bytes; - if (vectorValues instanceof RandomAccessVectorValues.Bytes byteVectorValues) { + assert vectorValues instanceof ByteVectorValues; + if (vectorValues instanceof ByteVectorValues byteVectorValues) { return new BitRandomVectorScorer(byteVectorValues, target); } - throw new IllegalArgumentException( - "vectorValues must be an instance of RandomAccessVectorValues.Bytes"); + throw new IllegalArgumentException("vectorValues must be an instance of ByteVectorValues"); } static class BitRandomVectorScorer implements RandomVectorScorer { - private final RandomAccessVectorValues.Bytes vectorValues; + private final ByteVectorValues vectorValues; private final int bitDimensions; private final byte[] query; - BitRandomVectorScorer(RandomAccessVectorValues.Bytes vectorValues, byte[] query) { + BitRandomVectorScorer(ByteVectorValues vectorValues, byte[] query) { this.query = query; this.bitDimensions = vectorValues.dimension() * Byte.SIZE; this.vectorValues = vectorValues; @@ -97,12 +92,11 @@ public Bits getAcceptOrds(Bits acceptDocs) { } static class BitRandomVectorScorerSupplier implements RandomVectorScorerSupplier { - protected final RandomAccessVectorValues.Bytes vectorValues; - protected final RandomAccessVectorValues.Bytes vectorValues1; - protected final RandomAccessVectorValues.Bytes vectorValues2; + protected final ByteVectorValues vectorValues; + protected final ByteVectorValues vectorValues1; + protected final ByteVectorValues vectorValues2; - public BitRandomVectorScorerSupplier(RandomAccessVectorValues.Bytes vectorValues) - throws IOException { + public BitRandomVectorScorerSupplier(ByteVectorValues vectorValues) throws IOException { this.vectorValues = vectorValues; this.vectorValues1 = vectorValues.copy(); this.vectorValues2 = vectorValues.copy(); diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java index faba629715b7..97a518701b00 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java @@ -192,8 +192,8 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits } FieldInfo info = readState.fieldInfos.fieldInfo(field); VectorSimilarityFunction vectorSimilarity = info.getVectorSimilarityFunction(); - int doc; - while ((doc = values.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { + for (int ord = 0; ord < values.size(); ord++) { + int doc = values.ordToDoc(ord); if (acceptDocs != null && acceptDocs.get(doc) == false) { continue; } @@ -202,7 +202,7 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits break; } - float[] vector = values.vectorValue(); + float[] vector = values.vectorValue(ord); float score = vectorSimilarity.compare(vector, target); knnCollector.collect(doc, score); knnCollector.incVisitedCount(1); @@ -223,8 +223,8 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits FieldInfo info = readState.fieldInfos.fieldInfo(field); VectorSimilarityFunction vectorSimilarity = info.getVectorSimilarityFunction(); - int doc; - while ((doc = values.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { + for (int ord = 0; ord < values.size(); ord++) { + int doc = values.ordToDoc(ord); if (acceptDocs != null && acceptDocs.get(doc) == false) { continue; } @@ -233,7 +233,7 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits break; } - byte[] vector = values.vectorValue(); + byte[] vector = values.vectorValue(ord); float score = vectorSimilarity.compare(vector, target); knnCollector.collect(doc, score); knnCollector.incVisitedCount(1); @@ -327,35 +327,18 @@ public int size() { } @Override - public float[] vectorValue() { - return values[curOrd]; + public float[] vectorValue(int ord) { + return values[ord]; } @Override - public int docID() { - if (curOrd == -1) { - return -1; - } else if (curOrd >= entry.size()) { - // when call to advance / nextDoc below already returns NO_MORE_DOCS, calling docID - // immediately afterward should also return NO_MORE_DOCS - // this is needed for TestSimpleTextKnnVectorsFormat.testAdvance test case - return NO_MORE_DOCS; - } - - return entry.ordToDoc[curOrd]; - } - - @Override - public int nextDoc() throws IOException { - if (++curOrd < entry.size()) { - return docID(); - } - return NO_MORE_DOCS; + public int ordToDoc(int ord) { + return entry.ordToDoc[ord]; } @Override - public int advance(int target) throws IOException { - return slowAdvance(target); + public DocIndexIterator iterator() { + return createSparseIterator(); } @Override @@ -365,17 +348,19 @@ public VectorScorer scorer(float[] target) { } SimpleTextFloatVectorValues simpleTextFloatVectorValues = new SimpleTextFloatVectorValues(this); + DocIndexIterator iterator = simpleTextFloatVectorValues.iterator(); return new VectorScorer() { @Override public float score() throws IOException { + int ord = iterator.index(); return entry .similarityFunction() - .compare(simpleTextFloatVectorValues.vectorValue(), target); + .compare(simpleTextFloatVectorValues.vectorValue(ord), target); } @Override public DocIdSetIterator iterator() { - return simpleTextFloatVectorValues; + return iterator; } }; } @@ -397,6 +382,11 @@ private void readVector(float[] value) throws IOException { value[i] = Float.parseFloat(floatStrings[i]); } } + + @Override + public SimpleTextFloatVectorValues copy() { + return this; + } } private static class SimpleTextByteVectorValues extends ByteVectorValues { @@ -439,36 +429,14 @@ public int size() { } @Override - public byte[] vectorValue() { - binaryValue.bytes = values[curOrd]; + public byte[] vectorValue(int ord) { + binaryValue.bytes = values[ord]; return binaryValue.bytes; } @Override - public int docID() { - if (curOrd == -1) { - return -1; - } else if (curOrd >= entry.size()) { - // when call to advance / nextDoc below already returns NO_MORE_DOCS, calling docID - // immediately afterward should also return NO_MORE_DOCS - // this is needed for TestSimpleTextKnnVectorsFormat.testAdvance test case - return NO_MORE_DOCS; - } - - return entry.ordToDoc[curOrd]; - } - - @Override - public int nextDoc() throws IOException { - if (++curOrd < entry.size()) { - return docID(); - } - return NO_MORE_DOCS; - } - - @Override - public int advance(int target) throws IOException { - return slowAdvance(target); + public DocIndexIterator iterator() { + return createSparseIterator(); } @Override @@ -478,16 +446,19 @@ public VectorScorer scorer(byte[] target) { } SimpleTextByteVectorValues simpleTextByteVectorValues = new SimpleTextByteVectorValues(this); return new VectorScorer() { + DocIndexIterator it = simpleTextByteVectorValues.iterator(); + @Override public float score() throws IOException { + int ord = it.index(); return entry .similarityFunction() - .compare(simpleTextByteVectorValues.vectorValue(), target); + .compare(simpleTextByteVectorValues.vectorValue(ord), target); } @Override public DocIdSetIterator iterator() { - return simpleTextByteVectorValues; + return it; } }; } @@ -509,6 +480,11 @@ private void readVector(byte[] value) throws IOException { value[i] = (byte) Float.parseFloat(floatStrings[i]); } } + + @Override + public SimpleTextByteVectorValues copy() { + return this; + } } private int readInt(IndexInput in, BytesRef field) throws IOException { diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsWriter.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsWriter.java index a7a76ac1bb98..eaf4b657755c 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsWriter.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsWriter.java @@ -28,6 +28,7 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.BytesRef; @@ -77,19 +78,18 @@ public void writeField(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, throws IOException { long vectorDataOffset = vectorData.getFilePointer(); List docIds = new ArrayList<>(); - for (int docV = floatVectorValues.nextDoc(); - docV != NO_MORE_DOCS; - docV = floatVectorValues.nextDoc()) { - writeFloatVectorValue(floatVectorValues); - docIds.add(docV); + KnnVectorValues.DocIndexIterator iter = floatVectorValues.iterator(); + for (int docId = iter.nextDoc(); docId != NO_MORE_DOCS; docId = iter.nextDoc()) { + writeFloatVectorValue(floatVectorValues, iter.index()); + docIds.add(docId); } long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; writeMeta(fieldInfo, vectorDataOffset, vectorDataLength, docIds); } - private void writeFloatVectorValue(FloatVectorValues vectors) throws IOException { + private void writeFloatVectorValue(FloatVectorValues vectors, int ord) throws IOException { // write vector value - float[] value = vectors.vectorValue(); + float[] value = vectors.vectorValue(ord); assert value.length == vectors.dimension(); write(vectorData, Arrays.toString(value)); newline(vectorData); @@ -100,19 +100,18 @@ public void writeField(FieldInfo fieldInfo, ByteVectorValues byteVectorValues, i throws IOException { long vectorDataOffset = vectorData.getFilePointer(); List docIds = new ArrayList<>(); - for (int docV = byteVectorValues.nextDoc(); - docV != NO_MORE_DOCS; - docV = byteVectorValues.nextDoc()) { - writeByteVectorValue(byteVectorValues); + KnnVectorValues.DocIndexIterator it = byteVectorValues.iterator(); + for (int docV = it.nextDoc(); docV != NO_MORE_DOCS; docV = it.nextDoc()) { + writeByteVectorValue(byteVectorValues, it.index()); docIds.add(docV); } long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; writeMeta(fieldInfo, vectorDataOffset, vectorDataLength, docIds); } - private void writeByteVectorValue(ByteVectorValues vectors) throws IOException { + private void writeByteVectorValue(ByteVectorValues vectors, int ord) throws IOException { // write vector value - byte[] value = vectors.vectorValue(); + byte[] value = vectors.vectorValue(ord); assert value.length == vectors.dimension(); write(vectorData, Arrays.toString(value)); newline(vectorData); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java index 8a9b4816571e..96b0f75a259f 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java @@ -20,14 +20,16 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.function.Supplier; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.Sorter; -import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.index.SortingCodecReader; +import org.apache.lucene.index.SortingCodecReader.SortingValuesIterator; +import org.apache.lucene.search.DocIdSet; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.RamUsageEstimator; @@ -80,24 +82,26 @@ public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { case FLOAT32: BufferedFloatVectorValues bufferedFloatVectorValues = new BufferedFloatVectorValues( - fieldData.docsWithField, (List) fieldData.vectors, - fieldData.fieldInfo.getVectorDimension()); + fieldData.fieldInfo.getVectorDimension(), + fieldData.docsWithField); FloatVectorValues floatVectorValues = sortMap != null - ? new SortingFloatVectorValues(bufferedFloatVectorValues, sortMap) + ? new SortingFloatVectorValues( + bufferedFloatVectorValues, fieldData.docsWithField, sortMap) : bufferedFloatVectorValues; writeField(fieldData.fieldInfo, floatVectorValues, maxDoc); break; case BYTE: BufferedByteVectorValues bufferedByteVectorValues = new BufferedByteVectorValues( - fieldData.docsWithField, (List) fieldData.vectors, - fieldData.fieldInfo.getVectorDimension()); + fieldData.fieldInfo.getVectorDimension(), + fieldData.docsWithField); ByteVectorValues byteVectorValues = sortMap != null - ? new SortingByteVectorValues(bufferedByteVectorValues, sortMap) + ? new SortingByteVectorValues( + bufferedByteVectorValues, fieldData.docsWithField, sortMap) : bufferedByteVectorValues; writeField(fieldData.fieldInfo, byteVectorValues, maxDoc); break; @@ -107,125 +111,77 @@ public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { /** Sorting FloatVectorValues that iterate over documents in the order of the provided sortMap */ private static class SortingFloatVectorValues extends FloatVectorValues { - private final BufferedFloatVectorValues randomAccess; - private final int[] docIdOffsets; - private int docId = -1; + private final BufferedFloatVectorValues delegate; + private final Supplier iteratorSupplier; - SortingFloatVectorValues(BufferedFloatVectorValues delegate, Sorter.DocMap sortMap) + SortingFloatVectorValues( + BufferedFloatVectorValues delegate, DocsWithFieldSet docsWithField, Sorter.DocMap sortMap) throws IOException { - this.randomAccess = delegate.copy(); - this.docIdOffsets = new int[sortMap.size()]; - - int offset = 1; // 0 means no vector for this (field, document) - int docID; - while ((docID = delegate.nextDoc()) != NO_MORE_DOCS) { - int newDocID = sortMap.oldToNew(docID); - docIdOffsets[newDocID] = offset++; - } - } - - @Override - public int docID() { - return docId; - } - - @Override - public int nextDoc() throws IOException { - while (docId < docIdOffsets.length - 1) { - ++docId; - if (docIdOffsets[docId] != 0) { - return docId; - } - } - docId = NO_MORE_DOCS; - return docId; + this.delegate = delegate.copy(); + iteratorSupplier = SortingCodecReader.iteratorSupplier(delegate, sortMap); } @Override - public float[] vectorValue() throws IOException { - return randomAccess.vectorValue(docIdOffsets[docId] - 1); + public float[] vectorValue(int ord) throws IOException { + return delegate.vectorValue(ord); } @Override public int dimension() { - return randomAccess.dimension(); + return delegate.dimension(); } @Override public int size() { - return randomAccess.size(); + return delegate.size(); } @Override - public int advance(int target) throws IOException { + public SortingFloatVectorValues copy() { throw new UnsupportedOperationException(); } @Override - public VectorScorer scorer(float[] target) { - throw new UnsupportedOperationException(); + public DocIndexIterator iterator() { + return iteratorSupplier.get(); } } - /** Sorting FloatVectorValues that iterate over documents in the order of the provided sortMap */ + /** Sorting ByteVectorValues that iterate over documents in the order of the provided sortMap */ private static class SortingByteVectorValues extends ByteVectorValues { - private final BufferedByteVectorValues randomAccess; - private final int[] docIdOffsets; - private int docId = -1; + private final BufferedByteVectorValues delegate; + private final Supplier iteratorSupplier; - SortingByteVectorValues(BufferedByteVectorValues delegate, Sorter.DocMap sortMap) + SortingByteVectorValues( + BufferedByteVectorValues delegate, DocsWithFieldSet docsWithField, Sorter.DocMap sortMap) throws IOException { - this.randomAccess = delegate.copy(); - this.docIdOffsets = new int[sortMap.size()]; - - int offset = 1; // 0 means no vector for this (field, document) - int docID; - while ((docID = delegate.nextDoc()) != NO_MORE_DOCS) { - int newDocID = sortMap.oldToNew(docID); - docIdOffsets[newDocID] = offset++; - } - } - - @Override - public int docID() { - return docId; + this.delegate = delegate; + iteratorSupplier = SortingCodecReader.iteratorSupplier(delegate, sortMap); } @Override - public int nextDoc() throws IOException { - while (docId < docIdOffsets.length - 1) { - ++docId; - if (docIdOffsets[docId] != 0) { - return docId; - } - } - docId = NO_MORE_DOCS; - return docId; - } - - @Override - public byte[] vectorValue() throws IOException { - return randomAccess.vectorValue(docIdOffsets[docId] - 1); + public byte[] vectorValue(int ord) throws IOException { + return delegate.vectorValue(ord); } @Override public int dimension() { - return randomAccess.dimension(); + return delegate.dimension(); } @Override public int size() { - return randomAccess.size(); + return delegate.size(); } @Override - public int advance(int target) throws IOException { + public SortingByteVectorValues copy() { throw new UnsupportedOperationException(); } @Override - public VectorScorer scorer(byte[] target) { - throw new UnsupportedOperationException(); + public DocIndexIterator iterator() { + return iteratorSupplier.get(); } } @@ -296,7 +252,9 @@ public final void addValue(int docID, T value) { @Override public final long ramBytesUsed() { - if (vectors.size() == 0) return 0; + if (vectors.isEmpty()) { + return 0; + } return docsWithField.ramBytesUsed() + vectors.size() * (long) @@ -307,25 +265,18 @@ public final long ramBytesUsed() { } private static class BufferedFloatVectorValues extends FloatVectorValues { - final DocsWithFieldSet docsWithField; - // These are always the vectors of a VectorValuesWriter, which are copied when added to it final List vectors; final int dimension; + private final DocIdSet docsWithField; + private final DocIndexIterator iterator; - DocIdSetIterator docsWithFieldIter; - int ord = -1; - - BufferedFloatVectorValues( - DocsWithFieldSet docsWithField, List vectors, int dimension) { - this.docsWithField = docsWithField; + BufferedFloatVectorValues(List vectors, int dimension, DocIdSet docsWithField) + throws IOException { this.vectors = vectors; this.dimension = dimension; - docsWithFieldIter = docsWithField.iterator(); - } - - public BufferedFloatVectorValues copy() { - return new BufferedFloatVectorValues(docsWithField, vectors, dimension); + this.docsWithField = docsWithField; + this.iterator = fromDISI(docsWithField.iterator()); } @Override @@ -339,58 +290,39 @@ public int size() { } @Override - public float[] vectorValue() { - return vectors.get(ord); - } - - float[] vectorValue(int targetOrd) { - return vectors.get(targetOrd); + public int ordToDoc(int ord) { + return ord; } @Override - public int docID() { - return docsWithFieldIter.docID(); - } - - @Override - public int nextDoc() throws IOException { - int docID = docsWithFieldIter.nextDoc(); - if (docID != NO_MORE_DOCS) { - ++ord; - } - return docID; + public float[] vectorValue(int targetOrd) { + return vectors.get(targetOrd); } @Override - public int advance(int target) { - throw new UnsupportedOperationException(); + public DocIndexIterator iterator() { + return iterator; } @Override - public VectorScorer scorer(float[] target) { - throw new UnsupportedOperationException(); + public BufferedFloatVectorValues copy() throws IOException { + return new BufferedFloatVectorValues(vectors, dimension, docsWithField); } } private static class BufferedByteVectorValues extends ByteVectorValues { - final DocsWithFieldSet docsWithField; - // These are always the vectors of a VectorValuesWriter, which are copied when added to it final List vectors; final int dimension; + private final DocIdSet docsWithField; + private final DocIndexIterator iterator; - DocIdSetIterator docsWithFieldIter; - int ord = -1; - - BufferedByteVectorValues(DocsWithFieldSet docsWithField, List vectors, int dimension) { - this.docsWithField = docsWithField; + BufferedByteVectorValues(List vectors, int dimension, DocIdSet docsWithField) + throws IOException { this.vectors = vectors; this.dimension = dimension; - docsWithFieldIter = docsWithField.iterator(); - } - - public BufferedByteVectorValues copy() { - return new BufferedByteVectorValues(docsWithField, vectors, dimension); + this.docsWithField = docsWithField; + iterator = fromDISI(docsWithField.iterator()); } @Override @@ -404,36 +336,18 @@ public int size() { } @Override - public byte[] vectorValue() { - return vectors.get(ord); - } - - byte[] vectorValue(int targetOrd) { + public byte[] vectorValue(int targetOrd) { return vectors.get(targetOrd); } @Override - public int docID() { - return docsWithFieldIter.docID(); + public DocIndexIterator iterator() { + return iterator; } @Override - public int nextDoc() throws IOException { - int docID = docsWithFieldIter.nextDoc(); - if (docID != NO_MORE_DOCS) { - ++ord; - } - return docID; - } - - @Override - public int advance(int target) { - throw new UnsupportedOperationException(); - } - - @Override - public VectorScorer scorer(byte[] target) { - throw new UnsupportedOperationException(); + public BufferedByteVectorValues copy() throws IOException { + return new BufferedByteVectorValues(vectors, dimension, docsWithField); } } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java index 3b185fd13a07..cbf68dd8c20d 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java @@ -30,6 +30,7 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.Sorter; import org.apache.lucene.index.VectorEncoding; @@ -55,28 +56,26 @@ protected KnnVectorsWriter() {} @SuppressWarnings("unchecked") public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { switch (fieldInfo.getVectorEncoding()) { - case BYTE: + case BYTE -> { KnnFieldVectorsWriter byteWriter = (KnnFieldVectorsWriter) addField(fieldInfo); ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState); - for (int doc = mergedBytes.nextDoc(); - doc != DocIdSetIterator.NO_MORE_DOCS; - doc = mergedBytes.nextDoc()) { - byteWriter.addValue(doc, mergedBytes.vectorValue()); + KnnVectorValues.DocIndexIterator iter = mergedBytes.iterator(); + for (int doc = iter.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iter.nextDoc()) { + byteWriter.addValue(doc, mergedBytes.vectorValue(iter.index())); } - break; - case FLOAT32: + } + case FLOAT32 -> { KnnFieldVectorsWriter floatWriter = (KnnFieldVectorsWriter) addField(fieldInfo); FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); - for (int doc = mergedFloats.nextDoc(); - doc != DocIdSetIterator.NO_MORE_DOCS; - doc = mergedFloats.nextDoc()) { - floatWriter.addValue(doc, mergedFloats.vectorValue()); + KnnVectorValues.DocIndexIterator iter = mergedFloats.iterator(); + for (int doc = iter.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iter.nextDoc()) { + floatWriter.addValue(doc, mergedFloats.vectorValue(iter.index())); } - break; + } } } @@ -117,32 +116,44 @@ public final void merge(MergeState mergeState) throws IOException { private static class FloatVectorValuesSub extends DocIDMerger.Sub { final FloatVectorValues values; + final KnnVectorValues.DocIndexIterator iterator; FloatVectorValuesSub(MergeState.DocMap docMap, FloatVectorValues values) { super(docMap); this.values = values; - assert values.docID() == -1; + this.iterator = values.iterator(); + assert iterator.docID() == -1; } @Override public int nextDoc() throws IOException { - return values.nextDoc(); + return iterator.nextDoc(); + } + + public int index() { + return iterator.index(); } } private static class ByteVectorValuesSub extends DocIDMerger.Sub { final ByteVectorValues values; + final KnnVectorValues.DocIndexIterator iterator; ByteVectorValuesSub(MergeState.DocMap docMap, ByteVectorValues values) { super(docMap); this.values = values; - assert values.docID() == -1; + iterator = values.iterator(); + assert iterator.docID() == -1; } @Override public int nextDoc() throws IOException { - return values.nextDoc(); + return iterator.nextDoc(); + } + + int index() { + return iterator.index(); } } @@ -287,7 +298,8 @@ static class MergedFloat32VectorValues extends FloatVectorValues { private final List subs; private final DocIDMerger docIdMerger; private final int size; - private int docId; + private int docId = -1; + private int lastOrd = -1; FloatVectorValuesSub current; private MergedFloat32VectorValues(List subs, MergeState mergeState) @@ -299,47 +311,81 @@ private MergedFloat32VectorValues(List subs, MergeState me totalSize += sub.values.size(); } size = totalSize; - docId = -1; } @Override - public int docID() { - return docId; + public DocIndexIterator iterator() { + return new DocIndexIterator() { + private int index = -1; + + @Override + public int docID() { + return docId; + } + + @Override + public int index() { + return index; + } + + @Override + public int nextDoc() throws IOException { + current = docIdMerger.next(); + if (current == null) { + docId = NO_MORE_DOCS; + index = NO_MORE_DOCS; + } else { + docId = current.mappedDocID; + ++index; + } + return docId; + } + + @Override + public int advance(int target) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public long cost() { + return size; + } + }; } @Override - public int nextDoc() throws IOException { - current = docIdMerger.next(); - if (current == null) { - docId = NO_MORE_DOCS; + public float[] vectorValue(int ord) throws IOException { + if (ord != lastOrd + 1) { + throw new IllegalStateException( + "only supports forward iteration: ord=" + ord + ", lastOrd=" + lastOrd); } else { - docId = current.mappedDocID; + lastOrd = ord; } - return docId; + return current.values.vectorValue(current.index()); } @Override - public float[] vectorValue() throws IOException { - return current.values.vectorValue(); + public int size() { + return size; } @Override - public int advance(int target) { - throw new UnsupportedOperationException(); + public int dimension() { + return subs.get(0).values.dimension(); } @Override - public int size() { - return size; + public int ordToDoc(int ord) { + throw new UnsupportedOperationException(); } @Override - public int dimension() { - return subs.get(0).values.dimension(); + public VectorScorer scorer(float[] target) { + throw new UnsupportedOperationException(); } @Override - public VectorScorer scorer(float[] target) { + public FloatVectorValues copy() { throw new UnsupportedOperationException(); } } @@ -349,7 +395,8 @@ static class MergedByteVectorValues extends ByteVectorValues { private final DocIDMerger docIdMerger; private final int size; - private int docId; + private int lastOrd = -1; + private int docId = -1; ByteVectorValuesSub current; private MergedByteVectorValues(List subs, MergeState mergeState) @@ -361,33 +408,57 @@ private MergedByteVectorValues(List subs, MergeState mergeS totalSize += sub.values.size(); } size = totalSize; - docId = -1; - } - - @Override - public byte[] vectorValue() throws IOException { - return current.values.vectorValue(); - } - - @Override - public int docID() { - return docId; } @Override - public int nextDoc() throws IOException { - current = docIdMerger.next(); - if (current == null) { - docId = NO_MORE_DOCS; + public byte[] vectorValue(int ord) throws IOException { + if (ord != lastOrd + 1) { + throw new IllegalStateException( + "only supports forward iteration: ord=" + ord + ", lastOrd=" + lastOrd); } else { - docId = current.mappedDocID; + lastOrd = ord; } - return docId; + return current.values.vectorValue(current.index()); } @Override - public int advance(int target) { - throw new UnsupportedOperationException(); + public DocIndexIterator iterator() { + return new DocIndexIterator() { + private int index = -1; + + @Override + public int docID() { + return docId; + } + + @Override + public int index() { + return index; + } + + @Override + public int nextDoc() throws IOException { + current = docIdMerger.next(); + if (current == null) { + docId = NO_MORE_DOCS; + index = NO_MORE_DOCS; + } else { + docId = current.mappedDocID; + ++index; + } + return docId; + } + + @Override + public int advance(int target) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public long cost() { + return size; + } + }; } @Override @@ -400,10 +471,20 @@ public int dimension() { return subs.get(0).values.dimension(); } + @Override + public int ordToDoc(int ord) { + throw new UnsupportedOperationException(); + } + @Override public VectorScorer scorer(byte[] target) { throw new UnsupportedOperationException(); } + + @Override + public ByteVectorValues copy() { + throw new UnsupportedOperationException(); + } } } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java index 1274e1c789e4..3e506037969a 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java @@ -18,8 +18,10 @@ package org.apache.lucene.codecs.hnsw; import java.io.IOException; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; @@ -34,24 +36,26 @@ public class DefaultFlatVectorScorer implements FlatVectorsScorer { @Override public RandomVectorScorerSupplier getRandomVectorScorerSupplier( - VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues) + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues) throws IOException { - if (vectorValues instanceof RandomAccessVectorValues.Floats floatVectorValues) { - return new FloatScoringSupplier(floatVectorValues, similarityFunction); - } else if (vectorValues instanceof RandomAccessVectorValues.Bytes byteVectorValues) { - return new ByteScoringSupplier(byteVectorValues, similarityFunction); + switch (vectorValues.getEncoding()) { + case FLOAT32 -> { + return new FloatScoringSupplier((FloatVectorValues) vectorValues, similarityFunction); + } + case BYTE -> { + return new ByteScoringSupplier((ByteVectorValues) vectorValues, similarityFunction); + } } throw new IllegalArgumentException( - "vectorValues must be an instance of RandomAccessVectorValues.Floats or RandomAccessVectorValues.Bytes"); + "vectorValues must be an instance of FloatVectorValues or ByteVectorValues, got a " + + vectorValues.getClass().getName()); } @Override public RandomVectorScorer getRandomVectorScorer( - VectorSimilarityFunction similarityFunction, - RandomAccessVectorValues vectorValues, - float[] target) + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target) throws IOException { - assert vectorValues instanceof RandomAccessVectorValues.Floats; + assert vectorValues instanceof FloatVectorValues; if (target.length != vectorValues.dimension()) { throw new IllegalArgumentException( "vector query dimension: " @@ -59,17 +63,14 @@ public RandomVectorScorer getRandomVectorScorer( + " differs from field dimension: " + vectorValues.dimension()); } - return new FloatVectorScorer( - (RandomAccessVectorValues.Floats) vectorValues, target, similarityFunction); + return new FloatVectorScorer((FloatVectorValues) vectorValues, target, similarityFunction); } @Override public RandomVectorScorer getRandomVectorScorer( - VectorSimilarityFunction similarityFunction, - RandomAccessVectorValues vectorValues, - byte[] target) + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target) throws IOException { - assert vectorValues instanceof RandomAccessVectorValues.Bytes; + assert vectorValues instanceof ByteVectorValues; if (target.length != vectorValues.dimension()) { throw new IllegalArgumentException( "vector query dimension: " @@ -77,8 +78,7 @@ public RandomVectorScorer getRandomVectorScorer( + " differs from field dimension: " + vectorValues.dimension()); } - return new ByteVectorScorer( - (RandomAccessVectorValues.Bytes) vectorValues, target, similarityFunction); + return new ByteVectorScorer((ByteVectorValues) vectorValues, target, similarityFunction); } @Override @@ -88,14 +88,13 @@ public String toString() { /** RandomVectorScorerSupplier for bytes vector */ private static final class ByteScoringSupplier implements RandomVectorScorerSupplier { - private final RandomAccessVectorValues.Bytes vectors; - private final RandomAccessVectorValues.Bytes vectors1; - private final RandomAccessVectorValues.Bytes vectors2; + private final ByteVectorValues vectors; + private final ByteVectorValues vectors1; + private final ByteVectorValues vectors2; private final VectorSimilarityFunction similarityFunction; private ByteScoringSupplier( - RandomAccessVectorValues.Bytes vectors, VectorSimilarityFunction similarityFunction) - throws IOException { + ByteVectorValues vectors, VectorSimilarityFunction similarityFunction) throws IOException { this.vectors = vectors; vectors1 = vectors.copy(); vectors2 = vectors.copy(); @@ -125,14 +124,13 @@ public String toString() { /** RandomVectorScorerSupplier for Float vector */ private static final class FloatScoringSupplier implements RandomVectorScorerSupplier { - private final RandomAccessVectorValues.Floats vectors; - private final RandomAccessVectorValues.Floats vectors1; - private final RandomAccessVectorValues.Floats vectors2; + private final FloatVectorValues vectors; + private final FloatVectorValues vectors1; + private final FloatVectorValues vectors2; private final VectorSimilarityFunction similarityFunction; private FloatScoringSupplier( - RandomAccessVectorValues.Floats vectors, VectorSimilarityFunction similarityFunction) - throws IOException { + FloatVectorValues vectors, VectorSimilarityFunction similarityFunction) throws IOException { this.vectors = vectors; vectors1 = vectors.copy(); vectors2 = vectors.copy(); @@ -162,14 +160,12 @@ public String toString() { /** A {@link RandomVectorScorer} for float vectors. */ private static class FloatVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer { - private final RandomAccessVectorValues.Floats values; + private final FloatVectorValues values; private final float[] query; private final VectorSimilarityFunction similarityFunction; public FloatVectorScorer( - RandomAccessVectorValues.Floats values, - float[] query, - VectorSimilarityFunction similarityFunction) { + FloatVectorValues values, float[] query, VectorSimilarityFunction similarityFunction) { super(values); this.values = values; this.query = query; @@ -184,14 +180,12 @@ public float score(int node) throws IOException { /** A {@link RandomVectorScorer} for byte vectors. */ private static class ByteVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer { - private final RandomAccessVectorValues.Bytes values; + private final ByteVectorValues values; private final byte[] query; private final VectorSimilarityFunction similarityFunction; public ByteVectorScorer( - RandomAccessVectorValues.Bytes values, - byte[] query, - VectorSimilarityFunction similarityFunction) { + ByteVectorValues values, byte[] query, VectorSimilarityFunction similarityFunction) { super(values); this.values = values; this.query = query; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsScorer.java index 17430c24f276..6ed170731de4 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsScorer.java @@ -18,8 +18,8 @@ package org.apache.lucene.codecs.hnsw; import java.io.IOException; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; @@ -40,8 +40,7 @@ public interface FlatVectorsScorer { * @throws IOException if an I/O error occurs */ RandomVectorScorerSupplier getRandomVectorScorerSupplier( - VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues) - throws IOException; + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues) throws IOException; /** * Returns a {@link RandomVectorScorer} for the given set of vectors and target vector. @@ -53,9 +52,7 @@ RandomVectorScorerSupplier getRandomVectorScorerSupplier( * @throws IOException if an I/O error occurs when reading from the index. */ RandomVectorScorer getRandomVectorScorer( - VectorSimilarityFunction similarityFunction, - RandomAccessVectorValues vectorValues, - float[] target) + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target) throws IOException; /** @@ -68,8 +65,6 @@ RandomVectorScorer getRandomVectorScorer( * @throws IOException if an I/O error occurs when reading from the index. */ RandomVectorScorer getRandomVectorScorer( - VectorSimilarityFunction similarityFunction, - RandomAccessVectorValues vectorValues, - byte[] target) + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target) throws IOException; } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/ScalarQuantizedVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/ScalarQuantizedVectorScorer.java index 4b73e1f7a4a6..ceb826aa3a11 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/ScalarQuantizedVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/ScalarQuantizedVectorScorer.java @@ -18,13 +18,13 @@ package org.apache.lucene.codecs.hnsw; import java.io.IOException; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.VectorUtil; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; -import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity; import org.apache.lucene.util.quantization.ScalarQuantizer; @@ -60,9 +60,9 @@ public ScalarQuantizedVectorScorer(FlatVectorsScorer flatVectorsScorer) { @Override public RandomVectorScorerSupplier getRandomVectorScorerSupplier( - VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues) + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues) throws IOException { - if (vectorValues instanceof RandomAccessQuantizedByteVectorValues quantizedByteVectorValues) { + if (vectorValues instanceof QuantizedByteVectorValues quantizedByteVectorValues) { return new ScalarQuantizedRandomVectorScorerSupplier( similarityFunction, quantizedByteVectorValues.getScalarQuantizer(), @@ -74,11 +74,9 @@ public RandomVectorScorerSupplier getRandomVectorScorerSupplier( @Override public RandomVectorScorer getRandomVectorScorer( - VectorSimilarityFunction similarityFunction, - RandomAccessVectorValues vectorValues, - float[] target) + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target) throws IOException { - if (vectorValues instanceof RandomAccessQuantizedByteVectorValues quantizedByteVectorValues) { + if (vectorValues instanceof QuantizedByteVectorValues quantizedByteVectorValues) { ScalarQuantizer scalarQuantizer = quantizedByteVectorValues.getScalarQuantizer(); byte[] targetBytes = new byte[target.length]; float offsetCorrection = @@ -104,9 +102,7 @@ public float score(int node) throws IOException { @Override public RandomVectorScorer getRandomVectorScorer( - VectorSimilarityFunction similarityFunction, - RandomAccessVectorValues vectorValues, - byte[] target) + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target) throws IOException { return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); } @@ -124,14 +120,14 @@ public String toString() { public static class ScalarQuantizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier { - private final RandomAccessQuantizedByteVectorValues values; + private final QuantizedByteVectorValues values; private final ScalarQuantizedVectorSimilarity similarity; private final VectorSimilarityFunction vectorSimilarityFunction; public ScalarQuantizedRandomVectorScorerSupplier( VectorSimilarityFunction similarityFunction, ScalarQuantizer scalarQuantizer, - RandomAccessQuantizedByteVectorValues values) { + QuantizedByteVectorValues values) { this.similarity = ScalarQuantizedVectorSimilarity.fromVectorSimilarity( similarityFunction, @@ -144,7 +140,7 @@ public ScalarQuantizedRandomVectorScorerSupplier( private ScalarQuantizedRandomVectorScorerSupplier( ScalarQuantizedVectorSimilarity similarity, VectorSimilarityFunction vectorSimilarityFunction, - RandomAccessQuantizedByteVectorValues values) { + QuantizedByteVectorValues values) { this.similarity = similarity; this.values = values; this.vectorSimilarityFunction = vectorSimilarityFunction; @@ -152,7 +148,7 @@ private ScalarQuantizedRandomVectorScorerSupplier( @Override public RandomVectorScorer scorer(int ord) throws IOException { - final RandomAccessQuantizedByteVectorValues vectorsCopy = values.copy(); + final QuantizedByteVectorValues vectorsCopy = values.copy(); final byte[] queryVector = values.vectorValue(ord); final float queryOffset = values.getScoreCorrectionConstant(ord); return new RandomVectorScorer.AbstractRandomVectorScorer(vectorsCopy) { diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/IndexedDISI.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/IndexedDISI.java index a2b2c84e12ae..dbd56125fcd1 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/IndexedDISI.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/IndexedDISI.java @@ -18,6 +18,7 @@ import java.io.DataInput; import java.io.IOException; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexOutput; @@ -439,6 +440,40 @@ public static RandomAccessInput createJumpTable( // ALL variables int gap; + /** + * Returns an iterator that delegates to the IndexedDISI. Advancing this iterator will advance the + * underlying IndexedDISI, and vice-versa. + */ + public static KnnVectorValues.DocIndexIterator asDocIndexIterator(IndexedDISI disi) { + // can we replace with fromDISI? + return new KnnVectorValues.DocIndexIterator() { + @Override + public int docID() { + return disi.docID(); + } + + @Override + public int index() { + return disi.index(); + } + + @Override + public int nextDoc() throws IOException { + return disi.nextDoc(); + } + + @Override + public int advance(int target) throws IOException { + return disi.advance(target); + } + + @Override + public long cost() { + return disi.cost(); + } + }; + } + @Override public int docID() { return doc; diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/RandomAccessQuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/HasIndexSlice.java similarity index 57% rename from lucene/core/src/java/org/apache/lucene/util/quantization/RandomAccessQuantizedByteVectorValues.java rename to lucene/core/src/java/org/apache/lucene/codecs/lucene95/HasIndexSlice.java index b86009a690e1..2bfe72386a05 100644 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/RandomAccessQuantizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/HasIndexSlice.java @@ -14,23 +14,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.lucene.util.quantization; +package org.apache.lucene.codecs.lucene95; -import java.io.IOException; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; +import org.apache.lucene.store.IndexInput; /** - * Random access values for byte[], but also includes accessing the score correction - * constant for the current vector in the buffer. - * - * @lucene.experimental + * Implementors can return the IndexInput from which their values are read. For use by vector + * quantizers. */ -public interface RandomAccessQuantizedByteVectorValues extends RandomAccessVectorValues.Bytes { - - ScalarQuantizer getScalarQuantizer(); - - float getScoreCorrectionConstant(int vectorOrd) throws IOException; +public interface HasIndexSlice { - @Override - RandomAccessQuantizedByteVectorValues copy() throws IOException; + /** Returns an IndexInput from which to read this instance's values. */ + IndexInput getSlice(); } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java index f45158eadac7..1e78c8ea7aa2 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java @@ -29,13 +29,11 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.RandomAccessInput; import org.apache.lucene.util.Bits; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.packed.DirectMonotonicReader; /** Read the vector values from the index input. This supports both iterated and random access. */ -public abstract class OffHeapByteVectorValues extends ByteVectorValues - implements RandomAccessVectorValues.Bytes { +public abstract class OffHeapByteVectorValues extends ByteVectorValues implements HasIndexSlice { protected final int dimension; protected final int size; @@ -132,9 +130,6 @@ public static OffHeapByteVectorValues load( * vector. */ public static class DenseOffHeapVectorValues extends OffHeapByteVectorValues { - - private int doc = -1; - public DenseOffHeapVectorValues( int dimension, int size, @@ -145,36 +140,17 @@ public DenseOffHeapVectorValues( super(dimension, size, slice, byteSize, flatVectorsScorer, vectorSimilarityFunction); } - @Override - public byte[] vectorValue() throws IOException { - return vectorValue(doc); - } - - @Override - public int docID() { - return doc; - } - - @Override - public int nextDoc() throws IOException { - return advance(doc + 1); - } - - @Override - public int advance(int target) throws IOException { - assert docID() < target; - if (target >= size) { - return doc = NO_MORE_DOCS; - } - return doc = target; - } - @Override public DenseOffHeapVectorValues copy() throws IOException { return new DenseOffHeapVectorValues( dimension, size, slice.clone(), byteSize, flatVectorsScorer, similarityFunction); } + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + @Override public Bits getAcceptOrds(Bits acceptDocs) { return acceptDocs; @@ -183,17 +159,18 @@ public Bits getAcceptOrds(Bits acceptDocs) { @Override public VectorScorer scorer(byte[] query) throws IOException { DenseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); RandomVectorScorer scorer = flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query); return new VectorScorer() { @Override public float score() throws IOException { - return scorer.score(copy.doc); + return scorer.score(iterator.docID()); } @Override public DocIdSetIterator iterator() { - return copy; + return iterator; } }; } @@ -238,27 +215,6 @@ public SparseOffHeapVectorValues( configuration.size); } - @Override - public byte[] vectorValue() throws IOException { - return vectorValue(disi.index()); - } - - @Override - public int docID() { - return disi.docID(); - } - - @Override - public int nextDoc() throws IOException { - return disi.nextDoc(); - } - - @Override - public int advance(int target) throws IOException { - assert docID() < target; - return disi.advance(target); - } - @Override public SparseOffHeapVectorValues copy() throws IOException { return new SparseOffHeapVectorValues( @@ -276,6 +232,11 @@ public int ordToDoc(int ord) { return (int) ordToDoc.get(ord); } + @Override + public DocIndexIterator iterator() { + return IndexedDISI.asDocIndexIterator(disi); + } + @Override public Bits getAcceptOrds(Bits acceptDocs) { if (acceptDocs == null) { @@ -307,7 +268,7 @@ public float score() throws IOException { @Override public DocIdSetIterator iterator() { - return copy; + return copy.disi; } }; } @@ -322,8 +283,6 @@ public EmptyOffHeapVectorValues( super(dimension, 0, null, 0, flatVectorsScorer, vectorSimilarityFunction); } - private int doc = -1; - @Override public int dimension() { return super.dimension(); @@ -335,23 +294,13 @@ public int size() { } @Override - public byte[] vectorValue() throws IOException { + public byte[] vectorValue(int ord) throws IOException { throw new UnsupportedOperationException(); } @Override - public int docID() { - return doc; - } - - @Override - public int nextDoc() throws IOException { - return advance(doc + 1); - } - - @Override - public int advance(int target) throws IOException { - return doc = NO_MORE_DOCS; + public DocIndexIterator iterator() { + return createDenseIterator(); } @Override @@ -359,11 +308,6 @@ public EmptyOffHeapVectorValues copy() throws IOException { throw new UnsupportedOperationException(); } - @Override - public byte[] vectorValue(int targetOrd) throws IOException { - throw new UnsupportedOperationException(); - } - @Override public int ordToDoc(int ord) { throw new UnsupportedOperationException(); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java index 1f61283b5002..2384657e93e1 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java @@ -28,13 +28,11 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.RandomAccessInput; import org.apache.lucene.util.Bits; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.packed.DirectMonotonicReader; /** Read the vector values from the index input. This supports both iterated and random access. */ -public abstract class OffHeapFloatVectorValues extends FloatVectorValues - implements RandomAccessVectorValues.Floats { +public abstract class OffHeapFloatVectorValues extends FloatVectorValues implements HasIndexSlice { protected final int dimension; protected final int size; @@ -128,8 +126,6 @@ public static OffHeapFloatVectorValues load( */ public static class DenseOffHeapVectorValues extends OffHeapFloatVectorValues { - private int doc = -1; - public DenseOffHeapVectorValues( int dimension, int size, @@ -141,54 +137,41 @@ public DenseOffHeapVectorValues( } @Override - public float[] vectorValue() throws IOException { - return vectorValue(doc); - } - - @Override - public int docID() { - return doc; - } - - @Override - public int nextDoc() throws IOException { - return advance(doc + 1); + public DenseOffHeapVectorValues copy() throws IOException { + return new DenseOffHeapVectorValues( + dimension, size, slice.clone(), byteSize, flatVectorsScorer, similarityFunction); } @Override - public int advance(int target) throws IOException { - assert docID() < target; - if (target >= size) { - return doc = NO_MORE_DOCS; - } - return doc = target; + public int ordToDoc(int ord) { + return ord; } @Override - public DenseOffHeapVectorValues copy() throws IOException { - return new DenseOffHeapVectorValues( - dimension, size, slice.clone(), byteSize, flatVectorsScorer, similarityFunction); + public Bits getAcceptOrds(Bits acceptDocs) { + return acceptDocs; } @Override - public Bits getAcceptOrds(Bits acceptDocs) { - return acceptDocs; + public DocIndexIterator iterator() { + return createDenseIterator(); } @Override public VectorScorer scorer(float[] query) throws IOException { DenseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); RandomVectorScorer randomVectorScorer = flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query); return new VectorScorer() { @Override public float score() throws IOException { - return randomVectorScorer.score(copy.doc); + return randomVectorScorer.score(iterator.docID()); } @Override public DocIdSetIterator iterator() { - return copy; + return iterator; } }; } @@ -227,27 +210,6 @@ public SparseOffHeapVectorValues( configuration.size); } - @Override - public float[] vectorValue() throws IOException { - return vectorValue(disi.index()); - } - - @Override - public int docID() { - return disi.docID(); - } - - @Override - public int nextDoc() throws IOException { - return disi.nextDoc(); - } - - @Override - public int advance(int target) throws IOException { - assert docID() < target; - return disi.advance(target); - } - @Override public SparseOffHeapVectorValues copy() throws IOException { return new SparseOffHeapVectorValues( @@ -283,20 +245,26 @@ public int length() { }; } + @Override + public DocIndexIterator iterator() { + return IndexedDISI.asDocIndexIterator(disi); + } + @Override public VectorScorer scorer(float[] query) throws IOException { SparseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); RandomVectorScorer randomVectorScorer = flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query); return new VectorScorer() { @Override public float score() throws IOException { - return randomVectorScorer.score(copy.disi.index()); + return randomVectorScorer.score(iterator.index()); } @Override public DocIdSetIterator iterator() { - return copy; + return iterator; } }; } @@ -311,8 +279,6 @@ public EmptyOffHeapVectorValues( super(dimension, 0, null, 0, flatVectorsScorer, similarityFunction); } - private int doc = -1; - @Override public int dimension() { return super.dimension(); @@ -323,26 +289,6 @@ public int size() { return 0; } - @Override - public float[] vectorValue() throws IOException { - throw new UnsupportedOperationException(); - } - - @Override - public int docID() { - return doc; - } - - @Override - public int nextDoc() throws IOException { - return advance(doc + 1); - } - - @Override - public int advance(int target) { - return doc = NO_MORE_DOCS; - } - @Override public EmptyOffHeapVectorValues copy() { throw new UnsupportedOperationException(); @@ -354,8 +300,8 @@ public float[] vectorValue(int targetOrd) { } @Override - public int ordToDoc(int ord) { - throw new UnsupportedOperationException(); + public DocIndexIterator iterator() { + return createDenseIterator(); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java index 1af68618d833..b731e758b7a8 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java @@ -39,6 +39,7 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; @@ -361,11 +362,10 @@ private void writeMeta( private static DocsWithFieldSet writeByteVectorData( IndexOutput output, ByteVectorValues byteVectorValues) throws IOException { DocsWithFieldSet docsWithField = new DocsWithFieldSet(); - for (int docV = byteVectorValues.nextDoc(); - docV != NO_MORE_DOCS; - docV = byteVectorValues.nextDoc()) { + KnnVectorValues.DocIndexIterator iter = byteVectorValues.iterator(); + for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) { // write vector - byte[] binaryValue = byteVectorValues.vectorValue(); + byte[] binaryValue = byteVectorValues.vectorValue(iter.index()); assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize; output.writeBytes(binaryValue, binaryValue.length); docsWithField.add(docV); @@ -382,11 +382,10 @@ private static DocsWithFieldSet writeVectorData( ByteBuffer buffer = ByteBuffer.allocate(floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize) .order(ByteOrder.LITTLE_ENDIAN); - for (int docV = floatVectorValues.nextDoc(); - docV != NO_MORE_DOCS; - docV = floatVectorValues.nextDoc()) { + KnnVectorValues.DocIndexIterator iter = floatVectorValues.iterator(); + for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) { // write vector - float[] value = floatVectorValues.vectorValue(); + float[] value = floatVectorValues.vectorValue(iter.index()); buffer.asFloatBuffer().put(value); output.writeBytes(buffer.array(), buffer.limit()); docsWithField.add(docV); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java index dc0fb7184c7c..0f4e8196d52d 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java @@ -32,14 +32,16 @@ import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.TaskExecutor; import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.IOUtils; @@ -54,7 +56,6 @@ import org.apache.lucene.util.hnsw.IncrementalHnswGraphMerger; import org.apache.lucene.util.hnsw.NeighborArray; import org.apache.lucene.util.hnsw.OnHeapHnswGraph; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.packed.DirectMonotonicWriter; @@ -359,18 +360,18 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE mergeState.knnVectorsReaders[i], mergeState.docMaps[i], mergeState.liveDocs[i]); } } - DocIdSetIterator mergedVectorIterator = null; + KnnVectorValues mergedVectorValues = null; switch (fieldInfo.getVectorEncoding()) { case BYTE -> - mergedVectorIterator = + mergedVectorValues = KnnVectorsWriter.MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState); case FLOAT32 -> - mergedVectorIterator = + mergedVectorValues = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); } graph = merger.merge( - mergedVectorIterator, + mergedVectorValues, segmentWriteState.infoStream, scorerSupplier.totalVectorCount()); vectorIndexNodeOffsets = writeGraph(graph); @@ -582,13 +583,13 @@ static FieldWriter create( case BYTE -> scorer.getRandomVectorScorerSupplier( fieldInfo.getVectorSimilarityFunction(), - RandomAccessVectorValues.fromBytes( + ByteVectorValues.fromBytes( (List) flatFieldVectorsWriter.getVectors(), fieldInfo.getVectorDimension())); case FLOAT32 -> scorer.getRandomVectorScorerSupplier( fieldInfo.getVectorSimilarityFunction(), - RandomAccessVectorValues.fromFloats( + FloatVectorValues.fromFloats( (List) flatFieldVectorsWriter.getVectors(), fieldInfo.getVectorDimension())); }; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java index 8443017d3f9a..a4770f01f46d 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java @@ -21,12 +21,12 @@ import java.io.IOException; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.VectorUtil; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; -import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import org.apache.lucene.util.quantization.ScalarQuantizer; /** @@ -45,9 +45,9 @@ public Lucene99ScalarQuantizedVectorScorer(FlatVectorsScorer flatVectorsScorer) @Override public RandomVectorScorerSupplier getRandomVectorScorerSupplier( - VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues) + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues) throws IOException { - if (vectorValues instanceof RandomAccessQuantizedByteVectorValues quantizedByteVectorValues) { + if (vectorValues instanceof QuantizedByteVectorValues quantizedByteVectorValues) { return new ScalarQuantizedRandomVectorScorerSupplier( quantizedByteVectorValues, similarityFunction); } @@ -57,11 +57,9 @@ public RandomVectorScorerSupplier getRandomVectorScorerSupplier( @Override public RandomVectorScorer getRandomVectorScorer( - VectorSimilarityFunction similarityFunction, - RandomAccessVectorValues vectorValues, - float[] target) + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target) throws IOException { - if (vectorValues instanceof RandomAccessQuantizedByteVectorValues quantizedByteVectorValues) { + if (vectorValues instanceof QuantizedByteVectorValues quantizedByteVectorValues) { ScalarQuantizer scalarQuantizer = quantizedByteVectorValues.getScalarQuantizer(); byte[] targetBytes = new byte[target.length]; float offsetCorrection = @@ -79,9 +77,7 @@ public RandomVectorScorer getRandomVectorScorer( @Override public RandomVectorScorer getRandomVectorScorer( - VectorSimilarityFunction similarityFunction, - RandomAccessVectorValues vectorValues, - byte[] target) + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target) throws IOException { return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); } @@ -96,7 +92,7 @@ static RandomVectorScorer fromVectorSimilarity( float offsetCorrection, VectorSimilarityFunction sim, float constMultiplier, - RandomAccessQuantizedByteVectorValues values) { + QuantizedByteVectorValues values) { return switch (sim) { case EUCLIDEAN -> new Euclidean(values, constMultiplier, targetBytes); case COSINE, DOT_PRODUCT -> @@ -120,7 +116,7 @@ private static RandomVectorScorer.AbstractRandomVectorScorer dotProductFactory( byte[] targetBytes, float offsetCorrection, float constMultiplier, - RandomAccessQuantizedByteVectorValues values, + QuantizedByteVectorValues values, FloatToFloatFunction scoreAdjustmentFunction) { if (values.getScalarQuantizer().getBits() <= 4) { if (values.getVectorByteLength() != values.dimension() && values.getSlice() != null) { @@ -137,10 +133,9 @@ private static RandomVectorScorer.AbstractRandomVectorScorer dotProductFactory( private static class Euclidean extends RandomVectorScorer.AbstractRandomVectorScorer { private final float constMultiplier; private final byte[] targetBytes; - private final RandomAccessQuantizedByteVectorValues values; + private final QuantizedByteVectorValues values; - private Euclidean( - RandomAccessQuantizedByteVectorValues values, float constMultiplier, byte[] targetBytes) { + private Euclidean(QuantizedByteVectorValues values, float constMultiplier, byte[] targetBytes) { super(values); this.values = values; this.constMultiplier = constMultiplier; @@ -159,13 +154,13 @@ public float score(int node) throws IOException { /** Calculates dot product on quantized vectors, applying the appropriate corrections */ private static class DotProduct extends RandomVectorScorer.AbstractRandomVectorScorer { private final float constMultiplier; - private final RandomAccessQuantizedByteVectorValues values; + private final QuantizedByteVectorValues values; private final byte[] targetBytes; private final float offsetCorrection; private final FloatToFloatFunction scoreAdjustmentFunction; public DotProduct( - RandomAccessQuantizedByteVectorValues values, + QuantizedByteVectorValues values, float constMultiplier, byte[] targetBytes, float offsetCorrection, @@ -193,14 +188,14 @@ public float score(int vectorOrdinal) throws IOException { private static class CompressedInt4DotProduct extends RandomVectorScorer.AbstractRandomVectorScorer { private final float constMultiplier; - private final RandomAccessQuantizedByteVectorValues values; + private final QuantizedByteVectorValues values; private final byte[] compressedVector; private final byte[] targetBytes; private final float offsetCorrection; private final FloatToFloatFunction scoreAdjustmentFunction; private CompressedInt4DotProduct( - RandomAccessQuantizedByteVectorValues values, + QuantizedByteVectorValues values, float constMultiplier, byte[] targetBytes, float offsetCorrection, @@ -231,13 +226,13 @@ public float score(int vectorOrdinal) throws IOException { private static class Int4DotProduct extends RandomVectorScorer.AbstractRandomVectorScorer { private final float constMultiplier; - private final RandomAccessQuantizedByteVectorValues values; + private final QuantizedByteVectorValues values; private final byte[] targetBytes; private final float offsetCorrection; private final FloatToFloatFunction scoreAdjustmentFunction; public Int4DotProduct( - RandomAccessQuantizedByteVectorValues values, + QuantizedByteVectorValues values, float constMultiplier, byte[] targetBytes, float offsetCorrection, @@ -271,13 +266,12 @@ private static final class ScalarQuantizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier { private final VectorSimilarityFunction vectorSimilarityFunction; - private final RandomAccessQuantizedByteVectorValues values; - private final RandomAccessQuantizedByteVectorValues values1; - private final RandomAccessQuantizedByteVectorValues values2; + private final QuantizedByteVectorValues values; + private final QuantizedByteVectorValues values1; + private final QuantizedByteVectorValues values2; public ScalarQuantizedRandomVectorScorerSupplier( - RandomAccessQuantizedByteVectorValues values, - VectorSimilarityFunction vectorSimilarityFunction) + QuantizedByteVectorValues values, VectorSimilarityFunction vectorSimilarityFunction) throws IOException { this.values = values; this.values1 = values.copy(); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java index 40002fe06a6a..24123a4f21e3 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java @@ -402,10 +402,10 @@ static FieldEntry create( private static final class QuantizedVectorValues extends FloatVectorValues { private final FloatVectorValues rawVectorValues; - private final OffHeapQuantizedByteVectorValues quantizedVectorValues; + private final QuantizedByteVectorValues quantizedVectorValues; QuantizedVectorValues( - FloatVectorValues rawVectorValues, OffHeapQuantizedByteVectorValues quantizedVectorValues) { + FloatVectorValues rawVectorValues, QuantizedByteVectorValues quantizedVectorValues) { this.rawVectorValues = rawVectorValues; this.quantizedVectorValues = quantizedVectorValues; } @@ -421,34 +421,28 @@ public int size() { } @Override - public float[] vectorValue() throws IOException { - return rawVectorValues.vectorValue(); + public float[] vectorValue(int ord) throws IOException { + return rawVectorValues.vectorValue(ord); } @Override - public int docID() { - return rawVectorValues.docID(); + public int ordToDoc(int ord) { + return rawVectorValues.ordToDoc(ord); } @Override - public int nextDoc() throws IOException { - int rawDocId = rawVectorValues.nextDoc(); - int quantizedDocId = quantizedVectorValues.nextDoc(); - assert rawDocId == quantizedDocId; - return quantizedDocId; + public QuantizedVectorValues copy() throws IOException { + return new QuantizedVectorValues(rawVectorValues.copy(), quantizedVectorValues.copy()); } @Override - public int advance(int target) throws IOException { - int rawDocId = rawVectorValues.advance(target); - int quantizedDocId = quantizedVectorValues.advance(target); - assert rawDocId == quantizedDocId; - return quantizedDocId; + public VectorScorer scorer(float[] query) throws IOException { + return quantizedVectorValues.scorer(query); } @Override - public VectorScorer scorer(float[] query) throws IOException { - return quantizedVectorValues.scorer(query); + public DocIndexIterator iterator() { + return rawVectorValues.iterator(); } } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java index bb333ad45c22..1a30b5271cd7 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java @@ -19,9 +19,7 @@ import static org.apache.lucene.codecs.KnnVectorsWriter.MergedVectorValues.hasVectorValues; import static org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT; -import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.DYNAMIC_CONFIDENCE_INTERVAL; -import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_COMPONENT; -import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.calculateDefaultConfidenceInterval; +import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.*; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance; @@ -45,6 +43,7 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; @@ -653,12 +652,11 @@ public static ScalarQuantizer mergeAndRecalculateQuantiles( || bits <= 4 || shouldRecomputeQuantiles(mergedQuantiles, quantizationStates)) { int numVectors = 0; - FloatVectorValues vectorValues = - KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + DocIdSetIterator iter = + KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState) + .iterator(); // iterate vectorValues and increment numVectors - for (int doc = vectorValues.nextDoc(); - doc != DocIdSetIterator.NO_MORE_DOCS; - doc = vectorValues.nextDoc()) { + for (int doc = iter.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iter.nextDoc()) { numVectors++; } return buildScalarQuantizer( @@ -730,11 +728,10 @@ public static DocsWithFieldSet writeQuantizedVectorData( ? OffHeapQuantizedByteVectorValues.compressedArray( quantizedByteVectorValues.dimension(), bits) : null; - for (int docV = quantizedByteVectorValues.nextDoc(); - docV != NO_MORE_DOCS; - docV = quantizedByteVectorValues.nextDoc()) { + KnnVectorValues.DocIndexIterator iter = quantizedByteVectorValues.iterator(); + for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) { // write vector - byte[] binaryValue = quantizedByteVectorValues.vectorValue(); + byte[] binaryValue = quantizedByteVectorValues.vectorValue(iter.index()); assert binaryValue.length == quantizedByteVectorValues.dimension() : "dim=" + quantizedByteVectorValues.dimension() + " len=" + binaryValue.length; if (compressedVector != null) { @@ -743,7 +740,8 @@ public static DocsWithFieldSet writeQuantizedVectorData( } else { output.writeBytes(binaryValue, binaryValue.length); } - output.writeInt(Float.floatToIntBits(quantizedByteVectorValues.getScoreCorrectionConstant())); + output.writeInt( + Float.floatToIntBits(quantizedByteVectorValues.getScoreCorrectionConstant(iter.index()))); docsWithField.add(docV); } return docsWithField; @@ -855,7 +853,6 @@ public DocsWithFieldSet getDocsWithFieldSet() { static class FloatVectorWrapper extends FloatVectorValues { private final List vectorList; - protected int curDoc = -1; FloatVectorWrapper(List vectorList) { this.vectorList = vectorList; @@ -872,51 +869,42 @@ public int size() { } @Override - public float[] vectorValue() throws IOException { - if (curDoc == -1 || curDoc >= vectorList.size()) { - throw new IOException("Current doc not set or too many iterations"); - } - return vectorList.get(curDoc); + public FloatVectorValues copy() throws IOException { + return this; } @Override - public int docID() { - if (curDoc >= vectorList.size()) { - return NO_MORE_DOCS; + public float[] vectorValue(int ord) throws IOException { + if (ord < 0 || ord >= vectorList.size()) { + throw new IOException("vector ord " + ord + " out of bounds"); } - return curDoc; - } - - @Override - public int nextDoc() throws IOException { - curDoc++; - return docID(); + return vectorList.get(ord); } @Override - public int advance(int target) throws IOException { - curDoc = target; - return docID(); - } - - @Override - public VectorScorer scorer(float[] target) { - throw new UnsupportedOperationException(); + public DocIndexIterator iterator() { + return createDenseIterator(); } } static class QuantizedByteVectorValueSub extends DocIDMerger.Sub { private final QuantizedByteVectorValues values; + private final KnnVectorValues.DocIndexIterator iterator; QuantizedByteVectorValueSub(MergeState.DocMap docMap, QuantizedByteVectorValues values) { super(docMap); this.values = values; - assert values.docID() == -1; + iterator = values.iterator(); + assert iterator.docID() == -1; } @Override public int nextDoc() throws IOException { - return values.nextDoc(); + return iterator.nextDoc(); + } + + public int index() { + return iterator.index(); } } @@ -973,7 +961,6 @@ public static MergedQuantizedVectorValues mergeQuantizedByteVectorValues( private final DocIDMerger docIdMerger; private final int size; - private int docId; private QuantizedByteVectorValueSub current; private MergedQuantizedVectorValues( @@ -985,33 +972,16 @@ private MergedQuantizedVectorValues( totalSize += sub.values.size(); } size = totalSize; - docId = -1; - } - - @Override - public byte[] vectorValue() throws IOException { - return current.values.vectorValue(); - } - - @Override - public int docID() { - return docId; } @Override - public int nextDoc() throws IOException { - current = docIdMerger.next(); - if (current == null) { - docId = NO_MORE_DOCS; - } else { - docId = current.mappedDocID; - } - return docId; + public byte[] vectorValue(int ord) throws IOException { + return current.values.vectorValue(current.index()); } @Override - public int advance(int target) { - throw new UnsupportedOperationException(); + public DocIndexIterator iterator() { + return new CompositeIterator(); } @Override @@ -1025,13 +995,51 @@ public int dimension() { } @Override - public float getScoreCorrectionConstant() throws IOException { - return current.values.getScoreCorrectionConstant(); + public float getScoreCorrectionConstant(int ord) throws IOException { + return current.values.getScoreCorrectionConstant(current.index()); } - @Override - public VectorScorer scorer(float[] target) throws IOException { - throw new UnsupportedOperationException(); + private class CompositeIterator extends DocIndexIterator { + private int docId; + private int ord; + + public CompositeIterator() { + docId = -1; + ord = -1; + } + + @Override + public int index() { + return ord; + } + + @Override + public int docID() { + return docId; + } + + @Override + public int nextDoc() throws IOException { + current = docIdMerger.next(); + if (current == null) { + docId = NO_MORE_DOCS; + ord = NO_MORE_DOCS; + } else { + docId = current.mappedDocID; + ++ord; + } + return docId; + } + + @Override + public int advance(int target) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public long cost() { + return size; + } } } @@ -1039,6 +1047,7 @@ static class QuantizedFloatVectorValues extends QuantizedByteVectorValues { private final FloatVectorValues values; private final ScalarQuantizer quantizer; private final byte[] quantizedVector; + private int lastOrd = -1; private float offsetValue = 0f; private final VectorSimilarityFunction vectorSimilarityFunction; @@ -1054,7 +1063,14 @@ public QuantizedFloatVectorValues( } @Override - public float getScoreCorrectionConstant() { + public float getScoreCorrectionConstant(int ord) { + if (ord != lastOrd) { + throw new IllegalStateException( + "attempt to retrieve score correction for different ord " + + ord + + " than the quantization was done for: " + + lastOrd); + } return offsetValue; } @@ -1069,41 +1085,31 @@ public int size() { } @Override - public byte[] vectorValue() throws IOException { + public byte[] vectorValue(int ord) throws IOException { + if (ord != lastOrd) { + offsetValue = quantize(ord); + lastOrd = ord; + } return quantizedVector; } @Override - public int docID() { - return values.docID(); + public VectorScorer scorer(float[] target) throws IOException { + throw new UnsupportedOperationException(); } - @Override - public int nextDoc() throws IOException { - int doc = values.nextDoc(); - if (doc != NO_MORE_DOCS) { - quantize(); - } - return doc; + private float quantize(int ord) throws IOException { + return quantizer.quantize(values.vectorValue(ord), quantizedVector, vectorSimilarityFunction); } @Override - public int advance(int target) throws IOException { - int doc = values.advance(target); - if (doc != NO_MORE_DOCS) { - quantize(); - } - return doc; + public int ordToDoc(int ord) { + return values.ordToDoc(ord); } @Override - public VectorScorer scorer(float[] target) throws IOException { - throw new UnsupportedOperationException(); - } - - private void quantize() throws IOException { - offsetValue = - quantizer.quantize(values.vectorValue(), quantizedVector, vectorSimilarityFunction); + public DocIndexIterator iterator() { + return values.iterator(); } } @@ -1160,9 +1166,9 @@ static final class OffsetCorrectedQuantizedByteVectorValues extends QuantizedByt } @Override - public float getScoreCorrectionConstant() throws IOException { + public float getScoreCorrectionConstant(int ord) throws IOException { return scalarQuantizer.recalculateCorrectiveOffset( - in.vectorValue(), oldScalarQuantizer, vectorSimilarityFunction); + in.vectorValue(ord), oldScalarQuantizer, vectorSimilarityFunction); } @Override @@ -1176,35 +1182,24 @@ public int size() { } @Override - public byte[] vectorValue() throws IOException { - return in.vectorValue(); + public byte[] vectorValue(int ord) throws IOException { + return in.vectorValue(ord); } @Override - public int docID() { - return in.docID(); + public int ordToDoc(int ord) { + return in.ordToDoc(ord); } @Override - public int nextDoc() throws IOException { - return in.nextDoc(); - } - - @Override - public int advance(int target) throws IOException { - return in.advance(target); - } - - @Override - public VectorScorer scorer(float[] target) throws IOException { - throw new UnsupportedOperationException(); + public DocIndexIterator iterator() { + return in.iterator(); } } static final class NormalizedFloatVectorValues extends FloatVectorValues { private final FloatVectorValues values; private final float[] normalizedVector; - int curDoc = -1; public NormalizedFloatVectorValues(FloatVectorValues values) { this.values = values; @@ -1222,38 +1217,25 @@ public int size() { } @Override - public float[] vectorValue() throws IOException { - return normalizedVector; - } - - @Override - public VectorScorer scorer(float[] query) throws IOException { - throw new UnsupportedOperationException(); + public int ordToDoc(int ord) { + return values.ordToDoc(ord); } @Override - public int docID() { - return values.docID(); + public float[] vectorValue(int ord) throws IOException { + System.arraycopy(values.vectorValue(ord), 0, normalizedVector, 0, normalizedVector.length); + VectorUtil.l2normalize(normalizedVector); + return normalizedVector; } @Override - public int nextDoc() throws IOException { - curDoc = values.nextDoc(); - if (curDoc != NO_MORE_DOCS) { - System.arraycopy(values.vectorValue(), 0, normalizedVector, 0, normalizedVector.length); - VectorUtil.l2normalize(normalizedVector); - } - return curDoc; + public DocIndexIterator iterator() { + return values.iterator(); } @Override - public int advance(int target) throws IOException { - curDoc = values.advance(target); - if (curDoc != NO_MORE_DOCS) { - System.arraycopy(values.vectorValue(), 0, normalizedVector, 0, normalizedVector.length); - VectorUtil.l2normalize(normalizedVector); - } - return curDoc; + public NormalizedFloatVectorValues copy() throws IOException { + return new NormalizedFloatVectorValues(values.copy()); } } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java index 655dcca11667..051c926a679e 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java @@ -30,15 +30,13 @@ import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.packed.DirectMonotonicReader; import org.apache.lucene.util.quantization.QuantizedByteVectorValues; -import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; import org.apache.lucene.util.quantization.ScalarQuantizer; /** * Read the quantized vector values and their score correction values from the index input. This * supports both iterated and random access. */ -public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValues - implements RandomAccessQuantizedByteVectorValues { +public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValues { protected final int dimension; protected final int size; @@ -141,11 +139,6 @@ public byte[] vectorValue(int targetOrd) throws IOException { return binaryValue; } - @Override - public float getScoreCorrectionConstant() { - return scoreCorrectionConstant[0]; - } - @Override public float getScoreCorrectionConstant(int targetOrd) throws IOException { if (lastOrd == targetOrd) { @@ -213,8 +206,6 @@ public static OffHeapQuantizedByteVectorValues load( */ public static class DenseOffHeapVectorValues extends OffHeapQuantizedByteVectorValues { - private int doc = -1; - public DenseOffHeapVectorValues( int dimension, int size, @@ -226,30 +217,6 @@ public DenseOffHeapVectorValues( super(dimension, size, scalarQuantizer, similarityFunction, vectorsScorer, compress, slice); } - @Override - public byte[] vectorValue() throws IOException { - return vectorValue(doc); - } - - @Override - public int docID() { - return doc; - } - - @Override - public int nextDoc() throws IOException { - return advance(doc + 1); - } - - @Override - public int advance(int target) throws IOException { - assert docID() < target; - if (target >= size) { - return doc = NO_MORE_DOCS; - } - return doc = target; - } - @Override public DenseOffHeapVectorValues copy() throws IOException { return new DenseOffHeapVectorValues( @@ -270,20 +237,26 @@ public Bits getAcceptOrds(Bits acceptDocs) { @Override public VectorScorer scorer(float[] target) throws IOException { DenseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); RandomVectorScorer vectorScorer = vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target); return new VectorScorer() { @Override public float score() throws IOException { - return vectorScorer.score(copy.doc); + return vectorScorer.score(iterator.index()); } @Override public DocIdSetIterator iterator() { - return copy; + return iterator; } }; } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } } private static class SparseOffHeapVectorValues extends OffHeapQuantizedByteVectorValues { @@ -312,24 +285,8 @@ public SparseOffHeapVectorValues( } @Override - public byte[] vectorValue() throws IOException { - return vectorValue(disi.index()); - } - - @Override - public int docID() { - return disi.docID(); - } - - @Override - public int nextDoc() throws IOException { - return disi.nextDoc(); - } - - @Override - public int advance(int target) throws IOException { - assert docID() < target; - return disi.advance(target); + public DocIndexIterator iterator() { + return IndexedDISI.asDocIndexIterator(disi); } @Override @@ -372,17 +329,18 @@ public int length() { @Override public VectorScorer scorer(float[] target) throws IOException { SparseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); RandomVectorScorer vectorScorer = vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target); return new VectorScorer() { @Override public float score() throws IOException { - return vectorScorer.score(copy.disi.index()); + return vectorScorer.score(iterator.index()); } @Override public DocIdSetIterator iterator() { - return copy; + return iterator; } }; } @@ -404,8 +362,6 @@ public EmptyOffHeapVectorValues( null); } - private int doc = -1; - @Override public int dimension() { return super.dimension(); @@ -417,23 +373,8 @@ public int size() { } @Override - public byte[] vectorValue() { - throw new UnsupportedOperationException(); - } - - @Override - public int docID() { - return doc; - } - - @Override - public int nextDoc() throws IOException { - return advance(doc + 1); - } - - @Override - public int advance(int target) { - return doc = NO_MORE_DOCS; + public DocIndexIterator iterator() { + return createDenseIterator(); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java index d33ca1ca3544..e9be3423c181 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java @@ -17,8 +17,8 @@ package org.apache.lucene.index; import java.io.IOException; +import java.util.List; import org.apache.lucene.document.KnnByteVectorField; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.VectorScorer; /** @@ -27,34 +27,21 @@ * * @lucene.experimental */ -public abstract class ByteVectorValues extends DocIdSetIterator { +public abstract class ByteVectorValues extends KnnVectorValues { /** Sole constructor */ protected ByteVectorValues() {} - /** Return the dimension of the vectors */ - public abstract int dimension(); - /** - * Return the number of vectors for this field. + * Return the vector value for the given vector ordinal which must be in [0, size() - 1], + * otherwise IndexOutOfBoundsException is thrown. The returned array may be shared across calls. * - * @return the number of vectors returned by this iterator + * @return the vector value */ - public abstract int size(); + public abstract byte[] vectorValue(int ord) throws IOException; @Override - public final long cost() { - return size(); - } - - /** - * Return the vector value for the current document ID. It is illegal to call this method when the - * iterator is not positioned: before advancing, or after failing to advance. The returned array - * may be shared across calls, re-used, and modified as the iterator advances. - * - * @return the vector value - */ - public abstract byte[] vectorValue() throws IOException; + public abstract ByteVectorValues copy() throws IOException; /** * Checks the Vector Encoding of a field @@ -78,12 +65,53 @@ public static void checkField(LeafReader in, String field) { } /** - * Return a {@link VectorScorer} for the given query vector. The iterator for the scorer is not - * the same instance as the iterator for this {@link ByteVectorValues}. It is a copy, and - * iteration over the scorer will not affect the iteration of this {@link ByteVectorValues}. + * Return a {@link VectorScorer} for the given query vector. * * @param query the query vector * @return a {@link VectorScorer} instance or null */ - public abstract VectorScorer scorer(byte[] query) throws IOException; + public VectorScorer scorer(byte[] query) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public VectorEncoding getEncoding() { + return VectorEncoding.BYTE; + } + + /** + * Creates a {@link ByteVectorValues} from a list of byte arrays. + * + * @param vectors the list of byte arrays + * @param dim the dimension of the vectors + * @return a {@link ByteVectorValues} instancec + */ + public static ByteVectorValues fromBytes(List vectors, int dim) { + return new ByteVectorValues() { + @Override + public int size() { + return vectors.size(); + } + + @Override + public int dimension() { + return dim; + } + + @Override + public byte[] vectorValue(int targetOrd) { + return vectors.get(targetOrd); + } + + @Override + public ByteVectorValues copy() { + return this; + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + }; + } } diff --git a/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java b/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java index b8256ecf5875..becb00cbb5b4 100644 --- a/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java +++ b/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java @@ -2760,16 +2760,16 @@ private static void checkFloatVectorValues( CheckIndex.Status.VectorValuesStatus status, CodecReader codecReader) throws IOException { - int docCount = 0; + int count = 0; int everyNdoc = Math.max(values.size() / 64, 1); - while (values.nextDoc() != NO_MORE_DOCS) { + while (count < values.size()) { // search the first maxNumSearches vectors to exercise the graph - if (values.docID() % everyNdoc == 0) { + if (values.ordToDoc(count) % everyNdoc == 0) { KnnCollector collector = new TopKnnCollector(10, Integer.MAX_VALUE); if (vectorsReaderSupportsSearch(codecReader, fieldInfo.name)) { codecReader .getVectorReader() - .search(fieldInfo.name, values.vectorValue(), collector, null); + .search(fieldInfo.name, values.vectorValue(count), collector, null); TopDocs docs = collector.topDocs(); if (docs.scoreDocs.length == 0) { throw new CheckIndexException( @@ -2777,7 +2777,7 @@ private static void checkFloatVectorValues( } } } - int valueLength = values.vectorValue().length; + int valueLength = values.vectorValue(count).length; if (valueLength != fieldInfo.getVectorDimension()) { throw new CheckIndexException( "Field \"" @@ -2787,19 +2787,19 @@ private static void checkFloatVectorValues( + " not matching the field's dimension=" + fieldInfo.getVectorDimension()); } - ++docCount; + ++count; } - if (docCount != values.size()) { + if (count != values.size()) { throw new CheckIndexException( "Field \"" + fieldInfo.name + "\" has size=" + values.size() + " but when iterated, returns " - + docCount + + count + " docs with values"); } - status.totalVectorValues += docCount; + status.totalVectorValues += count; } private static void checkByteVectorValues( @@ -2808,21 +2808,23 @@ private static void checkByteVectorValues( CheckIndex.Status.VectorValuesStatus status, CodecReader codecReader) throws IOException { - int docCount = 0; + int count = 0; int everyNdoc = Math.max(values.size() / 64, 1); boolean supportsSearch = vectorsReaderSupportsSearch(codecReader, fieldInfo.name); - while (values.nextDoc() != NO_MORE_DOCS) { + while (count < values.size()) { // search the first maxNumSearches vectors to exercise the graph - if (supportsSearch && values.docID() % everyNdoc == 0) { + if (supportsSearch && values.ordToDoc(count) % everyNdoc == 0) { KnnCollector collector = new TopKnnCollector(10, Integer.MAX_VALUE); - codecReader.getVectorReader().search(fieldInfo.name, values.vectorValue(), collector, null); + codecReader + .getVectorReader() + .search(fieldInfo.name, values.vectorValue(count), collector, null); TopDocs docs = collector.topDocs(); if (docs.scoreDocs.length == 0) { throw new CheckIndexException( "Field \"" + fieldInfo.name + "\" failed to search k nearest neighbors"); } } - int valueLength = values.vectorValue().length; + int valueLength = values.vectorValue(count).length; if (valueLength != fieldInfo.getVectorDimension()) { throw new CheckIndexException( "Field \"" @@ -2832,19 +2834,19 @@ private static void checkByteVectorValues( + " not matching the field's dimension=" + fieldInfo.getVectorDimension()); } - ++docCount; + ++count; } - if (docCount != values.size()) { + if (count != values.size()) { throw new CheckIndexException( "Field \"" + fieldInfo.name + "\" has size=" + values.size() + " but when iterated, returns " - + docCount + + count + " docs with values"); } - status.totalVectorValues += docCount; + status.totalVectorValues += count; } /** diff --git a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java index ca2cb1a27d45..614a652cd35a 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java @@ -429,52 +429,35 @@ private void checkAndThrow(DocIdSetIterator in) { } private class ExitableFloatVectorValues extends FloatVectorValues { - private int docToCheck; private final FloatVectorValues vectorValues; public ExitableFloatVectorValues(FloatVectorValues vectorValues) { this.vectorValues = vectorValues; - docToCheck = 0; } @Override - public int advance(int target) throws IOException { - final int advance = vectorValues.advance(target); - if (advance >= docToCheck) { - checkAndThrow(); - docToCheck = advance + DOCS_BETWEEN_TIMEOUT_CHECK; - } - return advance; + public int dimension() { + return vectorValues.dimension(); } @Override - public int docID() { - return vectorValues.docID(); + public float[] vectorValue(int ord) throws IOException { + return vectorValues.vectorValue(ord); } @Override - public int nextDoc() throws IOException { - final int nextDoc = vectorValues.nextDoc(); - if (nextDoc >= docToCheck) { - checkAndThrow(); - docToCheck = nextDoc + DOCS_BETWEEN_TIMEOUT_CHECK; - } - return nextDoc; + public int ordToDoc(int ord) { + return vectorValues.ordToDoc(ord); } @Override - public int dimension() { - return vectorValues.dimension(); - } - - @Override - public float[] vectorValue() throws IOException { - return vectorValues.vectorValue(); + public int size() { + return vectorValues.size(); } @Override - public int size() { - return vectorValues.size(); + public DocIndexIterator iterator() { + return createExitableIterator(vectorValues.iterator(), queryTimeout); } @Override @@ -482,95 +465,109 @@ public VectorScorer scorer(float[] target) throws IOException { return vectorValues.scorer(target); } - /** - * Throws {@link ExitingReaderException} if {@link QueryTimeout#shouldExit()} returns true, or - * if {@link Thread#interrupted()} returns true. - */ - private void checkAndThrow() { - if (queryTimeout.shouldExit()) { - throw new ExitingReaderException( - "The request took too long to iterate over vector values. Timeout: " - + queryTimeout.toString() - + ", FloatVectorValues=" - + in); - } else if (Thread.interrupted()) { - throw new ExitingReaderException( - "Interrupted while iterating over vector values. FloatVectorValues=" + in); - } + @Override + public FloatVectorValues copy() { + throw new UnsupportedOperationException(); } } private class ExitableByteVectorValues extends ByteVectorValues { - private int docToCheck; private final ByteVectorValues vectorValues; public ExitableByteVectorValues(ByteVectorValues vectorValues) { this.vectorValues = vectorValues; - docToCheck = 0; } @Override - public int advance(int target) throws IOException { - final int advance = vectorValues.advance(target); - if (advance >= docToCheck) { - checkAndThrow(); - docToCheck = advance + DOCS_BETWEEN_TIMEOUT_CHECK; - } - return advance; + public int dimension() { + return vectorValues.dimension(); } @Override - public int docID() { - return vectorValues.docID(); + public int size() { + return vectorValues.size(); } @Override - public int nextDoc() throws IOException { - final int nextDoc = vectorValues.nextDoc(); - if (nextDoc >= docToCheck) { - checkAndThrow(); - docToCheck = nextDoc + DOCS_BETWEEN_TIMEOUT_CHECK; - } - return nextDoc; + public byte[] vectorValue(int ord) throws IOException { + return vectorValues.vectorValue(ord); } @Override - public int dimension() { - return vectorValues.dimension(); + public int ordToDoc(int ord) { + return vectorValues.ordToDoc(ord); } @Override - public int size() { - return vectorValues.size(); + public DocIndexIterator iterator() { + return createExitableIterator(vectorValues.iterator(), queryTimeout); } @Override - public byte[] vectorValue() throws IOException { - return vectorValues.vectorValue(); + public VectorScorer scorer(byte[] target) throws IOException { + return vectorValues.scorer(target); } @Override - public VectorScorer scorer(byte[] target) throws IOException { - return vectorValues.scorer(target); + public ByteVectorValues copy() { + throw new UnsupportedOperationException(); + } + } + } + + private static KnnVectorValues.DocIndexIterator createExitableIterator( + KnnVectorValues.DocIndexIterator delegate, QueryTimeout queryTimeout) { + return new KnnVectorValues.DocIndexIterator() { + private int nextCheck; + + @Override + public int index() { + return delegate.index(); + } + + @Override + public int docID() { + return delegate.docID(); + } + + @Override + public int nextDoc() throws IOException { + int doc = delegate.nextDoc(); + if (doc >= nextCheck) { + checkAndThrow(); + nextCheck = doc + ExitableFilterAtomicReader.DOCS_BETWEEN_TIMEOUT_CHECK; + } + return doc; + } + + @Override + public long cost() { + return delegate.cost(); + } + + @Override + public int advance(int target) throws IOException { + int doc = delegate.advance(target); + if (doc >= nextCheck) { + checkAndThrow(); + nextCheck = doc + ExitableFilterAtomicReader.DOCS_BETWEEN_TIMEOUT_CHECK; + } + return doc; } - /** - * Throws {@link ExitingReaderException} if {@link QueryTimeout#shouldExit()} returns true, or - * if {@link Thread#interrupted()} returns true. - */ private void checkAndThrow() { if (queryTimeout.shouldExit()) { throw new ExitingReaderException( - "The request took too long to iterate over vector values. Timeout: " + "The request took too long to iterate over knn vector values. Timeout: " + queryTimeout.toString() - + ", ByteVectorValues=" - + in); + + ", KnnVectorValues=" + + delegate); } else if (Thread.interrupted()) { throw new ExitingReaderException( - "Interrupted while iterating over vector values. ByteVectorValues=" + in); + "Interrupted while iterating over knn vector values. KnnVectorValues=" + delegate); } } - } + }; } /** Wrapper class for another PointValues implementation that is used by ExitableFields. */ @@ -683,7 +680,7 @@ private void checkAndThrow() { if (queryTimeout.shouldExit()) { throw new ExitingReaderException( "The request took too long to intersect point values. Timeout: " - + queryTimeout.toString() + + queryTimeout + ", PointValues=" + pointValues); } else if (Thread.interrupted()) { @@ -815,7 +812,7 @@ public void grow(int count) { /** Wrapper class for another Terms implementation that is used by ExitableFields. */ public static class ExitableTerms extends FilterTerms { - private QueryTimeout queryTimeout; + private final QueryTimeout queryTimeout; /** Constructor * */ public ExitableTerms(Terms terms, QueryTimeout queryTimeout) { diff --git a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java index e5dbc620f5c3..aa840fc39319 100644 --- a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java @@ -17,8 +17,8 @@ package org.apache.lucene.index; import java.io.IOException; +import java.util.List; import org.apache.lucene.document.KnnFloatVectorField; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.VectorScorer; /** @@ -27,34 +27,21 @@ * * @lucene.experimental */ -public abstract class FloatVectorValues extends DocIdSetIterator { +public abstract class FloatVectorValues extends KnnVectorValues { /** Sole constructor */ protected FloatVectorValues() {} - /** Return the dimension of the vectors */ - public abstract int dimension(); - /** - * Return the number of vectors for this field. + * Return the vector value for the given vector ordinal which must be in [0, size() - 1], + * otherwise IndexOutOfBoundsException is thrown. The returned array may be shared across calls. * - * @return the number of vectors returned by this iterator + * @return the vector value */ - public abstract int size(); + public abstract float[] vectorValue(int ord) throws IOException; @Override - public final long cost() { - return size(); - } - - /** - * Return the vector value for the current document ID. It is illegal to call this method when the - * iterator is not positioned: before advancing, or after failing to advance. The returned array - * may be shared across calls, re-used, and modified as the iterator advances. - * - * @return the vector value - */ - public abstract float[] vectorValue() throws IOException; + public abstract FloatVectorValues copy() throws IOException; /** * Checks the Vector Encoding of a field @@ -79,12 +66,53 @@ public static void checkField(LeafReader in, String field) { /** * Return a {@link VectorScorer} for the given query vector and the current {@link - * FloatVectorValues}. The iterator for the scorer is not the same instance as the iterator for - * this {@link FloatVectorValues}. It is a copy, and iteration over the scorer will not affect the - * iteration of this {@link FloatVectorValues}. + * FloatVectorValues}. * - * @param query the query vector + * @param target the query vector * @return a {@link VectorScorer} instance or null */ - public abstract VectorScorer scorer(float[] query) throws IOException; + public VectorScorer scorer(float[] target) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public VectorEncoding getEncoding() { + return VectorEncoding.FLOAT32; + } + + /** + * Creates a {@link FloatVectorValues} from a list of float arrays. + * + * @param vectors the list of float arrays + * @param dim the dimension of the vectors + * @return a {@link FloatVectorValues} instance + */ + public static FloatVectorValues fromFloats(List vectors, int dim) { + return new FloatVectorValues() { + @Override + public int size() { + return vectors.size(); + } + + @Override + public int dimension() { + return dim; + } + + @Override + public float[] vectorValue(int targetOrd) { + return vectors.get(targetOrd); + } + + @Override + public FloatVectorValues copy() { + return this; + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + }; + } } diff --git a/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java new file mode 100644 index 000000000000..8e58f387a334 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.index; + +import java.io.IOException; +import org.apache.lucene.document.KnnByteVectorField; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.Bits; + +/** + * This class abstracts addressing of document vector values indexed as {@link KnnFloatVectorField} + * or {@link KnnByteVectorField}. + * + * @lucene.experimental + */ +public abstract class KnnVectorValues { + + /** Return the dimension of the vectors */ + public abstract int dimension(); + + /** + * Return the number of vectors for this field. + * + * @return the number of vectors returned by this iterator + */ + public abstract int size(); + + /** + * Return the docid of the document indexed with the given vector ordinal. This default + * implementation returns the argument and is appropriate for dense values implementations where + * every doc has a single value. + */ + public int ordToDoc(int ord) { + return ord; + } + + /** + * Creates a new copy of this {@link KnnVectorValues}. This is helpful when you need to access + * different values at once, to avoid overwriting the underlying vector returned. + */ + public abstract KnnVectorValues copy() throws IOException; + + /** Returns the vector byte length, defaults to dimension multiplied by float byte size */ + public int getVectorByteLength() { + return dimension() * getEncoding().byteSize; + } + + /** The vector encoding of these values. */ + public abstract VectorEncoding getEncoding(); + + /** Returns a Bits accepting docs accepted by the argument and having a vector value */ + public Bits getAcceptOrds(Bits acceptDocs) { + // FIXME: change default to return acceptDocs and provide this impl + // somewhere more specialized (in every non-dense impl). + if (acceptDocs == null) { + return null; + } + return new Bits() { + @Override + public boolean get(int index) { + return acceptDocs.get(ordToDoc(index)); + } + + @Override + public int length() { + return size(); + } + }; + } + + /** Create an iterator for this instance. */ + public DocIndexIterator iterator() { + throw new UnsupportedOperationException(); + } + + /** + * A DocIdSetIterator that also provides an index() method tracking a distinct ordinal for a + * vector associated with each doc. + */ + public abstract static class DocIndexIterator extends DocIdSetIterator { + + /** return the value index (aka "ordinal" or "ord") corresponding to the current doc */ + public abstract int index(); + } + + /** + * Creates an iterator for instances where every doc has a value, and the value ordinals are equal + * to the docids. + */ + protected DocIndexIterator createDenseIterator() { + return new DocIndexIterator() { + + int doc = -1; + + @Override + public int docID() { + return doc; + } + + @Override + public int index() { + return doc; + } + + @Override + public int nextDoc() throws IOException { + if (doc >= size() - 1) { + return doc = NO_MORE_DOCS; + } else { + return ++doc; + } + } + + @Override + public int advance(int target) { + if (target >= size()) { + return doc = NO_MORE_DOCS; + } + return doc = target; + } + + @Override + public long cost() { + return size(); + } + }; + } + + /** + * Creates an iterator from a DocIdSetIterator indicating which docs have values, and for which + * ordinals increase monotonically with docid. + */ + protected static DocIndexIterator fromDISI(DocIdSetIterator docsWithField) { + return new DocIndexIterator() { + + int ord = -1; + + @Override + public int docID() { + return docsWithField.docID(); + } + + @Override + public int index() { + return ord; + } + + @Override + public int nextDoc() throws IOException { + if (docID() == NO_MORE_DOCS) { + return NO_MORE_DOCS; + } + ord++; + return docsWithField.nextDoc(); + } + + @Override + public int advance(int target) throws IOException { + return docsWithField.advance(target); + } + + @Override + public long cost() { + return docsWithField.cost(); + } + }; + } + + /** + * Creates an iterator from this instance's ordinal-to-docid mapping which must be monotonic + * (docid increases when ordinal does). + */ + protected DocIndexIterator createSparseIterator() { + return new DocIndexIterator() { + private int ord = -1; + + @Override + public int docID() { + if (ord == -1) { + return -1; + } + if (ord == NO_MORE_DOCS) { + return NO_MORE_DOCS; + } + return ordToDoc(ord); + } + + @Override + public int index() { + return ord; + } + + @Override + public int nextDoc() throws IOException { + if (ord >= size() - 1) { + ord = NO_MORE_DOCS; + } else { + ++ord; + } + return docID(); + } + + @Override + public int advance(int target) throws IOException { + return slowAdvance(target); + } + + @Override + public long cost() { + return size(); + } + }; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java index b2f7f21fb7ed..7f5d8926b638 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java +++ b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java @@ -34,9 +34,7 @@ import org.apache.lucene.codecs.TermVectorsReader; import org.apache.lucene.index.MultiDocValues.MultiSortedDocValues; import org.apache.lucene.index.MultiDocValues.MultiSortedSetDocValues; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; -import org.apache.lucene.search.VectorScorer; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.Bits; import org.apache.lucene.util.IOUtils; @@ -303,38 +301,21 @@ public void checkIntegrity() throws IOException { } } - private record DocValuesSub(T sub, int docStart, int docEnd) {} + private record DocValuesSub(T sub, int docStart, int ordStart) {} - private static class MergedDocIdSetIterator extends DocIdSetIterator { + private static class MergedDocIterator + extends KnnVectorValues.DocIndexIterator { final Iterator> it; - final long cost; DocValuesSub current; - int currentIndex = 0; + KnnVectorValues.DocIndexIterator currentIterator; + int ord = -1; int doc = -1; - MergedDocIdSetIterator(List> subs) { - long cost = 0; - for (DocValuesSub sub : subs) { - if (sub.sub != null) { - cost += sub.sub.cost(); - } - } - this.cost = cost; + MergedDocIterator(List> subs) { this.it = subs.iterator(); current = it.next(); - } - - private boolean advanceSub(int target) { - while (current.sub == null || current.docEnd <= target) { - if (it.hasNext() == false) { - doc = NO_MORE_DOCS; - return false; - } - current = it.next(); - currentIndex++; - } - return true; + currentIterator = current.sub.iterator(); } @Override @@ -342,41 +323,39 @@ public int docID() { return doc; } + @Override + public int index() { + return ord; + } + @Override public int nextDoc() throws IOException { while (true) { if (current.sub != null) { - int next = current.sub.nextDoc(); + int next = currentIterator.nextDoc(); if (next != NO_MORE_DOCS) { + ++ord; return doc = current.docStart + next; } } if (it.hasNext() == false) { + ord = NO_MORE_DOCS; return doc = NO_MORE_DOCS; } current = it.next(); - currentIndex++; + currentIterator = current.sub.iterator(); + ord = current.ordStart - 1; } } @Override - public int advance(int target) throws IOException { - while (true) { - if (advanceSub(target) == false) { - return DocIdSetIterator.NO_MORE_DOCS; - } - int next = current.sub.advance(target - current.docStart); - if (next == DocIdSetIterator.NO_MORE_DOCS) { - target = current.docEnd; - } else { - return doc = current.docStart + next; - } - } + public long cost() { + throw new UnsupportedOperationException(); } @Override - public long cost() { - return cost; + public int advance(int target) throws IOException { + throw new UnsupportedOperationException(); } } @@ -848,55 +827,75 @@ public FloatVectorValues getFloatVectorValues(String field) throws IOException { int size = 0; for (CodecReader reader : codecReaders) { FloatVectorValues values = reader.getFloatVectorValues(field); + subs.add(new DocValuesSub<>(values, docStarts[i], size)); if (values != null) { if (dimension == -1) { dimension = values.dimension(); } size += values.size(); } - subs.add(new DocValuesSub<>(values, docStarts[i], docStarts[i + 1])); i++; } - final int finalDimension = dimension; - final int finalSize = size; - MergedDocIdSetIterator mergedIterator = new MergedDocIdSetIterator<>(subs); - return new FloatVectorValues() { - - @Override - public int dimension() { - return finalDimension; - } + return new MergedFloatVectorValues(dimension, size, subs); + } - @Override - public int size() { - return finalSize; + class MergedFloatVectorValues extends FloatVectorValues { + final int dimension; + final int size; + final DocValuesSub[] subs; + final MergedDocIterator iter; + final int[] starts; + int lastSubIndex; + + MergedFloatVectorValues(int dimension, int size, List> subs) { + this.dimension = dimension; + this.size = size; + this.subs = subs.toArray(new DocValuesSub[0]); + iter = new MergedDocIterator<>(subs); + // [0, start(1), ..., size] - we want the extra element + // to avoid checking for out-of-array bounds + starts = new int[subs.size() + 1]; + for (int i = 0; i < subs.size(); i++) { + starts[i] = subs.get(i).ordStart; } + starts[starts.length - 1] = size; + } - @Override - public float[] vectorValue() throws IOException { - return mergedIterator.current.sub.vectorValue(); - } + @Override + public MergedDocIterator iterator() { + return iter; + } - @Override - public int docID() { - return mergedIterator.docID(); - } + @Override + public int dimension() { + return dimension; + } - @Override - public int nextDoc() throws IOException { - return mergedIterator.nextDoc(); - } + @Override + public int size() { + return size; + } - @Override - public int advance(int target) throws IOException { - return mergedIterator.advance(target); + @SuppressWarnings("unchecked") + @Override + public FloatVectorValues copy() throws IOException { + List> subsCopy = new ArrayList<>(); + for (Object sub : subs) { + subsCopy.add((DocValuesSub) sub); } + return new MergedFloatVectorValues(dimension, size, subsCopy); + } - @Override - public VectorScorer scorer(float[] target) { - throw new UnsupportedOperationException(); - } - }; + @Override + public float[] vectorValue(int ord) throws IOException { + assert ord >= 0 && ord < size; + // We need to implement fully random-access API here in order to support callers like + // SortingCodecReader that + // rely on it. + lastSubIndex = findSub(ord, lastSubIndex, starts); + return ((FloatVectorValues) subs[lastSubIndex].sub) + .vectorValue(ord - subs[lastSubIndex].ordStart); + } } @Override @@ -907,55 +906,96 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { int size = 0; for (CodecReader reader : codecReaders) { ByteVectorValues values = reader.getByteVectorValues(field); + subs.add(new DocValuesSub<>(values, docStarts[i], size)); if (values != null) { if (dimension == -1) { dimension = values.dimension(); } size += values.size(); } - subs.add(new DocValuesSub<>(values, docStarts[i], docStarts[i + 1])); i++; } - final int finalDimension = dimension; - final int finalSize = size; - MergedDocIdSetIterator mergedIterator = new MergedDocIdSetIterator<>(subs); - return new ByteVectorValues() { + return new MergedByteVectorValues(dimension, size, subs); + } - @Override - public int dimension() { - return finalDimension; + class MergedByteVectorValues extends ByteVectorValues { + final int dimension; + final int size; + final DocValuesSub[] subs; + final MergedDocIterator iter; + final int[] starts; + int lastSubIndex; + + MergedByteVectorValues(int dimension, int size, List> subs) { + this.dimension = dimension; + this.size = size; + this.subs = subs.toArray(new DocValuesSub[0]); + iter = new MergedDocIterator<>(subs); + // [0, start(1), ..., size] - we want the extra element + // to avoid checking for out-of-array bounds + starts = new int[subs.size() + 1]; + for (int i = 0; i < subs.size(); i++) { + starts[i] = subs.get(i).ordStart; } + starts[starts.length - 1] = size; + } - @Override - public int size() { - return finalSize; - } + @Override + public MergedDocIterator iterator() { + return iter; + } - @Override - public byte[] vectorValue() throws IOException { - return mergedIterator.current.sub.vectorValue(); - } + @Override + public int dimension() { + return dimension; + } - @Override - public int docID() { - return mergedIterator.docID(); - } + @Override + public int size() { + return size; + } - @Override - public int nextDoc() throws IOException { - return mergedIterator.nextDoc(); - } + @Override + public byte[] vectorValue(int ord) throws IOException { + assert ord >= 0 && ord < size; + // We need to implement fully random-access API here in order to support callers like + // SortingCodecReader that rely on it. We maintain lastSubIndex since we expect some + // repetition. + lastSubIndex = findSub(ord, lastSubIndex, starts); + return ((ByteVectorValues) subs[lastSubIndex].sub) + .vectorValue(ord - subs[lastSubIndex].ordStart); + } - @Override - public int advance(int target) throws IOException { - return mergedIterator.advance(target); + @SuppressWarnings("unchecked") + @Override + public ByteVectorValues copy() throws IOException { + List> newSubs = new ArrayList<>(); + for (Object sub : subs) { + newSubs.add((DocValuesSub) sub); } + return new MergedByteVectorValues(dimension, size, newSubs); + } + } - @Override - public VectorScorer scorer(byte[] target) { - throw new UnsupportedOperationException(); + private static int findSub(int ord, int lastSubIndex, int[] starts) { + if (ord >= starts[lastSubIndex]) { + if (ord >= starts[lastSubIndex + 1]) { + return binarySearchStarts(starts, ord, lastSubIndex + 1, starts.length); } - }; + } else { + return binarySearchStarts(starts, ord, 0, lastSubIndex); + } + return lastSubIndex; + } + + private static int binarySearchStarts(int[] starts, int ord, int from, int to) { + int pos = Arrays.binarySearch(starts, from, to, ord); + if (pos < 0) { + // subtract one since binarySearch returns an *insertion point* + return -2 - pos; + } else { + return pos; + } } @Override diff --git a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java index fee0fc2f7309..daec0c197d6a 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java @@ -25,6 +25,7 @@ import java.util.Iterator; import java.util.Map; import java.util.Objects; +import java.util.function.Supplier; import org.apache.lucene.codecs.DocValuesProducer; import org.apache.lucene.codecs.FieldsProducer; import org.apache.lucene.codecs.KnnVectorsReader; @@ -32,10 +33,11 @@ import org.apache.lucene.codecs.PointsReader; import org.apache.lucene.codecs.StoredFieldsReader; import org.apache.lucene.codecs.TermVectorsReader; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; -import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.Bits; import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.IOSupplier; @@ -206,121 +208,175 @@ public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue } } - /** Sorting FloatVectorValues that iterate over documents in the order of the provided sortMap */ - private static class SortingFloatVectorValues extends FloatVectorValues { - final int size; - final int dimension; - final FixedBitSet docsWithField; - final float[][] vectors; + /** + * Factory for SortingValuesIterator. This enables us to create new iterators as needed without + * recomputing the sorting mappings. + */ + static class SortingIteratorSupplier implements Supplier { + private final FixedBitSet docBits; + private final int[] docToOrd; + private final int size; - private int docId = -1; + SortingIteratorSupplier(FixedBitSet docBits, int[] docToOrd, int size) { + this.docBits = docBits; + this.docToOrd = docToOrd; + this.size = size; + } - SortingFloatVectorValues(FloatVectorValues delegate, Sorter.DocMap sortMap) throws IOException { - this.size = delegate.size(); - this.dimension = delegate.dimension(); - docsWithField = new FixedBitSet(sortMap.size()); - vectors = new float[sortMap.size()][]; - for (int doc = delegate.nextDoc(); doc != NO_MORE_DOCS; doc = delegate.nextDoc()) { - int newDocID = sortMap.oldToNew(doc); - docsWithField.set(newDocID); - vectors[newDocID] = delegate.vectorValue().clone(); - } + @Override + public SortingValuesIterator get() { + return new SortingValuesIterator(docBits, docToOrd, size); + } + + public int size() { + return size; + } + } + + /** + * Creates a factory for SortingValuesIterator. Does the work of computing the (new docId to old + * ordinal) mapping, and caches the result, enabling it to create new iterators cheaply. + * + * @param values the values over which to iterate + * @param docMap the mapping from "old" docIds to "new" (sorted) docIds. + */ + public static SortingIteratorSupplier iteratorSupplier( + KnnVectorValues values, Sorter.DocMap docMap) throws IOException { + + final int[] docToOrd = new int[docMap.size()]; + final FixedBitSet docBits = new FixedBitSet(docMap.size()); + int count = 0; + // Note: docToOrd will contain zero for docids that have no vector. This is OK though + // because the iterator cannot be positioned on such docs + KnnVectorValues.DocIndexIterator iter = values.iterator(); + for (int doc = iter.nextDoc(); doc != NO_MORE_DOCS; doc = iter.nextDoc()) { + int newDocId = docMap.oldToNew(doc); + if (newDocId != -1) { + docToOrd[newDocId] = iter.index(); + docBits.set(newDocId); + ++count; + } + } + return new SortingIteratorSupplier(docBits, docToOrd, count); + } + + /** + * Iterator over KnnVectorValues accepting a mapping to differently-sorted docs. Consequently + * index() may skip around, not increasing monotonically as iteration proceeds. + */ + public static class SortingValuesIterator extends KnnVectorValues.DocIndexIterator { + private final FixedBitSet docBits; + private final DocIdSetIterator docsWithValues; + private final int[] docToOrd; + + int doc = -1; + + SortingValuesIterator(FixedBitSet docBits, int[] docToOrd, int size) { + this.docBits = docBits; + this.docToOrd = docToOrd; + docsWithValues = new BitSetIterator(docBits, size); } @Override public int docID() { - return docId; + return doc; + } + + @Override + public int index() { + assert docBits.get(doc); + return docToOrd[doc]; } @Override public int nextDoc() throws IOException { - return advance(docId + 1); + if (doc != NO_MORE_DOCS) { + doc = docsWithValues.nextDoc(); + } + return doc; } @Override - public float[] vectorValue() throws IOException { - return vectors[docId]; + public long cost() { + return docBits.cardinality(); + } + + @Override + public int advance(int target) { + throw new UnsupportedOperationException(); + } + } + + /** Sorting FloatVectorValues that maps ordinals using the provided sortMap */ + private static class SortingFloatVectorValues extends FloatVectorValues { + final FloatVectorValues delegate; + final SortingIteratorSupplier iteratorSupplier; + + SortingFloatVectorValues(FloatVectorValues delegate, Sorter.DocMap sortMap) throws IOException { + this.delegate = delegate; + // SortingValuesIterator consumes the iterator and records the docs and ord mapping + iteratorSupplier = iteratorSupplier(delegate, sortMap); + } + + @Override + public float[] vectorValue(int ord) throws IOException { + // ords are interpreted in the delegate's ord-space. + return delegate.vectorValue(ord); } @Override public int dimension() { - return dimension; + return delegate.dimension(); } @Override public int size() { - return size; + return iteratorSupplier.size(); } @Override - public int advance(int target) throws IOException { - if (target >= docsWithField.length()) { - return NO_MORE_DOCS; - } - return docId = docsWithField.nextSetBit(target); + public FloatVectorValues copy() { + throw new UnsupportedOperationException(); } @Override - public VectorScorer scorer(float[] target) { - throw new UnsupportedOperationException(); + public DocIndexIterator iterator() { + return iteratorSupplier.get(); } } private static class SortingByteVectorValues extends ByteVectorValues { - final int size; - final int dimension; - final FixedBitSet docsWithField; - final byte[][] vectors; - - private int docId = -1; + final ByteVectorValues delegate; + final SortingIteratorSupplier iteratorSupplier; SortingByteVectorValues(ByteVectorValues delegate, Sorter.DocMap sortMap) throws IOException { - this.size = delegate.size(); - this.dimension = delegate.dimension(); - docsWithField = new FixedBitSet(sortMap.size()); - vectors = new byte[sortMap.size()][]; - for (int doc = delegate.nextDoc(); doc != NO_MORE_DOCS; doc = delegate.nextDoc()) { - int newDocID = sortMap.oldToNew(doc); - docsWithField.set(newDocID); - vectors[newDocID] = delegate.vectorValue().clone(); - } + this.delegate = delegate; + // SortingValuesIterator consumes the iterator and records the docs and ord mapping + iteratorSupplier = iteratorSupplier(delegate, sortMap); } @Override - public int docID() { - return docId; + public byte[] vectorValue(int ord) throws IOException { + return delegate.vectorValue(ord); } @Override - public int nextDoc() throws IOException { - return advance(docId + 1); - } - - @Override - public byte[] vectorValue() throws IOException { - return vectors[docId]; + public DocIndexIterator iterator() { + return iteratorSupplier.get(); } @Override public int dimension() { - return dimension; + return delegate.dimension(); } @Override public int size() { - return size; - } - - @Override - public int advance(int target) throws IOException { - if (target >= docsWithField.length()) { - return NO_MORE_DOCS; - } - return docId = docsWithField.nextSetBit(target); + return iteratorSupplier.size(); } @Override - public VectorScorer scorer(byte[] target) { + public ByteVectorValues copy() { throw new UnsupportedOperationException(); } } diff --git a/lucene/core/src/java/org/apache/lucene/search/FieldExistsQuery.java b/lucene/core/src/java/org/apache/lucene/search/FieldExistsQuery.java index 409bcbc0b643..adaace27727e 100644 --- a/lucene/core/src/java/org/apache/lucene/search/FieldExistsQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/FieldExistsQuery.java @@ -181,8 +181,8 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti } else if (fieldInfo.getVectorDimension() != 0) { // the field indexes vectors iterator = switch (fieldInfo.getVectorEncoding()) { - case FLOAT32 -> context.reader().getFloatVectorValues(field); - case BYTE -> context.reader().getByteVectorValues(field); + case FLOAT32 -> context.reader().getFloatVectorValues(field).iterator(); + case BYTE -> context.reader().getByteVectorValues(field).iterator(); }; } else if (fieldInfo.getDocValuesType() != DocValuesType.NONE) { // the field indexes doc values diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswMerger.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswMerger.java index 392d83fa262c..c4e7d159b489 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswMerger.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswMerger.java @@ -19,7 +19,7 @@ import java.io.IOException; import org.apache.lucene.codecs.hnsw.HnswGraphProvider; import org.apache.lucene.index.FieldInfo; -import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.search.TaskExecutor; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.FixedBitSet; @@ -46,7 +46,7 @@ public ConcurrentHnswMerger( } @Override - protected HnswBuilder createBuilder(DocIdSetIterator mergedVectorIterator, int maxOrd) + protected HnswBuilder createBuilder(KnnVectorValues mergedVectorValues, int maxOrd) throws IOException { if (initReader == null) { return new HnswConcurrentMergeBuilder( @@ -61,7 +61,7 @@ protected HnswBuilder createBuilder(DocIdSetIterator mergedVectorIterator, int m HnswGraph initializerGraph = ((HnswGraphProvider) initReader).getGraph(fieldInfo.name); BitSet initializedNodes = new FixedBitSet(maxOrd); - int[] oldToNewOrdinalMap = getNewOrdMapping(mergedVectorIterator, initializedNodes); + int[] oldToNewOrdinalMap = getNewOrdMapping(mergedVectorValues, initializedNodes); return new HnswConcurrentMergeBuilder( taskExecutor, diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphMerger.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphMerger.java index 7ed5dd142de5..31e9c768dc03 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphMerger.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphMerger.java @@ -18,8 +18,8 @@ import java.io.IOException; import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.MergeState; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.Bits; import org.apache.lucene.util.InfoStream; @@ -45,12 +45,12 @@ HnswGraphMerger addReader(KnnVectorsReader reader, MergeState.DocMap docMap, Bit /** * Merge and produce the on heap graph * - * @param mergedVectorIterator iterator over the vectors in the merged segment + * @param mergedVectorValues view of the vectors in the merged segment * @param infoStream optional info stream to set to builder * @param maxOrd max number of vectors that will be added to the graph * @return merged graph * @throws IOException during merge */ - OnHeapHnswGraph merge(DocIdSetIterator mergedVectorIterator, InfoStream infoStream, int maxOrd) + OnHeapHnswGraph merge(KnnVectorValues mergedVectorValues, InfoStream infoStream, int maxOrd) throws IOException; } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java index 7331111d45a9..d64961a02ee4 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java @@ -25,9 +25,9 @@ import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.MergeState; import org.apache.lucene.internal.hppc.IntIntHashMap; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.Bits; import org.apache.lucene.util.FixedBitSet; @@ -108,12 +108,12 @@ public IncrementalHnswGraphMerger addReader( * Builds a new HnswGraphBuilder using the biggest graph from the merge state as a starting point. * If no valid readers were added to the merge state, a new graph is created. * - * @param mergedVectorIterator iterator over the vectors in the merged segment + * @param mergedVectorValues vector values in the merged segment * @param maxOrd max num of vectors that will be merged into the graph * @return HnswGraphBuilder * @throws IOException If an error occurs while reading from the merge state */ - protected HnswBuilder createBuilder(DocIdSetIterator mergedVectorIterator, int maxOrd) + protected HnswBuilder createBuilder(KnnVectorValues mergedVectorValues, int maxOrd) throws IOException { if (initReader == null) { return HnswGraphBuilder.create( @@ -123,7 +123,7 @@ protected HnswBuilder createBuilder(DocIdSetIterator mergedVectorIterator, int m HnswGraph initializerGraph = ((HnswGraphProvider) initReader).getGraph(fieldInfo.name); BitSet initializedNodes = new FixedBitSet(maxOrd); - int[] oldToNewOrdinalMap = getNewOrdMapping(mergedVectorIterator, initializedNodes); + int[] oldToNewOrdinalMap = getNewOrdMapping(mergedVectorValues, initializedNodes); return InitializedHnswGraphBuilder.fromGraph( scorerSupplier, M, @@ -137,8 +137,8 @@ protected HnswBuilder createBuilder(DocIdSetIterator mergedVectorIterator, int m @Override public OnHeapHnswGraph merge( - DocIdSetIterator mergedVectorIterator, InfoStream infoStream, int maxOrd) throws IOException { - HnswBuilder builder = createBuilder(mergedVectorIterator, maxOrd); + KnnVectorValues mergedVectorValues, InfoStream infoStream, int maxOrd) throws IOException { + HnswBuilder builder = createBuilder(mergedVectorValues, maxOrd); builder.setInfoStream(infoStream); return builder.build(maxOrd); } @@ -147,46 +147,45 @@ public OnHeapHnswGraph merge( * Creates a new mapping from old ordinals to new ordinals and returns the total number of vectors * in the newly merged segment. * - * @param mergedVectorIterator iterator over the vectors in the merged segment + * @param mergedVectorValues vector values in the merged segment * @param initializedNodes track what nodes have been initialized * @return the mapping from old ordinals to new ordinals * @throws IOException If an error occurs while reading from the merge state */ protected final int[] getNewOrdMapping( - DocIdSetIterator mergedVectorIterator, BitSet initializedNodes) throws IOException { - DocIdSetIterator initializerIterator = null; + KnnVectorValues mergedVectorValues, BitSet initializedNodes) throws IOException { + KnnVectorValues.DocIndexIterator initializerIterator = null; switch (fieldInfo.getVectorEncoding()) { - case BYTE -> initializerIterator = initReader.getByteVectorValues(fieldInfo.name); - case FLOAT32 -> initializerIterator = initReader.getFloatVectorValues(fieldInfo.name); + case BYTE -> initializerIterator = initReader.getByteVectorValues(fieldInfo.name).iterator(); + case FLOAT32 -> + initializerIterator = initReader.getFloatVectorValues(fieldInfo.name).iterator(); } IntIntHashMap newIdToOldOrdinal = new IntIntHashMap(initGraphSize); - int oldOrd = 0; int maxNewDocID = -1; - for (int oldId = initializerIterator.nextDoc(); - oldId != NO_MORE_DOCS; - oldId = initializerIterator.nextDoc()) { - int newId = initDocMap.get(oldId); + for (int docId = initializerIterator.nextDoc(); + docId != NO_MORE_DOCS; + docId = initializerIterator.nextDoc()) { + int newId = initDocMap.get(docId); maxNewDocID = Math.max(newId, maxNewDocID); - newIdToOldOrdinal.put(newId, oldOrd); - oldOrd++; + newIdToOldOrdinal.put(newId, initializerIterator.index()); } if (maxNewDocID == -1) { return new int[0]; } final int[] oldToNewOrdinalMap = new int[initGraphSize]; - int newOrd = 0; + KnnVectorValues.DocIndexIterator mergedVectorIterator = mergedVectorValues.iterator(); for (int newDocId = mergedVectorIterator.nextDoc(); newDocId <= maxNewDocID; newDocId = mergedVectorIterator.nextDoc()) { int hashDocIndex = newIdToOldOrdinal.indexOf(newDocId); if (newIdToOldOrdinal.indexExists(hashDocIndex)) { + int newOrd = mergedVectorIterator.index(); initializedNodes.set(newOrd); oldToNewOrdinalMap[newIdToOldOrdinal.indexGet(hashDocIndex)] = newOrd; } - newOrd++; } return oldToNewOrdinalMap; } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java deleted file mode 100644 index e2c7372b667a..000000000000 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.lucene.util.hnsw; - -import java.io.IOException; -import java.util.List; -import org.apache.lucene.store.IndexInput; -import org.apache.lucene.util.Bits; - -/** - * Provides random access to vectors by dense ordinal. This interface is used by HNSW-based - * implementations of KNN search. - * - * @lucene.experimental - */ -public interface RandomAccessVectorValues { - - /** Return the number of vector values */ - int size(); - - /** Return the dimension of the returned vector values */ - int dimension(); - - /** - * Creates a new copy of this {@link RandomAccessVectorValues}. This is helpful when you need to - * access different values at once, to avoid overwriting the underlying vector returned. - */ - RandomAccessVectorValues copy() throws IOException; - - /** - * Returns a slice of the underlying {@link IndexInput} that contains the vector values if - * available - */ - default IndexInput getSlice() { - return null; - } - - /** Returns the byte length of the vector values. */ - int getVectorByteLength(); - - /** - * Translates vector ordinal to the correct document ID. By default, this is an identity function. - * - * @param ord the vector ordinal - * @return the document Id for that vector ordinal - */ - default int ordToDoc(int ord) { - return ord; - } - - /** - * Returns the {@link Bits} representing live documents. By default, this is an identity function. - * - * @param acceptDocs the accept docs - * @return the accept docs - */ - default Bits getAcceptOrds(Bits acceptDocs) { - return acceptDocs; - } - - /** Float vector values. */ - interface Floats extends RandomAccessVectorValues { - @Override - RandomAccessVectorValues.Floats copy() throws IOException; - - /** - * Return the vector value indexed at the given ordinal. - * - * @param targetOrd a valid ordinal, ≥ 0 and < {@link #size()}. - */ - float[] vectorValue(int targetOrd) throws IOException; - - /** Returns the vector byte length, defaults to dimension multiplied by float byte size */ - @Override - default int getVectorByteLength() { - return dimension() * Float.BYTES; - } - } - - /** Byte vector values. */ - interface Bytes extends RandomAccessVectorValues { - @Override - RandomAccessVectorValues.Bytes copy() throws IOException; - - /** - * Return the vector value indexed at the given ordinal. - * - * @param targetOrd a valid ordinal, ≥ 0 and < {@link #size()}. - */ - byte[] vectorValue(int targetOrd) throws IOException; - - /** Returns the vector byte length, defaults to dimension multiplied by byte size */ - @Override - default int getVectorByteLength() { - return dimension() * Byte.BYTES; - } - } - - /** - * Creates a {@link RandomAccessVectorValues.Floats} from a list of float arrays. - * - * @param vectors the list of float arrays - * @param dim the dimension of the vectors - * @return a {@link RandomAccessVectorValues.Floats} instance - */ - static RandomAccessVectorValues.Floats fromFloats(List vectors, int dim) { - return new RandomAccessVectorValues.Floats() { - @Override - public int size() { - return vectors.size(); - } - - @Override - public int dimension() { - return dim; - } - - @Override - public float[] vectorValue(int targetOrd) { - return vectors.get(targetOrd); - } - - @Override - public RandomAccessVectorValues.Floats copy() { - return this; - } - }; - } - - /** - * Creates a {@link RandomAccessVectorValues.Bytes} from a list of byte arrays. - * - * @param vectors the list of byte arrays - * @param dim the dimension of the vectors - * @return a {@link RandomAccessVectorValues.Bytes} instance - */ - static RandomAccessVectorValues.Bytes fromBytes(List vectors, int dim) { - return new RandomAccessVectorValues.Bytes() { - @Override - public int size() { - return vectors.size(); - } - - @Override - public int dimension() { - return dim; - } - - @Override - public byte[] vectorValue(int targetOrd) { - return vectors.get(targetOrd); - } - - @Override - public RandomAccessVectorValues.Bytes copy() { - return this; - } - }; - } -} diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java index fc8ed3d004a1..a135df436991 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java @@ -18,6 +18,7 @@ package org.apache.lucene.util.hnsw; import java.io.IOException; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.util.Bits; /** A {@link RandomVectorScorer} for scoring random nodes in batches against an abstract query. */ @@ -57,14 +58,14 @@ default Bits getAcceptOrds(Bits acceptDocs) { /** Creates a default scorer for random access vectors. */ abstract class AbstractRandomVectorScorer implements RandomVectorScorer { - private final RandomAccessVectorValues values; + private final KnnVectorValues values; /** * Creates a new scorer for the given vector values. * * @param values the vector values */ - public AbstractRandomVectorScorer(RandomAccessVectorValues values) { + public AbstractRandomVectorScorer(KnnVectorValues values) { this.values = values; } diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/QuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/util/quantization/QuantizedByteVectorValues.java index a0fe957fecb4..b90ab8276dd1 100644 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/QuantizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/util/quantization/QuantizedByteVectorValues.java @@ -17,9 +17,10 @@ package org.apache.lucene.util.quantization; import java.io.IOException; +import org.apache.lucene.codecs.lucene95.HasIndexSlice; import org.apache.lucene.index.ByteVectorValues; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.IndexInput; /** * A version of {@link ByteVectorValues}, but additionally retrieving score correction offset for @@ -27,31 +28,31 @@ * * @lucene.experimental */ -public abstract class QuantizedByteVectorValues extends DocIdSetIterator { - public abstract float getScoreCorrectionConstant() throws IOException; +public abstract class QuantizedByteVectorValues extends ByteVectorValues implements HasIndexSlice { - public abstract byte[] vectorValue() throws IOException; + public ScalarQuantizer getScalarQuantizer() { + throw new UnsupportedOperationException(); + } - /** Return the dimension of the vectors */ - public abstract int dimension(); + public abstract float getScoreCorrectionConstant(int ord) throws IOException; /** - * Return the number of vectors for this field. + * Return a {@link VectorScorer} for the given query vector. * - * @return the number of vectors returned by this iterator + * @param query the query vector + * @return a {@link VectorScorer} instance or null */ - public abstract int size(); + public VectorScorer scorer(float[] query) throws IOException { + throw new UnsupportedOperationException(); + } @Override - public final long cost() { - return size(); + public QuantizedByteVectorValues copy() throws IOException { + return this; } - /** - * Return a {@link VectorScorer} for the given query vector. - * - * @param query the query vector - * @return a {@link VectorScorer} instance or null - */ - public abstract VectorScorer scorer(float[] query) throws IOException; + @Override + public IndexInput getSlice() { + return null; + } } diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java index ab8a911ddfae..3f7bcf6c5c45 100644 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java +++ b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java @@ -25,6 +25,7 @@ import java.util.Random; import java.util.stream.IntStream; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.HitQueue; import org.apache.lucene.search.ScoreDoc; @@ -269,11 +270,12 @@ static ScalarQuantizer fromVectors( if (totalVectorCount == 0) { return new ScalarQuantizer(0f, 0f, bits); } + KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator(); if (confidenceInterval == 1f) { float min = Float.POSITIVE_INFINITY; float max = Float.NEGATIVE_INFINITY; - while (floatVectorValues.nextDoc() != NO_MORE_DOCS) { - for (float v : floatVectorValues.vectorValue()) { + while (iterator.nextDoc() != NO_MORE_DOCS) { + for (float v : floatVectorValues.vectorValue(iterator.index())) { min = Math.min(min, v); max = Math.max(max, v); } @@ -289,8 +291,8 @@ static ScalarQuantizer fromVectors( if (totalVectorCount <= quantizationSampleSize) { int scratchSize = Math.min(SCRATCH_SIZE, totalVectorCount); int i = 0; - while (floatVectorValues.nextDoc() != NO_MORE_DOCS) { - float[] vectorValue = floatVectorValues.vectorValue(); + while (iterator.nextDoc() != NO_MORE_DOCS) { + float[] vectorValue = floatVectorValues.vectorValue(iterator.index()); System.arraycopy( vectorValue, 0, quantileGatheringScratch, i * vectorValue.length, vectorValue.length); i++; @@ -311,11 +313,11 @@ static ScalarQuantizer fromVectors( for (int i : vectorsToTake) { while (index <= i) { // We cannot use `advance(docId)` as MergedVectorValues does not support it - floatVectorValues.nextDoc(); + iterator.nextDoc(); index++; } - assert floatVectorValues.docID() != NO_MORE_DOCS; - float[] vectorValue = floatVectorValues.vectorValue(); + assert iterator.docID() != NO_MORE_DOCS; + float[] vectorValue = floatVectorValues.vectorValue(iterator.index()); System.arraycopy( vectorValue, 0, quantileGatheringScratch, idx * vectorValue.length, vectorValue.length); idx++; @@ -353,11 +355,16 @@ public static ScalarQuantizer fromVectorsAutoInterval( / (floatVectorValues.dimension() + 1), 1 - 1f / (floatVectorValues.dimension() + 1) }; + KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator(); if (totalVectorCount <= sampleSize) { int scratchSize = Math.min(SCRATCH_SIZE, totalVectorCount); int i = 0; - while (floatVectorValues.nextDoc() != NO_MORE_DOCS) { - gatherSample(floatVectorValues, quantileGatheringScratch, sampledDocs, i); + while (iterator.nextDoc() != NO_MORE_DOCS) { + gatherSample( + floatVectorValues.vectorValue(iterator.index()), + quantileGatheringScratch, + sampledDocs, + i); i++; if (i == scratchSize) { extractQuantiles(confidenceIntervals, quantileGatheringScratch, upperSum, lowerSum); @@ -374,11 +381,15 @@ public static ScalarQuantizer fromVectorsAutoInterval( for (int i : vectorsToTake) { while (index <= i) { // We cannot use `advance(docId)` as MergedVectorValues does not support it - floatVectorValues.nextDoc(); + iterator.nextDoc(); index++; } - assert floatVectorValues.docID() != NO_MORE_DOCS; - gatherSample(floatVectorValues, quantileGatheringScratch, sampledDocs, idx); + assert iterator.docID() != NO_MORE_DOCS; + gatherSample( + floatVectorValues.vectorValue(iterator.index()), + quantileGatheringScratch, + sampledDocs, + idx); idx++; if (idx == SCRATCH_SIZE) { extractQuantiles(confidenceIntervals, quantileGatheringScratch, upperSum, lowerSum); @@ -437,12 +448,7 @@ private static void extractQuantiles( } private static void gatherSample( - FloatVectorValues floatVectorValues, - float[] quantileGatheringScratch, - List sampledDocs, - int i) - throws IOException { - float[] vectorValue = floatVectorValues.vectorValue(); + float[] vectorValue, float[] quantileGatheringScratch, List sampledDocs, int i) { float[] copy = new float[vectorValue.length]; System.arraycopy(vectorValue, 0, copy, 0, vectorValue.length); sampledDocs.add(copy); diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorer.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorer.java index 0798885c9067..dae2cc3502cd 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorer.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorer.java @@ -19,11 +19,11 @@ import java.io.IOException; import java.lang.foreign.MemorySegment; import java.util.Optional; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.FilterIndexInput; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.MemorySegmentAccessInput; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; abstract sealed class Lucene99MemorySegmentByteVectorScorer @@ -39,10 +39,7 @@ abstract sealed class Lucene99MemorySegmentByteVectorScorer * returned. */ public static Optional create( - VectorSimilarityFunction type, - IndexInput input, - RandomAccessVectorValues values, - byte[] queryVector) { + VectorSimilarityFunction type, IndexInput input, KnnVectorValues values, byte[] queryVector) { input = FilterIndexInput.unwrapOnlyTest(input); if (!(input instanceof MemorySegmentAccessInput msInput)) { return Optional.empty(); @@ -58,7 +55,7 @@ public static Optional create( } Lucene99MemorySegmentByteVectorScorer( - MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] queryVector) { + MemorySegmentAccessInput input, KnnVectorValues values, byte[] queryVector) { super(values); this.input = input; this.vectorByteSize = values.getVectorByteLength(); @@ -92,7 +89,7 @@ final void checkOrdinal(int ord) { } static final class CosineScorer extends Lucene99MemorySegmentByteVectorScorer { - CosineScorer(MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) { + CosineScorer(MemorySegmentAccessInput input, KnnVectorValues values, byte[] query) { super(input, values, query); } @@ -105,8 +102,7 @@ public float score(int node) throws IOException { } static final class DotProductScorer extends Lucene99MemorySegmentByteVectorScorer { - DotProductScorer( - MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) { + DotProductScorer(MemorySegmentAccessInput input, KnnVectorValues values, byte[] query) { super(input, values, query); } @@ -120,7 +116,7 @@ public float score(int node) throws IOException { } static final class EuclideanScorer extends Lucene99MemorySegmentByteVectorScorer { - EuclideanScorer(MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) { + EuclideanScorer(MemorySegmentAccessInput input, KnnVectorValues values, byte[] query) { super(input, values, query); } @@ -133,8 +129,7 @@ public float score(int node) throws IOException { } static final class MaxInnerProductScorer extends Lucene99MemorySegmentByteVectorScorer { - MaxInnerProductScorer( - MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) { + MaxInnerProductScorer(MemorySegmentAccessInput input, KnnVectorValues values, byte[] query) { super(input, values, query); } diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorerSupplier.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorerSupplier.java index 90b3bfb014c3..9dd2b4620ace 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorerSupplier.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorerSupplier.java @@ -19,11 +19,11 @@ import java.io.IOException; import java.lang.foreign.MemorySegment; import java.util.Optional; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.FilterIndexInput; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.MemorySegmentAccessInput; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; @@ -33,7 +33,7 @@ public abstract sealed class Lucene99MemorySegmentByteVectorScorerSupplier final int vectorByteSize; final int maxOrd; final MemorySegmentAccessInput input; - final RandomAccessVectorValues values; // to support ordToDoc/getAcceptOrds + final KnnVectorValues values; // to support ordToDoc/getAcceptOrds byte[] scratch1, scratch2; /** @@ -41,7 +41,7 @@ public abstract sealed class Lucene99MemorySegmentByteVectorScorerSupplier * optional is returned. */ static Optional create( - VectorSimilarityFunction type, IndexInput input, RandomAccessVectorValues values) { + VectorSimilarityFunction type, IndexInput input, KnnVectorValues values) { input = FilterIndexInput.unwrapOnlyTest(input); if (!(input instanceof MemorySegmentAccessInput msInput)) { return Optional.empty(); @@ -56,7 +56,7 @@ static Optional create( } Lucene99MemorySegmentByteVectorScorerSupplier( - MemorySegmentAccessInput input, RandomAccessVectorValues values) { + MemorySegmentAccessInput input, KnnVectorValues values) { this.input = input; this.values = values; this.vectorByteSize = values.getVectorByteLength(); @@ -103,7 +103,7 @@ final MemorySegment getSecondSegment(int ord) throws IOException { static final class CosineSupplier extends Lucene99MemorySegmentByteVectorScorerSupplier { - CosineSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) { + CosineSupplier(MemorySegmentAccessInput input, KnnVectorValues values) { super(input, values); } @@ -128,7 +128,7 @@ public CosineSupplier copy() throws IOException { static final class DotProductSupplier extends Lucene99MemorySegmentByteVectorScorerSupplier { - DotProductSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) { + DotProductSupplier(MemorySegmentAccessInput input, KnnVectorValues values) { super(input, values); } @@ -155,7 +155,7 @@ public DotProductSupplier copy() throws IOException { static final class EuclideanSupplier extends Lucene99MemorySegmentByteVectorScorerSupplier { - EuclideanSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) { + EuclideanSupplier(MemorySegmentAccessInput input, KnnVectorValues values) { super(input, values); } @@ -181,7 +181,7 @@ public EuclideanSupplier copy() throws IOException { static final class MaxInnerProductSupplier extends Lucene99MemorySegmentByteVectorScorerSupplier { - MaxInnerProductSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) { + MaxInnerProductSupplier(MemorySegmentAccessInput input, KnnVectorValues values) { super(input, values); } diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFlatVectorsScorer.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFlatVectorsScorer.java index b085185fb113..63e79bccbdea 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFlatVectorsScorer.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFlatVectorsScorer.java @@ -19,11 +19,12 @@ import java.io.IOException; import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene95.HasIndexSlice; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; -import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues; public class Lucene99MemorySegmentFlatVectorsScorer implements FlatVectorsScorer { @@ -38,15 +39,15 @@ private Lucene99MemorySegmentFlatVectorsScorer(FlatVectorsScorer delegate) { @Override public RandomVectorScorerSupplier getRandomVectorScorerSupplier( - VectorSimilarityFunction similarityType, RandomAccessVectorValues vectorValues) - throws IOException { + VectorSimilarityFunction similarityType, KnnVectorValues vectorValues) throws IOException { // a quantized values here is a wrapping or delegation issue - assert !(vectorValues instanceof RandomAccessQuantizedByteVectorValues); + assert !(vectorValues instanceof QuantizedByteVectorValues); // currently only supports binary vectors - if (vectorValues instanceof RandomAccessVectorValues.Bytes && vectorValues.getSlice() != null) { + if (vectorValues instanceof HasIndexSlice byteVectorValues + && byteVectorValues.getSlice() != null) { var scorer = Lucene99MemorySegmentByteVectorScorerSupplier.create( - similarityType, vectorValues.getSlice(), vectorValues); + similarityType, byteVectorValues.getSlice(), vectorValues); if (scorer.isPresent()) { return scorer.get(); } @@ -56,9 +57,7 @@ public RandomVectorScorerSupplier getRandomVectorScorerSupplier( @Override public RandomVectorScorer getRandomVectorScorer( - VectorSimilarityFunction similarityType, - RandomAccessVectorValues vectorValues, - float[] target) + VectorSimilarityFunction similarityType, KnnVectorValues vectorValues, float[] target) throws IOException { // currently only supports binary vectors, so always delegate return delegate.getRandomVectorScorer(similarityType, vectorValues, target); @@ -66,17 +65,16 @@ public RandomVectorScorer getRandomVectorScorer( @Override public RandomVectorScorer getRandomVectorScorer( - VectorSimilarityFunction similarityType, - RandomAccessVectorValues vectorValues, - byte[] queryVector) + VectorSimilarityFunction similarityType, KnnVectorValues vectorValues, byte[] queryVector) throws IOException { checkDimensions(queryVector.length, vectorValues.dimension()); // a quantized values here is a wrapping or delegation issue - assert !(vectorValues instanceof RandomAccessQuantizedByteVectorValues); - if (vectorValues instanceof RandomAccessVectorValues.Bytes && vectorValues.getSlice() != null) { + assert !(vectorValues instanceof QuantizedByteVectorValues); + if (vectorValues instanceof HasIndexSlice byteVectorValues + && byteVectorValues.getSlice() != null) { var scorer = Lucene99MemorySegmentByteVectorScorer.create( - similarityType, vectorValues.getSlice(), vectorValues, queryVector); + similarityType, byteVectorValues.getSlice(), vectorValues, queryVector); if (scorer.isPresent()) { return scorer.get(); } diff --git a/lucene/core/src/test/org/apache/lucene/codecs/hnsw/TestFlatVectorScorer.java b/lucene/core/src/test/org/apache/lucene/codecs/hnsw/TestFlatVectorScorer.java index 9bce1f10a432..6fe9a685e1b4 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/hnsw/TestFlatVectorScorer.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/hnsw/TestFlatVectorScorer.java @@ -35,6 +35,8 @@ import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues; import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues; import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorScorer; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.Directory; import org.apache.lucene.store.IOContext; @@ -42,7 +44,6 @@ import org.apache.lucene.store.IndexOutput; import org.apache.lucene.store.MMapDirectory; import org.apache.lucene.tests.util.LuceneTestCase; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.hamcrest.Matcher; import org.hamcrest.MatcherAssert; @@ -174,13 +175,13 @@ public void testCheckFloatDimensions() throws IOException { } } - RandomAccessVectorValues byteVectorValues( - int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException { + ByteVectorValues byteVectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim) + throws IOException { return new OffHeapByteVectorValues.DenseOffHeapVectorValues( dims, size, in.slice("byteValues", 0, in.length()), dims, flatVectorsScorer, sim); } - RandomAccessVectorValues floatVectorValues( + FloatVectorValues floatVectorValues( int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException { return new OffHeapFloatVectorValues.DenseOffHeapVectorValues( dims, diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java index 825de3ab725a..c3225326f4c2 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java @@ -17,7 +17,6 @@ package org.apache.lucene.codecs.lucene99; import static java.lang.String.format; -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.oneOf; @@ -312,14 +311,13 @@ public void testQuantizedVectorsWriteAndRead() throws Exception { assertNotNull(hnswReader.getQuantizationState("f")); QuantizedByteVectorValues quantizedByteVectorValues = hnswReader.getQuantizedVectorValues("f"); - int docId = -1; - while ((docId = quantizedByteVectorValues.nextDoc()) != NO_MORE_DOCS) { - byte[] vector = quantizedByteVectorValues.vectorValue(); - float offset = quantizedByteVectorValues.getScoreCorrectionConstant(); + for (int ord = 0; ord < quantizedByteVectorValues.size(); ord++) { + byte[] vector = quantizedByteVectorValues.vectorValue(ord); + float offset = quantizedByteVectorValues.getScoreCorrectionConstant(ord); for (int i = 0; i < dim; i++) { - assertEquals(vector[i], expectedVectors[docId][i]); + assertEquals(vector[i], expectedVectors[ord][i]); } - assertEquals(offset, expectedCorrections[docId], 0.00001f); + assertEquals(offset, expectedCorrections[ord], 0.00001f); } } else { fail("reader is not Lucene99HnswVectorsReader"); diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java index a0f640fa650b..da1020dc36b1 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java @@ -46,7 +46,7 @@ import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.RandomVectorScorer; -import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import org.apache.lucene.util.quantization.ScalarQuantizer; public class TestLucene99ScalarQuantizedVectorScorer extends LuceneTestCase { @@ -100,8 +100,8 @@ private void vectorNonZeroScoringTest(int bits, boolean compress) throws IOExcep try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { Lucene99ScalarQuantizedVectorScorer scorer = new Lucene99ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer()); - RandomAccessQuantizedByteVectorValues values = - new RandomAccessQuantizedByteVectorValues() { + QuantizedByteVectorValues values = + new QuantizedByteVectorValues() { @Override public int dimension() { return 32; @@ -128,7 +128,7 @@ public float getScoreCorrectionConstant(int ord) { } @Override - public RandomAccessQuantizedByteVectorValues copy() throws IOException { + public QuantizedByteVectorValues copy() throws IOException { return this; } diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java index 64df927c7650..d6b42c697085 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java @@ -37,6 +37,7 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.NoMergePolicy; import org.apache.lucene.index.VectorSimilarityFunction; @@ -173,9 +174,10 @@ public void testQuantizedVectorsWriteAndRead() throws Exception { QuantizedByteVectorValues quantizedByteVectorValues = quantizedReader.getQuantizedVectorValues("f"); int docId = -1; - while ((docId = quantizedByteVectorValues.nextDoc()) != NO_MORE_DOCS) { - byte[] vector = quantizedByteVectorValues.vectorValue(); - float offset = quantizedByteVectorValues.getScoreCorrectionConstant(); + KnnVectorValues.DocIndexIterator iter = quantizedByteVectorValues.iterator(); + for (docId = iter.nextDoc(); docId != NO_MORE_DOCS; docId = iter.nextDoc()) { + byte[] vector = quantizedByteVectorValues.vectorValue(iter.index()); + float offset = quantizedByteVectorValues.getScoreCorrectionConstant(iter.index()); for (int i = 0; i < dim; i++) { assertEquals(vector[i], expectedVectors[docId][i]); } diff --git a/lucene/core/src/test/org/apache/lucene/document/TestField.java b/lucene/core/src/test/org/apache/lucene/document/TestField.java index 6e3a855a0df4..5c1b8f17294f 100644 --- a/lucene/core/src/test/org/apache/lucene/document/TestField.java +++ b/lucene/core/src/test/org/apache/lucene/document/TestField.java @@ -18,6 +18,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import java.io.IOException; import java.io.StringReader; import java.nio.charset.StandardCharsets; import org.apache.lucene.codecs.Codec; @@ -27,6 +28,7 @@ import org.apache.lucene.index.IndexOptions; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.Term; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.IndexSearcher; @@ -713,17 +715,21 @@ public void testKnnVectorField() throws Exception { try (IndexReader r = DirectoryReader.open(w)) { ByteVectorValues binary = r.leaves().get(0).reader().getByteVectorValues("binary"); assertEquals(1, binary.size()); - assertNotEquals(NO_MORE_DOCS, binary.nextDoc()); - assertNotNull(binary.vectorValue()); - assertArrayEquals(b, binary.vectorValue()); - assertEquals(NO_MORE_DOCS, binary.nextDoc()); + KnnVectorValues.DocIndexIterator iterator = binary.iterator(); + assertNotEquals(NO_MORE_DOCS, iterator.nextDoc()); + assertNotNull(binary.vectorValue(0)); + assertArrayEquals(b, binary.vectorValue(0)); + assertEquals(NO_MORE_DOCS, iterator.nextDoc()); + expectThrows(IOException.class, () -> binary.vectorValue(1)); FloatVectorValues floatValues = r.leaves().get(0).reader().getFloatVectorValues("float"); assertEquals(1, floatValues.size()); - assertNotEquals(NO_MORE_DOCS, floatValues.nextDoc()); - assertEquals(vector.length, floatValues.vectorValue().length); - assertEquals(vector[0], floatValues.vectorValue()[0], 0); - assertEquals(NO_MORE_DOCS, floatValues.nextDoc()); + KnnVectorValues.DocIndexIterator iterator1 = floatValues.iterator(); + assertNotEquals(NO_MORE_DOCS, iterator1.nextDoc()); + assertEquals(vector.length, floatValues.vectorValue(0).length); + assertEquals(vector[0], floatValues.vectorValue(0)[0], 0); + assertEquals(NO_MORE_DOCS, iterator1.nextDoc()); + expectThrows(IOException.class, () -> floatValues.vectorValue(1)); } } } diff --git a/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java b/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java index 3c82cd6b33e4..d03c8cf42b59 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java @@ -459,8 +459,8 @@ public void testFloatVectorValues() throws IOException { expectThrows( ExitingReaderException.class, () -> { - DocIdSetIterator iter = leaf.getFloatVectorValues("vector"); - scanAndRetrieve(leaf, iter); + KnnVectorValues values = leaf.getFloatVectorValues("vector"); + scanAndRetrieve(leaf, values); }); expectThrows( @@ -473,8 +473,8 @@ public void testFloatVectorValues() throws IOException { leaf.getLiveDocs(), Integer.MAX_VALUE)); } else { - DocIdSetIterator iter = leaf.getFloatVectorValues("vector"); - scanAndRetrieve(leaf, iter); + KnnVectorValues values = leaf.getFloatVectorValues("vector"); + scanAndRetrieve(leaf, values); leaf.searchNearestVectors( "vector", @@ -534,8 +534,8 @@ public void testByteVectorValues() throws IOException { expectThrows( ExitingReaderException.class, () -> { - DocIdSetIterator iter = leaf.getByteVectorValues("vector"); - scanAndRetrieve(leaf, iter); + KnnVectorValues values = leaf.getByteVectorValues("vector"); + scanAndRetrieve(leaf, values); }); expectThrows( @@ -549,8 +549,8 @@ public void testByteVectorValues() throws IOException { Integer.MAX_VALUE)); } else { - DocIdSetIterator iter = leaf.getByteVectorValues("vector"); - scanAndRetrieve(leaf, iter); + KnnVectorValues values = leaf.getByteVectorValues("vector"); + scanAndRetrieve(leaf, values); leaf.searchNearestVectors( "vector", @@ -564,20 +564,24 @@ public void testByteVectorValues() throws IOException { directory.close(); } - private static void scanAndRetrieve(LeafReader leaf, DocIdSetIterator iter) throws IOException { + private static void scanAndRetrieve(LeafReader leaf, KnnVectorValues values) throws IOException { + KnnVectorValues.DocIndexIterator iter = values.iterator(); for (iter.nextDoc(); iter.docID() != DocIdSetIterator.NO_MORE_DOCS && iter.docID() < leaf.maxDoc(); ) { - final int nextDocId = iter.docID() + 1; + int docId = iter.docID(); + if (docId >= leaf.maxDoc()) { + break; + } + final int nextDocId = docId + 1; if (random().nextBoolean() && nextDocId < leaf.maxDoc()) { iter.advance(nextDocId); } else { iter.nextDoc(); } - if (random().nextBoolean() && iter.docID() != DocIdSetIterator.NO_MORE_DOCS - && iter instanceof FloatVectorValues) { - ((FloatVectorValues) iter).vectorValue(); + && values instanceof FloatVectorValues) { + ((FloatVectorValues) values).vectorValue(iter.index()); } } } diff --git a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java index 72be0bd929fa..41410ad4e39d 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java @@ -413,11 +413,13 @@ private void assertConsistentGraph(IndexWriter iw, float[][] values, String vect // stored vector values are the same as original int nextDocWithVectors = 0; StoredFields storedFields = reader.storedFields(); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); for (int i = 0; i < reader.maxDoc(); i++) { - nextDocWithVectors = vectorValues.advance(i); + nextDocWithVectors = iterator.advance(i); while (i < nextDocWithVectors && i < reader.maxDoc()) { int id = Integer.parseInt(storedFields.document(i).get("id")); - assertNull("document " + id + " has no vector, but was expected to", values[id]); + assertNull( + "document " + id + ", expected to have no vector, does have one", values[id]); ++i; } if (nextDocWithVectors == NO_MORE_DOCS) { @@ -425,7 +427,7 @@ private void assertConsistentGraph(IndexWriter iw, float[][] values, String vect } int id = Integer.parseInt(storedFields.document(i).get("id")); // documents with KnnGraphValues have the expected vectors - float[] scratch = vectorValues.vectorValue(); + float[] scratch = vectorValues.vectorValue(iterator.index()); assertArrayEquals( "vector did not match for doc " + i + ", id=" + id + ": " + Arrays.toString(scratch), values[id], @@ -435,9 +437,9 @@ private void assertConsistentGraph(IndexWriter iw, float[][] values, String vect } // if IndexDisi.doc == NO_MORE_DOCS, we should not call IndexDisi.nextDoc() if (nextDocWithVectors != NO_MORE_DOCS) { - assertEquals(NO_MORE_DOCS, vectorValues.nextDoc()); + assertEquals(NO_MORE_DOCS, iterator.nextDoc()); } else { - assertEquals(NO_MORE_DOCS, vectorValues.docID()); + assertEquals(NO_MORE_DOCS, iterator.docID()); } // assert graph values: diff --git a/lucene/core/src/test/org/apache/lucene/index/TestSortingCodecReader.java b/lucene/core/src/test/org/apache/lucene/index/TestSortingCodecReader.java index 241fc0a5fe5f..fcbd0cdea21f 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestSortingCodecReader.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestSortingCodecReader.java @@ -242,6 +242,7 @@ public void testSortOnAddIndicesRandom() throws IOException { NumericDocValues ids = leaf.getNumericDocValues("id"); long prevValue = -1; boolean usingAltIds = false; + KnnVectorValues.DocIndexIterator valuesIterator = vectorValues.iterator(); for (int i = 0; i < actualNumDocs; i++) { int idNext = ids.nextDoc(); if (idNext == DocIdSetIterator.NO_MORE_DOCS) { @@ -262,7 +263,7 @@ public void testSortOnAddIndicesRandom() throws IOException { assertTrue(sorted_numeric_dv.advanceExact(idNext)); assertTrue(sorted_set_dv.advanceExact(idNext)); assertTrue(binary_sorted_dv.advanceExact(idNext)); - assertEquals(idNext, vectorValues.advance(idNext)); + assertEquals(idNext, valuesIterator.advance(idNext)); assertEquals(new BytesRef(ids.longValue() + ""), binary_dv.binaryValue()); assertEquals( new BytesRef(ids.longValue() + ""), @@ -274,7 +275,7 @@ public void testSortOnAddIndicesRandom() throws IOException { assertEquals(1, sorted_numeric_dv.docValueCount()); assertEquals(ids.longValue(), sorted_numeric_dv.nextValue()); - float[] vectorValue = vectorValues.vectorValue(); + float[] vectorValue = vectorValues.vectorValue(valuesIterator.index()); assertEquals(1, vectorValue.length); assertEquals((float) ids.longValue(), vectorValue[0], 0.001f); diff --git a/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorScorer.java b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorScorer.java index da9c312ef96d..b935e83331bf 100644 --- a/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorScorer.java +++ b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorScorer.java @@ -39,6 +39,7 @@ import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.Directory; import org.apache.lucene.store.IOContext; @@ -47,7 +48,6 @@ import org.apache.lucene.store.MMapDirectory; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.util.NamedThreadFactory; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.junit.BeforeClass; @@ -329,8 +329,8 @@ public void testLarge() throws IOException { } } - RandomAccessVectorValues vectorValues( - int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException { + KnnVectorValues vectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim) + throws IOException { return new OffHeapByteVectorValues.DenseOffHeapVectorValues( dims, size, in.slice("byteValues", 0, in.length()), dims, MEMSEG_SCORER, sim); } diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java index 21a33f9ca3e7..afa150e387fe 100644 --- a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java @@ -38,6 +38,7 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.QueryTimeout; @@ -740,7 +741,7 @@ public void testMergeAwayAllValues() throws IOException { LeafReader leafReader = getOnlyLeafReader(reader); FieldInfo fi = leafReader.getFieldInfos().fieldInfo("field"); assertNotNull(fi); - DocIdSetIterator vectorValues; + KnnVectorValues vectorValues; switch (fi.getVectorEncoding()) { case BYTE: vectorValues = leafReader.getByteVectorValues("field"); @@ -752,7 +753,7 @@ public void testMergeAwayAllValues() throws IOException { throw new AssertionError(); } assertNotNull(vectorValues); - assertEquals(NO_MORE_DOCS, vectorValues.nextDoc()); + assertEquals(NO_MORE_DOCS, vectorValues.iterator().nextDoc()); } } } diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/AbstractMockVectorValues.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/AbstractMockVectorValues.java deleted file mode 100644 index 54de3919b516..000000000000 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/AbstractMockVectorValues.java +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.lucene.util.hnsw; - -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; - -import java.io.IOException; -import org.apache.lucene.util.BytesRef; - -abstract class AbstractMockVectorValues implements RandomAccessVectorValues { - - protected final int dimension; - protected final T[] denseValues; - protected final T[] values; - protected final int numVectors; - protected final BytesRef binaryValue; - - protected int pos = -1; - - AbstractMockVectorValues(T[] values, int dimension, T[] denseValues, int numVectors) { - this.dimension = dimension; - this.values = values; - this.denseValues = denseValues; - // used by tests that build a graph from bytes rather than floats - binaryValue = new BytesRef(dimension); - binaryValue.length = dimension; - this.numVectors = numVectors; - } - - @Override - public int size() { - return numVectors; - } - - @Override - public int dimension() { - return dimension; - } - - public T vectorValue(int targetOrd) { - return denseValues[targetOrd]; - } - - @Override - public abstract AbstractMockVectorValues copy(); - - public abstract T vectorValue() throws IOException; - - private boolean seek(int target) { - if (target >= 0 && target < values.length && values[target] != null) { - pos = target; - return true; - } else { - return false; - } - } - - public int docID() { - return pos; - } - - public int nextDoc() { - return advance(pos + 1); - } - - public int advance(int target) { - while (++pos < values.length) { - if (seek(pos)) { - return pos; - } - } - return NO_MORE_DOCS; - } -} diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java index 4a6794b4994e..41aeef2e5c8d 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java @@ -56,6 +56,7 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.StoredFields; @@ -97,33 +98,28 @@ abstract class HnswGraphTestCase extends LuceneTestCase { abstract T randomVector(int dim); - abstract AbstractMockVectorValues vectorValues(int size, int dimension); + abstract KnnVectorValues vectorValues(int size, int dimension); - abstract AbstractMockVectorValues vectorValues(float[][] values); + abstract KnnVectorValues vectorValues(float[][] values); - abstract AbstractMockVectorValues vectorValues(LeafReader reader, String fieldName) - throws IOException; + abstract KnnVectorValues vectorValues(LeafReader reader, String fieldName) throws IOException; - abstract AbstractMockVectorValues vectorValues( - int size, - int dimension, - AbstractMockVectorValues pregeneratedVectorValues, - int pregeneratedOffset); + abstract KnnVectorValues vectorValues( + int size, int dimension, KnnVectorValues pregeneratedVectorValues, int pregeneratedOffset); abstract Field knnVectorField(String name, T vector, VectorSimilarityFunction similarityFunction); - abstract RandomAccessVectorValues circularVectorValues(int nDoc); + abstract KnnVectorValues circularVectorValues(int nDoc); abstract T getTargetVector(); - protected RandomVectorScorerSupplier buildScorerSupplier(RandomAccessVectorValues vectors) + protected RandomVectorScorerSupplier buildScorerSupplier(KnnVectorValues vectors) throws IOException { return flatVectorScorer.getRandomVectorScorerSupplier(similarityFunction, vectors); } - protected RandomVectorScorer buildScorer(RandomAccessVectorValues vectors, T query) - throws IOException { - RandomAccessVectorValues vectorsCopy = vectors.copy(); + protected RandomVectorScorer buildScorer(KnnVectorValues vectors, T query) throws IOException { + KnnVectorValues vectorsCopy = vectors.copy(); return switch (getVectorEncoding()) { case BYTE -> flatVectorScorer.getRandomVectorScorer(similarityFunction, vectorsCopy, (byte[]) query); @@ -134,6 +130,7 @@ protected RandomVectorScorer buildScorer(RandomAccessVectorValues vectors, T que // Tests writing segments of various sizes and merging to ensure there are no errors // in the HNSW graph merging logic. + @SuppressWarnings("unchecked") public void testRandomReadWriteAndMerge() throws IOException { int dim = random().nextInt(100) + 1; int[] segmentSizes = @@ -148,7 +145,7 @@ public void testRandomReadWriteAndMerge() throws IOException { int M = random().nextInt(4) + 2; int beamWidth = random().nextInt(10) + 5; long seed = random().nextLong(); - AbstractMockVectorValues vectors = vectorValues(numVectors, dim); + KnnVectorValues vectors = vectorValues(numVectors, dim); HnswGraphBuilder.randSeed = seed; try (Directory dir = newDirectory()) { @@ -173,7 +170,7 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { try (IndexWriter iw = new IndexWriter(dir, iwc)) { for (int i = 0; i < segmentSizes.length; i++) { int size = segmentSizes[i]; - while (vectors.nextDoc() < size) { + for (int ord = 0; ord < size; ord++) { if (isSparse[i] && random().nextBoolean()) { int d = random().nextInt(10) + 1; for (int j = 0; j < d; j++) { @@ -182,8 +179,24 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { } } Document doc = new Document(); - doc.add(knnVectorField("field", vectors.vectorValue(), similarityFunction)); - doc.add(new StringField("id", Integer.toString(vectors.docID()), Field.Store.NO)); + switch (vectors.getEncoding()) { + case BYTE -> { + doc.add( + knnVectorField( + "field", + (T) ((ByteVectorValues) vectors).vectorValue(ord), + similarityFunction)); + } + case FLOAT32 -> { + doc.add( + knnVectorField( + "field", + (T) ((FloatVectorValues) vectors).vectorValue(ord), + similarityFunction)); + } + } + ; + doc.add(new StringField("id", Integer.toString(vectors.ordToDoc(ord)), Field.Store.NO)); iw.addDocument(doc); } iw.commit(); @@ -199,13 +212,26 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { } try (IndexReader reader = DirectoryReader.open(dir)) { for (LeafReaderContext ctx : reader.leaves()) { - AbstractMockVectorValues values = vectorValues(ctx.reader(), "field"); + KnnVectorValues values = vectorValues(ctx.reader(), "field"); assertEquals(dim, values.dimension()); } } } } + @SuppressWarnings("unchecked") + private T vectorValue(KnnVectorValues vectors, int ord) throws IOException { + switch (vectors.getEncoding()) { + case BYTE -> { + return (T) ((ByteVectorValues) vectors).vectorValue(ord); + } + case FLOAT32 -> { + return (T) ((FloatVectorValues) vectors).vectorValue(ord); + } + } + throw new AssertionError("unknown encoding " + vectors.getEncoding()); + } + // test writing out and reading in a graph gives the expected graph public void testReadWrite() throws IOException { int dim = random().nextInt(100) + 1; @@ -213,8 +239,8 @@ public void testReadWrite() throws IOException { int M = random().nextInt(4) + 2; int beamWidth = random().nextInt(10) + 5; long seed = random().nextLong(); - AbstractMockVectorValues vectors = vectorValues(nDoc, dim); - AbstractMockVectorValues v2 = vectors.copy(), v3 = vectors.copy(); + KnnVectorValues vectors = vectorValues(nDoc, dim); + KnnVectorValues v2 = vectors.copy(), v3 = vectors.copy(); RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, M, beamWidth, seed); HnswGraph hnsw = builder.build(vectors.size()); @@ -242,15 +268,16 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { } }); try (IndexWriter iw = new IndexWriter(dir, iwc)) { - while (v2.nextDoc() != NO_MORE_DOCS) { - while (indexedDoc < v2.docID()) { + KnnVectorValues.DocIndexIterator it2 = v2.iterator(); + while (it2.nextDoc() != NO_MORE_DOCS) { + while (indexedDoc < it2.docID()) { // increment docId in the index by adding empty documents iw.addDocument(new Document()); indexedDoc++; } Document doc = new Document(); - doc.add(knnVectorField("field", v2.vectorValue(), similarityFunction)); - doc.add(new StoredField("id", v2.docID())); + doc.add(knnVectorField("field", vectorValue(v2, it2.index()), similarityFunction)); + doc.add(new StoredField("id", it2.docID())); iw.addDocument(doc); nVec++; indexedDoc++; @@ -258,7 +285,7 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { } try (IndexReader reader = DirectoryReader.open(dir)) { for (LeafReaderContext ctx : reader.leaves()) { - AbstractMockVectorValues values = vectorValues(ctx.reader(), "field"); + KnnVectorValues values = vectorValues(ctx.reader(), "field"); assertEquals(dim, values.dimension()); assertEquals(nVec, values.size()); assertEquals(indexedDoc, ctx.reader().maxDoc()); @@ -280,7 +307,7 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { public void testSortedAndUnsortedIndicesReturnSameResults() throws IOException { int dim = random().nextInt(10) + 3; int nDoc = random().nextInt(200) + 100; - AbstractMockVectorValues vectors = vectorValues(nDoc, dim); + KnnVectorValues vectors = vectorValues(nDoc, dim); int M = random().nextInt(10) + 5; int beamWidth = random().nextInt(10) + 10; @@ -323,15 +350,15 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { int indexedDoc = 0; try (IndexWriter iw = new IndexWriter(dir, iwc); IndexWriter iw2 = new IndexWriter(dir2, iwc2)) { - while (vectors.nextDoc() != NO_MORE_DOCS) { - while (indexedDoc < vectors.docID()) { + for (int ord = 0; ord < vectors.size(); ord++) { + while (indexedDoc < vectors.ordToDoc(ord)) { // increment docId in the index by adding empty documents iw.addDocument(new Document()); indexedDoc++; } Document doc = new Document(); - doc.add(knnVectorField("vector", vectors.vectorValue(), similarityFunction)); - doc.add(new StoredField("id", vectors.docID())); + doc.add(knnVectorField("vector", vectorValue(vectors, ord), similarityFunction)); + doc.add(new StoredField("id", vectors.ordToDoc(ord))); doc.add(new NumericDocValuesField("sortkey", random().nextLong())); iw.addDocument(doc); iw2.addDocument(doc); @@ -461,7 +488,7 @@ void assertGraphEqual(HnswGraph g, HnswGraph h) throws IOException { public void testAknnDiverse() throws IOException { int nDoc = 100; similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; - RandomAccessVectorValues vectors = circularVectorValues(nDoc); + KnnVectorValues vectors = circularVectorValues(nDoc); RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 10, 100, random().nextInt()); OnHeapHnswGraph hnsw = builder.build(vectors.size()); @@ -493,7 +520,7 @@ public void testAknnDiverse() throws IOException { @SuppressWarnings("unchecked") public void testSearchWithAcceptOrds() throws IOException { int nDoc = 100; - RandomAccessVectorValues vectors = circularVectorValues(nDoc); + KnnVectorValues vectors = circularVectorValues(nDoc); similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 16, 100, random().nextInt()); @@ -518,7 +545,7 @@ public void testSearchWithAcceptOrds() throws IOException { @SuppressWarnings("unchecked") public void testSearchWithSelectiveAcceptOrds() throws IOException { int nDoc = 100; - RandomAccessVectorValues vectors = circularVectorValues(nDoc); + KnnVectorValues vectors = circularVectorValues(nDoc); similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 16, 100, random().nextInt()); @@ -552,13 +579,13 @@ public void testHnswGraphBuilderInitializationFromGraph_withOffsetZero() throws int dim = atLeast(10); long seed = random().nextLong(); - AbstractMockVectorValues initializerVectors = vectorValues(initializerSize, dim); + KnnVectorValues initializerVectors = vectorValues(initializerSize, dim); RandomVectorScorerSupplier initialscorerSupplier = buildScorerSupplier(initializerVectors); HnswGraphBuilder initializerBuilder = HnswGraphBuilder.create(initialscorerSupplier, 10, 30, seed); OnHeapHnswGraph initializerGraph = initializerBuilder.build(initializerVectors.size()); - AbstractMockVectorValues finalVectorValues = + KnnVectorValues finalVectorValues = vectorValues(totalSize, dim, initializerVectors, docIdOffset); int[] initializerOrdMap = createOffsetOrdinalMap(initializerSize, finalVectorValues, docIdOffset); @@ -598,13 +625,13 @@ public void testHnswGraphBuilderInitializationFromGraph_withNonZeroOffset() thro int dim = atLeast(10); long seed = random().nextLong(); - AbstractMockVectorValues initializerVectors = vectorValues(initializerSize, dim); + KnnVectorValues initializerVectors = vectorValues(initializerSize, dim); RandomVectorScorerSupplier initialscorerSupplier = buildScorerSupplier(initializerVectors); HnswGraphBuilder initializerBuilder = HnswGraphBuilder.create(initialscorerSupplier, 10, 30, seed); OnHeapHnswGraph initializerGraph = initializerBuilder.build(initializerVectors.size()); - AbstractMockVectorValues finalVectorValues = + KnnVectorValues finalVectorValues = vectorValues(totalSize, dim, initializerVectors.copy(), docIdOffset); int[] initializerOrdMap = createOffsetOrdinalMap(initializerSize, finalVectorValues, docIdOffset); @@ -688,19 +715,17 @@ private int[] mapArrayAndSort(int[] arr, int[] offset) { } private int[] createOffsetOrdinalMap( - int docIdSize, AbstractMockVectorValues totalVectorValues, int docIdOffset) { + int docIdSize, KnnVectorValues totalVectorValues, int docIdOffset) throws IOException { // Compute the offset for the ordinal map to be the number of non-null vectors in the total - // vector values - // before the docIdOffset + // vector values before the docIdOffset int ordinalOffset = 0; - while (totalVectorValues.nextDoc() < docIdOffset) { + KnnVectorValues.DocIndexIterator it = totalVectorValues.iterator(); + while (it.nextDoc() < docIdOffset) { ordinalOffset++; } int[] offsetOrdinalMap = new int[docIdSize]; - for (int curr = 0; - totalVectorValues.docID() < docIdOffset + docIdSize; - totalVectorValues.nextDoc()) { + for (int curr = 0; it.docID() < docIdOffset + docIdSize; it.nextDoc()) { offsetOrdinalMap[curr] = ordinalOffset + curr++; } @@ -711,7 +736,7 @@ private int[] createOffsetOrdinalMap( public void testVisitedLimit() throws IOException { int nDoc = 500; similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; - RandomAccessVectorValues vectors = circularVectorValues(nDoc); + KnnVectorValues vectors = circularVectorValues(nDoc); RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 16, 100, random().nextInt()); OnHeapHnswGraph hnsw = builder.build(vectors.size()); @@ -746,7 +771,7 @@ public void testRamUsageEstimate() throws IOException { int M = randomIntBetween(4, 96); similarityFunction = RandomizedTest.randomFrom(VectorSimilarityFunction.values()); - RandomAccessVectorValues vectors = vectorValues(size, dim); + KnnVectorValues vectors = vectorValues(size, dim); RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); HnswGraphBuilder builder = @@ -771,7 +796,7 @@ public void testDiversity() throws IOException { unitVector2d(0.77), unitVector2d(0.6) }; - AbstractMockVectorValues vectors = vectorValues(values); + KnnVectorValues vectors = vectorValues(values); // First add nodes until everybody gets a full neighbor list RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 2, 10, random().nextInt()); @@ -825,7 +850,7 @@ public void testDiversityFallback() throws IOException { {10, 0, 0}, {0, 4, 0} }; - AbstractMockVectorValues vectors = vectorValues(values); + KnnVectorValues vectors = vectorValues(values); // First add nodes until everybody gets a full neighbor list RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 1, 10, random().nextInt()); @@ -855,7 +880,7 @@ public void testDiversity3d() throws IOException { {0, 0, 20}, {0, 9, 0} }; - AbstractMockVectorValues vectors = vectorValues(values); + KnnVectorValues vectors = vectorValues(values); // First add nodes until everybody gets a full neighbor list RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 1, 10, random().nextInt()); @@ -891,7 +916,7 @@ private void assertLevel0Neighbors(OnHeapHnswGraph graph, int node, int... expec public void testRandom() throws IOException { int size = atLeast(100); int dim = atLeast(10); - AbstractMockVectorValues vectors = vectorValues(size, dim); + KnnVectorValues vectors = vectorValues(size, dim); int topK = 5; RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 10, 30, random().nextLong()); @@ -908,15 +933,13 @@ public void testRandom() throws IOException { TopDocs topDocs = actual.topDocs(); NeighborQueue expected = new NeighborQueue(topK, false); for (int j = 0; j < size; j++) { - if (vectors.vectorValue(j) != null && (acceptOrds == null || acceptOrds.get(j))) { + if (vectorValue(vectors, j) != null && (acceptOrds == null || acceptOrds.get(j))) { if (getVectorEncoding() == VectorEncoding.BYTE) { - assert query instanceof byte[]; expected.add( - j, similarityFunction.compare((byte[]) query, (byte[]) vectors.vectorValue(j))); + j, similarityFunction.compare((byte[]) query, (byte[]) vectorValue(vectors, j))); } else { - assert query instanceof float[]; expected.add( - j, similarityFunction.compare((float[]) query, (float[]) vectors.vectorValue(j))); + j, similarityFunction.compare((float[]) query, (float[]) vectorValue(vectors, j))); } if (expected.size() > topK) { expected.pop(); @@ -940,7 +963,7 @@ public void testOnHeapHnswGraphSearch() throws IOException, ExecutionException, InterruptedException, TimeoutException { int size = atLeast(100); int dim = atLeast(10); - AbstractMockVectorValues vectors = vectorValues(size, dim); + KnnVectorValues vectors = vectorValues(size, dim); RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 10, 30, random().nextLong()); OnHeapHnswGraph hnsw = builder.build(vectors.size()); @@ -1004,7 +1027,7 @@ public void testOnHeapHnswGraphSearch() public void testConcurrentMergeBuilder() throws IOException { int size = atLeast(1000); int dim = atLeast(10); - AbstractMockVectorValues vectors = vectorValues(size, dim); + KnnVectorValues vectors = vectorValues(size, dim); RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); ExecutorService exec = Executors.newFixedThreadPool(4, new NamedThreadFactory("hnswMerge")); TaskExecutor taskExecutor = new TaskExecutor(exec); @@ -1033,7 +1056,7 @@ public void testAllNodesVisitedInSingleLevel() throws IOException { // Search for a large number of results int topK = size - 1; - AbstractMockVectorValues docVectors = vectorValues(size, dim); + KnnVectorValues docVectors = vectorValues(size, dim); HnswGraph graph = HnswGraphBuilder.create(buildScorerSupplier(docVectors), 10, 30, random().nextLong()) .build(size); @@ -1047,8 +1070,8 @@ public int numLevels() { } }; - AbstractMockVectorValues queryVectors = vectorValues(1, dim); - RandomVectorScorer queryScorer = buildScorer(docVectors, queryVectors.vectorValue(0)); + KnnVectorValues queryVectors = vectorValues(1, dim); + RandomVectorScorer queryScorer = buildScorer(docVectors, vectorValue(queryVectors, 0)); KnnCollector collector = new TopKnnCollector(topK, Integer.MAX_VALUE); HnswGraphSearcher.search(queryScorer, collector, singleLevelGraph, null); @@ -1076,8 +1099,7 @@ private int computeOverlap(int[] a, int[] b) { } /** Returns vectors evenly distributed around the upper unit semicircle. */ - static class CircularFloatVectorValues extends FloatVectorValues - implements RandomAccessVectorValues.Floats { + static class CircularFloatVectorValues extends FloatVectorValues { private final int size; private final float[] value; @@ -1103,22 +1125,18 @@ public int size() { return size; } - @Override public float[] vectorValue() { return vectorValue(doc); } - @Override public int docID() { return doc; } - @Override public int nextDoc() { return advance(doc + 1); } - @Override public int advance(int target) { if (target >= 0 && target < size) { doc = target; @@ -1140,8 +1158,7 @@ public VectorScorer scorer(float[] target) { } /** Returns vectors evenly distributed around the upper unit semicircle. */ - static class CircularByteVectorValues extends ByteVectorValues - implements RandomAccessVectorValues.Bytes { + static class CircularByteVectorValues extends ByteVectorValues { private final int size; private final float[] value; private final byte[] bValue; @@ -1169,22 +1186,18 @@ public int size() { return size; } - @Override public byte[] vectorValue() { return vectorValue(doc); } - @Override public int docID() { return doc; } - @Override public int nextDoc() { return advance(doc + 1); } - @Override public int advance(int target) { if (target >= 0 && target < size) { doc = target; @@ -1227,27 +1240,25 @@ private Set getNeighborNodes(HnswGraph g) throws IOException { return neighbors; } - void assertVectorsEqual(AbstractMockVectorValues u, AbstractMockVectorValues v) - throws IOException { + void assertVectorsEqual(KnnVectorValues u, KnnVectorValues v) throws IOException { int uDoc, vDoc; - while (true) { - uDoc = u.nextDoc(); - vDoc = v.nextDoc(); + assertEquals(u.size(), v.size()); + for (int ord = 0; ord < u.size(); ord++) { + uDoc = u.ordToDoc(ord); + vDoc = v.ordToDoc(ord); assertEquals(uDoc, vDoc); - if (uDoc == NO_MORE_DOCS) { - break; - } + assertNotEquals(NO_MORE_DOCS, uDoc); switch (getVectorEncoding()) { case BYTE -> assertArrayEquals( "vectors do not match for doc=" + uDoc, - (byte[]) u.vectorValue(), - (byte[]) v.vectorValue()); + (byte[]) vectorValue(u, ord), + (byte[]) vectorValue(v, ord)); case FLOAT32 -> assertArrayEquals( "vectors do not match for doc=" + uDoc, - (float[]) u.vectorValue(), - (float[]) v.vectorValue(), + (float[]) vectorValue(u, ord), + (float[]) vectorValue(v, ord), 1e-4f); default -> throw new IllegalArgumentException("unknown vector encoding: " + getVectorEncoding()); diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java index a3b17b9a621e..4ab86c707816 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java @@ -17,11 +17,17 @@ package org.apache.lucene.util.hnsw; +import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.BytesRef; -class MockByteVectorValues extends AbstractMockVectorValues - implements RandomAccessVectorValues.Bytes { +class MockByteVectorValues extends ByteVectorValues { + private final int dimension; + private final byte[][] denseValues; + protected final byte[][] values; + private final int numVectors; + private final BytesRef binaryValue; private final byte[] scratch; static MockByteVectorValues fromValues(byte[][] values) { @@ -43,10 +49,26 @@ static MockByteVectorValues fromValues(byte[][] values) { } MockByteVectorValues(byte[][] values, int dimension, byte[][] denseValues, int numVectors) { - super(values, dimension, denseValues, numVectors); + this.dimension = dimension; + this.values = values; + this.denseValues = denseValues; + this.numVectors = numVectors; + // used by tests that build a graph from bytes rather than floats + binaryValue = new BytesRef(dimension); + binaryValue.length = dimension; scratch = new byte[dimension]; } + @Override + public int size() { + return values.length; + } + + @Override + public int dimension() { + return dimension; + } + @Override public MockByteVectorValues copy() { return new MockByteVectorValues( @@ -55,20 +77,20 @@ public MockByteVectorValues copy() { @Override public byte[] vectorValue(int ord) { - return values[ord]; - } - - @Override - public byte[] vectorValue() { if (LuceneTestCase.random().nextBoolean()) { - return values[pos]; + return values[ord]; } else { // Sometimes use the same scratch array repeatedly, mimicing what the codec will do. // This should help us catch cases of aliasing where the same ByteVectorValues source is used // twice in a // single computation. - System.arraycopy(values[pos], 0, scratch, 0, dimension); + System.arraycopy(values[ord], 0, scratch, 0, dimension); return scratch; } } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } } diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java index f183f6c99a67..5411f2418de3 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java @@ -17,11 +17,15 @@ package org.apache.lucene.util.hnsw; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.util.ArrayUtil; -class MockVectorValues extends AbstractMockVectorValues - implements RandomAccessVectorValues.Floats { +class MockVectorValues extends FloatVectorValues { + private final int dimension; + private final float[][] denseValues; + protected final float[][] values; + private final int numVectors; private final float[] scratch; static MockVectorValues fromValues(float[][] values) { @@ -43,10 +47,23 @@ static MockVectorValues fromValues(float[][] values) { } MockVectorValues(float[][] values, int dimension, float[][] denseValues, int numVectors) { - super(values, dimension, denseValues, numVectors); + this.dimension = dimension; + this.values = values; + this.denseValues = denseValues; + this.numVectors = numVectors; this.scratch = new float[dimension]; } + @Override + public int size() { + return values.length; + } + + @Override + public int dimension() { + return dimension; + } + @Override public MockVectorValues copy() { return new MockVectorValues( @@ -54,20 +71,20 @@ public MockVectorValues copy() { } @Override - public float[] vectorValue() { + public float[] vectorValue(int ord) { if (LuceneTestCase.random().nextBoolean()) { - return values[pos]; + return values[ord]; } else { // Sometimes use the same scratch array repeatedly, mimicing what the codec will do. // This should help us catch cases of aliasing where the same vector values source is used // twice in a single computation. - System.arraycopy(values[pos], 0, scratch, 0, dimension); + System.arraycopy(values[ord], 0, scratch, 0, dimension); return scratch; } } @Override - public float[] vectorValue(int targetOrd) { - return denseValues[targetOrd]; + public DocIndexIterator iterator() { + return createDenseIterator(); } } diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswByteVectorGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswByteVectorGraph.java index 649bc1a64519..f0e6745211c6 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswByteVectorGraph.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswByteVectorGraph.java @@ -17,13 +17,12 @@ package org.apache.lucene.util.hnsw; -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; - import com.carrotsearch.randomizedtesting.RandomizedTest; import java.io.IOException; import org.apache.lucene.document.Field; import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; @@ -56,7 +55,7 @@ byte[] randomVector(int dim) { } @Override - AbstractMockVectorValues vectorValues(int size, int dimension) { + MockByteVectorValues vectorValues(int size, int dimension) { return MockByteVectorValues.fromValues(createRandomByteVectors(size, dimension, random())); } @@ -65,7 +64,7 @@ static boolean fitsInByte(float v) { } @Override - AbstractMockVectorValues vectorValues(float[][] values) { + MockByteVectorValues vectorValues(float[][] values) { byte[][] bValues = new byte[values.length][]; // The case when all floats fit within a byte already. boolean scaleSimple = fitsInByte(values[0][0]); @@ -86,42 +85,35 @@ AbstractMockVectorValues vectorValues(float[][] values) { } @Override - AbstractMockVectorValues vectorValues( - int size, - int dimension, - AbstractMockVectorValues pregeneratedVectorValues, - int pregeneratedOffset) { + MockByteVectorValues vectorValues( + int size, int dimension, KnnVectorValues pregeneratedVectorValues, int pregeneratedOffset) { + + MockByteVectorValues pvv = (MockByteVectorValues) pregeneratedVectorValues; byte[][] vectors = new byte[size][]; - byte[][] randomVectors = - createRandomByteVectors(size - pregeneratedVectorValues.values.length, dimension, random()); + byte[][] randomVectors = createRandomByteVectors(size - pvv.values.length, dimension, random()); for (int i = 0; i < pregeneratedOffset; i++) { vectors[i] = randomVectors[i]; } - int currentDoc; - while ((currentDoc = pregeneratedVectorValues.nextDoc()) != NO_MORE_DOCS) { - vectors[pregeneratedOffset + currentDoc] = pregeneratedVectorValues.values[currentDoc]; + for (int currentOrd = 0; currentOrd < pvv.size(); currentOrd++) { + vectors[pregeneratedOffset + currentOrd] = pvv.values[currentOrd]; } - for (int i = pregeneratedOffset + pregeneratedVectorValues.values.length; - i < vectors.length; - i++) { - vectors[i] = randomVectors[i - pregeneratedVectorValues.values.length]; + for (int i = pregeneratedOffset + pvv.values.length; i < vectors.length; i++) { + vectors[i] = randomVectors[i - pvv.values.length]; } return MockByteVectorValues.fromValues(vectors); } @Override - AbstractMockVectorValues vectorValues(LeafReader reader, String fieldName) - throws IOException { + MockByteVectorValues vectorValues(LeafReader reader, String fieldName) throws IOException { ByteVectorValues vectorValues = reader.getByteVectorValues(fieldName); byte[][] vectors = new byte[reader.maxDoc()][]; - while (vectorValues.nextDoc() != NO_MORE_DOCS) { - vectors[vectorValues.docID()] = - ArrayUtil.copyOfSubArray( - vectorValues.vectorValue(), 0, vectorValues.vectorValue().length); + for (int i = 0; i < vectorValues.size(); i++) { + vectors[vectorValues.ordToDoc(i)] = + ArrayUtil.copyOfSubArray(vectorValues.vectorValue(i), 0, vectorValues.dimension()); } return MockByteVectorValues.fromValues(vectors); } diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java index 5621edc4b35e..52d1da3dfa83 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java @@ -17,13 +17,12 @@ package org.apache.lucene.util.hnsw; -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; - import com.carrotsearch.randomizedtesting.RandomizedTest; import java.io.IOException; import org.apache.lucene.document.Field; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; @@ -60,52 +59,44 @@ float[] randomVector(int dim) { } @Override - AbstractMockVectorValues vectorValues(int size, int dimension) { + MockVectorValues vectorValues(int size, int dimension) { return MockVectorValues.fromValues(createRandomFloatVectors(size, dimension, random())); } @Override - AbstractMockVectorValues vectorValues(float[][] values) { + MockVectorValues vectorValues(float[][] values) { return MockVectorValues.fromValues(values); } @Override - AbstractMockVectorValues vectorValues(LeafReader reader, String fieldName) - throws IOException { + MockVectorValues vectorValues(LeafReader reader, String fieldName) throws IOException { FloatVectorValues vectorValues = reader.getFloatVectorValues(fieldName); float[][] vectors = new float[reader.maxDoc()][]; - while (vectorValues.nextDoc() != NO_MORE_DOCS) { - vectors[vectorValues.docID()] = - ArrayUtil.copyOfSubArray( - vectorValues.vectorValue(), 0, vectorValues.vectorValue().length); + for (int i = 0; i < vectorValues.size(); i++) { + vectors[vectorValues.ordToDoc(i)] = + ArrayUtil.copyOfSubArray(vectorValues.vectorValue(i), 0, vectorValues.dimension()); } return MockVectorValues.fromValues(vectors); } @Override - AbstractMockVectorValues vectorValues( - int size, - int dimension, - AbstractMockVectorValues pregeneratedVectorValues, - int pregeneratedOffset) { + MockVectorValues vectorValues( + int size, int dimension, KnnVectorValues pregeneratedVectorValues, int pregeneratedOffset) { + MockVectorValues pvv = (MockVectorValues) pregeneratedVectorValues; float[][] vectors = new float[size][]; float[][] randomVectors = - createRandomFloatVectors( - size - pregeneratedVectorValues.values.length, dimension, random()); + createRandomFloatVectors(size - pvv.values.length, dimension, random()); for (int i = 0; i < pregeneratedOffset; i++) { vectors[i] = randomVectors[i]; } - int currentDoc; - while ((currentDoc = pregeneratedVectorValues.nextDoc()) != NO_MORE_DOCS) { - vectors[pregeneratedOffset + currentDoc] = pregeneratedVectorValues.values[currentDoc]; + for (int currentOrd = 0; currentOrd < pvv.size(); currentOrd++) { + vectors[pregeneratedOffset + currentOrd] = pvv.values[currentOrd]; } - for (int i = pregeneratedOffset + pregeneratedVectorValues.values.length; - i < vectors.length; - i++) { - vectors[i] = randomVectors[i - pregeneratedVectorValues.values.length]; + for (int i = pregeneratedOffset + pvv.values.length; i < vectors.length; i++) { + vectors[i] = randomVectors[i - pvv.values.length]; } return MockVectorValues.fromValues(vectors); @@ -129,7 +120,7 @@ float[] getTargetVector() { public void testSearchWithSkewedAcceptOrds() throws IOException { int nDoc = 1000; similarityFunction = VectorSimilarityFunction.EUCLIDEAN; - RandomAccessVectorValues.Floats vectors = circularVectorValues(nDoc); + FloatVectorValues vectors = circularVectorValues(nDoc); RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 16, 100, random().nextInt()); OnHeapHnswGraph hnsw = builder.build(vectors.size()); diff --git a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizedVectorSimilarity.java b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizedVectorSimilarity.java index bdba822d4eca..f2cc3ac35c05 100644 --- a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizedVectorSimilarity.java +++ b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizedVectorSimilarity.java @@ -59,8 +59,7 @@ public void testToEuclidean() throws IOException { float error = Math.max((100 - confidenceInterval) * 0.01f, 0.01f); FloatVectorValues floatVectorValues = fromFloats(floats); ScalarQuantizer scalarQuantizer = - ScalarQuantizer.fromVectors( - floatVectorValues, confidenceInterval, floats.length, (byte) 7); + ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, numVecs, (byte) 7); byte[][] quantized = new byte[floats.length][]; float[] offsets = quantizeVectors(scalarQuantizer, floats, quantized, VectorSimilarityFunction.EUCLIDEAN); @@ -92,8 +91,7 @@ public void testToCosine() throws IOException { float error = Math.max((100 - confidenceInterval) * 0.01f, 0.01f); FloatVectorValues floatVectorValues = fromFloatsNormalized(floats, null); ScalarQuantizer scalarQuantizer = - ScalarQuantizer.fromVectors( - floatVectorValues, confidenceInterval, floats.length, (byte) 7); + ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, numVecs, (byte) 7); byte[][] quantized = new byte[floats.length][]; float[] offsets = quantizeVectorsNormalized( @@ -129,8 +127,7 @@ public void testToDotProduct() throws IOException { float error = Math.max((100 - confidenceInterval) * 0.01f, 0.01f); FloatVectorValues floatVectorValues = fromFloats(floats); ScalarQuantizer scalarQuantizer = - ScalarQuantizer.fromVectors( - floatVectorValues, confidenceInterval, floats.length, (byte) 7); + ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, numVecs, (byte) 7); byte[][] quantized = new byte[floats.length][]; float[] offsets = quantizeVectors(scalarQuantizer, floats, quantized, VectorSimilarityFunction.DOT_PRODUCT); @@ -162,8 +159,7 @@ public void testToMaxInnerProduct() throws IOException { float error = Math.max((100 - confidenceInterval) * 0.5f, 0.5f); FloatVectorValues floatVectorValues = fromFloats(floats); ScalarQuantizer scalarQuantizer = - ScalarQuantizer.fromVectors( - floatVectorValues, confidenceInterval, floats.length, (byte) 7); + ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, numVecs, (byte) 7); byte[][] quantized = new byte[floats.length][]; float[] offsets = quantizeVectors( @@ -242,11 +238,8 @@ private static FloatVectorValues fromFloatsNormalized( float[][] floats, Set deletedVectors) { return new TestScalarQuantizer.TestSimpleFloatVectorValues(floats, deletedVectors) { @Override - public float[] vectorValue() throws IOException { - if (curDoc == -1 || curDoc >= floats.length) { - throw new IOException("Current doc not set or too many iterations"); - } - float[] v = ArrayUtil.copyArray(floats[curDoc]); + public float[] vectorValue(int ord) throws IOException { + float[] v = ArrayUtil.copyArray(floats[ordToDoc[ord]]); VectorUtil.l2normalize(v); return v; } diff --git a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java index 48eb7ce651c6..7f56688b7999 100644 --- a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java +++ b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java @@ -272,14 +272,27 @@ static TestSimpleFloatVectorValues fromFloatsWithRandomDeletions( static class TestSimpleFloatVectorValues extends FloatVectorValues { protected final float[][] floats; protected final Set deletedVectors; + protected final int[] ordToDoc; protected final int numLiveVectors; - protected int curDoc = -1; TestSimpleFloatVectorValues(float[][] values, Set deletedVectors) { this.floats = values; this.deletedVectors = deletedVectors; - this.numLiveVectors = + numLiveVectors = deletedVectors == null ? values.length : values.length - deletedVectors.size(); + ordToDoc = new int[numLiveVectors]; + if (deletedVectors == null) { + for (int i = 0; i < numLiveVectors; i++) { + ordToDoc[i] = i; + } + } else { + int ord = 0; + for (int doc = 0; doc < values.length; doc++) { + if (!deletedVectors.contains(doc)) { + ordToDoc[ord++] = doc; + } + } + } } @Override @@ -293,40 +306,64 @@ public int size() { } @Override - public float[] vectorValue() throws IOException { - if (curDoc == -1 || curDoc >= floats.length) { - throw new IOException("Current doc not set or too many iterations"); - } - return floats[curDoc]; + public float[] vectorValue(int ord) throws IOException { + return floats[ordToDoc(ord)]; } @Override - public int docID() { - if (curDoc >= floats.length) { - return NO_MORE_DOCS; - } - return curDoc; + public int ordToDoc(int ord) { + return ordToDoc[ord]; } @Override - public int nextDoc() throws IOException { - while (++curDoc < floats.length) { - if (deletedVectors == null || !deletedVectors.contains(curDoc)) { - return curDoc; + public DocIndexIterator iterator() { + return new DocIndexIterator() { + + int ord = -1; + int doc = -1; + + @Override + public int docID() { + return doc; } - } - return docID(); - } - @Override - public int advance(int target) throws IOException { - curDoc = target - 1; - return nextDoc(); + @Override + public int nextDoc() throws IOException { + while (doc < floats.length - 1) { + ++doc; + if (deletedVectors == null || !deletedVectors.contains(doc)) { + ++ord; + return doc; + } + } + return doc = NO_MORE_DOCS; + } + + @Override + public int index() { + return ord; + } + + @Override + public long cost() { + return floats.length - deletedVectors.size(); + } + + @Override + public int advance(int target) throws IOException { + throw new UnsupportedOperationException(); + } + }; } @Override public VectorScorer scorer(float[] target) { throw new UnsupportedOperationException(); } + + @Override + public TestSimpleFloatVectorValues copy() { + return this; + } } } diff --git a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java index 2d46b243d838..04ac9285baba 100644 --- a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java +++ b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java @@ -2285,7 +2285,6 @@ public int[] clear() { private static final class MemoryFloatVectorValues extends FloatVectorValues { private final Info info; - private int currentDoc = -1; MemoryFloatVectorValues(Info info) { this.info = info; @@ -2302,14 +2301,19 @@ public int size() { } @Override - public float[] vectorValue() { - if (currentDoc == 0) { + public float[] vectorValue(int ord) { + if (ord == 0) { return info.floatVectorValues[0]; } else { return null; } } + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + @Override public VectorScorer scorer(float[] query) { if (query.length != info.fieldInfo.getVectorDimension()) { @@ -2320,50 +2324,31 @@ public VectorScorer scorer(float[] query) { + info.fieldInfo.getVectorDimension()); } MemoryFloatVectorValues vectorValues = new MemoryFloatVectorValues(info); + DocIndexIterator iterator = vectorValues.iterator(); return new VectorScorer() { @Override public float score() throws IOException { + assert iterator.docID() == 0; return info.fieldInfo .getVectorSimilarityFunction() - .compare(vectorValues.vectorValue(), query); + .compare(vectorValues.vectorValue(0), query); } @Override public DocIdSetIterator iterator() { - return vectorValues; + return iterator; } }; } @Override - public int docID() { - return currentDoc; - } - - @Override - public int nextDoc() { - int doc = ++currentDoc; - if (doc == 0) { - return doc; - } else { - return NO_MORE_DOCS; - } - } - - @Override - public int advance(int target) { - if (target == 0) { - currentDoc = target; - return target; - } else { - return NO_MORE_DOCS; - } + public MemoryFloatVectorValues copy() { + return this; } } private static final class MemoryByteVectorValues extends ByteVectorValues { private final Info info; - private int currentDoc = -1; MemoryByteVectorValues(Info info) { this.info = info; @@ -2380,14 +2365,19 @@ public int size() { } @Override - public byte[] vectorValue() { - if (currentDoc == 0) { + public byte[] vectorValue(int ord) { + if (ord == 0) { return info.byteVectorValues[0]; } else { return null; } } + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + @Override public VectorScorer scorer(byte[] query) { if (query.length != info.fieldInfo.getVectorDimension()) { @@ -2398,44 +2388,26 @@ public VectorScorer scorer(byte[] query) { + info.fieldInfo.getVectorDimension()); } MemoryByteVectorValues vectorValues = new MemoryByteVectorValues(info); + DocIndexIterator iterator = vectorValues.iterator(); return new VectorScorer() { @Override public float score() { + assert iterator.docID() == 0; return info.fieldInfo .getVectorSimilarityFunction() - .compare(vectorValues.vectorValue(), query); + .compare(vectorValues.vectorValue(0), query); } @Override public DocIdSetIterator iterator() { - return vectorValues; + return iterator; } }; } @Override - public int docID() { - return currentDoc; - } - - @Override - public int nextDoc() { - int doc = ++currentDoc; - if (doc == 0) { - return doc; - } else { - return NO_MORE_DOCS; - } - } - - @Override - public int advance(int target) { - if (target == 0) { - currentDoc = target; - return target; - } else { - return NO_MORE_DOCS; - } + public MemoryByteVectorValues copy() { + return this; } } } diff --git a/lucene/memory/src/test/org/apache/lucene/index/memory/TestMemoryIndex.java b/lucene/memory/src/test/org/apache/lucene/index/memory/TestMemoryIndex.java index 18e97c67d9d9..7c5928689127 100644 --- a/lucene/memory/src/test/org/apache/lucene/index/memory/TestMemoryIndex.java +++ b/lucene/memory/src/test/org/apache/lucene/index/memory/TestMemoryIndex.java @@ -63,6 +63,7 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexableField; import org.apache.lucene.index.IndexableFieldType; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.NumericDocValues; import org.apache.lucene.index.PostingsEnum; @@ -851,9 +852,10 @@ private static void assertFloatVectorValue(MemoryIndex mi, String fieldName, flo .reader() .getFloatVectorValues(fieldName); assertNotNull(fvv); - assertEquals(0, fvv.nextDoc()); - assertArrayEquals(expected, fvv.vectorValue(), 1e-6f); - assertEquals(DocIdSetIterator.NO_MORE_DOCS, fvv.nextDoc()); + KnnVectorValues.DocIndexIterator iterator = fvv.iterator(); + assertEquals(0, iterator.nextDoc()); + assertArrayEquals(expected, fvv.vectorValue(0), 1e-6f); + assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc()); } private static void assertFloatVectorScore( @@ -868,7 +870,7 @@ private static void assertFloatVectorScore( .getFloatVectorValues(fieldName); assertNotNull(fvv); if (random().nextBoolean()) { - fvv.nextDoc(); + fvv.iterator().nextDoc(); } VectorScorer scorer = fvv.scorer(queryVector); assertEquals(0, scorer.iterator().nextDoc()); @@ -886,9 +888,10 @@ private static void assertByteVectorValue(MemoryIndex mi, String fieldName, byte .reader() .getByteVectorValues(fieldName); assertNotNull(bvv); - assertEquals(0, bvv.nextDoc()); - assertArrayEquals(expected, bvv.vectorValue()); - assertEquals(DocIdSetIterator.NO_MORE_DOCS, bvv.nextDoc()); + KnnVectorValues.DocIndexIterator iterator = bvv.iterator(); + assertEquals(0, iterator.nextDoc()); + assertArrayEquals(expected, bvv.vectorValue(0)); + assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc()); } private static void assertByteVectorScore( @@ -903,7 +906,7 @@ private static void assertByteVectorScore( .getByteVectorValues(fieldName); assertNotNull(bvv); if (random().nextBoolean()) { - bvv.nextDoc(); + bvv.iterator().nextDoc(); } VectorScorer scorer = bvv.scorer(queryVector); assertEquals(0, scorer.iterator().nextDoc()); diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java index 32517496d542..c95bf632a73a 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java @@ -20,6 +20,7 @@ import java.util.Map; import java.util.Objects; import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.VectorEncoding; @@ -63,11 +64,12 @@ protected DocIdSetIterator getVectorIterator() { } return new VectorFieldFunction(this) { + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); @Override public byte[] byteVectorVal(int doc) throws IOException { if (exists(doc)) { - return vectorValues.vectorValue(); + return vectorValues.vectorValue(iterator.index()); } else { return null; } @@ -75,7 +77,7 @@ public byte[] byteVectorVal(int doc) throws IOException { @Override protected DocIdSetIterator getVectorIterator() { - return vectorValues; + return iterator; } }; } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java index 43cc3aff880e..f026d9537bc6 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java @@ -20,6 +20,7 @@ import java.util.Map; import java.util.Objects; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.VectorEncoding; @@ -62,11 +63,12 @@ protected DocIdSetIterator getVectorIterator() { } return new VectorFieldFunction(this) { + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); @Override public float[] floatVectorVal(int doc) throws IOException { if (exists(doc)) { - return vectorValues.vectorValue(); + return vectorValues.vectorValue(iterator.index()); } else { return null; } @@ -74,7 +76,7 @@ public float[] floatVectorVal(int doc) throws IOException { @Override protected DocIdSetIterator getVectorIterator() { - return vectorValues; + return iterator; } }; } diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/KMeans.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/KMeans.java index bb9d3ca63df5..88d2adba5fad 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/KMeans.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/KMeans.java @@ -25,11 +25,11 @@ import java.util.List; import java.util.Random; import java.util.Set; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.NeighborQueue; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; /** KMeans clustering algorithm for vectors */ public class KMeans { @@ -38,7 +38,7 @@ public class KMeans { public static final int DEFAULT_ITRS = 10; public static final int DEFAULT_SAMPLE_SIZE = 100_000; - private final RandomAccessVectorValues.Floats vectors; + private final FloatVectorValues vectors; private final int numVectors; private final int numCentroids; private final Random random; @@ -57,9 +57,7 @@ public class KMeans { * @throws IOException when if there is an error accessing vectors */ public static Results cluster( - RandomAccessVectorValues.Floats vectors, - VectorSimilarityFunction similarityFunction, - int numClusters) + FloatVectorValues vectors, VectorSimilarityFunction similarityFunction, int numClusters) throws IOException { return cluster( vectors, @@ -93,7 +91,7 @@ public static Results cluster( * @throws IOException if there is error accessing vectors */ public static Results cluster( - RandomAccessVectorValues.Floats vectors, + FloatVectorValues vectors, int numClusters, boolean assignCentroidsToVectors, long seed, @@ -124,7 +122,7 @@ public static Results cluster( if (numClusters == 1) { centroids = new float[1][vectors.dimension()]; } else { - RandomAccessVectorValues.Floats sampleVectors = + FloatVectorValues sampleVectors = vectors.size() <= sampleSize ? vectors : createSampleReader(vectors, sampleSize, seed); KMeans kmeans = new KMeans(sampleVectors, numClusters, random, initializationMethod, restarts, iters); @@ -142,7 +140,7 @@ public static Results cluster( } private KMeans( - RandomAccessVectorValues.Floats vectors, + FloatVectorValues vectors, int numCentroids, Random random, KmeansInitializationMethod initializationMethod, @@ -276,7 +274,7 @@ private float[][] initializePlusPlus() throws IOException { * @throws IOException if there is an error accessing vector values */ private static double runKMeansStep( - RandomAccessVectorValues.Floats vectors, + FloatVectorValues vectors, float[][] centroids, short[] docCentroids, boolean useKahanSummation, @@ -348,9 +346,7 @@ private static double runKMeansStep( * descending distance to the current centroid set */ static void assignCentroids( - RandomAccessVectorValues.Floats vectors, - float[][] centroids, - List unassignedCentroidsIdxs) + FloatVectorValues vectors, float[][] centroids, List unassignedCentroidsIdxs) throws IOException { int[] assignedCentroidsIdxs = new int[centroids.length - unassignedCentroidsIdxs.size()]; int assignedIndex = 0; diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/SampleReader.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/SampleReader.java index 9a718c811017..684c9fac838f 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/SampleReader.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/SampleReader.java @@ -20,18 +20,18 @@ import java.io.IOException; import java.util.Random; import java.util.function.IntUnaryOperator; +import org.apache.lucene.codecs.lucene95.HasIndexSlice; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.Bits; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; /** A reader of vector values that samples a subset of the vectors. */ -public class SampleReader implements RandomAccessVectorValues.Floats { - private final RandomAccessVectorValues.Floats origin; +public class SampleReader extends FloatVectorValues implements HasIndexSlice { + private final FloatVectorValues origin; private final int sampleSize; private final IntUnaryOperator sampleFunction; - SampleReader( - RandomAccessVectorValues.Floats origin, int sampleSize, IntUnaryOperator sampleFunction) { + SampleReader(FloatVectorValues origin, int sampleSize, IntUnaryOperator sampleFunction) { this.origin = origin; this.sampleSize = sampleSize; this.sampleFunction = sampleFunction; @@ -48,13 +48,13 @@ public int dimension() { } @Override - public Floats copy() throws IOException { + public FloatVectorValues copy() throws IOException { throw new IllegalStateException("Not supported"); } @Override public IndexInput getSlice() { - return origin.getSlice(); + return ((HasIndexSlice) origin).getSlice(); } @Override @@ -77,8 +77,7 @@ public Bits getAcceptOrds(Bits acceptDocs) { throw new IllegalStateException("Not supported"); } - public static SampleReader createSampleReader( - RandomAccessVectorValues.Floats origin, int k, long seed) { + public static SampleReader createSampleReader(FloatVectorValues origin, int k, long seed) { int[] samples = reservoirSample(origin.size(), k, seed); return new SampleReader(origin, samples.length, i -> samples[i]); } diff --git a/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/quantization/TestKMeans.java b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/quantization/TestKMeans.java index 61c0e58c91ef..3669079b719d 100644 --- a/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/quantization/TestKMeans.java +++ b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/quantization/TestKMeans.java @@ -20,9 +20,9 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.tests.util.LuceneTestCase; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; public class TestKMeans extends LuceneTestCase { @@ -32,7 +32,7 @@ public void testKMeansAPI() throws IOException { int dims = random().nextInt(2, 20); int randIdx = random().nextInt(VectorSimilarityFunction.values().length); VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.values()[randIdx]; - RandomAccessVectorValues.Floats vectors = generateData(nVectors, dims, nClusters); + FloatVectorValues vectors = generateData(nVectors, dims, nClusters); // default case { @@ -75,7 +75,7 @@ public void testKMeansSpecialCases() throws IOException { // nClusters > nVectors int nClusters = 20; int nVectors = 10; - RandomAccessVectorValues.Floats vectors = generateData(nVectors, 5, nClusters); + FloatVectorValues vectors = generateData(nVectors, 5, nClusters); KMeans.Results results = KMeans.cluster(vectors, VectorSimilarityFunction.EUCLIDEAN, nClusters); // assert that we get 1 centroid, as nClusters will be adjusted @@ -87,7 +87,7 @@ public void testKMeansSpecialCases() throws IOException { int sampleSize = 2; int nClusters = 2; int nVectors = 300; - RandomAccessVectorValues.Floats vectors = generateData(nVectors, 5, nClusters); + FloatVectorValues vectors = generateData(nVectors, 5, nClusters); KMeans.KmeansInitializationMethod initializationMethod = KMeans.KmeansInitializationMethod.PLUS_PLUS; KMeans.Results results = @@ -108,7 +108,7 @@ public void testKMeansSpecialCases() throws IOException { // test unassigned centroids int nClusters = 4; int nVectors = 400; - RandomAccessVectorValues.Floats vectors = generateData(nVectors, 5, nClusters); + FloatVectorValues vectors = generateData(nVectors, 5, nClusters); KMeans.Results results = KMeans.cluster(vectors, VectorSimilarityFunction.EUCLIDEAN, nClusters); float[][] centroids = results.centroids(); @@ -118,8 +118,7 @@ public void testKMeansSpecialCases() throws IOException { } } - private static RandomAccessVectorValues.Floats generateData( - int nSamples, int nDims, int nClusters) { + private static FloatVectorValues generateData(int nSamples, int nDims, int nClusters) { List vectors = new ArrayList<>(nSamples); float[][] centroids = new float[nClusters][nDims]; // Generate random centroids @@ -137,6 +136,6 @@ private static RandomAccessVectorValues.Floats generateData( } vectors.add(vector); } - return RandomAccessVectorValues.fromFloats(vectors, nDims); + return FloatVectorValues.fromFloats(vectors, nDims); } } diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java index 501e2e5616f0..21c62090a698 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java @@ -125,7 +125,7 @@ public FloatVectorValues getFloatVectorValues(String field) throws IOException { && fi.getVectorEncoding() == VectorEncoding.FLOAT32; FloatVectorValues floatValues = delegate.getFloatVectorValues(field); assert floatValues != null; - assert floatValues.docID() == -1; + assert floatValues.iterator().docID() == -1; assert floatValues.size() >= 0; assert floatValues.dimension() > 0; return floatValues; @@ -139,7 +139,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { && fi.getVectorEncoding() == VectorEncoding.BYTE; ByteVectorValues values = delegate.getByteVectorValues(field); assert values != null; - assert values.docID() == -1; + assert values.iterator().docID() == -1; assert values.size() >= 0; assert values.dimension() > 0; return values; diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java index 63fe2b8f4c11..e42d3e189819 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java @@ -55,6 +55,7 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.MergePolicy; @@ -437,9 +438,10 @@ public void testAddIndexesDirectory0() throws Exception { try (IndexReader reader = DirectoryReader.open(w2)) { LeafReader r = getOnlyLeafReader(reader); FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName); - assertEquals(0, vectorValues.nextDoc()); - assertEquals(0, vectorValues.vectorValue()[0], 0); - assertEquals(NO_MORE_DOCS, vectorValues.nextDoc()); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + assertEquals(0, iterator.nextDoc()); + assertEquals(0, vectorValues.vectorValue(0)[0], 0); + assertEquals(NO_MORE_DOCS, iterator.nextDoc()); } } } @@ -462,9 +464,10 @@ public void testAddIndexesDirectory1() throws Exception { try (IndexReader reader = DirectoryReader.open(w2)) { LeafReader r = getOnlyLeafReader(reader); FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName); - assertNotEquals(NO_MORE_DOCS, vectorValues.nextDoc()); - assertEquals(0, vectorValues.vectorValue()[0], 0); - assertEquals(NO_MORE_DOCS, vectorValues.nextDoc()); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + assertNotEquals(NO_MORE_DOCS, iterator.nextDoc()); + assertEquals(0, vectorValues.vectorValue(iterator.index())[0], 0); + assertEquals(NO_MORE_DOCS, iterator.nextDoc()); } } } @@ -489,12 +492,13 @@ public void testAddIndexesDirectory01() throws Exception { try (IndexReader reader = DirectoryReader.open(w2)) { LeafReader r = getOnlyLeafReader(reader); FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName); - assertEquals(0, vectorValues.nextDoc()); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + assertEquals(0, iterator.nextDoc()); // The merge order is randomized, we might get 0 first, or 1 - float value = vectorValues.vectorValue()[0]; + float value = vectorValues.vectorValue(0)[0]; assertTrue(value == 0 || value == 1); - assertEquals(1, vectorValues.nextDoc()); - value += vectorValues.vectorValue()[0]; + assertEquals(1, iterator.nextDoc()); + value += vectorValues.vectorValue(1)[0]; assertEquals(1, value, 0); } } @@ -879,8 +883,10 @@ public void testSparseVectors() throws Exception { ByteVectorValues byteVectorValues = ctx.reader().getByteVectorValues(fieldName); if (byteVectorValues != null) { docCount += byteVectorValues.size(); - while (byteVectorValues.nextDoc() != NO_MORE_DOCS) { - checksum += byteVectorValues.vectorValue()[0]; + KnnVectorValues.DocIndexIterator iterator = byteVectorValues.iterator(); + while (true) { + if (!(iterator.nextDoc() != NO_MORE_DOCS)) break; + checksum += byteVectorValues.vectorValue(iterator.index())[0]; } } } @@ -890,8 +896,10 @@ public void testSparseVectors() throws Exception { FloatVectorValues vectorValues = ctx.reader().getFloatVectorValues(fieldName); if (vectorValues != null) { docCount += vectorValues.size(); - while (vectorValues.nextDoc() != NO_MORE_DOCS) { - checksum += vectorValues.vectorValue()[0]; + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + while (true) { + if (!(iterator.nextDoc() != NO_MORE_DOCS)) break; + checksum += vectorValues.vectorValue(iterator.index())[0]; } } } @@ -950,10 +958,12 @@ public void testFloatVectorScorerIteration() throws Exception { assertSame(iterator, scorer.iterator()); assertNotSame(iterator, scorer); // verify scorer iteration scores are valid & iteration with vectorValues is consistent - while (iterator.nextDoc() != NO_MORE_DOCS && vectorValues.nextDoc() != NO_MORE_DOCS) { + KnnVectorValues.DocIndexIterator valuesIterator = vectorValues.iterator(); + while (iterator.nextDoc() != NO_MORE_DOCS) { + if (!(valuesIterator.nextDoc() != NO_MORE_DOCS)) break; float score = scorer.score(); assertTrue(score >= 0f); - assertEquals(iterator.docID(), vectorValues.docID()); + assertEquals(iterator.docID(), valuesIterator.docID()); } // verify that a new scorer can be obtained after iteration VectorScorer newScorer = vectorValues.scorer(vectorToScore); @@ -1009,10 +1019,12 @@ public void testByteVectorScorerIteration() throws Exception { assertSame(iterator, scorer.iterator()); assertNotSame(iterator, scorer); // verify scorer iteration scores are valid & iteration with vectorValues is consistent - while (iterator.nextDoc() != NO_MORE_DOCS && vectorValues.nextDoc() != NO_MORE_DOCS) { + KnnVectorValues.DocIndexIterator valuesIterator = vectorValues.iterator(); + while (iterator.nextDoc() != NO_MORE_DOCS) { + if (!(valuesIterator.nextDoc() != NO_MORE_DOCS)) break; float score = scorer.score(); assertTrue(score >= 0f); - assertEquals(iterator.docID(), vectorValues.docID()); + assertEquals(iterator.docID(), valuesIterator.docID()); } // verify that a new scorer can be obtained after iteration VectorScorer newScorer = vectorValues.scorer(vectorToScore); @@ -1118,12 +1130,16 @@ public void testIndexedValueNotAliased() throws Exception { LeafReader r = getOnlyLeafReader(reader); FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName); assertEquals(3, vectorValues.size()); - vectorValues.nextDoc(); - assertEquals(1, vectorValues.vectorValue()[0], 0); - vectorValues.nextDoc(); - assertEquals(1, vectorValues.vectorValue()[0], 0); - vectorValues.nextDoc(); - assertEquals(2, vectorValues.vectorValue()[0], 0); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + iterator.nextDoc(); + assertEquals(0, iterator.index()); + assertEquals(1, vectorValues.vectorValue(0)[0], 0); + iterator.nextDoc(); + assertEquals(1, iterator.index()); + assertEquals(1, vectorValues.vectorValue(1)[0], 0); + iterator.nextDoc(); + assertEquals(2, iterator.index()); + assertEquals(2, vectorValues.vectorValue(2)[0], 0); } } } @@ -1146,13 +1162,14 @@ public void testSortedIndex() throws Exception { FloatVectorValues vectorValues = leaf.getFloatVectorValues(fieldName); assertEquals(2, vectorValues.dimension()); assertEquals(3, vectorValues.size()); - assertEquals("1", storedFields.document(vectorValues.nextDoc()).get("id")); - assertEquals(-1f, vectorValues.vectorValue()[0], 0); - assertEquals("2", storedFields.document(vectorValues.nextDoc()).get("id")); - assertEquals(1, vectorValues.vectorValue()[0], 0); - assertEquals("4", storedFields.document(vectorValues.nextDoc()).get("id")); - assertEquals(0, vectorValues.vectorValue()[0], 0); - assertEquals(NO_MORE_DOCS, vectorValues.nextDoc()); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + assertEquals("1", storedFields.document(iterator.nextDoc()).get("id")); + assertEquals(-1f, vectorValues.vectorValue(0)[0], 0); + assertEquals("2", storedFields.document(iterator.nextDoc()).get("id")); + assertEquals(1, vectorValues.vectorValue(1)[0], 0); + assertEquals("4", storedFields.document(iterator.nextDoc()).get("id")); + assertEquals(0, vectorValues.vectorValue(2)[0], 0); + assertEquals(NO_MORE_DOCS, iterator.nextDoc()); } } } @@ -1175,13 +1192,13 @@ public void testSortedIndexBytes() throws Exception { ByteVectorValues vectorValues = leaf.getByteVectorValues(fieldName); assertEquals(2, vectorValues.dimension()); assertEquals(3, vectorValues.size()); - assertEquals("1", storedFields.document(vectorValues.nextDoc()).get("id")); - assertEquals(-1, vectorValues.vectorValue()[0], 0); - assertEquals("2", storedFields.document(vectorValues.nextDoc()).get("id")); - assertEquals(1, vectorValues.vectorValue()[0], 0); - assertEquals("4", storedFields.document(vectorValues.nextDoc()).get("id")); - assertEquals(0, vectorValues.vectorValue()[0], 0); - assertEquals(NO_MORE_DOCS, vectorValues.nextDoc()); + assertEquals("1", storedFields.document(vectorValues.iterator().nextDoc()).get("id")); + assertEquals(-1, vectorValues.vectorValue(0)[0], 0); + assertEquals("2", storedFields.document(vectorValues.iterator().nextDoc()).get("id")); + assertEquals(1, vectorValues.vectorValue(1)[0], 0); + assertEquals("4", storedFields.document(vectorValues.iterator().nextDoc()).get("id")); + assertEquals(0, vectorValues.vectorValue(2)[0], 0); + assertEquals(NO_MORE_DOCS, vectorValues.iterator().nextDoc()); } } } @@ -1211,27 +1228,30 @@ public void testIndexMultipleKnnVectorFields() throws Exception { FloatVectorValues vectorValues = leaf.getFloatVectorValues("field1"); assertEquals(2, vectorValues.dimension()); assertEquals(2, vectorValues.size()); - vectorValues.nextDoc(); - assertEquals(1f, vectorValues.vectorValue()[0], 0); - vectorValues.nextDoc(); - assertEquals(2f, vectorValues.vectorValue()[0], 0); - assertEquals(NO_MORE_DOCS, vectorValues.nextDoc()); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + iterator.nextDoc(); + assertEquals(1f, vectorValues.vectorValue(0)[0], 0); + iterator.nextDoc(); + assertEquals(2f, vectorValues.vectorValue(1)[0], 0); + assertEquals(NO_MORE_DOCS, iterator.nextDoc()); FloatVectorValues vectorValues2 = leaf.getFloatVectorValues("field2"); + KnnVectorValues.DocIndexIterator it2 = vectorValues2.iterator(); assertEquals(4, vectorValues2.dimension()); assertEquals(2, vectorValues2.size()); - vectorValues2.nextDoc(); - assertEquals(2f, vectorValues2.vectorValue()[1], 0); - vectorValues2.nextDoc(); - assertEquals(2f, vectorValues2.vectorValue()[1], 0); - assertEquals(NO_MORE_DOCS, vectorValues2.nextDoc()); + it2.nextDoc(); + assertEquals(2f, vectorValues2.vectorValue(0)[1], 0); + it2.nextDoc(); + assertEquals(2f, vectorValues2.vectorValue(1)[1], 0); + assertEquals(NO_MORE_DOCS, it2.nextDoc()); FloatVectorValues vectorValues3 = leaf.getFloatVectorValues("field3"); assertEquals(4, vectorValues3.dimension()); assertEquals(1, vectorValues3.size()); - vectorValues3.nextDoc(); - assertEquals(1f, vectorValues3.vectorValue()[0], 0.1); - assertEquals(NO_MORE_DOCS, vectorValues3.nextDoc()); + KnnVectorValues.DocIndexIterator it3 = vectorValues3.iterator(); + it3.nextDoc(); + assertEquals(1f, vectorValues3.vectorValue(0)[0], 0.1); + assertEquals(NO_MORE_DOCS, it3.nextDoc()); } } } @@ -1295,13 +1315,15 @@ public void testRandom() throws Exception { totalSize += vectorValues.size(); StoredFields storedFields = ctx.reader().storedFields(); int docId; - while ((docId = vectorValues.nextDoc()) != NO_MORE_DOCS) { - float[] v = vectorValues.vectorValue(); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + while (true) { + if (!((docId = iterator.nextDoc()) != NO_MORE_DOCS)) break; + float[] v = vectorValues.vectorValue(iterator.index()); assertEquals(dimension, v.length); String idString = storedFields.document(docId).getField("id").stringValue(); int id = Integer.parseInt(idString); if (ctx.reader().getLiveDocs() == null || ctx.reader().getLiveDocs().get(docId)) { - assertArrayEquals(idString, values[id], v, 0); + assertArrayEquals(idString + " " + docId, values[id], v, 0); ++valueCount; } else { ++numDeletes; @@ -1375,8 +1397,10 @@ public void testRandomBytes() throws Exception { totalSize += vectorValues.size(); StoredFields storedFields = ctx.reader().storedFields(); int docId; - while ((docId = vectorValues.nextDoc()) != NO_MORE_DOCS) { - byte[] v = vectorValues.vectorValue(); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + while (true) { + if (!((docId = iterator.nextDoc()) != NO_MORE_DOCS)) break; + byte[] v = vectorValues.vectorValue(iterator.index()); assertEquals(dimension, v.length); String idString = storedFields.document(docId).getField("id").stringValue(); int id = Integer.parseInt(idString); @@ -1495,8 +1519,10 @@ public void testRandomWithUpdatesAndGraph() throws Exception { StoredFields storedFields = ctx.reader().storedFields(); int docId; int numLiveDocsWithVectors = 0; - while ((docId = vectorValues.nextDoc()) != NO_MORE_DOCS) { - float[] v = vectorValues.vectorValue(); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + while (true) { + if (!((docId = iterator.nextDoc()) != NO_MORE_DOCS)) break; + float[] v = vectorValues.vectorValue(iterator.index()); assertEquals(dimension, v.length); String idString = storedFields.document(docId).getField("id").stringValue(); int id = Integer.parseInt(idString); @@ -1703,25 +1729,27 @@ public void testAdvance() throws Exception { FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName); int[] vectorDocs = new int[vectorValues.size() + 1]; int cur = -1; + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); while (++cur < vectorValues.size() + 1) { - vectorDocs[cur] = vectorValues.nextDoc(); + vectorDocs[cur] = iterator.nextDoc(); if (cur != 0) { assertTrue(vectorDocs[cur] > vectorDocs[cur - 1]); } } vectorValues = r.getFloatVectorValues(fieldName); + DocIdSetIterator iter = vectorValues.iterator(); cur = -1; for (int i = 0; i < numdocs; i++) { // randomly advance to i if (random().nextInt(4) == 3) { while (vectorDocs[++cur] < i) {} - assertEquals(vectorDocs[cur], vectorValues.advance(i)); - assertEquals(vectorDocs[cur], vectorValues.docID()); - if (vectorValues.docID() == NO_MORE_DOCS) { + assertEquals(vectorDocs[cur], iter.advance(i)); + assertEquals(vectorDocs[cur], iter.docID()); + if (iter.docID() == NO_MORE_DOCS) { break; } // make i equal to docid so that it is greater than docId in the next loop iteration - i = vectorValues.docID(); + i = iter.docID(); } } } @@ -1772,6 +1800,7 @@ public void testVectorValuesReportCorrectDocs() throws Exception { double checksum = 0; int docCount = 0; long sumDocIds = 0; + long sumOrdToDocIds = 0; switch (vectorEncoding) { case BYTE -> { for (LeafReaderContext ctx : r.leaves()) { @@ -1779,11 +1808,18 @@ public void testVectorValuesReportCorrectDocs() throws Exception { if (byteVectorValues != null) { docCount += byteVectorValues.size(); StoredFields storedFields = ctx.reader().storedFields(); - while (byteVectorValues.nextDoc() != NO_MORE_DOCS) { - checksum += byteVectorValues.vectorValue()[0]; - Document doc = storedFields.document(byteVectorValues.docID(), Set.of("id")); + KnnVectorValues.DocIndexIterator iter = byteVectorValues.iterator(); + for (iter.nextDoc(); iter.docID() != NO_MORE_DOCS; iter.nextDoc()) { + int ord = iter.index(); + checksum += byteVectorValues.vectorValue(ord)[0]; + Document doc = storedFields.document(iter.docID(), Set.of("id")); sumDocIds += Integer.parseInt(doc.get("id")); } + for (int ord = 0; ord < byteVectorValues.size(); ord++) { + Document doc = + storedFields.document(byteVectorValues.ordToDoc(ord), Set.of("id")); + sumOrdToDocIds += Integer.parseInt(doc.get("id")); + } } } } @@ -1793,11 +1829,17 @@ public void testVectorValuesReportCorrectDocs() throws Exception { if (vectorValues != null) { docCount += vectorValues.size(); StoredFields storedFields = ctx.reader().storedFields(); - while (vectorValues.nextDoc() != NO_MORE_DOCS) { - checksum += vectorValues.vectorValue()[0]; - Document doc = storedFields.document(vectorValues.docID(), Set.of("id")); + KnnVectorValues.DocIndexIterator iter = vectorValues.iterator(); + for (iter.nextDoc(); iter.docID() != NO_MORE_DOCS; iter.nextDoc()) { + int ord = iter.index(); + checksum += vectorValues.vectorValue(ord)[0]; + Document doc = storedFields.document(iter.docID(), Set.of("id")); sumDocIds += Integer.parseInt(doc.get("id")); } + for (int ord = 0; ord < vectorValues.size(); ord++) { + Document doc = storedFields.document(vectorValues.ordToDoc(ord), Set.of("id")); + sumOrdToDocIds += Integer.parseInt(doc.get("id")); + } } } } @@ -1809,6 +1851,7 @@ public void testVectorValuesReportCorrectDocs() throws Exception { vectorEncoding == VectorEncoding.BYTE ? numDocs * 0.2 : 1e-5); assertEquals(fieldDocCount, docCount); assertEquals(fieldSumDocIDs, sumDocIds); + assertEquals(fieldSumDocIDs, sumOrdToDocIds); } } } @@ -1839,25 +1882,27 @@ public void testMismatchedFields() throws Exception { ByteVectorValues byteVectors = leafReader.getByteVectorValues("byte"); assertNotNull(byteVectors); - assertEquals(0, byteVectors.nextDoc()); - assertArrayEquals(new byte[] {42}, byteVectors.vectorValue()); - assertEquals(1, byteVectors.nextDoc()); - assertArrayEquals(new byte[] {42}, byteVectors.vectorValue()); - assertEquals(DocIdSetIterator.NO_MORE_DOCS, byteVectors.nextDoc()); + KnnVectorValues.DocIndexIterator iter = byteVectors.iterator(); + assertEquals(0, iter.nextDoc()); + assertArrayEquals(new byte[] {42}, byteVectors.vectorValue(0)); + assertEquals(1, iter.nextDoc()); + assertArrayEquals(new byte[] {42}, byteVectors.vectorValue(1)); + assertEquals(DocIdSetIterator.NO_MORE_DOCS, iter.nextDoc()); FloatVectorValues floatVectors = leafReader.getFloatVectorValues("float"); assertNotNull(floatVectors); - assertEquals(0, floatVectors.nextDoc()); - float[] vector = floatVectors.vectorValue(); + iter = floatVectors.iterator(); + assertEquals(0, iter.nextDoc()); + float[] vector = floatVectors.vectorValue(0); assertEquals(2, vector.length); assertEquals(1f, vector[0], 0f); assertEquals(2f, vector[1], 0f); - assertEquals(1, floatVectors.nextDoc()); - vector = floatVectors.vectorValue(); + assertEquals(1, iter.nextDoc()); + vector = floatVectors.vectorValue(1); assertEquals(2, vector.length); assertEquals(1f, vector[0], 0f); assertEquals(2f, vector[1], 0f); - assertEquals(DocIdSetIterator.NO_MORE_DOCS, floatVectors.nextDoc()); + assertEquals(DocIdSetIterator.NO_MORE_DOCS, iter.nextDoc()); IOUtils.close(reader, w2, dir1, dir2); } diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingScorer.java b/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingScorer.java index 8badba0d12b7..dd408befdbf3 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingScorer.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingScorer.java @@ -183,7 +183,7 @@ public int advance(int target) throws IOException { } else { state = IteratorState.ITERATING; } - assert in.docID() == advanced; + assert in.docID() == advanced : in.docID() + " != " + advanced + " in " + in; assert AssertingScorer.this.in.docID() == in.docID(); return doc = advanced; }