import sqlite3
from os import path, makedirs, environ
import requests
import logging
from array import array
from io import StringIO
class Embedding:
    """
    Base class for word embeddings stored in a SQLite3 database.

    Instances are expected to carry a ``db`` attribute holding a
    ``sqlite3.Connection`` (see :meth:`initialize_db`) whose ``embeddings``
    table maps a word to its vector, serialised as packed 32-bit floats.
    """

    @staticmethod
    def path(p):
        """
        Args:
            p (str): relative path.
        Returns:
            str: absolute path to the file, located in the ``$EMBEDDINGS_ROOT`` directory.
        """
        root = environ.get('EMBEDDINGS_ROOT')
        if root is None:
            # Resolve the home directory lazily so that a missing ``$HOME``
            # does not matter when ``$EMBEDDINGS_ROOT`` is set explicitly.
            root = path.join(path.expanduser('~'), '.embeddings')
        return path.join(path.abspath(root), p)

    @staticmethod
    def download_file(url, local_filename, verify=False):
        """
        Downloads a file from an url to a local file.
        Args:
            url (str): url to download from.
            local_filename (str): local file to download to.
            verify (bool or str): forwarded to ``requests.get`` to control TLS certificate verification.
        Returns:
            str: file name of the downloaded file.
        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        # NOTE(review): ``verify`` defaults to False only to preserve the
        # historical behaviour; this leaves downloads open to tampering.
        # Callers should pass ``verify=True`` wherever possible.
        directory = path.dirname(local_filename)
        if directory:
            makedirs(directory, exist_ok=True)
        response = requests.get(url, stream=True, verify=verify)
        try:
            # Fail loudly instead of saving an HTTP error page as the file.
            response.raise_for_status()
            with open(local_filename, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1024):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
        finally:
            response.close()
        return local_filename

    @staticmethod
    def ensure_file(name, url=None, force=False, logger=logging.getLogger(), postprocess=None):
        """
        Ensures that the file requested exists in the cache, downloading it if it does not exist.
        Args:
            name (str): name of the file.
            url (str): url to download the file from, if it doesn't exist.
            force (bool): whether to force the download, regardless of the existence of the file.
            logger (logging.Logger): logger to log results.
            postprocess (function): a function that, if given, will be applied after the file is downloaded. The function has the signature ``f(fname)``
        Returns:
            str: file name of the downloaded file.
        Raises:
            Exception: if a download is needed but no ``url`` was given.
        """
        fname = Embedding.path(name)
        if path.isfile(fname) and not force:
            return fname
        if not url:
            raise Exception('{} does not exist!'.format(fname))
        logger.critical('Downloading from {} to {}'.format(url, fname))
        Embedding.download_file(url, fname)
        if postprocess:
            postprocess(fname)
        return fname

    @staticmethod
    def initialize_db(fname):
        """
        Args:
            fname (str): location of the database.
        Returns:
            db (sqlite3.Connection): a SQLite3 database with an embeddings table.
        """
        directory = path.dirname(fname)
        if directory:
            makedirs(directory, exist_ok=True)
        # open database in autocommit mode by setting isolation_level to None.
        db = sqlite3.connect(fname, isolation_level=None)
        db.cursor().execute('create table if not exists embeddings(word text primary key, emb blob)')
        return db

    def load_memory(self):
        """
        Replaces ``self.db`` with an in-memory copy of the current database.

        The existing connection is closed only after the copy has been
        built, so a failed import leaves the original connection usable.
        """
        # Dump the current database as a SQL script.
        script = ''.join('%s\n' % line for line in self.db.iterdump())
        # Create a database in memory and import the script.
        # open database in autocommit mode by setting isolation_level to None.
        memory_db = sqlite3.connect(":memory:", isolation_level=None)
        memory_db.cursor().executescript(script)
        memory_db.row_factory = sqlite3.Row
        self.db.close()
        self.db = memory_db

    def __len__(self):
        """
        Returns:
            count (int): number of embeddings in the database.
        """
        c = self.db.cursor()
        q = c.execute('select count(*) from embeddings')
        return q.fetchone()[0]

    def insert_batch(self, batch):
        """
        Inserts a batch of embeddings atomically: either every row is
        stored, or (on failure) none of them are.

        Args:
            batch (list): a list of embeddings to insert, each of which is a tuple ``(word, embeddings)``.
        Raises:
            Exception: whatever the database raised; re-raised after the batch is rolled back.
        Example:
        .. code-block:: python
            e = Embedding()
            e.db = e.initialize_db(e.path('mydb.db'))
            e.insert_batch([
                ('hello', [1, 2, 3]),
                ('world', [2, 3, 4]),
                ('!', [3, 4, 5]),
            ])
        """
        c = self.db.cursor()
        binarized = [(word, array('f', emb).tobytes()) for word, emb in batch]
        began = False
        try:
            c.execute("BEGIN TRANSACTION;")
            began = True
            c.executemany("insert into embeddings values (?, ?)", binarized)
            c.execute("COMMIT;")
        except Exception:
            print('insert failed\n{}'.format([w for w, _ in binarized]))
            # Undo only the transaction opened here, so the connection is
            # not left stuck inside a half-finished transaction.
            if began and self.db.in_transaction:
                c.execute("ROLLBACK;")
            raise

    def __contains__(self, w):
        """
        Args:
            w: word to look up.
        Returns:
            whether an embedding for ``w`` exists.
        """
        return self.lookup(w) is not None

    def clear(self):
        """
        Deletes all embeddings from the database.
        """
        c = self.db.cursor()
        c.execute('delete from embeddings')

    def lookup(self, w):
        """
        Args:
            w: word to look up.
        Returns:
            embeddings for ``w``, if it exists.
            ``None``, otherwise.
        """
        c = self.db.cursor()
        q = c.execute('select emb from embeddings where word = :word', {'word': w}).fetchone()
        # The blob holds packed 32-bit floats; decode it back into a list.
        return array('f', q[0]).tolist() if q else None