Skip to content
Open
341 changes: 196 additions & 145 deletions Cargo.lock

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions scripts/zx_tests/vadd_u16.zx
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from "EMBED" import vadd_u16

def main() -> u16[5]:
u16[5] a = [1, 2, 3, 4, 5]
u16[5] b = [2, 3, 4, 5, 6]
assert(vadd_u16(a, b) == [3, 5, 7, 9, 11])
return vadd_u16(a,b)
7 changes: 7 additions & 0 deletions scripts/zx_tests/vadd_u32.zx
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from "EMBED" import vadd_u32

def main() -> u32[5]:
u32[5] a = [1, 2, 3, 4, 5]
u32[5] b = [2, 3, 4, 5, 6]
assert(vadd_u32(a, b) == [3, 5, 7, 9, 11])
return vadd_u32(a,b)
6 changes: 6 additions & 0 deletions scripts/zx_tests/vadd_u32.zxf
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from "EMBED" import vadd_u32

def main() -> u32[5]:
u32[5] a = [1, 2, 3, 4, 5]
u32[5] b = []
return vadd_u32(a,b)
7 changes: 7 additions & 0 deletions scripts/zx_tests/vadd_u64.zx
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from "EMBED" import vadd_u64

def main() -> u64[5]:
u64[5] a = [1, 2, 3, 4, 5]
u64[5] b = [2, 3, 4, 5, 6]
assert(vadd_u64(a, b) == [3, 5, 7, 9, 11])
return vadd_u64(a,b)
7 changes: 7 additions & 0 deletions scripts/zx_tests/vadd_u8.zx
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from "EMBED" import vadd_u8

def main() -> u8[5]:
u8[5] a = [1, 2, 3, 4, 5]
u8[5] b = [2, 3, 4, 5, 6]
assert(vadd_u8(a, b) == [3, 5, 7, 9, 11])
return vadd_u8(a,b)
21 changes: 19 additions & 2 deletions src/front/zsharp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ impl<'ast> ZGen<'ast> {
"bit_array_le" => {
if args.len() != 2 {
Err(format!(
"Got {} args to EMBED/bit_array_le, expected 1",
"Got {} args to EMBED/bit_array_le, expected 2",
args.len()
))
} else if generics.len() != 1 {
Expand Down Expand Up @@ -305,6 +305,24 @@ impl<'ast> ZGen<'ast> {
Ok(uint_lit(DFL_T.modulus().significant_bits(), 32))
}
}
"vadd_u8" | "vadd_u16" | "vadd_u32" | "vadd_u64" => {
if args.len() != 2 {
Err(format!(
"Got {} args to EMBED/vadd_*, expected 2",
args.len()
))
} else if generics.len() != 1 {
Err(format!(
"Got {} generic args to EMBED/vadd_*, expected 1",
generics.len()
))
} else {
assert!(args.iter().all(|t| matches!(t.type_(), Ty::Array(_, _))));
let b = args.pop().unwrap();
let a = args.pop().unwrap();
vector_op(BV_ADD, a, b)
}
}
_ => Err(format!("Unknown or unimplemented builtin '{}'", f_name)),
}
}
Expand Down Expand Up @@ -891,7 +909,6 @@ impl<'ast> ZGen<'ast> {
} else {
debug!("Expr: {}", e.span().as_str());
}

