# vim: set fileencoding=utf-8 :
# Copyright 2018-2023 John Lees and Nick Croucher
'''Sketchlib functions for database construction'''
# universal
import os
import sys
import subprocess
# additional
import re
from random import sample
import numpy as np
from scipy import optimize
import pp_sketchlib
import h5py
from .__init__ import SKETCHLIB_MAJOR, SKETCHLIB_MINOR, SKETCHLIB_PATCH
from .utils import readRfile, stderr_redirected
from .plot import plot_fit
sketchlib_exe = "sketchlib"
def checkSketchlibVersion():
    """Checks that sketchlib can be run, and returns version

    The version is taken from ``pp_sketchlib.version``. If the module does not
    export that attribute, ``sketchlib --version`` is run instead and the
    second space-separated field of its first output line is used.

    A warning (not an error) is written to stderr when the version is older
    than ``SKETCHLIB_MAJOR.SKETCHLIB_MINOR.SKETCHLIB_PATCH``, or when the
    version cannot be determined or parsed.

    Returns:
        version (str)
            Version string with any leading ``v`` removed. ``"0.0.0"`` if the
            version could not be found.
    """
    try:
        version = pp_sketchlib.version
    # Older versions didn't export attributes
    except AttributeError:
        version = None
        try:
            # Argument list (no shell) so the executable name is never
            # interpreted by a shell; the context manager closes the pipe
            # and waits for the process
            with subprocess.Popen([sketchlib_exe, '--version'],
                                  stdout=subprocess.PIPE) as p:
                first_line = p.stdout.readline()
            fields = first_line.rstrip().decode(errors="replace").split(" ")
            if len(fields) > 1:
                version = fields[1]
        except OSError:
            # Executable missing or not runnable: reported below
            pass
        if version is None:
            sys.stderr.write("WARNING: Sketchlib version could not be found\n")
            # Keep a string so the parsing below (and callers) still work
            version = "0.0.0"

    version = re.sub(r'^v', '', str(version)) # Remove leading v
    try:
        parsed = [int(v) for v in version.split(".")]
    except ValueError:
        sys.stderr.write("WARNING: Could not parse sketchlib version '" +
                         version + "'\n")
        parsed = []
    # Pad to major.minor.patch so short versions such as '2.1' compare cleanly
    sketchlib_version = tuple((parsed + [0, 0, 0])[:3])

    # Tuple comparison is lexicographic: major, then minor, then patch
    if sketchlib_version < (SKETCHLIB_MAJOR, SKETCHLIB_MINOR, SKETCHLIB_PATCH):
        sys.stderr.write("This version of PopPUNK requires sketchlib "
                         "v" + str(SKETCHLIB_MAJOR) +
                         "." + str(SKETCHLIB_MINOR) +
                         "." + str(SKETCHLIB_PATCH) + " or higher\n")
        sys.stderr.write("Continuing... but safety not guaranteed\n")

    return version
def checkSketchlibLibrary():
    """Reports where the imported sketchlib library lives on disk

    Returns:
        lib (str)
            The ``__file__`` attribute of the imported ``pp_sketchlib`` module
    """
    return pp_sketchlib.__file__
