Skip to content

Commit

Permalink
Merge branch 'main' of github.com:EQuS/qcsys
Browse files Browse the repository at this point in the history
  • Loading branch information
Phionx committed Aug 11, 2024
2 parents 0135186 + c3e0a4d commit bd03bf0
Show file tree
Hide file tree
Showing 10 changed files with 347 additions and 278 deletions.
2 changes: 0 additions & 2 deletions qcsys/devices/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
from .kno import *
from .transmon import *
from .tunable_transmon import *
from .transmon_single_charge_basis import *
from .squid_transmon import *
from .fluxonium import *
from .ats import *
from .ideal_qubit import *
1 change: 1 addition & 0 deletions qcsys/devices/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
class BasisTypes(str, Enum):
fock = "fock"
charge = "charge"
single_charge = "single_charge"

@classmethod
def from_str(cls, string: str):
Expand Down
15 changes: 8 additions & 7 deletions qcsys/devices/drive.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from jax import tree_util, Array
from jax import config
import jax.numpy as jnp
import jaxquantum as jqt

config.update("jax_enable_x64", True)

Expand All @@ -31,20 +32,20 @@ def label(self):
def ops(self):
return self.common_ops()

def common_ops(self) -> Dict[str, Array]:
def common_ops(self) -> Dict[str, jqt.Qarray]:
ops = {}

M_max = self.M_max

# Construct M = ∑ₘ m|m><m| operator in drive charge basis
ops["M"] = jnp.diag(jnp.arange(-M_max, M_max + 1))
ops["M"] = jqt.jnp2jqt(jnp.diag(jnp.arange(-M_max, M_max + 1)))

# Construct Id = ∑ₘ|m><m| in the drive charge basis
ops["Id_drive"] = jnp.identity(2 * M_max + 1)
ops["id"] = jqt.jnp2jqt(jnp.identity(2 * M_max + 1))

# Construct M₊ ≡ exp(iθ) and M₋ ≡ exp(-iθ) operators for drive
ops["M-"] = jnp.eye(2 * M_max + 1, k=1)
ops["M+"] = jnp.eye(2 * M_max + 1, k=-1)
ops["M-"] = jqt.jnp2jqt(jnp.eye(2 * M_max + 1, k=1))
ops["M+"] = jqt.jnp2jqt(jnp.eye(2 * M_max + 1, k=-1))

# Construct cos(θ) ≡ 1/2 * [M₊ + M₋] = 1/2 * ∑ₘ|m+1><m| + h.c
ops["cos(θ)"] = 0.5 * (ops["M+"] + ops["M-"])
Expand All @@ -54,8 +55,8 @@ def common_ops(self) -> Dict[str, Array]:

# Construct more general drive operators cos(kθ) and sin(kθ)
for k in range(2, M_max + 1):
ops[f"M_+{k}"] = jnp.eye(2 * M_max + 1, k=-k)
ops[f"M_-{k}"] = jnp.eye(2 * M_max + 1, k=k)
ops[f"M_+{k}"] = jqt.jnp2jqt(jnp.eye(2 * M_max + 1, k=-k))
ops[f"M_-{k}"] = jqt.jnp2jqt(jnp.eye(2 * M_max + 1, k=k))
ops[f"cos({k}θ)"] = 0.5 * (ops[f"M_+{k}"] + ops[f"M_-{k}"])
ops[f"sin({k}θ)"] = -0.5j * (ops[f"M_+{k}"] - ops[f"M_-{k}"])

Expand Down
26 changes: 18 additions & 8 deletions qcsys/devices/ideal_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from jax import config
import jaxquantum as jqt

from .base import Device
from .base import Device, BasisTypes, HamiltonianTypes


config.update("jax_enable_x64", True)
Expand All @@ -16,30 +16,40 @@ class IdealQubit(Device):
Ideal qubit Device.
"""

@classmethod
def param_validation(cls, N, N_pre_diag, params, hamiltonian, basis):
""" This can be overridden by subclasses."""
assert basis == BasisTypes.fock, "IdealQubit is a two-level system defined in the Fock basis."
assert hamiltonian == HamiltonianTypes.full, "IdealQubit requires a full Hamiltonian."
assert N == N_pre_diag == 2, "IdealQubit is a two-level system."
assert "ω" in params, "IdealQubit requires a frequency parameter 'ω'."

def common_ops(self):
""" Written in the linear basis. """
"""Written in the linear basis."""
ops = {}

assert self.N_pre_diag == 2
assert self.N == 2

N = self.N_pre_diag
ops["id"] = jqt.identity(N)
ops["sigma_z"] = jqt.sigmaz()
ops["sigma_x"] = jqt.sigmax()
ops["sigmaz"] = jqt.sigmaz()
ops["sigmax"] = jqt.sigmax()
ops["sigmay"] = jqt.sigmay()
ops["sigmam"] = jqt.sigmam()
ops["sigmap"] = jqt.sigmap()

return ops

def get_linear_ω(self):
"""Get frequency of linear terms."""
return self.params["frequency"]
return self.params["ω"]

def get_H_linear(self):
"""Return linear terms in H."""
w = self.get_linear_ω()
return w * self.linear_ops["sigma_z"]
return (w / 2) * self.linear_ops["sigma_z"]

def get_H_full(self):
"""Return full H in linear basis."""
H = self.get_H_linear()
return H
return self.get_H_linear()
12 changes: 8 additions & 4 deletions qcsys/devices/kno.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
""" Kerr Nonlinear Oscillator """

from flax import struct
from jax import config
import jaxquantum as jqt
import jax.numpy as jnp

from qcsys.devices.base import Device
from qcsys.devices.base import Device, BasisTypes, HamiltonianTypes

config.update("jax_enable_x64", True)

Expand All @@ -16,9 +17,12 @@ class KNO(Device):
"""