match e {
ast::Expression::Ternary(u) => {
match self.expr_impl_::<true>(&u.first).ok().and_then(const_bool) {
Expand Down
14 changes: 14 additions & 0 deletions src/front/zsharp/term.rs
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,20 @@ pub fn bit_array_le(a: T, b: T, n: usize) -> Result<T, String> {
))
}

pub fn vector_op(op: Op, a: T, b: T) -> Result<T, String> {
match (a.ty, b.ty) {
(Ty::Array(a_s, a_ty), Ty::Array(b_s, b_ty)) => {
if a_s == b_s && a_ty == b_ty {
let t = term![Op::Map(Box::new(op)); a.term, b.term];
Ok(T::new(Ty::Array(a_s, a_ty), t))
} else {
panic!("Mismatched array types (this is a bug: type checking should have caught this!)");
}
Comment thread
edwjchen marked this conversation as resolved.
Outdated
}
_ => Err("Cannot do vector_op on non-array types".to_string()),
}
}

pub struct ZSharp {
values: Option<HashMap<String, Integer>>,
}
Expand Down
23 changes: 23 additions & 0 deletions src/ir/opt/cfold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,29 @@ pub fn fold_cache(node: &Term, cache: &mut TermCache<Term>) -> Term {
b.width() + w,
))))
}),
Op::Map(op) => match (get(0).as_array_opt(), get(1).as_array_opt()) {
Comment thread
edwjchen marked this conversation as resolved.
Outdated
(Some(a), Some(b)) => {
// TODO: extend for n-ary arrays
let mut res = a.clone();
Comment thread
edwjchen marked this conversation as resolved.
Outdated
let mut merge = ArrayMerge::new(a.clone(), b.clone());
Comment thread
edwjchen marked this conversation as resolved.
Outdated
for (i, va, vb) in merge.into_iter() {
let r = fold_cache(
&term![*op.clone(); leaf_term(Op::Const(va.clone())), leaf_term(Op::Const(vb.clone()))],
Comment thread
edwjchen marked this conversation as resolved.
Outdated
cache,
);
match r.as_value_opt() {
Some(v) => {
res = res.clone().store(i, v.clone());
Comment thread
edwjchen marked this conversation as resolved.
Outdated
}
None => {
panic!("Unable to constant fold idx: {}", i);
}
}
}
Some(leaf_term(Op::Const(Value::Array(res))))
}
_ => None,
},
Comment thread
edwjchen marked this conversation as resolved.
Outdated
_ => None,
};
let new_t = {
Expand Down
100 changes: 100 additions & 0 deletions src/ir/term/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ use hashconsing::{HConsed, WHConsed};
use lazy_static::lazy_static;
use log::debug;
use rug::Integer;
use std::cmp::Ordering;
use std::collections::BTreeMap;
use std::fmt::{self, Debug, Display, Formatter};
use std::iter::Peekable;
use std::sync::{Arc, RwLock};

pub mod bv;
Expand Down Expand Up @@ -721,6 +723,104 @@ impl Array {
self.check_idx(idx);
self.map.get(idx).unwrap_or(&*self.default).clone()
}

/// Iter
pub fn into_iter(&self) -> std::collections::btree_map::IntoIter<Value, Value> {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you're going to implement the into_iter function, you should implement the std::iter::IntoIterator trait, not just add a member function.

Continuing my crusade against unnecessary cloning: you probably want impl<'a> IntoIterator for &'a Array, i.e., not a consuming iterator---just one that returns references. So I think what you really want is something like this (untested---see impl<'a, K, V> IntoIterator for &'a BTreeMap<K, V> in the stdlib):

impl<'a> IntoIterator for &'a Array {
    type Item = (&'a Value, &'a Value);
    type IntoIter = std::collection::btree_map::Iter<'a, Value, Value>;
    fn into_iter(self) -> Self::IntoIter {
        self.map.iter()
    }
}

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A second, higher-level comment: it's extremely un-idiomatic for an into_iter function to take &self rather than self. This may seem slightly counterintuitive given what I've just said in my prior comment, but notice in my proposed impl IntoIterator that the argument to fn into_iter is self---it's just that we think of self as being a value of type &'a Array.

In contrast, what is currently implemented in the PR takes &self and then clones it to produce a by-value iterator, which is bad because it means that if someone did

    for (k, v) in some_array.into_iter() {
    }

they will have potentially caused a very expensive clone that they didn't expect (because idiomatically, into_iter never clones---it either consumes a value or it takes a reference and returns references).

Does this make sense?

self.map.clone().into_iter()
}
}

/// Merge two Array Iterators
pub struct ArrayMerge {
Comment thread
edwjchen marked this conversation as resolved.
Outdated
left: Peekable<Box<std::collections::btree_map::IntoIter<Value, Value>>>,
right: Peekable<Box<std::collections::btree_map::IntoIter<Value, Value>>>,
left_dfl: Value,
right_dfl: Value,
}

