-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
238 lines (197 loc) · 7.86 KB
/
main.py
File metadata and controls
238 lines (197 loc) · 7.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
#!/usr/bin/env python3
"""
Batchman - High-Performance Ollama Batch Processor
Processes input lines through Ollama LLM with parallel workers and detailed progress tracking.
"""
import json
import logging
import sys
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import timedelta
from threading import Lock

import ollama

from config import (
    OLLAMA_MODEL,
    OLLAMA_BASE_URL,
    OLLAMA_CONTEXT,
    OLLAMA_KEEP_ALIVE,
    PROMPT_FILE,
    INPUT_FILE,
    OUTPUT_FILE,
    ERROR_FILE,
    PARALLEL_WORKERS,
    REQUEST_TIMEOUT,
)
# Initialize logging.
# Errors are appended to ERROR_FILE. Messages logged in this module start with
# the 1-based input line number, so the literal "Line " in the format renders
# as "Line <n>: ...". The handler is explicitly UTF-8 because logged messages
# echo input lines and model responses; with the platform default encoding,
# non-ASCII text could fail to encode and the error record would be lost.
_error_log_handler = logging.FileHandler(ERROR_FILE, encoding='utf-8')
logging.basicConfig(
    handlers=[_error_log_handler],
    level=logging.ERROR,
    format='%(asctime)s - Line %(message)s',
)

# Thread-safe lock for progress updates. Not referenced by the code in this
# file (ProgressTracker owns its own lock); left in place in case other
# modules import it.
progress_lock = Lock()
class ProgressTracker:
    """Track completed items and render a single-line progress display.

    ``update`` and ``get_stats`` are serialized by an internal lock, so
    completions may be recorded from multiple threads.
    """

    def __init__(self, total_items):
        """Begin tracking a run of ``total_items`` items.

        Args:
            total_items: Number of items expected; drives percentage and ETA.
        """
        self.total_items = total_items
        self.completed_items = 0
        # Monotonic clock: elapsed time must not jump if the wall clock changes.
        self.start_time = time.monotonic()
        self.response_times = []
        # Running sum so each update's average is O(1) instead of re-summing.
        self._response_time_total = 0.0
        self.lock = Lock()

    def update(self, response_time):
        """Record one completed item and redraw the progress line.

        Args:
            response_time: Seconds the item took to process.
        """
        with self.lock:
            self.completed_items += 1
            self.response_times.append(response_time)
            self._response_time_total += response_time

            # Calculate metrics
            elapsed_time = time.monotonic() - self.start_time
            avg_response_time = self._response_time_total / self.completed_items
            remaining_items = max(self.total_items - self.completed_items, 0)
            # Project from observed wall-clock throughput rather than per-item
            # latency: with several items in flight at once, latency times
            # remaining count overstates the time left by about the
            # concurrency factor.
            eta_seconds = elapsed_time / self.completed_items * remaining_items
            if self.total_items:
                percentage = (self.completed_items / self.total_items) * 100
            else:
                percentage = 100.0

            # Format output
            eta_formatted = str(timedelta(seconds=int(eta_seconds)))
            elapsed_formatted = str(timedelta(seconds=int(elapsed_time)))
            print(f"\r┌─ Progress: {percentage:6.2f}% [{self.completed_items}/{self.total_items}] "
                  f"│ Avg Response: {avg_response_time:.2f}s "
                  f"│ Elapsed: {elapsed_formatted} "
                  f"│ ETA: {eta_formatted} ┐", end='', flush=True)

            if self.completed_items == self.total_items:
                print()  # New line when complete

    def get_stats(self):
        """Return a summary dict of the run so far.

        Returns:
            Dict with ``total_time`` (seconds since construction),
            ``average_response_time`` (0 if nothing completed),
            ``total_items`` and ``items_per_second`` (completed items per
            second of wall-clock time, 0 if no time has elapsed).
        """
        with self.lock:
            total_time = time.monotonic() - self.start_time
            completed = self.completed_items
            avg_time = self._response_time_total / completed if completed else 0
        return {
            'total_time': total_time,
            'average_response_time': avg_time,
            'total_items': self.total_items,
            # Based on items actually completed, so it is meaningful mid-run too.
            'items_per_second': completed / total_time if total_time > 0 else 0,
        }
