Training a Quantum-Classical Neural Network with Qiskit Runtime and AWS Batch
The script below trains a hybrid neural network model that classifies images of dogs and cats.
To run the script, first `pip install` the following requirements:
covalent==0.209.1
covalent-aws-plugins==0.13.0
covalent-awsbatch-plugin==0.26.0
matplotlib==3.7.1
numpy==1.23.5
qiskit-aer==0.12.0
qiskit-ibm-runtime==0.9.1
qiskit-ibmq-provider==0.20.2
qiskit-terra==0.23.2
scipy==1.10.1
torch==2.0.0
torchvision==0.15.1
Below, Covalent is used to access GPUs via AWS Batch and QPUs via IBM Quantum.
"""
Use Covalent to access EC2 instance on AWS that submits jobs to IBM Quantum
Hybrid classifier based on:
https://towardsdatascience.com/binary-image-classification-in-pytorch-5adf64f8c781
by Marcello Politi
data source:
https://www.kaggle.com/datasets/biaiscience/dogs-vs-cats
"""
import os
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, List, Optional, Tuple
from zipfile import ZipFile
import covalent as ct
import matplotlib as mpl
import matplotlib.pyplot as plt
import torch
from qiskit import QuantumCircuit, QuantumRegister
from qiskit.circuit import Parameter
from qiskit.quantum_info import SparsePauliOp
from qiskit_ibm_runtime import (Estimator, IBMBackend, Options,
QiskitRuntimeService, RuntimeJob)
from torch import Tensor, nn
from torch.nn.modules.loss import L1Loss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.models import resnet34
# The IBM Quantum API token is read from the environment, e.g.
# export IBM_QUANTUM_TOKEN="abcdefghijklmnopqrstuvwxyz1234567890mytokenfromibmquantum"
# TOKEN is None when the variable is unset.
TOKEN = os.getenv("IBM_QUANTUM_TOKEN", None)
# Placeholder IBM Quantum instance in "hub/group/project" form.
INSTANCE = "my_hub/my_group/my_project"
# Directory containing this script; data, model state and plots live next to it.
_BASE_PATH = Path(__file__).parent.resolve()
DATA_DIR = _BASE_PATH / "dogs_vs_cats_reduced_0.01"
STATE_FILE = _BASE_PATH / "model_state.pt"
# Placeholder name of the S3 bucket used for data and model-state transfers.
BUCKET_NAME = "my_s3_bucket"
# Packages passed as `deps_pip` to the remote training electron.
DEPS_PIP = ["torch", "torchvision", "qiskit", "qiskit_ibm_runtime"]
# AWS Batch executor used by the training electron. All names/paths below are
# placeholders to be replaced with real account resources.
# NOTE(review): units of `memory`, `poll_freq` and `time_limit` are not shown
# here (presumably GB / seconds / seconds) -- confirm against the plugin docs.
EXECUTOR = ct.executor.AWSBatchExecutor(
    credentials="/Users/username/.aws/credentials",
    s3_bucket_name=BUCKET_NAME,
    batch_job_log_group_name="my_log_group",
    batch_queue="my_batch_queue",
    memory=30,
    num_gpus=1,
    poll_freq=3,
    retry_attempts=1,
    time_limit=25000,
    vcpu=4,
)
# S3 transfer strategy shared by all three file transfers below.
S3_STRATEGY = ct.fs_strategies.S3(
    credentials="/Users/username/.aws/credentials",
    region_name="us-east-1"
)
# Runs BEFORE the electron it is attached to.
FT_1 = ct.fs.FileTransfer(  # download training/validation data from S3 bucket
    from_file=f"s3://{BUCKET_NAME}/{DATA_DIR.name}.zip",
    to_file=f"{DATA_DIR.name}.zip",
    order=ct.fs.Order.BEFORE,
    strategy=S3_STRATEGY
)
# Runs AFTER the electron it is attached to.
FT_2 = ct.fs.FileTransfer(  # upload model state to S3 bucket
    from_file=STATE_FILE.name,
    to_file=f"s3://{BUCKET_NAME}/{STATE_FILE.name}",
    order=ct.fs.Order.AFTER,
    strategy=S3_STRATEGY
)
# Runs BEFORE the electron it is attached to.
FT_3 = ct.fs.FileTransfer(  # download model state from S3 bucket
    from_file=f"s3://{BUCKET_NAME}/{STATE_FILE.name}",
    to_file=STATE_FILE.name,
    order=ct.fs.Order.BEFORE,
    strategy=S3_STRATEGY
)
class ParametricQC:
    """Simplified interface for getting expectation values from a quantum circuit.

    The circuit applies a Hadamard gate followed by a parametrized RY rotation
    to every qubit; the observable is Pauli-Z on all qubits.

    ``runs_total`` and ``calls_total`` are class-level counters shared by all
    instances. Every submission attempt (successful or not) adds the number of
    submitted circuits to ``runs_total`` and one to ``calls_total``.
    """

    RETRY_MAX: int = 5  # maximum number of submission attempts per job
    runs_total: int = 0
    calls_total: int = 0

    def __init__(
        self,
        n_qubits: int,
        shift: float,
        estimator: Estimator,
    ):
        """Store the settings and build the circuit and observable.

        Args:
            n_qubits: number of qubits (and of rotation parameters).
            shift: parameter offset stored for use by gradient computations.
            estimator: object whose ``run(circuits=..., observables=...,
                parameter_values=...)`` method submits the job.
        """
        self.n_qubits = n_qubits
        self.shift = shift
        self.estimator = estimator
        self._init_circuit_and_observable()

    def _init_circuit_and_observable(self):
        """Create the parametrized circuit and the all-Z observable."""
        qr = QuantumRegister(size=self.n_qubits)
        self.circuit = QuantumCircuit(qr)
        self.circuit.barrier()
        self.circuit.h(range(self.n_qubits))
        self.thetas = []
        for i in range(self.n_qubits):
            theta = Parameter(f"theta{i}")
            self.circuit.ry(theta, i)
            self.thetas.append(theta)
        # NOTE: the original code called `self.circuit.assign_parameters(...)`
        # here and discarded the result. Without `inplace=True` that call
        # returns a new circuit and leaves `self.circuit` untouched, so it was
        # dead code. The circuit must stay parametrized: values are bound per
        # call through `parameter_values` in `run`.
        self.obs = SparsePauliOp("Z" * self.n_qubits)

    def run(self, inputs: Tensor) -> Tensor:
        """Use each row of ``inputs`` as circuit parameters; return expectations.

        Args:
            inputs: 2-D tensor with one row of parameter values per circuit.

        Returns:
            Column tensor with one expectation value per input row, placed on
            the same device as ``inputs``.
        """
        parameter_values = inputs.tolist()
        n_circuits = len(parameter_values)
        job = self._run(
            parameter_values,
            [self.circuit] * n_circuits,
            [self.obs] * n_circuits,
        )
        exps = job.result()
        # keep the result on the caller's device so downstream torch ops do not
        # hit a CPU/CUDA mismatch
        return torch.tensor(exps.values, device=inputs.device).unsqueeze(dim=1)

    def _run(
        self,
        parameter_values: List[Any],
        circuits: List[QuantumCircuit],
        observables: List[SparsePauliOp],
    ) -> RuntimeJob:
        """Submit a job, retrying up to ``RETRY_MAX`` times on ``RuntimeError``.

        Returns:
            The job object returned by ``estimator.run``.

        Raises:
            RuntimeError: if every attempt failed; chained to the last error.
        """
        job = None
        last_error = None
        attempts = 0
        while attempts < ParametricQC.RETRY_MAX:
            try:
                job = self.estimator.run(
                    circuits=circuits,
                    observables=observables,
                    parameter_values=parameter_values
                )
                break
            except RuntimeError as err:
                last_error = err
                warnings.warn(
                    f"job failed on attempt {attempts + 1}:\n\n'{err}'\nresubmitting...",
                    category=UserWarning
                )
                attempts += 1
            finally:
                # count every attempt, whether or not it succeeded
                ParametricQC.runs_total += len(circuits)
                ParametricQC.calls_total += 1
        if job is None:
            # `attempts` already equals the number of attempts made; the old
            # message reported one more than that
            raise RuntimeError(
                f"job failed after {attempts} attempts"
            ) from last_error
        return job
