diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..1909a4a --- /dev/null +++ b/.coveragerc @@ -0,0 +1,5 @@ +[report] +exclude_lines = + def __repr__ + raise AssertionError + raise NotImplementedError diff --git a/camacqplugins/gain/README.md b/camacqplugins/gain/README.md new file mode 100644 index 0000000..902e332 --- /dev/null +++ b/camacqplugins/gain/README.md @@ -0,0 +1,17 @@ +# Gain + +## Usage + +Add configuration for the `gain` plugin, in the `camacq` configuration file. +See the [config_templates](../../config_templates/) directory for example configuration. + +```yaml +gain: + ... +``` + +Then start `camacq`. + +```sh +camacq +``` diff --git a/camacqplugins/gain/__init__.py b/camacqplugins/gain/__init__.py new file mode 100644 index 0000000..961ebb4 --- /dev/null +++ b/camacqplugins/gain/__init__.py @@ -0,0 +1,308 @@ +"""Handle default gain feedback plugin.""" +import logging +import os +from collections import defaultdict, namedtuple +from functools import partial +from itertools import groupby +from pathlib import Path + +import matplotlib +import pandas as pd +import voluptuous as vol +from scipy.optimize import curve_fit + +from camacq.const import CHANNEL_ID, WELL, WELL_NAME +from camacq.event import Event +from camacq.plugins.sample import Channel +from camacq.helper import BASE_ACTION_SCHEMA +from camacq.image import make_proj +from camacq.util import write_csv + +matplotlib.use("AGG") # use noninteractive default backend +# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports +import matplotlib.pyplot as plt # noqa: E402 + +_LOGGER = logging.getLogger(__name__) +BOX = "box" +COUNT = "count" +VALID = "valid" +CONF_CHANNEL = "channel" +CONF_CHANNELS = "channels" +CONF_GAIN = "gain" +CONF_INIT_GAIN = "init_gain" +CONF_SAVE_DIR = "save_dir" +COUNT_CLOSE_TO_ZERO = 2 +GAIN_CALC_EVENT = "gain_calc_event" +SAVED_GAINS = "saved_gains" + +ACTION_CALC_GAIN = "calc_gain" +CALC_GAIN_ACTION_SCHEMA = BASE_ACTION_SCHEMA.extend( + { + vol.Required("well_x"): vol.Coerce(int), + vol.Required("well_y"): vol.Coerce(int), + vol.Required("plate_name"): vol.Coerce(str), + "images": [vol.Coerce(str)], + } +) + +CONFIG_SCHEMA = vol.Schema( + { + vol.Required(CONF_CHANNELS): [ + { + vol.Required(CONF_CHANNEL): vol.Coerce(str), + vol.Required(CONF_INIT_GAIN): [vol.Coerce(int)], + } + ], + # pylint: disable=no-value-for-parameter + vol.Optional(CONF_SAVE_DIR): vol.IsDir(), + } +) + +GAIN = "gain" +Data = namedtuple("Data", [BOX, GAIN, VALID]) # pylint: disable=invalid-name + + +async def setup_module(center, config): + """Set up gain calculation plugin.""" + + async def handle_calc_gain(**kwargs): + """Handle call to calc_gain action.""" + well_x = kwargs.get("well_x") + well_y = kwargs.get("well_y") + plate_name = kwargs.get("plate_name") + paths = kwargs.get("images") # list of paths to calculate gain for + if not paths: + well = center.sample.get_well(plate_name, well_x, well_y) + if not well: + return + images = {path: image.channel_id for path, image in well.images.items()} + else: + images = { + path: image.channel_id + for path, image in center.sample.images.items() + if path in paths + } + projs = await center.add_executor_job(make_proj, images) + await calc_gain(center, config, plate_name, well_x, well_y, projs) + + center.actions.register( + "gain", ACTION_CALC_GAIN, handle_calc_gain, CALC_GAIN_ACTION_SCHEMA + ) + + +async def calc_gain( + center, config, plate_name, well_x, well_y, projs, +): + """Calculate gain values for the well.""" + # pylint: disable=too-many-arguments, too-many-locals + gain_conf = config[CONF_GAIN] + save_dir = gain_conf.get(CONF_SAVE_DIR) or "" + make_plots = bool(save_dir) + plot_dir = Path(save_dir) / "plots" + await center.add_executor_job(ensure_plot_dir, plot_dir) + + init_gain = [ + Channel(channel[CONF_CHANNEL], gain=gain) + for channel in gain_conf[CONF_CHANNELS] + for gain in channel[CONF_INIT_GAIN] + ] + + # This should be a path to a base file name, not to a dir or file. + plot_path = plot_dir / f"U{well_x:02}--V{well_y:02}" + gains = await center.add_executor_job( + partial(_calc_gain, projs, init_gain, plot=make_plots, save_path=plot_path) + ) + _LOGGER.info("Calculated gains: %s", gains) + if SAVED_GAINS not in center.data: + center.data[SAVED_GAINS] = defaultdict(dict) + center.data[SAVED_GAINS].update({WELL_NAME.format(well_x, well_y): gains}) + if make_plots: + await center.add_executor_job( + save_gain, save_dir, center.data[SAVED_GAINS], [WELL] + list(gains) + ) + + for channel_name, gain in gains.items(): + event = GainCalcEvent( + { + "plate_name": plate_name, + "well_x": well_x, + "well_y": well_y, + "channel_name": channel_name, + "gain": gain, + } + ) + await center.bus.notify(event) # await in sequential order + + +def _power_func(inp, alpha, beta): + """Return the value of function of inp, alpha and beta.""" + return alpha * inp ** beta + + +def _check_upward(points): + """Return a function that checks if points move upward.""" + + def wrapped(point): + """Return True if trend is upward. + + The calculation is done for a point with neighbouring points. + """ + idx, item = point + valid = item.valid and item.box <= 600 + prev = next_ = True + if idx > 0: + prev = item.box >= points[idx - 1].box + if idx < len(points) - 1: + next_ = item.box <= points[idx + 1].box + return valid and (prev or next_) + + return wrapped + + +def _create_plot(path, x_data, y_data, coeffs, label): + """Plot and save plot to path.""" + plt.ioff() + plt.clf() + plt.yscale("log") + plt.xscale("log") + plt.plot( + x_data, y_data, "bo", x_data, _power_func(x_data, *coeffs), "g-", label=label + ) + plt.savefig(path) + + +def _calc_gain(projs, init_gain, plot=True, save_path=""): + """Calculate gain values for the well. + + Do the actual math. + """ + # pylint: disable=too-many-locals + box_vs_gain = {} + + for c_id, proj in projs.items(): + channel = init_gain[c_id] + if channel.name not in box_vs_gain: + box_vs_gain[channel.name] = [] + hist_data = pd.DataFrame( + {BOX: list(range(len(proj.histogram[0]))), COUNT: proj.histogram[0]} + ) + # Handle all zero pixels + non_zero_hist_data = hist_data[(hist_data[COUNT] > 0) & (hist_data[BOX] > 0)] + if non_zero_hist_data.empty: + continue + # Find the max box holding pixels + box_max_count = non_zero_hist_data[BOX].iloc[-1] + # Select only histo data where count is > 0 and 255 > box > 0. + # Only use values in interval 10-100 and + # > (max box holding pixels - 175). + roi = hist_data[ + (hist_data[COUNT] > 0) + & (hist_data[BOX] > 0) + & (hist_data[BOX] < 255) + & (hist_data[COUNT] >= 10) + & (hist_data[COUNT] <= 100) + & (hist_data[BOX] > (box_max_count - 175)) + ] + if roi.shape[0] < 3: + continue + x_data = roi[COUNT].astype(float).values + y_data = roi[BOX].astype(float).values + coeffs, _ = curve_fit(_power_func, x_data, y_data, p0=(1000, -1)) + if plot: + _save_path = "{}{}.ome.png".format(save_path, CHANNEL_ID.format(c_id)) + _create_plot( + _save_path, hist_data[COUNT], hist_data[BOX], coeffs, "count-box" + ) + # Find box value where count is close to zero. + # Store that box value and it's corresponding gain value. + # Store boolean saying if second slope coefficient is negative. + box_vs_gain[channel.name].append( + Data._make( + (_power_func(COUNT_CLOSE_TO_ZERO, *coeffs), channel.gain, coeffs[1] < 0) + ) + ) + + gains = {} + for channel, points in box_vs_gain.items(): + # Sort points with ascending gain, to allow grouping. + points = sorted(points, key=lambda item: item.gain) + long_group = [] + for key, group in groupby(enumerate(points), _check_upward(points)): + # Find the group with the most points and use that below. + stored_group = list(group) + if key and len(stored_group) > len(long_group): + long_group = stored_group + + # Curve fit the longest group with power function. + # Plot the points and the fit. + # Return the calculated gains at bin 255, using fit function. + if len(long_group) < 3: + gains[channel] = None + continue + coeffs, _ = curve_fit( + _power_func, + [p[1].box for p in long_group], + [p[1].gain for p in long_group], + p0=(1, 1), + ) + if plot: + _save_path = "{}_{}.png".format(save_path, channel) + _create_plot( + _save_path, + [p.box for p in points], + [p.gain for p in points], + coeffs, + "box-gain", + ) + gains[channel] = round(_power_func(255, *coeffs)) + + return gains + + +def save_gain(save_dir, saved_gains, header): + """Save a csv file with gain values per image channel.""" + path = os.path.normpath(os.path.join(save_dir, "output_gains.csv")) + write_csv(path, saved_gains, header) + + +def ensure_plot_dir(plot_dir): + """Make sure that plot dir exists.""" + if not plot_dir.exists(): + plot_dir.mkdir() + + +class GainCalcEvent(Event): + """An event produced by a sample channel change event.""" + + __slots__ = () + + event_type = GAIN_CALC_EVENT + + @property + def channel_name(self): + """:str: Return the channel name of the event.""" + return self.data.get("channel_name") + + @property + def gain(self): + """:str: Return the channel gain of the event.""" + return self.data.get("gain") + + @property + def plate_name(self): + """:str: Return the name of the plate.""" + return self.data.get("plate_name") + + @property + def well_x(self): + """:int: Return the well x coordinate of the event.""" + return self.data.get("well_x") + + @property + def well_y(self): + """:int: Return the well y coordinate of the event.""" + return self.data.get("well_y") + + def __repr__(self): + """Return the representation.""" + return "<{}: {}>".format(type(self).__name__, self.gain) diff --git a/camacqplugins/production/__init__.py b/camacqplugins/production/__init__.py index d151b75..9bbb997 100644 --- a/camacqplugins/production/__init__.py +++ b/camacqplugins/production/__init__.py @@ -1,8 +1,6 @@ """Provide a plugin for production standard flow.""" import logging -import tempfile from math import ceil -from pathlib import Path import voluptuous as vol @@ -29,7 +27,6 @@ CONF_WELL_LAYOUT = "well_layout" CONF_X_FIELDS = "x_fields" CONF_Y_FIELDS = "y_fields" -CONF_PLOT_SAVE_PATH = "plot_save_path" CONF_SAMPLE_STATE_FILE = "sample_state_file" PLATE_NAME = "00" @@ -98,7 +95,6 @@ def is_sample_state(value): vol.Required(CONF_Y_FIELDS): vol.Coerce(int), }, # pylint: disable=no-value-for-parameter - CONF_PLOT_SAVE_PATH: vol.IsDir(), CONF_SAMPLE_STATE_FILE: vol.All( vol.IsFile(), is_csv, read_csv, is_sample_state ), @@ -143,7 +139,6 @@ def __init__(self, center, conf): well_layout = conf[CONF_WELL_LAYOUT] self.x_fields = well_layout[CONF_X_FIELDS] self.y_fields = well_layout[CONF_Y_FIELDS] - self.plot_save_path = conf.get(CONF_PLOT_SAVE_PATH) self._remove_handle_exp_image = None self.wells_left = set() @@ -228,23 +223,8 @@ async def calc_gain(center, event): return await center.actions.command.stop_imaging() - - if self.plot_save_path is None: - save_path = Path(tempfile.gettempdir()) / event.plate_name - else: - save_path = Path(self.plot_save_path) - if not save_path.exists(): - await center.add_executor_job(save_path.mkdir) - - # This should be a path to a base file name, not to a dir or file. - save_path = save_path / f"{event.well_x}--{event.well_y}" - await center.actions.gain.calc_gain( - plate_name=event.plate_name, - well_x=event.well_x, - well_y=event.well_y, - make_plots=True, - save_path=save_path, + plate_name=event.plate_name, well_x=event.well_x, well_y=event.well_y, ) return self._center.bus.register(IMAGE_EVENT, calc_gain) diff --git a/config_templates/gain.yml b/config_templates/gain.yml new file mode 100644 index 0000000..7e866ab --- /dev/null +++ b/config_templates/gain.yml @@ -0,0 +1,11 @@ +gain: + channels: + - channel: green + init_gain: [450, 495, 540, 585, 630, 675, 720, 765, 810, 855, 900] + - channel: blue + init_gain: [700, 735, 770, 805, 840, 875, 910] + - channel: yellow + init_gain: [700, 735, 770, 805, 840, 875, 910] + - channel: red + init_gain: [600, 635, 670, 705, 740, 775, 810] + #save_dir: "/path/to/gains/dir" diff --git a/config_templates/production.yml b/config_templates/production.yml index 3a28568..2181ae6 100644 --- a/config_templates/production.yml +++ b/config_templates/production.yml @@ -32,7 +32,6 @@ production: well_layout: x_fields: 2 y_fields: 3 - #plot_save_path: "/path/to/gains/dir/00/" gain: channels: diff --git a/requirements.txt b/requirements.txt index 438a380..4c66360 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,4 @@ -camacq==0.4.0 +camacq==0.5.0 +matplotlib==3.1.2 +pandas==0.25.3 +scipy==1.3.3 diff --git a/scripts/setup_tests.py b/scripts/setup_tests.py new file mode 100755 index 0000000..a605999 --- /dev/null +++ b/scripts/setup_tests.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +"""Tool to generate image fixture for tests from real image data.""" +import argparse +import fnmatch +import gzip +import os +import shutil +from pathlib import Path + +import numpy as np +import tifffile + +IMAGE_DATA_DIR = os.path.join(os.path.dirname(__file__), "../tests/fixtures/image_data") + + +def _find_files(root_dir, search): + """Search for files in root directory.""" + matches = [] + for root, _, filenames in os.walk(os.path.normpath(root_dir)): + for filename in fnmatch.filter(filenames, search): + matches.append(os.path.join(root, filename)) + return matches + + +def pack_image_fixture(root_dir=None): + """Gunzip tif images for image tests.""" + if root_dir is None: + root_dir = IMAGE_DATA_DIR + matches = _find_files(root_dir, "*.tif") + print("Gzipping the images, this will take some time...") + for path in matches: + gz_path = "{}.gz".format(path) + with open(path, "rb") as f_in: + with gzip.open(gz_path, "wb") as f_out: + shutil.copyfileobj(f_in, f_out) + os.remove(path) + + +def unpack_image_fixture(root_dir=None): + """Unzip gunzipped tif images for image tests.""" + if root_dir is None: + root_dir = IMAGE_DATA_DIR + matches = _find_files(root_dir, "*.gz") + for gz_path in matches: + path, _ = os.path.splitext(gz_path) + with gzip.open(gz_path, "rb") as f_in: + with open(path, "wb") as f_out: + shutil.copyfileobj(f_in, f_out) + + +def read_image_data(root_dir=None): + """Return a list of dicts with path and image numpy array data.""" + if root_dir is None: + root_dir = IMAGE_DATA_DIR + matches = _find_files(root_dir, "*.tif") + image_data = [] + for path in matches: + try: + data = tifffile.imread(path, key=0) + except OSError as exc: + print("Failed reading image:", exc) + raise + else: + image_data.append({"path": path, "data": data}) + + return image_data + + +def save_images_to_npz(path): + """Save image data as compressed npz.""" + path = Path(path).resolve() + image_data = read_image_data() + np.savez_compressed(path, **{data["path"]: data["data"] for data in image_data}) + + +def get_arguments(args=None): + """Get parsed arguments.""" + parser = argparse.ArgumentParser(description="Unpack or pack fixture files.") + parser.add_argument("--pack", action="store_true", help="Pack fixture files.") + parser.add_argument("--npz", help="Save image fixture data in a npz file.") + args = parser.parse_args(args=args) + + return args + + +def main(args=None): + """Pack or unpack the images for test fixtures.""" + args = get_arguments(args=args) + + if args.npz: + save_images_to_npz(args.npz) + return + if args.pack: + pack_image_fixture() + else: + unpack_image_fixture() + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index 8aada30..86e3518 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ VERSION = (PROJECT_DIR / "camacqplugins" / "VERSION").read_text().strip() README_FILE = PROJECT_DIR / "README.md" LONG_DESCR = README_FILE.read_text(encoding="utf-8") -REQUIRES = ["camacq"] +REQUIRES = ["camacq>=0.5.0", "matplotlib", "pandas", "scipy"] setuptools.setup( @@ -23,7 +23,12 @@ packages=setuptools.find_packages(), python_requires=">=3.6", install_requires=REQUIRES, - entry_points={"camacq.plugins": ["production = camacqplugins.production",],}, + entry_points={ + "camacq.plugins": [ + "gain = camacqplugins.gain", + "production = camacqplugins.production", + ], + }, classifiers=[ "Development Status :: 2 - Pre-Alpha", "Programming Language :: Python", diff --git a/tests/common.py b/tests/common.py new file mode 100644 index 0000000..d91c535 --- /dev/null +++ b/tests/common.py @@ -0,0 +1,10 @@ +"""Provide common test utils.""" +from pathlib import Path + +IMAGE_DATA_DIR = (Path(__file__).parent / "fixtures/image_data").resolve() +WELL_NAME = "U01--V00" +FULL_WELL_NAME = f"chamber--{WELL_NAME}" +WELL_PATH = Path(IMAGE_DATA_DIR) / "slide" / FULL_WELL_NAME +IMAGE_PATH = ( + Path(WELL_PATH) / "field--X00--Y00/image--U01--V00--E02--X00--Y00--Z00--C00.ome.tif" +) diff --git a/tests/fixtures/image_data/image_data.npz b/tests/fixtures/image_data/image_data.npz new file mode 100644 index 0000000..ef44c55 Binary files /dev/null and b/tests/fixtures/image_data/image_data.npz differ diff --git a/tests/gain/test_gain.py b/tests/gain/test_gain.py new file mode 100644 index 0000000..3933d66 --- /dev/null +++ b/tests/gain/test_gain.py @@ -0,0 +1,97 @@ +"""Test gain calculation.""" +from unittest.mock import patch, PropertyMock + +import numpy as np +import pytest + +from camacq import plugins +from camacq.plugins.leica import LeicaImageEvent +from camacqplugins.gain import GAIN_CALC_EVENT +from tests.common import IMAGE_DATA_DIR + +# All test coroutines will be treated as marked. +pytestmark = pytest.mark.asyncio # pylint: disable=invalid-name + +PLATE_NAME = "slide" +WELL_X, WELL_Y = 1, 0 + + +@pytest.fixture(name="load_image") +def load_image_fixture(): + """Patch load image and metadata.""" + with patch( + "camacq.image.ImageData._load_image_data", autospec=True + ) as load_image, patch( + "camacq.image.ImageData.metadata", new_callable=PropertyMock + ) as mock_metadata: + mock_metadata.return_value = "" + yield load_image + + +async def test_gain(center, load_image): + """Run gain calculation test.""" + config = { + "gain": { + "channels": [ + { + "channel": "green", + "init_gain": [ + 450, + 495, + 540, + 585, + 630, + 675, + 720, + 765, + 810, + 855, + 900, + ], + }, + {"channel": "blue", "init_gain": [400, 435, 470, 505, 540, 575, 610],}, + { + "channel": "yellow", + "init_gain": [550, 585, 620, 655, 690, 725, 760], + }, + {"channel": "red", "init_gain": [525, 560, 595, 630, 665, 700, 735],}, + ], + } + } + await plugins.setup_module(center, config) + image_fixture = IMAGE_DATA_DIR / "image_data.npz" + image_data = await center.add_executor_job(np.load, image_fixture) + + def mock_load_image(image): + """Mock load image.""" + data = image_data[image.path] + image._data = data # pylint: disable=protected-access + + load_image.side_effect = mock_load_image + + events = [LeicaImageEvent({"path": path}) for path in image_data] + + for event in events: + await center.bus.notify(event) + + calculated = {} + + async def handle_gain_event(center, event): + """Handle gain event.""" + if ( + event.plate_name != PLATE_NAME + or event.well_x != WELL_X + or event.well_y != WELL_Y + ): + return + calculated[event.channel_name] = event.gain + + center.bus.register(GAIN_CALC_EVENT, handle_gain_event) + + images = [event.path for event in events] + await center.actions.gain.calc_gain( + plate_name=PLATE_NAME, well_x=WELL_X, well_y=WELL_Y, images=images + ) + + solution = {"blue": 480, "green": 740, "red": 745, "yellow": 805} + assert calculated == pytest.approx(solution, abs=10) diff --git a/tests/production/test_production.py b/tests/production/test_production.py index 99e670d..061c0bd 100644 --- a/tests/production/test_production.py +++ b/tests/production/test_production.py @@ -8,7 +8,7 @@ from camacq import plugins from camacq.plugins.api import ImageEvent -from camacq.plugins.gain import GainCalcEvent +from camacqplugins.gain import GainCalcEvent # All test coroutines will be treated as marked. pytestmark = pytest.mark.asyncio # pylint: disable=invalid-name @@ -70,17 +70,13 @@ def job_id(self): return self.data.get("job_id") -async def test_image_events(center, tmp_path): +async def test_image_events(center): """Test image events.""" config = YAML(typ="safe").load(CONFIG) plate_name = "00" well_x = 0 well_y = 0 - save_path = tmp_path / plate_name - await center.add_executor_job(save_path.mkdir) - config["production"]["plot_save_path"] = save_path await plugins.setup_module(center, config) - save_path = save_path / f"{well_x}--{well_y}" calc_gain = CoroutineMock() gains = { "green": 800, @@ -130,12 +126,7 @@ async def fire_gain_event(**kwargs): assert calc_gain.call_count == 1 assert calc_gain.call_args == call( - action_id="calc_gain", - plate_name=plate_name, - well_x=well_x, - well_y=well_y, - make_plots=True, - save_path=save_path, + action_id="calc_gain", plate_name=plate_name, well_x=well_x, well_y=well_y, ) for channel_name, gain in gains.items(): channel = center.sample.get_channel(plate_name, well_x, well_y, channel_name)