def load_prompt(prompt_file):
    """Load the prompt template from file.

    Args:
        prompt_file: Path to a UTF-8 text file holding the template.

    Returns:
        The file contents with leading/trailing whitespace stripped.

    Exits the process with status 1 (after reporting on stderr) if the file
    does not exist.
    """
    try:
        with open(prompt_file, 'r', encoding='utf-8') as f:
            return f.read().strip()
    except FileNotFoundError:
        print(f"❌ Error: Prompt file '{prompt_file}' not found.", file=sys.stderr)
        # sys.exit always exists; the builtin exit() is added by the site
        # module for interactive use and may be absent (e.g. `python -S`).
        sys.exit(1)
def load_input_lines(input_file):
    """Load all input lines from file.

    Args:
        input_file: Path to a UTF-8 text file, one input item per line.

    Returns:
        A list with one whitespace-stripped string per physical line. Blank
        lines are kept (as ``""``) so result positions match input line numbers.

    Exits the process with status 1 (after reporting on stderr) if the file
    does not exist.
    """
    try:
        with open(input_file, 'r', encoding='utf-8') as f:
            return [line.strip() for line in f]
    except FileNotFoundError:
        print(f"❌ Error: Input file '{input_file}' not found.", file=sys.stderr)
        # sys.exit always exists; the builtin exit() is a site-module helper
        # intended for interactive sessions.
        sys.exit(1)
# Per-thread storage for a reusable Ollama client (one client per worker thread).
_thread_state = threading.local()


def _get_client():
    """Return the calling thread's Ollama client, creating it on first use.

    Reusing one client per thread avoids building (and never closing) a new
    HTTP client for every line. The client is created with REQUEST_TIMEOUT so
    a stalled server cannot block a worker indefinitely.
    """
    client = getattr(_thread_state, 'client', None)
    if client is None:
        # ollama.Client forwards `timeout` to its underlying httpx client.
        client = ollama.Client(host=OLLAMA_BASE_URL, timeout=REQUEST_TIMEOUT)
        _thread_state.client = client
    return client


def _extract_json_object(text):
    """Return the first decodable JSON object embedded in ``text``.

    LLMs often surround JSON with prose or code fences, so each '{' is tried
    in turn as the start of an object until one decodes.

    Raises:
        json.JSONDecodeError: if no JSON object can be decoded from ``text``.
    """
    decoder = json.JSONDecoder()
    start = text.find('{')
    while start != -1:
        try:
            obj, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            start = text.find('{', start + 1)
            continue
        return obj
    raise json.JSONDecodeError("No JSON object found", text, 0)


def process_single_line(line_index, input_line, prompt_template):
    """
    Process a single line through Ollama LLM.

    Args:
        line_index: The index of the line (0-based)
        input_line: The input text to process
        prompt_template: The prompt template with {INPUT} placeholder

    Returns:
        Tuple of (line_index, result_json, response_time). result_json is None
        when the request or JSON extraction failed; the failure is logged
        (with the 1-based line number) rather than raised.
    """
    start_time = time.monotonic()
    try:
        # Create the full prompt by replacing {INPUT} placeholder
        full_prompt = prompt_template.replace('{INPUT}', input_line)

        response = _get_client().generate(
            model=OLLAMA_MODEL,
            prompt=full_prompt,
            options={
                'num_ctx': OLLAMA_CONTEXT,
            },
            keep_alive=f"{OLLAMA_KEEP_ALIVE}m"
        )
        response_text = response['response'].strip()

        try:
            result_json = _extract_json_object(response_text)
        except json.JSONDecodeError as e:
            # Log JSON parsing error
            error_msg = f"{line_index + 1}: JSON Parse Error - {str(e)}\nInput: {input_line}\nResponse: {response_text}"
            logging.error(error_msg)
            result_json = None

        response_time = time.monotonic() - start_time
        return line_index, result_json, response_time
    except Exception as e:
        # Log general processing error (network, timeout, server error, ...)
        error_msg = f"{line_index + 1}: Processing Error - {str(e)}\nInput: {input_line}"
        logging.error(error_msg)
        response_time = time.monotonic() - start_time
        return line_index, None, response_time