def createDatabaseDir(outPrefix, kmers):
    """Creates the directory to write sketches to, removing an old database
    if it is incompatible with the requested k-mer sizes

    If ``outPrefix`` does not exist it is created (``sys.exit(1)`` on
    failure). If it exists and already contains ``<basename>.h5``, that file
    is removed when any sample in it was sketched at a k-mer size which is
    not in ``kmers``.

    Args:
        outPrefix (str)
            output db prefix
        kmers (list)
            k-mer sizes in db
    """
    if not os.path.isdir(outPrefix):
        try:
            os.makedirs(outPrefix)
        except OSError:
            sys.stderr.write("Cannot create output directory\n")
            sys.exit(1)
        return

    # remove old database files if not needed
    db_file = outPrefix + "/" + os.path.basename(outPrefix) + ".h5"
    if not os.path.isfile(db_file):
        return

    remove_prev_db = False
    with h5py.File(db_file, 'r') as ref_db:
        sketches = ref_db['sketches']
        for sample_name in sketches:
            knum = sketches[sample_name].attrs['kmers']
            # Compare the stored k-mer sizes against the *requested* ones
            unexpected = [kmer_length for kmer_length in knum
                          if kmer_length not in kmers]
            if unexpected:
                sys.stderr.write("Previously-calculated k-mer size " + str(unexpected[0]) +
                                 " not in requested range (" + str(kmers) + ")\n")
                remove_prev_db = True
                break

    # Only delete once the file handle has been released
    if remove_prev_db:
        sys.stderr.write("Removing old database " + db_file + "\n")
        os.remove(db_file)
def getSketchSize(dbPrefix):
    """Determine sketch size, and ensures consistent in whole database

    ``sys.exit(1)`` is called if samples in the DB have different sketch
    sizes, or if the DB contains no samples.

    Args:
        dbPrefix (str)
            Prefix for databases

    Returns:
        sketchSize (int)
            sketch size (64x C++ definition), read from the ``sketchsize64``
            attribute of each sample
        codonPhased (bool)
            whether the DB used codon phased seeds (``False`` if the
            ``codon_phased`` attribute is absent)
    """
    db_file = dbPrefix + "/" + os.path.basename(dbPrefix) + ".h5"
    sketch_size = None
    with h5py.File(db_file, 'r') as ref_db:
        sketches = ref_db['sketches']
        try:
            codon_phased = sketches.attrs['codon_phased']
        except KeyError:
            # Attribute not present in this DB
            codon_phased = False

        for sample_name in sketches:
            sample_size = sketches[sample_name].attrs['sketchsize64']
            if sketch_size is None:
                sketch_size = sample_size
            elif sample_size != sketch_size:
                sys.stderr.write("Problem with database; sketch size for sample " +
                                 sample_name + " is " + str(sample_size) +
                                 ", but previous samples have sketch sizes of " +
                                 str(sketch_size) + "\n")
                sys.exit(1)

    if sketch_size is None:
        # No samples: nothing to report a size for
        sys.stderr.write("Couldn't find sketches in " + dbPrefix + "\n")
        sys.exit(1)

    return int(sketch_size), codon_phased
def getKmersFromReferenceDatabase(dbPrefix):
    """Get kmers lengths from existing database

    ``sys.exit(1)`` is called if samples in the DB were sketched at
    different k-mer lengths.

    Args:
        dbPrefix (str)
            Prefix for sketch DB files

    Returns:
        kmers (numpy.array)
            Sorted k-mer lengths used in database (empty if the DB has no
            samples)
    """
    db_file = dbPrefix + "/" + os.path.basename(dbPrefix) + ".h5"
    reference_kmers = None
    with h5py.File(db_file, 'r') as ref_db:
        sketches = ref_db['sketches']
        for sample_name in sketches:
            kmer_size = np.asarray(sketches[sample_name].attrs['kmers'])
            if reference_kmers is None:
                reference_kmers = kmer_size
            # array_equal also copes with samples that have a different
            # *number* of k-mer lengths, where elementwise != is not usable
            elif not np.array_equal(kmer_size, reference_kmers):
                sys.stderr.write("Problem with database; kmer lengths inconsistent: " +
                                 str(kmer_size) + " vs " + str(reference_kmers) + "\n")
                sys.exit(1)

    if reference_kmers is None:
        return np.asarray([])
    # np.sort returns a sorted copy
    return np.sort(reference_kmers)
