{ "cells": [
 { "cell_type": "markdown", "id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67", "metadata": { "id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67" }, "source": [
  "## EZKL Jupyter Notebook Demo LOCAL\n",
  "\n",
  "Here we demonstrate how to use the EZKL package to run a publicly known / committed to network on some private data, producing a public output.\n" ] },
 { "cell_type": "code", "execution_count": null, "id": "95613ee9", "metadata": { "id": "95613ee9" }, "outputs": [], "source": [
  "# make sure you have the dependencies required here already installed\n",
  "import subprocess\n",
  "import sys\n",
  "from torch import nn\n",
  "import ezkl\n",
  "import os\n",
  "import json\n",
  "import torch\n",
  "from torch.utils.data import DataLoader\n",
  "from torchvision import datasets\n",
  "from torchvision.transforms import ToTensor\n",
  "\n",
  "# Make the project importable without a machine-specific absolute path.\n",
  "# Set ZKML_REPO_ROOT to the checkout location; the default is the home-relative\n",
  "# location of the original checkout.\n",
  "repo_root = os.environ.get(\n",
  "    \"ZKML_REPO_ROOT\", os.path.expanduser(\"~/src/zkml-bootcamp2025Q1-g6\")\n",
  ")\n",
  "extra_path = os.path.join(repo_root, \"app\")\n",
  "\n",
  "# app/model/fashion.py does `from .. import tools`, so it must be imported as\n",
  "# `app.model.fashion` with the repository root on sys.path. Importing it as the\n",
  "# top-level `model.fashion` raises\n",
  "# \"ImportError: attempted relative import beyond top-level package\".\n",
  "sys.path.extend(p for p in (repo_root, extra_path) if p not in sys.path)\n",
  "\n",
  "from app.model import fashion" ] },
 { "cell_type": "code", "execution_count": 5, "id": "2c90da79-a173-4a3d-ad82-7505fb718c88", "metadata": {}, "outputs": [], "source": [] },
 { "cell_type": "code", "execution_count": null, "id": "945dc7a9-f969-45b6-8202-eea00bc77e6e", "metadata": {}, "outputs": [], "source": [] },
 { "cell_type": "code", "execution_count": 2, "id": "cde92973-2ff1-405b-9ca4-fa3f68f3d954", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100.0%\n", "100.0%\n", "100.0%\n", "100.0%\n" ] } ], "source": [
  "\n",
  "# Download training data from open datasets.\n",
  "training_data = datasets.FashionMNIST(\n",
  "    root=\"data\",\n",
  "    train=True,\n",
  "    download=True,\n",
  "    transform=ToTensor(),\n",
  ")\n",
  "\n",
  "# Download test data from open datasets.\n",
  "test_data = datasets.FashionMNIST(\n",
  "    root=\"data\",\n",
  "    train=False,\n",
  "    download=True,\n",
  "    transform=ToTensor(),\n",
  ")\n",
  "\n",
  "batch_size = 64\n",
  "\n",
  "# Create data loaders.\n",
  "train_dataloader = DataLoader(training_data, batch_size=batch_size)\n",
  "test_dataloader = DataLoader(test_data, batch_size=batch_size)\n",
  "\n",
  "\n",
  "def train(dataloader, model, loss_fn, optimizer):\n",
  "    \"\"\"Run one pass over `dataloader`, updating `model` with `optimizer`.\n",
  "\n",
  "    Prints the current loss every 100 batches.\n",
  "    \"\"\"\n",
  "    size = len(dataloader.dataset)\n",
  "    model.train()\n",
  "    for batch, (X, y) in enumerate(dataloader):\n",
  "        # Compute prediction error\n",
  "        pred = model(X)\n",
  "        loss = loss_fn(pred, y)\n",
  "\n",
  "        # Backpropagation\n",
  "        loss.backward()\n",
  "        optimizer.step()\n",
  "        optimizer.zero_grad()\n",
  "\n",
  "        if batch % 100 == 0:\n",
  "            # Keep the tensor in `loss`; use a separate name for the Python float.\n",
  "            loss_value, current = loss.item(), (batch + 1) * len(X)\n",
  "            print(f\"loss: {loss_value:>7f} [{current:>5d}/{size:>5d}]\")\n",
  "\n",
  "\n",
  "def test(dataloader, model, loss_fn):\n",
  "    \"\"\"Evaluate `model` on `dataloader` and print accuracy and average batch loss.\"\"\"\n",
  "    size = len(dataloader.dataset)\n",
  "    num_batches = len(dataloader)\n",
  "    model.eval()\n",
  "    test_loss, correct = 0, 0\n",
  "    with torch.no_grad():\n",
  "        for X, y in dataloader:\n",
  "            pred = model(X)\n",
  "            test_loss += loss_fn(pred, y).item()\n",
  "            correct += (pred.argmax(1) == y).type(torch.float).sum().item()\n",
  "    test_loss /= num_batches\n",
  "    correct /= size\n",
  "    print(f\"Test Error: \\n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \\n\")\n",
  "\n",
  "\n",
  "# Defines the model\n",
  "class NeuralNetwork(nn.Module):\n",
  "    \"\"\"Fully connected classifier: flatten, then 784 -> 512 -> 512 -> 10 with ReLU.\"\"\"\n",
  "\n",
  "    def __init__(self):\n",
  "        super().__init__()\n",
  "        self.flatten = nn.Flatten()\n",
  "        self.linear_relu_stack = nn.Sequential(\n",
  "            nn.Linear(28*28, 512),\n",
  "            nn.ReLU(),\n",
  "            nn.Linear(512, 512),\n",
  "            nn.ReLU(),\n",
  "            nn.Linear(512, 10)\n",
  "        )\n",
  "\n",
  "    def forward(self, x):\n",
  "        \"\"\"Flatten `x` and return the 10 output logits.\"\"\"\n",
  "        x = self.flatten(x)\n",
  "        logits = self.linear_relu_stack(x)\n",
  "        return logits\n",
  "\n" ] },
 { "cell_type": "code", "execution_count": 5, "id": "b37637c4", "metadata": { "id": "b37637c4" }, "outputs": [], "source": [
  "# Artifact locations, all relative to the notebook's working directory.\n",
  "model_path = 'network.onnx'\n",
  "compiled_model_path = 'network.compiled'\n",
  "pk_path = 'test.pk'\n",
  "vk_path = 'test.vk'\n",
  "settings_path = 'settings.json'\n",
  "\n",
  "witness_path = 'witness.json'\n",
  "data_path = 'input.json'" ] },
 { "cell_type": "code", "execution_count": 6, "id": "mbTNTVVVUWHf", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "mbTNTVVVUWHf", "outputId": "298c39ef-d47d-45e1-b398-69b8752f9843" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
  "NeuralNetwork(\n",
  " (flatten): Flatten(start_dim=1, end_dim=-1)\n",
  " (linear_relu_stack): Sequential(\n",
  " (0): Linear(in_features=784, out_features=512, bias=True)\n",
  " (1): ReLU()\n",
  " (2): Linear(in_features=512, out_features=512, bias=True)\n",
  " (3): ReLU()\n",
  " (4): Linear(in_features=512, out_features=10, bias=True)\n",
  " )\n",
  ")\n",
  "Epoch 1\n",
  "-------------------------------\n",
  "loss: 2.312977 [ 64/60000]\n",
  "loss: 2.296970 [ 6464/60000]\n",
  "loss: 2.274892 [12864/60000]\n",
  "loss: 2.276994 [19264/60000]\n",
  "loss: 2.254272 [25664/60000]\n",
  "loss: 2.238387 [32064/60000]\n",
  "loss: 2.232406 [38464/60000]\n",
  "loss: 2.202704 [44864/60000]\n",
  "loss: 2.202426 [51264/60000]\n",
  "loss: 2.174140 [57664/60000]\n",
  "Test Error: \n",
  " Accuracy: 56.1%, Avg loss: 2.163561 \n",
  "\n",
  "Epoch 2\n",
  "-------------------------------\n",
  "loss: 2.175747 [ 64/60000]\n",
  "loss: 2.158823 [ 6464/60000]\n",
  "loss: 
2.105635 [12864/60000]\n", "loss: 2.131705 [19264/60000]\n", "loss: 2.067987 [25664/60000]\n", "loss: 2.017259 [32064/60000]\n", "loss: 2.039357 [38464/60000]\n", "loss: 1.960464 [44864/60000]\n", "loss: 1.975269 [51264/60000]\n", "loss: 1.903171 [57664/60000]\n", "Test Error: \n", " Accuracy: 58.7%, Avg loss: 1.897245 \n", "\n", "Epoch 3\n", "-------------------------------\n", "loss: 1.933288 [ 64/60000]\n", "loss: 1.893456 [ 6464/60000]\n", "loss: 1.784591 [12864/60000]\n", "loss: 1.838915 [19264/60000]\n", "loss: 1.714538 [25664/60000]\n", "loss: 1.665039 [32064/60000]\n", "loss: 1.689406 [38464/60000]\n", "loss: 1.585296 [44864/60000]\n", "loss: 1.619385 [51264/60000]\n", "loss: 1.516735 [57664/60000]\n", "Test Error: \n", " Accuracy: 61.9%, Avg loss: 1.527831 \n", "\n", "Epoch 4\n", "-------------------------------\n", "loss: 1.597046 [ 64/60000]\n", "loss: 1.551634 [ 6464/60000]\n", "loss: 1.408163 [12864/60000]\n", "loss: 1.488660 [19264/60000]\n", "loss: 1.363807 [25664/60000]\n", "loss: 1.358963 [32064/60000]\n", "loss: 1.368338 [38464/60000]\n", "loss: 1.288468 [44864/60000]\n", "loss: 1.324361 [51264/60000]\n", "loss: 1.233305 [57664/60000]\n", "Test Error: \n", " Accuracy: 63.3%, Avg loss: 1.251777 \n", "\n", "Epoch 5\n", "-------------------------------\n", "loss: 1.330162 [ 64/60000]\n", "loss: 1.306223 [ 6464/60000]\n", "loss: 1.143752 [12864/60000]\n", "loss: 1.257733 [19264/60000]\n", "loss: 1.132717 [25664/60000]\n", "loss: 1.157656 [32064/60000]\n", "loss: 1.170492 [38464/60000]\n", "loss: 1.103810 [44864/60000]\n", "loss: 1.139225 [51264/60000]\n", "loss: 1.067870 [57664/60000]\n", "Test Error: \n", " Accuracy: 64.5%, Avg loss: 1.082341 \n", "\n", "Epoch 6\n", "-------------------------------\n", "loss: 1.152435 [ 64/60000]\n", "loss: 1.153687 [ 6464/60000]\n", "loss: 0.973343 [12864/60000]\n", "loss: 1.116238 [19264/60000]\n", "loss: 0.991616 [25664/60000]\n", "loss: 1.024298 [32064/60000]\n", "loss: 1.049273 [38464/60000]\n", "loss: 0.989383 
[44864/60000]\n", "loss: 1.020612 [51264/60000]\n", "loss: 0.966005 [57664/60000]\n", "Test Error: \n", " Accuracy: 65.8%, Avg loss: 0.974972 \n", "\n", "Epoch 7\n", "-------------------------------\n", "loss: 1.031440 [ 64/60000]\n", "loss: 1.057128 [ 6464/60000]\n", "loss: 0.859297 [12864/60000]\n", "loss: 1.023219 [19264/60000]\n", "loss: 0.902562 [25664/60000]\n", "loss: 0.932362 [32064/60000]\n", "loss: 0.970724 [38464/60000]\n", "loss: 0.916998 [44864/60000]\n", "loss: 0.939589 [51264/60000]\n", "loss: 0.898259 [57664/60000]\n", "Test Error: \n", " Accuracy: 67.1%, Avg loss: 0.902818 \n", "\n", "Epoch 8\n", "-------------------------------\n", "loss: 0.944603 [ 64/60000]\n", "loss: 0.991639 [ 6464/60000]\n", "loss: 0.779391 [12864/60000]\n", "loss: 0.957908 [19264/60000]\n", "loss: 0.842370 [25664/60000]\n", "loss: 0.866209 [32064/60000]\n", "loss: 0.916129 [38464/60000]\n", "loss: 0.870094 [44864/60000]\n", "loss: 0.881627 [51264/60000]\n", "loss: 0.849794 [57664/60000]\n", "Test Error: \n", " Accuracy: 68.4%, Avg loss: 0.851363 \n", "\n", "Epoch 9\n", "-------------------------------\n", "loss: 0.878879 [ 64/60000]\n", "loss: 0.943430 [ 6464/60000]\n", "loss: 0.720472 [12864/60000]\n", "loss: 0.909670 [19264/60000]\n", "loss: 0.798988 [25664/60000]\n", "loss: 0.817197 [32064/60000]\n", "loss: 0.875414 [38464/60000]\n", "loss: 0.837863 [44864/60000]\n", "loss: 0.838859 [51264/60000]\n", "loss: 0.812779 [57664/60000]\n", "Test Error: \n", " Accuracy: 69.7%, Avg loss: 0.812555 \n", "\n", "Epoch 10\n", "-------------------------------\n", "loss: 0.826808 [ 64/60000]\n", "loss: 0.905002 [ 6464/60000]\n", "loss: 0.674947 [12864/60000]\n", "loss: 0.872401 [19264/60000]\n", "loss: 0.765720 [25664/60000]\n", "loss: 0.779864 [32064/60000]\n", "loss: 0.842847 [38464/60000]\n", "loss: 0.814456 [44864/60000]\n", "loss: 0.805864 [51264/60000]\n", "loss: 0.782933 [57664/60000]\n", "Test Error: \n", " Accuracy: 71.1%, Avg loss: 0.781673 \n", "\n", "Done!\n", "\n", "Saved 
PyTorch Model State to model.pth\n" ] } ], "source": [
  "# Train the model with plain SGD, then save its weights to model.pth\n",
  "model = NeuralNetwork()\n",
  "print(model)\n",
  "\n",
  "loss_fn = nn.CrossEntropyLoss()\n",
  "optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)\n",
  "\n",
  "epochs = 10\n",
  "for t in range(epochs):\n",
  "    print(f\"Epoch {t+1}\\n-------------------------------\")\n",
  "    train(train_dataloader, model, loss_fn, optimizer)\n",
  "    test(test_dataloader, model, loss_fn)\n",
  "print(\"Done!\")\n",
  "\n",
  "torch.save(model.state_dict(), \"model.pth\")\n",
  "print(\"\\nSaved PyTorch Model State to model.pth\")" ] },
 { "cell_type": "code", "execution_count": 7, "id": "82db373a", "metadata": { "id": "82db373a" }, "outputs": [], "source": [
  "\n",
  "model.eval()\n",
  "model.load_state_dict(torch.load(\"model.pth\", weights_only=True))\n",
  "# First test image; used both to trace the ONNX export and as the proof input.\n",
  "dummy_input = test_data[0][0]\n",
  "\n",
  "# Export the model\n",
  "torch.onnx.export(model, # model being run\n",
  "                  dummy_input, # model input (or a tuple for multiple inputs)\n",
  "                  model_path, # where to save the model (can be a file or file-like object)\n",
  "                  export_params=True, # store the trained parameter weights inside the model file\n",
  "                  opset_version=10, # the ONNX version to export the model to\n",
  "                  do_constant_folding=True, # whether to execute constant folding for optimization\n",
  "                  input_names = ['input'], # the model's input names\n",
  "                  output_names = ['output'], # the model's output names\n",
  "                  dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
  "                                'output' : {0 : 'batch_size'}})\n",
  "\n",
  "data_array = ((dummy_input).detach().numpy()).reshape([-1]).tolist()\n",
  "\n",
  "data = dict(input_data = [data_array])\n",
  "\n",
  "# Serialize data into file. The context manager closes (and flushes) the file\n",
  "# so that later cells read the complete JSON.\n",
  "with open(data_path, 'w') as input_file:\n",
  "    json.dump(data, input_file)\n" ] },
 { "cell_type": "code", "execution_count": 24, "id": "d5e374a2", "metadata": { "id": "d5e374a2" }, "outputs": [], "source": [
  "py_run_args = ezkl.PyRunArgs()\n",
  "py_run_args.input_visibility = \"private\"\n",
  "py_run_args.output_visibility = \"public\"\n",
  "py_run_args.param_visibility = \"fixed\" # private by default\n",
  "\n",
  "res = ezkl.gen_settings(model_path, settings_path, py_run_args=py_run_args)\n",
  "\n",
  "assert res == True\n" ] },
 { "cell_type": "code", "execution_count": null, "id": "c6iTDB6f2JOK", "metadata": { "id": "c6iTDB6f2JOK" }, "outputs": [], "source": [
  "cal_path = os.path.join(\"calibration.json\")\n",
  "\n",
  "# Shape of a single model input, taken from the sample used for the ONNX export.\n",
  "# (`shape` was previously undefined, so this cell raised NameError.)\n",
  "shape = tuple(dummy_input.shape)\n",
  "\n",
  "# 20 random inputs, flattened into one list for calibration.\n",
  "data_array = torch.rand(20, *shape).numpy().reshape([-1]).tolist()\n",
  "\n",
  "data = dict(input_data = [data_array])\n",
  "\n",
  "# Serialize data into file; close it before ezkl reads it.\n",
  "with open(cal_path, 'w') as cal_file:\n",
  "    json.dump(data, cal_file)\n",
  "\n",
  "\n",
  "await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")" ] },
 { "cell_type": "code", "execution_count": 25, "id": "3aa4f090", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "3aa4f090", "outputId": "c370bb1a-35ea-4044-84df-21110a79b7b0" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:ezkl:low scale values (<8) may impact precision\n" ] } ], "source": [
  "res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
  "assert res == True" ] },
 { "cell_type": "code", "execution_count": 27, "id": "8b74dcee", "metadata": { "id": "8b74dcee" }, "outputs": [], "source": [
  "# srs path\n",
  "res = await ezkl.get_srs( settings_path, srs_path=\"kzg.srs\")" ] },
 { "cell_type": "code", "execution_count": 28, "id": "18c8b7c7", "metadata": { "id": "18c8b7c7" }, "outputs": [], "source": [
  "# now generate the witness file\n",
  "\n",
  "res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
  "assert os.path.isfile(witness_path)" ] },
 { "cell_type": "code", "execution_count": 29, "id": "b1c561a8", "metadata": { "id": "b1c561a8" }, "outputs": [], "source": [
  "\n",
  "# HERE WE SETUP THE CIRCUIT PARAMS\n",
  "# WE GOT KEYS\n",
  "# WE GOT CIRCUIT PARAMETERS\n",
  "# EVERYTHING ANYONE HAS EVER NEEDED FOR 
ZK\n",
  "\n",
  "# Generate the verification key and proving key for the compiled circuit.\n",
  "res = ezkl.setup(compiled_model_path, vk_path, pk_path, srs_path=\"kzg.srs\")\n",
  "\n",
  "assert res == True\n",
  "# The keys must now exist on disk, next to the settings file.\n",
  "assert os.path.isfile(vk_path)\n",
  "assert os.path.isfile(pk_path)\n",
  "assert os.path.isfile(settings_path)" ] },
 { "cell_type": "code", "execution_count": 31, "id": "c384cbc8", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "c384cbc8", "outputId": "bdb4cd87-820c-460d-e9f9-962b3fa35235" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
  "{'instances': [['ebfdffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', 'e4fdffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', 'e4feffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', 'c4feffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', '46ffffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', '8302000000000000000000000000000000000000000000000000000000000000', 'e7feffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', '5102000000000000000000000000000000000000000000000000000000000000', '3d01000000000000000000000000000000000000000000000000000000000000', 'b302000000000000000000000000000000000000000000000000000000000000']], 'proof': '', 'transcript_type': 'EVM'}\n" ] } ], "source": [
  "# GENERATE A PROOF\n",
  "\n",
  "# Where the proof is written.\n",
  "proof_path = os.path.join('test.pf')\n",
  "\n",
  "res = ezkl.prove(witness_path, compiled_model_path, pk_path, proof_path, \"single\", srs_path=\"kzg.srs\")\n",
  "\n",
  "print(res)\n",
  "assert os.path.isfile(proof_path)" ] },
 { "cell_type": "code", "execution_count": 32, "id": "76f00d41", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "76f00d41", "outputId": "c4a7b153-dd32-42ed-9b31-1cb4616ebff4" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "verified\n" ] } ], "source": [
  "# VERIFY IT\n",
  "\n",
  "# Check the proof against the settings and the verification key.\n",
  "res = ezkl.verify(proof_path, settings_path, vk_path, srs_path=\"kzg.srs\")\n",
  "\n",
  "assert res == True\n",
  "print(\"verified\")" ] }
 ],
 "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.7" } },
 "nbformat": 4, "nbformat_minor": 5 }