tjhalanigrid commited on
Commit
d11ee9a
·
1 Parent(s): 862cf68

rlhf is added

Browse files
Files changed (2) hide show
  1. .gitignore +2 -2
  2. src/text2sql_engine.py +17 -4
.gitignore CHANGED
@@ -2,8 +2,8 @@ __pycache__/
2
  *.pyc
3
  .DS_Store
4
  # checkpoints/milestone_before_more_dbs
5
- checkpoints/best_rlhf_codet5_soft
6
- checkpoints/best_rlhf_model
7
  results/
8
  *.png
9
 
 
2
  *.pyc
3
  .DS_Store
4
  # checkpoints/milestone_before_more_dbs
5
+ # checkpoints/best_rlhf_codet5_soft
6
+ # checkpoints/best_rlhf_model
7
  results/
8
  *.png
9
 
src/text2sql_engine.py CHANGED
@@ -56,7 +56,11 @@ class Text2SQLEngine:
56
  self.dml_keywords = r'\b(delete|update|insert|drop|alter|truncate)\b'
57
 
58
  print("Loading base model...")
59
- base = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)
 
 
 
 
60
 
61
  if not use_lora:
62
  self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
@@ -71,7 +75,11 @@ class Text2SQLEngine:
71
 
72
  adapter_path = adapter_path.resolve()
73
 
74
- print("Loading tokenizer and LoRA adapter...")
 
 
 
 
75
 
76
  try:
77
  self.tokenizer = AutoTokenizer.from_pretrained(
@@ -81,7 +89,12 @@ class Text2SQLEngine:
81
  except Exception:
82
  self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
83
 
84
- self.model = PeftModel.from_pretrained(base, str(adapter_path)).to(self.device)
 
 
 
 
 
85
  self.model.eval()
86
 
87
  print("✅ RLHF model ready\n")
@@ -214,4 +227,4 @@ def get_engine():
214
  if _engine is None:
215
  _engine = Text2SQLEngine()
216
 
217
- return _engine
 
56
  self.dml_keywords = r'\b(delete|update|insert|drop|alter|truncate)\b'
57
 
58
  print("Loading base model...")
59
+ # Added tie_word_embeddings=False to silence the warning
60
+ base = AutoModelForSeq2SeqLM.from_pretrained(
61
+ base_model_name,
62
+ tie_word_embeddings=False
63
+ )
64
 
65
  if not use_lora:
66
  self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
 
75
 
76
  adapter_path = adapter_path.resolve()
77
 
78
+ # Sanity check to prevent confusing Hugging Face hub errors
79
+ if not adapter_path.exists():
80
+ raise FileNotFoundError(f"CRITICAL ERROR: Cannot find the model folder at {adapter_path}. It likely did not upload to Hugging Face correctly.")
81
+
82
+ print(f"Loading tokenizer and LoRA adapter from {adapter_path}...")
83
 
84
  try:
85
  self.tokenizer = AutoTokenizer.from_pretrained(
 
89
  except Exception:
90
  self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
91
 
92
+ # Added local_files_only=True to force local loading
93
+ self.model = PeftModel.from_pretrained(
94
+ base,
95
+ str(adapter_path),
96
+ local_files_only=True
97
+ ).to(self.device)
98
  self.model.eval()
99
 
100
  print("✅ RLHF model ready\n")
 
227
  if _engine is None:
228
  _engine = Text2SQLEngine()
229
 
230
+ return _engine