|
|
| """
|
| Verification script to demonstrate all implemented functionality.
|
| Run this to see layers.py and packing.py in action!
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| from bitlinear import BitLinear, MultiTernaryLinear, convert_linear_to_bitlinear
|
| from bitlinear.packing import (
|
| pack_ternary_base3,
|
| unpack_ternary_base3,
|
| estimate_memory_savings,
|
| )
|
|
|
|
|
def demo_bitlinear():
    """Demonstrate the BitLinear layer: construction, forward pass, conversion.

    Prints the layer's ternary weight/scale shapes, verifies a forward pass
    preserves the expected output shape, and converts an nn.Linear in place.
    """
    print("=" * 70)
    print("1. BitLinear Layer Demo")
    print("=" * 70)

    # Construct a standalone ternary-weight layer and inspect its buffers.
    layer = BitLinear(512, 256, bias=True)
    print("✓ Created BitLinear(512 → 256)")
    print(f" - W_ternary shape: {layer.W_ternary.shape}")
    print(f" - Gamma shape: {layer.gamma.shape}")
    # Ternary weights should only take values from {-1, 0, 1}.
    print(f" - Unique weight values: {sorted(layer.W_ternary.unique().tolist())}")

    # Forward pass: batch of 16 vectors, 512 → 256 features.
    x = torch.randn(16, 512)
    y = layer(x)
    print(f"\n✓ Forward pass: {x.shape} → {y.shape}")

    # Build a BitLinear from an existing dense layer's weights.
    linear = nn.Linear(512, 256)
    bitlinear = BitLinear.from_linear(linear)
    print("✓ Converted nn.Linear to BitLinear")
    print()
|
|
|
|
|
def demo_multi_ternary():
    """Demonstrate MultiTernaryLinear: shapes for several k, and error vs. k.

    Shows that the reconstruction error against a dense nn.Linear decreases
    as the number of ternary terms k grows.
    """
    print("=" * 70)
    print("2. MultiTernaryLinear Layer Demo")
    print("=" * 70)

    # Shapes of the stacked ternary weights and per-term scales for each k.
    for k in [1, 2, 4]:
        layer = MultiTernaryLinear(256, 128, k=k, bias=True)
        print(f"✓ MultiTernaryLinear(256 → 128, k={k})")
        print(f" - W_ternary shape: {layer.W_ternary.shape}")
        print(f" - Gammas shape: {layer.gammas.shape}")

    # Compare each k-term approximation to the dense reference output.
    print("\n✓ Approximation quality test:")
    linear = nn.Linear(128, 128)
    x = torch.randn(8, 128)
    dense_out = linear(x)

    errors = []
    for k in [1, 2, 4]:
        multi = MultiTernaryLinear.from_linear(linear, k=k)
        ternary_out = multi(x)
        error = torch.norm(dense_out - ternary_out).item()
        errors.append(error)
        print(f" - k={k}: reconstruction error = {error:.4f}")

    # More ternary terms should approximate the dense layer strictly better.
    print(f" - Error decreases with k: {errors[0] > errors[1] > errors[2]}")
    print()
|
|
|
|
|
def demo_model_conversion():
    """Demonstrate convert_linear_to_bitlinear on a small MLP.

    Counts Linear layers before and BitLinear layers after conversion,
    then verifies the converted model still runs a forward pass.
    """
    print("=" * 70)
    print("3. Model Conversion Utility Demo")
    print("=" * 70)

    class SimpleModel(nn.Module):
        # Minimal 3-layer MLP used only as a conversion target.
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(128, 256)
            self.relu = nn.ReLU()
            self.fc2 = nn.Linear(256, 128)
            self.fc3 = nn.Linear(128, 10)

        def forward(self, x):
            x = self.relu(self.fc1(x))
            x = self.relu(self.fc2(x))
            return self.fc3(x)

    model = SimpleModel()

    # Count the dense layers that should get converted.
    linear_count = sum(1 for m in model.modules() if isinstance(m, nn.Linear))
    print(f"✓ Original model: {linear_count} Linear layers")

    # inplace=False returns a converted copy, leaving `model` untouched.
    model_converted = convert_linear_to_bitlinear(model, inplace=False)
    bitlinear_count = sum(1 for m in model_converted.modules() if isinstance(m, BitLinear))
    print(f"✓ Converted model: {bitlinear_count} BitLinear layers")

    # Sanity-check the converted model end to end.
    x = torch.randn(4, 128)
    y = model_converted(x)
    print(f"✓ Forward pass works: {x.shape} → {y.shape}")
    print()
|
|
|
|
|
def demo_packing():
    """Demonstrate base-3 packing of ternary weights into uint8.

    Packs a small {-1, 0, 1} matrix, reports the compression ratio,
    and verifies the unpack is a perfect round-trip.
    """
    print("=" * 70)
    print("4. Base-3 Packing Demo")
    print("=" * 70)

    # Small hand-written ternary matrix covering all three values.
    W = torch.tensor([
        [-1, 0, 1, -1, 0],
        [1, 1, -1, 0, 1],
        [0, -1, 1, 1, -1],
    ], dtype=torch.float32)

    print(f"✓ Original ternary weights shape: {W.shape}")
    print(f" - Float32 memory: {W.numel() * 4} bytes")

    # Pack: each uint8 holds several base-3 digits; original shape is
    # returned so it can be restored on unpack.
    packed, original_shape = pack_ternary_base3(W)
    print("\n✓ Packed into uint8 tensor")
    print(f" - Packed shape: {packed.shape}")
    print(f" - Packed memory: {packed.numel()} bytes")
    print(f" - Compression: {W.numel() * 4 / packed.numel():.2f}x")

    # Unpack and confirm the round-trip is lossless.
    W_unpacked = unpack_ternary_base3(packed, original_shape)
    print("\n✓ Unpacked back to ternary")
    print(f" - Unpacked shape: {W_unpacked.shape}")
    print(f" - Perfect round-trip: {torch.allclose(W, W_unpacked)}")
    print()
|
|
|
|
|
def demo_memory_estimation():
    """Demonstrate estimate_memory_savings for transformer-scale configs.

    Reports float32 vs. packed memory and the compression ratio for a
    single FFN layer, BERT-base, and BERT-large sized weight stacks.
    """
    print("=" * 70)
    print("5. Memory Savings Estimation")
    print("=" * 70)

    # (in_dim, out_dim, num_layers, human-readable description)
    configs = [
        (768, 3072, 1, "Single Transformer FFN layer"),
        (768, 3072, 12, "BERT-base (12 layers)"),
        (1024, 4096, 24, "BERT-large (24 layers)"),
    ]

    for in_dim, out_dim, num_layers, description in configs:
        stats = estimate_memory_savings(in_dim, out_dim, num_layers)
        print(f"\n✓ {description}")
        print(f"  Configuration: {in_dim} → {out_dim} × {num_layers} layers")
        print(f"  Float32 memory: {stats['float32_bytes'] / 1e6:.2f} MB")
        print(f"  Packed memory: {stats['packed_bytes'] / 1e6:.2f} MB")
        print(f"  Savings: {stats['savings_bytes'] / 1e6:.2f} MB")
        print(f"  Compression: {stats['compression_ratio']:.2f}x")
    print()
|
|
|
|
|
def main():
    """Run all verification demos in order."""
    print("\n" + "=" * 70)
    print(" BitLinear Implementation Verification")
    print(" All functionality implemented and working!")
    print("=" * 70)
    print()

    demo_bitlinear()
    demo_multi_ternary()
    demo_model_conversion()
    demo_packing()
    demo_memory_estimation()

    print("=" * 70)
    print(" ✓ All implementations verified!")
    print(" ✓ Ready for C++/CUDA optimization")
    print("=" * 70)
    print()
|
|
|
|
|
# Run the full verification suite only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
|
|
|