Module hashformers.beamsearch.reranker
Expand source code
from hashformers.beamsearch.model_lm import ModelLM
from hashformers.beamsearch.data_structures import (
enforce_prob_dict,
ProbabilityDictionary
)
class Reranker(ModelLM):
    """Rescore candidate strings with the model wrapped by ``ModelLM``.

    The constructor forwards every argument unchanged to
    ``ModelLM.__init__`` and supplies its own defaults; the only behavior
    added here is :meth:`rerank`.
    """

    def __init__(
        self,
        model_name_or_path='bert-base-uncased',
        model_type='bert',
        gpu_batch_size=1,
        gpu_id=0
    ):
        """Initialize the underlying ``ModelLM``.

        Args:
            model_name_or_path: Forwarded to ``ModelLM``; defaults to
                ``'bert-base-uncased'``.
            model_type: Forwarded to ``ModelLM``; defaults to ``'bert'``.
            gpu_batch_size: Forwarded to ``ModelLM``; defaults to ``1``.
            gpu_id: Forwarded to ``ModelLM``; defaults to ``0``.
        """
        super().__init__(
            model_name_or_path=model_name_or_path,
            model_type=model_type,
            gpu_batch_size=gpu_batch_size,
            gpu_id=gpu_id
        )

    def rerank(self, data):
        """Score the candidates found in ``data`` and return the new scores.

        Only the keys of the normalized input are used; whatever values the
        input carried are not read, and every candidate is mapped to the
        value returned for it by ``self.model.get_probs``.

        Args:
            data: Anything accepted by ``enforce_prob_dict``.

        Returns:
            A ``ProbabilityDictionary`` mapping each candidate to its score.
        """
        input_data = enforce_prob_dict(data)
        # Materialize the keys once so that scoring and pairing below see
        # the candidates in the same order.
        candidates = list(input_data.dictionary.keys())
        scores = self.model.get_probs(candidates)
        # NOTE(review): zip pairs positionally and stops at the shorter
        # sequence -- presumably get_probs returns exactly one score per
        # candidate, in input order; confirm against ModelLM.
        rank = dict(zip(candidates, scores))
        return ProbabilityDictionary(rank)
Classes
class Reranker (model_name_or_path='bert-base-uncased', model_type='bert', gpu_batch_size=1, gpu_id=0)
-
Expand source code
class Reranker(ModelLM):
    """Rescore candidate strings with the model wrapped by ``ModelLM``.

    The constructor forwards every argument unchanged to
    ``ModelLM.__init__`` and supplies its own defaults; the only behavior
    added here is :meth:`rerank`.
    """

    def __init__(
        self,
        model_name_or_path='bert-base-uncased',
        model_type='bert',
        gpu_batch_size=1,
        gpu_id=0
    ):
        """Initialize the underlying ``ModelLM``.

        Args:
            model_name_or_path: Forwarded to ``ModelLM``; defaults to
                ``'bert-base-uncased'``.
            model_type: Forwarded to ``ModelLM``; defaults to ``'bert'``.
            gpu_batch_size: Forwarded to ``ModelLM``; defaults to ``1``.
            gpu_id: Forwarded to ``ModelLM``; defaults to ``0``.
        """
        super().__init__(
            model_name_or_path=model_name_or_path,
            model_type=model_type,
            gpu_batch_size=gpu_batch_size,
            gpu_id=gpu_id
        )

    def rerank(self, data):
        """Score the candidates found in ``data`` and return the new scores.

        Only the keys of the normalized input are used; whatever values the
        input carried are not read, and every candidate is mapped to the
        value returned for it by ``self.model.get_probs``.

        Args:
            data: Anything accepted by ``enforce_prob_dict``.

        Returns:
            A ``ProbabilityDictionary`` mapping each candidate to its score.
        """
        input_data = enforce_prob_dict(data)
        # Materialize the keys once so that scoring and pairing below see
        # the candidates in the same order.
        candidates = list(input_data.dictionary.keys())
        scores = self.model.get_probs(candidates)
        # NOTE(review): zip pairs positionally and stops at the shorter
        # sequence -- presumably get_probs returns exactly one score per
        # candidate, in input order; confirm against ModelLM.
        rank = dict(zip(candidates, scores))
        return ProbabilityDictionary(rank)
Ancestors
Methods
def rerank(self, data)
-
Expand source code
def rerank(self, data):
    """Score the candidates found in ``data`` and return the new scores.

    Only the keys of the normalized input are used; whatever values the
    input carried are not read, and every candidate is mapped to the
    value returned for it by ``self.model.get_probs``.

    Args:
        data: Anything accepted by ``enforce_prob_dict``.

    Returns:
        A ``ProbabilityDictionary`` mapping each candidate to its score.
    """
    input_data = enforce_prob_dict(data)
    # Materialize the keys once so that scoring and pairing below see
    # the candidates in the same order.
    candidates = list(input_data.dictionary.keys())
    scores = self.model.get_probs(candidates)
    # NOTE(review): zip pairs positionally and stops at the shorter
    # sequence -- presumably get_probs returns exactly one score per
    # candidate, in input order; confirm against ModelLM.
    rank = dict(zip(candidates, scores))
    return ProbabilityDictionary(rank)