diff --git a/docs/service/mllm_cli.rst b/docs/service/mllm_cli.rst
index 7cf913d09..fdaf84766 100644
--- a/docs/service/mllm_cli.rst
+++ b/docs/service/mllm_cli.rst
@@ -6,6 +6,11 @@ Overview
 
 This document describes the MLLM command-line interface (CLI) tool, which operates within a client-server architecture. The system is designed to provide network access to MLLM's core inference capabilities. The backend service is written in Go and interacts with the core C++ MLLM library through a C API. The frontend can be a Go-based command-line client or a standard GUI client such as Chatbox that communicates with the service via an OpenAI-compatible API.
 
+**Currently, the system officially supports the following models:**
+
+* **LLM**: ``mllmTeam/Qwen3-0.6B-w4a32kai``
+* **OCR**: ``mllmTeam/DeepSeek-OCR-w4a8-i8mm-kai``
+
 This guide covers three main areas:
 
 1. **System Architecture and API**: An explanation of the components and the C API bridge.
@@ -36,6 +41,7 @@ The C API uses shared data structures to pass information between Go and C++.
 These C functions wrap the C++ service logic, making them callable from Go via `cgo`.
 
 * `createQwen3Session(const char* model_path)`: Loads a model from the specified path and creates a session handle.
+* `createDeepseekOCRSession(const char* model_path)`: Loads a DeepSeek-OCR model from the specified path and creates a session handle. This session is specifically designed to handle visual inputs and OCR tasks.
 * `insertSession(const char* session_id, MllmCAny handle)`: Registers the created session with a unique ID in the service.
 * `sendRequest(const char* session_id, const char* json_request)`: Sends a user's request (in JSON format) to the specified session for processing.
 * `pollResponse(const char* session_id)`: Polls for a response from the model. This is used for streaming results back to the client.
@@ -46,9 +52,21 @@ These C functions wrap the C++ service logic, making them callable from Go via `
 
 The `mllm-server` is an HTTP server written in Go. It acts as a bridge between network clients and the MLLM C++ core.
 
-* **Initialization**: On startup, it initializes the MLLM context, starts the backend service, and loads the specified model into a session.
-* **API Endpoint**: It exposes an OpenAI-compatible endpoint at `/v1/chat/completions`.
-* **Request Handling**: When it receives a request, it retrieves the appropriate model session, forwards the request JSON to the C++ core using `sendRequest`, and then continuously polls for results using `pollResponse`.
+* **Initialization (Dual-Model Support)**: On startup, the server checks for two command-line arguments:
+
+  * ``--model-path``: Path to the Qwen3 LLM model.
+  * ``--ocr-model-path``: Path to the DeepSeek-OCR model.
+
+  For each path provided, the server initializes the corresponding session (`createQwen3Session` and/or `createDeepseekOCRSession`) and registers it under the model directory's base name as its session ID.
+
+* **API Endpoint**: It exposes an OpenAI-compatible endpoint at ``/v1/chat/completions``.
+
+* **Request Handling & OCR Preprocessing**:
+  When a request arrives, the server inspects the ``model`` parameter.
+
+  * **Text Requests**: Routed directly to the Qwen3 session.
+  * **OCR Requests**: If the model name contains "ocr" (matched case-insensitively), the server triggers a preprocessing step (`preprocessRequestForOCR`). It detects Base64-encoded images in the payload, decodes them, saves them to temporary files on the device, and rewrites the request so that the C++ backend is pointed at these file paths.
+
+* **Streaming Response**: Results are streamed back to the client over HTTP using Server-Sent Events (SSE).
 
 **Key Service Layer Files**
 
@@ -64,6 +82,7 @@ The `mllm-server` functionality is implemented across several key Go files:
   3. Forwarding the request to the C++ core (`session.SendRequest`).
   4. Continuously polling (`session.PollResponse`) for responses from the C++ layer.
   5. Streaming the responses back to the client in the standard Server-Sent Events (SSE) format.
+  6. For OCR models, identifying Base64-encoded images in the request, decoding them into temporary files on the Android device, and updating the request payload with the local file paths.
 * ``pkg/mllm/service.go``
   * **Purpose**: Acts as the Go-level **session manager** (`Service`). It holds a map that links model IDs (e.g., "Qwen3-0.6B-w4a32kai") to their active MLLM sessions (`*mllm.Session`). `handlers.go` uses this to find the correct session instance.
 * ``pkg/api/types.go``
@@ -78,6 +97,9 @@ The `mllm-client` is an interactive command-line tool that allows users to chat
 * **API Communication**: It formats the user input into an OpenAI-compatible JSON request and sends it to the `mllm-server`.
 * **Response Handling**: It receives the SSE stream from the server, decodes the JSON chunks, and prints the assistant's response to the console in real-time.
 
+.. note::
+   **Current Limitation**: The ``mllm-client`` is hardcoded to use the **Qwen3** model only. It does not support switching to the DeepSeek-OCR model or uploading images. For OCR tasks, please use a GUI client such as Chatbox.
+
 Alternative Client: Chatbox
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
@@ -99,6 +121,15 @@ After running this command, configure Chatbox with the following settings:
 
 * **API Key**: (Can be left blank or any value; the server does not currently check it)
 * **API Host**: ``http://localhost:8081``
 * **API Path**: ``/v1/chat/completions``
+* **Model**: **[Important]** You must manually add the model name by clicking **+ New**.
+
+  The name **MUST match the folder name** of the model directory on the Android device:
+
+  * For the LLM, enter ``Qwen3-0.6B-w4a32kai`` (or your specific LLM folder name).
+  * For OCR, enter ``DeepSeek-OCR-w4a8-i8mm-kai`` (or your specific OCR folder name).
+
+.. note::
+   The server uses the directory base name (via ``filepath.Base``) as the session ID. If you enter a different name in Chatbox, the server will return a "Model not found" error.
 
 Once configured, you can click the **Check** button to ensure the connection is successful. Please note that this step must be performed while the server is running.
 
@@ -168,15 +199,15 @@ First, we compile the MLLM C++ core, which produces the essential shared librari
 
       rsync -avz --checksum -e 'ssh -p ' --exclude 'build' --exclude '.git' ./ @:/your_workspace/your_programname/
 
 2. **Run the Build Task**:
-   On the build server, execute the build task. This task uses `tasks/build_android_debug.yaml` to configure and run CMake.
+   On the build server, execute the build task. This task uses `tasks/build_android.yaml` to configure and run CMake.
 
-   Before executing this step, you also need to ensure that the hardcoded directories in build_android_debug.yaml have been modified to match your requirements. The modification method is the same as for the Go compilation file mentioned earlier.
+   Before executing this step, ensure that the hardcoded directories in ``tasks/build_android.yaml`` have been modified to match your environment.
+   The modification method is the same as for the Go build configuration file mentioned earlier.
 
    .. code-block:: bash
 
       # These commands are run on your build server.
       cd /your_workspace/your_programname/
-      python task.py tasks/build_android_debug.yaml
+      python task.py tasks/build_android.yaml
 
 3. **Retrieve Compiled Libraries**:
    After the build succeeds, copy the compiled shared libraries from the build server back to your local machine. These libraries are the C++ backend that the Go application will call.
@@ -285,7 +316,14 @@ This covers testing with both the Go CLI client and Chatbox.
 
       chmod +x mllm_web_server
       export LD_LIBRARY_PATH=.
       # Ensure you provide the correct path to your model
-      ./mllm_web_server --model-path /path/to/your/model_directory/model_name
+      # Option A: Run with the Qwen3 LLM only
+      ./mllm_web_server --model-path /path/to/your/qwen3_model_dir
+
+      # Option B: Run with both the Qwen3 LLM and DeepSeek-OCR
+      # Use this if you plan to switch between text chat and OCR tasks
+      ./mllm_web_server \
+        --model-path /path/to/your/qwen3_model_dir \
+        --ocr-model-path /path/to/your/deepseek_ocr_model_dir
 
 .. warning::
    The `export LD_LIBRARY_PATH=.` command is crucial. It tells the Android dynamic linker to look for the `.so` files in the current directory. Without it, the server will fail to start.
@@ -302,7 +340,7 @@ You should now be able to interact with the model from the client terminal. Type
 
 **B. Testing with Chatbox (Host Machine)**
 
 1. **Terminal 1: Start the Server**:
-   Follow the same instructions as in **Step 5.A.1** to start the server on the Android device.
+   Follow the same instructions as in **Step 5.A.1** to start the server on the Android device. If you intend to test the OCR functionality, make sure you started the server with Option B (which specifies ``--ocr-model-path``).
 
 2. **Terminal 2: Set up Port Forwarding**:
    On your host machine (not in the adb shell), run the following command. This maps your local host port (e.g., 8081) to the device's port (e.g., 8080).
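+
+   As a quick sanity check without Chatbox, you can then hit the forwarded endpoint directly with ``curl``. The sketch below assumes the port mapping above and a model folder named ``Qwen3-0.6B-w4a32kai``; adjust both to your setup.
+
+   .. code-block:: bash
+
+      # Minimal streaming chat request through the forwarded port.
+      # "model" must equal the model directory's base name on the device.
+      curl -N http://localhost:8081/v1/chat/completions \
+        -H "Content-Type: application/json" \
+        -d '{
+              "model": "Qwen3-0.6B-w4a32kai",
+              "stream": true,
+              "messages": [{"role": "user", "content": "Hello!"}]
+            }'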
diff --git a/mllm-cli/cmd/mllm-server/main.go b/mllm-cli/cmd/mllm-server/main.go
index fc08b5a9e..b66011454 100644
--- a/mllm-cli/cmd/mllm-server/main.go
+++ b/mllm-cli/cmd/mllm-server/main.go
@@ -18,9 +18,10 @@ import (
 func main() {
 	modelPath := flag.String("model-path", "", "Path to the MLLM model directory.")
+	ocrModelPath := flag.String("ocr-model-path", "", "Path to the DeepSeek-OCR model directory.")
 	flag.Parse()
 
-	if *modelPath == "" {
-		log.Fatal("FATAL: --model-path argument is required.")
+	if *modelPath == "" && *ocrModelPath == "" {
+		log.Fatal("FATAL: at least one of --model-path or --ocr-model-path is required.")
 	}
 
@@ -36,19 +37,37 @@ func main() {
 	mllmService := pkgmllm.NewService()
 
-	log.Printf("Loading model and creating session from: %s", *modelPath)
-	session, err := mllm.NewSession(*modelPath)
-	if err != nil {
-		log.Fatalf("FATAL: Failed to create session: %v", err)
+	if *modelPath != "" {
+		log.Printf("Loading Qwen3 model and creating session from: %s", *modelPath)
+		session, err := mllm.NewSession(*modelPath)
+		if err != nil {
+			log.Fatalf("FATAL: Failed to create Qwen3 session: %v", err)
+		}
+
+		sessionID := filepath.Base(*modelPath)
+		if !session.Insert(sessionID) {
+			session.Close()
+			log.Fatalf("FATAL: Failed to insert Qwen3 session with ID '%s'", sessionID)
+		}
+		mllmService.RegisterSession(sessionID, session)
+		log.Printf("Qwen3 Session created and registered successfully with ID: %s", sessionID)
 	}
 
-	sessionID := filepath.Base(*modelPath)
-	if !session.Insert(sessionID) {
-		session.Close()
-		log.Fatalf("FATAL: Failed to insert session with ID '%s'", sessionID)
+	if *ocrModelPath != "" {
+		log.Printf("Loading DeepSeek-OCR model and creating session from: %s", *ocrModelPath)
+		session, err := mllm.NewDeepseekOCRSession(*ocrModelPath)
+		if err != nil {
+			log.Fatalf("FATAL: Failed to create DeepSeek-OCR session: %v", err)
+		}
+
+		sessionID := filepath.Base(*ocrModelPath)
+		if !session.Insert(sessionID) {
+			session.Close()
+			log.Fatalf("FATAL: Failed to insert DeepSeek-OCR session with ID '%s'", sessionID)
+		}
+		mllmService.RegisterSession(sessionID, session)
+		log.Printf("DeepSeek-OCR Session created and registered successfully with ID: %s", sessionID)
 	}
-	mllmService.RegisterSession(sessionID, session)
-	log.Printf("Session created and registered successfully with ID: %s", sessionID)
 
 	httpServer := server.NewServer(":8080", mllmService)
diff --git a/mllm-cli/mllm/c.go b/mllm-cli/mllm/c.go
index 27f02f09e..ec7995897 100644
--- a/mllm-cli/mllm/c.go
+++ b/mllm-cli/mllm/c.go
@@ -73,6 +73,23 @@ func NewSession(modelPath string) (*Session, error) {
 	return s, nil
 }
 
+func NewDeepseekOCRSession(modelPath string) (*Session, error) {
+	cModelPath := C.CString(modelPath)
+	defer C.free(unsafe.Pointer(cModelPath))
+
+	handle := C.createDeepseekOCRSession(cModelPath)
+	if !isOk(handle) {
+		return nil, fmt.Errorf("underlying C API call createDeepseekOCRSession failed")
+	}
+	s := &Session{cHandle: handle}
+	runtime.SetFinalizer(s, func(s *Session) {
+		fmt.Println("[Go Finalizer] Mllm OCR Session automatically released.")
+		C.freeSession(s.cHandle)
+	})
+
+	return s, nil
+}
+
 func (s *Session) Close() {
 	if C.MllmCAny_get_v_custom_ptr(s.cHandle) != nil {
 		fmt.Println("[Go Close] Mllm Session manually closed.")
diff --git a/mllm-cli/pkg/server/handlers.go b/mllm-cli/pkg/server/handlers.go
index a15946fd7..2041117da 100644
--- a/mllm-cli/pkg/server/handlers.go
+++ b/mllm-cli/pkg/server/handlers.go
@@ -1,14 +1,168 @@
+// Copyright (c) MLLM Team.
+// Licensed under the MIT License.
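+//
+// Package server implements the OpenAI-compatible HTTP front end. For model
+// names containing "ocr", incoming Base64 image payloads are rewritten into
+// temporary image files on disk before the request is handed to the C++ core
+// (see preprocessRequestForOCR below).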
 package server
 
 import (
+	"encoding/base64"
 	"encoding/json"
 	"fmt"
 	"log"
 	"net/http"
+	"os"
+	"path/filepath"
+	"strings"
 
 	"github.com/google/uuid"
 )
 
+func decodeBase64Image(uri string) ([]byte, string, error) {
+	if !strings.HasPrefix(uri, "data:image/") {
+		return nil, "", fmt.Errorf("invalid data URI: must start with 'data:image/'")
+	}
+
+	parts := strings.SplitN(uri, ",", 2)
+	if len(parts) != 2 {
+		return nil, "", fmt.Errorf("invalid base64 image data")
+	}
+
+	meta := parts[0]
+	ext := ""
+	if strings.Contains(meta, "image/jpeg") {
+		ext = ".jpg"
+	} else if strings.Contains(meta, "image/webp") {
+		ext = ".webp"
+	} else if strings.Contains(meta, "image/png") {
+		ext = ".png"
+	} else {
+		return nil, "", fmt.Errorf("unsupported image format in data URI")
+	}
+
+	data, err := base64.StdEncoding.DecodeString(parts[1])
+	if err != nil {
+		return nil, "", fmt.Errorf("failed to decode base64: %v", err)
+	}
+	return data, ext, nil
+}
+
+func (s *Server) preprocessRequestForOCR(payload map[string]interface{}) (bool, func(), error) {
+	messages, ok := payload["messages"].([]interface{})
+	if !ok {
+		return false, nil, fmt.Errorf("invalid messages format")
+	}
+
+	var userMessage map[string]interface{}
+	var contentArray []interface{}
+	imageFoundInPayload := false
+
+	for i := len(messages) - 1; i >= 0; i-- {
+		msg, ok := messages[i].(map[string]interface{})
+		if !ok {
+			continue
+		}
+		if role, _ := msg["role"].(string); role == "user" {
+			// Capture any client-supplied "images" field *before* stripping it;
+			// deleting first would make the custom-images branch below unreachable.
+			// The field is re-populated later with a local file path.
+			images, hasImages := msg["images"].([]interface{})
+			delete(msg, "images")
+
+			if content, ok := msg["content"].([]interface{}); ok {
+				userMessage = msg
+				contentArray = content
+				log.Println("[Handler] Found OpenAI Vision 'content' array.")
+				break
+			} else if hasImages && len(images) > 0 {
+				log.Println("[Handler] Found custom 'images' field.")
+				base64URI, ok := images[0].(string)
+				if !ok || !strings.HasPrefix(base64URI, "data:image") {
+					return false, nil, fmt.Errorf("image data in 'images' field is not a valid base64 URI")
+				}
+
+				textContent, _ := msg["content"].(string)
+
+				contentArray = []interface{}{
+					map[string]interface{}{"type": "text", "text": textContent},
+					map[string]interface{}{"type": "image_url", "image_url": map[string]interface{}{"url": base64URI}},
+				}
+				userMessage = msg
+				break
+			} else {
+				userMessage = msg
+				contentArray = nil
+				log.Println("[Handler] Found text-only user message.")
+				break
+			}
+		}
+	}
+
+	if userMessage == nil {
+		return false, nil, fmt.Errorf("no user message found")
+	}
+
+	var textContent string
+	var base64URI string
+	if contentArray != nil {
+		for _, part := range contentArray {
+			partMap, ok := part.(map[string]interface{})
+			if !ok {
+				continue
+			}
+			partType, _ := partMap["type"].(string)
+			if partType == "text" {
+				textContent, _ = partMap["text"].(string)
+			} else if partType == "image_url" {
+				imageUrl, ok := partMap["image_url"].(map[string]interface{})
+				if ok {
+					base64URI, _ = imageUrl["url"].(string)
+					imageFoundInPayload = true
+				}
+			}
+		}
+	} else {
+		textContent, _ = userMessage["content"].(string)
+	}
+
+	if strings.TrimSpace(textContent) == "" {
+		log.Println("[Handler] User content is empty, auto-filling default prompt.")
+		textContent = "Convert the document to markdown."
+	}
+
+	if !imageFoundInPayload {
+		log.Println("[Handler] No new image found in payload for OCR request.")
+		userMessage["content"] = textContent
+		delete(userMessage, "images")
+		return false, func() {}, nil
+	}
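+
+	// The decoded image must be materialized on disk because the C++ backend
+	// loads images by file path. The temp file has to outlive this function
+	// (the model reads it while streaming), so deletion is deferred to the
+	// returned cleanup closure rather than done here.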
Processing...") + + imageData, ext, err := decodeBase64Image(base64URI) + if err != nil { + return false, nil, err + } + tempFile, err := os.CreateTemp("", "ocr_temp_*"+ext) + if err != nil { + return false, nil, fmt.Errorf("failed to create temp file: %v", err) + } + if _, err := tempFile.Write(imageData); err != nil { + tempFile.Close() + os.Remove(tempFile.Name()) + return false, nil, fmt.Errorf("failed to write to temp file: %v", err) + } + tempFile.Close() + absPath, err := filepath.Abs(tempFile.Name()) + if err != nil { + os.Remove(tempFile.Name()) + return false, nil, fmt.Errorf("failed to get absolute path for temp file: %v", err) + } + log.Printf("[Handler] Saved Base64 image to temporary file: %s", absPath) + + userMessage["content"] = textContent + userMessage["images"] = []interface{}{absPath} + + cleanupFunc := func() { + log.Printf("[Handler] Cleaning up temporary file: %s", absPath) + os.Remove(absPath) + } + return true, cleanupFunc, nil +} + func (s *Server) chatCompletionsHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { @@ -23,6 +177,28 @@ func (s *Server) chatCompletionsHandler() http.HandlerFunc { } modelName, _ := requestPayload["model"].(string) + + if strings.Contains(strings.ToLower(modelName), "ocr") || strings.HasSuffix(strings.ToLower(modelName), "-ocr") { + log.Printf("[Handler] OCR model detected ('%s'). Checking for image data...", modelName) + + imageFound, cleanupFunc, err := s.preprocessRequestForOCR(requestPayload) + if err != nil { + log.Printf("ERROR: Failed to process OCR request: %v", err) + http.Error(w, fmt.Sprintf("Failed to process OCR request: %v", err), http.StatusBadRequest) + return + } + defer cleanupFunc() + + if !imageFound { + log.Println("ERROR: OCR model is single-turn and requires an image in *every* request. Text-only follow-ups are not supported.") + http.Error(w, "OCR model is single-turn and requires an image in every request. Text-only follow-ups are not supported.", http.StatusBadRequest) + return + } + + } else { + log.Printf("[Handler] Text model detected ('%s'). Forwarding request...", modelName) + } + session, err := s.mllmService.GetSession(modelName) if err != nil { log.Printf("ERROR: Could not get session for model '%s': %v", modelName, err) @@ -45,6 +221,7 @@ func (s *Server) chatCompletionsHandler() http.HandlerFunc { http.Error(w, "Failed to re-marshal request payload", http.StatusInternalServerError) return } + if !session.SendRequest(string(requestBytes)) { http.Error(w, "Failed to process request by the model", http.StatusInternalServerError) return @@ -53,26 +230,26 @@ func (s *Server) chatCompletionsHandler() http.HandlerFunc { w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") - flusher, _ := w.(http.Flusher) log.Printf("Streaming response for session %s (Request ID: %s)...", session.SessionID(), requestID) - + for { if r.Context().Err() != nil { log.Printf("Client disconnected. 
Stopping stream for %s.", session.SessionID()) break } - rawResponse := session.PollResponse(requestID) if rawResponse == "" { log.Println("Received empty response from poll, assuming stream has ended.") break } - + fmt.Fprintf(w, "data: %s\n\n", rawResponse) - flusher.Flush() + if fl, ok := w.(http.Flusher); ok { + fl.Flush() + } var responseChunk map[string]interface{} if json.Unmarshal([]byte(rawResponse), &responseChunk) == nil { @@ -88,7 +265,9 @@ func (s *Server) chatCompletionsHandler() http.HandlerFunc { } fmt.Fprintf(w, "data: [DONE]\n\n") - flusher.Flush() + if fl, ok := w.(http.Flusher); ok { + fl.Flush() + } log.Printf("Finished streaming for session %s (Request ID: %s).", session.SessionID(), requestID) } } \ No newline at end of file diff --git a/mllm/c_api/Runtime.cpp b/mllm/c_api/Runtime.cpp index b116debbe..e0b301ced 100644 --- a/mllm/c_api/Runtime.cpp +++ b/mllm/c_api/Runtime.cpp @@ -4,6 +4,7 @@ #include "mllm/c_api/Runtime.h" #include "mllm/engine/service/Service.hpp" #include "mllm/models/qwen3/modeling_qwen3_service.hpp" +#include "mllm/models/deepseek_ocr/modeling_deepseek_ocr_service.hpp" #include #include #include @@ -95,6 +96,25 @@ MllmCAny createQwen3Session(const char* model_path) { } } +MllmCAny createDeepseekOCRSession(const char* model_path) { + if (model_path == nullptr) { + printf("[C++ Service] createDeepseekOCRSession error: invalid arguments.\n"); + return MllmCAny{.type_id = kRetCode, .v_return_code = -1}; + } + try { + auto dpsk_session = std::make_shared(); + dpsk_session->fromPreTrain(model_path); + + auto* handle = new MllmSessionWrapper(); + handle->session_ptr = dpsk_session; + + return MllmCAny{.type_id = kCustomObject, .v_custom_ptr = handle}; + } catch (const std::exception& e) { + printf("[C++ Service] createDeepseekOCRSession exception: %s\n", e.what()); + return MllmCAny{.type_id = kRetCode, .v_return_code = -1}; + } +} + MllmCAny insertSession(const char* session_id, MllmCAny handle) { if (session_id == nullptr || handle.type_id != kCustomObject || handle.v_custom_ptr == nullptr) { printf("[C++ Service] insertSession error: invalid arguments.\n"); diff --git a/mllm/c_api/Runtime.h b/mllm/c_api/Runtime.h index 9b18b0259..97b34f4e8 100644 --- a/mllm/c_api/Runtime.h +++ b/mllm/c_api/Runtime.h @@ -42,6 +42,8 @@ void setLogLevel(int level); MllmCAny createQwen3Session(const char* model_path); +MllmCAny createDeepseekOCRSession(const char* model_path); + MllmCAny insertSession(const char* session_id, MllmCAny handle); MllmCAny freeSession(MllmCAny handle); diff --git a/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp b/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp index aa9afcdf1..332889371 100644 --- a/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp +++ b/mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp @@ -655,6 +655,9 @@ class DeepseekOCRForCausalLM final : public nn::Module, public ARGeneration { eos_token_id_ = config.eos_token_id; } + inline nn::StaticCache& kvCache() { return kv_cache_; } + inline int64_t eosTokenId() const { return eos_token_id_; } + ARGenerationOutputPast forward(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { auto patches = input.count("patches") ? input.at("patches") : Tensor::nil(); auto image_ori = input.count("image_ori") ? 
input.at("image_ori") : Tensor::nil(); diff --git a/mllm/models/deepseek_ocr/modeling_deepseek_ocr_service.hpp b/mllm/models/deepseek_ocr/modeling_deepseek_ocr_service.hpp new file mode 100644 index 000000000..eea2dc60f --- /dev/null +++ b/mllm/models/deepseek_ocr/modeling_deepseek_ocr_service.hpp @@ -0,0 +1,216 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include +#include + +#include "mllm/engine/service/Session.hpp" +#include "mllm/engine/prefix_cache/Cache.hpp" + +#include "mllm/models/deepseek_ocr/configuration_deepseek_ocr.hpp" +#include "mllm/models/deepseek_ocr/tokenization_deepseek_ocr.hpp" +#include "mllm/models/deepseek_ocr/conversation_preprocess.hpp" +#include "mllm/models/deepseek_ocr/modeling_deepseek_ocr.hpp" +#include "mllm/preprocessor/visual/ImageTransform.hpp" + +namespace mllm::models::deepseek_ocr { + +class DeepseekOCRSession final : public ::mllm::service::Session { + public: + DeepseekOCRSession() { + initializeTemplates(); + } + + void fromPreTrain(const std::string& model_path) override { + namespace fs = std::filesystem; + fs::path root = fs::path(model_path).lexically_normal(); + fs::path config_file = root / "config.json"; + fs::path model_file = root / "model.mllm"; + fs::path tokenizer_file = root / "tokenizer.json"; + + if (!fs::exists(config_file)) throw std::runtime_error(config_file.string() + " not found"); + if (!fs::exists(model_file)) throw std::runtime_error(model_file.string() + " not found"); + if (!fs::exists(tokenizer_file)) throw std::runtime_error(tokenizer_file.string() + " not found"); + + printf("[C++ Service] Loading DeepSeek-OCR model from: %s\n", model_path.c_str()); + + config_ = DpskOcrConfig(config_file.string()); + + config_.max_cache_length = 16384; + + model_ = std::make_shared(config_); + model_->load(load(model_file.string(), ModelFileVersion::kV2)); + tokenizer_ = std::make_shared(tokenizer_file.string()); + + printf("[C++ Service] DeepSeek-OCR model loaded successfully.\n"); + } + + void streamGenerate(const nlohmann::json& request, + const std::function& callback) override { + + model_->kvCache().clearCache(); + + const int base_size = 512; + const int image_size = 512; + const int PATCH_SIZE = 16; + const int DOWN_SAMPLE_RATIO = 4; + const std::string IMAGE_TOKEN = ""; + const int64_t IMAGE_TOKEN_ID = 128815; + + auto image_transform = BasicImageTransform(std::nullopt, std::nullopt, + std::vector{0.5, 0.5, 0.5}, + std::vector{0.5, 0.5, 0.5}); + + auto images = loadImages(request["messages"]); + if (images.empty()) { + printf("[C++ Service] Warning: No images found in OCR request.\n"); + } + + std::string user_text = ""; + const auto& messages = request["messages"]; + if (messages.is_array() && !messages.empty()) { + for (int i = static_cast(messages.size()) - 1; i >= 0; --i) { + const auto& msg = messages[i]; + std::string role = msg.value("role", ""); + if (role == "user" || role == "<|User|>") { + if (msg.contains("content") && msg["content"].is_string()) { + user_text = msg["content"].get(); + break; + } + } + } + } + + std::vector tokenized_str; + std::vector images_seq_mask; + std::vector images_list; + std::vector images_crop_list; + std::vector> images_spatial_crop; + + std::vector prefix_tokens; + prefix_tokens.push_back(config_.bos_token_id); + auto user_role_tokens = tokenizer_->encode("<|User|>"); + prefix_tokens.insert(prefix_tokens.end(), user_role_tokens.begin(), user_role_tokens.end()); + + tokenized_str.insert(tokenized_str.end(), 
+    if (!images.empty()) {
+      auto image = images[0];
+      std::tuple<int, int> crop_ratio;
+      std::vector<Image> images_crop_raw;
+
+      if (image.h() <= 640 && image.w() <= 640) {
+        crop_ratio = {1, 1};
+      } else {
+        auto p = dynamicPreprocess(image, 2, 9, image_size, false);
+        images_crop_raw = p.first;
+        crop_ratio = p.second;
+      }
+
+      auto global_view = image.pad(base_size, base_size, (int)(255 * 0.5), (int)(255 * 0.5), (int)(255 * 0.5));
+      images_list.emplace_back(image_transform(global_view));
+
+      auto [width_crop_num, height_crop_num] = crop_ratio;
+      images_spatial_crop.emplace_back(width_crop_num, height_crop_num);
+
+      if (width_crop_num > 1 || height_crop_num > 1) {
+        for (const auto& _i : images_crop_raw) { images_crop_list.emplace_back(image_transform(_i)); }
+      }
+
+      auto num_queries_base = std::ceil((base_size / PATCH_SIZE) / DOWN_SAMPLE_RATIO);
+      std::vector<int64_t> tokenized_image;
+      tokenized_image.reserve((num_queries_base + 1) * num_queries_base + 1 + (width_crop_num > 1 ? 1000 : 0));
+
+      for (int i = 0; i < num_queries_base; ++i) {
+        tokenized_image.insert(tokenized_image.end(), num_queries_base, IMAGE_TOKEN_ID);
+        tokenized_image.push_back(IMAGE_TOKEN_ID);
+      }
+      tokenized_image.push_back(IMAGE_TOKEN_ID);
+
+      if (width_crop_num > 1 || height_crop_num > 1) {
+        auto num_queries = std::ceil((image_size / PATCH_SIZE) / DOWN_SAMPLE_RATIO);
+        for (int h = 0; h < num_queries * height_crop_num; ++h) {
+          tokenized_image.insert(tokenized_image.end(), num_queries * width_crop_num, IMAGE_TOKEN_ID);
+          tokenized_image.push_back(IMAGE_TOKEN_ID);
+        }
+      }
+
+      tokenized_str.insert(tokenized_str.end(), tokenized_image.begin(), tokenized_image.end());
+      images_seq_mask.insert(images_seq_mask.end(), tokenized_image.size(), (int8_t)true);
+    }
+
+    std::string suffix_text = "\n<|grounding|>" + user_text + "<|Assistant|>";
+    auto tokenized_suffix = tokenizer_->encode(suffix_text);
+
+    tokenized_str.insert(tokenized_str.end(), tokenized_suffix.begin(), tokenized_suffix.end());
+    images_seq_mask.insert(images_seq_mask.end(), tokenized_suffix.size(), (int8_t)false);
+
+    auto input_ids = Tensor::fromVector(tokenized_str, {1, (int32_t)tokenized_str.size()}, kInt64, kCPU);
+    auto images_seq_mask_tensor = Tensor::fromVector(images_seq_mask, {1, (int32_t)images_seq_mask.size()}, kInt8, kCPU);
+
+    Tensor images_ori_tensor;
+    if (!images_list.empty()) {
+      images_ori_tensor = nn::functional::stack(images_list, 0);
+    } else {
+      images_ori_tensor = Tensor::zeros({1, 3, image_size, image_size}, kFloat32, kCPU);
+    }
+
+    auto images_spatial_crop_tensor = Tensor::zeros({(int32_t)images_spatial_crop.size(), 2}, kInt64, kCPU);
+    auto* _ptr = images_spatial_crop_tensor.ptr<int64_t>();
+    for (int _i = 0; _i < images_spatial_crop.size(); ++_i) {
+      auto [l, h] = images_spatial_crop[_i];
+      _ptr[2 * _i + 0] = l;
+      _ptr[2 * _i + 1] = h;
+    }
+
+    Tensor images_crop_tensor;
+    if (!images_crop_list.empty()) {
+      images_crop_tensor = nn::functional::stack(images_crop_list, 0);
+    } else {
+      images_crop_tensor = Tensor::zeros({1, 3, image_size, image_size}, kFloat32, kCPU);
+    }
+
+    ARGenerationOutputPast input;
+    input["sequence"] = input_ids;
+    input["patches"] = images_crop_tensor;
+    input["image_ori"] = images_ori_tensor;
+    input["images_spatial_crop"] = images_spatial_crop_tensor;
+    input["images_seq_mask"] = images_seq_mask_tensor;
+
+    ARGenerationArgs args;
+    args["kv_cache"] = mllm::AnyValue(&model_->kvCache());
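+    // Sampling parameters are read from the request with conservative
+    // defaults; since do_sample defaults to false, decoding is expected to
+    // remain greedy and the temperature/top_k/top_p values go unused then.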
args["temperature"] = request.value("temperature", 1.0f); + args["top_k"] = request.value("top_k", 0); + args["top_p"] = request.value("top_p", 0.0f); + args["max_length"] = request.value("max_length", 1024); + args["do_sample"] = request.value("do_sample", false); + + model_->streamGenerate(input, args, [this, &callback](int64_t idx) { + bool finished = false; + std::string ret_token; + + if (idx == model_->eosTokenId()) { + finished = true; + ret_token = ""; + } else { + finished = false; + ret_token = tokenizer_->decode({idx}); + } + + callback(ret_token, finished); + }); + } + + private: + std::shared_ptr model_; + std::shared_ptr tokenizer_; + DpskOcrConfig config_; +}; + +} // namespace mllm::models::deepseek_ocr \ No newline at end of file