Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 145 additions & 106 deletions fastcrypto/src/twisted_elgamal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,20 @@ pub struct VerifiableKeyEncapsulation<const N: usize> {
/// same limb values `u_i` that open those commitments reconstruct to the private key, i.e.
/// `(\sum_i u_i * 2^{32i}) * G == U` where `U` is the sender's public key. Crucially, the proof
/// binds (1) and (2) together, so the verifier is assured that the values inside the commitments
/// are exactly the limbs of the private key for `U`. The proof is made non-interactive via the
/// Fiat-Shamir transform and supports multiple recipients sharing the same commitment per limb.
/// are exactly the limbs of the private key for `U`.
///
/// The proof is made non-interactive via the Fiat-Shamir transform with two challenges:
///
/// * `rho` (m scalars), derived from the statement, batches the per-recipient auditor
/// equations into a single equation per limb. This makes the prover-side auditor
/// commitments `a3` constant in `m` (`N` group elements rather than `N*m`).
/// * `c` (one scalar), derived from the full transcript including `rho` and all commitments,
/// is the standard sigma-protocol challenge.
///
/// Commitments: `a1[i] = a_i*G + b_i*H` (Pedersen, round 1), `a2[i] = b_i*G` (recombination,
/// round 1), `a3[i] = a_i*S_rho` where `S_rho = sum_j rho_j*S_j` (auditor, batched, round 3).
pub struct KeyConsistencyProof<const N: usize> {
a1: Vec<RistrettoPoint>,
a1: [RistrettoPoint; N],
a2: [RistrettoPoint; N],
a3: [RistrettoPoint; N],
z1: [RistrettoScalar; N],
Expand Down Expand Up @@ -283,6 +293,22 @@ impl MultiRecipientCiphertext {
}

impl<const N: usize> KeyConsistencyProof<N> {
/// Construct the consistency proof, showing that the prover knows witness (r_i, u_i)_{i=1..N} satisfying:
///
/// (W1) Pedersen openings: C_i = r_i * G + u_i * H for all i
/// (W2) Key recombination: (sum_i 2^{32i} * u_i) * G = U
/// (W3) Decryption handles: D_ij = r_i * S_j for all i, j
///
/// The proof is a 5-message sigma protocol made non-interactive via Fiat-Shamir, with two
/// challenges:
///
/// Round 1: Sample a_i, b_i and send the round-1 commitments a1[i] = a_i * G + b_i * H (Pedersen) and
/// a2[i] = b_i * G (key recombination).
/// Round 2: Compute m batching scalars rho = challenge_rho(statement).

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

currently we compute hash_to_group(hash(...)) per rho_j
can we just use hash_to_group(hash(...))^j? would be cheaper to compute

/// Round 3: Send the batched decryption handle commitment a3[i] = a_i * S_rho where S_rho = sum_j rho_j * pk_j.
/// This collapses the m per-recipient decryption handle commitments into one per limb, making the proof size constant in m.
/// Round 4: Compute the main sigma protocol challenge c = challenge_c(statement, rho, a1, a2, a3).
/// Round 5: Send the responses z1[i] = a_i + c * r_i and z2[i] = b_i + c * u_i.
pub fn prove(
sender_private_key_limbs: &[u32; N],
sender_public_key: &PublicKey,
Expand All @@ -291,147 +317,129 @@ impl<const N: usize> KeyConsistencyProof<N> {
blindings: &[Blinding; N],
rng: &mut impl AllowedRng,
) -> Self {
// Sample N random a_i and b_i
let a: [_; N] = from_fn(|_| RistrettoScalar::rand(rng));
let b: [_; N] = from_fn(|_| RistrettoScalar::rand(rng));

// A_1ij = a_i * pk_j for all (i, j) — N*m elements, ordered by limb then recipient
let a1 = a
.iter()
.flat_map(|ai| recipient_encryption_keys.iter().map(move |pk| pk.0 * ai))
.collect_vec();
let a1: [_; N] = from_fn(|i| *G * a[i] + *H * b[i]);
let a2: [_; N] = from_fn(|i| *G * b[i]);

let rho = Self::challenge_rho(sender_public_key, recipient_encryption_keys, ciphertexts);

// A_2i = a_i * G + b_i * H for all i
let a2 = from_fn(|i| *G * a[i] + *H * b[i]);
let recipient_pk_points: Vec<RistrettoPoint> =
recipient_encryption_keys.iter().map(|pk| pk.0).collect();
let s_rho = RistrettoPoint::multi_scalar_mul(&rho, &recipient_pk_points)
.expect("Consistent lengths");

// A_3i = b_i * G for all i
let a3 = from_fn(|i| *G * b[i]);
let a3: [RistrettoPoint; N] = from_fn(|i| s_rho * a[i]);

// c = Hash(G, H, sender_public_key, recipient_encryption_keys, ciphertexts, a1, a2, a3)
let c = Self::challenge(
let c = Self::challenge_c(
sender_public_key,
recipient_encryption_keys,
ciphertexts,
&rho,
&a1,
&a2,
&a3,
);

// z_1i = a_i + c * r_i
let z1 = from_fn(|i| a[i] + c * blindings[i].0);

// z_2i = b_i + c * u_i
let z2 = from_fn(|i| b[i] + c * RistrettoScalar::from(sender_private_key_limbs[i] as u64));

Self { a1, a2, a3, z1, z2 }
}

/// Verify checks the provided consistency proof. To do so, it batches all three groups of verification equations
/// into a single MSM using hash-derived scalars. The three groups of equations that must hold for a valid proof are:
/// Verify the consistency proof. The three groups of equations that must hold are:
///
/// Check 1 (decryption handle consistency): Verifies that each decryption handle was formed with the same
/// blinding r_i as the commitment via
/// A1_ij + c * D_ij == z_1i * S_j
/// for all limbs i and recipients j where D_ij = r_i * S_j is the decryption handle and S_j is recipient j's public key.
/// Combined equations using scalars mu_ij = Hash("mu", c, i, j):
/// \sum_j (\sum_i mu_ij * z_1i) * S_j - \sum_{i,j} mu_ij * A1_ij - \sum_{i,j} (c * mu_ij) * D_ij == 0
/// Check 1 (Pedersen commitment): For each limb i, the prover knows the blinding r_i and message u_i
/// opening the commitment C_i = r_i * G + u_i * H: z1[i] * G + z2[i] * H == a1[i] + c * C_i.
///
/// Check 2 (commitment consistency): Verifies knowledge of the blinding r_i and message u_i opening the
/// commitment via
/// A2_i + c * C_i == z_1i * G + z_2i * H
/// for all limbs i where C_i = r_i * G + u_i * H is the Pedersen commitment.
/// Combined equations using scalars rho_i = Hash("rho", c, i):
/// (\sum_i rho_i * z_1i) * G + (\sum_i rho_i * z_2i) * H - \sum_i rho_i * A2_i - \sum_i (c * rho_i) * C_i == 0
/// Check 2 (key recombination): The encrypted 32-bit key limbs u_i reconstruct to the private
/// key corresponding to U: (sum_i 2^{32i} * z2[i]) * G == (sum_i 2^{32i} * a2[i]) + c * U.
///
/// Check 3 (public key consistency): Verifies that the encrypted 32-bit key limbs u_i reconstruct to the
/// private key corresponding to U via
/// (\sum_i z_2i * 2^{32i}) * G == (\sum_i A3_i * 2^{32i}) + c * U
/// where U is the sender's public key.
/// Check 3 (decryption handle consistency, batched against rho): For each limb i, each decryption handle was
/// formed with the same blinding r_i used in the commitment: z1[i] * S_rho == a3[i] + c * D_rho[i]
/// where S_rho = sum_j rho_j * pk_j and D_rho[i] = sum_j rho_j * D_ij.
///
/// We combine the individual checks as
/// (check 1) + alpha * (check 2) + beta * (check 3) == 0
/// using hash-derived outer scalars alpha = Hash("alpha", c) and beta = Hash("beta", c) to ensure soundness.
/// All 2n+1 equations are consolidated into a single MSM using verifier-side hash-derived
/// scalars from a fresh random seed msm_seed:
/// - w_ped[i] = Hash("w_ped", msm_seed, i): inner weights for the N Pedersen equations
/// - w_dec[i] = Hash("w_dec", msm_seed, i): inner weights for the N decryption handle equations
/// - w_rec = Hash("w_rec", msm_seed): outer weight for the key recombination equation
///
/// The combined check is sum_i w_ped[i] * (Pedersen)_i + w_rec * (Recomb) + sum_i w_dec[i] * (DecryptionHandle)_i == 0.
pub fn verify(
&self,
sender_public_key: &PublicKey,
recipient_encryption_keys: &[PublicKey],
ciphertexts: &[MultiRecipientCiphertext; N],
rng: &mut impl AllowedRng,
) -> FastCryptoResult<()> {
// Fiat-Shamir challenge
let c = Self::challenge(
let rho = Self::challenge_rho(sender_public_key, recipient_encryption_keys, ciphertexts);

let recipient_pk_points: Vec<RistrettoPoint> =
recipient_encryption_keys.iter().map(|pk| pk.0).collect();
let s_rho = RistrettoPoint::multi_scalar_mul(&rho, &recipient_pk_points)
.expect("Consistent lengths");

let d_rho: [RistrettoPoint; N] = from_fn(|i| {
RistrettoPoint::multi_scalar_mul(&rho, &ciphertexts[i].decryption_handles)
.expect("Consistent lengths")
});

let c = Self::challenge_c(
sender_public_key,
recipient_encryption_keys,
ciphertexts,
&rho,
&self.a1,
&self.a2,
&self.a3,
);

// Number of recipients
let m = recipient_encryption_keys.len();

// Compute inner scalars mu_ij = Hash("mu", c, i, j) for all i and j used in check 1
let mu: Vec<RistrettoScalar> = (0..N)
.flat_map(|i| (0..m).map(move |j| fiat_shamir_challenge(&("mu", &c, i, j))))
.collect();

// Compute inner scalars rho_i = Hash("rho", c, i) for all i used in check 2
let rho: [RistrettoScalar; N] = from_fn(|i| fiat_shamir_challenge(&("rho", &c, i)));
let msm_seed = RistrettoScalar::rand(rng);
let w_ped: [RistrettoScalar; N] =
from_fn(|i| fiat_shamir_challenge(&("w_ped", &msm_seed, i)));
let w_dec: [RistrettoScalar; N] =
from_fn(|i| fiat_shamir_challenge(&("w_dec", &msm_seed, i)));
let w_rec = fiat_shamir_challenge(&("w_rec", &msm_seed));

// Compute outer scalars alpha = Hash("alpha", c) and beta = Hash("beta", c) combining the three zero-expressions:
// (check 1) + alpha * (check 2) + beta * (check 3) == 0
let alpha = fiat_shamir_challenge(&("alpha", &c));
let beta = fiat_shamir_challenge(&("beta", &c));

// Check 2: compute sum_i(rho_i * z_1i) and sum_i(rho_i * z_2i)
let rho_z1 = RistrettoScalar::inner_product(rho, self.z1);
let rho_z2 = RistrettoScalar::inner_product(rho, self.z2);

// Check 3: compute z = \sum_i z_2i * 2^{32i}
let b = RistrettoScalar::from(1u64 << 32);
let z = RistrettoScalar::inner_product(
iterate(RistrettoScalar::generator(), |e| e * b),
self.z2,
);

let mut scalars: Vec<RistrettoScalar> = vec![alpha * rho_z1 + beta * z, alpha * rho_z2];
let mut points: Vec<RistrettoPoint> = vec![*G, *H];

// Check 1: Append (\sum_i mu_ij * z_1i, S_j) terms for each recipient j
for j in 0..m {
scalars.push(RistrettoScalar::inner_product(
(0..N).map(|i| mu[i * m + j]),
self.z1,
));
points.push(recipient_encryption_keys[j].0);
}

// Check 1: Append (-mu_ij, A1_ij) and (-c * mu_ij, D_ij) terms
for (i, (a1_chunk, ci)) in self.a1.chunks(m).zip(ciphertexts).enumerate() {
for (j, (a1ij, dij)) in a1_chunk.iter().zip(&ci.decryption_handles).enumerate() {
scalars.push(-mu[i * m + j]);
points.push(*a1ij);
scalars.push(-(c * mu[i * m + j]));
points.push(*dij);
}
}

// Check 2: Append (-alpha * rho_i, A2_i) and (-c * alpha * rho_i, C_i) terms
for ((rhoi, a2i), ci) in rho.iter().zip(self.a2).zip(ciphertexts) {
scalars.push(-(alpha * *rhoi));
points.push(a2i);
scalars.push(-(c * alpha * *rhoi));
points.push(ci.commitment.0);
}

// Check 3: Append (-beta * c, U) and (-beta * 2^{32i}, A3_i) terms
scalars.push(-(beta * c));
points.push(sender_public_key.0);
let mut exp = RistrettoScalar::generator();
for a3i in self.a3 {
scalars.push(-(beta * exp));
points.push(a3i);
exp *= b;
// Coefficients on the fixed points:
// on G: sum_i w_ped[i] * z1[i] + w_rec * sum_i 2^{32i} * z2[i]
// on H: sum_i w_ped[i] * z2[i]
// on S_rho: sum_i w_dec[i] * z1[i]
// on U: -w_rec * c
let coef_g = RistrettoScalar::inner_product(w_ped, self.z1)
+ w_rec
* RistrettoScalar::inner_product(
iterate(RistrettoScalar::generator(), |e| e * b),
self.z2,
);
let coef_h = RistrettoScalar::inner_product(w_ped, self.z2);
let coef_s_rho = RistrettoScalar::inner_product(w_dec, self.z1);
let coef_u = -(w_rec * c);

let mut scalars: Vec<RistrettoScalar> = vec![coef_g, coef_h, coef_s_rho, coef_u];
let mut points: Vec<RistrettoPoint> = vec![*G, *H, s_rho, sender_public_key.0];

// Per-limb terms (5n): a1[i], C_i, a2[i], a3[i], D_rho[i]
let mut w_rec_pow = w_rec;
for i in 0..N {
// Pedersen contributions
scalars.push(-w_ped[i]);
points.push(self.a1[i]);
scalars.push(-(c * w_ped[i]));
points.push(ciphertexts[i].commitment.0);
// Recombination contribution
scalars.push(-w_rec_pow);
points.push(self.a2[i]);
// Decryption handle contributions
scalars.push(-w_dec[i]);
points.push(self.a3[i]);
scalars.push(-(c * w_dec[i]));
points.push(d_rho[i]);
w_rec_pow *= b;
}

if RistrettoPoint::multi_scalar_mul(&scalars, &points).expect("Consistent lengths")
Expand All @@ -443,20 +451,45 @@ impl<const N: usize> KeyConsistencyProof<N> {
Ok(())
}

pub fn challenge(
/// Round-2 challenge: derive m batching scalars rho_1, ..., rho_m from the statement.
/// Hashes the statement once into a base scalar, then derives each rho_j from (base, j).
pub fn challenge_rho(
sender_public_key: &PublicKey,
recipient_encryption_keys: &[PublicKey],
ciphertexts: &[MultiRecipientCiphertext; N],
) -> Vec<RistrettoScalar> {
let base = fiat_shamir_challenge(&(
"rho_base",
&*G,
&*H,
sender_public_key,
recipient_encryption_keys,
ciphertexts.as_slice(),
));
(0..recipient_encryption_keys.len())
.map(|j| fiat_shamir_challenge(&("rho", &base, j)))
.collect()
}

/// Round-4 challenge: derive the sigma-protocol challenge c from the full transcript
/// (statement + rho + round-1 and round-3 commitments).
pub fn challenge_c(
sender_public_key: &PublicKey,
recipient_encryption_keys: &[PublicKey],
ciphertexts: &[MultiRecipientCiphertext; N],
rho: &[RistrettoScalar],
a1: &[RistrettoPoint],
a2: &[RistrettoPoint],
a3: &[RistrettoPoint],
) -> RistrettoScalar {
fiat_shamir_challenge(&(
"c",
&*G,
&*H,
sender_public_key,
recipient_encryption_keys,
ciphertexts.as_slice(),
rho,
a1,
a2,
a3,
Expand Down Expand Up @@ -542,6 +575,7 @@ impl<const N: usize> VerifiableKeyEncapsulation<N> {
sender_public_key,
recipient_encryption_keys,
&self.ciphertexts,
rng,
)
}

Expand Down Expand Up @@ -731,13 +765,18 @@ fn test_key_consistency_proof() {

// Verification passes with correct sender public key
assert!(proof
.verify(&pk_snd, std::slice::from_ref(&pk_rcv), &ciphertexts)
.verify(
&pk_snd,
std::slice::from_ref(&pk_rcv),
&ciphertexts,
&mut rng
)
.is_ok());

// Verification fails with a different sender public key
let (other_pk_snd, _) = generate_keypair(&mut rng);
assert!(proof
.verify(&other_pk_snd, &[pk_rcv], &ciphertexts)
.verify(&other_pk_snd, &[pk_rcv], &ciphertexts, &mut rng)
.is_err());
}

Expand Down
Loading