def readDBParams(dbPrefix):
    """Read the sketching parameters stored in an existing database

    Calls :func:`~getKmersFromReferenceDatabase` and, when k-mer lengths were
    found, :func:`~getSketchSize`. ``sys.exit(1)`` is called when no k-mer
    lengths are returned for the database.

    Args:
        dbPrefix (str)
            Prefix for sketch DB files

    Returns:
        kmers (list)
            k-mer lengths used in database
        sketch_sizes (int)
            Sketch size used in database
        codonPhased (bool)
            whether the DB used codon phased seeds
    """
    kmer_lengths = getKmersFromReferenceDatabase(dbPrefix)

    # Guard clause: a database without sketches cannot be used
    if len(kmer_lengths) == 0:
        sys.stderr.write("Couldn't find sketches in " + dbPrefix + "\n")
        sys.exit(1)

    sketch_size, phased = getSketchSize(dbPrefix)
    return kmer_lengths, sketch_size, phased
def getSeqsInDb(dbname):
    """Return a list with the sequences in the passed database

    Args:
        dbname (str)
            Sketches database filename

    Returns:
        seqs (list)
            List of sequence names in sketch DB (keys of the ``sketches``
            group)
    """
    # Context manager so the file is closed even if 'sketches' is missing
    with h5py.File(dbname, 'r') as ref:
        return list(ref['sketches'].keys())
def joinDBs(db1, db2, output, update_random = None, full_names = False):
    """Join two sketch databases with the low-level HDF5 copy interface

    The joined database is first written to ``<prefix>.tmp.h5`` and then
    moved to ``<prefix>.h5``, replacing any existing file of that name.
    ``sys.exit(1)`` is called if a ``RuntimeError`` is raised while joining.

    Args:
        db1 (str)
            Prefix for db1
        db2 (str)
            Prefix for db2
        output (str)
            Prefix for joined output
        update_random (dict)
            If None, the ``random`` group of db1 (if present) is copied to
            the output. Otherwise the random match chances are re-calculated
            on the joined DB when it has more than two samples. May contain
            control arguments strand_preserved and threads
            (see :func:`addRandom`)
        full_names (bool)
            If True, db1 and db2 are the full paths to h5 files, and
            output is the path prefix to which ``.h5`` is appended
    """
    if full_names:
        db1_name = db1
        db2_name = db2
        join_prefix = output
    else:
        join_prefix = output + "/" + os.path.basename(output)
        db1_name = db1 + "/" + os.path.basename(db1) + ".h5"
        db2_name = db2 + "/" + os.path.basename(db2) + ".h5"

    tmp_prefix = join_prefix + ".tmp" # add .tmp in case join_name exists

    with h5py.File(db1_name, 'r') as hdf1, h5py.File(db2_name, 'r') as hdf2:
        hdf_join = h5py.File(tmp_prefix + ".h5", 'w')
        try:
            # Can only copy into new group, so for second file these are
            # appended one at a time
            hdf1.copy('sketches', hdf_join)
            join_grp = hdf_join['sketches']
            read_grp = hdf2['sketches']
            for dataset in read_grp:
                join_grp.copy(read_grp[dataset], dataset)

            # Copy or update random matches
            if update_random is not None:
                threads = 1
                strand_preserved = False
                if isinstance(update_random, dict):
                    threads = update_random.get("threads", threads)
                    strand_preserved = update_random.get("strand_preserved",
                                                         strand_preserved)

                sequence_names = list(join_grp.keys())
                kmer_size = join_grp[sequence_names[0]].attrs['kmers']

                # Need to close before adding random
                hdf_join.close()
                if len(sequence_names) > 2:
                    sys.stderr.write("Updating random match chances\n")
                    pp_sketchlib.addRandom(db_name=tmp_prefix,
                                           samples=sequence_names,
                                           klist=kmer_size,
                                           use_rc=(not strand_preserved),
                                           num_threads=threads)
            elif 'random' in hdf1:
                hdf1.copy('random', hdf_join)
        except RuntimeError as e:
            sys.stderr.write("ERROR: " + str(e) + "\n")
            sys.stderr.write("Joining sketches failed, try running without --update-db\n")
            sys.exit(1)
        finally:
            # Always release the output handle, whichever path was taken
            hdf_join.close()

    # Move results to correct location; os.replace overwrites an existing
    # target on every platform
    os.replace(tmp_prefix + ".h5", join_prefix + ".h5")
