Skip to content

Commit

Permalink
Compute spikeFeatures after detect in one recording or all recordings
Browse files Browse the repository at this point in the history
  • Loading branch information
Alan Liddell committed Feb 21, 2019
1 parent 52fa1fa commit 033c331
Show file tree
Hide file tree
Showing 48 changed files with 1,024 additions and 901 deletions.
10 changes: 6 additions & 4 deletions +jrclust/+curate/@CurateController/CurateController.m
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@
%% LIFECYCLE
methods
function obj = CurateController(res)
obj.res = res;
% transfer hClust from res to cRes
obj.hClust = res.hClust;
obj.res = rmfield(res, 'hClust');

obj.hFigs = containers.Map();
obj.hMenus = containers.Map();
obj.isEnding = 0;
Expand Down Expand Up @@ -86,11 +88,11 @@ function delete(obj)
%% GETTERS/SETTERS
methods
% hCfg
function hc = get.hCfg(obj)
function val = get.hCfg(obj)
if ~isempty(obj.hClust)
hc = obj.hClust.hCfg;
val = obj.hClust.hCfg;
else
hc = [];
val = [];
end
end

Expand Down
41 changes: 41 additions & 0 deletions +jrclust/+detect/@DetectController/CARRealign.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
function [spikeWindows, spikeTimes] = CARRealign(obj, spikeWindows, samplesIn, spikeTimes, neighbors)
%CARREALIGN Realign spike peaks after applying local CAR
if ~strcmpi(obj.hCfg.getOr('vcSpkRef', 'nmean'), 'nmean')
return;
end

% find where true peaks are not in the correct place after applying local CAR
spikeWindowsCAR = jrclust.utils.localCAR(single(spikeWindows), obj.hCfg);
[shiftMe, shiftBy] = findShifted(spikeWindowsCAR, obj.hCfg);

if isempty(shiftMe)
return;
end

% adjust spike times
shiftedTimes = spikeTimes(shiftMe) - int32(shiftBy(:));
spikeTimes(shiftMe) = shiftedTimes;

% extract windows at new shifted times
spikeWindows(:, shiftMe, :) = obj.extractWindows(samplesIn, shiftedTimes, neighbors, 0);
end

%% LOCAL FUNCTIONS
function [shiftMe, shiftBy] = findShifted(spikeWindows, hCfg)
%FINDSHIFTED
% spikeWindows: nSamples x nSpikes x nSites
peakLoc = 1 - hCfg.evtWindowSamp(1);

if hCfg.detectBipolar
[~, truePeakLoc] = max(abs(spikeWindows(:, :, 1)));
else
[~, truePeakLoc] = min(spikeWindows(:, :, 1));
end

shiftMe = find(truePeakLoc ~= peakLoc);
shiftBy = peakLoc - truePeakLoc(shiftMe);

shiftOkay = (abs(shiftBy) <= 2); % throw out drastic shifts
shiftMe = shiftMe(shiftOkay);
shiftBy = shiftBy(shiftOkay);
end
151 changes: 12 additions & 139 deletions +jrclust/+detect/@DetectController/DetectController.m
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
classdef DetectController < handle
%DETECTCONTROLLER
%% CONFIGURATION
properties (SetObservable, Transient)
hCfg; % Config object
end

properties (SetAccess=private, SetObservable, Transient)
hCfg;
properties (SetAccess=private)
errMsg;
isError;
end
Expand All @@ -14,7 +17,6 @@
siteThresh;
spikeTimes;
spikeAmps;
spikeSites;
centerSites;

spikesRaw;
Expand All @@ -40,141 +42,12 @@
end
end

%% USER METHODS
methods
function res = detect(obj)
res = struct();
t0 = tic();

% get manually-set spike thresholds
if ~isempty(obj.hCfg.threshFile)
try
S = load(obj.hCfg.threshFile);
siteThresh_ = S.siteThresh;
if obj.hCfg.verbose
fprintf('Loaded %s\n', obj.hCfg.threshFile);
end
catch ME
warning('Could not load threshFile %s: %s', obj.hCfg.threshFile, ME.message);
siteThresh_ = [];
end
else
siteThresh_ = [];
end

nRecs = numel(obj.hCfg.rawRecordings);
hRecs = cell(nRecs, 1);

obj.siteThresh = cell(nRecs, 1);
obj.spikeTimes = cell(nRecs, 1);
obj.spikeAmps = cell(nRecs, 1);
obj.spikeSites = cell(nRecs, 1);
obj.centerSites = cell(nRecs, 1);
obj.spikesRaw = cell(nRecs, 1);
obj.spikesFilt = cell(nRecs, 1);
obj.spikeFeatures = cell(nRecs, 1);

