# vim: set fileencoding=utf-8 :
# Copyright 2018-2023 John Lees and Nick Croucher
'''General utility functions for data read/writing/manipulation in PopPUNK'''
# universal
import os
import sys
# additional
import pickle
import multiprocessing
from collections import defaultdict
from itertools import chain
from tempfile import mkstemp
from functools import partial
import contextlib
import numpy as np
import pandas as pd
try:
import cudf
import rmm
import cupy
from numba import cuda
except ImportError:
pass
import poppunk_refine
import pp_sketchlib
def setGtThreads(threads):
import graph_tool.all as gt
# Check on parallelisation of graph-tools
if gt.openmp_enabled():
gt.openmp_set_num_threads(threads)
sys.stderr.write('\nGraph-tools OpenMP parallelisation enabled:')
sys.stderr.write(' with ' + str(gt.openmp_get_num_threads()) + ' threads\n')
# thanks to Laurent LAPORTE on SO
@contextlib.contextmanager
def set_env(**environ):
"""
Temporarily set the process environment variables.
>>> with set_env(PLUGINS_DIR=u'test/plugins'):
... "PLUGINS_DIR" in os.environ
True
>>> "PLUGINS_DIR" in os.environ
False
"""
old_environ = dict(os.environ)
os.environ.update(environ)
try:
yield
finally:
os.environ.clear()
os.environ.update(old_environ)
# https://stackoverflow.com/a/17954769
from contextlib import contextmanager
@contextmanager
def stderr_redirected(to=os.devnull):
    '''Temporarily redirect stderr, e.g. to silence output from non-Python code.
    import os
    with stderr_redirected(to=filename):
        print("from Python", file=sys.stderr)
        os.system("echo non-Python applications are also supported 1>&2")
    '''
fd = sys.stderr.fileno()
def _redirect_stderr(to):
sys.stderr.close()
os.dup2(to.fileno(), fd)
sys.stderr = os.fdopen(fd, 'w')
with os.fdopen(os.dup(fd), 'w') as old_stderr:
with open(to, 'w') as file:
_redirect_stderr(to=file)
try:
yield
finally:
_redirect_stderr(to=old_stderr)
# Use partials to pre-set common options for the sketchlib backend's
# database construction and query calls
def setupDBFuncs(args):
"""Wraps common database access functions from sketchlib and mash,
to try and make their API more similar
Args:
args (argparse.opts)
Parsed command lines options
qc_dict (dict)
Table of parameters for QC function
Returns:
dbFuncs (dict)
Functions with consistent arguments to use as the database API
"""
from .sketchlib import checkSketchlibVersion
from .sketchlib import createDatabaseDir
from .sketchlib import joinDBs
from .sketchlib import constructDatabase as constructDatabaseSketchlib
from .sketchlib import queryDatabase as queryDatabaseSketchlib
from .sketchlib import readDBParams
from .sketchlib import getSeqsInDb
backend = "sketchlib"
version = checkSketchlibVersion()
constructDatabase = partial(constructDatabaseSketchlib,
strand_preserved = args.strand_preserved,
min_count = args.min_kmer_count,
use_exact = args.exact_count,
use_gpu = args.gpu_sketch,
deviceid = args.deviceid)
queryDatabase = partial(queryDatabaseSketchlib,
use_gpu = args.gpu_dist,
deviceid = args.deviceid)
# Dict of DB access functions for assign_query (which is out of scope)
dbFuncs = {'createDatabaseDir': createDatabaseDir,
'joinDBs': joinDBs,
'constructDatabase': constructDatabase,
'queryDatabase': queryDatabase,
'readDBParams': readDBParams,
'getSeqsInDb': getSeqsInDb,
'backend': backend,
'backend_version': version
}
return dbFuncs
def storePickle(rlist, qlist, self, X, pklName):
"""Saves core and accessory distances in a .npy file, names in a .pkl
Called during ``--create-db``
Args:
rlist (list)
List of reference sequence names (for :func:`~iterDistRows`)
qlist (list)
List of query sequence names (for :func:`~iterDistRows`)
self (bool)
Whether an all-vs-all self DB (for :func:`~iterDistRows`)
X (numpy.array)
n x 2 array of core and accessory distances
If None, do not save
pklName (str)
Prefix for output files
"""
with open(pklName + ".pkl", 'wb') as pickle_file:
pickle.dump([rlist, qlist, self], pickle_file)
if isinstance(X, np.ndarray):
np.save(pklName + ".npy", X)
def readPickle(pklName, enforce_self=False, distances=True):
"""Loads core and accessory distances saved by :func:`~storePickle`
Called during ``--fit-model``
Args:
pklName (str)
Prefix for saved files
        enforce_self (bool)
            Error if self == False
            [default = False]
distances (bool)
Read the distance matrix
[default = True]
Returns:
rlist (list)
List of reference sequence names (for :func:`~iterDistRows`)
qlist (list)
List of query sequence names (for :func:`~iterDistRows`)
self (bool)
Whether an all-vs-all self DB (for :func:`~iterDistRows`)
X (numpy.array)
n x 2 array of core and accessory distances
"""
with open(pklName + ".pkl", 'rb') as pickle_file:
rlist, qlist, self = pickle.load(pickle_file)
if enforce_self and (not self or rlist != qlist):
sys.stderr.write("Old distances " + pklName + ".npy not complete\n")
sys.exit(1)
if distances:
X = np.load(pklName + ".npy")
else:
X = None
return rlist, qlist, self, X
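# Illustrative round trip for the two functions above (hypothetical names and
# distances, not part of the original module):
#   >>> X = np.array([[0.01, 0.1]])
#   >>> storePickle(['s1', 's2'], ['s1', 's2'], True, X, 'example_db')
#   >>> rlist, qlist, is_self, X2 = readPickle('example_db', enforce_self=True)
# This writes example_db.pkl and example_db.npy; readPickle then returns the
# same names, self flag and distance matrix.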
def iterDistRows(refSeqs, querySeqs, self=True):
"""Gets the ref and query ID for each row of the distance matrix
Returns an iterable with ref and query ID pairs by row.
Args:
refSeqs (list)
List of reference sequence names.
querySeqs (list)
List of query sequence names.
self (bool)
Whether a self-comparison, used when constructing a database.
Requires refSeqs == querySeqs
Default is True
Returns:
ref, query (str, str)
Iterable of tuples with ref and query names for each distMat row.
"""
if self:
if refSeqs != querySeqs:
raise RuntimeError('refSeqs must equal querySeqs for db building (self = true)')
for i, ref in enumerate(refSeqs):
for j in range(i + 1, len(refSeqs)):
yield(refSeqs[j], ref)
else:
for query in querySeqs:
for ref in refSeqs:
yield(ref, query)
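# Illustrative ordering for the generator above (hypothetical names): for a
# self comparison the lower triangle is walked column by column:
#   >>> list(iterDistRows(['a', 'b', 'c'], ['a', 'b', 'c'], self=True))
#   [('b', 'a'), ('c', 'a'), ('c', 'b')]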
def listDistInts(refSeqs, querySeqs, self=True):
"""Gets the ref and query ID for each row of the distance matrix
Returns an iterable with ref and query ID pairs by row.
Args:
refSeqs (list)
List of reference sequence names.
querySeqs (list)
List of query sequence names.
self (bool)
Whether a self-comparison, used when constructing a database.
Requires refSeqs == querySeqs
Default is True
Returns:
        ref, query (int, int)
            Iterable of tuples with ref and query indices for each distMat row.
"""
num_ref = len(refSeqs)
num_query = len(querySeqs)
if self:
if refSeqs != querySeqs:
raise RuntimeError('refSeqs must equal querySeqs for db building (self = true)')
for i in range(num_ref):
for j in range(i + 1, num_ref):
yield(j, i)
    else:
        for i in range(num_query):
            for j in range(num_ref):
                yield(j, i)
def readIsolateTypeFromCsv(clustCSV, mode = 'clusters', return_dict = False):
"""Read cluster definitions from CSV file.
Args:
clustCSV (str)
File name of CSV with isolate assignments
mode (str)
            Type of file to read: 'clusters', 'lineages', or 'external'
return_dict (bool)
If True, return a dict with sample->cluster instead
of sets
[default = False]
Returns:
clusters (dict)
Dictionary of cluster assignments (keys are cluster names, values are
sets containing samples in the cluster). Or if return_dict is set keys
are sample names, values are cluster assignments.
"""
# data structures
if return_dict:
clusters = defaultdict(dict)
else:
clusters = {}
# read CSV
clustersCsv = pd.read_csv(clustCSV, index_col = 0, quotechar='"')
# select relevant columns according to mode
if mode == 'clusters':
type_columns = [n for n,col in enumerate(clustersCsv.columns) if ('Cluster' in col)]
elif mode == 'lineages':
type_columns = [n for n,col in enumerate(clustersCsv.columns) if ('Rank_' in col or 'overall' in col)]
elif mode == 'external':
if len(clustersCsv.columns) == 1:
type_columns = [0]
elif len(clustersCsv.columns) > 1:
type_columns = range((len(clustersCsv.columns)-1))
else:
sys.stderr.write('Unknown CSV reading mode: ' + mode + '\n')
sys.exit(1)
# read file
for row in clustersCsv.itertuples():
for cls_idx in type_columns:
cluster_name = clustersCsv.columns[cls_idx]
cluster_name = cluster_name.replace('__autocolour','')
if return_dict:
clusters[cluster_name][str(row.Index)] = str(row[cls_idx + 1])
else:
if cluster_name not in clusters.keys():
clusters[cluster_name] = defaultdict(set)
clusters[cluster_name][str(row[cls_idx + 1])].add(row.Index)
# return data structure
return clusters
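# Illustrative sketch for readIsolateTypeFromCsv (hypothetical file contents):
# a CSV such as
#   Taxon,Cluster
#   sample1,1
#   sample2,1
#   sample3,2
# read with mode='clusters' returns
#   {'Cluster': {'1': {'sample1', 'sample2'}, '2': {'sample3'}}}
# or, with return_dict=True,
#   {'Cluster': {'sample1': '1', 'sample2': '1', 'sample3': '2'}}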
def joinClusterDicts(d1, d2):
"""Join two dictionaries returned by :func:`~readIsolateTypeFromCsv` with
return_dict = True. Useful for concatenating ref and query assignments
Args:
d1 (dict of dicts)
First dictionary to concat
d2 (dict of dicts)
Second dictionary to concat
Returns:
d1 (dict of dicts)
d1 with d2 appended
"""
matching_cols = set(d1.keys()).intersection(d2.keys())
if len(matching_cols) == 0:
sys.stderr.write("Cluster columns do not match between sets being combined\n")
sys.stderr.write(f"{d1.keys()} {d2.keys()}\n")
sys.exit(1)
missing_cols = []
for column in d1.keys():
if column in matching_cols:
# Combine dicts: https://stackoverflow.com/a/15936211
d1[column] = \
dict(chain.from_iterable(d.items() for d in (d1[column], d2[column])))
else:
missing_cols.append(column)
for missing in missing_cols:
del d1[missing]
return d1
def update_distance_matrices(refList, distMat, queryList = None, query_ref_distMat = None,
query_query_distMat = None, threads = 1):
"""Convert distances from long form (1 matrix with n_comparisons rows and 2 columns)
to a square form (2 NxN matrices), with merging of query distances if necessary.
Args:
refList (list)
List of references
distMat (numpy.array)
Two column long form list of core and accessory distances
for pairwise comparisons between reference db sequences
queryList (list)
List of queries
query_ref_distMat (numpy.array)
Two column long form list of core and accessory distances
for pairwise comparisons between queries and reference db
sequences
query_query_distMat (numpy.array)
Two column long form list of core and accessory distances
for pairwise comparisons between query sequences
threads (int)
Number of threads to use
Returns:
seqLabels (list)
Combined list of reference and query sequences
coreMat (numpy.array)
NxN array of core distances for N sequences
accMat (numpy.array)
NxN array of accessory distances for N sequences
"""
seqLabels = refList
if queryList is not None:
seqLabels = seqLabels + queryList
    if queryList is None:
coreMat = pp_sketchlib.longToSquare(distVec=distMat[:, [0]],
num_threads=threads)
accMat = pp_sketchlib.longToSquare(distVec=distMat[:, [1]],
num_threads=threads)
else:
coreMat = pp_sketchlib.longToSquareMulti(distVec=distMat[:, [0]],
query_ref_distVec=query_ref_distMat[:, [0]],
query_query_distVec=query_query_distMat[:, [0]],
num_threads=threads)
accMat = pp_sketchlib.longToSquareMulti(distVec=distMat[:, [1]],
query_ref_distVec=query_ref_distMat[:, [1]],
query_query_distVec=query_query_distMat[:, [1]],
num_threads=threads)
# return outputs
return seqLabels, coreMat, accMat
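# Expected shapes for update_distance_matrices (assumed sizes, for illustration
# only): with r references and q queries, distMat has r*(r-1)/2 rows,
# query_ref_distMat has q*r rows and query_query_distMat has q*(q-1)/2 rows;
# coreMat and accMat are returned as symmetric (r+q) x (r+q) arrays.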
def readRfile(rFile, oneSeq=False):
"""Reads in files for sketching. Names and sequence, tab separated
Args:
rFile (str)
File with locations of assembly files to be sketched
oneSeq (bool)
Return only the first sequence listed, rather than a list
(used with mash)
Returns:
names (list)
Array of sequence names
sequences (list of lists)
Array of sequence files
"""
names = []
sequences = []
with open(rFile, 'r') as refFile:
for refLine in refFile:
rFields = refLine.rstrip().split("\t")
if len(rFields) < 2:
sys.stderr.write("Input reference list is misformatted\n"
"Must contain sample name and file, tab separated\n")
sys.exit(1)
if "/" in rFields[0]:
sys.stderr.write("Sample names may not contain slashes\n")
sys.exit(1)
names.append(rFields[0])
sample_files = []
for sequence in rFields[1:]:
sample_files.append(sequence)
# Take first of sequence list
if oneSeq:
if len(sample_files) > 1:
sys.stderr.write("Multiple sequence found for " + rFields[0] +
". Only using first\n")
sequences.append(sample_files[0])
else:
sequences.append(sample_files)
# Process names to ensure compatibility with downstream software
names = isolateNameToLabel(names)
if len(set(names)) != len(names):
seen = set()
dupes = set(x for x in names if x in seen or seen.add(x))
sys.stderr.write("Input contains duplicate names! All names must be unique\n")
sys.stderr.write("Non-unique names are " + ",".join(dupes) + "\n")
sys.exit(1)
# Names are sorted on return
# We have had issues (though they should be fixed) with unordered input
# not matching the database. This should help simplify things
list_iterable = zip(names, sequences)
sorted_names = sorted(list_iterable)
tuples = zip(*sorted_names)
names, sequences = [list(r_tuple) for r_tuple in tuples]
return (names, sequences)
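# Illustrative r-file layout for readRfile (hypothetical names and paths; fields
# are tab separated, and extra columns list further sequence files for the same
# sample):
#   sample1 /data/sample1.fasta
#   sample2 /data/sample2_1.fastq.gz    /data/sample2_2.fastq.gz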
def isolateNameToLabel(names):
"""Function to process isolate names to labels
appropriate for visualisation.
Args:
names (list)
List of isolate names.
Returns:
labels (list)
List of isolate labels.
"""
# useful to have as a function in case we
# want to remove certain characters
labels = [name.split('/')[-1].replace('.','_').replace(':','').replace('(','_').replace(')','_') \
for name in names]
return labels
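# Illustrative example (hypothetical names, not part of the original module):
#   >>> isolateNameToLabel(['assemblies/sample.1.fa', 'strain(2020):A'])
#   ['sample_1_fa', 'strain_2020_A']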
def createOverallLineage(rank_list, lineage_clusters):
    """Combine per-rank lineage assignments into per-rank dictionaries and a
    hyphen-joined 'overall' lineage string for each isolate."""
    # process multirank lineages
overall_lineages = {'Rank_' + str(rank):{} for rank in rank_list}
overall_lineages['overall'] = {}
isolate_list = lineage_clusters[rank_list[0]].keys()
for isolate in isolate_list:
overall_lineage = None
for rank in rank_list:
overall_lineages['Rank_' + str(rank)][isolate] = lineage_clusters[rank][isolate]
if overall_lineage is None:
overall_lineage = str(lineage_clusters[rank][isolate])
else:
overall_lineage = overall_lineage + '-' + str(lineage_clusters[rank][isolate])
overall_lineages['overall'][isolate] = overall_lineage
return overall_lineages
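# Illustrative example for createOverallLineage (hypothetical values): with
# rank_list = [1, 2] and lineage_clusters = {1: {'s1': 3}, 2: {'s1': 1}} the
# return value is
#   {'Rank_1': {'s1': 3}, 'Rank_2': {'s1': 1}, 'overall': {'s1': '3-1'}}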
def decisionBoundary(intercept, gradient, adj = 0.0):
"""Returns the co-ordinates where the triangle the decision boundary forms
meets the x- and y-axes.
Args:
intercept (numpy.array)
Cartesian co-ordinates of point along line (:func:`~transformLine`)
which intercepts the boundary
gradient (float)
Gradient of the line
adj (float)
Fraction by which to shift the intercept up the y axis
Returns:
x (float)
The x-axis intercept
y (float)
The y-axis intercept
"""
if adj != 0.0:
original_hypotenuse = (intercept[0]**2 + intercept[1]**2)**0.5
length_ratio = (original_hypotenuse + adj)/original_hypotenuse
intercept[0] = intercept[0] * length_ratio
intercept[1] = intercept[1] * length_ratio
x = intercept[0] + intercept[1] * gradient
y = intercept[1] + intercept[0] / gradient
return(x, y)
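# Worked example of the formula above (hypothetical values): with
# intercept = [0.2, 0.1], gradient = 0.5 and adj = 0.0,
#   x = 0.2 + 0.1 * 0.5 = 0.25
#   y = 0.1 + 0.2 / 0.5 = 0.5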
def check_and_set_gpu(use_gpu, gpu_lib, quit_on_fail = False):
"""Check GPU libraries can be loaded and set managed memory.
Args:
use_gpu (bool)
Whether GPU packages have been requested
        gpu_lib (bool)
            Whether GPU packages are available
        quit_on_fail (bool)
            Exit if GPU packages cannot be loaded
            [default = False]
Returns:
use_gpu (bool)
Whether GPU packages can be used
"""
# load CUDA libraries
if use_gpu and not gpu_lib:
if quit_on_fail:
sys.stderr.write('Unable to load GPU libraries; exiting\n')
sys.exit(1)
else:
sys.stderr.write('Unable to load GPU libraries; using CPU libraries '
'instead\n')
use_gpu = False
# Set memory management for large networks
if use_gpu:
multiprocessing.set_start_method('spawn', force=True)
rmm.reinitialize(managed_memory=True)
if "cupy" in sys.modules:
cupy.cuda.set_allocator(rmm.allocators.cupy.rmm_cupy_allocator)
if "cuda" in sys.modules:
cuda.set_memory_manager(rmm.allocators.numba.RMMNumbaManager)
assert(rmm.is_initialized())
return use_gpu
def read_rlist_from_distance_pickle(fn, allow_non_self = True, include_queries = False):
"""Return the list of reference sequences from a distance pickle.
Args:
fn (str)
Name of distance pickle
allow_non_self (bool)
Whether non-self distance datasets are permissible
include_queries (bool)
Whether queries should be included in the rlist
Returns:
rlist (list)
List of reference sequence names
"""
with open(fn, 'rb') as pickle_file:
rlist, qlist, self = pickle.load(pickle_file)
if not allow_non_self and not self:
sys.stderr.write("Thi analysis requires an all-v-all"
" distance dataset\n")
sys.exit(1)
if include_queries:
rlist = rlist + qlist
return rlist
def get_match_search_depth(rlist, rank_list):
"""Return a default search depth for lineage model fitting.
Args:
rlist (list)
List of sequences in database
rank_list (list)
List of ranks to be used to fit lineage models
Returns:
max_search_depth (int)
Maximum kNN used for lineage model fitting
"""
    # Default to 10% of the database size, unless that is smaller than the
    # maximum search rank, in which case go slightly above the maximum rank
    max_search_depth = max([int(0.1*len(rlist)), int(1.1*max(rank_list)), int(1 + max(rank_list))])
# Cannot be higher than the number of comparisons
if max_search_depth > len(rlist) - 1:
max_search_depth = len(rlist) - 1
return max_search_depth
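# Worked example (hypothetical sizes): for 500 samples and rank_list = [1, 2, 3],
# max([int(0.1*500), int(1.1*3), int(1+3)]) = max([50, 3, 4]) = 50, which is
# below the cap of len(rlist) - 1 = 499, so the search depth returned is 50.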