def removeFromDB(db_name, out_name, removeSeqs, full_names = False):
    """Remove sketches from the DB with the low-level HDF5 copy interface

    Copies the ``random`` group (if present), the attributes of the
    ``sketches`` group, and every sample not named in ``removeSeqs`` into a
    new file. A warning is written for names in ``removeSeqs`` that were not
    found. ``sys.exit(1)`` is called if a ``RuntimeError`` is raised while
    copying.

    Args:
        db_name (str)
            Prefix for hdf database
        out_name (str)
            Prefix for output (pruned) database. Unless ``full_names`` is
            set, output is written to ``<out_name>/<basename>.tmp.h5``
        removeSeqs (list)
            Names of sequences to remove from database
        full_names (bool)
            If True, db_name and out_name are the full paths to h5 files
    """
    removeSeqs = set(removeSeqs)
    if full_names:
        db_file = db_name
        out_file = out_name
    else:
        db_file = db_name + "/" + os.path.basename(db_name) + ".h5"
        out_file = out_name + "/" + os.path.basename(out_name) + ".tmp.h5"

    removed = []
    # Stays None until the copy loop starts, so the error handler can tell
    # whether a particular sample was being processed
    dataset = None
    with h5py.File(db_file, 'r') as hdf_in, h5py.File(out_file, 'w') as hdf_out:
        try:
            if 'random' in hdf_in:
                hdf_in.copy('random', hdf_out)
            out_grp = hdf_out.create_group('sketches')
            read_grp = hdf_in['sketches']
            for attr_name, attr_val in read_grp.attrs.items():
                out_grp.attrs.create(attr_name, attr_val)

            for dataset in read_grp:
                if dataset in removeSeqs:
                    removed.append(dataset)
                else:
                    out_grp.copy(read_grp[dataset], dataset)
        except RuntimeError as e:
            sys.stderr.write("ERROR: " + str(e) + "\n")
            if dataset is not None:
                sys.stderr.write("Error while deleting sequence " + dataset + "\n")
            sys.exit(1)

    missed = removeSeqs.difference(removed)
    if missed:
        sys.stderr.write("WARNING: Did not find samples to remove:\n")
        # Sorted so the warning is reproducible between runs
        sys.stderr.write("\t".join(sorted(missed)) + "\n")
