From 516ae39bcf42bea3f9fc14f5a9009644dd6b9099 Mon Sep 17 00:00:00 2001
From: xingchensong
Date: Wed, 18 Sep 2024 20:29:16 +0800
Subject: [PATCH] [s3tokenizer] Support online extraction

---
 README.md            | 7 +++++--
 s3tokenizer/model.py | 5 +++++
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index f4363c4..76b81ac 100644
--- a/README.md
+++ b/README.md
@@ -117,10 +117,13 @@ class SpeechLLM(nn.Module):
     def __init__(self, ...):
         ...
         self.speech_tokenizer = s3tokenizer.load_model("speech_tokenizer_v1")
+        self.speech_tokenizer.freeze()
 
     def forward(self, speech: Tensor, speech_lens: Tensor, text_ids: Tensor, ...):
         ...
-        speech_codes = self.speech_tokenizer(speech, speech_lens)
+        speech_codes, speech_codes_lens = self.speech_tokenizer.quantize(speech, speech_lens)
+        speech_codes = speech_codes.clone() # for backward compatibility
+        speech_codes_lens = speech_codes_lens.clone() # for backward compatibility
 ```
 
 
@@ -133,4 +136,4 @@ class SpeechLLM(nn.Module):
 
 - [x] Usage-1: Offline batch inference
 - [x] Usage-2: Distributed offline batch inference via command-line tools
-- [ ] Usage-3: Online speech code extraction
+- [x] Usage-3: Online speech code extraction
diff --git a/s3tokenizer/model.py b/s3tokenizer/model.py
index bc293a3..d911ba0 100644
--- a/s3tokenizer/model.py
+++ b/s3tokenizer/model.py
@@ -302,3 +302,8 @@ def init_from_onnx(self, onnx_path: str):
     def init_from_pt(self, ckpt_path: str):
         ckpt = torch.load(ckpt_path, map_location="cpu", mmap=True)
         self.load_state_dict(ckpt, strict=True)
+
+    def freeze(self):
+        for _, param in self.named_parameters():
+            param.requires_grad = False
+