I'm trying to couple several quadratic integrate-and-fire (QIF) neurons.
My script works with two neurons, but when I extended it to three neurons, the voltage of the third neuron suddenly blows up and the integration fails.
I did some basic analysis, and looking at the solution array, my guess is that the event detection of scipy.solve_ivp can't detect when two neurons fire at the same time. My reasoning is that the 2nd and 3rd neurons' firing rates should be identical, since the only neuron with an external current is the 1st one.
However, while they both fire together, the event detection only detects one event and therefore fails to reset the voltage of the other one, hence the blow-up.
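To see why a single missed reset is fatal, here is a short worked sketch that ignores the coupling terms (an assumption; in the script they only add to the right-hand side). A QIF voltage diverges in finite time rather than growing exponentially:

```latex
\begin{aligned}
\frac{dV}{dt} = V^2 \quad &\Longrightarrow\quad V(t) = \frac{V_0}{1 - V_0\,(t - t_0)},
\qquad V \to \infty \ \text{as}\ t \to t_0 + \frac{1}{V_0},\\[4pt]
\frac{dV}{dt} = V^2 + I,\ I > 0 \quad &\Longrightarrow\quad
V(t) = \sqrt{I}\,\tan\!\Big(\sqrt{I}\,(t - t_0) + \arctan\frac{V_0}{\sqrt{I}}\Big).
\end{aligned}
```

With the script's threshold V_0 = 2, a neuron that is not reset diverges within about 1/V_0 = 0.5 time units, and with a constant input (as for V1) the blow-up happens when the tangent's argument reaches pi/2, so the solver fails shortly after the missed event.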
My ultimate goal would be to couple this with other types of neurons, but since many of those have intrinsic repolarization dynamics, event handling of QIFs is the crucial part of scaling the network.
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
# Define vectors, indices and parameters
resetV = -0.1
nIN = 3
incIN = nIN
ylen = nIN*(incIN)
indIN = np.arange(0,ylen,incIN)
INs = np.arange(0,nIN)
gI = -0.4
Ileak = 0.5
# Define heaviside function for synaptic gates (just a continuous step function)
def heaviside(v,thresh):
    H = 0.5*(1 +np.tanh((v-thresh)/1e-8))
    return H
# Define event functions and set them as terminal
def event(t, y):
    return y[indIN[0]] - 2
event.terminal = True
def event2(t,y):
    return y[indIN[1]] - 2
event2.terminal = True
def event3(t,y):
    return y[indIN[2]] - 2
event3.terminal = True
#ODE function
def Network(t,y):
    V1 = y[0]
    n11 = y[1]
    n12 = y[2]
    V2 = y[3]
    n21 = y[4]
    n22 = y[5]
    V3 = y[6]
    n31 = y[7]
    n32 = y[8]
    H = heaviside(np.array([V1,V2,V3]),INthresh)
    dydt = [V1*V1 - gI*n11*(V2)- gI*n12*(V3)+0.5,
            H[1]*5*(1-n11) - (0.9*n11),
            H[2]*5*(1-n12) - (0.9*n12),
            V2*V2 -gI*n21*(V1)- gI*n22*(V3),
            H[0]*5*(1-n21) - (0.9*n21),
            H[2]*5*(1-n22) - (0.9*n22),
            V3*V3 -gI*n31*(V1)- gI*n32*(V2),
            H[0]*5*(1-n31) - (0.9*n31),
            H[1]*5*(1-n32) - (0.9*n32)
            ]
    return dydt
# Preallocation of some vectors (mostly not crucial)
INthresh = 0.5
dydt = [0]*ylen
INheavies = np.zeros((nIN,))
preInhVs = np.zeros((nIN,))
y = np.zeros((ylen,))
allt = []
ally = []
t = 0
end = 100
# Integrate until an event is hit, reset the spikes, and use the last time step and y-value to continue integration
while True:
    net = solve_ivp(Network, (t, end), y, events= [event,event2,event3])
    allt.append(net.t)
    ally.append(net.y)
    if net.status == 1:
        t = net.t[-1]
        y = net.y[:, -1].copy()
        for i in INs:
            if net.t_events[i].size != 0:
                y[indIN[i]] = resetV
                print('resetting V%d' %(i+1))
    elif net.status == -1:
        print('failed!')
        print(y[0])
        break
    else:
        break
# Putting things together and plotting
Tp = np.concatenate(allt)
Yp = np.concatenate(ally, axis=1)
fig = plt.figure(facecolor='w', edgecolor='k')
ax1 = fig.add_subplot(311)
ax2 = fig.add_subplot(312)
ax3 = fig.add_subplot(313)
ax1.plot(Tp, Yp[0].T)
ax2.plot(Tp, Yp[3].T)
ax3.plot(Tp, Yp[6].T)
plt.subplots_adjust(hspace=0.8)
plt.show()
Of course this is only a guess.
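One way to test the guess in isolation is a hypothetical minimal example that is not part of the script above: two uncoupled, identical QIF neurons (dV/dt = V^2 + I, with I_EXT, V_PEAK, qif_pair and make_spike_event as illustrative names) that each get their own terminal event. Because both reach the peak at the same instant, if the guess is right the run stops with status 1 but only one of the two t_events arrays is non-empty.

```python
import numpy as np
from scipy.integrate import solve_ivp

# Hypothetical minimal test: two *uncoupled* QIF neurons, dV/dt = V^2 + I,
# with identical input and initial voltage, so both reach V_PEAK together.
I_EXT = 0.5
V_PEAK = 2.0


def qif_pair(t, y):
    # y is a NumPy array here; one QIF equation per neuron, element-wise.
    return y * y + I_EXT


def make_spike_event(i):
    def spike(t, y):
        return y[i] - V_PEAK      # root when neuron i reaches the peak
    spike.terminal = True         # stop the integration at the spike
    return spike


events = [make_spike_event(i) for i in range(2)]
sol = solve_ivp(qif_pair, (0.0, 10.0), [-0.1, -0.1], events=events)

print('status:', sol.status)
print('events recorded per neuron:', [te.size for te in sol.t_events])
```

Both neurons reach V = 2 at roughly t = 1.94, well before the end of the interval, so the integration stops on the first spike; counting the recorded events shows whether the second, simultaneous crossing was reported.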
I'm currently learning to work with PyDSTool, but because of deadlines I'd like to get this script working; even a quick-and-dirty implementation of a QIF neural network would do for my preliminary analysis.
I'm a biology student and only know a bit of Python and MATLAB, but I'd appreciate any input.
You are indeed correct: solve_ivp does not report additional events that happen at the same time (when several terminal events are found in the same step, it keeps only the earliest root and stops there). Outside of situations like this one, where a component is effectively duplicated, it is highly unlikely for two events to coincide exactly in a numerical simulation. You can test this manually, since an event is a root of the event function. So set
def gen_event(i):
    def event(t, y):
        return y[indIN[i]] - 2
    event.terminal = True
    return event
events = [gen_event(i) for i in range(3)]
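The factory function is more than a style choice. A lambda created inside a list comprehension captures the loop variable itself, not its value at that iteration, so every event would end up testing the last neuron. The sketch below contrasts the two; to stay self-contained it indexes y[i] directly instead of y[indIN[i]] (a simplification of the code above).

```python
# Late-binding pitfall: every lambda reads the comprehension's final i (= 2).
naive = [lambda t, y: y[i] - 2 for i in range(3)]


# Factory pattern as in gen_event: each call has its own scope holding i.
def gen_event(i):
    def event(t, y):
        return y[i] - 2
    event.terminal = True
    return event


factory = [gen_event(i) for i in range(3)]

y = [2.0, 5.0, 9.0]
naive_vals = [f(0.0, y) for f in naive]
factory_vals = [f(0.0, y) for f in factory]
print(naive_vals)    # [7.0, 7.0, 7.0]
print(factory_vals)  # [0.0, 3.0, 7.0]
```

The factory also gives a natural place to set the terminal attribute once per event instead of repeating it for every hand-written function.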
pass this list to solve_ivp via events=events, and replace the test for which function triggered an event with
t = net.t[-1]
y = net.y[:, -1].copy()
for i in INs:
    if abs(events[i](t, y)) < 1e-12:
        y[indIN[i]] = resetV
        print(f'resetting V{i+1} at time {net.t_events[i]}')
This then also captures the simultaneous events, so both neurons are reset, and the script produces the plots.
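For reference, a self-contained sketch of the whole loop with this fix, assuming the question's model and parameters. A few things are my own choices rather than part of the answer: the helper name simulate and the constant V_PEAK are hypothetical, the end time is shortened to 30, plotting is left out, and the reset keeps the original t_events check alongside the 1e-12 tolerance test so the neuron that solve_ivp did report is always reset.

```python
import numpy as np
from scipy.integrate import solve_ivp

# Model and parameters as in the question's script.
resetV = -0.1
nIN = 3
incIN = nIN                                  # entries per neuron: V and two gates
indIN = np.arange(0, nIN * incIN, incIN)     # positions of V1, V2, V3 in y
gI = -0.4
INthresh = 0.5
V_PEAK = 2.0                                 # hypothetical name for the spike threshold


def heaviside(v, thresh):
    # Very steep tanh used as a continuous step for the synaptic gates.
    return 0.5 * (1 + np.tanh((v - thresh) / 1e-8))


def Network(t, y):
    V1, n11, n12, V2, n21, n22, V3, n31, n32 = y
    H = heaviside(np.array([V1, V2, V3]), INthresh)
    return [V1*V1 - gI*n11*V2 - gI*n12*V3 + 0.5,
            H[1]*5*(1 - n11) - 0.9*n11,
            H[2]*5*(1 - n12) - 0.9*n12,
            V2*V2 - gI*n21*V1 - gI*n22*V3,
            H[0]*5*(1 - n21) - 0.9*n21,
            H[2]*5*(1 - n22) - 0.9*n22,
            V3*V3 - gI*n31*V1 - gI*n32*V2,
            H[0]*5*(1 - n31) - 0.9*n31,
            H[1]*5*(1 - n32) - 0.9*n32]


def gen_event(i):
    def event(t, y):
        return y[indIN[i]] - V_PEAK
    event.terminal = True
    return event


events = [gen_event(i) for i in range(nIN)]


def simulate(end=30.0, tol=1e-12):
    """Integrate between spikes; return times, states, reset counts and final status."""
    t, y = 0.0, np.zeros(nIN * incIN)
    allt, ally = [], []
    resets = np.zeros(nIN, dtype=int)
    while True:
        net = solve_ivp(Network, (t, end), y, events=events)
        allt.append(net.t)
        ally.append(net.y)
        if net.status != 1:   # 0: reached `end`, -1: integration failed
            return (np.concatenate(allt), np.concatenate(ally, axis=1),
                    resets, net.status)
        t = net.t[-1]
        y = net.y[:, -1].copy()
        for i in range(nIN):
            # Reset the neuron solve_ivp reported, plus any other neuron
            # sitting on the threshold at the same instant.
            if net.t_events[i].size > 0 or abs(events[i](t, y)) < tol:
                y[indIN[i]] = resetV
                resets[i] += 1


Tp, Yp, resets, status = simulate()
print('status:', status, 'resets per neuron:', resets)
```

Tp and Yp have the same layout as in the question, so the original plotting lines (Yp[0], Yp[3], Yp[6]) can be reused unchanged.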