-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathrun_model.py
More file actions
1291 lines (1088 loc) · 68.5 KB
/
run_model.py
File metadata and controls
1291 lines (1088 loc) · 68.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#!/usr/bin/env python3
"""
Command-line script for evaluating a krakencoder model checkpoint on new data
Must provide a checkpoint file (.pt) and input transform files (*ioxfm*.npy), unless model was trained on non-transformed data
Outputs can include predicted connectomes of different types, latent vectors, and/or just heatmaps and records of model performance metrics on new data
Main functions it calls, after parsing args:
- krakencoder/model.py/Krakencoder.load_checkpoint()
- krakencoder/data.py/load_hcp_data()
or data.py/load_input_data()
- krakencoder/data.py/generate_transformer()
using information from specified saved transformer files
- krakencoder/model.py/Krakencoder()
- to run forward predictions on data
Examples:
#1. Evaluate checkpoint on held-out "test" split from HCP data, using precomputed PCA input transformers, and save performance metrics and heatmaps
python run_model.py --inputdata test --subjectfile subject_splits_958subj_683train_79val_196test_retestInTest.mat \
--checkpoint kraken_chkpt_SCFC_20240406_022034_ep002000.pt \
--inputxform kraken_ioxfm_SCFC_coco439_993subj_pc256_25paths_710train_20220527.npy \
kraken_ioxfm_SCFC_fs86_993subj_pc256_25paths_710train_20220527.npy \
kraken_ioxfm_SCFC_shen268_993subj_pc256_25paths_710train_20220527.npy \
--newtrainrecord hcp_20240406_022034_ep002000_mse.w1000_newver_test.mat \
--heatmap hcp_20240406_022034_ep002000_mse.w1000_newver_test.png \
--heatmapmetrics top1acc topNacc avgrank avgcorr_resid \
--fusion --fusioninclude fusion=all fusionSC=SC fusionFC=FC --fusionnoself --fusionnoatlas
#2. To generate predicted connectomes, add:
--outputname all --output mydata_20240406_022034_ep002000_{output}.mat
# which will generate a file of predictions for each connectivity flavor in the model, named like:
# mydata_20240406_022034_ep002000_FCcorr_shen268_hpf_FC.mat
# which will contain the predicted FCcorr_shen268_hpf_FC from every input type provided
#3. To generate predicted connectomes from only fusion inputs:
--output 'mydata_20240406_022034_ep002000_{input}.mat' --fusion --onlyfusioninputs
# or for multiple fusion types (i.e., fusionSC=only using SC inputs):
--output 'mydata_20240406_022034_ep002000_{input}.mat' --fusion --onlyfusioninputs \
--fusioninclude fusion=all fusionSC=SC fusionFC=FC --fusionnoself --fusionnoatlas
#4. To generate the latent space outputs for this input data, add:
--outputname encoded --output mydata_20240406_022034_ep002000_{output}.mat
#5. To use your own non-HCP input data, provide a .mat file for each input type, with a 'data' field containing the [subjects x region x region]
# connectivity data. Then include the filenames and connectivity names using:
--inputdata '[SCsdstream_fs86_volnorm]=mydata_fs86_sdstream_volnorm.mat' \
'[SCifod2act_fs86_volnorm]=mydata_fs86_ifod2act_volnorm.mat' \
'[SCsdstream_shen268_volnorm]=mydata_shen268_sdstream_volnorm.mat' \
'[SCifod2act_shen268_volnorm]=mydata_shen268_ifod2act_volnorm.mat' \
'[SCsdstream_coco439_volnorm]=mydata_coco439_sdstream_volnorm.mat' \
'[SCifod2act_coco439_volnorm]=mydata_coco439_ifod2act_volnorm.mat' \
--adaptmode meanfit+meanshift
#where "--adaptmode meanfit+meanshift" uses a minimal approach to domain shift: it linearly maps the population mean of your input data to the
# population mean of the training data. This is a simple way to adapt the input data to the model, but may not be sufficient for all cases.
"""
from krakencoder.model import *
from krakencoder.train import *
from krakencoder.data import *
from krakencoder.utils import *
from krakencoder.merge import *
from krakencoder.fetch import *
from scipy.io import loadmat, savemat
import re
import os
import sys
import argparse
import warnings
def argument_parse_newdata(argv):
    """Build the command-line parser for model evaluation and parse ``argv``.

    Args:
        argv: list of command-line tokens (without the program name).

    Returns:
        The parsed namespace after it has been passed through
        ``clean_args(args, arg_defaults)`` (defined elsewhere in the project).
    """
    # Defaults for list-valued options are kept out of argparse and handed to
    # clean_args instead: with action='append', an argparse-level default list
    # would simply be appended to rather than replaced.
    list_option_defaults = {
        'fusion_include': [],
        'input_names': [],
        'output_names': [],
        'input_data_file': [],
        'input_transform_file': ["auto"],
        'heatmap_metrictype_list': ['top1acc', 'topNacc', 'avgrank', 'avgcorr_resid'],
        'input_json_load_types': ['all'],
    }

    parser = argparse.ArgumentParser(
        description='Evaluate krakencoder checkpoint',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # ---- model / data / output selection ----
    parser.add_argument(
        '--checkpoint',
        dest='checkpoint',
        nargs='*',
        help='Checkpoint file (.pt)',
    )
    parser.add_argument(
        '--innercheckpoint',
        dest='innercheckpoint',
        help='Inner checkpoint file (.pt): eg: if checkpoint is an adaptation layer',
    )
    parser.add_argument(
        '--trainrecord',
        dest='trainrecord',
        default='auto',
        help='trainrecord.mat file',
    )
    parser.add_argument(
        '--inputxform',
        action='append',
        dest='input_transform_file',
        nargs='*',
        help='Precomputed transformer file (.npy)',
    )
    parser.add_argument(
        '--inputdata',
        action='append',
        dest='input_data_file',
        nargs='*',
        help='.mat file(s) containing input data to transform (instead of default HCP validation set). Can be "name=file"',
    )
    parser.add_argument(
        '--output',
        dest='output',
        help='file to save model outputs. Can include "{input}" and/or "{output}" in name to save separate files for each input/output combo (or just group inputs and outputs)',
    )
    parser.add_argument(
        '--output_squaremats', '--output_square',
        action='store_true',
        dest='output_squaremats',
        help='Convert output matrices from triangular to square form (Currently k=1 with NaN diag only)',
    )
    parser.add_argument(
        '--inputname', '--inputnames',
        action='append',
        dest='input_names',
        nargs='*',
        help='Name(s) of input data flavors (eg: FCcorr_fs86_hpf, SCsdstream_fs86, encoded)',
    )
    parser.add_argument(
        '--outputname', '--outputnames',
        action='append',
        dest='output_names',
        nargs='*',
        help='List of data flavors from model to predict (default=all)',
    )

    # ---- domain adaptation ----
    parser.add_argument(
        '--adaptmode',
        dest='adapt_mode',
        default='none',
        help='How do adapt new data to fit model (default: none)',
    )
    parser.add_argument(
        '--adaptmodesplitname',
        dest='adapt_mode_split_name',
        default='train',
        help='Which subject split to use for domain adaptation (default: train)',
    )
    parser.add_argument(
        '--adaptmodesourcedata',
        action='append',
        dest='adapt_mode_data_file',
        nargs='*',
        help='.mat file(s) containing input data to use for domain adaptation (instead of input mean). Can be "name=file"',
    )

    # ---- fusion ----
    parser.add_argument(
        '--fusioninclude',
        action='append',
        dest='fusion_include',
        nargs='*',
        help='inputnames to include in fusion average',
    )
    parser.add_argument(
        '--fusion',
        action='store_true',
        dest='fusion',
        help='fusion mode eval',
    )
    parser.add_argument(
        '--fusionnoself', '--fusion.noself',
        action='store_true',
        dest='fusion_noself',
        help='Add .noself version to fusion outputs (excludes latent from same input)',
    )
    parser.add_argument(
        '--fusionnoatlas', '--fusion.noatlas',
        action='store_true',
        dest='fusion_noatlas',
        help='Add .noatlas version to fusion outputs (excludes latent from same atlas)',
    )
    parser.add_argument(
        '--fusionnorm',
        action='store_true',
        dest='fusionnorm',
        help='re-normalize latent vectors after averaging',
    )
    parser.add_argument(
        '--onlyfusioninputs',
        action='store_true',
        dest='only_fusion_inputs',
        help='Only predict outputs from fusion inputs (not from individual flavors)',
    )

    # ---- json-driven configuration ----
    parser.add_argument(
        '--inputjson',
        dest='input_json',
        nargs='*',
        help='.json file containing filenames for checkpoints, xforms, and input data for each flavor',
    )
    parser.add_argument(
        '--inputjson_search_directories',
        dest='input_json_search_directories',
        nargs='*',
        help='List of directories to search for data files from input json',
    )
    parser.add_argument(
        '--inputjson_fetch_files',
        action='store_true',
        dest='input_json_fetch_files',
        help='Fetch files from input json listing if possible',
    )
    parser.add_argument(
        '--inputjson_load_types',
        dest='input_json_load_types',
        nargs='*',
        help='Data to load from json. (["checkpoint","xform","data"])',
    )

    # ---- miscellaneous reporting options ----
    misc = parser.add_argument_group('Misc')
    misc.add_argument(
        '--heatmap',
        dest='heatmap_file',
        help='Save heatmap image',
    )
    misc.add_argument(
        '--heatmapmetrics',
        dest='heatmap_metrictype_list',
        nargs='*',
        help='List of metric types for heatmap',
    )
    misc.add_argument(
        '--heatmap_colormap',
        dest='heatmap_colormap',
        default='magma',
        help='Colormap name for heatmap',
    )
    misc.add_argument(
        '--heatmapcsv',
        dest='heatmap_csv_file',
        help='Save heatmap csv file',
    )
    misc.add_argument(
        '--newtrainrecord',
        dest='new_train_record_file',
        help='Save a "fake" trainrecord file',
    )
    misc.add_argument(
        '--subjectsplitfile', '--subjectfile',
        dest='subject_split_file',
        help='.mat file containing pre-saved "subjects","subjidx_train","subjidx_val","subjidx_test" fields (or "trainrecord" to use from training record)',
    )
    misc.add_argument(
        '--subjectsplitname',
        dest='subject_split_name',
        help='Which data split to evaluate: "all", "train", "test", "val", "retest", etc... (overrides --inputdata for hardcoded HCP)',
    )
    misc.add_argument(
        '--metricincludeallsubj',
        action='store_true',
        dest='metric_include_allsubj',
        help='Include per-subject accuracy in outputs (default: only include summary metrics)',
    )

    # ---- options intended for testing ----
    testing = parser.add_argument_group('Testing options')
    testing.add_argument(
        '--savetransformedinputs',
        action='store_true',
        dest='save_transformed_inputs',
        help='Transform inputs and save them as "encoded" values (for testing PCA/transformed space data)',
    )
    testing.add_argument(
        '--untransformedoutputs',
        action='store_true',
        dest='save_untransformed_outputs',
        help='Keep outputs in PCA/transformed space (for testing)',
    )
    testing.add_argument(
        '--hackcortex',
        action='store_true',
        dest='hack_cortex',
        help='Hack to only use cortex for eval (for fs86 and coco439) (for testing)',
    )
    testing.add_argument(
        '--ccmat',
        action='store_true',
        dest='ccmat',
        help='Save full SUBJxSUBJ ccmat for every prediction path (large record file)',
    )
    testing.add_argument(
        '--usetrainmean',
        action='store_true',
        dest='use_train_mean',
        help='Use training mean for demean if available in ioxform',
    )

    version_banner = 'Krakencoder v{version}'.format(version=get_version(include_date=True))
    parser.add_argument('--version', action='version', version=version_banner)

    parsed = parser.parse_args(argv)
    return clean_args(parsed, list_option_defaults)
def search_flavors(searchstring_list, full_list, return_index=False):
    """Expand user-supplied search strings into flavor names from ``full_list``.

    Each search string is resolved by the first rule that applies:

    1. ``'all'`` (case-insensitive): every entry of ``full_list``.
    2. ``'encoded'``, ``'fusion'`` or ``'transformed'`` (case-insensitive):
       the lower-cased keyword itself, whether or not it is in ``full_list``.
    3. An exact entry of ``full_list``: that entry.
    4. A string whose ``canonical_data_flavor`` form is in ``full_list``:
       the canonical form.
    5. ``'sc'`` (case-insensitive): entries starting with ``"SC"`` or
       containing ``"sdstream"`` or ``"ifod2act"``.
    6. ``'fc'`` (case-insensitive): entries starting with ``"FC"``.
    7. Anything else: entries containing the string as a substring.

    An empty ``searchstring_list`` selects all of ``full_list``.

    Args:
        searchstring_list: a single search string or a sequence of them.
        full_list: sequence of available flavor names.
        return_index: when true, also return, for each result, the position
            in ``searchstring_list`` of the search string that first produced
            it (all zeros when ``searchstring_list`` is empty).

    Returns:
        A list of unique names (as ``str``) in first-occurrence order, or a
        ``(names, indices)`` tuple when ``return_index`` is true.
    """
    if isinstance(searchstring_list, str):
        searchstring_list = [searchstring_list]

    if len(searchstring_list) == 0:
        matches = list(full_list)
        match_index = [0] * len(matches)
    else:
        matches = []
        match_index = []
        for i, s in enumerate(searchstring_list):
            # Canonicalization is best-effort: a failure just means this rule
            # cannot match. Catch Exception rather than everything so that
            # KeyboardInterrupt / SystemExit still propagate.
            try:
                s_canonical = canonical_data_flavor(s)
            except Exception:
                s_canonical = None

            s_lower = s.lower()
            if s_lower == 'all':
                s_new = list(full_list)
            elif s_lower in ('encoded', 'fusion', 'transformed'):
                s_new = [s_lower]
            elif s in full_list:
                s_new = [s]
            elif s_canonical is not None and s_canonical in full_list:
                s_new = [s_canonical]
            elif s_lower == 'sc':
                s_new = [intype for intype in full_list
                         if intype.startswith("SC") or "sdstream" in intype or "ifod2act" in intype]
            elif s_lower == 'fc':
                s_new = [intype for intype in full_list if intype.startswith("FC")]
            else:
                s_new = [intype for intype in full_list if s in intype]

            matches += s_new
            match_index += [i] * len(s_new)

    # Keep only the first occurrence of each name, preserving original order.
    seen = set()
    new_list = []
    new_list_index = []
    for name, idx in zip(matches, match_index):
        if name in seen:
            continue
        seen.add(name)
        new_list.append(str(name))
        new_list_index.append(idx)

    if return_index:
        return new_list, new_list_index
    return new_list
def run_model_on_new_data(argv=None):
if argv is None:
argv=sys.argv[1:]
#read in command-line inputs
args=argument_parse_newdata(argv)
ptfile=args.checkpoint
innerptfile=args.innercheckpoint
recordfile=args.trainrecord
fusionmode=args.fusion
fusionnorm=args.fusionnorm
fusion_noself=args.fusion_noself
fusion_noatlas=args.fusion_noatlas
input_fusionmode_names=args.fusion_include
only_fusion_mode=args.only_fusion_inputs
do_save_transformed_inputs=args.save_transformed_inputs
outputs_in_model_space=args.save_untransformed_outputs
outfile = args.output
do_output_squaremats = args.output_squaremats
input_transform_file_list=args.input_transform_file
input_conntype_list=args.input_names
output_conntype_list=args.output_names
input_file_list=args.input_data_file
heatmapfile=args.heatmap_file
heatmap_metrictype_list=args.heatmap_metrictype_list
heatmap_colormap=args.heatmap_colormap
heatmap_csv_file=args.heatmap_csv_file
new_train_recordfile=args.new_train_record_file
input_subject_split_file=args.subject_split_file
input_subject_split_name=args.subject_split_name
metric_include_extra_resid=False
metric_include_allsubj=args.metric_include_allsubj
input_json=args.input_json
input_json_search_directories=args.input_json_search_directories
input_json_fetch_files=args.input_json_fetch_files
input_json_load_types=args.input_json_load_types
adapt_mode=args.adapt_mode
adapt_mode_split_name=args.adapt_mode_split_name
adapt_mode_file_list=args.adapt_mode_data_file
save_ccmat_in_record=args.ccmat
do_use_train_mean=args.use_train_mean
hack_cortex=args.hack_cortex
if hack_cortex:
trimask86=np.triu_indices(86,1)
trimask439=np.triu_indices(439,1)
cortex86=np.ones((86,86),dtype=bool)
cortex439=np.ones((439,439),dtype=bool)
cortex86[:18,:]=False
cortex86[:,:18]=False
cortex439[:81,:]=False
cortex439[:,:81]=False
hack_cortex_mask={'fs86':cortex86[trimask86],'coco439':cortex439[trimask439]}
if adapt_mode.lower()=='none':
adapt_mode=None
##############
#check for json option, and override checkpoints, xforms, inputfiles, etc...
if input_json:
if 'all' in input_json_load_types:
input_json_load_types=["checkpoint","xform","data"]
if len(input_conntype_list)==0:
raise Exception("Must specify input flavors with --inputname when using --inputjson")
print("Parsing options for input flavors from json file(s):",input_json)
json_directory_search_list=[model_data_folder()]
if input_json_search_directories is not None:
json_directory_search_list+=[d for d in input_json_search_directories]
flavor_input_info={}
for j in input_json:
tmpinfo=load_flavor_database(j, directory_search_list=json_directory_search_list, override_abs_path=True, fields_to_check=['checkpoint','xform','data'])
for k in tmpinfo:
flavor_input_info[k]=tmpinfo[k]
if ptfile is None:
ptfile=[]
#absolute paths for files that exist, otherwise just use the filename
ptfile=[os.path.abspath(x) if os.path.exists(os.path.abspath(x)) else x for x in ptfile]
if input_transform_file_list is None or all([x=='auto' for x in input_transform_file_list]):
input_transform_file_list=[]
input_transform_file_list=[os.path.abspath(x) if os.path.exists(os.path.abspath(x)) else x for x in input_transform_file_list]
if input_conntype_list[0].lower() == 'all':
input_conntype_list=list(flavor_input_info.keys())
elif input_conntype_list[0].lower() == 'sc':
input_conntype_list=[x for x in list(flavor_input_info.keys()) if x.startswith("SC")]
elif input_conntype_list[0].lower() == 'fc':
input_conntype_list=[x for x in list(flavor_input_info.keys()) if x.startswith("FC")]
if output_conntype_list[0].lower() == 'all':
output_conntype_list=list(flavor_input_info.keys())
elif output_conntype_list[0].lower() == 'sc':
output_conntype_list=[x for x in list(flavor_input_info.keys()) if x.startswith("SC")]
elif output_conntype_list[0].lower() == 'fc':
output_conntype_list=[x for x in list(flavor_input_info.keys()) if x.startswith("FC")]
conntypes_to_find=unique_preserve_order(np.concatenate([input_conntype_list,output_conntype_list]))
flavors_with_missing_files=[]
for conntype in conntypes_to_find:
if not conntype in flavor_input_info:
print("%s: No info for flavor in json file:" % (conntype),input_json)
continue
tmp_found=[]
tmp_chkpt=None
tmp_xform=None
tmp_datafile=None
tmp_all_exist=True
for f in input_json_load_types:
tmp_f=None
if flavor_input_info[conntype][f'{f}_exists']:
tmp_f=flavor_input_info[conntype][f]
tmp_found.append(f'- {f}: '+tmp_f)
elif input_json_fetch_files and flavor_input_info[conntype][f'{f}_fetchable']:
tmp_f=fetch_model_data(files_to_fetch=flavor_input_info[conntype][f],force_download=False)
tmp_found.append(f'- {f}: '+tmp_f)
elif flavor_input_info[conntype][f'{f}_fetchable']:
tmp_found.append(f'x (fetchable) {f}: '+flavor_input_info[conntype][f])
else:
tmp_found.append(f'x missing {f}: '+flavor_input_info[conntype][f])
if tmp_f is None:
tmp_all_exist=False
if f == 'checkpoint':
tmp_chkpt=tmp_f
elif f == 'xform':
tmp_xform=tmp_f
elif f == 'data':
tmp_datafile=tmp_f
if tmp_all_exist:
tmp_data=f'{conntype}={tmp_datafile}'
if tmp_chkpt not in ptfile:
ptfile.append(tmp_chkpt)
if tmp_xform not in input_transform_file_list:
input_transform_file_list.append(tmp_xform)
if tmp_data not in input_file_list and tmp_datafile not in input_file_list:
input_file_list.append(tmp_data)
print("\n%s: Found all files for flavor in json file:\n" % (conntype),"\n".join(["\t"+x for x in tmp_found]))
else:
print("\n%s: Necessary files not found for flavor in json file:\n" % (conntype),"\n".join(["\t"+x for x in tmp_found]))
flavors_with_missing_files.append(conntype)
if len(flavors_with_missing_files)>0:
raise FileNotFoundError("Necessary files not found for flavors in json file: ", flavors_with_missing_files)
if not 'checkpoint' in input_json_load_types:
ptfile=args.checkpoint
if not 'xform' in input_json_load_types:
input_transform_file_list=args.input_transform_file
##############
#load model checkpoint
warnings.filterwarnings("ignore", category=UserWarning, message="CUDA initialization")
outerptfile=None
if innerptfile is not None:
outerptfile=ptfile
ptfile=innerptfile
print("Loading inner model from %s" % (ptfile))
#download model data if necessary
for i,p in enumerate(ptfile):
p=fetch_model_data_if_needed(p)
if os.path.exists(p):
ptfile[i]=p
if outerptfile is not None:
for i,p in enumerate(outerptfile):
p=fetch_model_data_if_needed(p)
if os.path.exists(p):
outerptfile[i]=p
ptfile_list=[p for p in ptfile]
if len(ptfile)==1:
ptfile=ptfile[0]
net, checkpoint=Krakencoder.load_checkpoint(ptfile)
else:
net, checkpoint=merge_model_files(checkpoint_filename_list=ptfile, canonicalize_input_names=False)
print("Merged model info:")
print_merged_model(checkpoint)
ptfile='merged:'+ ','.join(ptfile)
##########
#handle special case for OLD checkpoints before we updated the connectivity flavors
#if the checkpoint uses the old style of flavor names, convert them to the new style
try:
all_old_names=all([canonical_data_flavor_OLD(x)==x for x in checkpoint['input_name_list']])
except:
all_old_names=False
if all_old_names:
print("This checkpoint uses old style of flavor names")
checkpoint['input_name_list']=[canonical_data_flavor(x) for x in checkpoint['input_name_list']]
checkpoint['training_params']['trainpath_names']=['%s->%s' %
(canonical_data_flavor(x.split("->")[0]),canonical_data_flavor(x.split("->")[1]))
for x in checkpoint['training_params']['trainpath_names']]
checkpoint['training_params']['trainpath_names_short']=['%s->%s' %
(canonical_data_flavor(x.split("->")[0]),canonical_data_flavor(x.split("->")[1]))
for x in checkpoint['training_params']['trainpath_names_short']]
########
conn_names=checkpoint['input_name_list']
trainpath_pairs = [[conn_names[i],conn_names[j]] for i,j in zip(checkpoint['trainpath_encoder_index_list'], checkpoint['trainpath_decoder_index_list'])]
#############
#for fusion mode, check whether we are doing multiple fusion types
if any(['=' in b for b in input_fusionmode_names]):
orig_input_fusionmode_names=input_fusionmode_names
input_fusionmode_names={}
for b in orig_input_fusionmode_names:
bname=b.split("=")[0]
blist=flatlist(b.split("=")[-1].split(","))
input_fusionmode_names[bname]=blist
else:
bname='fusion'
blist=flatlist([b.split(",") for b in input_fusionmode_names])
input_fusionmode_names={bname:blist}
#note: don't actually use trainrecord during model evaluation
#checkpoint includes all info about data flavors and model design
#but doesn't contain any info about loss functions, training schedule, etc...
#consider: in save_checkpoint, just add the trainrecord keys that aren't the
#big per-epoch loss values (ie: we have networkinfo[], make traininfo[] with those params)
#note: we might use this to get training/testing/val subject info
if recordfile == "auto":
recordfile=ptfile_list[0].replace("_checkpoint_","_trainrecord_")
recordfile=recordfile.replace("_chkpt_","_trainrecord_")
recordfile=re.sub(r"_(epoch|ep)[0-9]+\.pt$",".mat",recordfile)
if len(input_transform_file_list)>0 and input_transform_file_list[0] == "auto":
input_transform_file=ptfile_list[0].replace("_checkpoint_","_iox_")
input_transform_file=input_transform_file.replace("_chkpt_","_ioxfm_")
input_transform_file=re.sub(r"_(epoch|ep)[0-9]+\.pt$",".npy",input_transform_file)
input_transform_file_list=[input_transform_file]
if input_subject_split_file and input_subject_split_file.lower() == "trainrecord":
if os.path.exists(recordfile):
print("Using subject splits from training record: %s" % (recordfile))
input_subject_split_file=recordfile
else:
raise Exception("Training record not found. Cannot use 'trainrecord' subject split option. %s" % (recordfile))
input_subject_splits=None
subjects_train=None
subjects_val=None
subjects_test=None
subjects_to_eval=None #this might end iup being subjects_test, subjects_val, etc...
if input_subject_split_file:
input_subject_split_file=fetch_model_data_if_needed(input_subject_split_file)
subjects_to_eval_splitname='all'
if input_subject_split_name:
subjects_to_eval_splitname=input_subject_split_name.lower()
print("Loading subject splits from %s" % (input_subject_split_file))
input_subject_splits=loadmat(input_subject_split_file,simplify_cells=True)
for f in ["subjects", "subjidx_train", "subjidx_val", "subjidx_test"]:
if not f in input_subject_splits:
input_subject_splits[f]=[]
print("\t%d %s" % (len(input_subject_splits[f]),f))
subjects=input_subject_splits['subjects']
subjects=clean_subject_list(subjects)
subjects_train=[s for i,s in enumerate(subjects) if i in input_subject_splits['subjidx_train']]
subjects_val=[s for i,s in enumerate(subjects) if i in input_subject_splits['subjidx_val']]
subjects_test=[s for i,s in enumerate(subjects) if i in input_subject_splits['subjidx_test']]
if subjects_to_eval_splitname == 'all':
subjects_to_eval=subjects
elif 'subjidx_' + subjects_to_eval_splitname in input_subject_splits:
subjects_to_eval=[subjects[i] for i in input_subject_splits['subjidx_' + subjects_to_eval_splitname]]
else:
raise Exception("Invalid subject split name: %s" % (subjects_to_eval_splitname))
if len(input_file_list)>0:
if input_file_list[0] in ['all','test','train','val','retest']:
#use HCP data
pass
elif all([x.endswith(".zip") for x in input_file_list]):
tmp_inputfiles=input_file_list
input_conntype_list=[]
input_file_list=[]
for x in tmp_inputfiles:
tmpdata,tmpinfo=load_data_zip(x, just_read_info=True)
input_conntype_list+=[canonical_data_flavor(k) for k in tmpdata.keys()]
input_file_list+=[x for k in tmpdata.keys()] #just repeat this file for all
elif(all(["=" in x for x in input_file_list])):
tmp_inputfiles=input_file_list
input_conntype_list=[]
input_file_list=[]
for x in tmp_inputfiles:
input_conntype_list+=[x.split("=")[0]]
input_file_list+=[x.split("=")[-1]]
else:
#try to figure out conntypes from filenames
print("--inputname not provided. Guessing input type from filenames:")
input_conntype_list=[]
for x in input_file_list:
xc=canonical_data_flavor(justfilename(x))
input_conntype_list+=[xc]
print(" %s = %s" % (xc,x))
#if input "file list" is an HCP data split name, and we provided a subject split name argument, override the "file list" argument
if len(input_file_list) > 0 and input_file_list[0].lower() in ['all','test','train','val','retest']:
if input_subject_split_name in ['all','test','train','val','retest']:
input_file_list[0]=input_subject_split_name.lower()
######################
# parse user-specified input data type
#conn_names = full list of input names from model checkpoint
#input_conntype_canonical=canonical_data_flavor(input_conntype)
#input_conntype_list=[input_conntype_canonical]
input_group_dict={x.split('@')[0]:x.split('@')[1] for x in input_conntype_list if '@' in x}
input_conntype_list=[x.split('@')[0] for x in input_conntype_list]
input_conntype_list, input_conntype_idx =search_flavors(input_conntype_list, conn_names, return_index=True)
if len(input_file_list) > 0 and not input_file_list[0] in ['all','test','train','val','retest']:
#if we provided multiple input files, the conntype list may have been reordered during search_flavors
# or only a subset were used (if the model checkpoint only accepts certain inputs),
# so reorder the input files
input_file_list = [input_file_list[i] for i in input_conntype_idx]
print("Final conntype = filename mapping:")
for xc,x in zip(input_conntype_list,input_file_list):
print(" %s = %s" % (xc,x))
do_self_only=False
if "none" in [x.lower() for x in output_conntype_list]:
output_conntype_list=[]
elif "self" in [x.lower() for x in output_conntype_list]:
do_self_only=True
output_conntype_list=input_conntype_list.copy()
else:
output_conntype_list=search_flavors(output_conntype_list,conn_names)
#if user requested TRANSFORMED INPUTS (eg: PC-space inputs) dont produce any predicted outputs
if do_save_transformed_inputs:
#only output 'transformed' for this option
output_conntype_list=['transformed']
print("Input types (%d):" % (len(input_conntype_list)), input_conntype_list)
print("Output types (%d):" % (len(output_conntype_list)), output_conntype_list)
if fusionmode:
fusionmode_names_dict={k:search_flavors(v, input_conntype_list) for k,v in input_fusionmode_names.items()}
original_fusionmode_names_dict=fusionmode_names_dict.copy()
if fusion_noself:
#note: include all flavors as inputs for this, because the selection is done at the decoding stage
new_fusionmode_names_dict=fusionmode_names_dict.copy()
for k in original_fusionmode_names_dict:
new_fusionmode_names_dict[k+".noself"]=fusionmode_names_dict[k].copy()
fusionmode_names_dict=new_fusionmode_names_dict
if fusion_noatlas:
#note: include all flavors as inputs for this, because the selection is done at the decoding stage
new_fusionmode_names_dict=fusionmode_names_dict.copy()
for k in original_fusionmode_names_dict:
new_fusionmode_names_dict[k+".noatlas"]=fusionmode_names_dict[k].copy()
fusionmode_names_dict=new_fusionmode_names_dict
for k in fusionmode_names_dict:
print("fusion mode '%s' input types (%d):" % (k,len(fusionmode_names_dict[k])), fusionmode_names_dict[k])
else:
fusionmode_names_dict={}
#build a list of output files (either consolidated, per input/output or per input->output path)
if only_fusion_mode and fusionmode_names_dict:
eval_input_conntype_list=list(fusionmode_names_dict.keys())
elif fusionmode_names_dict:
eval_input_conntype_list=input_conntype_list.copy()+list(fusionmode_names_dict.keys())
else:
eval_input_conntype_list=input_conntype_list.copy()
#handle optional adaptation data files (to use instead of input mean)
#(needs to be down here because we need the final input_conntype_list)
adapt_mode_file_dict={c:None for c in input_conntype_list}
if adapt_mode_file_list is not None and len(adapt_mode_file_list)>0:
if(all(["=" in x for x in adapt_mode_file_list])):
for x in adapt_mode_file_list:
c=x.split("=")[0]
cf=x.split("=")[-1]
if c.lower()=='all':
#providing a single file that must contain all conntypes
adapt_mode_file_dict={c:cf for c in input_conntype_list}
break
else:
adapt_mode_file_dict[c]=cf
elif len(adapt_mode_file_list)==1 and len(input_conntype_list)==1:
adapt_mode_file_dict[input_conntype_list[0]]=adapt_mode_file_list[0]
else:
raise ValueError("When providing multiple --adaptmodedata files, must use name=file format.")
if any([adapt_mode_file_dict[c] is None for c in adapt_mode_file_dict]):
raise ValueError("When providing --adaptmodedata files, must have one for each input conntype")
#handle some shortcuts for the input/output filenames
#when no outfile was given nothing is saved, but a placeholder template is still expanded
#below so that the input/output pairing list gets built
do_save_output_data=outfile is not None
if not do_save_output_data:
    outfile="dummy_{input}_{output}.mat"
outfile_template=outfile
if re.search("{.+}",outfile_template):
    instrlist=["i","in","input","s","src","source"]
    outstrlist=["o","out","output","t","trg","targ","target"]
    #rewrite every accepted alias (upper case first, then lower case) to {input} / {output}
    for s in instrlist:
        for alias in (s.upper(),s):
            outfile_template=outfile_template.replace("{"+alias+"}","{input}")
    for s in outstrlist:
        for alias in (s.upper(),s):
            outfile_template=outfile_template.replace("{"+alias+"}","{output}")
#outfile_list[k] is written from the input/output types in outfile_input_output_list[k]
outfile_list=[]
outfile_input_output_list=[]
if "{input}" in outfile_template and "{output}" in outfile_template:
    #one file per input->output path
    for intype in eval_input_conntype_list:
        for outtype in output_conntype_list:
            if outtype == intype or not do_self_only:
                outfile_tmp=outfile_template.replace("{input}",intype).replace("{output}",outtype)
                outfile_list.append(outfile_tmp)
                outfile_input_output_list.append({"intype":[intype],"outtype":[outtype]})
elif "{input}" in outfile_template:
    #one file per input type, holding all of its outputs (or only itself for self-only)
    for intype in eval_input_conntype_list:
        outfile_tmp=outfile_template.replace("{input}",intype)
        outfile_list.append(outfile_tmp)
        outfile_input_output_list.append({"intype":[intype],
            "outtype":[intype] if do_self_only else output_conntype_list})
elif "{output}" in outfile_template:
    #one file per output type, holding all of its inputs (or only itself for self-only)
    for outtype in output_conntype_list:
        outfile_tmp=outfile_template.replace("{output}",outtype)
        outfile_list.append(outfile_tmp)
        outfile_input_output_list.append({"intype":[outtype] if do_self_only else eval_input_conntype_list,
            "outtype":[outtype]})
else:
    #no placeholders: a single consolidated file
    outfile_list.append(outfile_template)
    outfile_input_output_list.append({"intype":eval_input_conntype_list,"outtype":output_conntype_list})
##############
#load input transformers
#a transform file is required unless the checkpoint reports no input transformation
if not input_transform_file_list and checkpoint['input_transformation_info'].upper()!='NONE':
    print("Must provide input transform (ioxfm) file")
    sys.exit(1)
#both dicts are keyed by connectivity flavor name
transformer_list={}
transformer_info_list={}
if checkpoint['input_transformation_info']=='none':
    #no input transformation: generate a 'none'-type transformer for every model flavor.
    #the info is stored in transformer_info_list, the dict that the code below reads
    #(it was previously assigned into an undefined name, which raised NameError)
    for conntype in conn_names:
        transformer_list[conntype], transformer_info_list[conntype] = generate_transformer(traindata=None,
            transformer_type='none', transformer_param_dict=None,
            precomputed_transformer_params={'type':'none'}, return_components=True)
else:
    for i,ioxfile in enumerate(input_transform_file_list):
        #use the path returned by fetch_model_data_if_needed when it exists on disk
        ioxfile=fetch_model_data_if_needed(ioxfile)
        if os.path.exists(ioxfile):
            input_transform_file_list[i]=ioxfile
    transformer_list, transformer_info_list = load_transformers_from_file(input_transform_file_list, input_names=conn_names)
#collect the stored input mean for each flavor whose transformer params provide one
traindata_mean_list={}
for conntype, transformer_info in transformer_info_list.items():
    if 'params' not in transformer_info:
        continue
    #'input_mean' is checked last so it takes precedence when both fields are present
    for meanfield in ('pca_input_mean','input_mean'):
        if meanfield in transformer_info['params']:
            traindata_mean_list[conntype]=transformer_info['params'][meanfield]
##########
#handle special case for OLD saved transformers before we updated the connectivity flavors
#if the transformers use the old style of flavor names, convert them to the new style
try:
    all_old_names=all(canonical_data_flavor_OLD(x)==x for x in transformer_list)
except Exception:
    #a name the old canonicalizer cannot handle means these are not old-style names
    all_old_names=False
if all_old_names:
    print("The transformers use old style of flavor names")
    transformer_list={canonical_data_flavor(x):transformer_list[x] for x in transformer_list}
    transformer_info_list={canonical_data_flavor(x):transformer_info_list[x] for x in transformer_info_list}
########
######################
#load input data
#Two sources are supported:
# 1. explicit input files (first input_file_list entry is not one of the keywords below):
#    each input conntype is read from its file, adapted with generate_adapt_transformer,
#    and optionally restricted to subjects_to_eval
# 2. the hardcoded HCP data, selected by keyword ('all','test','train','val','retest')
#    or used by default when input_file_list is empty
#Result: conndata_alltypes[conntype] holds at least 'data' and 'subjects'
output_subject_splits=input_subject_splits
if len(input_file_list) > 0 and not input_file_list[0] in ['all','test','train','val','retest']:
#load data from files specified in command line
if subjects_to_eval is not None:
#an explicit evaluation subject list overrides the split dictionary for outputs
output_subject_splits=None
conndata_alltypes={}
for i,x in enumerate(input_conntype_list):
if input_file_list[i].endswith(".zip"):
#NOTE(review): this `break` leaves the whole conntype loop, so any later conntype that
#points at a different file would not be loaded -- confirm `continue` was not intended
if x in conndata_alltypes and i > 0 and input_file_list[i] in input_file_list[:i]:
#already read this zip file (and we read all data from it)
break
try:
#a single zip can supply several conntypes, so request all of them at once
conndata_tmp,participant_info_tmp=load_data_zip(input_file_list[i],conntypes_to_load=input_conntype_list)
#the subject identifiers are in either a 'subject' or a 'participant_id' field
if 'subject' in participant_info_tmp:
subj_tmp=list(participant_info_tmp['subject'])
else:
subj_tmp=list(participant_info_tmp['participant_id'])
subj_tmp=clean_subject_list(subj_tmp)
for kk in input_conntype_list:
if kk in conndata_tmp:
conndata_alltypes[kk]=load_data_square2tri(np.array(conndata_tmp[kk]), subjects=subj_tmp, group=None)
#conndata_alltypes[kk]={'subjects':subj_tmp,'data':conndata_tmp[kk]}
del conndata_tmp
#NOTE(review): the bare except reports every failure (including unrelated errors) with the
#same message, and exit(0) signals success to the calling shell -- consider narrowing the
#exception and exiting with a non-zero status
except:
print("Only bids-ish zips can be loaded by this function.")
exit(0)
else:
conndata_alltypes[x]=load_input_data(inputfile=input_file_list[i], inputfield=None)
if 'subjects' in conndata_alltypes[x]:
conndata_alltypes[x]['subjects']=clean_subject_list(conndata_alltypes[x]['subjects'])
#adaptation fit info returned by generate_adapt_transformer, keyed by conntype
adxfm_info_alltypes={}
for i,x in enumerate(input_conntype_list):
#how should we adapt the input data to the model?
#choose the rows (subjidx_adapt) of this conntype's data that the adaptation is fit on:
# - the named split from the subject split file, or every row for 'all'
# - otherwise the training subjects, if known
# - otherwise every row
if input_subject_splits and adapt_mode_split_name is not None:
if 'subjidx_'+adapt_mode_split_name in input_subject_splits:
#translate split indices (into `subjects`) to row indices of this conntype's data
subjects_adapt=[s for i,s in enumerate(subjects) if i in input_subject_splits['subjidx_'+adapt_mode_split_name]]
subjidx_adapt=np.array([i for i,s in enumerate(conndata_alltypes[x]['subjects']) if s in subjects_adapt])
elif adapt_mode_split_name == 'all':
subjidx_adapt=np.arange(conndata_alltypes[x]['data'].shape[0])
else:
raise Exception("Invalid adaptmode subject split name: %s" % (adapt_mode_split_name))
elif subjects_train is not None:
subjidx_adapt=np.array([i for i,s in enumerate(conndata_alltypes[x]['subjects']) if s in subjects_train])
else:
subjidx_adapt=np.arange(conndata_alltypes[x]['data'].shape[0])
#default: fit the adaptation on the input data itself, restricted to subjidx_adapt
adapt_data_tmp=conndata_alltypes[x]['data']
subjidx_adapt_tmp=subjidx_adapt
adapt_source_name="input"
if x in adapt_mode_file_dict and adapt_mode_file_dict[x] is not None:
#a separate adaptation file was given for this conntype: fit on that data instead
try:
#first try to load the data from a multi-conntype file using conntype as the fieldname
adapt_data_tmp=load_input_data(inputfile=adapt_mode_file_dict[x], inputfield=x)
except:
#if that fails, assume it is a single conntype file
adapt_data_tmp=load_input_data(inputfile=adapt_mode_file_dict[x], inputfield=None)
adapt_data_tmp=adapt_data_tmp['data']
#no subject mask is passed when fitting on the separate adaptation data
subjidx_adapt_tmp=None
adapt_source_name="adaptsource"
adxfm, adxfm_info=generate_adapt_transformer(input_data=adapt_data_tmp,
target_data=transformer_info_list[x],
adapt_mode=adapt_mode,
input_data_fitsubjmask=subjidx_adapt_tmp,
return_fit_info=True,
input_source_name=adapt_source_name)
adxfm_info_alltypes[x]=adxfm_info
#the fitted adaptation is applied to ALL rows of the input data, not only the fit subset
conndata_alltypes[x]['data']=adxfm.transform(conndata_alltypes[x]['data'])
if subjects_to_eval is not None:
#keep only the rows of the subjects requested for evaluation
#NOTE(review): only 'data' is subset here; 'subjects' keeps its original length -- confirm
#that later code does not pair the two by position
subjidx_for_eval=np.array([i for i,s in enumerate(conndata_alltypes[x]['subjects']) if s in subjects_to_eval])
conndata_alltypes[x]['data']=conndata_alltypes[x]['data'][subjidx_for_eval]
print(x,conndata_alltypes[x]['data'].shape)
else:
#load hardcoded HCP data files
input_file="all"
if len(input_file_list) > 0:
if any([x in input_file_list for x in ["train","val","test"]]) and input_subject_splits is None:
raise Exception("Must provide --subjectsplitfile for train, val, test options")
#accumulators across the requested keywords
subjects_out=None
conndata_alltypes=None
if any([x in input_file_list for x in ["train","val","test","retest"]]):
#any time we request ONLY a specific split, after reading in that split's data, we need
# to set the split dict to None so that the trainrecord/heatmap creation later uses all data
# from the requested split
output_subject_splits=None
for i in input_file_list:
if i=="all":
subjects_tmp, _ = load_hcp_subject_list(numsubj=993)
subjects_out_tmp, conndata_alltypes_tmp = load_hcp_data(subjects=subjects_tmp, conn_name_list=input_conntype_list, quiet=False)
elif i=="retest":
subjects_out_tmp, conndata_alltypes_tmp = load_hcp_data(subjects=None, conn_name_list=input_conntype_list, load_retest=True,quiet=False)
elif i=="train":
subjects_out_tmp, conndata_alltypes_tmp = load_hcp_data(subjects=subjects_train, conn_name_list=input_conntype_list, quiet=False)
elif i=="val":
subjects_out_tmp, conndata_alltypes_tmp = load_hcp_data(subjects=subjects_val, conn_name_list=input_conntype_list, quiet=False)
elif i=="test":
subjects_out_tmp, conndata_alltypes_tmp = load_hcp_data(subjects=subjects_test, conn_name_list=input_conntype_list, quiet=False)
#NOTE(review): an entry matching none of the keywords above leaves subjects_out_tmp and
#conndata_alltypes_tmp unset (or stale from the previous entry) -- confirm inputs are validated
if subjects_out is None:
subjects_out=subjects_out_tmp
else:
#the same subject may not appear in more than one requested set
if any([s in subjects_out for s in subjects_out_tmp]):
raise Exception("Duplicate subject for in input data list: %s" % (", ".join(input_file_list)))
subjects_out=np.concatenate((subjects_out,subjects_out_tmp))
if conndata_alltypes is None:
conndata_alltypes=conndata_alltypes_tmp
else:
#stack the new rows under the existing ones for every conntype
for conntype in conndata_alltypes.keys():
conndata_alltypes[conntype]['data']=np.vstack((conndata_alltypes[conntype]['data'],conndata_alltypes_tmp[conntype]['data']))
conndata_alltypes[conntype]['subjects']=subjects_out
else:
#no input list at all: load the full HCP subject list
subjects, famidx = load_hcp_subject_list(numsubj=993)
subjects_out, conndata_alltypes = load_hcp_data(subjects=subjects, conn_name_list=input_conntype_list, quiet=False)
for conntype in conndata_alltypes.keys():
conndata_alltypes[conntype]['subjects']=clean_subject_list(subjects_out)
#NOTE(review): quiet is hard-coded to False, so the report below is always printed
quiet=False
if not quiet:
#report how much variance of each input survives a transform -> inverse_transform round trip
for conntype in conndata_alltypes.keys():
var_explained_ratio=explained_variance_ratio(torchfloat(conndata_alltypes[conntype]['data']),
transformer_list[conntype].inverse_transform(
transformer_list[conntype].transform(conndata_alltypes[conntype]['data'])
))
print("%s (%dx%d) variance maintained: %.2f%%" % (conntype,conndata_alltypes[conntype]['data'].shape[0],conndata_alltypes[conntype]['data'].shape[1],var_explained_ratio*100))
#if we have an inner and outer checkpoint, we loaded the inner first, now load outer
#cant do this until we have all of the data transformers built
if outerptfile is not None:
    print("Loading outer model from %s" % (outerptfile))
    #the model and checkpoint loaded so far become the inner model of the adapter
    inner_net, inner_checkpoint_dict = net, checkpoint
    none_transformer, none_transformer_info = generate_transformer(transformer_type="none")
    #one transformer and one input size per model flavor, in conn_names order
    data_transformer_list=[]
    data_inputsize_list=[]
    for i_conn, conn_name in enumerate(conn_names):
        if conn_name in transformer_info_list:
            #rebuild this flavor's transformer from its stored parameters
            transformer, transformer_info = generate_transformer(transformer_type=transformer_info_list[conn_name]["params"]["type"], precomputed_transformer_params=transformer_info_list[conn_name]["params"])
        else:
            transformer, transformer_info = none_transformer, none_transformer_info
        #flavors without loaded data get a placeholder input size of 1
        inputsize=conndata_alltypes[conn_name]['data'].shape[1] if conn_name in conndata_alltypes else 1
        data_transformer_list.append(transformer)
        data_inputsize_list.append(inputsize)
    net, checkpoint = KrakenAdapter.load_checkpoint(filename=outerptfile, inner_model=inner_net, data_transformer_list=data_transformer_list, inner_model_extra_dict=inner_checkpoint_dict)
    #now that we've built the outer model with the data transformers built in, we need to reset the data transformers
    #objects to "none"
    transformer_list={conn_name:none_transformer for conn_name in conn_names}
    transformer_info_list={conn_name:none_transformer_info for conn_name in conn_names}
#latent encodings per input type, and predictions per input type (filled in later)
encoded_alltypes={}
predicted_alltypes={}
neg1_torch=torchint(-1)
net.eval()
#keep track of outputs where we converted category inputs to one-hot
conntype_output_is_onehot=dict.fromkeys(input_conntype_list,False)
#encode all inputs to latent space
for intype in input_conntype_list:
    if intype == "encoded" or intype.startswith("fusion"):
        #these inputs are stored directly as the encoding, without an encoder pass
        encoded_name=intype
        conn_encoded=torchfloat(conndata_alltypes[intype]['data'])
        encoded_alltypes[intype]=numpyvar(conn_encoded)
        continue
    if intype not in conn_names:
        raise Exception("Input type %s not found in model" % (intype))
    #position of this flavor in the model's flavor list (first match)
    encoder_index=next(idx for idx,c in enumerate(conn_names) if c==intype)
    encoder_index_torch=torchint(encoder_index)
    inputdata_origscale=conndata_alltypes[intype]['data']
    inputdata=torchfloat(transformer_list[intype].transform(inputdata_origscale))
    if do_save_transformed_inputs:
        #keep the transformed input itself in place of a latent encoding
        conn_encoded=inputdata
    else:
        if inputdata.shape[1] == 1 and net.inputsize_list[encoder_index] > 1:
            #this is a categorical variable
            #expand to the number of input features
            num_categories=net.inputsize_list[encoder_index]
            inputdata=torch.nn.functional.one_hot(inputdata[:,0].to(int),num_classes=num_categories).float()
            conntype_output_is_onehot[intype]=True
        with torch.no_grad():
            #decoder index of -1 is passed for this encode-only call
            conn_encoded = net(inputdata, encoder_index_torch, neg1_torch)
    encoded_alltypes[intype]=numpyvar(conn_encoded)
#fusionmode averaging in encoding latent space
#for every fusion mode, average the latent encodings of its listed input types and store
#the result in encoded_alltypes under the fusion mode's own name
for fusiontype, fusionmode_names in fusionmode_names_dict.items():
    print("%s: fusion mode evaluation. Computing mean of input data flavors in latent space." % (fusiontype))
    print("%s: input types " % (fusiontype),fusionmode_names)
    encoded_mean=None
    encoded_inputtype_count=0
    for intype in encoded_alltypes:
        if intype == fusiontype:
            #the fusion encoding was supplied directly as an input: use it as-is
            print("fusion type '%s' evaluation already computed in input" % (fusiontype))
            encoded_mean=encoded_alltypes[intype].copy()
            encoded_inputtype_count=1
            break
        if intype not in fusionmode_names:
            #only average encodings from specified input types
            continue
        print("fusion type '%s' evaluation includes: %s" % (fusiontype,intype))
        #copy so the in-place sum below never modifies the stored encoding
        conn_encoded=encoded_alltypes[intype].copy()
        if encoded_mean is None:
            encoded_mean=conn_encoded
        else:
            encoded_mean+=conn_encoded
        encoded_inputtype_count+=1
    if encoded_mean is None:
        #without this check the division below fails with an opaque None/int TypeError
        raise Exception("fusion type '%s': none of its input types were found among the encoded inputs" % (fusiontype))
    encoded_mean=encoded_mean/encoded_inputtype_count
    #per-row vector length of the averaged encoding
    encoded_norm=np.sqrt(np.sum(encoded_mean**2,axis=1,keepdims=True))
    if fusionnorm:
        print("Mean '%s' vector length before re-normalization: %.4f (renormed to 1.0)" % (fusiontype,np.mean(encoded_norm)))
        encoded_mean=encoded_mean/encoded_norm
    else:
        print("Mean '%s' vector length: %.4f" % (fusiontype,np.mean(encoded_norm)))
    encoded_alltypes[fusiontype]=encoded_mean.copy()
#now decode encoded inputs to requested outputs
for intype in encoded_alltypes.keys():
outtypes_for_this_input=np.unique([x['outtype'] for x in outfile_input_output_list if intype in x['intype']])
if len(outtypes_for_this_input) == 0:
continue
############## intergroup
#use this so we know if its intergroup
if intype in conn_names:
encoder_index=[idx for idx,c in enumerate(conn_names) if c==intype][0]
else:
encoder_index=-1
encoder_index_torch=torchint(encoder_index)
############## end intergroup
predicted_alltypes[intype]={}
for outtype in outtypes_for_this_input:
if outtype == "encoded" or outtype=='transformed' or outtype in fusionmode_names_dict:
continue
if not outtype in conn_names:
raise Exception("Output type %s not found in model" % (outtype))
if do_self_only and intype != outtype:
continue
decoder_index=[idx for idx,c in enumerate(conn_names) if c==outtype][0]
decoder_index_torch=torchint(decoder_index)
############# intergroup
conn_encoded=None
if intype in fusionmode_names_dict and (".noself" in intype or ".noatlas" in intype):
pass
else:
conn_encoded=torchfloat(encoded_alltypes[intype])
if net.intergroup:
#net.inputgroup_list has items, but no names; net.inputgroup_list should correspond to encoder_index and decoder_index
#conntype_group_dict=net.inputgroup_list
conntype_encoder_index_dict={k:i for i,k in enumerate(conn_names)}
encoded_alltypes_thisgroup={}
for k_in in encoded_alltypes:
if k_in in conntype_encoder_index_dict:
encidx=conntype_encoder_index_dict[k_in]
else:
encidx=-1
encoded_alltypes_thisgroup[k_in]=net.intergroup_transform_latent(torchfloat(encoded_alltypes[k_in]),torchint(encidx),decoder_index_torch).cpu().detach().numpy()
if conn_encoded is not None:
conn_encoded=net.intergroup_transform_latent(conn_encoded,encoder_index_torch,decoder_index_torch)
else:
encoded_alltypes_thisgroup=encoded_alltypes
############# end intergroup
noself_str=''
if intype in fusionmode_names_dict and ".noself" in intype:
#"noself" fusion mode has to compute a new average of all inputs except self