diff --git a/doc/GENAI.md b/doc/GENAI.md new file mode 100644 index 0000000000..66d5218a4c --- /dev/null +++ b/doc/GENAI.md @@ -0,0 +1,490 @@ +# GenAI Module Documentation + +## Overview + +The **GenAI (Generative AI) Module** in ProxySQL provides asynchronous, non-blocking access to embedding generation and document reranking services. It enables ProxySQL to interact with LLM services (like llama-server) for vector embeddings and semantic search operations without blocking MySQL threads. + +## Version + +- **Module Version**: 0.1.0 +- **Last Updated**: 2025-01-10 +- **Branch**: v3.1-vec_genAI_module + +## Architecture + +### Async Design + +The GenAI module uses a **non-blocking async architecture** based on socketpair IPC and epoll event notification: + +``` +┌─────────────────┐ socketpair ┌─────────────────┐ +│ MySQL_Session │◄────────────────────────────►│ GenAI Module │ +│ (MySQL Thread) │ fds[0] fds[1] │ Listener Loop │ +└────────┬────────┘ └────────┬────────┘ + │ │ + │ epoll │ queue + │ │ + └── epoll_wait() ────────────────────────────────┘ + (GenAI Response Ready) +``` + +### Key Components + +1. **MySQL_Session** - Client-facing interface that receives GENAI: queries +2. **GenAI Listener Thread** - Monitors socketpair fds via epoll for incoming requests +3. **GenAI Worker Threads** - Thread pool that processes requests (blocking HTTP calls) +4. 
**Socketpair Communication** - Bidirectional IPC between MySQL and GenAI modules + +### Communication Protocol + +#### Request Format (MySQL → GenAI) + +```c +struct GenAI_RequestHeader { + uint64_t request_id; // Client's correlation ID + uint32_t operation; // GENAI_OP_EMBEDDING, GENAI_OP_RERANK, or GENAI_OP_JSON + uint32_t query_len; // Length of JSON query that follows + uint32_t flags; // Reserved (must be 0) + uint32_t top_n; // For rerank: max results (0 = all) +}; +// Followed by: JSON query (query_len bytes) +``` + +#### Response Format (GenAI → MySQL) + +```c +struct GenAI_ResponseHeader { + uint64_t request_id; // Echo of client's request ID + uint32_t status_code; // 0 = success, >0 = error + uint32_t result_len; // Length of JSON result that follows + uint32_t processing_time_ms;// Time taken by GenAI worker + uint64_t result_ptr; // Reserved (must be 0) + uint32_t result_count; // Number of results + uint32_t reserved; // Reserved (must be 0) +}; +// Followed by: JSON result (result_len bytes) +``` + +## Configuration Variables + +### Thread Configuration + +| Variable | Type | Default | Description | +|----------|------|---------|-------------| +| `genai-threads` | int | 4 | Number of GenAI worker threads (1-256) | + +### Service Endpoints + +| Variable | Type | Default | Description | +|----------|------|---------|-------------| +| `genai-embedding_uri` | string | `http://127.0.0.1:8013/embedding` | Embedding service endpoint | +| `genai-rerank_uri` | string | `http://127.0.0.1:8012/rerank` | Reranking service endpoint | + +### Timeouts + +| Variable | Type | Default | Description | +|----------|------|---------|-------------| +| `genai-embedding_timeout_ms` | int | 30000 | Embedding request timeout (100-300000ms) | +| `genai-rerank_timeout_ms` | int | 30000 | Reranking request timeout (100-300000ms) | + +### Admin Commands + +```sql +-- Load/Save GenAI variables +LOAD GENAI VARIABLES TO RUNTIME; +SAVE GENAI VARIABLES FROM RUNTIME; +LOAD GENAI 
VARIABLES FROM DISK; +SAVE GENAI VARIABLES TO DISK; + +-- Set variables +SET genai-threads = 8; +SET genai-embedding_uri = 'http://localhost:8080/embed'; +SET genai-rerank_uri = 'http://localhost:8081/rerank'; + +-- View variables +SELECT @@genai-threads; +SHOW VARIABLES LIKE 'genai-%'; + +-- Checksum +CHECKSUM GENAI VARIABLES; +``` + +## Query Syntax + +### GENAI: Query Format + +GenAI queries use the special `GENAI:` prefix followed by JSON: + +```sql +GENAI: {"type": "embed", "documents": ["text1", "text2"]} +GENAI: {"type": "rerank", "query": "search text", "documents": ["doc1", "doc2"]} +``` + +### Supported Operations + +#### 1. Embedding + +Generate vector embeddings for documents: + +```sql +GENAI: { + "type": "embed", + "documents": [ + "Machine learning is a subset of AI.", + "Deep learning uses neural networks." + ] +} +``` + +**Response:** +``` ++------------------------------------------+ +| embedding | ++------------------------------------------+ +| 0.123, -0.456, 0.789, ... | +| 0.234, -0.567, 0.890, ... | ++------------------------------------------+ +``` + +#### 2. Reranking + +Rerank documents by relevance to a query: + +```sql +GENAI: { + "type": "rerank", + "query": "What is machine learning?", + "documents": [ + "Machine learning is a subset of artificial intelligence.", + "The capital of France is Paris.", + "Deep learning uses neural networks." + ], + "top_n": 2, + "columns": 3 +} +``` + +**Parameters:** +- `query` (required): Search query text +- `documents` (required): Array of documents to rerank +- `top_n` (optional): Maximum results to return (0 = all, default: all) +- `columns` (optional): 2 = {index, score}, 3 = {index, score, document} (default: 3) + +**Response:** +``` ++-------+-------+----------------------------------------------+ +| index | score | document | ++-------+-------+----------------------------------------------+ +| 0 | 0.95 | Machine learning is a subset of AI... | +| 2 | 0.82 | Deep learning uses neural networks... 
| ++-------+-------+----------------------------------------------+ +``` + +### Response Format + +All GenAI queries return results in MySQL resultset format with: +- `columns`: Array of column names +- `rows`: Array of row data + +**Success:** +```json +{ + "columns": ["index", "score", "document"], + "rows": [ + [0, 0.95, "Most relevant document"], + [2, 0.82, "Second most relevant"] + ] +} +``` + +**Error:** +```json +{ + "error": "Error message describing what went wrong" +} +``` + +## Usage Examples + +### Basic Embedding + +```sql +-- Generate embedding for a single document +GENAI: {"type": "embed", "documents": ["Hello, world!"]}; + +-- Batch embedding for multiple documents +GENAI: { + "type": "embed", + "documents": ["doc1", "doc2", "doc3"] +}; +``` + +### Basic Reranking + +```sql +-- Find most relevant documents +GENAI: { + "type": "rerank", + "query": "database optimization techniques", + "documents": [ + "How to bake a cake", + "Indexing strategies for MySQL", + "Python programming basics", + "Query optimization in ProxySQL" + ] +}; +``` + +### Top N Results + +```sql +-- Get only top 3 most relevant documents +GENAI: { + "type": "rerank", + "query": "best practices for SQL", + "documents": ["doc1", "doc2", "doc3", "doc4", "doc5"], + "top_n": 3 +}; +``` + +### Index and Score Only + +```sql +-- Get only relevance scores (no document text) +GENAI: { + "type": "rerank", + "query": "test query", + "documents": ["doc1", "doc2"], + "columns": 2 +}; +``` + +## Integration with ProxySQL + +### Session Lifecycle + +1. **Session Start**: MySQL session creates `genai_epoll_fd_` for monitoring GenAI responses +2. **Query Received**: `GENAI:` query detected in `handler___status_WAITING_CLIENT_DATA___STATE_SLEEP()` +3. **Async Send**: Socketpair created, request sent, returns immediately +4. **Main Loop**: `check_genai_events()` called on each iteration +5. **Response Ready**: `handle_genai_response()` processes response +6. 
**Result Sent**: MySQL result packet sent to client +7. **Cleanup**: Socketpair closed, resources freed + +### Main Loop Integration + +The GenAI event checking is integrated into the main MySQL handler loop: + +```cpp +handler_again: + switch (status) { + case WAITING_CLIENT_DATA: + handler___status_WAITING_CLIENT_DATA(); +#ifdef epoll_create1 + // Check for GenAI responses before processing new client data + if (check_genai_events()) { + goto handler_again; // Process more responses + } +#endif + break; + } +``` + +## Backend Services + +### llama-server Integration + +The GenAI module is designed to work with [llama-server](https://github.com/ggerganov/llama.cpp), a high-performance C++ inference server for LLaMA models. + +#### Starting llama-server + +```bash +# Start embedding server +./llama-server \ + --model /path/to/nomic-embed-text-v1.5.gguf \ + --port 8013 \ + --embedding \ + --ctx-size 512 + +# Start reranking server (using same model) +./llama-server \ + --model /path/to/nomic-embed-text-v1.5.gguf \ + --port 8012 \ + --ctx-size 512 +``` + +#### API Compatibility + +The GenAI module expects: +- **Embedding endpoint**: `POST /embedding` with JSON request +- **Rerank endpoint**: `POST /rerank` with JSON request + +Compatible with: +- llama-server +- OpenAI-compatible embedding APIs +- Custom services with matching request/response format + +## Testing + +### TAP Test Suite + +Comprehensive TAP tests are available in `test/tap/tests/genai_async-t.cpp`: + +```bash +cd test/tap/tests +make genai_async-t +./genai_async-t +``` + +**Test Coverage:** +- Single async requests +- Sequential requests (embedding and rerank) +- Batch requests (10+ documents) +- Mixed embedding and rerank +- Request/response matching +- Error handling (invalid JSON, missing fields) +- Special characters (quotes, unicode, etc.) 
+- Large documents (5KB+) +- `top_n` and `columns` parameters +- Concurrent connections + +### Manual Testing + +```sql +-- Test embedding +mysql> GENAI: {"type": "embed", "documents": ["test document"]}; + +-- Test reranking +mysql> GENAI: { + -> "type": "rerank", + -> "query": "test query", + -> "documents": ["doc1", "doc2", "doc3"] + -> }; +``` + +## Performance Characteristics + +### Non-Blocking Behavior + +- **MySQL threads**: Return immediately after sending request (~1ms) +- **GenAI workers**: Handle blocking HTTP calls (10-100ms typical) +- **Throughput**: Limited by GenAI service capacity and worker thread count + +### Resource Usage + +- **Per request**: 1 socketpair (2 file descriptors) +- **Memory**: Request metadata + pending response storage +- **Worker threads**: Configurable via `genai-threads` (default: 4) + +### Scalability + +- **Concurrent requests**: Limited by `genai-threads` and GenAI service capacity +- **Request queue**: Unlimited (pending requests stored in session map) +- **Recommended**: Set `genai-threads` to match expected concurrency + +## Error Handling + +### Common Errors + +| Error | Cause | Solution | +|-------|-------|----------| +| `Failed to create GenAI communication channel` | Socketpair creation failed | Check system limits (ulimit -n) | +| `Failed to register with GenAI module` | GenAI module not initialized | Run `LOAD GENAI VARIABLES TO RUNTIME` | +| `Failed to send request to GenAI module` | Write error on socketpair | Check connection stability | +| `GenAI module not initialized` | GenAI threads not started | Set `genai-threads > 0` and reload | + +### Timeout Handling + +Requests exceeding `genai-embedding_timeout_ms` or `genai-rerank_timeout_ms` will fail with: +- Status code > 0 in response header +- Error message in JSON result +- Socketpair cleanup + +## Monitoring + +### Status Variables + +```sql +-- Check GenAI module status (not yet implemented, planned) +SHOW STATUS LIKE 'genai-%'; +``` + +**Planned status 
variables:** +- `genai_threads_initialized`: Number of worker threads running +- `genai_active_requests`: Currently processing requests +- `genai_completed_requests`: Total successful requests +- `genai_failed_requests`: Total failed requests + +### Logging + +GenAI operations log at debug level: + +```bash +# Enable GenAI debug logging +SET mysql-debug = 1; + +# Check logs +tail -f proxysql.log | grep GenAI +``` + +## Limitations + +### Current Limitations + +1. **document_from_sql**: Not yet implemented (requires MySQL connection handling in workers) +2. **Shared memory**: Result pointer field reserved for future optimization +3. **Request size**: Limited by socket buffer size (typically 64KB-256KB) + +### Platform Requirements + +- **Epoll support**: Linux systems (kernel 2.6+) +- **Socketpair**: Unix domain sockets +- **Threading**: POSIX threads (pthread) + +## Future Enhancements + +### Planned Features + +1. **document_from_sql**: Execute SQL to retrieve documents for reranking +2. **Shared memory**: Zero-copy result transfer for large responses +3. **Connection pooling**: Reuse HTTP connections to GenAI services +4. **Metrics**: Enhanced monitoring and statistics +5. **Batch optimization**: Better support for large document batches +6. 
**Streaming**: Progressive result delivery for large operations + +## Related Documentation + +- [Posts Table Embeddings Setup](./posts-embeddings-setup.md) - Using sqlite-rembed with GenAI +- [SQLite3 Server Documentation](./SQLite3-Server.md) - SQLite3 backend integration +- [sqlite-rembed Integration](./sqlite-rembed-integration.md) - Embedding generation + +## Source Files + +### Core Implementation + +- `include/GenAI_Thread.h` - GenAI module interface and structures +- `lib/GenAI_Thread.cpp` - Implementation of listener and worker loops +- `include/MySQL_Session.h` - Session integration (GenAI async state) +- `lib/MySQL_Session.cpp` - Async handlers and main loop integration +- `include/Base_Session.h` - Base session GenAI members + +### Tests + +- `test/tap/tests/genai_module-t.cpp` - Admin commands and variables +- `test/tap/tests/genai_embedding_rerank-t.cpp` - Basic embedding/reranking +- `test/tap/tests/genai_async-t.cpp` - Async architecture tests + +## License + +Same as ProxySQL - See LICENSE file for details. 
+ +## Contributing + +For contributions and issues: +- GitHub: https://github.com/sysown/proxysql +- Branch: `v3.1-vec_genAI_module` + +--- + +*Last Updated: 2025-01-10* +*Module Version: 0.1.0* diff --git a/genai_prototype/.gitignore b/genai_prototype/.gitignore new file mode 100644 index 0000000000..3209566ed9 --- /dev/null +++ b/genai_prototype/.gitignore @@ -0,0 +1,26 @@ +# Build artifacts +genai_demo +genai_demo_event +*.o +*.oo + +# Debug files +*.dSYM/ +*.su +*.idb +*.pdb + +# Editor files +*~ +.*.swp +.vscode/ +.idea/ + +# Core dumps +core +core.* + +# Temporary files +*.tmp +*.temp +*.log diff --git a/genai_prototype/Makefile b/genai_prototype/Makefile new file mode 100644 index 0000000000..249d5e180f --- /dev/null +++ b/genai_prototype/Makefile @@ -0,0 +1,83 @@ +# Makefile for GenAI Prototype +# Standalone prototype for testing GenAI module architecture + +CXX = g++ +CXXFLAGS = -std=c++17 -Wall -Wextra -O2 -g +LDFLAGS = -lpthread -lcurl +CURL_CFLAGS = $(shell curl-config --cflags) +CURL_LDFLAGS = $(shell curl-config --libs) + +# Target executables +TARGET_THREAD = genai_demo +TARGET_EVENT = genai_demo_event +TARGETS = $(TARGET_THREAD) $(TARGET_EVENT) + +# Source files +SOURCES_THREAD = genai_demo.cpp +SOURCES_EVENT = genai_demo_event.cpp + +# Object files +OBJECTS_THREAD = $(SOURCES_THREAD:.cpp=.o) +OBJECTS_EVENT = $(SOURCES_EVENT:.cpp=.o) + +# Default target (build both demos) +all: $(TARGETS) + +# Individual demo targets +genai_demo: genai_demo.o + @echo "Linking genai_demo..." + $(CXX) genai_demo.o $(LDFLAGS) -o genai_demo + @echo "Build complete: genai_demo" + +genai_demo_event: genai_demo_event.o + @echo "Linking genai_demo_event..." + $(CXX) genai_demo_event.o $(CURL_LDFLAGS) $(LDFLAGS) -o genai_demo_event + @echo "Build complete: genai_demo_event" + +# Compile source files +genai_demo.o: genai_demo.cpp + @echo "Compiling $<..." + $(CXX) $(CXXFLAGS) -c $< -o $@ + +genai_demo_event.o: genai_demo_event.cpp + @echo "Compiling $<..." 
+ $(CXX) $(CXXFLAGS) $(CURL_CFLAGS) -c $< -o $@ + +# Run the demos +run: $(TARGET_THREAD) + @echo "Running thread-based GenAI demo..." + ./$(TARGET_THREAD) + +run-event: $(TARGET_EVENT) + @echo "Running event-based GenAI demo..." + ./$(TARGET_EVENT) + +# Clean build artifacts +clean: + @echo "Cleaning..." + rm -f $(OBJECTS_THREAD) $(OBJECTS_EVENT) $(TARGETS) + @echo "Clean complete" + +# Rebuild +rebuild: clean all + +# Debug build with more warnings +debug: CXXFLAGS += -DDEBUG -Wpedantic +debug: clean all + +# Help target +help: + @echo "GenAI Prototype Makefile" + @echo "" + @echo "Targets:" + @echo " all - Build both demos (default)" + @echo " genai_demo - Build thread-based demo" + @echo " genai_demo_event - Build event-based demo" + @echo " run - Build and run thread-based demo" + @echo " run-event - Build and run event-based demo" + @echo " clean - Remove build artifacts" + @echo " rebuild - Clean and build all" + @echo " debug - Build with debug flags and extra warnings" + @echo " help - Show this help message" + +.PHONY: all run run-event clean rebuild debug help diff --git a/genai_prototype/README.md b/genai_prototype/README.md new file mode 100644 index 0000000000..8d14e27bbc --- /dev/null +++ b/genai_prototype/README.md @@ -0,0 +1,139 @@ +# GenAI Module Prototype + +Standalone prototype demonstrating the GenAI module architecture for ProxySQL. + +## Architecture Overview + +This prototype demonstrates a thread-pool based GenAI module that: + +1. **Receives requests** from multiple clients (MySQL/PgSQL threads) via socket pairs +2. **Queues requests** internally with a fixed-size worker thread pool +3. **Processes requests asynchronously** without blocking the clients +4. 
**Returns responses** to clients via the same socket connections + +### Components + +``` +┌─────────────────────────────────────────────────────────┐ +│ GenAI Module │ +│ │ +│ ┌────────────────────────────────────────────────┐ │ +│ │ Listener Thread (epoll-based) │ │ +│ │ - Monitors all client file descriptors │ │ +│ │ - Reads incoming requests │ │ +│ │ - Pushes to request queue │ │ +│ └──────────────────┬─────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌────────────────────────────────────────────────┐ │ +│ │ Request Queue │ │ +│ │ - Thread-safe queue │ │ +│ │ - Condition variable for worker notification │ │ +│ └──────────────────┬─────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌────────────────────────────────────────────────┐ │ +│ │ Thread Pool (configurable number of workers) │ │ +│ │ ┌──────┐ ┌──────┐ ┌──────┐ ┌──────┐ │ │ +│ │ │Worker│ │Worker│ │Worker│ │Worker│ ... │ │ +│ │ └───┬──┘ └───┬──┘ └───┬──┘ └───┬──┘ │ │ +│ │ └──────────┴──────────┴──────────┘ │ │ +│ └────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────┘ + ▲ │ ▲ + │ │ │ + socketpair() Responses socketpair() + from clients to clients from clients +``` + +### Communication Protocol + +**Client → GenAI (Request)**: +```cpp +struct RequestHeader { + uint64_t request_id; // Client's correlation ID + uint32_t operation; // 0=embedding, 1=completion, 2=rag + uint32_t input_size; // Size of following data + uint32_t flags; // Reserved +}; +// Followed by input_size bytes of input data +``` + +**GenAI → Client (Response)**: +```cpp +struct ResponseHeader { + uint64_t request_id; // Echo client's ID + uint32_t status_code; // 0=success, >0=error + uint32_t output_size; // Size of following data + uint32_t processing_time_ms; // Time taken to process +}; +// Followed by output_size bytes of output data +``` + +## Building and Running + +```bash +# Build +make + +# Run +make run + +# Clean +make clean + +# Debug build +make debug + +# Show help 
+make help +``` + +## Current Status + +**Implemented:** +- ✅ Thread pool with configurable workers +- ✅ epoll-based listener thread +- ✅ Thread-safe request queue +- ✅ socketpair communication +- ✅ Multiple concurrent clients +- ✅ Non-blocking async operation +- ✅ Simulated processing (random sleep) + +**TODO (Enhancement Phase):** +- ⬜ Real LLM API integration (OpenAI, local models) +- ⬜ Request batching for efficiency +- ⬜ Priority queue for urgent requests +- ⬜ Timeout and cancellation +- ⬜ Backpressure handling (queue limits) +- ⬜ Metrics and monitoring +- ⬜ Error handling and retry logic +- ⬜ Configuration file support +- ⬜ Unit tests +- ⬜ Performance benchmarking + +## Integration Plan + +Phase 1: **Prototype Enhancement** (Current) +- Complete TODO items above +- Test with real LLM APIs +- Performance testing + +Phase 2: **ProxySQL Integration** +- Integrate into ProxySQL build system +- Add to existing MySQL/PgSQL thread logic +- Implement GenAI variable system + +Phase 3: **Production Features** +- Connection pooling +- Request multiplexing +- Caching layer +- Fallback strategies + +## Design Principles + +1. **Zero Coupling**: GenAI module doesn't know about client types +2. **Non-Blocking**: Clients never wait on GenAI responses +3. **Scalable**: Fixed resource usage (bounded thread pool) +4. **Observable**: Easy to monitor and debug +5. **Testable**: Standalone, independent testing diff --git a/genai_prototype/genai_demo.cpp b/genai_prototype/genai_demo.cpp new file mode 100644 index 0000000000..7c5d51eba3 --- /dev/null +++ b/genai_prototype/genai_demo.cpp @@ -0,0 +1,1151 @@ +/** + * @file genai_demo.cpp + * @brief Standalone demonstration of GenAI module architecture + * + * @par Architecture Overview + * + * This program demonstrates a thread-pool based GenAI module designed for + * integration into ProxySQL. 
The architecture follows these principles: + * + * - **Zero Coupling**: GenAI only knows about file descriptors, not client types + * - **Non-Blocking**: Clients never wait on GenAI responses + * - **Scalable**: Fixed resource usage (bounded thread pool) + * - **Observable**: Easy to monitor and debug + * - **Testable**: Standalone, independent testing + * + * @par Components + * + * 1. **GenAI Module** (GenAIModule class) + * - Listener thread using epoll to monitor all client file descriptors + * - Thread-safe request queue with condition variable + * - Fixed-size worker thread pool (configurable, default 4) + * - Processes requests asynchronously without blocking clients + * + * 2. **Client** (Client class) + * - Simulates MySQL/PgSQL threads in ProxySQL + * - Creates socketpair connections to GenAI + * - Sends requests non-blocking + * - Polls for responses asynchronously + * + * @par Communication Flow + * + * @msc + * Client, GenAI_Listener, GenAI_Worker; + * + * Client note Client + * Client->GenAI_Listener: socketpair() creates 2 FDs; + * Client->GenAI_Listener: register_client(write_fd); + * Client->GenAI_Listener: send request (async); + * Client note Client continues working; + * GenAI_Listener>>GenAI_Worker: enqueue request; + * GenAI_Worker note GenAI_Worker processes (simulate with sleep); + * GenAI_Worker->Client: write response to socket; + * Client note Client receives response when polling; + * @endmsc + * + * @par Build and Run + * + * @code{.sh} + * # Compile + * g++ -std=c++17 -o genai_demo genai_demo.cpp -lpthread + * + * # Run + * ./genai_demo + * + * # Or use the Makefile + * make run + * @endcode + * + * @author ProxySQL Team + * @date 2025-01-08 + * @version 1.0 + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// ============================================================================ +// Platform 
Compatibility +// ============================================================================ + +/** + * @def EFD_CLOEXEC + * @brief Close-on-exec flag for eventfd() + * + * Set the close-on-exec (FD_CLOEXEC) flag on the new file descriptor. + * This ensures the file descriptor is automatically closed when exec() is called. + */ +#ifndef EFD_CLOEXEC +#define EFD_CLOEXEC 0200000 +#endif + +/** + * @def EFD_NONBLOCK + * @brief Non-blocking flag for eventfd() + * + * Set the O_NONBLOCK flag on the new file descriptor. + * This allows read() and write() operations to return immediately with EAGAIN + * if the operation would block. + */ +#ifndef EFD_NONBLOCK +#define EFD_NONBLOCK 04000 +#endif + +// ============================================================================ +// Protocol Definitions +// ============================================================================ + +/** + * @struct RequestHeader + * @brief Header structure for client requests to GenAI module + * + * This structure is sent first, followed by input_size bytes of input data. + * All fields are in network byte order (little-endian on x86_64). + * + * @var RequestHeader::request_id + * Unique identifier for this request. Generated by the client and echoed + * back in the response for correlation. Allows tracking of multiple + * concurrent requests from the same client. + * + * @var RequestHeader::operation + * Type of operation to perform. See Operation enum for valid values: + * - OP_EMBEDDING (0): Generate text embeddings + * - OP_COMPLETION (1): Text completion/generation + * - OP_RAG (2): Retrieval-augmented generation + * + * @var RequestHeader::input_size + * Size in bytes of the input data that follows this header. + * The input data is sent immediately after this header. + * + * @var RequestHeader::flags + * Reserved for future use. Set to 0. 
+ * + * @par Example + * @code{.cpp} + * RequestHeader req; + * req.request_id = 12345; + * req.operation = OP_EMBEDDING; + * req.input_size = text.length(); + * req.flags = 0; + * + * write(fd, &req, sizeof(req)); + * write(fd, text.data(), text.length()); + * @endcode + */ +struct RequestHeader { + uint64_t request_id; ///< Unique request identifier for correlation + uint32_t operation; ///< Operation type (OP_EMBEDDING, OP_COMPLETION, OP_RAG) + uint32_t input_size; ///< Size of input data following header (bytes) + uint32_t flags; ///< Reserved for future use (set to 0) +}; + +/** + * @struct ResponseHeader + * @brief Header structure for GenAI module responses to clients + * + * This structure is sent first, followed by output_size bytes of output data. + * + * @var ResponseHeader::request_id + * Echoes the request_id from the original RequestHeader. + * Used by the client to correlate the response with the pending request. + * + * @var ResponseHeader::status_code + * Status of the request processing: + * - 0: Success + * - >0: Error code (specific codes to be defined) + * + * @var ResponseHeader::output_size + * Size in bytes of the output data that follows this header. + * May be 0 if there is no output data. + * + * @var ResponseHeader::processing_time_ms + * Time taken by GenAI to process this request, in milliseconds. + * Useful for performance monitoring and debugging. 
+ * + * @par Example + * @code{.cpp} + * ResponseHeader resp; + * read(fd, &resp, sizeof(resp)); + * + * std::vector<uint8_t> output(resp.output_size); + * read(fd, output.data(), resp.output_size); + * + * if (resp.status_code == 0) { + * std::cout << "Processed in " << resp.processing_time_ms << "ms\n"; + * } + * @endcode + */ +struct ResponseHeader { + uint64_t request_id; ///< Echo of client's request identifier + uint32_t status_code; ///< 0=success, >0=error + uint32_t output_size; ///< Size of output data following header (bytes) + uint32_t processing_time_ms; ///< Actual processing time in milliseconds +}; + +/** + * @enum Operation + * @brief Supported GenAI operations + * + * Defines the types of operations that the GenAI module can perform. + * These values are used in the RequestHeader::operation field. + * + * @var OP_EMBEDDING + * Generate text embeddings using an embedding model. + * Input: Text string to embed + * Output: Vector of floating-point numbers (the embedding) + * + * @var OP_COMPLETION + * Generate text completion using a language model. + * Input: Prompt text + * Output: Generated completion text + * + * @var OP_RAG + * Retrieval-augmented generation. + * Input: Query text + * Output: Generated response with retrieved context + */ +enum Operation { + OP_EMBEDDING = 0, ///< Generate text embeddings (e.g., OpenAI text-embedding-3-small) + OP_COMPLETION = 1, ///< Text completion/generation (e.g., GPT-4) + OP_RAG = 2 ///< Retrieval-augmented generation +}; + +// ============================================================================ +// GenAI Module +// ============================================================================ + +/** + * @class GenAIModule + * @brief Thread-pool based GenAI processing module + * + * The GenAI module implements an asynchronous request processing system + * designed to handle GenAI operations (embeddings, completions, RAG) without + * blocking calling threads. 
+ * + * @par Architecture + * + * - **Listener Thread**: Uses epoll to monitor all client file descriptors. + * When data arrives on any FD, it reads the request, validates it, and + * pushes it onto the request queue. + * + * - **Request Queue**: Thread-safe FIFO queue that holds pending requests. + * Protected by a mutex and signaled via condition variable. + * + * - **Worker Threads**: Fixed-size thread pool (configurable) that processes + * requests from the queue. Each worker waits for work, processes requests + * (potentially blocking on I/O to external services), and writes responses + * back to clients. + * + * @par Threading Model + * + * - One listener thread (epoll-based I/O multiplexing) + * - N worker threads (configurable via constructor) + * - Total threads = 1 + num_workers + * + * @par Thread Safety + * + * - Public methods are thread-safe + * - Multiple clients can register/unregister concurrently + * - Request queue is protected by mutex + * - Client FD set is protected by mutex + * + * @par Usage Example + * @code{.cpp} + * // Create GenAI module with 8 workers + * GenAIModule genai(8); + * + * // Start the module (spawns threads) + * genai.start(); + * + * // Register clients + * int client_fd = get_socket_from_client(); + * genai.register_client(client_fd); + * + * // ... module runs, processing requests ... + * + * // Shutdown + * genai.stop(); + * @endcode + * + * @par Shutdown Sequence + * + * 1. Set running_ flag to false + * 2. Write to event_fd to wake listener + * 3. Notify all worker threads via condition variable + * 4. Join all threads + * 5. Close all client FDs + * 6. Close epoll and event FDs + */ +class GenAIModule { +public: + + /** + * @struct Request + * @brief Internal request structure queued for worker processing + * + * This structure represents a request after it has been received from + * a client and enqueued for processing by worker threads. + * + * @var Request::client_fd + * File descriptor to write the response to. 
This is the client's + * socket FD that was registered via register_client(). + * + * @var Request::request_id + * Client's request identifier for correlation. Echoed back in response. + * + * @var Request::operation + * Operation type (OP_EMBEDDING, OP_COMPLETION, or OP_RAG). + * + * @var Request::input + * Input data (text prompt, etc.) from the client. + */ + struct Request { + int client_fd; ///< Where to send response (client's socket FD) + uint64_t request_id; ///< Client's correlation identifier + uint32_t operation; ///< Type of operation to perform + std::string input; ///< Input data (text, prompt, etc.) + }; + + /** + * @brief Construct a GenAI module with specified worker count + * + * Creates a GenAI module instance. The module is not started until + * start() is called. + * + * @param num_workers Number of worker threads in the pool (default: 4) + * + * @par Worker Count Guidelines + * - For I/O-bound operations (API calls): 4-8 workers typically sufficient + * - For CPU-bound operations (local models): match CPU core count + * - Too many workers can cause contention; too few can cause queue buildup + */ + GenAIModule(int num_workers = 4) + : num_workers_(num_workers), running_(false) {} + + /** + * @brief Start the GenAI module + * + * Initializes and starts all internal threads: + * - Creates epoll instance for listener + * - Creates eventfd for shutdown signaling + * - Spawns worker threads (each runs worker_loop()) + * - Spawns listener thread (runs listener_loop()) + * + * @post Module is running and ready to accept clients + * @post All worker threads are waiting for requests + * @post Listener thread is monitoring registered client FDs + * + * @note This method blocks briefly during thread creation but returns + * immediately after threads are spawned. + * + * @warning Do not call start() on an already-running module. + * Call stop() first. + * + * @par Thread Creation Sequence + * 1. Create epoll instance for I/O multiplexing + * 2. 
Create eventfd for shutdown notification + * 3. Add eventfd to epoll set + * 4. Spawn N worker threads (each calls worker_loop()) + * 5. Spawn listener thread (calls listener_loop()) + */ + void start() { + running_ = true; + + // Create epoll instance for listener + // EPOLL_CLOEXEC: Close on exec to prevent FD leaks + epoll_fd_ = epoll_create1(EPOLL_CLOEXEC); + if (epoll_fd_ < 0) { + perror("epoll_create1"); + exit(1); + } + + // Create eventfd for shutdown notification + // EFD_NONBLOCK: Non-blocking operations + // EFD_CLOEXEC: Close on exec + event_fd_ = eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC); + if (event_fd_ < 0) { + perror("eventfd"); + exit(1); + } + + // Add eventfd to epoll set so listener can be woken for shutdown + struct epoll_event ev; + ev.events = EPOLLIN; + ev.data.fd = event_fd_; + if (epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, event_fd_, &ev) < 0) { + perror("epoll_ctl eventfd"); + exit(1); + } + + // Start worker threads + // Each worker runs worker_loop() and waits for requests + for (int i = 0; i < num_workers_; i++) { + worker_threads_.emplace_back([this, i]() { worker_loop(i); }); + } + + // Start listener thread + // Listener runs listener_loop() and monitors client FDs via epoll + listener_thread_ = std::thread([this]() { listener_loop(); }); + + std::cout << "[GenAI] Module started with " << num_workers_ << " workers\n"; + } + + /** + * @brief Register a new client connection with the GenAI module + * + * Registers a client's file descriptor with the epoll set so the + * listener thread can monitor it for incoming requests. 
+ * + * @param client_fd File descriptor to monitor (one end of socketpair) + * + * @pre client_fd is a valid, open file descriptor + * @pre Module has been started (start() was called) + * @pre client_fd has not already been registered + * + * @post client_fd is added to epoll set for monitoring + * @post client_fd is set to non-blocking mode + * @post Listener thread will be notified when data arrives on client_fd + * + * @par Thread Safety + * This method is thread-safe and can be called concurrently by + * multiple threads registering different clients. + * + * @par Client Registration Flow + * 1. Client creates socketpair() (2 FDs) + * 2. Client keeps one FD for reading responses + * 3. Client passes other FD to this method + * 4. This FD is added to epoll set + * 5. When client writes request, listener is notified + * + * @note This method is typically called by the client after creating + * a socketpair(). The client keeps one end, the GenAI module + * gets the other end. + * + * @warning The caller retains ownership of the original socketpair FDs + * and is responsible for closing them after unregistering. 
+ */ + void register_client(int client_fd) { + std::lock_guard lock(clients_mutex_); + + // Set FD to non-blocking mode + // This ensures read/write operations don't block the listener + int flags = fcntl(client_fd, F_GETFL, 0); + fcntl(client_fd, F_SETFL, flags | O_NONBLOCK); + + // Add to epoll set + // EPOLLIN: Notify when FD is readable (data available to read) + struct epoll_event ev; + ev.events = EPOLLIN; + ev.data.fd = client_fd; + if (epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, client_fd, &ev) < 0) { + perror("epoll_ctl client_fd"); + return; + } + + client_fds_.insert(client_fd); + std::cout << "[GenAI] Registered client fd " << client_fd << "\n"; + } + + /** + * @brief Stop the GenAI module and clean up resources + * + * Initiates graceful shutdown: + * - Sets running_ flag to false (signals threads to stop) + * - Writes to event_fd to wake listener from epoll_wait() + * - Notifies all workers via condition variable + * - Joins all threads (waits for them to finish) + * - Closes all client file descriptors + * - Closes epoll and event FDs + * + * @post All threads have stopped + * @post All resources are cleaned up + * @post Module can be restarted by calling start() again + * + * @par Shutdown Sequence + * 1. Set running_ = false (signals threads to exit) + * 2. Write to event_fd (wakes listener from epoll_wait) + * 3. Notify all workers (wakes them from condition_variable wait) + * 4. Join listener thread (waits for it to finish) + * 5. Join all worker threads (wait for them to finish) + * 6. Close all client FDs + * 7. Close epoll_fd and event_fd + * + * @note This method blocks until all threads have finished. + * In-flight requests will complete before workers exit. + * + * @warning Do not call stop() on a module that is not running. + * The behavior is undefined. 
+ */ + void stop() { + running_ = false; + + // Wake up listener from epoll_wait + uint64_t value = 1; + write(event_fd_, &value, sizeof(value)); + + // Wake up all workers from condition_variable wait + queue_cv_.notify_all(); + + // Wait for listener thread to finish + if (listener_thread_.joinable()) { + listener_thread_.join(); + } + + // Wait for all worker threads to finish + for (auto& t : worker_threads_) { + if (t.joinable()) { + t.join(); + } + } + + // Close all client FDs + for (int fd : client_fds_) { + close(fd); + } + + // Close epoll and event FDs + close(epoll_fd_); + close(event_fd_); + + std::cout << "[GenAI] Module stopped\n"; + } + + /** + * @brief Get the current size of the request queue + * + * Returns the number of requests currently waiting in the queue + * to be processed by worker threads. + * + * @return Current queue size (number of pending requests) + * + * @par Thread Safety + * This method is thread-safe and can be called concurrently + * by multiple threads. + * + * @par Use Cases + * - Monitoring: Track queue depth to detect backpressure + * - Metrics: Collect statistics on request load + * - Debugging: Verify queue is draining properly + * + * @note The queue size is momentary and may change immediately + * after this method returns. + */ + size_t get_queue_size() const { + std::lock_guard lock(queue_mutex_); + return request_queue_.size(); + } + +private: + /** + * @brief Listener thread main loop + * + * Runs in a dedicated thread and monitors all registered client file + * descriptors using epoll. When a client sends a request, this method + * reads it, validates it, and enqueues it for worker processing. + * + * @par Event Loop + * 1. Wait on epoll for events (timeout: 100ms) + * 2. For each ready FD: + * - If event_fd: check for shutdown signal + * - If client FD: read request and enqueue + * 3. If client disconnects: remove from epoll and close FD + * 4. Loop until running_ is false + * + * @par Request Reading Flow + * 1. 
Read RequestHeader (fixed size) + * 2. Validate read succeeded (n > 0) + * 3. Read input data (variable size based on header.input_size) + * 4. Create Request structure + * 5. Push to request_queue_ + * 6. Notify one worker via condition variable + * + * @param worker_id Unused parameter (for future per-worker stats) + * + * @note Runs with 100ms timeout on epoll_wait to periodically check + * the running_ flag for shutdown. + */ + void listener_loop() { + const int MAX_EVENTS = 64; // Max events to process per epoll_wait + struct epoll_event events[MAX_EVENTS]; + + std::cout << "[GenAI] Listener thread started\n"; + + while (running_) { + // Wait for events on monitored FDs + // Timeout of 100ms allows periodic check of running_ flag + int nfds = epoll_wait(epoll_fd_, events, MAX_EVENTS, 100); + + if (nfds < 0 && errno != EINTR) { + perror("epoll_wait"); + break; + } + + // Process each ready FD + for (int i = 0; i < nfds; i++) { + // Check if this is the shutdown eventfd + if (events[i].data.fd == event_fd_) { + // Shutdown signal - will exit loop when running_ is false + continue; + } + + int client_fd = events[i].data.fd; + + // Read request header + RequestHeader header; + ssize_t n = read(client_fd, &header, sizeof(header)); + + if (n <= 0) { + // Connection closed or error + std::cout << "[GenAI] Client fd " << client_fd << " disconnected\n"; + epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, client_fd, nullptr); + close(client_fd); + std::lock_guard lock(clients_mutex_); + client_fds_.erase(client_fd); + continue; + } + + // Read input data (may require multiple reads for large data) + std::string input(header.input_size, '\0'); + size_t total_read = 0; + while (total_read < header.input_size) { + ssize_t r = read(client_fd, &input[total_read], header.input_size - total_read); + if (r <= 0) break; + total_read += r; + } + + // Create request and enqueue for processing + Request req; + req.client_fd = client_fd; + req.request_id = header.request_id; + req.operation = 
header.operation; + req.input = std::move(input); + + { + // Critical section: modify request_queue_ + std::lock_guard lock(queue_mutex_); + request_queue_.push(std::move(req)); + } + + // Notify one worker thread that work is available + queue_cv_.notify_one(); + + std::cout << "[GenAI] Enqueued request " << header.request_id + << " from fd " << client_fd + << " (queue size: " << get_queue_size() << ")\n" << std::flush; + } + } + + std::cout << "[GenAI] Listener thread stopped\n"; + } + + /** + * @brief Worker thread main loop + * + * Runs in each worker thread. Waits for requests to appear in the + * queue, processes them (potentially blocking), and sends responses + * back to clients. + * + * @par Worker Loop + * 1. Wait on condition variable for queue to have work + * 2. When notified, check if running_ is still true + * 3. Pop request from queue (critical section) + * 4. Process request (may block on I/O to external services) + * 5. Send response back to client + * 6. Loop back to step 1 + * + * @par Request Processing + * Currently simulates processing with a random sleep (100-500ms). + * In production, this would call actual LLM APIs (OpenAI, local models, etc.). + * + * @par Response Sending + * - Writes ResponseHeader first (fixed size) + * - Writes output data second (variable size) + * - Client reads both to get complete response + * + * @param worker_id Identifier for this worker (0 to num_workers_-1) + * Used for logging and potentially for per-worker stats. + * + * @note Workers exit when running_ is set to false and queue is empty. + * In-flight requests will complete before workers exit. 
+     */
+    void worker_loop(int worker_id) {
+        std::cout << "[GenAI] Worker " << worker_id << " started\n";
+
+        while (running_) {
+            Request req;
+
+            // Wait for work to appear in queue
+            {
+                std::unique_lock<std::mutex> lock(queue_mutex_);
+                queue_cv_.wait(lock, [this] {
+                    return !running_ || !request_queue_.empty();
+                });
+
+                if (!running_) break;
+
+                if (request_queue_.empty()) continue;
+
+                // Get request from front of queue
+                req = std::move(request_queue_.front());
+                request_queue_.pop();
+            }
+
+            // Simulate processing time (random sleep between 100-500ms)
+            // In production, this would be actual LLM API calls
+            unsigned int seed = req.request_id;
+            int sleep_ms = 100 + (rand_r(&seed) % 400);  // 100-500ms
+
+            std::cout << "[GenAI] Worker " << worker_id
+                      << " processing request " << req.request_id
+                      << " (sleep " << sleep_ms << "ms)\n";
+
+            std::this_thread::sleep_for(std::chrono::milliseconds(sleep_ms));
+
+            // Prepare response
+            std::string output = "Processed: " + req.input;
+
+            ResponseHeader resp;
+            resp.request_id = req.request_id;
+            resp.status_code = 0;
+            resp.output_size = output.size();
+            resp.processing_time_ms = sleep_ms;
+
+            // Send response back to client
+            write(req.client_fd, &resp, sizeof(resp));
+            write(req.client_fd, output.data(), output.size());
+
+            std::cout << "[GenAI] Worker " << worker_id
+                      << " completed request " << req.request_id << "\n";
+        }
+
+        std::cout << "[GenAI] Worker " << worker_id << " stopped\n";
+    }
+
+    // ========================================================================
+    // Member Variables
+    // ========================================================================
+
+    int num_workers_;                          ///< Number of worker threads in the pool
+    std::atomic<bool> running_;                ///< Flag indicating if module is running
+
+    int epoll_fd_;                             ///< epoll instance file descriptor
+    int event_fd_;                             ///< eventfd for shutdown notification
+
+    std::thread listener_thread_;              ///< Thread that monitors client FDs
+    std::vector<std::thread> worker_threads_;  ///< Thread pool for request processing
+
+    std::queue<Request> request_queue_;        ///< FIFO queue of pending requests
+    mutable std::mutex queue_mutex_;           ///< Protects request_queue_
+    std::condition_variable queue_cv_;         ///< Notifies workers when queue has work
+
+    std::unordered_set<int> client_fds_;       ///< Set of registered client FDs
+    mutable std::mutex clients_mutex_;         ///< Protects client_fds_
+};
+
+// ============================================================================
+// Client
+// ============================================================================
+
+/**
+ * @class Client
+ * @brief Simulates a ProxySQL thread (MySQL/PgSQL) making GenAI requests
+ *
+ * This class demonstrates how a client thread would interact with the
+ * GenAI module in a real ProxySQL deployment. It creates a socketpair
+ * connection, sends requests asynchronously, and polls for responses.
+ *
+ * @par Communication Pattern
+ *
+ * 1. Create socketpair() (2 FDs: read_fd and genai_fd)
+ * 2. Pass genai_fd to GenAI module via register_client()
+ * 3. Keep read_fd for monitoring responses
+ * 4. Send requests via genai_fd (non-blocking)
+ * 5. Poll read_fd for responses
+ * 6. Process responses as they arrive
+ *
+ * @par Asynchronous Operation
+ *
+ * The key design principle is that the client never blocks waiting for
+ * GenAI responses. Instead:
+ * - Send requests and continue working
+ * - Poll for responses periodically
+ * - Handle responses when they arrive
+ *
+ * This allows the client thread to handle many concurrent requests
+ * and continue serving other clients while GenAI processes.
+ * + * @par Usage Example + * @code{.cpp} + * // Create client + * Client client("MySQL-Thread-1", 10); + * + * // Connect to GenAI module + * client.connect_to_genai(genai_module); + * + * // Run (send requests, wait for responses) + * client.run(); + * + * // Clean up + * client.close(); + * @endcode + */ +class Client { +public: + + /** + * @brief Construct a Client with specified name and request count + * + * Creates a client instance that will send a specified number of + * requests to the GenAI module. + * + * @param name Human-readable name for this client (e.g., "MySQL-Thread-1") + * @param num_requests Number of requests to send (default: 5) + * + * @note The name is used for logging/debugging to identify which + * client is sending/receiving requests. + */ + Client(const std::string& name, int num_requests = 5) + : name_(name), num_requests_(num_requests), next_id_(1) {} + + /** + * @brief Connect to the GenAI module + * + * Creates a socketpair and registers one end with the GenAI module. + * The client keeps one end for reading responses. + * + * @param genai Reference to the GenAI module to connect to + * + * @pre GenAI module has been started (genai.start() was called) + * + * @post Socketpair is created + * @post One end is registered with GenAI module + * @post Other end is kept for reading responses + * @post Both FDs are set to non-blocking mode + * + * @par Socketpair Creation + * - socketpair() creates 2 connected Unix domain sockets + * - fds[0] (read_fd_): Client reads responses from this + * - fds[1] (genai_fd_): Passed to GenAI, GenAI writes responses to this + * - Data written to one end can be read from the other end + * + * @par Connection Flow + * 1. Create socketpair(AF_UNIX, SOCK_STREAM, 0, fds) + * 2. Store fds[0] as read_fd_ (client reads responses here) + * 3. Store fds[1] as genai_fd_ (pass to GenAI module) + * 4. Call genai.register_client(genai_fd_) to register with module + * 5. 
Set read_fd_ to non-blocking for polling + * + * @note This method is typically called once per client at initialization. + */ + void connect_to_genai(GenAIModule& genai) { + // Create socketpair for bidirectional communication + int fds[2]; + if (socketpair(AF_UNIX, SOCK_STREAM, 0, fds) < 0) { + perror("socketpair"); + exit(1); + } + + read_fd_ = fds[0]; // Client reads responses from this + genai_fd_ = fds[1]; // GenAI writes responses to this + + // Register write end with GenAI module + genai.register_client(genai_fd_); + + // Set read end to non-blocking for async polling + int flags = fcntl(read_fd_, F_GETFL, 0); + fcntl(read_fd_, F_SETFL, flags | O_NONBLOCK); + + std::cout << "[" << name_ << "] Connected to GenAI (read_fd=" << read_fd_ << ")\n"; + } + + /** + * @brief Send all requests and wait for all responses + * + * This method: + * 1. Sends all requests immediately (non-blocking) + * 2. Polls for responses periodically + * 3. Processes responses as they arrive + * 4. Returns when all responses have been received + * + * @pre connect_to_genai() has been called + * + * @post All requests have been sent + * @post All responses have been received + * @post pending_requests_ is empty + * + * @par Sending Phase + * - Loop num_requests_ times + * - Each iteration calls send_request() + * - Requests are sent immediately, no waiting + * + * @par Receiving Phase + * - Loop until completed_ == num_requests_ + * - Call process_responses() to check for new responses + * - Sleep 50ms between checks (non-blocking poll) + * - Process each response as it arrives + * + * @note In production, the 50ms sleep would be replaced by adding + * read_fd_ to the thread's epoll set along with other FDs. 
+ */ + void run() { + // Send all requests immediately (non-blocking) + for (int i = 0; i < num_requests_; i++) { + send_request(i); + } + + // Wait for all responses (simulate async handling) + std::cout << "[" << name_ << "] Waiting for " << num_requests_ << " responses...\n"; + + while (completed_ < num_requests_) { + process_responses(); + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + + std::cout << "[" << name_ << "] All requests completed!\n"; + } + + /** + * @brief Close the connection to GenAI module + * + * Closes both ends of the socketpair. + * + * @post read_fd_ is closed + * @post genai_fd_ is closed + * @post FDs are set to -1 (closed state) + * + * @note This should be called after run() completes. + */ + void close() { + if (read_fd_ >= 0) ::close(read_fd_); + if (genai_fd_ >= 0) ::close(genai_fd_); + } + +private: + + /** + * @brief Send a single request to the GenAI module + * + * Creates and sends a request with a generated input string. + * + * @param index Index of this request (used to generate input text) + * + * @par Request Creation + * - Generate unique request_id (incrementing counter) + * - Create input string: "name_ input #index" + * - Fill in RequestHeader + * - Write header to genai_fd_ + * - Write input data to genai_fd_ + * - Store request in pending_requests_ with timestamp + * + * @par Non-Blocking Operation + * The write operations may block briefly, but in practice: + * - Socket buffer is typically large enough + * - If buffer is full, EAGAIN is returned (not handled in this demo) + * - Client would need to retry in production + * + * @note This method increments next_id_ to ensure unique request IDs. 
+ */ + void send_request(int index) { + // Create input string for this request + std::string input = name_ + " input #" + std::to_string(index); + uint64_t request_id = next_id_++; + + // Fill request header + RequestHeader req; + req.request_id = request_id; + req.operation = OP_EMBEDDING; + req.input_size = input.size(); + req.flags = 0; + + // Send request header + write(genai_fd_, &req, sizeof(req)); + + // Send input data + write(genai_fd_, input.data(), input.size()); + + // Track this request with timestamp for measuring round-trip time + pending_requests_[request_id] = std::chrono::steady_clock::now(); + + std::cout << "[" << name_ << "] Sent request " << request_id + << " (" << input << ")\n"; + } + + /** + * @brief Check for and process any available responses + * + * Attempts to read a response from read_fd_. If a complete response + * is available, processes it and updates tracking. + * + * @par Response Reading Flow + * 1. Try to read ResponseHeader (non-blocking) + * 2. If no data available (n <= 0), return immediately + * 3. Read output data (may require multiple reads) + * 4. Look up pending request by request_id + * 5. Calculate round-trip time + * 6. Log response details + * 7. Remove from pending_requests_ + * 8. Increment completed_ counter + * + * @note This method is called periodically from run() to poll for + * responses. In production, read_fd_ would be in an epoll set. 
+     */
+    void process_responses() {
+        ResponseHeader resp;
+        ssize_t n = read(read_fd_, &resp, sizeof(resp));
+
+        if (n <= 0) {
+            return;  // No data available yet
+        }
+
+        // Read output data
+        std::string output(resp.output_size, '\0');
+        size_t total_read = 0;
+        while (total_read < resp.output_size) {
+            ssize_t r = read(read_fd_, &output[total_read], resp.output_size - total_read);
+            if (r <= 0) break;
+            total_read += r;
+        }
+
+        // Find and process the matching pending request
+        auto it = pending_requests_.find(resp.request_id);
+        if (it != pending_requests_.end()) {
+            auto start_time = it->second;
+            auto end_time = std::chrono::steady_clock::now();
+            auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(
+                end_time - start_time).count();
+
+            std::cout << "[" << name_ << "] Received response for request " << resp.request_id
+                      << " (took " << duration << "ms, processed in "
+                      << resp.processing_time_ms << "ms): " << output << "\n";
+
+            pending_requests_.erase(it);
+            completed_++;
+        }
+    }
+
+    // ========================================================================
+    // Member Variables
+    // ========================================================================
+
+    std::string name_;      ///< Human-readable client identifier
+    int num_requests_;      ///< Total number of requests to send
+    uint64_t next_id_;      ///< Counter for generating unique request IDs
+    int completed_ = 0;     ///< Number of requests completed
+
+    int read_fd_ = -1;      ///< FD for reading responses from GenAI
+    int genai_fd_ = -1;     ///< FD for writing requests to GenAI
+
+    /// Map of pending requests: request_id -> timestamp when sent
+    std::unordered_map<uint64_t, std::chrono::steady_clock::time_point> pending_requests_;
+};
+
+// ============================================================================
+// Main - Demonstration Entry Point
+// ============================================================================
+
+/**
+ * @brief Main entry point for the GenAI module demonstration
+ *
+ * Creates a GenAI module with 4 workers and spawns 3 client threads
+ * (simulating 2 MySQL threads and 1 PgSQL thread) that each send 3
+ * concurrent requests.
+ *
+ * @par Execution Flow
+ * 1. Create GenAI module with 4 workers
+ * 2. Start the module (spawns listener and worker threads)
+ * 3. Wait 100ms for module to initialize
+ * 4. Create and start 3 client threads:
+ *    - MySQL-Thread-1: 3 requests
+ *    - MySQL-Thread-2: 3 requests
+ *    - PgSQL-Thread-1: 3 requests
+ * 5. Wait for all clients to complete
+ * 6. Stop the GenAI module
+ *
+ * @par Expected Output
+ * The program will output:
+ * - Thread start/stop messages
+ * - Client connection messages
+ * - Request send/receive messages
+ * - Timing information (round-trip time, processing time)
+ * - Completion messages
+ *
+ * @return 0 on success, non-zero on error
+ *
+ * @note All clients are started with 50ms delays to demonstrate
+ *       interleaved execution.
+ */
+int main() {
+    std::cout << "=== GenAI Module Demonstration ===\n\n";
+
+    // Create and start GenAI module with 4 worker threads
+    GenAIModule genai(4);
+    genai.start();
+
+    std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+    // Create multiple clients
+    std::cout << "\n=== Creating Clients ===\n";
+
+    std::vector<std::thread> client_threads;
+
+    // Client 1: MySQL Thread simulation
+    client_threads.emplace_back([&genai]() {
+        Client client("MySQL-Thread-1", 3);
+        client.connect_to_genai(genai);
+        client.run();
+        client.close();
+    });
+
+    std::this_thread::sleep_for(std::chrono::milliseconds(50));
+
+    // Client 2: MySQL Thread simulation
+    client_threads.emplace_back([&genai]() {
+        Client client("MySQL-Thread-2", 3);
+        client.connect_to_genai(genai);
+        client.run();
+        client.close();
+    });
+
+    std::this_thread::sleep_for(std::chrono::milliseconds(50));
+
+    // Client 3: PgSQL Thread simulation
+    client_threads.emplace_back([&genai]() {
+        Client client("PgSQL-Thread-1", 3);
+        client.connect_to_genai(genai);
+        client.run();
+        client.close();
+    });
+
+    // Wait for all clients to complete
+    for (auto& t : 
client_threads) { + if (t.joinable()) { + t.join(); + } + } + + std::cout << "\n=== All Clients Completed ===\n"; + + // Stop GenAI module + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + genai.stop(); + + std::cout << "\n=== Demonstration Complete ===\n"; + + return 0; +} diff --git a/genai_prototype/genai_demo_event.cpp b/genai_prototype/genai_demo_event.cpp new file mode 100644 index 0000000000..e393ac3230 --- /dev/null +++ b/genai_prototype/genai_demo_event.cpp @@ -0,0 +1,1740 @@ +/** + * @file genai_demo_event.cpp + * @brief Event-driven GenAI module POC with real llama-server integration + * + * This POC demonstrates the GenAI module architecture with: + * - Shared memory communication (passing pointers, not copying data) + * - Real embedding generation via llama-server HTTP API + * - Real reranking via llama-server HTTP API + * - Support for single or multiple documents per request + * - libcurl-based HTTP client for API calls + * + * @par Architecture + * + * Client and GenAI module share the same process memory space. + * Documents and results are passed by pointer to avoid copying. + * + * @par Embedding Request Flow + * + * 1. Client allocates document(s) in its own memory + * 2. Client sends request with document pointers to GenAI + * 3. GenAI reads document pointers and accesses shared memory + * 4. GenAI calls llama-server via HTTP to get embeddings + * 5. GenAI allocates embedding result and passes pointer back to client + * 6. Client reads embedding from shared memory and displays length + * + * @par Rerank Request Flow + * + * 1. Client allocates query and document(s) in its own memory + * 2. Client sends request with query pointer and document pointers to GenAI + * 3. GenAI reads pointers and accesses shared memory + * 4. GenAI calls llama-server via HTTP to get rerank results + * 5. GenAI allocates rerank result array and passes pointer back to client + * 6. 
Client reads results (index, score) from shared memory + * + * @author ProxySQL Team + * @date 2025-01-09 + * @version 3.1 - POC with embeddings and reranking + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Platform compatibility +#ifndef EFD_CLOEXEC +#define EFD_CLOEXEC 0200000 +#endif +#ifndef EFD_NONBLOCK +#define EFD_NONBLOCK 04000 +#endif + +// ============================================================================ +// Protocol Definitions +// ============================================================================ + +/** + * @enum Operation + * @brief GenAI operation types + */ +enum Operation : uint32_t { + OP_EMBEDDING = 0, ///< Generate embeddings for documents + OP_COMPLETION = 1, ///< Text completion (future) + OP_RERANK = 2, ///< Rerank documents by relevance to query +}; + +/** + * @struct Document + * @brief Document structure passed by pointer (shared memory) + * + * Client allocates this structure and passes its pointer to GenAI. + * GenAI reads the document directly from shared memory. + */ +struct Document { + const char* text; ///< Pointer to document text (owned by client) + size_t text_size; ///< Length of text in bytes + + Document() : text(nullptr), text_size(0) {} + + Document(const char* t, size_t s) : text(t), text_size(s) {} +}; + +/** + * @struct RequestHeader + * @brief Header for GenAI requests + * + * For embedding requests: client sends document_count pointers to Document structures (as uint64_t). + * For rerank requests: client sends query (as null-terminated string), then document_count pointers. + */ +struct RequestHeader { + uint64_t request_id; ///< Client's correlation ID + uint32_t operation; ///< Operation type (OP_EMBEDDING, OP_RERANK, etc.) 
+ uint32_t document_count; ///< Number of documents (1 or more) + uint32_t flags; ///< Reserved for future use + uint32_t top_n; ///< For rerank: number of top results to return +}; + +/** + * @struct EmbeddingResult + * @brief Single embedding vector allocated by GenAI, read by client + * + * GenAI allocates this and passes the pointer to client. + * Client reads the embedding and then frees it. + */ +struct EmbeddingResult { + float* data; ///< Pointer to embedding vector (owned by GenAI initially) + size_t size; ///< Number of floats in the embedding + + EmbeddingResult() : data(nullptr), size(0) {} + + ~EmbeddingResult() { + if (data) { + delete[] data; + data = nullptr; + } + } + + // Move constructor and assignment + EmbeddingResult(EmbeddingResult&& other) noexcept + : data(other.data), size(other.size) { + other.data = nullptr; + other.size = 0; + } + + EmbeddingResult& operator=(EmbeddingResult&& other) noexcept { + if (this != &other) { + if (data) delete[] data; + data = other.data; + size = other.size; + other.data = nullptr; + other.size = 0; + } + return *this; + } + + // Disable copy + EmbeddingResult(const EmbeddingResult&) = delete; + EmbeddingResult& operator=(const EmbeddingResult&) = delete; +}; + +/** + * @struct BatchEmbeddingResult + * @brief Multiple embedding vectors allocated by GenAI, read by client + * + * For batch requests, GenAI allocates an array of embeddings. + * The embeddings are stored contiguously: [emb1 floats, emb2 floats, ...] + * Each embedding has the same size. 
+ */ +struct BatchEmbeddingResult { + float* data; ///< Pointer to contiguous embedding array (owned by GenAI initially) + size_t embedding_size; ///< Number of floats per embedding + size_t count; ///< Number of embeddings + + BatchEmbeddingResult() : data(nullptr), embedding_size(0), count(0) {} + + ~BatchEmbeddingResult() { + if (data) { + delete[] data; + data = nullptr; + } + } + + // Move constructor and assignment + BatchEmbeddingResult(BatchEmbeddingResult&& other) noexcept + : data(other.data), embedding_size(other.embedding_size), count(other.count) { + other.data = nullptr; + other.embedding_size = 0; + other.count = 0; + } + + BatchEmbeddingResult& operator=(BatchEmbeddingResult&& other) noexcept { + if (this != &other) { + if (data) delete[] data; + data = other.data; + embedding_size = other.embedding_size; + count = other.count; + other.data = nullptr; + other.embedding_size = 0; + other.count = 0; + } + return *this; + } + + // Disable copy + BatchEmbeddingResult(const BatchEmbeddingResult&) = delete; + BatchEmbeddingResult& operator=(const BatchEmbeddingResult&) = delete; + + size_t total_floats() const { return embedding_size * count; } +}; + +/** + * @struct RerankResult + * @brief Single rerank result with index and relevance score + * + * Represents one document's rerank result. + * Allocated by GenAI, passed to client via shared memory. + */ +struct RerankResult { + uint32_t index; ///< Original document index + float score; ///< Relevance score (higher is better) +}; + +/** + * @struct RerankResultArray + * @brief Array of rerank results allocated by GenAI + * + * For rerank requests, GenAI allocates an array of RerankResult. + * Client takes ownership and must free the array. 
+ */ +struct RerankResultArray { + RerankResult* data; ///< Pointer to result array (owned by GenAI initially) + size_t count; ///< Number of results + + RerankResultArray() : data(nullptr), count(0) {} + + ~RerankResultArray() { + if (data) { + delete[] data; + data = nullptr; + } + } + + // Move constructor and assignment + RerankResultArray(RerankResultArray&& other) noexcept + : data(other.data), count(other.count) { + other.data = nullptr; + other.count = 0; + } + + RerankResultArray& operator=(RerankResultArray&& other) noexcept { + if (this != &other) { + if (data) delete[] data; + data = other.data; + count = other.count; + other.data = nullptr; + other.count = 0; + } + return *this; + } + + // Disable copy + RerankResultArray(const RerankResultArray&) = delete; + RerankResultArray& operator=(const RerankResultArray&) = delete; +}; + +/** + * @struct ResponseHeader + * @brief Header for GenAI responses + * + * For embeddings: passes pointer to BatchEmbeddingResult as uint64_t. + * For rerank: passes pointer to RerankResultArray as uint64_t. + */ +struct ResponseHeader { + uint64_t request_id; ///< Echo client's request ID + uint32_t status_code; ///< 0=success, >0=error + uint32_t embedding_size; ///< For embeddings: floats per embedding + uint32_t processing_time_ms;///< Time taken to process + uint64_t result_ptr; ///< Pointer to result data (as uint64_t) + uint32_t result_count; ///< Number of results (embeddings or rerank results) + uint32_t data_size; ///< Additional data size (for future use) +}; + +// ============================================================================ +// GenAI Module +// ============================================================================ + +/** + * @class GenAIModule + * @brief Thread-pool based GenAI processing module with real embedding support + * + * This module provides embedding generation via llama-server HTTP API. + * It uses a thread pool with epoll-based listener for async processing. 
+ */ +class GenAIModule { +public: + /** + * @struct Request + * @brief Internal request representation + */ + struct Request { + int client_fd; + uint64_t request_id; + uint32_t operation; + std::string query; ///< Query text (for rerank) + uint32_t top_n; ///< Number of top results (for rerank) + std::vector documents; ///< Document pointers from shared memory + }; + + GenAIModule(int num_workers = 4) + : num_workers_(num_workers), running_(false) { + + // Initialize libcurl + curl_global_init(CURL_GLOBAL_ALL); + } + + ~GenAIModule() { + if (running_) { + stop(); + } + curl_global_cleanup(); + } + + /** + * @brief Start the GenAI module (spawn threads) + */ + void start() { + running_ = true; + + epoll_fd_ = epoll_create1(EPOLL_CLOEXEC); + if (epoll_fd_ < 0) { + perror("epoll_create1"); + exit(1); + } + + event_fd_ = eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC); + if (event_fd_ < 0) { + perror("eventfd"); + exit(1); + } + + struct epoll_event ev; + ev.events = EPOLLIN; + ev.data.fd = event_fd_; + if (epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, event_fd_, &ev) < 0) { + perror("epoll_ctl eventfd"); + exit(1); + } + + for (int i = 0; i < num_workers_; i++) { + worker_threads_.emplace_back([this, i]() { worker_loop(i); }); + } + + listener_thread_ = std::thread([this]() { listener_loop(); }); + + std::cout << "[GenAI] Module started with " << num_workers_ << " workers\n"; + std::cout << "[GenAI] Embedding endpoint: http://127.0.0.1:8013/embedding\n"; + std::cout << "[GenAI] Rerank endpoint: http://127.0.0.1:8012/rerank\n"; + } + + /** + * @brief Register a client file descriptor with GenAI + * + * @param client_fd File descriptor to monitor (from socketpair) + */ + void register_client(int client_fd) { + std::lock_guard lock(clients_mutex_); + + int flags = fcntl(client_fd, F_GETFL, 0); + fcntl(client_fd, F_SETFL, flags | O_NONBLOCK); + + struct epoll_event ev; + ev.events = EPOLLIN; + ev.data.fd = client_fd; + if (epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, client_fd, &ev) < 0) { + 
perror("epoll_ctl add client"); + return; + } + + client_fds_.insert(client_fd); + } + + /** + * @brief Stop the GenAI module + */ + void stop() { + running_ = false; + + uint64_t value = 1; + write(event_fd_, &value, sizeof(value)); + + queue_cv_.notify_all(); + + for (auto& t : worker_threads_) { + if (t.joinable()) t.join(); + } + + if (listener_thread_.joinable()) { + listener_thread_.join(); + } + + close(event_fd_); + close(epoll_fd_); + + std::cout << "[GenAI] Module stopped\n"; + } + + /** + * @brief Get current queue depth (for statistics) + */ + size_t get_queue_size() const { + std::lock_guard lock(queue_mutex_); + return request_queue_.size(); + } + +private: + /** + * @brief Listener loop - reads requests from clients via epoll + */ + void listener_loop() { + const int MAX_EVENTS = 64; + struct epoll_event events[MAX_EVENTS]; + + while (running_) { + int nfds = epoll_wait(epoll_fd_, events, MAX_EVENTS, 100); + + if (nfds < 0 && errno != EINTR) { + perror("epoll_wait"); + break; + } + + for (int i = 0; i < nfds; i++) { + if (events[i].data.fd == event_fd_) { + continue; + } + + int client_fd = events[i].data.fd; + + RequestHeader header; + ssize_t n = read(client_fd, &header, sizeof(header)); + + if (n <= 0) { + epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, client_fd, nullptr); + close(client_fd); + std::lock_guard lock(clients_mutex_); + client_fds_.erase(client_fd); + continue; + } + + // For rerank operations, read the query first + std::string query; + if (header.operation == OP_RERANK) { + // Read query as null-terminated string + char ch; + while (true) { + ssize_t r = read(client_fd, &ch, 1); + if (r <= 0) break; + if (ch == '\0') break; // Null terminator + query += ch; + } + } + + // Read document pointers (passed as uint64_t) + std::vector doc_ptrs(header.document_count); + size_t total_read = 0; + while (total_read < header.document_count * sizeof(uint64_t)) { + ssize_t r = read(client_fd, + (char*)doc_ptrs.data() + total_read, + header.document_count 
* sizeof(uint64_t) - total_read); + if (r <= 0) break; + total_read += r; + } + + // Build request with document pointers (shared memory) + Request req; + req.client_fd = client_fd; + req.request_id = header.request_id; + req.operation = header.operation; + req.query = query; + req.top_n = header.top_n; + req.documents.reserve(header.document_count); + + for (uint32_t i = 0; i < header.document_count; i++) { + Document* doc = reinterpret_cast(doc_ptrs[i]); + if (doc && doc->text) { + req.documents.push_back(*doc); + } + } + + { + std::lock_guard lock(queue_mutex_); + request_queue_.push(std::move(req)); + } + + queue_cv_.notify_one(); + } + } + } + + /** + * @brief Callback function for libcurl to handle HTTP response + */ + static size_t WriteCallback(void* contents, size_t size, size_t nmemb, void* userp) { + size_t totalSize = size * nmemb; + std::string* response = static_cast(userp); + response->append(static_cast(contents), totalSize); + return totalSize; + } + + /** + * @brief Call llama-server embedding API via libcurl + * + * @param text Document text to embed + * @return EmbeddingResult containing the embedding vector + */ + EmbeddingResult call_llama_embedding(const std::string& text) { + EmbeddingResult result; + CURL* curl = curl_easy_init(); + + if (!curl) { + std::cerr << "[Worker] Failed to initialize curl\n"; + return result; + } + + // Build JSON request + std::stringstream json; + json << "{\"input\":\""; + + // Escape JSON special characters + for (char c : text) { + switch (c) { + case '"': json << "\\\""; break; + case '\\': json << "\\\\"; break; + case '\n': json << "\\n"; break; + case '\r': json << "\\r"; break; + case '\t': json << "\\t"; break; + default: json << c; break; + } + } + + json << "\"}"; + + std::string json_str = json.str(); + + // Configure curl + curl_easy_setopt(curl, CURLOPT_URL, "http://127.0.0.1:8013/embedding"); + curl_easy_setopt(curl, CURLOPT_POST, 1L); + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, json_str.c_str()); 
+ curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback); + + std::string response_data; + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response_data); + + // Add content-type header + struct curl_slist* headers = nullptr; + headers = curl_slist_append(headers, "Content-Type: application/json"); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + + // Perform request + CURLcode res = curl_easy_perform(curl); + + if (res != CURLE_OK) { + std::cerr << "[Worker] curl_easy_perform() failed: " + << curl_easy_strerror(res) << "\n"; + } else { + // Parse JSON response to extract embedding + // Response format: [{"index":0,"embedding":[0.1,0.2,...]}] + size_t embedding_pos = response_data.find("\"embedding\":"); + if (embedding_pos != std::string::npos) { + // Find the array start + size_t array_start = response_data.find("[", embedding_pos); + if (array_start != std::string::npos) { + // Find matching bracket + size_t array_end = array_start; + int bracket_count = 0; + bool in_array = false; + + for (size_t i = array_start; i < response_data.size(); i++) { + if (response_data[i] == '[') { + bracket_count++; + in_array = true; + } else if (response_data[i] == ']') { + bracket_count--; + if (bracket_count == 0 && in_array) { + array_end = i; + break; + } + } + } + + // Parse the array of floats + std::string array_str = response_data.substr(array_start + 1, array_end - array_start - 1); + std::vector embedding; + std::stringstream ss(array_str); + std::string token; + + while (std::getline(ss, token, ',')) { + // Remove whitespace and "null" values + token.erase(0, token.find_first_not_of(" \t\n\r")); + token.erase(token.find_last_not_of(" \t\n\r") + 1); + + if (token == "null" || token.empty()) { + continue; + } + + try { + float val = std::stof(token); + embedding.push_back(val); + } catch (...) 
{ + // Skip invalid values + } + } + + if (!embedding.empty()) { + result.size = embedding.size(); + result.data = new float[embedding.size()]; + std::copy(embedding.begin(), embedding.end(), result.data); + } + } + } + } + + curl_slist_free_all(headers); + curl_easy_cleanup(curl); + + return result; + } + + /** + * @brief Call llama-server batch embedding API via libcurl + * + * @param texts Vector of document texts to embed + * @return BatchEmbeddingResult containing multiple embedding vectors + */ + BatchEmbeddingResult call_llama_batch_embedding(const std::vector& texts) { + BatchEmbeddingResult result; + CURL* curl = curl_easy_init(); + + if (!curl) { + std::cerr << "[Worker] Failed to initialize curl\n"; + return result; + } + + // Build JSON request with array of inputs + std::stringstream json; + json << "{\"input\":["; + + for (size_t i = 0; i < texts.size(); i++) { + if (i > 0) json << ","; + json << "\""; + + // Escape JSON special characters + for (char c : texts[i]) { + switch (c) { + case '"': json << "\\\""; break; + case '\\': json << "\\\\"; break; + case '\n': json << "\\n"; break; + case '\r': json << "\\r"; break; + case '\t': json << "\\t"; break; + default: json << c; break; + } + } + + json << "\""; + } + + json << "]}"; + + std::string json_str = json.str(); + + // Configure curl + curl_easy_setopt(curl, CURLOPT_URL, "http://127.0.0.1:8013/embedding"); + curl_easy_setopt(curl, CURLOPT_POST, 1L); + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, json_str.c_str()); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback); + + std::string response_data; + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response_data); + + // Add content-type header + struct curl_slist* headers = nullptr; + headers = curl_slist_append(headers, "Content-Type: application/json"); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + + // Perform request + CURLcode res = curl_easy_perform(curl); + + if (res != CURLE_OK) { + std::cerr << "[Worker] curl_easy_perform() 
failed: " + << curl_easy_strerror(res) << "\n"; + } else { + // Parse JSON response to extract embeddings + // Response format: [{"index":0,"embedding":[[float1,float2,...]]}, {"index":1,...}] + std::vector> all_embeddings; + + // Find all result objects by looking for "embedding": + size_t pos = 0; + while ((pos = response_data.find("\"embedding\":", pos)) != std::string::npos) { + // Find the array start (expecting nested [[...]]) + size_t array_start = response_data.find("[", pos); + if (array_start == std::string::npos) break; + + // Skip the first [ to find the inner array + size_t inner_start = array_start + 1; + if (inner_start >= response_data.size() || response_data[inner_start] != '[') { + // Not a nested array, use first bracket + inner_start = array_start; + } + + // Find matching bracket for the inner array + size_t array_end = inner_start; + int bracket_count = 0; + bool in_array = false; + + for (size_t i = inner_start; i < response_data.size(); i++) { + if (response_data[i] == '[') { + bracket_count++; + in_array = true; + } else if (response_data[i] == ']') { + bracket_count--; + if (bracket_count == 0 && in_array) { + array_end = i; + break; + } + } + } + + // Parse the array of floats + std::string array_str = response_data.substr(inner_start + 1, array_end - inner_start - 1); + std::vector embedding; + std::stringstream ss(array_str); + std::string token; + + while (std::getline(ss, token, ',')) { + // Remove whitespace and "null" values + token.erase(0, token.find_first_not_of(" \t\n\r")); + token.erase(token.find_last_not_of(" \t\n\r") + 1); + + if (token == "null" || token.empty()) { + continue; + } + + try { + float val = std::stof(token); + embedding.push_back(val); + } catch (...) 
{ + // Skip invalid values + } + } + + if (!embedding.empty()) { + all_embeddings.push_back(std::move(embedding)); + } + + // Move past this result + pos = array_end + 1; + } + + // Convert to contiguous array + if (!all_embeddings.empty()) { + result.count = all_embeddings.size(); + result.embedding_size = all_embeddings[0].size(); + + // Allocate contiguous array + size_t total_floats = result.embedding_size * result.count; + result.data = new float[total_floats]; + + // Copy embeddings + for (size_t i = 0; i < all_embeddings.size(); i++) { + size_t offset = i * result.embedding_size; + const auto& emb = all_embeddings[i]; + std::copy(emb.begin(), emb.end(), result.data + offset); + } + } + } + + curl_slist_free_all(headers); + curl_easy_cleanup(curl); + + return result; + } + + /** + * @brief Call llama-server rerank API via libcurl + * + * @param query Query string to rerank against + * @param texts Vector of document texts to rerank + * @param top_n Maximum number of results to return + * @return RerankResultArray containing top N results with index and score + */ + RerankResultArray call_llama_rerank(const std::string& query, + const std::vector& texts, + uint32_t top_n) { + RerankResultArray result; + CURL* curl = curl_easy_init(); + + if (!curl) { + std::cerr << "[Worker] Failed to initialize curl\n"; + return result; + } + + // Build JSON request + std::stringstream json; + json << "{\"query\":\""; + + // Escape query JSON special characters + for (char c : query) { + switch (c) { + case '"': json << "\\\""; break; + case '\\': json << "\\\\"; break; + case '\n': json << "\\n"; break; + case '\r': json << "\\r"; break; + case '\t': json << "\\t"; break; + default: json << c; break; + } + } + + json << "\",\"documents\":["; + + // Add documents + for (size_t i = 0; i < texts.size(); i++) { + if (i > 0) json << ","; + json << "\""; + + // Escape document JSON special characters + for (char c : texts[i]) { + switch (c) { + case '"': json << "\\\""; break; + 
case '\\': json << "\\\\"; break; + case '\n': json << "\\n"; break; + case '\r': json << "\\r"; break; + case '\t': json << "\\t"; break; + default: json << c; break; + } + } + + json << "\""; + } + + json << "]}"; + + std::string json_str = json.str(); + + // Configure curl + curl_easy_setopt(curl, CURLOPT_URL, "http://127.0.0.1:8012/rerank"); + curl_easy_setopt(curl, CURLOPT_POST, 1L); + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, json_str.c_str()); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback); + + std::string response_data; + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response_data); + + // Add content-type header + struct curl_slist* headers = nullptr; + headers = curl_slist_append(headers, "Content-Type: application/json"); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + + // Perform request + CURLcode res = curl_easy_perform(curl); + + if (res != CURLE_OK) { + std::cerr << "[Worker] curl_easy_perform() failed: " + << curl_easy_strerror(res) << "\n"; + } else { + // Parse JSON response to extract rerank results + // Response format: {"results": [{"index": 0, "relevance_score": 0.95}, ...]} + size_t results_pos = response_data.find("\"results\":"); + if (results_pos != std::string::npos) { + // Find the array start + size_t array_start = response_data.find("[", results_pos); + if (array_start != std::string::npos) { + // Find matching bracket + size_t array_end = array_start; + int bracket_count = 0; + bool in_array = false; + + for (size_t i = array_start; i < response_data.size(); i++) { + if (response_data[i] == '[') { + bracket_count++; + in_array = true; + } else if (response_data[i] == ']') { + bracket_count--; + if (bracket_count == 0 && in_array) { + array_end = i; + break; + } + } + } + + // Parse each result object + std::string array_str = response_data.substr(array_start + 1, array_end - array_start - 1); + std::vector results; + + // Simple parsing - look for "index" and "relevance_score" patterns + size_t pos = 0; + 
while (pos < array_str.size()) { + size_t index_pos = array_str.find("\"index\":", pos); + if (index_pos == std::string::npos) break; + + // Skip to the number + size_t num_start = index_pos + 8; // Skip "\"index\":" + while (num_start < array_str.size() && + (array_str[num_start] == ' ' || array_str[num_start] == '\t')) { + num_start++; + } + + // Find the end of the number + size_t num_end = num_start; + while (num_end < array_str.size() && + (isdigit(array_str[num_end]) || array_str[num_end] == '-')) { + num_end++; + } + + uint32_t index = 0; + if (num_start < num_end) { + try { + index = std::stoul(array_str.substr(num_start, num_end - num_start)); + } catch (...) {} + } + + // Find relevance_score + size_t score_pos = array_str.find("\"relevance_score\":", index_pos); + if (score_pos == std::string::npos) break; + + // Skip to the number + size_t score_start = score_pos + 18; // Skip "\"relevance_score\":" + while (score_start < array_str.size() && + (array_str[score_start] == ' ' || array_str[score_start] == '\t')) { + score_start++; + } + + // Find the end of the number (including decimal point and negative sign) + size_t score_end = score_start; + while (score_end < array_str.size() && + (isdigit(array_str[score_end]) || + array_str[score_end] == '.' || + array_str[score_end] == '-' || + array_str[score_end] == 'e' || + array_str[score_end] == 'E')) { + score_end++; + } + + float score = 0.0f; + if (score_start < score_end) { + try { + score = std::stof(array_str.substr(score_start, score_end - score_start)); + } catch (...) 
{} + } + + results.push_back({index, score}); + pos = score_end + 1; + } + + // Limit to top_n results + if (!results.empty() && top_n > 0) { + size_t count = std::min(static_cast(top_n), results.size()); + result.count = count; + result.data = new RerankResult[count]; + std::copy(results.begin(), results.begin() + count, result.data); + } + } + } + } + + curl_slist_free_all(headers); + curl_easy_cleanup(curl); + + return result; + } + + /** + * @brief Worker loop - processes requests from queue + */ + void worker_loop(int worker_id) { + while (running_) { + Request req; + + { + std::unique_lock lock(queue_mutex_); + queue_cv_.wait(lock, [this] { + return !running_ || !request_queue_.empty(); + }); + + if (!running_) break; + + if (request_queue_.empty()) continue; + + req = std::move(request_queue_.front()); + request_queue_.pop(); + } + + auto start_time = std::chrono::steady_clock::now(); + + // Process based on operation type + if (req.operation == OP_EMBEDDING) { + if (!req.documents.empty()) { + // Prepare texts for batch embedding + std::vector texts; + texts.reserve(req.documents.size()); + size_t total_bytes = 0; + + for (const auto& doc : req.documents) { + texts.emplace_back(doc.text, doc.text_size); + total_bytes += doc.text_size; + } + + std::cout << "[Worker " << worker_id << "] Processing batch embedding for " + << req.documents.size() << " document(s) (" << total_bytes << " bytes)\n"; + + // Use batch embedding for all documents + BatchEmbeddingResult batch_embedding = call_llama_batch_embedding(texts); + + auto end_time = std::chrono::steady_clock::now(); + int processing_time_ms = std::chrono::duration_cast( + end_time - start_time).count(); + + // Prepare response + ResponseHeader resp; + resp.request_id = req.request_id; + resp.status_code = (batch_embedding.data != nullptr) ? 
0 : 1; + resp.embedding_size = batch_embedding.embedding_size; + resp.processing_time_ms = processing_time_ms; + resp.result_ptr = reinterpret_cast(batch_embedding.data); + resp.result_count = batch_embedding.count; + resp.data_size = 0; + + // Send response header + write(req.client_fd, &resp, sizeof(resp)); + + // The batch embedding data stays in shared memory (allocated by GenAI) + // Client will read it and then take ownership (client must free it) + batch_embedding.data = nullptr; // Transfer ownership to client + } else { + // No documents + auto end_time = std::chrono::steady_clock::now(); + int processing_time_ms = std::chrono::duration_cast( + end_time - start_time).count(); + + ResponseHeader resp; + resp.request_id = req.request_id; + resp.status_code = 1; // Error + resp.embedding_size = 0; + resp.processing_time_ms = processing_time_ms; + resp.result_ptr = 0; + resp.result_count = 0; + resp.data_size = 0; + + write(req.client_fd, &resp, sizeof(resp)); + } + } else if (req.operation == OP_RERANK) { + if (!req.documents.empty() && !req.query.empty()) { + // Prepare texts for reranking + std::vector texts; + texts.reserve(req.documents.size()); + size_t total_bytes = 0; + + for (const auto& doc : req.documents) { + texts.emplace_back(doc.text, doc.text_size); + total_bytes += doc.text_size; + } + + std::cout << "[Worker " << worker_id << "] Processing rerank for " + << req.documents.size() << " document(s), query=\"" + << req.query.substr(0, 50) + << (req.query.size() > 50 ? "..." : "") + << "\" (" << total_bytes << " bytes)\n"; + + // Call rerank API + RerankResultArray rerank_results = call_llama_rerank(req.query, texts, req.top_n); + + auto end_time = std::chrono::steady_clock::now(); + int processing_time_ms = std::chrono::duration_cast( + end_time - start_time).count(); + + // Prepare response + ResponseHeader resp; + resp.request_id = req.request_id; + resp.status_code = (rerank_results.data != nullptr) ? 
0 : 1; + resp.embedding_size = 0; // Not used for rerank + resp.processing_time_ms = processing_time_ms; + resp.result_ptr = reinterpret_cast(rerank_results.data); + resp.result_count = rerank_results.count; + resp.data_size = 0; + + // Send response header + write(req.client_fd, &resp, sizeof(resp)); + + // The rerank results stay in shared memory (allocated by GenAI) + // Client will read them and then take ownership (client must free it) + rerank_results.data = nullptr; // Transfer ownership to client + } else { + // No documents or query + auto end_time = std::chrono::steady_clock::now(); + int processing_time_ms = std::chrono::duration_cast( + end_time - start_time).count(); + + ResponseHeader resp; + resp.request_id = req.request_id; + resp.status_code = 1; // Error + resp.embedding_size = 0; + resp.processing_time_ms = processing_time_ms; + resp.result_ptr = 0; + resp.result_count = 0; + resp.data_size = 0; + + write(req.client_fd, &resp, sizeof(resp)); + } + } else { + // Unknown operation + auto end_time = std::chrono::steady_clock::now(); + int processing_time_ms = std::chrono::duration_cast( + end_time - start_time).count(); + + ResponseHeader resp; + resp.request_id = req.request_id; + resp.status_code = 1; // Error + resp.embedding_size = 0; + resp.processing_time_ms = processing_time_ms; + resp.result_ptr = 0; + resp.result_count = 0; + resp.data_size = 0; + + write(req.client_fd, &resp, sizeof(resp)); + } + } + } + + int num_workers_; + std::atomic running_; + int epoll_fd_; + int event_fd_; + std::thread listener_thread_; + std::vector worker_threads_; + std::queue request_queue_; + mutable std::mutex queue_mutex_; + std::condition_variable queue_cv_; + std::unordered_set client_fds_; + mutable std::mutex clients_mutex_; +}; + +// ============================================================================ +// Configuration +// ============================================================================ + +/** + * @struct Config + * @brief 
/**
 * @struct Config
 * @brief Configuration for the GenAI event-driven demo.
 */
struct Config {
    int genai_workers = 8;                  ///< GenAI worker threads
    int max_clients = 20;                   ///< Max simultaneous demo clients
    int run_duration_seconds = 60;          ///< Total demo runtime
    double client_add_probability = 0.15;   ///< Per-tick chance to add a client
    double request_send_probability = 0.20; ///< Per-tick chance a client sends
    int min_documents_per_request = 1;      ///< Lower bound on batch size
    int max_documents_per_request = 10;     ///< Upper bound on batch size
    int stats_print_interval_ms = 2000;     ///< Stats reporting period
};

// ============================================================================
// Sample Documents
// ============================================================================

/**
 * @brief Sample documents for testing embeddings.
 */
const std::vector<std::string> SAMPLE_DOCUMENTS = {
    "The quick brown fox jumps over the lazy dog. This is a classic sentence that contains all letters of the alphabet.",
    "Machine learning is a subset of artificial intelligence that enables systems to learn from data.",
    "Embeddings convert text into numerical vectors that capture semantic meaning.",
    "Natural language processing has revolutionized how computers understand human language.",
    "Vector databases store embeddings for efficient similarity search and retrieval.",
    "Transformers have become the dominant architecture for modern natural language processing tasks.",
    "Large language models demonstrate remarkable capabilities in text generation and comprehension.",
    "Semantic search uses embeddings to find content based on meaning rather than keyword matching.",
    "Neural networks learn complex patterns through interconnected layers of artificial neurons.",
    "Convolutional neural networks excel at image recognition and computer vision tasks.",
    "Recurrent neural networks can process sequential data like text and time series.",
    "Attention mechanisms allow models to focus on relevant parts of the input.",
    "Transfer learning enables models trained on one task to be applied to related tasks.",
    "Gradient descent is the fundamental optimization algorithm for training neural networks.",
    "Backpropagation efficiently computes gradients by propagating errors backward through the network.",
    "Regularization techniques like dropout prevent overfitting in deep learning models.",
    "Batch normalization stabilizes training by normalizing layer inputs.",
    "Learning rate schedules adjust the step size during optimization for better convergence.",
    "Tokenization breaks text into smaller units for processing by language models.",
    "Word embeddings like Word2Vec capture semantic relationships between words.",
    "Contextual embeddings like BERT generate representations based on surrounding context.",
    "Sequence-to-sequence models are used for translation and text summarization.",
    "Beam search improves output quality in text generation by considering multiple candidates.",
    "Temperature controls randomness in probabilistic sampling for language model outputs.",
    "Fine-tuning adapts pre-trained models to specific tasks with limited data."
};

/**
 * @brief Sample queries for testing reranking.
 */
const std::vector<std::string> SAMPLE_QUERIES = {
    "What is machine learning?",
    "How do neural networks work?",
    "Explain embeddings and vectors",
    "What is transformers architecture?",
    "How does attention mechanism work?",
    "What is backpropagation?",
    "Explain natural language processing"
};

/**
 * @class Client
 * @brief Client that sends embedding and rerank requests to GenAI module.
 *
 * The client allocates documents and passes pointers to GenAI (shared memory).
 * Client waits for response before sending next request (ensures memory validity).
 */
+ */ +class Client { +public: + enum State { + NEW, + CONNECTED, + IDLE, + WAITING_FOR_RESPONSE, + DONE + }; + + Client(int id, const Config& config) + : id_(id), + config_(config), + state_(NEW), + read_fd_(-1), + genai_fd_(-1), + next_request_id_(1), + requests_sent_(0), + total_requests_(0), + responses_received_(0), + owned_embedding_(nullptr), + owned_rerank_results_(nullptr) { + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dist( + config_.min_documents_per_request, + config_.max_documents_per_request + ); + total_requests_ = dist(gen); + } + + ~Client() { + close(); + // Clean up any owned embedding + if (owned_embedding_) { + delete[] owned_embedding_; + } + // Clean up any owned rerank results + if (owned_rerank_results_) { + delete[] owned_rerank_results_; + } + } + + void connect(GenAIModule& genai) { + int fds[2]; + if (socketpair(AF_UNIX, SOCK_STREAM, 0, fds) < 0) { + perror("socketpair"); + return; + } + + // Client uses fds[0] for both reading and writing + // GenAI uses fds[1] for both reading and writing + read_fd_ = fds[0]; + genai_fd_ = fds[1]; // Only used for registration + + int flags = fcntl(read_fd_, F_GETFL, 0); + fcntl(read_fd_, F_SETFL, flags | O_NONBLOCK); + + genai.register_client(genai_fd_); // GenAI gets the other end + + state_ = IDLE; + + std::cout << "[" << id_ << "] Connected (will send " + << total_requests_ << " requests)\n"; + } + + bool can_send_request() const { + return state_ == IDLE; + } + + void send_request() { + if (state_ != IDLE) return; + + std::random_device rd; + std::mt19937 gen(rd()); + + // Randomly choose between embedding and rerank (30% chance of rerank) + std::uniform_real_distribution<> op_dist(0.0, 1.0); + bool use_rerank = op_dist(gen) < 0.3; + + // Allocate documents for this request (owned by client until response) + current_documents_.clear(); + + std::uniform_int_distribution<> doc_dist(0, SAMPLE_DOCUMENTS.size() - 1); + std::uniform_int_distribution<> 
count_dist( + config_.min_documents_per_request, + config_.max_documents_per_request + ); + + int num_docs = count_dist(gen); + for (int i = 0; i < num_docs; i++) { + const std::string& sample_text = SAMPLE_DOCUMENTS[doc_dist(gen)]; + current_documents_.push_back(Document(sample_text.c_str(), sample_text.size())); + } + + uint64_t request_id = next_request_id_++; + + if (use_rerank && !SAMPLE_QUERIES.empty()) { + // Send rerank request + std::uniform_int_distribution<> query_dist(0, SAMPLE_QUERIES.size() - 1); + const std::string& query = SAMPLE_QUERIES[query_dist(gen)]; + uint32_t top_n = 3 + (gen() % 3); // 3-5 results + + RequestHeader req; + req.request_id = request_id; + req.operation = OP_RERANK; + req.document_count = current_documents_.size(); + req.flags = 0; + req.top_n = top_n; + + // Send request header + write(read_fd_, &req, sizeof(req)); + + // Send query as null-terminated string + write(read_fd_, query.c_str(), query.size() + 1); // +1 for null terminator + + // Send document pointers (as uint64_t) + std::vector doc_ptrs; + doc_ptrs.reserve(current_documents_.size()); + for (const auto& doc : current_documents_) { + doc_ptrs.push_back(reinterpret_cast(&doc)); + } + write(read_fd_, doc_ptrs.data(), doc_ptrs.size() * sizeof(uint64_t)); + + pending_requests_[request_id] = std::chrono::steady_clock::now(); + requests_sent_++; + state_ = WAITING_FOR_RESPONSE; + + std::cout << "[" << id_ << "] Sent RERANK request " << request_id + << " with " << current_documents_.size() << " document(s), top_n=" << top_n + << " (" << requests_sent_ << "/" << total_requests_ << ")\n"; + } else { + // Send embedding request + RequestHeader req; + req.request_id = request_id; + req.operation = OP_EMBEDDING; + req.document_count = current_documents_.size(); + req.flags = 0; + req.top_n = 0; // Not used for embedding + + // Send request header + write(read_fd_, &req, sizeof(req)); + + // Send document pointers (as uint64_t) + std::vector doc_ptrs; + 
doc_ptrs.reserve(current_documents_.size()); + for (const auto& doc : current_documents_) { + doc_ptrs.push_back(reinterpret_cast(&doc)); + } + write(read_fd_, doc_ptrs.data(), doc_ptrs.size() * sizeof(uint64_t)); + + pending_requests_[request_id] = std::chrono::steady_clock::now(); + requests_sent_++; + state_ = WAITING_FOR_RESPONSE; + + std::cout << "[" << id_ << "] Sent EMBEDDING request " << request_id + << " with " << current_documents_.size() << " document(s) (" + << requests_sent_ << "/" << total_requests_ << ")\n"; + } + } + + void send_rerank_request(const std::string& query, const std::vector& documents, uint32_t top_n = 5) { + if (state_ != IDLE) return; + + // Store documents for this request (owned by client until response) + current_documents_ = documents; + + uint64_t request_id = next_request_id_++; + + RequestHeader req; + req.request_id = request_id; + req.operation = OP_RERANK; + req.document_count = current_documents_.size(); + req.flags = 0; + req.top_n = top_n; + + // Send request header + write(read_fd_, &req, sizeof(req)); + + // Send query as null-terminated string + write(read_fd_, query.c_str(), query.size() + 1); // +1 for null terminator + + // Send document pointers (as uint64_t) + std::vector doc_ptrs; + doc_ptrs.reserve(current_documents_.size()); + for (const auto& doc : current_documents_) { + doc_ptrs.push_back(reinterpret_cast(&doc)); + } + write(read_fd_, doc_ptrs.data(), doc_ptrs.size() * sizeof(uint64_t)); + + pending_requests_[request_id] = std::chrono::steady_clock::now(); + requests_sent_++; + state_ = WAITING_FOR_RESPONSE; + + std::cout << "[" << id_ << "] Sent rerank request " << request_id + << " with " << current_documents_.size() << " document(s), top_n=" << top_n + << " (" << requests_sent_ << "/" << total_requests_ << ")\n"; + } + + bool has_response() { + if (state_ != WAITING_FOR_RESPONSE) { + return false; + } + + ResponseHeader resp; + ssize_t n = read(read_fd_, &resp, sizeof(resp)); + + if (n <= 0) { + return 
false;
+        }
+
+        auto it = pending_requests_.find(resp.request_id);
+        if (it != pending_requests_.end()) {
+            auto start_time = it->second;
+            auto end_time = std::chrono::steady_clock::now();
+            auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(
+                end_time - start_time).count();
+
+            if (resp.status_code == 0) {
+                if (resp.embedding_size > 0) {
+                    // Batch embedding response
+                    float* batch_embedding_ptr = reinterpret_cast<float*>(resp.result_ptr);
+
+                    std::cout << "[" << id_ << "] Received embedding response " << resp.request_id
+                              << " (rtt=" << duration << "ms, proc=" << resp.processing_time_ms
+                              << "ms, embeddings=" << resp.result_count
+                              << " x " << resp.embedding_size << " floats = "
+                              << (resp.result_count * resp.embedding_size) << " total floats)\n";
+
+                    // Take ownership of the batch embedding
+                    if (owned_embedding_) {
+                        delete[] owned_embedding_;
+                    }
+                    owned_embedding_ = batch_embedding_ptr;
+                } else if (resp.result_count > 0) {
+                    // Rerank response
+                    RerankResult* rerank_ptr = reinterpret_cast<RerankResult*>(resp.result_ptr);
+
+                    std::cout << "[" << id_ << "] Received rerank response " << resp.request_id
+                              << " (rtt=" << duration << "ms, proc=" << resp.processing_time_ms
+                              << "ms, results=" << resp.result_count << ")\n";
+
+                    // Print top results
+                    for (uint32_t i = 0; i < std::min(resp.result_count, 5u); i++) {
+                        std::cout << "  [" << i << "] index=" << rerank_ptr[i].index
+                                  << ", score=" << rerank_ptr[i].score << "\n";
+                    }
+
+                    // Take ownership of the rerank results
+                    if (owned_rerank_results_) {
+                        delete[] owned_rerank_results_;
+                    }
+                    owned_rerank_results_ = rerank_ptr;
+                }
+            } else {
+                std::cout << "[" << id_ << "] Received response " << resp.request_id
+                          << " (rtt=" << duration << "ms, status=ERROR)\n";
+            }
+
+            pending_requests_.erase(it);
+        }
+
+        responses_received_++;
+
+        // Clean up current documents (safe now that response is received)
+        current_documents_.clear();
+
+        // Check if we should send more requests or are done
+        if (requests_sent_ >= total_requests_) {
+            state_ = DONE;
+        }
else {
+            state_ = IDLE;
+        }
+
+        return true;
+    }
+
+    bool is_done() const {
+        return state_ == DONE;
+    }
+
+    int get_read_fd() const {
+        return read_fd_;
+    }
+
+    int get_id() const {
+        return id_;
+    }
+
+    void close() {
+        if (read_fd_ >= 0) ::close(read_fd_);
+        if (genai_fd_ >= 0) ::close(genai_fd_);
+        read_fd_ = -1;
+        genai_fd_ = -1;
+    }
+
+    const char* get_state_string() const {
+        switch (state_) {
+            case NEW: return "NEW";
+            case CONNECTED: return "CONNECTED";
+            case IDLE: return "IDLE";
+            case WAITING_FOR_RESPONSE: return "WAITING";
+            case DONE: return "DONE";
+            default: return "UNKNOWN";
+        }
+    }
+
+private:
+    int id_;
+    Config config_;
+    State state_;
+
+    int read_fd_;
+    int genai_fd_;
+
+    uint64_t next_request_id_;
+    int requests_sent_;
+    int total_requests_;
+    int responses_received_;
+
+    std::vector<Document> current_documents_;    ///< Documents for current request
+    float* owned_embedding_;                     ///< Embedding received from GenAI (owned by client)
+    RerankResult* owned_rerank_results_;         ///< Rerank results from GenAI (owned by client)
+
+    std::unordered_map<uint64_t, std::chrono::steady_clock::time_point> pending_requests_;
+};
+
+// ============================================================================
+// Main
+// ============================================================================
+
+int main() {
+    std::cout << "=== GenAI Module Event-Driven POC ===\n";
+    std::cout << "Real embedding generation and reranking via llama-server\n\n";
+
+    Config config;
+    std::cout << "Configuration:\n";
+    std::cout << "  GenAI workers: " << config.genai_workers << "\n";
+    std::cout << "  Max clients: " << config.max_clients << "\n";
+    std::cout << "  Run duration: " << config.run_duration_seconds << "s\n";
+    std::cout << "  Client add probability: " << config.client_add_probability << "\n";
+    std::cout << "  Request send probability: " << config.request_send_probability << "\n";
+    std::cout << "  Documents per request: " << config.min_documents_per_request
+              << "-" << config.max_documents_per_request << "\n";
+    std::cout << " 
Sample documents: " << SAMPLE_DOCUMENTS.size() << "\n\n"; + + // Create and start GenAI module + GenAIModule genai(config.genai_workers); + genai.start(); + + // Create main epoll set for monitoring client responses + int main_epoll_fd = epoll_create1(EPOLL_CLOEXEC); + if (main_epoll_fd < 0) { + perror("epoll_create1"); + return 1; + } + + // Clients managed by main loop + std::vector clients; + int next_client_id = 1; + int total_clients_created = 0; + int total_clients_completed = 0; + + // Statistics + uint64_t total_requests_sent = 0; + uint64_t total_responses_received = 0; + auto last_stats_time = std::chrono::steady_clock::now(); + + // Random number generation + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution<> dis(0.0, 1.0); + + auto start_time = std::chrono::steady_clock::now(); + + std::cout << "=== Starting Event Loop ===\n\n"; + + bool running = true; + while (running) { + auto now = std::chrono::steady_clock::now(); + auto elapsed = std::chrono::duration_cast( + now - start_time).count(); + + // Check termination conditions + bool all_work_done = (total_clients_created >= config.max_clients) && + (clients.empty()) && + (total_clients_completed >= config.max_clients); + + if (all_work_done) { + std::cout << "\n=== All work completed, shutting down early ===\n"; + running = false; + break; + } + + if (elapsed >= config.run_duration_seconds) { + std::cout << "\n=== Time elapsed, shutting down ===\n"; + running = false; + break; + } + + // -------------------------------------------------------- + // 1. 
Randomly add new clients + // -------------------------------------------------------- + if (clients.size() < static_cast(config.max_clients) && + total_clients_created < config.max_clients && + dis(gen) < config.client_add_probability) { + + Client* client = new Client(next_client_id++, config); + client->connect(genai); + + // Add to main epoll for monitoring responses + struct epoll_event ev; + ev.events = EPOLLIN; + ev.data.ptr = client; + if (epoll_ctl(main_epoll_fd, EPOLL_CTL_ADD, client->get_read_fd(), &ev) < 0) { + perror("epoll_ctl client"); + delete client; + } else { + clients.push_back(client); + total_clients_created++; + } + } + + // -------------------------------------------------------- + // 2. Randomly send requests from idle clients + // -------------------------------------------------------- + for (auto* client : clients) { + if (client->can_send_request() && dis(gen) < config.request_send_probability) { + client->send_request(); + total_requests_sent++; + } + } + + // -------------------------------------------------------- + // 3. Wait for events (responses or timeout) + // -------------------------------------------------------- + const int MAX_EVENTS = 64; + struct epoll_event events[MAX_EVENTS]; + + int timeout_ms = 100; + int nfds = epoll_wait(main_epoll_fd, events, MAX_EVENTS, timeout_ms); + + // -------------------------------------------------------- + // 4. 
Process responses
+        // --------------------------------------------------------
+        for (int i = 0; i < nfds; i++) {
+            Client* client = static_cast<Client*>(events[i].data.ptr);
+
+            if (client->has_response()) {
+                total_responses_received++;
+
+                if (client->is_done()) {
+                    // Remove from epoll
+                    epoll_ctl(main_epoll_fd, EPOLL_CTL_DEL, client->get_read_fd(), nullptr);
+
+                    // Remove from clients vector
+                    clients.erase(
+                        std::remove(clients.begin(), clients.end(), client),
+                        clients.end()
+                    );
+
+                    std::cout << "[" << client->get_id() << "] Completed all requests, removing\n";
+
+                    client->close();
+                    delete client;
+                    total_clients_completed++;
+                }
+            }
+        }
+
+        // --------------------------------------------------------
+        // 5. Print statistics periodically
+        // --------------------------------------------------------
+        auto time_since_last_stats = std::chrono::duration_cast<std::chrono::milliseconds>(
+            now - last_stats_time).count();
+
+        if (time_since_last_stats >= config.stats_print_interval_ms) {
+            std::cout << "\n[STATS] T+" << elapsed << "s "
+                      << "| Active clients: " << clients.size()
+                      << " | Queue depth: " << genai.get_queue_size()
+                      << " | Requests sent: " << total_requests_sent
+                      << " | Responses: " << total_responses_received
+                      << " | Completed: " << total_clients_completed << "\n";
+
+            // Show state distribution
+            std::unordered_map<std::string, int> state_counts;
+            for (auto* client : clients) {
+                state_counts[client->get_state_string()]++;
+            }
+            std::cout << "  States: ";
+            for (auto& [state, count] : state_counts) {
+                std::cout << state << "=" << count << " ";
+            }
+            std::cout << "\n\n";
+
+            last_stats_time = now;
+        }
+    }
+
+    // ------------------------------------------------------------
+    // Final statistics
+    // ------------------------------------------------------------
+    std::cout << "\n=== Final Statistics ===\n";
+    std::cout << "Total clients created: " << total_clients_created << "\n";
+    std::cout << "Total clients completed: " << total_clients_completed << "\n";
+    std::cout << "Total requests sent: " 
<< total_requests_sent << "\n"; + std::cout << "Total responses received: " << total_responses_received << "\n"; + + // Clean up remaining clients + for (auto* client : clients) { + epoll_ctl(main_epoll_fd, EPOLL_CTL_DEL, client->get_read_fd(), nullptr); + client->close(); + delete client; + } + clients.clear(); + + close(main_epoll_fd); + + // Stop GenAI module + std::cout << "\nStopping GenAI module...\n"; + genai.stop(); + + std::cout << "\n=== Demonstration Complete ===\n"; + + return 0; +} diff --git a/include/Base_Session.h b/include/Base_Session.h index 3d13a15b33..8ff05fcaed 100644 --- a/include/Base_Session.h +++ b/include/Base_Session.h @@ -99,6 +99,21 @@ class Base_Session { //MySQL_STMTs_meta *sess_STMTs_meta; //StmtLongDataHandler *SLDH; + // GenAI async support +#ifdef epoll_create1 + struct GenAI_PendingRequest { + uint64_t request_id; + int client_fd; // MySQL side of socketpair + std::string json_query; + std::chrono::steady_clock::time_point start_time; + PtrSize_t *original_pkt; // Original packet to complete + }; + + std::unordered_map pending_genai_requests_; + uint64_t next_genai_request_id_; + int genai_epoll_fd_; // For monitoring GenAI response fds +#endif + void init(); diff --git a/include/GenAI_Thread.h b/include/GenAI_Thread.h new file mode 100644 index 0000000000..ea8a62b302 --- /dev/null +++ b/include/GenAI_Thread.h @@ -0,0 +1,397 @@ +#ifndef __CLASS_GENAI_THREAD_H +#define __CLASS_GENAI_THREAD_H + +#include "proxysql.h" +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef epoll_create1 +#include +#endif + +#include "curl/curl.h" + +#define GENAI_THREAD_VERSION "0.1.0" + +/** + * @brief GenAI operation types + */ +enum GenAI_Operation : uint32_t { + GENAI_OP_EMBEDDING = 0, ///< Generate embeddings for documents + GENAI_OP_RERANK = 1, ///< Rerank documents by relevance to query + GENAI_OP_JSON = 2, ///< Autonomous JSON query processing (handles embed/rerank/document_from_sql) +}; + +/** + * @brief 
Document structure for passing document data + */ +struct GenAI_Document { + const char* text; ///< Pointer to document text (owned by caller) + size_t text_size; ///< Length of text in bytes + + GenAI_Document() : text(nullptr), text_size(0) {} + GenAI_Document(const char* t, size_t s) : text(t), text_size(s) {} +}; + +/** + * @brief Embedding result structure + */ +struct GenAI_EmbeddingResult { + float* data; ///< Pointer to embedding vector + size_t embedding_size;///< Number of floats per embedding + size_t count; ///< Number of embeddings + + GenAI_EmbeddingResult() : data(nullptr), embedding_size(0), count(0) {} + ~GenAI_EmbeddingResult(); + + // Disable copy + GenAI_EmbeddingResult(const GenAI_EmbeddingResult&) = delete; + GenAI_EmbeddingResult& operator=(const GenAI_EmbeddingResult&) = delete; + + // Move semantics + GenAI_EmbeddingResult(GenAI_EmbeddingResult&& other) noexcept; + GenAI_EmbeddingResult& operator=(GenAI_EmbeddingResult&& other) noexcept; +}; + +/** + * @brief Rerank result structure + */ +struct GenAI_RerankResult { + uint32_t index; ///< Original document index + float score; ///< Relevance score +}; + +/** + * @brief Rerank result array structure + */ +struct GenAI_RerankResultArray { + GenAI_RerankResult* data; ///< Pointer to result array + size_t count; ///< Number of results + + GenAI_RerankResultArray() : data(nullptr), count(0) {} + ~GenAI_RerankResultArray(); + + // Disable copy + GenAI_RerankResultArray(const GenAI_RerankResultArray&) = delete; + GenAI_RerankResultArray& operator=(const GenAI_RerankResultArray&) = delete; + + // Move semantics + GenAI_RerankResultArray(GenAI_RerankResultArray&& other) noexcept; + GenAI_RerankResultArray& operator=(GenAI_RerankResultArray&& other) noexcept; +}; + +/** + * @brief Request structure for internal queue + */ +struct GenAI_Request { + int client_fd; ///< Client file descriptor + uint64_t request_id; ///< Request ID + uint32_t operation; ///< Operation type + std::string query; ///< Query 
for rerank (empty for embedding) + uint32_t top_n; ///< Top N results for rerank + std::vector documents; ///< Documents to process + std::string json_query; ///< Raw JSON query from client (for autonomous processing) +}; + +/** + * @brief Request header for socketpair communication between MySQL_Session and GenAI + * + * This structure is sent from MySQL_Session to the GenAI listener via socketpair + * when making async GenAI requests. It contains all the metadata needed to process + * the request without blocking the MySQL thread. + * + * Communication flow: + * 1. MySQL_Session creates socketpair() + * 2. MySQL_Session sends GenAI_RequestHeader + JSON query via its fd + * 3. GenAI listener reads from socketpair via epoll + * 4. GenAI worker processes request (blocking curl in worker thread) + * 5. GenAI worker sends GenAI_ResponseHeader + JSON result back via socketpair + * 6. MySQL_Session receives response via epoll notification + * + * @see GenAI_ResponseHeader + */ +struct GenAI_RequestHeader { + uint64_t request_id; ///< Client's correlation ID for matching requests/responses + uint32_t operation; ///< Operation type (GENAI_OP_EMBEDDING, GENAI_OP_RERANK, GENAI_OP_JSON) + uint32_t query_len; ///< Length of JSON query that follows this header (0 if no query) + uint32_t flags; ///< Reserved for future use (must be 0) + uint32_t top_n; ///< For rerank operations: maximum number of results to return (0 = all) +}; + +/** + * @brief Response header for socketpair communication from GenAI to MySQL_Session + * + * This structure is sent from the GenAI worker back to MySQL_Session via socketpair + * after processing completes. It contains status information and metadata about + * the results, followed by the JSON result payload. 
+ * + * Response format: + * - GenAI_ResponseHeader (this structure) + * - JSON result data (result_len bytes if result_len > 0) + * + * @see GenAI_RequestHeader + */ +struct GenAI_ResponseHeader { + uint64_t request_id; ///< Echo of client's request ID for request/response matching + uint32_t status_code; ///< Status code: 0=success, >0=error occurred + uint32_t result_len; ///< Length of JSON result payload that follows this header + uint32_t processing_time_ms;///< Time taken by GenAI worker to process the request (milliseconds) + uint64_t result_ptr; ///< Reserved for future shared memory optimizations (must be 0) + uint32_t result_count; ///< Number of results in the response (e.g., number of embeddings/reranks) + uint32_t reserved; ///< Reserved for future use (must be 0) +}; + +/** + * @brief GenAI Threads Handler class for managing GenAI module + * + * This class handles the GenAI module's configuration variables, lifecycle, + * and provides embedding and reranking functionality via external services. 
+ */ +class GenAI_Threads_Handler +{ +private: + int shutdown_; + pthread_rwlock_t rwlock; + + // Threading components + std::vector worker_threads_; + std::thread listener_thread_; + std::queue request_queue_; + std::mutex queue_mutex_; + std::condition_variable queue_cv_; + std::unordered_set client_fds_; + std::mutex clients_mutex_; + + // epoll for async I/O + int epoll_fd_; + int event_fd_; + + // Worker methods + void worker_loop(int worker_id); + void listener_loop(); + + // HTTP client methods + GenAI_EmbeddingResult call_llama_embedding(const std::string& text); + GenAI_EmbeddingResult call_llama_batch_embedding(const std::vector& texts); + GenAI_RerankResultArray call_llama_rerank(const std::string& query, + const std::vector& texts, + uint32_t top_n); + static size_t WriteCallback(void* contents, size_t size, size_t nmemb, void* userp); + +public: + /** + * @brief Structure holding GenAI module configuration variables + */ + struct { + // Thread configuration + int genai_threads; ///< Number of worker threads (default: 4) + + // Service endpoints + char* genai_embedding_uri; ///< URI for embedding service (default: http://127.0.0.1:8013/embedding) + char* genai_rerank_uri; ///< URI for reranking service (default: http://127.0.0.1:8012/rerank) + + // Timeouts (in milliseconds) + int genai_embedding_timeout_ms; ///< Timeout for embedding requests (default: 30000) + int genai_rerank_timeout_ms; ///< Timeout for reranking requests (default: 30000) + } variables; + + struct { + int threads_initialized = 0; + int active_requests = 0; + int completed_requests = 0; + int failed_requests = 0; + } status_variables; + + unsigned int num_threads; + + /** + * @brief Default constructor for GenAI_Threads_Handler + */ + GenAI_Threads_Handler(); + + /** + * @brief Destructor for GenAI_Threads_Handler + */ + ~GenAI_Threads_Handler(); + + /** + * @brief Initialize the GenAI module + * + * Starts worker threads and listener for processing requests. 
+ * + * @param num Number of threads (uses genai_threads variable if 0) + * @param stack Stack size for threads (unused, reserved) + */ + void init(unsigned int num = 0, size_t stack = 0); + + /** + * @brief Shutdown the GenAI module + * + * Stops all threads and cleans up resources. + */ + void shutdown(); + + /** + * @brief Acquire write lock on variables + */ + void wrlock(); + + /** + * @brief Release write lock on variables + */ + void wrunlock(); + + /** + * @brief Get the value of a variable as a string + * + * @param name The name of the variable (without 'genai-' prefix) + * @return Dynamically allocated string with the value, or NULL if not found + */ + char* get_variable(char* name); + + /** + * @brief Set the value of a variable + * + * @param name The name of the variable (without 'genai-' prefix) + * @param value The new value to set + * @return true if successful, false if variable not found or value invalid + */ + bool set_variable(char* name, const char* value); + + /** + * @brief Get a list of all variable names + * + * @return Dynamically allocated array of strings, terminated by NULL + */ + char** get_variables_list(); + + /** + * @brief Print the version information + */ + void print_version(); + + /** + * @brief Register a client file descriptor with GenAI module for async communication + * + * Registers the GenAI side of a socketpair with the GenAI epoll instance. + * This allows the GenAI listener to receive requests from MySQL sessions asynchronously. + * + * Usage flow: + * 1. MySQL_Session creates socketpair(fds) + * 2. MySQL_Session keeps fds[0] for reading responses + * 3. MySQL_Session calls register_client(fds[1]) to register GenAI side + * 4. GenAI listener adds fds[1] to its epoll for reading requests + * 5. 
When request is received, it's queued to worker threads + * + * @param client_fd The GenAI side file descriptor from socketpair (typically fds[1]) + * @return true if successfully registered and added to epoll, false on error + * + * @see unregister_client() + */ + bool register_client(int client_fd); + + /** + * @brief Unregister a client file descriptor from GenAI module + * + * Removes a previously registered client fd from the GenAI epoll instance + * and closes the connection. Called when a MySQL session ends or an error occurs. + * + * @param client_fd The GenAI side file descriptor to remove + * + * @see register_client() + */ + void unregister_client(int client_fd); + + /** + * @brief Get current queue depth (number of pending requests) + * + * @return Number of requests in the queue + */ + size_t get_queue_size(); + + // Public API methods for embedding and reranking + // These methods can be called directly without going through socket pairs + + /** + * @brief Generate embeddings for multiple documents + * + * Sends the documents to the embedding service (configured via genai_embedding_uri) + * and returns the resulting embedding vectors. This method blocks until the + * embedding service responds (typically 10-100ms per document depending on model size). + * + * For async non-blocking behavior, use the socketpair-based async API via + * MySQL_Session's GENAI: query handler instead. + * + * @param documents Vector of document texts to embed (each can be up to several KB) + * @return GenAI_EmbeddingResult containing all embeddings with metadata. + * The caller takes ownership of the returned data and must free it. + * On error, returns an empty result (data==nullptr || count==0). + * + * @note This is a BLOCKING call. For async operation, use GENAI: queries through MySQL_Session. 
+ * @see rerank_documents(), process_json_query() + */ + GenAI_EmbeddingResult embed_documents(const std::vector& documents); + + /** + * @brief Rerank documents based on query relevance + * + * Sends the query and documents to the reranking service (configured via genai_rerank_uri) + * and returns the documents sorted by relevance to the query. This method blocks + * until the reranking service responds (typically 20-50ms for most models). + * + * For async non-blocking behavior, use the socketpair-based async API via + * MySQL_Session's GENAI: query handler instead. + * + * @param query Query string to rerank against (e.g., search query, user question) + * @param documents Vector of document texts to rerank (typically search results or candidates) + * @param top_n Maximum number of top results to return (0 = return all sorted results) + * @return GenAI_RerankResultArray containing results sorted by relevance. + * Each result includes the original document index and a relevance score. + * The caller takes ownership of the returned data and must free it. + * On error, returns an empty result (data==nullptr || count==0). + * + * @note This is a BLOCKING call. For async operation, use GENAI: queries through MySQL_Session. + * @see embed_documents(), process_json_query() + */ + GenAI_RerankResultArray rerank_documents(const std::string& query, + const std::vector& documents, + uint32_t top_n = 0); + + /** + * @brief Process JSON query autonomously (handles embed/rerank/document_from_sql) + * + * This method processes JSON queries that describe embedding or reranking operations. + * It autonomously parses the JSON, determines the operation type, and routes to the + * appropriate handler. This is the main entry point for the async GENAI: query syntax. 
+ * + * Supported query formats: + * - {"type": "embed", "documents": ["doc1", "doc2", ...]} + * - {"type": "rerank", "query": "...", "documents": [...], "top_n": 5} + * - {"type": "rerank", "query": "...", "document_from_sql": {"query": "SELECT ..."}} + * + * The response format is a JSON object with "columns" and "rows" arrays: + * - {"columns": ["col1", "col2"], "rows": [["val1", "val2"], ...]} + * - Error responses: {"error": "error message"} + * + * @param json_query JSON query string from client (must be valid JSON) + * @return JSON string result with columns and rows formatted for MySQL resultset. + * Returns empty string on error. + * + * @note This method is called from worker threads as part of async request processing. + * The blocking HTTP calls occur in the worker thread, not the MySQL thread. + * + * @see embed_documents(), rerank_documents() + */ + std::string process_json_query(const std::string& json_query); +}; + +// Global instance of the GenAI Threads Handler +extern GenAI_Threads_Handler *GloGATH; + +#endif // __CLASS_GENAI_THREAD_H diff --git a/include/MySQL_Session.h b/include/MySQL_Session.h index 9d4d6fe395..b44eea8a5a 100644 --- a/include/MySQL_Session.h +++ b/include/MySQL_Session.h @@ -283,6 +283,73 @@ class MySQL_Session: public Base_Session&& pgsql_users_resultset = nullptr, const std::string& checksum = "", const time_t epoch = 0); void flush_pgsql_users__from_memory_to_disk(); void flush_pgsql_users__from_disk_to_memory(); diff --git a/include/proxysql_structs.h b/include/proxysql_structs.h index 36bc727cb5..141db59383 100644 --- a/include/proxysql_structs.h +++ b/include/proxysql_structs.h @@ -159,6 +159,7 @@ enum debug_module { PROXY_DEBUG_RESTAPI, PROXY_DEBUG_MONITOR, PROXY_DEBUG_CLUSTER, + PROXY_DEBUG_GENAI, PROXY_DEBUG_UNKNOWN // this module doesn't exist. 
It is used only to define the last possible module }; @@ -178,6 +179,7 @@ enum MySQL_DS_type { // MYDS_BACKEND_PAUSE_CONNECT, // MYDS_BACKEND_FAILED_CONNECT, MYDS_FRONTEND, + MYDS_INTERNAL_GENAI, // Internal GenAI module connection }; using PgSQL_DS_type = MySQL_DS_type; diff --git a/lib/Admin_FlushVariables.cpp b/lib/Admin_FlushVariables.cpp index 79019cb81e..e9d6c343ae 100644 --- a/lib/Admin_FlushVariables.cpp +++ b/lib/Admin_FlushVariables.cpp @@ -42,6 +42,7 @@ using json = nlohmann::json; #include "ProxySQL_Statistics.hpp" #include "MySQL_Logger.hpp" #include "PgSQL_Logger.hpp" +#include "GenAI_Thread.h" #include "SQLite3_Server.h" #include "Web_Interface.hpp" @@ -138,6 +139,7 @@ extern PgSQL_Logger* GloPgSQL_Logger; extern MySQL_STMT_Manager_v14 *GloMyStmt; extern MySQL_Monitor *GloMyMon; extern PgSQL_Threads_Handler* GloPTH; +extern GenAI_Threads_Handler* GloGATH; extern void (*flush_logs_function)(); @@ -953,6 +955,124 @@ void ProxySQL_Admin::flush_pgsql_variables___database_to_runtime(SQLite3DB* db, if (resultset) delete resultset; } +// GenAI Variables Flush Functions +void ProxySQL_Admin::flush_genai_variables___runtime_to_database(SQLite3DB* db, bool replace, bool del, bool onlyifempty, bool runtime, bool use_lock) { + proxy_debug(PROXY_DEBUG_ADMIN, 4, "Flushing GenAI variables. 
Replace:%d, Delete:%d, Only_If_Empty:%d\n", replace, del, onlyifempty); + if (onlyifempty) { + char* error = NULL; + int cols = 0; + int affected_rows = 0; + SQLite3_result* resultset = NULL; + char* q = (char*)"SELECT COUNT(*) FROM global_variables WHERE variable_name LIKE 'genai-%'"; + db->execute_statement(q, &error, &cols, &affected_rows, &resultset); + int matching_rows = 0; + if (error) { + proxy_error("Error on %s : %s\n", q, error); + return; + } + else { + for (std::vector::iterator it = resultset->rows.begin(); it != resultset->rows.end(); ++it) { + SQLite3_row* r = *it; + matching_rows += atoi(r->fields[0]); + } + } + if (resultset) delete resultset; + if (matching_rows) { + proxy_debug(PROXY_DEBUG_ADMIN, 4, "Table global_variables has GenAI variables - skipping\n"); + return; + } + } + if (del) { + proxy_debug(PROXY_DEBUG_ADMIN, 4, "Deleting GenAI variables from global_variables\n"); + db->execute("DELETE FROM global_variables WHERE variable_name LIKE 'genai-%'"); + } + static char* a; + static char* b; + if (replace) { + a = (char*)"REPLACE INTO global_variables(variable_name, variable_value) VALUES(?1, ?2)"; + } + else { + a = (char*)"INSERT OR IGNORE INTO global_variables(variable_name, variable_value) VALUES(?1, ?2)"; + } + int rc; + sqlite3_stmt* statement1 = NULL; + sqlite3_stmt* statement2 = NULL; + rc = db->prepare_v2(a, &statement1); + ASSERT_SQLITE_OK(rc, db); + if (runtime) { + db->execute("DELETE FROM runtime_global_variables WHERE variable_name LIKE 'genai-%'"); + b = (char*)"INSERT INTO runtime_global_variables(variable_name, variable_value) VALUES(?1, ?2)"; + rc = db->prepare_v2(b, &statement2); + ASSERT_SQLITE_OK(rc, db); + } + if (use_lock) { + GloGATH->wrlock(); + db->execute("BEGIN"); + } + char** varnames = GloGATH->get_variables_list(); + for (int i = 0; varnames[i]; i++) { + char* val = GloGATH->get_variable(varnames[i]); + char* qualified_name = (char*)malloc(strlen(varnames[i]) + 10); + sprintf(qualified_name, "genai-%s", 
varnames[i]); + rc = (*proxy_sqlite3_bind_text)(statement1, 1, qualified_name, -1, SQLITE_TRANSIENT); ASSERT_SQLITE_OK(rc, db); + rc = (*proxy_sqlite3_bind_text)(statement1, 2, (val ? val : (char*)""), -1, SQLITE_TRANSIENT); ASSERT_SQLITE_OK(rc, db); + SAFE_SQLITE3_STEP2(statement1); + rc = (*proxy_sqlite3_clear_bindings)(statement1); ASSERT_SQLITE_OK(rc, db); + rc = (*proxy_sqlite3_reset)(statement1); ASSERT_SQLITE_OK(rc, db); + if (runtime) { + rc = (*proxy_sqlite3_bind_text)(statement2, 1, qualified_name, -1, SQLITE_TRANSIENT); ASSERT_SQLITE_OK(rc, db); + rc = (*proxy_sqlite3_bind_text)(statement2, 2, (val ? val : (char*)""), -1, SQLITE_TRANSIENT); ASSERT_SQLITE_OK(rc, db); + SAFE_SQLITE3_STEP2(statement2); + rc = (*proxy_sqlite3_clear_bindings)(statement2); ASSERT_SQLITE_OK(rc, db); + rc = (*proxy_sqlite3_reset)(statement2); ASSERT_SQLITE_OK(rc, db); + } + if (val) + free(val); + free(qualified_name); + } + if (use_lock) { + db->execute("COMMIT"); + GloGATH->wrunlock(); + } + (*proxy_sqlite3_finalize)(statement1); + if (runtime) + (*proxy_sqlite3_finalize)(statement2); + for (int i = 0; varnames[i]; i++) { + free(varnames[i]); + } + free(varnames); +} + +void ProxySQL_Admin::flush_genai_variables___database_to_runtime(SQLite3DB* db, bool replace, const std::string& checksum, const time_t epoch) { + proxy_debug(PROXY_DEBUG_ADMIN, 4, "Flushing GenAI variables. 
Replace:%d\n", replace); + char* error = NULL; + int cols = 0; + int affected_rows = 0; + SQLite3_result* resultset = NULL; + char* q = (char*)"SELECT substr(variable_name,7) vn, variable_value FROM global_variables WHERE variable_name LIKE 'genai-%'"; + admindb->execute_statement(q, &error, &cols, &affected_rows, &resultset); + if (error) { + proxy_error("Error on %s : %s\n", q, error); + return; + } + else { + GloGATH->wrlock(); + for (std::vector::iterator it = resultset->rows.begin(); it != resultset->rows.end(); ++it) { + SQLite3_row* r = *it; + const char* value = r->fields[1]; + bool rc = GloGATH->set_variable(r->fields[0], value); + if (rc == false) { + proxy_debug(PROXY_DEBUG_ADMIN, 4, "Impossible to set variable %s with value \"%s\"\n", r->fields[0], value); + } + else { + proxy_debug(PROXY_DEBUG_ADMIN, 4, "Set variable %s with value \"%s\"\n", r->fields[0], value); + } + } + GloGATH->wrunlock(); + } + if (resultset) delete resultset; +} + void ProxySQL_Admin::flush_mysql_variables___runtime_to_database(SQLite3DB *db, bool replace, bool del, bool onlyifempty, bool runtime, bool use_lock) { proxy_debug(PROXY_DEBUG_ADMIN, 4, "Flushing MySQL variables. 
Replace:%d, Delete:%d, Only_If_Empty:%d\n", replace, del, onlyifempty); if (onlyifempty) { diff --git a/lib/Admin_Handler.cpp b/lib/Admin_Handler.cpp index 288ca2a85c..0611d31918 100644 --- a/lib/Admin_Handler.cpp +++ b/lib/Admin_Handler.cpp @@ -42,6 +42,7 @@ using json = nlohmann::json; #include "ProxySQL_Statistics.hpp" #include "MySQL_Logger.hpp" #include "PgSQL_Logger.hpp" +#include "GenAI_Thread.h" #include "SQLite3_Server.h" #include "Web_Interface.hpp" @@ -151,6 +152,7 @@ extern PgSQL_Logger* GloPgSQL_Logger; extern MySQL_STMT_Manager_v14 *GloMyStmt; extern MySQL_Monitor *GloMyMon; extern PgSQL_Threads_Handler* GloPTH; +extern GenAI_Threads_Handler* GloGATH; extern void (*flush_logs_function)(); @@ -269,6 +271,19 @@ const std::vector SAVE_PGSQL_VARIABLES_TO_MEMORY = { "SAVE PGSQL VARIABLES TO MEM" , "SAVE PGSQL VARIABLES FROM RUNTIME" , "SAVE PGSQL VARIABLES FROM RUN" }; + +// GenAI +const std::vector LOAD_GENAI_VARIABLES_FROM_MEMORY = { + "LOAD GENAI VARIABLES FROM MEMORY" , + "LOAD GENAI VARIABLES FROM MEM" , + "LOAD GENAI VARIABLES TO RUNTIME" , + "LOAD GENAI VARIABLES TO RUN" }; + +const std::vector SAVE_GENAI_VARIABLES_TO_MEMORY = { + "SAVE GENAI VARIABLES TO MEMORY" , + "SAVE GENAI VARIABLES TO MEM" , + "SAVE GENAI VARIABLES FROM RUNTIME" , + "SAVE GENAI VARIABLES FROM RUN" }; // const std::vector LOAD_COREDUMP_FROM_MEMORY = { "LOAD COREDUMP FROM MEMORY" , @@ -1637,10 +1652,12 @@ bool admin_handler_command_load_or_save(char *query_no_space, unsigned int query } if ((query_no_space_length > 21) && ((!strncasecmp("SAVE MYSQL VARIABLES ", query_no_space, 21)) || (!strncasecmp("LOAD MYSQL VARIABLES ", query_no_space, 21)) || - (!strncasecmp("SAVE PGSQL VARIABLES ", query_no_space, 21)) || (!strncasecmp("LOAD PGSQL VARIABLES ", query_no_space, 21)))) { + (!strncasecmp("SAVE PGSQL VARIABLES ", query_no_space, 21)) || (!strncasecmp("LOAD PGSQL VARIABLES ", query_no_space, 21)) || + (!strncasecmp("SAVE GENAI VARIABLES ", query_no_space, 21)) || 
(!strncasecmp("LOAD GENAI VARIABLES ", query_no_space, 21)))) { const bool is_pgsql = (query_no_space[5] == 'P' || query_no_space[5] == 'p') ? true : false; - const std::string modname = is_pgsql ? "pgsql_variables" : "mysql_variables"; + const bool is_genai = (query_no_space[5] == 'G' || query_no_space[5] == 'g') ? true : false; + const std::string modname = is_pgsql ? "pgsql_variables" : (is_genai ? "genai_variables" : "mysql_variables"); tuple, vector>& t = load_save_disk_commands[modname]; if (is_admin_command_or_alias(get<1>(t), query_no_space, query_no_space_length)) { @@ -1648,6 +1665,9 @@ bool admin_handler_command_load_or_save(char *query_no_space, unsigned int query if (is_pgsql) { *q = l_strdup("INSERT OR REPLACE INTO main.global_variables SELECT * FROM disk.global_variables WHERE variable_name LIKE 'pgsql-%'"); } + else if (is_genai) { + *q = l_strdup("INSERT OR REPLACE INTO main.global_variables SELECT * FROM disk.global_variables WHERE variable_name LIKE 'genai-%'"); + } else { *q = l_strdup("INSERT OR REPLACE INTO main.global_variables SELECT * FROM disk.global_variables WHERE variable_name LIKE 'mysql-%'"); } @@ -1660,6 +1680,9 @@ bool admin_handler_command_load_or_save(char *query_no_space, unsigned int query if (is_pgsql) { *q = l_strdup("INSERT OR REPLACE INTO disk.global_variables SELECT * FROM main.global_variables WHERE variable_name LIKE 'pgsql-%'"); } + else if (is_genai) { + *q = l_strdup("INSERT OR REPLACE INTO disk.global_variables SELECT * FROM main.global_variables WHERE variable_name LIKE 'genai-%'"); + } else { *q = l_strdup("INSERT OR REPLACE INTO disk.global_variables SELECT * FROM main.global_variables WHERE variable_name LIKE 'mysql-%'"); } @@ -1676,6 +1699,14 @@ bool admin_handler_command_load_or_save(char *query_no_space, unsigned int query SPA->send_ok_msg_to_client(sess, NULL, 0, query_no_space); return false; } + } else if (is_genai) { + if (is_admin_command_or_alias(LOAD_GENAI_VARIABLES_FROM_MEMORY, query_no_space, 
query_no_space_length)) { + ProxySQL_Admin* SPA = (ProxySQL_Admin*)pa; + SPA->load_genai_variables_to_runtime(); + proxy_debug(PROXY_DEBUG_ADMIN, 4, "Loaded genai variables to RUNTIME\n"); + SPA->send_ok_msg_to_client(sess, NULL, 0, query_no_space); + return false; + } } else { if (is_admin_command_or_alias(LOAD_MYSQL_VARIABLES_FROM_MEMORY, query_no_space, query_no_space_length)) { ProxySQL_Admin* SPA = (ProxySQL_Admin*)pa; @@ -1688,7 +1719,8 @@ bool admin_handler_command_load_or_save(char *query_no_space, unsigned int query if ( (query_no_space_length==strlen("LOAD MYSQL VARIABLES FROM CONFIG") && (!strncasecmp("LOAD MYSQL VARIABLES FROM CONFIG",query_no_space, query_no_space_length) || - !strncasecmp("LOAD PGSQL VARIABLES FROM CONFIG", query_no_space, query_no_space_length))) + !strncasecmp("LOAD PGSQL VARIABLES FROM CONFIG", query_no_space, query_no_space_length) || + !strncasecmp("LOAD GENAI VARIABLES FROM CONFIG", query_no_space, query_no_space_length))) ) { proxy_info("Received %s command\n", query_no_space); if (GloVars.configfile_open) { @@ -1699,6 +1731,9 @@ bool admin_handler_command_load_or_save(char *query_no_space, unsigned int query if (query_no_space[5] == 'P' || query_no_space[5] == 'p') { rows=SPA->proxysql_config().Read_Global_Variables_from_configfile("pgsql"); proxy_debug(PROXY_DEBUG_ADMIN, 4, "Loaded pgsql global variables from CONFIG\n"); + } else if (query_no_space[5] == 'G' || query_no_space[5] == 'g') { + rows=SPA->proxysql_config().Read_Global_Variables_from_configfile("genai"); + proxy_debug(PROXY_DEBUG_ADMIN, 4, "Loaded genai global variables from CONFIG\n"); } else { rows = SPA->proxysql_config().Read_Global_Variables_from_configfile("mysql"); proxy_debug(PROXY_DEBUG_ADMIN, 4, "Loaded mysql global variables from CONFIG\n"); @@ -1728,6 +1763,14 @@ bool admin_handler_command_load_or_save(char *query_no_space, unsigned int query SPA->send_ok_msg_to_client(sess, NULL, 0, query_no_space); return false; } + } else if (is_genai) { + if 
(is_admin_command_or_alias(SAVE_GENAI_VARIABLES_TO_MEMORY, query_no_space, query_no_space_length)) { + ProxySQL_Admin* SPA = (ProxySQL_Admin*)pa; + SPA->save_genai_variables_from_runtime(); + proxy_debug(PROXY_DEBUG_ADMIN, 4, "Saved genai variables from RUNTIME\n"); + SPA->send_ok_msg_to_client(sess, NULL, 0, query_no_space); + return false; + } } else { if (is_admin_command_or_alias(SAVE_MYSQL_VARIABLES_TO_MEMORY, query_no_space, query_no_space_length)) { ProxySQL_Admin* SPA = (ProxySQL_Admin*)pa; diff --git a/lib/Base_Session.cpp b/lib/Base_Session.cpp index e9f5bc1122..73c4823c82 100644 --- a/lib/Base_Session.cpp +++ b/lib/Base_Session.cpp @@ -85,6 +85,15 @@ void Base_Session::init() { MySQL_Session* mysession = static_cast(this); mysession->sess_STMTs_meta = new MySQL_STMTs_meta(); mysession->SLDH = new StmtLongDataHandler(); +#ifdef epoll_create1 + // Initialize GenAI async support + mysession->next_genai_request_id_ = 1; + mysession->genai_epoll_fd_ = epoll_create1(EPOLL_CLOEXEC); + if (mysession->genai_epoll_fd_ < 0) { + proxy_error("Failed to create GenAI epoll fd: %s\n", strerror(errno)); + mysession->genai_epoll_fd_ = -1; + } +#endif } }; diff --git a/lib/GenAI_Thread.cpp b/lib/GenAI_Thread.cpp new file mode 100644 index 0000000000..3b426382c1 --- /dev/null +++ b/lib/GenAI_Thread.cpp @@ -0,0 +1,1398 @@ +#include "GenAI_Thread.h" +#include "proxysql_debug.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "json.hpp" + +using json = nlohmann::json; + +// Platform compatibility +#ifndef EFD_CLOEXEC +#define EFD_CLOEXEC 0200000 +#endif +#ifndef EFD_NONBLOCK +#define EFD_NONBLOCK 04000 +#endif + +// epoll compatibility - detect epoll availability at compile time +#ifdef epoll_create1 + #define EPOLL_CREATE epoll_create1(0) +#else + #define EPOLL_CREATE epoll_create(1) +#endif + +// Define the array of variable names for the GenAI module +// Note: These do NOT include the "genai_" prefix - it's added 
by the flush functions +static const char* genai_thread_variables_names[] = { + "threads", + "embedding_uri", + "rerank_uri", + "embedding_timeout_ms", + "rerank_timeout_ms", + NULL +}; + +// ============================================================================ +// Move constructors and destructors for result structures +// ============================================================================ + +GenAI_EmbeddingResult::~GenAI_EmbeddingResult() { + if (data) { + delete[] data; + data = nullptr; + } +} + +GenAI_EmbeddingResult::GenAI_EmbeddingResult(GenAI_EmbeddingResult&& other) noexcept + : data(other.data), embedding_size(other.embedding_size), count(other.count) { + other.data = nullptr; + other.embedding_size = 0; + other.count = 0; +} + +GenAI_EmbeddingResult& GenAI_EmbeddingResult::operator=(GenAI_EmbeddingResult&& other) noexcept { + if (this != &other) { + if (data) delete[] data; + data = other.data; + embedding_size = other.embedding_size; + count = other.count; + other.data = nullptr; + other.embedding_size = 0; + other.count = 0; + } + return *this; +} + +GenAI_RerankResultArray::~GenAI_RerankResultArray() { + if (data) { + delete[] data; + data = nullptr; + } +} + +GenAI_RerankResultArray::GenAI_RerankResultArray(GenAI_RerankResultArray&& other) noexcept + : data(other.data), count(other.count) { + other.data = nullptr; + other.count = 0; +} + +GenAI_RerankResultArray& GenAI_RerankResultArray::operator=(GenAI_RerankResultArray&& other) noexcept { + if (this != &other) { + if (data) delete[] data; + data = other.data; + count = other.count; + other.data = nullptr; + other.count = 0; + } + return *this; +} + +// ============================================================================ +// GenAI_Threads_Handler implementation +// ============================================================================ + +GenAI_Threads_Handler::GenAI_Threads_Handler() { + shutdown_ = 0; + num_threads = 0; + pthread_rwlock_init(&rwlock, NULL); + epoll_fd_ 
= -1; + event_fd_ = -1; + + curl_global_init(CURL_GLOBAL_ALL); + + // Initialize variables with default values + variables.genai_threads = 4; + variables.genai_embedding_uri = strdup("http://127.0.0.1:8013/embedding"); + variables.genai_rerank_uri = strdup("http://127.0.0.1:8012/rerank"); + variables.genai_embedding_timeout_ms = 30000; + variables.genai_rerank_timeout_ms = 30000; + + status_variables.threads_initialized = 0; + status_variables.active_requests = 0; + status_variables.completed_requests = 0; + status_variables.failed_requests = 0; +} + +GenAI_Threads_Handler::~GenAI_Threads_Handler() { + if (shutdown_ == 0) { + shutdown(); + } + + if (variables.genai_embedding_uri) + free(variables.genai_embedding_uri); + if (variables.genai_rerank_uri) + free(variables.genai_rerank_uri); + + pthread_rwlock_destroy(&rwlock); +} + +void GenAI_Threads_Handler::init(unsigned int num, size_t stack) { + proxy_info("Initializing GenAI Threads Handler\n"); + + // Use variable value if num is 0 + if (num == 0) { + num = variables.genai_threads; + } + + num_threads = num; + shutdown_ = 0; + +#ifdef epoll_create1 + // Use epoll for async I/O + epoll_fd_ = EPOLL_CREATE; + if (epoll_fd_ < 0) { + proxy_error("Failed to create epoll: %s\n", strerror(errno)); + return; + } + + // Create eventfd for wakeup + event_fd_ = eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC); + if (event_fd_ < 0) { + proxy_error("Failed to create eventfd: %s\n", strerror(errno)); + close(epoll_fd_); + epoll_fd_ = -1; + return; + } + + struct epoll_event ev; + ev.events = EPOLLIN; + ev.data.fd = event_fd_; + if (epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, event_fd_, &ev) < 0) { + proxy_error("Failed to add eventfd to epoll: %s\n", strerror(errno)); + close(event_fd_); + close(epoll_fd_); + event_fd_ = -1; + epoll_fd_ = -1; + return; + } +#else + // Use pipe for wakeup on systems without epoll + int pipefds[2]; + if (pipe(pipefds) < 0) { + proxy_error("Failed to create pipe: %s\n", strerror(errno)); + return; + } + + // Set 
both ends to non-blocking + fcntl(pipefds[0], F_SETFL, O_NONBLOCK); + fcntl(pipefds[1], F_SETFL, O_NONBLOCK); + + event_fd_ = pipefds[1]; // Use write end for wakeup + epoll_fd_ = pipefds[0]; // Use read end for polling (repurposed) +#endif + + // Start listener thread + listener_thread_ = std::thread(&GenAI_Threads_Handler::listener_loop, this); + + // Start worker threads + for (unsigned int i = 0; i < num; i++) { + pthread_t thread; + if (pthread_create(&thread, NULL, [](void* arg) -> void* { + auto* handler = static_cast*>(arg); + handler->first->worker_loop(handler->second); + delete handler; + return NULL; + }, new std::pair(this, i)) == 0) { + worker_threads_.push_back(thread); + } else { + proxy_error("Failed to create worker thread %d\n", i); + } + } + + status_variables.threads_initialized = worker_threads_.size(); + + proxy_info("GenAI module started with %zu workers\n", worker_threads_.size()); + proxy_info("Embedding endpoint: %s\n", variables.genai_embedding_uri); + proxy_info("Rerank endpoint: %s\n", variables.genai_rerank_uri); + print_version(); +} + +void GenAI_Threads_Handler::shutdown() { + if (shutdown_ == 1) { + return; // Already shutting down + } + + proxy_info("Shutting down GenAI module\n"); + shutdown_ = 1; + + // Wake up listener + if (event_fd_ >= 0) { + uint64_t value = 1; + write(event_fd_, &value, sizeof(value)); + } + + // Notify all workers + queue_cv_.notify_all(); + + // Join worker threads + for (auto& t : worker_threads_) { + pthread_join(t, NULL); + } + worker_threads_.clear(); + + // Join listener thread + if (listener_thread_.joinable()) { + listener_thread_.join(); + } + + // Clean up epoll + if (event_fd_ >= 0) { + close(event_fd_); + event_fd_ = -1; + } + if (epoll_fd_ >= 0) { + close(epoll_fd_); + epoll_fd_ = -1; + } + + status_variables.threads_initialized = 0; +} + +void GenAI_Threads_Handler::wrlock() { + pthread_rwlock_wrlock(&rwlock); +} + +void GenAI_Threads_Handler::wrunlock() { + pthread_rwlock_unlock(&rwlock); 
+} + +char* GenAI_Threads_Handler::get_variable(char* name) { + if (!name) + return NULL; + + if (!strcmp(name, "threads")) { + char buf[64]; + sprintf(buf, "%d", variables.genai_threads); + return strdup(buf); + } + if (!strcmp(name, "embedding_uri")) { + return strdup(variables.genai_embedding_uri ? variables.genai_embedding_uri : ""); + } + if (!strcmp(name, "rerank_uri")) { + return strdup(variables.genai_rerank_uri ? variables.genai_rerank_uri : ""); + } + if (!strcmp(name, "embedding_timeout_ms")) { + char buf[64]; + sprintf(buf, "%d", variables.genai_embedding_timeout_ms); + return strdup(buf); + } + if (!strcmp(name, "rerank_timeout_ms")) { + char buf[64]; + sprintf(buf, "%d", variables.genai_rerank_timeout_ms); + return strdup(buf); + } + + return NULL; +} + +bool GenAI_Threads_Handler::set_variable(char* name, const char* value) { + if (!name || !value) + return false; + + if (!strcmp(name, "threads")) { + int val = atoi(value); + if (val < 1 || val > 256) { + proxy_error("Invalid value for genai_threads: %d (must be 1-256)\n", val); + return false; + } + variables.genai_threads = val; + return true; + } + if (!strcmp(name, "embedding_uri")) { + if (variables.genai_embedding_uri) + free(variables.genai_embedding_uri); + variables.genai_embedding_uri = strdup(value); + return true; + } + if (!strcmp(name, "rerank_uri")) { + if (variables.genai_rerank_uri) + free(variables.genai_rerank_uri); + variables.genai_rerank_uri = strdup(value); + return true; + } + if (!strcmp(name, "embedding_timeout_ms")) { + int val = atoi(value); + if (val < 100 || val > 300000) { + proxy_error("Invalid value for genai_embedding_timeout_ms: %d (must be 100-300000)\n", val); + return false; + } + variables.genai_embedding_timeout_ms = val; + return true; + } + if (!strcmp(name, "rerank_timeout_ms")) { + int val = atoi(value); + if (val < 100 || val > 300000) { + proxy_error("Invalid value for genai_rerank_timeout_ms: %d (must be 100-300000)\n", val); + return false; + } + 
variables.genai_rerank_timeout_ms = val; + return true; + } + + return false; +} + +char** GenAI_Threads_Handler::get_variables_list() { + // Count variables + int count = 0; + while (genai_thread_variables_names[count]) { + count++; + } + + // Allocate array + char** list = (char**)malloc(sizeof(char*) * (count + 1)); + if (!list) + return NULL; + + // Fill array + for (int i = 0; i < count; i++) { + list[i] = strdup(genai_thread_variables_names[i]); + } + list[count] = NULL; + + return list; +} + +void GenAI_Threads_Handler::print_version() { + fprintf(stderr, "GenAI Threads Handler rev. %s -- %s -- %s\n", GENAI_THREAD_VERSION, __FILE__, __TIMESTAMP__); +} + +/** + * @brief Register a client file descriptor for async GenAI communication + * + * This function is called by MySQL_Session to register the GenAI side of a + * socketpair for receiving async GenAI requests. The fd is added to the + * GenAI module's epoll instance so the listener thread can monitor it for + * incoming requests. + * + * Registration flow: + * 1. MySQL_Session creates socketpair(fds) + * 2. MySQL_Session keeps fds[0] for reading responses + * 3. MySQL_Session calls this function with fds[1] (GenAI side) + * 4. This function adds fds[1] to client_fds_ set and to epoll_fd_ + * 5. GenAI listener can now receive requests via fds[1] + * + * The fd is set to non-blocking mode to prevent the listener from blocking + * on a slow client. 
+ * + * @param client_fd The GenAI side file descriptor from socketpair (typically fds[1]) + * @return true if successfully registered and added to epoll, false on error + * + * @see unregister_client(), listener_loop() + */ +bool GenAI_Threads_Handler::register_client(int client_fd) { + std::lock_guard<std::mutex> lock(clients_mutex_); + + int flags = fcntl(client_fd, F_GETFL, 0); + fcntl(client_fd, F_SETFL, flags | O_NONBLOCK); + +#ifdef epoll_create1 + struct epoll_event ev; + ev.events = EPOLLIN; + ev.data.fd = client_fd; + if (epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, client_fd, &ev) < 0) { + proxy_error("Failed to add client fd %d to epoll: %s\n", client_fd, strerror(errno)); + return false; + } +#endif + + client_fds_.insert(client_fd); + proxy_debug(PROXY_DEBUG_GENAI, 3, "Registered GenAI client fd %d\n", client_fd); + return true; +} + +/** + * @brief Unregister a client file descriptor from GenAI module + * + * This function is called when a MySQL session ends or an error occurs + * to clean up the socketpair connection. It removes the fd from the + * GenAI module's epoll instance and closes the connection. + * + * Cleanup flow: + * 1. Remove fd from epoll_fd_ monitoring + * 2. Remove fd from client_fds_ set + * 3. 
Close the file descriptor + * + * This is typically called when: + * - MySQL session ends (client disconnect) + * - Socketpair communication error occurs + * - Session cleanup during shutdown + * + * @param client_fd The GenAI side file descriptor to remove (typically fds[1]) + * + * @see register_client(), listener_loop() + */ +void GenAI_Threads_Handler::unregister_client(int client_fd) { + std::lock_guard<std::mutex> lock(clients_mutex_); + +#ifdef epoll_create1 + if (epoll_fd_ >= 0) { + epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, client_fd, NULL); + } +#endif + + client_fds_.erase(client_fd); + close(client_fd); + proxy_debug(PROXY_DEBUG_GENAI, 3, "Unregistered GenAI client fd %d\n", client_fd); +} + +size_t GenAI_Threads_Handler::get_queue_size() { + std::lock_guard<std::mutex> lock(queue_mutex_); + return request_queue_.size(); +} + +// ============================================================================ +// Public API methods +// ============================================================================ + +size_t GenAI_Threads_Handler::WriteCallback(void* contents, size_t size, size_t nmemb, void* userp) { + size_t totalSize = size * nmemb; + std::string* response = static_cast<std::string*>(userp); + response->append(static_cast<char*>(contents), totalSize); + return totalSize; +} + +GenAI_EmbeddingResult GenAI_Threads_Handler::call_llama_embedding(const std::string& text) { + // For single document, use batch API with 1 document + std::vector<std::string> texts = {text}; + return call_llama_batch_embedding(texts); +} + +/** + * @brief Generate embeddings for multiple documents via HTTP to llama-server + * + * This function sends a batch embedding request to the configured embedding service + * (genai_embedding_uri) via libcurl. The request is sent as JSON with an "input" array + * containing all documents to embed. + * + * Request format: + * ```json + * { + * "input": ["document 1", "document 2", ...] 
+ * } + * ``` + * + * Response format (parsed): + * ```json + * { + * "results": [ + * {"embedding": [0.1, 0.2, ...]}, + * {"embedding": [0.3, 0.4, ...]} + * ] + * } + * ``` + * + * The function handles: + * - JSON escaping for special characters (quotes, backslashes, newlines, tabs) + * - HTTP POST request with Content-Type: application/json + * - Timeout enforcement via genai_embedding_timeout_ms + * - JSON response parsing to extract embedding arrays + * - Contiguous memory allocation for result embeddings + * + * Error handling: + * - On curl error: returns empty result, increments failed_requests + * - On parse error: returns empty result, increments failed_requests + * - On success: increments completed_requests + * + * @param texts Vector of document texts to embed (each can be up to several KB) + * @return GenAI_EmbeddingResult containing all embeddings with metadata. + * The caller takes ownership of the returned data and must free it. + * Returns empty result (data==nullptr || count==0) on error. + * + * @note This is a BLOCKING call (curl_easy_perform blocks). Should only be called + * from worker threads, not MySQL threads. Use embed_documents() wrapper instead. 
+ * @see embed_documents(), call_llama_rerank() + */ +GenAI_EmbeddingResult GenAI_Threads_Handler::call_llama_batch_embedding(const std::vector& texts) { + GenAI_EmbeddingResult result; + CURL* curl = curl_easy_init(); + + if (!curl) { + proxy_error("Failed to initialize curl\n"); + status_variables.failed_requests++; + return result; + } + + // Build JSON request using nlohmann/json + json payload; + payload["input"] = texts; + std::string json_str = payload.dump(); + + // Configure curl + curl_easy_setopt(curl, CURLOPT_URL, variables.genai_embedding_uri); + curl_easy_setopt(curl, CURLOPT_POST, 1L); + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, json_str.c_str()); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback); + curl_easy_setopt(curl, CURLOPT_TIMEOUT_MS, variables.genai_embedding_timeout_ms); + + std::string response_data; + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response_data); + + // Add content-type header + struct curl_slist* headers = nullptr; + headers = curl_slist_append(headers, "Content-Type: application/json"); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + + // Perform request + auto start_time = std::chrono::steady_clock::now(); + CURLcode res = curl_easy_perform(curl); + + if (res != CURLE_OK) { + proxy_error("curl_easy_perform() failed: %s\n", curl_easy_strerror(res)); + status_variables.failed_requests++; + } else { + // Parse JSON response using nlohmann/json + try { + json response_json = json::parse(response_data); + + std::vector> all_embeddings; + + // Handle different response formats + if (response_json.contains("results") && response_json["results"].is_array()) { + // Format: {"results": [{"embedding": [...]}, ...]} + for (const auto& result_item : response_json["results"]) { + if (result_item.contains("embedding") && result_item["embedding"].is_array()) { + std::vector embedding = result_item["embedding"].get>(); + all_embeddings.push_back(std::move(embedding)); + } + } + } else if 
(response_json.contains("data") && response_json["data"].is_array()) { + // Format: {"data": [{"embedding": [...]}]} + for (const auto& item : response_json["data"]) { + if (item.contains("embedding") && item["embedding"].is_array()) { + std::vector embedding = item["embedding"].get>(); + all_embeddings.push_back(std::move(embedding)); + } + } + } else if (response_json.contains("embeddings") && response_json["embeddings"].is_array()) { + // Format: {"embeddings": [[...], ...]} + all_embeddings = response_json["embeddings"].get>>(); + } + + // Convert to contiguous array + if (!all_embeddings.empty()) { + result.count = all_embeddings.size(); + result.embedding_size = all_embeddings[0].size(); + + size_t total_floats = result.embedding_size * result.count; + result.data = new float[total_floats]; + + for (size_t i = 0; i < all_embeddings.size(); i++) { + size_t offset = i * result.embedding_size; + const auto& emb = all_embeddings[i]; + std::copy(emb.begin(), emb.end(), result.data + offset); + } + + status_variables.completed_requests++; + } else { + status_variables.failed_requests++; + } + } catch (const json::parse_error& e) { + proxy_error("Failed to parse embedding response JSON: %s\n", e.what()); + status_variables.failed_requests++; + } catch (const std::exception& e) { + proxy_error("Error processing embedding response: %s\n", e.what()); + status_variables.failed_requests++; + } + } + + curl_slist_free_all(headers); + curl_easy_cleanup(curl); + + return result; +} + +/** + * @brief Rerank documents based on query relevance via HTTP to llama-server + * + * This function sends a reranking request to the configured reranking service + * (genai_rerank_uri) via libcurl. The request is sent as JSON with a query + * and documents array. The service returns documents sorted by relevance to the query. + * + * Request format: + * ```json + * { + * "query": "search query here", + * "documents": ["doc 1", "doc 2", ...] 
+ * } + * ``` + * + * Response format (parsed): + * ```json + * { + * "results": [ + * {"index": 3, "relevance_score": 0.95}, + * {"index": 0, "relevance_score": 0.82}, + * ... + * ] + * } + * ``` + * + * The function handles: + * - JSON escaping for special characters in query and documents + * - HTTP POST request with Content-Type: application/json + * - Timeout enforcement via genai_rerank_timeout_ms + * - JSON response parsing to extract results array + * - Optional top_n limiting of results + * + * Error handling: + * - On curl error: returns empty result, increments failed_requests + * - On parse error: returns empty result, increments failed_requests + * - On success: increments completed_requests + * + * @param query Query string to rerank against (e.g., search query, user question) + * @param texts Vector of document texts to rerank (typically search results) + * @param top_n Maximum number of top results to return (0 = return all sorted results) + * @return GenAI_RerankResultArray containing results sorted by relevance. + * Each result includes the original document index and a relevance score. + * The caller takes ownership of the returned data and must free it. + * Returns empty result (data==nullptr || count==0) on error. + * + * @note This is a BLOCKING call (curl_easy_perform blocks). Should only be called + * from worker threads, not MySQL threads. Use rerank_documents() wrapper instead. 
+ * @see rerank_documents(), call_llama_batch_embedding() + */ +GenAI_RerankResultArray GenAI_Threads_Handler::call_llama_rerank(const std::string& query, + const std::vector& texts, + uint32_t top_n) { + GenAI_RerankResultArray result; + CURL* curl = curl_easy_init(); + + if (!curl) { + proxy_error("Failed to initialize curl\n"); + status_variables.failed_requests++; + return result; + } + + // Build JSON request using nlohmann/json + json payload; + payload["query"] = query; + payload["documents"] = texts; + std::string json_str = payload.dump(); + + // Configure curl + curl_easy_setopt(curl, CURLOPT_URL, variables.genai_rerank_uri); + curl_easy_setopt(curl, CURLOPT_POST, 1L); + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, json_str.c_str()); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback); + curl_easy_setopt(curl, CURLOPT_TIMEOUT_MS, variables.genai_rerank_timeout_ms); + + std::string response_data; + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response_data); + + struct curl_slist* headers = nullptr; + headers = curl_slist_append(headers, "Content-Type: application/json"); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + + CURLcode res = curl_easy_perform(curl); + + if (res != CURLE_OK) { + proxy_error("curl_easy_perform() failed: %s\n", curl_easy_strerror(res)); + status_variables.failed_requests++; + } else { + // Parse JSON response using nlohmann/json + try { + json response_json = json::parse(response_data); + + std::vector results; + + // Handle different response formats + if (response_json.contains("results") && response_json["results"].is_array()) { + // Format: {"results": [{"index": 0, "relevance_score": 0.95}, ...]} + for (const auto& result_item : response_json["results"]) { + GenAI_RerankResult r; + r.index = result_item.value("index", 0); + // Support both "relevance_score" and "score" field names + if (result_item.contains("relevance_score")) { + r.score = result_item.value("relevance_score", 0.0f); + } else { + r.score = 
result_item.value("score", 0.0f); + } + results.push_back(r); + } + } else if (response_json.contains("data") && response_json["data"].is_array()) { + // Alternative format: {"data": [...]} + for (const auto& result_item : response_json["data"]) { + GenAI_RerankResult r; + r.index = result_item.value("index", 0); + // Support both "relevance_score" and "score" field names + if (result_item.contains("relevance_score")) { + r.score = result_item.value("relevance_score", 0.0f); + } else { + r.score = result_item.value("score", 0.0f); + } + results.push_back(r); + } + } + + // Apply top_n limit if specified + if (!results.empty() && top_n > 0 && top_n < results.size()) { + result.count = top_n; + result.data = new GenAI_RerankResult[top_n]; + std::copy(results.begin(), results.begin() + top_n, result.data); + } else if (!results.empty()) { + result.count = results.size(); + result.data = new GenAI_RerankResult[results.size()]; + std::copy(results.begin(), results.end(), result.data); + } + + if (!results.empty()) { + status_variables.completed_requests++; + } else { + status_variables.failed_requests++; + } + } catch (const json::parse_error& e) { + proxy_error("Failed to parse rerank response JSON: %s\n", e.what()); + status_variables.failed_requests++; + } catch (const std::exception& e) { + proxy_error("Error processing rerank response: %s\n", e.what()); + status_variables.failed_requests++; + } + } + + curl_slist_free_all(headers); + curl_easy_cleanup(curl); + + return result; +} + +// ============================================================================ +// Public API methods +// ============================================================================ + +GenAI_EmbeddingResult GenAI_Threads_Handler::embed_documents(const std::vector& documents) { + if (documents.empty()) { + proxy_error("embed_documents called with empty documents list\n"); + status_variables.failed_requests++; + return GenAI_EmbeddingResult(); + } + + status_variables.active_requests++; 
+ + GenAI_EmbeddingResult result; + if (documents.size() == 1) { + result = call_llama_embedding(documents[0]); + } else { + result = call_llama_batch_embedding(documents); + } + + status_variables.active_requests--; + return result; +} + +GenAI_RerankResultArray GenAI_Threads_Handler::rerank_documents(const std::string& query, + const std::vector& documents, + uint32_t top_n) { + if (documents.empty()) { + proxy_error("rerank_documents called with empty documents list\n"); + status_variables.failed_requests++; + return GenAI_RerankResultArray(); + } + + if (query.empty()) { + proxy_error("rerank_documents called with empty query\n"); + status_variables.failed_requests++; + return GenAI_RerankResultArray(); + } + + status_variables.active_requests++; + + GenAI_RerankResultArray result = call_llama_rerank(query, documents, top_n); + + status_variables.active_requests--; + return result; +} + +// ============================================================================ +// Worker and listener loops (for async socket pair integration) +// ============================================================================ + +/** + * @brief GenAI listener thread main loop + * + * This function runs in a dedicated thread and monitors registered client file + * descriptors via epoll for incoming GenAI requests from MySQL sessions. + * + * Workflow: + * 1. Wait for events on epoll_fd_ (100ms timeout for shutdown check) + * 2. When event occurs on client fd: + * - Read GenAI_RequestHeader + * - Read JSON query (if query_len > 0) + * - Build GenAI_Request and queue to request_queue_ + * - Notify worker thread via condition variable + * 3. 
Handle client disconnection and errors + * + * Communication protocol: + * - Client sends: GenAI_RequestHeader (fixed size) + JSON query (variable size) + * - Header includes: request_id, operation, query_len, flags, top_n + * + * This thread ensures that MySQL sessions never block - they send requests + * via socketpair and immediately return to handling other queries. The actual + * blocking HTTP calls to llama-server happen in worker threads. + * + * @note Runs in dedicated listener_thread_ created during init() + * @see worker_loop(), register_client(), process_json_query() + */ +void GenAI_Threads_Handler::listener_loop() { + proxy_debug(PROXY_DEBUG_GENAI, 3, "GenAI listener thread started\n"); + +#ifdef epoll_create1 + const int MAX_EVENTS = 64; + struct epoll_event events[MAX_EVENTS]; + + while (!shutdown_) { + int nfds = epoll_wait(epoll_fd_, events, MAX_EVENTS, 100); + + if (nfds < 0 && errno != EINTR) { + if (errno != EINTR) { + proxy_error("epoll_wait failed: %s\n", strerror(errno)); + } + continue; + } + + for (int i = 0; i < nfds; i++) { + if (events[i].data.fd == event_fd_) { + continue; + } + + int client_fd = events[i].data.fd; + + // Read request header + GenAI_RequestHeader header; + ssize_t n = read(client_fd, &header, sizeof(header)); + + if (n < 0) { + // Check for non-blocking read - not an error, just no data yet + if (errno == EAGAIN || errno == EWOULDBLOCK) { + continue; + } + // Real error - log and close connection + proxy_error("GenAI: Error reading from client fd %d: %s\n", + client_fd, strerror(errno)); + epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, client_fd, nullptr); + close(client_fd); + { + std::lock_guard lock(clients_mutex_); + client_fds_.erase(client_fd); + } + continue; + } + + if (n == 0) { + // Client disconnected (EOF) + epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, client_fd, nullptr); + close(client_fd); + { + std::lock_guard lock(clients_mutex_); + client_fds_.erase(client_fd); + } + continue; + } + + if (n != sizeof(header)) { + 
proxy_error("GenAI: Incomplete header read from fd %d: got %zd, expected %zu\n", + client_fd, n, sizeof(header)); + continue; + } + + // Read JSON query if present + std::string json_query; + if (header.query_len > 0) { + json_query.resize(header.query_len); + size_t total_read = 0; + while (total_read < header.query_len) { + ssize_t r = read(client_fd, &json_query[total_read], + header.query_len - total_read); + if (r <= 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + usleep(1000); // Wait 1ms and retry + continue; + } + proxy_error("GenAI: Error reading JSON query from fd %d: %s\n", + client_fd, strerror(errno)); + break; + } + total_read += r; + } + } + + // Build request and queue it + GenAI_Request req; + req.client_fd = client_fd; + req.request_id = header.request_id; + req.operation = header.operation; + req.top_n = header.top_n; + req.json_query = json_query; + + { + std::lock_guard lock(queue_mutex_); + request_queue_.push(std::move(req)); + } + + queue_cv_.notify_one(); + + proxy_debug(PROXY_DEBUG_GENAI, 3, + "GenAI: Queued request %lu from fd %d (op=%u, query_len=%u)\n", + header.request_id, client_fd, header.operation, header.query_len); + } + } +#else + // Use poll() for systems without epoll support + while (!shutdown_) { + // Build pollfd array + std::vector pollfds; + pollfds.reserve(client_fds_.size() + 1); + + // Add wakeup pipe read end + struct pollfd wakeup_pfd; + wakeup_pfd.fd = epoll_fd_; // Reused as pipe read end + wakeup_pfd.events = POLLIN; + wakeup_pfd.revents = 0; + pollfds.push_back(wakeup_pfd); + + // Add all client fds + { + std::lock_guard lock(clients_mutex_); + for (int fd : client_fds_) { + struct pollfd pfd; + pfd.fd = fd; + pfd.events = POLLIN; + pfd.revents = 0; + pollfds.push_back(pfd); + } + } + + int nfds = poll(pollfds.data(), pollfds.size(), 100); + + if (nfds < 0 && errno != EINTR) { + proxy_error("poll failed: %s\n", strerror(errno)); + continue; + } + + // Check for wakeup event + if (pollfds.size() > 0 && 
(pollfds[0].revents & POLLIN)) { + uint64_t value; + read(pollfds[0].fd, &value, sizeof(value)); // Clear the pipe + continue; + } + + // Handle client events + for (size_t i = 1; i < pollfds.size(); i++) { + if (pollfds[i].revents & POLLIN) { + // Handle client events here + // This will be implemented when integrating with MySQL/PgSQL threads + } + } + } +#endif + + proxy_debug(PROXY_DEBUG_GENAI, 3, "GenAI listener thread stopped\n"); +} + +/** + * @brief GenAI worker thread main loop + * + * This function runs in worker thread pool and processes GenAI requests + * from the request queue. Each worker handles: + * - JSON query parsing + * - HTTP requests to embedding/reranking services (via libcurl) + * - Response formatting and sending back via socketpair + * + * Workflow: + * 1. Wait on request_queue_ with condition variable (shutdown-safe) + * 2. Dequeue GenAI_Request + * 3. Process the JSON query via process_json_query() + * - This may involve HTTP calls to llama-server (blocking in worker thread) + * 4. Format response as GenAI_ResponseHeader + JSON result + * 5. Write response back to client via socketpair + * 6. Update status variables (completed_requests, failed_requests) + * + * The blocking HTTP calls (curl_easy_perform) happen in this worker thread, + * NOT in the MySQL thread. This is the key to non-blocking behavior - MySQL + * sessions can continue processing other queries while workers wait for HTTP responses. 
+ * + * Error handling: + * - On write error: cleanup request and mark as failed + * - On process_json_query error: send error response + * - Client fd cleanup on any error + * + * @param worker_id Worker thread identifier (0-based index for logging) + * + * @note Runs in worker_threads_[worker_id] created during init() + * @see listener_loop(), process_json_query(), GenAI_Request + */ +void GenAI_Threads_Handler::worker_loop(int worker_id) { + proxy_debug(PROXY_DEBUG_GENAI, 3, "GenAI worker thread %d started\n", worker_id); + + while (!shutdown_) { + std::unique_lock lock(queue_mutex_); + queue_cv_.wait(lock, [this] { + return shutdown_ || !request_queue_.empty(); + }); + + if (shutdown_) break; + + if (request_queue_.empty()) continue; + + GenAI_Request req = std::move(request_queue_.front()); + request_queue_.pop(); + lock.release(); + + // Process request + auto start_time = std::chrono::steady_clock::now(); + + proxy_debug(PROXY_DEBUG_GENAI, 3, + "Worker %d processing request %lu (op=%u)\n", + worker_id, req.request_id, req.operation); + + // Process the JSON query + std::string json_result = process_json_query(req.json_query); + + auto end_time = std::chrono::steady_clock::now(); + int processing_time_ms = std::chrono::duration_cast( + end_time - start_time).count(); + + // Prepare response header + GenAI_ResponseHeader resp; + resp.request_id = req.request_id; + resp.status_code = json_result.empty() ? 
1 : 0; + resp.result_len = json_result.length(); + resp.processing_time_ms = processing_time_ms; + resp.result_ptr = 0; // Not using shared memory + resp.result_count = 0; + resp.reserved = 0; + + // Send response header + ssize_t written = write(req.client_fd, &resp, sizeof(resp)); + if (written != sizeof(resp)) { + proxy_error("GenAI: Failed to write response header to fd %d: %s\n", + req.client_fd, strerror(errno)); + status_variables.failed_requests++; + close(req.client_fd); + { + std::lock_guard lock(clients_mutex_); + client_fds_.erase(req.client_fd); + } + continue; + } + + // Send JSON result + if (resp.result_len > 0) { + size_t total_written = 0; + while (total_written < json_result.length()) { + ssize_t w = write(req.client_fd, + json_result.data() + total_written, + json_result.length() - total_written); + if (w <= 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + usleep(1000); // Wait 1ms and retry + continue; + } + proxy_error("GenAI: Failed to write JSON result to fd %d: %s\n", + req.client_fd, strerror(errno)); + status_variables.failed_requests++; + break; + } + total_written += w; + } + } + + status_variables.completed_requests++; + + proxy_debug(PROXY_DEBUG_GENAI, 3, + "Worker %d completed request %lu (status=%u, result_len=%u, time=%dms)\n", + worker_id, req.request_id, resp.status_code, resp.result_len, processing_time_ms); + } + + proxy_debug(PROXY_DEBUG_GENAI, 3, "GenAI worker thread %d stopped\n", worker_id); +} + +/** + * @brief Execute SQL query to retrieve documents for reranking + * + * This helper function is used by the document_from_sql feature to execute + * a SQL query and retrieve documents from a database for reranking. + * + * The SQL query should return a single column containing document text. + * For example: + * ```sql + * SELECT content FROM posts WHERE category = 'tech' + * ``` + * + * Note: This function is currently a stub and needs MySQL connection handling + * to be implemented. 
The document_from_sql feature cannot be used until this + * is implemented. + * + * @param sql_query SQL query string to execute (should select document text) + * @return Pair of (success, vector of documents). On success, returns (true, documents). + * On failure, returns (false, empty vector). + * + * @todo Implement MySQL connection handling for document_from_sql feature + */ +static std::pair> execute_sql_for_documents(const std::string& sql_query) { + std::vector documents; + + // TODO: Implement MySQL connection handling + // For now, return error indicating this needs MySQL connectivity + return {false, {}}; +} + +/** + * @brief Process JSON query autonomously (handles embed/rerank/document_from_sql) + * + * This method is the main entry point for processing GenAI JSON queries from + * MySQL sessions. It parses the JSON, determines the operation type, and routes + * to the appropriate handler (embedding or reranking). + * + * Supported query formats: + * + * 1. Embed operation: + * ```json + * { + * "type": "embed", + * "documents": ["doc1 text", "doc2 text", ...] + * } + * ``` + * Response: `{"columns": ["embedding"], "rows": [["0.1,0.2,..."], ...]}` + * + * 2. Rerank with direct documents: + * ```json + * { + * "type": "rerank", + * "query": "search query", + * "documents": ["doc1", "doc2", ...], + * "top_n": 5, + * "columns": 3 + * } + * ``` + * Response: `{"columns": ["index", "score", "document"], "rows": [[0, 0.95, "doc1"], ...]}` + * + * 3. Rerank with SQL documents (not yet implemented): + * ```json + * { + * "type": "rerank", + * "query": "search query", + * "document_from_sql": {"query": "SELECT content FROM posts WHERE ..."}, + * "top_n": 5 + * } + * ``` + * + * Response format: + * - Success: `{"columns": [...], "rows": [[...], ...]}` + * - Error: `{"error": "error message"}` + * + * The response format matches MySQL resultset format for easy conversion to + * MySQL result packets in MySQL_Session. 
+ * + * @param json_query JSON query string from client (must be valid JSON) + * @return JSON string result with columns and rows formatted for MySQL resultset. + * Returns error JSON string on failure. + * + * @note This method is called from worker threads as part of async request processing. + * The blocking HTTP calls (embed_documents, rerank_documents) occur in the + * worker thread, not the MySQL thread. + * + * @see embed_documents(), rerank_documents(), worker_loop() + */ +std::string GenAI_Threads_Handler::process_json_query(const std::string& json_query) { + json result; + + try { + // Parse JSON query + json query_json = json::parse(json_query); + + if (!query_json.is_object()) { + result["error"] = "Query must be a JSON object"; + return result.dump(); + } + + // Extract operation type + if (!query_json.contains("type") || !query_json["type"].is_string()) { + result["error"] = "Query must contain a 'type' field (embed or rerank)"; + return result.dump(); + } + + std::string op_type = query_json["type"].get(); + + // Handle embed operation + if (op_type == "embed") { + // Extract documents array + if (!query_json.contains("documents") || !query_json["documents"].is_array()) { + result["error"] = "Embed operation requires a 'documents' array"; + return result.dump(); + } + + std::vector documents; + for (const auto& doc : query_json["documents"]) { + if (doc.is_string()) { + documents.push_back(doc.get()); + } else { + documents.push_back(doc.dump()); + } + } + + if (documents.empty()) { + result["error"] = "Embed operation requires at least one document"; + return result.dump(); + } + + // Call embedding service + GenAI_EmbeddingResult embeddings = embed_documents(documents); + + if (!embeddings.data || embeddings.count == 0) { + result["error"] = "Failed to generate embeddings"; + return result.dump(); + } + + // Build result + result["columns"] = json::array({"embedding"}); + json rows = json::array(); + + for (size_t i = 0; i < embeddings.count; i++) 
{ + float* embedding = embeddings.data + (i * embeddings.embedding_size); + std::ostringstream oss; + for (size_t k = 0; k < embeddings.embedding_size; k++) { + if (k > 0) oss << ","; + oss << embedding[k]; + } + rows.push_back(json::array({oss.str()})); + } + + result["rows"] = rows; + return result.dump(); + } + + // Handle rerank operation + if (op_type == "rerank") { + // Extract query + if (!query_json.contains("query") || !query_json["query"].is_string()) { + result["error"] = "Rerank operation requires a 'query' string"; + return result.dump(); + } + std::string query_str = query_json["query"].get(); + + if (query_str.empty()) { + result["error"] = "Rerank query cannot be empty"; + return result.dump(); + } + + // Check for document_from_sql or documents array + std::vector documents; + bool use_sql_documents = query_json.contains("document_from_sql") && query_json["document_from_sql"].is_object(); + + if (use_sql_documents) { + // document_from_sql mode - execute SQL to get documents + if (!query_json["document_from_sql"].contains("query") || !query_json["document_from_sql"]["query"].is_string()) { + result["error"] = "document_from_sql requires a 'query' string"; + return result.dump(); + } + + std::string sql_query = query_json["document_from_sql"]["query"].get(); + if (sql_query.empty()) { + result["error"] = "document_from_sql query cannot be empty"; + return result.dump(); + } + + // Execute SQL query to get documents + auto [success, docs] = execute_sql_for_documents(sql_query); + if (!success) { + result["error"] = "document_from_sql feature not yet implemented - MySQL connection handling required"; + return result.dump(); + } + documents = docs; + } else { + // Direct documents array mode + if (!query_json.contains("documents") || !query_json["documents"].is_array()) { + result["error"] = "Rerank operation requires 'documents' array or 'document_from_sql' object"; + return result.dump(); + } + + for (const auto& doc : query_json["documents"]) { + if 
(doc.is_string()) { + documents.push_back(doc.get()); + } else { + documents.push_back(doc.dump()); + } + } + } + + if (documents.empty()) { + result["error"] = "Rerank operation requires at least one document"; + return result.dump(); + } + + // Extract optional top_n (default 0 = return all) + uint32_t opt_top_n = 0; + if (query_json.contains("top_n") && query_json["top_n"].is_number()) { + opt_top_n = query_json["top_n"].get(); + } + + // Extract optional columns (default 3 = index, score, document) + uint32_t opt_columns = 3; + if (query_json.contains("columns") && query_json["columns"].is_number()) { + opt_columns = query_json["columns"].get(); + if (opt_columns != 2 && opt_columns != 3) { + result["error"] = "Rerank 'columns' must be 2 or 3"; + return result.dump(); + } + } + + // Call rerank service + GenAI_RerankResultArray rerank_result = rerank_documents(query_str, documents, opt_top_n); + + if (!rerank_result.data || rerank_result.count == 0) { + result["error"] = "Failed to rerank documents"; + return result.dump(); + } + + // Build result + json rows = json::array(); + + if (opt_columns == 2) { + result["columns"] = json::array({"index", "score"}); + + for (size_t i = 0; i < rerank_result.count; i++) { + const GenAI_RerankResult& r = rerank_result.data[i]; + std::string index_str = std::to_string(r.index); + std::string score_str = std::to_string(r.score); + rows.push_back(json::array({index_str, score_str})); + } + } else { + result["columns"] = json::array({"index", "score", "document"}); + + for (size_t i = 0; i < rerank_result.count; i++) { + const GenAI_RerankResult& r = rerank_result.data[i]; + if (r.index >= documents.size()) { + continue; // Skip invalid index + } + std::string index_str = std::to_string(r.index); + std::string score_str = std::to_string(r.score); + const std::string& doc = documents[r.index]; + rows.push_back(json::array({index_str, score_str, doc})); + } + } + + result["rows"] = rows; + return result.dump(); + } + + // 
Unknown operation type + result["error"] = "Unknown operation type: " + op_type + ". Use 'embed' or 'rerank'"; + return result.dump(); + + } catch (const json::parse_error& e) { + result["error"] = std::string("JSON parse error: ") + e.what(); + return result.dump(); + } catch (const std::exception& e) { + result["error"] = std::string("Error: ") + e.what(); + return result.dump(); + } +} diff --git a/lib/Makefile b/lib/Makefile index 3229254228..10c7a25ef8 100644 --- a/lib/Makefile +++ b/lib/Makefile @@ -75,6 +75,7 @@ _OBJ_CXX := ProxySQL_GloVars.oo network.oo debug.oo configfile.oo Query_Cache.oo proxy_protocol_info.oo \ proxysql_find_charset.oo ProxySQL_Poll.oo \ PgSQL_Protocol.oo PgSQL_Thread.oo PgSQL_Data_Stream.oo PgSQL_Session.oo PgSQL_Variables.oo PgSQL_HostGroups_Manager.oo PgSQL_Connection.oo PgSQL_Backend.oo PgSQL_Logger.oo PgSQL_Authentication.oo PgSQL_Error_Helper.oo \ + GenAI_Thread.oo \ MySQL_Query_Cache.oo PgSQL_Query_Cache.oo PgSQL_Monitor.oo \ MySQL_Set_Stmt_Parser.oo PgSQL_Set_Stmt_Parser.oo \ PgSQL_Variables_Validator.oo PgSQL_ExplicitTxnStateMgr.oo \ diff --git a/lib/MySQL_Session.cpp b/lib/MySQL_Session.cpp index bef0d4bd99..e7c270614c 100644 --- a/lib/MySQL_Session.cpp +++ b/lib/MySQL_Session.cpp @@ -14,6 +14,7 @@ using json = nlohmann::json; #include "MySQL_Data_Stream.h" #include "MySQL_Query_Processor.h" #include "MySQL_PreparedStatement.h" +#include "GenAI_Thread.h" #include "MySQL_Logger.hpp" #include "StatCounters.h" #include "MySQL_Authentication.hpp" @@ -3609,6 +3610,679 @@ bool MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C return false; } +// Handler for GENAI: queries - experimental GenAI integration +// Query formats: +// GENAI: {"type": "embed", "documents": ["doc1", "doc2", ...]} +// GENAI: {"type": "rerank", "query": "...", "documents": [...], "top_n": 5, "columns": 3} +// Returns: Resultset with embeddings or reranked documents +void 
MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY___genai(const char* query, size_t query_len, PtrSize_t* pkt) { + // Skip leading space after "GENAI:" + while (query_len > 0 && (*query == ' ' || *query == '\t')) { + query++; + query_len--; + } + + if (query_len == 0) { + // Empty query after GENAI: + client_myds->DSS = STATE_QUERY_SENT_NET; + client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1234, (char*)"HY000", "Empty GENAI: query", true); + l_free(pkt->size, pkt->ptr); + client_myds->DSS = STATE_SLEEP; + status = WAITING_CLIENT_DATA; + return; + } + + // Check GenAI module is initialized + if (!GloGATH) { + client_myds->DSS = STATE_QUERY_SENT_NET; + client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1237, (char*)"HY000", "GenAI module is not initialized", true); + l_free(pkt->size, pkt->ptr); + client_myds->DSS = STATE_SLEEP; + status = WAITING_CLIENT_DATA; + return; + } + +#ifdef epoll_create1 + // Use async path with socketpair for non-blocking operation + if (!handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___genai_send_async(query, query_len, pkt)) { + // Async send failed - error already sent to client + l_free(pkt->size, pkt->ptr); + client_myds->DSS = STATE_SLEEP; + status = WAITING_CLIENT_DATA; + return; + } + + // Request sent asynchronously - don't free pkt, will be freed in response handler + // Return immediately, session is now free to handle other queries + proxy_debug(PROXY_DEBUG_GENAI, 3, "GenAI: Query sent asynchronously, session continuing\n"); +#else + // Fallback to synchronous blocking path for systems without epoll + // Pass JSON query to GenAI module for autonomous processing + std::string json_query(query, query_len); + std::string result_json = GloGATH->process_json_query(json_query); + + if (result_json.empty()) { + client_myds->DSS = STATE_QUERY_SENT_NET; + client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1250, (char*)"HY000", "GenAI query processing failed", true); + 
l_free(pkt->size, pkt->ptr); + client_myds->DSS = STATE_SLEEP; + status = WAITING_CLIENT_DATA; + return; + } + + // Parse the JSON result and build MySQL resultset + try { + json result = json::parse(result_json); + + if (!result.is_object()) { + client_myds->DSS = STATE_QUERY_SENT_NET; + client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1251, (char*)"HY000", "GenAI returned invalid result format", true); + l_free(pkt->size, pkt->ptr); + client_myds->DSS = STATE_SLEEP; + status = WAITING_CLIENT_DATA; + return; + } + + // Check if result is an error + if (result.contains("error") && result["error"].is_string()) { + std::string error_msg = result["error"].get(); + client_myds->DSS = STATE_QUERY_SENT_NET; + client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1252, (char*)"HY000", (char*)error_msg.c_str(), true); + l_free(pkt->size, pkt->ptr); + client_myds->DSS = STATE_SLEEP; + status = WAITING_CLIENT_DATA; + return; + } + + // Extract resultset data + if (!result.contains("columns") || !result["columns"].is_array()) { + client_myds->DSS = STATE_QUERY_SENT_NET; + client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1253, (char*)"HY000", "GenAI result missing 'columns' field", true); + l_free(pkt->size, pkt->ptr); + client_myds->DSS = STATE_SLEEP; + status = WAITING_CLIENT_DATA; + return; + } + + if (!result.contains("rows") || !result["rows"].is_array()) { + client_myds->DSS = STATE_QUERY_SENT_NET; + client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1254, (char*)"HY000", "GenAI result missing 'rows' field", true); + l_free(pkt->size, pkt->ptr); + client_myds->DSS = STATE_SLEEP; + status = WAITING_CLIENT_DATA; + return; + } + + auto columns = result["columns"]; + auto rows = result["rows"]; + + // Build SQLite3 resultset + std::unique_ptr resultset(new SQLite3_result(columns.size())); + + // Add column definitions + for (size_t i = 0; i < columns.size(); i++) { + if (columns[i].is_string()) { + std::string col_name = columns[i].get(); + 
resultset->add_column_definition(SQLITE_TEXT, (char*)col_name.c_str()); + } + } + + // Add rows + for (const auto& row : rows) { + if (!row.is_array()) continue; + + // Create row data array + char** row_data = (char**)malloc(columns.size() * sizeof(char*)); + size_t valid_cols = 0; + + for (size_t i = 0; i < columns.size() && i < row.size(); i++) { + if (row[i].is_string()) { + std::string val = row[i].get(); + row_data[valid_cols++] = strdup(val.c_str()); + } else if (row[i].is_null()) { + row_data[valid_cols++] = NULL; + } else { + // Convert to string + std::string val = row[i].dump(); + // Remove quotes if present + if (val.size() >= 2 && val[0] == '"' && val[val.size()-1] == '"') { + val = val.substr(1, val.size() - 2); + } + row_data[valid_cols++] = strdup(val.c_str()); + } + } + + resultset->add_row(row_data); + + // Free row data + for (size_t i = 0; i < valid_cols; i++) { + if (row_data[i]) free(row_data[i]); + } + free(row_data); + } + + // Send resultset to client + SQLite3_to_MySQL(resultset.get(), NULL, 0, &client_myds->myprot, false, + (client_myds->myconn->options.client_flag & CLIENT_DEPRECATE_EOF)); + + l_free(pkt->size, pkt->ptr); + client_myds->DSS = STATE_SLEEP; + status = WAITING_CLIENT_DATA; + + } catch (const json::parse_error& e) { + client_myds->DSS = STATE_QUERY_SENT_NET; + std::string err_msg = "Failed to parse GenAI result: "; + err_msg += e.what(); + client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1255, (char*)"HY000", (char*)err_msg.c_str(), true); + l_free(pkt->size, pkt->ptr); + client_myds->DSS = STATE_SLEEP; + status = WAITING_CLIENT_DATA; + } catch (const std::exception& e) { + client_myds->DSS = STATE_QUERY_SENT_NET; + std::string err_msg = "Error processing GenAI result: "; + err_msg += e.what(); + client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1256, (char*)"HY000", (char*)err_msg.c_str(), true); + l_free(pkt->size, pkt->ptr); + client_myds->DSS = STATE_SLEEP; + status = WAITING_CLIENT_DATA; + } +#endif // 
epoll_create1 - fallback blocking path +} + +#ifdef epoll_create1 +/** + * @brief Send GenAI request asynchronously via socketpair + * + * This function implements the non-blocking async GenAI request path. It creates + * a socketpair for bidirectional communication with the GenAI module and sends + * the request immediately without waiting for the response. + * + * Async flow: + * 1. Create socketpair(fds) for bidirectional communication + * 2. Register fds[1] (GenAI side) with GenAI module via register_client() + * 3. Store request in pending_genai_requests_ map for later response matching + * 4. Send GenAI_RequestHeader + JSON query via fds[0] + * 5. Add fds[0] to session's genai_epoll_fd_ for response notification + * 6. Return immediately (MySQL thread is now free to process other queries) + * + * The response will be delivered asynchronously and handled by + * handle_genai_response() when the GenAI worker completes processing. + * + * Error handling: + * - On socketpair failure: Send ERR packet to client, return false + * - On register_client failure: Cleanup fds, send ERR packet, return false + * - On write failure: Cleanup request via genai_cleanup_request(), send ERR packet + * - On epoll add failure: Log warning but continue (request was sent successfully) + * + * Memory management: + * - Original packet is copied to pending.original_pkt for response generation + * - Memory is freed in genai_cleanup_request() when response is processed + * + * @param query JSON query string to send to GenAI module + * @param query_len Length of the query string + * @param pkt Original MySQL packet (for command number and later response) + * @return true if request was sent successfully, false on error + * + * @note This function is non-blocking and returns immediately after sending. + * The actual GenAI processing happens in worker threads, not MySQL threads. 
+ * @see handle_genai_response(), genai_cleanup_request(), check_genai_events() + */ +bool MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___genai_send_async( + const char* query, size_t query_len, PtrSize_t* pkt) { + + // Create socketpair for async communication + int fds[2]; + if (socketpair(AF_UNIX, SOCK_STREAM, 0, fds) < 0) { + proxy_error("GenAI: socketpair failed: %s\n", strerror(errno)); + client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1260, (char*)"HY000", + "Failed to create GenAI communication channel", true); + return false; + } + + // Set MySQL side to non-blocking + int flags = fcntl(fds[0], F_GETFL, 0); + fcntl(fds[0], F_SETFL, flags | O_NONBLOCK); + + // Register GenAI side with GenAI module + if (!GloGATH->register_client(fds[1])) { + proxy_error("GenAI: Failed to register client fd %d with GenAI module\n", fds[1]); + close(fds[0]); + close(fds[1]); + client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1261, (char*)"HY000", + "Failed to register with GenAI module", true); + return false; + } + + // Prepare request header + GenAI_RequestHeader hdr; + hdr.request_id = next_genai_request_id_++; + hdr.operation = GENAI_OP_JSON; + hdr.query_len = query_len; + hdr.flags = 0; + hdr.top_n = 0; + + // Store request in pending map + GenAI_PendingRequest pending; + pending.request_id = hdr.request_id; + pending.client_fd = fds[0]; + pending.json_query = std::string(query, query_len); + pending.start_time = std::chrono::steady_clock::now(); + + // Copy the original packet for later response + pending.original_pkt = (PtrSize_t*)malloc(sizeof(PtrSize_t)); + if (!pending.original_pkt) { + proxy_error("GenAI: Failed to allocate memory for packet copy\n"); + close(fds[0]); + close(fds[1]); + client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1262, (char*)"HY000", + "Memory allocation failed", true); + return false; + } + pending.original_pkt->ptr = pkt->ptr; + pending.original_pkt->size = pkt->size; + + 
pending_genai_requests_[hdr.request_id] = pending; + + // Send request header + ssize_t written = write(fds[0], &hdr, sizeof(hdr)); + if (written != sizeof(hdr)) { + proxy_error("GenAI: Failed to write request header to fd %d: %s\n", + fds[0], strerror(errno)); + genai_cleanup_request(hdr.request_id); + client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1263, (char*)"HY000", + "Failed to send request to GenAI module", true); + return false; + } + + // Send JSON query + size_t total_written = 0; + while (total_written < query_len) { + ssize_t w = write(fds[0], query + total_written, query_len - total_written); + if (w <= 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + usleep(1000); + continue; + } + proxy_error("GenAI: Failed to write JSON query to fd %d: %s\n", + fds[0], strerror(errno)); + genai_cleanup_request(hdr.request_id); + client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1264, (char*)"HY000", + "Failed to send query to GenAI module", true); + return false; + } + total_written += w; + } + + // Add to epoll for response notification + struct epoll_event ev; + ev.events = EPOLLIN; + ev.data.fd = fds[0]; + + if (epoll_ctl(genai_epoll_fd_, EPOLL_CTL_ADD, fds[0], &ev) < 0) { + proxy_error("GenAI: Failed to add fd %d to epoll: %s\n", fds[0], strerror(errno)); + // Request is sent, but we won't be notified of response + // This is not fatal - we'll timeout eventually + } + + proxy_debug(PROXY_DEBUG_GENAI, 3, + "GenAI: Sent async request %lu via fd %d (query_len=%zu)\n", + hdr.request_id, fds[0], query_len); + + return true; // Success - request sent asynchronously +} + +/** + * @brief Handle GenAI response from socketpair + * + * This function is called when epoll notifies that data is available on a + * GenAI response file descriptor. It reads the response from the GenAI worker + * thread, processes the result, and sends the MySQL result packet to the client. + * + * Response handling flow: + * 1. 
Read GenAI_ResponseHeader from socketpair + * 2. Find matching pending request via request_id in pending_genai_requests_ + * 3. Read JSON result payload (if result_len > 0) + * 4. Parse JSON and convert to MySQL resultset format + * 5. Send result packet (or ERR packet on error) to client + * 6. Cleanup resources via genai_cleanup_request() + * + * Response format (from GenAI worker): + * - GenAI_ResponseHeader (request_id, status_code, result_len, processing_time_ms) + * - JSON result payload (if result_len > 0) + * + * Error handling: + * - On read error: Find and cleanup pending request, return + * - On incomplete header: Log error, return + * - On unknown request_id: Log error, close fd, return + * - On status_code != 0: Send ERR packet to client with error details + * - On JSON parse error: Send ERR packet to client + * + * RTT (Round-Trip Time) tracking: + * - Calculates RTT from request start to response receipt + * - Logs RTT along with GenAI processing time for monitoring + * + * @param fd The MySQL side file descriptor from socketpair (fds[0]) + * + * @note This function is called from check_genai_events() which is invoked + * from the main handler() loop. It runs in the MySQL thread context. 
+ * @see genai_send_async(), genai_cleanup_request(), check_genai_events() + */ +void MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___handle_genai_response(int fd) { + // Read response header + GenAI_ResponseHeader resp; + ssize_t n = read(fd, &resp, sizeof(resp)); + + if (n < 0) { + // Check for non-blocking read - not an error, just no data yet + if (errno == EAGAIN || errno == EWOULDBLOCK) { + return; + } + // Real error - log and cleanup + proxy_error("GenAI: Error reading response header from fd %d: %s\n", + fd, strerror(errno)); + } else if (n == 0) { + // Connection closed (EOF) - cleanup + } else { + // Successfully read header, continue processing + goto process_response; + } + + // Cleanup path for error or EOF + for (auto& pair : pending_genai_requests_) { + if (pair.second.client_fd == fd) { + genai_cleanup_request(pair.first); + break; + } + } + return; + +process_response: + + if (n != sizeof(resp)) { + proxy_error("GenAI: Incomplete response header from fd %d: got %zd, expected %zu\n", + fd, n, sizeof(resp)); + return; + } + + // Find the pending request + auto it = pending_genai_requests_.find(resp.request_id); + if (it == pending_genai_requests_.end()) { + proxy_error("GenAI: Received response for unknown request %lu\n", resp.request_id); + close(fd); + return; + } + + GenAI_PendingRequest& pending = it->second; + + // Read JSON result + std::string json_result; + if (resp.result_len > 0) { + json_result.resize(resp.result_len); + size_t total_read = 0; + while (total_read < resp.result_len) { + ssize_t r = read(fd, &json_result[total_read], + resp.result_len - total_read); + if (r <= 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + usleep(1000); + continue; + } + proxy_error("GenAI: Error reading JSON result from fd %d: %s\n", + fd, strerror(errno)); + json_result.clear(); + break; + } + total_read += r; + } + } + + // Process the result + auto end_time = std::chrono::steady_clock::now(); + int rtt_ms = 
std::chrono::duration_cast( + end_time - pending.start_time).count(); + + proxy_debug(PROXY_DEBUG_GENAI, 3, + "GenAI: Received response %lu (status=%u, result_len=%u, rtt=%dms, proc=%dms)\n", + resp.request_id, resp.status_code, resp.result_len, rtt_ms, resp.processing_time_ms); + + // Check for errors + if (resp.status_code != 0 || json_result.empty()) { + client_myds->DSS = STATE_QUERY_SENT_NET; + client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1265, (char*)"HY000", + "GenAI query processing failed", true); + } else { + // Parse JSON result and send resultset + try { + json result = json::parse(json_result); + + if (!result.is_object()) { + client_myds->DSS = STATE_QUERY_SENT_NET; + client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1266, (char*)"HY000", + "GenAI returned invalid result format", true); + } else if (result.contains("error") && result["error"].is_string()) { + std::string error_msg = result["error"].get(); + client_myds->DSS = STATE_QUERY_SENT_NET; + client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1267, (char*)"HY000", + (char*)error_msg.c_str(), true); + } else if (!result.contains("columns") || !result.contains("rows")) { + client_myds->DSS = STATE_QUERY_SENT_NET; + client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1268, (char*)"HY000", + "GenAI result missing required fields", true); + } else { + // Build and send resultset + auto columns = result["columns"]; + auto rows = result["rows"]; + + std::unique_ptr resultset(new SQLite3_result(columns.size())); + + // Add column definitions + for (size_t i = 0; i < columns.size(); i++) { + if (columns[i].is_string()) { + std::string col_name = columns[i].get(); + resultset->add_column_definition(SQLITE_TEXT, (char*)col_name.c_str()); + } + } + + // Add rows + for (const auto& row : rows) { + if (!row.is_array()) continue; + + size_t num_cols = row.size(); + if (num_cols > columns.size()) num_cols = columns.size(); + + char** row_data = (char**)malloc(num_cols * 
sizeof(char*)); + size_t valid_cols = 0; + + for (size_t i = 0; i < num_cols; i++) { + if (!row[i].is_null()) { + std::string val; + if (row[i].is_string()) { + val = row[i].get(); + } else { + val = row[i].dump(); + } + row_data[valid_cols++] = strdup(val.c_str()); + } + } + + resultset->add_row(row_data); + + for (size_t i = 0; i < valid_cols; i++) { + if (row_data[i]) free(row_data[i]); + } + free(row_data); + } + + // Send resultset to client + SQLite3_to_MySQL(resultset.get(), NULL, 0, &client_myds->myprot, false, + (client_myds->myconn->options.client_flag & CLIENT_DEPRECATE_EOF)); + } + } catch (const json::parse_error& e) { + client_myds->DSS = STATE_QUERY_SENT_NET; + std::string err_msg = "Failed to parse GenAI result: "; + err_msg += e.what(); + client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1269, (char*)"HY000", + (char*)err_msg.c_str(), true); + } catch (const std::exception& e) { + client_myds->DSS = STATE_QUERY_SENT_NET; + std::string err_msg = "Error processing GenAI result: "; + err_msg += e.what(); + client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1270, (char*)"HY000", + (char*)err_msg.c_str(), true); + } + } + + // Cleanup the request + genai_cleanup_request(resp.request_id); + + // Return to waiting state + client_myds->DSS = STATE_SLEEP; + status = WAITING_CLIENT_DATA; +} + +/** + * @brief Cleanup a GenAI pending request + * + * This function cleans up all resources associated with a GenAI pending request. + * It is called after a response has been processed or when an error occurs. + * + * Cleanup operations: + * 1. Remove request from pending_genai_requests_ map + * 2. Close the socketpair file descriptor (client_fd) + * 3. Remove fd from genai_epoll_fd_ monitoring + * 4. 
Free the original packet memory (original_pkt) + * + * Resource cleanup details: + * - client_fd: The MySQL side of the socketpair (fds[0]) is closed + * - epoll: The fd is removed from genai_epoll_fd_ to stop monitoring + * - original_pkt: The copied packet memory is freed (ptr and size) + * - pending map: The request entry is removed from the map + * + * This function must be called exactly once per request to avoid: + * - File descriptor leaks (unclosed sockets) + * - Memory leaks (unfreed packets) + * - Epoll monitoring stale fds (removed from map but still in epoll) + * + * @param request_id The request ID to cleanup (must exist in pending_genai_requests_) + * + * @note This function is idempotent - if the request_id is not found, it safely + * returns without error (useful for error paths where cleanup might be + * called multiple times). + * @note If the request is not found in the map, this function silently returns + * without error (this is intentional to avoid crashes on double cleanup). + * + * @see genai_send_async(), handle_genai_response() + */ +void MySQL_Session::genai_cleanup_request(uint64_t request_id) { + auto it = pending_genai_requests_.find(request_id); + if (it == pending_genai_requests_.end()) { + return; + } + + GenAI_PendingRequest& pending = it->second; + + // Remove from epoll + epoll_ctl(genai_epoll_fd_, EPOLL_CTL_DEL, pending.client_fd, nullptr); + + // Close socketpair fds + close(pending.client_fd); + + // Free the original packet + if (pending.original_pkt) { + l_free(pending.original_pkt->size, pending.original_pkt->ptr); + free(pending.original_pkt); + } + + pending_genai_requests_.erase(it); + + proxy_debug(PROXY_DEBUG_GENAI, 3, "GenAI: Cleaned up request %lu\n", request_id); +} + +/** + * @brief Check for pending GenAI responses + * + * This function performs a non-blocking epoll_wait on the session's GenAI epoll + * file descriptor to check if any responses from GenAI workers are ready to be + * processed. 
It is called from the main handler() loop in the WAITING_CLIENT_DATA + * state to interleave GenAI response processing with normal client query handling. + * + * Event checking flow: + * 1. Early return if no pending requests (empty pending_genai_requests_) + * 2. Non-blocking epoll_wait with timeout=0 on genai_epoll_fd_ + * 3. For each ready fd, find matching pending request + * 4. Call handle_genai_response() to process the response + * 5. Return true after processing one response (to re-check for more) + * + * Integration with main loop: + * ```cpp + * handler_again: + * switch (status) { + * case WAITING_CLIENT_DATA: + * handler___status_WAITING_CLIENT_DATA(); + * #ifdef epoll_create1 + * // Check for GenAI responses before processing new client data + * if (check_genai_events()) { + * // GenAI response was processed, check for more + * goto handler_again; + * } + * #endif + * break; + * } + * ``` + * + * Non-blocking behavior: + * - epoll_wait timeout is 0 (immediate return) + * - Returns true only if a response was actually processed + * - Allows main loop to continue processing client queries between responses + * + * Return value: + * - true: A GenAI response was processed (caller should re-check for more) + * - false: No responses ready (caller can proceed to normal client handling) + * + * @return true if a GenAI response was processed, false otherwise + * + * @note This function is called from the main handler() loop on every iteration + * when in WAITING_CLIENT_DATA state. It must return quickly to avoid + * delaying normal client query processing. + * @note Only processes one response per call to avoid starving client handling. + * The main loop will call again to process additional responses. 
+ * + * @see handle_genai_response(), genai_send_async() + */ +bool MySQL_Session::check_genai_events() { +#ifdef epoll_create1 + if (pending_genai_requests_.empty()) { + return false; + } + + const int MAX_EVENTS = 16; + struct epoll_event events[MAX_EVENTS]; + + int nfds = epoll_wait(genai_epoll_fd_, events, MAX_EVENTS, 0); // Non-blocking check + + if (nfds <= 0) { + return false; + } + + for (int i = 0; i < nfds; i++) { + int fd = events[i].data.fd; + + // Find the pending request for this fd + for (auto it = pending_genai_requests_.begin(); it != pending_genai_requests_.end(); ++it) { + if (it->second.client_fd == fd) { + handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___handle_genai_response(fd); + return true; // Processed one response + } + } + } + + return false; +#else + return false; +#endif +} +#endif + // this function was inline inside MySQL_Session::get_pkts_from_client // where: // status = WAITING_CLIENT_DATA @@ -5004,6 +5678,13 @@ int MySQL_Session::handler() { case WAITING_CLIENT_DATA: // housekeeping handler___status_WAITING_CLIENT_DATA(); +#ifdef epoll_create1 + // Check for GenAI responses before processing new client data + if (check_genai_events()) { + // GenAI response was processed, check for more + goto handler_again; + } +#endif break; case FAST_FORWARD: if (mybe->server_myds->mypolls==NULL) { @@ -6067,6 +6748,19 @@ bool MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C */ bool exit_after_SetParse = true; unsigned char command_type=*((unsigned char *)pkt->ptr+sizeof(mysql_hdr)); + + // Check for GENAI: queries - experimental GenAI integration + if (pkt->size > sizeof(mysql_hdr) + 7) { // Need at least "GENAI: " (7 chars after header) + const char* query_ptr = (const char*)pkt->ptr + sizeof(mysql_hdr) + 1; + size_t query_len = pkt->size - sizeof(mysql_hdr) - 1; + + if (query_len >= 7 && strncasecmp(query_ptr, "GENAI:", 6) == 0) { + // This is a GENAI: query - handle with GenAI module + 
handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY___genai(query_ptr + 6, query_len - 6, pkt); + return true; + } + } + if (qpo->new_query) { handler_WCD_SS_MCQ_qpo_QueryRewrite(pkt); } diff --git a/lib/MySQL_Thread.cpp b/lib/MySQL_Thread.cpp index f8191ba9e0..78d164edfb 100644 --- a/lib/MySQL_Thread.cpp +++ b/lib/MySQL_Thread.cpp @@ -3009,10 +3009,46 @@ void MySQL_Thread::poll_listener_add(int sock) { listener_DS->fd=sock; proxy_debug(PROXY_DEBUG_NET,1,"Created listener %p for socket %d\n", listener_DS, sock); + + /** + * @brief Register listener socket with ProxySQL_Poll for incoming connections + * + * This usage pattern registers a listener socket file descriptor with the ProxySQL_Poll instance + * to monitor for incoming client connections. The listener data stream handles the accept() + * operation when connection events are detected. + * + * Usage pattern: mypolls.add(POLLIN, sock, listener_DS, monotonic_time()) + * - POLLIN: Monitor for read events (new connections ready to accept) + * - sock: Listener socket file descriptor + * - listener_DS: Data stream associated with the listener (accepts connections) + * - monotonic_time(): Current timestamp for tracking socket registration time + * + * Called during: Listener setup and initialization + * Purpose: Enables the thread to accept incoming MySQL client connections + */ mypolls.add(POLLIN, sock, listener_DS, monotonic_time()); } void MySQL_Thread::poll_listener_del(int sock) { + /** + * @brief Remove listener socket from the poll set using efficient index lookup + * + * This usage pattern demonstrates the complete removal workflow for listener sockets: + * 1. Find the index of the socket in the poll set using find_index() + * 2. 
Remove the socket using remove_index_fast() with the found index + * + * Usage pattern: + * int i = mypollolls.find_index(sock); // Find index by file descriptor + * if (i>=0) { + * mypolls.remove_index_fast(i); // Remove by index (O(1) operation) + * } + * + * find_index(sock): Returns index of socket or -1 if not found + * remove_index_fast(i): Removes the entry at index i efficiently + * + * Called during: Listener shutdown and cleanup + * Purpose: Properly removes listener sockets from polling to prevent memory leaks + */ int i=mypolls.find_index(sock); if (i>=0) { MySQL_Data_Stream *myds=mypolls.myds[i]; @@ -3288,6 +3324,26 @@ void MySQL_Thread::run() { //this is the only portion of code not protected by a global mutex proxy_debug(PROXY_DEBUG_NET,5,"Calling poll with timeout %d\n", ttw ); // poll is called with a timeout of mypolls.poll_timeout if set , or mysql_thread___poll_timeout + /** + * @brief Execute main poll() loop to monitor all registered FDs + * + * This usage pattern demonstrates the core polling mechanism that drives ProxySQL's event loop. + * The poll() system call blocks until one of the registered file descriptors becomes ready + * or the timeout expires. 
+ * + * Usage pattern: rc = poll(mypolls.fds, mypolls.len, ttw) + * - mypollolls.fds: Array of pollfd structures containing file descriptors and events + * - mypolls.len: Number of file descriptors to monitor + * - ttw: Timeout in milliseconds (mydynamic poll timeout) + * + * Return codes: + * - > 0: Number of file descriptors with events + * - 0: Timeout occurred + * - -1: Error (errno set) + * + * Called during: Main event loop iteration + * Purpose: Enables efficient I/O multiplexing across all connections + */ rc=poll(mypolls.fds,mypolls.len, ttw); proxy_debug(PROXY_DEBUG_NET,5,"%s\n", "Returning poll"); #ifdef IDLE_THREADS @@ -3594,6 +3650,23 @@ void MySQL_Thread::worker_thread_gets_sessions_from_idle_thread() { MySQL_Session *mysess=(MySQL_Session *)myexchange.resume_mysql_sessions->remove_index_fast(0); register_session(this, mysess, false); MySQL_Data_Stream *myds=mysess->client_myds; + + /** + * @brief Add client session to poll set for resumed connections + * + * This usage pattern registers a client data stream for resumed connections + * during session restoration in IDLE_THREADS mode. 
+ * + * Usage pattern: mypolls.add(POLLIN, myds->fd, myds, monotonic_time()) + * - POLLIN: Monitor for read events (client data available) + * - myds->fd: Client socket file descriptor + * - myds: MySQL_Data_Stream instance for the client session + * - monotonic_time(): Current timestamp for tracking session registration time + * + * Called during: Session restoration in IDLE_THREADS mode + * Purpose: Enables the thread to receive and process client MySQL protocol data + * for resumed sessions + */ mypolls.add(POLLIN, myds->fd, myds, monotonic_time()); } } @@ -3648,7 +3721,7 @@ bool MySQL_Thread::process_data_on_data_stream(MySQL_Data_Stream *myds, unsigned // // this can happen, for example, with a low wait_timeout and running transaction if (myds->sess->status==WAITING_CLIENT_DATA) { - if (myds->myconn->async_state_machine==ASYNC_IDLE) { + if (myds->myconn && myds->myconn->async_state_machine==ASYNC_IDLE) { proxy_warning("Detected broken idle connection on %s:%d\n", myds->myconn->parent->address, myds->myconn->parent->port); myds->destroy_MySQL_Connection_From_Pool(false); myds->sess->set_unhealthy(); @@ -3658,6 +3731,10 @@ bool MySQL_Thread::process_data_on_data_stream(MySQL_Data_Stream *myds, unsigned } return true; } + if (myds->myds_type==MYDS_INTERNAL_GENAI) { + // INTERNAL_GENAI doesn't need special idle connection handling + return true; + } if (mypolls.fds[n].revents) { if (mypolls.myds[n]->DSS < STATE_MARIADB_BEGIN || mypolls.myds[n]->DSS > STATE_MARIADB_END) { // only if we aren't using MariaDB Client Library @@ -4510,6 +4587,26 @@ void MySQL_Thread::listener_handle_new_connection(MySQL_Data_Stream *myds, unsig } sess->client_myds->myprot.generate_pkt_initial_handshake(true,NULL,NULL, &sess->thread_session_id, true); ioctl_FIONBIO(sess->client_myds->fd, 1); + + /** + * @brief Add client socket to poll set with both read and write monitoring + * + * This usage pattern registers a client socket with both POLLIN and POLLOUT events, + * which is typically 
done during initial client setup when we need to send the + * initial handshake packet and also be ready to receive client responses. + * + * Usage pattern: mypolls.add(POLLIN|POLLOUT, sess->client_myds->fd, sess->client_myds, curtime) + * - POLLIN|POLLOUT: Monitor both read and write events + * - sess->client_myds->fd: Client socket file descriptor + * - sess->client_myds: MySQL_Data_Stream instance for the client + * - curtime: Current timestamp for tracking + * + * Called during: Initial client connection setup after handshake packet generation + * Purpose: Enables bidirectional communication with the client during setup phase + * + * Note: This ensures we can send the initial handshake immediately and also handle + * any client packets that might arrive before the handshake is complete. + */ mypolls.add(POLLIN|POLLOUT, sess->client_myds->fd, sess->client_myds, curtime); proxy_debug(PROXY_DEBUG_NET,1,"Session=%p -- Adding client FD %d\n", sess, sess->client_myds->fd); diff --git a/lib/PgSQL_Data_Stream.cpp b/lib/PgSQL_Data_Stream.cpp index 9a7a0c0a08..ff326f2c75 100644 --- a/lib/PgSQL_Data_Stream.cpp +++ b/lib/PgSQL_Data_Stream.cpp @@ -321,6 +321,24 @@ PgSQL_Data_Stream::~PgSQL_Data_Stream() { delete PSarrayOUT; } + /** + * @brief Remove PostgreSQL data stream from poll set during destruction + * + * This usage pattern demonstrates how ProxySQL_Poll is used during PgSQL_Data_Stream cleanup: + * - Removes the data stream entry from the poll set to prevent polling on closed socket + * - Uses the stored poll_fds_idx for efficient removal (O(1) operation) + * - Called only if mypolls is not NULL (data stream is registered with a poll instance) + * + * Usage pattern: if (mypolls) mypolls->remove_index_fast(poll_fds_idx) + * - mypolls: Check if data stream is registered with a poll instance + * - remove_index_fast(poll_fds_idx): Remove by stored index from data stream + * + * Called during: PgSQL_Data_Stream destructor + * Purpose: Prevent memory leaks and ensure 
proper cleanup of poll entries + * + * Note: Each data stream maintains its poll_fds_idx to track its position in the poll array + * for efficient removal without requiring find_index() lookup. + */ if (mypolls) mypolls->remove_index_fast(poll_fds_idx); if (fd > 0) { @@ -652,6 +670,26 @@ int PgSQL_Data_Stream::read_from_net() { else { queue_w(queueIN, r); bytes_info.bytes_recv += r; + + /** + * @brief Update receive timestamp in ProxySQL_Poll for PostgreSQL activity tracking + * + * This usage pattern demonstrates how ProxySQL_Poll is used for PostgreSQL activity monitoring: + * - Updates the last receive timestamp in the poll entry for timeout management + * - Called after successful data reception to track connection activity + * - Uses the stored poll_fds_idx for direct array access (O(1) operation) + * + * Usage pattern: if (mypolls) mypolls->last_recv[poll_fds_idx] = sess->thread->curtime + * - mypolls: Check if data stream is registered with a poll instance + * - last_recv[poll_fds_idx]: Update the receive timestamp for this data stream + * - sess->thread->curtime: Current timestamp from the thread + * + * Called during: After receiving data on the PostgreSQL data stream + * Purpose: Enable timeout management and connection activity monitoring + * + * Note: This timestamp is used by the idle connection timeout system to detect + * inactive PostgreSQL connections that should be closed. 
+ */ if (mypolls) mypolls->last_recv[poll_fds_idx] = sess->thread->curtime; } return r; @@ -759,6 +797,25 @@ int PgSQL_Data_Stream::write_to_net() { } else { queue_r(queueOUT, bytes_io); + /** + * @brief Update send timestamp in ProxySQL_Poll for PostgreSQL activity tracking + * + * This usage pattern demonstrates how ProxySQL_Poll is used for PostgreSQL activity monitoring: + * - Updates the last send timestamp in the poll entry for timeout management + * - Called after successful data transmission to track connection activity + * - Uses the stored poll_fds_idx for direct array access (O(1) operation) + * + * Usage pattern: if (mypolls) mypolls->last_sent[poll_fds_idx] = sess->thread->curtime + * - mypolls: Check if data stream is registered with a poll instance + * - last_sent[poll_fds_idx]: Update the send timestamp for this data stream + * - sess->thread->curtime: Current timestamp from the thread + * + * Called during: After sending data on the PostgreSQL data stream + * Purpose: Enable timeout management and connection activity monitoring + * + * Note: This timestamp is used by the idle connection timeout system to detect + * inactive PostgreSQL connections that should be closed. + */ if (mypolls) mypolls->last_sent[poll_fds_idx] = sess->thread->curtime; bytes_info.bytes_sent += bytes_io; } diff --git a/lib/PgSQL_Thread.cpp b/lib/PgSQL_Thread.cpp index 94c8494cc7..dbf42d8b87 100644 --- a/lib/PgSQL_Thread.cpp +++ b/lib/PgSQL_Thread.cpp @@ -2782,10 +2782,46 @@ void PgSQL_Thread::poll_listener_add(int sock) { listener_DS->fd = sock; proxy_debug(PROXY_DEBUG_NET, 1, "Created listener %p for socket %d\n", listener_DS, sock); + + /** + * @brief Register PostgreSQL listener socket with ProxySQL_Poll for incoming connections + * + * This usage pattern registers a PostgreSQL listener socket file descriptor with the ProxySQL_Poll instance + * to monitor for incoming PostgreSQL client connections. 
The listener data stream handles the accept() + * operation when connection events are detected. + * + * Usage pattern: mypolls.add(POLLIN, sock, listener_DS, monotonic_time()) + * - POLLIN: Monitor for read events (new connections ready to accept) + * - sock: Listener socket file descriptor + * - listener_DS: Data stream associated with the listener (accepts connections) + * - monotonic_time(): Current timestamp for tracking socket registration time + * + * Called during: PostgreSQL listener setup and initialization + * Purpose: Enables the thread to accept incoming PostgreSQL client connections + */ mypolls.add(POLLIN, sock, listener_DS, monotonic_time()); } void PgSQL_Thread::poll_listener_del(int sock) { + /** + * @brief Remove PostgreSQL listener socket from the poll set using efficient index lookup + * + * This usage pattern demonstrates the complete removal workflow for PostgreSQL listener sockets: + * 1. Find the index of the socket in the poll set using find_index() + * 2. Remove the socket using remove_index_fast() with the found index + * + * Usage pattern: + * int i = mypolls.find_index(sock); // Find index by file descriptor + * if (i>=0) { + * mypolls.remove_index_fast(i); // Remove by index (O(1) operation) + * } + * + * find_index(sock): Returns index of socket or -1 if not found + * remove_index_fast(i): Removes the entry at index i efficiently + * + * Called during: PostgreSQL listener shutdown and cleanup + * Purpose: Properly removes listener sockets from polling to prevent memory leaks + */ int i = mypolls.find_index(sock); if (i >= 0) { PgSQL_Data_Stream* myds = mypolls.myds[i]; @@ -2961,7 +2997,27 @@ void PgSQL_Thread::run() { #endif // IDLE_THREADS //this is the only portion of code not protected by a global mutex proxy_debug(PROXY_DEBUG_NET, 5, "Calling poll with timeout %d\n", ttw); - // poll is called with a timeout of mypolls.poll_timeout if set , or pgsql_thread___poll_timeout + /** + * @brief Execute main poll() loop to monitor all 
registered FDs for PostgreSQL thread + * + * This usage pattern demonstrates the core polling mechanism that drives ProxySQL's PostgreSQL event loop. + * The poll() system call blocks until one of the registered file descriptors becomes ready + * or the timeout expires. + * + * Usage pattern: rc = poll(mypolls.fds, mypolls.len, ttw) + * - mypollolls.fds: Array of pollfd structures containing file descriptors and events + * - mypolls.len: Number of file descriptors to monitor + * - ttw: Timeout in milliseconds (dynamic poll timeout) + * + * Return codes: + * - > 0: Number of file descriptors with events + * - 0: Timeout occurred + * - -1: Error (errno set) + * + * Called during: Main PostgreSQL event loop iteration + * Purpose: Enables efficient I/O multiplexing across all PostgreSQL connections + */ + // poll is called with a timeout of mypolls.poll_timeout if set , or pgsql_thread___poll_timeout rc = poll(mypolls.fds, mypolls.len, ttw); proxy_debug(PROXY_DEBUG_NET, 5, "%s\n", "Returning poll"); #ifdef IDLE_THREADS @@ -4065,6 +4121,25 @@ void PgSQL_Thread::listener_handle_new_connection(PgSQL_Data_Stream * myds, unsi sess->status = CONNECTING_CLIENT; ioctl_FIONBIO(sess->client_myds->fd, 1); + /** + * @brief Add PostgreSQL client socket to poll set with both read and write monitoring + * + * This usage pattern registers a PostgreSQL client socket with both POLLIN and POLLOUT events, + * which is typically done during initial client setup when we need to establish the connection + * and also be ready to receive client responses. 
+ * + * Usage pattern: mypolls.add(POLLIN|POLLOUT, sess->client_myds->fd, sess->client_myds, curtime) + * - POLLIN|POLLOUT: Monitor both read and write events + * - sess->client_myds->fd: Client socket file descriptor + * - sess->client_myds: PgSQL_Data_Stream instance for the client + * - curtime: Current timestamp for tracking + * + * Called during: Initial PostgreSQL client connection setup + * Purpose: Enables bidirectional communication with the client during setup phase + * + * Note: This ensures we can establish the connection immediately and also handle + * any client packets that might arrive during the connection process. + */ mypolls.add(POLLIN | POLLOUT, sess->client_myds->fd, sess->client_myds, curtime); proxy_debug(PROXY_DEBUG_NET, 1, "Session=%p -- Adding client FD %d\n", sess, sess->client_myds->fd); diff --git a/lib/ProxySQL_Admin.cpp b/lib/ProxySQL_Admin.cpp index ebd2a2301f..1cb678b488 100644 --- a/lib/ProxySQL_Admin.cpp +++ b/lib/ProxySQL_Admin.cpp @@ -2610,6 +2610,7 @@ ProxySQL_Admin::ProxySQL_Admin() : generate_load_save_disk_commands("pgsql_users", "PGSQL USERS"); generate_load_save_disk_commands("pgsql_servers", "PGSQL SERVERS"); generate_load_save_disk_commands("pgsql_variables", "PGSQL VARIABLES"); + generate_load_save_disk_commands("genai_variables", "GENAI VARIABLES"); generate_load_save_disk_commands("scheduler", "SCHEDULER"); generate_load_save_disk_commands("restapi", "RESTAPI"); generate_load_save_disk_commands("proxysql_servers", "PROXYSQL SERVERS"); @@ -2838,6 +2839,12 @@ void ProxySQL_Admin::init_pgsql_variables() { flush_pgsql_variables___database_to_runtime(admindb, true); } +void ProxySQL_Admin::init_genai_variables() { + flush_genai_variables___runtime_to_database(configdb, false, false, false); + flush_genai_variables___runtime_to_database(admindb, false, true, false); + flush_genai_variables___database_to_runtime(admindb, true); +} + void ProxySQL_Admin::admin_shutdown() { int i; // do { usleep(50); } while 
(main_shutdown==0); diff --git a/lib/ProxySQL_Poll.cpp b/lib/ProxySQL_Poll.cpp index 637a288799..aadb91ceae 100644 --- a/lib/ProxySQL_Poll.cpp +++ b/lib/ProxySQL_Poll.cpp @@ -12,17 +12,80 @@ /** * @file ProxySQL_Poll.cpp * - * These functions provide functionality for managing file descriptors (FDs) and associated data streams in the ProxySQL_Poll class. - * They handle memory allocation, addition, removal, and searching of FDs within the poll object. - * Additionally, they ensure that memory is managed efficiently by dynamically resizing the internal arrays as needed. -*/ + * @brief Core I/O Multiplexing Engine for ProxySQL's Event-Driven Architecture + * + * The ProxySQL_Poll class is the heart of ProxySQL's event-driven I/O system, managing file descriptors + * and their associated data streams using the poll() system call. It serves as the central mechanism + * for handling thousands of concurrent connections efficiently. + * + * @section Architecture Integration + * + * ProxySQL_Poll is integrated throughout the ProxySQL codebase as follows: + * + * - **Thread Classes**: Each MySQL_Thread and PgSQL_Thread contains a ProxySQL_Poll instance + * - **Data Stream Classes**: MySQL_Data_Stream and PgSQL_Data_Stream maintain pointers to their poll instance + * - **Main Event Loop**: Forms the core of the poll() loop in both thread types + * + * @section Template Specialization + * + * ProxySQL_Poll is templated to work with different data stream types: + * - ProxySQL_Poll for MySQL protocol connections + * - ProxySQL_Poll for PostgreSQL protocol connections + * + * @section Memory Management + * + * The class implements sophisticated memory management: + * - Initial allocation: MIN_POLL_LEN (32) FDs + * - Dynamic expansion: Uses l_near_pow_2() for power-of-2 sizing + * - Automatic shrinking: When FD count drops below threshold + * - Efficient cleanup: Proper deallocation in destructor + * + * @section Event Processing Pipeline + * + * ProxySQL_Poll integrates into the 
event processing pipeline: + * 1. Before Poll: ProcessAllMyDS_BeforePoll() - set up poll events, handle timeouts + * 2. Poll Execution: poll() system call with dynamic timeout + * 3. After Poll: ProcessAllMyDS_AfterPoll() - process events, handle new connections + * + * @section Key Features + * + * - **Scalability**: Efficiently handles thousands of concurrent connections + * - **Event-Driven**: Non-blocking I/O for high performance + * - **Bidirectional Integration**: Data streams maintain poll array indices + * - **Memory Efficiency**: Dynamic resizing optimizes memory usage + * - **Thread Safety**: Each thread maintains its own poll instance + * + * @section Integration with Data Streams + * + * Each data stream maintains critical integration fields: + * - `mypolls`: Pointer to parent ProxySQL_Poll instance + * - `poll_fds_idx`: Index in poll array for quick lookup + * - `last_recv/sent`: Timestamps managed by poll system + * + * This tight integration enables ProxySQL to achieve high performance by minimizing + * lookup overhead and maintaining efficient event-driven processing across all connections. + * + * For usage patterns and examples, see the documentation in: + * - MySQL_Thread.cpp + * - PgSQL_Thread.cpp + * - MySQL_Data_Stream.cpp + * - PgSQL_Data_Stream.cpp + * + * @see MySQL_Thread + * @see PgSQL_Thread + * @see MySQL_Data_Stream + * @see PgSQL_Data_Stream + */ /** * @brief Shrinks the ProxySQL_Poll object by reallocating memory to fit the current number of elements. - * + * * This function reduces the size of the ProxySQL_Poll object by reallocating memory to fit the current number of elements. * It adjusts the size of internal arrays to a size that is a power of two near the current number of elements. 
+ * + * @note Called automatically when FD count drops below MIN_POLL_DELETE_RATIO threshold + * (see lib/ProxySQL_Poll.cpp:166 in remove_index_fast()) */ template void ProxySQL_Poll::shrink() { @@ -36,10 +99,13 @@ void ProxySQL_Poll::shrink() { /** * @brief Expands the ProxySQL_Poll object to accommodate additional elements. - * + * * This function expands the ProxySQL_Poll object to accommodate the specified number of additional elements. * If the resulting size after expansion exceeds the current size, it reallocates memory to fit the expanded size. - * + * + * @note Called automatically in add() method when current capacity is exhausted + * (see lib/ProxySQL_Poll.cpp:114 in add()) + * * @param more The number of additional elements to accommodate. */ template @@ -98,15 +164,27 @@ ProxySQL_Poll::~ProxySQL_Poll() { } /** - * @brief Adds a new file descriptor (FD) and its associated MySQL_Data_Stream to the ProxySQL_Poll object. - * - * This function adds a new file descriptor (FD) along with its associated MySQL_Data_Stream and relevant metadata + * @brief Adds a new file descriptor (FD) and its associated data stream to the ProxySQL_Poll object. + * + * This function adds a new file descriptor (FD) along with its associated data stream and relevant metadata * to the ProxySQL_Poll object. It automatically expands the internal arrays if needed. - * - * @param _events The events to monitor for the FD. - * @param _fd The file descriptor (FD) to add. - * @param _myds The MySQL_Data_Stream associated with the FD. - * @param sent_time The time when data was last sent on the FD. 
+ * + * @section Integration + * - Sets up bidirectional relationship: data stream -> poll instance via myds[i]->mypolls=this + * - Sets up index tracking: data stream -> poll array index via myds[i]->poll_fds_idx=i + * - Initializes timestamps for connection timeout management + * + * @section Usage Patterns + * Called during: + * - New client connections (MySQL_Thread.cpp:4518) + * - New server connections (MySQL_Thread.cpp:3600) + * - Listener socket registration (MySQL_Thread.cpp:3015) + * - PostgreSQL session establishment (PgSQL_Session.cpp:1094) + * + * @param _events The events to monitor for the FD (POLLIN, POLLOUT, POLLIN|POLLOUT) + * @param _fd The file descriptor (FD) to add + * @param _myds The data stream object associated with the FD + * @param sent_time The timestamp when data was last sent on the FD */ template void ProxySQL_Poll::add(uint32_t _events, int _fd, T *_myds, unsigned long long sent_time) { @@ -142,12 +220,24 @@ void ProxySQL_Poll::update_fd_at_index(unsigned int idx, int _fd) { } /** - * @brief Removes a file descriptor (FD) and its associated MySQL_Data_Stream from the ProxySQL_Poll object. - * - * This function removes a file descriptor (FD) along with its associated MySQL_Data_Stream from the ProxySQL_Poll object. - * It also adjusts internal arrays and may shrink the ProxySQL_Poll object if necessary. - * - * @param i The index of the file descriptor (FD) to remove. + * @brief Removes a file descriptor (FD) and its associated data stream from the ProxySQL_Poll object. + * + * This function removes a file descriptor (FD) along with its associated data stream from the ProxySQL_Poll object. + * It uses a swap-and-pop technique for efficient removal and may shrink the object if necessary. + * + * @section Removal Algorithm + * 1. Sets data stream's poll_fds_idx to -1 to prevent double-free + * 2. If not last element, swaps with last element (O(1) removal) + * 3. Updates swapped element's poll_fds_idx to new position + * 4. 
Decrement length and potentially shrink arrays + * + * @section Usage Patterns + * Called during: + * - Data stream destruction (MySQL_Data_Stream.cpp:337, PgSQL_Data_Stream.cpp:1114) + * - Thread cleanup operations (MySQL_Thread.cpp:3451) + * - Connection termination and cleanup + * + * @param i The index of the file descriptor (FD) to remove */ template void ProxySQL_Poll::remove_index_fast(unsigned int i) { @@ -170,12 +260,23 @@ void ProxySQL_Poll::remove_index_fast(unsigned int i) { /** * @brief Finds the index of a file descriptor (FD) in the ProxySQL_Poll object. - * - * This function searches for a file descriptor (FD) in the ProxySQL_Poll object and returns its index if found. + * + * This function performs a linear search for a file descriptor (FD) in the ProxySQL_Poll object and returns its index if found. * If the FD is not found, it returns -1. - * - * @param fd The file descriptor (FD) to search for. - * @return The index of the file descriptor (FD) if found, otherwise -1. + * + * @section Performance Notes + * - O(n) complexity where n is number of FDs in poll set + * - Used for lookup operations where FD is known but poll index is needed + * - Data streams maintain their own poll_fds_idx for O(1) access in most cases + * + * @section Usage Patterns + * Called during: + * - Listener socket operations (MySQL_Thread.cpp:3019) + * - Connection management and lookup (PgSQL_Thread.cpp:2789) + * - Debugging and diagnostic operations + * + * @param fd The file descriptor (FD) to search for + * @return The index of the file descriptor (FD) if found, otherwise -1 */ template int ProxySQL_Poll::find_index(int fd) { diff --git a/lib/debug.cpp b/lib/debug.cpp index 440ef80242..980326ba10 100644 --- a/lib/debug.cpp +++ b/lib/debug.cpp @@ -541,6 +541,7 @@ void init_debug_struct() { GloVars.global.gdbg_lvl[PROXY_DEBUG_RESTAPI].name=(char *)"debug_restapi"; GloVars.global.gdbg_lvl[PROXY_DEBUG_MONITOR].name=(char *)"debug_monitor"; 
GloVars.global.gdbg_lvl[PROXY_DEBUG_CLUSTER].name=(char *)"debug_cluster"; + GloVars.global.gdbg_lvl[PROXY_DEBUG_GENAI].name=(char *)"debug_genai"; for (i=0;iremove_index_fast(poll_fds_idx) + * - mypolls: Check if data stream is registered with a poll instance + * - remove_index_fast(poll_fds_idx): Remove by stored index from data stream + * + * Called during: MySQL_Data_Stream destructor + * Purpose: Prevent memory leaks and ensure proper cleanup of poll entries + * + * Note: Each data stream maintains its poll_fds_idx to track its position in the poll array + * for efficient removal without requiring find_index() lookup. + */ if (mypolls) mypolls->remove_index_fast(poll_fds_idx); @@ -738,6 +756,25 @@ int MySQL_Data_Stream::read_from_net() { } else { queue_w(queueIN,r); bytes_info.bytes_recv+=r; + /** + * @brief Update receive timestamp in ProxySQL_Poll for activity tracking + * + * This usage pattern demonstrates how ProxySQL_Poll is used for activity monitoring: + * - Updates the last receive timestamp in the poll entry for timeout management + * - Called after successful data reception to track connection activity + * - Uses the stored poll_fds_idx for direct array access (O(1) operation) + * + * Usage pattern: if (mypolls) mypolls->last_recv[poll_fds_idx] = sess->thread->curtime + * - mypolls: Check if data stream is registered with a poll instance + * - last_recv[poll_fds_idx]: Update the receive timestamp for this data stream + * - sess->thread->curtime: Current timestamp from the thread + * + * Called during: After receiving data on the data stream + * Purpose: Enable timeout management and connection activity monitoring + * + * Note: This timestamp is used by the idle connection timeout system to detect + * inactive connections that should be closed. 
+ */ if (mypolls) mypolls->last_recv[poll_fds_idx]=sess->thread->curtime; } return r; @@ -839,6 +876,26 @@ int MySQL_Data_Stream::write_to_net() { } } else { queue_r(queueOUT, bytes_io); + + /** + * @brief Update send timestamp in ProxySQL_Poll for activity tracking + * + * This usage pattern demonstrates how ProxySQL_Poll is used for activity monitoring: + * - Updates the last send timestamp in the poll entry for timeout management + * - Called after successful data transmission to track connection activity + * - Uses the stored poll_fds_idx for direct array access (O(1) operation) + * + * Usage pattern: if (mypolls) mypolls->last_sent[poll_fds_idx] = sess->thread->curtime + * - mypolls: Check if data stream is registered with a poll instance + * - last_sent[poll_fds_idx]: Update the send timestamp for this data stream + * - sess->thread->curtime: Current timestamp from the thread + * + * Called during: After sending data on the data stream + * Purpose: Enable timeout management and connection activity monitoring + * + * Note: This timestamp is used by the idle connection timeout system to detect + * inactive connections that should be closed. 
+ */ if (mypolls) mypolls->last_sent[poll_fds_idx]=sess->thread->curtime; bytes_info.bytes_sent+=bytes_io; } diff --git a/src/main.cpp b/src/main.cpp index aa78d0f799..1782cb7e4a 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -26,6 +26,7 @@ using json = nlohmann::json; #include "ProxySQL_Cluster.hpp" #include "MySQL_Logger.hpp" #include "PgSQL_Logger.hpp" +#include "GenAI_Thread.h" #include "SQLite3_Server.h" #include "MySQL_Query_Processor.h" #include "PgSQL_Query_Processor.h" @@ -477,6 +478,7 @@ PgSQL_Query_Processor* GloPgQPro; ProxySQL_Admin *GloAdmin; MySQL_Threads_Handler *GloMTH = NULL; PgSQL_Threads_Handler* GloPTH = NULL; +GenAI_Threads_Handler* GloGATH = NULL; Web_Interface *GloWebInterface; MySQL_STMT_Manager_v14 *GloMyStmt; PgSQL_STMT_Manager *GloPgStmt; @@ -929,6 +931,11 @@ void ProxySQL_Main_init_main_modules() { PgSQL_Threads_Handler* _tmp_GloPTH = NULL; _tmp_GloPTH = new PgSQL_Threads_Handler(); GloPTH = _tmp_GloPTH; + + // Initialize GenAI module + GenAI_Threads_Handler* _tmp_GloGATH = NULL; + _tmp_GloGATH = new GenAI_Threads_Handler(); + GloGATH = _tmp_GloGATH; } @@ -1258,6 +1265,14 @@ void ProxySQL_Main_shutdown_all_modules() { pthread_mutex_unlock(&GloVars.global.ext_glopth_mutex); #ifdef DEBUG std::cerr << "GloPTH shutdown in "; +#endif + } + if (GloGATH) { + cpu_timer t; + delete GloGATH; + GloGATH = NULL; +#ifdef DEBUG + std::cerr << "GloGATH shutdown in "; #endif } if (GloMyLogger) { @@ -1582,6 +1597,11 @@ void ProxySQL_Main_init_phase3___start_all() { GloAdmin->init_ldap_variables(); } + // GenAI + if (GloGATH) { + GloAdmin->init_genai_variables(); + } + // HTTP Server should be initialized after other modules. 
See #4510 GloAdmin->init_http_server(); GloAdmin->proxysql_restapi().load_restapi_to_runtime(); diff --git a/src/proxysql.cfg b/src/proxysql.cfg index 0d76936ae5..70f8ec27c2 100644 --- a/src/proxysql.cfg +++ b/src/proxysql.cfg @@ -57,6 +57,15 @@ mysql_variables= sessions_sort=true } +# GenAI module configuration +genai_variables= +{ + # Dummy variable 1: string value + var1="default_value_1" + # Dummy variable 2: integer value + var2=100 +} + mysql_servers = ( { diff --git a/test/tap/tests/genai_async-t.cpp b/test/tap/tests/genai_async-t.cpp new file mode 100644 index 0000000000..302d73ddd4 --- /dev/null +++ b/test/tap/tests/genai_async-t.cpp @@ -0,0 +1,786 @@ +/** + * @file genai_async-t.cpp + * @brief TAP test for the GenAI async architecture + * + * This test verifies the GenAI (Generative AI) module's async architecture: + * - Non-blocking GENAI: queries + * - Multiple concurrent requests + * - Socketpair communication + * - Epoll event handling + * - Request/response matching + * - Resource cleanup + * - Error handling in async mode + * + * Note: These tests require: + * 1. A running GenAI service (llama-server or compatible) at the configured endpoints + * 2. Epoll support (Linux systems) + * 3. 
The async architecture (socketpair + worker threads) + * + * @date 2025-01-10 + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "mysql.h" +#include "mysqld_error.h" + +#include "tap.h" +#include "command_line.h" +#include "utils.h" + +using std::string; + +// ============================================================================ +// Helper Functions +// ============================================================================ + +/** + * @brief Check if the GenAI module is initialized and async is available + * + * @param admin MySQL connection to admin interface + * @return true if GenAI module is initialized, false otherwise + */ +bool check_genai_initialized(MYSQL* admin) { + MYSQL_QUERY(admin, "SELECT @@genai-threads"); + MYSQL_RES* res = mysql_store_result(admin); + if (!res) { + return false; + } + + int num_rows = mysql_num_rows(res); + mysql_free_result(res); + + // If we get a result, the GenAI module is loaded and initialized + return num_rows == 1; +} + +/** + * @brief Execute a GENAI: query and verify result + * + * @param client MySQL connection to client interface + * @param json_query The JSON query to send (without GENAI: prefix) + * @param expected_rows Expected number of rows (or -1 for any) + * @return true if query succeeded, false otherwise + */ +bool execute_genai_query(MYSQL* client, const string& json_query, int expected_rows = -1) { + string full_query = "GENAI: " + json_query; + int rc = mysql_query(client, full_query.c_str()); + + if (rc != 0) { + diag("Query failed: %s", mysql_error(client)); + return false; + } + + MYSQL_RES* res = mysql_store_result(client); + if (!res) { + diag("No result set returned"); + return false; + } + + int num_rows = mysql_num_rows(res); + mysql_free_result(res); + + if (expected_rows >= 0 && num_rows != expected_rows) { + diag("Expected %d rows, got %d", expected_rows, num_rows); + return false; + } + + return true; +} + +/** + * 
@brief Execute a GENAI: query and expect an error + * + * @param client MySQL connection to client interface + * @param json_query The JSON query to send (without GENAI: prefix) + * @return true if query returned an error as expected, false otherwise + */ +bool execute_genai_query_expect_error(MYSQL* client, const string& json_query) { + string full_query = "GENAI: " + json_query; + int rc = mysql_query(client, full_query.c_str()); + + // Query should either fail or return an error result set + if (rc != 0) { + // Query failed at MySQL level - this is expected for errors + return true; + } + + MYSQL_RES* res = mysql_store_result(client); + if (!res) { + // No result set - error condition + return true; + } + + // Check if result set contains an error message + int num_fields = mysql_num_fields(res); + bool has_error = false; + + if (num_fields >= 1) { + MYSQL_ROW row = mysql_fetch_row(res); + if (row && row[0]) { + // Check if the first column contains "error" + if (strstr(row[0], "\"error\"") || strstr(row[0], "error")) { + has_error = true; + } + } + } + + mysql_free_result(res); + return has_error; +} + +/** + * @brief Test single async request + * + * @param client MySQL connection to client interface + * @return Number of tests performed + */ +int test_single_async_request(MYSQL* client) { + int test_count = 0; + + diag("Testing single async GenAI request"); + + // Test 1: Single embedding request - should return immediately (async) + auto start = std::chrono::steady_clock::now(); + string json = R"({"type": "embed", "documents": ["Test document for async"]})"; + bool success = execute_genai_query(client, json, 1); + auto end = std::chrono::steady_clock::now(); + auto elapsed = std::chrono::duration_cast(end - start).count(); + + ok(success, "Single async embedding request succeeds"); + test_count++; + + // Note: For async, the query returns quickly but the actual processing + // happens in the worker thread. 
We can't easily test the non-blocking + // behavior from a single connection, but we can verify it works. + + diag("Async request completed in %ld ms", elapsed); + + return test_count; +} + +/** + * @brief Test multiple sequential async requests + * + * @param client MySQL connection to client interface + * @return Number of tests performed + */ +int test_sequential_async_requests(MYSQL* client) { + int test_count = 0; + + diag("Testing multiple sequential async requests"); + + // Test 1: Send 5 sequential embedding requests + int success_count = 0; + for (int i = 0; i < 5; i++) { + string json = R"({"type": "embed", "documents": ["Sequential test document )" + + std::to_string(i) + R"("]})"; + if (execute_genai_query(client, json, 1)) { + success_count++; + } + } + + ok(success_count == 5, "All 5 sequential async requests succeeded (got %d/5)", success_count); + test_count++; + + // Test 2: Send 5 sequential rerank requests + success_count = 0; + for (int i = 0; i < 5; i++) { + string json = R"({ + "type": "rerank", + "query": "Sequential test query )" + std::to_string(i) + R"(", + "documents": ["doc1", "doc2", "doc3"] + })"; + if (execute_genai_query(client, json, 3)) { + success_count++; + } + } + + ok(success_count == 5, "All 5 sequential rerank requests succeeded (got %d/5)", success_count); + test_count++; + + return test_count; +} + +/** + * @brief Test batch async requests + * + * @param client MySQL connection to client interface + * @return Number of tests performed + */ +int test_batch_async_requests(MYSQL* client) { + int test_count = 0; + + diag("Testing batch async requests"); + + // Test 1: Batch embedding with 10 documents + string json = R"({"type": "embed", "documents": [)"; + for (int i = 0; i < 10; i++) { + if (i > 0) json += ","; + json += R"("Batch document )" + std::to_string(i) + R"(")"; + } + json += "]}"; + + ok(execute_genai_query(client, json, 10), + "Batch embedding with 10 documents returns 10 rows"); + test_count++; + + // Test 2: 
Batch rerank with 10 documents + json = R"({ + "type": "rerank", + "query": "Batch test query", + "documents": [)"; + for (int i = 0; i < 10; i++) { + if (i > 0) json += ","; + json += R"("Document )" + std::to_string(i) + R"(")"; + } + json += "]}"; + + ok(execute_genai_query(client, json, 10), + "Batch rerank with 10 documents returns 10 rows"); + test_count++; + + return test_count; +} + +/** + * @brief Test mixed embedding and rerank requests + * + * @param client MySQL connection to client interface + * @return Number of tests performed + */ +int test_mixed_requests(MYSQL* client) { + int test_count = 0; + + diag("Testing mixed embedding and rerank requests"); + + // Test 1: Interleave embedding and rerank requests + int success_count = 0; + for (int i = 0; i < 3; i++) { + // Embedding + string json = R"({"type": "embed", "documents": ["Mixed test )" + + std::to_string(i) + R"("]})"; + if (execute_genai_query(client, json, 1)) { + success_count++; + } + + // Rerank + json = R"({ + "type": "rerank", + "query": "Mixed query )" + std::to_string(i) + R"(", + "documents": ["doc1", "doc2"] + })"; + if (execute_genai_query(client, json, 2)) { + success_count++; + } + } + + ok(success_count == 6, "Mixed embedding and rerank requests succeeded (got %d/6)", success_count); + test_count++; + + return test_count; +} + +/** + * @brief Test request/response matching + * + * @param client MySQL connection to client interface + * @return Number of tests performed + */ +int test_request_response_matching(MYSQL* client) { + int test_count = 0; + + diag("Testing request/response matching"); + + // Test 1: Send requests with different document counts and verify + std::vector doc_counts = {1, 3, 5, 7}; + int success_count = 0; + + for (int doc_count : doc_counts) { + string json = R"({"type": "embed", "documents": [)"; + for (int i = 0; i < doc_count; i++) { + if (i > 0) json += ","; + json += R"("doc )" + std::to_string(i) + R"(")"; + } + json += "]}"; + + if 
(execute_genai_query(client, json, doc_count)) { + success_count++; + } + } + + ok(success_count == (int)doc_counts.size(), + "Request/response matching correct for varying document counts (got %d/%zu)", + success_count, doc_counts.size()); + test_count++; + + return test_count; +} + +/** + * @brief Test error handling in async mode + * + * @param client MySQL connection to client interface + * @return Number of tests performed + */ +int test_async_error_handling(MYSQL* client) { + int test_count = 0; + + diag("Testing async error handling"); + + // Test 1: Invalid JSON - should return error immediately + string json = R"({"type": "embed", "documents": [)"; + ok(execute_genai_query_expect_error(client, json), + "Invalid JSON returns error in async mode"); + test_count++; + + // Test 2: Missing documents array + json = R"({"type": "embed"})"; + ok(execute_genai_query_expect_error(client, json), + "Missing documents array returns error in async mode"); + test_count++; + + // Test 3: Empty documents array + json = R"({"type": "embed", "documents": []})"; + ok(execute_genai_query_expect_error(client, json), + "Empty documents array returns error in async mode"); + test_count++; + + // Test 4: Rerank without query + json = R"({"type": "rerank", "documents": ["doc1"]})"; + ok(execute_genai_query_expect_error(client, json), + "Rerank without query returns error in async mode"); + test_count++; + + // Test 5: Unknown operation type + json = R"({"type": "unknown", "documents": ["doc1"]})"; + ok(execute_genai_query_expect_error(client, json), + "Unknown operation type returns error in async mode"); + test_count++; + + // Test 6: Verify connection still works after errors + json = R"({"type": "embed", "documents": ["Recovery test"]})"; + ok(execute_genai_query(client, json, 1), + "Connection still works after error requests"); + test_count++; + + return test_count; +} + +/** + * @brief Test special characters in queries + * + * @param client MySQL connection to client 
interface + * @return Number of tests performed + */ +int test_special_characters(MYSQL* client) { + int test_count = 0; + + diag("Testing special characters in async queries"); + + // Test 1: Quotes and apostrophes + string json = R"({"type": "embed", "documents": ["Test with \"quotes\" and 'apostrophes'"]})"; + ok(execute_genai_query(client, json, 1), + "Embedding with quotes and apostrophes succeeds"); + test_count++; + + // Test 2: Backslashes + json = R"({"type": "embed", "documents": ["Path: C:\\Users\\test"]})"; + ok(execute_genai_query(client, json, 1), + "Embedding with backslashes succeeds"); + test_count++; + + // Test 3: Newlines and tabs + json = R"({"type": "embed", "documents": ["Line1\nLine2\tTabbed"]})"; + ok(execute_genai_query(client, json, 1), + "Embedding with newlines and tabs succeeds"); + test_count++; + + // Test 4: Unicode characters + json = R"({"type": "embed", "documents": ["Unicode: 你好世界 🌍 🚀"]})"; + ok(execute_genai_query(client, json, 1), + "Embedding with unicode characters succeeds"); + test_count++; + + // Test 5: Rerank with special characters in query + json = R"({ + "type": "rerank", + "query": "What is \"SQL\" injection?", + "documents": ["SQL injection is dangerous", "ProxySQL is a proxy"] + })"; + ok(execute_genai_query(client, json, 2), + "Rerank with quoted query succeeds"); + test_count++; + + return test_count; +} + +/** + * @brief Test large documents + * + * @param client MySQL connection to client interface + * @return Number of tests performed + */ +int test_large_documents(MYSQL* client) { + int test_count = 0; + + diag("Testing large documents in async mode"); + + // Test 1: Single large document (several KB) + string large_doc(5000, 'A'); // 5KB of 'A's + string json = R"({"type": "embed", "documents": [")" + large_doc + R"("]})"; + ok(execute_genai_query(client, json, 1), + "Single large document (5KB) embedding succeeds"); + test_count++; + + // Test 2: Multiple large documents + json = R"({"type": "embed", 
"documents": [)"; + for (int i = 0; i < 5; i++) { + if (i > 0) json += ","; + json += R"(")" + string(1000, 'A' + i) + R"(")"; // 1KB each + } + json += "]}"; + + ok(execute_genai_query(client, json, 5), + "Multiple large documents (5x1KB) embedding succeeds"); + test_count++; + + return test_count; +} + +/** + * @brief Test top_n parameter in async mode + * + * @param client MySQL connection to client interface + * @return Number of tests performed + */ +int test_top_n_parameter(MYSQL* client) { + int test_count = 0; + + diag("Testing top_n parameter in async mode"); + + // Test 1: top_n = 3 with 10 documents + string json = R"({ + "type": "rerank", + "query": "Test query", + "documents": ["doc1", "doc2", "doc3", "doc4", "doc5", "doc6", "doc7", "doc8", "doc9", "doc10"], + "top_n": 3 + })"; + ok(execute_genai_query(client, json, 3), + "Rerank with top_n=3 returns exactly 3 rows"); + test_count++; + + // Test 2: top_n = 1 + json = R"({ + "type": "rerank", + "query": "Test query", + "documents": ["doc1", "doc2", "doc3"], + "top_n": 1 + })"; + ok(execute_genai_query(client, json, 1), + "Rerank with top_n=1 returns exactly 1 row"); + test_count++; + + // Test 3: top_n = 0 (return all) + json = R"({ + "type": "rerank", + "query": "Test query", + "documents": ["doc1", "doc2", "doc3", "doc4", "doc5"], + "top_n": 0 + })"; + ok(execute_genai_query(client, json, 5), + "Rerank with top_n=0 returns all 5 rows"); + test_count++; + + return test_count; +} + +/** + * @brief Test columns parameter in async mode + * + * @param client MySQL connection to client interface + * @return Number of tests performed + */ +int test_columns_parameter(MYSQL* client) { + int test_count = 0; + + diag("Testing columns parameter in async mode"); + + // Test 1: columns = 2 (index and score only) + string json = R"({ + "type": "rerank", + "query": "Test query", + "documents": ["doc1", "doc2", "doc3"], + "columns": 2 + })"; + ok(execute_genai_query(client, json, 3), + "Rerank with columns=2 returns 3 
rows"); + test_count++; + + // Test 2: columns = 3 (index, score, document) - default + json = R"({ + "type": "rerank", + "query": "Test query", + "documents": ["doc1", "doc2"], + "columns": 3 + })"; + ok(execute_genai_query(client, json, 2), + "Rerank with columns=3 returns 2 rows"); + test_count++; + + // Test 3: Invalid columns value + json = R"({ + "type": "rerank", + "query": "Test query", + "documents": ["doc1"], + "columns": 5 + })"; + ok(execute_genai_query_expect_error(client, json), + "Invalid columns=5 returns error"); + test_count++; + + return test_count; +} + +/** + * @brief Test concurrent requests from multiple connections + * + * @param host MySQL host + * @param username MySQL username + * @param password MySQL password + * @param port MySQL port + * @return Number of tests performed + */ +int test_concurrent_connections(const char* host, const char* username, + const char* password, int port) { + int test_count = 0; + + diag("Testing concurrent requests from multiple connections"); + + // Create 3 separate connections + const int num_conns = 3; + MYSQL* conns[num_conns]; + + for (int i = 0; i < num_conns; i++) { + conns[i] = mysql_init(NULL); + if (!conns[i]) { + diag("Failed to initialize connection %d", i); + continue; + } + + if (!mysql_real_connect(conns[i], host, username, password, + NULL, port, NULL, 0)) { + diag("Failed to connect connection %d: %s", i, mysql_error(conns[i])); + mysql_close(conns[i]); + conns[i] = NULL; + continue; + } + } + + // Count successful connections + int valid_conns = 0; + for (int i = 0; i < num_conns; i++) { + if (conns[i]) valid_conns++; + } + + ok(valid_conns == num_conns, + "Created %d concurrent connections (expected %d)", valid_conns, num_conns); + test_count++; + + if (valid_conns < num_conns) { + // Skip remaining tests if we couldn't create all connections + for (int i = 0; i < num_conns; i++) { + if (conns[i]) mysql_close(conns[i]); + } + return test_count; + } + + // Send requests from all 
connections concurrently + std::atomic success_count{0}; + std::vector threads; + + for (int i = 0; i < num_conns; i++) { + threads.push_back(std::thread([&, i]() { + string json = R"({"type": "embed", "documents": ["Concurrent test )" + + std::to_string(i) + R"("]})"; + if (execute_genai_query(conns[i], json, 1)) { + success_count++; + } + })); + } + + // Wait for all threads to complete + for (auto& t : threads) { + t.join(); + } + + ok(success_count == num_conns, + "All %d concurrent requests succeeded", num_conns); + test_count++; + + // Cleanup + for (int i = 0; i < num_conns; i++) { + mysql_close(conns[i]); + } + + return test_count; +} + +// ============================================================================ +// Main Test Function +// ============================================================================ + +int main() { + CommandLine cl; + + if (cl.getEnv()) { + diag("Failed to get the required environmental variables."); + return EXIT_FAILURE; + } + + // Initialize connections + MYSQL* admin = mysql_init(NULL); + if (!admin) { + fprintf(stderr, "File %s, line %d, Error: mysql_init failed\n", __FILE__, __LINE__); + return EXIT_FAILURE; + } + + if (!mysql_real_connect(admin, cl.admin_host, cl.admin_username, cl.admin_password, + NULL, cl.admin_port, NULL, 0)) { + fprintf(stderr, "File %s, line %d, Error: %s\n", __FILE__, __LINE__, mysql_error(admin)); + return EXIT_FAILURE; + } + + diag("Connected to ProxySQL admin interface at %s:%d", cl.admin_host, cl.admin_port); + + MYSQL* client = mysql_init(NULL); + if (!client) { + fprintf(stderr, "File %s, line %d, Error: mysql_init failed\n", __FILE__, __LINE__); + mysql_close(admin); + return EXIT_FAILURE; + } + + if (!mysql_real_connect(client, cl.host, cl.username, cl.password, + NULL, cl.port, NULL, 0)) { + fprintf(stderr, "File %s, line %d, Error: %s\n", __FILE__, __LINE__, mysql_error(client)); + mysql_close(admin); + mysql_close(client); + return EXIT_FAILURE; + } + + diag("Connected to ProxySQL 
client interface at %s:%d", cl.host, cl.port); + + // Check if GenAI module is initialized + if (!check_genai_initialized(admin)) { + diag("GenAI module is not initialized. Skipping all tests."); + plan(1); + skip(1, "GenAI module not initialized"); + mysql_close(admin); + mysql_close(client); + return exit_status(); + } + + diag("GenAI module is initialized. Proceeding with async tests."); + + // Calculate total tests + // Single async request: 1 test + // Sequential requests: 2 tests + // Batch requests: 2 tests + // Mixed requests: 1 test + // Request/response matching: 1 test + // Error handling: 6 tests + // Special characters: 5 tests + // Large documents: 2 tests + // top_n parameter: 3 tests + // columns parameter: 3 tests + // Concurrent connections: 2 tests + int total_tests = 1 + 2 + 2 + 1 + 1 + 6 + 5 + 2 + 3 + 3 + 2; + plan(total_tests); + + int test_count = 0; + + // ============================================================================ + // Part 1: Test single async request + // ============================================================================ + diag("=== Part 1: Testing single async request ==="); + test_count += test_single_async_request(client); + + // ============================================================================ + // Part 2: Test sequential async requests + // ============================================================================ + diag("=== Part 2: Testing sequential async requests ==="); + test_count += test_sequential_async_requests(client); + + // ============================================================================ + // Part 3: Test batch async requests + // ============================================================================ + diag("=== Part 3: Testing batch async requests ==="); + test_count += test_batch_async_requests(client); + + // ============================================================================ + // Part 4: Test mixed embedding and rerank requests + // 
============================================================================ + diag("=== Part 4: Testing mixed requests ==="); + test_count += test_mixed_requests(client); + + // ============================================================================ + // Part 5: Test request/response matching + // ============================================================================ + diag("=== Part 5: Testing request/response matching ==="); + test_count += test_request_response_matching(client); + + // ============================================================================ + // Part 6: Test error handling in async mode + // ============================================================================ + diag("=== Part 6: Testing async error handling ==="); + test_count += test_async_error_handling(client); + + // ============================================================================ + // Part 7: Test special characters + // ============================================================================ + diag("=== Part 7: Testing special characters ==="); + test_count += test_special_characters(client); + + // ============================================================================ + // Part 8: Test large documents + // ============================================================================ + diag("=== Part 8: Testing large documents ==="); + test_count += test_large_documents(client); + + // ============================================================================ + // Part 9: Test top_n parameter + // ============================================================================ + diag("=== Part 9: Testing top_n parameter ==="); + test_count += test_top_n_parameter(client); + + // ============================================================================ + // Part 10: Test columns parameter + // ============================================================================ + diag("=== Part 10: Testing columns parameter ==="); + test_count += 
test_columns_parameter(client); + + // ============================================================================ + // Part 11: Test concurrent connections + // ============================================================================ + diag("=== Part 11: Testing concurrent connections ==="); + test_count += test_concurrent_connections(cl.host, cl.username, cl.password, cl.port); + + // ============================================================================ + // Cleanup + // ============================================================================ + mysql_close(admin); + mysql_close(client); + + diag("=== All GenAI async tests completed ==="); + + return exit_status(); +} diff --git a/test/tap/tests/genai_embedding_rerank-t.cpp b/test/tap/tests/genai_embedding_rerank-t.cpp new file mode 100644 index 0000000000..db9bfa481b --- /dev/null +++ b/test/tap/tests/genai_embedding_rerank-t.cpp @@ -0,0 +1,724 @@ +/** + * @file genai_embedding_rerank-t.cpp + * @brief TAP test for the GenAI embedding and reranking functionality + * + * This test verifies the GenAI (Generative AI) module's core functionality: + * - Embedding generation (single and batch) + * - Reranking documents by relevance + * - JSON query processing + * - Error handling for malformed queries + * - Timeout and error handling + * + * Note: These tests require a running GenAI service (llama-server or compatible) + * at the configured endpoints. The tests use the GENAI: query syntax which + * allows autonomous JSON query processing. 
+ * + * @date 2025-01-09 + */ + +#include +#include +#include +#include +#include +#include +#include + +#include "mysql.h" +#include "mysqld_error.h" + +#include "tap.h" +#include "command_line.h" +#include "utils.h" + +using std::string; + +// ============================================================================ +// Helper Functions +// ============================================================================ + +/** + * @brief Check if the GenAI module is initialized + * + * @param admin MySQL connection to admin interface + * @return true if GenAI module is initialized, false otherwise + */ +bool check_genai_initialized(MYSQL* admin) { + MYSQL_QUERY(admin, "SELECT @@genai-threads"); + MYSQL_RES* res = mysql_store_result(admin); + if (!res) { + return false; + } + + int num_rows = mysql_num_rows(res); + mysql_free_result(res); + + // If we get a result, the GenAI module is loaded and initialized + return num_rows == 1; +} + +/** + * @brief Execute a GENAI: query and check if it returns a result set + * + * @param client MySQL connection to client interface + * @param json_query The JSON query to send (without GENAI: prefix) + * @param expected_rows Expected number of rows (or -1 for any) + * @return true if query succeeded, false otherwise + */ +bool execute_genai_query(MYSQL* client, const string& json_query, int expected_rows = -1) { + string full_query = "GENAI: " + json_query; + int rc = mysql_query(client, full_query.c_str()); + + if (rc != 0) { + diag("Query failed: %s", mysql_error(client)); + return false; + } + + MYSQL_RES* res = mysql_store_result(client); + if (!res) { + diag("No result set returned"); + return false; + } + + int num_rows = mysql_num_rows(res); + mysql_free_result(res); + + if (expected_rows >= 0 && num_rows != expected_rows) { + diag("Expected %d rows, got %d", expected_rows, num_rows); + return false; + } + + return true; +} + +/** + * @brief Execute a GENAI: query and expect an error + * + * @param client MySQL connection 
to client interface + * @param json_query The JSON query to send (without GENAI: prefix) + * @return true if query returned an error as expected, false otherwise + */ +bool execute_genai_query_expect_error(MYSQL* client, const string& json_query) { + string full_query = "GENAI: " + json_query; + int rc = mysql_query(client, full_query.c_str()); + + // Query should either fail or return an error result set + if (rc != 0) { + // Query failed at MySQL level - this is expected for errors + return true; + } + + MYSQL_RES* res = mysql_store_result(client); + if (!res) { + // No result set - error condition + return true; + } + + // Check if result set contains an error message + int num_fields = mysql_num_fields(res); + bool has_error = false; + + if (num_fields >= 1) { + MYSQL_ROW row = mysql_fetch_row(res); + if (row && row[0]) { + // Check if the first column contains "error" + if (strstr(row[0], "\"error\"") || strstr(row[0], "error")) { + has_error = true; + } + } + } + + mysql_free_result(res); + return has_error; +} + +/** + * @brief Test embedding a single document + * + * @param client MySQL connection to client interface + * @return Number of tests performed + */ +int test_single_embedding(MYSQL* client) { + int test_count = 0; + + diag("Testing single document embedding"); + + // Test 1: Valid single document embedding + string json = R"({"type": "embed", "documents": ["Hello, world!"]})"; + ok(execute_genai_query(client, json, 1), + "Single document embedding returns 1 row"); + test_count++; + + // Test 2: Embedding with special characters + json = R"({"type": "embed", "documents": ["Test with quotes \" and 'apostrophes'"]})"; + ok(execute_genai_query(client, json, 1), + "Embedding with special characters returns 1 row"); + test_count++; + + // Test 3: Embedding with unicode + json = R"({"type": "embed", "documents": ["Unicode test: 你好世界 🌍"]})"; + ok(execute_genai_query(client, json, 1), + "Embedding with unicode returns 1 row"); + test_count++; + + return 
test_count; +} + +/** + * @brief Test embedding multiple documents (batch) + * + * @param client MySQL connection to client interface + * @return Number of tests performed + */ +int test_batch_embedding(MYSQL* client) { + int test_count = 0; + + diag("Testing batch document embedding"); + + // Test 1: Batch embedding with 3 documents + string json = R"({"type": "embed", "documents": ["First document", "Second document", "Third document"]})"; + ok(execute_genai_query(client, json, 3), + "Batch embedding with 3 documents returns 3 rows"); + test_count++; + + // Test 2: Batch embedding with 5 documents + json = R"({"type": "embed", "documents": ["doc1", "doc2", "doc3", "doc4", "doc5"]})"; + ok(execute_genai_query(client, json, 5), + "Batch embedding with 5 documents returns 5 rows"); + test_count++; + + // Test 3: Batch embedding with empty document (edge case) + json = R"({"type": "embed", "documents": [""]})"; + ok(execute_genai_query(client, json, 1), + "Batch embedding with empty document returns 1 row"); + test_count++; + + return test_count; +} + +/** + * @brief Test embedding error handling + * + * @param client MySQL connection to client interface + * @return Number of tests performed + */ +int test_embedding_errors(MYSQL* client) { + int test_count = 0; + + diag("Testing embedding error handling"); + + // Test 1: Missing documents array + string json = R"({"type": "embed"})"; + ok(execute_genai_query_expect_error(client, json), + "Embedding without documents array returns error"); + test_count++; + + // Test 2: Empty documents array + json = R"({"type": "embed", "documents": []})"; + ok(execute_genai_query_expect_error(client, json), + "Embedding with empty documents array returns error"); + test_count++; + + // Test 3: Invalid JSON + json = R"({"type": "embed", "documents": [)"; + ok(execute_genai_query_expect_error(client, json), + "Embedding with invalid JSON returns error"); + test_count++; + + // Test 4: Documents is not an array + json = R"({"type": 
"embed", "documents": "not an array"})"; + ok(execute_genai_query_expect_error(client, json), + "Embedding with non-array documents returns error"); + test_count++; + + return test_count; +} + +/** + * @brief Test basic reranking functionality + * + * @param client MySQL connection to client interface + * @return Number of tests performed + */ +int test_basic_rerank(MYSQL* client) { + int test_count = 0; + + diag("Testing basic reranking"); + + // Test 1: Simple rerank with 3 documents + string json = R"({ + "type": "rerank", + "query": "What is machine learning?", + "documents": [ + "Machine learning is a subset of artificial intelligence.", + "The capital of France is Paris.", + "Deep learning uses neural networks with multiple layers." + ] + })"; + ok(execute_genai_query(client, json, 3), + "Rerank with 3 documents returns 3 rows"); + test_count++; + + // Test 2: Rerank with query containing quotes + json = R"({ + "type": "rerank", + "query": "What is \"SQL\" injection?", + "documents": [ + "SQL injection is a code vulnerability.", + "ProxySQL is a database proxy." + ] + })"; + ok(execute_genai_query(client, json, 2), + "Rerank with quoted query returns 2 rows"); + test_count++; + + return test_count; +} + +/** + * @brief Test rerank with top_n parameter + * + * @param client MySQL connection to client interface + * @return Number of tests performed + */ +int test_rerank_top_n(MYSQL* client) { + int test_count = 0; + + diag("Testing rerank with top_n parameter"); + + // Test 1: top_n = 2 with 5 documents + string json = R"({ + "type": "rerank", + "query": "database systems", + "documents": [ + "ProxySQL is a proxy for MySQL.", + "PostgreSQL is an object-relational database.", + "Redis is an in-memory data store.", + "Elasticsearch is a search engine.", + "MongoDB is a NoSQL database." 
+ ], + "top_n": 2 + })"; + ok(execute_genai_query(client, json, 2), + "Rerank with top_n=2 returns exactly 2 rows"); + test_count++; + + // Test 2: top_n = 1 with 3 documents + json = R"({ + "type": "rerank", + "query": "best fruit", + "documents": ["Apple", "Banana", "Orange"], + "top_n": 1 + })"; + ok(execute_genai_query(client, json, 1), + "Rerank with top_n=1 returns exactly 1 row"); + test_count++; + + // Test 3: top_n = 0 should return all results + json = R"({ + "type": "rerank", + "query": "test query", + "documents": ["doc1", "doc2", "doc3"], + "top_n": 0 + })"; + ok(execute_genai_query(client, json, 3), + "Rerank with top_n=0 returns all 3 rows"); + test_count++; + + return test_count; +} + +/** + * @brief Test rerank with columns parameter + * + * @param client MySQL connection to client interface + * @return Number of tests performed + */ +int test_rerank_columns(MYSQL* client) { + int test_count = 0; + + diag("Testing rerank with columns parameter"); + + // Test 1: columns = 2 (index and score only) + string json = R"({ + "type": "rerank", + "query": "test query", + "documents": ["doc1", "doc2"], + "columns": 2 + })"; + ok(execute_genai_query(client, json, 2), + "Rerank with columns=2 returns 2 rows"); + test_count++; + + // Test 2: columns = 3 (index, score, document) - default + json = R"({ + "type": "rerank", + "query": "test query", + "documents": ["doc1", "doc2"], + "columns": 3 + })"; + ok(execute_genai_query(client, json, 2), + "Rerank with columns=3 returns 2 rows"); + test_count++; + + // Test 3: Invalid columns value should return error + json = R"({ + "type": "rerank", + "query": "test query", + "documents": ["doc1"], + "columns": 5 + })"; + ok(execute_genai_query_expect_error(client, json), + "Rerank with invalid columns=5 returns error"); + test_count++; + + return test_count; +} + +/** + * @brief Test rerank error handling + * + * @param client MySQL connection to client interface + * @return Number of tests performed + */ +int 
test_rerank_errors(MYSQL* client) { + int test_count = 0; + + diag("Testing rerank error handling"); + + // Test 1: Missing query + string json = R"({"type": "rerank", "documents": ["doc1"]})"; + ok(execute_genai_query_expect_error(client, json), + "Rerank without query returns error"); + test_count++; + + // Test 2: Empty query + json = R"({"type": "rerank", "query": "", "documents": ["doc1"]})"; + ok(execute_genai_query_expect_error(client, json), + "Rerank with empty query returns error"); + test_count++; + + // Test 3: Missing documents array + json = R"({"type": "rerank", "query": "test"})"; + ok(execute_genai_query_expect_error(client, json), + "Rerank without documents returns error"); + test_count++; + + // Test 4: Empty documents array + json = R"({"type": "rerank", "query": "test", "documents": []})"; + ok(execute_genai_query_expect_error(client, json), + "Rerank with empty documents returns error"); + test_count++; + + return test_count; +} + +/** + * @brief Test general JSON query error handling + * + * @param client MySQL connection to client interface + * @return Number of tests performed + */ +int test_json_query_errors(MYSQL* client) { + int test_count = 0; + + diag("Testing JSON query error handling"); + + // Test 1: Missing type field + string json = R"({"documents": ["doc1"]})"; + ok(execute_genai_query_expect_error(client, json), + "Query without type field returns error"); + test_count++; + + // Test 2: Unknown operation type + json = R"({"type": "unknown_op", "documents": ["doc1"]})"; + ok(execute_genai_query_expect_error(client, json), + "Query with unknown type returns error"); + test_count++; + + // Test 3: Completely invalid JSON + json = R"({invalid json})"; + ok(execute_genai_query_expect_error(client, json), + "Invalid JSON returns error"); + test_count++; + + // Test 4: Empty JSON object + json = R"({})"; + ok(execute_genai_query_expect_error(client, json), + "Empty JSON object returns error"); + test_count++; + + // Test 5: Query is 
not an object + json = R"(["array", "not", "object"])"; + ok(execute_genai_query_expect_error(client, json), + "JSON array (not object) returns error"); + test_count++; + + return test_count; +} + +/** + * @brief Test GENAI: query syntax variations + * + * @param client MySQL connection to client interface + * @return Number of tests performed + */ +int test_genai_syntax(MYSQL* client) { + int test_count = 0; + + diag("Testing GENAI: query syntax variations"); + + // Test 1: GENAI: with leading space should FAIL (not recognized) + int rc = mysql_query(client, " GENAI: {\"type\": \"embed\", \"documents\": [\"test\"]}"); + ok(rc != 0, "GENAI: with leading space is rejected"); + test_count++; + + // Test 2: Empty query after GENAI: + rc = mysql_query(client, "GENAI: "); + MYSQL_RES* res = mysql_store_result(client); + // Should either fail or have no rows + bool empty_query_ok = (rc != 0) || (res && mysql_num_rows(res) == 0); + if (res) mysql_free_result(res); + ok(empty_query_ok, "Empty GENAI: query handled correctly"); + test_count++; + + // Test 3: Case sensitivity - lowercase should also work + rc = mysql_query(client, "genai: {\"type\": \"embed\", \"documents\": [\"test\"]}"); + ok(rc == 0, "Lowercase 'genai:' works"); + test_count++; + + return test_count; +} + +/** + * @brief Test GenAI configuration variables + * + * @param admin MySQL connection to admin interface + * @return Number of tests performed + */ +int test_genai_configuration(MYSQL* admin) { + int test_count = 0; + + diag("Testing GenAI configuration variables"); + + // Test 1: Check genai-threads variable + MYSQL_QUERY(admin, "SELECT @@genai-threads"); + MYSQL_RES* res = mysql_store_result(admin); + ok(res != NULL, "genai-threads variable is accessible"); + if (res) { + int num_rows = mysql_num_rows(res); + ok(num_rows == 1, "genai-threads returns 1 row"); + test_count++; + mysql_free_result(res); + } else { + skip(1, "Cannot check row count"); + test_count++; + } + test_count++; + + // Test 2: 
Check genai-embedding_uri variable + MYSQL_QUERY(admin, "SELECT @@genai-embedding_uri"); + res = mysql_store_result(admin); + ok(res != NULL, "genai-embedding_uri variable is accessible"); + if (res) { + MYSQL_ROW row = mysql_fetch_row(res); + ok(row && row[0] && strlen(row[0]) > 0, "genai-embedding_uri has a value"); + test_count++; + mysql_free_result(res); + } else { + skip(1, "Cannot check value"); + test_count++; + } + test_count++; + + // Test 3: Check genai-rerank_uri variable + MYSQL_QUERY(admin, "SELECT @@genai-rerank_uri"); + res = mysql_store_result(admin); + ok(res != NULL, "genai-rerank_uri variable is accessible"); + if (res) { + MYSQL_ROW row = mysql_fetch_row(res); + ok(row && row[0] && strlen(row[0]) > 0, "genai-rerank_uri has a value"); + test_count++; + mysql_free_result(res); + } else { + skip(1, "Cannot check value"); + test_count++; + } + test_count++; + + // Test 4: Check genai-embedding_timeout_ms variable + MYSQL_QUERY(admin, "SELECT @@genai-embedding_timeout_ms"); + res = mysql_store_result(admin); + ok(res != NULL, "genai-embedding_timeout_ms variable is accessible"); + if (res) { + MYSQL_ROW row = mysql_fetch_row(res); + ok(row && row[0] && atoi(row[0]) > 0, "genai-embedding_timeout_ms has a positive value"); + test_count++; + mysql_free_result(res); + } else { + skip(1, "Cannot check value"); + test_count++; + } + test_count++; + + // Test 5: Check genai-rerank_timeout_ms variable + MYSQL_QUERY(admin, "SELECT @@genai-rerank_timeout_ms"); + res = mysql_store_result(admin); + ok(res != NULL, "genai-rerank_timeout_ms variable is accessible"); + if (res) { + MYSQL_ROW row = mysql_fetch_row(res); + ok(row && row[0] && atoi(row[0]) > 0, "genai-rerank_timeout_ms has a positive value"); + test_count++; + mysql_free_result(res); + } else { + skip(1, "Cannot check value"); + test_count++; + } + test_count++; + + return test_count; +} + +// ============================================================================ +// Main Test Function +// 
============================================================================ + +int main() { + CommandLine cl; + + if (cl.getEnv()) { + diag("Failed to get the required environmental variables."); + return EXIT_FAILURE; + } + + // Initialize connections + MYSQL* admin = mysql_init(NULL); + if (!admin) { + fprintf(stderr, "File %s, line %d, Error: mysql_init failed\n", __FILE__, __LINE__); + return EXIT_FAILURE; + } + + if (!mysql_real_connect(admin, cl.admin_host, cl.admin_username, cl.admin_password, + NULL, cl.admin_port, NULL, 0)) { + fprintf(stderr, "File %s, line %d, Error: %s\n", __FILE__, __LINE__, mysql_error(admin)); + mysql_close(admin); + return EXIT_FAILURE; + } + + diag("Connected to ProxySQL admin interface at %s:%d", cl.admin_host, cl.admin_port); + + MYSQL* client = mysql_init(NULL); + if (!client) { + fprintf(stderr, "File %s, line %d, Error: mysql_init failed\n", __FILE__, __LINE__); + mysql_close(admin); + return EXIT_FAILURE; + } + + if (!mysql_real_connect(client, cl.host, cl.username, cl.password, + NULL, cl.port, NULL, 0)) { + fprintf(stderr, "File %s, line %d, Error: %s\n", __FILE__, __LINE__, mysql_error(client)); + mysql_close(admin); + mysql_close(client); + return EXIT_FAILURE; + } + + diag("Connected to ProxySQL client interface at %s:%d", cl.host, cl.port); + + // Check if GenAI module is initialized + if (!check_genai_initialized(admin)) { + diag("GenAI module is not initialized. Skipping all tests."); + plan(1); + skip(1, "GenAI module not initialized"); + mysql_close(admin); + mysql_close(client); + return exit_status(); + } + + diag("GenAI module is initialized. 
Proceeding with tests."); + + // Calculate total tests + // Configuration tests: 10 tests (5 vars × 2 tests each) + // Single embedding: 3 tests + // Batch embedding: 3 tests + // Embedding errors: 4 tests + // Basic rerank: 2 tests + // Rerank top_n: 3 tests + // Rerank columns: 3 tests + // Rerank errors: 4 tests + // JSON query errors: 5 tests + // GENAI syntax: 3 tests + int total_tests = 10 + 3 + 3 + 4 + 2 + 3 + 3 + 4 + 5 + 3; + plan(total_tests); + + int test_count = 0; + + // ============================================================================ + // Part 1: Test GenAI configuration + // ============================================================================ + diag("=== Part 1: Testing GenAI configuration ==="); + test_count += test_genai_configuration(admin); + + // ============================================================================ + // Part 2: Test single document embedding + // ============================================================================ + diag("=== Part 2: Testing single document embedding ==="); + test_count += test_single_embedding(client); + + // ============================================================================ + // Part 3: Test batch embedding + // ============================================================================ + diag("=== Part 3: Testing batch embedding ==="); + test_count += test_batch_embedding(client); + + // ============================================================================ + // Part 4: Test embedding error handling + // ============================================================================ + diag("=== Part 4: Testing embedding error handling ==="); + test_count += test_embedding_errors(client); + + // ============================================================================ + // Part 5: Test basic reranking + // ============================================================================ + diag("=== Part 5: Testing basic reranking ==="); + test_count += 
test_basic_rerank(client); + + // ============================================================================ + // Part 6: Test rerank with top_n parameter + // ============================================================================ + diag("=== Part 6: Testing rerank with top_n parameter ==="); + test_count += test_rerank_top_n(client); + + // ============================================================================ + // Part 7: Test rerank with columns parameter + // ============================================================================ + diag("=== Part 7: Testing rerank with columns parameter ==="); + test_count += test_rerank_columns(client); + + // ============================================================================ + // Part 8: Test rerank error handling + // ============================================================================ + diag("=== Part 8: Testing rerank error handling ==="); + test_count += test_rerank_errors(client); + + // ============================================================================ + // Part 9: Test JSON query error handling + // ============================================================================ + diag("=== Part 9: Testing JSON query error handling ==="); + test_count += test_json_query_errors(client); + + // ============================================================================ + // Part 10: Test GENAI: query syntax variations + // ============================================================================ + diag("=== Part 10: Testing GENAI: query syntax variations ==="); + test_count += test_genai_syntax(client); + + // ============================================================================ + // Cleanup + // ============================================================================ + mysql_close(admin); + mysql_close(client); + + diag("=== All GenAI embedding and reranking tests completed ==="); + + return exit_status(); +} diff --git a/test/tap/tests/genai_module-t.cpp 
b/test/tap/tests/genai_module-t.cpp new file mode 100644 index 0000000000..425911f9b4 --- /dev/null +++ b/test/tap/tests/genai_module-t.cpp @@ -0,0 +1,376 @@ +/** + * @file genai_module-t.cpp + * @brief TAP test for the GenAI module + * + * This test verifies the functionality of the GenAI (Generative AI) module in ProxySQL. + * It tests: + * - LOAD/SAVE commands for GenAI variables across all variants + * - Variable access (SET and SELECT) for genai-var1 and genai-var2 + * - Variable persistence across storage layers (memory, disk, runtime) + * - CHECKSUM commands for GenAI variables + * - SHOW VARIABLES for GenAI module + * + * @date 2025-01-08 + */ + +#include +#include +#include +#include +#include +#include +#include + +#include "mysql.h" +#include "mysqld_error.h" + +#include "tap.h" +#include "command_line.h" +#include "utils.h" + +using std::string; + +/** + * @brief Helper function to add LOAD/SAVE command variants for GenAI module + * + * This function generates all the standard LOAD/SAVE command variants that + * ProxySQL supports for module variables. It follows the same pattern used + * for other modules like MYSQL, PGSQL, etc. 
+ * + * @param queries Vector to append the generated commands to + */ +void add_genai_load_save_commands(std::vector& queries) { + // LOAD commands - Memory variants + queries.push_back("LOAD GENAI VARIABLES TO MEMORY"); + queries.push_back("LOAD GENAI VARIABLES TO MEM"); + + // LOAD from disk + queries.push_back("LOAD GENAI VARIABLES FROM DISK"); + + // LOAD from memory + queries.push_back("LOAD GENAI VARIABLES FROM MEMORY"); + queries.push_back("LOAD GENAI VARIABLES FROM MEM"); + + // LOAD to runtime + queries.push_back("LOAD GENAI VARIABLES TO RUNTIME"); + queries.push_back("LOAD GENAI VARIABLES TO RUN"); + + // SAVE from memory + queries.push_back("SAVE GENAI VARIABLES FROM MEMORY"); + queries.push_back("SAVE GENAI VARIABLES FROM MEM"); + + // SAVE to disk + queries.push_back("SAVE GENAI VARIABLES TO DISK"); + + // SAVE to memory + queries.push_back("SAVE GENAI VARIABLES TO MEMORY"); + queries.push_back("SAVE GENAI VARIABLES TO MEM"); + + // SAVE from runtime + queries.push_back("SAVE GENAI VARIABLES FROM RUNTIME"); + queries.push_back("SAVE GENAI VARIABLES FROM RUN"); +} + +/** + * @brief Get the value of a GenAI variable as a string + * + * @param admin MySQL connection to admin interface + * @param var_name Variable name (without genai- prefix) + * @return std::string The variable value, or empty string on error + */ +std::string get_genai_variable(MYSQL* admin, const std::string& var_name) { + std::string query = "SELECT @@genai-" + var_name; + if (mysql_query(admin, query.c_str()) != 0) { + return ""; + } + + MYSQL_RES* res = mysql_store_result(admin); + if (!res) { + return ""; + } + + MYSQL_ROW row = mysql_fetch_row(res); + std::string value = row && row[0] ? row[0] : ""; + + mysql_free_result(res); + return value; +} + +/** + * @brief Test variable access operations (SET and SELECT) + * + * Tests setting and retrieving GenAI variables to ensure they work correctly. 
+ */ +int test_variable_access(MYSQL* admin) { + int test_num = 0; + + // Test 1: Get default value of genai-var1 + std::string var1_default = get_genai_variable(admin, "var1"); + ok(var1_default == "default_value_1", + "Default value of genai-var1 is 'default_value_1', got '%s'", var1_default.c_str()); + + // Test 2: Get default value of genai-var2 + std::string var2_default = get_genai_variable(admin, "var2"); + ok(var2_default == "100", + "Default value of genai-var2 is '100', got '%s'", var2_default.c_str()); + + // Test 3: Set genai-var1 to a new value + MYSQL_QUERY(admin, "SET genai-var1='test_value_123'"); + std::string var1_new = get_genai_variable(admin, "var1"); + ok(var1_new == "test_value_123", + "After SET, genai-var1 is 'test_value_123', got '%s'", var1_new.c_str()); + + // Test 4: Set genai-var2 to a new integer value + MYSQL_QUERY(admin, "SET genai-var2=42"); + std::string var2_new = get_genai_variable(admin, "var2"); + ok(var2_new == "42", + "After SET, genai-var2 is '42', got '%s'", var2_new.c_str()); + + // Test 5: Set genai-var1 with special characters + MYSQL_QUERY(admin, "SET genai-var1='test with spaces'"); + std::string var1_special = get_genai_variable(admin, "var1"); + ok(var1_special == "test with spaces", + "genai-var1 with spaces is 'test with spaces', got '%s'", var1_special.c_str()); + + // Test 6: Verify SHOW VARIABLES LIKE pattern + MYSQL_QUERY(admin, "SHOW VARIABLES LIKE 'genai-%'"); + MYSQL_RES* res = mysql_store_result(admin); + int num_rows = mysql_num_rows(res); + ok(num_rows == 2, + "SHOW VARIABLES LIKE 'genai-%%' returns 2 rows, got %d", num_rows); + mysql_free_result(res); + + // Test 7: Restore default values + MYSQL_QUERY(admin, "SET genai-var1='default_value_1'"); + MYSQL_QUERY(admin, "SET genai-var2=100"); + ok(1, "Restored default values for genai-var1 and genai-var2"); + + return test_num; +} + +/** + * @brief Test variable persistence across storage layers + * + * Tests that variables are correctly copied between: + * 
- Memory (main.global_variables) + * - Disk (disk.global_variables) + * - Runtime (GloGATH handler object) + */ +int test_variable_persistence(MYSQL* admin) { + int test_num = 0; + + // Test 1: Set values and save to disk + diag("Testing variable persistence: Set values, save to disk, modify, load from disk"); + MYSQL_QUERY(admin, "SET genai-var1='disk_test_value'"); + MYSQL_QUERY(admin, "SET genai-var2=999"); + MYSQL_QUERY(admin, "SAVE GENAI VARIABLES TO DISK"); + ok(1, "Set genai-var1='disk_test_value', genai-var2=999 and saved to disk"); + + // Test 2: Modify values in memory + MYSQL_QUERY(admin, "SET genai-var1='memory_value'"); + MYSQL_QUERY(admin, "SET genai-var2=111"); + std::string var1_mem = get_genai_variable(admin, "var1"); + std::string var2_mem = get_genai_variable(admin, "var2"); + ok(var1_mem == "memory_value" && var2_mem == "111", + "Modified in memory: genai-var1='memory_value', genai-var2='111'"); + + // Test 3: Load from disk and verify original values restored + MYSQL_QUERY(admin, "LOAD GENAI VARIABLES FROM DISK"); + std::string var1_disk = get_genai_variable(admin, "var1"); + std::string var2_disk = get_genai_variable(admin, "var2"); + ok(var1_disk == "disk_test_value" && var2_disk == "999", + "After LOAD FROM DISK: genai-var1='disk_test_value', genai-var2='999'"); + + // Test 4: Save to memory and verify + MYSQL_QUERY(admin, "SAVE GENAI VARIABLES TO MEMORY"); + ok(1, "SAVE GENAI VARIABLES TO MEMORY executed"); + + // Test 5: Load from memory + MYSQL_QUERY(admin, "LOAD GENAI VARIABLES FROM MEMORY"); + ok(1, "LOAD GENAI VARIABLES FROM MEMORY executed"); + + // Test 6: Test SAVE from runtime + MYSQL_QUERY(admin, "SAVE GENAI VARIABLES FROM RUNTIME"); + ok(1, "SAVE GENAI VARIABLES FROM RUNTIME executed"); + + // Test 7: Test LOAD to runtime + MYSQL_QUERY(admin, "LOAD GENAI VARIABLES TO RUNTIME"); + ok(1, "LOAD GENAI VARIABLES TO RUNTIME executed"); + + // Test 8: Restore default values + MYSQL_QUERY(admin, "SET genai-var1='default_value_1'"); + 
MYSQL_QUERY(admin, "SET genai-var2=100"); + MYSQL_QUERY(admin, "SAVE GENAI VARIABLES TO DISK"); + ok(1, "Restored default values and saved to disk"); + + return test_num; +} + +/** + * @brief Test CHECKSUM commands for GenAI variables + * + * Tests all CHECKSUM variants to ensure they work correctly. + */ +int test_checksum_commands(MYSQL* admin) { + int test_num = 0; + + // Test 1: CHECKSUM DISK GENAI VARIABLES + diag("Testing CHECKSUM commands for GenAI variables"); + int rc1 = mysql_query(admin, "CHECKSUM DISK GENAI VARIABLES"); + ok(rc1 == 0, "CHECKSUM DISK GENAI VARIABLES"); + if (rc1 == 0) { + MYSQL_RES* res = mysql_store_result(admin); + int num_rows = mysql_num_rows(res); + ok(num_rows == 1, "CHECKSUM DISK GENAI VARIABLES returns 1 row"); + mysql_free_result(res); + } else { + skip(1, "Skipping row count check due to error"); + } + + // Test 2: CHECKSUM MEM GENAI VARIABLES + int rc2 = mysql_query(admin, "CHECKSUM MEM GENAI VARIABLES"); + ok(rc2 == 0, "CHECKSUM MEM GENAI VARIABLES"); + if (rc2 == 0) { + MYSQL_RES* res = mysql_store_result(admin); + int num_rows = mysql_num_rows(res); + ok(num_rows == 1, "CHECKSUM MEM GENAI VARIABLES returns 1 row"); + mysql_free_result(res); + } else { + skip(1, "Skipping row count check due to error"); + } + + // Test 3: CHECKSUM MEMORY GENAI VARIABLES (alias for MEM) + int rc3 = mysql_query(admin, "CHECKSUM MEMORY GENAI VARIABLES"); + ok(rc3 == 0, "CHECKSUM MEMORY GENAI VARIABLES"); + if (rc3 == 0) { + MYSQL_RES* res = mysql_store_result(admin); + int num_rows = mysql_num_rows(res); + ok(num_rows == 1, "CHECKSUM MEMORY GENAI VARIABLES returns 1 row"); + mysql_free_result(res); + } else { + skip(1, "Skipping row count check due to error"); + } + + // Test 4: CHECKSUM GENAI VARIABLES (defaults to DISK) + int rc4 = mysql_query(admin, "CHECKSUM GENAI VARIABLES"); + ok(rc4 == 0, "CHECKSUM GENAI VARIABLES"); + if (rc4 == 0) { + MYSQL_RES* res = mysql_store_result(admin); + int num_rows = mysql_num_rows(res); + ok(num_rows == 1, 
"CHECKSUM GENAI VARIABLES returns 1 row"); + mysql_free_result(res); + } else { + skip(1, "Skipping row count check due to error"); + } + + return test_num; +} + +/** + * @brief Main test function + * + * Orchestrates all GenAI module tests. + */ +int main() { + CommandLine cl; + + if (cl.getEnv()) { + diag("Failed to get the required environmental variables."); + return EXIT_FAILURE; + } + + // Initialize connection to admin interface + MYSQL* admin = mysql_init(NULL); + if (!admin) { + fprintf(stderr, "File %s, line %d, Error: mysql_init failed\n", __FILE__, __LINE__); + return EXIT_FAILURE; + } + + if (!mysql_real_connect(admin, cl.host, cl.admin_username, cl.admin_password, + NULL, cl.admin_port, NULL, 0)) { + fprintf(stderr, "File %s, line %d, Error: %s\n", __FILE__, __LINE__, mysql_error(admin)); + return EXIT_FAILURE; + } + + diag("Connected to ProxySQL admin interface at %s:%d", cl.host, cl.admin_port); + + // Build the list of LOAD/SAVE commands to test + std::vector queries; + add_genai_load_save_commands(queries); + + // Each command test = 2 tests (execution + optional result check) + // LOAD/SAVE commands: 14 commands + // Variable access tests: 7 tests + // Persistence tests: 8 tests + // CHECKSUM tests: 8 tests (4 commands × 2) + int num_load_save_tests = (int)queries.size() * 2; // Each command + result check + int total_tests = num_load_save_tests + 7 + 8 + 8; + + plan(total_tests); + + int test_count = 0; + + // ============================================================================ + // Part 1: Test LOAD/SAVE commands + // ============================================================================ + diag("=== Part 1: Testing LOAD/SAVE GENAI VARIABLES commands ==="); + for (const auto& query : queries) { + MYSQL* admin_local = mysql_init(NULL); + if (!admin_local) { + diag("Failed to initialize MySQL connection"); + continue; + } + + if (!mysql_real_connect(admin_local, cl.host, cl.admin_username, cl.admin_password, + NULL, cl.admin_port, 
NULL, 0)) { + diag("Failed to connect to admin interface"); + mysql_close(admin_local); + continue; + } + + int rc = run_q(admin_local, query.c_str()); + ok(rc == 0, "Command executed successfully: %s", query.c_str()); + + // For SELECT/SHOW/CHECKSUM style commands, verify result set + if (strncasecmp(query.c_str(), "SELECT ", 7) == 0 || + strncasecmp(query.c_str(), "SHOW ", 5) == 0 || + strncasecmp(query.c_str(), "CHECKSUM ", 9) == 0) { + MYSQL_RES* res = mysql_store_result(admin_local); + unsigned long long num_rows = mysql_num_rows(res); + ok(num_rows != 0, "Command returned rows: %s", query.c_str()); + mysql_free_result(res); + } else { + // For non-query commands, just mark the test as passed + ok(1, "Command completed: %s", query.c_str()); + } + + mysql_close(admin_local); + } + + // ============================================================================ + // Part 2: Test variable access (SET and SELECT) + // ============================================================================ + diag("=== Part 2: Testing variable access (SET and SELECT) ==="); + test_count += test_variable_access(admin); + + // ============================================================================ + // Part 3: Test variable persistence across layers + // ============================================================================ + diag("=== Part 3: Testing variable persistence across storage layers ==="); + test_count += test_variable_persistence(admin); + + // ============================================================================ + // Part 4: Test CHECKSUM commands + // ============================================================================ + diag("=== Part 4: Testing CHECKSUM commands ==="); + test_count += test_checksum_commands(admin); + + // ============================================================================ + // Cleanup + // ============================================================================ + mysql_close(admin); + + diag("=== All GenAI 
module tests completed ==="); + + return exit_status(); +}