Skip to content

Commit

Permalink
[s3tokenizer] Support online extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
xingchensong committed Sep 18, 2024
1 parent ef156fd commit 516ae39
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,13 @@ class SpeechLLM(nn.Module):
def __init__(self, ...):
...
self.speech_tokenizer = s3tokenizer.load_model("speech_tokenizer_v1")
self.speech_tokenizer.freeze()

def forward(self, speech: Tensor, speech_lens: Tensor, text_ids: Tensor, ...):
...
speech_codes = self.speech_tokenizer(speech, speech_lens)
speech_codes, speech_codes_lens = self.speech_tokenizer.quantize(speech, speech_lens)
speech_codes = speech_codes.clone() # for backward compatbility
speech_codes_lens = speeech_codes_lens.clone() # for backward compatbility
```

</sub>
Expand All @@ -133,4 +136,4 @@ class SpeechLLM(nn.Module):

- [x] Usage-1: Offline batch inference
- [x] Usage-2: Distributed offline batch inference via command-line tools
- [ ] Usage-3: Online speech code extraction
- [x] Usage-3: Online speech code extraction
5 changes: 5 additions & 0 deletions s3tokenizer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,3 +302,8 @@ def init_from_onnx(self, onnx_path: str):
def init_from_pt(self, ckpt_path: str):
ckpt = torch.load(ckpt_path, map_location="cpu", mmap=True)
self.load_state_dict(ckpt, strict=True)

def freeze(self):
for _, param in self.named_parameters():
param.requires_grad = False

0 comments on commit 516ae39

Please sign in to comment.