From 81c4861a017d67613e1fca73f41d3b7478a50199 Mon Sep 17 00:00:00 2001 From: yexiangle Date: Wed, 23 Apr 2025 20:46:27 +0800 Subject: [PATCH 01/21] fix:fix l1 save problem (#269) * fix:fix l1 save problem * fix:simplify the code * fix:delete no use import * fix:delete useless data --- lpm_kernel/train/trainprocess_service.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/lpm_kernel/train/trainprocess_service.py b/lpm_kernel/train/trainprocess_service.py index b7bb9407..a348ffcd 100644 --- a/lpm_kernel/train/trainprocess_service.py +++ b/lpm_kernel/train/trainprocess_service.py @@ -30,10 +30,10 @@ from lpm_kernel.api.domains.trainprocess.process_step import ProcessStep from lpm_kernel.api.domains.trainprocess.progress_holder import TrainProgressHolder from lpm_kernel.train.training_params_manager import TrainingParamsManager +from lpm_kernel.common.repository.database_session import DatabaseSession +from lpm_kernel.api.domains.kernel.routes import store_l1_data import gc import subprocess -import shlex - from lpm_kernel.configs.logging import get_train_process_logger, TRAIN_LOG_FILE logger = get_train_process_logger() @@ -216,12 +216,16 @@ def generate_biography(self) -> bool: # Mark step as in progress self.progress.mark_step_status(ProcessStep.GENERATE_BIOGRAPHY, Status.IN_PROGRESS) logger.info("Starting biography generation...") - + # Generate L1 data and biography logger.info("Generating L1 data and biography...") - generate_l1_from_l0() + l1_data = generate_l1_from_l0() logger.info("Successfully generated L1 data and biography") - + + # Store L1 data + with DatabaseSession.session() as session: + store_l1_data(session, l1_data) + # Mark step as completed self.progress.mark_step_status(ProcessStep.GENERATE_BIOGRAPHY, Status.COMPLETED) logger.info("Biography generation completed successfully") From ce9cfcb4a8ab193adff64a8f475112af7962d5d4 Mon Sep 17 00:00:00 2001 From: KKKKKKKevin <115385420+kevin-mindverse@users.noreply.github.com> Date: Wed, 23 Apr 2025 20:47:41 +0800 Subject: [PATCH 02/21] Feature/fix update instace (#272) * fix password update logic, if there's more than one load * update fix --- lpm_kernel/api/domains/loads/load_service.py | 44 ++++++++++++-------- lpm_kernel/api/domains/upload/routes.py | 2 +- 2 files changed, 27 insertions(+), 19 deletions(-) diff --git a/lpm_kernel/api/domains/loads/load_service.py b/lpm_kernel/api/domains/loads/load_service.py index 593eaf09..463de918 100644 --- a/lpm_kernel/api/domains/loads/load_service.py +++ b/lpm_kernel/api/domains/loads/load_service.py @@ -60,10 +60,10 @@ def create_load(name: str, description: Optional[str] = None, email: str = "", i """ try: with DatabaseSession.session() as session: - # Check if a load record with the same name exists - existing_load = session.query(Load).filter(Load.name == name).first() + # Check if any load record exists + existing_load = session.query(Load).first() if existing_load: - return None, f"A load record with name '{name}' already exists", 400 + return None, f"A load record already exists. 
Only one load record is allowed in the system.", 400 # Create a new Load instance new_load = Load( @@ -106,13 +106,9 @@ def update_load(load_id: str, data: Dict[str, Any]) -> Tuple[Optional[Dict[str, return None, "Load record not found", 404 # Update fields - updatable_fields = ["name", "description", "email", "avatar_data", "instance_id", "status"] + updatable_fields = ["name", "description", "email", "avatar_data"] for field in updatable_fields: if field in data: - # Special validation for status field - if field == "status" and data[field] not in ["active", "inactive", "deleted"]: - return None, f"Status must be one of 'active', 'inactive', or 'deleted'", 400 - setattr(load, field, data[field]) session.commit() @@ -144,7 +140,7 @@ def update_current_load(data: Dict[str, Any]) -> Tuple[Optional[Dict[str, Any]], return LoadService.update_load(current_load.id, data) @staticmethod - def update_instance_id(instance_id: str, instance_password: str = None) -> Tuple[bool, Optional[str], int]: + def update_instance_credentials(instance_id: str, instance_password: str) -> Tuple[bool, Optional[str], int]: """ Update the instance_id and instance_password of the current load @@ -158,22 +154,34 @@ def update_instance_id(instance_id: str, instance_password: str = None) -> Tuple - Error message or None if successful - Status code (200 for success, 400/404/500 for errors) """ - try: - logger.info(f"Updating instance_id to {instance_id} and instance_password to {instance_password}") + try: + logger.info(f"Updating instance credentials: ID={instance_id}, Password={'*****' if instance_password else None}") + + # Get current load + current_load, error, status_code = LoadService.get_current_load() + if error: + return False, error, status_code + + if not current_load: + logger.warning("Load record not found") + return False, "Load record not found", 404 + + # Update fields in database with DatabaseSession.session() as session: - load = session.query(Load).first() + load = session.query(Load).get(current_load.id) if not load: - logger.warning("Load record not found") - return False, "Load record not found", 404 + logger.warning("Load record not found in database") + return False, "Load record not found in database", 404 load.instance_id = instance_id - if instance_password: - load.instance_password = instance_password + load.instance_password = instance_password + session.commit() - logger.info(f"Updated instance_id in database to: {instance_id}, instance_password: {instance_password}") + + logger.info(f"Updated instance credentials successfully") return True, None, 200 except Exception as e: - logger.error(f"Error updating instance_id: {str(e)}", exc_info=True) + logger.error(f"Error updating instance credentials: {str(e)}", exc_info=True) return False, f"Internal server error: {str(e)}", 500 @staticmethod diff --git a/lpm_kernel/api/domains/upload/routes.py b/lpm_kernel/api/domains/upload/routes.py index 6efdbe7d..080e01bb 100644 --- a/lpm_kernel/api/domains/upload/routes.py +++ b/lpm_kernel/api/domains/upload/routes.py @@ -39,7 +39,7 @@ def register_upload(): )) instance_password = result.get("instance_password") - LoadService.update_instance_id(instance_id_new, instance_password) + LoadService.update_instance_credentials(instance_id_new, instance_password) return jsonify(APIResponse.success( data=result From fd64b4e5da14e206066a78db070358a1ca3d3387 Mon Sep 17 00:00:00 2001 From: ryangyuan Date: Thu, 24 Apr 2025 11:02:52 +0800 Subject: [PATCH 03/21] fix: fetch uploadInfo in homepage (#271) --- 
lpm_frontend/src/app/home/page.tsx | 23 +++++++++++++++++----- lpm_frontend/src/store/useLoadInfoStore.ts | 22 ++++++++------------- 2 files changed, 26 insertions(+), 19 deletions(-) diff --git a/lpm_frontend/src/app/home/page.tsx b/lpm_frontend/src/app/home/page.tsx index de573415..8473f25a 100644 --- a/lpm_frontend/src/app/home/page.tsx +++ b/lpm_frontend/src/app/home/page.tsx @@ -4,11 +4,11 @@ import { useState, useEffect } from 'react'; import { useRouter } from 'next/navigation'; import CreateSecondMe from '@/app/home/components/Create'; import dynamic from 'next/dynamic'; -import { getUploadCount } from '@/service/info'; +import type { ILoadInfo } from '@/service/info'; +import { getCurrentInfo, getUploadCount } from '@/service/info'; import { ROUTER_PATH } from '@/utils/router'; import Footer from './components/Footer'; import SocialMedia from './components/SocialMedia'; -import { useLoadInfoStore } from '@/store/useLoadInfoStore'; import { message } from 'antd'; const NetworkSphere = dynamic(() => import('@/components/NetworkSphere'), { @@ -23,8 +23,21 @@ export default function Home() { const [isMounted, setIsMounted] = useState(false); const [contentVisible, setContentVisible] = useState(false); - const loadInfo = useLoadInfoStore((state) => state.loadInfo); - const firstLoaded = useLoadInfoStore((state) => state.firstLoaded); + const [loading, setLoading] = useState(true); + const [loadInfo, setLoadInfo] = useState(null); + + useEffect(() => { + getCurrentInfo() + .then((res) => { + if (res.data.code === 0) { + setLoadInfo(res.data.data); + localStorage.setItem('upload', JSON.stringify(res.data.data)); + } + }) + .finally(() => { + setLoading(false); + }); + }, []); useEffect(() => { setIsMounted(true); @@ -115,7 +128,7 @@ export default function Home() { - {firstLoaded && ( + {!loading && (
diff --git a/lpm_frontend/src/store/useLoadInfoStore.ts b/lpm_frontend/src/store/useLoadInfoStore.ts index cd9e01fd..17be5d79 100644 --- a/lpm_frontend/src/store/useLoadInfoStore.ts +++ b/lpm_frontend/src/store/useLoadInfoStore.ts @@ -4,7 +4,6 @@ import { EVENT } from '@/utils/event'; interface ILoadInfoState { loadInfo: ILoadInfo | null; - firstLoaded: boolean; fetchLoadInfo: () => void; setLoadInfo: (info: ILoadInfo | null) => void; clearLoadInfo: () => void; @@ -12,20 +11,15 @@ interface ILoadInfoState { export const useLoadInfoStore = create((set) => ({ loadInfo: null, - firstLoaded: false, fetchLoadInfo: () => { - getCurrentInfo() - .then((res) => { - if (res.data.code === 0) { - set({ loadInfo: res.data.data }); - localStorage.setItem('upload', JSON.stringify(res.data.data)); - } else if (res.data.code === 404) { - dispatchEvent(new Event(EVENT.LOGOUT)); - } - }) - .finally(() => { - set({ firstLoaded: true }); - }); + getCurrentInfo().then((res) => { + if (res.data.code === 0) { + set({ loadInfo: res.data.data }); + localStorage.setItem('upload', JSON.stringify(res.data.data)); + } else if (res.data.code === 404) { + dispatchEvent(new Event(EVENT.LOGOUT)); + } + }); }, setLoadInfo: (info) => set({ loadInfo: info }), clearLoadInfo: () => set({ loadInfo: null }) From 9fe511f0f2345cc3ca538d2c2b19c467b1b3355e Mon Sep 17 00:00:00 2001 From: ryangyuan Date: Thu, 24 Apr 2025 14:19:23 +0800 Subject: [PATCH 04/21] Feature/0416/add thinking mode (#264) * fix: modify thinking_model loading configuration * feat: realize thinkModel ui * feat:store * feat: add combined_llm_config_dto * add thinking_model_config & database migration * directly add thinking model to user_llm_config * delete thinking model repo dto service * delete thinkingmodel table migration * add is_cot config * feat: allow define is_cot * feat: simplify logs info * feat: add training model * feat: fix is_cot problem * fix: fix chat message * fix: fix progress error * fix: disable no settings thinking * feat: add thinking warning * fix: fix start service error * feat:fix init trainparams problem * feat: change playGround prompt * feat: Add Dimension Mismatch Handling for ChromaDB (#157) (#207) * Fix Issue #157 Add chroma_utils.py to manage chromaDB and added docs for explanation * Add logging and debugging process - Enhanced the`reinitialize_chroma_collections` function in`chroma_utils.py` to properly check if collections exist before attempting to delete them, preventing potential errors when collections don't exist. - Improved error handling in the`_handle_dimension_mismatch` method in`embedding_service.py` by adding more robust exception handling and verification steps after reinitialization. - Enhanced the collection initialization process in`embedding_service.py` to provide more detailed error messages and better handle cases where collections still have incorrect dimensions after reinitialization. - Added additional verification steps to ensure that collection dimensions match the expected dimension after creation or retrieval. - Improved logging throughout the code to provide more context in error messages, making debugging easier. 
* Change topics_generator timeout to 30 (#263) * quick fix * fix: shade -> shade_merge_info (#265) * fix: shade -> shade_merge_info * add convert array * quick fix import error * add log * add heartbeat * new strategy * sse version * add heartbeat * zh to en * optimize code * quick fix convert function * Feat/new branch management (#267) * feat: new branch management * feat: fix multi-upload * optimize contribute management --------- Co-authored-by: Crabboss Mr <1123357821@qq.com> Co-authored-by: Ye Xiangle Co-authored-by: Xinghan Pan Co-authored-by: doubleBlack2 <108928143+doubleBlack2@users.noreply.github.com> Co-authored-by: kevin-mindverse Co-authored-by: KKKKKKKevin <115385420+kevin-mindverse@users.noreply.github.com> --- docker/sqlite/init.sql | 5 + lpm_frontend/.eslintrc.js | 10 + .../app/dashboard/playground/chat/page.tsx | 66 ++- .../src/app/dashboard/train/training/page.tsx | 82 +-- .../src/app/standalone/role/[roleId]/page.tsx | 18 +- .../src/components/ModelStatus/index.tsx | 10 +- .../components/ThinkingModelModal/index.tsx | 126 ++++ .../src/components/modelConfigModal/index.tsx | 79 +-- lpm_frontend/src/components/svgs/stopIcon.tsx | 28 + .../train/TrainingConfiguration.tsx | 560 ++++++++++-------- .../src/components/train/TrainingLog.tsx | 10 +- lpm_frontend/src/service/modelConfig.ts | 18 +- lpm_frontend/src/service/train.ts | 8 +- lpm_frontend/src/store/useModelConfigStore.ts | 84 ++- lpm_frontend/src/store/useTrainingStore.ts | 4 +- lpm_kernel/L2/.env | 4 - .../diversity/diversity_data_generator.py | 17 +- .../preference/preference_QA_generate.py | 17 +- .../data_prep/selfqa/selfqa_generator.py | 17 +- lpm_kernel/L2/train_for_user.sh | 7 +- lpm_kernel/api/domains/trainprocess/routes.py | 9 +- .../api/domains/user_llm_config/routes.py | 68 +++ lpm_kernel/api/dto/user_llm_config_dto.py | 15 +- lpm_kernel/api/models/user_llm_config.py | 11 +- lpm_kernel/api/services/local_llm_service.py | 187 +++++- .../api/services/user_llm_config_service.py | 2 +- lpm_kernel/app.py | 2 + lpm_kernel/database/__init__.py | 1 + lpm_kernel/database/migration_manager.py | 374 ++++++++++++ ...add_thinking_fields_to_user_llm_configs.py | 92 +++ lpm_kernel/database/migrations/__init__.py | 1 + lpm_kernel/train/training_params_manager.py | 3 +- lpm_kernel/train/trainprocess_service.py | 18 +- scripts/downgrade_migration.py | 64 ++ scripts/run_migrations.py | 88 +++ scripts/start_local.sh | 4 + 36 files changed, 1639 insertions(+), 470 deletions(-) create mode 100644 lpm_frontend/src/components/ThinkingModelModal/index.tsx create mode 100644 lpm_frontend/src/components/svgs/stopIcon.tsx delete mode 100644 lpm_kernel/L2/.env create mode 100644 lpm_kernel/database/__init__.py create mode 100644 lpm_kernel/database/migration_manager.py create mode 100644 lpm_kernel/database/migrations/V20250420221300__add_thinking_fields_to_user_llm_configs.py create mode 100644 lpm_kernel/database/migrations/__init__.py create mode 100644 scripts/downgrade_migration.py create mode 100644 scripts/run_migrations.py diff --git a/docker/sqlite/init.sql b/docker/sqlite/init.sql index d0639c25..cd3a5233 100644 --- a/docker/sqlite/init.sql +++ b/docker/sqlite/init.sql @@ -188,6 +188,11 @@ CREATE TABLE IF NOT EXISTS user_llm_configs ( embedding_api_key VARCHAR(200), embedding_model_name VARCHAR(200), + -- Thinking configuration + thinking_model_name VARCHAR(200), + thinking_endpoint VARCHAR(200), + thinking_api_key VARCHAR(200), + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP 
); diff --git a/lpm_frontend/.eslintrc.js b/lpm_frontend/.eslintrc.js index 4583f90f..40e72cc9 100644 --- a/lpm_frontend/.eslintrc.js +++ b/lpm_frontend/.eslintrc.js @@ -17,5 +17,15 @@ module.exports = { ], rules: { '@next/next/no-sync-scripts': 'off' + }, + settings: { + 'import/resolver': { + typescript: { + project: 'Second-Me/lpm_frontend/tsconfig.json' + }, + node: { + extensions: ['.js', '.jsx', '.ts', '.tsx'] + } + } } }; diff --git a/lpm_frontend/src/app/dashboard/playground/chat/page.tsx b/lpm_frontend/src/app/dashboard/playground/chat/page.tsx index 815025ef..921f9d53 100644 --- a/lpm_frontend/src/app/dashboard/playground/chat/page.tsx +++ b/lpm_frontend/src/app/dashboard/playground/chat/page.tsx @@ -13,10 +13,13 @@ import { import type { ChatRequest } from '@/hooks/useSSE'; import { useSSE } from '@/hooks/useSSE'; import { useLoadInfoStore } from '@/store/useLoadInfoStore'; +import { getTrainingParams } from '@/service/train'; // Use the Message type directly from storage type Message = StorageMessage; +type ModelType = 'chat' | 'thinking'; + interface PlaygroundSettings { enableL0Retrieval: boolean; enableL1Retrieval: boolean; @@ -36,17 +39,44 @@ const generateMessageId = () => { }; export default function PlaygroundChat() { + const loadInfo = useLoadInfoStore((state) => state.loadInfo); + + const { sendStreamMessage, streaming, streamContent, stopSSE } = useSSE(); + const [sessions, setSessions] = useState([]); const [activeSessionId, setActiveSessionId] = useState(null); const [messages, setMessages] = useState([]); - const { sendStreamMessage, streaming, streamContent, stopSSE } = useSSE(); + const [modelType, setModelType] = useState(undefined); - const loadInfo = useLoadInfoStore((state) => state.loadInfo); const originPrompt = useMemo(() => { const name = loadInfo?.name || 'user'; - return `You are ${name}'s "Second Me", which is a personalized AI created by ${name}. You can help ${name} answer questions based on your understanding of ${name}'s background information and past records.`; - }, [loadInfo]); + if (modelType === 'chat') { + return `You are ${name}'s "Second Me", which is a personalized AI created by ${name}. You can help ${name} answer questions based on your understanding of ${name}'s background information and past records.`; + } + + if (modelType === 'thinking') { + return `You are ${name}'s "Second Me", and you are currently in conversation with ${name}. + Your task is to help ${name} answer relevant questions based on your understanding of ${name}'s background information and past records. + Please ensure your answers meet ${name}'s needs and provide precise solutions based on their historical information and personal preferences. + + When thinking, please follow these steps and output the results clearly in order: + 1. Consider the connection between questions and background: Review ${name}'s past records and personal information, analyzing the connections between their questions and these records. + 2. Derive answers to questions: Based on ${name}'s historical data and specific question content, conduct reasoning and analysis to ensure accuracy and relevance of answers. + 3. Generate high-quality responses: Distill answers that best meet ${name}'s needs and present them systematically with high information density. 
+ + Your output format must follow the following structure: + + + As "Second Me"'s thinking process, analyze the relationships between ${name}'s background information, historical records and the questions raised, deriving reasonable solution approaches. + + + This is the final answer for ${name}, ensuring the response is precise and meets their needs, while being systematic and information-dense. + `; + } + + return ''; + }, [loadInfo, modelType]); const originSettings = useMemo(() => { return { enableL0Retrieval: true, @@ -72,6 +102,23 @@ export default function PlaygroundChat() { }); }, [originPrompt]); + useEffect(() => { + getTrainingParams() + .then((res) => { + if (res.data.code === 0) { + const data = res.data.data; + + localStorage.setItem('trainingParams', JSON.stringify(data)); + setModelType(data.is_cot ? 'thinking' : 'chat'); + } else { + throw new Error(res.data.message); + } + }) + .catch((error) => { + console.error(error.message); + }); + }, []); + const scrollToBottom = () => { if (messagesEndRef.current) { const container = messagesEndRef.current.parentElement; @@ -173,7 +220,7 @@ export default function PlaygroundChat() { }; // Update message list, adding user message and empty assistant message - let newMessages = [...messages, userMessage, assistantMessage]; + let newMessages = [...messages, userMessage]; const systemMessage: Message = { id: generateMessageId(), @@ -186,21 +233,22 @@ export default function PlaygroundChat() { }; if (!newMessages.find((item) => item.role === 'system')) { - newMessages = [systemMessage, ...newMessages] + newMessages = [systemMessage, ...newMessages]; } else { newMessages = newMessages.map((msg) => { if (msg.role === 'system') { return { ...msg, content: originPrompt }; } + return msg; }); } - setMessages(newMessages); + setMessages([...newMessages, assistantMessage]); // Save messages to session if (activeSessionId) { - chatStorage.saveSessionMessages(activeSessionId, newMessages); + chatStorage.saveSessionMessages(activeSessionId, [...newMessages, assistantMessage]); // If it's the first message in a new session, update session title if (messages.length === 0) { @@ -306,8 +354,8 @@ export default function PlaygroundChat() { index === messages.length - 1 && message.role === 'assistant' } - role={message.role} message={message.content} + role={message.role} timestamp={message.timestamp} /> ))} diff --git a/lpm_frontend/src/app/dashboard/train/training/page.tsx b/lpm_frontend/src/app/dashboard/train/training/page.tsx index c201c94a..7938cae5 100644 --- a/lpm_frontend/src/app/dashboard/train/training/page.tsx +++ b/lpm_frontend/src/app/dashboard/train/training/page.tsx @@ -10,7 +10,6 @@ import { getMemoryList } from '@/service/memory'; import { message, Modal } from 'antd'; import { useModelConfigStore } from '@/store/useModelConfigStore'; import CelebrationEffect from '@/components/Celebration'; -import { getModelConfig } from '@/service/modelConfig'; import TrainingLog from '@/components/train/TrainingLog'; import TrainingProgress from '@/components/train/TrainingProgress'; import TrainingConfiguration from '@/components/train/TrainingConfiguration'; @@ -60,28 +59,40 @@ const baseModelOptions = [ } ]; +// Title and explanation section +const pageTitle = 'Training Process'; +const pageDescription = + 'Transform your memories into a personalized AI model that thinks and communicates like you.'; + export default function TrainingPage() { - // Title and explanation section - const pageTitle = 'Training Process'; - const pageDescription = - 
'Transform your memories into a personalized AI model that thinks and communicates like you.'; + const checkTrainStatus = useTrainingStore((state) => state.checkTrainStatus); + const resetTrainingState = useTrainingStore((state) => state.resetTrainingState); + const trainingError = useTrainingStore((state) => state.error); + const setStatus = useTrainingStore((state) => state.setStatus); + const fetchModelConfig = useModelConfigStore((state) => state.fetchModelConfig); + const modelConfig = useModelConfigStore((store) => store.modelConfig); + const status = useTrainingStore((state) => state.status); + const trainingProgress = useTrainingStore((state) => state.trainingProgress); + + const router = useRouter(); const [selectedInfo, setSelectedInfo] = useState(false); const [isTraining, setIsTraining] = useState(false); const [trainingParams, setTrainingParams] = useState({} as TrainingConfig); const [nowTrainingParams, setNowTrainingParams] = useState(null); const [trainActionLoading, setTrainActionLoading] = useState(false); - - const containerRef = useRef(null); - const firstLoadRef = useRef(true); const [showCelebration, setShowCelebration] = useState(false); const [showMemoryModal, setShowMemoryModal] = useState(false); - const modelConfig = useModelConfigStore((store) => store.modelConfig); - const updateModelConfig = useModelConfigStore((store) => store.updateModelConfig); + const [changeBaseModel, setChangeBaseModel] = useState(false); const cleanupEventSourceRef = useRef<(() => void) | undefined>(); + const containerRef = useRef(null); + const firstLoadRef = useRef(true); + const pollingStopRef = useRef(false); - const [changeBaseModel, setChangeBaseModel] = useState(false); + const [isResume, setIsResume] = useState( + trainingProgress.status === 'suspended' || trainingProgress.status === 'failed' + ); useEffect(() => { const localTrainingParams = JSON.parse(localStorage.getItem('trainingParams') || '{}'); @@ -90,30 +101,9 @@ export default function TrainingPage() { }, [trainingParams.model_name]); useEffect(() => { - getModelConfig().then((res) => { - if (res.data.code == 0) { - const data = res.data.data || {}; - - updateModelConfig(data); - } else { - message.error(res.data.message); - } - }); + fetchModelConfig(); }, []); - const pollingStopRef = useRef(false); - const router = useRouter(); - - const status = useTrainingStore((state) => state.status); - const trainingProgress = useTrainingStore((state) => state.trainingProgress); - const [isResume, setIsResume] = useState( - trainingProgress.status === 'suspended' || trainingProgress.status === 'failed' - ); - const checkTrainStatus = useTrainingStore((state) => state.checkTrainStatus); - const resetTrainingState = useTrainingStore((state) => state.resetTrainingState); - const trainingError = useTrainingStore((state) => state.error); - const setStatus = useTrainingStore((state) => state.setStatus); - // Start polling training progress const startPolling = () => { if (pollingStopRef.current) { @@ -171,13 +161,6 @@ export default function TrainingPage() { } }, [status, trainingError]); - // Monitor training status changes, scroll to bottom when status becomes 'training' - useEffect(() => { - if (status === 'training') { - scrollToBottom(); - } - }, [status]); - // Check training status once when component loads useEffect(() => { // Check if user has at least 3 memories @@ -201,14 +184,6 @@ export default function TrainingPage() { // Only proceed with training status check if memory check passes checkTrainStatus(); - - // Check if 
we were in the middle of retraining - const isRetraining = localStorage.getItem('isRetraining') === 'true'; - - if (isRetraining) { - // If we were retraining, set status to training - startGetTrainingProgress(); - } }; checkMemoryCount(); @@ -225,7 +200,9 @@ export default function TrainingPage() { if (firstLoadRef.current) { scrollPageToBottom(); - scrollToBottom(); + + // On first load, start polling and get training progress. + startGetTrainingProgress(); } } // If training is completed or failed, stop polling @@ -294,11 +271,6 @@ export default function TrainingPage() { firstLoadRef.current = false; }; - const scrollToBottom = () => { - // This function is kept for backward compatibility - // The actual scrolling is now handled by the TrainingLog component - }; - const updateTrainingParams = (params: TrainingConfig) => { setTrainingParams((state: TrainingConfig) => ({ ...state, ...params })); }; @@ -346,7 +318,6 @@ export default function TrainingPage() { eventSource.onerror = (error) => { console.error('EventSource failed:', error); eventSource.close(); - message.error('Failed to get training logs'); }; return () => { @@ -411,7 +382,6 @@ export default function TrainingPage() { resetTrainingState(); try { - // updateTrainLog(); setNowTrainingParams(trainingParams); console.log('Using startTrain API to train new model:', trainingParams.model_name); diff --git a/lpm_frontend/src/app/standalone/role/[roleId]/page.tsx b/lpm_frontend/src/app/standalone/role/[roleId]/page.tsx index f6c47447..9a2d1f71 100644 --- a/lpm_frontend/src/app/standalone/role/[roleId]/page.tsx +++ b/lpm_frontend/src/app/standalone/role/[roleId]/page.tsx @@ -57,7 +57,7 @@ export default function RoleChat() { const storedMessages = roleplayChatStorage.getMessages(role_id); const systemMessage: IChatMessage = { id: generateMessageId(), - content: role?.system_prompt ||'', + content: role?.system_prompt || '', role: 'system', timestamp: new Date().toLocaleTimeString([], { hour: '2-digit', @@ -123,11 +123,11 @@ export default function RoleChat() { }; // Update message list, add user message and empty assistant message - let newMessages = [...messages, userMessage, assistantMessage]; + let newMessages = [...messages, userMessage]; const systemMessage: IChatMessage = { id: generateMessageId(), - content: role?.system_prompt ||'', + content: role.system_prompt || '', role: 'system', timestamp: new Date().toLocaleTimeString([], { hour: '2-digit', @@ -136,20 +136,20 @@ export default function RoleChat() { }; if (!newMessages.find((item) => item.role === 'system')) { - newMessages = [systemMessage, ...newMessages] + newMessages = [systemMessage, ...newMessages]; } else { newMessages = newMessages.map((msg) => { if (msg.role === 'system') { - return { ...msg, content: role?.system_prompt ||'' }; + return { ...msg, content: role.system_prompt || '' }; } + return msg; }); } - - setMessages(newMessages); + setMessages([...newMessages, assistantMessage]); // Save messages - roleplayChatStorage.saveMessages(role_id, newMessages); + roleplayChatStorage.saveMessages(role_id, [...newMessages, assistantMessage]); // Send request const chatRequest: ChatRequest = { @@ -263,8 +263,8 @@ export default function RoleChat() { index === messages.length - 1 && msg.role === 'assistant' } - role={msg.role} message={msg.content} + role={msg.role} timestamp={msg.timestamp} /> ))} diff --git a/lpm_frontend/src/components/ModelStatus/index.tsx b/lpm_frontend/src/components/ModelStatus/index.tsx index 3804e105..28fde187 100644 --- 
a/lpm_frontend/src/components/ModelStatus/index.tsx +++ b/lpm_frontend/src/components/ModelStatus/index.tsx @@ -136,7 +136,13 @@ export function ModelStatus() { }; const handleServiceAction = () => { - const config = JSON.parse(localStorage.getItem('trainingConfig') || '{}'); + const config = JSON.parse(localStorage.getItem('trainingParams') || '{}'); + + if (!config.model_name) { + message.error('Please train a base model first'); + + return; + } if (status === 'running') { setServiceStopping(true); @@ -160,7 +166,7 @@ export function ModelStatus() { }); } else { setServiceStarting(true); - startService({ model_name: config.baseModel || 'Qwen2.5-0.5B-Instruct' }) + startService({ model_name: config.model_name }) .then((res) => { if (res.data.code === 0) { messageApi.success({ content: 'Service starting...', duration: 1 }); diff --git a/lpm_frontend/src/components/ThinkingModelModal/index.tsx b/lpm_frontend/src/components/ThinkingModelModal/index.tsx new file mode 100644 index 00000000..dd5f2cdc --- /dev/null +++ b/lpm_frontend/src/components/ThinkingModelModal/index.tsx @@ -0,0 +1,126 @@ +import type { IThinkingModelParams } from '@/service/modelConfig'; +import { updateThinkingConfig } from '@/service/modelConfig'; +import { useModelConfigStore } from '@/store/useModelConfigStore'; +import { Input, message, Modal } from 'antd'; +import { useEffect, useState } from 'react'; + +interface IProps { + open: boolean; + onClose: () => void; +} + +const ThinkingModelModal = (props: IProps) => { + const { open, onClose: handleCancel } = props; + + const fetchModelConfig = useModelConfigStore((store) => store.fetchModelConfig); + const [thinkingModelParams, setThinkingModelParams] = useState( + {} as IThinkingModelParams + ); + const updateThinkingModelConfig = useModelConfigStore((store) => store.updateThinkingModelConfig); + const thinkingModelConfig = useModelConfigStore((store) => store.thinkingModelConfig); + + useEffect(() => { + if (open) { + fetchModelConfig(); + } + }, [open]); + + useEffect(() => { + setThinkingModelParams(thinkingModelConfig); + }, [thinkingModelConfig]); + + const handleUpdate = () => { + const thinkingConfigComplete = + !!thinkingModelParams.thinking_model_name && + !!thinkingModelParams.thinking_api_key && + !!thinkingModelParams.thinking_endpoint; + + if (!thinkingConfigComplete) { + message.error('Please fill in all thinking model configuration fields'); + + return; + } + + updateThinkingConfig(thinkingModelParams) + .then((res) => { + if (res.data.code == 0) { + updateThinkingModelConfig(thinkingModelParams); + handleCancel(); + } else { + throw new Error(res.data.message); + } + }) + .catch((error) => { + console.error(error.message || 'Failed to update model config'); + }); + }; + + return ( + { + handleUpdate(); + }} + open={open} + > +
+      <div>
+        <div>Thinking model</div>
+        <div>Currently only supports DeepSeek</div>
+        <div>
+          <Input
+            onChange={(e) =>
+              setThinkingModelParams({
+                ...thinkingModelParams,
+                thinking_model_name: e.target.value
+              })
+            }
+            value={thinkingModelParams.thinking_model_name}
+          />
+        </div>
+        {/* form is to disable autoComplete */}
+        <form>
+          <Input
+            onChange={(e) =>
+              setThinkingModelParams({
+                ...thinkingModelParams,
+                thinking_api_key: e.target.value
+              })
+            }
+            value={thinkingModelParams.thinking_api_key}
+          />
+        </form>
+        <div>
+          <Input
+            onChange={(e) =>
+              setThinkingModelParams({
+                ...thinkingModelParams,
+                thinking_endpoint: e.target.value
+              })
+            }
+            value={thinkingModelParams.thinking_endpoint}
+          />
+        </div>
+      </div>
+    </Modal>
+ ); +}; + +export default ThinkingModelModal; diff --git a/lpm_frontend/src/components/modelConfigModal/index.tsx b/lpm_frontend/src/components/modelConfigModal/index.tsx index 78276599..f03015ba 100644 --- a/lpm_frontend/src/components/modelConfigModal/index.tsx +++ b/lpm_frontend/src/components/modelConfigModal/index.tsx @@ -1,6 +1,6 @@ -import { getModelConfig, updateModelConfig } from '@/service/modelConfig'; -import { useModelConfigStore } from '@/store/useModelConfigStore'; -import { Input, message, Modal, Radio } from 'antd'; +import { updateModelConfig } from '../../service/modelConfig'; +import { useModelConfigStore } from '../../store/useModelConfigStore'; +import { Input, Modal, Radio } from 'antd'; import Image from 'next/image'; import { useCallback, useEffect, useState } from 'react'; import { QuestionCircleOutlined } from '@ant-design/icons'; @@ -28,22 +28,21 @@ const options = [ const ModelConfigModal = (props: IProps) => { const { open, onClose } = props; const modelConfig = useModelConfigStore((store) => store.modelConfig); - const updateLocalModelConfig = useModelConfigStore((store) => store.updateModelConfig); - + const baseModelConfig = useModelConfigStore((store) => store.baseModelConfig); + const updateBaseModelConfig = useModelConfigStore((store) => store.updateBaseModelConfig); + const fetchModelConfig = useModelConfigStore((store) => store.fetchModelConfig); + const localProviderType = useModelConfigStore((store) => store.modelConfig.provider_type); const [modelType, setModelType] = useState(''); useEffect(() => { - getModelConfig().then((res) => { - if (res.data.code == 0) { - const data = res.data.data || {}; + if (open) { + fetchModelConfig(); + } + }, [open]); - updateLocalModelConfig(data); - setModelType(data.provider_type); - } else { - message.error(res.data.message); - } - }); - }, []); + useEffect(() => { + setModelType(localProviderType); + }, [localProviderType]); const renderEmpty = () => { return ( @@ -69,10 +68,10 @@ const ModelConfigModal = (props: IProps) => { { - updateLocalModelConfig({ ...modelConfig, key: e.target.value }); + updateBaseModelConfig({ ...baseModelConfig, key: e.target.value }); }} placeholder="Enter your OpenAI API key" - value={modelConfig.key} + value={baseModelConfig.key} />
You can get your API key from{' '} @@ -89,7 +88,7 @@ const ModelConfigModal = (props: IProps) => {
); - }, [modelConfig]); + }, [baseModelConfig]); const renderCustom = useCallback(() => { return ( @@ -106,10 +105,10 @@ const ModelConfigModal = (props: IProps) => { className="w-full" data-form-type="other" onChange={(e) => { - updateLocalModelConfig({ ...modelConfig, chat_model_name: e.target.value }); + updateBaseModelConfig({ ...baseModelConfig, chat_model_name: e.target.value }); }} spellCheck="false" - value={modelConfig.chat_model_name} + value={baseModelConfig.chat_model_name} /> @@ -122,10 +121,10 @@ const ModelConfigModal = (props: IProps) => { className="w-full" data-form-type="other" onChange={(e) => { - updateLocalModelConfig({ ...modelConfig, chat_api_key: e.target.value }); + updateBaseModelConfig({ ...baseModelConfig, chat_api_key: e.target.value }); }} spellCheck="false" - value={modelConfig.chat_api_key} + value={baseModelConfig.chat_api_key} /> @@ -136,9 +135,9 @@ const ModelConfigModal = (props: IProps) => { autoComplete="off" className="w-full" onChange={(e) => { - updateLocalModelConfig({ ...modelConfig, chat_endpoint: e.target.value }); + updateBaseModelConfig({ ...baseModelConfig, chat_endpoint: e.target.value }); }} - value={modelConfig.chat_endpoint} + value={baseModelConfig.chat_endpoint} /> @@ -151,9 +150,12 @@ const ModelConfigModal = (props: IProps) => { { - updateLocalModelConfig({ ...modelConfig, embedding_model_name: e.target.value }); + updateBaseModelConfig({ + ...baseModelConfig, + embedding_model_name: e.target.value + }); }} - value={modelConfig.embedding_model_name} + value={baseModelConfig.embedding_model_name} /> @@ -162,9 +164,9 @@ const ModelConfigModal = (props: IProps) => { { - updateLocalModelConfig({ ...modelConfig, embedding_api_key: e.target.value }); + updateBaseModelConfig({ ...baseModelConfig, embedding_api_key: e.target.value }); }} - value={modelConfig.embedding_api_key} + value={baseModelConfig.embedding_api_key} /> @@ -174,31 +176,27 @@ const ModelConfigModal = (props: IProps) => { { - updateLocalModelConfig({ ...modelConfig, embedding_endpoint: e.target.value }); + updateBaseModelConfig({ ...baseModelConfig, embedding_endpoint: e.target.value }); }} - value={modelConfig.embedding_endpoint} + value={baseModelConfig.embedding_endpoint} /> ); - }, [modelConfig]); + }, [baseModelConfig, updateBaseModelConfig]); const handleUpdate = () => { - // When None is selected, save an empty provider_type instead of deleting the config - const providerType = modelType || ''; - - updateModelConfig({ ...modelConfig, provider_type: providerType }) + updateModelConfig(modelConfig) .then((res) => { if (res.data.code == 0) { - updateLocalModelConfig({ ...modelConfig, provider_type: providerType }); onClose(); } else { throw new Error(res.data.message); } }) - .catch((error: any) => { - message.error(error.message || 'Failed to update model config'); + .catch((error) => { + console.error(error.message || 'Failed to update model config'); }); }; @@ -246,7 +244,10 @@ const ModelConfigModal = (props: IProps) => {

setModelType(e.target.value)} + onChange={(e) => { + setModelType(e.target.value); + updateBaseModelConfig({ ...baseModelConfig, provider_type: e.target.value }); + }} optionType="button" options={options} value={modelType ? modelType : ''} diff --git a/lpm_frontend/src/components/svgs/stopIcon.tsx b/lpm_frontend/src/components/svgs/stopIcon.tsx new file mode 100644 index 00000000..50505c04 --- /dev/null +++ b/lpm_frontend/src/components/svgs/stopIcon.tsx @@ -0,0 +1,28 @@ +import classNames from 'classnames'; + +interface StopIconProps { + className?: string; +} + +const StopIcon = (props: StopIconProps) => { + const { className } = props; + + return ( + + + + ); +}; + +export default StopIcon; diff --git a/lpm_frontend/src/components/train/TrainingConfiguration.tsx b/lpm_frontend/src/components/train/TrainingConfiguration.tsx index 7e2e9aaf..699dd021 100644 --- a/lpm_frontend/src/components/train/TrainingConfiguration.tsx +++ b/lpm_frontend/src/components/train/TrainingConfiguration.tsx @@ -4,14 +4,17 @@ import type React from 'react'; import { Fragment, useEffect, useMemo, useState } from 'react'; import { Listbox, Transition } from '@headlessui/react'; import { PlayIcon, StopIcon } from '@heroicons/react/24/outline'; -import { EVENT } from '@/utils/event'; -import { InputNumber, Radio, Spin, Tooltip } from 'antd'; +import { EVENT } from '../../utils/event'; +import { Checkbox, InputNumber, Radio, Spin, Tooltip } from 'antd'; import type { TrainingConfig } from '@/service/train'; import { QuestionCircleOutlined } from '@ant-design/icons'; import OpenAiModelIcon from '../svgs/OpenAiModelIcon'; import CustomModelIcon from '../svgs/CustomModelIcon'; import ColumnArrowIcon from '../svgs/ColumnArrowIcon'; import DoneIcon from '../svgs/DoneIcon'; +import ThinkingModelModal from '../ThinkingModelModal'; +import { useModelConfigStore } from '@/store/useModelConfigStore'; +import classNames from 'classnames'; interface BaseModelOption { value: string; @@ -61,6 +64,17 @@ const TrainingConfiguration: React.FC = ({ setSelectedInfo }) => { const [disabledChangeParams, setDisabledChangeParams] = useState(false); + const [openThinkingModel, setOpenThinkingModel] = useState(false); + const [showThinkingWarning, setShowThinkingWarning] = useState(false); + const thinkingModelConfig = useModelConfigStore((state) => state.thinkingModelConfig); + + const thinkingConfigComplete = useMemo(() => { + return ( + !!thinkingModelConfig.thinking_model_name && + !!thinkingModelConfig.thinking_api_key && + !!thinkingModelConfig.thinking_endpoint + ); + }, [thinkingModelConfig]); const trainButtonText = useMemo(() => { return isTraining @@ -116,272 +130,311 @@ const TrainingConfiguration: React.FC = ({

-
-
-
-

- Step 1: Choose Support Model for Data Synthesis -

- {!modelConfig?.provider_type ? ( -
-
- - -
- - Model used for processing and synthesizing your memory data - +
+
+

+ Step 1: Choose Support Model for Data Synthesis +

+ {!modelConfig?.provider_type ? ( +
+
+ +
- ) : ( -
-
{ + + Model used for processing and synthesizing your memory data + +
+ ) : ( +
+
{ + window.dispatchEvent(new CustomEvent(EVENT.SHOW_MODEL_CONFIG_MODAL)); + }} + > + Model Used :   + {modelConfig.provider_type === 'openai' ? ( + + ) : ( + + )} + + {modelConfig.provider_type === 'openai' ? 'OpenAI' : 'Custom Model'} + + -
- - Model used for processing and synthesizing your memory data + Configure Model for Data Synthesis + +
+ + Model used for processing and synthesizing your memory data + +
+ )} +
+
Data Synthesis Mode
+ + updateTrainingParams({ + ...trainingParams, + data_synthesis_mode: e.target.value + }) + } + optionType="button" + options={synthesisModeOptions} + value={ + disabledChangeParams && nowTrainingParams && !changeBaseModel + ? nowTrainingParams.data_synthesis_mode + : trainingParams.data_synthesis_mode + } + /> + + + Low: Fast data synthesis. Medium: Balanced synthesis and speed. High: Rich + synthesis, slower speed. + +
+
+ +
+
+

+ Step 2: Choose Base Model for Training Second Me +

+ + Base model for training your Second Me. Choose based on your available system + resources. + +
+ updateTrainingParams({ model_name: value })} + value={trainingParams.model_name} + > +
+ + + {baseModelOptions.find((option) => option.value === trainingParams.model_name) + ?.label || 'Select a model...'} + + + + + + + {baseModelOptions.map((option) => ( + + `relative cursor-pointer select-none py-2 pl-10 pr-4 ${active ? 'bg-blue-100 text-blue-900' : 'text-gray-900'}` + } + value={option.value} + > + {({ selected }) => ( + <> + + {option.label} + + {selected ? ( + + + + ) : null} + + )} + + ))} + + +
+
+
+ +
+
+

+ Step 3: Configure Advanced Training Parameters +

+
+ Adjust these parameters to control training quality and performance. Recommended + settings will ensure stable training. +
+
+
+
+
+
Learning Rate
+ + +
- )} -
-
Data Synthesis Mode
- - updateTrainingParams({ - ...trainingParams, - data_synthesis_mode: e.target.value - }) + max={0.005} + min={0.00003} + onChange={(value) => { + if (value == null) { + return; + } + + updateTrainingParams({ ...trainingParams, learning_rate: value }); + }} + status={ + trainingParams.learning_rate == 0.005 || trainingParams.learning_rate == 0.00003 + ? 'warning' + : undefined } - optionType="button" - options={synthesisModeOptions} + step={0.0001} value={ disabledChangeParams && nowTrainingParams && !changeBaseModel - ? nowTrainingParams.data_synthesis_mode - : trainingParams.data_synthesis_mode + ? nowTrainingParams.learning_rate + : trainingParams.learning_rate } /> - - - Low: Fast data synthesis. Medium: Balanced synthesis and speed. High: Rich - synthesis, slower speed. - +
+ Enter a value between 0.00003 and 0.005 (recommended: 0.0001) +
-
+
+
+
Number of Epochs
+ + + +
+ { + if (value == null) { + return; + } -
-
-

- Step 2: Choose Base Model for Training Second Me -

- - Base model for training your Second Me. Choose based on your available system - resources. - + updateTrainingParams({ ...trainingParams, number_of_epochs: value }); + }} + status={ + trainingParams.number_of_epochs == 10 || trainingParams.number_of_epochs == 1 + ? 'warning' + : undefined + } + step={1} + value={ + disabledChangeParams && nowTrainingParams && !changeBaseModel + ? nowTrainingParams.number_of_epochs + : trainingParams.number_of_epochs + } + /> +
+ Enter an integer between 1 and 10 (recommended: 2) +
- updateTrainingParams({ model_name: value })} - value={trainingParams.model_name} - > -
- - - {baseModelOptions.find((option) => option.value === trainingParams.model_name) - ?.label || 'Select a model...'} - - - - - - - - {baseModelOptions.map((option) => ( - - `relative cursor-pointer select-none py-2 pl-10 pr-4 ${active ? 'bg-blue-100 text-blue-900' : 'text-gray-900'}` - } - value={option.value} - > - {({ selected }) => ( - <> - - {option.label} - - {selected ? ( - - - - ) : null} - - )} - - ))} - - +
+
+
Concurrency Threads
+ + +
- -
+ { + if (value == null) { + return; + } -
-
-

- Step 3: Configure Advanced Training Parameters -

+ updateTrainingParams({ ...trainingParams, concurrency_threads: value }); + }} + status={ + trainingParams.concurrency_threads == 10 || + trainingParams.concurrency_threads == 1 + ? 'warning' + : undefined + } + step={1} + value={ + disabledChangeParams && nowTrainingParams && !changeBaseModel + ? nowTrainingParams.concurrency_threads + : trainingParams.concurrency_threads + } + />
- Adjust these parameters to control training quality and performance. Recommended - settings will ensure stable training. + Enter an integer between 1 and 10 (recommended: 2)
-
-
-
-
Learning Rate
- - - -
- { - if (value == null) { - return; - } +
+
- updateTrainingParams({ ...trainingParams, learning_rate: value }); - }} - status={ - trainingParams.learning_rate == 0.005 || - trainingParams.learning_rate == 0.00003 - ? 'warning' - : undefined - } - step={0.0001} - value={ - disabledChangeParams && nowTrainingParams && !changeBaseModel - ? nowTrainingParams.learning_rate - : trainingParams.learning_rate - } - /> -
- Enter a value between 0.00003 and 0.005 (recommended: 0.0001) -
-
-
-
-
Number of Epochs
- - - -
- { - if (value == null) { - return; - } +
+
+ Step 4: Configure Advanced Behavior +
- updateTrainingParams({ ...trainingParams, number_of_epochs: value }); - }} - status={ - trainingParams.number_of_epochs == 10 || trainingParams.number_of_epochs == 1 - ? 'warning' - : undefined - } - step={1} - value={ - disabledChangeParams && nowTrainingParams && !changeBaseModel - ? nowTrainingParams.number_of_epochs - : trainingParams.number_of_epochs - } - /> -
- Enter an integer between 1 and 10 (recommended: 2) -
-
-
-
-
Concurrency Threads
- - - -
- { - if (value == null) { - return; - } +
+ { + e.stopPropagation(); - updateTrainingParams({ ...trainingParams, concurrency_threads: value }); - }} - status={ - trainingParams.concurrency_threads == 10 || - trainingParams.concurrency_threads == 1 - ? 'warning' - : undefined - } - step={1} - value={ - disabledChangeParams && nowTrainingParams && !changeBaseModel - ? nowTrainingParams.concurrency_threads - : trainingParams.concurrency_threads - } - /> -
- Enter an integer between 1 and 10 (recommended: 2) -
-
+ if (!thinkingConfigComplete) { + setShowThinkingWarning(true); + setTimeout(() => setShowThinkingWarning(false), 2000); + + return; + } + + updateTrainingParams({ ...trainingParams, is_cot: e.target.checked }); + }} + /> +
{ + setOpenThinkingModel(true); + }} + > + Thinking Model
@@ -390,20 +443,7 @@ const TrainingConfiguration: React.FC = ({
{isTraining && (
- - - + Full stop only when the current step is complete
)} @@ -429,6 +469,8 @@ const TrainingConfiguration: React.FC = ({
+ + setOpenThinkingModel(false)} open={openThinkingModel} />
); }; diff --git a/lpm_frontend/src/components/train/TrainingLog.tsx b/lpm_frontend/src/components/train/TrainingLog.tsx index 93ca2995..a6a600fd 100644 --- a/lpm_frontend/src/components/train/TrainingLog.tsx +++ b/lpm_frontend/src/components/train/TrainingLog.tsx @@ -15,7 +15,7 @@ const TrainingLog: React.FC = ({ trainingDetails }: TrainingLo // Smooth scroll console to bottom const smoothScrollConsole = () => { if (consoleEndRef.current && !isUserScrolling) { - const consoleContainer = consoleEndRef.current.closest('.overflow-y-auto'); + const consoleContainer = consoleEndRef.current; if (consoleContainer instanceof HTMLElement) { consoleContainer.scrollTo({ @@ -44,7 +44,7 @@ const TrainingLog: React.FC = ({ trainingDetails }: TrainingLo // Find the console container and attach the scroll listener if (consoleEndRef.current) { - const consoleContainer = consoleEndRef.current.closest('.overflow-y-auto'); + const consoleContainer = consoleEndRef.current; if (consoleContainer instanceof HTMLElement) { consoleContainer.addEventListener('scroll', handleUserScroll); @@ -70,7 +70,10 @@ const TrainingLog: React.FC = ({ trainingDetails }: TrainingLo return (

Training Log

-
+
{trainingDetails.length > 0 ? ( trainingDetails.map((detail, index) => ( @@ -83,7 +86,6 @@ const TrainingLog: React.FC = ({ trainingDetails }: TrainingLo No training logs available. Start training to see logs here.
)} -
diff --git a/lpm_frontend/src/service/modelConfig.ts b/lpm_frontend/src/service/modelConfig.ts index 78bf74fa..af1eab5c 100644 --- a/lpm_frontend/src/service/modelConfig.ts +++ b/lpm_frontend/src/service/modelConfig.ts @@ -1,7 +1,9 @@ import { Request } from '../utils/request'; import type { CommonResponse, EmptyResponse } from '../types/responseModal'; -export interface IModelConfig { +export type IModelConfig = IBaseModelParams & IThinkingModelParams; + +export interface IBaseModelParams { id: number; provider_type: string; key: string; @@ -15,6 +17,12 @@ export interface IModelConfig { updated_at: string; } +export interface IThinkingModelParams { + thinking_model_name: string; + thinking_api_key: string; + thinking_endpoint: string; +} + export const getModelConfig = () => { return Request>({ method: 'get', @@ -30,6 +38,14 @@ export const updateModelConfig = (data: IModelConfig) => { }); }; +export const updateThinkingConfig = (data: IThinkingModelParams) => { + return Request>({ + method: 'put', + url: `/api/user-llm-configs/thinking`, + data + }); +}; + export const deleteModelConfig = () => { return Request({ method: 'delete', diff --git a/lpm_frontend/src/service/train.ts b/lpm_frontend/src/service/train.ts index a62a384c..7ecf1dcb 100644 --- a/lpm_frontend/src/service/train.ts +++ b/lpm_frontend/src/service/train.ts @@ -56,6 +56,10 @@ export interface TrainProgress { status: StageStatus; } +export interface TrainAdvanceParams { + is_cot?: boolean; +} + export interface TrainingParams { concurrency_threads?: number; data_synthesis_mode?: string; @@ -63,10 +67,12 @@ export interface TrainingParams { number_of_epochs?: number; } -export interface TrainingConfig extends TrainingParams { +export interface TrainBaseParams { model_name: string; } +export type TrainingConfig = TrainingParams & TrainAdvanceParams & TrainBaseParams; + export const startTrain = (config: TrainingConfig) => { return Request>({ method: 'post', diff --git a/lpm_frontend/src/store/useModelConfigStore.ts b/lpm_frontend/src/store/useModelConfigStore.ts index 7520be30..e7084692 100644 --- a/lpm_frontend/src/store/useModelConfigStore.ts +++ b/lpm_frontend/src/store/useModelConfigStore.ts @@ -1,18 +1,96 @@ import { create } from 'zustand'; -import type { IModelConfig } from '@/service/modelConfig'; +import { + getModelConfig, + type IBaseModelParams, + type IModelConfig, + type IThinkingModelParams +} from '@/service/modelConfig'; interface ModelConfigState { modelConfig: IModelConfig; + baseModelConfig: IBaseModelParams; + thinkingModelConfig: IThinkingModelParams; + fetchModelConfig: () => Promise; updateModelConfig: (config: IModelConfig) => void; deleteModelConfig: () => void; + updateBaseModelConfig: (config: IBaseModelParams) => void; + updateThinkingModelConfig: (config: IThinkingModelParams) => void; } export const useModelConfigStore = create((set, get) => ({ modelConfig: {} as IModelConfig, + baseModelConfig: {} as IBaseModelParams, + thinkingModelConfig: {} as IThinkingModelParams, + fetchModelConfig: async () => { + return getModelConfig() + .then((res) => { + if (res.data.code !== 0) { + throw new Error(res.data.message); + } + + const { thinking_model_name, thinking_api_key, thinking_endpoint, ...baseModelConfig } = + res.data.data; + + set({ + modelConfig: { ...(get().modelConfig as IModelConfig), ...res.data.data }, + baseModelConfig: { ...(get().baseModelConfig as IBaseModelParams), ...baseModelConfig }, + thinkingModelConfig: { + ...(get().thinkingModelConfig as IThinkingModelParams), + 
thinking_model_name, + thinking_api_key, + thinking_endpoint + } + }); + }) + .catch((error) => { + console.error(error.message || 'Failed to fetch model config'); + }); + }, updateModelConfig(config: IModelConfig) { - set({ modelConfig: { ...(get().modelConfig as IModelConfig), ...config } }); + const { thinking_model_name, thinking_api_key, thinking_endpoint, ...baseModelConfig } = config; + + set({ + modelConfig: { ...(get().modelConfig as IModelConfig), ...config }, + baseModelConfig: { ...(get().baseModelConfig as IBaseModelParams), ...baseModelConfig }, + thinkingModelConfig: { + ...(get().thinkingModelConfig as IThinkingModelParams), + thinking_model_name, + thinking_api_key, + thinking_endpoint + } + }); }, deleteModelConfig() { - set({ modelConfig: {} as IModelConfig }); + set({ + modelConfig: {} as IModelConfig, + baseModelConfig: {} as IBaseModelParams, + thinkingModelConfig: {} as IThinkingModelParams + }); + }, + updateBaseModelConfig(config: IBaseModelParams) { + set({ + baseModelConfig: { ...(get().baseModelConfig as IBaseModelParams), ...config } + }); + + set({ + modelConfig: { + ...(get().modelConfig as IModelConfig), + ...(get().baseModelConfig as IBaseModelParams), + ...config + } + }); + }, + updateThinkingModelConfig(config: IThinkingModelParams) { + set({ + thinkingModelConfig: { ...(get().thinkingModelConfig as IThinkingModelParams), ...config } + }); + + set({ + modelConfig: { + ...(get().modelConfig as IModelConfig), + ...(get().thinkingModelConfig as IThinkingModelParams), + ...config + } + }); } })); diff --git a/lpm_frontend/src/store/useTrainingStore.ts b/lpm_frontend/src/store/useTrainingStore.ts index a80d7e0c..12b1ba63 100644 --- a/lpm_frontend/src/store/useTrainingStore.ts +++ b/lpm_frontend/src/store/useTrainingStore.ts @@ -90,13 +90,13 @@ export const useTrainingStore = create((set) => ({ setTrainingProgress: (progress) => set({ trainingProgress: progress }), resetTrainingState: () => set({ trainingProgress: defaultTrainingProgress }), checkTrainStatus: async () => { - const config = JSON.parse(localStorage.getItem('trainingConfig') || '{}'); + const config = JSON.parse(localStorage.getItem('trainingParams') || '{}'); set({ error: false }); try { const res = await getTrainProgress({ - model_name: config.baseModel || 'Qwen2.5-0.5B-Instruct' + model_name: config.model_name || 'Qwen2.5-0.5B-Instruct' }); if (res.data.code === 0) { diff --git a/lpm_kernel/L2/.env b/lpm_kernel/L2/.env deleted file mode 100644 index 1da08004..00000000 --- a/lpm_kernel/L2/.env +++ /dev/null @@ -1,4 +0,0 @@ -# L2 API_KEY -DEEPSEEK_MODEL_NAME=deepseek-r1-* -DEEPSEEK_API_KEY=your_api_key -DEEPSEEK_BASE_URL=your_base_url \ No newline at end of file diff --git a/lpm_kernel/L2/data_pipeline/data_prep/diversity/diversity_data_generator.py b/lpm_kernel/L2/data_pipeline/data_prep/diversity/diversity_data_generator.py index 9b410079..ccc76f72 100644 --- a/lpm_kernel/L2/data_pipeline/data_prep/diversity/diversity_data_generator.py +++ b/lpm_kernel/L2/data_pipeline/data_prep/diversity/diversity_data_generator.py @@ -10,7 +10,6 @@ import pandas as pd from tqdm import tqdm from enum import Enum -from dotenv import load_dotenv from lpm_kernel.api.services.user_llm_config_service import UserLLMConfigService from lpm_kernel.configs.config import Config from lpm_kernel.L2.data_pipeline.data_prep.diversity.utils import remove_similar_dicts @@ -70,19 +69,11 @@ def __init__(self, preference_language: str, is_cot: bool = True): self.is_cot = is_cot if self.is_cot: logger.info("generate diversity 
data in longcot pattern!!!") - self.env_path = os.path.join(os.getcwd(), "lpm_kernel/L2/.env") - if os.path.exists(self.env_path): - load_dotenv(self.env_path) - else: - raise FileNotFoundError(f"Config file not found: {self.env_path}") - self.model_name = os.getenv("DEEPSEEK_MODEL_NAME", "") - self.api_key = os.getenv("DEEPSEEK_API_KEY", "") - self.base_url = os.getenv("DEEPSEEK_BASE_URL", "") + self.model_name = user_llm_config.thinking_model_name + self.api_key = user_llm_config.thinking_api_key + self.base_url = user_llm_config.thinking_endpoint if self.model_name.startswith("deepseek"): - self.client = openai.OpenAI( - api_key=self.api_key, - base_url=self.base_url, - ) + self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url) else: logger.error(f"Error model_name, longcot data generating model_name: deepseek series") raise diff --git a/lpm_kernel/L2/data_pipeline/data_prep/preference/preference_QA_generate.py b/lpm_kernel/L2/data_pipeline/data_prep/preference/preference_QA_generate.py index 141f8d4f..961c4dc2 100644 --- a/lpm_kernel/L2/data_pipeline/data_prep/preference/preference_QA_generate.py +++ b/lpm_kernel/L2/data_pipeline/data_prep/preference/preference_QA_generate.py @@ -4,7 +4,6 @@ import os import random import re -from dotenv import load_dotenv from tqdm import tqdm import openai from enum import Enum @@ -74,19 +73,11 @@ def __init__(self, filename: str, bio: str, preference_language: str, is_cot: bo ) if self.is_cot: logger.info("generate pereference data in longcot pattern!!!") - self.env_path = os.path.join(os.getcwd(), "lpm_kernel/L2/.env") - if os.path.exists(self.env_path): - load_dotenv(self.env_path) - else: - raise FileNotFoundError(f"Config file not found: {self.env_path}") - self.model_name = os.getenv("DEEPSEEK_MODEL_NAME", "") - self.api_key = os.getenv("DEEPSEEK_API_KEY", "") - self.base_url = os.getenv("DEEPSEEK_BASE_URL", "") + self.model_name = user_llm_config.thinking_model_name + self.api_key = user_llm_config.thinking_api_key + self.base_url = user_llm_config.thinking_endpoint if self.model_name.startswith("deepseek"): - self.client = openai.OpenAI( - api_key=self.api_key, - base_url=self.base_url, - ) + self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url) else: logger.error(f"Error model_name, longcot data generating model_name: deepseek series") raise diff --git a/lpm_kernel/L2/data_pipeline/data_prep/selfqa/selfqa_generator.py b/lpm_kernel/L2/data_pipeline/data_prep/selfqa/selfqa_generator.py index 9f514382..8adc1f45 100644 --- a/lpm_kernel/L2/data_pipeline/data_prep/selfqa/selfqa_generator.py +++ b/lpm_kernel/L2/data_pipeline/data_prep/selfqa/selfqa_generator.py @@ -2,7 +2,6 @@ import traceback import os import random -from dotenv import load_dotenv import openai from tqdm import tqdm from enum import Enum @@ -84,19 +83,11 @@ def __init__( self.data_synthesis_mode = os.environ.get("DATA_SYNTHESIS_MODE", "low") if self.is_cot: logger.info("generate selfQA data in longcot pattern!!!") - self.env_path = os.path.join(os.getcwd(), "lpm_kernel/L2/.env") - if os.path.exists(self.env_path): - load_dotenv(self.env_path) - else: - raise FileNotFoundError(f"Config file not found: {self.env_path}") - self.model_name = os.getenv("DEEPSEEK_MODEL_NAME", "") - self.api_key = os.getenv("DEEPSEEK_API_KEY", "") - self.base_url = os.getenv("DEEPSEEK_BASE_URL", "") + self.model_name = user_llm_config.thinking_model_name + self.api_key = user_llm_config.thinking_api_key + self.base_url = user_llm_config.thinking_endpoint if 
self.model_name.startswith("deepseek"): - self.client = openai.OpenAI( - api_key=self.api_key, - base_url=self.base_url, - ) + self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url) else: logger.error(f"Error model_name, longcot data generating model_name: deepseek series") raise diff --git a/lpm_kernel/L2/train_for_user.sh b/lpm_kernel/L2/train_for_user.sh index 9877803a..92965a4f 100755 --- a/lpm_kernel/L2/train_for_user.sh +++ b/lpm_kernel/L2/train_for_user.sh @@ -6,6 +6,7 @@ NUM_TRAIN_EPOCHS="3" CONCURRENCY_THREADS="2" DATA_SYNTHESIS_MODE="low" HALF=False +IS_COT=False # Process parameters while [[ "$#" -gt 0 ]]; do @@ -14,6 +15,7 @@ while [[ "$#" -gt 0 ]]; do --epochs) NUM_TRAIN_EPOCHS="$2"; shift ;; --threads) CONCURRENCY_THREADS="$2"; shift ;; --mode) DATA_SYNTHESIS_MODE="$2"; shift ;; + --is_cot) IS_COT="$2"; shift ;; *) echo "Unknown parameter: $1"; exit 1 ;; esac shift @@ -25,6 +27,7 @@ echo " Learning rate: $LEARNING_RATE" echo " Number of epochs: $NUM_TRAIN_EPOCHS" echo " Concurrency threads: $CONCURRENCY_THREADS" echo " Data synthesis mode: $DATA_SYNTHESIS_MODE" +echo " Is chain of thought: $IS_COT" # If concurrency threads are set, configure related environment variables if [ "$CONCURRENCY_THREADS" != "1" ]; then @@ -75,6 +78,6 @@ python lpm_kernel/L2/train.py \ --lora_target_modules "all-linear" \ --use_4bit_quantization False \ --use_nested_quant False \ - --bnb_4bit_compute_dtype "bfloat16" \ - --is_cot False + --bnb_4bit_compute_dtype "bfloat16"\ + --is_cot $IS_COT diff --git a/lpm_kernel/api/domains/trainprocess/routes.py b/lpm_kernel/api/domains/trainprocess/routes.py index 4a8a9fb9..2ac3beec 100644 --- a/lpm_kernel/api/domains/trainprocess/routes.py +++ b/lpm_kernel/api/domains/trainprocess/routes.py @@ -63,9 +63,10 @@ def start_process(): number_of_epochs = data.get("number_of_epochs", None) concurrency_threads = data.get("concurrency_threads", None) data_synthesis_mode = data.get("data_synthesis_mode", None) + is_cot = data.get("is_cot", None) # Log the received parameters - logger.info(f"Training parameters: model_name={model_name}, learning_rate={learning_rate}, number_of_epochs={number_of_epochs}, concurrency_threads={concurrency_threads}, data_synthesis_mode={data_synthesis_mode}") + logger.info(f"Training parameters: model_name={model_name}, learning_rate={learning_rate}, number_of_epochs={number_of_epochs}, concurrency_threads={concurrency_threads}, data_synthesis_mode={data_synthesis_mode}, is_cot={is_cot}") # Create service instance with model name and additional parameters train_service = TrainProcessService( @@ -88,7 +89,8 @@ def start_process(): "learning_rate": learning_rate, "number_of_epochs": number_of_epochs, "concurrency_threads": concurrency_threads, - "data_synthesis_mode": data_synthesis_mode + "data_synthesis_mode": data_synthesis_mode, + "is_cot": is_cot } params_manager = TrainingParamsManager() @@ -110,7 +112,8 @@ def start_process(): "learning_rate": learning_rate, "number_of_epochs": number_of_epochs, "concurrency_threads": concurrency_threads, - "data_synthesis_mode": data_synthesis_mode + "data_synthesis_mode": data_synthesis_mode, + "is_cot": is_cot } ) ) diff --git a/lpm_kernel/api/domains/user_llm_config/routes.py b/lpm_kernel/api/domains/user_llm_config/routes.py index dd0e5a1e..a5919c85 100644 --- a/lpm_kernel/api/domains/user_llm_config/routes.py +++ b/lpm_kernel/api/domains/user_llm_config/routes.py @@ -40,7 +40,28 @@ def validate_llm_config(data: Dict[Any, Any]) -> Dict[str, str]: errors[field] = f'{field} is 
required for custom provider' else: errors['provider_type'] = 'provider_type is required' + + return errors + + +def validate_thinking_model(data: Dict[Any, Any]) -> Dict[str, str]: + """Validate thinking model configuration + + Args: + data: Configuration data + Returns: + Dictionary with error messages if validation fails, empty dict if validation passes + """ + errors = {} + + # Validate required fields + if not data.get('thinking_model_name'): + errors['thinking_model_name'] = 'Thinking model name is required' + + if not data.get('thinking_endpoint'): + errors['thinking_endpoint'] = 'Thinking endpoint is required' + return errors def process_openai_config(data: Dict[Any, Any]) -> Dict[Any, Any]: @@ -57,6 +78,9 @@ def process_openai_config(data: Dict[Any, Any]) -> Dict[Any, Any]: return data + + + @user_llm_config_bp.route("", methods=["GET"]) def get_config(): """Get LLM configuration""" @@ -116,6 +140,50 @@ def update_config(): ), HTTPStatus.INTERNAL_SERVER_ERROR + + + +@user_llm_config_bp.route("/thinking", methods=["PUT"]) +def update_thinking_model(): + """Update thinking model configuration""" + try: + # Validate request data + request_data = request.json + validation_errors = validate_thinking_model(request_data) + + if validation_errors: + error_message = "; ".join([f"{k}: {v}" for k, v in validation_errors.items()]) + return jsonify( + APIResponse.error(f"Validation failed: {error_message}") + ), HTTPStatus.BAD_REQUEST + + # Create a DTO with only thinking model fields + thinking_data = {} + if 'thinking_model_name' in request_data: + thinking_data['thinking_model_name'] = request_data['thinking_model_name'] + if 'thinking_endpoint' in request_data: + thinking_data['thinking_endpoint'] = request_data['thinking_endpoint'] + if 'thinking_api_key' in request_data: + thinking_data['thinking_api_key'] = request_data['thinking_api_key'] + + # Update the configuration + data = UpdateUserLLMConfigDTO(**thinking_data) + config = user_llm_config_service.update_config(1, data) # Default configuration ID is 1 + + return jsonify( + APIResponse.success( + data=config.dict(), + message="Thinking model configuration updated successfully" + ) + ), HTTPStatus.OK + + except Exception as e: + logger.error(f"Failed to update thinking model configuration: {str(e)}", exc_info=True) + return jsonify( + APIResponse.error(f"Failed to update thinking model configuration: {str(e)}") + ), HTTPStatus.INTERNAL_SERVER_ERROR + + @user_llm_config_bp.route("/key", methods=["DELETE"]) def delete_key(): """Delete API key from LLM configuration""" diff --git a/lpm_kernel/api/dto/user_llm_config_dto.py b/lpm_kernel/api/dto/user_llm_config_dto.py index 13cc911a..7cc46482 100644 --- a/lpm_kernel/api/dto/user_llm_config_dto.py +++ b/lpm_kernel/api/dto/user_llm_config_dto.py @@ -18,6 +18,11 @@ class BaseUserLLMConfigDTO(BaseModel): embedding_api_key: Optional[str] = None embedding_model_name: Optional[str] = None + # Thinking configuration + thinking_model_name: Optional[str] = None + thinking_endpoint: Optional[str] = None + thinking_api_key: Optional[str] = None + def dict(self, *args, **kwargs): result = super().dict(*args, **kwargs) return result @@ -43,8 +48,11 @@ class UpdateUserLLMConfigDTO(BaseModel): embedding_api_key: Optional[str] = None embedding_model_name: Optional[str] = None - - + # Thinking configuration + thinking_model_name: Optional[str] = None + thinking_endpoint: Optional[str] = None + thinking_api_key: Optional[str] = None + def dict(self, *args, **kwargs): result = super().dict(*args, **kwargs) 
return result @@ -71,6 +79,9 @@ def from_model(cls, model): embedding_endpoint=model.embedding_endpoint, embedding_api_key=model.embedding_api_key, embedding_model_name=model.embedding_model_name, + thinking_model_name=model.thinking_model_name, + thinking_endpoint=model.thinking_endpoint, + thinking_api_key=model.thinking_api_key, created_at=model.created_at, updated_at=model.updated_at ) diff --git a/lpm_kernel/api/models/user_llm_config.py b/lpm_kernel/api/models/user_llm_config.py index 4cdca65f..b8dcc489 100644 --- a/lpm_kernel/api/models/user_llm_config.py +++ b/lpm_kernel/api/models/user_llm_config.py @@ -21,7 +21,10 @@ class UserLLMConfig(Base): embedding_api_key = Column(String(200), nullable=True, comment='Embedding API key') embedding_model_name = Column(String(200), nullable=True, comment='Embedding model name') - + # Thinking configuration + thinking_model_name = Column(String(200), nullable=True, comment='Thinking model name') + thinking_endpoint = Column(String(200), nullable=True, comment='Thinking API endpoint') + thinking_api_key = Column(String(200), nullable=True, comment='Thinking API key') created_at = Column(DateTime, default=datetime.utcnow, comment='Creation time') updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, comment='Update time') @@ -41,6 +44,9 @@ def to_dict(self): 'embedding_endpoint': self.embedding_endpoint, 'embedding_api_key': self.embedding_api_key, 'embedding_model_name': self.embedding_model_name, + 'thinking_model_name': self.thinking_model_name, + 'thinking_endpoint': self.thinking_endpoint, + 'thinking_api_key': self.thinking_api_key, 'created_at': self.created_at, 'updated_at': self.updated_at } @@ -58,6 +64,9 @@ def from_dict(cls, data): embedding_endpoint=data.get('embedding_endpoint'), embedding_api_key=data.get('embedding_api_key'), embedding_model_name=data.get('embedding_model_name'), + thinking_model_name=data.get('thinking_model_name'), + thinking_endpoint=data.get('thinking_endpoint'), + thinking_api_key=data.get('thinking_api_key'), created_at=data.get('created_at'), updated_at=data.get('updated_at') ) diff --git a/lpm_kernel/api/services/local_llm_service.py b/lpm_kernel/api/services/local_llm_service.py index 32a8a3b9..f5bb3447 100644 --- a/lpm_kernel/api/services/local_llm_service.py +++ b/lpm_kernel/api/services/local_llm_service.py @@ -4,6 +4,8 @@ import psutil import time import subprocess +import threading +import queue from typing import Iterator, Any, Optional, Generator, Dict from datetime import datetime from flask import Response @@ -244,48 +246,181 @@ def _parse_response_chunk(self, chunk): def handle_stream_response(self, response_iter: Iterator[Any]) -> Response: """Handle streaming response from the LLM server""" - def generate(): - chunk = None # Initialize chunk variable + # Create a queue for thread communication + message_queue = queue.Queue() + # Create an event flag to notify when model processing is complete + completion_event = threading.Event() + # Create a variable to track if heartbeat is needed after first response + first_response_received = False + + def heartbeat_thread(): + """Thread function for sending heartbeats""" + start_time = time.time() + heartbeat_interval = 10 # Send heartbeat every 10 seconds + heartbeat_count = 0 + + logger.info("[STREAM_DEBUG] Heartbeat thread started") + try: + # Send initial heartbeat + message_queue.put((b": initial heartbeat\n\n", "[INITIAL_HEARTBEAT]")) + last_heartbeat_time = time.time() + + while not completion_event.is_set(): + 
current_time = time.time() + + # Check if we need to send a heartbeat + if current_time - last_heartbeat_time >= heartbeat_interval: + heartbeat_count += 1 + elapsed = current_time - start_time + logger.info(f"[STREAM_DEBUG] Sending heartbeat #{heartbeat_count} at {elapsed:.2f}s") + message_queue.put((f": heartbeat #{heartbeat_count}\n\n".encode('utf-8'), "[HEARTBEAT]")) + last_heartbeat_time = current_time + + # Short sleep to prevent CPU spinning + time.sleep(0.1) + + logger.info(f"[STREAM_DEBUG] Heartbeat thread stopping after {heartbeat_count} heartbeats") + except Exception as e: + logger.error(f"[STREAM_DEBUG] Error in heartbeat thread: {str(e)}", exc_info=True) + message_queue.put((f"data: {{\"error\": \"Heartbeat error: {str(e)}\"}}\n\n".encode('utf-8'), "[ERROR]")) + + def model_response_thread(): + """Thread function for processing model responses""" + chunk = None + start_time = time.time() + chunk_count = 0 + + try: + logger.info("[STREAM_DEBUG] Model response thread started") + + # Process model responses for chunk in response_iter: + current_time = time.time() + elapsed_time = current_time - start_time + chunk_count += 1 + + logger.info(f"[STREAM_DEBUG] Received chunk #{chunk_count} after {elapsed_time:.2f}s") + if chunk is None: - logger.warning("Received None chunk in stream, skipping") + logger.warning("[STREAM_DEBUG] Received None chunk, skipping") continue - - # logger.info(f"Received raw chunk: {chunk}") - # Check if this is the done marker for custom format + + # Check if it's an end marker if chunk == "[DONE]": - logger.info("Received [DONE] marker") - yield b"data: [DONE]\n\n" - return # Use return instead of break to ensure [DONE] in finally won't be executed + logger.info(f"[STREAM_DEBUG] Received [DONE] marker after {elapsed_time:.2f}s") + message_queue.put((b"data: [DONE]\n\n", "[DONE]")) + break - # Handle OpenAI error format directly + # Handle error responses if isinstance(chunk, dict) and "error" in chunk: - logger.warning(f"Received error response: {chunk}") + logger.warning(f"[STREAM_DEBUG] Received error response: {chunk}") data_str = json.dumps(chunk) - yield f"data: {data_str}\n\n".encode('utf-8') - # After sending error, send [DONE] marker to close the stream properly - yield b"data: [DONE]\n\n" - return + message_queue.put((f"data: {data_str}\n\n".encode('utf-8'), "[ERROR]")) + message_queue.put((b"data: [DONE]\n\n", "[DONE]")) + break + # Handle normal responses response_data = self._parse_response_chunk(chunk) if response_data: data_str = json.dumps(response_data) - # logger.info(f"Sending response data: {data_str}") - yield f"data: {data_str}\n\n".encode('utf-8') + content = response_data.get("choices", [{}])[0].get("delta", {}).get("content", "") + content_length = len(content) if content else 0 + logger.info(f"[STREAM_DEBUG] Sending chunk #{chunk_count}, content length: {content_length}, elapsed: {elapsed_time:.2f}s") + message_queue.put((f"data: {data_str}\n\n".encode('utf-8'), "[CONTENT]")) else: - logger.warning("Parsed response data is None, skipping chunk") - + logger.warning(f"[STREAM_DEBUG] Parsed response data is None for chunk #{chunk_count}") + + # Handle the case where no responses were received + if chunk_count == 0: + logger.info("[STREAM_DEBUG] No chunks received, sending empty message") + thinking_message = { + "id": str(uuid.uuid4()), + "object": "chat.completion.chunk", + "created": int(datetime.now().timestamp()), + "model": "models/lpm", + "system_fingerprint": None, + "choices": [ + { + "index": 0, + "delta": { + "content": "" 
# Empty content won't affect frontend display + }, + "finish_reason": None + } + ] + } + data_str = json.dumps(thinking_message) + message_queue.put((f"data: {data_str}\n\n".encode('utf-8'), "[THINKING]")) + + # Model processing is complete, send end marker + if chunk != "[DONE]": + logger.info(f"[STREAM_DEBUG] Sending final [DONE] marker after {elapsed_time:.2f}s") + message_queue.put((b"data: [DONE]\n\n", "[DONE]")) + except Exception as e: - error_msg = json.dumps({'error': str(e)}) - logger.error(f"Failed to process stream response: {str(e)}", exc_info=True) - yield f"data: {error_msg}\n\n".encode('utf-8') + logger.error(f"[STREAM_DEBUG] Error processing model response: {str(e)}", exc_info=True) + message_queue.put((f"data: {{\"error\": \"{str(e)}\"}}\n\n".encode('utf-8'), "[ERROR]")) + message_queue.put((b"data: [DONE]\n\n", "[DONE]")) finally: - if chunk != "[DONE]": # Only send if [DONE] marker was not received - logger.info("Sending final [DONE] marker") + # Set completion event to notify heartbeat thread to stop + completion_event.set() + logger.info(f"[STREAM_DEBUG] Model response thread completed with {chunk_count} chunks") + + def generate(): + """Main generator function for generating responses""" + # Start heartbeat thread + heart_thread = threading.Thread(target=heartbeat_thread, daemon=True) + heart_thread.start() + + # Start model response processing thread + model_thread = threading.Thread(target=model_response_thread, daemon=True) + model_thread.start() + + try: + # Get messages from queue and return to client + while True: + try: + # Use short timeout to get message, prevent blocking + message, message_type = message_queue.get(timeout=0.1) + logger.debug(f"[STREAM_DEBUG] Yielding message type: {message_type}") + yield message + + # If end marker is received, exit loop + if message_type == "[DONE]": + logger.info("[STREAM_DEBUG] Received [DONE] marker, ending generator") + break + except queue.Empty: + # Queue is empty, continue trying to get message + # Check if model thread has completed but didn't send [DONE] + if completion_event.is_set() and not model_thread.is_alive(): + logger.warning("[STREAM_DEBUG] Model thread completed without [DONE], ending generator") + yield b"data: [DONE]\n\n" + break + pass + except GeneratorExit: + # Client closed connection + logger.info("[STREAM_DEBUG] Client closed connection (GeneratorExit)") + completion_event.set() + except Exception as e: + logger.error(f"[STREAM_DEBUG] Error in generator: {str(e)}", exc_info=True) + try: + yield f"data: {{\"error\": \"Generator error: {str(e)}\"}}\n\n".encode('utf-8') yield b"data: [DONE]\n\n" - logger.info("Stream response completed successfully") - + except: + pass + completion_event.set() + finally: + # Ensure completion event is set + completion_event.set() + # Wait for threads to complete + if heart_thread.is_alive(): + heart_thread.join(timeout=1.0) + if model_thread.is_alive(): + model_thread.join(timeout=1.0) + logger.info("[STREAM_DEBUG] Generator completed") + + # Return response return Response( generate(), mimetype='text/event-stream', diff --git a/lpm_kernel/api/services/user_llm_config_service.py b/lpm_kernel/api/services/user_llm_config_service.py index 63678c1c..3d318dc2 100644 --- a/lpm_kernel/api/services/user_llm_config_service.py +++ b/lpm_kernel/api/services/user_llm_config_service.py @@ -42,7 +42,7 @@ def update_config( # Update or create the configuration return self.repository.update(config_id, dto) - + def delete_key(self, config_id: int = 1) -> 
Optional[UserLLMConfigDTO]: """Delete API key from the configuration diff --git a/lpm_kernel/app.py b/lpm_kernel/app.py index 29c423e4..62f18c2f 100644 --- a/lpm_kernel/app.py +++ b/lpm_kernel/app.py @@ -3,6 +3,7 @@ from .common.logging import logger from .api import init_routes from .api.file_server.handler import FileServerHandler +from .database.migration_manager import MigrationManager import os import atexit @@ -14,6 +15,7 @@ def create_app(): try: DatabaseSession.initialize() logger.info("Database connection initialized successfully") + except Exception as e: logger.error(f"Failed to initialize database connection: {str(e)}") raise diff --git a/lpm_kernel/database/__init__.py b/lpm_kernel/database/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/lpm_kernel/database/__init__.py @@ -0,0 +1 @@ + diff --git a/lpm_kernel/database/migration_manager.py b/lpm_kernel/database/migration_manager.py new file mode 100644 index 00000000..bf7a7439 --- /dev/null +++ b/lpm_kernel/database/migration_manager.py @@ -0,0 +1,374 @@ +""" +Database Migration Manager + +This module provides functionality to manage database migrations in a systematic way. +It ensures that migrations are applied in order and only once. +""" +import os +import importlib.util +import sqlite3 +from datetime import datetime +from lpm_kernel.common.logging import logger +class MigrationManager: + """Manages database migrations for SQLite database""" + + def __init__(self, db_path): + """ + Initialize the migration manager + + Args: + db_path: Path to the SQLite database file + """ + self.db_path = db_path + self._ensure_migration_table() + + def _ensure_migration_table(self): + """Create migration tracking table if it doesn't exist""" + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + + cursor.execute(""" + CREATE TABLE IF NOT EXISTS schema_migrations ( + version VARCHAR(50) PRIMARY KEY, + description TEXT, + applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ); + """) + + conn.commit() + conn.close() + logger.debug("Migration tracking table checked/created") + + def get_applied_migrations(self): + """ + Get list of already applied migrations + + Returns: + List of applied migration versions + """ + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + + cursor.execute("SELECT version FROM schema_migrations ORDER BY version") + versions = [row[0] for row in cursor.fetchall()] + + conn.close() + return versions + + def apply_migrations(self, migrations_dir=None): + """ + Apply all pending migrations from the migrations directory + + Args: + migrations_dir: Directory containing migration scripts. 
+ If None, use 'migrations' subdirectory + + Returns: + List of applied migration versions + """ + if migrations_dir is None: + migrations_dir = os.path.join(os.path.dirname(__file__), "migrations") + # logger.info(f"Using default migrations directory: {migrations_dir}") + + # Ensure migrations directory exists + os.makedirs(migrations_dir, exist_ok=True) + + # Get already applied migrations + applied = self.get_applied_migrations() + # logger.info(f"Found {len(applied)} previously applied migrations") + + # Get all migration files and sort them + migration_files = [] + for f in os.listdir(migrations_dir): + if f.endswith('.py') and not f.startswith('__'): + try: + # Extract version from filename (format: V20250420221300__description.py) + version = f.split('__')[0].replace('V', '') + migration_files.append((version, f)) + # logger.info(f"Found migration file: {f}") + except Exception as e: + logger.warning(f"Skipping invalid migration filename: {f}, error: {e}") + + # Sort by version + migration_files.sort(key=lambda x: x[0]) + # logger.info(f"Found {len(migration_files)} migration files: {', '.join([f[1] for f in migration_files])}") + + applied_in_session = [] + + # Apply each migration that hasn't been applied yet + for version, migration_file in migration_files: + if version in applied: + logger.debug(f"Skipping already applied migration: {migration_file}") + continue + + # logger.info(f"Applying migration: {migration_file}") + + # Load the migration module + module_path = os.path.join(migrations_dir, migration_file) + module_name = f"migration_{version}" + + try: + # Import the migration module dynamically + spec = importlib.util.spec_from_file_location(module_name, module_path) + migration_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(migration_module) + + # Get migration description + description = getattr(migration_module, 'description', migration_file) + + # Connect to database and start transaction + conn = sqlite3.connect(self.db_path) + conn.execute("BEGIN TRANSACTION") + + try: + # Execute the migration + migration_module.upgrade(conn) + + # Record the migration + conn.execute( + "INSERT INTO schema_migrations (version, description) VALUES (?, ?)", + (version, description) + ) + + # Commit the transaction + conn.commit() + # logger.info(f"Successfully applied migration: {migration_file}") + applied_in_session.append(version) + + except Exception as e: + # Rollback on error + conn.rollback() + logger.error(f"Error applying migration {migration_file}: {str(e)}") + raise + + finally: + conn.close() + + except Exception as e: + logger.error(f"Failed to load migration {migration_file}: {str(e)}") + raise + + # if not applied_in_session: + # # logger.info("No new migrations to apply") + # else: + # logger.info(f"Applied {len(applied_in_session)} new migrations") + + return applied_in_session + + def downgrade_migration(self, version, migrations_dir=None): + """ + Downgrade a specific migration by version + + Args: + version: Version of the migration to downgrade + migrations_dir: Directory containing migration scripts + + Returns: + True if downgrade was successful, False otherwise + """ + if migrations_dir is None: + migrations_dir = os.path.join(os.path.dirname(__file__), "migrations") + # logger.info(f"Using default migrations directory: {migrations_dir}") + + # Check if migration is applied + applied = self.get_applied_migrations() + if version not in applied: + logger.warning(f"Migration version {version} is not applied, cannot downgrade") + return 
False + + # Find migration file + migration_file = None + for f in os.listdir(migrations_dir): + if f.endswith('.py') and not f.startswith('__') and f.startswith(f'V{version}'): + migration_file = f + break + + if not migration_file: + logger.error(f"Migration file for version {version} not found") + return False + + # logger.info(f"Downgrading migration: {migration_file}") + + # Load the migration module + module_path = os.path.join(migrations_dir, migration_file) + module_name = f"migration_{version}" + + try: + # Import the migration module dynamically + spec = importlib.util.spec_from_file_location(module_name, module_path) + migration_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(migration_module) + + # Check if downgrade method exists + if not hasattr(migration_module, 'downgrade'): + logger.error(f"Migration {migration_file} does not have a downgrade method") + return False + + # Connect to database and start transaction + conn = sqlite3.connect(self.db_path) + conn.execute("BEGIN TRANSACTION") + + try: + # Execute the downgrade + migration_module.downgrade(conn) + + # Remove the migration record + conn.execute( + "DELETE FROM schema_migrations WHERE version = ?", + (version,) + ) + + # Commit the transaction + conn.commit() + # logger.info(f"Successfully downgraded migration: {migration_file}") + return True + + except Exception as e: + # Rollback on error + conn.rollback() + logger.error(f"Error downgrading migration {migration_file}: {str(e)}") + raise + + finally: + conn.close() + + except Exception as e: + logger.error(f"Failed to load migration {migration_file}: {str(e)}") + raise + + def downgrade_to_version(self, target_version=None, migrations_dir=None): + """ + Downgrade migrations to a specific version + + Args: + target_version: Version to downgrade to (inclusive). If None, downgrade all migrations. 
+ migrations_dir: Directory containing migration scripts + + Returns: + List of downgraded migration versions + """ + if migrations_dir is None: + migrations_dir = os.path.join(os.path.dirname(__file__), "migrations") + # logger.info(f"Using default migrations directory: {migrations_dir}") + + # Get applied migrations + applied = self.get_applied_migrations() + # logger.info(f"Found {len(applied)} applied migrations") + + if not applied: + # logger.info("No migrations to downgrade") + return [] + + # Determine which migrations to downgrade + to_downgrade = [] + if target_version is None: + # Downgrade all migrations + to_downgrade = applied + else: + # Find target version index + if target_version not in applied: + logger.error(f"Target version {target_version} is not applied") + return [] + + target_index = applied.index(target_version) + + # Get all versions after target_version + to_downgrade = applied[target_index + 1:] + + # Sort in reverse order to downgrade newest first + to_downgrade.sort(reverse=True) + + downgraded = [] + for version in to_downgrade: + try: + if self.downgrade_migration(version, migrations_dir): + downgraded.append(version) + else: + logger.error(f"Failed to downgrade migration {version}, stopping") + break + except Exception as e: + logger.error(f"Error during downgrade of {version}: {str(e)}") + break + + # if not downgraded: + # logger.info("No migrations were downgraded") + # else: + # logger.info(f"Downgraded {len(downgraded)} migrations: {', '.join(downgraded)}") + + return downgraded + + def create_migration(self, description, migrations_dir=None): + """ + Create a new migration file with template code + + Args: + description: Short description of what the migration does + migrations_dir: Directory to create migration in + + Returns: + Path to the created migration file + """ + if migrations_dir is None: + migrations_dir = os.path.join(os.path.dirname(__file__), "migrations") + + # Ensure migrations directory exists + os.makedirs(migrations_dir, exist_ok=True) + + # Get current timestamp for version + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + + # Format description for filename (lowercase, underscores) + safe_description = description.lower().replace(' ', '_').replace('-', '_') + safe_description = ''.join(c for c in safe_description if c.isalnum() or c == '_') + + # Create filename + filename = f"V{timestamp}__{safe_description}.py" + filepath = os.path.join(migrations_dir, filename) + + # Create migration file with template + with open(filepath, 'w') as f: + f.write(f'''""" +Migration: {description} +Version: {timestamp} +""" + +description = "{description}" + +def upgrade(conn): + """ + Apply the migration + + Args: + conn: SQLite connection object + """ + cursor = conn.cursor() + + # TODO: Implement your migration logic here + # Example: + # cursor.execute(""" + # CREATE TABLE IF NOT EXISTS new_table ( + # id INTEGER PRIMARY KEY AUTOINCREMENT, + # name TEXT NOT NULL + # ) + # """) + + # No need to commit, the migration manager handles transactions + +def downgrade(conn): + """ + Revert the migration + + Args: + conn: SQLite connection object + """ + cursor = conn.cursor() + + # TODO: Implement your downgrade logic here + # Example: + # cursor.execute("DROP TABLE IF EXISTS new_table") + + # No need to commit, the migration manager handles transactions +''') + + # logger.info(f"Created new migration: {filename}") + return filepath diff --git a/lpm_kernel/database/migrations/V20250420221300__add_thinking_fields_to_user_llm_configs.py 
b/lpm_kernel/database/migrations/V20250420221300__add_thinking_fields_to_user_llm_configs.py new file mode 100644 index 00000000..4889c998 --- /dev/null +++ b/lpm_kernel/database/migrations/V20250420221300__add_thinking_fields_to_user_llm_configs.py @@ -0,0 +1,92 @@ +""" +Migration: Add thinking model fields to user_llm_configs table +Version: 20250420221300 +""" + +description = "Add thinking model fields to user_llm_configs table" + +def upgrade(conn): + """ + Apply the migration + + Args: + conn: SQLite connection object + """ + cursor = conn.cursor() + + # Check if thinking_model_name column already exists in user_llm_configs table + cursor.execute("PRAGMA table_info(user_llm_configs)") + columns = [row[1] for row in cursor.fetchall()] + + # Add thinking model fields if they don't exist + if 'thinking_model_name' not in columns: + cursor.execute("ALTER TABLE user_llm_configs ADD COLUMN thinking_model_name VARCHAR(200)") + print("Added thinking_model_name column to user_llm_configs table") + + if 'thinking_endpoint' not in columns: + cursor.execute("ALTER TABLE user_llm_configs ADD COLUMN thinking_endpoint VARCHAR(200)") + print("Added thinking_endpoint column to user_llm_configs table") + + if 'thinking_api_key' not in columns: + cursor.execute("ALTER TABLE user_llm_configs ADD COLUMN thinking_api_key VARCHAR(200)") + print("Added thinking_api_key column to user_llm_configs table") + + # No need to commit, the migration manager handles transactions + +def downgrade(conn): + """ + Revert the migration + + Args: + conn: SQLite connection object + """ + cursor = conn.cursor() + + # SQLite doesn't support dropping columns directly + # We need to create a new table without the thinking model fields, copy the data, and replace the old table + + # Create a temporary table without thinking model fields + cursor.execute(""" + CREATE TABLE user_llm_configs_temp ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + provider_type VARCHAR(50) NOT NULL DEFAULT 'openai', + key VARCHAR(200), + chat_endpoint VARCHAR(200), + chat_api_key VARCHAR(200), + chat_model_name VARCHAR(200), + embedding_endpoint VARCHAR(200), + embedding_api_key VARCHAR(200), + embedding_model_name VARCHAR(200), + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + # Copy data from the original table to the temporary table + cursor.execute(""" + INSERT INTO user_llm_configs_temp ( + id, provider_type, key, + chat_endpoint, chat_api_key, chat_model_name, + embedding_endpoint, embedding_api_key, embedding_model_name, + created_at, updated_at + ) + SELECT + id, provider_type, key, + chat_endpoint, chat_api_key, chat_model_name, + embedding_endpoint, embedding_api_key, embedding_model_name, + created_at, updated_at + FROM user_llm_configs + """) + + # Drop the original table + cursor.execute("DROP TABLE user_llm_configs") + + # Rename the temporary table to the original table name + cursor.execute("ALTER TABLE user_llm_configs_temp RENAME TO user_llm_configs") + + # Recreate the index + cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_llm_configs_created_at ON user_llm_configs(created_at)") + + print("Removed thinking model fields from user_llm_configs table") + + # No need to commit, the migration manager handles transactions diff --git a/lpm_kernel/database/migrations/__init__.py b/lpm_kernel/database/migrations/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/lpm_kernel/database/migrations/__init__.py @@ -0,0 +1 @@ + diff --git 
a/lpm_kernel/train/training_params_manager.py b/lpm_kernel/train/training_params_manager.py index f5389664..2f6589ac 100644 --- a/lpm_kernel/train/training_params_manager.py +++ b/lpm_kernel/train/training_params_manager.py @@ -22,7 +22,8 @@ class TrainingParamsManager: "learning_rate": 1e-4, "number_of_epochs": 3, "concurrency_threads": 2, - "data_synthesis_mode": "low" + "data_synthesis_mode": "low", + "is_cot": False } # Parameters file path diff --git a/lpm_kernel/train/trainprocess_service.py b/lpm_kernel/train/trainprocess_service.py index a348ffcd..fe1e58e0 100644 --- a/lpm_kernel/train/trainprocess_service.py +++ b/lpm_kernel/train/trainprocess_service.py @@ -48,7 +48,7 @@ def __new__(cls, *args, **kwargs): cls._instance = super().__new__(cls) return cls._instance - def __init__(self, current_model_name: str = None, is_cot: bool = False): + def __init__(self, current_model_name: str = None): if not self._initialized: # Generate a unique progress file name based on model name self.progress = TrainProgressHolder(current_model_name) @@ -76,7 +76,6 @@ def __init__(self, current_model_name: str = None, is_cot: bool = False): self.model_name = current_model_name # Create new progress instance with updated progress file name self.progress = TrainProgressHolder(current_model_name) - self.is_cot = is_cot def list_documents(self): """List all documents""" @@ -308,7 +307,8 @@ def decode_preference_patterns(self)->bool: self._prepare_l2_data() # Use data from l2_data dictionary - L2Generator(is_cot=self.is_cot).gen_preference_data( + training_params = TrainingParamsManager.get_latest_training_params() + L2Generator(is_cot=training_params.get("is_cot", False)).gen_preference_data( self.l2_data["notes"], self.l2_data["basic_info"], self.l2_data["data_output_base_dir"], @@ -336,9 +336,11 @@ def reinforce_identity(self)->bool: # Get or prepare L2 data self._prepare_l2_data() + # Get training parameters + training_params = TrainingParamsManager.get_latest_training_params() # Use data from l2_data dictionary l2_generator = L2Generator( - data_path=os.path.join(os.getcwd(), "resources"), is_cot=self.is_cot + data_path=os.path.join(os.getcwd(), "resources"), is_cot=training_params.get("is_cot", False) ) l2_generator.gen_selfqa_data( self.l2_data["notes"], @@ -386,8 +388,10 @@ def augment_content_retention(self) -> bool: # Get or prepare L2 data self._prepare_l2_data() + # Get training parameters + training_params = TrainingParamsManager.get_latest_training_params() # Use data from l2_data dictionary - l2_generator = L2Generator(data_path=os.path.join(os.getcwd(), "resources"), is_cot=self.is_cot) + l2_generator = L2Generator(data_path=os.path.join(os.getcwd(), "resources"), is_cot=training_params.get("is_cot", False)) l2_generator.gen_diversity_data( self.l2_data["notes"], self.l2_data["basic_info"], @@ -629,6 +633,7 @@ def _start_training(self, script_path, log_path): num_train_epochs = training_params.get("number_of_epochs") concurrency_threads = training_params.get("concurrency_threads") data_synthesis_mode = training_params.get("data_synthesis_mode") + is_cot = training_params.get("is_cot", False) # Log training parameters logger.info("Training parameters from latest settings:") @@ -644,7 +649,8 @@ def _start_training(self, script_path, log_path): "--lr", str(learning_rate), "--epochs", str(num_train_epochs), "--threads", str(concurrency_threads), - "--mode", str(data_synthesis_mode) + "--mode", str(data_synthesis_mode), + "--is_cot", str(is_cot) ] # Ensure log directory exists diff --git 
a/scripts/downgrade_migration.py b/scripts/downgrade_migration.py new file mode 100644 index 00000000..527a3c61 --- /dev/null +++ b/scripts/downgrade_migration.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python +""" +Database Migration Downgrade Tool + +Usage: + python downgrade_migration.py [version] + +Arguments: + version: Optional, version to downgrade to (inclusive). + If not provided, all migrations will be downgraded. +""" + +import os +import sys +import sqlite3 +import importlib.util +from pathlib import Path + +# Add project root to Python path +project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, project_root) + +# Import migration manager +from lpm_kernel.database.migration_manager import MigrationManager +from lpm_kernel.configs.config import Config + +def main(): + # Get database path + config = Config.from_env() + db_path = config.database.db_file + + print(f"Database path: {db_path}") + + # Create migration manager + manager = MigrationManager(db_path) + + # Get applied migrations + applied = manager.get_applied_migrations() + print(f"Applied migrations: {', '.join(applied) if applied else 'none'}") + + if not applied: + print("No migrations to downgrade") + return + + # Check command line arguments + if len(sys.argv) > 1: + target_version = sys.argv[1] + if target_version not in applied: + print(f"Error: Version {target_version} is not applied, cannot downgrade to this version") + return + + print(f"Downgrading to version {target_version}...") + downgraded = manager.downgrade_to_version(target_version) + else: + print("Downgrading all migrations...") + downgraded = manager.downgrade_to_version() + + if downgraded: + print(f"Successfully downgraded migrations: {', '.join(downgraded)}") + else: + print("No migrations were downgraded") + +if __name__ == "__main__": + main() diff --git a/scripts/run_migrations.py b/scripts/run_migrations.py new file mode 100644 index 00000000..567dd2e7 --- /dev/null +++ b/scripts/run_migrations.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python +""" +Database Migration Runner + +This script runs database migrations using the migration manager. +It should be executed whenever the database schema needs to be updated. 
+""" + +import os +import sys +from pathlib import Path + +# Add project root to path +project_root = str(Path(__file__).parent.parent) +sys.path.insert(0, project_root) + +from lpm_kernel.configs.config import Config +from lpm_kernel.database.migration_manager import MigrationManager + +from lpm_kernel.common.logging import logger + +def get_db_path(): + """Get the database path from environment or use default""" + config = Config.from_env() + db_path = config.get("SQLITE_DB_PATH", os.path.join(project_root, "data", "sqlite", "lpm.db")) + return db_path + +def run_migrations(): + """Run all pending database migrations""" + db_path = get_db_path() + + # logger.info(f"Using database at: {db_path}") + + # Check if database file exists + if not os.path.exists(db_path): + # logger.error(f"Database file not found at {db_path}") + return False + + try: + # Initialize migration manager + migrations_dir = os.path.join(project_root, "lpm_kernel", "database", "migrations") + manager = MigrationManager(db_path) + + # Apply migrations + applied = manager.apply_migrations(migrations_dir) + + # if applied: + # logger.info(f"Successfully applied {len(applied)} migrations") + # else: + # logger.info("No new migrations to apply") + + return True + + except Exception as e: + logger.error(f"Error during migrations: {e}") + return False + +def create_migration(description): + """Create a new migration file""" + db_path = get_db_path() + migrations_dir = os.path.join(project_root, "lpm_kernel", "database", "migrations") + + manager = MigrationManager(db_path) + filepath = manager.create_migration(description, migrations_dir) + + # logger.info(f"Created new migration at: {filepath}") + return filepath + +if __name__ == "__main__": + # logger.info("Starting database migration") + + if len(sys.argv) > 1 and sys.argv[1] == "create": + if len(sys.argv) > 2: + description = sys.argv[2] + create_migration(description) + else: + logger.error("Missing migration description") + print("Usage: python run_migrations.py create 'Add new table'") + sys.exit(1) + else: + success = run_migrations() + + if success: + # logger.info("Migration completed successfully") + sys.exit(0) + else: + logger.error("Migration failed") + sys.exit(1) diff --git a/scripts/start_local.sh b/scripts/start_local.sh index 52d13441..b2eb0b29 100755 --- a/scripts/start_local.sh +++ b/scripts/start_local.sh @@ -82,6 +82,10 @@ python docker/app/init_chroma.py # Get local IP address (excluding localhost and docker networks) LOCAL_IP=$(ifconfig | grep "inet " | grep -v "127.0.0.1" | grep "192.168" | awk '{print $2}' | head -n 1) +# Run database migrations first +log_info "Running database migrations..." +python scripts/run_migrations.py + # Start Flask application log_info "Starting Flask application..." 
 log_info "Application will run at the following addresses:"

From 516843d9637c9b727c0223e358bce53cf209f67d Mon Sep 17 00:00:00 2001
From: doubleBlack2 <108928143+doubleBlack2@users.noreply.github.com>
Date: Thu, 24 Apr 2025 14:24:45 +0800
Subject: [PATCH 05/21] mcp search online secondme model (#242)

---
 mcp/mcp_public.py | 26 ++++++++++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/mcp/mcp_public.py b/mcp/mcp_public.py
index ee7ea4bb..1f3bfa2d 100644
--- a/mcp/mcp_public.py
+++ b/mcp/mcp_public.py
@@ -2,6 +2,7 @@
 from mcp.server.fastmcp import FastMCP
 import http.client
 import json
+import requests
 mindverse = FastMCP("mindverse_public")
 url = "app.secondme.io"
@@ -66,6 +67,31 @@ async def get_response(query:str, instance_id:str) -> str | None:
     else:
         return None
+@mindverse.tool()
+async def get_online_instances():
+    """
+    Check which secondme models are available for chatting online.
+    """
+    url = "https://app.secondme.io/api/upload/list?page_size=100"
+    response = requests.get(url)
+
+    if response.status_code == 200:
+        data = response.json()
+        items = data.get("data", {}).get("items", [])
+
+        online_items = [
+            {
+                "upload_name": item["upload_name"],
+                "instance_id": item["instance_id"],
+                "description": item["description"]
+            }
+            for item in items if item.get("status") == "online"
+        ]
+
+        return json.dumps(online_items, ensure_ascii=False, indent=2)
+    else:
+        raise Exception(f"Request failed with status code: {response.status_code}")
+
 if __name__ == "__main__":
     mindverse.run(transport='stdio')

From f04916754c37fb364d9ded0d0e9ca040b9da6bce Mon Sep 17 00:00:00 2001
From: ryangyuan
Date: Thu, 24 Apr 2025 14:25:00 +0800
Subject: [PATCH 06/21] feat: replace tutorial link (#268)

* feat: replace tutorial link

* replace video link

---------

Co-authored-by: kevin-mindverse

---
 lpm_frontend/src/app/home/components/Footer/index.tsx  | 2 +-
 lpm_frontend/src/layouts/DashboardLayout/Menu/index.tsx | 5 +----
 2 files changed, 2 insertions(+), 5 deletions(-)

diff --git a/lpm_frontend/src/app/home/components/Footer/index.tsx b/lpm_frontend/src/app/home/components/Footer/index.tsx
index 30e33d60..f58a3e84 100644
--- a/lpm_frontend/src/app/home/components/Footer/index.tsx
+++ b/lpm_frontend/src/app/home/components/Footer/index.tsx
@@ -24,7 +24,7 @@ const Footer = (props: IProps) => {
<div>
diff --git a/lpm_frontend/src/layouts/DashboardLayout/Menu/index.tsx b/lpm_frontend/src/layouts/DashboardLayout/Menu/index.tsx index 12276ebf..14d09dc5 100644 --- a/lpm_frontend/src/layouts/DashboardLayout/Menu/index.tsx +++ b/lpm_frontend/src/layouts/DashboardLayout/Menu/index.tsx @@ -255,10 +255,7 @@ const Menu = () => {
+ +
+
+
Enable CUDA GPU Acceleration
+ + + +
+
+ +
+
+ {cudaAvailable + ? 'Enable for faster training on NVIDIA GPUs.' + : 'CUDA acceleration is not available on this system.'} +
+
diff --git a/lpm_frontend/src/components/train/TrainingLog.tsx b/lpm_frontend/src/components/train/TrainingLog.tsx index a6a600fd..aceee8b4 100644 --- a/lpm_frontend/src/components/train/TrainingLog.tsx +++ b/lpm_frontend/src/components/train/TrainingLog.tsx @@ -11,6 +11,7 @@ const TrainingLog: React.FC = ({ trainingDetails }: TrainingLo const consoleEndRef = useRef(null); const [isUserScrolling, setIsUserScrolling] = useState(false); const userScrollTimeout = useRef(null); + const [isAutoScrollEnabled, setIsAutoScrollEnabled] = useState(true); // Smooth scroll console to bottom const smoothScrollConsole = () => { @@ -29,17 +30,37 @@ const TrainingLog: React.FC = ({ trainingDetails }: TrainingLo useEffect(() => { // Set up scroll event listener to detect user scrolling const handleUserScroll = () => { - setIsUserScrolling(true); + if (!consoleEndRef.current) return; + + const consoleContainer = consoleEndRef.current.closest('.overflow-y-auto'); + + if (!(consoleContainer instanceof HTMLElement)) return; + + // Check if scrolled away from bottom + const isScrolledToBottom = + Math.abs((consoleContainer.scrollHeight - consoleContainer.scrollTop) - consoleContainer.clientHeight) < 50; + + // If scrolled away from bottom, consider it manual scrolling + if (!isScrolledToBottom) { + setIsUserScrolling(true); - // Clear any existing timeout - if (userScrollTimeout.current) { - clearTimeout(userScrollTimeout.current); - } + // Clear any existing timeout + if (userScrollTimeout.current) { + clearTimeout(userScrollTimeout.current); + } - // Reset the flag after a short delay - userScrollTimeout.current = setTimeout(() => { + // Reset the flag after a delay + userScrollTimeout.current = setTimeout(() => { + setIsUserScrolling(false); + }, 5000); // 5 seconds delay before allowing auto-scroll again + } else { + // If at bottom, not considered manual scrolling setIsUserScrolling(false); - }, 2000); // 2 seconds delay before allowing auto-scroll again + if (userScrollTimeout.current) { + clearTimeout(userScrollTimeout.current); + userScrollTimeout.current = null; + } + } }; // Find the console container and attach the scroll listener @@ -65,7 +86,16 @@ const TrainingLog: React.FC = ({ trainingDetails }: TrainingLo if (trainingDetails.length > 0) { smoothScrollConsole(); } - }, [trainingDetails]); + }, [trainingDetails, isAutoScrollEnabled]); + + const toggleAutoScroll = () => { + setIsAutoScrollEnabled(!isAutoScrollEnabled); + if (!isAutoScrollEnabled) { + // If we're re-enabling auto-scroll, scroll to bottom immediately + setIsUserScrolling(false); + setTimeout(smoothScrollConsole, 50); + } + }; return (
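Note on the TrainingLog change above: auto-scroll is only re-engaged when the log container is effectively at its bottom, using the scrollHeight - scrollTop - clientHeight distance check. A minimal TypeScript sketch of that check, factored into a standalone helper, follows; the names isNearBottom and maybeAutoScroll and the 50-pixel default threshold are illustrative assumptions, not code taken from this patch.

// Illustrative sketch: true when a scrollable element is within `threshold`
// pixels of its bottom edge, mirroring the check in handleUserScroll above.
function isNearBottom(el: HTMLElement, threshold = 50): boolean {
  return Math.abs(el.scrollHeight - el.scrollTop - el.clientHeight) < threshold;
}

// Example usage: only auto-scroll when the reader has not scrolled away.
function maybeAutoScroll(container: HTMLElement): void {
  if (isNearBottom(container)) {
    container.scrollTo({ top: container.scrollHeight, behavior: 'smooth' });
  }
}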
diff --git a/lpm_frontend/src/service/train.ts b/lpm_frontend/src/service/train.ts index 7ecf1dcb..0d04bf21 100644 --- a/lpm_frontend/src/service/train.ts +++ b/lpm_frontend/src/service/train.ts @@ -138,3 +138,17 @@ export const getTrainingParams = () => { url: `/api/trainprocess/training_params` }); }; + +export const checkCudaAvailability = () => { + return Request>({ + method: 'get', + url: '/api/kernel2/cuda/available' + }); +}; diff --git a/lpm_kernel/L1/shade_generator.py b/lpm_kernel/L1/shade_generator.py index 428a0338..7e952ace 100644 --- a/lpm_kernel/L1/shade_generator.py +++ b/lpm_kernel/L1/shade_generator.py @@ -169,6 +169,18 @@ def __add_second_view_info(self, shade_info: ShadeInfo) -> ShadeInfo: content = response.choices[0].message.content shift_pattern = r"\{.*\}" shift_perspective_result = self.__parse_json_response(content, shift_pattern) + + # Check if result is None and provide default values to avoid TypeError + if shift_perspective_result is None: + logger.warning(f"Failed to parse perspective shift result, using default values: {content}") + # Create a default mapping with expected parameters + shift_perspective_result = { + "domainDesc": f"You have knowledge and experience related to {shade_info.name}.", + "domainContent": shade_info.content_third_view, + "domainTimeline": [] + } + + # Now it's safe to pass shift_perspective_result as kwargs shade_info.add_second_view(**shift_perspective_result) return shade_info diff --git a/lpm_kernel/L2/data_pipeline/data_prep/preference/preference_QA_generate.py b/lpm_kernel/L2/data_pipeline/data_prep/preference/preference_QA_generate.py index 961c4dc2..c7a25ce0 100644 --- a/lpm_kernel/L2/data_pipeline/data_prep/preference/preference_QA_generate.py +++ b/lpm_kernel/L2/data_pipeline/data_prep/preference/preference_QA_generate.py @@ -53,8 +53,18 @@ def __init__(self, filename: str, bio: str, preference_language: str, is_cot: bo bio: Biography or context information to use in prompt generation. preference_language: Language for prompts ("Chinese/中文" or otherwise English). """ + # Ensure the filename is actually a string + if filename is None: + raise ValueError("Filename cannot be None") + self.filename = filename - self.is_cot = is_cot + # Convert is_cot to bool if it's a string + if isinstance(is_cot, str): + self.is_cot = is_cot.lower() == 'true' + else: + self.is_cot = bool(is_cot) + + logger.info(f"PreferenceQAGenerator initialized with is_cot={self.is_cot}") with open(self.filename, "r", encoding="utf-8") as f: self.pre_msg = json.load(f) diff --git a/lpm_kernel/L2/download_model.sh b/lpm_kernel/L2/download_model.sh index 0ac2fc7d..31d89510 100644 --- a/lpm_kernel/L2/download_model.sh +++ b/lpm_kernel/L2/download_model.sh @@ -1 +1,12 @@ -python lpm_kernel/L2/utils.py Qwen2.5-0.5B-Instruct +#!/bin/bash +# Script to download model from Hugging Face +# Usage: ./download_model.sh [model_name] +# If no model name is provided, will attempt to get it from config + +if [ "$1" != "" ]; then + # Use provided model name + python lpm_kernel/L2/utils.py "$1" +else + # No model name provided, let utils.py determine from config + python lpm_kernel/L2/utils.py +fi diff --git a/lpm_kernel/L2/l2_generator.py b/lpm_kernel/L2/l2_generator.py index d8468724..03eb6621 100644 --- a/lpm_kernel/L2/l2_generator.py +++ b/lpm_kernel/L2/l2_generator.py @@ -35,11 +35,19 @@ def __init__(self, data_path: str = "../raw_data", preferred_lang: str = "Englis Args: data_path: Path to the raw data directory. Defaults to "../raw_data". 
preferred_lang: Preferred language for data processing. Defaults to "English". + is_cot: Whether to use Chain of Thought reasoning. Can be bool or string. """ self.data_path = data_path self.data_processor = L2DataProcessor(data_path, preferred_lang) self.preferred_lang = preferred_lang - self.is_cot = is_cot + + # Convert is_cot to bool if it's a string + if isinstance(is_cot, str): + self.is_cot = is_cot.lower() == 'true' + else: + self.is_cot = bool(is_cot) + + logging.info(f"L2Generator initialized with is_cot={self.is_cot}") def data_preprocess(self, note_list: List[Note], basic_info: Dict): """Preprocess the input notes and basic information. @@ -60,39 +68,50 @@ def gen_subjective_data( graph_path: str, config_path: str, ): - """Generate subjective data based on input notes and user information. + """Generate subjective data for personalization. + + This method orchestrates the generation of subjective data including preferences, + diversity, self-Q&A data, and graph indexing. Args: - note_list: List of Note objects. - basic_info: Dictionary containing basic user information. + note_list: List of Note objects to process. + basic_info: Dictionary containing user information. data_output_base_dir: Base directory for output data. topics_path: Path to topics data. - entities_path: Path to entities data. + entities_path: Path to entity data. graph_path: Path to graph data. config_path: Path to configuration file. """ - global_bio = basic_info["globalBio"] - user_name = basic_info["username"] - user_intro = basic_info["aboutMe"] + if not os.path.exists(data_output_base_dir): + os.makedirs(data_output_base_dir) - preference_output_path = "preference.json" - diversity_output_path = "diversity.json" - selfqa_output_path = "selfqa.json" + # Check if the file exists + if not os.path.exists(topics_path): + # Create an empty file + with open(topics_path, "w") as f: + f.write(json.dumps([])) + # Generate subjective data self.data_processor.gen_subjective_data( - note_list, - data_output_base_dir, - preference_output_path, - diversity_output_path, - selfqa_output_path, - global_bio, - topics_path, - entities_path, - graph_path, - user_name, - config_path, - user_intro, + note_list=note_list, + data_output_base_dir=data_output_base_dir, + preference_output_path="preference.json", + diversity_output_path="diversity.json", + selfqa_output_path="selfqa.json", + global_bio=basic_info["globalBio"], + topics_path=topics_path, + entitys_path=entities_path, + graph_path=graph_path, + user_name=basic_info["username"], + config_path=config_path, + user_intro=basic_info["aboutMe"], ) + + # Merge JSON files for training + self.merge_json_files(data_output_base_dir) + + # Release Ollama models from memory after data synthesis is complete + self._release_ollama_models() def gen_preference_data( self, @@ -187,6 +206,20 @@ def merge_json_files(self, data_output_base_dir: str): with open(merged_output_path, 'w', encoding='utf-8') as f: json.dump(merged_data, f, ensure_ascii=False, indent=2) + def _release_ollama_models(self): + """Release Ollama models from memory to free up VRAM for training. + + This method calls the release function defined in the train module. + It's important to release models after data synthesis and before training + to ensure VRAM is properly freed. 
+ """ + try: + from lpm_kernel.L2.train import release_ollama_models + release_ollama_models() + except Exception as e: + import logging + logging = logging.getLogger(__name__) + logging.warning(f"Failed to release Ollama models: {str(e)}") def clean_graphrag_keys(self): GRAPH_CONFIG = os.path.join( diff --git a/lpm_kernel/L2/memory_manager.py b/lpm_kernel/L2/memory_manager.py new file mode 100644 index 00000000..b4c8950a --- /dev/null +++ b/lpm_kernel/L2/memory_manager.py @@ -0,0 +1,149 @@ +"""Memory management utilities for PyTorch training. + +This module provides lightweight utilities to monitor memory usage +and configure PyTorch's built-in memory management features. +""" + +import os +import gc +import logging +import psutil +import torch +from typing import Dict, Any + +# Configure logging +logger = logging.getLogger(__name__) + +class MemoryManager: + """Simple memory manager that leverages PyTorch's built-in memory optimizations.""" + + def __init__(self): + """Initialize the memory manager.""" + self.cuda_available = torch.cuda.is_available() + self.process = psutil.Process(os.getpid()) + + # Remove redundant environment variable setting - now handled in train_for_user.sh + + def get_memory_info(self) -> Dict[str, Any]: + """Get current memory usage information.""" + info = { + "ram_used_percent": psutil.virtual_memory().percent, + "ram_used_gb": psutil.virtual_memory().used / (1024**3), + "ram_available_gb": psutil.virtual_memory().available / (1024**3), + "ram_total_gb": psutil.virtual_memory().total / (1024**3), + } + + if self.cuda_available: + try: + info.update({ + "vram_used_gb": torch.cuda.memory_allocated() / (1024**3), + "vram_reserved_gb": torch.cuda.memory_reserved() / (1024**3), + "vram_total_gb": torch.cuda.get_device_properties(0).total_memory / (1024**3), + }) + except RuntimeError as e: + logger.warning(f"Error getting CUDA memory info: {str(e)}") + self.cuda_available = False + + return info + + def cleanup_memory(self, force: bool = False) -> None: + """Free up memory by garbage collection and emptying CUDA cache.""" + # Run Python garbage collection + gc.collect() + + # Empty CUDA cache if available + if self.cuda_available: + torch.cuda.empty_cache() + + # Log memory status after cleanup + if force: + info = self.get_memory_info() + logger.info( + f"Memory after cleanup: RAM: {info['ram_used_gb']:.2f}GB / {info['ram_total_gb']:.2f}GB, " + f"VRAM: {info.get('vram_used_gb', 0):.2f}GB / {info.get('vram_total_gb', 0):.2f}GB" + ) + + def get_optimal_training_config(self) -> Dict[str, Any]: + """Get recommended configurations for model training based on hardware capabilities.""" + # Default configs that rely on PyTorch's automatic memory management + config = { + "device_map": "auto", + "fp16": False, + "bf16": False, + "gradient_checkpointing": True, + "gradient_accumulation_steps": 1, + } + + # Enable mixed precision based on hardware support + if self.cuda_available: + capability = torch.cuda.get_device_capability() + if capability[0] >= 8: # Ampere or newer (supports BF16) + config["bf16"] = True + elif capability[0] >= 7: # Volta or newer (supports FP16) + config["fp16"] = True + + # Adjust accumulation steps based on available memory + vram_gb = self.get_memory_info().get("vram_total_gb", 0) + if vram_gb < 8: # Small GPUs + config["gradient_accumulation_steps"] = 4 + elif vram_gb < 16: # Medium GPUs + config["gradient_accumulation_steps"] = 2 + + return config + + def optimize_model_for_training(self, model): + """Apply PyTorch's built-in memory 
optimizations for training.""" + # Enable gradient checkpointing if available + if hasattr(model, "gradient_checkpointing_enable"): + logger.info("Enabling gradient checkpointing for memory efficiency") + model.gradient_checkpointing_enable() + + # Enable memory-efficient attention for PyTorch 2.0+ + if hasattr(model, "config"): + try: + model.config.use_memory_efficient_attention = True + except: + pass + + # Enable flash attention for compatible GPUs + if self.cuda_available and torch.cuda.get_device_capability()[0] >= 8: + try: + model.config.attn_implementation = "flash_attention_2" + except: + pass + + return model + + def optimize_training_args(self, training_args): + """Configure training arguments for efficient memory usage.""" + if not training_args: + return None + + # Get optimal configuration based on hardware + config = self.get_optimal_training_config() + + # Apply configurations to training arguments + if not getattr(training_args, "fp16", False) and not getattr(training_args, "bf16", False): + training_args.fp16 = config["fp16"] + training_args.bf16 = config["bf16"] + + if not getattr(training_args, "gradient_checkpointing", False): + training_args.gradient_checkpointing = config["gradient_checkpointing"] + + if training_args.gradient_accumulation_steps == 1: + training_args.gradient_accumulation_steps = config["gradient_accumulation_steps"] + + logger.info("Training configuration optimized for memory efficiency:") + logger.info(f" Mixed precision: FP16={training_args.fp16}, BF16={training_args.bf16}") + logger.info(f" Gradient checkpointing: {training_args.gradient_checkpointing}") + logger.info(f" Gradient accumulation steps: {training_args.gradient_accumulation_steps}") + + return training_args + + +# Global memory manager instance +memory_manager = MemoryManager() + +def get_memory_manager() -> MemoryManager: + """Get the global memory manager instance.""" + return memory_manager \ No newline at end of file diff --git a/lpm_kernel/L2/merge_lora_weights.py b/lpm_kernel/L2/merge_lora_weights.py index 59877b2c..14e892f4 100644 --- a/lpm_kernel/L2/merge_lora_weights.py +++ b/lpm_kernel/L2/merge_lora_weights.py @@ -6,36 +6,123 @@ """ import argparse +import os +import gc +import sys +import logging +import traceback +import torch +import datetime +from typing import Optional, Dict, Any from peft import PeftModel from transformers import AutoModelForCausalLM, AutoTokenizer -import logging +from lpm_kernel.L2.memory_manager import get_memory_manager + +# Configure logging +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) def merge_lora_weights(base_model_path, lora_adapter_path, output_model_path): """Merge LoRA weights into a base model and save the result. This function loads a base model and a LoRA adapter, merges them together, - and saves the resulting model to the specified output path. + and saves the resulting model to the specified output path. It leverages + PyTorch's built-in memory management features. Args: base_model_path: Path to the base model directory. lora_adapter_path: Path to the LoRA adapter directory. output_model_path: Path where the merged model will be saved. 
""" - # Load the base model - logging.info(f"Loading base model from {base_model_path}") - base_model = AutoModelForCausalLM.from_pretrained(base_model_path) - tokenizer = AutoTokenizer.from_pretrained(base_model_path) - - # Load the LoRA adapter and apply it to the base model - lora_model = PeftModel.from_pretrained(base_model, lora_adapter_path) - - # Merge LoRA weights into the base model - merged_model = lora_model.merge_and_unload() - - # Save the merged model and tokenizer - merged_model.save_pretrained(output_model_path) - tokenizer.save_pretrained(output_model_path) + # Get memory manager + memory_manager = get_memory_manager() + + try: + # Log initial memory state + memory_info = memory_manager.get_memory_info() + logger.info(f"Initial memory state: RAM used: {memory_info['ram_used_gb']:.2f}GB, " + f"available: {memory_info['ram_available_gb']:.2f}GB") + + # Determine if CUDA is available and should be used + use_cuda = memory_manager.cuda_available + device = "cuda" if use_cuda else "cpu" + + if use_cuda: + logger.info(f"CUDA is available. VRAM used: {memory_info.get('vram_used_gb', 0):.2f}GB") + else: + logger.warning("CUDA not available or not enabled. Using CPU for model operations.") + + # Clean up memory before starting + memory_manager.cleanup_memory(force=True) + + # Explicitly set device configuration based on available hardware + device_map = "auto" if use_cuda else None + dtype = torch.float16 if use_cuda else torch.float32 + + logger.info(f"Loading base model from {base_model_path} with device_map={device_map}, dtype={dtype}") + + # Use explicit configuration for GPU utilization + base_model = AutoModelForCausalLM.from_pretrained( + base_model_path, + torch_dtype=dtype, + device_map=device_map + ) + + # Load tokenizer - this doesn't consume much memory + tokenizer = AutoTokenizer.from_pretrained(base_model_path) + + # Load the LoRA adapter and apply it to the base model + logger.info(f"Loading LoRA adapter from {lora_adapter_path}") + lora_model = PeftModel.from_pretrained(base_model, lora_adapter_path) + + # Merge weights - this is done automatically by PyTorch on appropriate devices + logger.info(f"Merging LoRA weights into base model on {device}") + merged_model = lora_model.merge_and_unload() + + # Clean up before saving + memory_manager.cleanup_memory() + + # Add inference optimization config to the merged model for faster startup + if use_cuda: + # Set inference-specific configuration in model config + if hasattr(merged_model.config, "torch_dtype"): + merged_model.config.torch_dtype = "float16" # Prefer float16 for inference + if not hasattr(merged_model.config, "pretraining_tp"): + merged_model.config.pretraining_tp = 1 # For tensor parallelism during inference + + # Set default inference device + if not hasattr(merged_model.config, "_default_inference_device"): + merged_model.config._default_inference_device = "cuda:0" + + logger.info("Added GPU optimization settings to model configuration") + + # Save merged model with shard size to prevent OOM errors during save + logger.info(f"Saving merged model to {output_model_path}") + merged_model.save_pretrained( + output_model_path, + safe_serialization=True, + max_shard_size="2GB" # Sharded saving to avoid memory spikes + ) + tokenizer.save_pretrained(output_model_path) + + # Save a special marker file to indicate this model should use GPU for inference + if use_cuda: + with open(os.path.join(output_model_path, "gpu_optimized.json"), "w") as f: + import json + json.dump({"gpu_optimized": True, "optimized_on": 
datetime.datetime.now().isoformat()}, f) + logger.info("Added GPU optimization marker file for faster service startup") + + logger.info("Model successfully merged and saved!") + + except Exception as e: + logger.error(f"Error during model merge: {str(e)}") + logger.error(traceback.format_exc()) + # Force cleanup + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + raise def merge_model_weights( diff --git a/lpm_kernel/L2/train.py b/lpm_kernel/L2/train.py index 539cfe8a..54a00ee6 100644 --- a/lpm_kernel/L2/train.py +++ b/lpm_kernel/L2/train.py @@ -28,10 +28,13 @@ create_and_prepare_model, formatting_prompts_func, create_chat_data, + release_ollama_models_early, ) from lpm_kernel.configs.logging import LOGGING_CONFIG import logging.config from lpm_kernel.configs.logging import get_train_process_logger +from lpm_kernel.L2.memory_manager import get_memory_manager + logger = get_train_process_logger() @@ -45,6 +48,25 @@ def __init__(self, *args, **kwargs): # Replace the default tqdm sys.modules["tqdm"].tqdm = LogTqdm +# Debug callback for logging training progress +class DebugCallback(transformers.TrainerCallback): + def __init__(self): + self.total_time = 0 + self.last_time = time.time() + + def on_step_end(self, args, state, control, **kwargs): + if state.global_step % 10 == 0: + current_time = time.time() + step_time = current_time - self.last_time + self.total_time += step_time + self.last_time = current_time + + # Log step time and training progress + logger.info(f"Step {state.global_step}: {step_time:.2f}s - Total training time: {self.total_time:.2f}s") + + def on_epoch_end(self, args, state, control, **kwargs): + logger.info(f"Epoch {state.epoch} completed") + @dataclass class ModelArguments: @@ -112,6 +134,10 @@ class ModelArguments: default=False, metadata={"help": "Enables UnSloth for training."}, ) + use_cuda: Optional[bool] = field( + default=False, + metadata={"help": "Enables CUDA GPU acceleration for training and inference when available."}, + ) @dataclass @@ -162,80 +188,117 @@ def main(model_args, data_args, training_args): for handler in logging.getLogger().handlers: handler.flush() - logger.info("start 1") + # Get memory manager for optimization + memory_manager = get_memory_manager() + memory_manager.cleanup_memory(force=True) + + # Release Ollama models if they exist to free up VRAM + if torch.cuda.is_available() and model_args.use_cuda: + release_ollama_models_early() + + logger.info("Initializing training with memory optimizations") set_seed(training_args.seed) - logger.info("start 2") - # model + + # Apply PyTorch memory optimizations to training arguments + logger.info("Applying memory optimizations to training configuration") + training_args = memory_manager.optimize_training_args(training_args) + + # --- Accelerate optimizer state offloading logic --- + # Enable optimizer state offload to CPU if VRAM is low and not using DeepSpeed + vram_total = memory_manager.get_memory_info().get("vram_total_gb", 0) + use_accelerate_offload = False + if torch.cuda.is_available() and model_args.use_cuda and vram_total > 0 and vram_total < 16: + # Only set if not already using DeepSpeed + if not hasattr(training_args, "deepspeed") or training_args.deepspeed is None: + logger.info("Enabling Hugging Face Accelerate optimizer state offload to CPU for low VRAM GPUs") + accelerate_config = { + "compute_environment": "LOCAL_MACHINE", + "deepspeed_config": None, + "distributed_type": "NO", + "downcast_bf16": False, + "fsdp_config": {}, + "main_training_function": 
"main", + "mixed_precision": "no", + "num_machines": 1, + "num_processes": 1, + "use_cpu": False, + "zero3_init_flag": False, + "offload_optimizer_device": "cpu", + "offload_param_device": "none" + } + training_args.accelerate_config = accelerate_config + use_accelerate_offload = True + + # Model loading with device_map="auto" for automatic offloading + logger.info(f"Loading model with automatic memory management from {model_args.model_name_or_path}") + + # Create model arguments dict with automatic offloading + model_kwargs = { + # Don't use "auto" device_map initially to avoid meta tensor issues + "device_map": None, + "trust_remote_code": True + } + + # Configure quantization if requested + if model_args.use_4bit_quantization: + from transformers import BitsAndBytesConfig + compute_dtype = getattr(torch, model_args.bnb_4bit_compute_dtype) + quant_storage_dtype = getattr(torch, model_args.bnb_4bit_quant_storage_dtype) + + model_kwargs["quantization_config"] = BitsAndBytesConfig( + load_in_4bit=model_args.use_4bit_quantization, + bnb_4bit_quant_type=model_args.bnb_4bit_quant_type, + bnb_4bit_compute_dtype=compute_dtype, + bnb_4bit_use_double_quant=model_args.use_nested_quant, + bnb_4bit_quant_storage=quant_storage_dtype, + ) + # For 4-bit models, we can use device_map="auto" + model_kwargs["device_map"] = "auto" + logger.info("Using 4-bit quantization for memory efficiency") + elif model_args.use_8bit_quantization: + from transformers import BitsAndBytesConfig + model_kwargs["quantization_config"] = BitsAndBytesConfig( + load_in_8bit=model_args.use_8bit_quantization + ) + # For 8-bit models, we can use device_map="auto" + model_kwargs["device_map"] = "auto" + logger.info("Using 8-bit quantization for memory efficiency") + + # Flash attention for memory efficiency when supported + if model_args.use_flash_attn and torch.cuda.is_available() and model_args.use_cuda: + model_kwargs["attn_implementation"] = "flash_attention_2" + logger.info("Using Flash Attention 2 for memory efficiency") + + # Load model with built-in memory management features model, peft_config, tokenizer = create_and_prepare_model( - model_args, data_args, training_args - ) - logger.info("start 3") - # gradient ckpt - model.config.use_cache = not training_args.gradient_checkpointing - training_args.gradient_checkpointing = ( - training_args.gradient_checkpointing and not model_args.use_unsloth + model_args, data_args, training_args, model_kwargs=model_kwargs ) - logger.info("start 4") - if training_args.gradient_checkpointing: - training_args.gradient_checkpointing_kwargs = { - "use_reentrant": model_args.use_reentrant - } - - # Configure system resources for optimal performance - def configure_system_resources(num_cores=None): - """ - Configure system resources to optimize training performance - - Args: - num_cores: Number of CPU cores to use, if None, automatically detect - """ - # Automatically detect available cores, if not specified - if num_cores is None: - num_cores = min(os.cpu_count(), 6) # Limit to 6 cores, match Docker configuration - - logger.info(f"Configuring system to use {num_cores} CPU cores") - - # Set environment variables - os.environ["OMP_NUM_THREADS"] = str(num_cores) - os.environ["MKL_NUM_THREADS"] = str(num_cores) - os.environ["NUMEXPR_NUM_THREADS"] = str(num_cores) - - # Set PyTorch thread count - torch.set_num_threads(num_cores) - - # If supported, set PyTorch multi-thread optimization - if hasattr(torch, "set_num_interop_threads"): - torch.set_num_interop_threads(num_cores) - - # Enable 
memory-optimized garbage collection - # import gc - # gc.enable() - - # # Monitor memory usage and clean up periodically - # def schedule_gc(): - # gc.collect() - # torch.cuda.empty_cache() if torch.cuda.is_available() else None - # return schedule_gc - - # If CUDA is available, set CUDA device - if torch.cuda.is_available(): - torch.cuda.set_device(0) - logger.info(f"CUDA is available. Using device: {torch.cuda.get_device_name(0)}") - # Display CUDA memory information - logger.info(f"CUDA memory allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB") - logger.info(f"CUDA memory reserved: {torch.cuda.memory_reserved(0) / 1024**2:.2f} MB") - # Call function to configure system resources - configure_system_resources() + # If model has meta tensors, handle them properly + if hasattr(model, "is_meta") and model.is_meta: + logger.info("Model has meta tensors, using to_empty() to properly initialize") + device = "cuda" if torch.cuda.is_available() and model_args.use_cuda else "cpu" + model = model.to_empty(device=device) + + # Apply gradient checkpointing for memory efficiency + if training_args.gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"): + logger.info("Enabling gradient checkpointing for memory efficiency") + model.gradient_checkpointing_enable() + model.config.use_cache = False + + # Allow only one full forward/backward pass at a time (if needed for memory) + if torch.cuda.is_available() and memory_manager.get_memory_info().get("vram_total_gb", 0) < 8: + torch.cuda.set_per_process_memory_fraction(0.9) + logger.info("Setting memory fraction limit to avoid OOM errors") # datasets train_dataset = create_chat_data( data_args, tokenizer, ) - + response_template = "\n<|im_start|>assistant\n" - + collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer) training_args.dataset_kwargs = { @@ -243,6 +306,45 @@ def configure_system_resources(num_cores=None): "add_special_tokens": data_args.add_special_tokens, } + # Use DeepSpeed to handle meta tensors if available + try: + # Only configure DeepSpeed if meta tensors are present and DeepSpeed is available + if hasattr(model, "is_meta") and model.is_meta: + logger.info("Model has meta tensors, checking DeepSpeed availability") + # First verify DeepSpeed is properly installed and importable + try: + import deepspeed + logger.info("DeepSpeed is available, configuring for meta tensor handling") + + # Configure with appropriate settings for meta tensors + training_args.deepspeed = { + "zero_stage": 3, + "offload_optimizer": { + "device": "cpu" + }, + "offload_param": { + "device": "cpu" + }, + "zero3_init_flag": True, + "zero_force_ds_cpu_optimizer": False + } + logger.info("DeepSpeed configured for meta tensor handling") + except ImportError: + logger.warning("DeepSpeed is not available, meta tensors will be handled differently") + # If DeepSpeed isn't available, use alternative approach to handle meta tensors + if torch.cuda.is_available() and model_args.use_cuda: + logger.info("Initializing meta tensors on GPU") + # Use device_map instead of DeepSpeed for meta tensor initialization + from accelerate import init_empty_weights + with init_empty_weights(): + model.to_empty(device="cuda") + else: + logger.info("Initializing meta tensors on CPU") + model.to_empty(device="cpu") + except Exception as e: + logger.warning(f"Could not configure meta tensor handling: {e}") + logger.warning(traceback.format_exc()) + trainer = SFTTrainer( model=model, tokenizer=tokenizer, @@ -252,145 +354,46 @@ def 
configure_system_resources(num_cores=None): formatting_func=formatting_prompts_func, data_collator=collator, ) + + # Print model details trainer.accelerator.print(f"{trainer.model}") - trainer.model.print_trainable_parameters() - - logger.info("start 6") - # train - checkpoint = None - if training_args.resume_from_checkpoint is not None: - logger.info("start 6.1") - checkpoint = training_args.resume_from_checkpoint - logger.info("start 6.2") - - class DebugCallback(transformers.TrainerCallback): - """ - Debug callback to monitor training process - """ - + + if hasattr(trainer.model, "print_trainable_parameters"): + trainer.model.print_trainable_parameters() + + # Memory usage tracking callback + class MemoryMonitorCallback(transformers.TrainerCallback): def __init__(self): - self.step_times = {} - self.current_step_start = None - - def on_train_begin(self, args, state, control, **kwargs): - logger.info("=== Training Begin ===") - logger.info("Checking initial conditions:") - trainer = kwargs.get("trainer") - if trainer: - # Check model status - logger.info(f"Model device: {trainer.model.device}") - logger.info(f"Model dtype: {next(trainer.model.parameters()).dtype}") - - # Check data loader - if hasattr(trainer, "train_dataset"): - logger.info(f"Training dataset size: {len(trainer.train_dataset)}") - - # Check optimizer - if hasattr(trainer, "optimizer"): - logger.info("Optimizer configuration:") - for i, group in enumerate(trainer.optimizer.param_groups): - logger.info( - f"Group {i}: lr={group['lr']}, weight_decay={group['weight_decay']}" - ) - - def on_step_begin(self, args, state, control, **kwargs): - self.current_step_start = time.time() - logger.info(f"\n=== Starting Step {state.global_step + 1} ===") - - # Check system status every 10 steps - if state.global_step % 10 == 0: - process = psutil.Process() - with process.oneshot(): - logger.info(f"CPU Usage: {process.cpu_percent()}%") - logger.info( - f"Memory Usage: {process.memory_info().rss / 1024**2:.2f}MB" - ) - logger.info(f"Thread Count: {process.num_threads()}") - + self.memory_manager = get_memory_manager() + def on_step_end(self, args, state, control, **kwargs): - if self.current_step_start: - step_time = time.time() - self.current_step_start - self.step_times[state.global_step] = step_time - avg_time = sum(self.step_times.values()) / len(self.step_times) - logger.info( - f"Step {state.global_step + 1} completed in {step_time:.2f}s (avg: {avg_time:.2f}s)" - ) - - # Check if step time is much longer than average - if step_time > avg_time * 2 and len(self.step_times) > 1: - logger.warning( - f"Step {state.global_step + 1} took {step_time:.2f}s, which is much longer than average!" 
- ) - - trainer = kwargs.get("trainer") - if trainer and hasattr(trainer, "optimizer"): - # Check gradient status - grad_norms = [] - for name, param in trainer.model.named_parameters(): - if param.grad is not None: - grad_norms.append(param.grad.norm().item()) - - if grad_norms: - avg_grad_norm = sum(grad_norms) / len(grad_norms) - logger.info(f"Average gradient norm: {avg_grad_norm:.5f}") - else: - logger.warning("No gradients found in this step!") - - def on_log(self, args, state, control, logs=None, **kwargs): - if logs: - logger.info(f"=== Logs for Step {state.global_step} ===") - for key, value in logs.items(): - logger.info(f"{key}: {value}") - - def on_train_end(self, args, state, control, **kwargs): - logger.info("=== Training Ended ===") - logger.info(f"Total steps completed: {state.global_step}") - if self.step_times: - avg_time = sum(self.step_times.values()) / len(self.step_times) - logger.info(f"Average step time: {avg_time:.2f}s") + # Check memory every 5 steps + if state.global_step % 5 == 0 and torch.cuda.is_available(): + info = self.memory_manager.get_memory_info() + vram_usage_pct = info.get("vram_used_gb", 0) / info.get("vram_total_gb", 1) * 100 + + if vram_usage_pct > 90: + logger.info(f"VRAM usage high ({vram_usage_pct:.1f}%), cleaning cache") + self.memory_manager.cleanup_memory() + + def on_save(self, args, state, control, **kwargs): + # Free up memory before saving + self.memory_manager.cleanup_memory(force=True) + # Add memory monitoring + trainer.add_callback(MemoryMonitorCallback()) + + # Add existing debug callback trainer.add_callback(DebugCallback()) - # Add more detailed logs - logger.info("Starting training preparation...") - try: - logger.info("Initializing training process...") - # Check model loading and structure - logger.info("Analyzing model structure...") - model = trainer.model - - def print_model_structure(model, prefix=""): - logger.info(f"{prefix}Model class: {model.__class__.__name__}") - for name, child in model.named_children(): - logger.info(f"{prefix}Child: {name} ({child.__class__.__name__})") - if len(list(child.named_children())) > 0: - print_model_structure(child, prefix + " ") - - # print_model_structure(model) - - # Check model size - total_params = sum(p.numel() for p in trainer.model.parameters()) - trainable_params = sum( - p.numel() for p in trainer.model.parameters() if p.requires_grad - ) - logger.info(f"Total parameters: {total_params:,}") - logger.info(f"Trainable parameters: {trainable_params:,}") - - # Check optimizer settings - logger.info("Checking optimizer settings...") - - # Check data loader - train_dataloader = trainer.get_train_dataloader() - logger.info(f"Train dataloader created with {len(train_dataloader)} batches") - - process = psutil.Process() - memory_info = process.memory_info() - logger.info(f"Memory usage details:") - logger.info(f"RSS (Resident Set Size): {memory_info.rss / 1024**2:.2f}MB") - logger.info(f"VMS (Virtual Memory Size): {memory_info.vms / 1024**2:.2f}MB") + # Resume from checkpoint if specified + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint - # Start training - logger.info("Starting actual training process...") + # Training with automatic memory management + try: + logger.info("Starting training with memory-optimized configuration") trainer.train(resume_from_checkpoint=checkpoint) except Exception as e: logger.error(f"Error during training: {str(e)}") @@ -398,9 +401,13 @@ def print_model_structure(model, 
prefix=""): logger.error(f"Traceback: {traceback.format_exc()}") raise - logger.info("start 7") + # Save the model if trainer.is_fsdp_enabled: trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT") + + # Clean up before saving + memory_manager.cleanup_memory(force=True) + trainer.save_model() logger.info("Training completed successfully") diff --git a/lpm_kernel/L2/train_for_user.sh b/lpm_kernel/L2/train_for_user.sh old mode 100755 new mode 100644 index 92965a4f..976426fc --- a/lpm_kernel/L2/train_for_user.sh +++ b/lpm_kernel/L2/train_for_user.sh @@ -6,6 +6,7 @@ NUM_TRAIN_EPOCHS="3" CONCURRENCY_THREADS="2" DATA_SYNTHESIS_MODE="low" HALF=False +USE_CUDA=False # Default to False, will be overridden by parameter IS_COT=False # Process parameters @@ -15,33 +16,75 @@ while [[ "$#" -gt 0 ]]; do --epochs) NUM_TRAIN_EPOCHS="$2"; shift ;; --threads) CONCURRENCY_THREADS="$2"; shift ;; --mode) DATA_SYNTHESIS_MODE="$2"; shift ;; + --cuda) + # Convert string to lowercase for consistent comparison + cuda_value=$(echo "$2" | tr '[:upper:]' '[:lower:]') + if [[ "$cuda_value" == "true" || "$cuda_value" == "1" || "$cuda_value" == "yes" ]]; then + USE_CUDA=True + echo "CUDA enabled by user configuration." + else + USE_CUDA=False + echo "CUDA disabled by user configuration." + fi + shift ;; --is_cot) IS_COT="$2"; shift ;; *) echo "Unknown parameter: $1"; exit 1 ;; esac shift done +# Explicitly log the CUDA setting passed from the command line +echo "CUDA parameter received: $USE_CUDA" + +# Verify CUDA availability if enabled +if [[ "$USE_CUDA" == "True" ]]; then + # Set CUDA environment variables to ensure PyTorch detects GPU + export CUDA_VISIBLE_DEVICES=0 + echo "CUDA_VISIBLE_DEVICES set to 0" + + # Set CUDA_LAUNCH_BLOCKING to 0 for async operations (better performance) + export CUDA_LAUNCH_BLOCKING=0 + echo "CUDA_LAUNCH_BLOCKING set to 0 for better performance" +else + # Explicitly disable CUDA + export CUDA_VISIBLE_DEVICES="" + echo "CUDA_VISIBLE_DEVICES explicitly disabled" +fi + # Log the parameters being used echo "Using training parameters:" echo " Learning rate: $LEARNING_RATE" echo " Number of epochs: $NUM_TRAIN_EPOCHS" echo " Concurrency threads: $CONCURRENCY_THREADS" echo " Data synthesis mode: $DATA_SYNTHESIS_MODE" +echo " Use CUDA: $USE_CUDA" echo " Is chain of thought: $IS_COT" # If concurrency threads are set, configure related environment variables if [ "$CONCURRENCY_THREADS" != "1" ]; then + # Limit the number of parallel threads to avoid memory issues export OMP_NUM_THREADS=$CONCURRENCY_THREADS export MKL_NUM_THREADS=$CONCURRENCY_THREADS export NUMEXPR_NUM_THREADS=$CONCURRENCY_THREADS + # Add torch-specific threading controls + export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128 echo "Set thread environment variables to $CONCURRENCY_THREADS" fi -# Add BF16 option based on the platform -if [ "$PLATFORM" != "apple" ]; then +# Add BF16 option based on the platform and CUDA availability +if [ "$PLATFORM" != "apple" ] && [ "$USE_CUDA" == "True" ]; then HALF=True + echo "Enabling BF16 half precision for non-Apple platform with CUDA" +else + echo "Using standard precision (not using BF16)" fi +# Print environment for debugging +echo "Environment configuration:" +echo " CUDA_VISIBLE_DEVICES: ${CUDA_VISIBLE_DEVICES}" +echo " PYTORCH_CUDA_ALLOC_CONF: ${PYTORCH_CUDA_ALLOC_CONF}" +echo " Using half precision: ${HALF}" + # Execute training script with parameters from environment variables python lpm_kernel/L2/train.py \ --seed 42 \ @@ -70,7 +113,7 @@ python lpm_kernel/L2/train.py 
\ --per_device_train_batch_size 2 \ --gradient_accumulation_steps $CONCURRENCY_THREADS \ --gradient_checkpointing True \ - --use_reentrant True \ + --use_reentrant False \ --use_peft_lora True \ --lora_r 8 \ --lora_alpha 16 \ @@ -78,6 +121,7 @@ python lpm_kernel/L2/train.py \ --lora_target_modules "all-linear" \ --use_4bit_quantization False \ --use_nested_quant False \ - --bnb_4bit_compute_dtype "bfloat16"\ - --is_cot $IS_COT + --bnb_4bit_compute_dtype "bfloat16" \ + --is_cot $IS_COT \ + --use_cuda $USE_CUDA diff --git a/lpm_kernel/L2/utils.py b/lpm_kernel/L2/utils.py index 5b13a808..000debc9 100644 --- a/lpm_kernel/L2/utils.py +++ b/lpm_kernel/L2/utils.py @@ -33,12 +33,171 @@ MEMORY_COT_PROMPT, ) +# Add import for memory manager +from .memory_manager import get_memory_manager +import gc +import requests + +# Initialize the logger +logger = logging.getLogger(__name__) # Default chat templates for different model formats DEFAULT_CHATML_CHAT_TEMPLATE = "{% for message in messages %}\n{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% if loop.last and add_generation_prompt %}{{'<|im_start|>assistant\n' }}{% endif %}{% endfor %}" DEFAULT_ZEPHYR_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" +def release_ollama_models_early(): + """Release Ollama models from memory as early as possible before model loading. + + This function uses the Ollama API with keep_alive=0 parameter to properly unload models + and free up VRAM before loading the training model. + """ + try: + from lpm_kernel.api.services.user_llm_config_service import UserLLMConfigService + import json + + logger.info("Early release of Ollama models to free up VRAM for training") + + # Get current user LLM config to identify models to release + user_llm_config_service = UserLLMConfigService() + user_llm_config = user_llm_config_service.get_available_llm() + + if not user_llm_config: + logger.warning("No user LLM configuration found. 
Skipping Ollama model release.") + return + + # Track which models have been released + released_models = set() + + def get_generate_url(base_endpoint): + """Helper function to get the API endpoint for unloading models""" + if not base_endpoint: + return None + + base_url = base_endpoint.rstrip("/") + + # Convert to API base URL if needed (may be v1 format or direct ollama format) + if "/v1/" in base_url: + api_base = base_url.split("/v1/")[0] + return f"{api_base}/api/generate" + else: + # Check if this is a non-localhost Ollama instance + if "ollama" in base_url.lower(): + if "localhost" in base_url or "127.0.0.1" in base_url: + return "http://localhost:11434/api/generate" + else: + # Extract the base URL and use it + parts = base_url.split("//") + if len(parts) > 1: + host = parts[1].split("/")[0] + return f"{parts[0]}//{host}/api/generate" + + # Default ollama endpoint as fallback + return "http://localhost:11434/api/generate" + + # Release chat model if using Ollama + if "ollama" in user_llm_config.chat_endpoint.lower() and user_llm_config.chat_model_name: + chat_model = user_llm_config.chat_model_name + generate_url = get_generate_url(user_llm_config.chat_endpoint) + + if not generate_url: + logger.warning(f"Could not determine API endpoint for chat model: {chat_model}") + else: + logger.info(f"Releasing Ollama chat model: {chat_model} via {generate_url}") + + try: + # Set up headers with API key if provided + headers = { + "Content-Type": "application/json" + } + if user_llm_config.chat_api_key: + headers["Authorization"] = f"Bearer {user_llm_config.chat_api_key}" + + # Use the proper generate endpoint with keep_alive=0 to unload + payload = { + "model": chat_model, + "keep_alive": 0, + "prompt": " " # Minimal prompt needed for request + } + + unload_response = requests.post( + generate_url, + headers=headers, + data=json.dumps(payload), + timeout=30 # Add timeout to prevent hanging + ) + + if unload_response.status_code < 300: + logger.info(f"✅ Successfully unloaded chat model: {chat_model}") + released_models.add(chat_model) + else: + logger.warning(f"Failed to unload model via API: {unload_response.status_code} - {unload_response.text}") + except Exception as e: + logger.warning(f"Failed to release chat model {chat_model}: {str(e)}") + + # Release embedding model if different from chat model and using Ollama + if (user_llm_config.embedding_model_name and + "ollama" in user_llm_config.embedding_endpoint.lower() and + user_llm_config.embedding_model_name != user_llm_config.chat_model_name and + user_llm_config.embedding_model_name not in released_models): + + embedding_model = user_llm_config.embedding_model_name + generate_url = get_generate_url(user_llm_config.embedding_endpoint) + + if not generate_url: + logger.warning(f"Could not determine API endpoint for embedding model: {embedding_model}") + else: + logger.info(f"Releasing Ollama embedding model: {embedding_model} via {generate_url}") + + try: + # Set up headers with API key if provided + headers = { + "Content-Type": "application/json" + } + if user_llm_config.embedding_api_key: + headers["Authorization"] = f"Bearer {user_llm_config.embedding_api_key}" + + # Use the proper generate endpoint with keep_alive=0 to unload + payload = { + "model": embedding_model, + "keep_alive": 0, + "prompt": " " # Minimal prompt needed for request + } + + unload_response = requests.post( + generate_url, + headers=headers, + data=json.dumps(payload), + timeout=30 # Add timeout to prevent hanging + ) + + if unload_response.status_code < 300: + 
logger.info(f"✅ Successfully unloaded embedding model: {embedding_model}") + released_models.add(embedding_model) + else: + logger.warning(f"Failed to unload model via API: {unload_response.status_code} - {unload_response.text}") + except Exception as e: + logger.warning(f"Failed to release embedding model {embedding_model}: {str(e)}") + + # Final cleanup and verification + if torch.cuda.is_available(): + torch.cuda.empty_cache() + memory_info = get_memory_manager().get_memory_info() + vram_used = memory_info.get('vram_used_gb', 0) + vram_total = memory_info.get('vram_total_gb', 1) + logger.info(f"VRAM after early model release: {vram_used:.2f}GB / {vram_total:.2f}GB ({vram_used/vram_total*100:.1f}%)") + + if released_models: + logger.info(f"Early release completed for {len(released_models)} Ollama models: {', '.join(released_models)}") + else: + logger.info("No Ollama models were released early") + + except Exception as e: + import traceback + logger.error(f"Error in early Ollama model release: {str(e)}") + logger.error(traceback.format_exc()) + + def count_tokens_from_string(string: str, encoding_name: str = "cl100k_base") -> int: """Returns the number of tokens in a text string using a specified encoding. @@ -105,21 +264,31 @@ def list(cls): """Returns a list of all special tokens.""" return [token.value for token in cls] -def create_and_prepare_model(args, data_args, training_args): +def create_and_prepare_model(args, data_args, training_args, model_kwargs=None): """Creates and prepares a model for training. Args: args: Model arguments containing model configuration. data_args: Data arguments for training. training_args: Training configuration arguments. + model_kwargs: Additional kwargs to pass to the model loading function. Returns: - Tuple of (model, tokenizer) ready for training. + Tuple of (model, tokenizer, peft_config) ready for training. 
""" + # Get the memory manager for adaptive loading + memory_manager = get_memory_manager() + model_kwargs = model_kwargs or {} + + # Release Ollama models early before we load any models + if torch.cuda.is_available() and args.use_cuda: + release_ollama_models_early() + # Force cleanup memory after releasing Ollama models + memory_manager.cleanup_memory(force=True) + if args.use_unsloth: from unsloth import FastLanguageModel bnb_config = None - quant_storage_dtype = None if ( torch.distributed.is_available() @@ -129,55 +298,142 @@ def create_and_prepare_model(args, data_args, training_args): ): raise NotImplementedError("Unsloth is not supported in distributed training") - if args.use_4bit_quantization: - compute_dtype = getattr(torch, args.bnb_4bit_compute_dtype) - quant_storage_dtype = getattr(torch, args.bnb_4bit_quant_storage_dtype) - - bnb_config = BitsAndBytesConfig( - load_in_4bit=args.use_4bit_quantization, - bnb_4bit_quant_type=args.bnb_4bit_quant_type, - bnb_4bit_compute_dtype=compute_dtype, - bnb_4bit_use_double_quant=args.use_nested_quant, - bnb_4bit_quant_storage=quant_storage_dtype, - ) + # Clean up memory before loading model + memory_manager.cleanup_memory() + + # Check for CUDA availability and use it if enabled + cuda_available = torch.cuda.is_available() + use_cuda_requested = args.use_cuda + device = "cpu" + + # Always enable memory-adaptive loading by default (device_map="auto"), unless CUDA is off + if cuda_available and use_cuda_requested: + device = "cuda" + model_kwargs["device_map"] = "auto" + else: + if use_cuda_requested and not cuda_available: + logger.warning("⚠️ CUDA was requested but is not available on this system. Falling back to CPU.") + elif cuda_available and not use_cuda_requested: + logger.info("ℹ️ CUDA is available but not requested. Using CPU as specified.") + else: + logger.info("ℹ️ CUDA is not available. 
Using CPU for training.") + # Explicitly remove device_map to force CPU-only + if "device_map" in model_kwargs: + model_kwargs.pop("device_map") + logger.info("Using CPU for model training and inference.") + + # Configure quantization based on available memory + # Use model_kwargs quantization_config if provided, otherwise build it + if "quantization_config" not in model_kwargs: + if args.use_4bit_quantization: + compute_dtype = getattr(torch, args.bnb_4bit_compute_dtype) + quant_storage_dtype = getattr(torch, args.bnb_4bit_quant_storage_dtype) + + bnb_config = BitsAndBytesConfig( + load_in_4bit=args.use_4bit_quantization, + bnb_4bit_quant_type=args.bnb_4bit_quant_type, + bnb_4bit_compute_dtype=compute_dtype, + bnb_4bit_use_double_quant=args.use_nested_quant, + bnb_4bit_quant_storage=quant_storage_dtype, + ) + model_kwargs["quantization_config"] = bnb_config - if compute_dtype == torch.float16 and args.use_4bit_quantization: - major, _ = torch.cuda.get_device_capability() - if major >= 8: - logging.info("=" * 80) - logging.info( - "Your GPU supports bfloat16, you can accelerate training with the argument --bf16" - ) - logging.info("=" * 80) + if compute_dtype == torch.float16 and args.use_4bit_quantization: + major, _ = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0) + if major >= 8: + logger.info("Your GPU supports bfloat16, you can accelerate training with the argument --bf16") elif args.use_8bit_quantization: bnb_config = BitsAndBytesConfig(load_in_8bit=args.use_8bit_quantization) + model_kwargs["quantization_config"] = bnb_config - if args.use_unsloth: - # Load model - model, _ = FastLanguageModel.from_pretrained( - model_name=args.model_name_or_path, - max_seq_length=data_args.max_seq_length, - dtype=None, - load_in_4bit=args.use_4bit_quantization, - ) - else: - if os.getenv("PLATFORM") != "apple": - model = AutoModelForCausalLM.from_pretrained( - args.model_name_or_path, - quantization_config=bnb_config, - trust_remote_code=True, - attn_implementation="flash_attention_2" if args.use_flash_attn else "eager", - torch_dtype=torch.bfloat16 - ) + # Load model with memory-adaptive approach + model = None + tokenizer = None + peft_config = None + + try: + # First try loading the model with the requested configuration + if args.use_unsloth: + # Load model with Unsloth using memory manager + unsloth_kwargs = { + "model_name": args.model_name_or_path, + "max_seq_length": data_args.max_seq_length, + "dtype": None, + "load_in_4bit": args.use_4bit_quantization, + "load_in_8bit": args.use_8bit_quantization, + "trust_remote_code": True, + "device_map": model_kwargs.get("device_map", "auto") if args.use_cuda and torch.cuda.is_available() else None, + } + + logger.info(f"Loading model with Unsloth with parameters: {unsloth_kwargs}") + model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs) + else: + # Load model with standard approach + load_kwargs = { + "trust_remote_code": True, + } + + # Use any provided model_kwargs + load_kwargs.update(model_kwargs) + + if "attn_implementation" not in load_kwargs and args.use_flash_attn: + load_kwargs["attn_implementation"] = "flash_attention_2" + + # Set default device_map if not specified + if "device_map" not in load_kwargs and args.use_cuda and torch.cuda.is_available(): + load_kwargs["device_map"] = "auto" + + logger.info(f"Loading model with parameters: {load_kwargs}") + model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, **load_kwargs) + + except (RuntimeError, torch.cuda.OutOfMemoryError, MemoryError) 
as e: + # If standard approaches fail, try progressive fallbacks + logger.warning(f"Failed to load model with standard settings: {str(e)}") + logger.info("Falling back to adaptive model loading...") + + # First cleanup to ensure maximum memory available + memory_manager.cleanup_memory(force=True) + + try: + # Try with simpler configuration - float16 instead of bfloat16 + logger.info("Attempting to load with float16 precision...") model = AutoModelForCausalLM.from_pretrained( args.model_name_or_path, - quantization_config=bnb_config, + device_map="auto" if torch.cuda.is_available() and args.use_cuda else None, + torch_dtype=torch.float16 if torch.cuda.is_available() and args.use_cuda else None, trust_remote_code=True ) + except (RuntimeError, torch.cuda.OutOfMemoryError, MemoryError) as e: + # If that fails too, try even more conservative loading + logger.warning(f"Float16 loading failed: {str(e)}") + memory_manager.cleanup_memory(force=True) + + try: + # Try with CPU offloading and gradual loading + logger.info("Attempting most conservative loading with CPU offloading...") + model = AutoModelForCausalLM.from_pretrained( + args.model_name_or_path, + device_map="auto", + offload_folder="offload_folder", + offload_state_dict=True, + torch_dtype=torch.float16 if torch.cuda.is_available() else None, + trust_remote_code=True, + low_cpu_mem_usage=True + ) + except Exception as e: + # If all fallbacks fail, it's a fatal error + logger.error(f"All adaptive loading approaches failed: {str(e)}") + raise RuntimeError(f"Failed to load model with any memory adaptation technique: {str(e)}") - peft_config = None - chat_template = None + # If still not loaded, it's a fatal error + if model is None: + raise RuntimeError("Failed to load model with any memory adaptation technique") + + # Apply memory optimization to model + model = memory_manager.optimize_model_for_training(model) + + # Configure LoRA if requested if args.use_peft_lora and not args.use_unsloth: peft_config = LoraConfig( lora_alpha=args.lora_alpha, @@ -190,6 +446,7 @@ def create_and_prepare_model(args, data_args, training_args): else args.lora_target_modules, ) + # Load tokenizer - tokenizers are usually small and don't need memory management special_tokens = None chat_template = None if args.chat_template_format == "chatml": @@ -214,22 +471,43 @@ def create_and_prepare_model(args, data_args, training_args): tokenizer = AutoTokenizer.from_pretrained( args.model_name_or_path, trust_remote_code=True ) - tokenizer.pad_token = tokenizer.eos_token + # Make sure pad_token is set + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + # Apply Unsloth LoRA if requested and check memory status if args.use_unsloth: - # Do model patching and add fast LoRA weights - model = FastLanguageModel.get_peft_model( - model, - lora_alpha=args.lora_alpha, - lora_dropout=args.lora_dropout, - r=args.lora_r, - target_modules=args.lora_target_modules.split(",") - if args.lora_target_modules != "all-linear" - else args.lora_target_modules, - use_gradient_checkpointing=training_args.gradient_checkpointing, - random_state=training_args.seed, - max_seq_length=data_args.max_seq_length, - ) + try: + # Clean up first + memory_manager.cleanup_memory() + + # Apply LoRA with memory monitoring + model = FastLanguageModel.get_peft_model( + model, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + r=args.lora_r, + target_modules=args.lora_target_modules.split(",") + if args.lora_target_modules != "all-linear" + else args.lora_target_modules, + 
use_gradient_checkpointing=training_args.gradient_checkpointing, + random_state=training_args.seed, + max_seq_length=data_args.max_seq_length, + ) + + except Exception as e: + logger.error(f"Failed to apply Unsloth LoRA: {str(e)}") + # If Unsloth fails, we might need to fall back to standard training + if args.use_cuda and torch.cuda.is_available(): + logger.warning("Low VRAM detected, moving model to CPU") + model = model.cpu() + torch.cuda.empty_cache() + + # Final memory status check + memory_info = memory_manager.get_memory_info() + logger.info(f"Memory after model preparation: RAM: {memory_info['ram_used_gb']:.2f}GB / {memory_info['ram_total_gb']:.2f}GB") + if torch.cuda.is_available(): + logger.info(f"VRAM: {memory_info.get('vram_used_gb', 0):.2f}GB / {memory_info.get('vram_total_gb', 0):.2f}GB") return model, peft_config, tokenizer @@ -343,11 +621,11 @@ def setup_logger(log_path, logger_name="download_logger"): return logger -def save_hf_model(model_name="Qwen2.5-0.5B-Instruct", log_file_path=None) -> str: +def save_hf_model(model_name=None, log_file_path=None) -> str: """Saves a Hugging Face model locally. Args: - model_name: Name of the model to save. Defaults to "Qwen2.5-0.5B-Instruct". + model_name: Name of the model to save. If None, will attempt to get from config. log_file_path: Path to save download logs. If None, uses default path. Returns: @@ -360,25 +638,39 @@ def save_hf_model(model_name="Qwen2.5-0.5B-Instruct", log_file_path=None) -> str # Setup logging logger = setup_logger(log_file_path) + # If no model name provided, attempt to get from training configuration + if not model_name: + try: + from lpm_kernel.configs.config import Config + config = Config() + model_name = config.get("training", {}).get("model_name") + if not model_name: + logger.warning("No model name provided and none found in config. Using Qwen2.5-0.5B-Instruct as fallback.") + model_name = "Qwen2.5-0.5B-Instruct" + except Exception as e: + logger.warning(f"Failed to get model name from config: {str(e)}. Using Qwen2.5-0.5B-Instruct as fallback.") + model_name = "Qwen2.5-0.5B-Instruct" + base_dir = os.path.join(os.getcwd(), "resources/L2/base_models") + # Normalize model name and check for path traversal attempts normalized_model_name = os.path.normpath(model_name) if ".." 
in normalized_model_name or normalized_model_name.startswith("/"): - raise ValueError("Invalid model name") + raise ValueError("Invalid model name - potential path traversal attempt") + + # Prepare save path save_path = os.path.join(base_dir, normalized_model_name) os.makedirs(save_path, exist_ok=True) - from huggingface_hub import list_repo_files, configure_http_backend - import requests + from huggingface_hub import list_repo_files, configure_http_backend, hf_hub_download from tqdm import tqdm - from tqdm.contrib.concurrent import thread_map - import shutil + from concurrent.futures import ThreadPoolExecutor + import traceback - # Set a higher timeout, but remove the unsupported pool_size parameter + # Configure HTTP backend more simply try: - # Try using the timeout parameter to configure configure_http_backend(timeout=100.0) - except TypeError: - # If the timeout parameter is also not supported, do not use any parameters + except Exception as e: + logger.warning(f"Failed to configure HTTP backend with timeout: {e}") try: configure_http_backend() except Exception as e: @@ -391,31 +683,31 @@ def save_hf_model(model_name="Qwen2.5-0.5B-Instruct", log_file_path=None) -> str hf_model_name = f"Qwen/{model_name}" try: - # First get the list of all files in the repository + # Get list of files to download files = list_repo_files(hf_model_name) logger.info(f"Found {len(files)} files to download from {hf_model_name}") - - # Define a function for downloading a single file and recording progress - def download_file_with_progress(file_info): - filename, file_path = file_info + + def download_file_with_progress(filename): + """Download a single file from the model repository""" + local_file_path = os.path.join(save_path, filename) + + # Create directories if they don't exist + os.makedirs(os.path.dirname(local_file_path), exist_ok=True) + + # Check if file already exists and is not empty + if os.path.exists(local_file_path) and os.path.getsize(local_file_path) > 0: + logger.info(f"File already exists: {filename}") + return True + try: # Build the download URL url = f"https://huggingface.co/{hf_model_name}/resolve/main/{filename}" - # Target file path - local_file_path = os.path.join(save_path, filename) - os.makedirs(os.path.dirname(local_file_path), exist_ok=True) - - # Check if the file already exists - if os.path.exists(local_file_path): - logger.info(f"File already exists: {filename}") - return filename, True - # Get file size response = requests.head(url) total_size = int(response.headers.get('content-length', 0)) - # If the size cannot be obtained, set a default value or do not display the percentage + # If the size cannot be obtained, set a default value if total_size == 0: logger.info(f"Starting download of file: {filename} (Size unknown)") else: @@ -423,7 +715,7 @@ def download_file_with_progress(file_info): # Create the file to write to with open(local_file_path, 'wb') as f: - # Create a progress bar, if the total size is unknown, set it to None + # Create a progress bar progress_bar = tqdm( total=total_size if total_size > 0 else None, unit='iB', @@ -432,84 +724,87 @@ def download_file_with_progress(file_info): disable=False ) - # Define the progress callback function + # Define progress callback def progress_callback(current, total): - # Update the progress bar progress_bar.update(current - progress_bar.n) - # Record the log every 1MB (or a value close to 1MB) - if current % (1024 * 1024) < 8192: # Record every 1MB - # Ensure total is greater than 0 before calculating the percentage - if 
total and total > 0: # Use and to ensure total is not None and greater than 0 + # Log progress every ~1MB + if current % (1024 * 1024) < 8192: + if total and total > 0: percent = current / total * 100 logger.info(f"File {filename}: Downloaded {current/1024/1024:.2f} MB / {total/1024/1024:.2f} MB ({percent:.2f}%)") else: - # If the total size is unknown or 0, only show the downloaded size logger.info(f"File {filename}: Downloaded {current/1024/1024:.2f} MB (total size unknown)") - # Use the request library to download the file and update the progress + # Download file with progress tracking response = requests.get(url, stream=True) if response.status_code == 200: downloaded = 0 - # Check if the response contains the Content-Length header information + # Update total size if needed actual_total = int(response.headers.get('content-length', 0)) if actual_total > 0 and (total_size == 0 or total_size != actual_total): - # If the HEAD request did not return the correct size, but the GET request did, then update the total size total_size = actual_total logger.info(f"Updated file size for {filename}: {total_size / 1024 / 1024:.2f} MB") progress_bar.total = total_size progress_bar.refresh() for chunk in response.iter_content(chunk_size=8192): - if chunk: # Filter out empty chunks that keep the connection alive + if chunk: f.write(chunk) downloaded += len(chunk) progress_callback(downloaded, total_size) progress_bar.close() logger.info(f"Completed download of file: {filename}") - return filename, True + return True else: logger.error(f"Failed to download {filename}: HTTP status {response.status_code}") - return filename, False + failed_files.append(filename) + return False except Exception as e: - logger.error(f"Error downloading {filename}: {str(e)}") - return filename, False - - # Create a list of file path information - file_infos = [(filename, os.path.join(save_path, filename)) for filename in files] - - # Use a thread pool to download all files in parallel - logger.info(f"Starting parallel download of {len(file_infos)} files") - - # Use a thread pool to download all files in parallel - from concurrent.futures import ThreadPoolExecutor + logger.error(f"Failed to download {filename}: {str(e)}") + failed_files.append(filename) + return False + + # Keep track of failed files for potential retry + failed_files = [] - # Limit the number of concurrent requests to avoid too many requests - max_workers = min(10, len(file_infos)) - results = [] + # Use ThreadPoolExecutor for parallel downloads with controlled concurrency + max_workers = min(8, len(files)) # Limit concurrent downloads to avoid overloading + successful_downloads = 0 with ThreadPoolExecutor(max_workers=max_workers) as executor: - # Submit all download tasks - future_to_file = {executor.submit(download_file_with_progress, file_info): file_info[0] - for file_info in file_infos} - - # Wait for all tasks to complete and collect results - for future in tqdm(future_to_file, desc="Overall Progress", unit="file"): - filename = future_to_file[future] - try: + # Create a progress bar for overall download progress + with tqdm(total=len(files), desc="Downloading model files") as progress: + futures = [executor.submit(download_file_with_progress, file) for file in files] + + for future in futures: result = future.result() - results.append(result) - logger.info(f"Finished processing {filename}") - except Exception as exc: - logger.error(f'{filename} generated an exception: {exc}') - results.append((filename, False)) + if result: + successful_downloads += 
1 + progress.update(1) + + # Report progress periodically + if progress.n % 5 == 0 or progress.n == len(files): + logger.info(f"Downloaded {progress.n}/{len(files)} files ({successful_downloads} successful)") - # Check the download results - success_count = sum(1 for _, success in results if success) - logger.info(f"Download completed. Successfully downloaded {success_count}/{len(files)} files.") + # Handle any failed downloads + if failed_files: + logger.warning(f"Failed to download {len(failed_files)} files. First few: {failed_files[:5]}") + + # If most files failed, there might be an issue with the model repository + if len(failed_files) > len(files) * 0.5: + logger.error(f"More than 50% of files failed to download. There might be an issue with the model repository.") + raise RuntimeError("Too many files failed to download") + + # If critical files failed (like model weights or config), warn specifically + critical_patterns = ['model.safetensors', 'config.json', 'tokenizer.json'] + critical_failed = [f for f in failed_files if any(pattern in f for pattern in critical_patterns)] + if critical_failed: + logger.error(f"Failed to download critical files: {critical_failed}") + raise RuntimeError(f"Failed to download critical model files: {', '.join(critical_failed)}") # Record the download completion information try: @@ -527,15 +822,14 @@ def progress_callback(current, total): raise except KeyboardInterrupt: logger.warning(f"Download interrupted by user for model: {model_name}") + # Clean up partial downloads raise except Exception as e: - # Log any errors that occur logger.error(f"Error downloading model: {str(e)}") + logger.error(traceback.format_exc()) raise - return save_path - def format_timestr(utc_time_str): """Formats a UTC time string to a more readable format. 
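
Note: the release_ollama_models_early() helper introduced above relies on Ollama's behavior of unloading a model when a /api/generate request is sent with "keep_alive" set to 0, as the patch's comments describe. A minimal standalone sketch of that unload call follows; the default endpoint matches the fallback used in the patch, while the model name in the usage comment is only an illustrative placeholder.

    import json
    import requests

    def unload_ollama_model(model: str,
                            endpoint: str = "http://localhost:11434/api/generate",
                            timeout: int = 30) -> bool:
        """Ask an Ollama server to drop a model from VRAM.

        A generate request with keep_alive=0 tells the server to unload the
        model as soon as this (effectively empty) request completes.
        """
        payload = {"model": model, "prompt": " ", "keep_alive": 0}
        resp = requests.post(
            endpoint,
            headers={"Content-Type": "application/json"},
            data=json.dumps(payload),
            timeout=timeout,
        )
        # Any 2xx response means the server accepted the unload request.
        return resp.status_code < 300

    # Hypothetical usage before training starts (model name is a placeholder):
    # unload_ollama_model("qwen2.5:0.5b-instruct")
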
diff --git a/lpm_kernel/api/domains/documents/routes.py b/lpm_kernel/api/domains/documents/routes.py index 3ae1ded0..d83b1ef3 100644 --- a/lpm_kernel/api/domains/documents/routes.py +++ b/lpm_kernel/api/domains/documents/routes.py @@ -349,3 +349,56 @@ def get_document_embedding(document_id: int): return jsonify( APIResponse.error(message=f"Error getting document embedding: {str(e)}") ) + + +@document_bp.route("/documents/verify-embeddings", methods=["GET"]) +def verify_document_embeddings(): + """Verify all document embeddings and return statistics""" + try: + verbose = request.args.get("verbose", "").lower() == "true" + results = document_service.verify_document_embeddings(verbose=verbose) + return jsonify(APIResponse.success(data=results)) + + except Exception as e: + logger.error(f"Error verifying document embeddings: {str(e)}", exc_info=True) + return jsonify(APIResponse.error(message=f"Error verifying document embeddings: {str(e)}")) + + +@document_bp.route("/documents/repair", methods=["POST"]) +def repair_documents(): + """Repair documents with missing analysis and embeddings""" + try: + # First, fix missing document analysis (summaries and insights) + fixed_analysis_count = document_service.fix_missing_document_analysis() + + # Get verification results after fixing analysis + verification_results = document_service.verify_document_embeddings(verbose=False) + + # Process documents with missing embeddings + documents_fixed = 0 + for doc in document_service._repository.list(): + embedding = document_service.get_document_embedding(doc.id) + if doc.raw_content and embedding is None: + try: + document_service.process_document_embedding(doc.id) + # Also process chunk embeddings + document_service.generate_document_chunk_embeddings(doc.id) + documents_fixed += 1 + except Exception as e: + logger.error(f"Error processing document {doc.id} embedding: {str(e)}") + + # Get final verification results + final_results = document_service.verify_document_embeddings(verbose=False) + + return jsonify(APIResponse.success( + data={ + "analysis_fixed": fixed_analysis_count, + "embeddings_fixed": documents_fixed, + "initial_state": verification_results, + "final_state": final_results + } + )) + + except Exception as e: + logger.error(f"Error repairing documents: {str(e)}", exc_info=True) + return jsonify(APIResponse.error(message=f"Error repairing documents: {str(e)}")) diff --git a/lpm_kernel/api/domains/kernel2/routes_l2.py b/lpm_kernel/api/domains/kernel2/routes_l2.py index 4491acab..8f352c4a 100644 --- a/lpm_kernel/api/domains/kernel2/routes_l2.py +++ b/lpm_kernel/api/domains/kernel2/routes_l2.py @@ -3,6 +3,8 @@ import os import time import sys +import torch # Add torch import for CUDA detection +import traceback from dataclasses import asdict from flask import Blueprint, jsonify, Response, request @@ -352,6 +354,19 @@ def train2(): model_name = data["model_name"] paths = get_model_paths(model_name) + # Get optional parameters with defaults + learning_rate = data.get("learning_rate", 2e-4) + num_train_epochs = data.get("number_of_epochs", 3) + concurrency_threads = data.get("concurrency_threads", 2) + data_synthesis_mode = data.get("data_synthesis_mode", "low") + use_cuda = data.get("use_cuda", False) + + # Convert use_cuda to string "True" or "False" for the shell script + use_cuda_str = "True" if use_cuda else "False" + + logger.info(f"Training configuration: learning_rate={learning_rate}, epochs={num_train_epochs}, " + f"threads={concurrency_threads}, mode={data_synthesis_mode}, 
use_cuda={use_cuda} ({use_cuda_str})") + # Check if model exists if not os.path.exists(paths["base_path"]): return jsonify(APIResponse.error( @@ -385,29 +400,61 @@ def train2(): script_path = os.path.join(os.getcwd(), "lpm_kernel/L2/train_for_user.sh") + # Build command arguments + cmd_args = [ + "--lr", str(learning_rate), + "--epochs", str(num_train_epochs), + "--threads", str(concurrency_threads), + "--mode", str(data_synthesis_mode), + "--cuda", use_cuda_str # Use the properly formatted string + ] + # Start training import threading _training_thread = threading.Thread( - target=start_training, - args=(script_path, log_path), + target=start_training_with_args, + args=(script_path, log_path, cmd_args), daemon=True ) _training_thread.start() return jsonify(APIResponse.success( data={ + "status": "training_started", "model_name": model_name, "log_path": log_path, - "personal_dir": paths["personal_dir"], - "merged_dir": paths["merged_dir"] }, - message="Training task started" + message="Training task started successfully" )) except Exception as e: - error_msg = f"Failed to start training: {str(e)}" - logger.error(error_msg) - return jsonify(APIResponse.error(message=error_msg, code=500)) + logger.error(f"Error starting training task: {str(e)}") + traceback.print_exc() + return jsonify(APIResponse.error(message=f"Failed to start training: {str(e)}")) + + +def start_training_with_args(script_path: str, log_path: str, args: list) -> None: + """Start training with additional arguments""" + global _training_process + try: + # Convert script path and args to a command + cmd = [script_path] + args + + # Use ScriptRunner to execute the script + runner = ScriptRunner(log_path=log_path) + _training_process = runner.execute_script( + script_path=script_path, + script_type="training", + is_python=False, # This is a bash script + args=args + ) + + logger.info(f"Training process started with args: {args}, process: {_training_process}") + + except Exception as e: + logger.error(f"Failed to start training process: {str(e)}") + _training_process = None + raise @kernel2_bp.route("/merge_weights", methods=["POST"]) @@ -550,6 +597,9 @@ def start_llama_server(): return jsonify(APIResponse.error(message="Missing required parameter: model_name", code=400)) model_name = data["model_name"] + # Get optional use_gpu parameter with default value of True + use_gpu = data.get("use_gpu", True) + paths = get_model_paths(model_name) gguf_path = os.path.join(paths["gguf_dir"], "model.gguf") @@ -564,53 +614,34 @@ def start_llama_server(): server_executable = "llama-server" server_path = os.path.join(server_path, server_executable) - # Check if service and model file exist - if not os.path.exists(server_path): - return jsonify(APIResponse.error(message="llama-server executable file does not exist", code=400)) + # Check if model file exists if not os.path.exists(gguf_path): return jsonify(APIResponse.error( message=f"Model '{model_name}' GGUF file does not exist, please convert model first", code=400 )) - # Check if service is already running + # Start the server using the LocalLLMService with GPU acceleration if requested + success = local_llm_service.start_server(gguf_path, use_gpu=use_gpu) + + if not success: + return jsonify(APIResponse.error(message="Failed to start llama-server", code=500)) + + # Get updated service status status = local_llm_service.get_server_status() - if status.is_running: - return jsonify( - APIResponse.error( - message=f"llama-server is already running, PID: {status.process_info.pid}", - code=400 - 
) - ) - - # Build parameters - args = [server_path, "-m", gguf_path, "--port", "8080"] - - # Use thread to start service asynchronously - def start_server(): - script_executor.execute( - script_path=server_path, - script_type="llama_server", - args=args[1:], # Remove first parameter (executable file path) - shell=False, - ) - - # Start new thread to run service - from threading import Thread - - thread = Thread(target=start_server) - thread.daemon = True - thread.start() - - # Return start status immediately + + # Return success response with GPU info + gpu_info = "with GPU acceleration" if use_gpu and torch.cuda.is_available() else "with CPU only" return jsonify( APIResponse.success( data={ "model_name": model_name, "gguf_path": gguf_path, - "status": "starting" + "status": "running" if status.is_running else "starting", + "use_gpu": use_gpu and torch.cuda.is_available(), + "gpu_info": gpu_info }, - message="llama-server service is starting" + message=f"llama-server service started {gpu_info}" ) ) @@ -805,3 +836,31 @@ def chat(body: ChatRequest): if not getattr(body, 'stream', True): # Default to stream if attribute missing return jsonify(error_response), 500 return local_llm_service.handle_stream_response(iter([error_response])) + + +@kernel2_bp.route("/cuda/available", methods=["GET"]) +def check_cuda_available(): + """Check if CUDA is available for model training/inference""" + try: + import torch + cuda_available = torch.cuda.is_available() + cuda_info = {} + + if cuda_available: + cuda_info = { + "device_count": torch.cuda.device_count(), + "current_device": torch.cuda.current_device(), + "device_name": torch.cuda.get_device_name(0) + } + + return jsonify(APIResponse.success( + data={ + "cuda_available": cuda_available, + "cuda_info": cuda_info + }, + message="CUDA availability check completed" + )) + except Exception as e: + error_msg = f"Error checking CUDA availability: {str(e)}" + logger.error(error_msg) + return jsonify(APIResponse.error(message=error_msg, code=500)) diff --git a/lpm_kernel/api/domains/trainprocess/routes.py b/lpm_kernel/api/domains/trainprocess/routes.py index 2ac3beec..19ee9c40 100644 --- a/lpm_kernel/api/domains/trainprocess/routes.py +++ b/lpm_kernel/api/domains/trainprocess/routes.py @@ -24,6 +24,7 @@ def start_process(): number_of_epochs: Number of training epochs (optional) concurrency_threads: Number of threads for concurrent processing (optional) data_synthesis_mode: Mode for data synthesis (optional) + use_cuda: Whether to use CUDA for training (optional) Includes the following steps: 1. 
Health check @@ -63,6 +64,7 @@ def start_process(): number_of_epochs = data.get("number_of_epochs", None) concurrency_threads = data.get("concurrency_threads", None) data_synthesis_mode = data.get("data_synthesis_mode", None) + use_cuda = data.get("use_cuda", False) # Default to False if not provided is_cot = data.get("is_cot", None) # Log the received parameters @@ -90,6 +92,7 @@ def start_process(): "number_of_epochs": number_of_epochs, "concurrency_threads": concurrency_threads, "data_synthesis_mode": data_synthesis_mode, + "use_cuda": use_cuda, # Make sure to include use_cuda parameter "is_cot": is_cot } @@ -113,6 +116,7 @@ def start_process(): "number_of_epochs": number_of_epochs, "concurrency_threads": concurrency_threads, "data_synthesis_mode": data_synthesis_mode, + "use_cuda": use_cuda, # Include in response "is_cot": is_cot } ) diff --git a/lpm_kernel/api/services/local_llm_service.py b/lpm_kernel/api/services/local_llm_service.py index f5bb3447..29de8424 100644 --- a/lpm_kernel/api/services/local_llm_service.py +++ b/lpm_kernel/api/services/local_llm_service.py @@ -4,6 +4,7 @@ import psutil import time import subprocess +import torch # Add torch import for CUDA detection import threading import queue from typing import Iterator, Any, Optional, Generator, Dict @@ -38,9 +39,16 @@ def client(self) -> OpenAI: ) return self._client - def start_server(self, model_path: str) -> bool: + def start_server(self, model_path: str, use_gpu: bool = True) -> bool: """ - Start the llama-server service + Start the llama-server service with GPU acceleration when available + + Args: + model_path: Path to the GGUF model file + use_gpu: Whether to use GPU acceleration if available + + Returns: + bool: True if server started successfully, False otherwise """ try: # Check if server is already running @@ -49,27 +57,163 @@ def start_server(self, model_path: str) -> bool: logger.info("LLama server is already running") return True - # Start server + # Check for CUDA availability if GPU was requested + cuda_available = torch.cuda.is_available() if use_gpu else False + gpu_info = "" + + if use_gpu and cuda_available: + gpu_device = torch.cuda.current_device() + gpu_info = f" using GPU: {torch.cuda.get_device_name(gpu_device)}" + gpu_memory = torch.cuda.get_device_properties(gpu_device).total_memory / (1024**3) + + logger.info(f"CUDA is available. Using GPU acceleration{gpu_info}") + logger.info(f"CUDA device capabilities: {torch.cuda.get_device_capability(gpu_device)}") + logger.info(f"CUDA memory: {gpu_memory:.2f} GB") + + # Pre-initialize CUDA to speed up first inference + logger.info("Pre-initializing CUDA context to speed up first inference") + torch.cuda.init() + torch.cuda.empty_cache() + elif use_gpu and not cuda_available: + logger.warning("CUDA was requested but is not available. 
Using CPU instead.") + else: + logger.info("Using CPU for inference (GPU not requested)") + + # Check for GPU optimization marker + gpu_optimized = False + model_dir = os.path.dirname(model_path) + gpu_marker_path = os.path.join(model_dir, "gpu_optimized.json") + if os.path.exists(gpu_marker_path): + try: + with open(gpu_marker_path, 'r') as f: + gpu_data = json.load(f) + if gpu_data.get("gpu_optimized", False): + gpu_optimized = True + logger.info(f"Found GPU optimization marker created on {gpu_data.get('optimized_on', 'unknown date')}") + except Exception as e: + logger.warning(f"Error reading GPU marker file: {e}") + + # Get the correct path to the llama-server executable + base_dir = os.getcwd() + server_path = os.path.join(base_dir, "llama.cpp", "build", "bin", "llama-server") + + # For Windows, add .exe extension if needed + if os.name == 'nt' and not server_path.endswith('.exe'): + server_path += '.exe' + + # Verify executable exists + if not os.path.exists(server_path): + logger.error(f"llama-server executable not found at: {server_path}") + return False + + # Start server with optimal parameters for faster startup cmd = [ - "llama-server", + server_path, "-m", model_path, "--host", "0.0.0.0", - "--port", "8000" + "--port", "8080", + "--ctx-size", "2048", # Default context size (adjust based on needs) + "--parallel", "2", # Enable request parallelism + "--cont-batching" # Enable continuous batching ] + # Set up environment with CUDA variables to ensure GPU detection + env = os.environ.copy() + + # Add GPU-related parameters if CUDA is available + if cuda_available and use_gpu: + # Force GPU usage with optimal parameters for faster loads + cmd.extend([ + "--n-gpu-layers", "999", # Use all layers on GPU + "--tensor-split", "0", # Use the first GPU for all operations + "--main-gpu", "0", # Use GPU 0 as the primary device + "--mlock" # Lock memory to prevent swapping during inference + ]) + + # Set CUDA environment variables to help with GPU detection + env["CUDA_VISIBLE_DEVICES"] = "0" # Force using first GPU + + # Ensure comprehensive library paths for CUDA + cuda_lib_paths = [ + "/usr/local/cuda/lib64", + "/usr/lib/cuda/lib64", + "/usr/local/lib", + "/usr/lib/x86_64-linux-gnu", + "/usr/lib/wsl/lib" # For Windows WSL environments + ] + + # Build a comprehensive LD_LIBRARY_PATH + current_ld_path = env.get("LD_LIBRARY_PATH", "") + for path in cuda_lib_paths: + if os.path.exists(path) and path not in current_ld_path: + current_ld_path = f"{path}:{current_ld_path}" if current_ld_path else path + + env["LD_LIBRARY_PATH"] = current_ld_path + logger.info(f"Setting LD_LIBRARY_PATH to: {current_ld_path}") + + # If this is Windows, use different approach for CUDA libraries + if os.name == 'nt': + # Windows typically has CUDA in PATH already if installed + logger.info("Windows system detected, using system CUDA libraries") + else: + # On Linux, try to find CUDA libraries in common locations + for cuda_path in [ + # Common CUDA paths + "/usr/local/cuda/lib64", + "/usr/lib/cuda/lib64", + "/usr/local/lib/python3.12/site-packages/nvidia/cuda_runtime/lib", + "/usr/local/lib/python3.10/site-packages/nvidia/cuda_runtime/lib", + ]: + if os.path.exists(cuda_path): + # Add CUDA path to library path + env["LD_LIBRARY_PATH"] = f"{cuda_path}:{env.get('LD_LIBRARY_PATH', '')}" + env["CUDA_HOME"] = os.path.dirname(cuda_path) + logger.info(f"Found CUDA at {cuda_path}, setting environment variables") + break + + # NOTE: CUDA support and rebuild should be handled at build/setup time (e.g., Docker build or setup 
script). + # The runtime check and rebuild logic has been removed for efficiency and reliability. + # Ensure llama.cpp is built with CUDA support before running the server if GPU is required. + + # Pre-heat GPU to ensure faster initial response + if torch.cuda.is_available(): + logger.info("Pre-warming GPU to reduce initial latency...") + dummy_tensor = torch.zeros(1, 1).cuda() + del dummy_tensor + torch.cuda.synchronize() + torch.cuda.empty_cache() + logger.info("GPU warm-up complete") + + logger.info("Using GPU acceleration for inference with optimized settings") + else: + # If GPU isn't available or supported, optimize for CPU + cmd.extend([ + "--threads", str(max(1, os.cpu_count() - 1)), # Use all CPU cores except one + ]) + logger.info(f"Using CPU-only mode with {max(1, os.cpu_count() - 1)} threads") + + logger.info(f"Starting llama-server with command: {' '.join(cmd)}") + process = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - universal_newlines=True + universal_newlines=True, + env=env ) - # Wait for server to start - time.sleep(2) + # Wait for server to start (longer wait for GPU initialization) + wait_time = 5 if cuda_available and use_gpu else 3 + logger.info(f"Waiting {wait_time} seconds for server to start...") + time.sleep(wait_time) - # Check if process started successfully + # Check if process is still running if process.poll() is None: - logger.info("LLama server started successfully") + # Log initialization success + if cuda_available and use_gpu: + logger.info(f"✅ LLama server started successfully with GPU acceleration{gpu_info}") + else: + logger.info("✅ LLama server started successfully in CPU-only mode") return True else: stdout, stderr = process.communicate() @@ -151,10 +295,15 @@ def get_server_status(self) -> ServerStatus: Returns: ServerStatus object """ try: + base_dir = os.getcwd() + server_path = os.path.join(base_dir, "llama.cpp", "build", "bin", "llama-server") + server_exec_name = os.path.basename(server_path) + for proc in psutil.process_iter(["pid", "name", "cmdline"]): try: cmdline = proc.cmdline() - if any("llama-server" in cmd for cmd in cmdline): + # Check both for the executable name and the full path + if any(server_exec_name in cmd for cmd in cmdline) or any("llama-server" in cmd for cmd in cmdline): with proc.oneshot(): process_info = ProcessInfo( pid=proc.pid, diff --git a/lpm_kernel/app.py b/lpm_kernel/app.py index 62f18c2f..dba7383b 100644 --- a/lpm_kernel/app.py +++ b/lpm_kernel/app.py @@ -6,6 +6,7 @@ from .database.migration_manager import MigrationManager import os import atexit +import subprocess def create_app(): diff --git a/lpm_kernel/common/llm.py b/lpm_kernel/common/llm.py index badd2d95..977fe17f 100644 --- a/lpm_kernel/common/llm.py +++ b/lpm_kernel/common/llm.py @@ -7,6 +7,13 @@ logger = get_train_process_logger() import lpm_kernel.common.strategy.classification as classification from sentence_transformers import SentenceTransformer +import json + +class EmbeddingError(Exception): + """Custom exception class for embedding-related errors""" + def __init__(self, message, original_error=None): + super().__init__(message) + self.original_error = original_error class LLMClient: """LLM client utility class""" @@ -55,10 +62,8 @@ def get_embedding(self, texts: Union[str, List[str]]) -> np.ndarray: user_llm_config = self.user_llm_config_service.get_available_llm() if not user_llm_config: - - raise Exception("No LLM configuration found") - # Prepare request data - + raise EmbeddingError("No LLM configuration found") + 
try: # Send request to embedding endpoint embeddings_array = classification.strategy_classification(user_llm_config, chunked_texts) @@ -81,7 +86,25 @@ def get_embedding(self, texts: Union[str, List[str]]) -> np.ndarray: return embeddings_array except requests.exceptions.RequestException as e: - raise Exception(f"Failed to get embeddings: {str(e)}") + # Handle request errors + error_msg = f"Request error getting embeddings: {str(e)}" + logger.error(error_msg) + raise EmbeddingError(error_msg, e) + except json.JSONDecodeError as e: + # Handle JSON parsing errors + error_msg = f"Invalid JSON response from embedding API: {str(e)}" + logger.error(error_msg) + raise EmbeddingError(error_msg, e) + except (KeyError, IndexError, ValueError) as e: + # Handle response structure errors + error_msg = f"Invalid response structure from embedding API: {str(e)}" + logger.error(error_msg) + raise EmbeddingError(error_msg, e) + except Exception as e: + # Fallback for any other errors + error_msg = f"Unexpected error getting embeddings: {str(e)}" + logger.error(error_msg, exc_info=True) + raise EmbeddingError(error_msg, e) @property def chat_credentials(self): diff --git a/lpm_kernel/file_data/document_service.py b/lpm_kernel/file_data/document_service.py index 805985f2..ab04ce18 100644 --- a/lpm_kernel/file_data/document_service.py +++ b/lpm_kernel/file_data/document_service.py @@ -148,6 +148,37 @@ def _analyze_document(self, doc: DocumentDTO) -> DocumentDTO: self._update_analyze_status_failed(doc.id) raise + def analyze_document(self, document_id: int) -> DocumentDTO: + """ + Analyze a single document by ID + + Args: + document_id (int): ID of document to analyze + + Returns: + DocumentDTO: The analyzed document + + Raises: + ValueError: If document not found + Exception: If analysis fails + """ + try: + # Get document + document = self._repository.find_one(document_id) + if not document: + raise ValueError(f"Document not found with id: {document_id}") + + # Perform analysis + return self._analyze_document(document) + + except ValueError as e: + logger.error(f"Document {document_id} not found: {str(e)}") + raise + except Exception as e: + logger.error(f"Error analyzing document {document_id}: {str(e)}", exc_info=True) + self._update_analyze_status_failed(document_id) + raise + def _update_analyze_status_failed(self, doc_id: int) -> None: """update status as failed""" try: @@ -593,6 +624,120 @@ def delete_file_by_name(self, filename: str) -> bool: logger.error(f"Error deleting file: {str(e)}", exc_info=True) raise + def fix_missing_document_analysis(self) -> int: + """Fix documents with missing insights or summaries + + Returns: + int: Number of documents fixed + """ + try: + # Find all documents that have analysis issues + docs = self._repository.list() + fixed_count = 0 + + for doc in docs: + needs_fixing = False + + # Check if document needs analysis + if not doc.analyze_status or doc.analyze_status != ProcessStatus.SUCCESS: + needs_fixing = True + logger.info(f"Document {doc.id} needs analysis (status: {doc.analyze_status})") + + # Check if document has missing insights or summaries + elif not doc.insight or not doc.summary: + needs_fixing = True + logger.info(f"Document {doc.id} has missing insight or summary") + + # Process documents that need fixing + if needs_fixing: + try: + # Process document analysis + self.analyze_document(doc.id) + fixed_count += 1 + logger.info(f"Fixed document {doc.id} analysis") + except Exception as e: + logger.error(f"Error fixing document {doc.id} analysis: {str(e)}") + + 
logger.info(f"Fixed {fixed_count} documents with missing analysis") + return fixed_count + + except Exception as e: + logger.error(f"Error in fix_missing_document_analysis: {str(e)}") + raise FileProcessingError(f"Failed to fix document analysis: {str(e)}") + + def verify_document_embeddings(self, verbose=True) -> Dict: + """ + Verify all document embeddings and return statistics + + Args: + verbose (bool): Whether to log detailed information + + Returns: + Dict: Statistics about document embeddings + """ + try: + docs = self._repository.list() + results = { + "total_documents": len(docs), + "documents_with_embedding": 0, + "documents_without_embedding": 0, + "documents_with_content": 0, + "documents_without_content": 0, + "documents_with_summary": 0, + "documents_without_summary": 0, + "documents_with_insight": 0, + "documents_without_insight": 0, + "documents_needing_repair": 0, + } + + documents_needing_repair = [] + + for doc in docs: + # Check if document has content + if doc.raw_content: + results["documents_with_content"] += 1 + else: + results["documents_without_content"] += 1 + + # Check if document has summary + if doc.summary: + results["documents_with_summary"] += 1 + else: + results["documents_without_summary"] += 1 + + # Check if document has insight + if doc.insight: + results["documents_with_insight"] += 1 + else: + results["documents_without_insight"] += 1 + + # Check if embeddings exist in ChromaDB + embedding = self.get_document_embedding(doc.id) + if embedding is not None: + results["documents_with_embedding"] += 1 + if verbose: + logger.info(f"Document {doc.id}: '{doc.name}' has embedding of dimension {len(embedding)}") + else: + results["documents_without_embedding"] += 1 + if verbose: + logger.warning(f"Document {doc.id}: '{doc.name}' missing embedding") + + # Check if document needs repair (has content but missing embedding or analysis) + if doc.raw_content and (embedding is None or not doc.summary or not doc.insight): + documents_needing_repair.append(doc.id) + results["documents_needing_repair"] += 1 + + # Log statistics + logger.info(f"Document embedding verification results: {results}") + if documents_needing_repair and verbose: + logger.info(f"Documents needing repair: {documents_needing_repair}") + + return results + + except Exception as e: + logger.error(f"Error verifying document embeddings: {str(e)}", exc_info=True) + raise + # create service document_service = DocumentService() diff --git a/lpm_kernel/file_data/embedding_service.py b/lpm_kernel/file_data/embedding_service.py index cc6fc7fc..dedd0717 100644 --- a/lpm_kernel/file_data/embedding_service.py +++ b/lpm_kernel/file_data/embedding_service.py @@ -418,4 +418,4 @@ def search_similar_chunks( raise except Exception as e: logger.error(f"Error searching similar chunks: {str(e)}") - raise + raise \ No newline at end of file diff --git a/lpm_kernel/train/training_params_manager.py b/lpm_kernel/train/training_params_manager.py index 2f6589ac..c400bae4 100644 --- a/lpm_kernel/train/training_params_manager.py +++ b/lpm_kernel/train/training_params_manager.py @@ -23,6 +23,7 @@ class TrainingParamsManager: "number_of_epochs": 3, "concurrency_threads": 2, "data_synthesis_mode": "low", + "use_cuda": False, # Default to using CUDA when available "is_cot": False } diff --git a/lpm_kernel/train/trainprocess_service.py b/lpm_kernel/train/trainprocess_service.py index fe1e58e0..24aeb1eb 100644 --- a/lpm_kernel/train/trainprocess_service.py +++ b/lpm_kernel/train/trainprocess_service.py @@ -42,7 +42,7 @@ class 
TrainProcessService: _instance = None _initialized = False - + def __new__(cls, *args, **kwargs): if cls._instance is None: cls._instance = super().__new__(cls) @@ -633,6 +633,7 @@ def _start_training(self, script_path, log_path): num_train_epochs = training_params.get("number_of_epochs") concurrency_threads = training_params.get("concurrency_threads") data_synthesis_mode = training_params.get("data_synthesis_mode") + use_cuda = training_params.get("use_cuda", False) is_cot = training_params.get("is_cot", False) # Log training parameters @@ -641,6 +642,8 @@ def _start_training(self, script_path, log_path): logger.info(f" Number of epochs: {num_train_epochs}") logger.info(f" Concurrency threads: {concurrency_threads}") logger.info(f" Data synthesis mode: {data_synthesis_mode}") + logger.info(f" Use CUDA: {use_cuda}") + logger.info(f" Is CoT: {is_cot}") # Prepare arguments for the script # Build command line arguments, need to include script path as the first parameter @@ -650,6 +653,7 @@ def _start_training(self, script_path, log_path): "--epochs", str(num_train_epochs), "--threads", str(concurrency_threads), "--mode", str(data_synthesis_mode), + "--cuda", str(use_cuda), "--is_cot", str(is_cot) ] diff --git a/scripts/prompt_cuda.bat b/scripts/prompt_cuda.bat new file mode 100644 index 00000000..f2826ae5 --- /dev/null +++ b/scripts/prompt_cuda.bat @@ -0,0 +1,69 @@ +@echo off +REM Script to prompt user for CUDA support preference + +echo === CUDA Support Selection === +echo. +echo Do you want to build with NVIDIA GPU (CUDA) support? +echo This requires an NVIDIA GPU and proper NVIDIA Docker runtime configuration. +echo. +set /p choice="Build with CUDA support? (y/n): " + +if /i "%choice%"=="y" goto cuda +if /i "%choice%"=="yes" goto cuda +goto nocuda + +:cuda +echo Selected: Build WITH CUDA support + +REM Create or update .env file with the Dockerfile selection +if exist .env ( + REM Check if variable already exists in file + findstr /c:"DOCKER_BACKEND_DOCKERFILE" .env >nul + if %ERRORLEVEL% EQU 0 ( + REM Update existing variable + powershell -Command "(Get-Content .env) -replace '^DOCKER_BACKEND_DOCKERFILE=.*', 'DOCKER_BACKEND_DOCKERFILE=Dockerfile.backend.cuda' | Set-Content .env" + ) else ( + REM Append to file with newline before + echo.>> .env + echo DOCKER_BACKEND_DOCKERFILE=Dockerfile.backend.cuda>> .env + ) +) else ( + REM Create new file + echo DOCKER_BACKEND_DOCKERFILE=Dockerfile.backend.cuda> .env +) + +REM Create a flag file to indicate GPU use +echo GPU > .gpu_selected + +echo Environment set to build with CUDA support +goto end + +:nocuda +echo Selected: Build WITHOUT CUDA support (CPU only) + +REM Create or update .env file with the Dockerfile selection +if exist .env ( + REM Check if variable already exists in file + findstr /c:"DOCKER_BACKEND_DOCKERFILE" .env >nul + if %ERRORLEVEL% EQU 0 ( + REM Update existing variable + powershell -Command "(Get-Content .env) -replace '^DOCKER_BACKEND_DOCKERFILE=.*', 'DOCKER_BACKEND_DOCKERFILE=Dockerfile.backend' | Set-Content .env" + ) else ( + REM Append to file with newline before + echo.>> .env + echo DOCKER_BACKEND_DOCKERFILE=Dockerfile.backend>> .env + ) +) else ( + REM Create new file + echo DOCKER_BACKEND_DOCKERFILE=Dockerfile.backend> .env +) + +REM Remove any GPU flag file if it exists +if exist .gpu_selected ( + del .gpu_selected +) + +echo Environment set to build without CUDA support + +:end +echo === CUDA Selection Complete === \ No newline at end of file diff --git a/scripts/prompt_cuda.sh b/scripts/prompt_cuda.sh new file 
mode 100644 index 00000000..0e34be0b --- /dev/null +++ b/scripts/prompt_cuda.sh @@ -0,0 +1,62 @@ +#!/bin/bash +# Script to prompt user for CUDA support preference and directly build with the appropriate Dockerfile + +echo "=== CUDA Support Selection ===" +echo "" +echo "Do you want to build with NVIDIA GPU (CUDA) support?" +echo "This requires an NVIDIA GPU and proper NVIDIA Docker runtime configuration." +echo "" +read -p "Build with CUDA support? (y/n): " choice + +case "$choice" in + y|Y|yes|YES|Yes ) + echo "Selected: Build WITH CUDA support" + + # Create or update .env file with the Dockerfile selection + if [ -f .env ]; then + # Update existing file + if grep -q "DOCKER_BACKEND_DOCKERFILE" .env; then + sed -i 's/^DOCKER_BACKEND_DOCKERFILE=.*/DOCKER_BACKEND_DOCKERFILE=Dockerfile.backend.cuda/' .env + else + # Add a newline before appending new content + echo "" >> .env + echo "DOCKER_BACKEND_DOCKERFILE=Dockerfile.backend.cuda" >> .env + fi + else + # Create new file + echo "DOCKER_BACKEND_DOCKERFILE=Dockerfile.backend.cuda" > .env + fi + + # Create a flag file to indicate GPU use + echo "GPU" > .gpu_selected + + echo "Environment set to build with CUDA support" + ;; + * ) + echo "Selected: Build WITHOUT CUDA support (CPU only)" + + # Create or update .env file with the Dockerfile selection + if [ -f .env ]; then + # Update existing file + if grep -q "DOCKER_BACKEND_DOCKERFILE" .env; then + sed -i 's/^DOCKER_BACKEND_DOCKERFILE=.*/DOCKER_BACKEND_DOCKERFILE=Dockerfile.backend/' .env + else + # Add a newline before appending new content + echo "" >> .env + echo "DOCKER_BACKEND_DOCKERFILE=Dockerfile.backend" >> .env + fi + else + # Create new file + echo "DOCKER_BACKEND_DOCKERFILE=Dockerfile.backend" > .env + fi + + # Remove any GPU flag file if it exists + if [ -f .gpu_selected ]; then + rm .gpu_selected + fi + + echo "Environment set to build without CUDA support" + ;; +esac + +echo "=== CUDA Selection Complete ===" \ No newline at end of file From 29a17c861560d930f8882d61787e757f618a41de Mon Sep 17 00:00:00 2001 From: KKKKKKKevin <115385420+kevin-mindverse@users.noreply.github.com> Date: Fri, 25 Apr 2025 14:17:15 +0800 Subject: [PATCH 09/21] Optimize TrainProcessService Singleton Pattern Implementation (#279) * feature: use uv to setup python environment * TrainProcessService add singleten method: get_instance --- lpm_kernel/api/domains/loads/load_service.py | 4 +- lpm_kernel/api/domains/trainprocess/routes.py | 38 ++----------------- lpm_kernel/train/trainprocess_service.py | 29 +++++++++++++- 3 files changed, 33 insertions(+), 38 deletions(-) diff --git a/lpm_kernel/api/domains/loads/load_service.py b/lpm_kernel/api/domains/loads/load_service.py index 463de918..915370d6 100644 --- a/lpm_kernel/api/domains/loads/load_service.py +++ b/lpm_kernel/api/domains/loads/load_service.py @@ -437,12 +437,12 @@ def _reset_training_progress() -> None: # Extract model name model_name = file.replace('trainprocess_progress_', '').replace('.json', '') # Create a new service instance for each model and reset progress - train_service = TrainProcessService(model_name=model_name) + train_service = TrainProcessService(current_model_name=model_name) train_service.progress.reset_progress() logger.info(f"Reset training progress for model: {model_name}") # Reset default training progress - default_train_service = TrainProcessService() + default_train_service = TrainProcessService.get_instance() default_train_service.progress.reset_progress() logger.info("Reset default training progress") diff --git 
a/lpm_kernel/api/domains/trainprocess/routes.py b/lpm_kernel/api/domains/trainprocess/routes.py index 19ee9c40..71bce508 100644 --- a/lpm_kernel/api/domains/trainprocess/routes.py +++ b/lpm_kernel/api/domains/trainprocess/routes.py @@ -71,9 +71,7 @@ def start_process(): logger.info(f"Training parameters: model_name={model_name}, learning_rate={learning_rate}, number_of_epochs={number_of_epochs}, concurrency_threads={concurrency_threads}, data_synthesis_mode={data_synthesis_mode}, is_cot={is_cot}") # Create service instance with model name and additional parameters - train_service = TrainProcessService( - current_model_name=model_name - ) + train_service = TrainProcessService(current_model_name=model_name) # Check if there are any in_progress statuses that need to be reset if train_service.progress.progress.data["status"] == "in_progress": @@ -198,7 +196,7 @@ def reset_progress(): } """ try: - train_service = TrainProcessService() + train_service = TrainProcessService.get_instance() train_service.progress.reset_progress() return jsonify(APIResponse.success(message="Progress reset successfully")) @@ -212,7 +210,7 @@ def stop_training(): """Stop training process and wait until status is suspended""" try: # Get the TrainProcessService instance - train_service = TrainProcessService() # Need to get instance based on your implementation + train_service = TrainProcessService.get_instance() # Need to get instance based on your implementation # Stop the process train_service.stop_process() @@ -239,32 +237,6 @@ def stop_training(): return jsonify(APIResponse.error(message=f"Error stopping training process: {str(e)}")) -@trainprocess_bp.route("/model_name", methods=["GET"]) -def get_model_name(): - """ - Get the model name currently used by the training service - - Returns: - Response: JSON response - { - "code": 0 for success, non-zero for failure, - "message": "Error message", - "data": { - "model_name": "Model name" - } - } - """ - try: - # Get TrainProcessService instance - train_service = TrainProcessService() - model_name = train_service.model_name - - return jsonify(APIResponse.success(data={"model_name": model_name})) - except Exception as e: - logger.error(f"Failed to get model name: {str(e)}", exc_info=True) - return jsonify(APIResponse.error(message=f"Failed to get model name: {str(e)}")) - - @trainprocess_bp.route("/training_params", methods=["GET"]) def get_training_params(): """ @@ -323,9 +295,7 @@ def retrain(): return jsonify(APIResponse.error(message="missing necessary parameter: model_name", code=400)) # Create training service instance - train_service = TrainProcessService( - current_model_name=model_name - ) + train_service = TrainProcessService(current_model_name=model_name) # Check if there are any in_progress statuses that need to be reset if train_service.progress.progress.data["status"] == "in_progress": diff --git a/lpm_kernel/train/trainprocess_service.py b/lpm_kernel/train/trainprocess_service.py index 24aeb1eb..dd4214a2 100644 --- a/lpm_kernel/train/trainprocess_service.py +++ b/lpm_kernel/train/trainprocess_service.py @@ -48,7 +48,10 @@ def __new__(cls, *args, **kwargs): cls._instance = super().__new__(cls) return cls._instance - def __init__(self, current_model_name: str = None): + def __init__(self, current_model_name: str): + if current_model_name is None: + raise ValueError("current_model_name cannot be None") + if not self._initialized: # Generate a unique progress file name based on model name self.progress = TrainProgressHolder(current_model_name) @@ -72,10 +75,32 
@@ def __init__(self, current_model_name: str = None): self.l2_data_prepared = False # Update model name and progress instance if model name changes - if current_model_name is not None and current_model_name != self.model_name: + if current_model_name != self.model_name: self.model_name = current_model_name # Create new progress instance with updated progress file name self.progress = TrainProgressHolder(current_model_name) + + @classmethod + def get_instance(cls, current_model_name: str = None): + """Get the current instance of TrainProcessService + + Args: + current_model_name: Optional model name to update the instance with + + Returns: + TrainProcessService: The singleton instance + """ + if cls._instance is None: + if current_model_name is None: + raise ValueError("current_model_name must be provided when creating a new instance") + return cls(current_model_name) + + if current_model_name is not None: + # Update the existing instance with new model name + cls._instance.model_name = current_model_name + cls._instance.progress = TrainProgressHolder(current_model_name) + + return cls._instance def list_documents(self): """List all documents""" From 37553fb23b42cab3df94f345897f20bc6717f8f4 Mon Sep 17 00:00:00 2001 From: KKKKKKKevin <115385420+kevin-mindverse@users.noreply.github.com> Date: Fri, 25 Apr 2025 15:27:52 +0800 Subject: [PATCH 10/21] Feature/fix training model switch bug2 (#281) * feature: use uv to setup python environment * TrainProcessService add singleten method: get_instance * feat: fix code * Added CUDA support (#228) * Add CUDA support - CUDA detection - Memory handling - Ollama model release after training * Fix logging issue added cuda support flag so log accurately reflected cuda toggle * Update llama.cpp rebuild Changed llama.cpp to only check if cuda support is enabled and if so rebuild during the first build rather than each run * Improved vram management Enabled memory pinning and optimizer state offload * Fix CUDA check rewrote llama.cpp rebuild logic, added manual y/n toggle if user wants to enable cuda support * Added fast restart and fixed CUDA check command Added make docker-restart-backend-fast to restart the backend and reflect code changes without causing a full llama.cpp rebuild Fixed make docker-check-cuda command to correctly reflect cuda support * Added docker-compose.gpu.yml Added docker-compose.gpu.yml to fix error on machines without nvidia gpu and made sure "\n" is added before .env modification * Fixed cuda toggle Last push accidentally broke cuda toggle * Code review fixes Fixed errors resulting from removed code: - Added return save_path to end of save_hf_model function - Rolled back download_file_with_progress function * Update Makefile Use cuda by default when using docker-restart-backend-fast * Minor cleanup Removed unnecessary makefile command and fixed gpu logging * Delete .gpu_selected * Simplified cuda training code - Removed dtype setting to let torch automatically handle it - Removed vram logging - Removed Unnecessary/old comments * Fixed gpu/cpu selection Made "make docker-use-gpu/cpu" command work with .gpu_selected flag and changed "make docker-restart-backend-fast" command to respect flag instead of always using gpu * Fix Ollama embedding error Added custom exception class for Ollama embeddings, which seemed to be returning keyword arguments while the Python exception class only accepts positional ones * Fixed model selection & memory error Fixed training defaulting to 0.5B model regardless of selection and fixed "free(): double free 
detected in tcache 2" error caused by cuda flag being passed incorrectly * fix: train service singlten --------- Co-authored-by: Zachary Pitroda <30330004+zpitroda@users.noreply.github.com> --- lpm_kernel/api/domains/loads/load_service.py | 14 +++++++------- lpm_kernel/api/domains/trainprocess/routes.py | 14 +++++++++++--- lpm_kernel/train/trainprocess_service.py | 16 +++++++++------- 3 files changed, 27 insertions(+), 17 deletions(-) diff --git a/lpm_kernel/api/domains/loads/load_service.py b/lpm_kernel/api/domains/loads/load_service.py index 915370d6..55b9b32f 100644 --- a/lpm_kernel/api/domains/loads/load_service.py +++ b/lpm_kernel/api/domains/loads/load_service.py @@ -3,12 +3,14 @@ This module provides service functions for managing Load entities. """ - +import os import logging from typing import Optional, Dict, Any, Tuple from lpm_kernel.models.load import Load from lpm_kernel.common.repository.database_session import DatabaseSession from lpm_kernel.api.domains.loads.dto import LoadDTO +from lpm_kernel.train.trainprocess_service import TrainProcessService +from lpm_kernel.api.domains.kernel2.routes_l2 import _training_process, _training_thread, _stopping_training logger = logging.getLogger(__name__) @@ -424,10 +426,6 @@ def _clean_data_directories() -> None: def _reset_training_progress() -> None: """Reset training progress objects in memory""" try: - import os - # Import training service - from lpm_kernel.train.trainprocess_service import TrainProcessService - # Get all possible training progress file patterns base_dir = os.getenv('LOCAL_BASE_DIR', '.') progress_dir = os.path.join(base_dir, 'data', 'progress') @@ -443,14 +441,16 @@ def _reset_training_progress() -> None: # Reset default training progress default_train_service = TrainProcessService.get_instance() - default_train_service.progress.reset_progress() + if default_train_service is not None: + default_train_service.progress.reset_progress() + logger.info("Reset default training progress") # Reset global training process variables from lpm_kernel.api.domains.kernel2.routes_l2 import _training_process, _training_thread, _stopping_training if _training_process is not None: logger.info("Resetting global training process variables") - _training_process = None + _training_process = None _training_thread = None _stopping_training = False diff --git a/lpm_kernel/api/domains/trainprocess/routes.py b/lpm_kernel/api/domains/trainprocess/routes.py index 71bce508..b18082e1 100644 --- a/lpm_kernel/api/domains/trainprocess/routes.py +++ b/lpm_kernel/api/domains/trainprocess/routes.py @@ -71,15 +71,17 @@ def start_process(): logger.info(f"Training parameters: model_name={model_name}, learning_rate={learning_rate}, number_of_epochs={number_of_epochs}, concurrency_threads={concurrency_threads}, data_synthesis_mode={data_synthesis_mode}, is_cot={is_cot}") # Create service instance with model name and additional parameters - train_service = TrainProcessService(current_model_name=model_name) + last_train_service = TrainProcessService.get_instance() # Check if there are any in_progress statuses that need to be reset - if train_service.progress.progress.data["status"] == "in_progress": + if last_train_service is not None and last_train_service.progress.progress.data["status"] == "in_progress": return jsonify(APIResponse.error( message="There is an existing training process that was interrupted.", code=409 # Conflict status code )) + + train_service = TrainProcessService(current_model_name=model_name) if not 
train_service.check_training_condition(): train_service.reset_progress() @@ -197,7 +199,11 @@ def reset_progress(): """ try: train_service = TrainProcessService.get_instance() - train_service.progress.reset_progress() + if train_service is not None: + train_service.progress.reset_progress() + logger.info("Progress reset successfully") + else: + logger.warning("No active training process found") return jsonify(APIResponse.success(message="Progress reset successfully")) except Exception as e: @@ -211,6 +217,8 @@ def stop_training(): try: # Get the TrainProcessService instance train_service = TrainProcessService.get_instance() # Need to get instance based on your implementation + if train_service is None: + return jsonify(APIResponse.error(message="Failed to stop training: No active training process")) # Stop the process train_service.stop_process() diff --git a/lpm_kernel/train/trainprocess_service.py b/lpm_kernel/train/trainprocess_service.py index dd4214a2..26e0bad1 100644 --- a/lpm_kernel/train/trainprocess_service.py +++ b/lpm_kernel/train/trainprocess_service.py @@ -92,7 +92,8 @@ def get_instance(cls, current_model_name: str = None): """ if cls._instance is None: if current_model_name is None: - raise ValueError("current_model_name must be provided when creating a new instance") + logger.warning("current_model_name must be provided when creating a new instance") + return None return cls(current_model_name) if current_model_name is not None: @@ -1071,6 +1072,7 @@ def check_training_condition(self) -> bool: if current_step: step = ProcessStep(current_step) self.progress.mark_step_status(step, Status.FAILED) + return False def start_process(self) -> bool: """Start training process""" @@ -1122,8 +1124,9 @@ def start_process(self) -> bool: return True except Exception as e: - logger.error(f"Exception occurred: {str(e)}") - self.progress.mark_step_status(step, Status.FAILED) + logger.error(f"Exception occurred: {str(e)}", exc_info=True) + if self.current_step: + self.progress.mark_step_status(self.current_step, Status.FAILED) return False def reset_progress(self): @@ -1133,10 +1136,9 @@ def reset_progress(self): """ try: self.progress.reset_progress() - self.progress._save_progress() logger.info("Progress saved successfully") except Exception as e: - logger.error(f"Failed to save progress: {str(e)}") + logger.error(f"Failed to save progress: {str(e)}", exc_info=True) def stop_process(self): """Stop training process @@ -1201,8 +1203,8 @@ def stop_process(self): logger.warning(f"Process with PID {self.current_pid} no longer exists") return True except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess) as e: - logger.error(f"Failed to terminate process: {str(e)}") + logger.error(f"Failed to terminate process: {str(e)}", exc_info=True) except Exception as e: - logger.error(f"Error stopping training process: {str(e)}") + logger.error(f"Error stopping training process: {str(e)}", exc_info=True) return False \ No newline at end of file From de8370ba0d6818ac32730404502d8232a9edca29 Mon Sep 17 00:00:00 2001 From: KKKKKKKevin <115385420+kevin-mindverse@users.noreply.github.com> Date: Fri, 25 Apr 2025 16:26:36 +0800 Subject: [PATCH 11/21] fix move trainprocess to solve loop (#288) --- lpm_kernel/api/domains/loads/load_service.py | 3 +-- lpm_kernel/api/domains/trainprocess/routes.py | 4 ++-- .../domains/trainprocess}/training_params_manager.py | 0 .../domains/trainprocess}/trainprocess_service.py | 7 +------ lpm_kernel/train/__init__.py | 0 5 files changed, 4 insertions(+), 10 deletions(-) 
rename lpm_kernel/{train => api/domains/trainprocess}/training_params_manager.py (100%) rename lpm_kernel/{train => api/domains/trainprocess}/trainprocess_service.py (99%) delete mode 100644 lpm_kernel/train/__init__.py diff --git a/lpm_kernel/api/domains/loads/load_service.py b/lpm_kernel/api/domains/loads/load_service.py index 55b9b32f..3e4ed91e 100644 --- a/lpm_kernel/api/domains/loads/load_service.py +++ b/lpm_kernel/api/domains/loads/load_service.py @@ -9,8 +9,7 @@ from lpm_kernel.models.load import Load from lpm_kernel.common.repository.database_session import DatabaseSession from lpm_kernel.api.domains.loads.dto import LoadDTO -from lpm_kernel.train.trainprocess_service import TrainProcessService -from lpm_kernel.api.domains.kernel2.routes_l2 import _training_process, _training_thread, _stopping_training +from lpm_kernel.api.domains.trainprocess.trainprocess_service import TrainProcessService logger = logging.getLogger(__name__) diff --git a/lpm_kernel/api/domains/trainprocess/routes.py b/lpm_kernel/api/domains/trainprocess/routes.py index b18082e1..c4a3dc3c 100644 --- a/lpm_kernel/api/domains/trainprocess/routes.py +++ b/lpm_kernel/api/domains/trainprocess/routes.py @@ -3,8 +3,8 @@ from flask import Blueprint, jsonify, Response, request from charset_normalizer import from_path -from lpm_kernel.train.trainprocess_service import TrainProcessService -from lpm_kernel.train.training_params_manager import TrainingParamsManager +from lpm_kernel.api.domains.trainprocess.trainprocess_service import TrainProcessService +from lpm_kernel.api.domains.trainprocess.training_params_manager import TrainingParamsManager from ...common.responses import APIResponse from threading import Thread diff --git a/lpm_kernel/train/training_params_manager.py b/lpm_kernel/api/domains/trainprocess/training_params_manager.py similarity index 100% rename from lpm_kernel/train/training_params_manager.py rename to lpm_kernel/api/domains/trainprocess/training_params_manager.py diff --git a/lpm_kernel/train/trainprocess_service.py b/lpm_kernel/api/domains/trainprocess/trainprocess_service.py similarity index 99% rename from lpm_kernel/train/trainprocess_service.py rename to lpm_kernel/api/domains/trainprocess/trainprocess_service.py index 26e0bad1..f33f7703 100644 --- a/lpm_kernel/train/trainprocess_service.py +++ b/lpm_kernel/api/domains/trainprocess/trainprocess_service.py @@ -1,11 +1,7 @@ -from enum import Enum -from typing import Dict, List, Optional -import json import os import re import time import psutil -from lpm_kernel.configs.config import Config from lpm_kernel.L1.utils import save_true_topics from lpm_kernel.L1.serializers import NotesStorage from lpm_kernel.kernel.note_service import NoteService @@ -26,10 +22,9 @@ from lpm_kernel.kernel.l1.l1_manager import generate_l1_from_l0 import threading from lpm_kernel.api.domains.trainprocess.progress_enum import Status -from lpm_kernel.api.domains.trainprocess.train_progress import TrainProgress from lpm_kernel.api.domains.trainprocess.process_step import ProcessStep from lpm_kernel.api.domains.trainprocess.progress_holder import TrainProgressHolder -from lpm_kernel.train.training_params_manager import TrainingParamsManager +from lpm_kernel.api.domains.trainprocess.training_params_manager import TrainingParamsManager from lpm_kernel.common.repository.database_session import DatabaseSession from lpm_kernel.api.domains.kernel.routes import store_l1_data import gc diff --git a/lpm_kernel/train/__init__.py b/lpm_kernel/train/__init__.py deleted file mode 100644 
index e69de29b..00000000 From 3ae664fe09e34074ba0bd7718de99d6c159ade18 Mon Sep 17 00:00:00 2001 From: KKKKKKKevin <115385420+kevin-mindverse@users.noreply.github.com> Date: Fri, 25 Apr 2025 17:20:07 +0800 Subject: [PATCH 12/21] add execute right (#289) --- lpm_kernel/L2/train_for_user.sh | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 lpm_kernel/L2/train_for_user.sh diff --git a/lpm_kernel/L2/train_for_user.sh b/lpm_kernel/L2/train_for_user.sh old mode 100644 new mode 100755 From 19adcac4352b42762d5a1eaf5ee946df1a197ee3 Mon Sep 17 00:00:00 2001 From: ryangyuan Date: Fri, 25 Apr 2025 17:46:37 +0800 Subject: [PATCH 13/21] Feat/0423/train status (#287) * fix: adjustment status order * fix: adjustment train status * fix: split the status of service and train --- .../src/app/dashboard/train/memories/page.tsx | 2 +- .../src/app/dashboard/train/training/page.tsx | 42 +++-- .../src/components/ModelStatus/index.tsx | 155 ++++++++++-------- .../src/components/StatusBar/index.tsx | 8 - .../train/TrainingConfiguration.tsx | 26 ++- .../src/components/train/TrainingProgress.tsx | 2 +- .../components/upload/RegisterUploadModal.tsx | 9 - .../src/components/upload/TraingTipModal.tsx | 25 +++ .../layouts/DashboardLayout/Menu/index.tsx | 9 +- lpm_frontend/src/service/train.ts | 21 ++- lpm_frontend/src/store/useTrainingStore.ts | 67 ++++++-- 11 files changed, 225 insertions(+), 141 deletions(-) create mode 100644 lpm_frontend/src/components/upload/TraingTipModal.tsx diff --git a/lpm_frontend/src/app/dashboard/train/memories/page.tsx b/lpm_frontend/src/app/dashboard/train/memories/page.tsx index 87b74d07..1d0c15c2 100644 --- a/lpm_frontend/src/app/dashboard/train/memories/page.tsx +++ b/lpm_frontend/src/app/dashboard/train/memories/page.tsx @@ -217,7 +217,6 @@ export default function TrainPage() {
@@ -236,6 +235,7 @@ export default function TrainPage() { ) : null } onClose={() => setSelectedInfo(null)} + open={!!selectedInfo && !!trainSectionInfo[selectedInfo]} title={selectedInfo ? trainSectionInfo[selectedInfo].name : ''} /> diff --git a/lpm_frontend/src/app/dashboard/train/training/page.tsx b/lpm_frontend/src/app/dashboard/train/training/page.tsx index 21cb2789..0543b408 100644 --- a/lpm_frontend/src/app/dashboard/train/training/page.tsx +++ b/lpm_frontend/src/app/dashboard/train/training/page.tsx @@ -4,7 +4,14 @@ import { useState, useEffect, useRef } from 'react'; import { useRouter } from 'next/navigation'; import InfoModal from '@/components/InfoModal'; import type { TrainingConfig } from '@/service/train'; -import { startTrain, stopTrain, retrain, getTrainingParams, checkCudaAvailability, resetProgress } from '@/service/train'; +import { + startTrain, + stopTrain, + retrain, + getTrainingParams, + checkCudaAvailability, + resetProgress +} from '@/service/train'; import { useTrainingStore } from '@/store/useTrainingStore'; import { getMemoryList } from '@/service/memory'; import { message, Modal } from 'antd'; @@ -73,11 +80,13 @@ export default function TrainingPage() { const modelConfig = useModelConfigStore((store) => store.modelConfig); const status = useTrainingStore((state) => state.status); const trainingProgress = useTrainingStore((state) => state.trainingProgress); + const serviceStarted = useTrainingStore((state) => state.serviceStarted); const router = useRouter(); const [selectedInfo, setSelectedInfo] = useState(false); - const [isTraining, setIsTraining] = useState(false); + const isTraining = useTrainingStore((state) => state.isTraining); + const setIsTraining = useTrainingStore((state) => state.setIsTraining); const [trainingParams, setTrainingParams] = useState({} as TrainingConfig); const [nowTrainingParams, setNowTrainingParams] = useState(null); const [trainActionLoading, setTrainActionLoading] = useState(false); @@ -108,11 +117,12 @@ export default function TrainingPage() { useEffect(() => { // Check CUDA availability once on load checkCudaAvailability() - .then(res => { + .then((res) => { if (res.data.code === 0) { const { cuda_available, cuda_info } = res.data.data; + setCudaAvailable(cuda_available); - + if (cuda_available) { console.log('CUDA is available:', cuda_info); } else { @@ -122,7 +132,7 @@ export default function TrainingPage() { message.error(res.data.message || 'Failed to check CUDA availability'); } }) - .catch(err => { + .catch((err) => { console.error('CUDA availability check failed', err); message.error('CUDA availability check failed'); }); @@ -306,13 +316,17 @@ export default function TrainingPage() { eventSource.onmessage = (event) => { // Don't try to parse as JSON, just use the raw text data directly const logMessage = event.data; - + setTrainingDetails((prev) => { const newLogs = [ ...prev.slice(-500), // Keep more log entries (500 instead of 100) { message: logMessage, - timestamp: new Date().toLocaleTimeString([], { hour: '2-digit', minute: '2-digit', second: '2-digit' }) + timestamp: new Date().toLocaleTimeString([], { + hour: '2-digit', + minute: '2-digit', + second: '2-digit' + }) } ]; @@ -462,6 +476,12 @@ export default function TrainingPage() { return; } + if (!isTraining && serviceStarted) { + message.error('Model is already running, please stop it first'); + + return; + } + setTrainActionLoading(true); // If training is in progress, stop it @@ -472,8 +492,8 @@ export default function TrainingPage() { return; } - // If the 
same model has already been trained and status is 'trained' or 'running', perform retraining - if (!changeBaseModel && (status === 'trained' || status === 'running')) { + // If the same model has already been trained and service is started, perform retraining + if (!changeBaseModel && (status === 'trained' || serviceStarted)) { await handleRetrainModel(); } else { // Otherwise start new training @@ -532,6 +552,7 @@ export default function TrainingPage() { {/* Only show training progress after training starts */} - {(status === 'training' || status === 'trained' || status === 'running') && + {(status === 'training' || status === 'trained' || serviceStarted) && renderTrainingProgress()} {/* Always show training log regardless of training status */} diff --git a/lpm_frontend/src/components/ModelStatus/index.tsx b/lpm_frontend/src/components/ModelStatus/index.tsx index 28fde187..9a96d54a 100644 --- a/lpm_frontend/src/components/ModelStatus/index.tsx +++ b/lpm_frontend/src/components/ModelStatus/index.tsx @@ -1,5 +1,5 @@ -import { useTrainingStore } from '@/store/useTrainingStore'; -import { startService, stopService, getServiceStatus } from '@/service/train'; +import { Status, statusRankMap, useTrainingStore } from '@/store/useTrainingStore'; +import { startService, stopService } from '@/service/train'; import { StatusBar } from '../StatusBar'; import { useRef, useEffect, useState, useMemo } from 'react'; import { message } from 'antd'; @@ -13,6 +13,8 @@ import { import RegisterUploadModal from '../upload/RegisterUploadModal'; import { useLoadInfoStore } from '@/store/useLoadInfoStore'; +import TrainingTipModal from '../upload/TraingTipModal'; +import { getMemoryList } from '@/service/memory'; const StatusDot = ({ active }: { active: boolean }) => (
( export function ModelStatus() { const status = useTrainingStore((state) => state.status); const setStatus = useTrainingStore((state) => state.setStatus); + const serviceStarted = useTrainingStore((state) => state.serviceStarted); const isServiceStarting = useTrainingStore((state) => state.isServiceStarting); const isServiceStopping = useTrainingStore((state) => state.isServiceStopping); const setServiceStarting = useTrainingStore((state) => state.setServiceStarting); const setServiceStopping = useTrainingStore((state) => state.setServiceStopping); + const fetchServiceStatus = useTrainingStore((state) => state.fetchServiceStatus); + const isTraining = useTrainingStore((state) => state.isTraining); const [messageApi, contextHolder] = message.useMessage(); @@ -36,43 +41,37 @@ export function ModelStatus() { }, [loadInfo]); const [showRegisterModal, setShowRegisterModal] = useState(false); + const [showtrainingModal, setShowtrainingModal] = useState(false); const handleRegistryClick = () => { - if (status !== 'trained' && status !== 'running') { - messageApi.info({ - content: 'Please train your model first', - duration: 1 - }); - } else if (status === 'trained') { + if (!serviceStarted) { messageApi.info({ content: 'Please start your model service first', duration: 1 }); - } else if (status === 'running') { + } else { setShowRegisterModal(true); } }; - const fetchServiceStatus = async () => { + const fetchMemories = async () => { try { - const statusRes = await getServiceStatus(); + const memoryRes = await getMemoryList(); - if (statusRes.data.code === 0) { - const isRunning = statusRes.data.data.is_running; + if (memoryRes.data.code === 0) { + const memories = memoryRes.data.data; - if (isRunning) { - setStatus('running'); - setServiceStarting(false); - } else if (status === 'running') { - setStatus('trained'); + if (memories.length > 0 && statusRankMap[status] < statusRankMap[Status.MEMORY_UPLOAD]) { + setStatus(Status.MEMORY_UPLOAD); } } } catch (error) { - console.error('Error checking initial service status:', error); + console.error('Error fetching memories:', error); } }; useEffect(() => { + fetchMemories(); fetchServiceStatus(); return () => { @@ -94,13 +93,12 @@ export function ModelStatus() { // Start new polling interval pollingInterval.current = setInterval(() => { - getServiceStatus() - .then((statusRes) => { - if (statusRes.data.code === 0) { - const isRunning = statusRes.data.data.is_running; + fetchServiceStatus() + .then((res) => { + if (res.data.code === 0) { + const isRunning = res.data.data.is_running; if (isRunning) { - setStatus('running'); setServiceStarting(false); clearPolling(); } @@ -117,13 +115,12 @@ export function ModelStatus() { // Start new polling interval pollingInterval.current = setInterval(() => { - getServiceStatus() - .then((statusRes) => { - if (statusRes.data.code === 0) { - const isRunning = statusRes.data.data.is_running; + fetchServiceStatus() + .then((res) => { + if (res.data.code === 0) { + const isRunning = res.data.data.is_running; if (!isRunning) { - setStatus('trained'); setServiceStopping(false); clearPolling(); } @@ -135,7 +132,7 @@ export function ModelStatus() { }, 3000); }; - const handleServiceAction = () => { + const handleStartService = () => { const config = JSON.parse(localStorage.getItem('trainingParams') || '{}'); if (!config.model_name) { @@ -144,46 +141,60 @@ export function ModelStatus() { return; } - if (status === 'running') { - setServiceStopping(true); - stopService() - .then((res) => { - if (res.data.code === 0) { - 
messageApi.success({ content: 'Service stopping...', duration: 1 }); - startStopPolling(); - } else { - messageApi.error({ content: res.data.message!, duration: 1 }); - setServiceStopping(false); - } - }) - .catch((error) => { - console.error('Error stopping service:', error); - messageApi.error({ - content: error.response?.data?.message || error.message, - duration: 1 - }); + setServiceStarting(true); + startService({ model_name: config.model_name }) + .then((res) => { + if (res.data.code === 0) { + messageApi.success({ content: 'Service starting...', duration: 1 }); + startPolling(); + } else { + setServiceStarting(false); + messageApi.error({ content: res.data.message!, duration: 1 }); + } + }) + .catch((error) => { + console.error('Error starting service:', error); + setServiceStarting(false); + messageApi.error({ + content: error.response?.data?.message || error.message, + duration: 1 + }); + }); + }; + + const handleStopService = () => { + setServiceStopping(true); + stopService() + .then((res) => { + if (res.data.code === 0) { + messageApi.success({ content: 'Service stopping...', duration: 1 }); + startStopPolling(); + } else { + messageApi.error({ content: res.data.message!, duration: 1 }); setServiceStopping(false); + } + }) + .catch((error) => { + console.error('Error stopping service:', error); + messageApi.error({ + content: error.response?.data?.message || error.message, + duration: 1 }); + setServiceStopping(false); + }); + }; + + const handleServiceAction = () => { + if (serviceStarted) { + handleStopService(); } else { - setServiceStarting(true); - startService({ model_name: config.model_name }) - .then((res) => { - if (res.data.code === 0) { - messageApi.success({ content: 'Service starting...', duration: 1 }); - startPolling(); - } else { - setServiceStarting(false); - messageApi.error({ content: res.data.message!, duration: 1 }); - } - }) - .catch((error) => { - console.error('Error starting service:', error); - setServiceStarting(false); - messageApi.error({ - content: error.response?.data?.message || error.message, - duration: 1 - }); - }); + if (isTraining) { + setShowtrainingModal(true); + + return; + } + + handleStartService(); } }; @@ -211,7 +222,7 @@ export function ModelStatus() { {isServiceStarting ? 'Starting...' : 'Stopping...'} - ) : status === 'running' ? ( + ) : serviceStarted ? ( <> @@ -250,6 +261,14 @@ export function ModelStatus() {
       <RegisterUploadModal
         onClose={() => setShowRegisterModal(false)}
         open={showRegisterModal}
       />
+      <TrainingTipModal
+        onConfirm={() => {
+          handleStartService();
+          setShowtrainingModal(false);
+        }}
+        onClose={() => setShowtrainingModal(false)}
+        open={showtrainingModal}
+      />
); } diff --git a/lpm_frontend/src/components/StatusBar/index.tsx b/lpm_frontend/src/components/StatusBar/index.tsx index 5c9e319c..98beac6c 100644 --- a/lpm_frontend/src/components/StatusBar/index.tsx +++ b/lpm_frontend/src/components/StatusBar/index.tsx @@ -13,14 +13,6 @@ export function StatusBar({ status }: StatusBarProps) { ] as const; const getStepState = (stepStatus: (typeof steps)[number]['status']) => { - // If current status is running, all steps should be shown as completed, and trained should be active - if (status === 'running') { - return { - isActive: stepStatus === 'trained', - isCompleted: stepStatus !== 'trained' - }; - } - const statusOrder = ['seed_identity', 'memory_upload', 'training', 'trained']; const currentIndex = statusOrder.indexOf(status); const stepIndex = statusOrder.indexOf(stepStatus); diff --git a/lpm_frontend/src/components/train/TrainingConfiguration.tsx b/lpm_frontend/src/components/train/TrainingConfiguration.tsx index ca3411ea..30047f35 100644 --- a/lpm_frontend/src/components/train/TrainingConfiguration.tsx +++ b/lpm_frontend/src/components/train/TrainingConfiguration.tsx @@ -26,14 +26,6 @@ interface ModelConfig { [key: string]: any; } -interface TrainingParams { - data_synthesis_mode: string; - learning_rate: number; - number_of_epochs: number; - concurrency_threads: number; - use_cuda: boolean; -} - interface TrainingConfigurationProps { baseModelOptions: BaseModelOption[]; modelConfig: ModelConfig | null; @@ -90,7 +82,7 @@ const TrainingConfiguration: React.FC = ({ return isTraining ? 'Stop Training' : !changeBaseModel - ? status === 'trained' || status === 'running' + ? status === 'trained' ? 'Retrain' : isResume ? 'Resume Training' @@ -415,7 +407,6 @@ const TrainingConfiguration: React.FC = ({
@@ -61,6 +62,14 @@ Star and join us, and you will receive all release notifications from GitHub wit
 
+## FAQ + +Got questions about running Second Me, model setup, or memory configuration? + +📖 [Check out the FAQ here](https://secondme.gitbook.io/secondme/faq) + + + ## Quick Start ### 📊 Model Deployment Memory and Supported Model Size Reference Guide From ef4c491d5fde1bc060d2ed5197de137c4a32a575 Mon Sep 17 00:00:00 2001 From: ryangyuan Date: Fri, 25 Apr 2025 18:08:13 +0800 Subject: [PATCH 15/21] Feat/0425/adjustment of training rule (#290) * fix: adjustment status order * fix: adjustment train status * fix: split the status of service and train * feat: adjustment train rule --- .../src/app/dashboard/train/training/page.tsx | 41 ++---- .../train/TrainingConfiguration.tsx | 117 +++++++----------- .../layouts/DashboardLayout/Menu/index.tsx | 23 +++- lpm_frontend/src/store/useTrainingStore.ts | 10 +- 4 files changed, 86 insertions(+), 105 deletions(-) diff --git a/lpm_frontend/src/app/dashboard/train/training/page.tsx b/lpm_frontend/src/app/dashboard/train/training/page.tsx index 0543b408..bd74182f 100644 --- a/lpm_frontend/src/app/dashboard/train/training/page.tsx +++ b/lpm_frontend/src/app/dashboard/train/training/page.tsx @@ -88,11 +88,9 @@ export default function TrainingPage() { const isTraining = useTrainingStore((state) => state.isTraining); const setIsTraining = useTrainingStore((state) => state.setIsTraining); const [trainingParams, setTrainingParams] = useState({} as TrainingConfig); - const [nowTrainingParams, setNowTrainingParams] = useState(null); const [trainActionLoading, setTrainActionLoading] = useState(false); const [showCelebration, setShowCelebration] = useState(false); const [showMemoryModal, setShowMemoryModal] = useState(false); - const [changeBaseModel, setChangeBaseModel] = useState(false); const cleanupEventSourceRef = useRef<(() => void) | undefined>(); const containerRef = useRef(null); @@ -100,15 +98,8 @@ export default function TrainingPage() { const pollingStopRef = useRef(false); const [cudaAvailable, setCudaAvailable] = useState(false); - const [isResume, setIsResume] = useState( - trainingProgress.status === 'suspended' || trainingProgress.status === 'failed' - ); - - useEffect(() => { - const localTrainingParams = JSON.parse(localStorage.getItem('trainingParams') || '{}'); - - setChangeBaseModel(localTrainingParams?.model_name !== trainingParams.model_name); - }, [trainingParams.model_name]); + const trainSuspended = useTrainingStore((state) => state.trainSuspended); + const setTrainSuspended = useTrainingStore((state) => state.setTrainSuspended); useEffect(() => { fetchModelConfig(); @@ -175,10 +166,6 @@ export default function TrainingPage() { pollingStopRef.current = true; }; - useEffect(() => { - setIsResume(trainingProgress.status === 'suspended' || trainingProgress.status === 'failed'); - }, [trainingProgress]); - useEffect(() => { if (status === 'trained' || trainingError) { stopPolling(); @@ -277,7 +264,6 @@ export default function TrainingPage() { const data = res.data.data; setTrainingParams(data); - setNowTrainingParams(data); localStorage.setItem('trainingParams', JSON.stringify(data)); } else { @@ -362,7 +348,7 @@ export default function TrainingPage() { if (res.data.code === 0) { setIsTraining(false); - setIsResume(true); + setTrainSuspended(true); } else { message.error(res.data.message || 'Failed to stop training'); } @@ -378,9 +364,7 @@ export default function TrainingPage() { resetProgress() .then((res) => { if (res.data.code === 0) { - setTrainingParams(nowTrainingParams || ({} as TrainingConfig)); - setNowTrainingParams(null); - 
setIsResume(false); + setTrainSuspended(false); resetTrainingState(); } else { throw new Error(res.data.message || 'Failed to reset progress'); @@ -404,18 +388,15 @@ export default function TrainingPage() { resetTrainingState(); try { - setNowTrainingParams(trainingParams); - console.log('Using startTrain API to train new model:', trainingParams.model_name); const res = await startTrain({ - ...(isResume && !changeBaseModel ? {} : trainingParams), + ...trainingParams, model_name: trainingParams.model_name }); if (res.data.code === 0) { // Save training configuration and start polling localStorage.setItem('trainingParams', JSON.stringify(trainingParams)); - setChangeBaseModel(false); scrollPageToBottom(); startGetTrainingProgress(); } else { @@ -425,7 +406,6 @@ export default function TrainingPage() { } catch (error: unknown) { console.error('Error starting training:', error); setIsTraining(false); - setNowTrainingParams(null); if (error instanceof Error) { message.error(error.message || 'Failed to start training'); @@ -492,8 +472,8 @@ export default function TrainingPage() { return; } - // If the same model has already been trained and service is started, perform retraining - if (!changeBaseModel && (status === 'trained' || serviceStarted)) { + // If the same model has already been trained and status is 'trained' or 'running', perform retraining + if (status === 'trained') { await handleRetrainModel(); } else { // Otherwise start new training @@ -551,24 +531,21 @@ export default function TrainingPage() { {/* Training Configuration Component */} {/* Only show training progress after training starts */} - {(status === 'training' || status === 'trained' || serviceStarted) && - renderTrainingProgress()} + {(status === 'training' || status === 'trained') && renderTrainingProgress()} {/* Always show training log regardless of training status */} {renderTrainingLog()} diff --git a/lpm_frontend/src/components/train/TrainingConfiguration.tsx b/lpm_frontend/src/components/train/TrainingConfiguration.tsx index 30047f35..866a439c 100644 --- a/lpm_frontend/src/components/train/TrainingConfiguration.tsx +++ b/lpm_frontend/src/components/train/TrainingConfiguration.tsx @@ -1,11 +1,11 @@ 'use client'; import type React from 'react'; -import { Fragment, useEffect, useMemo, useState } from 'react'; +import { Fragment, useMemo, useState } from 'react'; import { Listbox, Transition } from '@headlessui/react'; import { PlayIcon, StopIcon } from '@heroicons/react/24/outline'; import { EVENT } from '../../utils/event'; -import { Checkbox, InputNumber, Radio, Spin, Tooltip } from 'antd'; +import { Checkbox, InputNumber, message, Radio, Spin, Tooltip } from 'antd'; import type { TrainingConfig } from '@/service/train'; import { QuestionCircleOutlined } from '@ant-design/icons'; import OpenAiModelIcon from '../svgs/OpenAiModelIcon'; @@ -32,10 +32,8 @@ interface TrainingConfigurationProps { isTraining: boolean; updateTrainingParams: (params: TrainingConfig) => void; status: string; - isResume: boolean; + trainSuspended: boolean; handleResetProgress: () => void; - nowTrainingParams: TrainingConfig | null; - changeBaseModel: boolean; handleTrainingAction: () => Promise; trainActionLoading: boolean; setSelectedInfo: React.Dispatch>; @@ -55,21 +53,22 @@ const TrainingConfiguration: React.FC = ({ isTraining, updateTrainingParams, trainingParams, - nowTrainingParams, status, handleResetProgress, - isResume, - changeBaseModel, + trainSuspended, trainActionLoading, handleTrainingAction, setSelectedInfo, cudaAvailable }) => { 
- const [disabledChangeParams, setDisabledChangeParams] = useState(false); const [openThinkingModel, setOpenThinkingModel] = useState(false); const [showThinkingWarning, setShowThinkingWarning] = useState(false); const thinkingModelConfig = useModelConfigStore((state) => state.thinkingModelConfig); + const disabledChangeParams = useMemo(() => { + return isTraining || trainSuspended; + }, [isTraining, trainSuspended]); + const thinkingConfigComplete = useMemo(() => { return ( !!thinkingModelConfig.thinking_model_name && @@ -81,14 +80,12 @@ const TrainingConfiguration: React.FC = ({ const trainButtonText = useMemo(() => { return isTraining ? 'Stop Training' - : !changeBaseModel - ? status === 'trained' - ? 'Retrain' - : isResume - ? 'Resume Training' - : 'Start Training' - : 'Start Training'; - }, [isTraining, status, isResume, changeBaseModel]); + : status === 'trained' + ? 'Retrain' + : trainSuspended + ? 'Resume Training' + : 'Start Training'; + }, [isTraining, status, trainSuspended]); const trainButtonIcon = useMemo(() => { return isTraining ? ( @@ -102,10 +99,6 @@ const TrainingConfiguration: React.FC = ({ ); }, [isTraining, trainActionLoading]); - useEffect(() => { - setDisabledChangeParams(isTraining || (isResume && !changeBaseModel)); - }, [isTraining, isResume, changeBaseModel]); - return (
@@ -160,12 +153,7 @@ const TrainingConfiguration: React.FC = ({
) : (
-
{ - window.dispatchEvent(new CustomEvent(EVENT.SHOW_MODEL_CONFIG_MODAL)); - }} - > +
Model Used :   {modelConfig.provider_type === 'openai' ? ( @@ -176,10 +164,20 @@ const TrainingConfiguration: React.FC = ({ {modelConfig.provider_type === 'openai' ? 'OpenAI' : 'Custom Model'}
updateTrainingParams({ model_name: value })} value={trainingParams.model_name} >
- + {baseModelOptions.find((option) => option.value === trainingParams.model_name) ?.label || 'Select a model...'} @@ -315,11 +314,7 @@ const TrainingConfiguration: React.FC = ({ : undefined } step={0.0001} - value={ - disabledChangeParams && nowTrainingParams && !changeBaseModel - ? nowTrainingParams.learning_rate - : trainingParams.learning_rate - } + value={trainingParams.learning_rate} />
Enter a value between 0.00003 and 0.005 (recommended: 0.0001) @@ -350,11 +345,7 @@ const TrainingConfiguration: React.FC = ({ : undefined } step={1} - value={ - disabledChangeParams && nowTrainingParams && !changeBaseModel - ? nowTrainingParams.number_of_epochs - : trainingParams.number_of_epochs - } + value={trainingParams.number_of_epochs} />
Enter an integer between 1 and 10 (recommended: 2) @@ -386,11 +377,7 @@ const TrainingConfiguration: React.FC = ({ : undefined } step={1} - value={ - disabledChangeParams && nowTrainingParams && !changeBaseModel - ? nowTrainingParams.concurrency_threads - : trainingParams.concurrency_threads - } + value={trainingParams.concurrency_threads} />
Enter an integer between 1 and 10 (recommended: 2) @@ -407,11 +394,7 @@ const TrainingConfiguration: React.FC = ({
@@ -451,18 +428,17 @@ const TrainingConfiguration: React.FC = ({
{ e.stopPropagation(); if (!thinkingConfigComplete) { setShowThinkingWarning(true); - setTimeout(() => setShowThinkingWarning(false), 2000); + + if (!showThinkingWarning) { + setTimeout(() => setShowThinkingWarning(false), 2000); + } return; } @@ -475,9 +451,12 @@ const TrainingConfiguration: React.FC = ({ `text-sm font-medium px-4 py-2 bg-white border rounded-md cursor-pointer transition-all duration-500 ease-[cubic-bezier(0.4,0,0.2,1)]`, showThinkingWarning ? 'border-red-500 text-red-600 bg-red-50 shadow-[0_0_0_2px_rgba(220,38,38,0.4)] animate-pulse' - : 'border-gray-300 text-gray-700 hover:bg-gray-50' + : 'border-gray-300 text-gray-700 hover:bg-gray-50', + disabledChangeParams && 'opacity-50 !cursor-not-allowed' )} onClick={() => { + if (disabledChangeParams) return; + setOpenThinkingModel(true); }} > diff --git a/lpm_frontend/src/layouts/DashboardLayout/Menu/index.tsx b/lpm_frontend/src/layouts/DashboardLayout/Menu/index.tsx index 7bc3f293..79996760 100644 --- a/lpm_frontend/src/layouts/DashboardLayout/Menu/index.tsx +++ b/lpm_frontend/src/layouts/DashboardLayout/Menu/index.tsx @@ -23,7 +23,6 @@ const Menu = () => { const pathname = usePathname(); const router = useRouter(); const [isSidebarCollapsed, setIsSidebarCollapsed] = useState(false); - const status = useTrainingStore((state) => state.status); const loadInfo = useLoadInfoStore((state) => state.loadInfo); const clearLoadInfo = useLoadInfoStore((state) => state.clearLoadInfo); const serviceStarted = useTrainingStore((state) => state.serviceStarted); @@ -33,6 +32,13 @@ const Menu = () => { const [deleteConfirmLoading, setDeleteConfirmLoading] = useState(false); const [showModelConfig, setShowModelConfig] = useState(false); + const isTraining = useTrainingStore((state) => state.isTraining); + const trainSuspended = useTrainingStore((state) => state.trainSuspended); + + const disabledChangeParams = useMemo(() => { + return isTraining || trainSuspended; + }, [isTraining, trainSuspended]); + const isRegistered = useMemo(() => { return loadInfo?.status === 'online'; }, [loadInfo]); @@ -265,8 +271,19 @@ const Menu = () => {
-