Module hashformers.beamsearch.algorithm
Expand source code
import itertools
import re
from hashformers.beamsearch.data_structures import (
Node,
ProbabilityDictionary
)
from hashformers.beamsearch.model_lm import ModelLM
class Beamsearch(ModelLM):
    """Beam search over the positions where a space can be inserted in a string.

    A hypothesis is an input string with zero or more single spaces inserted.
    Hypotheses are scored with ``self.model.get_probs`` and, after each step,
    only the ``topk`` lowest-scoring hypotheses per input string survive.

    NOTE(review): ``self.model`` and ``self.gpu_batch_size`` are read here but
    never assigned in this class; presumably ``ModelLM.__init__`` sets them.
    """

    # Two consecutive whitespace characters mark a hypothesis as malformed.
    _DOUBLE_WHITESPACE = re.compile(r"\s{2}")

    def __init__(
        self,
        model_name_or_path=None,
        model_type=None,
        device='cuda',
        gpu_batch_size=1
    ):
        """Forward every argument unchanged to ``ModelLM.__init__``."""
        super().__init__(
            model_name_or_path=model_name_or_path,
            model_type=model_type,
            device=device,
            gpu_batch_size=gpu_batch_size)

    def next_step(self, list_of_candidates):
        """Expand every candidate by inserting one more space.

        For each candidate the result contains the candidate itself followed
        by one string per interior position with a space inserted there.
        Strings containing two consecutive whitespace characters are dropped;
        this also drops every expansion of a candidate that already contains
        such a pair. Empty candidates produce nothing.

        Args:
            list_of_candidates: iterable of strings.

        Returns:
            A flat list of strings, in input order.
        """
        has_double_whitespace = self._DOUBLE_WHITESPACE.search
        output = []
        for candidate in list_of_candidates:
            if not candidate:
                continue
            # Keep the unchanged candidate first so that an existing
            # hypothesis is carried over to the next step.
            expansions = [candidate]
            expansions.extend(
                candidate[:pos] + ' ' + candidate[pos:]
                for pos in range(1, len(candidate))
            )
            output.extend(
                text for text in expansions
                if has_double_whitespace(text) is None
            )
        return output

    def update_probabilities(self, tree, prob_dict):
        """Score every string that is not yet in ``prob_dict``.

        Args:
            tree: iterable of batches, each an iterable of strings.
            prob_dict: dict mapping string to score; updated in place.

        Returns:
            The same ``prob_dict`` object.
        """
        for batch in tree:
            # dict.fromkeys removes repeats inside the batch while keeping
            # first-seen order, so no string is sent to the model twice.
            pending = list(dict.fromkeys(
                word for word in batch if word not in prob_dict
            ))
            if not pending:
                continue
            scores = self.model.get_probs(pending)
            for idx, word in enumerate(pending):
                prob_dict[word] = scores[idx]
        return prob_dict

    def reshape_tree(self, tree, measure):
        """Split ``tree`` into consecutive slices of at most ``measure`` items."""
        return [tree[start:start + measure] for start in range(0, len(tree), measure)]

    def flatten_list(self, list_):
        """Concatenate the items of every sub-iterable into one list."""
        return [item for sublist in list_ for item in sublist]

    def trim_tree(self, tree, prob_dict, topk):
        """Keep the ``topk`` lowest-scoring distinct hypotheses per base string.

        Hypotheses are grouped by their text with spaces removed. Grouping
        does not depend on the order of ``tree``, and a hypothesis that occurs
        several times is counted once, so repeats cannot use up beam slots.

        Args:
            tree: iterable of hypothesis strings.
            prob_dict: dict mapping every hypothesis in ``tree`` to its score.
            topk: maximum number of hypotheses kept per group.

        Returns:
            A flat list of hypotheses: groups in first-seen order, each group
            sorted by ascending score.

        Raises:
            KeyError: if a hypothesis has no entry in ``prob_dict``.
        """
        # characters -> {hypothesis: score}; both levels keep insertion order.
        groups = {}
        for hypothesis in tree:
            characters = hypothesis.replace(" ", "")
            groups.setdefault(characters, {}).setdefault(
                hypothesis, prob_dict[hypothesis]
            )
        output = []
        for scored in groups.values():
            # sorted() is stable, so ties keep their first-seen order.
            output.extend(sorted(scored, key=scored.get)[:topk])
        return output

    def run(self, dataset, topk=20, steps=13):
        """Run ``steps`` rounds of expand, score and trim over ``dataset``.

        Args:
            dataset: iterable of strings to start from.
            topk: beam width per input string.
            steps: number of expansion rounds, i.e. the maximum number of
                spaces inserted into any single string.

        Returns:
            A ``ProbabilityDictionary`` wrapping the score of every hypothesis
            that was evaluated, including those trimmed from the beam.
        """
        tree = dataset
        prob_dict = {}
        for _ in range(steps):
            tree = self.next_step(tree)
            tree = self.reshape_tree(tree, self.gpu_batch_size)
            prob_dict = self.update_probabilities(tree, prob_dict)
            tree = self.flatten_list(tree)
            tree = self.trim_tree(tree, prob_dict, topk)
        return ProbabilityDictionary(prob_dict)