class QuantumFunction(torch.autograd.Function):
    """Custom autograd function that evaluates a quantum circuit."""

    @staticmethod
    def forward(
        ctx,
        batch_inputs: Tensor,
        qc: ParametricQC,
    ) -> Tensor:
        """Forward pass: evaluate the circuit for every row of ``batch_inputs``.

        Args:
            ctx: autograd context; stores the inputs and ``qc`` for backward.
            batch_inputs: 2-D tensor, one row of circuit parameters per sample.
            qc: object providing ``run(inputs)`` and a ``shift`` attribute.

        Returns:
            Whatever ``qc.run(batch_inputs)`` returns.
        """
        ctx.save_for_backward(batch_inputs)
        ctx.qc = qc
        return qc.run(batch_inputs)

    @staticmethod
    def backward(
        ctx,
        grad_output: Tensor
    ):
        """Backward pass using shifted circuit evaluations.

        For every sample ``i`` and parameter ``j`` the circuit is evaluated at
        ``input_i + shift * e_j`` and ``input_i - shift * e_j``; the difference
        of the two expectations is used as the derivative with respect to
        parameter ``j``.

        NOTE(review): the textbook parameter-shift rule divides this
        difference by ``2 * sin(shift)``. The original code did not, and that
        scaling is kept here unchanged -- confirm whether it is intentional.

        Returns:
            One gradient per ``forward`` input: the gradient with respect to
            ``batch_inputs`` and ``None`` for the non-tensor ``qc`` argument.
        """
        (batch_inputs,) = ctx.saved_tensors
        qc = ctx.qc
        n_batch, n_params = batch_inputs.shape

        # One shift vector per parameter, created on the inputs' device/dtype.
        # The original per-element loop wrote a length-n_params vector into a
        # scalar slot, which only worked for n_params == 1, and built its
        # helper tensors on the CPU regardless of where the inputs lived.
        shifts = qc.shift * torch.eye(
            n_params, dtype=batch_inputs.dtype, device=batch_inputs.device
        )
        # (n_batch, n_params, n_params) -> one circuit per (sample, parameter)
        shifted_inputs_r = (batch_inputs.unsqueeze(1) + shifts).reshape(-1, n_params)
        shifted_inputs_l = (batch_inputs.unsqueeze(1) - shifts).reshape(-1, n_params)

        # run gradients in batches
        exps_r = qc.run(shifted_inputs_r).reshape(n_batch, n_params)
        exps_l = qc.run(shifted_inputs_l).reshape(n_batch, n_params)

        # match the inputs' dtype and device before combining with grad_output
        grads = (exps_r - exps_l).to(batch_inputs)
        return grads * grad_output.to(batch_inputs), None
