-
I am not sure about your use case. Could you explain the feature you desire in more detail?
-
I understand the pros of JIT (though once a graph is jitted, does it become static?), and there are also use cases where dynamic graphs are handy. I changed the PyTorch MNIST example in TC's tutorial to do the following: note the new function `qpred_1step` and the for loop in the `forward` method of `QuantumNet`. After these changes, the code snippet below is closer to how I would build a similar PQC in TorchQuantum, but I am not sure whether this is correct usage in TensorCircuit. For example, since the circuit `c` is defined outside of both `QuantumNet` and `qpred_1step`, would each call of `qpred_1step` still add gates to the same circuit `c`? Would the changes break anything in TensorCircuit? Any comments or suggestions are welcome. (I am just testing the logic; the following runs, but not everything is correct.) Additionally, since PyTorch is in the process of adding native JIT and vmap support, is there any plan for adding native PyTorch support in TensorCircuit without needing TensorFlow?

```python
import time
import numpy as np
import tensorflow as tf
import torch

import tensorcircuit as tc

K = tc.set_backend("tensorflow")

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train[..., np.newaxis] / 255.0


def filter_pair(x, y, a, b):
    keep = (y == a) | (y == b)
    x, y = x[keep], y[keep]
    y = y == a
    return x, y


x_train, y_train = filter_pair(x_train, y_train, 1, 5)
x_train_small = tf.image.resize(x_train, (3, 3)).numpy()
x_train_bin = np.array(x_train_small > 0.5, dtype=np.float32)
x_train_bin = np.squeeze(x_train_bin).reshape([-1, 9])
y_train_torch = torch.tensor(y_train, dtype=torch.float32)
x_train_torch = torch.tensor(x_train_bin)
x_train_torch.shape, y_train_torch.shape

n = 9
nlayers = 3

# We define the quantum function;
# note how this function is running on tensorflow
c = tc.Circuit(9)


def qpred_1step(x, weights):  # each call of this constructs one layer of c
    for k in range(n):
        c.rx(k, theta=x[k])
    for k in range(n - 1):
        c.cnot(k, k + 1)
    for k in range(n):
        c.rx(k, theta=weights[2 * 0, k])
        c.ry(k, theta=weights[2 * 0 + 1, k])
    ypred = c.expectation_ps(z=[n // 2])
    ypred = K.real(ypred)
    return K.sigmoid(ypred)


# Wrap the function into pytorch form but with tensorflow speed!
qpred_torch = tc.interfaces.torch_interface(qpred_1step, jit=True)


class QuantumNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.q_weights = torch.nn.Parameter(torch.randn([2 * nlayers, n]))

    def forward(self, inputs):
        for j in range(nlayers):  # new for loop here
            ypred = qpred_torch(inputs, self.q_weights)
        return ypred


net = QuantumNet()
# net(x_train_torch[0])

criterion = torch.nn.BCELoss()
opt = torch.optim.Adam(net.parameters(), lr=1e-2)
nepochs = 500
nbatch = 32
times = []
for epoch in range(nepochs):
    index = np.random.randint(low=0, high=100, size=nbatch)
    # index = np.arange(nbatch)
    inputs, labels = x_train_torch[index], y_train_torch[index]
    opt.zero_grad()
    with torch.set_grad_enabled(True):
        time0 = time.time()
        yps = []
        for i in range(nbatch):
            yp = net(inputs[i])
            yps.append(yp)
        yps = torch.stack(yps)
        loss = criterion(
            torch.reshape(yps, [nbatch, 1]), torch.reshape(labels, [nbatch, 1])
        )
        loss.backward()
        if epoch % 100 == 0:
            print(loss)
        opt.step()
        time1 = time.time()
        times.append(time1 - time0)
print("training time per step: ", np.mean(time1 - time0)) |
-
Yes, many of these use cases can be transformed into a static graph smartly, and if not, you can of course just use a dynamic graph as long as you don't wrap the function with jit.

One can now also run on the PyTorch backend if speed is not a concern (no vmap and jit, though). Last time I tried the jit and vmap features of PyTorch, they were still buggy and not robust enough to use. If these features mature, it is actually very easy to integrate them into tensorcircuit (less than 100 lines of code).

I am still not sure about this part of the code:

```python
for j in range(nlayers):  # new for loop here
    ypred = qpred_torch(inputs, self.q_weights)
```

This is not really a loop, right? Since `ypred` is not passed anywhere, running the loop nlayers times makes no sense. What do you mean here, i.e., do you want this repeated behavior or do you want to avoid it?
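For comparison, here is a sketch of what a genuinely layered circuit would look like (illustrative only, following the tutorial's pattern): the circuit is created inside the function, and each layer indexes its own slice of the weights.

```python
def qpred(x, weights):
    c = tc.Circuit(n)  # fresh circuit per call, so gates never accumulate
    for j in range(nlayers):  # layers stack within one forward evaluation
        for k in range(n):
            c.rx(k, theta=x[k])
        for k in range(n - 1):
            c.cnot(k, k + 1)
        for k in range(n):
            c.rx(k, theta=weights[2 * j, k])
            c.ry(k, theta=weights[2 * j + 1, k])
    ypred = c.expectation_ps(z=[n // 2])
    return K.sigmoid(K.real(ypred))
```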
-
Here is some pseudocode:

```python
import time
import numpy as np
import tensorflow as tf
import torch

import tensorcircuit as tc

K = tc.set_backend("pytorch")

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train[..., np.newaxis] / 255.0


def filter_pair(x, y, a, b):
    keep = (y == a) | (y == b)
    x, y = x[keep], y[keep]
    y = y == a
    return x, y


x_train, y_train = filter_pair(x_train, y_train, 1, 5)
x_train_small = tf.image.resize(x_train, (3, 3)).numpy()
x_train_bin = np.array(x_train_small > 0.5, dtype=np.float32)
x_train_bin = np.squeeze(x_train_bin).reshape([-1, 9])
y_train_torch = torch.tensor(y_train, dtype=torch.float32)
x_train_torch = torch.tensor(x_train_bin)
x_train_torch.shape, y_train_torch.shape

n = 9
# We define the quantum function;
# note that this function now runs on the pytorch backend
def qpred_1step(x, c):  # each call of this constructs one layer of c
    for k in range(n):
        c.rx(k, theta=x[k])
    for k in range(n - 1):
        c.cnot(k, k + 1)
    ypred = c.expectation_ps(z=[n // 2])
    ypred = K.real(ypred)
    return ypred
class QuantumNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.depth = 20
        self.c = tc.Circuit(9)

    def forward(self, inputs):
        ypreds = []
        for j in range(self.depth):  # new for loop here
            ypreds.append(qpred_1step(inputs, self.c))
        return ypreds


net = QuantumNet()
# net(x_train_torch[0])

criterion = torch.nn.BCELoss()
opt = torch.optim.Adam(net.parameters(), lr=1e-2)
nepochs = 500
nbatch = 32
times = []
for epoch in range(nepochs):
    index = np.random.randint(low=0, high=100, size=nbatch)
    # index = np.arange(nbatch)
    inputs, labels = x_train_torch[index], y_train_torch[index]
    opt.zero_grad()
    with torch.set_grad_enabled(True):
        time0 = time.time()
        yps = []
        for i in range(nbatch):
            yp = net(inputs[i])
            yps.append(yp)
        yps = torch.stack(yps)
        loss = criterion(
            torch.reshape(yps, [nbatch, 1]), torch.reshape(labels, [nbatch, 1])
        )
        loss.backward()
        if epoch % 100 == 0:
            print(loss)
        opt.step()
        time1 = time.time()
        times.append(time1 - time0)
print("training time per step: ", np.mean(time1 - time0)) |
-
Hi,
Can TensorCircuit be used to construct hybrid classical-quantum computation graphs dynamically, as is standard in PyTorch?
Thanks!