recOffset = 0; % sample offset for each recording in sequence

% load from files
for iRec = 1:nRecs
t1 = tic;

fn = obj.hCfg.rawRecordings{iRec};
hRec = jrclust.models.recording.Recording(fn, obj.hCfg);

if hRec.isError
error(hRec.errMsg);
end

% subset imported samples in this recording interval
[impTimes, impSites] = deal([]);
if ~isempty(obj.importTimes)
inInterval = (obj.importTimes > recOffset & obj.importTimes <= recOffset + hRec.nSamples);
impTimes = obj.importTimes(inInterval) - recOffset; % shift spike timing

% take sites assoc with times between limits
if ~isempty(obj.importSites)
impSites = obj.importSites(inInterval);
end
end

recData = detectInRecording(hRec, impTimes, impSites, siteThresh_, obj.hCfg);
try
hRec.setDetections(recData);
catch ME % maybe rethrow
warning('error caught: %s', ME.message);
continue;
end

t1 = toc(t1);
nBytesFile = recData.nBytesFile;
tr = (nBytesFile/jrclust.utils.typeBytes(obj.hCfg.dataType)/obj.hCfg.nChans)/obj.hCfg.sampleRate;

if obj.hCfg.verbose
fprintf('File %d/%d took %0.1fs (%0.1f MB, %0.1f MB/s, x%0.1f realtime)\n', ...
iRec, nRecs, t1, nBytesFile/1e6, nBytesFile/t1/1e6, tr/t1);
end

obj.siteThresh{iRec} = recData.siteThresh;
obj.spikeTimes{iRec} = recData.spikeTimes + recOffset;
obj.spikeAmps{iRec} = recData.spikeAmps;
obj.spikeSites{iRec} = recData.spikeSites;
obj.centerSites{iRec} = recData.centerSites;
obj.spikesRaw{iRec} = recData.spikesRaw;
obj.spikesFilt{iRec} = recData.spikesFilt;
obj.spikeFeatures{iRec} = recData.spikeFeatures;

recOffset = recOffset + hRec.nSamples;
hRecs{iRec} = hRec;
end % for

res.spikeTimes = cat(1, obj.spikeTimes{:});
res.spikeAmps = cat(1, obj.spikeAmps{:});
res.siteThresh = mean(single(cat(1, obj.siteThresh{:})), 1);

% spike sites
obj.centerSites = cat(1, obj.centerSites{:});
res.spikeSites = obj.centerSites(:, 1);
if size(obj.centerSites, 2) > 1
res.spikeSites2 = obj.centerSites(:, 2);
else
res.spikeSites2 = [];
end

% spikes by site
nSites = obj.hCfg.nSites;
res.spikesBySite = arrayfun(@(iSite) find(obj.centerSites(:, 1) == iSite), 1:nSites, 'UniformOutput', 0);
if size(obj.centerSites, 2) >= 2
res.spikesBySite2 = arrayfun(@(iSite) find(obj.centerSites(:, 2) == iSite), 1:nSites, 'UniformOutput', 0);
else
res.spikesBySite2 = cell(1, nSites);
end
if size(obj.centerSites, 2) == 3
res.spikesBySite3 = arrayfun(@(iSite) find(obj.centerSites(:, 3) == iSite), 1:nSites, 'UniformOutput', 0);
else
res.spikesBySite3 = [];
end

% detected spikes (raw and filtered), features
res.spikesRaw = cat(3, obj.spikesRaw{:});
res.rawShape = size(res.spikesRaw);

res.spikesFilt = cat(3, obj.spikesFilt{:});
res.filtShape = size(res.spikesFilt);

res.spikeFeatures = cat(3, obj.spikeFeatures{:});
res.featuresShape = size(res.spikeFeatures);

% spike positions
res.spikePositions = spikePos(res.spikeSites, res.spikeFeatures, obj.hCfg);

% recordings for inspection
res.hRecs = hRecs;

