Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 29 additions & 26 deletions scripts_on_cluster/bodyseg_training/job.run
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
#SBATCH --ntasks 1
#SBATCH --cpus-per-task 16
#SBATCH --mem 92GB
#SBATCH --time 48:00:00
#SBATCH --time 72:00:00
#SBATCH --partition=h100
#SBATCH --qos=normal
#SBATCH --gres=gpu:1
#SBATCH --output /home/sibwang/poseforge/scripts_on_cluster/bodyseg_training/output_20251118a.log
#SBATCH --output /home/sibwang/poseforge/scripts_on_cluster/bodyseg_training/output_20251127a.log

echo "Hello from $(hostname)"

Expand All @@ -19,7 +19,10 @@ conda activate poseforge
cd $HOME/poseforge

training_cli_path="src/poseforge/pose/bodyseg/scripts/run_bodyseg_training.py"
training_trial_name="trial_20251118a"
training_trial_name="trial_20251127a"
contrastive_pretraining_trial_name="trial_20251125a_lowlr"
contrastive_pretraining_epoch="epoch009"
contrastive_pretraining_local_step="step003055"

echo "Training starting at $(date)"

Expand All @@ -30,32 +33,32 @@ python -u $training_cli_path \
--model-architecture-config.final-upsampler-n-hidden-channels 32 \
--model-architecture-config.confidence-method entropy \
--model-weights-config.feature-extractor-weights \
"bulk_data/pose_estimation/contrastive_pretraining/trial_20251117a/checkpoints/checkpoint_epoch009_step003055.feature_extractor.pth" \
"bulk_data/pose_estimation/contrastive_pretraining/$contrastive_pretraining_trial_name/checkpoints/checkpoint_${contrastive_pretraining_epoch}_${contrastive_pretraining_local_step}.feature_extractor.pth" \
--loss-config.weight-dice 1.0 \
--loss-config.weight-ce 1.0 \
--training-data-config.train-data-dirs \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial001" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial002" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial003" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial004" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial005" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly2_trial001" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly2_trial002" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly2_trial003" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly2_trial004" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly2_trial005" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly3_trial001" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly3_trial002" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly3_trial003" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly3_trial004" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly3_trial005" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly4_trial001" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly4_trial002" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly4_trial003" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly4_trial004" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly4_trial005" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial001" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial002" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial003" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial004" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial005" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial001" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial002" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial003" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial004" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial005" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial001" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial002" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial003" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial004" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial005" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial001" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial002" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial003" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial004" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial005" \
--training-data-config.val-data-dirs \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial001" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial001" \
--training-data-config.input-image-size 256 256 \
--training-data-config.atomic-batch-n-samples 32 \
--training-data-config.atomic-batch-n-variants 4 \
Expand All @@ -67,7 +70,7 @@ python -u $training_cli_path \
--optimizer-config.learning-rate-segmentation-head 3e-4 \
--optimizer-config.weight-decay 1e-5 \
--training-artifacts-config.output-basedir \
"bulk_data/pose_estimation/bodyseg/trial_20251118a/" \
"bulk_data/pose_estimation/bodyseg/$training_trial_name/" \
--training-artifacts-config.logging-interval 10 \
--training-artifacts-config.checkpoint-interval 1000 \
--training-artifacts-config.validation-interval 1000 \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
#SBATCH --ntasks 1
#SBATCH --cpus-per-task 16
#SBATCH --mem 90GB
#SBATCH --time 48:00:00
#SBATCH --time 72:00:00
#SBATCH --partition=h100
#SBATCH --qos=normal
#SBATCH --gres=gpu:1
#SBATCH --output /home/sibwang/poseforge/scripts_on_cluster/contrastive_pretraining_training/output_20251117a.log
#SBATCH --output /home/sibwang/poseforge/scripts_on_cluster/contrastive_pretraining_training/output_20251125b_2variants.log

echo "Hello from $(hostname)"

Expand All @@ -20,7 +20,7 @@ conda activate poseforge
cd $HOME/poseforge
echo "Training starting at $(date)"

trial_name="trial_20251117a"
trial_name="trial_20251125b_2variants"

python -u src/poseforge/pose/contrast/scripts/run_contrastive_pretraining.py \
--n-epochs 10 \
Expand Down Expand Up @@ -53,7 +53,7 @@ python -u src/poseforge/pose/contrast/scripts/run_contrastive_pretraining.py \
--training-data-config.val-data-dirs \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial001" \
--training-data-config.atomic-batch-n-samples 32 \
--training-data-config.atomic-batch-n-variants 4 \
--training-data-config.atomic-batch-n-variants 2 \
--training-data-config.train-batch-size 960 \
--training-data-config.val-batch-size 256 \
--training-data-config.image-size 256 256 \
Expand Down
70 changes: 70 additions & 0 deletions scripts_on_cluster/contrastive_pretraining_training/low_lr.run
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#!/bin/bash -l

