Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move mappers to separate files #448

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/main/java/org/opensearch/knn/index/IndexUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.cluster.metadata.MappingMetadata;
import org.opensearch.common.ValidationException;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.apache.lucene.store.FilterDirectory;
import org.opensearch.index.engine.Engine;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
import org.opensearch.knn.index.memory.NativeMemoryEntryContext;
import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.knn.index;

import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.plugin.stats.KNNCounter;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.store.FilterDirectory;
import org.opensearch.knn.index.KNNVectorFieldMapper;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.common.KNNConstants;

import java.io.Closeable;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,19 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index;
package org.opensearch.knn.index.mapper;

import lombok.Getter;
import org.opensearch.common.Strings;
import org.opensearch.common.ValidationException;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.knn.common.KNNConstants;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.search.DocValuesFieldExistsQuery;
import org.apache.lucene.search.Query;
import org.opensearch.common.Explicit;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.ToXContent;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentParser;
Expand All @@ -36,9 +31,11 @@
import org.opensearch.index.mapper.ValueFetcher;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.QueryShardException;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.KNNVectorIndexFieldData;
import org.opensearch.knn.index.VectorField;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.search.aggregations.support.CoreValuesSourceType;
import org.opensearch.search.lookup.SearchLookup;

Expand All @@ -47,17 +44,10 @@
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;

import static org.opensearch.knn.common.KNNConstants.DIMENSION;
import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_CONSTRUCTION;
import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_M;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.KNN_METHOD;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE;

