Skip to content

Commit

Permalink
feat: bounded lookup log argument (#864)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto authored Nov 7, 2024
1 parent 0876faa commit 00155e5
Show file tree
Hide file tree
Showing 17 changed files with 492 additions and 19 deletions.
2 changes: 1 addition & 1 deletion examples/notebooks/data_attest.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.12.5"
},
"orig_nbformat": 4
},
Expand Down
4 changes: 2 additions & 2 deletions examples/notebooks/data_attest_hashed.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -648,10 +648,10 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.5"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
}
2 changes: 1 addition & 1 deletion examples/notebooks/logistic_regression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.12.7"
}
},
"nbformat": 4,
Expand Down
42 changes: 42 additions & 0 deletions examples/onnx/log/gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from torch import nn
import torch
import json
import numpy as np


class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()

def forward(self, x):
m = torch.log(x)

return m


circuit = MyModel()

x = torch.empty(1, 8).uniform_(0, 3)

out = circuit(x)

print(out)

torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=17, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})


d1 = ((x).detach().numpy()).reshape([-1]).tolist()

data = dict(
input_data=[d1],
)

# Serialize data into file:
json.dump(data, open("input.json", 'w'))
1 change: 1 addition & 0 deletions examples/onnx/log/input.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"input_data": [[1.9252371788024902, 1.8418371677398682, 0.8400403261184692, 2.083845853805542, 0.9760497808456421, 0.6940176486968994, 0.015579521656036377, 2.2689192295074463]]}
14 changes: 14 additions & 0 deletions examples/onnx/log/network.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
pytorch2.2.2:o

inputoutput/Log"Log
main_graphZ!
input


batch_size
b"
output


batch_size
B
5 changes: 5 additions & 0 deletions src/bindings/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,9 @@ struct PyRunArgs {
/// int: The number of legs used for decomposition
#[pyo3(get, set)]
pub decomp_legs: usize,
/// bool: Should the circuit use unbounded lookups for log
#[pyo3(get, set)]
pub bounded_log_lookup: bool,
}

/// default instantiation of PyRunArgs
Expand All @@ -212,6 +215,7 @@ impl PyRunArgs {
impl From<PyRunArgs> for RunArgs {
fn from(py_run_args: PyRunArgs) -> Self {
RunArgs {
bounded_log_lookup: py_run_args.bounded_log_lookup,
tolerance: Tolerance::from(py_run_args.tolerance),
input_scale: py_run_args.input_scale,
param_scale: py_run_args.param_scale,
Expand All @@ -236,6 +240,7 @@ impl From<PyRunArgs> for RunArgs {
impl Into<PyRunArgs> for RunArgs {
fn into(self) -> PyRunArgs {
PyRunArgs {
bounded_log_lookup: self.bounded_log_lookup,
tolerance: self.tolerance.val,
input_scale: self.input_scale,
param_scale: self.param_scale,
Expand Down
9 changes: 9 additions & 0 deletions src/circuit/ops/hybrid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ use serde::{Deserialize, Serialize};
/// An enum representing the operations that consist of both lookups and arithmetic operations.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum HybridOp {
Ln {
scale: utils::F32,
},

RoundHalfToEven {
scale: utils::F32,
legs: usize,
Expand Down Expand Up @@ -112,6 +116,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid

fn as_string(&self) -> String {
match self {
HybridOp::Ln { scale } => format!("LN(scale={})", scale),
HybridOp::RoundHalfToEven { scale, legs } => {
format!("ROUND_HALF_TO_EVEN(scale={}, legs={})", scale, legs)
}
Expand Down Expand Up @@ -189,6 +194,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
values: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
Ok(Some(match self {
HybridOp::Ln { scale } => layouts::ln(config, region, values[..].try_into()?, *scale)?,
HybridOp::RoundHalfToEven { scale, legs } => {
layouts::round_half_to_even(config, region, values[..].try_into()?, *scale, *legs)?
}
Expand Down Expand Up @@ -327,6 +333,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
HybridOp::Softmax { output_scale, .. } | HybridOp::Recip { output_scale, .. } => {
multiplier_to_scale(output_scale.0 as f64)
}
HybridOp::Ln {
scale: output_scale,
} => 4 * multiplier_to_scale(output_scale.0 as f64),
_ => in_scales[0],
};
Ok(scale)
Expand Down
Loading

0 comments on commit 00155e5

Please sign in to comment.