EmbeddingGemma-300M: Width Compressed

This is a compressed version of google/embeddinggemma-300m with 10x width reduction (hidden_size: 768 → 72), keeping all 24 layers.

Model Details

  • Base Model: google/embeddinggemma-300m
  • Compression:
    • Width reduction: 10x (hidden_size: 768 → 72)
    • Depth: All 24 layers preserved
  • Parameters: ~20M (down from ~300M)
  • Output Dimension: 768 (via projection layer)
  • Compression Ratio: ~15x
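The parameter budget can be sanity-checked with a bit of arithmetic. The sketch below assumes the Gemma tokenizer's 262,144-token vocabulary (an assumption, not stated on this card) and shows that at hidden_size 72 the token-embedding table accounts for nearly all of the ~20M total:

```python
# Back-of-envelope parameter estimate for the compressed model.
# Assumption: vocabulary size of 262,144 (Gemma tokenizer); per-layer
# transformer weights are omitted since they are small at hidden_size=72.
vocab_size = 262_144  # assumed, not stated on this card
hidden_size = 72
output_dim = 768

embedding_params = vocab_size * hidden_size                 # token embedding table
projection_params = hidden_size * output_dim + output_dim   # 72 -> 768 head

print(f"Embedding table: {embedding_params / 1e6:.1f}M")    # ~18.9M
print(f"Projection head: {projection_params / 1e3:.1f}K")   # ~56.1K
```

This also explains why width compression yields ~15x fewer parameters overall: the embedding table shrinks linearly with hidden size, while per-layer weights shrink roughly quadratically.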

Installation

pip install torch sentence-transformers transformers

Usage

Simple Usage with SentenceTransformers (Recommended)

The easiest way to use this model is with the sentence-transformers library:

from sentence_transformers import SentenceTransformer
import torch

# Load the model
model = SentenceTransformer(
    "Pieces/embeddinggemma-300m-distilled-width10pct-768dim-best",
    device="cuda" if torch.cuda.is_available() else "cpu"
)

# Encode texts
texts = ["Hello world", "This is a test"]
embeddings = model.encode(texts, convert_to_tensor=True)

print(f"Embeddings shape: {embeddings.shape}")
# Output: torch.Size([2, 768])

Advanced Usage (Full Model Access)

If you need access to the full model structure:

import torch
from pathlib import Path
import sys

# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))

from playground.validate_from_checkpoint import load_trained_model
from tags_model.training.train_distillation import _get_transformer_layers

# Load the model
model, config = load_trained_model(
    checkpoint_path="Pieces/embeddinggemma-300m-distilled-width10pct-768dim-best",
    device="cuda" if torch.cuda.is_available() else "cpu",
    compile_model=False,
)

# Verify structure
transformer = model.backbone.transformer
layers = _get_transformer_layers(transformer)
layer_count = len(layers) if layers else 0
total_params = sum(p.numel() for p in model.parameters())

print("Model structure:")
print(f"  - Layers: {layer_count}")
print(f"  - Parameters: {total_params:,}")
print(f"  - Hidden size: {model.backbone.hidden_size}")

# Encode texts
model.eval()
with torch.no_grad():
    texts = ["Hello world", "This is a test"]
    embeddings = model.backbone.encode_texts(
        texts, max_length=512, return_dict=False
    )

print(f"Embeddings shape: {embeddings.shape}")

Query-Tag Retrieval Example

from sentence_transformers import SentenceTransformer
import torch

# Load model
model = SentenceTransformer(
    "Pieces/embeddinggemma-300m-distilled-width10pct-768dim-best"
)

def compute_similarities(query_embeddings, tag_embeddings):
    """Compute cosine similarities between queries and tags."""
    query_norm = query_embeddings / (query_embeddings.norm(dim=1, keepdim=True) + 1e-8)
    tag_norm = tag_embeddings / (tag_embeddings.norm(dim=1, keepdim=True) + 1e-8)
    return torch.mm(query_norm, tag_norm.t())

# Example queries and tags
queries = [
    "How to implement authentication in a web application?",
    "What are the best practices for database optimization?",
    "How to deploy a machine learning model to production?",
]

tags = [
    "authentication", "security", "web-development",
    "database", "sql", "performance", "optimization",
    "machine-learning", "deployment", "production",
    "api", "backend", "frontend", "javascript", "python",
]

# Encode queries and tags
query_embeddings = model.encode(queries, convert_to_tensor=True)
tag_embeddings = model.encode(tags, convert_to_tensor=True)

print(f"Query embeddings shape: {query_embeddings.shape}")
print(f"Tag embeddings shape: {tag_embeddings.shape}")

# Compute similarities
similarities = compute_similarities(query_embeddings, tag_embeddings)
print(f"Similarities shape: {similarities.shape}")

# Get top tags for each query
for query_idx, query in enumerate(queries):
    top_k = 3
    top_similarities, top_indices = torch.topk(
        similarities[query_idx], k=top_k, dim=0
    )
    
    print(f"\nQuery: {query}")
    for rank, (tag_idx, sim) in enumerate(
        zip(top_indices.cpu().tolist(), top_similarities.cpu().tolist()), start=1
    ):
        print(f"  {rank}. {tags[tag_idx]} (similarity: {sim:.4f})")

See usage_example.py in this repository for a complete standalone example.
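Top-k selection can also be combined with an absolute similarity threshold so that weakly related tags are dropped entirely. A minimal pure-PyTorch sketch, using a toy similarity matrix in place of real model output (the values and tag list below are illustrative only):

```python
import torch

# Toy similarity matrix standing in for model output above
# (3 queries x 5 tags); values are illustrative only.
similarities = torch.tensor([
    [0.82, 0.10, 0.55, 0.03, 0.41],
    [0.12, 0.77, 0.64, 0.20, 0.05],
    [0.30, 0.25, 0.15, 0.91, 0.60],
])
tags = ["authentication", "database", "performance", "deployment", "python"]

threshold = 0.5  # keep only tags at or above this cosine similarity
for q_idx in range(similarities.shape[0]):
    keep = (similarities[q_idx] >= threshold).nonzero(as_tuple=True)[0]
    selected = [tags[i] for i in keep.tolist()]
    print(f"Query {q_idx}: {selected}")
```

A fixed threshold avoids returning k tags when fewer than k are actually relevant; the right cutoff depends on your data and is best tuned on held-out query-tag pairs.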

Model Architecture

  • Transformer Layers: 24 (all preserved)
  • Hidden Size: 72 (reduced from 768)
  • Output Dimension: 768 (via projection layer: 72 → 768)
  • Projection: Linear layer projects from compressed hidden_size to output dimension
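The projection step in the last bullet can be pictured as a single linear map from the compressed hidden size to the output dimension. A minimal sketch, assuming a plain `nn.Linear` head (the actual head in the checkpoint may differ in bias or normalization details):

```python
import torch
import torch.nn as nn

# Sketch of the projection described above: 72-dim pooled hidden
# states are mapped to 768-dim output embeddings.
hidden_size, output_dim = 72, 768
projection = nn.Linear(hidden_size, output_dim)

# Fake pooled hidden states for a batch of 2 texts
pooled = torch.randn(2, hidden_size)
embeddings = projection(pooled)
print(embeddings.shape)  # torch.Size([2, 768])
```

Because the output dimension matches the original model's 768, these embeddings are drop-in compatible with downstream code expecting the full-size model's output shape.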

Performance

This compressed model maintains reasonable performance while being significantly smaller:

  • Parameters: ~19.8M (vs ~300M original, stored as F32 safetensors)
  • Memory: ~80MB (vs ~1.2GB original)
  • Speed: Faster inference due to smaller hidden size
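The memory figure follows directly from the parameter count: ~19.8M float32 parameters at 4 bytes each. A quick check:

```python
# Sanity check on the ~80MB memory figure: 19.8M params * 4 bytes (F32).
params = 19_800_000
bytes_fp32 = params * 4
print(f"{bytes_fp32 / 1e6:.1f} MB")  # 79.2 MB
```

Quantizing to int8 or fp16 would roughly quarter or halve this footprint, at some cost in embedding quality.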

Citation

If you use this model, please cite:

@misc{embeddinggemma-compressed-width,
  title={EmbeddingGemma-300M: Width Compressed},
  author={Pieces},
  year={2024},
  url={https://huggingface.co/Pieces/embeddinggemma-300m-distilled-width10pct-768dim-best}
}