class QuantumLayer(torch.nn.Module):
    """Neural-network layer that delegates its computation to a quantum circuit."""

    def __init__(
        self,
        n_qubits: int,
        estimator: Estimator,
    ):
        """Build the wrapped circuit with a shift of pi/2."""
        super().__init__()
        half_pi = torch.pi / 2
        self.qc = ParametricQC(n_qubits=n_qubits, shift=half_pi, estimator=estimator)

    def forward(self, xs: Tensor) -> Tensor:
        """Apply the quantum function to ``xs``.

        When ``xs`` holds a single sample the result is viewed as shape (1, 1).
        """
        output = QuantumFunction.apply(xs, self.qc)
        is_single_sample = xs.shape[0] == 1
        return output.view((1, 1)) if is_single_sample else output

    @property
    def qc_counts(self) -> dict:
        """Qubit count plus the class-wide circuit and call counters."""
        return dict(
            n_qubits=self.qc.n_qubits,
            runs_total=ParametricQC.runs_total,
            calls_total=ParametricQC.calls_total,
        )
def _get_model(
    n_qubits: int,
    pretrained: bool,
    backend: Optional[IBMBackend] = None,
    options: Optional[Options] = None,
) -> nn.Sequential:
    """Prepare a ResNet-34 followed by a quantum layer.

    Args:
        n_qubits: output size of the ResNet's final linear layer and number of
            qubits of the quantum layer.
        pretrained: load the default pre-trained ResNet-34 weights and freeze
            them; otherwise start from an untrained ResNet-34.
        backend: IBM backend for the runtime estimator.
        options: runtime options for the estimator.

    Returns:
        ``nn.Sequential(resnet, QuantumLayer)`` moved to CUDA when available,
        else CPU. The runtime ``Estimator`` is used only when both ``backend``
        and ``options`` are given; otherwise the local Qiskit reference
        estimator is used.
    """
    if pretrained:
        # with pre-trained weights
        resnet_model = resnet34(weights="ResNet34_Weights.DEFAULT")
        for params in resnet_model.parameters():
            # Call the method. The original assigned `False` to the attribute
            # `requires_grad_`, which only shadowed the method and left every
            # backbone weight trainable.
            params.requires_grad_(False)
    else:
        resnet_model = resnet34()

    # replace the final layer so it outputs one value per qubit; the new layer
    # is created after the freeze above, so it stays trainable
    resnet_model.fc = nn.Linear(resnet_model.fc.in_features, n_qubits)

    # choose the estimator backing the final quantum layer
    if backend and options:
        estimator = Estimator(session=backend, options=options)
    else:
        from qiskit.primitives import Estimator as _Estimator
        estimator = _Estimator(options=options)

    # initialize sequential neural network model
    model = nn.Sequential(
        resnet_model,
        QuantumLayer(n_qubits, estimator),
    )
    model.to("cuda" if torch.cuda.is_available() else "cpu")
    return model
