From ad05352b52f41fbf32b8f21b687c1c8e4c792e85 Mon Sep 17 00:00:00 2001 From: Phionx Date: Tue, 30 Jul 2024 22:44:18 -0400 Subject: [PATCH] removing jit since we don't want to vmap over anything that is jitted --- qcsys/common/utils.py | 3 +-- qcsys/devices/drive.py | 2 +- qcsys/devices/squid_transmon.py | 2 -- qcsys/devices/system.py | 3 +-- qcsys/devices/transmon.py | 1 - qcsys/devices/transmon_single_charge_basis.py | 2 -- qcsys/devices/truncated_transmon.py | 1 - 7 files changed, 3 insertions(+), 11 deletions(-) diff --git a/qcsys/common/utils.py b/qcsys/common/utils.py index b48197d..98c477c 100644 --- a/qcsys/common/utils.py +++ b/qcsys/common/utils.py @@ -2,11 +2,10 @@ from scipy.special import pbdv from scipy import constants -from jax import jit, vmap, grad +from jax import vmap, grad import jax.scipy as jsp -@jit def factorial_approx(n): return jsp.special.gamma(n+1) diff --git a/qcsys/devices/drive.py b/qcsys/devices/drive.py index 18f444a..7b23381 100644 --- a/qcsys/devices/drive.py +++ b/qcsys/devices/drive.py @@ -4,7 +4,7 @@ from typing import Dict, Any from flax import struct -from jax import tree_util, jit, Array +from jax import tree_util, Array from jax import config import jax.numpy as jnp diff --git a/qcsys/devices/squid_transmon.py b/qcsys/devices/squid_transmon.py index 5bb9d96..a8f92b3 100644 --- a/qcsys/devices/squid_transmon.py +++ b/qcsys/devices/squid_transmon.py @@ -9,7 +9,6 @@ import jaxquantum as jqt import jax.numpy as jnp import jax.scipy as jsp -from jax import jit from qcsys.devices.base import Device @@ -125,7 +124,6 @@ def get_H_full(self): - 2 * self.params["Ej"] * jnp.cos(phi_ext / 2) * self.linear_ops["cos(φ)"] ) - @jit def get_op_in_H_eigenbasis(self, op): """ We overwrite this function to effectively truncate to the first N levels out of N_max_charge diff --git a/qcsys/devices/system.py b/qcsys/devices/system.py index 589941a..c05f738 100644 --- a/qcsys/devices/system.py +++ b/qcsys/devices/system.py @@ -5,7 +5,7 @@ import math from flax import struct -from jax import jit, vmap, Array +from jax import vmap, Array from jax import config import jaxquantum as jqt import jax.numpy as jnp @@ -16,7 +16,6 @@ config.update("jax_enable_x64", True) -@partial(jit, static_argnums=(0,)) def calculate_eig(Ns, H: jqt.Qarray): N_tot = math.prod(Ns) edxs = jnp.arange(N_tot) diff --git a/qcsys/devices/transmon.py b/qcsys/devices/transmon.py index 7fd798c..6302d8a 100644 --- a/qcsys/devices/transmon.py +++ b/qcsys/devices/transmon.py @@ -5,7 +5,6 @@ import jaxquantum as jqt import jax.numpy as jnp import jax.scipy as jsp -from jax import jit from qcsys.devices.base import BasisTypes, FluxDevice, HamiltonianTypes diff --git a/qcsys/devices/transmon_single_charge_basis.py b/qcsys/devices/transmon_single_charge_basis.py index 508da42..f1f6862 100644 --- a/qcsys/devices/transmon_single_charge_basis.py +++ b/qcsys/devices/transmon_single_charge_basis.py @@ -9,7 +9,6 @@ import jaxquantum as jqt import jax.numpy as jnp import jax.scipy as jsp -from jax import jit from qcsys.devices.base import Device @@ -110,7 +109,6 @@ def get_H_full(self): - self.params["Ej"] * self.linear_ops["cos(φ)"] ) - @jit def get_op_in_H_eigenbasis(self, op): """ We overwrite this function to effectively truncate to the first N levels out of N_max_charge diff --git a/qcsys/devices/truncated_transmon.py b/qcsys/devices/truncated_transmon.py index 14afd30..f2cd9ca 100644 --- a/qcsys/devices/truncated_transmon.py +++ b/qcsys/devices/truncated_transmon.py @@ -5,7 +5,6 @@ import jaxquantum as jqt import jax.numpy as jnp import jax.scipy as jsp -from jax import jit from qcsys.devices.base import FluxDevice