impl ArrayMerge {
/// Create a new [ArrayMerge] from two [Array]s
pub fn new(a: Array, b: Array) -> Self {
if a.size != b.size {
panic!("IR Arrays have different lengths: {}, {}", a.size, b.size);
}
if a.key_sort != b.key_sort {
panic!(
"IR Arrays have different key sorts: {}, {}",
a.key_sort, b.key_sort
);
}
if a.default.sort() != b.default.sort() {
panic!(
"IR Arrays default values have different key sorts: {}, {}",
a.default.sort(),
b.default.sort()
);
}
Comment thread
edwjchen marked this conversation as resolved.
Outdated

Self {
left: Box::new(a.into_iter()).peekable(),
right: Box::new(b.into_iter()).peekable(),
Comment thread
edwjchen marked this conversation as resolved.
Outdated
left_dfl: *a.default,
right_dfl: *b.default,
}
}

/// Iter
Comment thread
edwjchen marked this conversation as resolved.
Outdated
pub fn into_iter(&mut self) -> Box<dyn Iterator<Item = (Value, Value, Value)>> {
let mut acc: Vec<(Value, Value, Value)> = Vec::new();
let mut next = self.next_();
while let Some(n) = next {
acc.push(n);
next = self.next_();
}
Box::new(acc.into_iter())
}

/// Next
pub fn next_(&mut self) -> Option<(Value, Value, Value)> {
let l_peek = self.left.peek();
let r_peek = self.right.peek();

let mut left_next = false;
let mut right_next = false;
Comment thread
edwjchen marked this conversation as resolved.
Outdated

let res = match (l_peek, r_peek) {
Comment thread
edwjchen marked this conversation as resolved.
Outdated
(Some((l_ind, l_val)), Some((r_ind, r_val))) => match l_ind.cmp(r_ind) {
Ordering::Less => {
left_next = true;
Some((l_ind.clone(), l_val.clone(), self.right_dfl.clone()))
Comment thread
edwjchen marked this conversation as resolved.
Outdated
}
Ordering::Greater => {
right_next = true;
Some((r_ind.clone(), self.left_dfl.clone(), r_val.clone()))
}
Ordering::Equal => {
left_next = true;
right_next = true;
Some((l_ind.clone(), l_val.clone(), r_val.clone()))
}
},
(Some((l_ind, l_val)), None) => {
left_next = true;
Some((l_ind.clone(), l_val.clone(), self.right_dfl.clone()))
}
(None, Some((r_ind, r_val))) => {
right_next = true;
Some((r_ind.clone(), self.left_dfl.clone(), r_val.clone()))
}
(None, None) => None,
};

if left_next {
self.left.next();
}
if right_next {
self.right.next();
}

res
}
}

impl Display for Value {
Expand Down
13 changes: 13 additions & 0 deletions third_party/ZoKrates/zokrates_stdlib/stdlib/EMBED.zok
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,16 @@ def u16_to_u32(u16 i) -> u32:

def u8_to_u16(u8 i) -> u16:
return 0u16

// vector functions
def vadd_u8<N>(u8[N] a, u8[N] b) -> u8[N]:
return [0; N]

def vadd_u16<N>(u16[N] a, u16[N] b) -> u16[N]:
return [0; N]

def vadd_u32<N>(u32[N] a, u32[N] b) -> u32[N]:
return [0; N]

def vadd_u64<N>(u64[N] a, u64[N] b) -> u64[N]:
return [0; N]
Comment on lines +79 to +90
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think if you're only going to support one type here it should be field, not u32.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think only BV are currently supported in the FHE backend, which is why I left it as u32.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still, it might make sense to support all the basic types, no? It wouldn't be much extra work:

  • add vadd_field, vadd_u64, vadd_u32, vadd_u16, vadd_u8 stubs in EMBED
  • in builtin_call_, extend the current match clause to catch all of these
  • extend the function in zsharp/term.rs to support both bitvector and field types
  • the current constant folder may or may not work, but I bet there's a small edit distance to working if no

Since this is so simple, I don't see a reason to leave it half-done. But I could be missing something!

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, updated!