Skip to content
Closed
4 changes: 2 additions & 2 deletions docs/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -82,5 +82,5 @@ Contributors
~~~~~~~~~~~~

:user:`Adrin Jalali <adrinjalali>`, :user:`Merve Noyan <merveenoyan>`,
:user:`Benjamin Bossan <BenjaminBossan>`, :user:`Ayyuce Demirbas
<ayyucedemirbas>`, :user:`Prajjwal Mishra <p-mishra1>`
:user:`Benjamin Bossan <BenjaminBossan>`, :user:`Ayyuce Demirbas <ayyucedemirbas>`,
:user:`Prajjwal Mishra <p-mishra1>`, :user:`Ali Osman Kaya <aliosmankaya>`
19 changes: 17 additions & 2 deletions skops/hub_utils/_hf_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,11 @@ def _get_column_names(data):
# TODO: this is going to fail for Structured Arrays. We can add support for
# them later if we see need for it.
if isinstance(data, np.ndarray):
return [f"x{x}" for x in range(data.shape[1])]
if data.dtype.names:
return [f"x{x}" for x in range(len(data.dtype.names))]
return (
[f"x{x}" for x in range(data.shape[1])] if len(data.shape) > 1 else ["x0"]
)

raise ValueError("The data is not a pandas.DataFrame or a numpy.ndarray.")

Expand Down Expand Up @@ -672,7 +676,18 @@ def get_model_output(repo_id: str, data: Any, token: Optional[str] = None) -> An
inputs = {"data": data.to_dict(orient="list")}
except AttributeError:
# the input is not a pandas DataFrame
inputs = {f"x{i}": data[:, i] for i in range(data.shape[1])}
if data.dtype.names:
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With that condition, we can figure out the data is a structured array or a regular numpy array.

inputs = {
col: [x[0] for x in data[:, idx].tolist()]
if len(data.shape) > 1
else [x[0] for x in data.tolist()]
for idx, col in enumerate(_get_column_names(data))
}
else:
inputs = {
col: data[:, idx].tolist() if len(data.shape) > 1 else data.tolist()
for idx, col in enumerate(_get_column_names(data))
}
inputs = {"data": inputs}

res = InferenceApi(repo_id=repo_id, task=model_info.pipeline_tag, token=token)(
Expand Down
13 changes: 12 additions & 1 deletion skops/hub_utils/tests/test_hf_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,18 @@ def test_get_column_names():
expected_columns = [f"x{x}" for x in range(10)]
assert _get_column_names(X_array) == expected_columns

expected_columns = [f"column{x}" for x in range(10)]
expected_columns = ["x0", "x1"]
X_array = np.array([(1, 2), (3, 4)], dtype=[("foo", "i8"), ("bar", "f4")])
assert _get_column_names(X_array) == expected_columns

expected_columns = ["x0", "x1", "x2"] # Default names
X_array = np.zeros(3, dtype="int8, float32, float64")
assert _get_column_names(X_array) == expected_columns

expected_columns = ["x0"]
X_array = np.zeros(3)
assert _get_column_names(X_array) == expected_columns

X_df = pd.DataFrame(X_array, columns=expected_columns)
assert _get_column_names(X_df) == expected_columns

Expand Down