Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjust DC sweep plot "read point" #298

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
78 changes: 47 additions & 31 deletions silq/parameters/acquisition_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,8 +834,8 @@ def __init__(self, name='DC_sweep', **kwargs):

self.sweep_parameters = OrderedDict()
# Pulse to acquire trace at the end, disabled by default
self.trace_pulse = DCPulse(name='trace', duration=100e-3, enabled=False,
acquire=True, average='trace', amplitude=0)
self.trace_pulses = [DCPulse(name='trace', duration=100e-3, enabled=False,
acquire=True, average='trace', amplitude=0)]

self.pulse_duration = 1e-3
self.final_delay = 120e-3
Expand Down Expand Up @@ -876,31 +876,38 @@ def setpoints(self):
setpoints = (convert_setpoints(outer_sweep_voltages,
inner_sweep_voltages)),

if self.trace_pulse.enabled:
# Also obtain a time trace at the end
points = round(self.trace_pulse.duration * self.sample_rate)
trace_setpoints = tuple(
np.linspace(0, self.trace_pulse.duration, points))
setpoints += (convert_setpoints(trace_setpoints),)
for trace_pulse in self.trace_pulses:
if trace_pulse.enabled:
# Also obtain a time trace at the end
points = round(trace_pulse.duration * self.sample_rate)
trace_setpoints = tuple(
np.linspace(0, trace_pulse.duration, points))
setpoints += (convert_setpoints(trace_setpoints),)
return setpoints

@property_ignore_setter
def names(self):
if self.trace_pulse.enabled:
return ('DC_voltage', 'trace_voltage')
else:
return ('DC_voltage',)
names = ('DC_voltage', )
for trace_pulse in self.trace_pulses:
if trace_pulse.enabled:
names += (trace_pulse.name,)
return names

@property_ignore_setter
def labels(self):
if self.trace_pulse.enabled:
return ('DC voltage', 'Trace voltage')
else:
return ('DC voltage',)
labels = ('DC voltage',)
for trace_pulse in self.trace_pulses:
if trace_pulse.enabled:
labels += (trace_pulse.name,)
return labels

@property_ignore_setter
def units(self):
return ('V', 'V') if self.trace_pulse.enabled else ('V',)
units = ('V',)
for trace_pulse in self.trace_pulses:
if trace_pulse.enabled:
units += ('V',)
return units

@property_ignore_setter
def shapes(self):
Expand All @@ -915,33 +922,38 @@ def shapes(self):
outer_sweep_voltages = next(iter_sweep_parameters).sweep_voltages
shapes = (len(outer_sweep_voltages), len(inner_sweep_voltages)),

if self.trace_pulse.enabled:
shapes += (round(
self.trace_pulse.duration * self.sample_rate),),
for trace_pulse in self.trace_pulses:
if trace_pulse.enabled:
shapes += (round(
trace_pulse.duration * self.sample_rate),),
return shapes

@property_ignore_setter
def setpoint_names(self):
iter_sweep_parameters = reversed(self.sweep_parameters.keys())
names = tuple(iter_sweep_parameters),
if self.trace_pulse.enabled:
names += (('time',), )

for trace_pulse in self.trace_pulses:
if trace_pulse.enabled:
names += (('time',), )
return names

@property_ignore_setter
def setpoint_labels(self):
iter_sweep_parameters = reversed(
[(p if p.isupper() else p.capitalize()) for p in self.sweep_parameters.keys()])
labels = tuple(iter_sweep_parameters),
if self.trace_pulse.enabled:
labels += (('Time',),)
for trace_pulse in self.trace_pulses:
if trace_pulse.enabled:
labels += (('Time',),)
return labels

@property_ignore_setter
def setpoint_units(self):
setpoint_units = (('V',) * len(self.sweep_parameters),)
if self.trace_pulse.enabled:
setpoint_units += (('s',), )
for trace_pulse in self.trace_pulses:
if trace_pulse.enabled:
setpoint_units += (('s',), )
return setpoint_units