def constructDatabase(assemblyList, klist, sketch_size, oPrefix,
                      threads, overwrite,
                      strand_preserved, min_count,
                      use_exact, calc_random = True,
                      codon_phased = False,
                      use_gpu = False, deviceid = 0):
    """Sketch the input assemblies at the requested k-mer lengths

    Reads sample names and sequence files with ``readRfile``, optionally
    removes an existing ``<oPrefix>/<basename>.h5``, then calls
    ``pp_sketchlib.constructDatabase``. Random match chances are added in a
    separate step through :func:`~addRandom` when ``calc_random`` is set.

    Args:
        assemblyList (str)
            File with locations of assembly files to be sketched
        klist (list)
            List of k-mer sizes to sketch
        sketch_size (int)
            Size of sketch (``-s`` option)
        oPrefix (str)
            Output prefix for resulting sketch files
        threads (int)
            Number of threads to use
        overwrite (bool)
            Whether to overwrite sketch DBs, if they already exist.
        strand_preserved (bool)
            Ignore reverse complement k-mers
        min_count (int)
            Minimum count of k-mer in reads to include
        use_exact (bool)
            Use exact count of k-mer appearance in reads
        calc_random (bool)
            Add random match chances to DB (turn off for queries)
            (default = True)
        codon_phased (bool)
            Use codon phased seeds
            (default = False)
        use_gpu (bool)
            Use GPU for read sketching
            (default = False)
        deviceid (int)
            GPU device id
            (default = 0)

    Returns:
        names (list)
            List of names included in the database (from rfile)
    """
    # Sample names and their sequence files, as listed in the rfile
    sample_names, sequence_files = readRfile(assemblyList)

    db_prefix = oPrefix + "/" + os.path.basename(oPrefix)
    h5_path = db_prefix + ".h5"
    if overwrite == True and os.path.isfile(h5_path):
        sys.stderr.write("Overwriting db: " + h5_path + "\n")
        os.remove(h5_path)

    # Sketching itself never adds random matches; that is done below
    sketch_args = {
        "db_name": db_prefix,
        "samples": sample_names,
        "files": sequence_files,
        "klist": klist,
        "sketch_size": sketch_size,
        "codon_phased": codon_phased,
        "calc_random": False,
        "use_rc": not strand_preserved,
        "min_count": min_count,
        "exact": use_exact,
        "num_threads": threads,
        "use_gpu": use_gpu,
        "device_id": deviceid,
    }
    pp_sketchlib.constructDatabase(**sketch_args)

    # Random matches are typically on for reference, off for query
    if calc_random:
        addRandom(oPrefix, sample_names, klist, strand_preserved,
                  overwrite = True, threads = threads)

    return sample_names
def addRandom(oPrefix, sequence_names, klist,
              strand_preserved = False, overwrite = False, threads = 1):
    """Add chance of random match to a HDF5 sketch DB

    Does nothing (other than writing a message to stderr) when two or fewer
    sequence names are given, or when the DB already has a ``random`` group
    and ``overwrite`` is not set.

    Args:
        oPrefix (str)
            Sketch database prefix
        sequence_names (list)
            Names of sequences to include in calculation
        klist (list)
            List of k-mer sizes to sketch
        strand_preserved (bool)
            Set true to ignore rc k-mers
        overwrite (bool)
            Set true to overwrite existing random match chances
        threads (int)
            Number of threads to use (default = 1)
    """
    if len(sequence_names) <= 2:
        sys.stderr.write("Cannot add random match chances with this few genomes\n")
        return

    dbname = oPrefix + "/" + os.path.basename(oPrefix)
    # The handle must be released before the library call below, and also
    # on the early return, hence the context manager
    with h5py.File(dbname + ".h5", 'r+') as hdf_in:
        if 'random' in hdf_in:
            if not overwrite:
                sys.stderr.write("Using existing random match chances in DB\n")
                return
            del hdf_in['random']

    pp_sketchlib.addRandom(db_name=dbname,
                           samples=sequence_names,
                           klist=klist,
                           use_rc=(not strand_preserved),
                           num_threads=threads)
def _plotExampleFit(ref_db, query_db, ref_name, query_name, klist, jacobian,
                    threads, plot_prefix, title):
    """Fit and plot the k-mer match curve for a single pair of samples

    Queries the raw and random-corrected Jaccard values for the pair (on the
    CPU), fits both with :func:`~fitKmerCurve` and passes the results to
    ``plot_fit``.
    """
    with stderr_redirected(): # Hide the many progress bars
        raw = pp_sketchlib.queryDatabase(ref_db_name=ref_db,
                                         query_db_name=query_db,
                                         rList=[ref_name],
                                         qList=[query_name],
                                         klist=klist,
                                         random_correct=False,
                                         jaccard=True,
                                         num_threads=threads,
                                         use_gpu=False)
        corrected = pp_sketchlib.queryDatabase(ref_db_name=ref_db,
                                               query_db_name=query_db,
                                               rList=[ref_name],
                                               qList=[query_name],
                                               klist=klist,
                                               random_correct=True,
                                               jaccard=True,
                                               num_threads=threads,
                                               use_gpu=False)
    kmer_lengths = np.asarray(klist)
    # A single pair was queried, so row 0 is that pair
    raw_fit = fitKmerCurve(raw[0], kmer_lengths, jacobian)
    corrected_fit = fitKmerCurve(corrected[0], kmer_lengths, jacobian)
    plot_fit(kmer_lengths,
             raw[0],
             raw_fit,
             corrected[0],
             corrected_fit,
             plot_prefix,
             title)


