#!/usr/bin/env python3
import os
import sys
import json
import timeit
import numpy as np
from snudda.neurons.neuron_prototype import NeuronPrototype
from snudda.utils.numpy_encoder import NumpyEncoder
import scipy.sparse as sparse
from scipy.spatial import distance_matrix
class SnuddaLoad:
    """
    Load data from a Snudda network HDF5 file (neuron position file or network synapse file)
    into a python dictionary, available as ``self.data`` after construction.
    """

    ############################################################################

    def __init__(self, network_file, snudda_data=None, load_synapses=True, verbose=False):
        """
        Constructor

        Args:
            network_file (str) : Data file to load. If a directory is given, the file
                                 "network-synapses.hdf5" inside it is loaded. If falsy,
                                 nothing is loaded and self.data is None.
            snudda_data (str, optional) : Snudda Data path, if you want to override the one specified in the hdf5 file
            load_synapses (bool, optional) : Whether to read synapses into memory, or keep them on disk (this keeps file open)
            verbose (bool, optional) : Print more info during execution

        Raises:
            ValueError : If network_file is a directory that lacks network-synapses.hdf5
        """

        # This variable will only be set if the synapses are not kept in
        # memory so that we can access them later, otherwise the hdf5 file is
        # automatically closed
        self.hdf5_file = None

        self.verbose = verbose
        self.config = None
        self.network_file = None
        self.snudda_data = snudda_data

        if network_file:
            if os.path.isdir(network_file):
                alt_file = os.path.join(network_file, "network-synapses.hdf5")
                if not os.path.isfile(alt_file):
                    raise ValueError(f"Network path {network_file} specified, but no file {alt_file}")
                network_file = alt_file

            self.data = self.load_hdf5(network_file, load_synapses)
        else:
            self.data = None

    ############################################################################

    def close(self):
        """ Close hdf5 data file (it is only kept open when synapses were not loaded into memory). """
        if self.hdf5_file:
            self.hdf5_file.close()
            self.hdf5_file = None

    ############################################################################

    def __del__(self):
        # getattr: the object may be only partially constructed (e.g. __init__ raised,
        # or the instance was created with __new__), in which case hdf5_file is missing
        hdf5_file = getattr(self, "hdf5_file", None)
        if hdf5_file is not None:
            try:
                hdf5_file.close()
            except Exception:
                print("Unable to close HDF5, already closed?")

    ############################################################################

    @staticmethod
    def to_str(data_str):
        """
        Helper function to convert data to string.

        Args:
            data_str : bytes (including numpy.bytes_), str or None

        Returns:
            Decoded str, the str itself, or None if data_str is None.
        """
        if data_str is None:
            return None

        # numpy.bytes_ is a subclass of bytes, so this covers both
        if isinstance(data_str, bytes):
            return data_str.decode()

        # Warn the user if they accidentally call to_str on an int or something else
        assert isinstance(data_str, str), f"to_str is used on strings or bytes (that are converted to str), " \
                                          f"received {type(data_str)} -- {data_str}"

        return data_str

    ############################################################################

    def load_hdf5(self, network_file, load_synapses=True, load_morph=False):
        """
        Load data from hdf5 file.

        Args:
            network_file (str) : Network file to load data from
            load_synapses (bool) : Load synapses into memory, or read on demand from file (keeps file open).
                                   Forced to False if the file has more than 100e6 synapses.
            load_morph (bool) : Must be False, loading morphologies is currently disabled

        Returns:
            data (dictionary) : Dictionary with data.

        Raises:
            ValueError : If the file was written by an old Snudda version (uses "neuronID")

        The HDF5 file is organised in the groups "meta", "morphologies" and "network".
        The returned dictionary is flat: the entries below are stored directly as data[key].

            From "meta":
                "slurm_id" (int) : Run ID of Slurm job (-1 if not set)
                "axon_stump_id_flag" (bool) : Should axon be replaced by a axon stump
                "config" : Config data (if stored in the file)
                "config_file" : Config file
                "connectivity_distributions" : Information about connections allowed, including pruning information,
                                               keyed by (pre_type, post_type)
                "hyper_voxel_size" (int, int, int): Number of hyper voxels along each dimension x, y, z
                "hyper_voxel_width" (float, float, float) : Size of hypervoxel in meters (x,y,z)
                "num_hyper_voxels" (int) : Number of hyper voxels
                "position_file" : Neuron position file
                "simulation_origo" (float, float, float) : Simulation origo (x, y, z) in SI-units (m)
                "voxel_size" (float) : Voxel size in meters
                "snudda_data" (str) : Snudda data path stored in the file

            From "network":
                "hyper_voxel_ids" : List of all hyper voxels IDs
                "gap_junctions" : Gap junction data matrix (see below for format)
                "num_gap_junctions" (int) : Number of gap junctions
                "num_synapses" (int) : Number of synapses
                "neurons" : Neuron data structure (see below for format)
                "synapses" : Synapse data matrix (see below for format)
                "synapse_coords" : (num_synapses, 3) synapse coordinates in meters (only if synapses are loaded)

            Neuron data format:
                List of per-neuron dictionaries (built by extract_neurons). Keys used by this class include
                "neuron_id", "name", "type", "position", "rotation", "parameter_key", "morphology_key",
                "modulation_key", "neuron_path", "virtual_neuron", "volume_id" and "population_unit".

            Synapse data format (column within parenthesis):
                0: source_cell_id, 1: dest_cell_id, 2: voxel_x, 3: voxel_y, 4: voxel_z,
                5: hyper_voxel_id, 6: channel_model_id,
                7: source_axon_soma_dist (not SI scaled 1e6, micrometers),
                8: dest_dend_soma_dist (not SI scaled 1e6, micrometers)
                9: dest_seg_id, 10: dest_seg_x (int 0 - 1000, SONATA wants float 0.0-1.0)
                11: conductance (int, not SI scaled 1e12, in pS)
                12: parameterID

                Note on parameterID:
                If there are n parameter sets for the particular synapse type, then
                the ID to use is parameterID % n, this way we can reuse connectivity
                if we add more synapse parameter sets later.

            Gap junction format (column in parentheses):
                0: source_cell_id, 1: dest_cell_id, 2: source_seg_id, 3: dest_seg_id,
                4: source_seg_x, 5: dest_seg_x, 6: voxel_x, 7: voxel_y, 8: voxel_z,
                9: hyper_voxel_id, 10: conductance (integer, in pS)
        """

        assert not load_morph, "load_hdf5: load_morph=True currently disabled, does not handle morph variations"

        self.network_file = network_file

        if self.verbose:
            print(f"Loading {network_file}")

        start_time = timeit.default_timer()
        data = {}

        # Save a reference to the name of the loaded network file
        data["network_file"] = self.network_file

        # Blender notebook has hdf5 library/header file mismatch, so importing this only where needed
        # This allows us to use fake_load.py in snudda.utils
        import h5py

        # We need to keep f open if load_synapses = False, using "with" would close file
        f = h5py.File(network_file, 'r')

        try:
            load_synapses = self._load_hdf5_contents(f, data, load_synapses, load_morph)
        except BaseException:
            # Do not leak the file handle when the file cannot be parsed
            f.close()
            raise

        if self.verbose:
            print(f"Load done. {timeit.default_timer() - start_time:.1f}")

        if load_synapses:
            f.close()
        else:
            # Synapses are read on demand from disk, so the file must stay open
            self.hdf5_file = f

        return data

    def _load_hdf5_contents(self, f, data, load_synapses, load_morph):
        """ Fill data from the open HDF5 file f. Returns the (possibly downgraded) load_synapses flag. """

        if "config" in f["meta"]:
            if self.verbose:
                print("Loading config data from HDF5")
            self.config = json.loads(f["meta/config"][()])
            data["config"] = self.config

        # Added so this code can also load the position file, which
        # does not have the network group yet
        if "network/synapses" in f:
            load_synapses = self._load_network_group(f, data, load_synapses)
        else:
            data["num_neurons"] = f["network/neurons/neuron_id"].shape[0]
            assert data["num_neurons"] == f["network/neurons/neuron_id"][-1] + 1, \
                "Internal error, something fishy with number of neurons found"

        self._load_meta(f, data)

        # NOTE(review): extract_neurons is not defined in the visible part of this class -- confirm it exists
        data["neurons"] = self.extract_neurons(f)

        # This is for old format, update for new format
        if "parameters" in f:
            data["synapse_range"] = f["parameters/synapse_range"][()]
            data["gap_junction_range"] = f["parameters/gap_junction_range"][()]
            data["min_synapse_spacing"] = f["parameters/min_synapse_spacing"][()]

        data["neuron_positions"] = f["network/neurons/position"][()]
        data["name"] = [SnuddaLoad.to_str(x) for x in f["network/neurons/name"][()]]

        if "population_unit_id" in f["network/neurons"]:
            data["population_unit"] = f["network/neurons/population_unit_id"][()]
        else:
            if self.verbose:
                print("No Population Units detected.")
            data["population_unit"] = np.zeros(data["num_neurons"], dtype=int)

        # TODO: Remove this, or make it able to handle multiple morphologies for each neuron_name,
        #  ie when morphologies is given as a dir. Currently unreachable, load_hdf5 asserts load_morph is False.
        if load_morph and "morphologies" in f:
            data["morph"] = {}
            for morph_name in f["morphologies"].keys():
                data["morph"][morph_name] = {"swc": f["morphologies"][morph_name]["swc"][()],
                                             "location": f["morphologies"][morph_name]["location"][()]}

        data["connectivity_distributions"] = {}

        if "connectivity_distributions" in f["meta"]:
            orig_connectivity_distributions = \
                json.loads(SnuddaLoad.to_str(f["meta/connectivity_distributions"][()]))

            # Keys are stored as "pre_type$$post_type" in the file
            for keys, value in orig_connectivity_distributions.items():
                (pre_type, post_type) = keys.split("$$")
                data["connectivity_distributions"][pre_type, post_type] = value

        if "synapses" in data and self.verbose:
            if "gap_junctions" in data:
                print(f"Loading {len(data['neurons'])} neurons with {data['num_synapses']} synapses"
                      f" and {data['num_gap_junctions']} gap junctions")
            else:
                print(f"Loading {len(data['neurons'])} neurons with {data['synapses'].shape[0]} synapses")

        return load_synapses

    @staticmethod
    def _read_row_count(f, count_name, dataset_name, plural, singular):
        """ Return the row count of network/<dataset_name>, cross-checked against network/<count_name> if stored. """
        num_rows = f[f"network/{dataset_name}"].shape[0]

        if count_name not in f["network"]:
            return num_rows

        count_data = f[f"network/{count_name}"]
        count = count_data[()] if count_data.shape == () else count_data[0]

        if count != num_rows:
            print(f"Expected {count} {plural}, found {num_rows} {singular} rows")
            count = num_rows

        return count

    def _load_network_group(self, f, data, load_synapses):
        """ Read neuron ids, synapses and gap junctions from the "network" group. Returns updated load_synapses. """

        try:
            data["num_neurons"] = f["network/neurons/neuron_id"].shape[0]
            data["neuron_id"] = f["network/neurons/neuron_id"][()]
        except KeyError as err:
            if "neuronID" in f["network/neurons"]:
                raise ValueError(f"Old version of the network detected! "
                                 f"Please regenerate your Snudda network: {self.network_file}") from err
            raise

        data["num_synapses"] = self._read_row_count(f, "num_synapses", "synapses", "synapses", "synapse")
        data["num_gap_junctions"] = self._read_row_count(f, "num_gap_junctions", "gap_junctions",
                                                         "gap junctions", "gap junction")

        if data["num_synapses"] > 100e6:
            print(f"Found {data['num_synapses']} synapses (too many!), not loading them into memory!")
            load_synapses = False

        for key, path in (("hyper_voxel_ids", "network/hyper_voxel_ids"),
                          ("hyper_voxel_size", "meta/hyper_voxel_size"),
                          ("hyper_voxel_width", "meta/hyper_voxel_width"),
                          ("num_hyper_voxels", "meta/num_hyper_voxels")):
            if path in f:
                data[key] = f[path][()]

        if load_synapses:
            # See load_hdf5 docstring for the synapse and gap junction column layout
            data["synapses"] = f["network/synapses"][:]
            data["gap_junctions"] = f["network/gap_junctions"][:]

            # Convert from voxel idx to coordinates, reusing the in-memory copy of the synapses
            if data["synapses"].shape[0] > 0:
                data["synapse_coords"] = data["synapses"][:, 2:5] * f["meta/voxel_size"][()] \
                                         + f["meta/simulation_origo"][()]
            else:
                # Same (rows, xyz) layout as the non-empty case
                data["synapse_coords"] = np.zeros((0, 3))
        else:
            # Point the data structure to the synapses and gap junctions on file
            # This will be slower, and only work while the file is open
            data["synapses"] = f["network/synapses"]
            data["gap_junctions"] = f["network/gap_junctions"]

        return load_synapses

    def _load_meta(self, f, data):
        """ Read the scalar entries of the "meta" group into data (and default self.snudda_data). """

        data["config_file"] = SnuddaLoad.to_str(f["meta/config_file"][()])

        if "meta/position_file" in f:
            data["position_file"] = SnuddaLoad.to_str(f["meta/position_file"][()])

        if "meta/slurm_id" in f:
            slurm_id = f["meta/slurm_id"][()]
            if isinstance(slurm_id, bytes):
                slurm_id = slurm_id.decode()
            data["slurm_id"] = int(slurm_id)
        else:
            if self.verbose:
                print("No slurm_id set, using -1")
            data["slurm_id"] = -1

        for key in ("simulation_origo", "voxel_size", "axon_stump_id_flag"):
            if f"meta/{key}" in f:
                data[key] = f[f"meta/{key}"][()]

        if "meta/snudda_data" in f:
            data["snudda_data"] = SnuddaLoad.to_str(f["meta/snudda_data"][()])

            # Only use the file's path if the user did not override it in the constructor
            if self.snudda_data is None:
                self.snudda_data = data["snudda_data"]

    ############################################################################

    @staticmethod
    def gather_extra_axons(hdf5_file):
        """
        Collect the extra axons stored under "network/neurons/extra_axons".

        Args:
            hdf5_file : Open HDF5 file

        Returns:
            dict : parent neuron id -> axon name -> {"position", "rotation" (3x3), "morphology" (str)}.
                   Empty if the file has no extra axons.
        """
        extra_axons = dict()

        if "extra_axons" in hdf5_file["network/neurons"]:
            for neuron_id, axon_name, position, rotation, swc_file \
                    in zip(hdf5_file["network/neurons/extra_axons/parent_neuron"],
                           hdf5_file["network/neurons/extra_axons/name"],
                           hdf5_file["network/neurons/extra_axons/position"],
                           hdf5_file["network/neurons/extra_axons/rotation"],
                           hdf5_file["network/neurons/extra_axons/morphology"]):
                extra_axons.setdefault(neuron_id, dict())[axon_name] = {
                    "position": position.copy(),
                    "rotation": rotation.copy().reshape(3, 3),
                    "morphology": SnuddaLoad.to_str(swc_file)}

        return extra_axons

    ############################################################################

    def load_config_file(self):
        """ Load config data from the JSON file named in data["config_file"], unless config is already loaded. """
        if self.config is None:
            config_file = self.data["config_file"]
            with open(config_file, 'r') as f:
                self.config = json.load(f)

    ############################################################################

    def synapse_iterator(self, chunk_size=1000000, data_type=None):
        """
        Iterates through all synapses in chunks (default 1e6 synapses).

        Args:
            chunk_size (int) : Maximum number of synapses per chunk
            data_type (string) : "synapses" (default) or "gap_junctions"

        Returns:
            Iterator over the synapses, each chunk a contiguous block of rows
        """

        if data_type is None:
            data_type = "synapses"

        # data_type is "synapses" or "gap_junctions"
        assert data_type in ["synapses", "gap_junctions"]

        num_rows = self.data[data_type].shape[0]
        if num_rows == 0:
            # No synapses
            return

        chunk_size = min(num_rows, chunk_size)
        num_steps = int(np.ceil(num_rows / chunk_size))

        # Evenly spaced chunk boundaries, the last one is exactly num_rows
        row_start = 0
        for row_end in np.linspace(chunk_size, num_rows, num_steps, dtype=int):
            synapses = self.data[data_type][row_start:row_end, :]
            row_start = row_end

            yield synapses

    ############################################################################

    def gap_junction_iterator(self, chunk_size=1000000):
        """
        Iterates through all gap junctions in chunks (default 1e6 gap junctions).

        Args:
            chunk_size (int) : Maximum number of gap junctions per chunk

        Returns:
            Iterator over the gap junctions
        """
        return self.synapse_iterator(chunk_size=chunk_size, data_type="gap_junctions")

    ############################################################################

    # Helper methods for sorting. Cast to python int so post_id * num_neurons
    # cannot overflow the int32 synapse matrix dtype on large networks.

    @staticmethod
    def _row_eval_post_pre(row, num_neurons):
        return int(row[1]) * num_neurons + int(row[0])

    @staticmethod
    def _row_eval_post(row, num_neurons):
        return int(row[1])

    ############################################################################

    def _find_rows_involving(self, data_type, neuron_id, columns, n_max, num_columns):
        """ Return (rows, row_index) of data_type rows where any of `columns` holds an id in neuron_id (max n_max rows). """

        if np.issubdtype(type(neuron_id), np.integer):
            wanted_id = np.array([neuron_id])
        else:
            wanted_id = np.array(list(neuron_id))

        row_parts = []
        index_parts = []
        num_found = 0
        row_offset = 0

        for chunk in self.synapse_iterator(data_type=data_type):
            if num_found >= n_max:
                break

            mask = np.zeros(chunk.shape[0], dtype=bool)
            for col in columns:
                mask |= np.isin(chunk[:, col], wanted_id)

            hit_idx = np.flatnonzero(mask)[:n_max - num_found]
            row_parts.append(chunk[hit_idx, :])
            index_parts.append(hit_idx + row_offset)

            num_found += hit_idx.size
            row_offset += chunk.shape[0]

        if row_parts:
            rows = np.concatenate(row_parts).astype(np.int32)
            row_index = np.concatenate(index_parts).astype(int)
        else:
            rows = np.zeros((0, num_columns), dtype=np.int32)
            row_index = np.zeros((0,), dtype=int)

        return rows, row_index

    def find_synapses_slow(self, pre_id, n_max=1000000):
        """
        Returns subset of synapses, by scanning the whole synapse matrix.

        Args:
            pre_id (int) : Pre-synaptic neuron ID (can also be a list)
            n_max (int) : Maximum number of synapses to return (further matches are ignored)

        Returns:
            Subset of synapse matrix, synapse coordinates
        """

        if self.verbose:
            print(f"Finding synapses originating from {pre_id}, this is slow")

        synapses, _ = self._find_rows_involving("synapses", pre_id, columns=(0,), n_max=n_max, num_columns=13)
        synapse_coords = synapses[:, 2:5] * self.data["voxel_size"] + self.data["simulation_origo"]

        return synapses, synapse_coords

    ############################################################################

    # Either give preID and postID, or just postID

    def find_synapses(self, pre_id=None, post_id=None, silent=True, return_index=False):
        """
        Returns subset of synapses.

        Args:
            pre_id (int) : Pre-synaptic neuron ID
            post_id (int) : Post-synaptic neuron ID (if None, falls back to find_synapses_slow on pre_id)
            silent (bool) : Work quietly or verbosely
            return_index (bool) : Also return the row indexes of the synapses (requires post_id)

        Returns:
            Subset of synapse matrix, synapse coordinates (and row index if return_index).
            None for each value if no synapses are found.
        """

        synapse_data = self.data["synapses"]
        no_hit = (None, None, None) if return_index else (None, None)

        if synapse_data.shape[0] == 0:
            if not silent:
                print("No synapses in network")
            return no_hit

        if post_id is None:
            assert return_index is False, "You must specify pre_id and post_id if return_index is True"
            return self.find_synapses_slow(pre_id=pre_id)

        num_rows = synapse_data.shape[0]
        num_neurons = len(self.data["neuron_id"])

        if pre_id is None:
            row_eval = self._row_eval_post
            val_target = int(post_id)
        else:
            row_eval = self._row_eval_post_pre
            val_target = int(post_id) * num_neurons + int(pre_id)

        def key_at(row_idx):
            return row_eval(synapse_data[row_idx, :], num_neurons)

        # Binary search: this works because the matrix is sorted on postID,
        # and then preID if postID matches. Once one matching row is found we
        # walk up and down to find the full range of matching rows.
        idx_lo = 0
        idx_hi = num_rows - 1
        idx_found = None

        if key_at(idx_lo) == val_target:
            idx_found = idx_lo
        elif key_at(idx_hi) == val_target:
            idx_found = idx_hi

        # If idx_lo and idx_hi are one apart, we have checked all values
        while idx_found is None and idx_lo < idx_hi - 1:
            idx_mid = (idx_lo + idx_hi) // 2
            val_mid = key_at(idx_mid)

            if val_mid < val_target:
                idx_lo = idx_mid
            elif val_mid > val_target:
                idx_hi = idx_mid
            else:
                idx_found = idx_mid

        if idx_found is None:
            # No synapses found
            if self.verbose:
                print("No synapses found")
            return no_hit

        # Find start of synapse range, never reading outside the matrix
        idx_first = idx_found
        while idx_first > 0 and key_at(idx_first - 1) == val_target:
            idx_first -= 1

        # Find end of synapse range
        idx_last = idx_found
        while idx_last + 1 < num_rows and key_at(idx_last + 1) == val_target:
            idx_last += 1

        synapses = synapse_data[idx_first:idx_last + 1, :].copy()

        if not silent and self.verbose:
            print(f"Synapse range, first {idx_first}, last {idx_last}")
            print(f"{synapses}")

        # Calculate coordinates
        synapse_coords = synapses[:, 2:5] * self.data["voxel_size"] + self.data["simulation_origo"]

        if return_index:
            return synapses, synapse_coords, np.arange(idx_first, idx_last + 1)

        return synapses, synapse_coords

    ############################################################################

    def get_neuron_population_units(self, neuron_id=None, return_set=False):
        """ Population unit of the given neurons (all if neuron_id is None), as array or set. """
        if neuron_id is not None:
            neuron_population_units = self.data["population_unit"][neuron_id].flatten().copy()
        else:
            neuron_population_units = self.data["population_unit"].flatten().copy()

        if return_set:
            return set(neuron_population_units)

        return neuron_population_units

    def get_neuron_types(self, neuron_id=None, return_set=False):
        """ Type of the given neurons (all if neuron_id is None), as list or set. """
        if neuron_id is not None:
            neuron_types = [self.data["neurons"][x]["type"] for x in neuron_id]
        else:
            neuron_types = [x["type"] for x in self.data["neurons"]]

        if return_set:
            return set(neuron_types)

        return neuron_types

    ############################################################################

    # Returns neuron_id of all neurons of neuron_type
    # OBS, random_permute is not using a controlled rng, so not affected by random seed set

    def get_neuron_id_of_type(self, neuron_type, num_neurons=None, random_permute=False, volume=None,
                              include_virtual=True, population_unit_id=None):
        """
        Find all neuron ID of a specific neuron type.

        Args:
            neuron_type (string) : Neuron type (e.g. "FS"), None = all types
            num_neurons (int) : Maximum number of neurons to return
            random_permute (bool) : Shuffle the resulting neuron IDs? (requires num_neurons)
            volume (string) : volume_id containing neurons (default None -- all neurons of type)
            include_virtual (bool) : Include virtual neurons
            population_unit_id (int) : Only neurons in this population unit (None = all)

        Returns:
            Integer array of neuron ID of specified neuron type (empty if none match)
        """

        # dtype=int so an empty result can still be used as an index array
        neuron_id = np.array([x["neuron_id"] for x in self.data["neurons"]
                              if (neuron_type is None or x["type"] == neuron_type)
                              and (volume is None or x["volume_id"] == volume)
                              and (include_virtual or not x["virtual_neuron"])
                              and (population_unit_id is None or x["population_unit"] == population_unit_id)],
                             dtype=int)

        assert not random_permute or num_neurons is not None, "random_permute is only valid when num_neurons is given"

        if num_neurons is not None:
            if random_permute:
                # Do not use this if you have a simulation with multiple
                # workers... they might randomize differently, and you might
                # fewer neurons in total than you wanted
                keep_idx = np.random.permutation(len(neuron_id))[:num_neurons]
                neuron_id = neuron_id[keep_idx]
            else:
                neuron_id = neuron_id[:num_neurons]

            if len(neuron_id) < num_neurons:
                if self.verbose:
                    print(f"get_neuron_id_of_type: wanted {num_neurons} only got {len(neuron_id)} "
                          f"neurons of type {neuron_type}")

        # Double check that all of the same type (or neuron_type is None)
        assert neuron_type is None or np.array([self.data["neurons"][x]["type"] == neuron_type for x in neuron_id]).all()
        assert volume is None or np.array([self.data["neurons"][x]["volume_id"] == volume for x in neuron_id]).all()

        return neuron_id

    def get_neuron_id(self, include_virtual=True):
        """ Returns all neuron_id, if include_virtual is set (default) virtual neurons are also included."""
        return np.array([x["neuron_id"] for x in self.data["neurons"]
                         if include_virtual or not x["virtual_neuron"]], dtype=int)

    def get_neuron_id_with_name(self, neuron_name, include_virtual=True):
        """
        Find neuron ID of neurons with a given name.

        Args:
            neuron_name (str): Name of neurons (e.g. "dSPN_0")
            include_virtual (bool): Should virtual neurons also be included?

        Returns:
            Integer array of neuron ID
        """

        return np.array([x["neuron_id"] for x in self.data["neurons"]
                         if x["name"] == neuron_name and (include_virtual or not x["virtual_neuron"])], dtype=int)

    def get_population_unit_members(self, population_unit, num_neurons=None, random_permute=False):
        """
        Returns neuron ID of neurons belonging to a specific population unit.

        Args:
            population_unit (int) : Population unit ID
            num_neurons (int) : Number of neurons to return (None = all hits)
            random_permute (bool) : Randomly shuffle neuron IDs?

        Returns:
            List of neuron ID belonging to population unit.
        """

        neuron_id = np.where(self.data["population_unit"] == population_unit)[0]

        if num_neurons:
            if random_permute:
                neuron_id = np.random.permutation(neuron_id)

            if len(neuron_id) > num_neurons:
                neuron_id = neuron_id[:num_neurons]

        # Just double check
        assert (self.data["population_unit"][neuron_id] == population_unit).all()

        return neuron_id

    def load_neuron(self, neuron_id):
        """
        Loads a specific neuron. Returns a NeuronMorphology object.

        Uses self.snudda_data, i.e. the constructor override if given, otherwise the path stored in the file.

        Args:
            neuron_id (int): Neuron ID

        Returns:
            NeuronMorphology object.
        """

        neuron_info = self.data["neurons"][neuron_id]

        neuron_prototype = NeuronPrototype(neuron_path=neuron_info["neuron_path"],
                                           snudda_data=self.snudda_data,
                                           neuron_name=neuron_info["name"])

        neuron_object = neuron_prototype.clone(parameter_key=neuron_info["parameter_key"],
                                               morphology_key=neuron_info["morphology_key"],
                                               modulation_key=neuron_info["modulation_key"],
                                               position=neuron_info["position"],
                                               rotation=neuron_info["rotation"])

        return neuron_object

    def iter_neuron_id(self):
        """ Iterate over neuron ids, checking that each neuron is stored at the index equal to its id. """
        for x, v in enumerate(self.data["neurons"]):
            assert x == v["neuron_id"], \
                f"Neuron at position {x} has neuron_id {v['neuron_id']} (should be same)"
            yield x

    def get_neuron_keys(self, neuron_id):
        """ Return (parameter_key, morphology_key, modulation_key) for neuron_id. """
        n = self.data["neurons"][neuron_id]
        return n["parameter_key"], n["morphology_key"], n["modulation_key"]

    def get_neuron_params(self, neuron_id):
        """
        Read the parameter set (and modulation set, if any) used by neuron_id.

        Returns:
            (param_data, mod_data), mod_data is None if the neuron has no modulation_key
        """
        neuron_info = self.data["neurons"][neuron_id]
        neuron_path = neuron_info["neuron_path"]
        parameter_key = neuron_info["parameter_key"]

        parameter_file = os.path.join(neuron_path, "parameters.json")
        with open(parameter_file, "r") as f:
            parameter_data = json.load(f)

        param_data = parameter_data[parameter_key]

        modulation_key = neuron_info.get("modulation_key")
        if modulation_key is not None:
            modulation_file = os.path.join(neuron_path, "modulation.json")
            with open(modulation_file, "r") as f:
                modulation_data = json.load(f)

            mod_data = modulation_data[modulation_key]
        else:
            mod_data = None

        return param_data, mod_data

    def find_gap_junctions(self, neuron_id, n_max=1000000, return_index=False):
        """ Find gap junctions associated with neuron_id

        Args:
            neuron_id (int) : Neuron with gap junction (can also be a list)
            n_max (int) : Maximum number of gap junctions to return (further matches are ignored)
            return_index (bool): Should third return value index be present

        Returns:
            Subset of gap junction matrix, gap junction coordinates (and row index in the gap junction matrix)
        """

        if self.verbose:
            print(f"Finding gap junctions connecting neuron {neuron_id}")

        # A gap junction matches if the neuron is on either side (column 0 or 1)
        gap_junctions, gj_index = self._find_rows_involving("gap_junctions", neuron_id, columns=(0, 1),
                                                            n_max=n_max, num_columns=11)

        gj_coords = gap_junctions[:, 6:9] * self.data["voxel_size"] + self.data["simulation_origo"]

        if return_index:
            return gap_junctions, gj_coords, gj_index

        return gap_junctions, gj_coords

    def get_centre_neurons_iterator(self, n_neurons=None, neuron_type=None, neuron_name=None,
                                    centre_point=None, max_distance=None,
                                    return_distance=True, include_virtual=False):
        """ Return neuron id:s, starting from the centre most and moving outwards

        Args:
            n_neurons (int) : Number of neurons to return, None = all available
            neuron_type (str) : Type of neurons to return, None = all available
            neuron_name (str) : Name of neurons to return (e.g. "dSPN_0"), None = all available
            centre_point (np.array) : x,y,z of centre position, None = mean of all neuron positions
            max_distance (float) : Stop once neurons are further than this from centre_point, None = no limit
            return_distance (bool) : Yield (neuron_id, distance) instead of only neuron_id
            include_virtual (bool) : Include virtual neurons
        """

        if centre_point is None:
            centre_point = np.mean(self.data["neuron_positions"], axis=0)

        dist_to_centre = np.linalg.norm(self.data["neuron_positions"] - centre_point, axis=-1)
        idx = np.argsort(dist_to_centre)

        neuron_ctr = 0

        for neuron_id in idx:
            neuron_info = self.data["neurons"][neuron_id]

            if not include_virtual and neuron_info["virtual_neuron"]:
                # Ignore virtual neurons
                continue

            if neuron_type is not None and neuron_info["type"] != neuron_type:
                # Wrong neuron type
                continue

            if neuron_name is not None and neuron_info["name"] != neuron_name:
                # Wrong neuron name (e.g. "dSPN_0", a specific model of a neuron type)
                continue

            if max_distance is not None and dist_to_centre[neuron_id] > max_distance:
                # Stop iterator if max distance is reached (neurons are sorted by distance)
                return

            if return_distance:
                yield neuron_id, dist_to_centre[neuron_id]
            else:
                yield neuron_id

            neuron_ctr += 1

            if n_neurons is not None and neuron_ctr >= n_neurons:
                # Stop iterator if n_neurons are delivered
                return

    ############################################################################

    def create_connection_matrix(self, sparse_matrix=True, connection_type="synapses"):
        """
        Count connections between each pair of neurons.

        Args:
            sparse_matrix (bool) : Return a scipy.sparse lil_matrix (default) or a dense numpy array
            connection_type (str) : Key in self.data, "synapses" (default) or "gap_junctions"

        Returns:
            (num_neurons, num_neurons) int16 matrix, element [pre, post] is the number of rows
            from pre to post. Gap junctions are symmetric, so they are counted in both directions.
        """

        num_neurons = self.data["num_neurons"]
        connections = self.data[connection_type]
        chunk_size = 1000000

        pre_parts = []
        post_parts = []

        # Only columns 0 (source) and 1 (destination) are needed; reading in chunks
        # keeps memory bounded when the matrix is still on disk
        for row_start in range(0, connections.shape[0], chunk_size):
            ids = np.asarray(connections[row_start:row_start + chunk_size, 0:2])
            pre_parts.append(ids[:, 0])
            post_parts.append(ids[:, 1])

            if connection_type == "gap_junctions":
                # Gap junctions are symmetric, so mark other side also
                pre_parts.append(ids[:, 1])
                post_parts.append(ids[:, 0])

        pre_id = np.concatenate(pre_parts) if pre_parts else np.zeros((0,), dtype=int)
        post_id = np.concatenate(post_parts) if post_parts else np.zeros((0,), dtype=int)

        if sparse_matrix:
            counts = np.ones(pre_id.shape[0], dtype=np.int16)
            # Duplicate (pre, post) entries are summed when converting from COO
            connection_matrix = sparse.coo_matrix((counts, (pre_id, post_id)),
                                                  shape=(num_neurons, num_neurons),
                                                  dtype=np.int16).tocsr().tolil()
        else:
            connection_matrix = np.zeros((num_neurons, num_neurons), dtype=np.int16)
            # add.at accumulates repeated (pre, post) pairs, unlike fancy-index +=
            np.add.at(connection_matrix, (pre_id, post_id), 1)

        return connection_matrix

    def find_neighbours(self, neuron_id, connection_matrix=None, exclude_parent=True):
        """
        Find neurons connected by synapses to any neuron in neuron_id (an iterable of ids).

        Returns:
            (pre_neighbours, post_neighbours) as sets of neuron ids
        """

        if connection_matrix is None:
            connection_matrix = self.create_connection_matrix(sparse_matrix=False, connection_type="synapses")

        pre_neighbours = set(np.where(np.sum(connection_matrix[:, list(neuron_id)], axis=1))[0])
        post_neighbours = set(np.where(np.sum(connection_matrix[list(neuron_id), :], axis=0))[0])

        if exclude_parent:
            parent_id = set(neuron_id)
            pre_neighbours -= parent_id
            post_neighbours -= parent_id

        return pre_neighbours, post_neighbours

    def find_neighbours_gap_junctions(self, neuron_id, connection_matrix=None, exclude_parent=True):
        """
        Find neurons coupled by gap junctions to any neuron in neuron_id (an iterable of ids).

        Returns:
            set of neuron ids
        """

        if connection_matrix is None:
            connection_matrix = self.create_connection_matrix(sparse_matrix=False, connection_type="gap_junctions")

        # This matrix should be symmetric!
        pre_neighbours = set(np.where(np.sum(connection_matrix[:, list(neuron_id)], axis=1))[0])
        post_neighbours = set(np.where(np.sum(connection_matrix[list(neuron_id), :], axis=0))[0])

        neighbours = pre_neighbours | post_neighbours

        if exclude_parent:
            neighbours -= set(neuron_id)

        return neighbours

    def create_distance_matrix(self, neuron_id=None, pre_id=None, post_id=None):
        """
        Euclidean distances between neuron positions.

        Args:
            neuron_id : Neurons to include (square matrix), None = all neurons
            pre_id, post_id : Give both instead of neuron_id for a pre x post matrix

        Raises:
            ValueError : On inconsistent combinations of arguments
        """

        if neuron_id is not None and pre_id is not None and post_id is not None:
            raise ValueError("Specify either neuron_id or the two parameters pre_id and post_id.")

        if (pre_id is None) ^ (post_id is None):
            raise ValueError("pre_id and post_id must both either be specified, or neither")

        pos = self.data["neuron_positions"]

        if neuron_id is not None:
            pos = pos[neuron_id, :]
            dist_matrix = distance_matrix(pos, pos)
        elif pre_id is not None and post_id is not None:
            dist_matrix = distance_matrix(pos[pre_id, :], pos[post_id, :])
        else:
            dist_matrix = distance_matrix(pos, pos)

        return dist_matrix

    def print_all_synapse_counts_per_type(self):
        """ Print the number of synapses for every (pre type, post type) pair with at least one synapse. """
        synapse_types = sorted(self.get_neuron_types(return_set=True))
        connection_matrix = self.create_connection_matrix()

        for pre_type in synapse_types:
            for post_type in synapse_types:
                count = self.count_synapses_per_type(pre_neuron_type=pre_type,
                                                     post_neuron_type=post_type,
                                                     connection_matrix=connection_matrix)
                if count > 0:
                    print(f"{pre_type} -> {post_type}: {count} synapses")

    def count_synapses_per_type(self, pre_neuron_type=None, post_neuron_type=None, connection_matrix=None):
        """
        Count synapses from neurons of the pre types onto neurons of the post types.

        Args:
            pre_neuron_type (str or list): Neuron type, or list of neuron types (None = all types)
            post_neuron_type (str or list): Neuron type, or list of neuron types, e.g. ["dSPN", "iSPN"] (None = all)
            connection_matrix: Precomputed matrix from create_connection_matrix (computed if None)

        Returns:
            Total number of synapses between the selected neuron groups
        """

        if connection_matrix is None:
            connection_matrix = self.create_connection_matrix()

        pre_mask = np.zeros((connection_matrix.shape[0]), dtype=bool)
        post_mask = np.zeros((connection_matrix.shape[1]), dtype=bool)

        if pre_neuron_type is None:
            pre_neuron_type = self.get_neuron_types(return_set=True)
        elif isinstance(pre_neuron_type, str):
            pre_neuron_type = [pre_neuron_type]

        if post_neuron_type is None:
            post_neuron_type = self.get_neuron_types(return_set=True)
        elif isinstance(post_neuron_type, str):
            post_neuron_type = [post_neuron_type]

        for neuron_type in pre_neuron_type:
            pre_mask[self.get_neuron_id_of_type(neuron_type)] = True

        for neuron_type in post_neuron_type:
            post_mask[self.get_neuron_id_of_type(neuron_type)] = True

        # Use the union of all selected types, not just the last one in each loop
        pre_idx = np.flatnonzero(pre_mask)
        post_idx = np.flatnonzero(post_mask)

        if pre_idx.size == 0 or post_idx.size == 0:
            return 0

        synapse_count = np.sum(connection_matrix[pre_idx, :][:, post_idx])

        return synapse_count

    def count_incoming_connections(self, neuron_type):
        """
        Count synapses onto, and gap junctions touching, neurons of neuron_type.

        Returns:
            (synapse_count, gap_junction_count) where each gap junction end in the group counts 0.5
        """

        neuron_id = self.get_neuron_id_of_type(neuron_type)
        neuron_id_mask = np.zeros((self.data["num_neurons"],), dtype=bool)
        neuron_id_mask[neuron_id] = True

        synapse_count = 0
        gap_junction_count = 0

        for synapses in self.synapse_iterator():
            synapse_count += np.sum(neuron_id_mask[synapses[:, 1]])

        for gap_junctions in self.gap_junction_iterator():
            gap_junction_count += np.sum(neuron_id_mask[gap_junctions[:, 0]])
            gap_junction_count += np.sum(neuron_id_mask[gap_junctions[:, 1]])

        gap_junction_count /= 2

        return synapse_count, gap_junction_count
