diff --git a/src/openeo_processes/cubes.py b/src/openeo_processes/cubes.py index 5b31178a..e4855943 100644 --- a/src/openeo_processes/cubes.py +++ b/src/openeo_processes/cubes.py @@ -2061,9 +2061,11 @@ def predict_random_forest(): class PredictRandomForest: @staticmethod - def exec_xar(data, model, dimension, client = None, input_filepath = None): + def exec_xar(data, dimension, model = None, context = None, client = None, input_filepath = None): if isinstance(model, str): model = load_ml_model(model, input_filepath = input_filepath) + if context is not None: + model = context if dimension in ['time', 't', 'times']: # time dimension must be converted into values dimension = get_time_dimension_from_data(data, dimension) predictor_cols = list(data.dims) @@ -2142,11 +2144,11 @@ def load_ml_model(): class LoadMLModel: @staticmethod - def exec_num(model, input_filepath = 'path'): - date = os.listdir(f'{input_filepath}/jobs/{model}') + def exec_num(id, input_filepath = 'path'): + date = os.listdir(f'{input_filepath}/jobs/{id}') if len(date) > 0: date = date[0] - filepath = f'{input_filepath}/jobs/{model}/{date}/result/out_model.json' + filepath = f'{input_filepath}/jobs/{id}/{date}/result/out_model.json' model_xgb = xgb.Booster() model_xgb.load_model(filepath) return model_xgb