Dimension mismatch in tensor op #827

Closed
pgermani opened this issue Jul 12, 2024 · 1 comment · Fixed by #828
Labels
bug Something isn't working

Comments


pgermani commented Jul 12, 2024

I've been trying to use EZKL with the apple/mobilevit-xx-small model, but I get an error at the very first step of the EZKL workflow, when calling `gen_settings`. I exported the model to ONNX format and ran inference on it with onnxruntime, and then attempted to use EZKL.
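Before handing an ONNX export to EZKL, it can help to confirm that the export reproduces the PyTorch model's outputs. Below is a minimal sketch of such a parity check using only NumPy; the helper name `outputs_match` and the default tolerances are illustrative assumptions, not part of the original script or of EZKL.

```python
import numpy as np


def outputs_match(reference: np.ndarray, candidate: np.ndarray,
                  rtol: float = 1e-3, atol: float = 1e-5) -> bool:
    """Return True if two logit arrays agree numerically and on the top-1 class.

    Hypothetical helper: `reference` would be the PyTorch logits and
    `candidate` the onnxruntime logits for the same input.
    """
    if reference.shape != candidate.shape:
        return False
    # Element-wise closeness within the (assumed) tolerances
    close = np.allclose(reference, candidate, rtol=rtol, atol=atol)
    # The predicted class must also be identical for every batch item
    same_top1 = np.array_equal(reference.argmax(axis=-1), candidate.argmax(axis=-1))
    return bool(close and same_top1)
```

In the script below this could be called as `outputs_match(model(**inputs).logits.detach().numpy(), session.run(None, {"pixel_values": input_array})[0])`. Note that an export passing this check can still hit operator-support issues in EZKL, as happened here.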

Error:

```
[tensor] dimension mismatch in tensor op: Given groups=128, expected kernel to be at least 128 at dimension 0 but got 0 instead
Traceback (most recent call last):
  File "mobilevit.py", line 83, in <module>
    res = ezkl.gen_settings(onnx_path, settings_path, py_run_args=py_run_args)
RuntimeError: Failed to generate settings: [graph] [halo2] General synthesis
```
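The first line of the error is a shape check on a grouped convolution. In the ONNX/PyTorch `Conv` layout the weight has shape (C_out, C_in / groups, kH, kW), so with groups=128 (for example a depthwise convolution, where groups equals the channel count) dimension 0 of the kernel must be at least 128; "got 0" suggests the kernel reached that check empty. The sketch below reproduces this kind of validation in plain Python to make the message concrete. It is not EZKL's implementation, and the function name is hypothetical.

```python
from typing import Sequence


def check_grouped_conv_shapes(input_shape: Sequence[int],
                              kernel_shape: Sequence[int],
                              groups: int) -> None:
    """Raise ValueError if shapes are inconsistent with a grouped 2-D convolution.

    input_shape:  (N, C_in, H, W)
    kernel_shape: (C_out, C_in // groups, kH, kW)  -- ONNX/PyTorch Conv layout
    """
    if len(input_shape) != 4 or len(kernel_shape) != 4:
        raise ValueError("expected 4-D input and kernel shapes")
    if groups < 1:
        raise ValueError(f"groups must be >= 1, got {groups}")
    c_in = input_shape[1]
    c_out, c_in_per_group = kernel_shape[0], kernel_shape[1]
    # Each group needs at least one output channel
    if c_out < groups:
        raise ValueError(
            f"Given groups={groups}, expected kernel to be at least {groups} "
            f"at dimension 0 but got {c_out} instead"
        )
    if c_out % groups != 0:
        raise ValueError(f"output channels {c_out} not divisible by groups={groups}")
    if c_in % groups != 0:
        raise ValueError(f"input channels {c_in} not divisible by groups={groups}")
    # Each group sees C_in / groups input channels
    if c_in // groups != c_in_per_group:
        raise ValueError(
            f"kernel dimension 1 should be {c_in // groups}, got {c_in_per_group}"
        )
```

A well-formed depthwise layer such as `check_grouped_conv_shapes((1, 128, 32, 32), (128, 1, 3, 3), groups=128)` passes, while a kernel of shape `(0, 1, 3, 3)` raises the same kind of message as the one reported above.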

Here is my code:

```python
from transformers import MobileViTFeatureExtractor, MobileViTForImageClassification
from PIL import Image
import requests
import torch
import inspect
import onnxruntime as ort
import numpy as np
import os
import ezkl

onnx_path = os.path.join("./mobilevit.onnx")

# Load a sample COCO image and the pretrained MobileViT model
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = MobileViTFeatureExtractor.from_pretrained("apple/mobilevit-xx-small")
model = MobileViTForImageClassification.from_pretrained("apple/mobilevit-xx-small")

inputs = feature_extractor(images=image, return_tensors="pt")

input_names = list(inputs.keys())
print(f"input keys: {input_names}")

# Run the PyTorch model once as a reference
outputs = model(**inputs)
logits = outputs.logits

output_names = list(outputs.keys())
print(f"output keys: {output_names}")

predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])

signature = inspect.signature(model.forward)
print(f"forward signature: {signature}")

print(f"input shape: {inputs['pixel_values'].shape}")
print(f"output shape: {outputs.logits.shape}")

# Export to ONNX (opset 13) with a dynamic batch dimension
torch.onnx.export(
    model,
    (inputs['pixel_values'],),
    onnx_path,
    input_names=['pixel_values'],
    output_names=['logits'],
    dynamic_axes={
        'pixel_values': {0: 'batch_size'},
        'logits': {0: 'batch_size'},
    },
    opset_version=13,
)

# Run the exported model with onnxruntime
session = ort.InferenceSession(onnx_path)
input_array = inputs['pixel_values'].numpy()
outputs = session.run(None, {"pixel_values": input_array})

logits = outputs[0]
print(f"Output logits: {logits}")
predicted_class = np.argmax(logits, axis=1)
print(f"Predicted class: {predicted_class[0]}")

# EZKL pipeline
sol_code_path = os.path.join('./proof/Verifier.sol')
abi_path = os.path.join('./proof/Verifier.abi')

compiled_model_path = os.path.join('./proof/network.ezkl')
settings_path = os.path.join('./proof/settings.json')
cal_data_path = os.path.join('./proof/cal_data.json')
witness_path = os.path.join('./proof/witness.json')
proof_path = os.path.join('./proof/proof.json')
data_path = os.path.join('./proof/input.json')
pk_path = os.path.join('./proof/test.pk')
vk_path = os.path.join('./proof/test.vk')

if not os.path.exists("proof"):
    os.makedirs("proof")

py_run_args = ezkl.PyRunArgs()
py_run_args.output_visibility = "public"
py_run_args.param_visibility = "private"
py_run_args.input_visibility = "hashed"

# Fails here with the dimension-mismatch error above
res = ezkl.gen_settings(onnx_path, settings_path, py_run_args=py_run_args)
# Not reached; note that cal_data_path is never written by this script
ezkl.calibrate_settings(cal_data_path, onnx_path, settings_path, "resources")
ezkl.compile_circuit(onnx_path, compiled_model_path, settings_path)
res = ezkl.get_srs(settings_path)
res = ezkl.setup(
    compiled_model_path,
    vk_path,
    pk_path,
)
res = ezkl.create_evm_verifier(
    vk_path,
    settings_path,
    sol_code_path,
    abi_path,
)
```
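The script passes `cal_data_path` to `calibrate_settings` and defines `data_path`, but never writes `./proof/cal_data.json` or `./proof/input.json`, so the later steps would still need those files once `gen_settings` succeeds. A minimal sketch for producing them, assuming the `{"input_data": [<flattened tensor>]}` layout used in the EZKL example notebooks (verify against the docs for the installed EZKL version); both helper names are hypothetical.

```python
import json

import numpy as np


def to_ezkl_input(pixel_values: np.ndarray) -> dict:
    """Flatten one input tensor into the (assumed) EZKL input JSON layout."""
    flat = np.asarray(pixel_values, dtype=np.float64).reshape(-1).tolist()
    # One entry per model input; MobileViT has a single input, pixel_values
    return {"input_data": [flat]}


def write_ezkl_input(pixel_values: np.ndarray, path: str) -> None:
    """Write the payload produced by to_ezkl_input to `path` as JSON."""
    with open(path, "w") as f:
        json.dump(to_ezkl_input(pixel_values), f)
```

In the script this would be called before the EZKL steps, e.g. `write_ezkl_input(inputs['pixel_values'].numpy(), data_path)` and likewise for `cal_data_path`. Reusing the single sample image for calibration keeps the sketch short; a more representative calibration set would normally contain several images.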
pgermani added the bug (Something isn't working) label on Jul 12, 2024
alexander-camuto (Collaborator) commented:

@pgermani taking a look at this now, was able to reproduce. Will patch it soon :)
