Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First-class random access API for KnnVectorValues #13779

Open
wants to merge 25 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
cd9c486
compiles!
Sep 1, 2024
2bbf8f1
adding some ordToDoc
Sep 2, 2024
a451fdb
restore vector count argument to scalarquantizer methods
Sep 2, 2024
8152b9d
remove docToOrd; mostly can use iterator.index()
Sep 3, 2024
dce766c
Make KnnVectorValues primarily a random access API
Sep 5, 2024
2f0cc8c
HasIndexSlice
Sep 6, 2024
327b930
remove RandomAccessVectorValues
Sep 7, 2024
98ab0a6
tests pass
Sep 10, 2024
1450b44
fixing up javadocs and making iterator methods instance methods
Sep 10, 2024
8d087e2
rename DocIterator to DocIndexIterator
Sep 10, 2024
c2ae86b
clean up some comments
Sep 10, 2024
ff7a317
fix case where index is reordered
Sep 12, 2024
9e5b9f9
rename 'fromOrdToDoc' to 'all'; move fromIndexedDISI to codecs/lucene90
Sep 15, 2024
d43785d
no default advance(); default cost() unsupported
Sep 15, 2024
787e89c
make iterator() API sane
Sep 16, 2024
1873955
Merge branch 'main' into knn-vector-random
Sep 16, 2024
4feecf8
Rename IteratorSupplier->SortingIteratorSupplier and add javadoc
Sep 16, 2024
abc1713
cache vector values iterators in VectorFieldSources
Sep 16, 2024
3f6091c
rename KnnvectorValues.all() to createSparseIterator()
Sep 17, 2024
d8ab1ec
implement cost(); enforce forward iteration in KnnVectorsWriter
Sep 18, 2024
a2ca172
add implementations of KnnVectorValues.copy()
Sep 19, 2024
274859f
Merge remote-tracking branch 'origin/main' into knn-vector-random
Sep 20, 2024
2b21668
fix SlowCOmpositeCodecReaderWrapper; off-by-one AND lazy iterator access
Sep 20, 2024
2a284f2
Merge remote-tracking branch 'origin/main' into knn-vector-random
Sep 20, 2024
cb62025
resolve merge conflicts
Sep 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,18 @@
package org.apache.lucene.analysis.synonym.word2vec;

import java.io.IOException;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefHash;
import org.apache.lucene.util.TermAndVector;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;