Classes
class Beamsearch (model_name_or_path=None, model_type=None, device='cuda', gpu_batch_size=1)
-
Expand source code
class Beamsearch(ModelLM):
    """Beam search over the positions where a space can be inserted in a string.

    A hypothesis is an input string with zero or more single spaces inserted.
    Hypotheses are scored with ``self.model.get_probs`` and, after each step,
    only the ``topk`` lowest-scoring hypotheses per input string survive.

    NOTE(review): ``self.model`` and ``self.gpu_batch_size`` are read here but
    never assigned in this class; presumably ``ModelLM.__init__`` sets them.
    """

    # Two consecutive whitespace characters mark a hypothesis as malformed.
    _DOUBLE_WHITESPACE = re.compile(r"\s{2}")

    def __init__(
        self,
        model_name_or_path=None,
        model_type=None,
        device='cuda',
        gpu_batch_size=1
    ):
        """Forward every argument unchanged to ``ModelLM.__init__``."""
        super().__init__(
            model_name_or_path=model_name_or_path,
            model_type=model_type,
            device=device,
            gpu_batch_size=gpu_batch_size)

    def next_step(self, list_of_candidates):
        """Expand every candidate by inserting one more space.

        For each candidate the result contains the candidate itself followed
        by one string per interior position with a space inserted there.
        Strings containing two consecutive whitespace characters are dropped;
        this also drops every expansion of a candidate that already contains
        such a pair. Empty candidates produce nothing.

        Args:
            list_of_candidates: iterable of strings.

        Returns:
            A flat list of strings, in input order.
        """
        has_double_whitespace = self._DOUBLE_WHITESPACE.search
        output = []
        for candidate in list_of_candidates:
            if not candidate:
                continue
            # Keep the unchanged candidate first so that an existing
            # hypothesis is carried over to the next step.
            expansions = [candidate]
            expansions.extend(
                candidate[:pos] + ' ' + candidate[pos:]
                for pos in range(1, len(candidate))
            )
            output.extend(
                text for text in expansions
                if has_double_whitespace(text) is None
            )
        return output

    def update_probabilities(self, tree, prob_dict):
        """Score every string that is not yet in ``prob_dict``.

        Args:
            tree: iterable of batches, each an iterable of strings.
            prob_dict: dict mapping string to score; updated in place.

        Returns:
            The same ``prob_dict`` object.
        """
        for batch in tree:
            # dict.fromkeys removes repeats inside the batch while keeping
            # first-seen order, so no string is sent to the model twice.
            pending = list(dict.fromkeys(
                word for word in batch if word not in prob_dict
            ))
            if not pending:
                continue
            scores = self.model.get_probs(pending)
            for idx, word in enumerate(pending):
                prob_dict[word] = scores[idx]
        return prob_dict

    def reshape_tree(self, tree, measure):
        """Split ``tree`` into consecutive slices of at most ``measure`` items."""
        return [tree[start:start + measure] for start in range(0, len(tree), measure)]

    def flatten_list(self, list_):
        """Concatenate the items of every sub-iterable into one list."""
        return [item for sublist in list_ for item in sublist]

    def trim_tree(self, tree, prob_dict, topk):
        """Keep the ``topk`` lowest-scoring distinct hypotheses per base string.

        Hypotheses are grouped by their text with spaces removed. Grouping
        does not depend on the order of ``tree``, and a hypothesis that occurs
        several times is counted once, so repeats cannot use up beam slots.

        Args:
            tree: iterable of hypothesis strings.
            prob_dict: dict mapping every hypothesis in ``tree`` to its score.
            topk: maximum number of hypotheses kept per group.

        Returns:
            A flat list of hypotheses: groups in first-seen order, each group
            sorted by ascending score.

        Raises:
            KeyError: if a hypothesis has no entry in ``prob_dict``.
        """
        # characters -> {hypothesis: score}; both levels keep insertion order.
        groups = {}
        for hypothesis in tree:
            characters = hypothesis.replace(" ", "")
            groups.setdefault(characters, {}).setdefault(
                hypothesis, prob_dict[hypothesis]
            )
        output = []
        for scored in groups.values():
            # sorted() is stable, so ties keep their first-seen order.
            output.extend(sorted(scored, key=scored.get)[:topk])
        return output

    def run(self, dataset, topk=20, steps=13):
        """Run ``steps`` rounds of expand, score and trim over ``dataset``.

        Args:
            dataset: iterable of strings to start from.
            topk: beam width per input string.
            steps: number of expansion rounds, i.e. the maximum number of
                spaces inserted into any single string.

        Returns:
            A ``ProbabilityDictionary`` wrapping the score of every hypothesis
            that was evaluated, including those trimmed from the beam.
        """
        tree = dataset
        prob_dict = {}
        for _ in range(steps):
            tree = self.next_step(tree)
            tree = self.reshape_tree(tree, self.gpu_batch_size)
            prob_dict = self.update_probabilities(tree, prob_dict)
            tree = self.flatten_list(tree)
            tree = self.trim_tree(tree, prob_dict, topk)
        return ProbabilityDictionary(prob_dict)
