Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, sample_batch_size: int, tp_world_size: int = 1, limit : int =
'''
Workers in the same tp group share this buffer and need same sample for one step.
Therefore a held_sample should be returned tp_world_size times before it could be dropped.
worker_state records wheter a worker got the held_sample
worker_state records whether a worker got the held_sample
'''
self.tp_world_size = tp_world_size
self.worker_state = [False] * self.tp_world_size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
class ExperienceMakerHolder:
'''
Args:
detached_trainer_name_list: str list to get ray actor handleskkk
detached_trainer_name_list: str list to get ray actor handles
strategy:
experience_batch_size: batch size of generated experience
kl_coef: the coefficient of kl divergence loss
Expand Down
2 changes: 1 addition & 1 deletion applications/Chat/coati/ray/src/pipeline_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
class PipelineModel(torch.nn.Module):
'''
Actor has 2 kinds of jobs: forward and generate.
better to just pipelinize the inner model
better to just pipeline the inner model
'''
def __init__(self,
model: torch.nn.Module,
Expand Down
2 changes: 1 addition & 1 deletion applications/Chat/evaluate/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def save(self, path: str, model_name_list: List[str]) -> None:
jdump(all_evaluations,
os.path.join(evaluation_results_save_path, f"{model_name_list[0]}_evaluation_results.json"))

# Start to calculate scores and save statictics.
# Start to calculate scores and save statistics.
evaluation_statistics_save_path = os.path.join(base_save_path, "evaluation_statistics")
gpt_evaluate.save_gpt35_evaluation_statistics(model_name_list[0], all_evaluations,
evaluation_statistics_save_path)
Expand Down
8 changes: 4 additions & 4 deletions applications/Chat/evaluate/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def calculate_precision_recall_f1(preds: list, targets: list) -> dict:
The calculation of precision, recall and f1-score is realized by counting
the number f overlaps between the preds and target. The comparison length
limited by the shorter one of preds and targets. This design is mainly
considered for classifiction and extraction categories.
considered for classification and extraction categories.
"""
precision_recall_f1 = {"precision": 0, "recall": 0, "f1_score": 0}
precision_scores = []
Expand All @@ -138,7 +138,7 @@ def calculate_precision_recall_f1(preds: list, targets: list) -> dict:

def precision(preds: list, targets: list) -> dict:
"""Calculate Precision Metric
(design for classifiction and extraction categories)
(design for classification and extraction categories)

Calculating precision by counting the number of overlaps between the preds and target.
"""
Expand All @@ -149,7 +149,7 @@ def precision(preds: list, targets: list) -> dict:

def recall(preds: list, targets: list) -> dict:
"""Calculate Recall Metric
(design for classifiction and extraction categories)
(design for classification and extraction categories)

Calculating recall by counting the number of overlaps between the preds and target.
"""
Expand All @@ -160,7 +160,7 @@ def recall(preds: list, targets: list) -> dict:

def F1_score(preds: list, targets: list) -> dict:
"""Calculate F1-score Metric
(design for classifiction and extraction categories)
(design for classification and extraction categories)

Calculating f1-score by counting the number of overlaps between the preds and target.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def _remove_sharding_on_broadcast_dim(key, strategy):
# e.g. [1, 2, 4] x [4, 4, 8] -> [4, 2, 8]
# the dim 0 of [1, 2, 4] is multiplied to 4
tensor_shape[dim_idx] = 1
elif broadcast_type == BroadcastType.PADDDING:
elif broadcast_type == BroadcastType.PADDING:
# if the dim is padded
# we remove its sharding
tensor_shape[dim_idx] = None
Expand Down
12 changes: 6 additions & 6 deletions colossalai/auto_parallel/tensor_shard/utils/broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

class BroadcastType(Enum):
EQUAL = auto()
PADDDING = auto()
PADDING = auto()
MULTIPLE = auto()


Expand Down Expand Up @@ -69,18 +69,18 @@ def get_broadcast_dim_info(logical_shape, physical_shape):
for i in range(logical_num_dims):
# get the trailing dim size
logical_dim_idx = logical_num_dims - i - 1
phyiscal_dim_idx = physical_num_dims - i - 1
physical_dim_idx = physical_num_dims - i - 1
logical_dim_size = logical_shape[logical_dim_idx]

if phyiscal_dim_idx >= 0:
physical_dim_size = physical_shape[phyiscal_dim_idx]
if physical_dim_idx >= 0:
physical_dim_size = physical_shape[physical_dim_idx]

if physical_dim_size == logical_dim_size:
logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.EQUAL
elif physical_dim_size == 1 and physical_dim_size != logical_dim_size:
logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.MULTIPLE
else:
logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.PADDDING
logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.PADDING

return logical_dim_broadcast_info

Expand Down Expand Up @@ -117,7 +117,7 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe
for shape_dim, mesh_dim in logical_dim_partition.items():
logical_broadcast_type = logical_dim_broadcast_info[shape_dim]

if logical_broadcast_type == BroadcastType.PADDDING or logical_broadcast_type == BroadcastType.MULTIPLE:
if logical_broadcast_type == BroadcastType.PADDING or logical_broadcast_type == BroadcastType.MULTIPLE:
removed_dims.extend(mesh_dim)
else:
# get the corresponding physical dim
Expand Down