#SBATCH --job-name contr_pretrain_lr3e-5
#SBATCH --nodes 1
#SBATCH --ntasks 1
#SBATCH --cpus-per-task 16
#SBATCH --mem 90GB
#SBATCH --time 72:00:00
#SBATCH --partition=h100
#SBATCH --qos=normal
#SBATCH --gres=gpu:1
#SBATCH --output /home/sibwang/poseforge/scripts_on_cluster/contrastive_pretraining_training/output_20251125a_lowlr.log

echo "Hello from $(hostname)"

. ~/spack/share/spack/setup-env.sh
spack load ffmpeg
conda activate poseforge

cd $HOME/poseforge
echo "Training starting at $(date)"

trial_name="trial_20251125a_lowlr"

python -u src/poseforge/pose/contrast/scripts/run_contrastive_pretraining.py \
--n-epochs 10 \
--seed 42 \
--model-architecture-config.projection-head-hidden-dim 512 \
--model-architecture-config.projection-head-output-dim 256 \
--model-weights-config.feature-extractor-weights "IMAGENET1K_V1" \
--loss-config.info-nce-temperature 0.1 \
--training-data-config.train-data-dirs \
"bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly1_trial001" \
"bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly1_trial002" \
"bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly1_trial003" \
"bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly1_trial004" \
"bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly1_trial005" \
"bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly2_trial001" \
"bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly2_trial002" \
"bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly2_trial003" \
"bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly2_trial004" \
"bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly2_trial005" \
"bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly3_trial001" \
"bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly3_trial002" \
"bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly3_trial003" \
"bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly3_trial004" \
"bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly3_trial005" \
"bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly4_trial001" \
"bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly4_trial002" \
"bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly4_trial003" \
"bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly4_trial004" \
"bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly4_trial005" \
--training-data-config.val-data-dirs \
"bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly1_trial001" \
--training-data-config.atomic-batch-n-samples 32 \
--training-data-config.atomic-batch-n-variants 4 \
--training-data-config.train-batch-size 960 \
--training-data-config.val-batch-size 256 \
--training-data-config.image-size 256 256 \
--training-data-config.n-workers 4 \
--optimizer-config.adam-lr 3e-5 \
--optimizer-config.adam-weight-decay 1e-4 \
--training-artifacts-config.output-basedir \
"bulk_data/pose_estimation/contrastive_pretraining/$trial_name" \
--training-artifacts-config.logging-interval 10 \
--training-artifacts-config.checkpoint-interval 500 \
--training-artifacts-config.validation-interval 200 \
--training-artifacts-config.n-batches-per-validation 100

echo "Training ends at $(date)"
53 changes: 28 additions & 25 deletions scripts_on_cluster/keypoints3d_training/job.run
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash -l

#SBATCH --job-name keypoints3d-20251118a
#SBATCH --job-name keypoints3d
#SBATCH --nodes 1
#SBATCH --ntasks 1
#SBATCH --cpus-per-task 16
Expand All @@ -9,7 +9,7 @@
#SBATCH --partition=h100
#SBATCH --qos=normal
#SBATCH --gres=gpu:1
#SBATCH --output /home/sibwang/poseforge/scripts_on_cluster/keypoints3d_training/output_20251118a.log
#SBATCH --output /home/sibwang/poseforge/scripts_on_cluster/keypoints3d_training/output_20251127a.log

echo "Hello from $(hostname)"

Expand All @@ -19,37 +19,40 @@ conda activate poseforge
cd $HOME/poseforge

training_cli_path="src/poseforge/pose/keypoints3d/scripts/run_keypoints3d_training.py"
training_trial_name="trial_20251118a"
training_trial_name="trial_20251127a"
contrastive_pretraining_trial_name="trial_20251125a_lowlr"
contrastive_pretraining_epoch="epoch009"
contrastive_pretraining_local_step="step003055"

echo "Training starting at $(date)"