def queryDatabase(rNames, qNames, dbPrefix, queryPrefix, klist, self = True, number_plot_fits = 0,
                  threads = 1, use_gpu = False, deviceid = 0):
    """Calculate core and accessory distances between query sequences and a sketched database

    For a reference database, runs the query against itself to find all pairwise
    core and accessory distances.

    Uses the relation :math:`pr(a, b) = (1-a)(1-c)^k`

    To get the ref and query name for each row of the returned distances, call to the iterator
    :func:`~PopPUNK.utils.iterDistRows` with the returned refList and queryList

    ``sys.exit(1)`` is called for a non-self query if any name appears in
    both ``rNames`` and ``qNames``.

    Args:
        rNames (list)
            Names of references to query
        qNames (list)
            Names of queries (ignored when ``self`` is set)
        dbPrefix (str)
            Prefix for reference sketch database created by :func:`~constructDatabase`
        queryPrefix (str)
            Prefix for query sketch database created by :func:`~constructDatabase`
        klist (list)
            K-mer sizes to use in the calculation
        self (bool)
            Set true if query = ref
            (default = True)
        number_plot_fits (int)
            If > 0, the number of k-mer length fits to plot.
            Takes random pairs of comparisons and calls :func:`~PopPUNK.plot.plot_fit`.
            Limited by the number of available names
            (default = 0)
        threads (int)
            Number of threads to use in the process
            (default = 1)
        use_gpu (bool)
            Use a GPU for querying
            (default = False)
        deviceid (int)
            Index of the CUDA GPU device to use
            (default = 0)

    Returns:
        distMat (numpy.array)
            Core distances (column 0) and accessory distances (column 1) between
            refList and queryList

    Raises:
        RuntimeError
            If ``self`` is set but ``dbPrefix`` and ``queryPrefix`` differ
    """
    ref_db = dbPrefix + "/" + os.path.basename(dbPrefix)

    if self:
        if dbPrefix != queryPrefix:
            raise RuntimeError("Must use same db for self query")
        query_db = ref_db
        query_names = rNames
    else:
        duplicated = set(rNames).intersection(set(qNames))
        if len(duplicated) > 0:
            sys.stderr.write("Sample names in query are contained in reference database:\n")
            # Trailing newline keeps the last name separate from the next message
            sys.stderr.write("\n".join(duplicated) + "\n")
            sys.stderr.write("Unique names are required!\n")
            sys.exit(1)
        query_db = queryPrefix + "/" + os.path.basename(queryPrefix)
        query_names = qNames

    # Calls to library
    distMat = pp_sketchlib.queryDatabase(ref_db_name=ref_db,
                                         query_db_name=query_db,
                                         rList=rNames,
                                         qList=query_names,
                                         klist=klist,
                                         random_correct=True,
                                         jaccard=False,
                                         num_threads=threads,
                                         use_gpu=use_gpu,
                                         device_id=deviceid)

    # option to plot core/accessory fits. Choose a random number from cmd line option
    if number_plot_fits > 0:
        kmer_lengths = np.asarray(klist)
        jacobian = -np.hstack((np.ones((kmer_lengths.shape[0], 1)),
                               kmer_lengths.reshape(-1, 1)))
        if self:
            if len(rNames) < 2:
                sys.stderr.write("Need at least two samples to plot example fits\n")
            else:
                for plot_idx in range(number_plot_fits):
                    example = sample(rNames, k=2)
                    _plotExampleFit(ref_db, ref_db, example[0], example[1],
                                    klist, jacobian, threads,
                                    ref_db + "_fit_example_" + str(plot_idx + 1),
                                    "Example fit " + str(plot_idx + 1) + " - " +
                                    example[0] + " vs. " + example[1])
        else:
            # sample() cannot draw more items than are available
            n_plots = min(number_plot_fits, len(rNames), len(qNames))
            ref_examples = sample(rNames, k=n_plots)
            query_examples = sample(qNames, k=n_plots)
            # Query each pair on its own so the plotted distances always
            # belong to the pair named in the title, whatever row order a
            # multi-sample query returns
            for plot_idx, (ref_name, query_name) in enumerate(zip(ref_examples,
                                                                  query_examples)):
                _plotExampleFit(ref_db, query_db, ref_name, query_name,
                                klist, jacobian, threads,
                                os.path.join(os.path.dirname(queryPrefix),
                                             os.path.basename(queryPrefix) +
                                             "_fit_example_" + str(plot_idx + 1)),
                                "Example fit " + str(plot_idx + 1) + " - " +
                                ref_name + " vs. " + query_name)

    return distMat
