Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions colossalai/shardformer/policies/auto_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,9 @@ class PolicyLocation:

# ChatGLM
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel":
PolicyLocation(file_name="chatglm", class_name="ChatGLMModelPolicy"),
PolicyLocation(file_name="chatglm2", class_name="ChatGLMModelPolicy"),
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration":
PolicyLocation(file_name="chatglm", class_name="ChatGLMForConditionalGenerationPolicy"),
PolicyLocation(file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"),
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@

import colossalai.shardformer.layer as col_nn
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.modeling.chatglm import ChatGLMPipelineForwards
from colossalai.shardformer.modeling.chatglm2 import ChatGLMPipelineForwards
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
ChatGLMForConditionalGeneration,
ChatGLMModel,
GLMBlock,
)

from ..modeling.chatglm import get_flash_core_attention_forward, get_jit_fused_glm_block_forward
from ..modeling.chatglm2 import get_flash_core_attention_forward, get_jit_fused_glm_block_forward
from ..modeling.jit import get_jit_fused_dropout_add_func
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

Expand Down
2 changes: 1 addition & 1 deletion tests/kit/model_zoo/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .bert import *
from .blip2 import *
from .bloom import *
from .chatglm import *
from .chatglm2 import *
from .gpt import *
from .llama import *
from .opt import *
Expand Down