/**
* Word2VecModel is a class representing the parsed Word2Vec model containing the vectors for each
* word in dictionary
*
* @lucene.experimental
*/
public class Word2VecModel implements RandomAccessVectorValues.Floats {
public class Word2VecModel extends FloatVectorValues {

private final int dictionarySize;
private final int vectorDimension;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
import java.util.Objects;
import java.util.SplittableRandom;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;

/**
* Builder for HNSW graph. See {@link Lucene90OnHeapHnswGraph} for a gloss on the algorithm and the
Expand All @@ -49,7 +49,7 @@ public final class Lucene90HnswGraphBuilder {
private final Lucene90NeighborArray scratch;

private final VectorSimilarityFunction similarityFunction;
private final RandomAccessVectorValues.Floats vectorValues;
private final FloatVectorValues vectorValues;
private final SplittableRandom random;
private final Lucene90BoundsChecker bound;
final Lucene90OnHeapHnswGraph hnsw;
Expand All @@ -58,7 +58,7 @@ public final class Lucene90HnswGraphBuilder {

// we need two sources of vectors in order to perform diversity check comparisons without
// colliding
private final RandomAccessVectorValues.Floats buildVectors;
private final FloatVectorValues buildVectors;

/**
* Reads all the vectors from vector values, builds a graph connecting them by their dense
Expand All @@ -73,7 +73,7 @@ public final class Lucene90HnswGraphBuilder {
* to ensure repeatable construction.
*/
public Lucene90HnswGraphBuilder(
RandomAccessVectorValues.Floats vectors,
FloatVectorValues vectors,
VectorSimilarityFunction similarityFunction,
int maxConn,
int beamWidth,
Expand All @@ -97,14 +97,14 @@ public Lucene90HnswGraphBuilder(
}

/**
* Reads all the vectors from two copies of a {@link RandomAccessVectorValues}. Providing two
* copies enables efficient retrieval without extra data copying, while avoiding collision of the
* Reads all the vectors from two copies of a {@link FloatVectorValues}. Providing two copies
* enables efficient retrieval without extra data copying, while avoiding collision of the
* returned values.
*
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet
* accessor for the vectors
*/
public Lucene90OnHeapHnswGraph build(RandomAccessVectorValues.Floats vectors) throws IOException {
public Lucene90OnHeapHnswGraph build(FloatVectorValues vectors) throws IOException {
if (vectors == vectorValues) {
throw new IllegalArgumentException(
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
Expand Down Expand Up @@ -230,7 +230,7 @@ private boolean diversityCheck(
float[] candidate,
float score,
Lucene90NeighborArray neighbors,
RandomAccessVectorValues.Floats vectorValues)
FloatVectorValues vectorValues)
throws IOException {
bound.set(score);
for (int i = 0; i < neighbors.size(); i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.SplittableRandom;
Expand All @@ -34,7 +33,6 @@
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.ChecksumIndexInput;
Expand All @@ -44,7 +42,6 @@
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;

/**
* Reads vectors from the index segments along with index data structures supporting KNN search.
Expand Down Expand Up @@ -355,8 +352,7 @@ int size() {
}

/** Read the vector values from the index input. This supports both iterated and random access. */
static class OffHeapFloatVectorValues extends FloatVectorValues
implements RandomAccessVectorValues.Floats {
static class OffHeapFloatVectorValues extends FloatVectorValues {

final int dimension;
final int[] ordToDoc;
Expand All @@ -367,9 +363,6 @@ static class OffHeapFloatVectorValues extends FloatVectorValues
final float[] value;
final VectorSimilarityFunction similarityFunction;

int ord = -1;
int doc = -1;

OffHeapFloatVectorValues(
int dimension,
int[] ordToDoc,
Expand All @@ -394,42 +387,6 @@ public int size() {
return ordToDoc.length;
}

@Override
public float[] vectorValue() throws IOException {
return vectorValue(ord);
}

@Override
public int docID() {
return doc;
}

@Override
public int nextDoc() {
if (++ord >= size()) {
doc = NO_MORE_DOCS;
} else {
doc = ordToDoc[ord];
}
return doc;
}

@Override
public int advance(int target) {
assert docID() < target;
ord = Arrays.binarySearch(ordToDoc, ord + 1, ordToDoc.length, target);
if (ord < 0) {
ord = -(ord + 1);
}
assert ord <= ordToDoc.length;
if (ord == ordToDoc.length) {
doc = NO_MORE_DOCS;
} else {
doc = ordToDoc[ord];
}
return doc;
}

@Override
public OffHeapFloatVectorValues copy() {
return new OffHeapFloatVectorValues(dimension, ordToDoc, similarityFunction, dataIn.clone());
Expand All @@ -446,6 +403,16 @@ public float[] vectorValue(int targetOrd) throws IOException {
return value;
}

@Override
public int ordToDoc(int ord) {
return ordToDoc[ord];
}

@Override
protected DocIndexIterator createIterator() {
return fromOrdToDoc();
}

@Override
public VectorScorer scorer(float[] target) {
if (size() == 0) {
Expand All @@ -455,12 +422,12 @@ public VectorScorer scorer(float[] target) {
return new VectorScorer() {
@Override
public float score() throws IOException {
return values.similarityFunction.compare(values.vectorValue(), target);
return values.similarityFunction.compare(values.vectorValue(iterator().index()), target);
}

@Override
public DocIdSetIterator iterator() {
return values;
public DocIndexIterator iterator() {
return values.iterator();
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
import java.util.ArrayList;
import java.util.List;
import java.util.SplittableRandom;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.SparseFixedBitSet;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;

/**
* An {@link HnswGraph} where all nodes and connections are held in memory. This class is used to
Expand Down Expand Up @@ -74,7 +74,7 @@ public static NeighborQueue search(
float[] query,
int topK,
int numSeed,
RandomAccessVectorValues.Floats vectors,
FloatVectorValues vectors,
VectorSimilarityFunction similarityFunction,
HnswGraph graphValues,
Bits acceptOrds,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;

/**
Expand Down Expand Up @@ -398,8 +397,7 @@ int ordToDoc(int ord) {
}

/** Read the vector values from the index input. This supports both iterated and random access. */
static class OffHeapFloatVectorValues extends FloatVectorValues
implements RandomAccessVectorValues.Floats {
static class OffHeapFloatVectorValues extends FloatVectorValues {

private final int dimension;
private final int size;
Expand All @@ -410,9 +408,6 @@ static class OffHeapFloatVectorValues extends FloatVectorValues
private final float[] value;
private final VectorSimilarityFunction similarityFunction;

private int ord = -1;
private int doc = -1;

OffHeapFloatVectorValues(
int dimension,
int size,
Expand All @@ -439,49 +434,6 @@ public int size() {
return size;
}

@Override
public float[] vectorValue() throws IOException {
dataIn.seek((long) ord * byteSize);
dataIn.readFloats(value, 0, value.length);
return value;
}

@Override
public int docID() {
return doc;
}

@Override
public int nextDoc() {
if (++ord >= size) {
doc = NO_MORE_DOCS;
} else {
doc = ordToDocOperator.applyAsInt(ord);
}
return doc;
}

@Override
public int advance(int target) {
assert docID() < target;

if (ordToDoc == null) {
ord = target;
} else {
ord = Arrays.binarySearch(ordToDoc, ord + 1, ordToDoc.length, target);
if (ord < 0) {
ord = -(ord + 1);
}
}

if (ord < size) {
doc = ordToDocOperator.applyAsInt(ord);
} else {
doc = NO_MORE_DOCS;
}
return doc;
}

@Override
public OffHeapFloatVectorValues copy() {
return new OffHeapFloatVectorValues(
Expand All @@ -495,6 +447,16 @@ public float[] vectorValue(int targetOrd) throws IOException {
return value;
}

@Override
public int ordToDoc(int ord) {
return ordToDocOperator.applyAsInt(ord);
}

@Override
protected DocIndexIterator createIterator() {
return fromOrdToDoc();
}

@Override
public VectorScorer scorer(float[] target) {
if (size == 0) {
Expand All @@ -504,12 +466,13 @@ public VectorScorer scorer(float[] target) {
return new VectorScorer() {
@Override
public float score() throws IOException {
return values.similarityFunction.compare(values.vectorValue(), target);
return values.similarityFunction.compare(
values.vectorValue(values.iterator().index()), target);
}

@Override
public DocIdSetIterator iterator() {
return values;
return values.iterator();
}
};
}
Expand Down
Loading