% summarize
res.detectTime = toc(t0);
res.detectedOn = now();
end
%% UTILITY METHODS
methods (Access=protected, Hidden)
[spikeWindowsOut, spikeWindows2Out] = cancelOverlap(obj, spikeWindows, spikeWindows2, spikeTimes, spikeSites, spikeSites2, siteThresh);
[spikeWindows, spikeTimes] = CARRealign(obj, spikeWindows, samplesIn, spikeTimes, neighbors);
[windows, timeRanges] = extractWindows(obj, samplesIn, spTimes, spSites, fRaw);
[spikeSites2, spikeSites3] = findSecondaryPeaks(obj, spikeWindows, spikeSites);
spikeWindows = samplesToWindows2(obj, samplesIn, spikeSites, spikeTimes);
end
end
Original file line number Diff line number Diff line change
@@ -1,36 +1,34 @@
% 12/16/17 JJJ: Find overlapping spikes and set superthreshold sample points to zero in the overlapping region
function [tnWav_spk_out, tnWav_spk2_out] = cancel_overlap_spk_(spikeWindows, spikeWindows2, spikeTimes, spikeSites, spikeSites2, siteThresh, hCfg)
function [spikeWindowsOut, spikeWindows2Out] = cancelOverlap(obj, spikeWindows, spikeWindows2, spikeTimes, spikeSites, spikeSites2, siteThresh)
% Overlap detection. only return one stronger than other
[spikeTimes, spikeWindows, spikeWindows2] = jrclust.utils.tryGather(spikeTimes, spikeWindows, spikeWindows2);
[viSpk_ol_spk, vnDelay_ol_spk] = findPotentialOverlaps(spikeTimes, spikeSites, hCfg);
[tnWav_spk_out, tnWav_spk2_out] = deal(spikeWindows, spikeWindows2);
[viSpk_ol_spk, vnDelay_ol_spk] = findPotentialOverlaps(spikeTimes, spikeSites, obj.hCfg);
[spikeWindowsOut, spikeWindows2Out] = deal(spikeWindows, spikeWindows2);

% find spike index that are larger and fit and deploy
viSpk_ol_a = find(viSpk_ol_spk>0); % later occuring
[viSpk_ol_b, vnDelay_ol_b] = deal(viSpk_ol_spk(viSpk_ol_a), vnDelay_ol_spk(viSpk_ol_a)); % first occuring
viTime_spk0 = int32(hCfg.evtWindowSamp(1):hCfg.evtWindowSamp(2));
viTime_spk0 = int32(obj.hCfg.evtWindowSamp(1):obj.hCfg.evtWindowSamp(2));
siteThresh = jrclust.utils.tryGather(-abs(siteThresh(:))');

% for each pair identify time range where threshold crossing occurs and set to zero
% correct only first occuring (b)
siteNeighbors = hCfg.siteNeighbors;
nSpk_ol = numel(viSpk_ol_a);
nSpk = size(spikeWindows,2);

for iSpk_ol = 1:nSpk_ol
[iSpk_b, nDelay_b] = deal(viSpk_ol_b(iSpk_ol), vnDelay_ol_b(iSpk_ol));
viSite_b = siteNeighbors(:,spikeSites(iSpk_b));
mnWav_b = tnWav_spk_out(nDelay_b+1:end,:,iSpk_b);
viSite_b = obj.hCfg.siteNeighbors(:,spikeSites(iSpk_b));
mnWav_b = spikeWindowsOut(nDelay_b+1:end,:,iSpk_b);
mlWav_b = bsxfun(@le, mnWav_b, siteThresh(viSite_b));
mnWav_b(mlWav_b) = 0;
tnWav_spk_out(nDelay_b+1:end,:,iSpk_b) = mnWav_b;
spikeWindowsOut(nDelay_b+1:end,:,iSpk_b) = mnWav_b;

if ~isempty(spikeWindows2)
viSite_b = siteNeighbors(:,spikeSites2(iSpk_b));
mnWav_b = tnWav_spk2_out(nDelay_b+1:end,:,iSpk_b);
viSite_b = obj.hCfg.siteNeighbors(:,spikeSites2(iSpk_b));
mnWav_b = spikeWindows2Out(nDelay_b+1:end,:,iSpk_b);
mlWav_b = bsxfun(@le, mnWav_b, siteThresh(viSite_b));
mnWav_b(mlWav_b) = 0;
tnWav_spk2_out(nDelay_b+1:end,:,iSpk_b) = mnWav_b;
spikeWindows2Out(nDelay_b+1:end,:,iSpk_b) = mnWav_b;
end
end

Expand Down
11 changes: 11 additions & 0 deletions +jrclust/+detect/@DetectController/computeThreshold.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
function siteThresh = computeThreshold(obj, samplesIn)
%COMPUTETHRESHOLD Compute sitewise threshold for samplesIn
try
siteThresh = jrclust.utils.estimateRMS(samplesIn, 1e5)*obj.hCfg.qqFactor;
siteThresh = int16(jrclust.utils.tryGather(siteThresh));
catch ME
warning('GPU threshold computation failed: %s (retrying in CPU)', ME.message);
obj.hCfg.useGPU = 0;
siteThresh = int16(jrclust.utils.estimateRMS(jrclust.utils.tryGather(samplesIn), 1e5)*obj.hCfg.qqFactor);
end
end
Loading

0 comments on commit 033c331

Please sign in to comment.