/**
* Field Mapper for KNN vector type.
Expand All @@ -69,8 +59,6 @@
*/
public abstract class KNNVectorFieldMapper extends ParametrizedFieldMapper {

private static Logger logger = LogManager.getLogger(KNNVectorFieldMapper.class);

public static final String CONTENT_TYPE = "knn_vector";
public static final String KNN_FIELD = "knn_field";

Expand Down Expand Up @@ -99,11 +87,13 @@ public static class Builder extends ParametrizedFieldMapper.Builder {
}
int value = XContentMapValues.nodeIntegerValue(o);
if (value > MAX_DIMENSION) {
throw new IllegalArgumentException("Dimension value cannot be greater than " + MAX_DIMENSION + " for vector: " + name);
throw new IllegalArgumentException(
String.format("Dimension value cannot be greater than %s for vector: %s", MAX_DIMENSION, name)
);
}

if (value <= 0) {
throw new IllegalArgumentException("Dimension value must be greater than 0 " + "for vector: " + name);
throw new IllegalArgumentException(String.format("Dimension value must be greater than 0 for vector: %s", name));
}
return value;
}, m -> toType(m).dimension);
Expand Down Expand Up @@ -285,12 +275,12 @@ public Mapper.Builder<?> parse(String name, Map<String, Object> node, ParserCont
// is done before any mappers are built. Therefore, validation should be done during parsing
// so that it can fail early.
if (builder.knnMethodContext.get() != null && builder.modelId.get() != null) {
throw new IllegalArgumentException("Method and model can not be both specified in the mapping: " + name);
throw new IllegalArgumentException(String.format("Method and model can not be both specified in the mapping: %s", name));
}

// Dimension should not be null unless modelId is used
if (builder.dimension.getValue() == -1 && builder.modelId.get() == null) {
throw new IllegalArgumentException("Dimension value missing for vector: " + name);
throw new IllegalArgumentException(String.format("Dimension value missing for vector: %s", name));
}

return builder;
Expand Down Expand Up @@ -337,7 +327,7 @@ public Query existsQuery(QueryShardContext context) {
public Query termQuery(Object value, QueryShardContext context) {
throw new QueryShardException(
context,
"KNN vector do not support exact searching, use KNN queries " + "instead: [" + name() + "]"
String.format("KNN vector do not support exact searching, use KNN queries instead: [%s]", name())
);
}

Expand Down Expand Up @@ -392,16 +382,39 @@ protected void parseCreateField(ParseContext context) throws IOException {

protected void parseCreateField(ParseContext context, int dimension) throws IOException {

if (!KNNSettings.isKNNPluginEnabled()) {
throw new IllegalStateException("KNN plugin is disabled. To enable " + "update knn.plugin.enabled setting to true");
validateIfKNNPluginEnabled();
validateIfCircuitBreakerIsNotTriggered();

Optional<float[]> arrayOptional = getFloatsFromContext(context, dimension);

if (!arrayOptional.isPresent()) {
return;
}
final float[] array = arrayOptional.get();
VectorField point = new VectorField(name(), array, fieldType);

context.doc().add(point);
if (fieldType.stored()) {
context.doc().add(new StoredField(name(), point.toString()));
}
context.path().remove();
}

void validateIfCircuitBreakerIsNotTriggered() {
if (KNNSettings.isCircuitBreakerTriggered()) {
throw new IllegalStateException(
"Indexing knn vector fields is rejected as circuit breaker triggered." + " Check _opendistro/_knn/stats for detailed state"
"Indexing knn vector fields is rejected as circuit breaker triggered. Check _opendistro/_knn/stats for detailed state"
);
}
}

void validateIfKNNPluginEnabled() {
if (!KNNSettings.isKNNPluginEnabled()) {
throw new IllegalStateException("KNN plugin is disabled. To enable update knn.plugin.enabled setting to true");
}
}

Optional<float[]> getFloatsFromContext(ParseContext context, int dimension) throws IOException {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part will be common with Lucene field mapper, extract it to a method to avoid code duplication

context.path().add(simpleName());

ArrayList<Float> vector = new ArrayList<>();
Expand Down Expand Up @@ -438,7 +451,7 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx
context.parser().nextToken();
} else if (token == XContentParser.Token.VALUE_NULL) {
context.path().remove();
return;
return Optional.empty();
}

if (dimension != vector.size()) {
Expand All @@ -451,14 +464,7 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx
for (Float f : vector) {
array[i++] = f;
}

VectorField point = new VectorField(name(), array, fieldType);

context.doc().add(point);
if (fieldType.stored()) {
context.doc().add(new StoredField(name(), point.toString()));
}
context.path().remove();
return Optional.of(array);
}

@Override
Expand Down Expand Up @@ -505,187 +511,4 @@ public static class Defaults {
FIELD_TYPE.freeze();
}
}

/**
* Field mapper for original implementation
*/
protected static class LegacyFieldMapper extends KNNVectorFieldMapper {

protected String spaceType;
protected String m;
protected String efConstruction;

private LegacyFieldMapper(
String simpleName,
KNNVectorFieldType mappedFieldType,
MultiFields multiFields,
CopyTo copyTo,
Explicit<Boolean> ignoreMalformed,
boolean stored,
boolean hasDocValues,
String spaceType,
String m,
String efConstruction
) {
super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues);

this.spaceType = spaceType;
this.m = m;
this.efConstruction = efConstruction;

this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE);

this.fieldType.putAttribute(DIMENSION, String.valueOf(dimension));
this.fieldType.putAttribute(SPACE_TYPE, spaceType);
this.fieldType.putAttribute(KNN_ENGINE, KNNEngine.NMSLIB.getName());

// These are extra just for legacy
this.fieldType.putAttribute(HNSW_ALGO_M, m);
this.fieldType.putAttribute(HNSW_ALGO_EF_CONSTRUCTION, efConstruction);

this.fieldType.freeze();
}

@Override
public ParametrizedFieldMapper.Builder getMergeBuilder() {
return new KNNVectorFieldMapper.Builder(simpleName(), this.spaceType, this.m, this.efConstruction).init(this);
}

static String getSpaceType(Settings indexSettings) {
String spaceType = indexSettings.get(KNNSettings.INDEX_KNN_SPACE_TYPE.getKey());
if (spaceType == null) {
logger.info(
"[KNN] The setting \""
+ METHOD_PARAMETER_SPACE_TYPE
+ "\" was not set for the index. "
+ "Likely caused by recent version upgrade. Setting the setting to the default value="
+ KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE
);
return KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE;
}
return spaceType;
}

static String getM(Settings indexSettings) {
String m = indexSettings.get(KNNSettings.INDEX_KNN_ALGO_PARAM_M_SETTING.getKey());
if (m == null) {
logger.info(
"[KNN] The setting \""
+ HNSW_ALGO_M
+ "\" was not set for the index. "
+ "Likely caused by recent version upgrade. Setting the setting to the default value="
+ KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M
);
return String.valueOf(KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M);
}
return m;
}

static String getEfConstruction(Settings indexSettings) {
String efConstruction = indexSettings.get(KNNSettings.INDEX_KNN_ALGO_PARAM_EF_CONSTRUCTION_SETTING.getKey());
if (efConstruction == null) {
logger.info(
"[KNN] The setting \""
+ HNSW_ALGO_EF_CONSTRUCTION
+ "\" was not set for"
+ " the index. Likely caused by recent version upgrade. Setting the setting to the default value="
+ KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION
);
return String.valueOf(KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION);
}
return efConstruction;
}
}

/**
* Field mapper for method definition in mapping
*/
protected static class MethodFieldMapper extends KNNVectorFieldMapper {

private MethodFieldMapper(
String simpleName,
KNNVectorFieldType mappedFieldType,
MultiFields multiFields,
CopyTo copyTo,
Explicit<Boolean> ignoreMalformed,
boolean stored,
boolean hasDocValues,
KNNMethodContext knnMethodContext
) {

super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues);

this.knnMethod = knnMethodContext;

this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE);

this.fieldType.putAttribute(DIMENSION, String.valueOf(dimension));
this.fieldType.putAttribute(SPACE_TYPE, knnMethodContext.getSpaceType().getValue());

KNNEngine knnEngine = knnMethodContext.getKnnEngine();
this.fieldType.putAttribute(KNN_ENGINE, knnEngine.getName());

try {
this.fieldType.putAttribute(
PARAMETERS,
Strings.toString(XContentFactory.jsonBuilder().map(knnEngine.getMethodAsMap(knnMethodContext)))
);
} catch (IOException ioe) {
throw new RuntimeException("Unable to create KNNVectorFieldMapper: " + ioe);
}

this.fieldType.freeze();
}
}

