Skip to content

Commit

Permalink
config pooling and l2norm for autoencoder (#1722)
Browse files Browse the repository at this point in the history
* config pooling and l2norm for autoencoder
* config prefix for query and document encoder
  • Loading branch information
MXueguang committed Nov 29, 2023
1 parent e6700f6 commit 723e06c
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 15 deletions.
17 changes: 8 additions & 9 deletions pyserini/encode/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,8 @@
"openai-api": OpenAIDocumentEncoder,
"auto": AutoDocumentEncoder,
}
ALLOWED_POOLING_OPTS = ["cls","mean"]

def init_encoder(encoder, encoder_class, device):
def init_encoder(encoder, encoder_class, device, pooling, l2_norm, prefix):
_encoder_class = encoder_class

# determine encoder_class
Expand All @@ -52,6 +51,7 @@ def init_encoder(encoder, encoder_class, device):
# if none of the class keyword was matched,
# use the AutoDocumentEncoder
if encoder_class is None:
_encoder_class = "auto"
encoder_class = AutoDocumentEncoder

# prepare arguments to encoder class
Expand All @@ -60,6 +60,8 @@ def init_encoder(encoder, encoder_class, device):
kwargs.update(dict(pooling='mean', l2_norm=True))
if (_encoder_class == "contriever") or ("contriever" in encoder):
kwargs.update(dict(pooling='mean', l2_norm=False))
if (_encoder_class == "auto"):
kwargs.update(dict(pooling=pooling, l2_norm=l2_norm, prefix=prefix))
return encoder_class(**kwargs)


Expand Down Expand Up @@ -117,19 +119,16 @@ def parse_args(parser, commands):
default='cuda:0', required=False)
encoder_parser.add_argument('--fp16', action='store_true', default=False)
encoder_parser.add_argument('--add-sep', action='store_true', default=False)
encoder_parser.add_argument('--pooling', type=str, default='cls', help='for auto classes, allow the ability to dictate pooling strategy', required=False)
encoder_parser.add_argument('--pooling', type=str, default='cls', help='for auto classes, allow the ability to dictate pooling strategy', choices=['cls', 'mean'], required=False)
encoder_parser.add_argument('--l2-norm', action='store_true', help='whether to normalize embedding', default=False, required=False)
encoder_parser.add_argument('--prefix', type=str, help='prefix of document input', default=None, required=False)
encoder_parser.add_argument('--use-openai', help='use OpenAI text-embedding-ada-002 to retreive embeddings', action='store_true', default=False)
encoder_parser.add_argument('--rate-limit', type=int, help='rate limit of the requests per minute for OpenAI embeddings', default=3500, required=False)

args = parse_args(parser, commands)
delimiter = args.input.delimiter.replace("\\n", "\n") # argparse would add \ prior to the passed '\n\n'

encoder = init_encoder(args.encoder.encoder, args.encoder.encoder_class, device=args.encoder.device)
if type(encoder).__name__ == "AutoDocumentEncoder":
if args.encoder.pooling in ALLOWED_POOLING_OPTS:
encoder.pooling = args.encoder.pooling
else:
raise ValueError(f"Only allowed to use pooling types {ALLOWED_POOLING_OPTS}. You entered {args.encoder.pooling}")
encoder = init_encoder(args.encoder.encoder, args.encoder.encoder_class, device=args.encoder.device, pooling=args.encoder.pooling, l2_norm=args.encoder.l2_norm, prefix=args.encoder.prefix)
if args.output.to_faiss:
embedding_writer = FaissRepresentationWriter(args.output.embeddings, dimension=args.encoder.dimension)
else:
Expand Down
5 changes: 4 additions & 1 deletion pyserini/encode/_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


