-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy path: bpi_17.py
More file actions
125 lines (104 loc) · 3.43 KB
/
bpi_17.py
File metadata and controls
125 lines (104 loc) · 3.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from argparse import ArgumentParser
from pathlib import Path
import pandas as pd
import pm4py
from tqdm.auto import tqdm
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql.types import FloatType, LongType
from common import cat_freq, collect_lists
tqdm.pandas()
CAT_FEATURES = [
"Action",
"org_resource",
"concept_name",
"EventOrigin",
"lifecycle_transition",
"case_LoanGoal",
"case_ApplicationType",
]
NUM_FEATURES = [
"case_RequestedAmount",
"FirstWithdrawalAmount",
"NumberOfTerms",
"MonthlyCost",
"CreditScore",
"OfferedAmount",
]
INDEX_COLUMNS = ["OfferID", "Accepted"]
ORDERING_COLUMNS = ["time_timestamp"]
FIRST_TEST_DATE = "2016-10-20"
def get_seqs_from_app(df):
seqs = []
cr_off = df.query("`concept:name` == 'O_Create Offer'")
for idx in cr_off.index:
off = df.loc[:idx].copy()
offer_id = df.at[idx + 1, "OfferID"]
acc = df.at[idx, "Accepted"]
off["OfferID"] = offer_id
off["Accepted"] = acc
seqs.append(off.drop(["Selected", "EventID"], axis=1))
return pd.concat(seqs, axis=0)
def main():
parser = ArgumentParser()
parser.add_argument(
"--data-path",
help="Path to directory containing .xes.gz file",
required=True,
type=Path,
)
parser.add_argument(
"--save-path",
help="Where to save preprocessed parquets",
required=True,
type=Path,
)
parser.add_argument(
"--cat-codes-path",
help="Path where to save codes for categorical features",
type=Path,
)
parser.add_argument(
"--overwrite",
help='Toggle "overwrite" mode on all spark writes',
action="store_true",
)
args = parser.parse_args()
mode = "overwrite" if args.overwrite else "error"
pdf = pm4py.read_xes((args.data_path / "BPI Challenge 2017.xes.gz").as_posix())
pdf = pdf.groupby("case:concept:name", group_keys=False).progress_apply(
get_seqs_from_app, include_groups=False
)
pdf.columns = map(lambda s: s.replace(":", "_"), pdf.columns)
spark = SparkSession.builder.master("local[32]").getOrCreate() # pyright: ignore
df = spark.createDataFrame(pdf)
df = df.withColumn("Accepted", F.col("Accepted").cast(LongType()))
for nc in NUM_FEATURES:
df = df.withColumn(nc, F.col(nc).cast(FloatType()))
last = (
df.groupby("OfferID").agg(F.max("time_timestamp").alias("last_ev_dt")).cache()
)
train_clients = last.filter(f"last_ev_dt < '{FIRST_TEST_DATE}'").select("OfferID")
test_clients = last.select("OfferID").subtract(train_clients)
train_df = df.join(train_clients, on="OfferID")
test_df = df.join(test_clients, on="OfferID")
vcs = cat_freq(train_df, CAT_FEATURES)
for vc in vcs:
train_df = vc.encode(train_df)
test_df = vc.encode(test_df)
if args.cat_codes_path is not None:
vc.write(args.cat_codes_path / vc.feature_name, mode=mode)
train_df = collect_lists(
train_df,
group_by=INDEX_COLUMNS,
order_by=ORDERING_COLUMNS,
)
test_df = collect_lists(
test_df,
group_by=INDEX_COLUMNS,
order_by=ORDERING_COLUMNS,
)
train_df.coalesce(1).write.parquet((args.save_path / "train").as_posix(), mode=mode)
test_df.coalesce(1).write.parquet((args.save_path / "test").as_posix(), mode=mode)
if __name__ == "__main__":
main()