From ec8c84201793a375bb487c2bc3456b9cee6e9ea5 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 1 Jul 2026 10:01:24 -0600 Subject: [PATCH] small improvements --- src/probeinterface/probegroup.py | 20 +++++++++++++++ tests/test_probegroup.py | 43 ++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+) diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index 6214599f..144e9983 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -427,6 +427,11 @@ def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup": if k not in new_probe.annotations: new_probe.annotate(**{k: orig_probe.annotations[k]}) + # probe_planar_contour is a probe-level attribute, not part of the to_numpy dtype, + # so from_numpy cannot restore it; copy it over explicitly. + if orig_probe.probe_planar_contour is not None and new_probe.probe_planar_contour is None: + new_probe.set_planar_contour(orig_probe.probe_planar_contour) + return sliced_probe_group def select_probes(self, probe_ids: str | np.ndarray | list) -> "ProbeGroup": @@ -494,6 +499,21 @@ def select_contacts( f"contact_ids must be unique, but {duplicated.tolist()} appear more than once. " "If the same contact id is on multiple probes, use probe_ids to disambiguate." ) + # each requested contact id must live on exactly one probe; collect every + # ambiguous one so the user can disambiguate them all at once + ambiguous = {} + for contact_id in contact_ids: + probes_for_id = np.unique(all_probe_ids[all_contact_ids == contact_id]).tolist() + if len(probes_for_id) > 1: + ambiguous[str(contact_id)] = probes_for_id + if ambiguous: + ambiguity_lines = "\n".join( + f'"{contact_id}" lives on probes {probes}' for contact_id, probes in ambiguous.items() + ) + message = f"""\ +Some contact ids are ambiguous because they live on multiple probes; pass probe_ids to disambiguate which probe each belongs to: +{ambiguity_lines}""" + raise ValueError(message) probe_ids = [None] * len(contact_ids) else: if len(probe_ids) != len(contact_ids): diff --git a/tests/test_probegroup.py b/tests/test_probegroup.py index b457f903..f0127926 100644 --- a/tests/test_probegroup.py +++ b/tests/test_probegroup.py @@ -675,6 +675,49 @@ def test_add_probe_default_id_with_non_numeric_ids(): assert pg.probe_ids == ["left", "right", "0"] +def test_select_contacts_ambiguous_id_message_points_to_probe_ids(): + """ + When a contact id exists on several probes and no probe_ids are given, the + error must guide the user to pass probe_ids rather than claim it cannot happen. + """ + pg = _probegroup_with_contact_ids(unique=False) + expected_error = """Some contact ids are ambiguous because they live on multiple probes; pass probe_ids to disambiguate which probe each belongs to: +"c0" lives on probes ['0', '1', '2']""" + with pytest.raises(ValueError) as exc_info: + pg.select_contacts(["c0"]) + assert str(exc_info.value) == expected_error + + +def test_select_contacts_reports_all_ambiguous_ids_at_once(): + """ + When several requested contact ids are ambiguous, the error lists all of them + (with the probes each lives on) rather than failing on the first one. + """ + pg = _probegroup_with_contact_ids(unique=False) + expected_error = """Some contact ids are ambiguous because they live on multiple probes; pass probe_ids to disambiguate which probe each belongs to: +"c0" lives on probes ['0', '1', '2'] +"c1" lives on probes ['0', '1', '2']""" + with pytest.raises(ValueError) as exc_info: + pg.select_contacts(["c0", "c1"]) + assert str(exc_info.value) == expected_error + + +def test_get_slice_preserves_planar_contour(): + """ + probe_planar_contour is a probe-level attribute (not part of the to_numpy + dtype), so get_slice must copy it over explicitly instead of losing it. + """ + pg = ProbeGroup() + probe = generate_dummy_probe() + contour = [[-10, -10], [-10, 100], [50, 120], [50, -10]] + probe.set_planar_contour(contour) + pg.add_probe(probe) + + sub = pg.get_slice(np.array([0, 1, 2])) + assert sub.probes[0].probe_planar_contour is not None + np.testing.assert_array_equal(sub.probes[0].probe_planar_contour, contour) + + if __name__ == "__main__": probegroup = _make_probegroup()