Ancestors
Methods
def flatten_list(self, list_)
-
Expand source code
def flatten_list(self, list_):
    """Concatenate the items of every sub-iterable of ``list_`` into one list.

    Only one level of nesting is removed; items keep their original order.
    """
    flattened = []
    for chunk in list_:
        flattened.extend(chunk)
    return flattened
def next_step(self, list_of_candidates)
-
Expand source code
def next_step(self, list_of_candidates):
    """Expand every candidate by inserting one more space.

    For each candidate the result contains the candidate itself followed by
    one string per interior position with a space inserted there. Strings
    containing two consecutive whitespace characters are dropped; this also
    drops every expansion of a candidate that already contains such a pair.
    Empty candidates produce nothing.

    Args:
        list_of_candidates: iterable of strings.

    Returns:
        A flat list of strings, in input order.
    """
    output = []
    for candidate in list_of_candidates:
        if not candidate:
            continue
        # Keep the unchanged candidate first so that an existing hypothesis
        # is carried over to the next step.
        expansions = [candidate]
        expansions.extend(
            candidate[:pos] + ' ' + candidate[pos:]
            for pos in range(1, len(candidate))
        )
        # Raw-string pattern: "\s" in a plain literal is an invalid escape.
        # A single linear search replaces the lazy scan with lookahead.
        output.extend(
            text for text in expansions
            if re.search(r"\s{2}", text) is None
        )
    return output
def reshape_tree(self, tree, measure)
-
Expand source code
def reshape_tree(self, tree, measure):
    """Split ``tree`` into consecutive slices of at most ``measure`` items.

    The last slice is shorter when ``len(tree)`` is not a multiple of
    ``measure``. An empty ``tree`` gives an empty list.
    """
    batches = []
    for start in range(0, len(tree), measure):
        stop = start + measure
        batches.append(tree[start:stop])
    return batches
def run(self, dataset, topk=20, steps=13)
-
Expand source code
def run(self, dataset, topk=20, steps=13):
    """Run ``steps`` rounds of expand, score and trim over ``dataset``.

    Each round expands the current beam with ``next_step``, batches it by
    ``self.gpu_batch_size``, scores the unseen strings, and trims the beam
    with ``trim_tree``.

    Args:
        dataset: iterable of strings to start from.
        topk: beam width handed to ``trim_tree``.
        steps: number of rounds to run.

    Returns:
        A ``ProbabilityDictionary`` wrapping the score of every string that
        was evaluated, including those trimmed from the beam.
    """
    scores = {}
    beam = dataset
    for _ in range(steps):
        expanded = self.next_step(beam)
        batches = self.reshape_tree(expanded, self.gpu_batch_size)
        scores = self.update_probabilities(batches, scores)
        flat = self.flatten_list(batches)
        beam = self.trim_tree(flat, scores, topk)
    return ProbabilityDictionary(scores)
def trim_tree(self, tree, prob_dict, topk)
-
Expand source code
def trim_tree(self, tree, prob_dict, topk):
    """Keep the ``topk`` lowest-scoring distinct hypotheses per base string.

    Hypotheses are grouped by their text with spaces removed. Grouping does
    not depend on the order of ``tree``, and a hypothesis that occurs several
    times is counted once, so repeats cannot use up beam slots.

    Args:
        tree: iterable of hypothesis strings.
        prob_dict: dict mapping every hypothesis in ``tree`` to its score.
        topk: maximum number of hypotheses kept per group.

    Returns:
        A flat list of hypotheses: groups in first-seen order, each group
        sorted by ascending score.

    Raises:
        KeyError: if a hypothesis has no entry in ``prob_dict``.
    """
    # characters -> {hypothesis: score}; both levels keep insertion order.
    groups = {}
    for hypothesis in tree:
        characters = hypothesis.replace(" ", "")
        groups.setdefault(characters, {}).setdefault(
            hypothesis, prob_dict[hypothesis]
        )
    output = []
    for scored in groups.values():
        # sorted() is stable, so ties keep their first-seen order.
        output.extend(sorted(scored, key=scored.get)[:topk])
    return output
def update_probabilities(self, tree, prob_dict)
-
Expand source code
def update_probabilities(self, tree, prob_dict):
    """Score every string that is not yet in ``prob_dict``.

    Each batch is sent to ``self.model.get_probs`` on its own, after removing
    strings that already have a score and repeats within the batch.

    Args:
        tree: iterable of batches, each an iterable of strings.
        prob_dict: dict mapping string to score; updated in place.

    Returns:
        The same ``prob_dict`` object.
    """
    for batch in tree:
        # dict.fromkeys removes repeats inside the batch while keeping
        # first-seen order, so no string is sent to the model twice.
        pending = list(dict.fromkeys(
            word for word in batch if word not in prob_dict
        ))
        if not pending:
            continue
        scores = self.model.get_probs(pending)
        for idx, word in enumerate(pending):
            prob_dict[word] = scores[idx]
    return prob_dict