Spaces:
Sleeping
Sleeping
Commit ·
d11ee9a
1
Parent(s): 862cf68
RLHF is added
Browse files
- .gitignore +2 -2
- src/text2sql_engine.py +17 -4
.gitignore
CHANGED
|
@@ -2,8 +2,8 @@ __pycache__/
|
|
| 2 |
*.pyc
|
| 3 |
.DS_Store
|
| 4 |
# checkpoints/milestone_before_more_dbs
|
| 5 |
-
checkpoints/best_rlhf_codet5_soft
|
| 6 |
-
checkpoints/best_rlhf_model
|
| 7 |
results/
|
| 8 |
*.png
|
| 9 |
|
|
|
|
| 2 |
*.pyc
|
| 3 |
.DS_Store
|
| 4 |
# checkpoints/milestone_before_more_dbs
|
| 5 |
+
# checkpoints/best_rlhf_codet5_soft
|
| 6 |
+
# checkpoints/best_rlhf_model
|
| 7 |
results/
|
| 8 |
*.png
|
| 9 |
|
src/text2sql_engine.py
CHANGED
|
@@ -56,7 +56,11 @@ class Text2SQLEngine:
|
|
| 56 |
self.dml_keywords = r'\b(delete|update|insert|drop|alter|truncate)\b'
|
| 57 |
|
| 58 |
print("Loading base model...")
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
if not use_lora:
|
| 62 |
self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
|
|
@@ -71,7 +75,11 @@ class Text2SQLEngine:
|
|
| 71 |
|
| 72 |
adapter_path = adapter_path.resolve()
|
| 73 |
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
try:
|
| 77 |
self.tokenizer = AutoTokenizer.from_pretrained(
|
|
@@ -81,7 +89,12 @@ class Text2SQLEngine:
|
|
| 81 |
except Exception:
|
| 82 |
self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
|
| 83 |
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
self.model.eval()
|
| 86 |
|
| 87 |
print("✅ RLHF model ready\n")
|
|
@@ -214,4 +227,4 @@ def get_engine():
|
|
| 214 |
if _engine is None:
|
| 215 |
_engine = Text2SQLEngine()
|
| 216 |
|
| 217 |
-
return _engine
|
|
|
|
| 56 |
self.dml_keywords = r'\b(delete|update|insert|drop|alter|truncate)\b'
|
| 57 |
|
| 58 |
print("Loading base model...")
|
| 59 |
+
# Added tie_word_embeddings=False to silence the warning
|
| 60 |
+
base = AutoModelForSeq2SeqLM.from_pretrained(
|
| 61 |
+
base_model_name,
|
| 62 |
+
tie_word_embeddings=False
|
| 63 |
+
)
|
| 64 |
|
| 65 |
if not use_lora:
|
| 66 |
self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
|
|
|
|
| 75 |
|
| 76 |
adapter_path = adapter_path.resolve()
|
| 77 |
|
| 78 |
+
# Sanity check to prevent confusing Hugging Face hub errors
|
| 79 |
+
if not adapter_path.exists():
|
| 80 |
+
raise FileNotFoundError(f"CRITICAL ERROR: Cannot find the model folder at {adapter_path}. It likely did not upload to Hugging Face correctly.")
|
| 81 |
+
|
| 82 |
+
print(f"Loading tokenizer and LoRA adapter from {adapter_path}...")
|
| 83 |
|
| 84 |
try:
|
| 85 |
self.tokenizer = AutoTokenizer.from_pretrained(
|
|
|
|
| 89 |
except Exception:
|
| 90 |
self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
|
| 91 |
|
| 92 |
+
# Added local_files_only=True to force local loading
|
| 93 |
+
self.model = PeftModel.from_pretrained(
|
| 94 |
+
base,
|
| 95 |
+
str(adapter_path),
|
| 96 |
+
local_files_only=True
|
| 97 |
+
).to(self.device)
|
| 98 |
self.model.eval()
|
| 99 |
|
| 100 |
print("✅ RLHF model ready\n")
|
|
|
|
| 227 |
if _engine is None:
|
| 228 |
_engine = Text2SQLEngine()
|
| 229 |
|
| 230 |
+
return _engine
|