def _get_transform(image_size: int) -> transforms.Compose:
    """Build the preprocessing pipeline: square resize, to-tensor, normalize."""
    target_size = (image_size, image_size)
    # NOTE(review): these look like the usual ImageNet channel statistics --
    # confirm against the pre-trained weights in use.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    steps = [
        transforms.Resize(target_size),
        transforms.ToTensor(),
        normalize,
    ]
    return transforms.Compose(steps)
def _dataloader(
    kind: str,
    batch_size: int,
    image_size: int,
    base_dir: Optional[Path] = None,
    shuffle: bool = True,
) -> DataLoader:
    """Create the data loader for the training or the validation images.

    Args:
        kind: ``"train"`` (reads the ``training`` folder) or ``"test"``
            (reads the ``validation`` folder).
        batch_size: samples per batch.
        image_size: side length images are resized to.
        base_dir: directory holding the data folder; defaults to the resolved
            current working directory.
        shuffle: forwarded to the ``DataLoader``.

    Raises:
        ValueError: if ``kind`` is neither ``"train"`` nor ``"test"``.
    """
    transform = _get_transform(image_size)
    root = Path(".").resolve() if base_dir is None else base_dir
    data_root = root / DATA_DIR.name

    if kind == "train":
        subfolder = "training"
    elif kind == "test":
        subfolder = "validation"
    else:
        raise ValueError("parameter `kind` must be 'train' or 'test'.")

    dataset = ImageFolder(
        data_root / subfolder,
        transform=transform,
        # rescales target labels from {0,1} to {-1,1}
        target_transform=lambda label: 2 * label - 1,
    )
    return DataLoader(dataset, shuffle=shuffle, batch_size=batch_size)
def _init_ibm_runtime(
    backend_name: str,
    n_qubits: int,
    n_shots: int
) -> Tuple[IBMBackend, Options]:
    """Connect to IBM Quantum and return the chosen backend with run options.

    ``backend_name == "least_busy"`` asks the service for its least busy
    backend, passing ``n_qubits``; any other value is looked up by name.
    The returned options carry ``n_shots`` as the execution shot count.
    """
    service = QiskitRuntimeService(
        channel="ibm_quantum",
        token=TOKEN,
        instance=INSTANCE,
    )

    # select remote backend
    backend = (
        service.least_busy(n_qubits)
        if backend_name == "least_busy"
        else service.backend(backend_name)
    )

    # set options
    run_options = Options()
    run_options.execution.shots = n_shots
    return backend, run_options
@dataclass
class TrainingResult:
    """Record of one training run: its settings, circuit counters and losses."""

    # run configuration
    backend_name: str
    n_qubits: int
    n_shots: int
    n_epochs: int
    batch_size: int
    image_size: int
    learning_rate: float
    # circuit counters
    runs_total: int
    calls_total: int
    pretrained: bool
    saved_state_filename: str
    # evaluation tallies, zero until filled in
    n_tested: int = 0
    n_correct: int = 0
    # loss histories; left out of the repr
    losses: List[float] = field(default_factory=list, repr=False)
    epoch_losses: List[float] = field(default_factory=list, repr=False)
