Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/probeinterface/probegroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,11 @@ def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup":
if k not in new_probe.annotations:
new_probe.annotate(**{k: orig_probe.annotations[k]})

# probe_planar_contour is a probe-level attribute, not part of the to_numpy dtype,
# so from_numpy cannot restore it; copy it over explicitly.
if orig_probe.probe_planar_contour is not None and new_probe.probe_planar_contour is None:
new_probe.set_planar_contour(orig_probe.probe_planar_contour)

return sliced_probe_group

def select_probes(self, probe_ids: str | np.ndarray | list) -> "ProbeGroup":
Expand Down Expand Up @@ -494,6 +499,21 @@ def select_contacts(
f"contact_ids must be unique, but {duplicated.tolist()} appear more than once. "
"If the same contact id is on multiple probes, use probe_ids to disambiguate."
)
# each requested contact id must live on exactly one probe; collect every
# ambiguous one so the user can disambiguate them all at once
ambiguous = {}
for contact_id in contact_ids:
probes_for_id = np.unique(all_probe_ids[all_contact_ids == contact_id]).tolist()
if len(probes_for_id) > 1:
ambiguous[str(contact_id)] = probes_for_id
if ambiguous:
ambiguity_lines = "\n".join(
f'"{contact_id}" lives on probes {probes}' for contact_id, probes in ambiguous.items()
)
message = f"""\
Some contact ids are ambiguous because they live on multiple probes; pass probe_ids to disambiguate which probe each belongs to:
{ambiguity_lines}"""
raise ValueError(message)
probe_ids = [None] * len(contact_ids)
else:
if len(probe_ids) != len(contact_ids):
Expand Down
43 changes: 43 additions & 0 deletions tests/test_probegroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,49 @@ def test_add_probe_default_id_with_non_numeric_ids():
assert pg.probe_ids == ["left", "right", "0"]


def test_select_contacts_ambiguous_id_message_points_to_probe_ids():
"""
When a contact id exists on several probes and no probe_ids are given, the
error must guide the user to pass probe_ids rather than claim it cannot happen.
"""
pg = _probegroup_with_contact_ids(unique=False)
expected_error = """Some contact ids are ambiguous because they live on multiple probes; pass probe_ids to disambiguate which probe each belongs to:
"c0" lives on probes ['0', '1', '2']"""
with pytest.raises(ValueError) as exc_info:
pg.select_contacts(["c0"])
assert str(exc_info.value) == expected_error


def test_select_contacts_reports_all_ambiguous_ids_at_once():
"""
When several requested contact ids are ambiguous, the error lists all of them
(with the probes each lives on) rather than failing on the first one.
"""
pg = _probegroup_with_contact_ids(unique=False)
expected_error = """Some contact ids are ambiguous because they live on multiple probes; pass probe_ids to disambiguate which probe each belongs to:
"c0" lives on probes ['0', '1', '2']
"c1" lives on probes ['0', '1', '2']"""
with pytest.raises(ValueError) as exc_info:
pg.select_contacts(["c0", "c1"])
assert str(exc_info.value) == expected_error


def test_get_slice_preserves_planar_contour():
"""
probe_planar_contour is a probe-level attribute (not part of the to_numpy
dtype), so get_slice must copy it over explicitly instead of losing it.
"""
pg = ProbeGroup()
probe = generate_dummy_probe()
contour = [[-10, -10], [-10, 100], [50, 120], [50, -10]]
probe.set_planar_contour(contour)
pg.add_probe(probe)

sub = pg.get_slice(np.array([0, 1, 2]))
assert sub.probes[0].probe_planar_contour is not None
np.testing.assert_array_equal(sub.probes[0].probe_planar_contour, contour)


if __name__ == "__main__":
probegroup = _make_probegroup()

Expand Down
Loading