{ "cells": [ { "cell_type": "markdown", "source": [ "# QML on MNIST Classification" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "## Overview\n", "\n", "This tutorial is not about designing a better QML method for MNIST classification from a machine learning perspective. Instead, we use a simple parameterized circuit to demonstrate the QML-related technical ingredients of ``tensorcircuit``. Accordingly, this note should not be taken as an example of good QML practice.\n", "\n", "[WIP note]" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "## Setup" ], "metadata": {} }, { "cell_type": "code", "execution_count": 1, "source": [ "from functools import partial\n", "import numpy as np\n", "import tensorflow as tf\n", "import jax\n", "\n", "# enable float64 in jax; set this at startup, before any jax arrays are created\n", "jax.config.update(\"jax_enable_x64\", True)\n", "from jax import numpy as jnp\n", "import optax\n", "import tensorcircuit as tc" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 2, "source": [ "tc.set_backend(\"tensorflow\")\n", "tc.set_dtype(\"complex128\")" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "('complex128', 'float64')" ] }, "metadata": {}, "execution_count": 2 } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "## Data Processing\n", "\n", "We use the MNIST dataset, resize each image to 3x3 so that it fits into a 9-qubit circuit, and binarize the pixels with a threshold of 0.5.\n", "The testbed is a binary classification task: digit 1 vs. digit 5.\n", "Since this tutorial is not about good QML practice, we skip the validation set.\n", "We also keep only the first 100 training samples for a small demo."
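, "\n", "\n", "For reference, the preprocessing in the next cell can be summarized as follows, where $\\tilde{p}_i$ (our notation) is the $i$-th pixel of the resized image, with values in $[0, 1]$:\n", "\n", "$$x_i = \\mathbb{1}[\\tilde{p}_i > 0.5] \\in \\{0, 1\\}, \\quad i = 0, \\dots, 8, \\qquad y = \\mathbb{1}[\\text{digit} = 1].$$\n", "\n", "The label is therefore $y = 1$ for digit 1 and $y = 0$ for digit 5. Each $x_i$ is later used directly as the angle of an $R_x$ gate on qubit $i$, so a pixel is encoded as a rotation of either $0$ or $1$ radian."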
], "metadata": {} }, { "cell_type": "code", "execution_count": 3, "source": [ "# numpy data\n", "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n", "x_train = x_train[..., np.newaxis] / 255.0\n", "\n", "\n", "def filter_pair(x, y, a, b):\n", "    # keep only digits a and b; the new label is True for digit a\n", "    keep = (y == a) | (y == b)\n", "    x, y = x[keep], y[keep]\n", "    y = y == a\n", "    return x, y\n", "\n", "\n", "x_train, y_train = filter_pair(x_train, y_train, 1, 5)\n", "# downsample to 3x3, binarize at 0.5, then keep the first 100 samples\n", "x_train_small = tf.image.resize(x_train, (3, 3)).numpy()\n", "x_train_bin = np.array(x_train_small > 0.5, dtype=np.float32)\n", "x_train_bin = np.squeeze(x_train_bin)[:100]" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 4, "source": [ "# tensorflow data\n", "\n", "x_train_tf = tf.reshape(tf.constant(x_train_bin, dtype=tf.float64), [-1, 9])\n", "y_train_tf = tf.constant(y_train[:100], dtype=tf.float64)\n", "\n", "# jax data\n", "\n", "x_train_jax = jnp.array(x_train_bin, dtype=np.float64).reshape([100, -1])\n", "y_train_jax = jnp.array(y_train[:100], dtype=np.float64).reshape([100])" ], "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "WARNING:absl:No GPU/TPU found, falling back to CPU. 
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" ] } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "## Using ``vectorized_value_and_grad`` API" ], "metadata": {} }, { "cell_type": "code", "execution_count": 5, "source": [ "nlayers = 3\n", "\n", "\n", "def qml_loss(x, y, weights, nlayers):\n", "    n = 9\n", "    weights = tc.backend.cast(weights, \"complex128\")\n", "    x = tc.backend.cast(x, \"complex128\")\n", "    c = tc.Circuit(n)\n", "    # encode the 9 binary pixels as rx rotation angles\n", "    for i in range(n):\n", "        c.rx(i, theta=x[i])\n", "    # each layer: a CNOT ladder followed by trainable rx and ry rotations\n", "    for j in range(nlayers):\n", "        for i in range(n - 1):\n", "            c.cnot(i, i + 1)\n", "        for i in range(n):\n", "            c.rx(i, theta=weights[2 * j, i])\n", "            c.ry(i, theta=weights[2 * j + 1, i])\n", "    # map <Z> on qubit 4 from [-1, 1] to a probability in [0, 1]\n", "    ypred = c.expectation([tc.gates.z(), (4,)])\n", "    ypred = (tc.backend.real(ypred) + 1) / 2.0\n", "    # binary cross-entropy loss; ypred is returned as auxiliary output\n", "    return -y * tc.backend.log(ypred) - (1 - y) * tc.backend.log(1 - ypred), ypred" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 6, "source": [ "def get_qml_vvag():\n", "    # batch over x and y (args 0, 1), differentiate w.r.t. weights (arg 2)\n", "    qml_vvag = tc.backend.vectorized_value_and_grad(\n", "        qml_loss, argnums=(2,), vectorized_argnums=(0, 1), has_aux=True\n", "    )\n", "    # nlayers (arg 3) is a static Python int\n", "    qml_vvag = tc.backend.jit(qml_vvag, static_argnums=(3,))\n", "    return qml_vvag\n", "\n", "\n", "qml_vvag = get_qml_vvag()\n", "qml_vvag(x_train_tf, y_train_tf, tf.ones([nlayers * 2, 9], dtype=tf.float64), nlayers)" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "((,\n", " ),\n", " [])" ] }, "metadata": {}, "execution_count": 6 } ], "metadata": {} }, { "cell_type": "code", "execution_count": 7, "source": [ "# %timeit qml_vvag(x_train_tf, y_train_tf, tf.ones([nlayers*2, 9], dtype=tf.float64), nlayers)" ], "outputs": [], "metadata": {} }, { "cell_type": "markdown", "source": [ "### Jax Backend Compatibility" ], "metadata": {} }, { "cell_type": "code", "execution_count": 8, "source": [ "tc.set_backend(\"jax\")" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "" ] },
"metadata": {}, "execution_count": 8 } ], "metadata": {} }, { "cell_type": "code", "execution_count": 9, "source": [ "qml_vvag = get_qml_vvag()\n", "qml_vvag(\n", " x_train_jax, y_train_jax, jnp.ones([nlayers * 2, 9], dtype=np.float64), nlayers\n", ")" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "((DeviceArray([0.8433698 , 0.56257199, 0.54653163, 0.56257199, 0.82036163,\n", " 0.56257199, 0.56257199, 0.58030506, 0.82036163, 0.56257199,\n", " 0.82036163, 0.56257199, 0.82036163, 0.56257199, 0.54653163,\n", " 0.54653163, 0.56257199, 0.56257199, 0.58030506, 0.82036163,\n", " 0.54653163, 0.56257199, 0.56257199, 0.56257199, 0.56257199,\n", " 0.56257199, 0.56257199, 0.85182866, 0.56257199, 0.82036163,\n", " 0.82036163, 0.56257199, 0.8433698 , 0.56257199, 0.8433698 ,\n", " 0.56257199, 0.85182866, 0.56257199, 0.82036163, 0.54653163,\n", " 0.56257199, 0.56257199, 0.56257199, 0.56257199, 0.8433698 ,\n", " 0.58030506, 0.56257199, 0.82036163, 0.8433698 , 0.8433698 ,\n", " 0.54653163, 0.56257199, 0.82036163, 0.86501404, 0.56257199,\n", " 0.56257199, 0.8433698 , 0.56257199, 0.85182866, 0.82036163,\n", " 0.82036163, 0.56257199, 0.82036163, 0.56257199, 0.56257199,\n", " 0.56257199, 0.82036163, 0.8433698 , 0.8433698 , 0.82036163,\n", " 0.56257199, 0.56257199, 0.56257199, 0.56257199, 0.56257199,\n", " 0.54653163, 0.86501404, 0.54653163, 0.54653163, 0.82036163,\n", " 0.56257199, 0.54653163, 0.8433698 , 0.54653163, 0.8433698 ,\n", " 0.56257199, 0.56257199, 0.8433698 , 0.82036163, 0.8433698 ,\n", " 0.56257199, 0.56257199, 0.56257199, 0.56257199, 0.56257199,\n", " 0.82036163, 0.56257199, 0.58030506, 0.8433698 , 0.56257199], dtype=float64),\n", " DeviceArray([0.56974181, 0.56974181, 0.57895436, 0.56974181, 0.55972759,\n", " 0.56974181, 0.56974181, 0.55972759, 0.55972759, 0.56974181,\n", " 0.55972759, 0.56974181, 0.55972759, 0.56974181, 0.57895436,\n", " 0.57895436, 0.56974181, 0.56974181, 0.55972759, 0.55972759,\n", " 0.57895436, 0.56974181, 0.56974181, 
0.56974181, 0.56974181,\n", " 0.56974181, 0.56974181, 0.57336595, 0.56974181, 0.55972759,\n", " 0.55972759, 0.56974181, 0.56974181, 0.56974181, 0.56974181,\n", " 0.56974181, 0.57336595, 0.56974181, 0.55972759, 0.57895436,\n", " 0.56974181, 0.56974181, 0.56974181, 0.56974181, 0.56974181,\n", " 0.55972759, 0.56974181, 0.55972759, 0.56974181, 0.56974181,\n", " 0.57895436, 0.56974181, 0.55972759, 0.57895436, 0.56974181,\n", " 0.56974181, 0.56974181, 0.56974181, 0.57336595, 0.55972759,\n", " 0.55972759, 0.56974181, 0.55972759, 0.56974181, 0.56974181,\n", " 0.56974181, 0.55972759, 0.56974181, 0.56974181, 0.55972759,\n", " 0.56974181, 0.56974181, 0.56974181, 0.56974181, 0.56974181,\n", " 0.57895436, 0.57895436, 0.57895436, 0.57895436, 0.55972759,\n", " 0.56974181, 0.57895436, 0.56974181, 0.57895436, 0.56974181,\n", " 0.56974181, 0.56974181, 0.56974181, 0.55972759, 0.56974181,\n", " 0.56974181, 0.56974181, 0.56974181, 0.56974181, 0.56974181,\n", " 0.55972759, 0.56974181, 0.55972759, 0.56974181, 0.56974181], dtype=float64)),\n", " (DeviceArray([[ 5.79464357e-02, 1.12182823e-01, 8.13605755e-02,\n", " 1.52611620e-01, 1.13641690e+00, -1.41695736e+00,\n", " 7.62883290e-01, 1.01307851e-15, -1.03389519e-15],\n", " [ 3.57155554e-02, 1.61488509e-01, 7.62331819e-02,\n", " 1.50335863e-01, -1.10363460e-01, -3.23606686e-01,\n", " -3.14523756e-01, -4.09394740e-16, -1.65145675e-15],\n", " [-3.04959149e-02, 6.03271869e-02, 2.47760477e-02,\n", " -7.61859417e-02, 1.72064441e+00, 1.66891120e+00,\n", " 6.24500451e-16, 9.71445147e-16, -8.39606162e-16],\n", " [-8.26503466e-03, -6.90338030e-02, -1.07589110e-01,\n", " 1.88650816e-01, -2.68228700e+00, -2.41159987e+00,\n", " -6.66133815e-16, 5.27355937e-16, -7.56339436e-16],\n", " [-1.08246745e-15, 1.65839564e-15, -2.77555756e-16,\n", " 1.38777878e-17, 4.02608813e-01, -7.14706072e-16,\n", " -4.16333634e-17, -9.02056208e-16, -9.57567359e-16],\n", " [-1.08940634e-15, 3.26128013e-16, -5.34294831e-16,\n", " 6.93889390e-18, -1.11152817e+00, 
3.19189120e-16,\n", " -1.68615122e-15, 1.24900090e-16, 1.79717352e-15]], dtype=float64),))" ] }, "metadata": {}, "execution_count": 9 } ], "metadata": {} }, { "cell_type": "code", "execution_count": 10, "source": [ "# %timeit qml_vvag(x_train_jax, y_train_jax, jnp.ones([nlayers * 2, 9], dtype=np.float64), nlayers)" ], "outputs": [], "metadata": {} }, { "cell_type": "markdown", "source": [ "### Training Using ``tf.data``" ], "metadata": {} }, { "cell_type": "code", "execution_count": 11, "source": [ "# switch back to tensorflow and rebuild the (already jitted) vvag function\n", "tc.set_backend(\"tensorflow\")\n", "qml_vvag = get_qml_vvag()" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 12, "source": [ "mnist_data = (\n", "    tf.data.Dataset.from_tensor_slices((x_train_tf, y_train_tf))\n", "    .repeat(200)\n", "    .shuffle(100)\n", "    .batch(32)\n", ")" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 13, "source": [ "opt = tf.keras.optimizers.Adam(1e-2)\n", "w = tf.Variable(\n", "    initial_value=tf.random.normal(shape=(2 * nlayers, 9), stddev=0.5, dtype=tf.float64)\n", ")\n", "# 100 samples x 200 repeats / batch size 32 = 625 batches, so the loop ends after 625 steps\n", "for i, (xs, ys) in zip(range(2000), mnist_data):\n", "    (losses, ypreds), grad = qml_vvag(xs, ys, w, nlayers)\n", "    if i % 20 == 0:\n", "        print(tf.reduce_mean(losses))\n", "    opt.apply_gradients([(grad[0], w)])" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "tf.Tensor(0.689301607482696, shape=(), dtype=float64)\n", "tf.Tensor(0.6825438352666904, shape=(), dtype=float64)\n", "tf.Tensor(0.6815497367036047, shape=(), dtype=float64)\n", "tf.Tensor(0.6632433448327015, shape=(), dtype=float64)\n", "tf.Tensor(0.6641348270253142, shape=(), dtype=float64)\n", "tf.Tensor(0.6779914200102861, shape=(), dtype=float64)\n", "tf.Tensor(0.6550256969249619, shape=(), dtype=float64)\n", "tf.Tensor(0.6801325087248677, shape=(), dtype=float64)\n", "tf.Tensor(0.6190616725052769, shape=(), dtype=float64)\n", 
"tf.Tensor(0.6711760566099414, shape=(), dtype=float64)\n", "tf.Tensor(0.6965496746836946, shape=(), dtype=float64)\n", "tf.Tensor(0.6443036572691725, shape=(), dtype=float64)\n", "tf.Tensor(0.6060956714527996, shape=(), dtype=float64)\n", "tf.Tensor(0.6728839286340991, shape=(), dtype=float64)\n", "tf.Tensor(0.6584085272471567, shape=(), dtype=float64)\n", "tf.Tensor(0.6600981577311038, shape=(), dtype=float64)\n", "tf.Tensor(0.6581071758186605, shape=(), dtype=float64)\n", "tf.Tensor(0.6609348320181809, shape=(), dtype=float64)\n", "tf.Tensor(0.5919640703180435, shape=(), dtype=float64)\n", "tf.Tensor(0.6362392080775805, shape=(), dtype=float64)\n", "tf.Tensor(0.6844038809425064, shape=(), dtype=float64)\n", "tf.Tensor(0.6924617230085226, shape=(), dtype=float64)\n", "tf.Tensor(0.6594653043250199, shape=(), dtype=float64)\n", "tf.Tensor(0.7076707818117074, shape=(), dtype=float64)\n", "tf.Tensor(0.6730725215608222, shape=(), dtype=float64)\n", "tf.Tensor(0.6565711271336594, shape=(), dtype=float64)\n", "tf.Tensor(0.6665226844123278, shape=(), dtype=float64)\n", "tf.Tensor(0.6368469891760338, shape=(), dtype=float64)\n", "tf.Tensor(0.6499572506552256, shape=(), dtype=float64)\n", "tf.Tensor(0.6110576844713855, shape=(), dtype=float64)\n", "tf.Tensor(0.6312147945757532, shape=(), dtype=float64)\n", "tf.Tensor(0.6013772883771527, shape=(), dtype=float64)\n" ] } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "## Using ``tf.keras`` API" ], "metadata": {} }, { "cell_type": "code", "execution_count": 14, "source": [ "from tensorcircuit import keras\n", "\n", "\n", "def qml_y(x, weights, nlayers):\n", "    n = 9\n", "    weights = tc.backend.cast(weights, \"complex128\")\n", "    x = tc.backend.cast(x, \"complex128\")\n", "    c = tc.Circuit(n)\n", "    # same circuit as qml_loss, but only the prediction is returned\n", "    for i in range(n):\n", "        c.rx(i, theta=x[i])\n", "    for j in range(nlayers):\n", "        for i in range(n - 1):\n", "            c.cnot(i, i + 1)\n", "        for i in range(n):\n", "            c.rx(i, theta=weights[2 * j, i])\n", "            c.ry(i, theta=weights[2 * j + 1, i])\n", "    ypred = c.expectation([tc.gates.z(), (4,)])\n", "    ypred = (tc.backend.real(ypred) + 1) / 2.0\n", "    return ypred\n", "\n", "\n", "ql = keras.QuantumLayer(partial(qml_y, nlayers=nlayers), [(2 * nlayers, 9)])" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 15, "source": [ "# keras interface with value and grad paradigm\n", "\n", "\n", "@tf.function\n", "def my_vvag(xs, ys):\n", "    with tf.GradientTape() as tape:\n", "        ypred = ql(xs)\n", "        loss = tf.keras.losses.BinaryCrossentropy()(ys, ypred)\n", "    # gradients w.r.t. the trainable weights of the quantum layer\n", "    return loss, tape.gradient(loss, ql.variables)\n", "\n", "\n", "my_vvag(x_train_tf, y_train_tf)" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(,\n", " [])" ] }, "metadata": {}, "execution_count": 15 } ], "metadata": {} }, { "cell_type": "code", "execution_count": 16, "source": [ "# %timeit my_vvag(x_train_tf, y_train_tf)" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 17, "source": [ "# keras interface with keras training paradigm\n", "\n", "model = tf.keras.Sequential([ql])\n", "\n", "model.compile(\n", "    loss=tf.keras.losses.BinaryCrossentropy(),\n", "    optimizer=tf.keras.optimizers.Adam(0.01),\n", "    metrics=[tf.keras.metrics.BinaryAccuracy()],\n", ")\n", "\n", "model.fit(x_train_tf, y_train_tf, batch_size=32, epochs=100)" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Epoch 1/100\n", "4/4 [==============================] - 21s 8ms/step - loss: 0.7221 - binary_accuracy: 0.6016\n", "Epoch 2/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.7123 - binary_accuracy: 0.6016\n", "Epoch 3/100\n", "4/4 [==============================] - 0s 8ms/step - loss: 0.7039 - binary_accuracy: 0.6562\n", "Epoch 4/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.7009 - binary_accuracy: 0.6562\n", "Epoch 5/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6979 - 
binary_accuracy: 0.6562\n", "Epoch 6/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6957 - binary_accuracy: 0.6016\n", "Epoch 7/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6935 - binary_accuracy: 0.4922\n", "Epoch 8/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6918 - binary_accuracy: 0.6562\n", "Epoch 9/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6910 - binary_accuracy: 0.7109\n", "Epoch 10/100\n", "4/4 [==============================] - 0s 8ms/step - loss: 0.6901 - binary_accuracy: 0.5469\n", "Epoch 11/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6893 - binary_accuracy: 0.6016\n", "Epoch 12/100\n", "4/4 [==============================] - 0s 8ms/step - loss: 0.6883 - binary_accuracy: 0.6562\n", "Epoch 13/100\n", "4/4 [==============================] - 0s 8ms/step - loss: 0.6876 - binary_accuracy: 0.6016\n", "Epoch 14/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6869 - binary_accuracy: 0.5469\n", "Epoch 15/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6865 - binary_accuracy: 0.7109\n", "Epoch 16/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6858 - binary_accuracy: 0.6562\n", "Epoch 17/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6853 - binary_accuracy: 0.5469\n", "Epoch 18/100\n", "4/4 [==============================] - 0s 9ms/step - loss: 0.6847 - binary_accuracy: 0.6016\n", "Epoch 19/100\n", "4/4 [==============================] - 0s 9ms/step - loss: 0.6844 - binary_accuracy: 0.6016\n", "Epoch 20/100\n", "4/4 [==============================] - 0s 9ms/step - loss: 0.6842 - binary_accuracy: 0.5469\n", "Epoch 21/100\n", "4/4 [==============================] - 0s 9ms/step - loss: 0.6841 - binary_accuracy: 0.6016\n", "Epoch 22/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6839 - binary_accuracy: 0.7109\n", "Epoch 23/100\n", 
"4/4 [==============================] - 0s 7ms/step - loss: 0.6835 - binary_accuracy: 0.6562\n", "Epoch 24/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6829 - binary_accuracy: 0.6016\n", "Epoch 25/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6823 - binary_accuracy: 0.7109\n", "Epoch 26/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6816 - binary_accuracy: 0.6016\n", "Epoch 27/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6811 - binary_accuracy: 0.5469\n", "Epoch 28/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6805 - binary_accuracy: 0.4922\n", "Epoch 29/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6803 - binary_accuracy: 0.6562\n", "Epoch 30/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6799 - binary_accuracy: 0.5469\n", "Epoch 31/100\n", "4/4 [==============================] - 0s 6ms/step - loss: 0.6795 - binary_accuracy: 0.6016\n", "Epoch 32/100\n", "4/4 [==============================] - 0s 6ms/step - loss: 0.6793 - binary_accuracy: 0.5469\n", "Epoch 33/100\n", "4/4 [==============================] - 0s 8ms/step - loss: 0.6789 - binary_accuracy: 0.6016\n", "Epoch 34/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6785 - binary_accuracy: 0.5469\n", "Epoch 35/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6781 - binary_accuracy: 0.6562\n", "Epoch 36/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6775 - binary_accuracy: 0.5469\n", "Epoch 37/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6762 - binary_accuracy: 0.6016\n", "Epoch 38/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6752 - binary_accuracy: 0.6562\n", "Epoch 39/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6736 - binary_accuracy: 0.6016\n", "Epoch 40/100\n", "4/4 [==============================] - 0s 
7ms/step - loss: 0.6714 - binary_accuracy: 0.6562\n", "Epoch 41/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6690 - binary_accuracy: 0.6562\n", "Epoch 42/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6658 - binary_accuracy: 0.6016\n", "Epoch 43/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6637 - binary_accuracy: 0.6016\n", "Epoch 44/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6617 - binary_accuracy: 0.6562\n", "Epoch 45/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6596 - binary_accuracy: 0.6016\n", "Epoch 46/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6586 - binary_accuracy: 0.6016\n", "Epoch 47/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6571 - binary_accuracy: 0.6016\n", "Epoch 48/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6561 - binary_accuracy: 0.6562\n", "Epoch 49/100\n", "4/4 [==============================] - 0s 8ms/step - loss: 0.6549 - binary_accuracy: 0.6562\n", "Epoch 50/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6536 - binary_accuracy: 0.6562\n", "Epoch 51/100\n", "4/4 [==============================] - 0s 8ms/step - loss: 0.6536 - binary_accuracy: 0.6562\n", "Epoch 52/100\n", "4/4 [==============================] - 0s 8ms/step - loss: 0.6519 - binary_accuracy: 0.6016\n", "Epoch 53/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6516 - binary_accuracy: 0.7109\n", "Epoch 54/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6504 - binary_accuracy: 0.6016\n", "Epoch 55/100\n", "4/4 [==============================] - 0s 8ms/step - loss: 0.6500 - binary_accuracy: 0.6016\n", "Epoch 56/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6486 - binary_accuracy: 0.5469\n", "Epoch 57/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6468 - binary_accuracy: 
0.6016\n", "Epoch 58/100\n", "4/4 [==============================] - 0s 8ms/step - loss: 0.6466 - binary_accuracy: 0.7109\n", "Epoch 59/100\n", "4/4 [==============================] - 0s 8ms/step - loss: 0.6456 - binary_accuracy: 0.6562\n", "Epoch 60/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6446 - binary_accuracy: 0.7109\n", "Epoch 61/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6435 - binary_accuracy: 0.6016\n", "Epoch 62/100\n", "4/4 [==============================] - 0s 6ms/step - loss: 0.6429 - binary_accuracy: 0.7109\n", "Epoch 63/100\n", "4/4 [==============================] - 0s 6ms/step - loss: 0.6417 - binary_accuracy: 0.7109\n", "Epoch 64/100\n", "4/4 [==============================] - 0s 6ms/step - loss: 0.6432 - binary_accuracy: 0.5469\n", "Epoch 65/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6439 - binary_accuracy: 0.7109\n", "Epoch 66/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6430 - binary_accuracy: 0.6016\n", "Epoch 67/100\n", "4/4 [==============================] - 0s 6ms/step - loss: 0.6415 - binary_accuracy: 0.5469\n", "Epoch 68/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6391 - binary_accuracy: 0.6562\n", "Epoch 69/100\n", "4/4 [==============================] - 0s 6ms/step - loss: 0.6375 - binary_accuracy: 0.6016\n", "Epoch 70/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6372 - binary_accuracy: 0.6016\n", "Epoch 71/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6369 - binary_accuracy: 0.5469\n", "Epoch 72/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6367 - binary_accuracy: 0.6016\n", "Epoch 73/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6380 - binary_accuracy: 0.7109\n", "Epoch 74/100\n", "4/4 [==============================] - 0s 8ms/step - loss: 0.6377 - binary_accuracy: 0.6562\n", "Epoch 75/100\n", "4/4 
[==============================] - 0s 8ms/step - loss: 0.6365 - binary_accuracy: 0.6562\n", "Epoch 76/100\n", "4/4 [==============================] - 0s 8ms/step - loss: 0.6350 - binary_accuracy: 0.6562\n", "Epoch 77/100\n", "4/4 [==============================] - 0s 8ms/step - loss: 0.6331 - binary_accuracy: 0.6016\n", "Epoch 78/100\n", "4/4 [==============================] - 0s 8ms/step - loss: 0.6331 - binary_accuracy: 0.6562\n", "Epoch 79/100\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "4/4 [==============================] - 0s 7ms/step - loss: 0.6337 - binary_accuracy: 0.5469\n", "Epoch 80/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6319 - binary_accuracy: 0.6562\n", "Epoch 81/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6317 - binary_accuracy: 0.7109\n", "Epoch 82/100\n", "4/4 [==============================] - 0s 6ms/step - loss: 0.6312 - binary_accuracy: 0.6562\n", "Epoch 83/100\n", "4/4 [==============================] - 0s 6ms/step - loss: 0.6307 - binary_accuracy: 0.6562\n", "Epoch 84/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6326 - binary_accuracy: 0.6016\n", "Epoch 85/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6307 - binary_accuracy: 0.6016\n", "Epoch 86/100\n", "4/4 [==============================] - 0s 6ms/step - loss: 0.6299 - binary_accuracy: 0.6016\n", "Epoch 87/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6288 - binary_accuracy: 0.6016\n", "Epoch 88/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6288 - binary_accuracy: 0.7109\n", "Epoch 89/100\n", "4/4 [==============================] - 0s 6ms/step - loss: 0.6289 - binary_accuracy: 0.6562\n", "Epoch 90/100\n", "4/4 [==============================] - 0s 8ms/step - loss: 0.6273 - binary_accuracy: 0.6562\n", "Epoch 91/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6275 - binary_accuracy: 0.5469\n", "Epoch 
92/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6269 - binary_accuracy: 0.7109\n", "Epoch 93/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6269 - binary_accuracy: 0.6016\n", "Epoch 94/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6263 - binary_accuracy: 0.6016\n", "Epoch 95/100\n", "4/4 [==============================] - 0s 6ms/step - loss: 0.6258 - binary_accuracy: 0.6016\n", "Epoch 96/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6256 - binary_accuracy: 0.6562\n", "Epoch 97/100\n", "4/4 [==============================] - 0s 7ms/step - loss: 0.6250 - binary_accuracy: 0.6562\n", "Epoch 98/100\n", "4/4 [==============================] - 0s 6ms/step - loss: 0.6246 - binary_accuracy: 0.7109\n", "Epoch 99/100\n", "4/4 [==============================] - 0s 6ms/step - loss: 0.6240 - binary_accuracy: 0.6562\n", "Epoch 100/100\n", "4/4 [==============================] - 0s 6ms/step - loss: 0.6251 - binary_accuracy: 0.5469\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": {}, "execution_count": 17 } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "### Quantum-Classical Hybrid Model in Keras" ], "metadata": {} }, { "cell_type": "code", "execution_count": 18, "source": [ "def qml_ys(x, weights, nlayers):\n", "    n = 9\n", "    weights = tc.backend.cast(weights, \"complex128\")\n", "    x = tc.backend.cast(x, \"complex128\")\n", "    c = tc.Circuit(n)\n", "    for i in range(n):\n", "        c.rx(i, theta=x[i])\n", "    for j in range(nlayers):\n", "        for i in range(n - 1):\n", "            c.cnot(i, i + 1)\n", "        for i in range(n):\n", "            c.rx(i, theta=weights[2 * j, i])\n", "            c.ry(i, theta=weights[2 * j + 1, i])\n", "    # read out (<Z_i> + 1) / 2 on every qubit as 9 features for the classical head\n", "    ypreds = []\n", "    for i in range(n):\n", "        ypred = c.expectation([tc.gates.z(), (i,)])\n", "        ypred = (tc.backend.real(ypred) + 1) / 2.0\n", "        ypreds.append(ypred)\n", "    return tc.backend.stack(ypreds)" ], "outputs": 
[], "metadata": {} }, { "cell_type": "code", "execution_count": 19, "source": [ "ql = tc.keras.QuantumLayer(partial(qml_ys, nlayers=nlayers), [(2 * nlayers, 9)])\n", "model = tf.keras.Sequential([ql, tf.keras.layers.Dense(1, activation=\"sigmoid\")])" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 20, "source": [ "model.compile(\n", " loss=tf.keras.losses.BinaryCrossentropy(),\n", " optimizer=tf.keras.optimizers.Adam(0.01),\n", " metrics=[tf.keras.metrics.BinaryAccuracy()],\n", ")\n", "\n", "model.fit(x_train_tf, y_train_tf, batch_size=32, epochs=100)" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Epoch 1/100\n", "4/4 [==============================] - 24s 14ms/step - loss: 0.9307 - binary_accuracy: 0.3700\n", "Epoch 2/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.8286 - binary_accuracy: 0.3700\n", "Epoch 3/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.7538 - binary_accuracy: 0.3700\n", "Epoch 4/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.7044 - binary_accuracy: 0.3700\n", "Epoch 5/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6796 - binary_accuracy: 0.6300\n", "Epoch 6/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6599 - binary_accuracy: 0.6300\n", "Epoch 7/100\n", "4/4 [==============================] - 0s 13ms/step - loss: 0.6543 - binary_accuracy: 0.6300\n", "Epoch 8/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.6559 - binary_accuracy: 0.6300\n", "Epoch 9/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6575 - binary_accuracy: 0.6300\n", "Epoch 10/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6588 - binary_accuracy: 0.6300\n", "Epoch 11/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6587 - binary_accuracy: 0.6300\n", "Epoch 12/100\n", "4/4 [==============================] - 0s 14ms/step 
- loss: 0.6567 - binary_accuracy: 0.6300\n", "Epoch 13/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6551 - binary_accuracy: 0.6300\n", "Epoch 14/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6540 - binary_accuracy: 0.6300\n", "Epoch 15/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6528 - binary_accuracy: 0.6300\n", "Epoch 16/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.6533 - binary_accuracy: 0.6300\n", "Epoch 17/100\n", "4/4 [==============================] - 0s 13ms/step - loss: 0.6540 - binary_accuracy: 0.6300\n", "Epoch 18/100\n", "4/4 [==============================] - 0s 13ms/step - loss: 0.6550 - binary_accuracy: 0.6300\n", "Epoch 19/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.6546 - binary_accuracy: 0.6300\n", "Epoch 20/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6538 - binary_accuracy: 0.6300\n", "Epoch 21/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.6513 - binary_accuracy: 0.6300\n", "Epoch 22/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.6499 - binary_accuracy: 0.6300\n", "Epoch 23/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6497 - binary_accuracy: 0.6300\n", "Epoch 24/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.6491 - binary_accuracy: 0.6300\n", "Epoch 25/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6492 - binary_accuracy: 0.6300\n", "Epoch 26/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6493 - binary_accuracy: 0.6300\n", "Epoch 27/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.6487 - binary_accuracy: 0.6300\n", "Epoch 28/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6473 - binary_accuracy: 0.6300\n", "Epoch 29/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6470 - 
binary_accuracy: 0.6300\n", "Epoch 30/100\n", "4/4 [==============================] - 0s 16ms/step - loss: 0.6462 - binary_accuracy: 0.6300\n", "Epoch 31/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6451 - binary_accuracy: 0.6300\n", "Epoch 32/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.6439 - binary_accuracy: 0.6300\n", "Epoch 33/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6433 - binary_accuracy: 0.6300\n", "Epoch 34/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6438 - binary_accuracy: 0.6300\n", "Epoch 35/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.6413 - binary_accuracy: 0.6300\n", "Epoch 36/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6403 - binary_accuracy: 0.6300\n", "Epoch 37/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.6391 - binary_accuracy: 0.6300\n", "Epoch 38/100\n", "4/4 [==============================] - 0s 16ms/step - loss: 0.6388 - binary_accuracy: 0.6300\n", "Epoch 39/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.6379 - binary_accuracy: 0.6300\n", "Epoch 40/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.6365 - binary_accuracy: 0.6300\n", "Epoch 41/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.6352 - binary_accuracy: 0.6300\n", "Epoch 42/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.6336 - binary_accuracy: 0.6300\n", "Epoch 43/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.6338 - binary_accuracy: 0.6300\n", "Epoch 44/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.6358 - binary_accuracy: 0.6300\n", "Epoch 45/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6367 - binary_accuracy: 0.6300\n", "Epoch 46/100\n", "4/4 [==============================] - 0s 13ms/step - loss: 0.6345 - binary_accuracy: 0.6300\n", 
"Epoch 47/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.6303 - binary_accuracy: 0.6300\n", "Epoch 48/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6298 - binary_accuracy: 0.6300\n", "Epoch 49/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6285 - binary_accuracy: 0.6300\n", "Epoch 50/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6280 - binary_accuracy: 0.6300\n", "Epoch 51/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6274 - binary_accuracy: 0.6300\n", "Epoch 52/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.6268 - binary_accuracy: 0.6300\n", "Epoch 53/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6262 - binary_accuracy: 0.6300\n", "Epoch 54/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6246 - binary_accuracy: 0.6300\n", "Epoch 55/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.6231 - binary_accuracy: 0.6300\n", "Epoch 56/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6228 - binary_accuracy: 0.6300\n", "Epoch 57/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6226 - binary_accuracy: 0.6300\n", "Epoch 58/100\n", "4/4 [==============================] - 0s 13ms/step - loss: 0.6224 - binary_accuracy: 0.6300\n", "Epoch 59/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6228 - binary_accuracy: 0.6900\n", "Epoch 60/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.6224 - binary_accuracy: 0.7200\n", "Epoch 61/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6214 - binary_accuracy: 0.7200\n", "Epoch 62/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6183 - binary_accuracy: 0.6300\n", "Epoch 63/100\n", "4/4 [==============================] - 0s 17ms/step - loss: 0.6161 - binary_accuracy: 0.6300\n", "Epoch 64/100\n", "4/4 
[==============================] - 0s 13ms/step - loss: 0.6142 - binary_accuracy: 0.6300\n", "Epoch 65/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6131 - binary_accuracy: 0.6300\n", "Epoch 66/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6124 - binary_accuracy: 0.6300\n", "Epoch 67/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6101 - binary_accuracy: 0.6300\n", "Epoch 68/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.6117 - binary_accuracy: 0.6600\n", "Epoch 69/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6099 - binary_accuracy: 0.7200\n", "Epoch 70/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.6085 - binary_accuracy: 0.7200\n", "Epoch 71/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.6070 - binary_accuracy: 0.7200\n", "Epoch 72/100\n", "4/4 [==============================] - 0s 13ms/step - loss: 0.6069 - binary_accuracy: 0.7200\n", "Epoch 73/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6060 - binary_accuracy: 0.7200\n", "Epoch 74/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6040 - binary_accuracy: 0.7200\n", "Epoch 75/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6041 - binary_accuracy: 0.7200\n", "Epoch 76/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.6011 - binary_accuracy: 0.6900\n", "Epoch 77/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.6005 - binary_accuracy: 0.6300\n", "Epoch 78/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.5993 - binary_accuracy: 0.6300\n", "Epoch 79/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.5987 - binary_accuracy: 0.6300\n", "Epoch 80/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.5982 - binary_accuracy: 0.6300\n", "Epoch 81/100\n", "4/4 
[==============================] - 0s 15ms/step - loss: 0.5970 - binary_accuracy: 0.6300\n", "Epoch 82/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.5960 - binary_accuracy: 0.6900\n", "Epoch 83/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.5937 - binary_accuracy: 0.7200\n", "Epoch 84/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.5932 - binary_accuracy: 0.7200\n", "Epoch 85/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.5913 - binary_accuracy: 0.7200\n", "Epoch 86/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.5904 - binary_accuracy: 0.7200\n", "Epoch 87/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.5895 - binary_accuracy: 0.7200\n", "Epoch 88/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.5882 - binary_accuracy: 0.7200\n", "Epoch 89/100\n", "4/4 [==============================] - 0s 13ms/step - loss: 0.5873 - binary_accuracy: 0.7200\n", "Epoch 90/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.5858 - binary_accuracy: 0.7200\n", "Epoch 91/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.5848 - binary_accuracy: 0.7200\n", "Epoch 92/100\n", "4/4 [==============================] - 0s 16ms/step - loss: 0.5839 - binary_accuracy: 0.7200\n", "Epoch 93/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.5832 - binary_accuracy: 0.7200\n", "Epoch 94/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.5836 - binary_accuracy: 0.7200\n", "Epoch 95/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.5848 - binary_accuracy: 0.7200\n", "Epoch 96/100\n", "4/4 [==============================] - 0s 15ms/step - loss: 0.5836 - binary_accuracy: 0.7200\n", "Epoch 97/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.5812 - binary_accuracy: 0.7200\n", "Epoch 98/100\n", "4/4 
[==============================] - 0s 14ms/step - loss: 0.5795 - binary_accuracy: 0.7200\n", "Epoch 99/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.5780 - binary_accuracy: 0.7200\n", "Epoch 100/100\n", "4/4 [==============================] - 0s 14ms/step - loss: 0.5768 - binary_accuracy: 0.7200\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": {}, "execution_count": 20 } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "### Hybrid Model in JAX" ], "metadata": {} }, { "cell_type": "code", "execution_count": 21, "source": [ "tc.set_backend(\"jax\")" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": {}, "execution_count": 21 } ], "metadata": {} }, { "cell_type": "code", "execution_count": 22, "source": [ "# split the PRNG key into three subkeys, one per parameter group\n", "key = jax.random.PRNGKey(42)\n", "key, *subkeys = jax.random.split(key, num=4)\n", "# circuit weights plus a classical linear readout (weights w, bias b)\n", "params = {\n", " \"qweights\": jax.random.normal(subkeys[0], shape=[nlayers * 2, 9]),\n", " \"cweights:w\": jax.random.normal(subkeys[1], shape=[9]),\n", " \"cweights:b\": jax.random.normal(subkeys[2], shape=[1]),\n", "}" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 23, "source": [ "def qml_hybrid_loss(x, y, params, nlayers):\n", " weights = params[\"qweights\"]\n", " w = params[\"cweights:w\"]\n", " b = params[\"cweights:b\"]\n", " # circuit outputs for a single sample, reshaped to a column vector [9, 1]\n", " ypred = qml_ys(x, weights, nlayers)\n", " ypred = tc.backend.reshape(ypred, [-1, 1])\n", " # classical linear layer followed by a sigmoid gives the predicted probability\n", " ypred = w @ ypred + b\n", " ypred = jax.nn.sigmoid(ypred)\n", " ypred = ypred[0]\n", " # binary cross-entropy loss\n", " loss = -y * tc.backend.log(ypred) - (1 - y) * tc.backend.log(1 - ypred)\n", " return loss" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 24, "source": [ "# vectorize over the batch (x, y), differentiate w.r.t. params, jit with nlayers static\n", "qml_hybrid_loss_vag = tc.backend.jit(\n", " tc.backend.vvag(qml_hybrid_loss, vectorized_argnums=(0, 1), argnums=2),\n", " static_argnums=3,\n", ")" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 25, "source": [
"qml_hybrid_loss_vag(x_train_jax, y_train_jax, params, nlayers)" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(DeviceArray([3.73282398, 0.02421603, 0.02899787, 0.02421603, 4.08996787,\n", " 0.03069481, 0.02421603, 0.01688146, 4.08996787, 0.03069481,\n", " 4.08996787, 0.02421603, 4.08996787, 0.02421603, 0.02899787,\n", " 0.03354042, 0.02421603, 0.02421603, 0.01688146, 4.08996787,\n", " 0.03354042, 0.02421603, 0.02421603, 0.03069481, 0.02421603,\n", " 0.02421603, 0.03069481, 3.73798651, 0.02421603, 3.68810189,\n", " 4.08996787, 0.03069481, 3.73282398, 0.03069481, 3.73282398,\n", " 0.02421603, 3.49674264, 0.02421603, 4.08996787, 0.02899787,\n", " 0.02421603, 0.02421603, 0.03069481, 0.03069481, 3.73282398,\n", " 0.02533775, 0.03069481, 3.68810189, 3.73282398, 3.49896983,\n", " 0.02899787, 0.03069481, 4.08996787, 3.41172721, 0.02421603,\n", " 0.02421603, 3.73282398, 0.02421603, 3.73798651, 3.68810189,\n", " 4.08996787, 0.03069481, 4.08996787, 0.02421603, 0.03069481,\n", " 0.02421603, 3.68810189, 3.49896983, 3.49896983, 4.08996787,\n", " 0.02421603, 0.02421603, 0.02421603, 0.02421603, 0.03069481,\n", " 0.02899787, 3.41172721, 0.03354042, 0.02899787, 3.68810189,\n", " 0.02421603, 0.03354042, 3.73282398, 0.02899787, 3.73282398,\n", " 0.03069481, 0.02421603, 3.73282398, 3.68810189, 3.73282398,\n", " 0.02421603, 0.02421603, 0.03069481, 0.03069481, 0.02421603,\n", " 4.08996787, 0.02421603, 0.01688146, 3.73282398, 0.02421603], dtype=float64),\n", " {'cweights:b': DeviceArray([34.49476789], dtype=float64),\n", " 'cweights:w': DeviceArray([16.81782277, 15.05718878, 15.02498328, 23.18351696,\n", " 17.01897109, 16.13466029, 16.26046722, 23.54180309,\n", " 12.0721068 ], dtype=float64),\n", " 'qweights': DeviceArray([[-1.16993912e+01, -6.74730815e+00, -2.27227872e+00,\n", " -1.08703899e+00, 2.56625721e+00, 1.69462223e+00,\n", " -4.89847061e+00, 1.62487935e+00, 1.02424785e+01],\n", " [ 3.29984130e+00, -5.90635608e-01, 2.11407610e+00,\n", " 
3.67096431e-02, 3.32526833e+00, -1.06468920e+00,\n", " -4.12299772e-01, -7.78105081e+00, -3.38506241e+00],\n", " [-3.59434442e+00, 3.84548015e+00, 8.50409406e-01,\n", " -2.66504333e+00, 1.47559967e+00, 1.38536529e+00,\n", " -1.47291602e-01, -7.32213541e+00, 5.17021200e+00],\n", " [-1.30975045e+00, 1.83003338e+00, 1.51443252e+00,\n", " 3.15082430e+00, -4.41767236e+00, 6.25968228e+00,\n", " 5.96980281e+00, 9.67198061e+00, -1.63091455e+01],\n", " [-2.24757712e+00, -5.66276080e-01, -1.67376432e+00,\n", " 1.75249049e-01, 2.77917505e-01, 3.84402979e-02,\n", " 1.03434679e-01, -4.05760762e-02, -3.33671956e-03],\n", " [ 3.13599600e+00, 3.85470136e+00, 3.17986238e-01,\n", " 1.72308312e-01, 5.09749793e+00, 2.90706770e-02,\n", " -5.59919189e-01, 1.96734688e+00, -6.96372626e-01]], dtype=float64)})" ] }, "metadata": {}, "execution_count": 25 } ], "metadata": {} }, { "cell_type": "code", "execution_count": 26, "source": [ "optimizer = optax.adam(5e-3)\n", "opt_state = optimizer.init(params)\n", "for i, (xs, ys) in zip(range(2000), mnist_data): # using tf data loader here\n", " xs = xs.numpy()\n", " ys = ys.numpy()\n", " v, grads = qml_hybrid_loss_vag(xs, ys, params, nlayers)\n", " updates, opt_state = optimizer.update(grads, opt_state)\n", " params = optax.apply_updates(params, updates)\n", " if i % 30 == 0:\n", " print(jnp.mean(v))" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "1.2979572281332594\n", "0.8331012068009501\n", "0.6805939758448183\n", "0.5897353928152392\n", "0.6460840124038746\n", "0.6093143713632384\n", "0.6671721223530598\n", "0.5863347320393952\n", "0.5465362554431986\n", "0.5594138744621404\n", "0.5493311423294576\n", "0.5228166702417829\n", "0.6176455570797168\n", "0.5256494465741394\n", "0.5359881696740493\n", "0.5787532611935906\n", "0.49082340457493323\n", "0.4062487079116086\n", "0.5802733401377229\n", "0.4762524476616207\n", "0.5404245247888219\n" ] } ], "metadata": {} } ], "metadata": { "kernelspec": { "display_name": "Python 3 
(ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.0" } }, "nbformat": 4, "nbformat_minor": 5 }