/**
* Field mapper for model in mapping
*/
protected static class ModelFieldMapper extends KNNVectorFieldMapper {

private ModelFieldMapper(
String simpleName,
KNNVectorFieldType mappedFieldType,
MultiFields multiFields,
CopyTo copyTo,
Explicit<Boolean> ignoreMalformed,
boolean stored,
boolean hasDocValues,
ModelDao modelDao,
String modelId
) {
super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues);

this.modelId = modelId;
this.modelDao = modelDao;

this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE);
this.fieldType.putAttribute(MODEL_ID, modelId);
this.fieldType.freeze();
}

@Override
protected void parseCreateField(ParseContext context) throws IOException {
// For the model field mapper, we cannot validate the model during index creation due to
// an issue with reading cluster state during mapper creation. So, we need to validate the
// model when ingestion starts.
ModelMetadata modelMetadata = this.modelDao.getMetadata(modelId);

if (modelMetadata == null) {
throw new IllegalStateException(
"Model \""
+ modelId
+ "\" from "
+ context.mapperService().index().getName()
+ "'s mapping does not exist. Because the "
+ "\""
+ MODEL_ID
+ "\" parameter is not updateable, this index will need to "
+ "be recreated with a valid model."
);
}

parseCreateField(context, modelMetadata.getDimension());
}
}
}
Loading