Source code for datasketch.lsh

from collections import defaultdict


_integration_precision = 0.001
def _integration(f, a, b):
    # Midpoint-rule numerical integration of f over [a, b] with a fixed step.
    p = _integration_precision
    area = 0.0
    x = a
    while x < b:
        area += f(x+0.5*p)*p
        x += p
    # Return a (value, error) tuple to mirror scipy.integrate.quad's interface.
    return area, None

try:
    from scipy.integrate import quad as integrate
except ImportError:
    # Fall back to the midpoint-rule integrator when scipy is not installed.
    integrate = _integration
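
# A minimal sketch (not part of the original module) illustrating the fallback:
# `integrate` keeps the same (value, error) return shape as scipy.integrate.quad,
# so callers can unpack the tuple either way. For example, integrating
# f(x) = x over [0, 1] gives approximately 0.5 with both implementations.
def _example_integrate(f=lambda x: x, a=0.0, b=1.0):
    area, _ = integrate(f, a, b)
    return area  # ~0.5 for the default arguments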


def _false_positive_probability(threshold, b, r):
    _probability = lambda s : 1 - (1 - s**float(r))**float(b)
    a, err = integrate(_probability, 0.0, threshold) 
    return a


def _false_negative_probability(threshold, b, r):
    _probability = lambda s : 1 - (1 - (1 - s**float(r))**float(b))
    a, err = integrate(_probability, threshold, 1.0)
    return a
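

# A minimal sketch (not in the original source) of the curve being integrated
# above: with b bands of r hash values each, two sets with Jaccard similarity s
# collide in at least one band with probability 1 - (1 - s**r)**b. The false
# positive integral accumulates this curve on [0, threshold]; the false negative
# integral accumulates its complement on [threshold, 1].
def _example_collision_probability(s, b, r):
    return 1 - (1 - s**float(r))**float(b)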


def _optimal_param(threshold, num_perm, false_positive_weight,
        false_negative_weight):
    '''
    Compute the optimal `MinHashLSH` parameters (b, r) that minimize the
    weighted sum of the false positive and false negative probabilities.
    '''
    min_error = float("inf")
    opt = (0, 0)
    for b in range(1, num_perm+1):
        max_r = int(num_perm / b)
        for r in range(1, max_r+1):
            fp = _false_positive_probability(threshold, b, r)
            fn = _false_negative_probability(threshold, b, r)
            error = fp*false_positive_weight + fn*false_negative_weight
            if error < min_error:
                min_error = error
                opt = (b, r)
    return opt
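

# A minimal usage sketch (illustrative only, not part of the library): for a
# 0.5 threshold, 128 permutations, and equal weights, _optimal_param searches
# every (b, r) with b * r <= 128 and returns the band/row split with the
# smallest weighted error. MinHashLSH.__init__ below calls it the same way.
def _example_optimal_param(threshold=0.5, num_perm=128):
    b, r = _optimal_param(threshold, num_perm, 0.5, 0.5)
    return b, r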


class MinHashLSH(object):
    '''
    The Locality Sensitive Hashing index with MinHash (MinHash LSH).
    It supports queries with a `Jaccard similarity`_ threshold.
    Reference: `Chapter 3, Mining of Massive Datasets <http://www.mmds.org/>`_.

    Args:
        threshold (float): The Jaccard similarity threshold between 0.0 and
            1.0. The initialized MinHash LSH will be optimized for the
            threshold by minimizing the false positive and false negative.
        num_perm (int, optional): The number of permutation functions used
            by the MinHash to be indexed. For weighted MinHash, this is the
            sample size (`sample_size`).
        weights (tuple, optional): Used to adjust the relative importance of
            minimizing false positive and false negative when optimizing
            for the Jaccard similarity threshold. `weights` is a tuple in
            the format of
            :code:`(false_positive_weight, false_negative_weight)`.

    Note:
        The MinHash LSH index also works with weighted Jaccard similarity
        and weighted MinHash without modification.
    '''

    def __init__(self, threshold=0.9, num_perm=128, weights=(0.5,0.5)):
        if threshold > 1.0 or threshold < 0.0:
            raise ValueError("threshold must be in [0.0, 1.0]")
        if num_perm < 2:
            raise ValueError("Too few permutation functions")
        if any(w < 0.0 or w > 1.0 for w in weights):
            raise ValueError("Weight must be in [0.0, 1.0]")
        if sum(weights) != 1.0:
            raise ValueError("Weights must sum to 1.0")
        self.threshold = threshold
        self.h = num_perm
        false_positive_weight, false_negative_weight = weights
        self.b, self.r = _optimal_param(threshold, num_perm,
                false_positive_weight, false_negative_weight)
        self.hashtables = [defaultdict(list) for _ in range(self.b)]
        self.hashranges = [(i*self.r, (i+1)*self.r) for i in range(self.b)]
        self.keys = dict()
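
    # Note (added commentary, not in the original source): self.hashranges
    # splits the num_perm hash values of every MinHash into self.b contiguous
    # bands of self.r values each. Because b * r <= num_perm, any trailing
    # hash values beyond b * r are simply not used by the index.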

    def insert(self, key, minhash):
        '''
        Insert a unique key into the index, together with a MinHash
        (or weighted MinHash) of the set referenced by the key.

        Args:
            key (hashable): The unique identifier of the set.
            minhash (datasketch.MinHash): The MinHash of the set.
        '''
        if len(minhash) != self.h:
            raise ValueError("Expecting minhash with length %d, got %d"
                    % (self.h, len(minhash)))
        if key in self.keys:
            raise ValueError("The given key already exists")
        # Store the per-band digests so remove() can locate the key later.
        self.keys[key] = [self._H(minhash.hashvalues[start:end])
                for start, end in self.hashranges]
        for H, hashtable in zip(self.keys[key], self.hashtables):
            hashtable[H].append(key)

    def query(self, minhash):
        '''
        Given the MinHash of the query set, retrieve the keys that reference
        sets with Jaccard similarities greater than the threshold.

        Args:
            minhash (datasketch.MinHash): The MinHash of the query set.

        Returns:
            `list` of keys.
        '''
        if len(minhash) != self.h:
            raise ValueError("Expecting minhash with length %d, got %d"
                    % (self.h, len(minhash)))
        candidates = set()
        for (start, end), hashtable in zip(self.hashranges, self.hashtables):
            H = self._H(minhash.hashvalues[start:end])
            if H in hashtable:
                for key in hashtable[H]:
                    candidates.add(key)
        return list(candidates)

    def __contains__(self, key):
        '''
        Args:
            key (hashable): The unique identifier of a set.

        Returns:
            bool: True only if the key exists in the index.
        '''
        return key in self.keys

    def remove(self, key):
        '''
        Remove the key from the index.

        Args:
            key (hashable): The unique identifier of a set.
        '''
        if key not in self.keys:
            raise ValueError("The given key does not exist")
        for H, hashtable in zip(self.keys[key], self.hashtables):
            hashtable[H].remove(key)
            if not hashtable[H]:
                del hashtable[H]
        self.keys.pop(key)

    def is_empty(self):
        '''
        Returns:
            bool: Check if the index is empty.
        '''
        # Every insert populates all hash tables, so one empty table implies
        # the whole index is empty.
        return any(len(t) == 0 for t in self.hashtables)

    def _H(self, hs):
        # Serialize a band of hash values into a hashable bytes key.
        return bytes(hs.byteswap().data)
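

# A minimal end-to-end usage sketch (illustrative, not part of the library
# source). It assumes the MinHash class from the datasketch package with its
# usual update()/num_perm API; both MinHash objects must be built with the
# same num_perm as the index.
def _example_minhash_lsh_usage():
    from datasketch import MinHash

    m1, m2 = MinHash(num_perm=128), MinHash(num_perm=128)
    for d in ("apple", "banana", "cherry"):
        m1.update(d.encode("utf8"))
    for d in ("apple", "banana", "durian"):
        m2.update(d.encode("utf8"))

    lsh = MinHashLSH(threshold=0.5, num_perm=128)
    lsh.insert("set1", m1)   # index m1 under the key "set1"
    result = lsh.query(m2)   # candidate keys with similarity likely above 0.5
    lsh.remove("set1")       # keys can also be removed again
    return result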