From d3c2f7a20ccae5832f9a0039eb04db356b85be85 Mon Sep 17 00:00:00 2001 From: anmol thapar Date: Wed, 1 Jul 2026 08:06:40 +0100 Subject: [PATCH 1/4] get working batched version --- src/estimint/hbr.py | 181 +++++++++++-------------- src/estimint/scenarios.py | 271 +++++++++++++++++++++++++------------- 2 files changed, 251 insertions(+), 201 deletions(-) diff --git a/src/estimint/hbr.py b/src/estimint/hbr.py index 894495b..660d337 100644 --- a/src/estimint/hbr.py +++ b/src/estimint/hbr.py @@ -11,143 +11,110 @@ 5. EIR_new = EIR_baseline * (EIR_scaled / EIR_roundtrip) """ -from typing import Dict, Any, Optional +from typing import Any import pandas as pd from .run import run_xgb_model -from .storage import load_xgb_model -# Lazy-loaded model cache -_models: Dict[str, Any] = {} - - -def _get_model(name: str) -> Dict[str, Any]: - """Load and cache a bundled model by name.""" - if name not in _models: - _models[name] = load_xgb_model(name) - return _models[name] - - -def estimate_eir_with_mosquito_delta( - prevalence: float, - mosquito_delta: float, - dn0_use: float, - Q0: float, - phi_bednets: float, - seasonal: float, - itn_use: float, - irs_use: float, - prev_model: Optional[Dict[str, Any]] = None, - hbr_model: Optional[Dict[str, Any]] = None, - eir_to_hbr_model: Optional[Dict[str, Any]] = None, -) -> Dict[str, float]: +def estimate_eir_with_mosquito_delta_batch(inputs: pd.DataFrame, *, models: dict[str, Any]) -> pd.DataFrame: """ - Estimate the new EIR after a change in mosquito density. + Estimate new EIR after a change in mosquito density for multiple scenarios. - Uses a ratio approach: the HBR model predicts EIR at both the baseline - and scaled HBR, then applies the relative multiplier to the clean + The HBR model predicts EIR at both the + baseline and scaled HBR, then applies the relative multiplier to the clean baseline EIR from the prevalence model. Parameters ---------- - prevalence : float - Baseline malaria prevalence (prev_y9), e.g. 0.30 for 30%. - mosquito_delta : float - Fractional change in mosquito density, e.g. 0.10 for +10%, -0.50 for -50%. - Must be > -1 (cannot eliminate more mosquitoes than exist). - dn0_use : float - Bednet contact reduction parameter. - Q0 : float - Human blood index. - phi_bednets : float - Proportion of bites on humans while in bed. - seasonal : float - Seasonality flag (0.0 or 1.0). - itn_use : float - ITN coverage (0-1). - irs_use : float - IRS coverage (0-1). - prev_model : dict, optional - Custom prevalence model. If None, uses bundled model. - hbr_model : dict, optional - Custom HBR->EIR model. If None, uses bundled model. - eir_to_hbr_model : dict, optional - Custom EIR->HBR model. If None, uses bundled model. + inputs : pd.DataFrame + One row per scenario. Required columns: + + - ``prevalence`` : baseline malaria prevalence (prev_y9), e.g. 0.30 for 30%. + - ``mosquito_delta`` : fractional change in mosquito density, e.g. 0.10 + for +10%, -0.50 for -50%. Must be > -1 per row. + - ``dn0_use`` : bednet contact reduction parameter. + - ``Q0`` : human blood index. + - ``phi_bednets`` : proportion of bites on humans while in bed. + - ``seasonal`` : seasonality flag (0.0 or 1.0). + - ``itn_use`` : ITN coverage (0-1). + - ``irs_use`` : IRS coverage (0-1). + + models : dict + Pre-loaded model dictionary with keys ``"prevalence"``, ``"hbr"``, + and ``"eir_to_hbr"``. Returns ------- - dict - Dictionary with keys: - - eir_baseline: Baseline EIR from prevalence - - eir_new: New EIR after mosquito density change - - eir_multiplier: Ratio of new EIR to baseline - - hbr_baseline: Estimated baseline HBR - - hbr_new: HBR after mosquito density change + pd.DataFrame + Same index as *inputs*, with columns: + + - ``eir_baseline`` : baseline EIR from prevalence. + - ``eir_new`` : new EIR after mosquito density change. + - ``eir_multiplier`` : ratio of new EIR to baseline. + - ``hbr_baseline`` : estimated baseline HBR. + - ``hbr_new`` : HBR after mosquito density change. Examples -------- - >>> from estimint import estimate_eir_with_mosquito_delta - >>> result = estimate_eir_with_mosquito_delta( - ... prevalence=0.30, mosquito_delta=0.25, - ... dn0_use=0.33, Q0=0.87, phi_bednets=0.82, - ... seasonal=0.0, itn_use=0.6, irs_use=0.0, - ... ) - >>> print(f"EIR: {result['eir_baseline']:.1f} -> {result['eir_new']:.1f}") + >>> import pandas as pd + >>> from estimint import estimate_eir_with_mosquito_delta_batch + >>> inputs = pd.DataFrame([ + ... {"prevalence": 0.30, "mosquito_delta": 0.25, + ... "dn0_use": 0.33, "Q0": 0.87, "phi_bednets": 0.82, + ... "seasonal": 0.0, "itn_use": 0.6, "irs_use": 0.0}, + ... ]) + >>> result = estimate_eir_with_mosquito_delta_batch(inputs, models=models) + >>> print(result[["eir_baseline", "eir_new"]]) """ - if prev_model is None: - prev_model = _get_model("prevalence") - if hbr_model is None: - hbr_model = _get_model("hbr") - if eir_to_hbr_model is None: - eir_to_hbr_model = _get_model("eir_to_hbr") - - intv = { - "dn0_use": [dn0_use], - "Q0": [Q0], - "phi_bednets": [phi_bednets], - "seasonal": [seasonal], - "itn_use": [itn_use], - "irs_use": [irs_use], - } + features = [ + "dn0_use", + "Q0", + "phi_bednets", + "seasonal", + "itn_use", + "irs_use", + ] + intervention_data = inputs[features] # Step 1: prevalence -> EIR baseline - X_prev = pd.DataFrame({"prev_y9": [prevalence], **intv}) - eir_baseline = float(run_xgb_model(X_prev, prev_model)[0]) - - if mosquito_delta == 0: - return { - "eir_baseline": eir_baseline, - "eir_new": eir_baseline, - "eir_multiplier": 1.0, - "hbr_baseline": 0.0, - "hbr_new": 0.0, - } + prevalence_data = intervention_data.assign(prev_y9=inputs["prevalence"].to_numpy()) + eir_baseline = run_xgb_model(prevalence_data, models["prevalence"]) # Step 2: EIR -> HBR baseline - X_eir = pd.DataFrame({"eir": [eir_baseline], **intv}) - hbr_baseline = float(run_xgb_model(X_eir, eir_to_hbr_model)[0]) + eir_data = intervention_data.assign(eir=eir_baseline) + hbr_baseline = run_xgb_model(eir_data, models["eir_to_hbr"]) # Step 3: apply mosquito delta (positive or negative) - hbr_new = hbr_baseline * (1 + mosquito_delta) + hbr_new = hbr_baseline * (1 + inputs["mosquito_delta"].to_numpy()) # Step 4: ratio approach — batch both HBR values in one call so they # share the same smooth PCHIP curve - intv2 = {k: v * 2 for k, v in intv.items()} # repeat for 2 rows - X_hbr = pd.DataFrame({"hbr_y9": [hbr_baseline, hbr_new], **intv2}) - eir_both = run_xgb_model(X_hbr, hbr_model) - eir_rt = float(eir_both[0]) - eir_new_raw = float(eir_both[1]) + hbr_data = pd.concat( + [ + intervention_data.assign(hbr_y9=hbr_baseline), + intervention_data.assign(hbr_y9=hbr_new), + ], + ignore_index=True, + ) + eir_from_hbr = run_xgb_model(hbr_data, models["hbr"]) + + count = len(inputs) + eir_rt = eir_from_hbr[:count] + eir_new_raw = eir_from_hbr[count:] # Step 5: multiplier applied to clean baseline multiplier = eir_new_raw / eir_rt eir_new = eir_baseline * multiplier - return { - "eir_baseline": eir_baseline, - "eir_new": eir_new, - "eir_multiplier": multiplier, - "hbr_baseline": hbr_baseline, - "hbr_new": hbr_new, - } + return pd.DataFrame( + { + "eir_baseline": eir_baseline, + "eir_new": eir_new, + "eir_multiplier": multiplier, + "hbr_baseline": hbr_baseline, + "hbr_new": hbr_new, + }, + index=inputs.index, + ) diff --git a/src/estimint/scenarios.py b/src/estimint/scenarios.py index fd76d88..eace341 100644 --- a/src/estimint/scenarios.py +++ b/src/estimint/scenarios.py @@ -1,18 +1,18 @@ -"""One-call estiMINT -> stateMINT scenario runner (stateMINT imported lazily).""" - from __future__ import annotations +from dataclasses import dataclass from typing import Any, Dict import numpy as np import pandas as pd -from .bednet import calculate_dn0, DN0Result +from .bednet import DN0Result, calculate_dn0 +from .hbr import estimate_eir_with_mosquito_delta_batch from .run import run_xgb_model -from .hbr import estimate_eir_with_mosquito_delta from .storage import load_xgb_model -from .types import Scenario +from .types import EirTarget, Scenario +####################### Constants and global storage ################### HF_REPO = "dide-ic/stateMINT" # 157 windows of 14 days from day 2190; intervention at day 3285. _ABS_T = 2190 + 14 * np.arange(157) @@ -28,29 +28,19 @@ "py_ppf", ) - -def _bednet(scenario: Scenario): - """Current and future net (dn0, itn_use); returns (cur, net_now, net_next).""" - cur_nets = {net_type: getattr(scenario, net_type) for net_type in _NET_KEYS if getattr(scenario, net_type)} - - net_now = calculate_dn0(scenario.res_use, **cur_nets) if cur_nets else DN0Result(0.0, 0.0) - - net_type_future = scenario.net_type_future - if scenario.itn_future == 0.0 or not net_type_future: - net_next = DN0Result(0.0, 0.0) - else: - net_next = calculate_dn0(scenario.res_use, **{net_type_future: scenario.itn_future}) - return cur_nets, net_now, net_next +_EST_MODEL_NAMES = ("prevalence", "hbr", "eir_to_hbr") +_EIR_INPUT_COLS = {"prevalence": ("prev_y9", "prevalence"), "hbr": ("hbr_y9", "hbr")} -def _est_models() -> Dict[str, Any]: +######################## Internal helpers ######################## +def _load_eir_hbr_models() -> Dict[str, Any]: if not _MODELS: - _MODELS["prevalence"] = load_xgb_model("prevalence") - _MODELS["hbr"] = load_xgb_model("hbr") + for name in _EST_MODEL_NAMES: + _MODELS[name] = load_xgb_model(name) return _MODELS -def _emulators(hf_repo: str) -> Dict[str, Any]: +def _load_emulators(hf_repo: str) -> Dict[str, Any]: if hf_repo not in _EMULATORS: try: from stateMINT.model import Mamba2Regressor @@ -61,87 +51,178 @@ def _emulators(hf_repo: str) -> Dict[str, Any]: '"git+https://github.com/mrc-ide/stateMINT.git@mamba2-train").' ) from e _EMULATORS[hf_repo] = { - p: Mamba2Regressor.from_pretrained(hf_repo, predictor=p) for p in ("prevalence", "cases") + predictor: Mamba2Regressor.from_pretrained(hf_repo, predictor=predictor) + for predictor in ("prevalence", "cases") } return _EMULATORS[hf_repo] -_EIR_INPUT_COLS = {"prevalence": ("prev_y9", "prevalence"), "hbr": ("hbr_y9", "hbr")} +def _bednet(scenario: Scenario): + """Current and future net (dn0, itn_use); returns (cur, net_now, net_next).""" + cur_nets = {net_type: getattr(scenario, net_type) for net_type in _NET_KEYS if getattr(scenario, net_type)} + + net_now = calculate_dn0(scenario.res_use, **cur_nets) if cur_nets else DN0Result(0.0, 0.0) + + net_type_future = scenario.net_type_future + if scenario.itn_future == 0.0 or not net_type_future: + net_next = DN0Result(0.0, 0.0) + else: + net_next = calculate_dn0(scenario.res_use, **{net_type_future: scenario.itn_future}) + return cur_nets, net_now, net_next -def _estimate_eir(scenario: Scenario, eir_models: Dict[str, Any]) -> Dict[str, Any]: - """scenario -> EIR + the stateMINT covariate dict.""" +@dataclass +class _ScenarioWork: + scenario: Scenario + eir_target: EirTarget + mosquito_delta: float + eir_features: Dict[str, float] + row: dict[str, Any] + cov: dict[str, float] + + +def _prepare_scenario(scenario: Scenario) -> _ScenarioWork: + """Compute non-model scenario values once.""" cur_nets, net_now, net_next = _bednet(scenario) dn0_use, itn_use = net_now.dn0, net_now.itn_use dn0_future, itn_future = net_next.dn0, net_next.itn_use - Q0, phi, seasonal = scenario.Q0, scenario.phi, scenario.seasonal - irs_use = scenario.irs - irs_future = scenario.irs_future - routine = scenario.routine - ppf = scenario.py_ppf lsm = scenario.lsm - if ppf > 0: - lsm = min(ppf * 0.248 + lsm, 1.0) - feats = dict(dn0_use=dn0_use, Q0=Q0, phi_bednets=phi, seasonal=seasonal, itn_use=itn_use, irs_use=irs_use) - - # ── EIR by input mode ── - delta = scenario.mosquito_delta - hbr_baseline, hbr_new = np.nan, np.nan - input_mode, input_value = scenario.eir_target.input_mode, scenario.eir_target.input_value - - if input_mode == "eir": - eir_base = eir_final = input_value - elif delta and input_mode == "prevalence": - r = estimate_eir_with_mosquito_delta(prevalence=input_value, mosquito_delta=delta, **feats) - eir_base, eir_final = r["eir_baseline"], r["eir_new"] - hbr_baseline, hbr_new = r["hbr_baseline"], r["hbr_new"] - elif input_mode in _EIR_INPUT_COLS: - col, model_key = _EIR_INPUT_COLS[input_mode] - eir_base = eir_final = float( - run_xgb_model( - pd.DataFrame({**{k: [v] for k, v in feats.items()}, col: [input_value]}), - eir_models[model_key], - )[0] - ) - else: - raise ValueError(f"'input' must be 'prevalence', 'hbr' or 'eir'; got {input_mode!r}") - - cov = dict( - eir=eir_final, - dn0_use=dn0_use, - dn0_future=dn0_future, - Q0=Q0, - phi_bednets=phi, - seasonal=seasonal, - routine=routine, - itn_use=itn_use, - irs_use=irs_use, - itn_future=itn_future, - irs_future=irs_future, - lsm=lsm, - ) - row = dict( - name=scenario.name, - input_mode=input_mode, - net="+".join(cur_nets) or "none", - net_future=scenario.net_type_future or "none", - dn0_use=dn0_use, - itn_use=itn_use, - irs_use=irs_use, - dn0_future=dn0_future, - itn_future=itn_future, - irs_future=irs_future, - routine=routine, - lsm=lsm, - seasonal=seasonal, - eir_baseline=eir_base, - mosquito_delta=delta, - eir_final=eir_final, - hbr_baseline=hbr_baseline, - hbr_new=hbr_new, + if scenario.py_ppf > 0: + lsm = min(scenario.py_ppf * 0.248 + lsm, 1.0) + + eir_features = { + "dn0_use": dn0_use, + "Q0": scenario.Q0, + "phi_bednets": scenario.phi, + "seasonal": scenario.seasonal, + "itn_use": itn_use, + "irs_use": scenario.irs, + } + + cov = { + "eir": np.nan, + "dn0_use": dn0_use, + "dn0_future": dn0_future, + "Q0": scenario.Q0, + "phi_bednets": scenario.phi, + "seasonal": scenario.seasonal, + "routine": scenario.routine, + "itn_use": itn_use, + "irs_use": scenario.irs, + "itn_future": itn_future, + "irs_future": scenario.irs_future, + "lsm": lsm, + } + + row = { + "name": scenario.name, + "input_mode": scenario.eir_target.input_mode, + "net": "+".join(cur_nets) or "none", + "net_future": scenario.net_type_future or "none", + "dn0_use": dn0_use, + "itn_use": itn_use, + "irs_use": scenario.irs, + "dn0_future": dn0_future, + "itn_future": itn_future, + "irs_future": scenario.irs_future, + "routine": scenario.routine, + "lsm": lsm, + "seasonal": scenario.seasonal, + "eir_baseline": np.nan, + "mosquito_delta": scenario.mosquito_delta, + "eir_final": np.nan, + "hbr_baseline": np.nan, + "hbr_new": np.nan, + } + + return _ScenarioWork( + scenario=scenario, + eir_target=scenario.eir_target, + mosquito_delta=scenario.mosquito_delta, + eir_features=eir_features, + row=row, + cov=cov, ) - return {"row": row, "cov": cov} + + +def _set_eir_result( + work: _ScenarioWork, + *, + eir_baseline: float, + eir_final: float, + hbr_baseline: float = np.nan, + hbr_new: float = np.nan, +) -> None: + work.row["eir_baseline"] = float(eir_baseline) + work.row["eir_final"] = float(eir_final) + work.row["hbr_baseline"] = float(hbr_baseline) + work.row["hbr_new"] = float(hbr_new) + work.cov["eir"] = float(eir_final) + + +def _predict_direct_inputs(works: list[_ScenarioWork], *, input_mode: str, eir_models: Dict[str, Any]) -> None: + """Predict EIR for scenarios with direct inputs (prevalence or hbr).""" + col, model_key = _EIR_INPUT_COLS[input_mode] + rows = [{**work.eir_features, col: work.eir_target.input_value} for work in works] + + predictions = run_xgb_model(pd.DataFrame(rows), eir_models[model_key]) + + for work, prediction in zip(works, predictions): + _set_eir_result(work, eir_baseline=prediction, eir_final=prediction) + + +def _estimate_eir_many(scenarios: list[Scenario], eir_models: Dict[str, Any]) -> list[Dict[str, Any]]: + """Estimate EIR for many scenarios.""" + if any(scenario.eir_target.input_mode not in {"prevalence", "eir", "hbr"} for scenario in scenarios): + raise ValueError("All scenarios must have input_mode in {'prevalence', 'eir', 'hbr'}") + + works = [_prepare_scenario(scenario) for scenario in scenarios] + + eir_works = [work for work in works if work.eir_target.input_mode == "eir"] + for work in eir_works: + _set_eir_result(work, eir_baseline=work.eir_target.input_value, eir_final=work.eir_target.input_value) + + prevalence_works = [ + work for work in works if work.eir_target.input_mode == "prevalence" and not work.mosquito_delta + ] + if prevalence_works: + _predict_direct_inputs(prevalence_works, input_mode="prevalence", eir_models=eir_models) + + hbr_works = [work for work in works if work.eir_target.input_mode == "hbr"] + if hbr_works: + _predict_direct_inputs(hbr_works, input_mode="hbr", eir_models=eir_models) + + mosquito_delta_works = [ + work for work in works if work.eir_target.input_mode == "prevalence" and work.mosquito_delta + ] + if mosquito_delta_works: + mosquito_deltas_inputs = pd.DataFrame( + [ + {"prevalence": work.eir_target.input_value, "mosquito_delta": work.mosquito_delta, **work.eir_features} + for work in mosquito_delta_works + ] + ) + mosquito_delta_results = estimate_eir_with_mosquito_delta_batch(mosquito_deltas_inputs, models=eir_models) + for work, result in zip(mosquito_delta_works, mosquito_delta_results.to_dict(orient="records")): + _set_eir_result( + work, + eir_baseline=result["eir_baseline"], + eir_final=result["eir_new"], + hbr_baseline=result["hbr_baseline"], + hbr_new=result["hbr_new"], + ) + + return [{"row": work.row, "cov": work.cov} for work in works] + + +######################### Public API ######################## +def preload_models(*, hf_repo: str = HF_REPO) -> tuple[Dict[str, Any], Dict[str, Any]]: + """Preload the models used by run_scenarios.""" + eir_models = _load_eir_hbr_models() + emulator_models = _load_emulators(hf_repo) + + return eir_models, emulator_models def run_scenarios( @@ -150,11 +231,13 @@ def run_scenarios( hf_repo: str = HF_REPO, ) -> pd.DataFrame: """Run a list of scenarios through the estiMINT -> stateMINT pipeline.""" + if not scenarios: + return pd.DataFrame() - eir_models = _est_models() - emulator_models = _emulators(hf_repo) + eir_models, emulator_models = preload_models(hf_repo=hf_repo) - parts = [_estimate_eir(scenario, eir_models) for scenario in scenarios] + # parts = [_estimate_eir(scenario, eir_models) for scenario in scenarios] + parts = _estimate_eir_many(scenarios, eir_models) covs = [p["cov"] for p in parts] prev = emulator_models["prevalence"].predict(covs) cases = np.maximum(emulator_models["cases"].predict(covs), 0.0) From d1886f368f6a8b201ee228169d39bf6ac8d01b18 Mon Sep 17 00:00:00 2001 From: anmol thapar Date: Wed, 1 Jul 2026 09:33:21 +0100 Subject: [PATCH 2/4] refactor and clean up tests and add --- src/estimint/hbr.py | 6 +- src/estimint/scenarios.py | 364 +++++++++++++++++++++++++------------- tests/test_flows.py | 40 +++-- tests/test_scenarios.py | 230 +++++++++++++++++------- 4 files changed, 436 insertions(+), 204 deletions(-) diff --git a/src/estimint/hbr.py b/src/estimint/hbr.py index 660d337..fa70135 100644 --- a/src/estimint/hbr.py +++ b/src/estimint/hbr.py @@ -18,7 +18,7 @@ from .run import run_xgb_model -def estimate_eir_with_mosquito_delta_batch(inputs: pd.DataFrame, *, models: dict[str, Any]) -> pd.DataFrame: +def estimate_eir_with_mosquito_delta(inputs: pd.DataFrame, *, models: dict[str, Any]) -> pd.DataFrame: """ Estimate new EIR after a change in mosquito density for multiple scenarios. @@ -59,13 +59,13 @@ def estimate_eir_with_mosquito_delta_batch(inputs: pd.DataFrame, *, models: dict Examples -------- >>> import pandas as pd - >>> from estimint import estimate_eir_with_mosquito_delta_batch + >>> from estimint import estimate_eir_with_mosquito_delta >>> inputs = pd.DataFrame([ ... {"prevalence": 0.30, "mosquito_delta": 0.25, ... "dn0_use": 0.33, "Q0": 0.87, "phi_bednets": 0.82, ... "seasonal": 0.0, "itn_use": 0.6, "irs_use": 0.0}, ... ]) - >>> result = estimate_eir_with_mosquito_delta_batch(inputs, models=models) + >>> result = estimate_eir_with_mosquito_delta(inputs, models=models) >>> print(result[["eir_baseline", "eir_new"]]) """ features = [ diff --git a/src/estimint/scenarios.py b/src/estimint/scenarios.py index eace341..7f7ad73 100644 --- a/src/estimint/scenarios.py +++ b/src/estimint/scenarios.py @@ -7,19 +7,20 @@ import pandas as pd from .bednet import DN0Result, calculate_dn0 -from .hbr import estimate_eir_with_mosquito_delta_batch +from .hbr import estimate_eir_with_mosquito_delta from .run import run_xgb_model from .storage import load_xgb_model from .types import EirTarget, Scenario +from collections import defaultdict ####################### Constants and global storage ################### HF_REPO = "dide-ic/stateMINT" # 157 windows of 14 days from day 2190; intervention at day 3285. -_ABS_T = 2190 + 14 * np.arange(157) -_IDX_Y9 = int(np.argmin(np.abs(_ABS_T - 3285))) +_ABS_TIME = 2190 + 14 * np.arange(157) +_IDX_Y9 = int(np.argmin(np.abs(_ABS_TIME - 3285))) -_MODELS: Dict[str, Any] = {} -_EMULATORS: Dict[str, Dict[str, Any]] = {} +_EIR_MODEL_CACHE: Dict[str, Any] = {} +_EMULATOR_MODEL_CACHE: Dict[str, Dict[str, Any]] = {} _NET_KEYS = ( "py_only", @@ -28,106 +29,130 @@ "py_ppf", ) -_EST_MODEL_NAMES = ("prevalence", "hbr", "eir_to_hbr") -_EIR_INPUT_COLS = {"prevalence": ("prev_y9", "prevalence"), "hbr": ("hbr_y9", "hbr")} +_REQUIRED_EIR_MODEL_NAMES = ("prevalence", "hbr", "eir_to_hbr") + + +@dataclass(frozen=True) +class _EirInputModelConfig: + feature_column: str + model_name: str + + +_EIR_INPUT_MODEL_CONFIG = { + "prevalence": _EirInputModelConfig(feature_column="prev_y9", model_name="prevalence"), + "hbr": _EirInputModelConfig(feature_column="hbr_y9", model_name="hbr"), +} ######################## Internal helpers ######################## def _load_eir_hbr_models() -> Dict[str, Any]: - if not _MODELS: - for name in _EST_MODEL_NAMES: - _MODELS[name] = load_xgb_model(name) - return _MODELS + if not _EIR_MODEL_CACHE: + for model_name in _REQUIRED_EIR_MODEL_NAMES: + _EIR_MODEL_CACHE[model_name] = load_xgb_model(model_name) + return _EIR_MODEL_CACHE def _load_emulators(hf_repo: str) -> Dict[str, Any]: - if hf_repo not in _EMULATORS: + if hf_repo not in _EMULATOR_MODEL_CACHE: try: from stateMINT.model import Mamba2Regressor - except ImportError as e: + except ImportError as error: raise ImportError( "run_scenarios needs stateMINT. Install it with: " "uv sync --extra scenarios (or pip install " '"git+https://github.com/mrc-ide/stateMINT.git@mamba2-train").' - ) from e - _EMULATORS[hf_repo] = { - predictor: Mamba2Regressor.from_pretrained(hf_repo, predictor=predictor) - for predictor in ("prevalence", "cases") + ) from error + _EMULATOR_MODEL_CACHE[hf_repo] = { + outcome_name: Mamba2Regressor.from_pretrained(hf_repo, predictor=outcome_name) + for outcome_name in ("prevalence", "cases") } - return _EMULATORS[hf_repo] + return _EMULATOR_MODEL_CACHE[hf_repo] -def _bednet(scenario: Scenario): - """Current and future net (dn0, itn_use); returns (cur, net_now, net_next).""" - cur_nets = {net_type: getattr(scenario, net_type) for net_type in _NET_KEYS if getattr(scenario, net_type)} +@dataclass(frozen=True) +class _BedNetEffects: + current_coverages: dict[str, float] + current: DN0Result + future: DN0Result - net_now = calculate_dn0(scenario.res_use, **cur_nets) if cur_nets else DN0Result(0.0, 0.0) - net_type_future = scenario.net_type_future - if scenario.itn_future == 0.0 or not net_type_future: - net_next = DN0Result(0.0, 0.0) +def _calculate_bednet_effects(scenario: Scenario) -> _BedNetEffects: + """Calculate current and future bed-net coverage and transmission effects.""" + current_coverages = {net_type: getattr(scenario, net_type) for net_type in _NET_KEYS if getattr(scenario, net_type)} + + current_effect = calculate_dn0(scenario.res_use, **current_coverages) if current_coverages else DN0Result(0.0, 0.0) + + future_net_type = scenario.net_type_future + if scenario.itn_future == 0.0 or not future_net_type: + future_effect = DN0Result(0.0, 0.0) else: - net_next = calculate_dn0(scenario.res_use, **{net_type_future: scenario.itn_future}) - return cur_nets, net_now, net_next + future_effect = calculate_dn0(scenario.res_use, **{future_net_type: scenario.itn_future}) + + return _BedNetEffects( + current_coverages=current_coverages, + current=current_effect, + future=future_effect, + ) @dataclass -class _ScenarioWork: - scenario: Scenario +class _PreparedScenario: eir_target: EirTarget - mosquito_delta: float - eir_features: Dict[str, float] - row: dict[str, Any] - cov: dict[str, float] + mosquito_density_change: float + eir_model_features: Dict[str, float] + summary_values: dict[str, Any] + emulator_covariates: dict[str, float] -def _prepare_scenario(scenario: Scenario) -> _ScenarioWork: - """Compute non-model scenario values once.""" - cur_nets, net_now, net_next = _bednet(scenario) - dn0_use, itn_use = net_now.dn0, net_now.itn_use - dn0_future, itn_future = net_next.dn0, net_next.itn_use +def _prepare_scenario_inputs(scenario: Scenario) -> _PreparedScenario: + """Compute the model inputs and initial output values for one scenario.""" + bednet_effects = _calculate_bednet_effects(scenario) + current_dn0 = bednet_effects.current.dn0 + current_itn_use = bednet_effects.current.itn_use + future_dn0 = bednet_effects.future.dn0 + future_itn_use = bednet_effects.future.itn_use - lsm = scenario.lsm + adjusted_lsm_coverage = scenario.lsm if scenario.py_ppf > 0: - lsm = min(scenario.py_ppf * 0.248 + lsm, 1.0) + adjusted_lsm_coverage = min(scenario.py_ppf * 0.248 + adjusted_lsm_coverage, 1.0) - eir_features = { - "dn0_use": dn0_use, + eir_model_features = { + "dn0_use": current_dn0, "Q0": scenario.Q0, "phi_bednets": scenario.phi, "seasonal": scenario.seasonal, - "itn_use": itn_use, + "itn_use": current_itn_use, "irs_use": scenario.irs, } - cov = { + emulator_covariates = { "eir": np.nan, - "dn0_use": dn0_use, - "dn0_future": dn0_future, + "dn0_use": current_dn0, + "dn0_future": future_dn0, "Q0": scenario.Q0, "phi_bednets": scenario.phi, "seasonal": scenario.seasonal, "routine": scenario.routine, - "itn_use": itn_use, + "itn_use": current_itn_use, "irs_use": scenario.irs, - "itn_future": itn_future, + "itn_future": future_itn_use, "irs_future": scenario.irs_future, - "lsm": lsm, + "lsm": adjusted_lsm_coverage, } - row = { + summary_values = { "name": scenario.name, "input_mode": scenario.eir_target.input_mode, - "net": "+".join(cur_nets) or "none", + "net": "+".join(bednet_effects.current_coverages) or "none", "net_future": scenario.net_type_future or "none", - "dn0_use": dn0_use, - "itn_use": itn_use, + "dn0_use": current_dn0, + "itn_use": current_itn_use, "irs_use": scenario.irs, - "dn0_future": dn0_future, - "itn_future": itn_future, + "dn0_future": future_dn0, + "itn_future": future_itn_use, "irs_future": scenario.irs_future, "routine": scenario.routine, - "lsm": lsm, + "lsm": adjusted_lsm_coverage, "seasonal": scenario.seasonal, "eir_baseline": np.nan, "mosquito_delta": scenario.mosquito_delta, @@ -136,84 +161,111 @@ def _prepare_scenario(scenario: Scenario) -> _ScenarioWork: "hbr_new": np.nan, } - return _ScenarioWork( - scenario=scenario, + return _PreparedScenario( eir_target=scenario.eir_target, - mosquito_delta=scenario.mosquito_delta, - eir_features=eir_features, - row=row, - cov=cov, + mosquito_density_change=scenario.mosquito_delta, + eir_model_features=eir_model_features, + summary_values=summary_values, + emulator_covariates=emulator_covariates, ) -def _set_eir_result( - work: _ScenarioWork, +def _record_eir_estimate( + prepared_scenario: _PreparedScenario, *, eir_baseline: float, eir_final: float, hbr_baseline: float = np.nan, hbr_new: float = np.nan, ) -> None: - work.row["eir_baseline"] = float(eir_baseline) - work.row["eir_final"] = float(eir_final) - work.row["hbr_baseline"] = float(hbr_baseline) - work.row["hbr_new"] = float(hbr_new) - work.cov["eir"] = float(eir_final) + prepared_scenario.summary_values["eir_baseline"] = float(eir_baseline) + prepared_scenario.summary_values["eir_final"] = float(eir_final) + prepared_scenario.summary_values["hbr_baseline"] = float(hbr_baseline) + prepared_scenario.summary_values["hbr_new"] = float(hbr_new) + prepared_scenario.emulator_covariates["eir"] = float(eir_final) -def _predict_direct_inputs(works: list[_ScenarioWork], *, input_mode: str, eir_models: Dict[str, Any]) -> None: - """Predict EIR for scenarios with direct inputs (prevalence or hbr).""" - col, model_key = _EIR_INPUT_COLS[input_mode] - rows = [{**work.eir_features, col: work.eir_target.input_value} for work in works] +def _predict_eir_from_measurements( + prepared_scenarios: list[_PreparedScenario], *, input_mode: str, eir_models: Dict[str, Any] +) -> None: + """Predict EIR from baseline prevalence or HBR measurements.""" + model_config = _EIR_INPUT_MODEL_CONFIG[input_mode] + model_input_records = [ + { + **prepared_scenario.eir_model_features, + model_config.feature_column: prepared_scenario.eir_target.input_value, + } + for prepared_scenario in prepared_scenarios + ] + + eir_predictions = run_xgb_model(pd.DataFrame(model_input_records), eir_models[model_config.model_name]) + + for prepared_scenario, eir_prediction in zip(prepared_scenarios, eir_predictions): + _record_eir_estimate( + prepared_scenario, + eir_baseline=eir_prediction, + eir_final=eir_prediction, + ) - predictions = run_xgb_model(pd.DataFrame(rows), eir_models[model_key]) - for work, prediction in zip(works, predictions): - _set_eir_result(work, eir_baseline=prediction, eir_final=prediction) +def _classify_prepared_scenario(prepared_scenario: _PreparedScenario) -> str: + """Return the EIR estimation method for a prepared scenario.""" + if prepared_scenario.eir_target.input_mode == "eir": + return "eir" + if prepared_scenario.eir_target.input_mode == "prevalence" and prepared_scenario.mosquito_density_change != 0.0: + return "mosquito_delta" + return prepared_scenario.eir_target.input_mode # "prevalence" or "hbr" -def _estimate_eir_many(scenarios: list[Scenario], eir_models: Dict[str, Any]) -> list[Dict[str, Any]]: - """Estimate EIR for many scenarios.""" +def _apply_mosquito_delta_batch(prepared_scenarios: list[_PreparedScenario], eir_models: Dict[str, Any]) -> None: + inputs = pd.DataFrame( + [ + { + "prevalence": prepared_scenario.eir_target.input_value, + "mosquito_delta": prepared_scenario.mosquito_density_change, + **prepared_scenario.eir_model_features, + } + for prepared_scenario in prepared_scenarios + ] + ) + estimates = estimate_eir_with_mosquito_delta(inputs, models=eir_models).to_dict(orient="records") + for prepared_scenario, estimate in zip(prepared_scenarios, estimates): + _record_eir_estimate( + prepared_scenario, + eir_baseline=estimate["eir_baseline"], + eir_final=estimate["eir_new"], + hbr_baseline=estimate["hbr_baseline"], + hbr_new=estimate["hbr_new"], + ) + + +def _estimate_eir_many(scenarios: list[Scenario], eir_models: Dict[str, Any]) -> list[_PreparedScenario]: + """Estimate EIR for many scenarios, dispatching each to one of three paths: + - "eir": supplied directly, passed through unchanged + - "prevalence" / "hbr": predicted from baseline measurements via XGBoost + - "mosquito_delta": prevalence input with a projected mosquito-density change + """ if any(scenario.eir_target.input_mode not in {"prevalence", "eir", "hbr"} for scenario in scenarios): raise ValueError("All scenarios must have input_mode in {'prevalence', 'eir', 'hbr'}") - works = [_prepare_scenario(scenario) for scenario in scenarios] + prepared_scenarios = [_prepare_scenario_inputs(scenario) for scenario in scenarios] - eir_works = [work for work in works if work.eir_target.input_mode == "eir"] - for work in eir_works: - _set_eir_result(work, eir_baseline=work.eir_target.input_value, eir_final=work.eir_target.input_value) + scenario_groups: dict[str, list[_PreparedScenario]] = defaultdict(list) + for prepared_scenario in prepared_scenarios: + scenario_groups[_classify_prepared_scenario(prepared_scenario)].append(prepared_scenario) - prevalence_works = [ - work for work in works if work.eir_target.input_mode == "prevalence" and not work.mosquito_delta - ] - if prevalence_works: - _predict_direct_inputs(prevalence_works, input_mode="prevalence", eir_models=eir_models) + for prepared_scenario in scenario_groups["eir"]: + supplied_eir = prepared_scenario.eir_target.input_value + _record_eir_estimate(prepared_scenario, eir_baseline=supplied_eir, eir_final=supplied_eir) - hbr_works = [work for work in works if work.eir_target.input_mode == "hbr"] - if hbr_works: - _predict_direct_inputs(hbr_works, input_mode="hbr", eir_models=eir_models) + for input_mode in ("prevalence", "hbr"): + if scenario_groups[input_mode]: + _predict_eir_from_measurements(scenario_groups[input_mode], input_mode=input_mode, eir_models=eir_models) - mosquito_delta_works = [ - work for work in works if work.eir_target.input_mode == "prevalence" and work.mosquito_delta - ] - if mosquito_delta_works: - mosquito_deltas_inputs = pd.DataFrame( - [ - {"prevalence": work.eir_target.input_value, "mosquito_delta": work.mosquito_delta, **work.eir_features} - for work in mosquito_delta_works - ] - ) - mosquito_delta_results = estimate_eir_with_mosquito_delta_batch(mosquito_deltas_inputs, models=eir_models) - for work, result in zip(mosquito_delta_works, mosquito_delta_results.to_dict(orient="records")): - _set_eir_result( - work, - eir_baseline=result["eir_baseline"], - eir_final=result["eir_new"], - hbr_baseline=result["hbr_baseline"], - hbr_new=result["hbr_new"], - ) + if scenario_groups["mosquito_delta"]: + _apply_mosquito_delta_batch(scenario_groups["mosquito_delta"], eir_models) - return [{"row": work.row, "cov": work.cov} for work in works] + return prepared_scenarios ######################### Public API ######################## @@ -230,28 +282,94 @@ def run_scenarios( *, hf_repo: str = HF_REPO, ) -> pd.DataFrame: - """Run a list of scenarios through the estiMINT -> stateMINT pipeline.""" + """Run a list of scenarios through the estiMINT -> stateMINT pipeline. + + For each scenario, estimates the entomological inoculation rate (EIR) from + the given input (prevalence, EIR, or HBR), then feeds the resulting + covariate vector into emulator models to predict malaria prevalence and + case burden over time. + + Args: + scenarios: Scenarios to evaluate. Each ``Scenario`` describes + intervention coverages (ITN, IRS, LSM, etc.) and an + ``EirTarget`` specifying the baseline transmission intensity. + hf_repo: HuggingFace repo ID from which emulator model weights are + downloaded. Defaults to the package-level ``HF_REPO`` constant. + + Returns: + A ``pd.DataFrame`` with one row per scenario containing: + + - ``name``, ``input_mode``, ``net``, ``net_future`` — scenario labels + - ``dn0_use``, ``itn_use``, ``irs_use``, ``dn0_future``, + ``itn_future``, ``irs_future``, ``routine``, ``lsm``, + ``seasonal``, ``mosquito_delta`` — intervention covariates + - ``eir_baseline``, ``eir_final`` — estimated EIR before and after + interventions + - ``hbr_baseline``, ``hbr_new`` — human biting rate (populated for + prevalence inputs with a mosquito-density change, otherwise ``NaN``) + - ``prev_y9`` — prevalence at year 9 (≈3285 days) + - ``prev_endline`` — prevalence at the final time step + - ``cases_endline`` — cases at the final time step + - ``prevalence`` — full prevalence time series (numpy array) + - ``cases`` — full cases time series (numpy array, floored at 0) + + Returns an empty ``DataFrame`` if ``scenarios`` is empty. + + Example: + >>> from estimint.scenarios import run_scenarios + >>> from estimint.types import EirTarget, Scenario + >>> + >>> scenarios = [ + ... Scenario( + ... name="baseline", + ... res_use=0.0, + ... Q0=0.9, + ... phi=0.85, + ... seasonal=0.5, + ... irs=0.0, + ... eir_target=EirTarget(input_value=50.0, input_mode="eir"), + ... ), + ... Scenario( + ... name="itn_campaign", + ... res_use=0.0, + ... Q0=0.9, + ... phi=0.85, + ... seasonal=0.5, + ... irs=0.0, + ... eir_target=EirTarget(input_value=50.0, input_mode="eir"), + ... py_only=0.6, + ... net_type_future="pyrethroid-only", + ... itn_future=0.6, + ... ), + ... ] + >>> + >>> results = run_scenarios(scenarios) + >>> results[["name", "prevalence", "cases"]] + """ if not scenarios: return pd.DataFrame() eir_models, emulator_models = preload_models(hf_repo=hf_repo) - # parts = [_estimate_eir(scenario, eir_models) for scenario in scenarios] - parts = _estimate_eir_many(scenarios, eir_models) - covs = [p["cov"] for p in parts] - prev = emulator_models["prevalence"].predict(covs) - cases = np.maximum(emulator_models["cases"].predict(covs), 0.0) + scenario_estimates = _estimate_eir_many(scenarios, eir_models) + emulator_covariates = [estimate.emulator_covariates for estimate in scenario_estimates] + prevalence_timeseries = emulator_models["prevalence"].predict(emulator_covariates) + case_timeseries = np.maximum(emulator_models["cases"].predict(emulator_covariates), 0.0) return pd.DataFrame( [ { - **part["row"], - "prev_y9": float(p[_IDX_Y9]), - "prev_endline": float(p[-1]), - "cases_endline": float(c[-1]), - "prevalence": p, - "cases": c, + **scenario_estimate.summary_values, + "prev_y9": float(prevalence_series[_IDX_Y9]), + "prev_endline": float(prevalence_series[-1]), + "cases_endline": float(case_series[-1]), + "prevalence": prevalence_series, + "cases": case_series, } - for part, p, c in zip(parts, prev, cases) + for scenario_estimate, prevalence_series, case_series in zip( + scenario_estimates, + prevalence_timeseries, + case_timeseries, + ) ] ) diff --git a/tests/test_flows.py b/tests/test_flows.py index e963b0d..4d7a164 100644 --- a/tests/test_flows.py +++ b/tests/test_flows.py @@ -4,6 +4,7 @@ """ import pandas as pd +import pytest from estimint import ( load_xgb_model, @@ -33,33 +34,40 @@ def test_higher_prevalence_gives_higher_eir(self): class TestMosquitoDelta: - def _run(self, delta): - return estimate_eir_with_mosquito_delta( - prevalence=0.30, mosquito_delta=delta, **INTERVENTIONS - ) + @pytest.fixture(scope="class") + def models(self): + return {name: load_xgb_model(name) for name in ("prevalence", "hbr", "eir_to_hbr")} - def test_returns_expected_keys(self): - res = self._run(0.25) - assert set(res) == { + def _run(self, models, delta): + inputs = pd.DataFrame([{"prevalence": 0.30, "mosquito_delta": delta, **INTERVENTIONS}]) + return estimate_eir_with_mosquito_delta(inputs, models=models).iloc[0] + + def test_returns_expected_columns(self, models): + res = self._run(models, 0.25) + assert set(res.index) == { "eir_baseline", "eir_new", "eir_multiplier", "hbr_baseline", "hbr_new", } - def test_zero_delta_is_identity(self): - res = self._run(0.0) + def test_zero_delta_is_identity(self, models): + res = self._run(models, 0.0) assert res["eir_new"] == res["eir_baseline"] assert res["eir_multiplier"] == 1.0 - def test_more_mosquitoes_raises_eir(self): - res = self._run(0.25) + def test_more_mosquitoes_raises_eir(self, models): + res = self._run(models, 0.25) assert res["eir_new"] > res["eir_baseline"] assert res["eir_multiplier"] > 1.0 assert res["hbr_new"] > res["hbr_baseline"] - def test_fewer_mosquitoes_lowers_eir(self): - res = self._run(-0.50) + def test_fewer_mosquitoes_lowers_eir(self, models): + res = self._run(models, -0.50) assert res["eir_new"] < res["eir_baseline"] assert res["hbr_new"] < res["hbr_baseline"] - def test_monotonic_in_delta(self): - eirs = [self._run(d)["eir_new"] for d in (-0.5, -0.25, 0.0, 0.25, 0.5, 1.0)] - assert eirs == sorted(eirs) + def test_batch_is_monotonic_in_delta(self, models): + # a single batched call handles every row and preserves input order + deltas = [-0.5, -0.25, 0.0, 0.25, 0.5, 1.0] + inputs = pd.DataFrame([{"prevalence": 0.30, "mosquito_delta": d, **INTERVENTIONS} for d in deltas]) + res = estimate_eir_with_mosquito_delta(inputs, models=models) + assert list(res.index) == list(range(len(deltas))) + assert list(res["eir_new"]) == sorted(res["eir_new"]) diff --git a/tests/test_scenarios.py b/tests/test_scenarios.py index bb18c34..af1ceb6 100644 --- a/tests/test_scenarios.py +++ b/tests/test_scenarios.py @@ -1,8 +1,8 @@ """Tests for the run_scenarios pipeline. -The estiMINT half (_estimate_eir) is tested offline against the bundled models. -The full run_scenarios call also runs the stateMINT emulator, so it is skipped -unless stateMINT is installed. +The estiMINT half (_estimate_eir_many) is tested offline against the bundled +models. The full run_scenarios call also runs the stateMINT emulator, so it is +skipped unless stateMINT is installed. """ from typing import Any @@ -10,7 +10,15 @@ import numpy as np import pytest -from estimint.scenarios import _estimate_eir, _est_models, run_scenarios +from estimint.scenarios import ( + _PreparedScenario, + _apply_mosquito_delta_batch, + _classify_prepared_scenario, + _estimate_eir_many, + _load_eir_hbr_models, + _prepare_scenario_inputs, + run_scenarios, +) from estimint.types import EirTarget, Scenario INTV: dict[str, Any] = dict(Q0=0.87, phi=0.82, seasonal=0.0, irs=0.0) @@ -24,119 +32,217 @@ def mk(**kwargs: Any) -> Scenario: return Scenario(eir_target=EirTarget(input_value, input_mode), **defaults) +def _estimate_eir(scenario: Scenario, eir_models: dict[str, Any]) -> _PreparedScenario: + """Estimate EIR for a single scenario via the batch estimator.""" + return _estimate_eir_many([scenario], eir_models)[0] + + @pytest.fixture(scope="module") def est(): - return _est_models() + return _load_eir_hbr_models() class TestEstimateEir: def test_prevalence_input(self, est): out = _estimate_eir(mk(input="prevalence", value=0.30), est) - assert out["row"]["eir_baseline"] > 0 - assert out["cov"]["eir"] == out["row"]["eir_final"] + assert out.summary_values["eir_baseline"] > 0 + assert out.emulator_covariates["eir"] == out.summary_values["eir_final"] # no nets, no delta - assert out["row"]["dn0_use"] == 0.0 - assert np.isnan(out["row"]["hbr_baseline"]) + assert out.summary_values["dn0_use"] == 0.0 + assert np.isnan(out.summary_values["hbr_baseline"]) def test_eir_input_passes_through(self, est): out = _estimate_eir(mk(input="eir", value=20.0), est) - assert out["row"]["eir_baseline"] == 20.0 - assert out["row"]["eir_final"] == 20.0 + assert out.summary_values["eir_baseline"] == 20.0 + assert out.summary_values["eir_final"] == 20.0 def test_hbr_input(self, est): out = _estimate_eir(mk(input="hbr", value=250000.0), est) - assert out["row"]["eir_baseline"] > 0 + assert out.summary_values["eir_baseline"] > 0 def test_bednet_mix_matches_minte(self, est): # net-type usage mix feeds calculate_dn0 directly, same as minte; itn_use # is the sum of pyrethroid shares (NOT rescaled by coverage again). - out = _estimate_eir( - mk(input="prevalence", value=0.30, py_only=0.70, py_pbo=0.30, res_use=0.30), est) - assert out["row"]["dn0_use"] > 0 - assert out["row"]["itn_use"] == pytest.approx(1.00) - assert out["row"]["net"] == "py_only+py_pbo" + out = _estimate_eir(mk(input="prevalence", value=0.30, py_only=0.70, py_pbo=0.30, res_use=0.30), est) + assert out.summary_values["dn0_use"] > 0 + assert out.summary_values["itn_use"] == pytest.approx(1.00) + assert out.summary_values["net"] == "py_only+py_pbo" def test_future_net_switch_is_separate_leg(self, est): # net_type_future/itn_future drive the future leg, not current nets; # the future leg shares the same res_use as current (no res_future field). - out = _estimate_eir( - mk(input="prevalence", value=0.30, py_only=0.50, res_use=0.30, - net_type_future="pyrethroid_pbo", itn_future=0.70), est)["row"] - assert out["itn_use"] == pytest.approx(0.50) # current: py_only=0.50 - assert out["itn_future"] == pytest.approx(0.70) # future: pbo=0.70 - assert out["dn0_use"] != out["dn0_future"] - assert out["net"] == "py_only" and out["net_future"] == "pyrethroid_pbo" + scenario_estimate = _estimate_eir( + mk( + input="prevalence", + value=0.30, + py_only=0.50, + res_use=0.30, + net_type_future="pyrethroid_pbo", + itn_future=0.70, + ), + est, + ) + summary_values = scenario_estimate.summary_values + assert summary_values["itn_use"] == pytest.approx(0.50) # current: py_only=0.50 + assert summary_values["itn_future"] == pytest.approx(0.70) # future: pbo=0.70 + assert summary_values["dn0_use"] != summary_values["dn0_future"] + assert summary_values["net"] == "py_only" and summary_values["net_future"] == "pyrethroid_pbo" def test_future_without_net_type_is_zeroed(self, est): # no net_type_future named -> future leg is zeroed, even if itn_future is set # (no carry-forward of the current mix; this is intentional, not a default) - out = _estimate_eir( - mk(input="prevalence", value=0.30, py_pbo=0.80, res_use=0.30, - itn_future=0.70), est)["row"] - assert out["dn0_future"] == 0.0 and out["itn_future"] == 0.0 - assert out["net_future"] == "none" + summary_values = _estimate_eir( + mk(input="prevalence", value=0.30, py_pbo=0.80, res_use=0.30, itn_future=0.70), est + ).summary_values + assert summary_values["dn0_future"] == 0.0 and summary_values["itn_future"] == 0.0 + assert summary_values["net_future"] == "none" def test_future_nets_removed(self, est): # itn_future == 0 removes nets in the future leg - out = _estimate_eir( - mk(input="prevalence", value=0.30, py_only=0.60, res_use=0.30, - itn_future=0.0), est)["row"] - assert out["itn_use"] == pytest.approx(0.60) - assert out["dn0_future"] == 0.0 and out["itn_future"] == 0.0 + summary_values = _estimate_eir( + mk(input="prevalence", value=0.30, py_only=0.60, res_use=0.30, itn_future=0.0), est + ).summary_values + assert summary_values["itn_use"] == pytest.approx(0.60) + assert summary_values["dn0_future"] == 0.0 and summary_values["itn_future"] == 0.0 def test_ppf_boosts_lsm(self, est): # PPF nets add larviciding to LSM (minte: py_ppf * 0.248) - out = _estimate_eir( - mk(input="eir", value=15.0, py_ppf=0.50, res_use=0.30, lsm=0.10), est) - assert out["cov"]["lsm"] == pytest.approx(0.50 * 0.248 + 0.10) + out = _estimate_eir(mk(input="eir", value=15.0, py_ppf=0.50, res_use=0.30, lsm=0.10), est) + assert out.emulator_covariates["lsm"] == pytest.approx(0.50 * 0.248 + 0.10) def test_irs_future_and_routine_are_inputs(self, est): # irs_future and routine are separate scenario inputs, like minte - out = _estimate_eir( - mk(input="eir", value=15.0, irs=0.40, irs_future=0.10, routine=0.25), est)["cov"] - assert out["irs_use"] == 0.40 and out["irs_future"] == 0.10 - assert out["routine"] == 0.25 + emulator_covariates = _estimate_eir( + mk(input="eir", value=15.0, irs=0.40, irs_future=0.10, routine=0.25), est + ).emulator_covariates + assert emulator_covariates["irs_use"] == 0.40 and emulator_covariates["irs_future"] == 0.10 + assert emulator_covariates["routine"] == 0.25 def test_irs_future_and_routine_defaults(self, est): # irs_future is a static dataclass default (0.0); it does NOT carry irs # forward automatically, even when irs is nonzero. - out = _estimate_eir(mk(input="eir", value=15.0, irs=0.40), est)["cov"] - assert out["irs_future"] == 0.0 - assert out["irs_use"] == 0.40 - assert out["routine"] == 0.0 + emulator_covariates = _estimate_eir( + mk(input="eir", value=15.0, irs=0.40), est + ).emulator_covariates + assert emulator_covariates["irs_future"] == 0.0 + assert emulator_covariates["irs_use"] == 0.40 + assert emulator_covariates["routine"] == 0.0 def test_mosquito_delta_direction(self, est): up = _estimate_eir(mk(input="prevalence", value=0.30, mosquito_delta=0.25), est) down = _estimate_eir(mk(input="prevalence", value=0.30, mosquito_delta=-0.50), est) - assert up["row"]["eir_final"] > up["row"]["eir_baseline"] - assert down["row"]["eir_final"] < down["row"]["eir_baseline"] - assert up["row"]["hbr_new"] > up["row"]["hbr_baseline"] + assert up.summary_values["eir_final"] > up.summary_values["eir_baseline"] + assert down.summary_values["eir_final"] < down.summary_values["eir_baseline"] + assert up.summary_values["hbr_new"] > up.summary_values["hbr_baseline"] def test_covariate_dict_keys(self, est): out = _estimate_eir(mk(input="eir", value=15.0, lsm=0.3), est) - assert set(out["cov"]) == { - "eir", "dn0_use", "dn0_future", "Q0", "phi_bednets", "seasonal", - "routine", "itn_use", "irs_use", "itn_future", "irs_future", "lsm", + assert set(out.emulator_covariates) == { + "eir", + "dn0_use", + "dn0_future", + "Q0", + "phi_bednets", + "seasonal", + "routine", + "itn_use", + "irs_use", + "itn_future", + "irs_future", + "lsm", } - assert out["cov"]["lsm"] == 0.3 - assert out["cov"]["dn0_future"] == out["cov"]["dn0_use"] + assert out.emulator_covariates["lsm"] == 0.3 + assert out.emulator_covariates["dn0_future"] == out.emulator_covariates["dn0_use"] def test_bad_input_raises(self, est): - with pytest.raises(ValueError, match="input"): + with pytest.raises(ValueError, match="input_mode"): _estimate_eir(mk(input="nope", value=1.0), est) + def test_batch_mixed_input_modes(self, est): + # one batched call must estimate every input mode and preserve order + scenarios = [ + mk(name="prev", input="prevalence", value=0.30), + mk(name="eir", input="eir", value=20.0), + mk(name="hbr", input="hbr", value=250000.0), + mk(name="prev_delta", input="prevalence", value=0.30, mosquito_delta=0.25), + ] + outs = _estimate_eir_many(scenarios, est) + assert [out.summary_values["name"] for out in outs] == ["prev", "eir", "hbr", "prev_delta"] + assert all(out.summary_values["eir_baseline"] > 0 for out in outs) + assert outs[1].summary_values["eir_final"] == 20.0 # explicit eir passes through + # mosquito-delta scenario shifts eir away from baseline and fills hbr + assert outs[3].summary_values["eir_final"] != outs[3].summary_values["eir_baseline"] + assert outs[3].summary_values["hbr_new"] > 0 + + +class TestClassifyPreparedScenario: + def _make_prepared(self, input_mode: str, mosquito_density_change: float = 0.0) -> _PreparedScenario: + from typing import cast + from estimint.types import Input_Mode + + return _PreparedScenario( + eir_target=EirTarget(input_value=10.0, input_mode=cast(Input_Mode, input_mode)), + mosquito_density_change=mosquito_density_change, + eir_model_features={}, + summary_values={}, + emulator_covariates={}, + ) + + def test_eir_mode(self): + assert _classify_prepared_scenario(self._make_prepared("eir")) == "eir" + + def test_eir_mode_ignores_mosquito_delta(self): + # eir input always passes through, even if mosquito_delta is set + assert _classify_prepared_scenario(self._make_prepared("eir", mosquito_density_change=0.5)) == "eir" + + def test_prevalence_without_delta(self): + assert _classify_prepared_scenario(self._make_prepared("prevalence")) == "prevalence" + + def test_prevalence_with_delta(self): + assert _classify_prepared_scenario(self._make_prepared("prevalence", mosquito_density_change=0.25)) == "mosquito_delta" + + def test_hbr_mode(self): + assert _classify_prepared_scenario(self._make_prepared("hbr")) == "hbr" + + +class TestApplyMosquitoDeltaBatch: + def test_fills_eir_and_hbr_estimates(self, est): + prepared = _prepare_scenario_inputs(mk(input="prevalence", value=0.30, mosquito_delta=0.25)) + _apply_mosquito_delta_batch([prepared], est) + assert prepared.summary_values["eir_baseline"] > 0 + assert prepared.summary_values["eir_final"] > prepared.summary_values["eir_baseline"] + assert prepared.summary_values["hbr_new"] > prepared.summary_values["hbr_baseline"] + assert prepared.emulator_covariates["eir"] == prepared.summary_values["eir_final"] + + def test_negative_delta_lowers_eir(self, est): + prepared = _prepare_scenario_inputs(mk(input="prevalence", value=0.30, mosquito_delta=-0.50)) + _apply_mosquito_delta_batch([prepared], est) + assert prepared.summary_values["eir_final"] < prepared.summary_values["eir_baseline"] + class TestRunScenariosFullPipeline: def test_end_to_end(self): pytest.importorskip("stateMINT", reason="stateMINT not installed") - df = run_scenarios([ - Scenario(name="prev+delta", eir_target=EirTarget(0.30, "prevalence"), - py_only=0.60, res_use=0.55, net_type_future="pyrethroid_pbo", - itn_future=0.85, - Q0=0.90, phi=0.85, seasonal=1, irs=0.40, mosquito_delta=0.60), - Scenario(name="eir", eir_target=EirTarget(20.0, "eir"), res_use=0.0, - Q0=0.88, phi=0.78, seasonal=1, irs=0.60), - ]) + df = run_scenarios( + [ + Scenario( + name="prev+delta", + eir_target=EirTarget(0.30, "prevalence"), + py_only=0.60, + res_use=0.55, + net_type_future="pyrethroid_pbo", + itn_future=0.85, + Q0=0.90, + phi=0.85, + seasonal=1, + irs=0.40, + mosquito_delta=0.60, + ), + Scenario( + name="eir", eir_target=EirTarget(20.0, "eir"), res_use=0.0, Q0=0.88, phi=0.78, seasonal=1, irs=0.60 + ), + ] + ) assert len(df) == 2 assert {"eir_final", "prev_y9", "prevalence", "cases"} <= set(df.columns) assert len(df.iloc[0]["prevalence"]) == 157 From 78689197b3476218a621ae89b73b40d645cbea61 Mon Sep 17 00:00:00 2001 From: anmol thapar Date: Wed, 1 Jul 2026 09:35:09 +0100 Subject: [PATCH 3/4] bump version to 1.5.4 and add preload_models to public API --- pyproject.toml | 2 +- src/estimint/__init__.py | 4 +++- uv.lock | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fddb894..46c921a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "estimint" -version = "1.5.3" +version = "1.5.4" description = "EIR Estimation using Machine learning interventions " readme = "README.md" license = "MIT" diff --git a/src/estimint/__init__.py b/src/estimint/__init__.py index 0f23b4d..95c1834 100644 --- a/src/estimint/__init__.py +++ b/src/estimint/__init__.py @@ -56,7 +56,7 @@ from .bednet import calculate_dn0, net_types, DN0Result -from .scenarios import run_scenarios +from .scenarios import run_scenarios, preload_models from .types import Scenario, EirTarget __all__ = [ @@ -100,6 +100,8 @@ "DN0Result", # scenarios "run_scenarios", + "preload_models", + # types "Scenario", "EirTarget", ] diff --git a/uv.lock b/uv.lock index 9505c2f..657798c 100644 --- a/uv.lock +++ b/uv.lock @@ -361,7 +361,7 @@ wheels = [ [[package]] name = "estimint" -version = "1.5.3" +version = "1.5.4" source = { editable = "." } dependencies = [ { name = "numpy" }, From 9f8ec066dd910a5023a824ccbb71d2438294d770 Mon Sep 17 00:00:00 2001 From: anmol thapar Date: Wed, 1 Jul 2026 09:37:26 +0100 Subject: [PATCH 4/4] refactor: rename _estimate_eir_many to _estimate_eir for clarity and consistency --- src/estimint/scenarios.py | 4 ++-- tests/test_scenarios.py | 43 ++++++++++++++++++++------------------- 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/src/estimint/scenarios.py b/src/estimint/scenarios.py index 7f7ad73..d853d4b 100644 --- a/src/estimint/scenarios.py +++ b/src/estimint/scenarios.py @@ -239,7 +239,7 @@ def _apply_mosquito_delta_batch(prepared_scenarios: list[_PreparedScenario], eir ) -def _estimate_eir_many(scenarios: list[Scenario], eir_models: Dict[str, Any]) -> list[_PreparedScenario]: +def _estimate_eir(scenarios: list[Scenario], eir_models: Dict[str, Any]) -> list[_PreparedScenario]: """Estimate EIR for many scenarios, dispatching each to one of three paths: - "eir": supplied directly, passed through unchanged - "prevalence" / "hbr": predicted from baseline measurements via XGBoost @@ -351,7 +351,7 @@ def run_scenarios( eir_models, emulator_models = preload_models(hf_repo=hf_repo) - scenario_estimates = _estimate_eir_many(scenarios, eir_models) + scenario_estimates = _estimate_eir(scenarios, eir_models) emulator_covariates = [estimate.emulator_covariates for estimate in scenario_estimates] prevalence_timeseries = emulator_models["prevalence"].predict(emulator_covariates) case_timeseries = np.maximum(emulator_models["cases"].predict(emulator_covariates), 0.0) diff --git a/tests/test_scenarios.py b/tests/test_scenarios.py index af1ceb6..e986d37 100644 --- a/tests/test_scenarios.py +++ b/tests/test_scenarios.py @@ -14,7 +14,7 @@ _PreparedScenario, _apply_mosquito_delta_batch, _classify_prepared_scenario, - _estimate_eir_many, + _estimate_eir, _load_eir_hbr_models, _prepare_scenario_inputs, run_scenarios, @@ -32,9 +32,9 @@ def mk(**kwargs: Any) -> Scenario: return Scenario(eir_target=EirTarget(input_value, input_mode), **defaults) -def _estimate_eir(scenario: Scenario, eir_models: dict[str, Any]) -> _PreparedScenario: +def _estimate_eir_single(scenario: Scenario, eir_models: dict[str, Any]) -> _PreparedScenario: """Estimate EIR for a single scenario via the batch estimator.""" - return _estimate_eir_many([scenario], eir_models)[0] + return _estimate_eir([scenario], eir_models)[0] @pytest.fixture(scope="module") @@ -44,7 +44,7 @@ def est(): class TestEstimateEir: def test_prevalence_input(self, est): - out = _estimate_eir(mk(input="prevalence", value=0.30), est) + out = _estimate_eir_single(mk(input="prevalence", value=0.30), est) assert out.summary_values["eir_baseline"] > 0 assert out.emulator_covariates["eir"] == out.summary_values["eir_final"] # no nets, no delta @@ -52,18 +52,18 @@ def test_prevalence_input(self, est): assert np.isnan(out.summary_values["hbr_baseline"]) def test_eir_input_passes_through(self, est): - out = _estimate_eir(mk(input="eir", value=20.0), est) + out = _estimate_eir_single(mk(input="eir", value=20.0), est) assert out.summary_values["eir_baseline"] == 20.0 assert out.summary_values["eir_final"] == 20.0 def test_hbr_input(self, est): - out = _estimate_eir(mk(input="hbr", value=250000.0), est) + out = _estimate_eir_single(mk(input="hbr", value=250000.0), est) assert out.summary_values["eir_baseline"] > 0 def test_bednet_mix_matches_minte(self, est): # net-type usage mix feeds calculate_dn0 directly, same as minte; itn_use # is the sum of pyrethroid shares (NOT rescaled by coverage again). - out = _estimate_eir(mk(input="prevalence", value=0.30, py_only=0.70, py_pbo=0.30, res_use=0.30), est) + out = _estimate_eir_single(mk(input="prevalence", value=0.30, py_only=0.70, py_pbo=0.30, res_use=0.30), est) assert out.summary_values["dn0_use"] > 0 assert out.summary_values["itn_use"] == pytest.approx(1.00) assert out.summary_values["net"] == "py_only+py_pbo" @@ -71,7 +71,7 @@ def test_bednet_mix_matches_minte(self, est): def test_future_net_switch_is_separate_leg(self, est): # net_type_future/itn_future drive the future leg, not current nets; # the future leg shares the same res_use as current (no res_future field). - scenario_estimate = _estimate_eir( + scenario_estimate = _estimate_eir_single( mk( input="prevalence", value=0.30, @@ -91,7 +91,7 @@ def test_future_net_switch_is_separate_leg(self, est): def test_future_without_net_type_is_zeroed(self, est): # no net_type_future named -> future leg is zeroed, even if itn_future is set # (no carry-forward of the current mix; this is intentional, not a default) - summary_values = _estimate_eir( + summary_values = _estimate_eir_single( mk(input="prevalence", value=0.30, py_pbo=0.80, res_use=0.30, itn_future=0.70), est ).summary_values assert summary_values["dn0_future"] == 0.0 and summary_values["itn_future"] == 0.0 @@ -99,7 +99,7 @@ def test_future_without_net_type_is_zeroed(self, est): def test_future_nets_removed(self, est): # itn_future == 0 removes nets in the future leg - summary_values = _estimate_eir( + summary_values = _estimate_eir_single( mk(input="prevalence", value=0.30, py_only=0.60, res_use=0.30, itn_future=0.0), est ).summary_values assert summary_values["itn_use"] == pytest.approx(0.60) @@ -107,12 +107,12 @@ def test_future_nets_removed(self, est): def test_ppf_boosts_lsm(self, est): # PPF nets add larviciding to LSM (minte: py_ppf * 0.248) - out = _estimate_eir(mk(input="eir", value=15.0, py_ppf=0.50, res_use=0.30, lsm=0.10), est) + out = _estimate_eir_single(mk(input="eir", value=15.0, py_ppf=0.50, res_use=0.30, lsm=0.10), est) assert out.emulator_covariates["lsm"] == pytest.approx(0.50 * 0.248 + 0.10) def test_irs_future_and_routine_are_inputs(self, est): # irs_future and routine are separate scenario inputs, like minte - emulator_covariates = _estimate_eir( + emulator_covariates = _estimate_eir_single( mk(input="eir", value=15.0, irs=0.40, irs_future=0.10, routine=0.25), est ).emulator_covariates assert emulator_covariates["irs_use"] == 0.40 and emulator_covariates["irs_future"] == 0.10 @@ -121,22 +121,20 @@ def test_irs_future_and_routine_are_inputs(self, est): def test_irs_future_and_routine_defaults(self, est): # irs_future is a static dataclass default (0.0); it does NOT carry irs # forward automatically, even when irs is nonzero. - emulator_covariates = _estimate_eir( - mk(input="eir", value=15.0, irs=0.40), est - ).emulator_covariates + emulator_covariates = _estimate_eir_single(mk(input="eir", value=15.0, irs=0.40), est).emulator_covariates assert emulator_covariates["irs_future"] == 0.0 assert emulator_covariates["irs_use"] == 0.40 assert emulator_covariates["routine"] == 0.0 def test_mosquito_delta_direction(self, est): - up = _estimate_eir(mk(input="prevalence", value=0.30, mosquito_delta=0.25), est) - down = _estimate_eir(mk(input="prevalence", value=0.30, mosquito_delta=-0.50), est) + up = _estimate_eir_single(mk(input="prevalence", value=0.30, mosquito_delta=0.25), est) + down = _estimate_eir_single(mk(input="prevalence", value=0.30, mosquito_delta=-0.50), est) assert up.summary_values["eir_final"] > up.summary_values["eir_baseline"] assert down.summary_values["eir_final"] < down.summary_values["eir_baseline"] assert up.summary_values["hbr_new"] > up.summary_values["hbr_baseline"] def test_covariate_dict_keys(self, est): - out = _estimate_eir(mk(input="eir", value=15.0, lsm=0.3), est) + out = _estimate_eir_single(mk(input="eir", value=15.0, lsm=0.3), est) assert set(out.emulator_covariates) == { "eir", "dn0_use", @@ -156,7 +154,7 @@ def test_covariate_dict_keys(self, est): def test_bad_input_raises(self, est): with pytest.raises(ValueError, match="input_mode"): - _estimate_eir(mk(input="nope", value=1.0), est) + _estimate_eir_single(mk(input="nope", value=1.0), est) def test_batch_mixed_input_modes(self, est): # one batched call must estimate every input mode and preserve order @@ -166,7 +164,7 @@ def test_batch_mixed_input_modes(self, est): mk(name="hbr", input="hbr", value=250000.0), mk(name="prev_delta", input="prevalence", value=0.30, mosquito_delta=0.25), ] - outs = _estimate_eir_many(scenarios, est) + outs = _estimate_eir(scenarios, est) assert [out.summary_values["name"] for out in outs] == ["prev", "eir", "hbr", "prev_delta"] assert all(out.summary_values["eir_baseline"] > 0 for out in outs) assert outs[1].summary_values["eir_final"] == 20.0 # explicit eir passes through @@ -199,7 +197,10 @@ def test_prevalence_without_delta(self): assert _classify_prepared_scenario(self._make_prepared("prevalence")) == "prevalence" def test_prevalence_with_delta(self): - assert _classify_prepared_scenario(self._make_prepared("prevalence", mosquito_density_change=0.25)) == "mosquito_delta" + assert ( + _classify_prepared_scenario(self._make_prepared("prevalence", mosquito_density_change=0.25)) + == "mosquito_delta" + ) def test_hbr_mode(self): assert _classify_prepared_scenario(self._make_prepared("hbr")) == "hbr"