diff --git a/docs/api-reference/index.md b/docs/api-reference/index.md index 5458e71..38d33a4 100644 --- a/docs/api-reference/index.md +++ b/docs/api-reference/index.md @@ -12,15 +12,17 @@ Chopper ChopperReading + Component ComponentReading Dashboard Detector DetectorReading + InelasticSample Model ReadingField Result Source - SourceParameters + SourceReading ``` ## Top-level functions @@ -39,5 +41,6 @@ :template: module-template.rst :recursive: + facilities utils ``` diff --git a/docs/components.ipynb b/docs/components.ipynb index 9e5a463..7a2185d 100644 --- a/docs/components.ipynb +++ b/docs/components.ipynb @@ -18,7 +18,10 @@ "metadata": {}, "outputs": [], "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", "import scipp as sc\n", + "import plopp as pp\n", "import tof\n", "\n", "meter = sc.Unit('m')\n", @@ -412,7 +415,183 @@ "id": "33", "metadata": {}, "source": [ - "## Loading from a JSON file\n", + "## Inelastic sample\n", + "\n", + "Placing an `InelasticSample` in the instrument will change the energy of the incoming neutrons by a $\\Delta E$ defined by a probability distribution function.\n", + "It defines the likeliness of what energy-shift will be applied to each neutron.\n", + "\n", + "To give an equal chance for a random $\\Delta E$ between -0.2 and 0.2 meV, we can create a flat distribution (all ones):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "34", + "metadata": {}, + "outputs": [], + "source": [ + "sample = tof.InelasticSample(\n", + " distance=28.0 * meter,\n", + " name=\"sample\",\n", + " delta_e=sc.DataArray(\n", + " data=sc.ones(sizes={'e': 100}),\n", + " coords={'e': sc.linspace('e', -0.2, 0.2, 100, unit='meV')},\n", + " ),\n", + ")\n", + "sample" + ] + }, + { + "cell_type": "markdown", + "id": "35", + "metadata": {}, + "source": [ + "We then make a single fast-rotating chopper with one small opening,\n", + "to select a narrow wavelength range at every rotation.\n", + "\n", + "We also add a monitor before the sample, and a detector after the sample so we can follow the changes in energies/wavelengths." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "36", + "metadata": {}, + "outputs": [], + "source": [ + "choppers = [\n", + " tof.Chopper(\n", + " frequency=70.0 * Hz,\n", + " open=sc.array(dims=['cutout'], values=[0.0], unit='deg'),\n", + " close=sc.array(dims=['cutout'], values=[1.0], unit='deg'),\n", + " phase=0.0 * deg,\n", + " distance=20.0 * meter,\n", + " name=\"fastchopper\",\n", + " ),\n", + "]\n", + "\n", + "detectors = [\n", + " tof.Detector(distance=26.0 * meter, name='monitor'),\n", + " tof.Detector(distance=32.0 * meter, name='detector'),\n", + "]\n", + "\n", + "source = tof.Source(facility='ess', neutrons=5_000_000)\n", + "\n", + "model = tof.Model(source=source, components=choppers + detectors + [sample])\n", + "res = model.run()\n", + "\n", + "\n", + "fig, ax = plt.subplots(1, 2, figsize=(12, 4.5))\n", + "\n", + "dw = sc.scalar(0.1, unit='angstrom')\n", + "pp.plot(\n", + " {\n", + " 'monitor': res['monitor'].data.hist(wavelength=dw),\n", + " 'detector': res['detector'].data.hist(wavelength=dw),\n", + " },\n", + " title=\"With inelastic sample\",\n", + " xmin=4,\n", + " xmax=20,\n", + " ymin=-20,\n", + " ymax=400,\n", + " ax=ax[1],\n", + ")\n", + "\n", + "res.plot(visible_rays=10000, ax=ax[0])" + ] + }, + { + "cell_type": "markdown", + "id": "37", + "metadata": {}, + "source": [ + "### Non-uniform energy-transfer distributions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38", + "metadata": {}, + "outputs": [], + "source": [ + "# Sample 1: double-peak at min and max\n", + "delta_e = sc.DataArray(\n", + " data=sc.zeros(sizes={'e': 100}),\n", + " coords={'e': sc.linspace('e', -0.2, 0.2, 100, unit='meV')},\n", + ")\n", + "delta_e.values[[0, -1]] = 1.0\n", + "sample1 = tof.InelasticSample(\n", + " distance=28.0 * meter,\n", + " name=\"sample\",\n", + " delta_e=delta_e,\n", + ")\n", + "\n", + "# Sample 2: normal distribution\n", + "x = sc.linspace('e', -0.2, 0.2, 100, unit='meV')\n", + "sig = sc.scalar(0.03, unit='meV')\n", + "y = 1.0 / (np.sqrt(2.0 * np.pi) * sig) * sc.exp(-((x / sig) ** 2) / 2)\n", + "y.unit = \"\"\n", + "\n", + "sample2 = tof.InelasticSample(\n", + " distance=28.0 * meter,\n", + " name=\"sample\",\n", + " delta_e=sc.DataArray(data=y, coords={'e': x}),\n", + ")\n", + "\n", + "sample1.plot() + sample2.plot()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "39", + "metadata": {}, + "outputs": [], + "source": [ + "model1 = tof.Model(source=source, components=choppers + detectors + [sample1])\n", + "model2 = tof.Model(source=source, components=choppers + detectors + [sample2])\n", + "\n", + "res1 = model1.run()\n", + "res2 = model2.run()\n", + "\n", + "fig, ax = plt.subplots(2, 2, figsize=(12, 9))\n", + "\n", + "res1.plot(ax=ax[0, 0], title=\"Sample 1\")\n", + "pp.plot(\n", + " {\n", + " 'monitor': res1['monitor'].data.hist(wavelength=dw),\n", + " 'detector': res1['detector'].data.hist(wavelength=dw),\n", + " },\n", + " title=\"Sample 1\",\n", + " xmin=4,\n", + " xmax=20,\n", + " ymin=-20,\n", + " ymax=400,\n", + " ax=ax[0, 1],\n", + ")\n", + "\n", + "res2.plot(ax=ax[1, 0], title=\"Sample 2\")\n", + "_ = pp.plot(\n", + " {\n", + " 'monitor': res2['monitor'].data.hist(wavelength=dw),\n", + " 'detector': res2['detector'].data.hist(wavelength=dw),\n", + " },\n", + " title=\"Sample 2\",\n", + " xmin=4,\n", + " xmax=20,\n", + " ymin=-20,\n", + " ymax=400,\n", + " ax=ax[1, 1],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "40", + "metadata": {}, + "source": [ + "## Loading components from a JSON file\n", "\n", "It is also possible to load components from a JSON file,\n", "which can be very useful to quickly load a pre-configured instrument.\n", @@ -423,19 +602,14 @@ { "cell_type": "code", "execution_count": null, - "id": "34", + "id": "41", "metadata": {}, "outputs": [], "source": [ "import json\n", "\n", "params = {\n", - " \"source\": {\n", - " \"type\": \"source\",\n", - " \"facility\": \"ess\",\n", - " \"neutrons\": 1e6,\n", - " \"pulses\": 1\n", - " },\n", + " \"source\": {\"type\": \"source\", \"facility\": \"ess\", \"neutrons\": 1e6, \"pulses\": 1},\n", " \"chopper1\": {\n", " \"type\": \"chopper\",\n", " \"frequency\": {\"value\": 56.0, \"unit\": \"Hz\"},\n", @@ -470,7 +644,7 @@ { "cell_type": "code", "execution_count": null, - "id": "35", + "id": "42", "metadata": {}, "outputs": [], "source": [ @@ -479,7 +653,7 @@ }, { "cell_type": "markdown", - "id": "36", + "id": "43", "metadata": {}, "source": [ "We now use the `tof.Model.from_json()` method to load our instrument:" @@ -488,7 +662,7 @@ { "cell_type": "code", "execution_count": null, - "id": "37", + "id": "44", "metadata": {}, "outputs": [], "source": [ @@ -498,7 +672,7 @@ }, { "cell_type": "markdown", - "id": "38", + "id": "45", "metadata": {}, "source": [ "We can see that all components have been read in correctly.\n", @@ -508,7 +682,7 @@ { "cell_type": "code", "execution_count": null, - "id": "39", + "id": "46", "metadata": {}, "outputs": [], "source": [ @@ -517,7 +691,7 @@ }, { "cell_type": "markdown", - "id": "40", + "id": "47", "metadata": {}, "source": [ "### Modifying the source\n", @@ -531,7 +705,7 @@ { "cell_type": "code", "execution_count": null, - "id": "41", + "id": "48", "metadata": {}, "outputs": [], "source": [ @@ -542,7 +716,7 @@ }, { "cell_type": "markdown", - "id": "42", + "id": "49", "metadata": {}, "source": [ "## Saving to JSON\n", @@ -553,7 +727,7 @@ { "cell_type": "code", "execution_count": null, - "id": "43", + "id": "50", "metadata": {}, "outputs": [], "source": [ diff --git a/src/tof/__init__.py b/src/tof/__init__.py index d3a75af..93f8699 100644 --- a/src/tof/__init__.py +++ b/src/tof/__init__.py @@ -17,12 +17,13 @@ submodules=['facilities'], submod_attrs={ 'chopper': ['AntiClockwise', 'Chopper', 'ChopperReading', 'Clockwise'], + 'component': ['Component', 'ComponentReading', 'ReadingField'], 'dashboard': ['Dashboard'], 'detector': ['Detector', 'DetectorReading'], + 'inelastic': ['InelasticSample', 'InelasticSampleReading'], 'model': ['Model'], - 'reading': ['ComponentReading', 'ReadingField'], 'result': ['Result'], - 'source': ['Source', 'SourceParameters'], + 'source': ['Source', 'SourceReading'], }, ) diff --git a/src/tof/chopper.py b/src/tof/chopper.py index 939a4c4..9f367cc 100644 --- a/src/tof/chopper.py +++ b/src/tof/chopper.py @@ -6,10 +6,11 @@ from enum import Enum from typing import TYPE_CHECKING +import numpy as np import scipp as sc -from .reading import ComponentReading -from .utils import two_pi, var_to_dict +from .component import Component, ComponentReading +from .utils import two_pi, var_from_dict, var_to_dict if TYPE_CHECKING: try: @@ -27,7 +28,76 @@ class Direction(Enum): AntiClockwise = Direction.ANTICLOCKWISE -class Chopper: +def _array_or_none(container: dict, key: str) -> sc.Variable | None: + return var_from_dict(container[key], dim="cutout") if key in container else None + + +@dataclass(frozen=True) +class ChopperReading(ComponentReading): + """ + Read-only container for the neutrons that reach the chopper. + """ + + distance: sc.Variable + name: str + frequency: sc.Variable + open: sc.Variable + close: sc.Variable + phase: sc.Variable + open_times: sc.Variable + close_times: sc.Variable + data: sc.DataArray + + @property + def kind(self) -> str: + return "chopper" + + def _repr_stats(self) -> str: + return ( + f"visible={int(self.data.sum().value)}, " + f"blocked={int(self.data.masks['blocked_by_me'].sum().value)}" + ) + + def __repr__(self) -> str: + return f"""ChopperReading: '{self.name}' + distance: {self.distance:c} + frequency: {self.frequency:c} + phase: {self.phase:c} + cutouts: {len(self.open)} + neutrons: {self._repr_stats()} +""" + + def __str__(self) -> str: + return self.__repr__() + + def __getitem__(self, val: int | slice | tuple[str, int | slice]) -> ChopperReading: + if isinstance(val, int): + val = ('pulse', val) + return replace(self, data=self.data[val]) + + def plot_on_time_distance_diagram(self, ax, tmax) -> None: + dx = 0.05 * tmax + x0 = self.open_times.values + x1 = self.close_times.values + x = np.empty(3 * x0.size, dtype=x0.dtype) + x[0::3] = x0 + x[1::3] = 0.5 * (x0 + x1) + x[2::3] = x1 + x = np.concatenate( + ([[0]] if x[0] > 0 else [x[0:1]]) + + [x] + + ([[tmax + dx]] if x[-1] < tmax else []) + ) + y = np.full_like(x, self.distance.value) + y[2::3] = None + inds = np.argsort(x) + ax.plot(x[inds], y[inds], color="k") + ax.text( + tmax, self.distance.value, self.name, ha="right", va="bottom", color="k" + ) + + +class Chopper(Component): """ A chopper is a rotating device with cutouts that blocks the beam at certain times. @@ -106,6 +176,7 @@ def __init__( self.distance = distance.to(dtype=float, copy=False) self.phase = phase.to(dtype=float, copy=False) self.name = name + self.kind = "chopper" super().__init__() @property @@ -178,6 +249,18 @@ def __repr__(self) -> str: f"direction={self.direction.name}, cutouts={len(self.open)})" ) + def __eq__(self, other: object) -> bool: + if not isinstance(other, Chopper): + return NotImplemented + if self.name != other.name: + return False + if self.direction != other.direction: + return False + return all( + sc.identical(getattr(self, field), getattr(other, field)) + for field in ('frequency', 'distance', 'phase', 'open', 'close') + ) + def as_dict(self) -> dict: """ Return the chopper as a dictionary. @@ -192,6 +275,30 @@ def as_dict(self) -> dict: 'direction': self.direction, } + @classmethod + def from_json(cls, name: str, params: dict) -> Chopper: + direction = params["direction"].lower() + if direction == "clockwise": + _dir = Clockwise + elif any(x in direction for x in ("anti", "counter")): + _dir = AntiClockwise + else: + raise ValueError( + f"Chopper direction must be 'clockwise' or 'anti-clockwise', got " + f"'{params['direction']}' for component {name}." + ) + return cls( + frequency=var_from_dict(params["frequency"]), + direction=_dir, + open=_array_or_none(params, "open"), + close=_array_or_none(params, "close"), + centers=_array_or_none(params, "centers"), + widths=_array_or_none(params, "widths"), + phase=var_from_dict(params["phase"]), + distance=var_from_dict(params["distance"]), + name=name, + ) + def as_json(self) -> dict: """ Return the chopper as a JSON-serializable dictionary. @@ -212,18 +319,6 @@ def as_json(self) -> dict: ) return out - def __eq__(self, other: object) -> bool: - if not isinstance(other, Chopper): - return NotImplemented - if self.name != other.name: - return False - if self.direction != other.direction: - return False - return all( - sc.identical(getattr(self, field), getattr(other, field)) - for field in ('frequency', 'distance', 'phase', 'open', 'close') - ) - @classmethod def from_diskchopper( cls, disk_chopper: DiskChopper, name: str | None = None @@ -273,6 +368,27 @@ def from_diskchopper( name=name, ) + def to_diskchopper(self) -> DiskChopper: + """ + Export the chopper as a scippneutron DiskChopper. + """ + from scippneutron.chopper import DiskChopper + + frequency = ( + self.frequency if self.direction == AntiClockwise else -self.frequency + ) + phase = self.phase if self.direction == AntiClockwise else -self.phase + return DiskChopper( + frequency=frequency, + beam_position=sc.scalar(0.0, unit='deg'), + slit_begin=self.open, + slit_end=self.close, + phase=phase, + axle_position=sc.vector( + value=[0.0, 0.0, self.distance.value], unit=self.distance.unit + ), + ) + @classmethod def from_nexus(cls, nexus_chopper, name: str | None = None) -> Chopper: """ @@ -321,63 +437,37 @@ def from_nexus(cls, nexus_chopper, name: str | None = None) -> Chopper: name=name, ) - def to_diskchopper(self) -> DiskChopper: + def as_readonly( + self, neutrons: sc.DataArray, time_limit: sc.Variable + ) -> ChopperReading: """ - Export the chopper as a scippneutron DiskChopper. + Create a ChopperReading from the given neutrons that have been processed by this + chopper. """ - from scippneutron.chopper import DiskChopper - - frequency = ( - self.frequency if self.direction == AntiClockwise else -self.frequency - ) - phase = self.phase if self.direction == AntiClockwise else -self.phase - return DiskChopper( - frequency=frequency, - beam_position=sc.scalar(0.0, unit='deg'), - slit_begin=self.open, - slit_end=self.close, - phase=phase, - axle_position=sc.vector( - value=[0.0, 0.0, self.distance.value], unit=self.distance.unit - ), - ) - - -@dataclass(frozen=True) -class ChopperReading(ComponentReading): - """ - Read-only container for the neutrons that reach the chopper. - """ - - distance: sc.Variable - name: str - frequency: sc.Variable - open: sc.Variable - close: sc.Variable - phase: sc.Variable - open_times: sc.Variable - close_times: sc.Variable - data: sc.DataArray - - def _repr_stats(self) -> str: - return ( - f"visible={int(self.data.sum().value)}, " - f"blocked={int(self.data.masks['blocked_by_me'].sum().value)}" + to, tc = self.open_close_times(time_limit=time_limit) + return ChopperReading( + distance=self.distance, + name=self.name, + frequency=self.frequency, + phase=self.phase, + open=self.open, + close=self.close, + open_times=to, + close_times=tc, + data=neutrons, ) - def __repr__(self) -> str: - return f"""ChopperReading: '{self.name}' - distance: {self.distance:c} - frequency: {self.frequency:c} - phase: {self.phase:c} - cutouts: {len(self.open)} - neutrons: {self._repr_stats()} -""" - - def __str__(self) -> str: - return self.__repr__() - - def __getitem__(self, val: int | slice | tuple[str, int | slice]) -> ChopperReading: - if isinstance(val, int): - val = ('pulse', val) - return replace(self, data=self.data[val]) + def apply( + self, neutrons: sc.DataArray, time_limit: sc.Variable + ) -> tuple[sc.DataArray, ChopperReading]: + """ + Apply the effect of the chopper to the given neutrons. + """ + # Apply the chopper's open/close times to the data + m = sc.zeros(sizes=neutrons.sizes, unit=None, dtype=bool) + to, tc = self.open_close_times(time_limit=time_limit) + for i in range(len(to)): + m |= (neutrons.coords['toa'] > to[i]) & (neutrons.coords['toa'] < tc[i]) + neutrons.masks['blocked_by_me'] = (~m) & (~neutrons.masks['blocked_by_others']) + + return neutrons, self.as_readonly(neutrons, time_limit=time_limit) diff --git a/src/tof/reading.py b/src/tof/component.py similarity index 85% rename from src/tof/reading.py rename to src/tof/component.py index 62887de..628ff85 100644 --- a/src/tof/reading.py +++ b/src/tof/component.py @@ -1,8 +1,9 @@ # SPDX-License-Identifier: BSD-3-Clause -# Copyright (c) 2025 Scipp contributors (https://github.com/scipp) +# Copyright (c) 2026 Scipp contributors (https://github.com/scipp) from __future__ import annotations +from abc import abstractmethod from dataclasses import dataclass import plopp as pp @@ -66,11 +67,7 @@ def __getitem__(self, val: int | slice | tuple[str, int | slice]) -> ReadingFiel def _make_reading_field(da: sc.DataArray, dim: str) -> ReadingField: return ReadingField( - data=sc.DataArray( - data=da.data, - coords={dim: da.coords[dim]}, - masks=da.masks, - ), + data=sc.DataArray(data=da.data, coords={dim: da.coords[dim]}, masks=da.masks), dim=dim, ) @@ -133,3 +130,25 @@ def plot(self, bins: int = 300) -> Plot: Number of bins to use for histogramming the neutrons. """ return self.toa.plot(bins=bins) + self.wavelength.plot(bins=bins) + + +class Component: + kind: str + + @abstractmethod + def apply( + self, neutrons: sc.DataArray, time_limit: sc.Variable + ) -> tuple[sc.DataArray, ComponentReading]: + """ + Apply the component to the given neutrons. + + Parameters + ---------- + neutrons: + The neutrons to which the component will be applied. + + Returns + ------- + The modified neutrons. + """ + raise NotImplementedError diff --git a/src/tof/dashboard.py b/src/tof/dashboard.py index d9fbae2..28e4f21 100644 --- a/src/tof/dashboard.py +++ b/src/tof/dashboard.py @@ -259,24 +259,25 @@ def populate_from_instrument(self, change): for det in self.detectors_container.children: self.remove_detector(None, uid=det._uid) params = INSTRUMENT_LIBRARY[change["new"]] - for ch in params["choppers"]: - self.add_chopper(None) - chop = self.choppers_container.children[-1] - chop.frequency_widget.value = ch.frequency.to(unit='Hz').value - chop.open_widget.value = ", ".join( - str(x) for x in ch.open.to(unit='deg').values - ) - chop.close_widget.value = ", ".join( - str(x) for x in ch.close.to(unit='deg').values - ) - chop.phase_widget.value = ch.phase.to(unit='deg').value - chop.distance_widget.value = ch.distance.to(unit='m').value - chop.name_widget.value = ch.name - for d in params["detectors"]: - self.add_detector(None) - det = self.detectors_container.children[-1] - det.distance_widget.value = d.distance.to(unit='m').value - det.name_widget.value = d.name + for comp in params["components"]: + if comp.kind == "chopper": + self.add_chopper(None) + chop = self.choppers_container.children[-1] + chop.frequency_widget.value = comp.frequency.to(unit='Hz').value + chop.open_widget.value = ", ".join( + str(x) for x in comp.open.to(unit='deg').values + ) + chop.close_widget.value = ", ".join( + str(x) for x in comp.close.to(unit='deg').values + ) + chop.phase_widget.value = comp.phase.to(unit='deg').value + chop.distance_widget.value = comp.distance.to(unit='m').value + chop.name_widget.value = comp.name + elif comp.kind == "detector": + self.add_detector(None) + det = self.detectors_container.children[-1] + det.distance_widget.value = comp.distance.to(unit='m').value + det.name_widget.value = comp.name self.run(None) self.continuous_update.value = cont_update_value diff --git a/src/tof/detector.py b/src/tof/detector.py index 9dc8075..57027a8 100644 --- a/src/tof/detector.py +++ b/src/tof/detector.py @@ -6,11 +6,49 @@ import scipp as sc -from .reading import ComponentReading -from .utils import var_to_dict +from .component import Component, ComponentReading +from .utils import var_from_dict, var_to_dict -class Detector: +@dataclass(frozen=True) +class DetectorReading(ComponentReading): + """ + Read-only container for the neutrons that reach the detector. + """ + + distance: sc.Variable + name: str + data: sc.DataArray + + @property + def kind(self) -> str: + return "detector" + + def _repr_stats(self) -> str: + return f"visible={int(self.data.sum().value)}" + + def __repr__(self) -> str: + return f"""DetectorReading: '{self.name}' + distance: {self.distance:c} + neutrons: {self._repr_stats()} +""" + + def __str__(self) -> str: + return self.__repr__() + + def __getitem__( + self, val: int | slice | tuple[str, int | slice] + ) -> DetectorReading: + if isinstance(val, int): + val = ('pulse', val) + return replace(self, data=self.data[val]) + + def plot_on_time_distance_diagram(self, ax, tmax) -> None: + ax.plot([0, tmax], [self.distance.value] * 2, color="gray", lw=3) + ax.text(0, self.distance.value, self.name, ha="left", va="bottom", color="gray") + + +class Detector(Component): """ A detector component does not block any neutrons, it sees all neutrons passing through it. @@ -26,16 +64,32 @@ class Detector: def __init__(self, distance: sc.Variable, name: str): self.distance = distance.to(dtype=float, copy=False) self.name = name + self.kind = "detector" def __repr__(self) -> str: return f"Detector(name={self.name}, distance={self.distance:c})" + def __eq__(self, other: object) -> bool: + if not isinstance(other, Detector): + return NotImplemented + return self.name == other.name and sc.identical(self.distance, other.distance) + def as_dict(self) -> dict: """ Return the detector as a dictionary. """ return {'distance': self.distance, 'name': self.name} + @classmethod + def from_json(cls, name: str, params: dict) -> Detector: + """ + Create a detector from a JSON-serializable dictionary. + """ + return cls( + distance=var_from_dict(params["distance"]), + name=name, + ) + def as_json(self) -> dict: """ Return the detector as a JSON-serializable dictionary. @@ -48,37 +102,22 @@ def as_json(self) -> dict: 'name': self.name, } - def __eq__(self, other: object) -> bool: - if not isinstance(other, Detector): - return NotImplemented - return self.name == other.name and sc.identical(self.distance, other.distance) - + def as_readonly(self, neutrons: sc.DataArray) -> DetectorReading: + return DetectorReading(distance=self.distance, name=self.name, data=neutrons) -@dataclass(frozen=True) -class DetectorReading(ComponentReading): - """ - Read-only container for the neutrons that reach the detector. - """ - - distance: sc.Variable - name: str - data: sc.DataArray - - def _repr_stats(self) -> str: - return f"visible={int(self.data.sum().value)}" - - def __repr__(self) -> str: - return f"""DetectorReading: '{self.name}' - distance: {self.distance:c} - neutrons: {self._repr_stats()} -""" - - def __str__(self) -> str: - return self.__repr__() - - def __getitem__( - self, val: int | slice | tuple[str, int | slice] - ) -> DetectorReading: - if isinstance(val, int): - val = ('pulse', val) - return replace(self, data=self.data[val]) + def apply( + self, neutrons: sc.DataArray, time_limit: sc.Variable + ) -> tuple[sc.DataArray, DetectorReading]: + """ + Apply the detector to the given neutrons. + A detector does not modify the neutrons, it simply records them without + blocking any. + + Parameters + ---------- + neutrons: + The neutrons to which the detector will be applied. + time_limit: + The time limit for the neutrons to be considered as reaching the detector. + """ + return neutrons, self.as_readonly(neutrons) diff --git a/src/tof/inelastic.py b/src/tof/inelastic.py new file mode 100644 index 0000000..3e412fd --- /dev/null +++ b/src/tof/inelastic.py @@ -0,0 +1,182 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2026 Scipp contributors (https://github.com/scipp) +from __future__ import annotations + +from dataclasses import dataclass, replace + +import numpy as np +import plopp as pp +import scipp as sc + +from .component import Component, ComponentReading +from .utils import ( + energy_to_wavelength, + var_from_dict, + var_to_dict, + wavelength_to_energy, + wavelength_to_speed, +) + + +@dataclass(frozen=True) +class InelasticSampleReading(ComponentReading): + """ + Read-only container for the neutrons that reach the inelastic sample. + """ + + distance: sc.Variable + name: str + data: sc.DataArray + + @property + def kind(self) -> str: + return "inelastic_sample" + + def _repr_stats(self) -> str: + return f"visible={int(self.data.sum().value)}" + + def __repr__(self) -> str: + return f"""InelasticSampleReading: '{self.name}' + distance: {self.distance:c} + neutrons: {self._repr_stats()} +""" + + def __str__(self) -> str: + return self.__repr__() + + def __getitem__( + self, val: int | slice | tuple[str, int | slice] + ) -> InelasticSampleReading: + if isinstance(val, int): + val = ('pulse', val) + return replace(self, data=self.data[val]) + + def plot_on_time_distance_diagram(self, ax, tmax) -> None: + ax.plot([0, tmax], [self.distance.value] * 2, color="tab:brown", lw=4) + ax.text( + 0, self.distance.value, self.name, ha="left", va="bottom", color="tab:brown" + ) + + +class InelasticSample(Component): + """ + An inelastic sample component changes the energy of the neutrons that pass through + it, but does not block any. + + Parameters + ---------- + distance: + The distance from the source to the inelastic sample. + name: + The name of the inelastic sample. + delta_e: + The change in energy of the neutrons when they pass through the inelastic + sample. The values of the array represent the probability of a neutron to have + its energy changed by the corresponding amount in the coordinates. The + coordinate values should be in energy units, and the array should be 1D. + seed: + The seed for the random number generator used to apply the energy change. + """ + + def __init__( + self, + distance: sc.Variable, + name: str, + delta_e: sc.DataArray, + seed: int | None = None, + ): + self.distance = distance.to(dtype=float, copy=False) + self.name = name + if delta_e.ndim != 1: + raise ValueError("delta_e must be a 1D array.") + self.probabilities = delta_e.data / delta_e.data.sum() + dim = delta_e.dim + self.energies = delta_e.coords[dim] + # TODO: check for bin edges + self._noise_scale = ( + 0.5 + * (self.energies.max() - self.energies.min()).value + / (max(len(delta_e), 2) - 1) + ) + self.kind = "inelastic_sample" + self.seed = seed + self._rng = np.random.default_rng(self.seed) + + def __repr__(self) -> str: + return f"InelasticSample(name={self.name}, distance={self.distance:c})" + + def plot(self, **kwargs) -> pp.FigureLike: + return pp.xyplot(self.energies, self.probabilities, **kwargs) + + def as_dict(self) -> dict: + """ + Return the inelastic sample as a dictionary. + """ + return {'distance': self.distance, 'name': self.name, 'delta_e': self.delta_e} + + @classmethod + def from_json(cls, name: str, params: dict) -> InelasticSample: + """ + Create an inelastic sample from a JSON-serializable dictionary. + """ + return cls( + distance=var_from_dict(params["distance"]), + name=name, + delta_e=sc.DataArray( + data=var_from_dict(params['probabilities'], dim='e'), + coords={'e': var_from_dict(params['energies'], dim='e')}, + ), + seed=params.get("seed"), + ) + + def as_json(self) -> dict: + """ + Return the inelastic sample as a JSON-serializable dictionary. + .. versionadded:: 26.03.0 + """ + return { + 'type': 'inelastic_sample', + 'distance': var_to_dict(self.distance), + 'name': self.name, + 'energies': var_to_dict(self.energies), + 'probabilities': var_to_dict(self.probabilities), + 'seed': self.seed, + } + + def as_readonly(self, neutrons: sc.DataArray) -> InelasticSampleReading: + return InelasticSampleReading( + distance=self.distance, name=self.name, data=neutrons + ) + + def apply( + self, neutrons: sc.DataArray, time_limit: sc.Variable + ) -> tuple[sc.DataArray, InelasticSampleReading]: + """ + Apply the change in energy to the given neutrons. + + Parameters + ---------- + neutrons: + The neutrons to which the inelastic sample will be applied. + time_limit: + The time limit for the neutrons to be considered as reaching the inelastic + sample. + """ + w_initial = neutrons.coords["wavelength"] + + n = neutrons.shape + inds = self._rng.choice(len(self.energies), size=n, p=self.probabilities.values) + de = sc.array( + dims=w_initial.dims, + values=self.energies.values[inds] + + self._rng.normal(scale=self._noise_scale, size=n), + unit=self.energies.unit, + ) + # Convert energy change to wavelength change + w_final = energy_to_wavelength( + wavelength_to_energy(w_initial, unit=de.unit) + de, unit=w_initial.unit + ) + neutrons = neutrons.assign_coords( + wavelength=w_final, speed=wavelength_to_speed(w_final) + ) + return neutrons, self.as_readonly(neutrons) diff --git a/src/tof/model.py b/src/tof/model.py index 36c241f..a169144 100644 --- a/src/tof/model.py +++ b/src/tof/model.py @@ -5,27 +5,20 @@ import warnings from itertools import chain +from types import MappingProxyType import scipp as sc -from .chopper import AntiClockwise, Chopper, Clockwise +from .chopper import Chopper +from .component import Component from .detector import Detector from .result import Result from .source import Source +from .utils import extract_component_group ComponentType = Chopper | Detector -def _array_or_none(container: dict, key: str) -> sc.Variable | None: - return ( - sc.array( - dims=["cutout"], values=container[key]["value"], unit=container[key]["unit"] - ) - if key in container - else None - ) - - def make_beamline(instrument: dict) -> dict[str, list[Chopper] | list[Detector]]: """ Create choppers and detectors from a dictionary. @@ -42,85 +35,85 @@ def make_beamline(instrument: dict) -> dict[str, list[Chopper] | list[Detector]] type, see the documentation of the :class:`Chopper` and :class:`Detector` classes for details. """ - choppers = [] - detectors = [] + beamline = {"components": []} + mapping = {"chopper": Chopper, "detector": Detector} for name, comp in instrument.items(): - if comp["type"] == "chopper": - direction = comp["direction"].lower() - if direction == "clockwise": - _dir = Clockwise - elif any(x in direction for x in ("anti", "counter")): - _dir = AntiClockwise - else: + if comp["type"] == "source": + if "source" in beamline: raise ValueError( - f"Chopper direction must be 'clockwise' or 'anti-clockwise', got " - f"'{comp['direction']}' for component {name}." - ) - choppers.append( - Chopper( - frequency=comp["frequency"]["value"] - * sc.Unit(comp["frequency"]["unit"]), - direction=_dir, - open=_array_or_none(comp, "open"), - close=_array_or_none(comp, "close"), - centers=_array_or_none(comp, "centers"), - widths=_array_or_none(comp, "widths"), - phase=comp["phase"]["value"] * sc.Unit(comp["phase"]["unit"]), - distance=comp["distance"]["value"] - * sc.Unit(comp["distance"]["unit"]), - name=name, + "Only one source is allowed, but multiple were found in the" + "instrument parameters." ) - ) - elif comp["type"] == "detector": - detectors.append( - Detector( - distance=comp["distance"]["value"] - * sc.Unit(comp["distance"]["unit"]), - name=name, - ) - ) - elif comp["type"] == "source": + beamline["source"] = Source.from_json(params=comp) continue - else: + if comp["type"] not in mapping: raise ValueError( f"Unknown component type: {comp['type']} for component {name}. " - "Supported types are 'chopper', 'detector', and 'source'." ) - return {"choppers": choppers, "detectors": detectors} + beamline["components"].append( + mapping[comp["type"]].from_json(name=name, params=comp) + ) + return beamline class Model: """ A class that represents a neutron instrument. - It is defined by a list of choppers, a list of detectors, and a source. + It is defined by a source and a list of components (choppers, detectors, etc.). Parameters ---------- - choppers: - A list of choppers. - detectors: - A list of detectors. source: A source of neutrons. + components: + A list of components. + choppers: + A list of choppers. This is kept for backwards-compatibility; new code + should use the `components` parameter instead. + detectors: + A list of detectors. This is kept for backwards-compatibility; new code + should use the `components` parameter instead. """ def __init__( self, source: Source | None = None, + components: list[Component] | tuple[Component, ...] | None = None, choppers: list[Chopper] | tuple[Chopper, ...] | None = None, detectors: list[Detector] | tuple[Detector, ...] | None = None, ): - self.choppers = {} - self.detectors = {} self.source = source - for components, kind in ((choppers, Chopper), (detectors, Detector)): - for c in components or (): - if not isinstance(c, kind): - raise TypeError( - f"Beamline components: expected {kind.__name__} instance, " - f"got {type(c)}." - ) - self.add(c) + self._components = {} + for comp in chain((choppers or ()), (detectors or ()), (components or ())): + self.add(comp) + + @property + def components(self) -> dict[str, Component]: + """ + A dictionary of the components in the instrument. + """ + return self._components + + @property + def choppers(self) -> MappingProxyType[str, Chopper]: + """ + A dictionary of the choppers in the instrument. + """ + return extract_component_group(self._components, "chopper") + + @property + def detectors(self) -> MappingProxyType[str, Detector]: + """ + A dictionary of the detectors in the instrument. + """ + return extract_component_group(self._components, "detector") + + @property + def samples(self) -> MappingProxyType[str, Component]: + """ + A dictionary of the samples in the instrument. + """ + return extract_component_group(self._components, "sample") @classmethod def from_json(cls, filename: str) -> Model: @@ -140,20 +133,7 @@ def from_json(cls, filename: str) -> Model: with open(filename) as f: instrument = json.load(f) - beamline = make_beamline(instrument) - source = None - for item in instrument.values(): - if item.get("type") == "source": - if "facility" not in item: - raise ValueError( - "Currently, only sources from facilities are supported when " - "loading from JSON." - ) - source_args = item.copy() - del source_args["type"] - source = Source(**source_args) - break - return cls(source=source, **beamline) + return cls(**make_beamline(instrument)) def as_json(self) -> dict: """ @@ -173,13 +153,11 @@ def as_json(self) -> dict: ) else: instrument_dict['source'] = self.source.as_json() - for ch in self.choppers.values(): - instrument_dict[ch.name] = ch.as_json() - for det in self.detectors.values(): - instrument_dict[det.name] = det.as_json() + for comp in self._components.values(): + instrument_dict[comp.name] = comp.as_json() return instrument_dict - def to_json(self, filename: str): + def to_json(self, filename: str) -> None: """ Save the model to a JSON file. If the source is not from a facility, it is not included in the output. @@ -196,7 +174,7 @@ def to_json(self, filename: str): with open(filename, 'w') as f: json.dump(self.as_json(), f, indent=2) - def add(self, component: Chopper | Detector): + def add(self, component: Component) -> None: """ Add a component to the instrument. Component names must be unique across choppers and detectors. @@ -208,21 +186,19 @@ def add(self, component: Chopper | Detector): component: A chopper or detector. """ - if not isinstance(component, (Chopper | Detector)): + if not isinstance(component, Component): raise TypeError( - f"Cannot add component of type {type(component)} to the model. " - "Only Chopper and Detector instances are allowed." + "Component must be an instance of Component or derived class, " + f"but got {type(component)}." ) # Note that the name "source" is reserved for the source. - if component.name in chain(self.choppers, self.detectors, ("source",)): + if component.name in (*self._components, "source"): raise KeyError( f"Component with name {component.name} already exists. " "If you wish to replace/update an existing component, use " - "``model.choppers['name'] = new_chopper`` or " - "``model.detectors['name'] = new_detector``." + "``model.components['name'] = new_component``." ) - container = self.choppers if isinstance(component, Chopper) else self.detectors - container[component.name] = component + self._components[component.name] = component def remove(self, name: str): """ @@ -233,25 +209,9 @@ def remove(self, name: str): name: The name of the component to remove. """ - if name in self.choppers: - del self.choppers[name] - elif name in self.detectors: - del self.detectors[name] - else: - raise KeyError(f"No component with name {name} was found.") - - def __iter__(self): - return chain(self.choppers, self.detectors) - - def __getitem__(self, name) -> Chopper | Detector: - if name not in self: - raise KeyError(f"No component with name {name} was found.") - return self.choppers[name] if name in self.choppers else self.detectors[name] - - def __delitem__(self, name): - self.remove(name) + del self._components[name] - def run(self): + def run(self) -> Result: """ Run the simulation. """ @@ -260,10 +220,7 @@ def run(self): "No source has been defined for this model. Please add a source using " "`model.source = Source(...)` before running the simulation." ) - components = sorted( - chain(self.choppers.values(), self.detectors.values()), - key=lambda c: c.distance.value, - ) + components = sorted(self._components.values(), key=lambda c: c.distance.value) if len(components) == 0: raise ValueError("Cannot run model: no components have been defined.") @@ -274,56 +231,59 @@ def run(self): "itself. Please check the distances of the components." ) - birth_time = self.source.data.coords['birth_time'] - speed = self.source.data.coords['speed'] - initial_mask = sc.ones(sizes=birth_time.sizes, unit=None, dtype=bool) + neutrons = self.source.data.copy(deep=False) + neutrons.masks["blocked_by_others"] = sc.zeros( + sizes=neutrons.sizes, unit=None, dtype=bool + ) + neutrons.coords.update( + distance=self.source.distance, toa=neutrons.coords['birth_time'] + ) - result_choppers = {} - result_detectors = {} + time_unit = neutrons.coords['birth_time'].unit + + readings = {} time_limit = ( - birth_time - + ((components[-1].distance - self.source.distance) / speed).to( - unit=birth_time.unit - ) + neutrons.coords['birth_time'] + + ( + (components[-1].distance - self.source.distance) + / neutrons.coords['speed'] + ).to(unit=time_unit) ).max() - for c in components: - container = result_detectors if isinstance(c, Detector) else result_choppers - container[c.name] = c.as_dict() - container[c.name]['data'] = self.source.data.copy(deep=False) - tof = ((c.distance - self.source.distance) / speed).to( - unit=birth_time.unit, copy=False + for comp in components: + neutrons = neutrons.copy(deep=False) + toa = neutrons.coords['toa'] + ( + (comp.distance - neutrons.coords['distance']) / neutrons.coords['speed'] + ).to(unit=time_unit, copy=False) + neutrons.coords['toa'] = toa + neutrons.coords['eto'] = toa % (1 / self.source.frequency).to( + unit=time_unit, copy=False ) - t = birth_time + tof - container[c.name]['data'].coords['toa'] = t - container[c.name]['data'].coords['eto'] = t % ( - 1 / self.source.frequency - ).to(unit=t.unit, copy=False) - container[c.name]['data'].coords['distance'] = c.distance - container[c.name]['data'].coords['tof'] = tof - if isinstance(c, Detector): - container[c.name]['data'].masks['blocked_by_others'] = ~initial_mask - continue - m = sc.zeros(sizes=t.sizes, unit=None, dtype=bool) - to, tc = c.open_close_times(time_limit=time_limit) - container[c.name].update({'open_times': to, 'close_times': tc}) - for i in range(len(to)): - m |= (t > to[i]) & (t < tc[i]) - combined = initial_mask & m - container[c.name]['data'].masks['blocked_by_others'] = ~initial_mask - container[c.name]['data'].masks['blocked_by_me'] = ~m & initial_mask - initial_mask = combined - - return Result( - source=self.source, choppers=result_choppers, detectors=result_detectors - ) + neutrons.coords['distance'] = comp.distance + + if "blocked_by_me" in neutrons.masks: + # Because we use shallow copies, we do not want to do an in-place |= + # operation here + neutrons.masks['blocked_by_others'] = neutrons.masks[ + 'blocked_by_others' + ] | neutrons.masks.pop('blocked_by_me') + + neutrons, reading = comp.apply(neutrons=neutrons, time_limit=time_limit) + readings[comp.name] = reading + + return Result(source=self.source.as_readonly(), readings=readings) def __repr__(self) -> str: - out = f"Model:\n Source: {self.source}\n Choppers:\n" - for name, ch in self.choppers.items(): - out += f" {name}: {ch}\n" - out += " Detectors:\n" - for name, det in self.detectors.items(): - out += f" {name}: {det}\n" + out = f"Model:\n Source: {self.source}\n" + groups = {} + for comp in self._components.values(): + if comp.kind not in groups: + groups[comp.kind] = [] + groups[comp.kind].append(comp) + + for group, comps in groups.items(): + out += f" {group.capitalize()}s:\n" + for comp in sorted(comps, key=lambda c: c.distance): + out += f" {comp.name}: {comp}\n" return out def __str__(self) -> str: diff --git a/src/tof/result.py b/src/tof/result.py index 0e3ad2b..cdb1bf1 100644 --- a/src/tof/result.py +++ b/src/tof/result.py @@ -11,10 +11,37 @@ import scipp as sc from matplotlib.collections import LineCollection -from .chopper import Chopper, ChopperReading -from .detector import Detector, DetectorReading -from .source import Source, SourceParameters -from .utils import Plot, one_mask +from .chopper import ChopperReading +from .component import ComponentReading +from .detector import DetectorReading +from .source import SourceReading +from .utils import Plot, extract_component_group, one_mask + + +def _get_rays( + components: list[ComponentReading], pulse: int, inds: np.ndarray +) -> tuple[np.ndarray, np.ndarray]: + x = [] + y = [] + c = [] + data = components[0].data["pulse", pulse] + xstart = data.coords["toa"].values[inds] + ystart = np.full_like(xstart, components[0].distance.value) + color = data.coords["wavelength"].values[inds] + for comp in components[1:]: + xend = comp.data["pulse", pulse].coords["toa"].values[inds] + yend = np.full_like(xend, comp.distance.value) + x.append([xstart, xend]) + y.append([ystart, yend]) + c.append(color) + xstart, ystart = xend, yend + color = comp.data["pulse", pulse].coords["wavelength"].values[inds] + + return ( + np.array(x).transpose((0, 2, 1)), + np.array(y).transpose((0, 2, 1)), + np.array(c), + ) def _add_rays( @@ -29,12 +56,13 @@ def _add_rays( cax: plt.Axes | None = None, zorder: int = 1, ): + x, y = (a.reshape((-1, 2)) for a in (x, y)) coll = LineCollection(np.stack((x, y), axis=2), zorder=zorder) if isinstance(color, str): coll.set_color(color) else: coll.set_cmap(plt.colormaps[cmap]) - coll.set_array(color) + coll.set_array(color.ravel()) coll.set_norm(plt.Normalize(vmin, vmax)) if cbar: cb = plt.colorbar(coll, ax=ax, cax=cax) @@ -51,64 +79,45 @@ class Result: ---------- source: The source of neutrons. - choppers: - The choppers in the model. - detectors: - The detectors in the model. + results: + The state of neutrons at each component in the model. """ - def __init__( - self, - source: Source, - choppers: dict[str, Chopper], - detectors: dict[str, Detector], - ): - self._source = source.as_readonly() - self._choppers = {} - for name, chopper in choppers.items(): - self._choppers[name] = ChopperReading( - distance=chopper["distance"], - name=chopper["name"], - frequency=chopper["frequency"], - open=chopper["open"], - close=chopper["close"], - phase=chopper["phase"], - open_times=chopper["open_times"], - close_times=chopper["close_times"], - data=chopper["data"], - ) - - self._detectors = {} - for name, det in detectors.items(): - self._detectors[name] = DetectorReading( - distance=det["distance"], name=det["name"], data=det["data"] - ) - - self._choppers = MappingProxyType(self._choppers) - self._detectors = MappingProxyType(self._detectors) + def __init__(self, source: SourceReading, readings: dict[str, dict]): + self._source = source + self._components = MappingProxyType(readings) @property def choppers(self) -> MappingProxyType[str, ChopperReading]: - """The choppers in the model.""" - return self._choppers + """ + A dictionary of the choppers in the instrument. + """ + return extract_component_group(self._components, "chopper") @property def detectors(self) -> MappingProxyType[str, DetectorReading]: - """The detectors in the model.""" - return self._detectors + """ + A dictionary of the detectors in the instrument. + """ + return extract_component_group(self._components, "detector") @property - def source(self) -> SourceParameters: + def samples(self) -> MappingProxyType[str, ComponentReading]: + """ + A dictionary of the samples in the instrument. + """ + return extract_component_group(self._components, "sample") + + @property + def source(self) -> SourceReading: """The source of neutrons.""" return self._source def __iter__(self): - return chain(self._choppers, self._detectors) + return iter(self._components) - def __getitem__(self, name: str) -> ChopperReading | DetectorReading: - if name not in self: - raise KeyError(f"No component with name {name} was found.") - return self._choppers[name] if name in self._choppers else self._detectors[name] + def __getitem__(self, name: str) -> ComponentReading: + return self._components[name] def plot( self, @@ -122,6 +131,7 @@ def plot( seed: int | None = None, vmin: float | None = None, vmax: float | None = None, + title: str | None = None, ) -> Plot: """ Plot the time-distance diagram for the instrument, including the rays of @@ -158,13 +168,11 @@ def plot( fig, ax = plt.subplots(figsize=figsize) else: fig = ax.get_figure() + components = sorted( - chain(self.choppers.values(), self.detectors.values()), - key=lambda c: c.distance, + chain((self.source,), self._components.values()), key=lambda c: c.distance ) furthest_component = components[-1] - source_dist = self.source.distance.value - repeats = [1] + [2] * len(components) wavelengths = sc.DataArray( data=furthest_component.data.coords["wavelength"], @@ -173,11 +181,12 @@ def plot( wmin, wmax = wavelengths.min(), wavelengths.max() rng = np.random.default_rng(seed) + # Make ids for neutrons per pulse, instead of using their id coord + ids = np.arange(self.source.neutrons) for i in range(self._source.data.sizes["pulse"]): - source_data = self.source.data["pulse", i] component_data = furthest_component.data["pulse", i] - ids = np.arange(self.source.neutrons) + # Plot visible rays blocked = one_mask(component_data.masks).values nblocked = int(blocked.sum()) @@ -187,17 +196,12 @@ def plot( size=min(self.source.neutrons - nblocked, visible_rays), replace=False, ) - - xstart = source_data.coords["birth_time"].values[inds] - xend = component_data.coords["toa"].values[inds] - ystart = np.full_like(xstart, source_dist) - yend = np.full_like(ystart, furthest_component.distance.value) - + x, y, c = _get_rays(components, pulse=i, inds=inds) _add_rays( ax=ax, - x=np.stack((xstart, xend), axis=1), - y=np.stack((ystart, yend), axis=1), - color=source_data.coords["wavelength"].values[inds], + x=x, + y=y, + color=c, cbar=cbar and (i == 0), cmap=cmap, vmin=wmin.value if vmin is None else vmin, @@ -209,94 +213,59 @@ def plot( inds = rng.choice( ids[blocked], size=min(blocked_rays, nblocked), replace=False ) - x = np.repeat( - np.stack( - [source_data.coords["birth_time"].values[inds]] - + [ - c.data.coords["toa"]["pulse", i].values[inds] - for c in components - ], - axis=1, - ), - repeats, + x, y, _ = _get_rays(components, pulse=i, inds=inds) + blocked_by_others = np.stack( + [ + comp.data["pulse", i].masks["blocked_by_others"].values[inds] + for comp in components[1:] + ], axis=1, + ).T + line_selection = np.broadcast_to( + blocked_by_others.reshape((*blocked_by_others.shape, 1)), x.shape ) - y = np.repeat( - np.stack( - [np.full_like(x[:, 0], source_dist)] - + [np.full_like(x[:, 0], c.distance.value) for c in components], - axis=1, - ), - repeats, - axis=1, - ) - for j, c in enumerate(components): - comp_data = c.data["pulse", i] - m_others = comp_data.masks["blocked_by_others"].values[inds] - x[:, 2 * j + 1][m_others] = np.nan - y[:, 2 * j + 1][m_others] = np.nan - if "blocked_by_me" in comp_data.masks: - m_me = comp_data.masks["blocked_by_me"].values[inds] - x[:, 2 * j + 2][m_me] = np.nan - y[:, 2 * j + 2][m_me] = np.nan + x[line_selection] = np.nan + y[line_selection] = np.nan _add_rays(ax=ax, x=x, y=y, color="lightgray", zorder=-1) # Plot pulse - time_coord = source_data.coords["birth_time"].values - tmin = time_coord.min() - ax.plot([tmin, time_coord.max()], [source_dist] * 2, color="gray", lw=3) - ax.text(tmin, source_dist, "Pulse", ha="left", va="top", color="gray") + self.source.plot_on_time_distance_diagram(ax, pulse=i) if furthest_component.toa.data.sum().value > 0: toa_max = furthest_component.toa.max().value else: toa_max = furthest_component.toa.data.coords["toa"].max().value - dx = 0.05 * toa_max - # Plot choppers - for ch in self._choppers.values(): - x0 = ch.open_times.values - x1 = ch.close_times.values - x = np.empty(3 * x0.size, dtype=x0.dtype) - x[0::3] = x0 - x[1::3] = 0.5 * (x0 + x1) - x[2::3] = x1 - x = np.concatenate( - ([[0]] if x[0] > 0 else [x[0:1]]) - + [x] - + ([[toa_max + dx]] if x[-1] < toa_max else []) - ) - y = np.full_like(x, ch.distance.value) - y[2::3] = None - inds = np.argsort(x) - ax.plot(x[inds], y[inds], color="k") - ax.text( - toa_max, ch.distance.value, ch.name, ha="right", va="bottom", color="k" - ) - # Plot detectors - for det in self._detectors.values(): - ax.plot([0, toa_max], [det.distance.value] * 2, color="gray", lw=3) - ax.text( - 0, det.distance.value, det.name, ha="left", va="bottom", color="gray" - ) + # Plot components + for comp in self._components.values(): + comp.plot_on_time_distance_diagram(ax=ax, tmax=toa_max) + dx = 0.05 * toa_max ax.set(xlabel="Time [μs]", ylabel="Distance [m]") ax.set_xlim(0 - dx, toa_max + dx) if figsize is None: inches = fig.get_size_inches() fig.set_size_inches((min(inches[0] * self.source.pulses, 12.0), inches[1])) fig.tight_layout() + if title is not None: + ax.set_title(title) return Plot(fig=fig, ax=ax) def __repr__(self) -> str: out = ( f"Result:\n Source: {self.source.pulses} pulses, " - f"{self.source.neutrons} neutrons per pulse.\n Choppers:\n" + f"{self.source.neutrons} neutrons per pulse.\n" ) - for name, ch in self._choppers.items(): - out += f" {name}: {ch._repr_stats()}\n" - out += " Detectors:\n" - for name, det in self._detectors.items(): - out += f" {name}: {det._repr_stats()}\n" + groups = {} + for comp in self._components.values(): + if comp.kind not in groups: + groups[comp.kind] = [] + groups[comp.kind].append(comp) + + for group, comps in groups.items(): + out += f" {group.capitalize()}s:\n" + for comp in sorted(comps, key=lambda c: c.distance): + out += f" {comp.name}: {comp._repr_stats()}\n" + return out def __str__(self) -> str: @@ -315,11 +284,12 @@ def to_nxevent_data(self, key: str | None = None) -> sc.DataArray: start = sc.datetime("2024-01-01T12:00:00.000000") period = sc.reciprocal(self.source.frequency) - keys = list(self._detectors.keys()) if key is None else [key] + detectors = self.detectors + keys = list(detectors.keys()) if key is None else [key] event_data = [] for name in keys: - raw_data = self._detectors[name].data.flatten(to="event") + raw_data = detectors[name].data.flatten(to="event") events = ( raw_data[~raw_data.masks["blocked_by_others"]] .copy() @@ -338,7 +308,7 @@ def to_nxevent_data(self, key: str | None = None) -> sc.DataArray: "toa" ) % period.to(unit=dt.unit) out = ( - event_data.drop_coords(["tof", "speed", "birth_time", "wavelength"]) + event_data.drop_coords(["speed", "birth_time", "wavelength"]) .group("distance") .rename_dims(distance="detector_number") ) diff --git a/src/tof/source.py b/src/tof/source.py index 504d4c3..fc0489f 100644 --- a/src/tof/source.py +++ b/src/tof/source.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2025 Scipp contributors (https://github.com/scipp) +from __future__ import annotations import warnings from dataclasses import dataclass @@ -8,6 +9,7 @@ import plopp as pp import scipp as sc +from .component import ComponentReading from .utils import wavelength_to_speed TIME_UNIT = "us" @@ -24,6 +26,18 @@ def _default_frequency(frequency: sc.Variable | None, pulses: int) -> sc.Variabl return frequency +def _bin_edges_to_midpoints( + da: sc.DataArray, dims: list[str] | tuple[str] +) -> sc.DataArray: + return da.assign_coords( + { + dim: sc.midpoints(da.coords[dim], dim) + for dim in dims + if da.coords.is_edges(dim) + } + ) + + def _make_pulses( neutrons: int, frequency: sc.Variable, @@ -407,6 +421,14 @@ def from_distribution( distance if distance is not None else sc.scalar(0.0, unit="m") ) source._frequency = _default_frequency(frequency, pulses) + + if p is not None: + p = _bin_edges_to_midpoints(p, dims=["birth_time", "wavelength"]) + if p_time is not None: + p_time = _bin_edges_to_midpoints(p_time, dims=["birth_time"]) + if p_wav is not None: + p_wav = _bin_edges_to_midpoints(p_wav, dims=["wavelength"]) + pulse_params = _make_pulses( neutrons=neutrons, p=p, @@ -452,8 +474,10 @@ def plot(self, bins: int = 300) -> tuple: return f1 + f2 def as_readonly(self): - return SourceParameters( - data=self.data, + return SourceReading( + data=self.data.assign_masks( + blocked_by_others=sc.zeros_like(self.data.data, dtype=bool, unit=None) + ), facility=self.facility, neutrons=self.neutrons, frequency=self.frequency, @@ -469,6 +493,31 @@ def __repr__(self) -> str: f" distance={self.distance:c}" ) + @classmethod + def from_json(cls, params: dict) -> Source: + """ + Create a source from a JSON-serializable dictionary. + Currently, only sources from facilities are supported when loading from JSON. + + The dictionary should have the following format: + + .. code-block:: json + + { + "type": "source", + "facility": "ess", + "neutrons": 1000000, + "pulses": 1, + "seed": 42 + } + """ + if params.get("facility") is None: + raise ValueError( + "Currently, only sources from facilities are supported when " + "loading from JSON." + ) + return cls(**{k: v for k, v in params.items() if k != "type"}) + def as_json(self) -> dict: """ Return the source as a JSON-serializable dictionary. @@ -485,7 +534,7 @@ def as_json(self) -> dict: @dataclass(frozen=True) -class SourceParameters: +class SourceReading(ComponentReading): """ Read-only container for the parameters of a source. """ @@ -496,3 +545,14 @@ class SourceParameters: frequency: sc.Variable pulses: int distance: sc.Variable + + @property + def kind(self) -> str: + return "source" + + def plot_on_time_distance_diagram(self, ax, pulse) -> None: + birth_time = self.data.coords["birth_time"]["pulse", pulse] + tmin = birth_time.min().value + dist = self.distance.value + ax.plot([tmin, birth_time.max().value], [dist] * 2, color="gray", lw=3) + ax.text(tmin, dist, "Pulse", ha="left", va="top", color="gray") diff --git a/src/tof/utils.py b/src/tof/utils.py index 640d121..83356d5 100644 --- a/src/tof/utils.py +++ b/src/tof/utils.py @@ -6,6 +6,7 @@ from types import MappingProxyType import matplotlib.pyplot as plt +import numpy as np import scipp as sc import scipp.constants as const @@ -69,6 +70,34 @@ def energy_to_speed(x: sc.Variable, unit="m/s") -> sc.Variable: return sc.sqrt(x / (0.5 * const.m_n)).to(unit=unit) +def wavelength_to_energy(x: sc.Variable, unit="meV") -> sc.Variable: + """ + Convert neutron wavelengths to energies. + + Parameters + ---------- + x: + Input wavelengths. + unit: + The unit of the output energies. + """ + return speed_to_energy(wavelength_to_speed(x)).to(unit=unit) + + +def energy_to_wavelength(x: sc.Variable, unit="angstrom") -> sc.Variable: + """ + Convert neutron energies to wavelengths. + + Parameters + ---------- + x: + Input energies. + unit: + The unit of the output wavelengths. + """ + return speed_to_wavelength(energy_to_speed(x)).to(unit=unit) + + def one_mask( masks: MappingProxyType[str, sc.Variable], unit: str | None = None ) -> sc.Variable: @@ -102,6 +131,44 @@ def var_to_dict(var: sc.Variable) -> dict: } +def var_from_dict(data: dict, dim: str | None = None) -> sc.Variable: + """ + Convert a dictionary with 'value' and 'unit' keys to a scipp Variable. + + Parameters + ---------- + data: + The dictionary to convert. + dim: + The dimension of the output variable (non-scalar data only). + """ + values = np.asarray(data["value"]) + unit = data['unit'] + if values.shape: + if dim is None: + raise ValueError("Missing dimension to construct variable from json.") + return sc.array(dims=[dim], values=values, unit=unit) + return sc.scalar(values, unit=unit) + + +def extract_component_group( + components: dict | MappingProxyType, kind: str +) -> MappingProxyType: + """ + Extract a group of components of a given kind from a dictionary of components. + + Parameters + ---------- + components: + The components to extract from. + kind: + The kind of components to extract. + """ + return MappingProxyType( + {name: comp for name, comp in components.items() if kind in comp.kind} + ) + + @dataclass class Plot: ax: plt.Axes diff --git a/tests/inelastic_test.py b/tests/inelastic_test.py new file mode 100644 index 0000000..f39ad17 --- /dev/null +++ b/tests/inelastic_test.py @@ -0,0 +1,246 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) + +import numpy as np +import scipp as sc + +import tof + +Hz = sc.Unit('Hz') +deg = sc.Unit('deg') +meter = sc.Unit('m') + + +def test_inelastic_sample_flat_distribution(): + sample = tof.InelasticSample( + distance=28.0 * meter, + name="sample", + delta_e=sc.DataArray( + data=sc.ones(sizes={'e': 100}), + coords={'e': sc.linspace('e', -0.2, 0.2, 100, unit='meV')}, + ), + ) + + choppers = [ + tof.Chopper( + frequency=70.0 * Hz, + open=sc.array(dims=['cutout'], values=[0.0], unit='deg'), + close=sc.array(dims=['cutout'], values=[1.0], unit='deg'), + phase=0.0 * deg, + distance=20.0 * meter, + name="fastchopper", + ), + ] + + detectors = [ + tof.Detector(distance=26.0 * meter, name='monitor'), + tof.Detector(distance=32.0 * meter, name='detector'), + ] + + source = tof.Source(facility='ess', neutrons=500_000, seed=77) + + model = tof.Model(source=source, components=choppers + detectors + [sample]) + model_no_sample = tof.Model(source=source, components=choppers + detectors) + + res = model.run() + res_no_sample = model_no_sample.run() + + assert sc.identical( + res_no_sample['monitor'].data.coords['wavelength'], + res_no_sample['detector'].data.coords['wavelength'], + ) + assert not sc.identical( + res['monitor'].data.coords['wavelength'], + res['detector'].data.coords['wavelength'], + ) + assert not sc.allclose( + res['monitor'].data.coords['wavelength'], + res['detector'].data.coords['wavelength'], + ) + + +def test_inelastic_sample_doube_peaked_distribution(): + delta_e = sc.DataArray( + data=sc.zeros(sizes={'e': 100}), + coords={'e': sc.linspace('e', -0.2, 0.2, 100, unit='meV')}, + ) + delta_e.values[[0, -1]] = 1.0 + sample = tof.InelasticSample(distance=28.0 * meter, name="sample", delta_e=delta_e) + + choppers = [ + tof.Chopper( + frequency=70.0 * Hz, + open=sc.array(dims=['cutout'], values=[0.0], unit='deg'), + close=sc.array(dims=['cutout'], values=[1.0], unit='deg'), + phase=0.0 * deg, + distance=20.0 * meter, + name="fastchopper", + ), + ] + + detectors = [ + tof.Detector(distance=26.0 * meter, name='monitor'), + tof.Detector(distance=32.0 * meter, name='detector'), + ] + + source = tof.Source(facility='ess', neutrons=500_000, seed=78) + + model = tof.Model(source=source, components=choppers + detectors + [sample]) + model_no_sample = tof.Model(source=source, components=choppers + detectors) + + res = model.run() + res_no_sample = model_no_sample.run() + + assert sc.identical( + res_no_sample['monitor'].data.coords['wavelength'], + res_no_sample['detector'].data.coords['wavelength'], + ) + assert not sc.identical( + res['monitor'].data.coords['wavelength'], + res['detector'].data.coords['wavelength'], + ) + assert not sc.allclose( + res['monitor'].data.coords['wavelength'], + res['detector'].data.coords['wavelength'], + ) + + +def test_inelastic_sample_normal_distribution(): + x = sc.linspace('e', -0.2, 0.2, 100, unit='meV') + sig = sc.scalar(0.03, unit='meV') + y = 1.0 / (np.sqrt(2.0 * np.pi) * sig) * sc.exp(-((x / sig) ** 2) / 2) + y.unit = "" + + sample = tof.InelasticSample( + distance=28.0 * meter, + name="sample", + delta_e=sc.DataArray(data=y, coords={'e': x}), + ) + + choppers = [ + tof.Chopper( + frequency=70.0 * Hz, + open=sc.array(dims=['cutout'], values=[0.0], unit='deg'), + close=sc.array(dims=['cutout'], values=[1.0], unit='deg'), + phase=0.0 * deg, + distance=20.0 * meter, + name="fastchopper", + ), + ] + + detectors = [ + tof.Detector(distance=26.0 * meter, name='monitor'), + tof.Detector(distance=32.0 * meter, name='detector'), + ] + + source = tof.Source(facility='ess', neutrons=500_000, seed=78) + + model = tof.Model(source=source, components=choppers + detectors + [sample]) + model_no_sample = tof.Model(source=source, components=choppers + detectors) + + res = model.run() + res_no_sample = model_no_sample.run() + + assert sc.identical( + res_no_sample['monitor'].data.coords['wavelength'], + res_no_sample['detector'].data.coords['wavelength'], + ) + assert not sc.identical( + res['monitor'].data.coords['wavelength'], + res['detector'].data.coords['wavelength'], + ) + assert not sc.allclose( + res['monitor'].data.coords['wavelength'], + res['detector'].data.coords['wavelength'], + ) + + +def test_inelastic_sample_that_has_zero_delta_e(): + sample = tof.InelasticSample( + distance=28.0 * meter, + name="sample", + delta_e=sc.DataArray( + data=sc.array(dims=['e'], values=[1.0], unit=''), + coords={'e': sc.array(dims=['e'], values=[0.0], unit='meV')}, + ), + ) + + choppers = [ + tof.Chopper( + frequency=70.0 * Hz, + open=sc.array(dims=['cutout'], values=[0.0], unit='deg'), + close=sc.array(dims=['cutout'], values=[1.0], unit='deg'), + phase=0.0 * deg, + distance=20.0 * meter, + name="fastchopper", + ), + ] + + detectors = [ + tof.Detector(distance=26.0 * meter, name='monitor'), + tof.Detector(distance=32.0 * meter, name='detector'), + ] + + source = tof.Source(facility='ess', neutrons=500_000, seed=78) + + model = tof.Model(source=source, components=choppers + detectors + [sample]) + model_no_sample = tof.Model(source=source, components=choppers + detectors) + + res = model.run() + res_no_sample = model_no_sample.run() + + assert sc.identical( + res_no_sample['monitor'].data.coords['wavelength'], + res_no_sample['detector'].data.coords['wavelength'], + ) + assert sc.allclose( + res['monitor'].data.coords['wavelength'], + res['detector'].data.coords['wavelength'], + ) + + +def test_inelastic_sample_as_json(): + p = np.array([0.4, 0.45, 0.7, 1.0, 0.7, 0.45, 0.4]) + e = np.array([-0.5, -0.25, -0.1, 0.0, 0.1, 0.25, 0.5]) + + delta_e = sc.DataArray( + data=sc.array(dims=['e'], values=p), + coords={'e': sc.array(dims=['e'], values=e, unit='meV')}, + ) + sample = tof.InelasticSample( + distance=28.0 * meter, name="sample1", delta_e=delta_e, seed=66 + ) + + json_dict = sample.as_json() + assert json_dict['type'] == 'inelastic_sample' + assert json_dict['name'] == 'sample1' + assert json_dict['distance']['value'] == 28.0 + assert json_dict['distance']['unit'] == 'm' + assert np.array_equal(json_dict['probabilities']['value'], p / p.sum()) + assert json_dict['probabilities']['unit'] == 'dimensionless' + assert np.array_equal(json_dict['energies']['value'], e) + assert json_dict['energies']['unit'] == 'meV' + assert json_dict['seed'] == 66 + + +def test_inelastic_sample_from_json(): + p = np.array([0.4, 0.7, 1.0, 0.7, 0.4]) + e = np.array([-0.5, -0.25, 0.0, 0.25, 0.5]) + json_dict = { + 'type': 'inelastic_sample', + 'distance': {'value': 28.0, 'unit': 'm'}, + 'name': 'sample1', + 'energies': {'value': e, 'unit': 'meV'}, + 'probabilities': {'value': p, 'unit': ''}, + 'seed': 78, + } + sample = tof.InelasticSample.from_json(name=json_dict['name'], params=json_dict) + + assert sample.distance.value == 28.0 + assert sample.distance.unit == 'm' + assert sample.name == 'sample1' + assert np.array_equal(sample.energies.values, e) + assert sample.energies.unit == 'meV' + assert np.array_equal(sample.probabilities.values, p / p.sum()) + assert sample.probabilities.unit == 'dimensionless' + assert sample.seed == 78 diff --git a/tests/model_test.py b/tests/model_test.py index 11e9005..15c8733 100644 --- a/tests/model_test.py +++ b/tests/model_test.py @@ -17,7 +17,8 @@ ms = sc.Unit('ms') -def test_one_chopper_one_opening(make_chopper, make_source): +@pytest.mark.parametrize("use_components", [True, False]) +def test_one_chopper_one_opening(make_chopper, make_source, use_components): # Make a chopper open from 10-20 ms. Assume zero phase. topen = 10.0 * ms tclose = 20.0 * ms @@ -40,7 +41,11 @@ def test_one_chopper_one_opening(make_chopper, make_source): distance=chopper.distance, ) - model = tof.Model(source=source, choppers=[chopper], detectors=[detector]) + if use_components: + args = {"components": [chopper, detector]} + else: + args = {"choppers": [chopper], "detectors": [detector]} + model = tof.Model(source=source, **args) res = model.run() toa = res.choppers['chopper'].toa.data @@ -71,7 +76,8 @@ def test_one_chopper_one_opening(make_chopper, make_source): ) -def test_two_choppers_one_opening(make_chopper, make_source): +@pytest.mark.parametrize("use_components", [True, False]) +def test_two_choppers_one_opening(make_chopper, make_source, use_components): # Make a first chopper open from 5-16 ms. Assume zero phase. topen = 5.0 * ms tclose = 16.0 * ms @@ -105,9 +111,11 @@ def test_two_choppers_one_opening(make_chopper, make_source): distance=chopper1.distance, ) - model = tof.Model( - source=source, choppers=[chopper1, chopper2], detectors=[detector] - ) + if use_components: + args = {"components": [chopper1, chopper2, detector]} + else: + args = {"choppers": [chopper1, chopper2], "detectors": [detector]} + model = tof.Model(source=source, **args) res = model.run() ch1_toas = res.choppers['chopper1'].toa.data @@ -158,7 +166,8 @@ def test_two_choppers_one_opening(make_chopper, make_source): ) -def test_two_choppers_one_and_two_openings(make_chopper, make_source): +@pytest.mark.parametrize("use_components", [True, False]) +def test_two_choppers_one_and_two_openings(make_chopper, make_source, use_components): topen = 5.0 * ms tclose = 16.0 * ms chopper1 = make_chopper( @@ -201,9 +210,11 @@ def test_two_choppers_one_and_two_openings(make_chopper, make_source): distance=chopper1.distance, ) - model = tof.Model( - source=source, choppers=[chopper1, chopper2], detectors=[detector] - ) + if use_components: + args = {"components": [chopper1, chopper2, detector]} + else: + args = {"choppers": [chopper1, chopper2], "detectors": [detector]} + model = tof.Model(source=source, **args) res = model.run() assert res.choppers['chopper1'].toa.data.sum().value == 5 @@ -216,7 +227,8 @@ def test_two_choppers_one_and_two_openings(make_chopper, make_source): ) -def test_neutron_conservation(make_chopper): +@pytest.mark.parametrize("use_components", [True, False]) +def test_neutron_conservation(make_chopper, use_components): N = 100_000 source = tof.Source(facility='ess', neutrons=N) @@ -238,9 +250,11 @@ def test_neutron_conservation(make_chopper): ) detector = tof.Detector(distance=20 * meter, name='detector') - model = tof.Model( - source=source, choppers=[chopper1, chopper2], detectors=[detector] - ) + if use_components: + args = {"components": [chopper1, chopper2, detector]} + else: + args = {"choppers": [chopper1, chopper2], "detectors": [detector]} + model = tof.Model(source=source, **args) res = model.run() ch1 = res.choppers['chopper1'].toa.data @@ -258,42 +272,6 @@ def test_neutron_conservation(make_chopper): assert det.sum().value + det.masks['blocked_by_others'].sum().value == N -def test_neutron_time_of_flight(make_chopper): - N = 10_000 - source = tof.Source(facility='ess', neutrons=N) - - chopper1 = make_chopper( - topen=[5.0 * ms], - tclose=[16.0 * ms], - f=10.0 * Hz, - phase=0.0 * deg, - distance=10 * meter, - name='chopper1', - ) - chopper2 = make_chopper( - topen=[9.0 * ms, 15.0 * ms], - tclose=[15.0 * ms, 20.0 * ms], - f=15.0 * Hz, - phase=0.0 * deg, - distance=15 * meter, - name='chopper2', - ) - - detector = tof.Detector(distance=20 * meter, name='detector') - model = tof.Model( - source=source, choppers=[chopper1, chopper2], detectors=[detector] - ) - res = model.run() - - ch1 = res.choppers['chopper1'].data - ch2 = res.choppers['chopper2'].data - det = res.detectors['detector'].data - - assert sc.allclose(ch1.coords['tof'], ch1.coords['toa'] - ch1.coords['birth_time']) - assert sc.allclose(ch2.coords['tof'], ch2.coords['toa'] - ch2.coords['birth_time']) - assert sc.allclose(det.coords['tof'], det.coords['toa'] - det.coords['birth_time']) - - def test_source_not_at_origin(make_chopper): N = 100_000 source1 = tof.Source(facility='ess', neutrons=N, seed=123) @@ -405,50 +383,25 @@ def test_remove(dummy_chopper, dummy_detector, dummy_source): chopper = dummy_chopper detector = dummy_detector model = tof.Model(source=dummy_source, choppers=[chopper], detectors=[detector]) - del model['dummy_chopper'] - assert 'dummy_chopper' not in model - assert 'dummy_detector' in model - del model['dummy_detector'] - assert 'dummy_detector' not in model - - -def test_getitem(dummy_chopper, dummy_detector, dummy_source): - chopper = dummy_chopper - detector = dummy_detector - model = tof.Model(source=dummy_source, choppers=[chopper], detectors=[detector]) - assert model['dummy_chopper'] is chopper - assert model['dummy_detector'] is detector - with pytest.raises(KeyError, match='No component with name foo'): - model['foo'] + model.remove('dummy_chopper') + assert 'dummy_chopper' not in model.components + assert 'dummy_detector' in model.components + model.remove('dummy_detector') + assert 'dummy_detector' not in model.components def test_bad_input_type_raises(dummy_chopper, dummy_detector, dummy_source): chopper = dummy_chopper detector = dummy_detector - with pytest.raises( - TypeError, match='Beamline components: expected Chopper instance' - ): + err = "Component must be an instance of Component or derived class" + with pytest.raises(TypeError, match=err): _ = tof.Model(source=dummy_source, choppers='bad chopper') - with pytest.raises( - TypeError, match='Beamline components: expected Detector instance' - ): + with pytest.raises(TypeError, match=err): _ = tof.Model(source=dummy_source, choppers=[chopper], detectors='abc') - with pytest.raises( - TypeError, match='Beamline components: expected Chopper instance' - ): + with pytest.raises(TypeError, match=err): _ = tof.Model(source=dummy_source, choppers=[chopper, 'bad chopper']) - with pytest.raises( - TypeError, match='Beamline components: expected Detector instance' - ): + with pytest.raises(TypeError, match=err): _ = tof.Model(source=dummy_source, detectors=(1234, detector)) - with pytest.raises( - TypeError, match='Beamline components: expected Chopper instance' - ): - _ = tof.Model(source=dummy_source, choppers=[detector]) - with pytest.raises( - TypeError, match='Beamline components: expected Detector instance' - ): - _ = tof.Model(source=dummy_source, detectors=[chopper]) def test_model_repr_does_not_raise(make_chopper): diff --git a/tests/result_test.py b/tests/result_test.py index f72872a..7b9f8b9 100644 --- a/tests/result_test.py +++ b/tests/result_test.py @@ -343,7 +343,7 @@ def test_plot_reading_pulse_skipping_does_not_raise(): distance=10 * meter, name='skip', ) - model.choppers['skip'] = skip + model.components['skip'] = skip res = model.run() res.choppers['chopper'].toa.plot() res.choppers['chopper'].wavelength.plot() @@ -361,7 +361,7 @@ def test_plot_reading_nothing_to_plot_raises(): distance=10 * meter, name='skip', ) - model.choppers['skip'] = skip + model.components['skip'] = skip res = model.run() with pytest.raises(RuntimeError, match="Nothing to plot."): res.detectors['detector'].toa.plot()