diff --git a/packages/wasm-mps/src/lib.rs b/packages/wasm-mps/src/lib.rs index 01ee70c396f..4d9af3ecbfc 100644 --- a/packages/wasm-mps/src/lib.rs +++ b/packages/wasm-mps/src/lib.rs @@ -116,6 +116,18 @@ mod mps { pub chaincode: [u8; 32], } + fn rem_prefix(prefix: &str, data: &Vec) -> Result, MpsError> { + Ok(data + .as_slice() + .strip_prefix(prefix.as_bytes()) + .ok_or(MpsError::InvalidInput)? + .to_vec()) + } + + fn add_prefix(prefix: &str, data: &Vec) -> Vec { + [prefix.as_bytes(), data.as_slice()].concat() + } + fn internal_dkg_round0_process( party_id: u8, decryption_key: &[u8; 32], @@ -193,7 +205,16 @@ mod mps { encryption_keys: &[Vec; 2], seed: &[u8; 32], ) -> Result { - internal_dkg_round0_process::(party_id, decryption_key, encryption_keys, seed) + let result = internal_dkg_round0_process::( + party_id, + decryption_key, + encryption_keys, + seed, + )?; + Ok(MsgState { + msg: add_prefix("mps-ed25519-dkg-round1-message$", &result.msg), + state: add_prefix("mps-ed25519-dkg-round1-state$", &result.state), + }) } fn internal_dkg_round1_process( @@ -240,7 +261,14 @@ mod mps { round1_messages: &[Vec; 2], state: &[u8], ) -> Result { - internal_dkg_round1_process::(round1_messages, state) + let i0_msg1 = rem_prefix("mps-ed25519-dkg-round1-message$", &round1_messages[0])?; + let i1_msg1 = rem_prefix("mps-ed25519-dkg-round1-message$", &round1_messages[1])?; + let state = rem_prefix("mps-ed25519-dkg-round1-state$", &state.to_vec())?; + let result = internal_dkg_round1_process::(&[i0_msg1, i1_msg1], &state)?; + Ok(MsgState { + msg: add_prefix("mps-ed25519-dkg-round2-message$", &result.msg), + state: add_prefix("mps-ed25519-dkg-round2-state$", &result.state), + }) } fn internal_dkg_round2_process( @@ -277,7 +305,10 @@ mod mps { round2_messages: &[Vec; 2], state: &[u8], ) -> Result { - let share = internal_dkg_round2_process::(round2_messages, state)?; + let i0_msg2 = rem_prefix("mps-ed25519-dkg-round2-message$", &round2_messages[0])?; + let i1_msg2 = rem_prefix("mps-ed25519-dkg-round2-message$", &round2_messages[1])?; + let state = rem_prefix("mps-ed25519-dkg-round2-state$", &state.to_vec())?; + let share = internal_dkg_round2_process::(&[i0_msg2, i1_msg2], &state)?; Ok(Share { share: bincode::serialize(&share).map_err(|_| MpsError::SerializationError)?, pk: share.public_key.compress().to_bytes(), @@ -328,7 +359,12 @@ mod mps { &mut rand::thread_rng(), ); - internal_dsg_round0_process(p0) + let result = internal_dsg_round0_process(p0)?; + + Ok(MsgState { + msg: add_prefix("mps-ed25519-dsg-round1-message$", &result.msg), + state: add_prefix("mps-ed25519-dsg-round1-state$", &result.state), + }) } fn internal_dsg_round1_process( @@ -373,7 +409,17 @@ mod mps { round1_message: &[u8], state: &[u8], ) -> Result { - internal_dsg_round1_process::(round1_message, state) + let round1_message = + rem_prefix("mps-ed25519-dsg-round1-message$", &round1_message.to_vec())?; + let state = rem_prefix("mps-ed25519-dsg-round1-state$", &state.to_vec())?; + let result = internal_dsg_round1_process::( + round1_message.as_slice(), + state.as_slice(), + )?; + Ok(MsgState { + msg: add_prefix("mps-ed25519-dsg-round2-message$", &result.msg), + state: add_prefix("mps-ed25519-dsg-round2-state$", &result.state), + }) } /// Process round 2 of DSG protocol. @@ -383,13 +429,18 @@ mod mps { round2_message: &[u8], state: &[u8], ) -> Result { + // Strip prefix + let round2_message = + rem_prefix("mps-ed25519-dsg-round2-message$", &round2_message.to_vec())?; + let state = rem_prefix("mps-ed25519-dsg-round2-state$", &state.to_vec())?; + // Parse state let state: DsgStateR2 = - bincode::deserialize(state).map_err(|_| MpsError::DeserializationError)?; + bincode::deserialize(&state).map_err(|_| MpsError::DeserializationError)?; // Parse messages let i0_msg2: SignMsg2 = - bincode::deserialize(round2_message).map_err(|_| MpsError::DeserializationError)?; + bincode::deserialize(&round2_message).map_err(|_| MpsError::DeserializationError)?; let msgs = vec![i0_msg2, state.msg]; // Process all round2 messages together @@ -408,8 +459,14 @@ mod mps { }; Ok(MsgState { - msg: bincode::serialize(&msg3).map_err(|_| MpsError::SerializationError)?, - state: bincode::serialize(&state).map_err(|_| MpsError::SerializationError)?, + msg: add_prefix( + "mps-ed25519-dsg-round3-message$", + &bincode::serialize(&msg3).map_err(|_| MpsError::SerializationError)?, + ), + state: add_prefix( + "mps-ed25519-dsg-round3-state$", + &bincode::serialize(&state).map_err(|_| MpsError::SerializationError)?, + ), }) } @@ -420,13 +477,18 @@ mod mps { round3_message: &[u8], state: &[u8], ) -> Result, MpsError> { + // Strip prefix + let round3_message = + rem_prefix("mps-ed25519-dsg-round3-message$", &round3_message.to_vec())?; + let state = rem_prefix("mps-ed25519-dsg-round3-state$", &state.to_vec())?; + // Parse state let state: DsgStateR3 = - bincode::deserialize(state).map_err(|_| MpsError::DeserializationError)?; + bincode::deserialize(&state).map_err(|_| MpsError::DeserializationError)?; // Parse messages let i0_msg3: SignMsg3 = - bincode::deserialize(round3_message).map_err(|_| MpsError::DeserializationError)?; + bincode::deserialize(&round3_message).map_err(|_| MpsError::DeserializationError)?; let msgs = vec![i0_msg3, state.msg]; // Process all round2 messages together diff --git a/packages/wasm-mps/test/mps.ts b/packages/wasm-mps/test/mps.ts index cd3730ee45d..05ff3cdc2bc 100644 --- a/packages/wasm-mps/test/mps.ts +++ b/packages/wasm-mps/test/mps.ts @@ -13,6 +13,15 @@ describe("mps", function () { ]; const keypairs: Array<{ privateKey: Uint8Array; publicKey: Uint8Array }> = []; + function shouldThrow(fn: () => unknown): unknown { + try { + fn(); + } catch (e: unknown) { + return e; + } + throw new Error("Expected function to throw an error"); + } + before("generates keypairs", function () { for (let i = 0; i < 3; i++) { keypairs.push(sodium.crypto_box_keypair()); @@ -21,13 +30,17 @@ describe("mps", function () { describe("dkg", function () { it("performs round 0", function () { + const messagePrefix = Buffer.from("mps-ed25519-dkg-round1-message$"); + const statePrefix = Buffer.from("mps-ed25519-dkg-round1-state$"); for (let i = 0; i < keypairs.length; i++) { - mps.ed25519_dkg_round0_process( + const result = mps.ed25519_dkg_round0_process( i, keypairs[i].privateKey, otherIndices[i].map((i) => keypairs[i].publicKey), crypto.randomBytes(32), ); + assert(Buffer.from(result.msg).slice(0, messagePrefix.length).equals(messagePrefix)); + assert(Buffer.from(result.state).slice(0, statePrefix.length).equals(statePrefix)); } }); @@ -45,11 +58,59 @@ describe("mps", function () { }); it("performs round 1", function () { + const messagePrefix = Buffer.from("mps-ed25519-dkg-round2-message$"); + const statePrefix = Buffer.from("mps-ed25519-dkg-round2-state$"); for (let i = 0; i < results1.length; i++) { - mps.ed25519_dkg_round1_process( + const result = mps.ed25519_dkg_round1_process( otherIndices[i].map((i) => results1[i].msg), results1[i].state, ); + assert(Buffer.from(result.msg).slice(0, messagePrefix.length).equals(messagePrefix)); + assert(Buffer.from(result.state).slice(0, statePrefix.length).equals(statePrefix)); + } + }); + + it("fails to perform round 1 with invalid message prefix", function () { + const messagePrefix = Buffer.from("mps-ed25519-dkg-round1-message$"); + for (let i = 0; i < results1.length; i++) { + shouldThrow(() => + mps.ed25519_dkg_round1_process( + otherIndices[i].map((i) => Buffer.from(results1[i].msg).slice(messagePrefix.length)), + results1[i].state, + ), + ); + shouldThrow(() => + mps.ed25519_dkg_round1_process( + otherIndices[i].map((i) => + Buffer.concat([ + Buffer.from("msg-ed25519-dkg-round2-message$"), + Buffer.from(results1[i].msg).slice(messagePrefix.length), + ]), + ), + results1[i].state, + ), + ); + } + }); + + it("fails to perform round 1 with invalid state prefix", function () { + const statePrefix = Buffer.from("mps-ed25519-dkg-round1-state$"); + for (let i = 0; i < results1.length; i++) { + shouldThrow(() => + mps.ed25519_dkg_round1_process( + otherIndices[i].map((i) => results1[i].msg), + Buffer.from(results1[i].state).slice(statePrefix.length), + ), + ); + shouldThrow(() => + mps.ed25519_dkg_round1_process( + otherIndices[i].map((i) => results1[i].msg), + Buffer.concat([ + Buffer.from("mps-ed25519-dkg-round2-state$"), + Buffer.from(results1[i].state).slice(statePrefix.length), + ]), + ), + ); } }); @@ -80,15 +141,6 @@ describe("mps", function () { }); describe("input handling", function () { - function shouldThrow(fn: () => unknown): unknown { - try { - fn(); - } catch (e: unknown) { - return e; - } - throw new Error("Expected function to throw an error"); - } - describe("round0_process", function () { it("does not panic on bad party size", function () { shouldThrow(() => @@ -248,8 +300,12 @@ describe("mps", function () { ); it("performs round 0", function () { + const messagePrefix = Buffer.from("mps-ed25519-dsg-round1-message$"); + const statePrefix = Buffer.from("mps-ed25519-dsg-round1-state$"); for (const i of [0, 2]) { - mps.ed25519_dsg_round0_process(shares[i].share, "m", message); + const result = mps.ed25519_dsg_round0_process(shares[i].share, "m", message); + assert(Buffer.from(result.msg).slice(0, messagePrefix.length).equals(messagePrefix)); + assert(Buffer.from(result.state).slice(0, statePrefix.length).equals(statePrefix)); } }); @@ -260,8 +316,54 @@ describe("mps", function () { }); it("performs round 1", function () { + const messagePrefix = Buffer.from("mps-ed25519-dsg-round2-message$"); + const statePrefix = Buffer.from("mps-ed25519-dsg-round2-state$"); + for (let i = 0; i < results1.length; i++) { + const result = mps.ed25519_dsg_round1_process( + results1[otherIndex[i]].msg, + results1[i].state, + ); + assert(Buffer.from(result.msg).slice(0, messagePrefix.length).equals(messagePrefix)); + assert(Buffer.from(result.state).slice(0, statePrefix.length).equals(statePrefix)); + } + }); + + it("fails to perform round 1 with invalid message prefix", function () { + const messagePrefix = Buffer.from("mps-ed25519-dsg-round1-message$"); for (let i = 0; i < results1.length; i++) { - mps.ed25519_dsg_round1_process(results1[otherIndex[i]].msg, results1[i].state); + shouldThrow(() => + mps.ed25519_dsg_round1_process( + Buffer.from(results1[otherIndex[i]].msg).slice(messagePrefix.length), + results1[i].state, + ), + ); + shouldThrow(() => + mps.ed25519_dsg_round1_process( + Buffer.concat([ + Buffer.from("mps-ed25519-dsg-round2-message$"), + Buffer.from(results1[otherIndex[i]].msg).slice(messagePrefix.length), + ]), + results1[i].state, + ), + ); + } + }); + + it("fails to perform round 1 with invalid state prefix", function () { + const statePrefix = Buffer.from("mps-ed25519-dsg-round1-state$"); + for (let i = 0; i < results1.length; i++) { + shouldThrow(() => + mps.ed25519_dsg_round1_process( + results1[otherIndex[i]].msg, + Buffer.from(results1[i].state).slice(statePrefix.length), + ), + ); + shouldThrow(() => + mps.ed25519_dsg_round1_process( + results1[otherIndex[i]].msg, + Buffer.concat([Buffer.from("mps-ed25519-dsg-round2-state$"), results1[i].state]), + ), + ); } }); @@ -279,6 +381,45 @@ describe("mps", function () { } }); + it("fails to perform round 2 with invalid message prefix", function () { + const messagePrefix = Buffer.from("mps-ed25519-dsg-round2-message$"); + for (let i = 0; i < results2.length; i++) { + shouldThrow(() => + mps.ed25519_dsg_round2_process( + Buffer.from(results2[otherIndex[i]].msg).slice(messagePrefix.length), + results2[i].state, + ), + ); + shouldThrow(() => + mps.ed25519_dsg_round2_process( + Buffer.concat([ + Buffer.from("mps-ed25519-dsg-round3-message$"), + Buffer.from(results2[otherIndex[i]].msg).slice(messagePrefix.length), + ]), + results2[i].state, + ), + ); + } + }); + + it("fails to perform round 2 with invalid state prefix", function () { + const statePrefix = Buffer.from("mps-ed25519-dsg-round2-state$"); + for (let i = 0; i < results2.length; i++) { + shouldThrow(() => + mps.ed25519_dsg_round2_process( + results2[otherIndex[i]].msg, + Buffer.from(results2[i].state).slice(statePrefix.length), + ), + ); + shouldThrow(() => + mps.ed25519_dsg_round2_process( + results2[otherIndex[i]].msg, + Buffer.concat([Buffer.from("mps-ed25519-dsg-round3-state$"), results2[i].state]), + ), + ); + } + }); + let results3: Array; before("performs round 2", function () { @@ -294,5 +435,44 @@ describe("mps", function () { assert(sodium.crypto_sign_verify_detached(signatures[0], message, shares[0].pk)); assert(sodium.crypto_sign_verify_detached(signatures[1], message, shares[2].pk)); }); + + it("fails to perform round 3 with invalid message prefix", function () { + const messagePrefix = Buffer.from("mps-ed25519-dsg-round3-message$"); + for (let i = 0; i < results3.length; i++) { + shouldThrow(() => + mps.ed25519_dsg_round3_process( + Buffer.from(results3[otherIndex[i]].msg).slice(messagePrefix.length), + results3[i].state, + ), + ); + shouldThrow(() => + mps.ed25519_dsg_round3_process( + Buffer.concat([ + Buffer.from("mps-ed25519-dsg-round4-message$"), + Buffer.from(results3[otherIndex[i]].msg).slice(messagePrefix.length), + ]), + results3[i].state, + ), + ); + } + }); + + it("fails to perform round 3 with invalid state prefix", function () { + const statePrefix = Buffer.from("mps-ed25519-dsg-round3-state$"); + for (let i = 0; i < results3.length; i++) { + shouldThrow(() => + mps.ed25519_dsg_round3_process( + results3[otherIndex[i]].msg, + Buffer.from(results3[i].state).slice(statePrefix.length), + ), + ); + shouldThrow(() => + mps.ed25519_dsg_round3_process( + results3[otherIndex[i]].msg, + Buffer.concat([Buffer.from("mps-ed25519-dsg-round4-state$"), results3[i].state]), + ), + ); + } + }); }); });