def add_sweep(self,
Expand Down Expand Up @@ -1081,9 +1093,12 @@ def generate(self):
raise NotImplementedError(
f"Cannot handle {len(self.sweep_parameters)} parameters")

if self.trace_pulse.enabled:
# Also obtain a time trace at the end
pulses.append(self.trace_pulse)
for trace_pulse in self.trace_pulses:
if trace_pulse.enabled:
# Explicitly add trace_pulses after each other at the end of
# the pulse sequence.
trace_pulse.t_start = pulses[-1].t_stop
pulses.append(trace_pulse)

self.pulse_sequence = PulseSequence(pulses=pulses)
self.pulse_sequence.final_delay = self.final_delay
Expand Down Expand Up @@ -1123,8 +1138,9 @@ def analyse(self,
results = {'DC_voltage':
DC_voltages.reshape(self.shapes[0])}

if self.trace_pulse.enabled:
results['trace_voltage'] = traces['trace'][self.channel_label]
for k, trace_pulse in enumerate(self.trace_pulses, 1):
if trace_pulse.enabled:
results[trace_pulse.name] = traces[trace_pulse.name][self.channel_label]

return results

Expand Down
74 changes: 58 additions & 16 deletions silq/tools/plot_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@
import logging

import qcodes as qc
from qcodes.plots.qcmatplotlib import MatPlot
from qcodes.plots.qcmatplotlib import MatPlot, align_x_axis
from qcodes.instrument.parameter import _BaseParameter
from qcodes.station import Station
from qcodes.data.data_set import DataSet
from qcodes.data.data_array import DataArray
from qcodes.utils.helpers import PerformanceTimer
from silq.meta_instruments.layout import Connection, CombinedConnection

__all__ = ['PlotAction', 'SetGates', 'MeasureSingle', 'MoveGates',
'SwitchPlotIdx', 'InteractivePlot', 'SliderPlot', 'CalibrationPlot',
Expand Down Expand Up @@ -799,7 +800,7 @@ class DCSweepPlot(ScanningPlot):
**kwargs: Additional kwargs to `InteractivePlot` and ``MatPlot``.
"""
gate_mapping = {}
point_color = 'r'
trace_ylim = (-0.1, 1.3)

# DCSweepParameter type
def __init__(self,
Expand All @@ -810,26 +811,70 @@ def __init__(self,
if gate_mapping is not None:
self.gate_mapping = gate_mapping

if parameter.trace_pulse.enabled:
subplots = (2, 1)
kwargs['gridspec_kw'] = {'height_ratios': [2, 1]}
kwargs['figsize'] = kwargs.get('figsize', (6.5, 6))
num_traces = np.count_nonzero([trace_pulse.enabled
for trace_pulse in parameter.trace_pulses])
if num_traces > 0:
subplots = (1 + num_traces, 1)
kwargs['gridspec_kw'] = {'height_ratios': [1.5] + [0.5]*num_traces}
kwargs['figsize'] = kwargs.get('figsize', (5, 4 + 1*num_traces))
else:
subplots = 1

self.point = None
self.points = {}

self.buffer_length = averages
self.buffers = [None, ] * self.buffer_length
self.buf_idx = 0

super().__init__(parameter, subplots=subplots, **kwargs)

if parameter.trace_pulse.enabled:
self[1].set_ylim(-0.1, 1.3)
self.tight_layout()
for k, trace_pulse in enumerate(parameter.trace_pulses, 1):
if trace_pulse.enabled:
self[k].set_ylim(*self.trace_ylim)
align_x_axis(self[k], self[0])

self.actions = [MoveGates(self)]

def _update_points(self, ax=None):
# This implicitly assumes the trace_pulse has a connection_label and an
# amplitude. There should be no reason that the trace_pulse will not be
# correctly initialized.
x_ref = self.x_gate.get_latest()
y_ref = self.y_gate.get_latest()

for k, trace_pulse in enumerate(self.parameter.trace_pulses):
if trace_pulse.enabled:
connection = self.layout.get_connection(
trace_pulse.connection_label)

# Add scaled offset for "read point" in diagram.
# Since pulse is already scaled to device voltages, we
# only need to apply the combination scaling.
new_x = x_ref
new_y = y_ref

if isinstance(connection, CombinedConnection):
A = trace_pulse.amplitude
for con, scale in zip(connection.connections,
connection.scale):
if self.x_gate.name == con.label:
new_x += A * scale
elif self.y_gate.name == con.label:
new_y += A * scale

if trace_pulse.name not in self.points:
assert ax is not None, "For the initial point to be drawn, axes must" \
"be provided."
self.points[trace_pulse.name] = ax.plot(new_x, new_y,
marker='o', linestyle='',
color=f'C{k}', ms=5,
label=trace_pulse.name)[0]
else:
self.points[trace_pulse.name].set_xdata(new_x)
self.points[trace_pulse.name].set_ydata(new_y)
else:
pass

def update_plot(self, initialize=False):
"""Update plot with new 2D DC scan.

Expand Down Expand Up @@ -868,9 +913,7 @@ def update_plot(self, initialize=False):
self.x_gate = getattr(self.station, self.x_label)
self.y_gate = getattr(self.station, self.y_label)

self.point = self[k].plot(self.x_gate.get_latest(),
self.y_gate.get_latest(),
'o' + self.point_color, ms=5)[0]
self._update_points(ax=self[k])
else:
self[k].add(result, x=setpoints[0],
xlabel=setpoint_names[0],
Expand All @@ -884,11 +927,10 @@ def update_plot(self, initialize=False):
result_config['x'] = self.parameter.setpoints[k][1]
result_config['y'] = self.parameter.setpoints[k][0]
result_config['z'] = result
if self.point is not None:
self.point.set_xdata(self.x_gate.get_latest())
self.point.set_ydata(self.y_gate.get_latest())
self._update_points()
else:
result_config['y'] = result
result_config['x'] = self.parameter.setpoints[k][0]

super().update_plot()

Expand Down