-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathcount_params.py
More file actions
103 lines (85 loc) · 3.13 KB
/
count_params.py
File metadata and controls
103 lines (85 loc) · 3.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# Silence pyright's argument-type check for the whole file: values looked up
# by key in the loaded configs would otherwise be reported.
# pyright: reportArgumentType=false
import warnings
from glob import glob
from pathlib import Path
from typing import Any
import pandas as pd
from omegaconf import OmegaConf
from tqdm import tqdm
from ebes.model import build_model
# Dataset key (as used in config and log paths) -> display name used in the
# report. Insertion order is the row order of the resulting table.
DATASETS_PRETTY = {
    "mbd": "MBD",
    "x5": "Retail",
    "age": "Age",
    "taobao": "Taobao",
    "bpi_17": "BPI17",
    "physionet2012": "PhysioNet2012",
    "mimic3": "MIMIC-III",
    "pendulum_cls": "Pendulum",
    "arabic": "ArabicDigits",
    "electric_devices": "ElectricDevices",
}
# Method key (as used in config and log paths) -> display name used in the
# report. Insertion order is the column order of the resulting table.
METHODS_PRETTY = {
    "coles": "CoLES",
    "gru": "GRU",
    "mlem": "MLEM",
    "transformer": "Transformer",
    "mamba": "Mamba",
    "convtran": "ConvTran",
    "mtand": "mTAND",
    "primenet": "PrimeNet",
    "mlp": "MLP",
}
# Hide the UserWarning whose message starts with the text below so it does not
# clutter the output while models are being built.
# NOTE(review): presumably raised for single-layer recurrent models configured
# with dropout — confirm the source of the warning.
warnings.filterwarnings(
    action="ignore",
    message="dropout option adds dropout after all but last recurrent layer.*",
    category=UserWarning,
)
def collect_config(dataset, method, specify=None) -> dict[str, Any]:
    """Load and merge the configuration files for one dataset/method pair.

    The sources are passed to ``OmegaConf.merge`` in this order: dataset
    config, method config, the ``test`` experiment config, and finally the
    "specify" file.

    Args:
        dataset: Dataset key; selects ``configs/datasets/{dataset}.yaml``.
        method: Method key; selects ``configs/methods/{method}.yaml``.
        specify: Path to the specify file. When ``None``,
            ``configs/specify/{dataset}/{method}/best.yaml`` is used.

    Returns:
        The merged configuration with ``"device"`` forced to ``"cpu"``. The
        object is whatever ``OmegaConf.merge`` returns; the ``dict``
        annotation is a convenience, hence the ``type: ignore``.
    """
    specify_path = (
        Path(f"configs/specify/{dataset}/{method}/best.yaml")
        if specify is None
        else Path(specify)
    )
    sources = (
        Path(f"configs/datasets/{dataset}.yaml"),
        Path(f"configs/methods/{method}.yaml"),
        Path("configs/experiments/test.yaml"),
        specify_path,
    )
    merged = OmegaConf.merge(*(OmegaConf.load(src) for src in sources))
    merged["device"] = "cpu"
    return merged  # type: ignore
def get_param_counts(dataset, method, specify=None):
    """Return the number of trainable parameters of the configured model.

    The configuration is assembled by ``collect_config`` and its ``"model"``
    section is handed to ``build_model``. Only parameters with
    ``requires_grad`` set contribute to the total.

    Args:
        dataset: Dataset key forwarded to ``collect_config``.
        method: Method key forwarded to ``collect_config``.
        specify: Optional specify-file path forwarded to ``collect_config``.

    Returns:
        Sum of ``numel()`` over all parameters that require gradients.
    """
    model_conf = collect_config(dataset, method, specify)["model"]
    net = build_model(model_conf)
    total = 0
    for param in net.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
if __name__ == "__main__":
    # Build a table of trainable-parameter counts: one row per
    # (dataset, statistic) pair and one column per method. For every pair the
    # "best" row holds the count of the best configuration, while "min" and
    # "max" span the counts of all finished optuna trials.
    #
    # NOTE(review): the first index level contains dataset names but is
    # labelled "Method"; kept unchanged because the label is part of the
    # emitted LaTeX/CSV — confirm whether "Dataset" was intended.
    index = pd.MultiIndex.from_product(
        [DATASETS_PRETTY.values(), ["min", "best", "max"]], names=["Method", "Params"]
    )
    res = pd.DataFrame(index=index, columns=METHODS_PRETTY.values())

    for dataset in DATASETS_PRETTY:
        print(dataset.upper(), "STARTED")
        for method in METHODS_PRETTY:
            row = DATASETS_PRETTY[dataset]
            col = METHODS_PRETTY[method]

            best_c = f"{get_param_counts(dataset, method):.1e}"
            # The best configuration is highlighted in the LaTeX table.
            res.loc[(row, "best"), col] = f"\\cellcolor{{lightgray}}{best_c}"

            optuna_counts = []
            for params_file in tqdm(glob(f"log/{dataset}/{method}/optuna/*/params.txt")):
                params_path = Path(params_file)
                # A trial without results.csv is skipped (treated as unfinished).
                if not params_path.with_name("results.csv").exists():
                    continue
                optuna_counts.append(get_param_counts(dataset, method, params_path))

            # min()/max() raise ValueError on an empty sequence; without any
            # finished trial the min/max cells are left empty instead of
            # aborting the whole table.
            if not optuna_counts:
                print(f"no finished optuna trials for {dataset}/{method}; min/max left empty")
                continue

            res.loc[(row, "min"), col] = f"{min(optuna_counts):.1e}"
            res.loc[(row, "max"), col] = f"{max(optuna_counts):.1e}"

    print(
        res.to_latex(
            bold_rows=True,
            # Two index columns precede the per-method columns.
            column_format="r" * (len(METHODS_PRETTY) + 2),
        )
    )

    # Make sure the target directory exists so the results are not lost at the
    # very last step.
    out_path = Path("log/Ablations/param_counts.csv")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    res.to_csv(out_path)