-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcode_searcher.py
More file actions
139 lines (119 loc) · 6.02 KB
/
code_searcher.py
File metadata and controls
139 lines (119 loc) · 6.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import torch
import numpy as np
import faiss
import pickle
from pathlib import Path
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, BitsAndBytesConfig
class CodeSearcher:
    """Semantic code search over pre-built FAISS indices with LLM-generated answers.

    Embeds a natural-language query with UniXcoder, retrieves the most similar
    file (and, when available, its most similar components) from FAISS indices
    stored in ``index_dir``, and asks an instruction-tuned causal LM to explain
    the retrieved context.
    """

    def __init__(self, index_dir, generation_model_name=None, verbose=True, quantize=False):
        """Load the embedding model, the generation model, and all on-disk indices.

        Args:
            index_dir: Directory holding the FAISS indices and pickled metadata
                produced by the indexing step.
            generation_model_name: HF model id for answer generation; defaults
                to ``codellama/CodeLlama-7b-Instruct-hf``.
            verbose: When True, print diagnostics (e.g. quantization fallback).
            quantize: When True, attempt to load the generation model in 4-bit.
        """
        self.index_dir = Path(index_dir)
        self.verbose = verbose
        # Embedding model used for both indexing-time and query-time vectors.
        self.code_embedding_model_name = "microsoft/unixcoder-base"
        self.code_embedding_tokenizer = AutoTokenizer.from_pretrained(self.code_embedding_model_name)
        self.code_embedding_model = AutoModel.from_pretrained(self.code_embedding_model_name)
        self.generation_model_name = generation_model_name or "codellama/CodeLlama-7b-Instruct-hf"
        self.generation_tokenizer = AutoTokenizer.from_pretrained(self.generation_model_name)
        self.generation_model = self._load_generation_model(quantize)
        self._load_indices()

    def _load_generation_model(self, quantize):
        """Load the causal LM, optionally 4-bit quantized, with a full-precision fallback."""
        if quantize:
            try:
                quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
                return AutoModelForCausalLM.from_pretrained(
                    self.generation_model_name,
                    device_map="auto",
                    quantization_config=quant_config,
                )
            except Exception as exc:
                # bitsandbytes may be missing/unsupported on this platform.
                # Fall back to full precision, but report it rather than
                # silently swallowing the error (the original did).
                if self.verbose:
                    print(f"4-bit quantization unavailable ({exc}); loading full-precision model.")
        return AutoModelForCausalLM.from_pretrained(
            self.generation_model_name,
            device_map="auto",
        )

    def _get_query_embedding(self, query):
        """Return the CLS-token embedding of ``query`` as a (1, hidden) tensor."""
        inputs = self.code_embedding_tokenizer(
            query, return_tensors="pt", truncation=True, padding=True, max_length=512
        )
        with torch.no_grad():
            outputs = self.code_embedding_model(**inputs)
        # CLS position ([:, 0, :]) is used as the sentence-level representation.
        return outputs.last_hidden_state[:, 0, :]

    def _load_indices(self):
        """Load the file-level FAISS index, metadata, and per-file component indices.

        Populates:
            code_file_index: FAISS index over whole-file embeddings.
            file_paths: list of indexed file paths.
            code_file_idx: dict mapping FAISS row id -> file path.
            files_content: dict mapping file path -> raw file text.
            filepath_cembeddings: dict mapping file path -> (component FAISS
                index, component mapping) for files that have one on disk.
        """
        p = self.index_dir
        self.code_file_index = faiss.read_index(str(p / "code_file_index.faiss"))
        self.file_paths = np.load(p / "file_paths.npy", allow_pickle=True).tolist()
        self.code_file_idx = np.load(p / "code_file_idx.npy", allow_pickle=True).item()
        with open(p / "files_content.pkl", 'rb') as f:
            self.files_content = pickle.load(f)
        with open(p / "component_mappings.pkl", 'rb') as f:
            component_mappings = pickle.load(f)
        self.filepath_cembeddings = {}
        for file_path in self.file_paths:
            # Component index filenames flatten path separators to underscores.
            safe_name = file_path.replace("/", "_").replace("\\", "_")
            index_path = p / f"component_{safe_name}.faiss"
            if index_path.exists() and file_path in component_mappings:
                comp_index = faiss.read_index(str(index_path))
                self.filepath_cembeddings[file_path] = (comp_index, component_mappings[file_path])

    def retrieve_relevant_files(self, query_embedding, k=5):
        """Return up to ``k`` indexed files most similar to ``query_embedding``.

        Args:
            query_embedding: float32 array of shape (1, dim).
            k: Maximum number of files to return.

        Returns:
            List of ``{'path', 'similarity'}`` dicts sorted by descending score.
            NOTE(review): the FAISS score is treated as a similarity
            (higher = better), which assumes an inner-product index — confirm
            against the indexing code.
        """
        distances, indices = self.code_file_index.search(query_embedding, k)
        results = []
        for rank, idx in enumerate(indices[0]):
            # FAISS pads missing results with -1; also guard stale row ids.
            if idx < 0 or idx >= len(self.file_paths):
                continue
            file_path = self.code_file_idx.get(idx)
            if file_path and file_path in self.files_content:
                results.append({
                    'path': file_path,
                    'similarity': float(distances[0][rank]),
                })
        return sorted(results, key=lambda r: r['similarity'], reverse=True)

    def retrieve_relevant_components(self, file_path, query_embedding, k=2):
        """Return up to ``k`` components of ``file_path`` most similar to the query.

        Returns an empty list when the file has no component index or no
        components at all (the original called ``search`` with k=0 in that
        case, which FAISS does not handle gracefully).
        """
        if file_path not in self.filepath_cembeddings:
            return []
        index, components = self.filepath_cembeddings[file_path]
        k = min(k, len(components))
        if k <= 0:
            return []
        distances, indices = index.search(query_embedding, k)
        return [
            {'component': components[i], 'similarity': float(distances[0][j])}
            # i >= 0 skips FAISS's -1 padding before the membership test.
            for j, i in enumerate(indices[0]) if i >= 0 and i in components
        ]

    def generate_explanation(self, context, query):
        """Ask the generation model to explain ``context`` with respect to ``query``.

        Returns the model's answer (text after the [/INST] tag), or an error
        message string if generation fails.
        """
        prompt = f"""<s>[INST] I have the following code or text:
{context}
Question: {query}
Please provide a detailed explanation of the code or answer the question.
If the answer to the query does not exist in the context given
then reply "Your question cannot be answered by the given content".[/INST]
"""
        try:
            inputs = self.generation_tokenizer(prompt, return_tensors="pt", truncation=True)
            outputs = self.generation_model.generate(
                inputs["input_ids"].to(self.generation_model.device),
                attention_mask=inputs["attention_mask"].to(self.generation_model.device),
                # max_new_tokens bounds only the generated text; the original's
                # max_length=2048 also counted the prompt, so a long retrieved
                # context could leave no room to generate anything.
                max_new_tokens=1024,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=self.generation_tokenizer.eos_token_id,
            )
            result = self.generation_tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Strip the echoed prompt: keep only the text after the [/INST] tag.
            return result.split("[/INST]")[1].strip() if "[/INST]" in result else result
        except Exception as e:
            return f"An error occurred while generating explanation: {e}"

    def answer_query(self, query):
        """Answer ``query``: embed it, retrieve the best file/components, generate an answer."""
        embedding = self._get_query_embedding(query).squeeze().cpu().numpy().astype('float32')
        norm = np.linalg.norm(embedding)
        if norm > 0:
            # Normalize so inner-product search behaves like cosine similarity;
            # skip on a zero vector to avoid NaN propagation.
            embedding = embedding / norm
        embedding = embedding.reshape(1, -1)
        top_files = self.retrieve_relevant_files(embedding, k=1)
        if not top_files:
            return "No relevant files found."
        file_path = top_files[0]['path']
        components = self.retrieve_relevant_components(file_path, embedding, k=2)
        if components:
            context = "Here are the relevant code components:\n\n"
            for comp in components:
                c = comp['component']
                context += f"--- {c['type']} {c['name']} ---\n{c['code']}\n\n"
        else:
            # No component-level hits: fall back to the whole file's content
            # (membership in files_content is guaranteed by retrieve_relevant_files).
            context = f"File: {file_path}\n\n{self.files_content[file_path]}"
        return self.generate_explanation(context, query)