[docs]
def _print_synapse_details(synapses, synapse_coords, idx):
    """Print one detail line (section, soma distance, coordinate, conductance) per synapse row in idx, then a blank line."""
    for i in idx:
        print(f" -- SecID {synapses[i, 9]}, SecX {synapses[i, 10] * 1e-3:.2f}, "
              f"Soma dist: {synapses[i,8]:.1f} μm, "
              f"Coord: ({synapse_coords[i, 0]*1e6:.1f}, "
              f"{synapse_coords[i, 1]*1e6:.1f}, "
              f"{synapse_coords[i, 2]*1e6:.1f}) μm, "
              f"Cond: {synapses[i, 11] * 1e-3:.2f} nS")
    print("")


def _print_gap_junction_details(gap_junctions, gap_junction_coords, idx, sec_id_col, sec_x_col):
    """Print one detail line per gap-junction row in idx.

    sec_id_col / sec_x_col choose which section-id / section-x columns are printed
    (2/4 when the listed neuron is in column 0, 3/5 when it is in column 1).
    """
    for i in idx:
        print(f" -- SecID {gap_junctions[i, sec_id_col]}, SecX {gap_junctions[i, sec_x_col] * 1e-3:.3f}, "
              f"Coord: ({gap_junction_coords[i, 0]*1e6:.1f}, "
              f"{gap_junction_coords[i, 1]*1e6:.1f}, "
              f"{gap_junction_coords[i, 2]*1e6:.1f}) μm, "
              f"Cond: {gap_junctions[i, 10] * 1e-3:.3f} nS")


