Module ktrain.text.qa.core

Expand source code
from ...imports import *
from ... import utils as U
from .. import textutils as TU
from .. import preprocessor as tpp


from whoosh import index
from whoosh.fields import *
from whoosh import qparser
from whoosh.qparser import QueryParser


#from transformers import TFBertForQuestionAnswering
#from transformers import BertTokenizer
from transformers import TFAutoModelForQuestionAnswering
from transformers import AutoTokenizer
LOWCONF = -10000

def _answers2df(answers):
    dfdata = []
    for a in answers:
        answer_text = a['answer']
        snippet_html = '<div>' +a['sentence_beginning'] + " <font color='red'>"+a['answer']+"</font> "+a['sentence_end']+'</div>'
        confidence = a['confidence']
        doc_key = a['reference']
        dfdata.append([answer_text, snippet_html, confidence, doc_key])
    df = pd.DataFrame(dfdata, columns = ['Candidate Answer', 'Context',  'Confidence', 'Document Reference'])
    if "\t" in answers[0]['reference']:
        df['Document Reference'] = df['Document Reference'].apply(lambda x: '<a href="{}" target="_blank">{}</a>'.format(x.split('\t')[1], x.split('\t')[0]))
    return df



def display_answers(answers):
    if not answers: return
    df = _answers2df(answers)
    from IPython.core.display import display, HTML
    return display(HTML(df.to_html(render_links=True, escape=False)))


def _process_question(question, include_np=False):
    if include_np:
        try:
            # attempt to use extract_noun_phrases first if textblob is installed
            np_list = ['"%s"' % (np) for np in TU.extract_noun_phrases(question) if len(np.split()) > 1]
            q_tokens = TU.tokenize(question, join_tokens=False)
            q_tokens.extend(np_list)
            return " ".join(q_tokens)
        except:
            import warnings
            warnings.warn('TextBlob is not currently installed, so falling back to include_np=False with no extra question processing. '+\
                          'To install: pip install textblob')
            return TU.tokenize(question, join_tokens=True)
    else:
        return TU.tokenize(question, join_tokens=True)



class QA(ABC):
    """
    Base class for QA
    """

    def __init__(self, bert_squad_model='bert-large-uncased-whole-word-masking-finetuned-squad',
                 bert_emb_model='bert-base-uncased'):
        self.model_name = bert_squad_model
        try:
            self.model = TFAutoModelForQuestionAnswering.from_pretrained(self.model_name)
        except:
            self.model = TFAutoModelForQuestionAnswering.from_pretrained(self.model_name, from_pt=True)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.maxlen = 512
        self.te = tpp.TransformerEmbedding(bert_emb_model, layers=[-2])


    @abstractmethod
    def search(self, query):
        pass

    def predict_squad(self, documents, question):
        """ 
        Generates candidate answers to the <question> provided given <documents> as contexts.
        """
        if isinstance(documents, str): documents = [documents]
        sequences = [[question, d] for d in documents]
        batch = self.tokenizer.batch_encode_plus(sequences, return_tensors='tf', max_length=512, truncation='only_second', padding=True)
        tokens_batch = list( map(self.tokenizer.convert_ids_to_tokens, batch['input_ids']))

        # Added from: https://github.com/huggingface/transformers/commit/16ce15ed4bd0865d24a94aa839a44cf0f400ef50
        if U.get_hf_model_name(self.model_name) in  ['xlm', 'roberta', 'distilbert']:
           start_scores, end_scores = self.model(batch['input_ids'], attention_mask=batch['attention_mask'], return_dict=False)
        else:
           start_scores, end_scores = self.model(batch['input_ids'], attention_mask=batch['attention_mask'], 
                                                 token_type_ids=batch['token_type_ids'], return_dict=False)
        start_scores = start_scores[:,1:-1]
        end_scores = end_scores[:,1:-1]
        answer_starts = np.argmax(start_scores, axis=1)
        answer_ends = np.argmax(end_scores, axis=1)

        answers = []
        for i, tokens in enumerate(tokens_batch):
            answer_start = answer_starts[i]
            answer_end = answer_ends[i]
            answer = self._reconstruct_text(tokens, answer_start, answer_end+2)
            if answer.startswith('. ') or answer.startswith(', '):
                answer = answer[2:]  
            sep_index = tokens.index('[SEP]')
            full_txt_tokens = tokens[sep_index+1:]
            paragraph_bert = self._reconstruct_text(full_txt_tokens)

            ans={}
            ans['answer'] = answer
            if answer.startswith('[CLS]') or answer_end < sep_index or answer.endswith('[SEP]'):
                ans['confidence'] = LOWCONF
            else:
                #confidence = torch.max(start_scores) + torch.max(end_scores)
                #confidence = np.log(confidence.item())
                ans['confidence'] = start_scores[i,answer_start]+end_scores[i,answer_end]
            ans['start'] = answer_start
            ans['end'] = answer_end
            ans['context'] = paragraph_bert
            answers.append(ans)
        #if len(answers) == 1: answers = answers[0]
        return answers




    def _reconstruct_text(self, tokens, start=0, stop=-1):
        """
        Reconstruct text of *either* question or answer
        """
        tokens = tokens[start: stop]
        #if '[SEP]' in tokens:
            #sepind = tokens.index('[SEP]')
            #tokens = tokens[sepind+1:]
        txt = ' '.join(tokens)
        txt = txt.replace('[SEP]', '') # added for batch_encode_plus - removes [SEP] before [PAD]
        txt = txt.replace('[PAD]', '') # added for batch_encode_plus - removes [PAD]
        txt = txt.replace(' ##', '')
        txt = txt.replace('##', '')
        txt = txt.strip()
        txt = " ".join(txt.split())
        txt = txt.replace(' .', '.')
        txt = txt.replace('( ', '(')
        txt = txt.replace(' )', ')')
        txt = txt.replace(' - ', '-')
        txt_list = txt.split(' , ')
        txt = ''
        length = len(txt_list)
        if length == 1:
            return txt_list[0]
        new_list =[]
        for i,t in enumerate(txt_list):
            if i < length -1:
                if t[-1].isdigit() and txt_list[i+1][0].isdigit():
                    new_list += [t,',']
                else:
                    new_list += [t, ', ']
            else:
                new_list += [t]
        return ''.join(new_list)


    def _expand_answer(self, answer):
        """
        expand answer to include more of the context
        """
        full_abs = answer['context']
        bert_ans = answer['answer']
        split_abs = full_abs.split(bert_ans)
        sent_beginning = split_abs[0][split_abs[0].rfind('.')+1:]
        if len(split_abs) == 1:
            sent_end_pos = len(full_abs)
            sent_end =''
        else:
            sent_end_pos = split_abs[1].find('. ')+1
            if sent_end_pos == 0:
                sent_end = split_abs[1]
            else:
                sent_end = split_abs[1][:sent_end_pos]
            
        answer['full_answer'] = sent_beginning+bert_ans+sent_end
        answer['full_answer'] = answer['full_answer'].strip()
        answer['sentence_beginning'] = sent_beginning
        answer['sentence_end'] = sent_end
        return answer



    def ask(self, question, query=None, batch_size=8, n_docs_considered=10, n_answers=50, 
            rerank_threshold=0.015, include_np=False):
        """
        ```
        submit question to obtain candidate answers

        Args:
          question(str): question in the form of a string
          query(str): Optional. If not None, words in query will be used to retrieve contexts instead of words in question
          batch_size(int):  number of question-context pairs fed to model at each iteration
                            Default:8
                            Increase for faster answer-retrieval.
                            Decrease to reduce memory (if out-of-memory errors occur).
          n_docs_considered(int): number of top search results that will
                                  be searched for answer
                                  default:10
          n_answers(int): maximum number of candidate answers to return
                          default:50
          rerank_threshold(float): rerank top answers with confidence >= rerank_threshold
                                 based on semantic similarity between question and answer.
                                 This can help bump the correct answer closer to the top.
                                 default:0.015.
                                 If None, no re-ranking is performed.
          include_np(bool):  If True, noun phrases will be extracted from question and included
                             in query that retrieves documents likely to contain candidate answers.
                             This may be useful if you ask a question about artificial intelligence
                             and the answers returned pertain just to intelligence, for example.
                             Note: include_np=True requires textblob be installed.
                             Default:False
        Returns:
          list
        ```
        """
        # locate candidate document contexts
        paragraphs = []
        refs = []
        #doc_results = self.search(question, limit=n_docs_considered)
        doc_results = self.search(_process_question(query if query is not None else question, include_np=include_np), limit=n_docs_considered)
        if not doc_results: 
            warnings.warn('No documents matched words in question (or query if supplied)')
            return []
        # extract paragraphs as contexts
        contexts = []
        refs = []
        for doc_result in doc_results:
            rawtext = doc_result.get('rawtext', '')
            reference = doc_result.get('reference', '')
            if len(self.tokenizer.tokenize(rawtext)) < self.maxlen:
                contexts.append(rawtext)
                refs.append(reference)
            else:
                paragraphs = TU.paragraph_tokenize(rawtext, join_sentences=True)
                contexts.extend(paragraphs)
                refs.extend([reference] * len(paragraphs))


        #for doc_result in doc_results:
            #rawtext = doc_result.get('rawtext', '')
            #reference = doc_result.get('reference', '')
            #if len(self.tokenizer.tokenize(rawtext)) < self.maxlen:
                #paragraphs.append(rawtext)
                #refs.append(reference)
                #continue
            #plist = TU.paragraph_tokenize(rawtext, join_sentences=True)
            #paragraphs.extend(plist)
            #refs.extend([reference]*len(plist))


        # batchify contexts
        #return contexts
        if batch_size  > len(contexts): batch_size = len(contexts)
        #if len(contexts) >= 100 and batch_size==8:
            #warnings.warn('TIP: Try increasing batch_size to speedup ask predictions')
        num_chunks = math.ceil(len(contexts)/batch_size)
        context_batches = list( U.list2chunks(contexts, n=num_chunks) )


        # locate candidate answers
        answers = []
        mb = master_bar(range(1))
        answer_batches = []
        for i in mb:
            idx = 0
            for batch_id, contexts in enumerate(progress_bar(context_batches, parent=mb)):
                answer_batch = self.predict_squad(contexts, question)
                answer_batches.extend(answer_batch)
                for answer in answer_batch:
                    idx+=1
                    if not answer['answer'] or answer['confidence'] <-100: continue
                    answer['confidence'] = answer['confidence'].numpy()
                    answer['reference'] = refs[idx-1]
                    answer = self._expand_answer(answer)
                    answers.append(answer)

                mb.child.comment = f'generating candidate answers'

        if not answers: return answers # fix for #307
        answers = sorted(answers, key = lambda k:k['confidence'], reverse=True)
        if n_answers is not None:
            answers = answers[:n_answers]

        # transform confidence scores
        confidences = [a['confidence'] for a in answers]
        max_conf = max(confidences)
        total = 0.0
        exp_scores = []
        for c in confidences:
            s = np.exp(c-max_conf)
            exp_scores.append(s)
        total = sum(exp_scores)
        for idx,c in enumerate(confidences):
            answers[idx]['confidence'] = exp_scores[idx]/total

        if rerank_threshold is None:
            return answers

        # re-rank
        top_confidences = [a['confidence'] for idx, a in enumerate(answers) if a['confidence']> rerank_threshold]
        v1 = self.te.embed(question, word_level=False)
        for idx, answer in enumerate(answers):
            #if idx >= rerank_top_n: 
            if answer['confidence'] <= rerank_threshold:
                answer['similarity_score'] = 0.0
                continue
            v2 = self.te.embed(answer['full_answer'], word_level=False)
            score = v1 @ v2.T / (np.linalg.norm(v1)*np.linalg.norm(v2))
            answer['similarity_score'] = float(np.squeeze(score))
            answer['confidence'] = top_confidences[idx]
        answers = sorted(answers, key = lambda k:(k['similarity_score'], k['confidence']), reverse=True)
        for idx, confidence in enumerate(top_confidences):
            answers[idx]['confidence'] = confidence


        return answers


    def display_answers(self, answers):
        return display_answers(answers)



class SimpleQA(QA):
    """
    SimpleQA: Question-Answering on a list of texts
    """
    def __init__(self, index_dir, 
                 bert_squad_model='bert-large-uncased-whole-word-masking-finetuned-squad',
                 bert_emb_model='bert-base-uncased'):
        """
        ```
        SimpleQA constructor
        Args:
          index_dir(str):  path to index directory created by SimpleQA.initialize_index
          bert_squad_model(str): name of BERT SQUAD model to use
          bert_emb_model(str): BERT model to use to generate embeddings for semantic similarity
        ```
        """

        self.index_dir = index_dir
        try:
            ix = index.open_dir(self.index_dir)
        except:
            raise ValueError('index_dir has not yet been created - please call SimpleQA.initialize_index("%s")' % (self.index_dir))
        super().__init__(bert_squad_model=bert_squad_model, bert_emb_model=bert_emb_model)


    def _open_ix(self):
        return index.open_dir(self.index_dir)


    @classmethod
    def initialize_index(cls, index_dir):
        schema = Schema(reference=ID(stored=True), content=TEXT, rawtext=TEXT(stored=True))
        if not os.path.exists(index_dir):
            os.makedirs(index_dir)
        else:
            raise ValueError('There is already an existing directory or file with path %s' % (index_dir))
        ix = index.create_in(index_dir, schema)
        return ix

    @classmethod
    def index_from_list(cls, docs, index_dir, commit_every=1024, breakup_docs=True,
                        procs=1, limitmb=256, multisegment=False, min_words=20, references=None):
        """
        ```
        index documents from list.
        The procs, limitmb, and especially multisegment arguments can be used to 
        speed up indexing, if it is too slow.  Please see the whoosh documentation
        for more information on these parameters:  https://whoosh.readthedocs.io/en/latest/batch.html
        Args:
          docs(list): list of strings representing documents
          index_dir(str): path to index directory (see initialize_index)
          commit_every(int): commit after adding this many documents
          breakup_docs(bool): break up documents into smaller paragraphs and treat those as the documents.
                              This can potentially improve the speed at which answers are returned by the ask method
                              when documents being searched are longer.
          procs(int): number of processors
          limitmb(int): memory limit in MB for each process
          multisegment(bool): new segments written instead of merging
          min_words(int):  minimum words for a document (or paragraph extracted from document when breakup_docs=True) to be included in index.
                           Useful for pruning contexts that are unlikely to contain useful answers
          references(list): List of strings containing a reference (e.g., file name) for each document in docs.
                            Each string is treated as a label for the document (e.g., file name, MD5 hash, etc.):
                               Example:  ['some_file.pdf', 'some_other_file.pdf', ...]
                            Strings can also be hyperlinks in which case the label and URL should be separated by a single tab character:
                               Example: ['ktrain_article\thttps://arxiv.org/pdf/2004.10703v4.pdf', ...]

                            These references will be returned in the output of the ask method.
                            If strings are  hyperlinks, then they will automatically be made clickable when the display_answers function
                            displays candidate answers in a pandas DataFrame.

                            If references is None, the index of element in docs is used as reference.
        ```
        """
        if not isinstance(docs, (np.ndarray, list)): raise ValueError('docs must be a list of strings')
        if references is not None and not isinstance(references, (np.ndarray, list)): raise ValueError('references must be a list of strings')
        if references is not None and len(references) != len(docs): raise ValueError('lengths of docs and references must be equal')

        ix = index.open_dir(index_dir)
        writer = ix.writer(procs=procs, limitmb=limitmb, multisegment=multisegment)

        mb = master_bar(range(1))
        for i in mb:
            for idx, doc in enumerate(progress_bar(docs, parent=mb)):
                reference = "%s" % (idx) if references is None else references[idx]

                if breakup_docs:
                    small_docs = TU.paragraph_tokenize(doc, join_sentences=True, lang='en')
                    refs = [reference] * len(small_docs)
                    for i, small_doc in enumerate(small_docs):
                        if len(small_doc.split()) < min_words: continue
                        content = small_doc
                        reference = refs[i]
                        writer.add_document(reference=reference, content=content, rawtext=content)
                else:
                    if len(doc.split()) < min_words: continue
                    content = doc 
                    writer.add_document(reference=reference, content=content, rawtext=content)

                idx +=1
                if idx % commit_every == 0:
                    writer.commit()
                    #writer = ix.writer()
                    writer = ix.writer(procs=procs, limitmb=limitmb, multisegment=multisegment)
                mb.child.comment = f'indexing documents'
            writer.commit()
            #mb.write(f'Finished indexing documents')
        return


    @classmethod
    def index_from_folder(cls, folder_path, index_dir,  use_text_extraction=False, commit_every=1024, breakup_docs=True, 
                          min_words=20, encoding='utf-8', procs=1, limitmb=256, multisegment=False, verbose=1):
        """
        ```
        index all plain text documents within a folder.
        The procs, limitmb, and especially multisegment arguments can be used to 
        speed up indexing, if it is too slow.  Please see the whoosh documentation
        for more information on these parameters:  https://whoosh.readthedocs.io/en/latest/batch.html

        Args:
          folder_path(str): path to folder containing plain text documents (e.g., .txt files)
          index_dir(str): path to index directory (see initialize_index)
          use_text_extraction(bool): If True, the  `textract` package will be used to index text from various
                                     file types including PDF, MS Word, and MS PowerPoint (in addition to plain text files).
                                     If False, only plain text files will be indexed.
          commit_every(int): commit after adding this many documents
          breakup_docs(bool): break up documents into smaller paragraphs and treat those as the documents.
                              This can potentially improve the speed at which answers are returned by the ask method
                              when documents being searched are longer.
          min_words(int):  minimum words for a document (or paragraph extracted from document when breakup_docs=True) to be included in index.
                           Useful for pruning contexts that are unlikely to contain useful answers
          encoding(str): encoding to use when reading document files from disk
          procs(int): number of processors
          limitmb(int): memory limit in MB for each process
          multisegment(bool): new segments written instead of merging
          verbose(bool): verbosity
        ```
        """
        if use_text_extraction:
            try:
                import textract
            except ImportError:
                raise Exception('use_text_extraction=True requires textract:   pip install textract')


        if not os.path.isdir(folder_path): raise ValueError('folder_path is not a valid folder')
        if folder_path[-1] != os.sep: folder_path += os.sep
        ix = index.open_dir(index_dir)
        writer = ix.writer(procs=procs, limitmb=limitmb, multisegment=multisegment)
        for idx, fpath in enumerate(TU.extract_filenames(folder_path)):
            reference = "%s" % (fpath.join(fpath.split(folder_path)[1:]))
            if TU.is_txt(fpath):
                with open(fpath, 'r', encoding=encoding) as f:
                    doc = f.read()
            else:
                if use_text_extraction:
                    try:
                        doc = textract.process(fpath)
                        doc = doc.decode('utf-8', 'ignore')
                    except:
                        if verbose:
                            warnings.warn('Could not extract text from %s' % (fpath))
                        continue
                else:
                    continue

            if breakup_docs:
                small_docs = TU.paragraph_tokenize(doc, join_sentences=True, lang='en')
                refs = [reference] * len(small_docs)
                for i, small_doc in enumerate(small_docs):
                    if len(small_doc.split()) < min_words: continue
                    content = small_doc
                    reference = refs[i]
                    writer.add_document(reference=reference, content=content, rawtext=content)
            else:
                if len(doc.split()) < min_words: continue
                content = doc
                writer.add_document(reference=reference, content=content, rawtext=content)

            idx +=1
            if idx % commit_every == 0:
                writer.commit()
                writer = ix.writer(procs=procs, limitmb=limitmb, multisegment=multisegment)
                if verbose: print("%s docs indexed" % (idx))
        writer.commit()
        return


    def search(self, query, limit=10):
        """
        ```
        search index for query
        Args:
          query(str): search query
          limit(int):  number of top search results to return
        Returns:
          list of dicts with keys: reference, rawtext
        ```
        """
        ix = self._open_ix()
        with ix.searcher() as searcher:
            query_obj = QueryParser("content", ix.schema, group=qparser.OrGroup).parse(query)
            results = searcher.search(query_obj, limit=limit)
            docs = []
            output = [dict(r) for r in results]
            return output
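
As a quick orientation before the generated reference below, here is a minimal end-to-end sketch of the indexing-and-asking workflow defined in this module. The index path, document strings, and question are placeholders, and the reference strings show the plain-label and tab-separated label/URL formats described in index_from_list; this is an illustration of the API above, not an excerpt from the ktrain documentation.

from ktrain.text.qa.core import SimpleQA

# hypothetical two-document corpus; real documents need at least min_words words to be indexed
docs = ['first document text ...', 'second document text ...']
refs = ['notes.txt',
        'ktrain_article\thttps://arxiv.org/pdf/2004.10703v4.pdf']  # label<TAB>URL renders as a clickable link

index_dir = '/tmp/qa_index'                       # arbitrary path; must not exist yet
SimpleQA.initialize_index(index_dir)              # create the whoosh index and schema
SimpleQA.index_from_list(docs, index_dir, references=refs, breakup_docs=True)

qa = SimpleQA(index_dir)                          # loads the BERT-SQuAD reader and tokenizer
answers = qa.ask('What is question answering?')   # placeholder question
qa.display_answers(answers[:5])                   # render the top candidates in a notebook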

Functions

def display_answers(answers)
Expand source code
def display_answers(answers):
    if not answers: return
    df = _answers2df(answers)
    from IPython.core.display import display, HTML
    return display(HTML(df.to_html(render_links=True, escape=False)))
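
display_answers expects the list of dicts produced by QA.ask; _answers2df above reads the keys 'answer', 'sentence_beginning', 'sentence_end', 'confidence', and 'reference' from each dict. A small illustrative call with made-up values:

answers = [{'answer': 'a transformer model',
            'sentence_beginning': 'BERT is ',
            'sentence_end': ' pretrained on large text corpora.',
            'confidence': 0.42,
            'reference': '0'}]
display_answers(answers)   # renders an HTML table (links become clickable if a reference contains a tab-separated URL)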

Classes

class QA (bert_squad_model='bert-large-uncased-whole-word-masking-finetuned-squad', bert_emb_model='bert-base-uncased')

Base class for QA

Expand source code
class QA(ABC):
    """
    Base class for QA
    """

    def __init__(self, bert_squad_model='bert-large-uncased-whole-word-masking-finetuned-squad',
                 bert_emb_model='bert-base-uncased'):
        self.model_name = bert_squad_model
        try:
            self.model = TFAutoModelForQuestionAnswering.from_pretrained(self.model_name)
        except:
            self.model = TFAutoModelForQuestionAnswering.from_pretrained(self.model_name, from_pt=True)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.maxlen = 512
        self.te = tpp.TransformerEmbedding(bert_emb_model, layers=[-2])


    @abstractmethod
    def search(self, query):
        pass

    def predict_squad(self, documents, question):
        """ 
        Generates candidate answers to the <question> provided given <documents> as contexts.
        """
        if isinstance(documents, str): documents = [documents]
        sequences = [[question, d] for d in documents]
        batch = self.tokenizer.batch_encode_plus(sequences, return_tensors='tf', max_length=512, truncation='only_second', padding=True)
        tokens_batch = list( map(self.tokenizer.convert_ids_to_tokens, batch['input_ids']))

        # Added from: https://github.com/huggingface/transformers/commit/16ce15ed4bd0865d24a94aa839a44cf0f400ef50
        if U.get_hf_model_name(self.model_name) in  ['xlm', 'roberta', 'distilbert']:
           start_scores, end_scores = self.model(batch['input_ids'], attention_mask=batch['attention_mask'], return_dict=False)
        else:
           start_scores, end_scores = self.model(batch['input_ids'], attention_mask=batch['attention_mask'], 
                                                 token_type_ids=batch['token_type_ids'], return_dict=False)
        start_scores = start_scores[:,1:-1]
        end_scores = end_scores[:,1:-1]
        answer_starts = np.argmax(start_scores, axis=1)
        answer_ends = np.argmax(end_scores, axis=1)

        answers = []
        for i, tokens in enumerate(tokens_batch):
            answer_start = answer_starts[i]
            answer_end = answer_ends[i]
            answer = self._reconstruct_text(tokens, answer_start, answer_end+2)
            if answer.startswith('. ') or answer.startswith(', '):
                answer = answer[2:]  
            sep_index = tokens.index('[SEP]')
            full_txt_tokens = tokens[sep_index+1:]
            paragraph_bert = self._reconstruct_text(full_txt_tokens)

            ans={}
            ans['answer'] = answer
            if answer.startswith('[CLS]') or answer_end < sep_index or answer.endswith('[SEP]'):
                ans['confidence'] = LOWCONF
            else:
                #confidence = torch.max(start_scores) + torch.max(end_scores)
                #confidence = np.log(confidence.item())
                ans['confidence'] = start_scores[i,answer_start]+end_scores[i,answer_end]
            ans['start'] = answer_start
            ans['end'] = answer_end
            ans['context'] = paragraph_bert
            answers.append(ans)
        #if len(answers) == 1: answers = answers[0]
        return answers




    def _reconstruct_text(self, tokens, start=0, stop=-1):
        """
        Reconstruct text of *either* question or answer
        """
        tokens = tokens[start: stop]
        #if '[SEP]' in tokens:
            #sepind = tokens.index('[SEP]')
            #tokens = tokens[sepind+1:]
        txt = ' '.join(tokens)
        txt = txt.replace('[SEP]', '') # added for batch_encode_plus - removes [SEP] before [PAD]
        txt = txt.replace('[PAD]', '') # added for batch_encode_plus - removes [PAD]
        txt = txt.replace(' ##', '')
        txt = txt.replace('##', '')
        txt = txt.strip()
        txt = " ".join(txt.split())
        txt = txt.replace(' .', '.')
        txt = txt.replace('( ', '(')
        txt = txt.replace(' )', ')')
        txt = txt.replace(' - ', '-')
        txt_list = txt.split(' , ')
        txt = ''
        length = len(txt_list)
        if length == 1:
            return txt_list[0]
        new_list =[]
        for i,t in enumerate(txt_list):
            if i < length -1:
                if t[-1].isdigit() and txt_list[i+1][0].isdigit():
                    new_list += [t,',']
                else:
                    new_list += [t, ', ']
            else:
                new_list += [t]
        return ''.join(new_list)


    def _expand_answer(self, answer):
        """
        expand answer to include more of the context
        """
        full_abs = answer['context']
        bert_ans = answer['answer']
        split_abs = full_abs.split(bert_ans)
        sent_beginning = split_abs[0][split_abs[0].rfind('.')+1:]
        if len(split_abs) == 1:
            sent_end_pos = len(full_abs)
            sent_end =''
        else:
            sent_end_pos = split_abs[1].find('. ')+1
            if sent_end_pos == 0:
                sent_end = split_abs[1]
            else:
                sent_end = split_abs[1][:sent_end_pos]
            
        answer['full_answer'] = sent_beginning+bert_ans+sent_end
        answer['full_answer'] = answer['full_answer'].strip()
        answer['sentence_beginning'] = sent_beginning
        answer['sentence_end'] = sent_end
        return answer



    def ask(self, question, query=None, batch_size=8, n_docs_considered=10, n_answers=50, 
            rerank_threshold=0.015, include_np=False):
        """
        ```
        submit question to obtain candidate answers

        Args:
          question(str): question in the form of a string
          query(str): Optional. If not None, words in query will be used to retrieve contexts instead of words in question
          batch_size(int):  number of question-context pairs fed to model at each iteration
                            Default:8
                            Increase for faster answer-retrieval.
                            Decrease to reduce memory (if out-of-memory errors occur).
          n_docs_considered(int): number of top search results that will
                                  be searched for answer
                                  default:10
          n_answers(int): maximum number of candidate answers to return
                          default:50
          rerank_threshold(float): rerank top answers with confidence >= rerank_threshold
                                 based on semantic similarity between question and answer.
                                 This can help bump the correct answer closer to the top.
                                 default:0.015.
                                 If None, no re-ranking is performed.
          include_np(bool):  If True, noun phrases will be extracted from question and included
                             in query that retrieves documents likely to contain candidate answers.
                             This may be useful if you ask a question about artificial intelligence
                             and the answers returned pertain just to intelligence, for example.
                             Note: include_np=True requires textblob be installed.
                             Default:False
        Returns:
          list
        ```
        """
        # locate candidate document contexts
        paragraphs = []
        refs = []
        #doc_results = self.search(question, limit=n_docs_considered)
        doc_results = self.search(_process_question(query if query is not None else question, include_np=include_np), limit=n_docs_considered)
        if not doc_results: 
            warnings.warn('No documents matched words in question (or query if supplied)')
            return []
        # extract paragraphs as contexts
        contexts = []
        refs = []
        for doc_result in doc_results:
            rawtext = doc_result.get('rawtext', '')
            reference = doc_result.get('reference', '')
            if len(self.tokenizer.tokenize(rawtext)) < self.maxlen:
                contexts.append(rawtext)
                refs.append(reference)
            else:
                paragraphs = TU.paragraph_tokenize(rawtext, join_sentences=True)
                contexts.extend(paragraphs)
                refs.extend([reference] * len(paragraphs))


        #for doc_result in doc_results:
            #rawtext = doc_result.get('rawtext', '')
            #reference = doc_result.get('reference', '')
            #if len(self.tokenizer.tokenize(rawtext)) < self.maxlen:
                #paragraphs.append(rawtext)
                #refs.append(reference)
                #continue
            #plist = TU.paragraph_tokenize(rawtext, join_sentences=True)
            #paragraphs.extend(plist)
            #refs.extend([reference]*len(plist))


        # batchify contexts
        #return contexts
        if batch_size  > len(contexts): batch_size = len(contexts)
        #if len(contexts) >= 100 and batch_size==8:
            #warnings.warn('TIP: Try increasing batch_size to speedup ask predictions')
        num_chunks = math.ceil(len(contexts)/batch_size)
        context_batches = list( U.list2chunks(contexts, n=num_chunks) )


        # locate candidate answers
        answers = []
        mb = master_bar(range(1))
        answer_batches = []
        for i in mb:
            idx = 0
            for batch_id, contexts in enumerate(progress_bar(context_batches, parent=mb)):
                answer_batch = self.predict_squad(contexts, question)
                answer_batches.extend(answer_batch)
                for answer in answer_batch:
                    idx+=1
                    if not answer['answer'] or answer['confidence'] <-100: continue
                    answer['confidence'] = answer['confidence'].numpy()
                    answer['reference'] = refs[idx-1]
                    answer = self._expand_answer(answer)
                    answers.append(answer)

                mb.child.comment = f'generating candidate answers'

        if not answers: return answers # fix for #307
        answers = sorted(answers, key = lambda k:k['confidence'], reverse=True)
        if n_answers is not None:
            answers = answers[:n_answers]

        # transform confidence scores
        confidences = [a['confidence'] for a in answers]
        max_conf = max(confidences)
        total = 0.0
        exp_scores = []
        for c in confidences:
            s = np.exp(c-max_conf)
            exp_scores.append(s)
        total = sum(exp_scores)
        for idx,c in enumerate(confidences):
            answers[idx]['confidence'] = exp_scores[idx]/total

        if rerank_threshold is None:
            return answers

        # re-rank
        top_confidences = [a['confidence'] for idx, a in enumerate(answers) if a['confidence']> rerank_threshold]
        v1 = self.te.embed(question, word_level=False)
        for idx, answer in enumerate(answers):
            #if idx >= rerank_top_n: 
            if answer['confidence'] <= rerank_threshold:
                answer['similarity_score'] = 0.0
                continue
            v2 = self.te.embed(answer['full_answer'], word_level=False)
            score = v1 @ v2.T / (np.linalg.norm(v1)*np.linalg.norm(v2))
            answer['similarity_score'] = float(np.squeeze(score))
            answer['confidence'] = top_confidences[idx]
        answers = sorted(answers, key = lambda k:(k['similarity_score'], k['confidence']), reverse=True)
        for idx, confidence in enumerate(top_confidences):
            answers[idx]['confidence'] = confidence


        return answers


    def display_answers(self, answers):
        return display_answers(answers)
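
The constructor arguments above let you swap in a different reader or embedding model; the try/except fallback re-loads checkpoints that only ship PyTorch weights with from_pt=True. A hedged sketch using SimpleQA (the concrete subclass documented below) with a smaller distilled SQuAD checkpoint; the model names are Hugging Face hub identifiers, not values pinned by this module, and index_dir is the path from the sketch near the top of this page:

qa = SimpleQA(index_dir,
              bert_squad_model='distilbert-base-cased-distilled-squad',  # smaller, faster reader
              bert_emb_model='bert-base-uncased')                        # embeddings used for re-ranking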

Ancestors

  • abc.ABC

Subclasses

  • SimpleQA

Methods

def ask(self, question, query=None, batch_size=8, n_docs_considered=10, n_answers=50, rerank_threshold=0.015, include_np=False)
submit question to obtain candidate answers

Args:
  question(str): question in the form of a string
  query(str): Optional. If not None, words in query will be used to retrieve contexts instead of words in question
  batch_size(int):  number of question-context pairs fed to model at each iteration
                    Default:8
                    Increase for faster answer-retrieval.
                    Decrease to reduce memory (if out-of-memory errors occur).
  n_docs_considered(int): number of top search results that will
                          be searched for answer
                          default:10
  n_answers(int): maximum number of candidate answers to return
                  default:50
  rerank_threshold(float): rerank top answers with confidence >= rerank_threshold
                         based on semantic similarity between question and answer.
                         This can help bump the correct answer closer to the top.
                         default:0.015.
                         If None, no re-ranking is performed.
  include_np(bool):  If True, noun phrases will be extracted from question and included
                     in query that retrieves documents likely to contain candidate answers.
                     This may be useful if you ask a question about artificial intelligence
                     and the answers returned pertain just to intelligence, for example.
                     Note: include_np=True requires textblob be installed.
                     Default:False
Returns:
  list
Expand source code
def ask(self, question, query=None, batch_size=8, n_docs_considered=10, n_answers=50, 
        rerank_threshold=0.015, include_np=False):
    """
    ```
    submit question to obtain candidate answers

    Args:
      question(str): question in the form of a string
      query(str): Optional. If not None, words in query will be used to retrieve contexts instead of words in question
      batch_size(int):  number of question-context pairs fed to model at each iteration
                        Default:8
                        Increase for faster answer-retrieval.
                        Decrease to reduce memory (if out-of-memory errors occur).
      n_docs_considered(int): number of top search results that will
                              be searched for answer
                              default:10
      n_answers(int): maximum number of candidate answers to return
                      default:50
      rerank_threshold(float): rerank top answers with confidence >= rerank_threshold
                             based on semantic similarity between question and answer.
                             This can help bump the correct answer closer to the top.
                             default:0.015.
                             If None, no re-ranking is performed.
      include_np(bool):  If True, noun phrases will be extracted from question and included
                         in query that retrieves documents likely to contain candidate answers.
                         This may be useful if you ask a question about artificial intelligence
                         and the answers returned pertain just to intelligence, for example.
                         Note: include_np=True requires textblob be installed.
                         Default:False
    Returns:
      list
    ```
    """
    # locate candidate document contexts
    paragraphs = []
    refs = []
    #doc_results = self.search(question, limit=n_docs_considered)
    doc_results = self.search(_process_question(query if query is not None else question, include_np=include_np), limit=n_docs_considered)
    if not doc_results: 
        warnings.warn('No documents matched words in question (or query if supplied)')
        return []
    # extract paragraphs as contexts
    contexts = []
    refs = []
    for doc_result in doc_results:
        rawtext = doc_result.get('rawtext', '')
        reference = doc_result.get('reference', '')
        if len(self.tokenizer.tokenize(rawtext)) < self.maxlen:
            contexts.append(rawtext)
            refs.append(reference)
        else:
            paragraphs = TU.paragraph_tokenize(rawtext, join_sentences=True)
            contexts.extend(paragraphs)
            refs.extend([reference] * len(paragraphs))


    #for doc_result in doc_results:
        #rawtext = doc_result.get('rawtext', '')
        #reference = doc_result.get('reference', '')
        #if len(self.tokenizer.tokenize(rawtext)) < self.maxlen:
            #paragraphs.append(rawtext)
            #refs.append(reference)
            #continue
        #plist = TU.paragraph_tokenize(rawtext, join_sentences=True)
        #paragraphs.extend(plist)
        #refs.extend([reference]*len(plist))


    # batchify contexts
    #return contexts
    if batch_size  > len(contexts): batch_size = len(contexts)
    #if len(contexts) >= 100 and batch_size==8:
        #warnings.warn('TIP: Try increasing batch_size to speedup ask predictions')
    num_chunks = math.ceil(len(contexts)/batch_size)
    context_batches = list( U.list2chunks(contexts, n=num_chunks) )


    # locate candidate answers
    answers = []
    mb = master_bar(range(1))
    answer_batches = []
    for i in mb:
        idx = 0
        for batch_id, contexts in enumerate(progress_bar(context_batches, parent=mb)):
            answer_batch = self.predict_squad(contexts, question)
            answer_batches.extend(answer_batch)
            for answer in answer_batch:
                idx+=1
                if not answer['answer'] or answer['confidence'] <-100: continue
                answer['confidence'] = answer['confidence'].numpy()
                answer['reference'] = refs[idx-1]
                answer = self._expand_answer(answer)
                answers.append(answer)

            mb.child.comment = f'generating candidate answers'

    if not answers: return answers # fix for #307
    answers = sorted(answers, key = lambda k:k['confidence'], reverse=True)
    if n_answers is not None:
        answers = answers[:n_answers]

    # transform confidence scores
    confidences = [a['confidence'] for a in answers]
    max_conf = max(confidences)
    total = 0.0
    exp_scores = []
    for c in confidences:
        s = np.exp(c-max_conf)
        exp_scores.append(s)
    total = sum(exp_scores)
    for idx,c in enumerate(confidences):
        answers[idx]['confidence'] = exp_scores[idx]/total

    if rerank_threshold is None:
        return answers

    # re-rank
    top_confidences = [a['confidence'] for idx, a in enumerate(answers) if a['confidence']> rerank_threshold]
    v1 = self.te.embed(question, word_level=False)
    for idx, answer in enumerate(answers):
        #if idx >= rerank_top_n: 
        if answer['confidence'] <= rerank_threshold:
            answer['similarity_score'] = 0.0
            continue
        v2 = self.te.embed(answer['full_answer'], word_level=False)
        score = v1 @ v2.T / (np.linalg.norm(v1)*np.linalg.norm(v2))
        answer['similarity_score'] = float(np.squeeze(score))
        answer['confidence'] = top_confidences[idx]
    answers = sorted(answers, key = lambda k:(k['similarity_score'], k['confidence']), reverse=True)
    for idx, confidence in enumerate(top_confidences):
        answers[idx]['confidence'] = confidence


    return answers
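
A sketch of how the optional arguments of ask interact (the question and query strings are placeholders, and qa is a SimpleQA instance): query controls which words are used for document retrieval while question is still what the reader model answers, and rerank_threshold=None skips the embedding-based re-ranking step.

answers = qa.ask('how is the reader model fine-tuned?',
                 query='fine-tuning BERT SQuAD',   # retrieval terms differ from the question text
                 n_docs_considered=20,             # search more documents for contexts
                 batch_size=16,                    # larger batches are faster but use more memory
                 rerank_threshold=None)            # keep the softmax-normalized confidence ranking
top = answers[0]
print(top['answer'], top['confidence'], top['reference'])
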
def display_answers(self, answers)
Expand source code
def display_answers(self, answers):
    return display_answers(answers)
def predict_squad(self, documents, question)

Generates candidate answers to the question provided, given the documents as contexts.

Expand source code
def predict_squad(self, documents, question):
    """ 
    Generates candidate answers to the <question> provided given <documents> as contexts.
    """
    if isinstance(documents, str): documents = [documents]
    sequences = [[question, d] for d in documents]
    batch = self.tokenizer.batch_encode_plus(sequences, return_tensors='tf', max_length=512, truncation='only_second', padding=True)
    tokens_batch = list( map(self.tokenizer.convert_ids_to_tokens, batch['input_ids']))

    # Added from: https://github.com/huggingface/transformers/commit/16ce15ed4bd0865d24a94aa839a44cf0f400ef50
    if U.get_hf_model_name(self.model_name) in  ['xlm', 'roberta', 'distilbert']:
       start_scores, end_scores = self.model(batch['input_ids'], attention_mask=batch['attention_mask'], return_dict=False)
    else:
       start_scores, end_scores = self.model(batch['input_ids'], attention_mask=batch['attention_mask'], 
                                             token_type_ids=batch['token_type_ids'], return_dict=False)
    start_scores = start_scores[:,1:-1]
    end_scores = end_scores[:,1:-1]
    answer_starts = np.argmax(start_scores, axis=1)
    answer_ends = np.argmax(end_scores, axis=1)

    answers = []
    for i, tokens in enumerate(tokens_batch):
        answer_start = answer_starts[i]
        answer_end = answer_ends[i]
        answer = self._reconstruct_text(tokens, answer_start, answer_end+2)
        if answer.startswith('. ') or answer.startswith(', '):
            answer = answer[2:]  
        sep_index = tokens.index('[SEP]')
        full_txt_tokens = tokens[sep_index+1:]
        paragraph_bert = self._reconstruct_text(full_txt_tokens)

        ans={}
        ans['answer'] = answer
        if answer.startswith('[CLS]') or answer_end < sep_index or answer.endswith('[SEP]'):
            ans['confidence'] = LOWCONF
        else:
            #confidence = torch.max(start_scores) + torch.max(end_scores)
            #confidence = np.log(confidence.item())
            ans['confidence'] = start_scores[i,answer_start]+end_scores[i,answer_end]
        ans['start'] = answer_start
        ans['end'] = answer_end
        ans['context'] = paragraph_bert
        answers.append(ans)
    #if len(answers) == 1: answers = answers[0]
    return answers
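
predict_squad can also be called directly with raw context strings, bypassing the search step; a small sketch (the context and question are placeholders, and qa is an already-constructed SimpleQA instance). Note that the returned 'confidence' is a raw start+end score, not the normalized value produced by ask.

context = ('ktrain is a lightweight wrapper around TensorFlow Keras that helps '
           'build, train, and deploy neural network models.')
results = qa.predict_squad([context], 'What is ktrain?')
for r in results:
    print(r['answer'], float(r['confidence']))
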
def search(self, query)
Expand source code
@abstractmethod
def search(self, query):
    pass
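
search is the only abstract method, so a custom retriever can be plugged in by subclassing QA; ask only requires that search(query, limit=...) return dicts with 'rawtext' and 'reference' keys (see SimpleQA.search below). A hedged sketch with a naive in-memory keyword matcher (ListQA is illustrative, not part of ktrain):

from ktrain.text.qa.core import QA

class ListQA(QA):
    """Illustrative QA subclass that 'retrieves' from a small in-memory list of texts."""
    def __init__(self, docs, **kwargs):
        self.docs = docs
        super().__init__(**kwargs)

    def search(self, query, limit=10):
        # naive keyword overlap; ask() only needs the 'rawtext' and 'reference' keys
        terms = set(query.lower().split())
        hits = [{'rawtext': d, 'reference': str(i)}
                for i, d in enumerate(self.docs)
                if terms & set(d.lower().split())]
        return hits[:limit]
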
class SimpleQA (index_dir, bert_squad_model='bert-large-uncased-whole-word-masking-finetuned-squad', bert_emb_model='bert-base-uncased')

SimpleQA: Question-Answering on a list of texts

SimpleQA constructor
Args:
  index_dir(str):  path to index directory created by SimpleQA.initialize_index
  bert_squad_model(str): name of BERT SQUAD model to use
  bert_emb_model(str): BERT model to use to generate embeddings for semantic similarity
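
Constructing a SimpleQA simply points it at an index previously built with initialize_index and one of the index_from_* methods; the reader and embedding models are downloaded on first use. A one-line sketch (the path is a placeholder):

qa = SimpleQA('/tmp/qa_index')   # raises ValueError if the index has not been created yet
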
Expand source code
class SimpleQA(QA):
    """
    SimpleQA: Question-Answering on a list of texts
    """
    def __init__(self, index_dir, 
                 bert_squad_model='bert-large-uncased-whole-word-masking-finetuned-squad',
                 bert_emb_model='bert-base-uncased'):
        """
        ```
        SimpleQA constructor
        Args:
          index_dir(str):  path to index directory created by SimpleQA.initialize_index
          bert_squad_model(str): name of BERT SQUAD model to use
          bert_emb_model(str): BERT model to use to generate embeddings for semantic similarity
        ```
        """

        self.index_dir = index_dir
        try:
            ix = index.open_dir(self.index_dir)
        except:
            raise ValueError('index_dir has not yet been created - please call SimpleQA.initialize_index("%s")' % (self.index_dir))
        super().__init__(bert_squad_model=bert_squad_model, bert_emb_model=bert_emb_model)


    def _open_ix(self):
        return index.open_dir(self.index_dir)


    @classmethod
    def initialize_index(cls, index_dir):
        schema = Schema(reference=ID(stored=True), content=TEXT, rawtext=TEXT(stored=True))
        if not os.path.exists(index_dir):
            os.makedirs(index_dir)
        else:
            raise ValueError('There is already an existing directory or file with path %s' % (index_dir))
        ix = index.create_in(index_dir, schema)
        return ix

    @classmethod
    def index_from_list(cls, docs, index_dir, commit_every=1024, breakup_docs=True,
                        procs=1, limitmb=256, multisegment=False, min_words=20, references=None):
        """
        ```
        index documents from list.
        The procs, limitmb, and especially multisegment arguments can be used to 
        speed up indexing, if it is too slow.  Please see the whoosh documentation
        for more information on these parameters:  https://whoosh.readthedocs.io/en/latest/batch.html
        Args:
          docs(list): list of strings representing documents
          index_dir(str): path to index directory (see initialize_index)
          commit_every(int): commit after adding this many documents
          breakup_docs(bool): break up documents into smaller paragraphs and treat those as the documents.
                              This can potentially improve the speed at which answers are returned by the ask method
                              when documents being searched are longer.
          procs(int): number of processors
          limitmb(int): memory limit in MB for each process
          multisegment(bool): new segments written instead of merging
          min_words(int):  minimum words for a document (or paragraph extracted from document when breakup_docs=True) to be included in index.
                           Useful for pruning contexts that are unlikely to contain useful answers
          references(list): List of strings containing a reference (e.g., file name) for each document in docs.
                            Each string is treated as a label for the document (e.g., file name, MD5 hash, etc.):
                               Example:  ['some_file.pdf', 'some_other_file.pdf', ...]
                            Strings can also be hyperlinks in which case the label and URL should be separated by a single tab character:
                               Example: ['ktrain_article\thttps://arxiv.org/pdf/2004.10703v4.pdf', ...]

                            These references will be returned in the output of the ask method.
                            If strings are  hyperlinks, then they will automatically be made clickable when the display_answers function
                            displays candidate answers in a pandas DataFrame.

                            If references is None, the index of element in docs is used as reference.
        ```
        """
        if not isinstance(docs, (np.ndarray, list)): raise ValueError('docs must be a list of strings')
        if references is not None and not isinstance(references, (np.ndarray, list)): raise ValueError('references must be a list of strings')
        if references is not None and len(references) != len(docs): raise ValueError('lengths of docs and references must be equal')

        ix = index.open_dir(index_dir)
        writer = ix.writer(procs=procs, limitmb=limitmb, multisegment=multisegment)

        mb = master_bar(range(1))
        for i in mb:
            for idx, doc in enumerate(progress_bar(docs, parent=mb)):
                reference = "%s" % (idx) if references is None else references[idx]

                if breakup_docs:
                    small_docs = TU.paragraph_tokenize(doc, join_sentences=True, lang='en')
                    refs = [reference] * len(small_docs)
                    for i, small_doc in enumerate(small_docs):
                        if len(small_doc.split()) < min_words: continue
                        content = small_doc
                        reference = refs[i]
                        writer.add_document(reference=reference, content=content, rawtext=content)
                else:
                    if len(doc.split()) < min_words: continue
                    content = doc 
                    writer.add_document(reference=reference, content=content, rawtext=content)

                idx +=1
                if idx % commit_every == 0:
                    writer.commit()
                    #writer = ix.writer()
                    writer = ix.writer(procs=procs, limitmb=limitmb, multisegment=multisegment)
                mb.child.comment = f'indexing documents'
            writer.commit()
            #mb.write(f'Finished indexing documents')
        return


    @classmethod
    def index_from_folder(cls, folder_path, index_dir,  use_text_extraction=False, commit_every=1024, breakup_docs=True, 
                          min_words=20, encoding='utf-8', procs=1, limitmb=256, multisegment=False, verbose=1):
        """
        ```
        index all plain text documents within a folder.
        The procs, limitmb, and especially multisegment arguments can be used to 
        speed up indexing, if it is too slow.  Please see the whoosh documentation
        for more information on these parameters:  https://whoosh.readthedocs.io/en/latest/batch.html

        Args:
          folder_path(str): path to folder containing plain text documents (e.g., .txt files)
          index_dir(str): path to index directory (see initialize_index)
          use_text_extraction(bool): If True, the  `textract` package will be used to index text from various
                                     file types including PDF, MS Word, and MS PowerPoint (in addition to plain text files).
                                     If False, only plain text files will be indexed.
          commit_every(int): commit after adding this many documents
          breakup_docs(bool): break up documents into smaller paragraphs and treat those as the documents.
                              This can potentially improve the speed at which answers are returned by the ask method
                              when documents being searched are longer.
          min_words(int):  minimum words for a document (or paragraph extracted from document when breakup_docs=True) to be included in index.
                           Useful for pruning contexts that are unlikely to contain useful answers
          encoding(str): encoding to use when reading document files from disk
          procs(int): number of processors
          limitmb(int): memory limit in MB for each process
          multisegment(bool): new segments written instead of merging
          verbose(bool): verbosity
        ```
        """
        if use_text_extraction:
            try:
                import textract
            except ImportError:
                raise Exception('use_text_extraction=True requires textract:   pip install textract')


        if not os.path.isdir(folder_path): raise ValueError('folder_path is not a valid folder')
        if folder_path[-1] != os.sep: folder_path += os.sep
        ix = index.open_dir(index_dir)
        writer = ix.writer(procs=procs, limitmb=limitmb, multisegment=multisegment)
        for idx, fpath in enumerate(TU.extract_filenames(folder_path)):
            reference = "%s" % (fpath.join(fpath.split(folder_path)[1:]))
            if TU.is_txt(fpath):
                with open(fpath, 'r', encoding=encoding) as f:
                    doc = f.read()
            else:
                if use_text_extraction:
                    try:
                        doc = textract.process(fpath)
                        doc = doc.decode('utf-8', 'ignore')
                    except:
                        if verbose:
                            warnings.warn('Could not extract text from %s' % (fpath))
                        continue
                else:
                    continue

            if breakup_docs:
                small_docs = TU.paragraph_tokenize(doc, join_sentences=True, lang='en')
                refs = [reference] * len(small_docs)
                for i, small_doc in enumerate(small_docs):
                    if len(small_doc.split()) < min_words: continue
                    content = small_doc
                    reference = refs[i]
                    writer.add_document(reference=reference, content=content, rawtext=content)
            else:
                if len(doc.split()) < min_words: continue
                content = doc
                writer.add_document(reference=reference, content=content, rawtext=content)

            idx +=1
            if idx % commit_every == 0:
                writer.commit()
                writer = ix.writer(procs=procs, limitmb=limitmb, multisegment=multisegment)
                if verbose: print("%s docs indexed" % (idx))
        writer.commit()
        return


    def search(self, query, limit=10):
        """
        ```
        search index for query
        Args:
          query(str): search query
          limit(int):  number of top search results to return
        Returns:
          list of dicts with keys: reference, rawtext
        ```
        """
        ix = self._open_ix()
        with ix.searcher() as searcher:
            query_obj = QueryParser("content", ix.schema, group=qparser.OrGroup).parse(query)
            results = searcher.search(query_obj, limit=limit)
            docs = []
            output = [dict(r) for r in results]
            return output
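
The class documented here is the whoosh-backed document-QA index (a subclass of QA, per the Ancestors list below). As a rough end-to-end sketch, the typical flow is: create the index, add documents to it, instantiate the class on that index, then call ask. The class name SimpleQA, its import path, and its constructor signature are assumptions for illustration; only initialize_index, index_from_folder, index_from_list, search, and ask are referenced on this page.

    # minimal end-to-end sketch; class name, import path, and constructor argument are assumed
    from ktrain.text.qa import SimpleQA   # assumed import path

    INDEXDIR = '/tmp/myindex'                        # hypothetical index location
    SimpleQA.initialize_index(INDEXDIR)              # create the whoosh index and schema
    SimpleQA.index_from_folder('/path/to/docs/',     # hypothetical folder of .txt files
                               INDEXDIR,
                               breakup_docs=True)    # index paragraphs rather than whole documents
    qa = SimpleQA(INDEXDIR)                          # assumed constructor argument
    answers = qa.ask('What is the question of interest?')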

Ancestors

  • QA
  • abc.ABC

Static methods

def index_from_folder(folder_path, index_dir, use_text_extraction=False, commit_every=1024, breakup_docs=True, min_words=20, encoding='utf-8', procs=1, limitmb=256, multisegment=False, verbose=1)
index all plain text documents within a folder.
The procs, limitmb, and especially multisegment arguments can be used to 
speed up indexing, if it is too slow.  Please see the whoosh documentation
for more information on these parameters:  https://whoosh.readthedocs.io/en/latest/batch.html

Args:
  folder_path(str): path to folder containing plain text documents (e.g., .txt files)
  index_dir(str): path to index directory (see initialize_index)
  use_text_extraction(bool): If True, the  `textract` package will be used to index text from various
                             file types including PDF, MS Word, and MS PowerPoint (in addition to plain text files).
                             If False, only plain text files will be indexed.
  commit_every(int): commit the index after adding this many documents
  breakup_docs(bool): break up documents into smaller paragraphs and treat those as the documents.
                      This can potentially improve the speed at which answers are returned by the ask method
                      when documents being searched are longer.
  min_words(int):  minimum words for a document (or paragraph extracted from document when breakup_docs=True) to be included in index.
                   Useful for pruning contexts that are unlikely to contain useful answers
  encoding(str): encoding to use when reading document files from disk
  procs(int): number of processors
  limitmb(int): memory limit in MB for each process
  multisegment(bool): if True, write new segments instead of merging into existing ones (can speed up indexing)
  verbose(bool): verbosity
Expand source code
@classmethod
def index_from_folder(cls, folder_path, index_dir,  use_text_extraction=False, commit_every=1024, breakup_docs=True, 
                      min_words=20, encoding='utf-8', procs=1, limitmb=256, multisegment=False, verbose=1):
    """
    ```
    index all plain text documents within a folder.
    The procs, limitmb, and especially multisegment arguments can be used to 
    speed up indexing, if it is too slow.  Please see the whoosh documentation
    for more information on these parameters:  https://whoosh.readthedocs.io/en/latest/batch.html

    Args:
      folder_path(str): path to folder containing plain text documents (e.g., .txt files)
      index_dir(str): path to index directory (see initialize_index)
      use_text_extraction(bool): If True, the  `textract` package will be used to index text from various
                                 file types including PDF, MS Word, and MS PowerPoint (in addition to plain text files).
                                 If False, only plain text files will be indexed.
      commit_every(int): commit the index after adding this many documents
      breakup_docs(bool): break up documents into smaller paragraphs and treat those as the documents.
                          This can potentially improve the speed at which answers are returned by the ask method
                          when documents being searched are longer.
      min_words(int):  minimum words for a document (or paragraph extracted from document when breakup_docs=True) to be included in index.
                       Useful for pruning contexts that are unlikely to contain useful answers
      encoding(str): encoding to use when reading document files from disk
      procs(int): number of processors
      limitmb(int): memory limit in MB for each process
      multisegment(bool): if True, write new segments instead of merging into existing ones (can speed up indexing)
      verbose(bool): verbosity
    ```
    """
    if use_text_extraction:
        try:
            import textract
        except ImportError:
            raise Exception('use_text_extraction=True requires textract:   pip install textract')


    if not os.path.isdir(folder_path): raise ValueError('folder_path is not a valid folder')
    if folder_path[-1] != os.sep: folder_path += os.sep
    ix = index.open_dir(index_dir)
    writer = ix.writer(procs=procs, limitmb=limitmb, multisegment=multisegment)
    for idx, fpath in enumerate(TU.extract_filenames(folder_path)):
        reference = "%s" % (fpath.join(fpath.split(folder_path)[1:]))
        if TU.is_txt(fpath):
            with open(fpath, 'r', encoding=encoding) as f:
                doc = f.read()
        else:
            if use_text_extraction:
                try:
                    doc = textract.process(fpath)
                    doc = doc.decode('utf-8', 'ignore')
                except:
                    if verbose:
                        warnings.warn('Could not extract text from %s' % (fpath))
                    continue
            else:
                continue

        if breakup_docs:
            small_docs = TU.paragraph_tokenize(doc, join_sentences=True, lang='en')
            refs = [reference] * len(small_docs)
            for i, small_doc in enumerate(small_docs):
                if len(small_doc.split()) < min_words: continue
                content = small_doc
                reference = refs[i]
                writer.add_document(reference=reference, content=content, rawtext=content)
        else:
            if len(doc.split()) < min_words: continue
            content = doc
            writer.add_document(reference=reference, content=content, rawtext=content)

        idx +=1
        if idx % commit_every == 0:
            writer.commit()
            writer = ix.writer(procs=procs, limitmb=limitmb, multisegment=multisegment)
            if verbose: print("%s docs indexed" % (idx))
    writer.commit()
    return
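
For a large corpus, the whoosh batch-indexing parameters described above can be passed straight through. A minimal sketch, assuming the same (hypothetical) SimpleQA class and paths as in the earlier sketch:

    # hypothetical paths; procs/limitmb/multisegment follow the whoosh batch-indexing guidance linked above
    SimpleQA.index_from_folder('/path/to/corpus/', '/tmp/myindex',
                               use_text_extraction=True,   # also index PDF/Word/PowerPoint files (requires textract)
                               commit_every=4096,
                               procs=4, limitmb=512, multisegment=True)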
def index_from_list(docs, index_dir, commit_every=1024, breakup_docs=True, procs=1, limitmb=256, multisegment=False, min_words=20, references=None)
index documents from list.
The procs, limitmb, and especially multisegment arguments can be used to 
speed up indexing, if it is too slow.  Please see the whoosh documentation
for more information on these parameters:  https://whoosh.readthedocs.io/en/latest/batch.html
Args:
  docs(list): list of strings representing documents
  index_dir(str): path to index directory (see initialize_index)
  commit_every(int): commit the index after adding this many documents
  breakup_docs(bool): break up documents into smaller paragraphs and treat those as the documents.
                      This can potentially improve the speed at which answers are returned by the ask method
                      when documents being searched are longer.
  procs(int): number of processors
  limitmb(int): memory limit in MB for each process
  multisegment(bool): if True, write new segments instead of merging into existing ones (can speed up indexing)
  min_words(int):  minimum words for a document (or paragraph extracted from document when breakup_docs=True) to be included in index.
                   Useful for pruning contexts that are unlikely to contain useful answers
  references(list): List of strings containing a reference (e.g., file name) for each document in docs.
                    Each string is treated as a label for the document (e.g., file name, MD5 hash, etc.):
                       Example:  ['some_file.pdf', 'some_other_file.pdf', ...]
                    Strings can also be hyperlinks, in which case the label and URL should be separated by a single tab character:
                       Example: ['ktrain_article\thttps://arxiv.org/pdf/2004.10703v4.pdf', ...]

                    These references will be returned in the output of the ask method.
                    If strings are hyperlinks, they will automatically be made clickable when the display_answers function
                    displays candidate answers in a pandas DataFrame.

                    If references is None, the index of the element in docs is used as the reference.
Expand source code
@classmethod
def index_from_list(cls, docs, index_dir, commit_every=1024, breakup_docs=True,
                    procs=1, limitmb=256, multisegment=False, min_words=20, references=None):
    """
    ```
    index documents from list.
    The procs, limitmb, and especially multisegment arguments can be used to 
    speed up indexing, if it is too slow.  Please see the whoosh documentation
    for more information on these parameters:  https://whoosh.readthedocs.io/en/latest/batch.html
    Args:
      docs(list): list of strings representing documents
      index_dir(str): path to index directory (see initialize_index)
      commit_every(int): commit the index after adding this many documents
      breakup_docs(bool): break up documents into smaller paragraphs and treat those as the documents.
                          This can potentially improve the speed at which answers are returned by the ask method
                          when documents being searched are longer.
      procs(int): number of processors
      limitmb(int): memory limit in MB for each process
      multisegment(bool): if True, write new segments instead of merging into existing ones (can speed up indexing)
      min_words(int):  minimum words for a document (or paragraph extracted from document when breakup_docs=True) to be included in index.
                       Useful for pruning contexts that are unlikely to contain useful answers
      references(list): List of strings containing a reference (e.g., file name) for each document in docs.
                        Each string is treated as a label for the document (e.g., file name, MD5 hash, etc.):
                           Example:  ['some_file.pdf', 'some_other_file.pdf', ...]
                        Strings can also be hyperlinks, in which case the label and URL should be separated by a single tab character:
                           Example: ['ktrain_article\thttps://arxiv.org/pdf/2004.10703v4.pdf', ...]

                        These references will be returned in the output of the ask method.
                        If strings are hyperlinks, they will automatically be made clickable when the display_answers function
                        displays candidate answers in a pandas DataFrame.

                        If references is None, the index of the element in docs is used as the reference.
    ```
    """
    if not isinstance(docs, (np.ndarray, list)): raise ValueError('docs must be a list of strings')
    if references is not None and not isinstance(references, (np.ndarray, list)): raise ValueError('references must be a list of strings')
    if references is not None and len(references) != len(docs): raise ValueError('lengths of docs and references must be equal')

    ix = index.open_dir(index_dir)
    writer = ix.writer(procs=procs, limitmb=limitmb, multisegment=multisegment)

    mb = master_bar(range(1))
    for i in mb:
        for idx, doc in enumerate(progress_bar(docs, parent=mb)):
            reference = "%s" % (idx) if references is None else references[idx]

            if breakup_docs:
                small_docs = TU.paragraph_tokenize(doc, join_sentences=True, lang='en')
                refs = [reference] * len(small_docs)
                for i, small_doc in enumerate(small_docs):
                    if len(small_doc.split()) < min_words: continue
                    content = small_doc
                    reference = refs[i]
                    writer.add_document(reference=reference, content=content, rawtext=content)
            else:
                if len(doc.split()) < min_words: continue
                content = doc 
                writer.add_document(reference=reference, content=content, rawtext=content)

            idx +=1
            if idx % commit_every == 0:
                writer.commit()
                #writer = ix.writer()
                writer = ix.writer(procs=procs, limitmb=limitmb, multisegment=multisegment)
            mb.child.comment = f'indexing documents'
        writer.commit()
        #mb.write(f'Finished indexing documents')
    return
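
A minimal sketch of index_from_list, again assuming the (hypothetical) SimpleQA class from the earlier sketch. The documents and references below are made up, and the second reference uses the tab-separated label/URL form described above:

    docs = ['First document text ...',                                   # hypothetical documents
            'Second document text ...']
    refs = ['report_2020.pdf',                                           # plain label
            'ktrain_article\thttps://arxiv.org/pdf/2004.10703v4.pdf']    # label + URL separated by a tab
    SimpleQA.index_from_list(docs, '/tmp/myindex',
                             references=refs,
                             min_words=1)   # keep even very short documents (default is 20)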
def initialize_index(index_dir)
Expand source code
@classmethod
def initialize_index(cls, index_dir):
    schema = Schema(reference=ID(stored=True), content=TEXT, rawtext=TEXT(stored=True))
    if not os.path.exists(index_dir):
        os.makedirs(index_dir)
    else:
        raise ValueError('There is already an existing directory or file with path %s' % (index_dir))
    ix = index.create_in(index_dir, schema)
    return ix
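
Because initialize_index refuses to overwrite an existing path, callers typically guard the call. A minimal sketch (index path hypothetical, class name assumed as before):

    import os

    INDEXDIR = '/tmp/myindex'   # hypothetical index location
    if not os.path.exists(INDEXDIR):
        # creates a whoosh index with schema: reference (stored ID), content (TEXT), rawtext (stored TEXT)
        SimpleQA.initialize_index(INDEXDIR)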

Methods

def search(self, query, limit=10)
search index for query
Args:
  query(str): search query
  limit(int):  number of top search results to return
Returns:
  list of dicts with keys: reference, rawtext
Expand source code
def search(self, query, limit=10):
    """
    ```
    search index for query
    Args:
      query(str): search query
      limit(int):  number of top search results to return
    Returns:
      list of dicts with keys: reference, rawtext
    ```
    """
    ix = self._open_ix()
    with ix.searcher() as searcher:
        query_obj = QueryParser("content", ix.schema, group=qparser.OrGroup).parse(query)
        results = searcher.search(query_obj, limit=limit)
        docs = []
        output = [dict(r) for r in results]
        return output
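
A minimal sketch of search, assuming an index has already been built as in the sketches above (class name and constructor argument still assumed):

    qa = SimpleQA('/tmp/myindex')   # assumed constructor argument
    hits = qa.search('question answering with transformers', limit=5)
    for hit in hits:
        # each hit is a dict containing the stored fields 'reference' and 'rawtext'
        print(hit['reference'], hit['rawtext'][:80])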

Inherited members