diff --git a/docs/tutorials/preprocessing.ipynb b/docs/tutorials/preprocessing.ipynb index e34c399..d3d27b5 100644 --- a/docs/tutorials/preprocessing.ipynb +++ b/docs/tutorials/preprocessing.ipynb @@ -179,6 +179,13 @@ "Next, we compensate the data using the compensation matrix that is included in the FCS file header. Alternatively, one may provide a custom compensation matrix." ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `compensate` function matches the `var_names` of `adata` with the column names of the spillover matrix to compensate the correct channels. " + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/pytometry/preprocessing/_process_data.py b/pytometry/preprocessing/_process_data.py index a8a7079..2f630de 100644 --- a/pytometry/preprocessing/_process_data.py +++ b/pytometry/preprocessing/_process_data.py @@ -101,21 +101,17 @@ def find_indexes( # rename compute bleedthr to compensate def compensate( adata: AnnData, - var_key=None, - key="signal_type", - comp_matrix=None, - matrix_type="spillover", + comp_matrix: pd.DataFrame = None, + matrix_type: str = "spillover", inplace: bool = True, ) -> Optional[AnnData]: """Computes compensation for data channels. Args: adata (AnnData): AnnData object - var_key (str, optional): key where to check if a feature is an area, - height etc. type of value. Use `var_names` if None. key (str, optional): key where result vector is added to the adata.var. Defaults to 'signal_type'. - comp_matrix (None, optional): a custom compensation matrix. + comp_matrix (pd.DataFrame, optional): a custom compensation matrix. Please note that by default we use the spillover matrix directly for numeric stability. matrix_type (str, optional): whether to use a spillover matrix (default) @@ -132,8 +128,6 @@ def compensate( """ adata = adata if inplace else adata.copy() - key_in = key - # locate compensation matrix if comp_matrix is not None: if matrix_type == "spillover": @@ -157,22 +151,41 @@ def compensate( # Ignore channels 'FSC-H', 'FSC-A', 'SSC-H', 'SSC-A', # 'FSC-Width', 'Time' - if key_in not in adata.var_keys(): - find_indexes(adata, var_key=var_key, data_type="facs") - # select non other indices - indexes = np.invert(adata.var[key_in] == "other") - - # To Do: + # and compensate only the values indicated in the compensation matrix + # Note: # the compensation matrix may have different index names than the adata.X matrix - # add a check and match for the compensation - X_comp = np.linalg.solve(compens, adata.X[:, indexes].T).T - adata.X[:, indexes] = X_comp + ref_col = adata.var.index + idx_in = np.intersect1d(compens.columns, ref_col) + if idx_in is None: + # try the adata.var['channel'] as reference + ref_col = adata.var["channel"] + idx_in = np.intersect1d(compens.columns, ref_col) + if idx_in is None: + raise ValueError( + "Could not match the column names of the compensation matrix" + 'with neither `adata.var.index` nor `adata.var["channel"].' + ) + # match columns of spill mat such that they exactly correspond to adata.var.index + ref_names = ref_col[np.in1d(ref_col, idx_in)] + query_names = compens.columns[np.in1d(compens.columns, idx_in)] + idx_sort = [np.where(query_names == x)[0][0] for x in ref_names] + query_idx = np.in1d(compens.columns, query_names) + ref_idx = np.in1d(ref_col, ref_names) + + # subset compensation matrix to the columns to run the compensation on + compens = compens.iloc[query_idx, query_idx] + # sort compensation matrix by adata.var_names + compens = compens.iloc[idx_sort, idx_sort] + X_comp = np.linalg.solve(compens, adata.X[:, ref_idx].T).T + adata.X[:, ref_idx] = X_comp # check for nan values - nan_val = np.isnan(adata.X[:, indexes]).sum() + nan_val = np.isnan(adata.X[:, ref_idx]).sum() if nan_val > 0: - assert f"{nan_val} NaN values found after compensation. Please adjust " - "compensation matrix." + print( + f"{nan_val} NaN values found after compensation. Please adjust " + "compensation matrix." + ) return None if inplace else adata