@classmethod
def create(cls, N, params, label=0, use_linear=False):
return cls(N, params, label, use_linear)

def param_validation(cls, N, N_pre_diag, params, hamiltonian, basis):
""" This can be overridden by subclasses."""
assert basis == BasisTypes.fock, "Kerr Nonlinear Oscillator must be defined in the Fock basis."
assert hamiltonian == HamiltonianTypes.full, "Kerr Nonlinear Oscillator uses a full Hamiltonian."
assert "ω" in params and "α" in params, "Kerr Nonlinear Oscillator requires frequency 'ω' and anharmonicity 'α' as parameters."

def common_ops(self):
ops = {}

Expand Down
133 changes: 0 additions & 133 deletions qcsys/devices/squid_transmon.py

This file was deleted.

40 changes: 34 additions & 6 deletions qcsys/devices/transmon.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ def param_validation(cls, N, N_pre_diag, params, hamiltonian, basis):
elif hamiltonian == HamiltonianTypes.truncated:
assert basis == BasisTypes.fock, "Truncated Hamiltonian only works with Fock basis."
elif hamiltonian == HamiltonianTypes.full:
assert basis == BasisTypes.charge, "Full Hamiltonian only works with charge basis."
assert basis in [BasisTypes.charge, BasisTypes.single_charge], "Full Hamiltonian only works with Cooper pair charge or single-electron charge bases."

# Set the gate offset charge to zero if not provided
if "ng" not in params:
params["ng"] = 0.0

assert (N_pre_diag - 1) % 2 == 0, "N_pre_diag must be odd."

Expand All @@ -44,11 +48,37 @@ def common_ops(self):
ops["n"] = 1j * self.n_zpf() * (ops["a_dag"] - ops["a"])

elif self.basis == BasisTypes.charge:
"""
Here H = 4 * Ec (n - ng)² - Ej cos(φ) in the Cooper pair charge basis.
"""
ops["id"] = jqt.identity(N)
ops["cos(φ)"] = 0.5*(jqt.jnp2jqt(jnp.eye(N,k=1) + jnp.eye(N,k=-1)))
ops["sin(φ)"] = 0.5j*(jqt.jnp2jqt(jnp.eye(N,k=1) - jnp.eye(N,k=-1)))
n_max = (N - 1) // 2
ops["n"] = jqt.jnp2jqt(jnp.diag(jnp.arange(-n_max, n_max + 1)))

n_minus_ng_array = jnp.arange(-n_max, n_max + 1) - self.params["ng"] * jnp.ones(N)
ops["H_charge"] = jqt.jnp2jqt(jnp.diag(4 * self.params["Ec"] * n_minus_ng_array**2))

elif self.basis == BasisTypes.single_charge:
"""
Here H = Ec (n - 2ng)² - Ej cos(φ) in the single-electron charge basis. Using Eq. (5.36) of Kyle Serniak's
thesis, we have H = Ec ∑ₙ(n - 2*ng) |n⟩⟨n| - Ej/2 * ∑ₙ|n⟩⟨n+2| + h.c where n counts the number of electrons,
not Cooper pairs. Note, we use 2ng instead of ng to match the gate offset charge convention of the transmon
(as done in Kyle's thesis).
"""
n_max = (N - 1) // 2

ops["id"] = jqt.identity(N)
ops["cos(φ)"] = 0.5*(jqt.jnp2jqt(jnp.eye(N,k=2) + jnp.eye(N,k=-2)))
ops["sin(φ)"] = 0.5j*(jqt.jnp2jqt(jnp.eye(N,k=2) - jnp.eye(N,k=-2)))
ops["cos(φ/2)"] = 0.5*(jqt.jnp2jqt(jnp.eye(N,k=1) + jnp.eye(N,k=-1)))
ops["sin(φ/2)"] = 0.5j*(jqt.jnp2jqt(jnp.eye(N,k=1) - jnp.eye(N,k=-1)))
ops["n"] = jqt.jnp2jqt(jnp.diag(jnp.arange(-n_max, n_max + 1)))

n_minus_ng_array = jnp.arange(-n_max, n_max + 1) - 2 * self.params["ng"] * jnp.ones(N)
ops["H_charge"] = jqt.jnp2jqt(jnp.diag(self.params["Ec"] * n_minus_ng_array**2))

return ops

@property
Expand All @@ -74,10 +104,7 @@ def get_H_linear(self):

def get_H_full(self):
"""Return full H in specified basis."""

cos_phi_op = self.original_ops["cos(φ)"]
n_op = self.original_ops["n"]
return 4*self.params["Ec"]*n_op@n_op - self.Ej * cos_phi_op
return self.original_ops["H_charge"] - self.Ej * self.original_ops["cos(φ)"]

def get_H_truncated(self):
"""Return truncated H in specified basis."""
Expand All @@ -95,7 +122,6 @@ def _get_H_in_original_basis(self):
return self.get_H_full()
elif self.hamiltonian == HamiltonianTypes.truncated:
return self.get_H_truncated()


def potential(self, phi):
"""Return potential energy for a given phi."""
Expand All @@ -115,6 +141,8 @@ def calculate_wavefunctions(self, phi_vals):

if self.basis == BasisTypes.fock:
return super().calculate_wavefunctions(phi_vals)
elif self.basis == BasisTypes.single_charge:
raise NotImplementedError("Wavefunctions for single charge basis not yet implemented.")
elif self.basis == BasisTypes.charge:
phi_vals = jnp.array(phi_vals)

Expand Down
Loading

0 comments on commit bd03bf0

Please sign in to comment.