Module hashformers.beamsearch.model_lm

Expand source code
class ModelLM(object):
    """Select and build a language-model backend from a type name.

    Attributes:
        gpu_batch_size: The ``gpu_batch_size`` passed to the constructor,
            stored unchanged.
        model: The backend instance. It is only set when ``model_type`` is
            ``'gpt2'`` (a ``GPT2LM``) or ``'bert'`` (a ``BertLM``). For any
            other value, including the default ``None``, this attribute is
            never assigned.
    """

    def __init__(self, model_name_or_path=None, model_type=None, device=None, gpu_batch_size=None, gpu_id=0):
        """Store the batch size and instantiate the requested backend.

        Args:
            model_name_or_path: Forwarded positionally to the backend
                constructor.
            model_type: ``'gpt2'`` or ``'bert'``. Any other value builds
                nothing.
            device: Forwarded to ``GPT2LM`` only; ignored for ``'bert'``.
            gpu_batch_size: Stored on ``self`` and forwarded to either
                backend.
            gpu_id: Forwarded to ``BertLM`` only; ignored for ``'gpt2'``.
        """
        self.gpu_batch_size = gpu_batch_size

        if model_type == 'gpt2':
            # Imported lazily so only the selected backend's module is loaded.
            from hashformers.beamsearch.gpt2_lm import GPT2LM
            self.model = GPT2LM(model_name_or_path, device=device, gpu_batch_size=gpu_batch_size)
            return

        if model_type != 'bert':
            # NOTE(review): unrecognised model_type leaves ``self.model``
            # unset; presumably subclasses assign it themselves — confirm.
            return

        from hashformers.beamsearch.bert_lm import BertLM
        self.model = BertLM(model_name_or_path, gpu_batch_size=gpu_batch_size, gpu_id=gpu_id)

Classes

class ModelLM(model_name_or_path=None, model_type=None, device=None, gpu_batch_size=None, gpu_id=0)
Expand source code
class ModelLM(object):
    """Select and build a language-model backend from a type name.

    Attributes:
        gpu_batch_size: The ``gpu_batch_size`` passed to the constructor,
            stored unchanged.
        model: The backend instance. It is only set when ``model_type`` is
            ``'gpt2'`` (a ``GPT2LM``) or ``'bert'`` (a ``BertLM``). For any
            other value, including the default ``None``, this attribute is
            never assigned.
    """

    def __init__(self, model_name_or_path=None, model_type=None, device=None, gpu_batch_size=None, gpu_id=0):
        """Store the batch size and instantiate the requested backend.

        Args:
            model_name_or_path: Forwarded positionally to the backend
                constructor.
            model_type: ``'gpt2'`` or ``'bert'``. Any other value builds
                nothing.
            device: Forwarded to ``GPT2LM`` only; ignored for ``'bert'``.
            gpu_batch_size: Stored on ``self`` and forwarded to either
                backend.
            gpu_id: Forwarded to ``BertLM`` only; ignored for ``'gpt2'``.
        """
        self.gpu_batch_size = gpu_batch_size

        if model_type == 'gpt2':
            # Imported lazily so only the selected backend's module is loaded.
            from hashformers.beamsearch.gpt2_lm import GPT2LM
            self.model = GPT2LM(model_name_or_path, device=device, gpu_batch_size=gpu_batch_size)
            return

        if model_type != 'bert':
            # NOTE(review): unrecognised model_type leaves ``self.model``
            # unset; presumably subclasses assign it themselves — confirm.
            return

        from hashformers.beamsearch.bert_lm import BertLM
        self.model = BertLM(model_name_or_path, gpu_batch_size=gpu_batch_size, gpu_id=gpu_id)

Subclasses