class AutoDocumentEncoder(DocumentEncoder):
def __init__(self, model_name, tokenizer_name=None, device='cuda:0', pooling='cls', l2_norm=False):
def __init__(self, model_name, tokenizer_name=None, device='cuda:0', pooling='cls', l2_norm=False, prefix=None):
self.device = device
self.model = AutoModel.from_pretrained(model_name)
self.model.to(self.device)
Expand All @@ -33,8 +33,11 @@ def __init__(self, model_name, tokenizer_name=None, device='cuda:0', pooling='cl
self.has_model = True
self.pooling = pooling
self.l2_norm = l2_norm
self.prefix = prefix

def encode(self, texts, titles=None, max_length=256, add_sep=False, **kwargs):
if self.prefix is not None:
texts = [f'{self.prefix} {text}' for text in texts]
shared_tokenizer_kwargs = dict(
max_length=max_length,
truncation=True,
Expand Down
11 changes: 8 additions & 3 deletions pyserini/encode/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from pyserini.encode import UniCoilQueryEncoder, SpladeQueryEncoder, OpenAIQueryEncoder


def init_encoder(encoder, device):
def init_encoder(encoder, device, pooling, l2_norm, prefix):
if 'dpr' in encoder.lower():
return DprQueryEncoder(encoder, device=device)
elif 'tct' in encoder.lower():
Expand All @@ -40,7 +40,7 @@ def init_encoder(encoder, device):
elif 'openai-api' in encoder.lower():
return OpenAIQueryEncoder()
else:
return AutoQueryEncoder(encoder, device=device)
return AutoQueryEncoder(encoder, device=device, pooling=pooling, l2_norm=l2_norm, prefix=prefix)


if __name__ == '__main__':
Expand All @@ -54,9 +54,14 @@ def init_encoder(encoder, device):
parser.add_argument('--device', type=str, help='device cpu or cuda [cuda:0, cuda:1...]',
default='cpu', required=False)
parser.add_argument('--max-length', type=int, help='max length', default=256, required=False)
parser.add_argument('--pooling', type=str, help='pooling strategy', default='cls', choices=['cls', 'mean'],
required=False)
parser.add_argument('--l2-norm', action='store_true', help='whether to normalize embedding', default=False,
required=False)
parser.add_argument('--prefx', type=str, help='prefix query input', default=None, required=False)
args = parser.parse_args()

encoder = init_encoder(args.encoder, device=args.device)
encoder = init_encoder(args.encoder, device=args.device, pooling=args.pooling, l2_norm=args.l2_norm, prefix=args.prefx)
query_iterator = DefaultQueryIterator.from_topics(args.topics)

is_sparse = False
Expand Down
12 changes: 10 additions & 2 deletions pyserini/search/faiss/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ def define_dsearch_args(parser):
parser.add_argument('--encoder', type=str, metavar='path to query encoder checkpoint or encoder name',
required=False,
help="Path to query encoder pytorch checkpoint or hgf encoder model name")
parser.add_argument('--pooling', type=str, metavar='pooling strategy', required=False, default='cls',
choices=['cls', 'mean'],
help="Pooling strategy for query encoder")
parser.add_argument('--l2-norm', action='store_true', help='whether to normalize embedding', default=False,
required=False)
parser.add_argument('--tokenizer', type=str, metavar='name or path',
required=False,
help="Path to a hgf tokenizer name or path")
Expand Down Expand Up @@ -85,7 +90,7 @@ def define_dsearch_args(parser):
help="Set efSearch for HNSW index")


def init_query_encoder(encoder, encoder_class, tokenizer_name, topics_name, encoded_queries, device, prefix, max_length):
def init_query_encoder(encoder, encoder_class, tokenizer_name, topics_name, encoded_queries, device, max_length, pooling, l2_norm, prefix):
encoded_queries_map = {
'msmarco-passage-dev-subset': 'tct_colbert-msmarco-passage-dev-subset',
'dpr-nq-dev': 'dpr_multi-nq-dev',
Expand Down Expand Up @@ -126,6 +131,7 @@ def init_query_encoder(encoder, encoder_class, tokenizer_name, topics_name, enco
# if none of the class keyword was matched,
# use the AutoQueryEncoder
if encoder_class is None:
_encoder_class = "auto"
encoder_class = AutoQueryEncoder

# prepare arguments to encoder class
Expand All @@ -136,6 +142,8 @@ def init_query_encoder(encoder, encoder_class, tokenizer_name, topics_name, enco
kwargs.update(dict(pooling='mean', l2_norm=False))
if (_encoder_class == "openai-api") or ("openai" in encoder):
kwargs.update(dict(max_length=max_length))
if (_encoder_class == "auto"):
kwargs.update(dict(pooling=pooling, l2_norm=l2_norm, prefix=prefix))
return encoder_class(**kwargs)

if encoded_queries:
Expand Down Expand Up @@ -188,7 +196,7 @@ def init_query_encoder(encoder, encoder_class, tokenizer_name, topics_name, enco
topics = query_iterator.topics

query_encoder = init_query_encoder(
args.encoder, args.encoder_class, args.tokenizer, args.topics, args.encoded_queries, args.device, args.query_prefix, args.max_length)
args.encoder, args.encoder_class, args.tokenizer, args.topics, args.encoded_queries, args.device, args.max_length, args.pooling, args.l2_norm, args.query_prefix)
if args.pca_model:
query_encoder = PcaEncoder(query_encoder, args.pca_model)
kwargs = {}
Expand Down

0 comments on commit 723e06c

Please sign in to comment.