diff --git a/silq/parameters/acquisition_parameters.py b/silq/parameters/acquisition_parameters.py index 759a31fd3..08442cf8a 100644 --- a/silq/parameters/acquisition_parameters.py +++ b/silq/parameters/acquisition_parameters.py @@ -834,8 +834,8 @@ def __init__(self, name='DC_sweep', **kwargs): self.sweep_parameters = OrderedDict() # Pulse to acquire trace at the end, disabled by default - self.trace_pulse = DCPulse(name='trace', duration=100e-3, enabled=False, - acquire=True, average='trace', amplitude=0) + self.trace_pulses = [DCPulse(name='trace', duration=100e-3, enabled=False, + acquire=True, average='trace', amplitude=0)] self.pulse_duration = 1e-3 self.final_delay = 120e-3 @@ -876,31 +876,38 @@ def setpoints(self): setpoints = (convert_setpoints(outer_sweep_voltages, inner_sweep_voltages)), - if self.trace_pulse.enabled: - # Also obtain a time trace at the end - points = round(self.trace_pulse.duration * self.sample_rate) - trace_setpoints = tuple( - np.linspace(0, self.trace_pulse.duration, points)) - setpoints += (convert_setpoints(trace_setpoints),) + for trace_pulse in self.trace_pulses: + if trace_pulse.enabled: + # Also obtain a time trace at the end + points = round(trace_pulse.duration * self.sample_rate) + trace_setpoints = tuple( + np.linspace(0, trace_pulse.duration, points)) + setpoints += (convert_setpoints(trace_setpoints),) return setpoints @property_ignore_setter def names(self): - if self.trace_pulse.enabled: - return ('DC_voltage', 'trace_voltage') - else: - return ('DC_voltage',) + names = ('DC_voltage', ) + for trace_pulse in self.trace_pulses: + if trace_pulse.enabled: + names += (trace_pulse.name,) + return names @property_ignore_setter def labels(self): - if self.trace_pulse.enabled: - return ('DC voltage', 'Trace voltage') - else: - return ('DC voltage',) + labels = ('DC voltage',) + for trace_pulse in self.trace_pulses: + if trace_pulse.enabled: + labels += (trace_pulse.name,) + return labels @property_ignore_setter def units(self): - return ('V', 'V') if self.trace_pulse.enabled else ('V',) + units = ('V',) + for trace_pulse in self.trace_pulses: + if trace_pulse.enabled: + units += ('V',) + return units @property_ignore_setter def shapes(self): @@ -915,17 +922,20 @@ def shapes(self): outer_sweep_voltages = next(iter_sweep_parameters).sweep_voltages shapes = (len(outer_sweep_voltages), len(inner_sweep_voltages)), - if self.trace_pulse.enabled: - shapes += (round( - self.trace_pulse.duration * self.sample_rate),), + for trace_pulse in self.trace_pulses: + if trace_pulse.enabled: + shapes += (round( + trace_pulse.duration * self.sample_rate),), return shapes @property_ignore_setter def setpoint_names(self): iter_sweep_parameters = reversed(self.sweep_parameters.keys()) names = tuple(iter_sweep_parameters), - if self.trace_pulse.enabled: - names += (('time',), ) + + for trace_pulse in self.trace_pulses: + if trace_pulse.enabled: + names += (('time',), ) return names @property_ignore_setter @@ -933,15 +943,17 @@ def setpoint_labels(self): iter_sweep_parameters = reversed( [(p if p.isupper() else p.capitalize()) for p in self.sweep_parameters.keys()]) labels = tuple(iter_sweep_parameters), - if self.trace_pulse.enabled: - labels += (('Time',),) + for trace_pulse in self.trace_pulses: + if trace_pulse.enabled: + labels += (('Time',),) return labels @property_ignore_setter def setpoint_units(self): setpoint_units = (('V',) * len(self.sweep_parameters),) - if self.trace_pulse.enabled: - setpoint_units += (('s',), ) + for trace_pulse in self.trace_pulses: + if trace_pulse.enabled: + setpoint_units += (('s',), ) return setpoint_units def add_sweep(self, @@ -1081,9 +1093,12 @@ def generate(self): raise NotImplementedError( f"Cannot handle {len(self.sweep_parameters)} parameters") - if self.trace_pulse.enabled: - # Also obtain a time trace at the end - pulses.append(self.trace_pulse) + for trace_pulse in self.trace_pulses: + if trace_pulse.enabled: + # Explicitly add trace_pulses after each other at the end of + # the pulse sequence. + trace_pulse.t_start = pulses[-1].t_stop + pulses.append(trace_pulse) self.pulse_sequence = PulseSequence(pulses=pulses) self.pulse_sequence.final_delay = self.final_delay @@ -1123,8 +1138,9 @@ def analyse(self, results = {'DC_voltage': DC_voltages.reshape(self.shapes[0])} - if self.trace_pulse.enabled: - results['trace_voltage'] = traces['trace'][self.channel_label] + for k, trace_pulse in enumerate(self.trace_pulses, 1): + if trace_pulse.enabled: + results[trace_pulse.name] = traces[trace_pulse.name][self.channel_label] return results diff --git a/silq/tools/plot_tools.py b/silq/tools/plot_tools.py index f1df14662..60564dac5 100644 --- a/silq/tools/plot_tools.py +++ b/silq/tools/plot_tools.py @@ -10,12 +10,13 @@ import logging import qcodes as qc -from qcodes.plots.qcmatplotlib import MatPlot +from qcodes.plots.qcmatplotlib import MatPlot, align_x_axis from qcodes.instrument.parameter import _BaseParameter from qcodes.station import Station from qcodes.data.data_set import DataSet from qcodes.data.data_array import DataArray from qcodes.utils.helpers import PerformanceTimer +from silq.meta_instruments.layout import Connection, CombinedConnection __all__ = ['PlotAction', 'SetGates', 'MeasureSingle', 'MoveGates', 'SwitchPlotIdx', 'InteractivePlot', 'SliderPlot', 'CalibrationPlot', @@ -799,7 +800,7 @@ class DCSweepPlot(ScanningPlot): **kwargs: Additional kwargs to `InteractivePlot` and ``MatPlot``. """ gate_mapping = {} - point_color = 'r' + trace_ylim = (-0.1, 1.3) # DCSweepParameter type def __init__(self, @@ -810,26 +811,70 @@ def __init__(self, if gate_mapping is not None: self.gate_mapping = gate_mapping - if parameter.trace_pulse.enabled: - subplots = (2, 1) - kwargs['gridspec_kw'] = {'height_ratios': [2, 1]} - kwargs['figsize'] = kwargs.get('figsize', (6.5, 6)) + num_traces = np.count_nonzero([trace_pulse.enabled + for trace_pulse in parameter.trace_pulses]) + if num_traces > 0: + subplots = (1 + num_traces, 1) + kwargs['gridspec_kw'] = {'height_ratios': [1.5] + [0.5]*num_traces} + kwargs['figsize'] = kwargs.get('figsize', (5, 4 + 1*num_traces)) else: subplots = 1 - self.point = None + self.points = {} self.buffer_length = averages self.buffers = [None, ] * self.buffer_length self.buf_idx = 0 super().__init__(parameter, subplots=subplots, **kwargs) - - if parameter.trace_pulse.enabled: - self[1].set_ylim(-0.1, 1.3) + self.tight_layout() + for k, trace_pulse in enumerate(parameter.trace_pulses, 1): + if trace_pulse.enabled: + self[k].set_ylim(*self.trace_ylim) + align_x_axis(self[k], self[0]) self.actions = [MoveGates(self)] + def _update_points(self, ax=None): + # This implicitly assumes the trace_pulse has a connection_label and an + # amplitude. There should be no reason that the trace_pulse will not be + # correctly initialized. + x_ref = self.x_gate.get_latest() + y_ref = self.y_gate.get_latest() + + for k, trace_pulse in enumerate(self.parameter.trace_pulses): + if trace_pulse.enabled: + connection = self.layout.get_connection( + trace_pulse.connection_label) + + # Add scaled offset for "read point" in diagram. + # Since pulse is already scaled to device voltages, we + # only need to apply the combination scaling. + new_x = x_ref + new_y = y_ref + + if isinstance(connection, CombinedConnection): + A = trace_pulse.amplitude + for con, scale in zip(connection.connections, + connection.scale): + if self.x_gate.name == con.label: + new_x += A * scale + elif self.y_gate.name == con.label: + new_y += A * scale + + if trace_pulse.name not in self.points: + assert ax is not None, "For the initial point to be drawn, axes must" \ + "be provided." + self.points[trace_pulse.name] = ax.plot(new_x, new_y, + marker='o', linestyle='', + color=f'C{k}', ms=5, + label=trace_pulse.name)[0] + else: + self.points[trace_pulse.name].set_xdata(new_x) + self.points[trace_pulse.name].set_ydata(new_y) + else: + pass + def update_plot(self, initialize=False): """Update plot with new 2D DC scan. @@ -868,9 +913,7 @@ def update_plot(self, initialize=False): self.x_gate = getattr(self.station, self.x_label) self.y_gate = getattr(self.station, self.y_label) - self.point = self[k].plot(self.x_gate.get_latest(), - self.y_gate.get_latest(), - 'o' + self.point_color, ms=5)[0] + self._update_points(ax=self[k]) else: self[k].add(result, x=setpoints[0], xlabel=setpoint_names[0], @@ -884,11 +927,10 @@ def update_plot(self, initialize=False): result_config['x'] = self.parameter.setpoints[k][1] result_config['y'] = self.parameter.setpoints[k][0] result_config['z'] = result - if self.point is not None: - self.point.set_xdata(self.x_gate.get_latest()) - self.point.set_ydata(self.y_gate.get_latest()) + self._update_points() else: result_config['y'] = result + result_config['x'] = self.parameter.setpoints[k][0] super().update_plot()