python -u $training_cli_path \
--n-epochs 30 \
--model-weights-config.feature-extractor-weights \
"bulk_data/pose_estimation/contrastive_pretraining/trial_20251117a/checkpoints/checkpoint_epoch009_step003055.feature_extractor.pth" \
"bulk_data/pose_estimation/contrastive_pretraining/$contrastive_pretraining_trial_name/checkpoints/checkpoint_${contrastive_pretraining_epoch}_${contrastive_pretraining_local_step}.feature_extractor.pth" \
--training-data-config.train-data-dirs \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial001" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial002" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial003" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial004" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial005" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly2_trial001" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly2_trial002" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly2_trial003" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly2_trial004" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly2_trial005" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly3_trial001" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly3_trial002" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly3_trial003" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly3_trial004" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly3_trial005" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly4_trial001" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly4_trial002" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly4_trial003" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly4_trial004" \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly4_trial005" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial001" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial002" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial003" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial004" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial005" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial001" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial002" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial003" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial004" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial005" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial001" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial002" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial003" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial004" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial005" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial001" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial002" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial003" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial004" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial005" \
--training-data-config.val-data-dirs \
"bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly5_trial001" \
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly5_trial001" \
--training-data-config.input-image-size 256 256 \
--training-data-config.atomic-batch-n-samples 32 \
--training-data-config.atomic-batch-n-variants 4 \
Expand Down
75 changes: 74 additions & 1 deletion src/poseforge/neuromechfly/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import numpy as np


###########################################################################
## NEUROMECHFLY BODY CONFIGURATION BELOW ##
###########################################################################

dof_name_lookup_nmf_to_canonical = {
"Coxa": "ThC_pitch",
"Coxa_roll": "ThC_roll",
Expand Down Expand Up @@ -38,7 +42,33 @@
f"{leg}{link}" for leg in legs for link in leg_keypoints_nmf
] + ["LPedicel", "RPedicel"]

kchain_plotting_colors = {
all_segment_names_per_leg = [
"Coxa",
"Femur",
"Tibia",
"Tarsus1",
"Tarsus2",
"Tarsus3",
"Tarsus4",
"Tarsus5",
]

all_leg_dofs = [
f"joint_{side}{pos}{dof}"
for side in "LR"
for pos in "FMH"
for dof in [
"Coxa",
"Coxa_roll",
"Coxa_yaw",
"Femur",
"Femur_roll",
"Tibia",
"Tarsus1",
]
]

kchain_plotting_colors = { # these are only for plotting aesthetics
"LF": np.array([15, 115, 153]) / 255,
"LM": np.array([26, 141, 175]) / 255,
"LH": np.array([117, 190, 203]) / 255,
Expand All @@ -49,6 +79,49 @@
"RAntenna": np.array([50, 120, 32]) / 255,
}


###########################################################################
## COLORS FOR BODY SEGMENT RENDERING BELOW ##
## These are set to artificially boost contrast between body segments ##
## -- they are NOT just for aesthetics! ##
###########################################################################

# Define color combo by body segment
color_by_link = {
"Coxa": "cyan",
"Femur": "yellow",
"Tibia": "blue",
"Tarsus": "green",
"Antenna": "magenta",
"Thorax": "gray",
}
color_by_kinematic_chain = {
"LF": "red", # left front leg
"LM": "green", # left mid leg
"LH": "blue", # left hind leg
"RF": "cyan", # right front leg
"RM": "magenta", # right mid leg
"RH": "yellow", # right hind leg
"L": "red", # left antenna
"R": "green", # right antenna
"Thorax": "white", # thorax
}
color_palette = {
"red": (1, 0, 0, 1),
"green": (0, 1, 0, 1),
"blue": (0, 0, 1, 1),
"yellow": (1, 1, 0, 1),
"magenta": (1, 0, 1, 1),
"cyan": (0, 1, 1, 1),
"gray": (0.4, 0.4, 0.4, 1),
"white": (1, 1, 1, 1),
}


###########################################################################
## PARAMETERS FOR INVERSE KINEMATICS WITH SEQIKPY BELOW ##
###########################################################################

# SeqIKPy considers the anchor point of every DoF a "joint" keypoint. However, some
# anatomical joints have multiple DoFs (e.g., ThC has yaw, pitch, roll). This results in
# some "virtual" keypoints in the inverse kinematics output. This mask filters them out.
Expand Down
3 changes: 1 addition & 2 deletions src/poseforge/neuromechfly/data.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import numpy as np
import pandas as pd
from pathlib import Path
from flygym.preprogrammed import all_leg_dofs

from poseforge.neuromechfly.constants import parse_nmf_joint_name
from poseforge.neuromechfly.constants import parse_nmf_joint_name, all_leg_dofs


def extract_joint_angles_trajectory(
Expand Down
Loading