def snudda_load_cli():
    """ Command line parser for SnuddaLoad script.

    Loads the network file given on the command line and prints whatever the flags request:
    neuron lists, neurons of a type, pre/post-synaptic partners, gap junctions, neuron parameters,
    neurons closest to the centre, synapse counts per type pair, and incoming-connection totals.
    """
    from argparse import ArgumentParser

    parser = ArgumentParser(description="Load snudda network file (hdf5)")
    parser.add_argument("network_file", help="Network file (hdf5)", type=str)
    parser.add_argument("--listN", help="Lists neurons in network", action="store_true")
    parser.add_argument("--listT", type=str, help="List neurons of type, --listT ? list the types.", default=None)
    parser.add_argument("--listPre", help="List pre synaptic neurons", type=int)
    parser.add_argument("--listPost", help="List post synaptic neurons (slow)", type=int)
    parser.add_argument("--listGJ", help="List gap junctions (slow)", type=int)
    parser.add_argument("--listTotalIncoming", help="List number of total incoming connections to neuron type",
                        type=str, default=None)
    parser.add_argument("--keepOpen", help="This prevents loading of synapses to memory, and keeps HDF5 file open",
                        action="store_true")
    parser.add_argument("--detailed", help="More information", action="store_true")
    parser.add_argument("--voxels", help="Voxel information", action="store_true")
    parser.add_argument("--centre", help="List n neurons in centre (-1 = all)", type=int)
    parser.add_argument("--listParam", help="List parameters for neuron_id", type=int)
    parser.add_argument("--countSyn", help="Count synapses per type", action="store_true")
    args = parser.parse_args()

    # --keepOpen keeps synapses on disk (HDF5 file stays open) instead of reading them into memory
    load_synapses = not args.keepOpen
    nl = SnuddaLoad(args.network_file, load_synapses=load_synapses, verbose=True)

    try:
        if args.listN:
            print("Neurons in network: ")
            if args.detailed:
                for nid, name, pos, par_key, morph_key, mod_key, neuron_path, pop_id, virt_flag \
                        in [(x["neuron_id"], x["name"], x["position"],
                             x["parameter_key"], x["morphology_key"], x["modulation_key"],
                             x["neuron_path"], x["population_unit"], x["virtual_neuron"])
                            for x in nl.data["neurons"]]:
                    print(f"{nid} : {name}{' [virtual]' if virt_flag else ''}, ({pos[0]:.6f}, {pos[1]:.6f}, {pos[2]:.6f}) "
                          f"pop_id {pop_id}, par_key {par_key}, morph_key {morph_key}, neuron_path: {neuron_path}")
            else:
                for nid, name, pos, pid, virt_flag \
                        in [(x["neuron_id"], x["name"], x["position"], x["population_unit"], x["virtual_neuron"])
                            for x in nl.data["neurons"]]:
                    print(f"{nid} {'V' if virt_flag else ':'} {name} [{pid}], ({pos[0]:.6f}, {pos[1]:.6f}, {pos[2]:.6f})")

        if args.listT is not None:
            if args.listT == "?":
                print("List neuron types in network:")
                # Single pass instead of re-scanning all neurons once per type
                n_types, n_counts = np.unique([x["type"] for x in nl.data["neurons"]], return_counts=True)
                for nt, num in zip(n_types, n_counts):
                    print(f"{nt} ({num} total)")
            else:
                print(f"Neurons of type {args.listT}:")
                n_of_type = [(x["neuron_id"], x["name"]) for x in nl.data["neurons"]
                             if x["type"] == args.listT]
                for nid, name in n_of_type:
                    print("%d : %s" % (nid, name))

        if args.listPre is not None:
            print(f"List neurons pre-synaptic to neuron_id = {args.listPre} "
                  f"({nl.data['neurons'][args.listPre]['name']})")
            synapses, synapse_coords = nl.find_synapses(post_id=args.listPre)
            if synapses is None:
                print("No pre synaptic neurons were found.")
            else:
                print(f"The neuron receives {synapses.shape[0]} synapses")
                pre_id = np.unique(synapses[:, 0])
                for nid, name in [(x["neuron_id"], x["name"]) for x in nl.data["neurons"] if x["neuron_id"] in pre_id]:
                    n_syn = np.sum(synapses[:, 0] == nid)
                    print(f"{nid} : {name} ({n_syn} synapses)")
                    if args.detailed:
                        _print_synapse_details(synapses, synapse_coords, np.where(synapses[:, 0] == nid)[0])
                    if args.voxels:
                        idx = np.where(synapses[:, 0] == nid)[0]
                        for i in idx:
                            print(f" -- voxels: {synapses[i, 2:5]}, hyper id: {synapses[i, 5]}, "
                                  f" sec id {synapses[i, 9]}, sec x {synapses[i, 10] * 1e-3:.2f}, "
                                  f"Soma dist: {synapses[i,8]:.1f} μm, ")
                        print("")

        if args.listPost is not None:
            print(f"List neurons post-synaptic to neuron_id = {args.listPost}"
                  f" ({nl.data['neurons'][args.listPost]['name']}):")
            synapses, synapse_coords = nl.find_synapses(pre_id=args.listPost)
            # Check for None BEFORE using synapses.shape (previously crashed when no targets were found)
            if synapses is None:
                print("No post synaptic targets found.")
            else:
                print(f"The neuron makes {synapses.shape[0]} synapses on other neurons")
                post_id = np.unique(synapses[:, 1])
                for nid, name in [(x["neuron_id"], x["name"]) for x in nl.data["neurons"] if x["neuron_id"] in post_id]:
                    n_syn = np.sum(synapses[:, 1] == nid)
                    print(f"{nid} : {name} ({n_syn} synapses)")
                    if args.detailed:
                        _print_synapse_details(synapses, synapse_coords, np.where(synapses[:, 1] == nid)[0])

        if args.listGJ is not None:
            print(f"List gap junctions of neuron_id = {args.listGJ}"
                  f" ({nl.data['neurons'][args.listGJ]['name']})")
            gap_junctions, gap_junction_coords = nl.find_gap_junctions(neuron_id=args.listGJ)
            # NOTE(review): None guard added by analogy with find_synapses; confirm what find_gap_junctions returns when empty
            if gap_junctions is None or gap_junctions.shape[0] == 0:
                print("No gap junctions on neuron.")
            else:
                connected_id = set(gap_junctions[:, 0]).union(gap_junctions[:, 1])
                connected_id.discard(args.listGJ)
                for nid, name in [(x["neuron_id"], x["name"]) for x in nl.data["neurons"] if x["neuron_id"] in connected_id]:
                    n_gj = np.sum(gap_junctions[:, 0] == nid) + np.sum(gap_junctions[:, 1] == nid)
                    print(f"{nid} : {name} ({n_gj} gap junctions)")
                    if args.detailed:
                        _print_gap_junction_details(gap_junctions, gap_junction_coords,
                                                    np.where(gap_junctions[:, 0] == nid)[0],
                                                    sec_id_col=2, sec_x_col=4)
                        _print_gap_junction_details(gap_junctions, gap_junction_coords,
                                                    np.where(gap_junctions[:, 1] == nid)[0],
                                                    sec_id_col=3, sec_x_col=5)

        if args.listParam is not None:
            param_data, mod_data = nl.get_neuron_params(neuron_id=args.listParam)
            print(f"Neuron ID {args.listParam}\n"
                  f"parameters: {json.dumps(param_data, indent=4, cls=NumpyEncoder)}\n\n"
                  f"modulation: {json.dumps(mod_data, indent=4, cls=NumpyEncoder)}\n")

        if args.centre:
            # A negative value means "all neurons", signalled to the iterator by None
            n_neurons = None if args.centre < 0 else args.centre
            for neuron_id, centre_dist in nl.get_centre_neurons_iterator(n_neurons):
                pos = nl.data["neurons"][neuron_id]["position"] * 1e6
                name = nl.data["neurons"][neuron_id]["name"]
                print(f"{neuron_id} {name}: ({pos[0]:.1f}, {pos[1]:.1f}, {pos[2]:.1f}) μm, distance to centre {centre_dist*1e6:.1f} μm")

        if args.countSyn:
            nl.print_all_synapse_counts_per_type()

        if args.listTotalIncoming:
            incoming_to_type = args.listTotalIncoming
            synapse_count, gap_junction_count = nl.count_incoming_connections(neuron_type=incoming_to_type)
            print(f"All neurons of type {incoming_to_type} receive in total {synapse_count:.0f} synapses, "
                  f"and have {gap_junction_count:.0f} gap junctions in total.")
    finally:
        # Release the HDF5 file explicitly (it is held open when --keepOpen is used)
        nl.close()
# Entry point when this module is run directly as a script: parse command line flags and print network info.
if __name__ == "__main__":
    snudda_load_cli()