Source code for embeddings.kazuma

import tarfile
import numpy as np
from tqdm import tqdm
from embeddings.embedding import Embedding


def ngrams(sentence, n):
    """
    Returns:
        list: a list of lists of tokens corresponding to the n-grams of the sentence.
    """
    return [sentence[i:i+n] for i in range(len(sentence)-n+1)]
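
# Illustrative check of ngrams() (a sketch, not part of the library): with the
# '#BEGIN#'/'#END#' padding that KazumaCharEmbedding.emb() applies below, the
# character trigrams of 'cat' are:
#
#   >>> ngrams(['#BEGIN#', 'c', 'a', 't', '#END#'], 3)
#   [['#BEGIN#', 'c', 'a'], ['c', 'a', 't'], ['a', 't', '#END#']]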


class KazumaCharEmbedding(Embedding):
    """
    Reference: https://www.logos.t.u-tokyo.ac.jp/~hassy/publications/arxiv2016jmt/
    """

    url = 'https://www.logos.t.u-tokyo.ac.jp/~hassy/publications/arxiv2016jmt/jmt_pre-trained_embeddings.tar.gz'
    size = 874474
    d_emb = 100
    def __init__(self, show_progress=True):
        """
        Args:
            show_progress: whether to print progress.
        """
        self.db = self.initialize_db(self.path('kazuma.db'))

        if len(self) < self.size:
            self.clear()
            self.load_word2emb(show_progress=show_progress)
    def emb(self, w, default='zero'):
        assert default == 'zero', 'only zero default is supported for character embeddings'
        chars = ['#BEGIN#'] + list(w) + ['#END#']
        embs = np.zeros(self.d_emb, dtype=np.float32)
        match = {}
        # Collect the stored vectors for every 2-, 3-, and 4-gram of the
        # padded character sequence.
        for i in [2, 3, 4]:
            grams = ngrams(chars, i)
            for g in grams:
                g = '{}gram-{}'.format(i, ''.join(g))
                e = self.lookup(g)
                if e is not None:
                    match[g] = np.array(e, np.float32)
        # Average the matched n-gram vectors; fall back to the zero vector
        # when no n-gram of w is in the database.
        if match:
            embs = sum(match.values()) / len(match)
        return embs.tolist()
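
    # Trace of emb()'s key construction for w = 'c' (illustrative only): chars
    # becomes ['#BEGIN#', 'c', '#END#'], so the database is probed with the keys
    # '2gram-#BEGIN#c', '2gram-c#END#', and '3gram-#BEGIN#c#END#' (no 4-grams
    # exist for a sequence of length 3). The result is the mean of whichever
    # vectors were found, or the zero vector if none were.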
    def load_word2emb(self, show_progress=True, batch_size=1000):
        fin_name = self.ensure_file('kazuma.tar.gz', url=self.url)
        seen = set()
        with tarfile.open(fin_name, 'r:gz') as fzip:
            ftxt = fzip.extractfile('charNgram.txt')
            content = ftxt.read()
            ftxt.close()
        lines = content.splitlines()
        if show_progress:
            lines = tqdm(lines)
        batch = []
        for line in lines:
            elems = line.decode().rstrip().split()
            # The last d_emb fields are the vector; everything before them is
            # the n-gram key (which may itself contain spaces).
            vec = [float(n) for n in elems[-self.d_emb:]]
            word = ' '.join(elems[:-self.d_emb])
            if word in seen:
                continue
            seen.add(word)
            batch.append((word, vec))
            # Flush to the database in batches to bound memory use.
            if len(batch) == batch_size:
                self.insert_batch(batch)
                batch.clear()
        if batch:
            self.insert_batch(batch)
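
    # Line format of charNgram.txt as implied by the parsing above (schematic,
    # not verbatim data): whitespace-separated, key first, with the last
    # d_emb=100 fields forming the vector:
    #
    #   <ngram-key> <v1> <v2> ... <v100>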

if __name__ == '__main__':
    from time import time
    emb = KazumaCharEmbedding(show_progress=True)
    for w in ['canada', 'vancouver', 'toronto']:
        start = time()
        print('embedding {}'.format(w))
        emb.emb(w)  # perform the lookup so the elapsed time measures real work
        print('took {}s'.format(time() - start))
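
# Usage sketch (hypothetical session; assumes the Embedding base class supplies
# initialize_db, path, ensure_file, lookup, and insert_batch as used above):
#
#   >>> k = KazumaCharEmbedding(show_progress=False)
#   >>> vec = k.emb('canada')
#   >>> len(vec)  # always d_emb entries; all zeros if no n-gram matched
#   100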