def _print_banner():
    """Print the run-configuration banner."""
    print("=" * 80)
    print("🚀 BATCHMAN - Ollama Batch Processor")
    print("=" * 80)
    print(f"📋 Model: {OLLAMA_MODEL}")
    print(f"🔗 Server: {OLLAMA_BASE_URL}")
    print(f"👷 Workers: {PARALLEL_WORKERS}")
    print(f"⏱️ Timeout: {REQUEST_TIMEOUT}s")
    print("=" * 80)


def _write_results(results, output_file):
    """Write one JSON line per input line, in input order.

    ``None`` results (failed lines) are written as ``{}`` so that output
    line N always corresponds to input line N.
    """
    with open(output_file, 'w', encoding='utf-8') as f:
        for result in results:
            if result is not None:
                f.write(json.dumps(result, ensure_ascii=False) + '\n')
            else:
                f.write(json.dumps({}) + '\n')


def _print_summary(stats, failed_count):
    """Print final statistics, including how many lines failed."""
    print("=" * 80)
    print("✅ PROCESSING COMPLETE!")
    print("=" * 80)
    print(f"📊 Total Items: {stats['total_items']}")
    print(f"❌ Failed Items: {failed_count}")
    print(f"⏱️ Total Time: {timedelta(seconds=int(stats['total_time']))}")
    print(f"📈 Average Response Time: {stats['average_response_time']:.2f}s")
    print(f"🚀 Throughput: {stats['items_per_second']:.2f} items/second")
    print("=" * 80)
    print(f"📁 Output saved to: {OUTPUT_FILE}")
    print(f"⚠️ Errors logged to: {ERROR_FILE}")
    print("=" * 80)


def main():
    """Load inputs, process them in parallel, write results and report stats."""
    _print_banner()

    # Load prompt and input
    print("📖 Loading prompt and input...")
    prompt_template = load_prompt(PROMPT_FILE)
    input_lines = load_input_lines(INPUT_FILE)
    total_lines = len(input_lines)
    print(f"✅ Loaded {total_lines} lines to process")
    print("=" * 80)

    # In-memory results indexed by input line; None marks a failed line.
    results = [None] * total_lines
    tracker = ProgressTracker(total_lines)

    print(f"⚡ Processing with {PARALLEL_WORKERS} parallel workers...\n")
    with ThreadPoolExecutor(max_workers=PARALLEL_WORKERS) as executor:
        futures = [
            executor.submit(process_single_line, idx, line, prompt_template)
            for idx, line in enumerate(input_lines)
        ]
        try:
            # Collect results as they complete; each carries its own index.
            for future in as_completed(futures):
                line_idx, result, response_time = future.result()
                results[line_idx] = result
                tracker.update(response_time)
        except KeyboardInterrupt:
            # Leaving the `with` block waits for all queued work; cancel the
            # not-yet-started lines so Ctrl+C stops after in-flight requests.
            for future in futures:
                future.cancel()
            raise

    print("\n💾 Writing results to output file...")
    _write_results(results, OUTPUT_FILE)

    failed_count = sum(1 for result in results if result is None)
    _print_summary(tracker.get_stats(), failed_count)


if __name__ == "__main__":
    main()