@ct.electron(executor=EXECUTOR, deps_pip=DEPS_PIP, files=[FT_1, FT_2])
def train_model(
    backend_name: str,
    n_qubits: int,
    n_shots: int,
    n_epochs: int,
    batch_size: int,
    image_size: int,
    learning_rate: float,
    pretrained: bool,
    save_state: str,
    base_dir: Optional[Path] = None,
    run_local: bool = False,
    files=[],
) -> TrainingResult:
    """Train the hybrid model and return the losses and run metadata.

    Args:
        backend_name: IBM backend name, or ``"least_busy"``.
        n_qubits: number of qubits of the quantum layer.
        n_shots: shots per circuit on the IBM backend.
        n_epochs: number of passes over the training data.
        batch_size: samples per optimisation step.
        image_size: side length images are resized to.
        learning_rate: Adam learning rate.
        pretrained: start from pre-trained ResNet weights.
        save_state: file to save the model ``state_dict`` to; falsy disables
            saving.
        base_dir: directory holding the data folder; defaults to the current
            working directory.
        run_local: use the local estimator instead of IBM Quantum.
        files: file-transfer hook of the ``ct.electron`` decorator; kept
            as-is for interface compatibility and never mutated here.

    Returns:
        A ``TrainingResult`` with per-batch and per-epoch losses.
    """
    # Extract the training data into the directory the data loader reads from.
    # The original checked `DATA_DIR` (next to this script) but extracted into
    # the working directory, so the check and the loader could disagree.
    data_root = Path(".").resolve() if base_dir is None else Path(base_dir)
    if not (data_root / DATA_DIR.name).exists():
        with ZipFile(f"{DATA_DIR.name}.zip", "r") as zipped_file:
            zipped_file.extractall(data_root)

    losses = []
    epoch_losses = []
    device = "cuda" if torch.cuda.is_available() else "cpu"

    if run_local:
        model = _get_model(n_qubits, pretrained)
    else:
        backend, estimator_options = _init_ibm_runtime(backend_name, n_qubits, n_shots)
        model = _get_model(n_qubits, pretrained, backend, estimator_options)

    loader_train = _dataloader("train", batch_size, image_size, base_dir=base_dir)
    loss_fn = L1Loss()
    optimizer = Adam(model.parameters(), lr=learning_rate)

    # switch to training mode before the first forward pass, not after it
    model.train()
    n_batches = len(loader_train)

    for _ in range(n_epochs):
        epoch_loss = 0.0
        for x_batch, y_batch in loader_train:
            x_batch = x_batch.to(device)
            # targets arrive as a 1-D batch; make them a float column
            y_batch = y_batch.unsqueeze(1).float().to(device)

            optimizer.zero_grad()
            yhat = model(x_batch)
            loss = loss_fn(yhat, y_batch)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss / n_batches
            losses.append(batch_loss)
        epoch_losses.append(epoch_loss)

    if save_state:
        torch.save(model.state_dict(), save_state)

    qc_counts = model[-1].qc_counts
    return TrainingResult(
        backend_name="local_simulator" if run_local else backend_name,
        n_qubits=n_qubits,
        n_shots=n_shots,
        n_epochs=n_epochs,
        batch_size=batch_size,
        image_size=image_size,
        learning_rate=learning_rate,
        runs_total=qc_counts["runs_total"],
        calls_total=qc_counts["calls_total"],
        pretrained=pretrained,
        saved_state_filename=save_state,
        losses=losses,
        epoch_losses=epoch_losses,
    )
