RikkaBotan committed on
Commit
8c123d7
·
verified ·
1 Parent(s): 6992684

Upload SSE_quantize.py

Browse files
Files changed (1) hide show
  1. SSE_quantize.py +228 -0
SSE_quantize.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ coding = utf-8
3
+ Copyright 2026 Rikka Botan. All rights reserved
4
+ Licensed under "MIT License"
5
+ Stable Static Embedding official PyTorch implementation
6
+ """
7
+
8
+ from __future__ import annotations
9
+ import os
10
+ from pathlib import Path
11
+ from safetensors.torch import save_file as save_safetensors_file
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import numpy as np
16
+ from typing import Dict
17
+ from dataclasses import dataclass
18
+ from tokenizers import Tokenizer
19
+ from transformers import PreTrainedTokenizerFast
20
+ from sentence_transformers.models.InputModule import InputModule
21
+ from safetensors.torch import load_file
22
+
23
+
24
def quantize_q4_k_m(weight: torch.Tensor) -> dict:
    """Quantize an embedding matrix to packed 4-bit codes with per-row scales.

    Symmetric per-row absmax quantization: each row is normalized by its
    max absolute value, mapped from [-1, 1] to integer codes in [0, 15],
    and two 4-bit codes are packed into each uint8 byte (even column in
    the high nibble, odd column in the low nibble). No zero-point is
    stored — the mapping is symmetric around code 7.5.

    Args:
        weight: (vocab, dim) float tensor; dim must be even so that the
            two nibble halves line up when packing.

    Returns:
        dict with:
            "packed": (vocab, dim // 2) uint8 array of packed codes.
            "scales": (vocab, 1) float32 per-row dequantization scales.

    Raises:
        ValueError: if dim is odd (the original silent broadcast error is
            replaced with an explicit message).
    """
    w = weight.detach().cpu().numpy().astype(np.float32)

    if w.shape[1] % 2 != 0:
        raise ValueError("embedding dim must be even to pack two 4-bit codes per byte")

    # Per-row absmax scale; epsilon guards all-zero rows against division by zero.
    scales = np.max(np.abs(w), axis=1, keepdims=True) + 1e-8
    w_norm = w / scales

    # Map [-1, 1] -> [0, 15]; np.round uses round-half-to-even.
    q = np.clip(np.round((w_norm + 1) * 7.5), 0, 15).astype(np.uint8)

    # Pack 2x4bit -> uint8: even columns to high nibble, odd columns to low.
    packed = (q[:, 0::2] << 4) | q[:, 1::2]

    return {
        "packed": packed,
        "scales": scales.astype(np.float32),
    }
43
+
44
+
45
def dequantize_q4_k_m(packed: np.ndarray, scales: np.ndarray):
    """Invert quantize_q4_k_m: unpack 4-bit nibbles and rescale to floats.

    Args:
        packed: (vocab, dim // 2) uint8 array, two 4-bit codes per byte.
        scales: (vocab, 1) per-row scales produced by the quantizer.

    Returns:
        torch.Tensor of shape (vocab, dim) with the reconstructed weights.
    """
    rows, half = packed.shape
    codes = np.empty((rows, half * 2), dtype=np.uint8)

    # High nibble holds the even columns, low nibble the odd columns.
    codes[:, 0::2] = (packed >> 4) & 0xF
    codes[:, 1::2] = packed & 0xF

    # Map codes [0, 15] back to [-1, 1], then restore per-row magnitude.
    restored = (codes.astype(np.float32) / 7.5 - 1.0) * scales
    return torch.from_numpy(restored)
56
+
57
+
58
class SeparableDyT(nn.Module):
    """Separable Dynamic Tanh (DyT) layer: y = beta * tanh(alpha * x + bias).

    All three parameters are per-channel learnable vectors of length
    hidden_dim, applied elementwise over the last dimension of the input.
    """

    def __init__(
        self,
        hidden_dim: int,
        alpha_init: float = 0.5
    ):
        """
        Args:
            hidden_dim: number of channels (size of the last input dim).
            alpha_init: initial value for the per-channel slope alpha.
        """
        super().__init__()
        self.alpha = nn.Parameter(alpha_init * torch.ones(hidden_dim))
        self.beta = nn.Parameter(torch.ones(hidden_dim))
        self.bias = nn.Parameter(torch.zeros(hidden_dim))

    def forward(
        self,
        x: torch.Tensor
    ) -> torch.Tensor:
        """Apply the scaled tanh; output has the same shape as x."""
        # torch.tanh replaces the deprecated F.tanh alias (same numerics).
        x = self.beta * torch.tanh(self.alpha * x + self.bias)
        return x
75
+
76
+
77
class SSEQ(InputModule):
    """
    Stable Static Embedding (SSE)
    StaticEmbedding-compatible Sentence-Transformers module

    Sentences are embedded by pooling static token embeddings with
    nn.EmbeddingBag (mean pooling by EmbeddingBag's default mode) and
    passing the result through a SeparableDyT activation.  save()/load()
    persist the embedding table in a custom 4-bit format, so the
    round-trip is lossy by design.
    """

    def __init__(
        self,
        tokenizer: Tokenizer | PreTrainedTokenizerFast,
        vocab_size: int,
        hidden_dim: int = 1024,
        **kwargs,
    ):
        """
        Args:
            tokenizer: a Rust-backed tokenizer; a PreTrainedTokenizerFast
                is unwrapped to its underlying tokenizers.Tokenizer.
            vocab_size: number of rows in the embedding table.
            hidden_dim: embedding (and output) dimension.
            **kwargs: optional "base_model" key, stored for model cards.

        Raises:
            ValueError: if the tokenizer is not a fast (Rust) tokenizer.
        """
        super().__init__()

        if isinstance(tokenizer, PreTrainedTokenizerFast):
            tokenizer = tokenizer._tokenizer
        elif not isinstance(tokenizer, Tokenizer):
            raise ValueError("Tokenizer must be a fast (Rust) tokenizer")

        self.tokenizer: Tokenizer = tokenizer
        # Disable padding: tokenize() flattens ids and uses offsets instead.
        self.tokenizer.no_padding()

        self.embedding = nn.EmbeddingBag(vocab_size, hidden_dim)
        self.dyt = SeparableDyT(hidden_dim)

        self.embedding_dim = hidden_dim

        # For model card compatibility
        self.base_model = kwargs.get("base_model", None)

    # Tokenization (StaticEmbedding-compatible)
    def tokenize(
        self,
        texts: list[str],
        **kwargs
    ) -> dict[str, torch.Tensor]:
        """Tokenize texts into the flattened-ids + offsets layout that
        nn.EmbeddingBag consumes.

        Returns:
            dict with "input_ids" (all token ids concatenated, int64) and
            "offsets" (start index of each text within input_ids).
        """
        encodings = self.tokenizer.encode_batch(texts, add_special_tokens=False)
        encodings_ids = [encoding.ids for encoding in encodings]

        # Offset i is the cumulative length of texts 0..i-1 (first is 0).
        offsets = torch.from_numpy(
            np.cumsum(
                [0] + [len(token_ids) for token_ids in encodings_ids[:-1]]
            )
        )
        input_ids = torch.tensor(
            [token_id for token_ids in encodings_ids for token_id in token_ids],
            dtype=torch.long
        )
        return {
            "input_ids": input_ids,
            "offsets": offsets
        }

    # Forward
    def forward(
        self,
        features: Dict[str, torch.Tensor],
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Pool token embeddings per sentence, apply DyT, and store the
        result under "sentence_embedding" (mutates and returns features)."""
        x = self.embedding(features["input_ids"], features["offsets"])
        x = self.dyt(x)
        features["sentence_embedding"] = x
        return features

    # Required APIs
    def get_sentence_embedding_dimension(self) -> int:
        """Dimension of the produced sentence embeddings."""
        return self.embedding_dim

    @property
    def max_seq_length(self) -> float:
        # Static embeddings have no positional limit; torch.inf is a float,
        # so the annotation is float (the original said int).
        return torch.inf

    def save(self, output_path: str):
        """Persist the module to output_path:

        - model_rest.safetensors: all parameters except the embedding table.
        - embedding.q4_k_m.bin: packed 4-bit codes followed by the float32
          per-row scales (layout load() depends on).
        - tokenizer.json: the tokenizer.

        NOTE(review): despite the name, the format is the per-row absmax
        4-bit scheme of quantize_q4_k_m, not GGML's Q4_K_M — confirm before
        advertising compatibility.
        """
        os.makedirs(output_path, exist_ok=True)

        state = self.state_dict()

        emb = state["embedding.weight"]
        q = quantize_q4_k_m(emb)

        # The table is stored separately in quantized form.
        del state["embedding.weight"]

        save_safetensors_file(
            state,
            os.path.join(output_path, "model_rest.safetensors"),
        )

        # Binary layout: all packed bytes, then all scale bytes.
        with open(os.path.join(output_path, "embedding.q4_k_m.bin"), "wb") as f:
            f.write(q["packed"].tobytes())
            f.write(q["scales"].tobytes())

        self.tokenizer.save(
            str(Path(output_path) / "tokenizer.json")
        )

    @classmethod
    def load(cls, model_path: str):
        """Rebuild an SSEQ from a directory written by save().

        The vocab size is inferred from the binary file size: each row
        contributes hidden // 2 packed bytes plus 4 bytes of float32 scale.
        """
        tokenizer = Tokenizer.from_file(
            os.path.join(model_path, "tokenizer.json")
        )

        state = load_file(
            os.path.join(model_path, "model_rest.safetensors"),
            device="cpu"
        )

        # read q4 binary
        bin_path = os.path.join(model_path, "embedding.q4_k_m.bin")
        with open(bin_path, "rb") as f:
            raw = f.read()

        # hidden_dim is recovered from a per-channel DyT parameter.
        hidden = state["dyt.alpha"].shape[0]
        total_uint8 = len(raw)

        # hidden // 2 packed bytes + 4 scale bytes (one float32) per row.
        bytes_per_row = hidden // 2 + 4
        vocab = total_uint8 // bytes_per_row

        packed_size = vocab * hidden // 2

        packed = np.frombuffer(raw[:packed_size], dtype=np.uint8)
        scales = np.frombuffer(raw[packed_size:], dtype=np.float32)

        packed = packed.reshape(vocab, hidden // 2)
        scales = scales.reshape(vocab, 1)

        emb = dequantize_q4_k_m(packed, scales)

        # rebuild model
        model = cls(
            tokenizer=tokenizer,
            vocab_size=emb.shape[0],
            hidden_dim=emb.shape[1]
        )

        # Re-insert the dequantized table so load_state_dict sees all keys.
        state["embedding.weight"] = emb
        model.load_state_dict(state)

        return model
217
+
218
+
219
@dataclass
class SSESforzandoConfig:
    """Default hyper-parameters for the SSE "Sforzando" (512-dim) variant."""

    # Embedding / output dimension of the static embedding table.
    hidden_dim: int = 512
    # Tokenizer vocabulary size.
    vocab_size: int = 30522
223
+
224
+
225
@dataclass
class SSEForzandoConfig:
    """Default hyper-parameters for the SSE "Forzando" (384-dim) variant."""

    # Embedding / output dimension of the static embedding table.
    hidden_dim: int = 384
    # Tokenizer vocabulary size.
    vocab_size: int = 30522