def fitKmerCurve(pairwise, klist, jacobian):
    """Fit the function :math:`pr = (1-a)(1-c)^k`

    The fit is done on the log scale, where the model is linear in k:
    ``log(pr) = log(1-a) + k*log(1-c)``. Both fitted coefficients are
    bounded above by zero.

    Supply ``jacobian = -np.hstack((np.ones((klist.shape[0], 1)), klist.reshape(-1, 1)))``

    Args:
        pairwise (numpy.array)
            Proportion of shared k-mers at k-mer values in klist
        klist (list)
            k-mer sizes used
        jacobian (numpy.array)
            Should be set as above (set once to try and save memory)

    Returns:
        transformed_params (numpy.array)
            Core distance followed by accessory distance. Both are zero if
            the fit raised a ``ValueError``
    """
    def residuals(params, kmer_lengths, log_matches):
        # params[0] = log(1-a), params[1] = log(1-c)
        return log_matches - (params[0] + params[1] * kmer_lengths)

    def fixed_jacobian(params, kmer_lengths, log_matches):
        # The model is linear, so the Jacobian does not depend on params
        return jacobian

    try:
        fit_result = optimize.least_squares(
            fun=residuals,
            x0=[0.0, -0.01],
            jac=fixed_jacobian,
            args=(klist, np.log(pairwise)),
            bounds=([-np.inf, -np.inf], [0, 0]),
        )
        # Back-transform to (a, c)
        accessory_then_core = 1 - np.exp(fit_result.x)
    except ValueError as e:
        sys.stderr.write("Fitting k-mer curve failed: " + format(e) +
                         "\nWith k-mer match values " +
                         np.array2string(pairwise, precision=4, separator=',',suppress_small=True) +
                         "\nCheck for low quality input genomes\n")
        accessory_then_core = [0, 0]

    # Reverse the order so the result is core, accessory
    return np.flipud(accessory_then_core)
def get_database_statistics(prefix):
    """Extract statistics for evaluating databases.

    Args:
        prefix (str)
            Prefix of database

    Returns:
        genome_lengths (list)
            ``length`` attribute of each sample in the database
        ambiguous_bases (list)
            ``missing_bases`` attribute of each sample, in the same order
    """
    db_file = prefix + "/" + os.path.basename(prefix) + ".h5"
    genome_lengths = []
    ambiguous_bases = []
    # Context manager so the file is closed even if an attribute is missing
    with h5py.File(db_file, 'r') as ref_db:
        sketches = ref_db['sketches']
        for sample_name in sketches:
            sample_attrs = sketches[sample_name].attrs
            genome_lengths.append(sample_attrs['length'])
            ambiguous_bases.append(sample_attrs['missing_bases'])
    return genome_lengths, ambiguous_bases