diff --git a/scripts_on_cluster/bodyseg_training/job.run b/scripts_on_cluster/bodyseg_training/job.run index 73e663b..7a1154c 100644 --- a/scripts_on_cluster/bodyseg_training/job.run +++ b/scripts_on_cluster/bodyseg_training/job.run @@ -5,11 +5,11 @@ #SBATCH --ntasks 1 #SBATCH --cpus-per-task 16 #SBATCH --mem 92GB -#SBATCH --time 48:00:00 +#SBATCH --time 72:00:00 #SBATCH --partition=h100 #SBATCH --qos=normal #SBATCH --gres=gpu:1 -#SBATCH --output /home/sibwang/poseforge/scripts_on_cluster/bodyseg_training/output_20251118a.log +#SBATCH --output /home/sibwang/poseforge/scripts_on_cluster/bodyseg_training/output_20251127a.log echo "Hello from $(hostname)" @@ -19,7 +19,10 @@ conda activate poseforge cd $HOME/poseforge training_cli_path="src/poseforge/pose/bodyseg/scripts/run_bodyseg_training.py" -training_trial_name="trial_20251118a" +training_trial_name="trial_20251127a" +contrastive_pretraining_trial_name="trial_20251125a_lowlr" +contrastive_pretraining_epoch="epoch009" +contrastive_pretraining_local_step="step003055" echo "Training starting at $(date)" @@ -30,32 +33,32 @@ python -u $training_cli_path \ --model-architecture-config.final-upsampler-n-hidden-channels 32 \ --model-architecture-config.confidence-method entropy \ --model-weights-config.feature-extractor-weights \ - "bulk_data/pose_estimation/contrastive_pretraining/trial_20251117a/checkpoints/checkpoint_epoch009_step003055.feature_extractor.pth" \ + "bulk_data/pose_estimation/contrastive_pretraining/$contrastive_pretraining_trial_name/checkpoints/checkpoint_${contrastive_pretraining_epoch}_${contrastive_pretraining_local_step}.feature_extractor.pth" \ --loss-config.weight-dice 1.0 \ --loss-config.weight-ce 1.0 \ --training-data-config.train-data-dirs \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial001" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial002" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial003" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial004" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial005" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly2_trial001" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly2_trial002" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly2_trial003" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly2_trial004" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly2_trial005" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly3_trial001" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly3_trial002" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly3_trial003" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly3_trial004" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly3_trial005" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly4_trial001" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly4_trial002" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly4_trial003" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly4_trial004" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly4_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial005" \ --training-data-config.val-data-dirs \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial001" \ --training-data-config.input-image-size 256 256 \ --training-data-config.atomic-batch-n-samples 32 \ --training-data-config.atomic-batch-n-variants 4 \ @@ -67,7 +70,7 @@ python -u $training_cli_path \ --optimizer-config.learning-rate-segmentation-head 3e-4 \ --optimizer-config.weight-decay 1e-5 \ --training-artifacts-config.output-basedir \ - "bulk_data/pose_estimation/bodyseg/trial_20251118a/" \ + "bulk_data/pose_estimation/bodyseg/$training_trial_name/" \ --training-artifacts-config.logging-interval 10 \ --training-artifacts-config.checkpoint-interval 1000 \ --training-artifacts-config.validation-interval 1000 \ diff --git a/scripts_on_cluster/contrastive_pretraining_training/job.run b/scripts_on_cluster/contrastive_pretraining_training/2variants.run similarity index 94% rename from scripts_on_cluster/contrastive_pretraining_training/job.run rename to scripts_on_cluster/contrastive_pretraining_training/2variants.run index befbd3c..8ba3069 100644 --- a/scripts_on_cluster/contrastive_pretraining_training/job.run +++ b/scripts_on_cluster/contrastive_pretraining_training/2variants.run @@ -5,11 +5,11 @@ #SBATCH --ntasks 1 #SBATCH --cpus-per-task 16 #SBATCH --mem 90GB -#SBATCH --time 48:00:00 +#SBATCH --time 72:00:00 #SBATCH --partition=h100 #SBATCH --qos=normal #SBATCH --gres=gpu:1 -#SBATCH --output /home/sibwang/poseforge/scripts_on_cluster/contrastive_pretraining_training/output_20251117a.log +#SBATCH --output /home/sibwang/poseforge/scripts_on_cluster/contrastive_pretraining_training/output_20251125b_2variants.log echo "Hello from $(hostname)" @@ -20,7 +20,7 @@ conda activate poseforge cd $HOME/poseforge echo "Training starting at $(date)" -trial_name="trial_20251117a" +trial_name="trial_20251125b_2variants" python -u src/poseforge/pose/contrast/scripts/run_contrastive_pretraining.py \ --n-epochs 10 \ @@ -53,7 +53,7 @@ python -u src/poseforge/pose/contrast/scripts/run_contrastive_pretraining.py \ --training-data-config.val-data-dirs \ "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial001" \ --training-data-config.atomic-batch-n-samples 32 \ - --training-data-config.atomic-batch-n-variants 4 \ + --training-data-config.atomic-batch-n-variants 2 \ --training-data-config.train-batch-size 960 \ --training-data-config.val-batch-size 256 \ --training-data-config.image-size 256 256 \ diff --git a/scripts_on_cluster/contrastive_pretraining_training/low_lr.run b/scripts_on_cluster/contrastive_pretraining_training/low_lr.run new file mode 100644 index 0000000..c519def --- /dev/null +++ b/scripts_on_cluster/contrastive_pretraining_training/low_lr.run @@ -0,0 +1,70 @@ +#!/bin/bash -l + +#SBATCH --job-name contr_pretrain_lr3e-5 +#SBATCH --nodes 1 +#SBATCH --ntasks 1 +#SBATCH --cpus-per-task 16 +#SBATCH --mem 90GB +#SBATCH --time 72:00:00 +#SBATCH --partition=h100 +#SBATCH --qos=normal +#SBATCH --gres=gpu:1 +#SBATCH --output /home/sibwang/poseforge/scripts_on_cluster/contrastive_pretraining_training/output_20251125a_lowlr.log + +echo "Hello from $(hostname)" + +. ~/spack/share/spack/setup-env.sh +spack load ffmpeg +conda activate poseforge + +cd $HOME/poseforge +echo "Training starting at $(date)" + +trial_name="trial_20251125a_lowlr" + +python -u src/poseforge/pose/contrast/scripts/run_contrastive_pretraining.py \ + --n-epochs 10 \ + --seed 42 \ + --model-architecture-config.projection-head-hidden-dim 512 \ + --model-architecture-config.projection-head-output-dim 256 \ + --model-weights-config.feature-extractor-weights "IMAGENET1K_V1" \ + --loss-config.info-nce-temperature 0.1 \ + --training-data-config.train-data-dirs \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly1_trial001" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly1_trial002" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly1_trial003" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly1_trial004" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly1_trial005" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly2_trial001" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly2_trial002" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly2_trial003" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly2_trial004" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly2_trial005" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly3_trial001" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly3_trial002" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly3_trial003" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly3_trial004" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly3_trial005" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly4_trial001" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly4_trial002" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly4_trial003" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly4_trial004" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly4_trial005" \ + --training-data-config.val-data-dirs \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly1_trial001" \ + --training-data-config.atomic-batch-n-samples 32 \ + --training-data-config.atomic-batch-n-variants 4 \ + --training-data-config.train-batch-size 960 \ + --training-data-config.val-batch-size 256 \ + --training-data-config.image-size 256 256 \ + --training-data-config.n-workers 4 \ + --optimizer-config.adam-lr 3e-5 \ + --optimizer-config.adam-weight-decay 1e-4 \ + --training-artifacts-config.output-basedir \ + "bulk_data/pose_estimation/contrastive_pretraining/$trial_name" \ + --training-artifacts-config.logging-interval 10 \ + --training-artifacts-config.checkpoint-interval 500 \ + --training-artifacts-config.validation-interval 200 \ + --training-artifacts-config.n-batches-per-validation 100 + +echo "Training ends at $(date)" diff --git a/scripts_on_cluster/keypoints3d_training/job.run b/scripts_on_cluster/keypoints3d_training/job.run index 678ebcb..5e623d4 100644 --- a/scripts_on_cluster/keypoints3d_training/job.run +++ b/scripts_on_cluster/keypoints3d_training/job.run @@ -1,6 +1,6 @@ #!/bin/bash -l -#SBATCH --job-name keypoints3d-20251118a +#SBATCH --job-name keypoints3d #SBATCH --nodes 1 #SBATCH --ntasks 1 #SBATCH --cpus-per-task 16 @@ -9,7 +9,7 @@ #SBATCH --partition=h100 #SBATCH --qos=normal #SBATCH --gres=gpu:1 -#SBATCH --output /home/sibwang/poseforge/scripts_on_cluster/keypoints3d_training/output_20251118a.log +#SBATCH --output /home/sibwang/poseforge/scripts_on_cluster/keypoints3d_training/output_20251127a.log echo "Hello from $(hostname)" @@ -19,37 +19,40 @@ conda activate poseforge cd $HOME/poseforge training_cli_path="src/poseforge/pose/keypoints3d/scripts/run_keypoints3d_training.py" -training_trial_name="trial_20251118a" +training_trial_name="trial_20251127a" +contrastive_pretraining_trial_name="trial_20251125a_lowlr" +contrastive_pretraining_epoch="epoch009" +contrastive_pretraining_local_step="step003055" echo "Training starting at $(date)" python -u $training_cli_path \ --n-epochs 30 \ --model-weights-config.feature-extractor-weights \ - "bulk_data/pose_estimation/contrastive_pretraining/trial_20251117a/checkpoints/checkpoint_epoch009_step003055.feature_extractor.pth" \ + "bulk_data/pose_estimation/contrastive_pretraining/$contrastive_pretraining_trial_name/checkpoints/checkpoint_${contrastive_pretraining_epoch}_${contrastive_pretraining_local_step}.feature_extractor.pth" \ --training-data-config.train-data-dirs \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial001" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial002" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial003" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial004" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial005" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly2_trial001" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly2_trial002" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly2_trial003" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly2_trial004" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly2_trial005" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly3_trial001" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly3_trial002" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly3_trial003" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly3_trial004" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly3_trial005" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly4_trial001" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly4_trial002" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly4_trial003" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly4_trial004" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly4_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial005" \ --training-data-config.val-data-dirs \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly5_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly5_trial001" \ --training-data-config.input-image-size 256 256 \ --training-data-config.atomic-batch-n-samples 32 \ --training-data-config.atomic-batch-n-variants 4 \ diff --git a/src/poseforge/neuromechfly/constants.py b/src/poseforge/neuromechfly/constants.py index 12dac23..dfccbfc 100644 --- a/src/poseforge/neuromechfly/constants.py +++ b/src/poseforge/neuromechfly/constants.py @@ -1,6 +1,10 @@ import numpy as np +########################################################################### +## NEUROMECHFLY BODY CONFIGURATION BELOW ## +########################################################################### + dof_name_lookup_nmf_to_canonical = { "Coxa": "ThC_pitch", "Coxa_roll": "ThC_roll", @@ -38,7 +42,33 @@ f"{leg}{link}" for leg in legs for link in leg_keypoints_nmf ] + ["LPedicel", "RPedicel"] -kchain_plotting_colors = { +all_segment_names_per_leg = [ + "Coxa", + "Femur", + "Tibia", + "Tarsus1", + "Tarsus2", + "Tarsus3", + "Tarsus4", + "Tarsus5", +] + +all_leg_dofs = [ + f"joint_{side}{pos}{dof}" + for side in "LR" + for pos in "FMH" + for dof in [ + "Coxa", + "Coxa_roll", + "Coxa_yaw", + "Femur", + "Femur_roll", + "Tibia", + "Tarsus1", + ] +] + +kchain_plotting_colors = { # these are only for plotting aesthetics "LF": np.array([15, 115, 153]) / 255, "LM": np.array([26, 141, 175]) / 255, "LH": np.array([117, 190, 203]) / 255, @@ -49,6 +79,49 @@ "RAntenna": np.array([50, 120, 32]) / 255, } + +########################################################################### +## COLORS FOR BODY SEGMENT RENDERING BELOW ## +## These are set to artificially boost contrast between body segments ## +## -- they are NOT just for aesthetics! ## +########################################################################### + +# Define color combo by body segment +color_by_link = { + "Coxa": "cyan", + "Femur": "yellow", + "Tibia": "blue", + "Tarsus": "green", + "Antenna": "magenta", + "Thorax": "gray", +} +color_by_kinematic_chain = { + "LF": "red", # left front leg + "LM": "green", # left mid leg + "LH": "blue", # left hind leg + "RF": "cyan", # right front leg + "RM": "magenta", # right mid leg + "RH": "yellow", # right hind leg + "L": "red", # left antenna + "R": "green", # right antenna + "Thorax": "white", # thorax +} +color_palette = { + "red": (1, 0, 0, 1), + "green": (0, 1, 0, 1), + "blue": (0, 0, 1, 1), + "yellow": (1, 1, 0, 1), + "magenta": (1, 0, 1, 1), + "cyan": (0, 1, 1, 1), + "gray": (0.4, 0.4, 0.4, 1), + "white": (1, 1, 1, 1), +} + + +########################################################################### +## PARAMETERS FOR INVERSE KINEMATICS WITH SEQIKPY BELOW ## +########################################################################### + # SeqIKPy considers the anchor point of every DoF a "joint" keypoint. However, some # anatomical joints have multiple DoFs (e.g., ThC has yaw, pitch, roll). This results in # some "virtual" keypoints in the inverse kinematics output. This mask filters them out. diff --git a/src/poseforge/neuromechfly/data.py b/src/poseforge/neuromechfly/data.py index 8901e7f..4cf3b8e 100644 --- a/src/poseforge/neuromechfly/data.py +++ b/src/poseforge/neuromechfly/data.py @@ -1,9 +1,8 @@ import numpy as np import pandas as pd from pathlib import Path -from flygym.preprogrammed import all_leg_dofs -from poseforge.neuromechfly.constants import parse_nmf_joint_name +from poseforge.neuromechfly.constants import parse_nmf_joint_name, all_leg_dofs def extract_joint_angles_trajectory( diff --git a/src/poseforge/neuromechfly/postprocessing.py b/src/poseforge/neuromechfly/postprocessing.py index ce62626..fa33870 100644 --- a/src/poseforge/neuromechfly/postprocessing.py +++ b/src/poseforge/neuromechfly/postprocessing.py @@ -11,14 +11,7 @@ from tqdm import tqdm from joblib import Parallel, delayed -import poseforge.neuromechfly.simulate as simulate -from poseforge.neuromechfly.constants import ( - keypoint_name_lookup_canonical_to_nmf, - kchain_plotting_colors, - keypoint_segments_nmf, - legs, - leg_keypoints_canonical, -) +import poseforge.neuromechfly.constants as constants from poseforge.util.plot import ( configure_matplotlib_style, get_segmentation_color_palette, @@ -65,8 +58,10 @@ def __init__(self): for pos in "FMH": for link in leg_segments: leg = f"{side}{pos}" - color0 = nmf_rendered_colors[simulate.color_by_link[link]] - color1 = nmf_rendered_colors[simulate.color_by_kinematic_chain[leg]] + color0 = nmf_rendered_colors[constants.color_by_link[link]] + color1 = nmf_rendered_colors[ + constants.color_by_kinematic_chain[leg] + ] color_6d = np.array(list(color0) + list(color1)) label = f"{leg}{link}" self.label_keys.append(label) @@ -74,8 +69,8 @@ def __init__(self): # Antennas for side in "LR": - color0 = nmf_rendered_colors[simulate.color_by_link["Antenna"]] - color1 = nmf_rendered_colors[simulate.color_by_kinematic_chain[side]] + color0 = nmf_rendered_colors[constants.color_by_link["Antenna"]] + color1 = nmf_rendered_colors[constants.color_by_kinematic_chain[side]] color_6d = np.array(list(color0) + list(color1)) label = f"{side}Antenna" self.label_keys.append(label) @@ -288,7 +283,7 @@ def process_single_frame( # Gather keypoint positions in coordinates and rotate/center-crop accordingly keypoints_pos_dict_world_raw, keypoints_pos_dict_camera_raw = ( extract_body_segment_positions( - h5_file, frame_idx, "pos_atparent", keypoint_segments_nmf + h5_file, frame_idx, "pos_atparent", constants.keypoint_segments_nmf ) ) keypoints_pos_dict_world_rotated = rotate_keypoint_positions_world( @@ -352,8 +347,7 @@ def process_subsegment( segment_label_parser, ) ) - - # Process frames in parallel + # Parallel execution with joblib # Use 'loky' backend for CPU-intensive image processing operations parallel_executor = Parallel(n_jobs=n_jobs, backend="loky") effective_n_jobs = parallel_executor._effective_n_jobs() @@ -446,9 +440,10 @@ def process_subsegment( keypoint_pos_group = postprocessed_group.create_group("keypoint_pos") for ref_frame in ["camera", "world"]: data_block = np.empty( - (num_frames, len(keypoint_segments_nmf), 3), dtype="float32" + (num_frames, len(constants.keypoint_segments_nmf), 3), + dtype="float32", ) - for seg_id, body_segment in enumerate(keypoint_segments_nmf): + for seg_id, body_segment in enumerate(constants.keypoint_segments_nmf): key = f"keypoint_pos_{ref_frame}_{body_segment}" values = np.array(derived_variables_by_key[key]) data_block[:, seg_id, :] = values @@ -456,13 +451,13 @@ def process_subsegment( pos_ds = keypoint_pos_group.create_dataset( f"{ref_frame}_coords", data=data_block, dtype="float32" ) - pos_ds.attrs["keys"] = keypoint_segments_nmf + pos_ds.attrs["keys"] = constants.keypoint_segments_nmf pos_ds.attrs["description"] = ( f"Keypoint positions in {ref_frame} coordinates. Shape is " "(num_frames, num_keypoints, 3). See the `.attrs['keys']` for the " "order of keypoints." ) - keypoint_pos_group.attrs["keys"] = keypoint_segments_nmf + keypoint_pos_group.attrs["keys"] = constants.keypoint_segments_nmf keypoint_pos_group.attrs["description"] = ( "This group contains positions of joint keypoints in the rotated image " "centered around the fly, cropped, and rotated so that the fly faces " @@ -492,6 +487,17 @@ def process_subsegment( "for the mapping from label IDs (pixel values) to body segment names." ) + # Add mesh state labels + seg_states_grp = postprocessed_group.create_group("body_segment_states") + seg_states_grp.attrs.update(source_h5_file["body_segment_states"].attrs) + for sensor_type in source_h5_file["body_segment_states"].keys(): + source_ds = source_h5_file["body_segment_states"][sensor_type] + seg_states_grp.create_dataset( + sensor_type, + data=source_ds[frame_idx_start:frame_idx_end, :, :], + dtype="float32", + ) + def _draw_pose_2d_and_3d( ax_pose2d: plt.Axes, @@ -518,11 +524,11 @@ def _draw_pose_2d_and_3d( # Legs keypoint_pos_cam_ds = h5_file["postprocessed/keypoint_pos/camera_coords"] keypoints = keypoint_pos_cam_ds.attrs["keys"].tolist() - for leg in legs: - color = kchain_plotting_colors[leg] + for leg in constants.legs: + color = constants.kchain_plotting_colors[leg] all_positions = [] - for kpt in leg_keypoints_canonical: - segment_name = keypoint_name_lookup_canonical_to_nmf[kpt] + for kpt in constants.leg_keypoints_canonical: + segment_name = constants.keypoint_name_lookup_canonical_to_nmf[kpt] keypoint_idx = keypoints.index(f"{leg}{segment_name}") pos = keypoint_pos_cam_ds[frame_index, keypoint_idx, :] all_positions.append(pos) @@ -535,17 +541,17 @@ def _draw_pose_2d_and_3d( segment_name = f"{side}Pedicel" keypoint_idx = keypoints.index(segment_name) pos = keypoint_pos_cam_ds[frame_index, keypoint_idx, :] - color = kchain_plotting_colors[f"{side}Antenna"] + color = constants.kchain_plotting_colors[f"{side}Antenna"] ax_pose2d.plot(pos[0], pos[1], marker="o", color=color, markersize=5) # Plot 3D keypoints keypoint_pos_world_ds = h5_file["postprocessed/keypoint_pos/world_coords"] # Legs - for leg in legs: - color = kchain_plotting_colors[leg] + for leg in constants.legs: + color = constants.kchain_plotting_colors[leg] all_positions = [] - for kpt in leg_keypoints_canonical: - segment_name = keypoint_name_lookup_canonical_to_nmf[kpt] + for kpt in constants.leg_keypoints_canonical: + segment_name = constants.keypoint_name_lookup_canonical_to_nmf[kpt] keypoint_idx = keypoints.index(f"{leg}{segment_name}") pos = keypoint_pos_world_ds[frame_index, keypoint_idx, :] all_positions.append(pos) @@ -563,7 +569,7 @@ def _draw_pose_2d_and_3d( segment_name = f"{side}Pedicel" keypoint_idx = keypoints.index(segment_name) pos = keypoint_pos_world_ds[frame_index, keypoint_idx, :] - color = kchain_plotting_colors[f"{side}Antenna"] + color = constants.kchain_plotting_colors[f"{side}Antenna"] ax_pose3d.plot( pos[0], pos[1], pos[2], marker="o", color=color, markersize=5 ) diff --git a/src/poseforge/neuromechfly/scripts/run_simulation.py b/src/poseforge/neuromechfly/scripts/run_simulation.py index 8af55bf..c875731 100644 --- a/src/poseforge/neuromechfly/scripts/run_simulation.py +++ b/src/poseforge/neuromechfly/scripts/run_simulation.py @@ -53,7 +53,7 @@ from pathlib import Path from poseforge.neuromechfly.data import load_kinematic_recording -from poseforge.neuromechfly.simulate import simulate_one_segment +# from poseforge.neuromechfly.simulate import simulate_one_segment # TODO: revert from poseforge.neuromechfly.postprocessing import postprocess_segment from poseforge.util import get_hardware_availability @@ -134,19 +134,20 @@ def simulate_using_kinematic_prior( print(f"=== Simulating segment #{segment_id} ({num_segments} total) ===") segment = kinematic_recording_segments[segment_id] output_subdir = trial_output_dir / f"segment_{segment_id:03d}" - is_success = simulate_one_segment( - kinematic_recording_segment=segment, - output_dir=output_subdir, - input_timestep=input_timestep, - sim_timestep=sim_timestep, - output_data_freq=output_data_freq, - render_play_speed=render_play_speed, - min_sim_duration_sec=0.2, - max_sim_steps=max_sim_steps_per_segment, - ) + # is_success = simulate_one_segment( # TODO: revert + # kinematic_recording_segment=segment, + # output_dir=output_subdir, + # input_timestep=input_timestep, + # sim_timestep=sim_timestep, + # output_data_freq=output_data_freq, + # render_play_speed=render_play_speed, + # min_sim_duration_sec=0.2, + # max_sim_steps=max_sim_steps_per_segment, + # ) + is_success = output_subdir.exists() and len(list(output_subdir.iterdir())) > 0 if is_success: postprocess_segment( - output_subdir, visualize=True, min_subsegment_duration_sec=0.1 + output_subdir, visualize=False, min_subsegment_duration_sec=0.1 # TODO: enable visualization ) print(f"### Done processing trial: {trial_name} ###") @@ -154,19 +155,22 @@ def simulate_using_kinematic_prior( def run_sequentially_for_testing(): """Run everything sequentially (for debugging)""" # Configs - output_basedir = Path("bulk_data/nmf_rendering_test/") + output_basedir = Path("bulk_data/nmf_rendering_new/") # TODO: change back to *_test input_timestep = 0.01 sim_timestep = 0.0001 - trial_paths = [ - # For testing: change this list to limit the scope - Path("bulk_data/kinematic_prior/aymanns2022/trials/BO_Gal4_fly1_trial001.pkl") - ] + # trial_paths = [ + # # For testing: change this list to limit the scope + # Path("bulk_data/kinematic_prior/aymanns2022/trials/BO_Gal4_fly1_trial001.pkl") + # ] + trial_paths = sorted( # TODO: revert + Path("bulk_data/kinematic_prior/aymanns2022/trials/").glob("*.pkl") + ) # Limit scope of simulation as this is only for testing # Don't make `max_sim_steps_per_segment` too small; otherwise no subsegment-level # postprocessing will be performed - max_segments_per_trial = 2 - max_sim_steps_per_segment = 3000 + max_segments_per_trial = None # 2 # TODO: revert + max_sim_steps_per_segment = None # 3000 # TODO: revert # Process each trial for trial_path in trial_paths: @@ -187,7 +191,7 @@ def run_sequentially_for_testing(): get_hardware_availability(check_gpu=False, print_results=True) # Run the CLI - tyro.cli(simulate_using_kinematic_prior) + tyro.cli(simulate_using_kinematic_prior) # TODO: enable CLI - # # Run everything sequentially (for debugging) + # Run everything sequentially (for debugging) # TODO: disable testing # run_sequentially_for_testing() diff --git a/src/poseforge/neuromechfly/scripts/visualize_meshes.py b/src/poseforge/neuromechfly/scripts/visualize_meshes.py index a965440..cb6554e 100644 --- a/src/poseforge/neuromechfly/scripts/visualize_meshes.py +++ b/src/poseforge/neuromechfly/scripts/visualize_meshes.py @@ -1,75 +1,125 @@ import numpy as np import pandas as pd import pyvista as pv -import flygym +import h5py from xml.etree import ElementTree from pathlib import Path from scipy.spatial.transform import Rotation +from scipy.linalg import rq +from poseforge.neuromechfly.constants import legs, all_segment_names_per_leg -df = pd.read_pickle( - "bulk_data/nmf_rendering_enhanced/BO_Gal4_fly1_trial001/segment_000/subsegment_000/processed_kinematic_states.pkl" + +# Define paths +subsegment_dir = Path( + "bulk_data/nmf_rendering/BO_Gal4_fly1_trial001/segment_000/subsegment_000" ) -flygym_data_dir = Path(flygym.__file__).parent / "data" +sim_data_path = subsegment_dir / "processed_simulation_data.h5" +flygym_data_dir = Path("~/projects/flygym/flygym").expanduser() / "data" nmf_mesh_dir = flygym_data_dir / "mesh" mjcf_path = flygym_data_dir / "mjcf/neuromechfly_seqik_kinorder_ypr.xml" -sides = "LR" -positions = "FMH" -links = [ - "Coxa", - "Femur", - "Tibia", - "Tarsus1", - "Tarsus2", - "Tarsus3", - "Tarsus4", - "Tarsus5", -] - +# Load NeuroMechFly model mjcf_tree = ElementTree.parse(mjcf_path) worldbody = mjcf_tree.find("worldbody") body_attributes = {body.attrib["name"]: body.attrib for body in worldbody.iter("body")} -frame_idx = 10 -entry = df.loc[frame_idx] +# Load simulation data +with h5py.File(sim_data_path, "r") as f: + all_seg_pos_global = f["raw/body_segment_states/pos_global"][:] + all_seg_quat_global = f["raw/body_segment_states/quat_global"][:] + all_cam_matrices = f["raw/camera_matrix"][:] + all_seg_names = list(f["raw/body_segment_states"].attrs["keys"]) -plotter = pv.Plotter() - +n_frames = all_seg_pos_global.shape[0] segments_to_include = [ - f"{side}{pos}{link}" for side in sides for pos in positions for link in links + f"{leg}{seg}" for leg in legs for seg in all_segment_names_per_leg ] segments_to_include += ["Thorax"] -meshes = [] -for key in segments_to_include: - translation = entry[f"body_seg_pos_global_{key}"] - quaternion = entry[f"body_seg_quat_global_{key}"] - placement_transform = np.eye(4) - placement_transform[:3, :3] = Rotation.from_quat(quaternion).as_matrix() - placement_transform[:3, 3] = translation - mesh_file = nmf_mesh_dir / f"{key}.stl" - mesh = pv.read(mesh_file) +# Load original meshes once (before any transformations) +original_meshes = {} +for seg_name in segments_to_include: + mesh_file = nmf_mesh_dir / f"{seg_name}.stl" + original_meshes[seg_name] = pv.read(mesh_file) - # Scale - scale_transform = np.eye(4) - np.fill_diagonal(scale_transform, [1000, 1000, 1000, 1]) - mesh.transform(scale_transform) +# Create plotter +plotter = pv.Plotter() +plotter.set_background("black") +plotter.show_axes() - # Apply transformation based on MuJoCo state - placement_transform = np.eye(4) - rotation_object = Rotation.from_quat(quaternion, scalar_first=True) - placement_transform[:3, :3] = rotation_object.as_matrix() - placement_transform[:3, 3] = translation - mesh.transform(placement_transform) +# Add all meshes to plotter initially +current_meshes = {} +for seg_name in segments_to_include: + current_meshes[seg_name] = original_meshes[seg_name].copy() + plotter.add_mesh( + current_meshes[seg_name], + show_edges=False, + name=seg_name, + smooth_shading=True + ) - meshes.append(mesh) - plotter.add_mesh(mesh, show_edges=False, name=key, smooth_shading=True) +# Current frame tracker +current_frame = [0] - # print(f"Loading {key}:\n\ttranslation={translation}\n\tquaternion={quaternion}") +def update_frame(): + """Update all meshes to the current frame""" + frame_idx = current_frame[0] + cam_matrix = all_cam_matrices[frame_idx, :, :] + + # Compute camera transformation once per frame + cam_intrinsics, cam_rotation = rq(cam_matrix[:, :3]) + _sign_multiplier = np.diag(np.sign(np.diag(cam_intrinsics))) + cam_intrinsics = cam_intrinsics @ _sign_multiplier + cam_rotation = _sign_multiplier @ cam_rotation + cam_translation = np.linalg.inv(cam_intrinsics) @ cam_matrix[:, 3] + + transform_world2cam = np.eye(4) + transform_world2cam[:3, :3] = cam_rotation + transform_world2cam[:3, 3] = cam_translation + transform_cam2world = np.linalg.inv(transform_world2cam) + + # Update each segment mesh + for seg_name in segments_to_include: + seg_idx = all_seg_names.index(seg_name) + translation = all_seg_pos_global[frame_idx, seg_idx, :] + quaternion = all_seg_quat_global[frame_idx, seg_idx, :] + + # Start with original mesh + mesh = original_meshes[seg_name].copy() + + # Scale + scale_transform = np.eye(4) + np.fill_diagonal(scale_transform, [1000, 1000, 1000, 1]) + mesh = mesh.transform(scale_transform, inplace=False) + + # Apply MuJoCo state transformation + placement_transform = np.eye(4) + rotation_object = Rotation.from_quat(quaternion, scalar_first=True) + placement_transform[:3, :3] = rotation_object.as_matrix() + placement_transform[:3, 3] = translation + mesh = mesh.transform(placement_transform, inplace=False) + + # Transform to camera coordinates + mesh = mesh.transform(transform_cam2world, inplace=False) + + # Update the mesh points in place + current_meshes[seg_name].points[:] = mesh.points + + # Update frame counter + current_frame[0] = (current_frame[0] + 1) % n_frames + + # Update title to show current frame + plotter.add_text(f"Frame: {frame_idx}/{n_frames}", name="frame_counter", position="upper_left") -plotter.set_background("black") + +# Initialize first frame +update_frame() plotter.reset_camera() -plotter.show_axes() -plotter.show() # or plotter.show(screenshot='screenshot.png') + +# Add timer callback for animation (30 fps) +# The callback receives a step argument, so we need to accept it +plotter.add_timer_event(max_steps=n_frames, duration=int(1000/30), callback=lambda step: update_frame()) + +plotter.show() \ No newline at end of file diff --git a/src/poseforge/neuromechfly/simulate.py b/src/poseforge/neuromechfly/simulate.py index 75642e0..144b631 100644 --- a/src/poseforge/neuromechfly/simulate.py +++ b/src/poseforge/neuromechfly/simulate.py @@ -13,39 +13,12 @@ from flygym.preprogrammed import all_leg_dofs from poseforge.neuromechfly.data import interpolate_trajectories -from poseforge.neuromechfly.constants import parse_nmf_joint_name - - -# Define color combo by body segment -color_by_link = { - "Coxa": "cyan", - "Femur": "yellow", - "Tibia": "blue", - "Tarsus": "green", - "Antenna": "magenta", - "Thorax": "gray", -} -color_by_kinematic_chain = { - "LF": "red", # left front leg - "LM": "green", # left mid leg - "LH": "blue", # left hind leg - "RF": "cyan", # right front leg - "RM": "magenta", # right mid leg - "RH": "yellow", # right hind leg - "L": "red", # left antenna - "R": "green", # right antenna - "Thorax": "white", # thorax -} -color_palette = { - "red": (1, 0, 0, 1), - "green": (0, 1, 0, 1), - "blue": (0, 0, 1, 1), - "yellow": (1, 1, 0, 1), - "magenta": (1, 0, 1, 1), - "cyan": (0, 1, 1, 1), - "gray": (0.4, 0.4, 0.4, 1), - "white": (1, 1, 1, 1), -} +from poseforge.neuromechfly.constants import ( + parse_nmf_joint_name, + color_by_link, + color_by_kinematic_chain, + color_palette, +) class SpotlightArena(FlatTerrain):