Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 67 additions & 24 deletions qlib/data/dataset/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,11 +295,14 @@ class DataHandlerLP(DataHandler):

# process type
PTYPE_I = "independent"
# - self._infer will be processed by infer_processors
# - self._learn will be processed by learn_processors
# - self._infer will be processed by shared_processors + infer_processors
# - self._learn will be processed by shared_processors + learn_processors

# NOTE:
PTYPE_A = "append"
# - self._infer will be processed by infer_processors
# - self._learn will be processed by infer_processors + learn_processors

# - self._infer will be processed by shared_processors + infer_processors
# - self._learn will be processed by shared_processors + infer_processors + learn_processors
# - (e.g. self._infer processed by learn_processors )

def __init__(
Expand All @@ -308,8 +311,9 @@ def __init__(
start_time=None,
end_time=None,
data_loader: Union[dict, str, DataLoader] = None,
infer_processors=[],
learn_processors=[],
infer_processors: List = [],
learn_processors: List = [],
shared_processors: List = [],
process_type=PTYPE_A,
drop_raw=False,
**kwargs,
Expand Down Expand Up @@ -360,7 +364,8 @@ def __init__(
# Setup preprocessor
self.infer_processors = [] # for lint
self.learn_processors = [] # for lint
for pname in "infer_processors", "learn_processors":
self.shared_processors = [] # for lint
for pname in "infer_processors", "learn_processors", "shared_processors":
for proc in locals()[pname]:
getattr(self, pname).append(
init_instance_by_config(
Expand All @@ -375,9 +380,12 @@ def __init__(
super().__init__(instruments, start_time, end_time, data_loader, **kwargs)

def get_all_processors(self):
    """
    Return every processor attached to this handler, in application order:
    shared processors first, then inference processors, then learning processors.
    """
    # NOTE: keep the merged (post-PR) behavior: shared_processors participate too.
    return self.shared_processors + self.infer_processors + self.learn_processors

def fit(self):
    """
    Fit every processor (shared + infer + learn) on the currently loaded
    raw data WITHOUT transforming it — ``self._data`` is read but not
    replaced here; use ``fit_process_data`` to fit and process in one pass.
    """
    for proc in self.get_all_processors():
        # time each processor's fit so slow steps are visible in the logs
        with TimeInspector.logt(f"{proc.__class__.__name__}"):
            proc.fit(self._data)
Expand All @@ -390,45 +398,80 @@ def fit_process_data(self):
"""
self.process_data(with_fit=True)

@staticmethod
def _run_proc_l(
    df: pd.DataFrame, proc_l: List[processor_module.Processor], with_fit: bool, check_for_infer: bool
) -> pd.DataFrame:
    """
    Run a chain of processors over ``df`` sequentially.

    Parameters
    ----------
    df : pd.DataFrame
        data to process; each processor receives the output of the previous one.
    proc_l : List[processor_module.Processor]
        processors applied in list order.
    with_fit : bool
        when True, fit each processor on its input before applying it.
    check_for_infer : bool
        when True, reject any processor not usable for inference.

    Returns
    -------
    pd.DataFrame
        the fully processed data.
    """
    result = df
    for processor in proc_l:
        if check_for_infer and not processor.is_for_infer():
            raise TypeError("Only processors usable for inference can be used in `infer_processors` ")
        # time each processor so slow steps show up in the logs
        with TimeInspector.logt(f"{processor.__class__.__name__}"):
            if with_fit:
                processor.fit(result)
            result = processor(result)
    return result

@staticmethod
def _is_proc_readonly(proc_l: List[processor_module.Processor]):
"""
NOTE: it will return True if `len(proc_l) == 0`
"""
for p in proc_l:
if not p.readonly():
return False
return True

def process_data(self, with_fit: bool = False):
"""
process_data data. Fun `processor.fit` if necessary

Notation: (data) [processor]

# data processing flow of self.process_type == DataHandlerLP.PTYPE_I
(self._data)-[shared_processors]-(_shared_df)-[learn_processors]-(_learn_df)
\
-[infer_processors]-(_infer_df)

# data processing flow of self.process_type == DataHandlerLP.PTYPE_A
(self._data)-[shared_processors]-(_shared_df)-[infer_processors]-(_infer_df)-[learn_processors]-(_learn_df)

Parameters
----------
with_fit : bool
The input of the `fit` will be the output of the previous processor
"""
# shared data processors
# 1) assign
_shared_df = self._data
if not self._is_proc_readonly(self.shared_processors): # avoid modifying the original data
_shared_df = _shared_df.copy()
# 2) process
_shared_df = self._run_proc_l(_shared_df, self.shared_processors, with_fit=with_fit, check_for_infer=True)

# data for inference
_infer_df = self._data
if len(self.infer_processors) > 0 and not self.drop_raw: # avoid modifying the original data
# 1) assign
_infer_df = _shared_df
if not self._is_proc_readonly(self.infer_processors): # avoid modifying the original data
_infer_df = _infer_df.copy()
# 2) process
_infer_df = self._run_proc_l(_infer_df, self.infer_processors, with_fit=with_fit, check_for_infer=True)

for proc in self.infer_processors:
if not proc.is_for_infer():
raise TypeError("Only processors usable for inference can be used in `infer_processors` ")
with TimeInspector.logt(f"{proc.__class__.__name__}"):
if with_fit:
proc.fit(_infer_df)
_infer_df = proc(_infer_df)
self._infer = _infer_df

# data for learning
# 1) assign
if self.process_type == DataHandlerLP.PTYPE_I:
_learn_df = self._data
elif self.process_type == DataHandlerLP.PTYPE_A:
# based on `infer_df` and append the processor
_learn_df = _infer_df
else:
raise NotImplementedError(f"This type of input is not supported")

if len(self.learn_processors) > 0: # avoid modifying the original data
if not self._is_proc_readonly(self.learn_processors): # avoid modifying the original data
_learn_df = _learn_df.copy()
for proc in self.learn_processors:
with TimeInspector.logt(f"{proc.__class__.__name__}"):
if with_fit:
proc.fit(_learn_df)
_learn_df = proc(_learn_df)
# 2) process
_learn_df = self._run_proc_l(_learn_df, self.learn_processors, with_fit=with_fit, check_for_infer=False)

self._learn = _learn_df

if self.drop_raw:
Expand Down
17 changes: 17 additions & 0 deletions qlib/data/dataset/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,14 @@ def is_for_infer(self) -> bool:
"""
return True

def readonly(self) -> bool:
    """
    Does the processor treat the input data readonly (i.e. does not write the input data) when processing

    Knowing the readonly information is helpful to the Handler to avoid unnecessary copy
    """
    # Conservative default: assume the processor may mutate its input.
    return False

def config(self, **kwargs):
attr_list = {"fit_start_time", "fit_end_time"}
for k, v in kwargs.items():
Expand All @@ -92,6 +100,9 @@ def __init__(self, fields_group=None):
def __call__(self, df):
    """Drop rows with NaN in the configured column group; a new frame is returned."""
    subset_cols = get_group_columns(df, self.fields_group)
    return df.dropna(subset=subset_cols)

def readonly(self):
    # `DataFrame.dropna` (see `__call__` above) returns a new object,
    # so the input frame is never mutated.
    return True


class DropnaLabel(DropnaProcessor):
def __init__(self, fields_group="label"):
Expand All @@ -113,6 +124,9 @@ def __call__(self, df):
mask = df.columns.isin(self.col_list)
return df.loc[:, ~mask]

def readonly(self):
    # `df.loc[:, ~mask]` in `__call__` above selects columns into a new
    # frame; the input is left unchanged.
    return True


class FilterCol(Processor):
def __init__(self, fields_group="feature", col_list=[]):
Expand All @@ -128,6 +142,9 @@ def __call__(self, df):
mask = df.columns.get_level_values(-1).isin(self.col_list)
return df.loc[:, mask]

def readonly(self):
    # `df.loc[:, mask]` in `__call__` above returns a new frame; the
    # input is left untouched.
    return True


class TanhProcess(Processor):
"""Use tanh to process noise data"""
Expand Down
13 changes: 6 additions & 7 deletions tests/storage_tests/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from pathlib import Path
from collections.abc import Iterable

import pytest
import numpy as np
from qlib.tests import TestAutoData

Expand Down Expand Up @@ -33,13 +32,13 @@ def test_calendar_storage(self):
print(f"calendar[-1]: {calendar[-1]}")

calendar = CalendarStorage(freq="1min", future=False, provider_uri="not_found")
with pytest.raises(ValueError):
with self.assertRaises(ValueError):
print(calendar.data)

with pytest.raises(ValueError):
with self.assertRaises(ValueError):
print(calendar[:])

with pytest.raises(ValueError):
with self.assertRaises(ValueError):
print(calendar[0])

def test_instrument_storage(self):
Expand Down Expand Up @@ -90,10 +89,10 @@ def test_instrument_storage(self):
print(f"instrument['SH600000']: {instrument['SH600000']}")

instrument = InstrumentStorage(market="csi300", provider_uri="not_found")
with pytest.raises(ValueError):
with self.assertRaises(ValueError):
print(instrument.data)

with pytest.raises(ValueError):
with self.assertRaises(ValueError):
print(instrument["sSH600000"])

def test_feature_storage(self):
Expand Down Expand Up @@ -152,7 +151,7 @@ def test_feature_storage(self):

feature = FeatureStorage(instrument="SH600004", field="close", freq="day", provider_uri=self.provider_uri)

with pytest.raises(IndexError):
with self.assertRaises(IndexError):
print(feature[0])
assert isinstance(
feature[815][1], (float, np.float32)
Expand Down