@ct.electron(files=[FT_3])
def plot_predictions(
    tr: TrainingResult,
    grid_dims: Tuple[int, int] = (6, 6),
    device: str = "cpu",
    save_name: str = "predictions.png",
    random_seed: Optional[int] = None,
    files=[]
) -> TrainingResult:
    """Plot labelled predictions and the loss curves of a trained model.

    Writes ``save_name``, ``loss.png`` and ``epoch_loss.png`` next to this
    script. Mispredicted images get a red frame.

    Args:
        tr: training result naming the saved model state to load.
        grid_dims: ``(rows, columns)`` of the prediction grid.
        device: torch device used for inference.
        save_name: file name of the prediction grid image.
        random_seed: optional seed set before the test loader is created.
        files: file-transfer hook of the ``ct.electron`` decorator; kept
            as-is for interface compatibility and never mutated here.

    Returns:
        ``tr`` with ``n_tested`` and ``n_correct`` filled in.
    """
    # set non-interactive MPL backend
    mpl.use(backend="Agg")

    # load model with local simulator; map_location lets a state saved on a
    # GPU host be loaded on a host without CUDA
    model = _get_model(n_qubits=tr.n_qubits, pretrained=tr.pretrained)
    model.load_state_dict(torch.load(tr.saved_state_filename, map_location=device))
    model.to(device)

    # set random seed optionally
    if random_seed is not None:
        torch.random.manual_seed(random_seed)

    # Create figure. Width scales with the column count and height with the
    # row count (identical to the old size for square grids). squeeze=False
    # keeps `axes` two-dimensional for single-row/column grids.
    n_rows, n_cols = grid_dims
    fig, axes = plt.subplots(
        nrows=n_rows,
        ncols=n_cols,
        figsize=(1.5 * n_cols, 1.25 * n_rows),
        layout="constrained",
        squeeze=False,
    )

    n = 0
    n_correct = 0
    n_cells = n_rows * n_cols
    loader_test = _dataloader(
        "test",
        batch_size=1,
        image_size=tr.image_size,
        base_dir=_BASE_PATH,
    )

    with torch.no_grad():
        model.eval()
        for x, y in loader_test:
            if n >= n_cells:
                break

            # Row-major position in the grid. The old code divided by the row
            # count, which is only right for square grids.
            i, j = divmod(n, n_cols)
            ax = axes[i][j]

            # get model prediction and compare to target (on the CPU)
            pred = model(x.to(device)).cpu()
            y_pred = pred.sign()
            if y_pred == y:
                n_correct += 1
            else:
                for spine in ax.spines.values():
                    spine.set_color("red")
                    spine.set_linewidth(2.0)

            # rescale the image to [0, 1] and move channels last for imshow
            img = x - x.min()
            img /= img.max()
            img = img.squeeze().permute(1, 2, 0)
            label = ("CAT" if pred < 0 else "DOG") + f" ({float(pred):.4f})"

            # plot image
            ax.imshow(img)
            ax.set_xlabel(label, fontsize=10)
            ax.set_xticks([])
            ax.set_yticks([])
            n += 1

    fig.suptitle(f"correct: {n_correct}/{n}")
    fig.savefig(_BASE_PATH / save_name, dpi=96 * 4)
    plt.close(fig)  # close this figure explicitly, not "the current one"

    # plot per-batch training losses, then per-epoch average losses
    for values, ylabel, xlabel, filename in (
        (tr.losses, "Loss", "Batch Iteration", "loss.png"),
        (tr.epoch_losses, "Ave. Loss", "Epoch", "epoch_loss.png"),
    ):
        fig, ax = plt.subplots(layout="constrained")
        ax.plot(values)
        ax.set_ylabel(ylabel, fontsize=10)
        ax.set_xlabel(xlabel)
        fig.savefig(_BASE_PATH / filename, dpi=96 * 2)
        plt.close(fig)

    tr.n_tested = n
    tr.n_correct = n_correct
    return tr
@ct.lattice
def workflow(
    backend_name="ibm_nairobi",
    n_qubits: int = 1,
    n_shots: int = 100,
    n_epochs: int = 1,
    batch_size: int = 16,
    image_size: int = 244,
    learning_rate: float = 1e-4,
    pretrained: bool = True,
    save_state: str = "model_state.pt",
) -> TrainingResult:
    """
    Two-step workflow:
    - train the model with the `train_model` electron
    - plot the results with the `plot_predictions` electron

    Raises EnvironmentError when the IBM Quantum token is not set.
    """
    if TOKEN is None:
        raise EnvironmentError("IBM_QUANTUM_TOKEN is not set")

    # step 1: training
    trained = train_model(
        backend_name=backend_name,
        n_qubits=n_qubits,
        n_shots=n_shots,
        n_epochs=n_epochs,
        batch_size=batch_size,
        image_size=image_size,
        learning_rate=learning_rate,
        pretrained=pretrained,
        save_state=save_state,
        base_dir=None,
    )

    # step 2: plotting; its output is the workflow result
    return plot_predictions(trained)
if __name__ == "__main__":
    # dispatch the workflow with its default arguments and show the id
    workflow_id = ct.dispatch(workflow)()
    print(f"\n{workflow_id}")

    # wait for the workflow to finish, then print the result object
    result = ct.get_result(workflow_id, wait=True)
    print(result)