Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor knn type and codecs #439

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 4 additions & 29 deletions src/main/java/org/opensearch/knn/index/KNNMethod.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

package org.opensearch.knn.index;

import lombok.AllArgsConstructor;
import lombok.Getter;
import org.opensearch.common.ValidationException;
import org.opensearch.knn.common.KNNConstants;

Expand All @@ -26,31 +28,13 @@
* KNNMethod is used to define the structure of a method supported by a particular k-NN library. It is used to validate
* the KNNMethodContext passed in by the user. It is also used to provide superficial string translations.
*/
@AllArgsConstructor
@Getter
public class KNNMethod {

private final MethodComponent methodComponent;
private final Set<SpaceType> spaces;

/**
* KNNMethod Constructor
*
* @param methodComponent top level method component that is compatible with the underlying library
* @param spaces set of valid space types that the method supports
*/
public KNNMethod(MethodComponent methodComponent, Set<SpaceType> spaces) {
this.methodComponent = methodComponent;
this.spaces = spaces;
}

/**
* getMainMethodComponent
*
* @return mainMethodComponent
*/
public MethodComponent getMethodComponent() {
return methodComponent;
}

/**
* Determines whether the provided space is supported for this method
*
Expand All @@ -61,15 +45,6 @@ public boolean containsSpace(SpaceType space) {
return spaces.contains(space);
}

/**
* Get all valid spaces for this method
*
* @return spaces that can be used with this method
*/
public Set<SpaceType> getSpaces() {
return spaces;
}

/**
* Validate that the configured KNNMethodContext is valid for this method
*
Expand Down
48 changes: 4 additions & 44 deletions src/main/java/org/opensearch/knn/index/KNNMethodContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

package org.opensearch.knn.index;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import lombok.AllArgsConstructor;
import lombok.Getter;
import org.opensearch.common.ValidationException;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -40,10 +40,10 @@
* KNNMethodContext will contain the information necessary to produce a library index from an Opensearch mapping.
* It will encompass all parameters necessary to build the index.
*/
@AllArgsConstructor
@Getter
public class KNNMethodContext implements ToXContentFragment, Writeable {

private static final Logger logger = LogManager.getLogger(KNNMethodContext.class);

private static KNNMethodContext defaultInstance = null;

public static synchronized KNNMethodContext getDefault() {
Expand All @@ -61,19 +61,6 @@ public static synchronized KNNMethodContext getDefault() {
private final SpaceType spaceType;
private final MethodComponentContext methodComponent;

/**
* Constructor
*
* @param knnEngine engine that this method uses
* @param spaceType space type that this method uses
* @param methodComponent MethodComponent describing the main index
*/
public KNNMethodContext(KNNEngine knnEngine, SpaceType spaceType, MethodComponentContext methodComponent) {
this.knnEngine = knnEngine;
this.spaceType = spaceType;
this.methodComponent = methodComponent;
}

/**
* Constructor from stream.
*
Expand All @@ -86,33 +73,6 @@ public KNNMethodContext(StreamInput in) throws IOException {
this.methodComponent = new MethodComponentContext(in);
}

/**
* Gets the main method component
*
* @return methodComponent
*/
public MethodComponentContext getMethodComponent() {
return methodComponent;
}

/**
* Gets the engine to be used for this context
*
* @return knnEngine
*/
public KNNEngine getEngine() {
return knnEngine;
}

/**
* Gets the space type for this context
*
* @return spaceType
*/
public SpaceType getSpaceType() {
return spaceType;
}

/**
* This method uses the knnEngine to validate that the method is compatible with the engine
*
Expand Down
27 changes: 13 additions & 14 deletions src/main/java/org/opensearch/knn/index/KNNVectorFieldMapper.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.knn.index;

import lombok.Getter;
import org.opensearch.common.Strings;
import org.opensearch.common.ValidationException;
import org.opensearch.common.xcontent.XContentFactory;
Expand Down Expand Up @@ -206,7 +207,7 @@ public KNNVectorFieldMapper build(BuilderContext context) {
if (knnMethodContext != null) {
return new MethodFieldMapper(
name,
new KNNVectorFieldType(buildFullName(context), meta.getValue(), dimension.getValue()),
new KNNVectorFieldType(buildFullName(context), meta.getValue(), dimension.getValue(), knnMethodContext),
multiFieldsBuilder.build(this, context),
copyTo.build(),
ignoreMalformed(context),
Expand All @@ -225,7 +226,7 @@ public KNNVectorFieldMapper build(BuilderContext context) {

return new ModelFieldMapper(
name,
new KNNVectorFieldType(buildFullName(context), meta.getValue(), -1, modelIdAsString),
new KNNVectorFieldType(buildFullName(context), meta.getValue(), -1, knnMethodContext, modelIdAsString),
multiFieldsBuilder.build(this, context),
copyTo.build(),
ignoreMalformed(context),
Expand Down Expand Up @@ -296,19 +297,25 @@ public Mapper.Builder<?> parse(String name, Map<String, Object> node, ParserCont
}
}

@Getter
public static class KNNVectorFieldType extends MappedFieldType {

int dimension;
String modelId;
KNNMethodContext knnMethodContext;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will changing this impact BWC?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It shouldn't, the only case I can think of is if do older clusters after migration this field is null. I can add extra checks in next PR when we add lucene engine. Can we run any specific bwc test to make sure there are no side-effects?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont think it will be a problem. When mapping reaches new node, the parser will parse it. Since we arent changing the parser, it should be okay. The MappedFieldType is not serialized, so it will be okay.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, exactly


public KNNVectorFieldType(String name, Map<String, String> meta, int dimension) {
this(name, meta, dimension, null);
this(name, meta, dimension, null, null);
}

public KNNVectorFieldType(String name, Map<String, String> meta, int dimension, KNNMethodContext knnMethodContext) {
this(name, meta, dimension, knnMethodContext, null);
}

public KNNVectorFieldType(String name, Map<String, String> meta, int dimension, String modelId) {
public KNNVectorFieldType(String name, Map<String, String> meta, int dimension, KNNMethodContext knnMethodContext, String modelId) {
super(name, false, false, true, TextSearchInfo.NONE, meta);
this.dimension = dimension;
this.modelId = modelId;
this.knnMethodContext = knnMethodContext;
}

@Override
Expand All @@ -334,14 +341,6 @@ public Query termQuery(Object value, QueryShardContext context) {
);
}

public int getDimension() {
return dimension;
}

public String getModelId() {
return modelId;
}

@Override
public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName, Supplier<SearchLookup> searchLookup) {
failIfNoDocValues();
Expand Down Expand Up @@ -623,7 +622,7 @@ private MethodFieldMapper(
this.fieldType.putAttribute(DIMENSION, String.valueOf(dimension));
this.fieldType.putAttribute(SPACE_TYPE, knnMethodContext.getSpaceType().getValue());

KNNEngine knnEngine = knnMethodContext.getEngine();
KNNEngine knnEngine = knnMethodContext.getKnnEngine();
this.fieldType.putAttribute(KNN_ENGINE, knnEngine.getName());

try {
Expand Down
21 changes: 3 additions & 18 deletions src/main/java/org/opensearch/knn/index/MethodComponent.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

package org.opensearch.knn.index;

import lombok.Getter;
import org.opensearch.common.TriFunction;
import org.opensearch.common.ValidationException;
import org.opensearch.knn.common.KNNConstants;
Expand All @@ -26,7 +27,9 @@
*/
public class MethodComponent {

@Getter
private String name;
@Getter
private Map<String, Parameter<?>> parameters;
private BiFunction<MethodComponent, MethodComponentContext, Map<String, Object>> mapGenerator;
private TriFunction<MethodComponent, MethodComponentContext, Integer, Long> overheadInKBEstimator;
Expand All @@ -45,24 +48,6 @@ private MethodComponent(Builder builder) {
this.requiresTraining = builder.requiresTraining;
}

/**
* Get the name of the component
*
* @return name
*/
public String getName() {
return name;
}

/**
* Get the parameters for the component
*
* @return parameters
*/
public Map<String, Parameter<?>> getParameters() {
return parameters;
}

/**
* Parse methodComponentContext into a map that the library can use to configure the method
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

package org.opensearch.knn.index;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import lombok.AllArgsConstructor;
import lombok.Getter;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.io.stream.Writeable;
Expand All @@ -36,24 +36,13 @@
*
* Each component is composed of a name and a map of parameters.
*/
@AllArgsConstructor
public class MethodComponentContext implements ToXContentFragment, Writeable {

private static final Logger logger = LogManager.getLogger(MethodComponentContext.class);

@Getter
private final String name;
private final Map<String, Object> parameters;

/**
* Constructor
*
* @param name component name
* @param parameters component parameters
*/
public MethodComponentContext(String name, Map<String, Object> parameters) {
this.name = name;
this.parameters = parameters;
}

/**
* Constructor from stream.
*
Expand Down Expand Up @@ -183,15 +172,6 @@ public int hashCode() {
return new HashCodeBuilder().append(name).append(parameters).toHashCode();
}

/**
* Gets the name of the component
*
* @return name
*/
public String getName() {
return name;
}

/**
* Gets the parameters of the component
*
Expand Down
23 changes: 13 additions & 10 deletions src/main/java/org/opensearch/knn/index/codec/KNNCodecFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,36 @@
import com.google.common.collect.ImmutableMap;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.backward_codecs.lucene91.Lucene91Codec;
import org.opensearch.knn.index.codec.KNN910Codec.KNN910Codec;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.index.codec.util.CodecBuilder;

import java.lang.reflect.Constructor;
import java.util.Map;

/**
* Factory abstraction for KNN codec
*/
public class KNNCodecFactory {

private static Map<KNNCodecVersion, Class> CODEC_BY_VERSION = ImmutableMap.of(KNNCodecVersion.KNN910, KNN910Codec.class);
private final Map<KNNCodecVersion, CodecBuilder> codecByVersion;

private static KNNCodecVersion LATEST_KNN_CODEC_VERSION = KNNCodecVersion.KNN910;
private static final KNNCodecVersion LATEST_KNN_CODEC_VERSION = KNNCodecVersion.KNN910;

public static Codec createKNNCodec(final Codec userCodec) {
public KNNCodecFactory(MapperService mapperService) {
codecByVersion = ImmutableMap.of(KNNCodecVersion.KNN910, new CodecBuilder.KNN91CodecBuilder(mapperService));
}

public Codec createKNNCodec(final Codec userCodec) {
return getCodec(LATEST_KNN_CODEC_VERSION, userCodec);
}

public static Codec createKNNCodec(final KNNCodecVersion knnCodecVersion, final Codec userCodec) {
public Codec createKNNCodec(final KNNCodecVersion knnCodecVersion, final Codec userCodec) {
return getCodec(knnCodecVersion, userCodec);
}

private static Codec getCodec(final KNNCodecVersion knnCodecVersion, final Codec userCodec) {
private Codec getCodec(final KNNCodecVersion knnCodecVersion, final Codec userCodec) {
try {
Constructor<?> constructor = CODEC_BY_VERSION.getOrDefault(knnCodecVersion, CODEC_BY_VERSION.get(LATEST_KNN_CODEC_VERSION))
.getConstructor(Codec.class);
return (Codec) constructor.newInstance(userCodec);
final CodecBuilder codecBuilder = codecByVersion.getOrDefault(knnCodecVersion, codecByVersion.get(LATEST_KNN_CODEC_VERSION));
return codecBuilder.userCodec(userCodec).build();
} catch (Exception ex) {
throw new RuntimeException("Cannot create instance of KNN codec", ex);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
*/
public class KNNCodecService extends CodecService {

private final KNNCodecFactory knnCodecFactory;

public KNNCodecService(CodecServiceConfig codecServiceConfig) {
super(codecServiceConfig.getMapperService(), codecServiceConfig.getLogger());
knnCodecFactory = new KNNCodecFactory(codecServiceConfig.getMapperService());
}

/**
Expand All @@ -26,6 +29,6 @@ public KNNCodecService(CodecServiceConfig codecServiceConfig) {
*/
@Override
public Codec codec(String name) {
return KNNCodecFactory.createKNNCodec(super.codec(name));
return knnCodecFactory.createKNNCodec(super.codec(name));
}
}
Loading