aalpy.learning_algs.non_deterministic.AbstractedOnfsmObservationTable

View Source
from collections import defaultdict

from aalpy.automata import Onfsm, OnfsmState
from aalpy.learning_algs.non_deterministic.OnfsmObservationTable import NonDetObservationTable
from aalpy.learning_algs.non_deterministic.NonDeterministicSULWrapper import NonDeterministicSULWrapper
from aalpy.utils.HelperFunctions import all_suffixes, extend_set


class AbstractedNonDetObservationTable:
    """
    Observation table used for learning abstracted ONFSMs.

    A concrete ``NonDetObservationTable`` performs the queries; this class mirrors its ``S`` and ``E`` sets and
    keeps a table ``T`` whose cells hold the observed output sequences after every output was translated with
    ``abstraction_mapping``. Rows are pairs ``(inputs, outputs)`` of tuples.
    """

    def __init__(self, alphabet: list, sul: NonDeterministicSULWrapper, abstraction_mapping: dict, n_sampling=100):
        """
        Construction of the abstracted non-deterministic observation table.

        Args:

            alphabet: input alphabet
            sul: system under learning
            abstraction_mapping: map that translates outputs to abstracted outputs
            n_sampling: number of samples to be performed for each cell
        """

        assert alphabet is not None and sul is not None

        self.observation_table = NonDetObservationTable(alphabet, sul, n_sampling)

        self.S = []
        self.S_dot_A = []
        self.E = []
        self.T = defaultdict(dict)
        # every input symbol is stored as a 1-tuple so it can be concatenated onto a row's input sequence
        self.A = [(a,) for a in alphabet]

        self.abstraction_mapping = abstraction_mapping
        self.sul = sul

        # the initial row is the empty input/output sequence
        empty_word = tuple()
        self.S.append((empty_word, empty_word))

    def update_obs_table(self, s_set=None, e_set: list = None):
        """
        Perform the membership queries and the abstraction on the observation table.
        With the all-weather assumption, each output query is tried a number of times on the system,
        and the driver reports the set of all possible outputs.

        Args:

            s_set: Prefixes of S set on which to perform membership queries (Default value = None)
            e_set: Suffixes of E set on which to perform membership queries (Default value = None)

        """

        self.observation_table.query_missing_observations(s_set, e_set)
        self.abstract_obs_table()
        self.clean_obs_table()

    def abstract_obs_table(self):
        """
        Creation of abstracted observation table. The provided abstraction mapping is used to
        replace outputs by abstracted outputs.

        S and E are taken over from the concrete table (the same list objects, not copies). S.A becomes the
        concrete table's extended S plus any previously known S.A rows that are not in S.
        """

        self.S = self.observation_table.S
        extended = set(self.observation_table.get_extended_S())
        self.S_dot_A = list(extended.union(set(self.S_dot_A) - set(self.S)))
        self.E = self.observation_table.E

        for s in self.S + self.S_dot_A:
            for e in self.E:
                for trace in self.get_all_outputs(s, e):
                    self.add_to_T(s, e, tuple(self.get_abstraction(o) for o in trace))

    def add_to_T(self, s, e, value):
        """
        Add values to the cell at T[s][e].

        Args:

            s: prefix
            e: element of E
            value: value to be added to the cell

        """
        self.T[s].setdefault(e, set()).add(value)

    def get_all_outputs(self, s, e):
        """
        Collect the concrete output traces cached for suffix e after prefix s.

        Args:

            s: row prefix (input/output sequence pair)
            e: input suffix

        Returns:

            set of output traces reported by the cache of the system under learning
        """
        return set(self.sul.cache.get_all_traces(s, e))

    def update_extended_S(self, row_prefix=None):
        """
        Return the extended S (S.A) set as computed by the concrete observation table.

        Args:

            row_prefix: forwarded to the concrete table's ``get_extended_S`` (Default value = None)

        Returns:

            New rows of extended S set.
        """
        return self.observation_table.get_extended_S(row_prefix=row_prefix)

    def get_row_to_close(self):
        """
        Get row that needs to be closed. The returned row is moved from S.A to S as a side effect.

        Returns:

            row that was moved to S set, or None if every S.A row matches some row of S
        """
        s_rows = {self.row_to_hashable(s) for s in self.S}

        for t in self.S_dot_A:
            if self.row_to_hashable(t) not in s_rows:
                # returning right after the mutation keeps the iteration over S_dot_A safe
                self.S.append(t)
                self.S_dot_A.remove(t)
                return t

        return None

    def get_row_to_complete(self):
        """
        Get row that needs to be completed.

        Returns:

            row that will be added to S.A, or None if no such row exists
        """
        dot_a_rows = [(t, self.row_to_hashable(t)) for t in self.S_dot_A]

        # S is walked in list order so that the result does not depend on set iteration order
        for s in self.S:
            s_row = self.row_to_hashable(s)
            similar_rows = sorted((t for t, t_row in dot_a_rows if t_row == s_row), key=lambda r: len(r[0]))

            for a in self.A:
                complete_outputs = self.get_all_outputs(s, a)
                for similar_row in similar_rows:
                    output_difference = self.get_all_outputs(similar_row, a).difference(complete_outputs)
                    for o in output_difference:
                        extension = (similar_row[0] + a, similar_row[1] + (o[0],))
                        if extension not in self.S and extension not in self.S_dot_A:
                            return extension
                        complete_outputs = complete_outputs.union(output_difference)

        return None

    def get_row_to_make_consistent(self):
        """
        Get a suffix that exposes a consistency violation.

        Two rows with equal content violate consistency when extending both with the same input and output
        leads to rows with different content.

        Returns:

            input sequence to be added to E, or None if no violation was found
        """
        unified_S = set(self.S + self.S_dot_A)
        dot_a_rows = [(t, self.row_to_hashable(t)) for t in self.S_dot_A]

        for s in self.S:
            s_row = self.row_to_hashable(s)
            similar_rows = sorted((t for t, t_row in dot_a_rows if t_row == s_row), key=lambda r: len(r[0]))

            for a in self.A:
                for o in self.get_all_outputs(s, a):
                    # o is an output trace (a tuple); only its first output belongs to the single input a
                    step_output = (o[0],)
                    extended_s_sequence = (s[0] + a, s[1] + step_output)
                    if extended_s_sequence not in unified_S:
                        continue
                    extended_s_sequence_row = self.row_to_hashable(extended_s_sequence)

                    for similar_row in similar_rows:
                        extended_s_dot_a_sequence = (similar_row[0] + a, similar_row[1] + step_output)
                        if extended_s_dot_a_sequence not in unified_S:
                            continue
                        # rows are compared by value; identity would differ for any two freshly built tuples
                        if extended_s_sequence_row != self.row_to_hashable(extended_s_dot_a_sequence):
                            return self.get_distinctive_input_sequence(extended_s_sequence,
                                                                       extended_s_dot_a_sequence, a[0])

        return None

    def get_distinctive_input_sequence(self, first_row, second_row, inp):
        """
        Get input sequence that leads to a different output sequence for two given input/output sequences.

        Args:

            first_row: row to be compared
            second_row: row to be compared
            inp: single input symbol that was appended to reach first_row and second_row

        Returns:

            inp followed by the first suffix of E on which the two rows differ, or None if they agree on all of E

        """
        for e in self.E:
            # a missing cell is treated as empty; cells are compared in both directions
            if self.T[first_row].get(e, set()) != self.T[second_row].get(e, set()):
                return tuple([inp]) + e

        return None

    def update_E(self, seq):
        """
        Add a suffix to E unless it is already present.

        Args:

            seq: input sequence to be added to E
        """
        if seq not in self.E:
            self.E.append(seq)

    def _remove_from_S(self, row):
        """
        Remove row from S of the abstracted table and of the concrete table.

        After abstract_obs_table both tables share one list, so the row is removed only once in that case.
        """
        if row in self.S:
            self.S.remove(row)
        concrete_S = self.observation_table.S
        if concrete_S is not self.S and row in concrete_S:
            concrete_S.remove(row)

    def clean_obs_table(self):
        """
        Removes duplicates from S. A row of S is a duplicate if a shorter (or earlier) row of S has the same
        content. The rows of S that extend a removed row get removed as well.
        The table will be smaller and more efficient.

        """
        tmp_S = sorted(self.S, key=lambda t: len(t[0]))
        tmp_both_S = self.S + self.S_dot_A
        hashed_rows_from_s = set()

        for s in tmp_S:
            if s not in self.S:
                # already removed as an extension of an earlier duplicate
                continue

            hashed_s_row = self.row_to_hashable(s)
            if hashed_s_row not in hashed_rows_from_s:
                hashed_rows_from_s.add(hashed_s_row)
                continue

            self._remove_from_S(s)
            size = len(s[0])
            for row_prefix in tmp_both_S:
                if row_prefix != s and (row_prefix[0][:size], row_prefix[1][:size]) == s:
                    self._remove_from_S(row_prefix)

    def row_to_hashable(self, row_prefix):
        """
        Creates the hashable representation of the row. Frozenset is used as the order of element in each cell does not
        matter. A cell that has not been filled yet is represented by an empty frozenset.

        Args:

            row_prefix: prefix of the row in the observation table

        Returns:

            hashable representation of the row

        """
        cells = self.T[row_prefix]
        return tuple(frozenset(cells.get(e, ())) for e in self.E)

    def gen_hypothesis(self) -> Onfsm:
        """
        Generate automaton based on the values found in the abstracted observation table.

        Every row of S becomes a state; the first row of S is the initial state. Transitions are derived from the
        concrete outputs of all rows that have the same content as the state's row.

        Returns:

            Current abstracted hypothesis

        """
        unified_S = self.S + self.S_dot_A
        known_rows = set(unified_S)
        # row contents are computed once instead of once per comparison
        row_of = {row: self.row_to_hashable(row) for row in unified_S}

        state_distinguish = dict()
        states_dict = dict()

        for state_counter, prefix in enumerate(self.S):
            state = OnfsmState(f's{state_counter}')
            state.prefix = prefix
            states_dict[prefix] = state
            state_distinguish[row_of[prefix]] = state

        initial = states_dict[self.S[0]] if self.S else None

        for prefix in self.S:
            source = states_dict[prefix]
            prefix_row = row_of[prefix]
            for row in unified_S:
                if row_of[row] != prefix_row:
                    continue
                for a in self.A:
                    for t in self.get_all_outputs(row, a):
                        s_entry = (row[0] + a, row[1] + t)
                        if s_entry in known_rows:
                            transition = (t[0], state_distinguish[row_of[s_entry]])
                            if transition not in source.transitions[a[0]]:
                                source.transitions[a[0]].append(transition)

        assert initial
        automaton = Onfsm(initial, list(states_dict.values()))
        automaton.characterization_set = self.E

        return automaton

    def extend_S_dot_A(self, cex_prefixes: list):
        """
        Extends S.A based on counterexample prefixes.

        Args:

            cex_prefixes: input/output sequences that are added to S.A

        Returns:

            input/output sequences that have been added to the S.A
        """
        known = set(self.S + self.S_dot_A)
        prefixes_to_extend = []
        for cex_prefix in cex_prefixes:
            if cex_prefix not in known:
                # remember it so that a repeated prefix is not added twice
                known.add(cex_prefix)
                prefixes_to_extend.append(cex_prefix)
                self.S_dot_A.append(cex_prefix)
        return prefixes_to_extend

    def get_abstraction(self, out):
        """
        Get an abstraction for a concrete output. If such abstraction is not defined, return output.

        Args:

            out: output to be abstracted if possible

        Returns:

            abstracted output or output itself
        """
        return self.abstraction_mapping.get(out, out)

    def cex_processing(self, cex: tuple, hypothesis: Onfsm):
        """
        Add counterexample to the observation table. If the counterexample leads to a state where an output of the
        same equivalence class already exists, the prefixes of the counterexample are added to S.A.
        Otherwise, the suffixes of the counterexample inputs are added to E.

        Args:

            cex: counterexample (inputs, outputs) that should be added to the observation table
            hypothesis: onfsm that implements the counterexample
        """
        cex_inputs, cex_outputs = cex[0], cex[1]
        cex_len = len(cex_inputs)

        # replay everything but the last step on the hypothesis
        hypothesis.reset_to_initial()
        for step in range(cex_len - 1):
            hypothesis.step_to(cex_inputs[step], cex_outputs[step])

        possible_outputs = hypothesis.outputs_on_input(cex_inputs[cex_len - 1])
        abstracted_out_cex = self.get_abstraction(cex_outputs[cex_len - 1])
        equivalent_output = any(abstracted_out_cex == self.get_abstraction(out) for out in possible_outputs)

        if equivalent_output:
            # add prefixes of cex to S_dot_A
            cex_prefixes = [(tuple(cex_inputs[0:i + 1]), tuple(cex_outputs[0:i + 1])) for i in range(cex_len)]
            prefixes_to_extend = self.extend_S_dot_A(cex_prefixes)
            self.update_obs_table(s_set=prefixes_to_extend)
        else:
            # add suffixes of the cex inputs to E
            # TODO: longest-prefix cex processing no longer fits here since cex processing was changed
            cex_suffixes = all_suffixes(cex_inputs)

            added_suffixes = extend_set(self.observation_table.E, cex_suffixes)
            self.update_obs_table(e_set=added_suffixes)

    def clean_tables(self):
        """
        Clean the concrete table, rebuild the abstraction and move duplicated rows of S to S.A.

        A row of S whose content equals that of an earlier (shorter) row of S is moved to S.A; all rows of S and
        S.A that extend the moved row are removed.
        """

        self.observation_table.clean_obs_table()
        self.abstract_obs_table()

        update_S = self.S.copy()
        whole_S = self.S + self.S_dot_A

        # lexicographic order first, then a stable sort by length gives a deterministic processing order
        update_S.sort()
        update_S.sort(key=lambda t: len(t[0]))

        s_rows = set()
        for s in update_S:
            hashed_s_row = self.row_to_hashable(s)
            if hashed_s_row not in s_rows:
                s_rows.add(hashed_s_row)
                continue
            if s not in self.S:
                # already removed as an extension of an earlier duplicate
                continue

            size = len(s[0])
            for row in whole_S:
                # s itself is handled after the loop, it must not be removed twice
                if row != s and (row[0][:size], row[1][:size]) == s:
                    if row in self.S_dot_A:
                        self.S_dot_A.remove(row)
                    elif row in self.S:
                        self.S.remove(row)

            self.S.remove(s)
            if s not in self.S_dot_A:
                self.S_dot_A.append(s)
#   class AbstractedNonDetObservationTable:
View Source
class AbstractedNonDetObservationTable:
    """
    Observation table used for learning abstracted ONFSMs.

    A concrete ``NonDetObservationTable`` performs the queries; this class mirrors its ``S`` and ``E`` sets and
    keeps a table ``T`` whose cells hold the observed output sequences after every output was translated with
    ``abstraction_mapping``. Rows are pairs ``(inputs, outputs)`` of tuples.
    """

    def __init__(self, alphabet: list, sul: NonDeterministicSULWrapper, abstraction_mapping: dict, n_sampling=100):
        """
        Construction of the abstracted non-deterministic observation table.

        Args:

            alphabet: input alphabet
            sul: system under learning
            abstraction_mapping: map that translates outputs to abstracted outputs
            n_sampling: number of samples to be performed for each cell
        """

        assert alphabet is not None and sul is not None

        self.observation_table = NonDetObservationTable(alphabet, sul, n_sampling)

        self.S = []
        self.S_dot_A = []
        self.E = []
        self.T = defaultdict(dict)
        # every input symbol is stored as a 1-tuple so it can be concatenated onto a row's input sequence
        self.A = [(a,) for a in alphabet]

        self.abstraction_mapping = abstraction_mapping
        self.sul = sul

        # the initial row is the empty input/output sequence
        empty_word = tuple()
        self.S.append((empty_word, empty_word))

    def update_obs_table(self, s_set=None, e_set: list = None):
        """
        Perform the membership queries and the abstraction on the observation table.
        With the all-weather assumption, each output query is tried a number of times on the system,
        and the driver reports the set of all possible outputs.

        Args:

            s_set: Prefixes of S set on which to perform membership queries (Default value = None)
            e_set: Suffixes of E set on which to perform membership queries (Default value = None)

        """

        self.observation_table.query_missing_observations(s_set, e_set)
        self.abstract_obs_table()
        self.clean_obs_table()

    def abstract_obs_table(self):
        """
        Creation of abstracted observation table. The provided abstraction mapping is used to
        replace outputs by abstracted outputs.

        S and E are taken over from the concrete table (the same list objects, not copies). S.A becomes the
        concrete table's extended S plus any previously known S.A rows that are not in S.
        """

        self.S = self.observation_table.S
        extended = set(self.observation_table.get_extended_S())
        self.S_dot_A = list(extended.union(set(self.S_dot_A) - set(self.S)))
        self.E = self.observation_table.E

        for s in self.S + self.S_dot_A:
            for e in self.E:
                for trace in self.get_all_outputs(s, e):
                    self.add_to_T(s, e, tuple(self.get_abstraction(o) for o in trace))

    def add_to_T(self, s, e, value):
        """
        Add values to the cell at T[s][e].

        Args:

            s: prefix
            e: element of E
            value: value to be added to the cell

        """
        self.T[s].setdefault(e, set()).add(value)

    def get_all_outputs(self, s, e):
        """
        Collect the concrete output traces cached for suffix e after prefix s.

        Args:

            s: row prefix (input/output sequence pair)
            e: input suffix

        Returns:

            set of output traces reported by the cache of the system under learning
        """
        return set(self.sul.cache.get_all_traces(s, e))

    def update_extended_S(self, row_prefix=None):
        """
        Return the extended S (S.A) set as computed by the concrete observation table.

        Args:

            row_prefix: forwarded to the concrete table's ``get_extended_S`` (Default value = None)

        Returns:

            New rows of extended S set.
        """
        return self.observation_table.get_extended_S(row_prefix=row_prefix)

    def get_row_to_close(self):
        """
        Get row that needs to be closed. The returned row is moved from S.A to S as a side effect.

        Returns:

            row that was moved to S set, or None if every S.A row matches some row of S
        """
        s_rows = {self.row_to_hashable(s) for s in self.S}

        for t in self.S_dot_A:
            if self.row_to_hashable(t) not in s_rows:
                # returning right after the mutation keeps the iteration over S_dot_A safe
                self.S.append(t)
                self.S_dot_A.remove(t)
                return t

        return None

    def get_row_to_complete(self):
        """
        Get row that needs to be completed.

        Returns:

            row that will be added to S.A, or None if no such row exists
        """
        dot_a_rows = [(t, self.row_to_hashable(t)) for t in self.S_dot_A]

        # S is walked in list order so that the result does not depend on set iteration order
        for s in self.S:
            s_row = self.row_to_hashable(s)
            similar_rows = sorted((t for t, t_row in dot_a_rows if t_row == s_row), key=lambda r: len(r[0]))

            for a in self.A:
                complete_outputs = self.get_all_outputs(s, a)
                for similar_row in similar_rows:
                    output_difference = self.get_all_outputs(similar_row, a).difference(complete_outputs)
                    for o in output_difference:
                        extension = (similar_row[0] + a, similar_row[1] + (o[0],))
                        if extension not in self.S and extension not in self.S_dot_A:
                            return extension
                        complete_outputs = complete_outputs.union(output_difference)

        return None

    def get_row_to_make_consistent(self):
        """
        Get a suffix that exposes a consistency violation.

        Two rows with equal content violate consistency when extending both with the same input and output
        leads to rows with different content.

        Returns:

            input sequence to be added to E, or None if no violation was found
        """
        unified_S = set(self.S + self.S_dot_A)
        dot_a_rows = [(t, self.row_to_hashable(t)) for t in self.S_dot_A]

        for s in self.S:
            s_row = self.row_to_hashable(s)
            similar_rows = sorted((t for t, t_row in dot_a_rows if t_row == s_row), key=lambda r: len(r[0]))

            for a in self.A:
                for o in self.get_all_outputs(s, a):
                    # o is an output trace (a tuple); only its first output belongs to the single input a
                    step_output = (o[0],)
                    extended_s_sequence = (s[0] + a, s[1] + step_output)
                    if extended_s_sequence not in unified_S:
                        continue
                    extended_s_sequence_row = self.row_to_hashable(extended_s_sequence)

                    for similar_row in similar_rows:
                        extended_s_dot_a_sequence = (similar_row[0] + a, similar_row[1] + step_output)
                        if extended_s_dot_a_sequence not in unified_S:
                            continue
                        # rows are compared by value; identity would differ for any two freshly built tuples
                        if extended_s_sequence_row != self.row_to_hashable(extended_s_dot_a_sequence):
                            return self.get_distinctive_input_sequence(extended_s_sequence,
                                                                       extended_s_dot_a_sequence, a[0])

        return None

    def get_distinctive_input_sequence(self, first_row, second_row, inp):
        """
        Get input sequence that leads to a different output sequence for two given input/output sequences.

        Args:

            first_row: row to be compared
            second_row: row to be compared
            inp: single input symbol that was appended to reach first_row and second_row

        Returns:

            inp followed by the first suffix of E on which the two rows differ, or None if they agree on all of E

        """
        for e in self.E:
            # a missing cell is treated as empty; cells are compared in both directions
            if self.T[first_row].get(e, set()) != self.T[second_row].get(e, set()):
                return tuple([inp]) + e

        return None

    def update_E(self, seq):
        """
        Add a suffix to E unless it is already present.

        Args:

            seq: input sequence to be added to E
        """
        if seq not in self.E:
            self.E.append(seq)

    def _remove_from_S(self, row):
        """
        Remove row from S of the abstracted table and of the concrete table.

        After abstract_obs_table both tables share one list, so the row is removed only once in that case.
        """
        if row in self.S:
            self.S.remove(row)
        concrete_S = self.observation_table.S
        if concrete_S is not self.S and row in concrete_S:
            concrete_S.remove(row)

    def clean_obs_table(self):
        """
        Removes duplicates from S. A row of S is a duplicate if a shorter (or earlier) row of S has the same
        content. The rows of S that extend a removed row get removed as well.
        The table will be smaller and more efficient.

        """
        tmp_S = sorted(self.S, key=lambda t: len(t[0]))
        tmp_both_S = self.S + self.S_dot_A
        hashed_rows_from_s = set()

        for s in tmp_S:
            if s not in self.S:
                # already removed as an extension of an earlier duplicate
                continue

            hashed_s_row = self.row_to_hashable(s)
            if hashed_s_row not in hashed_rows_from_s:
                hashed_rows_from_s.add(hashed_s_row)
                continue

            self._remove_from_S(s)
            size = len(s[0])
            for row_prefix in tmp_both_S:
                if row_prefix != s and (row_prefix[0][:size], row_prefix[1][:size]) == s:
                    self._remove_from_S(row_prefix)

    def row_to_hashable(self, row_prefix):
        """
        Creates the hashable representation of the row. Frozenset is used as the order of element in each cell does not
        matter. A cell that has not been filled yet is represented by an empty frozenset.

        Args:

            row_prefix: prefix of the row in the observation table

        Returns:

            hashable representation of the row

        """
        cells = self.T[row_prefix]
        return tuple(frozenset(cells.get(e, ())) for e in self.E)

    def gen_hypothesis(self) -> Onfsm:
        """
        Generate automaton based on the values found in the abstracted observation table.

        Every row of S becomes a state; the first row of S is the initial state. Transitions are derived from the
        concrete outputs of all rows that have the same content as the state's row.

        Returns:

            Current abstracted hypothesis

        """
        unified_S = self.S + self.S_dot_A
        known_rows = set(unified_S)
        # row contents are computed once instead of once per comparison
        row_of = {row: self.row_to_hashable(row) for row in unified_S}

        state_distinguish = dict()
        states_dict = dict()

        for state_counter, prefix in enumerate(self.S):
            state = OnfsmState(f's{state_counter}')
            state.prefix = prefix
            states_dict[prefix] = state
            state_distinguish[row_of[prefix]] = state

        initial = states_dict[self.S[0]] if self.S else None

        for prefix in self.S:
            source = states_dict[prefix]
            prefix_row = row_of[prefix]
            for row in unified_S:
                if row_of[row] != prefix_row:
                    continue
                for a in self.A:
                    for t in self.get_all_outputs(row, a):
                        s_entry = (row[0] + a, row[1] + t)
                        if s_entry in known_rows:
                            transition = (t[0], state_distinguish[row_of[s_entry]])
                            if transition not in source.transitions[a[0]]:
                                source.transitions[a[0]].append(transition)

        assert initial
        automaton = Onfsm(initial, list(states_dict.values()))
        automaton.characterization_set = self.E

        return automaton

    def extend_S_dot_A(self, cex_prefixes: list):
        """
        Extends S.A based on counterexample prefixes.

        Args:

            cex_prefixes: input/output sequences that are added to S.A

        Returns:

            input/output sequences that have been added to the S.A
        """
        known = set(self.S + self.S_dot_A)
        prefixes_to_extend = []
        for cex_prefix in cex_prefixes:
            if cex_prefix not in known:
                # remember it so that a repeated prefix is not added twice
                known.add(cex_prefix)
                prefixes_to_extend.append(cex_prefix)
                self.S_dot_A.append(cex_prefix)
        return prefixes_to_extend

    def get_abstraction(self, out):
        """
        Get an abstraction for a concrete output. If such abstraction is not defined, return output.

        Args:

            out: output to be abstracted if possible

        Returns:

            abstracted output or output itself
        """
        return self.abstraction_mapping.get(out, out)

    def cex_processing(self, cex: tuple, hypothesis: Onfsm):
        """
        Add counterexample to the observation table. If the counterexample leads to a state where an output of the
        same equivalence class already exists, the prefixes of the counterexample are added to S.A.
        Otherwise, the suffixes of the counterexample inputs are added to E.

        Args:

            cex: counterexample (inputs, outputs) that should be added to the observation table
            hypothesis: onfsm that implements the counterexample
        """
        cex_inputs, cex_outputs = cex[0], cex[1]
        cex_len = len(cex_inputs)

        # replay everything but the last step on the hypothesis
        hypothesis.reset_to_initial()
        for step in range(cex_len - 1):
            hypothesis.step_to(cex_inputs[step], cex_outputs[step])

        possible_outputs = hypothesis.outputs_on_input(cex_inputs[cex_len - 1])
        abstracted_out_cex = self.get_abstraction(cex_outputs[cex_len - 1])
        equivalent_output = any(abstracted_out_cex == self.get_abstraction(out) for out in possible_outputs)

        if equivalent_output:
            # add prefixes of cex to S_dot_A
            cex_prefixes = [(tuple(cex_inputs[0:i + 1]), tuple(cex_outputs[0:i + 1])) for i in range(cex_len)]
            prefixes_to_extend = self.extend_S_dot_A(cex_prefixes)
            self.update_obs_table(s_set=prefixes_to_extend)
        else:
            # add suffixes of the cex inputs to E
            # TODO: longest-prefix cex processing no longer fits here since cex processing was changed
            cex_suffixes = all_suffixes(cex_inputs)

            added_suffixes = extend_set(self.observation_table.E, cex_suffixes)
            self.update_obs_table(e_set=added_suffixes)

    def clean_tables(self):
        """
        Clean the concrete table, rebuild the abstraction and move duplicated rows of S to S.A.

        A row of S whose content equals that of an earlier (shorter) row of S is moved to S.A; all rows of S and
        S.A that extend the moved row are removed.
        """

        self.observation_table.clean_obs_table()
        self.abstract_obs_table()

        update_S = self.S.copy()
        whole_S = self.S + self.S_dot_A

        # lexicographic order first, then a stable sort by length gives a deterministic processing order
        update_S.sort()
        update_S.sort(key=lambda t: len(t[0]))

        s_rows = set()
        for s in update_S:
            hashed_s_row = self.row_to_hashable(s)
            if hashed_s_row not in s_rows:
                s_rows.add(hashed_s_row)
                continue
            if s not in self.S:
                # already removed as an extension of an earlier duplicate
                continue

            size = len(s[0])
            for row in whole_S:
                # s itself is handled after the loop, it must not be removed twice
                if row != s and (row[0][:size], row[1][:size]) == s:
                    if row in self.S_dot_A:
                        self.S_dot_A.remove(row)
                    elif row in self.S:
                        self.S.remove(row)

            self.S.remove(s)
            if s not in self.S_dot_A:
                self.S_dot_A.append(s)
#   AbstractedNonDetObservationTable( alphabet: list, sul: aalpy.learning_algs.non_deterministic.NonDeterministicSULWrapper.NonDeterministicSULWrapper, abstraction_mapping: dict, n_sampling=100 )
View Source
    def __init__(self, alphabet: list, sul: NonDeterministicSULWrapper, abstraction_mapping: dict, n_sampling=100):
        """
        Set up an empty abstracted table on top of a concrete non-deterministic observation table.

        Args:

            alphabet: input alphabet
            sul: system under learning
            abstraction_mapping: map that translates outputs to abstracted outputs
            n_sampling: number of samples to be performed for each cell
        """

        assert not (alphabet is None or sul is None)

        self.sul = sul
        self.abstraction_mapping = abstraction_mapping

        # the concrete table performs the actual queries
        self.observation_table = NonDetObservationTable(alphabet, sul, n_sampling)

        # each input symbol is kept as a 1-tuple
        self.A = [(symbol,) for symbol in alphabet]
        self.T = defaultdict(dict)
        self.E = []
        self.S_dot_A = []

        # S starts with the row of the empty input/output sequence
        self.S = [((), ())]

Construction of the abstracted non-deterministic observation table.

Args:

alphabet: input alphabet
sul: system under learning
abstraction_mapping: map that translates outputs to abstracted outputs
n_sampling: number of samples to be performed for each cell
#   def update_obs_table(self, s_set=None, e_set: list = None):
View Source
    def update_obs_table(self, s_set=None, e_set: list = None):
        """
        Perform the membership queries and abstraction on observation table
        With  the  all-weather  assumption,  each  output  query  is  tried  a  number  of  times  on  the  system,
        and  the  driver  reports  the  set  of  all  possible  outputs.

        Args:

            s_set: Prefixes of S set on which to preform membership queries (Default value = None)
            e_set: Suffixes of E set on which to perform membership queries


        """

        self.observation_table.query_missing_observations(s_set, e_set)
        self.abstract_obs_table()
        self.clean_obs_table()

Perform the membership queries and abstraction on observation table With the all-weather assumption, each output query is tried a number of times on the system, and the driver reports the set of all possible outputs.

Args:

s_set: Prefixes of S set on which to perform membership queries (Default value = None)
e_set: Suffixes of E set on which to perform membership queries
#   def abstract_obs_table(self):
View Source
    def abstract_obs_table(self):
        """
        Creation of abstracted observation table. The provided abstraction mapping is used to
        replace outputs by abstracted outputs.
        """

        self.S = self.observation_table.S
        self.S_dot_A = list(set(self.observation_table.get_extended_S()).union(set(self.S_dot_A) - set(self.S)))
        self.E = self.observation_table.E

        update_S = self.S + self.S_dot_A
        update_E = self.E

        for s in update_S:
            for e in update_E:
                for o_tup in self.get_all_outputs(s, e):
                    abstracted_outputs = []
                    o_tup = tuple([o_tup])
                    for outputs in o_tup:
                        for o in outputs:
                            abstract_output = self.get_abstraction(o)
                            abstracted_outputs.append(abstract_output)
                    self.add_to_T(s, e, tuple(abstracted_outputs))

Creation of abstracted observation table. The provided abstraction mapping is used to replace outputs by abstracted outputs.

#   def add_to_T(self, s, e, value):
View Source
    def add_to_T(self, s, e, value):
        """
        Add values to the cell at T[s][e].

        Args:

            s: prefix
            e: element of S
            value: value to be added to the cell


        """
        if e not in self.T[s]:
            self.T[s][e] = set()
        self.T[s][e].add(value)

Add values to the cell at T[s][e].

Args:

s: prefix
e: suffix (element of E)
value: value to be added to the cell
#   def get_all_outputs(self, s, e):
View Source
    def get_all_outputs(self, s, e):
        cell_outputs = set()
        cell_outputs.update(self.sul.cache.get_all_traces(s, e))
        return cell_outputs
#   def update_extended_S(self, row_prefix=None):
View Source
    def update_extended_S(self, row_prefix=None):
        """
        Helper generator function that returns extended S, or S.A set.
        For all values in the cell, create a new row where inputs is parent input plus element of alphabet, and
        output is parent output plus value in cell.

        Returns:

            New rows of extended S set.
        """
        return self.observation_table.get_extended_S(row_prefix=row_prefix)

Helper generator function that returns extended S, or S.A set. For all values in the cell, create a new row where inputs is parent input plus element of alphabet, and output is parent output plus value in cell.

Returns:

New rows of extended S set.
#   def get_row_to_close(self):
View Source
    def get_row_to_close(self):
        """
        Get row for that needs to be closed.

        Returns:

            row that will be moved to S set and closed
        """
        s_rows = set()
        for s in self.S:
            s_rows.add(self.row_to_hashable(s))

        for t in self.S_dot_A:
            row_t = self.row_to_hashable(t)

            if row_t not in s_rows:
                self.S.append(t)
                self.S_dot_A.remove(t)
                return t

        return None

Get the row that needs to be closed.

Returns:

row that will be moved to S set and closed
#   def get_row_to_complete(self):
View Source
    def get_row_to_complete(self):
        """
        Get row for that needs to be completed.

        Returns:

            row that will be added to S.A
        """

        s_rows = set()
        for s in self.S:
            s_rows.add(tuple((s, self.row_to_hashable(s))))

        for s_row in s_rows:
            similar_s_dot_a_rows = []
            for t in self.S_dot_A:
                row_t = self.row_to_hashable(t)
                if row_t == s_row[1]:
                    similar_s_dot_a_rows.append(t)
            similar_s_dot_a_rows.sort(key=lambda row: len(row[0]))
            for a in self.A:
                complete_outputs = self.get_all_outputs(s_row[0], a)
                for similar_s_dot_a_row in similar_s_dot_a_rows:
                    t_row_outputs = self.get_all_outputs(similar_s_dot_a_row, a)
                    output_difference = t_row_outputs.difference(complete_outputs)
                    if len(output_difference) > 0:
                        for o in output_difference:
                            extension = (similar_s_dot_a_row[0] + a, similar_s_dot_a_row[1] + tuple([o[0]]))
                            if extension not in self.S and extension not in self.S_dot_A:
                                return extension
                            else:
                                complete_outputs = complete_outputs.union(output_difference)

        return None

Get the row that needs to be completed.

Returns:

row that will be added to S.A
#   def get_row_to_make_consistent(self):
View Source
    def get_row_to_make_consistent(self):
        """
        Get row that violates consistency.
        """
        unified_S = self.S + self.S_dot_A
        s_rows = set()
        for s in self.S:
            s_rows.add(tuple((s, self.row_to_hashable(s))))

        for s_row in s_rows:
            similar_s_dot_a_rows = []
            for t in self.S_dot_A:
                row_t = self.row_to_hashable(t)
                if row_t == s_row[1]:
                    similar_s_dot_a_rows.append(t)

            similar_s_dot_a_rows.sort(key=lambda row: len(row[0]))

            for a in self.A:
                # CHANGED
                #                 outputs = self.observation_table.T[s_row[0]][a]
                outputs = self.get_all_outputs(s_row[0], a)
                for o in outputs:
                    extended_s_sequence = (s_row[0][0] + a, s_row[0][1] + tuple([o]))
                    if extended_s_sequence in unified_S:
                        extended_s_sequence_row = self.row_to_hashable(extended_s_sequence)
                        for similar_s_dot_a_row in similar_s_dot_a_rows:
                            extended_s_dot_a_sequence = (
                                similar_s_dot_a_row[0] + a, similar_s_dot_a_row[1] + tuple([o]))
                            if extended_s_dot_a_sequence in unified_S:
                                extended_s_dot_a_sequence_row = self.row_to_hashable(extended_s_dot_a_sequence)
                                if extended_s_sequence_row is not extended_s_dot_a_sequence_row:
                                    return self.get_distinctive_input_sequence(extended_s_sequence,
                                                                               extended_s_dot_a_sequence, a)

        return None

Get row that violates consistency.

#   def get_distinctive_input_sequence(self, first_row, second_row, inp):
View Source
    def get_distinctive_input_sequence(self, first_row, second_row, inp):
        """
        get input sequence that leads to a different output sequence for two given input/output sequences

        Args:

            first_row: row to be compared
            second_row: row to be compared
            inp: appended input to first_row and second_row that leads to different state 

        Returns:

            input sequence that leads to different outputs

        """
        for e in self.E:
            if len(self.T[first_row][e].difference(self.T[second_row][e])) > 0:
                return tuple([inp]) + e

        return None

get input sequence that leads to a different output sequence for two given input/output sequences

Args:

first_row: row to be compared
second_row: row to be compared
inp: appended input to first_row and second_row that leads to different state

Returns:

input sequence that leads to different outputs
#   def update_E(self, seq):
View Source
    def update_E(self, seq):
        if seq not in self.E:
            self.E.append(seq)
#   def clean_obs_table(self):
View Source
    def clean_obs_table(self):
        """
        Moves duplicates from S to S_dot_A. The entries in S_dot_A which are based on the moved row get deleted.
        The table will be smaller and more efficient.

        """
        # just for testing without cleaning
        # return False

        tmp_S = self.S.copy()
        tmp_both_S = self.S + self.S_dot_A
        hashed_rows_from_s = set()

        tmp_S.sort(key=lambda t: len(t[0]))

        for s in tmp_S:
            hashed_s_row = self.row_to_hashable(s)
            if hashed_s_row in hashed_rows_from_s:
                if s in self.S:
                    self.S.remove(s)
                    self.observation_table.S.remove(s)
                size = len(s[0])
                for row_prefix in tmp_both_S:
                    s_both_row = (row_prefix[0][:size], row_prefix[1][:size])
                    if s != row_prefix and s == s_both_row:
                        if row_prefix in self.S:
                            self.S.remove(row_prefix)
                            self.observation_table.S.remove(s)
            else:
                hashed_rows_from_s.add(hashed_s_row)

Moves duplicates from S to S_dot_A. The entries in S_dot_A which are based on the moved row get deleted. The table will be smaller and more efficient.

#   def row_to_hashable(self, row_prefix):
View Source
    def row_to_hashable(self, row_prefix):
        """
        Creates the hashable representation of the row. Frozenset is used as the order of element in each cell does not
        matter

        Args:

            row_prefix: prefix of the row in the observation table

        Returns:

            hashable representation of the row

        """
        row_repr = tuple()
        for e in self.E:
            # if e in self.T[row_prefix].keys():
            row_repr += (frozenset(self.T[row_prefix][e]),)
        return row_repr

Creates the hashable representation of the row. Frozenset is used as the order of element in each cell does not matter

Args:

row_prefix: prefix of the row in the observation table

Returns:

hashable representation of the row
#   def gen_hypothesis(self) -> aalpy.automata.Onfsm.Onfsm:
View Source
    def gen_hypothesis(self) -> Onfsm:
        """
        Build an automaton from the current content of the abstracted observation table.

        One state is created per row of S; the state of the first row of S is the initial state.
        Transitions of a state are derived from every row (of S and S.A) that has the same table
        row as the state's prefix.

        Returns:

            Current abstracted hypothesis

        """
        all_rows = self.S + self.S_dot_A
        state_of_prefix = dict()
        state_of_row = dict()
        initial_state = None

        for index, prefix in enumerate(self.S):
            state = OnfsmState(f's{index}')
            state.prefix = prefix

            state_of_prefix[prefix] = state
            state_of_row[self.row_to_hashable(prefix)] = state

            if prefix == self.S[0]:
                initial_state = state

        for prefix in self.S:
            source_state = state_of_prefix[prefix]
            prefix_row = self.row_to_hashable(prefix)
            equivalent_rows = [row for row in all_rows if self.row_to_hashable(row) == prefix_row]

            for row in equivalent_rows:
                for a in self.A:
                    for t in self.get_all_outputs(row, a):
                        successor = (row[0] + a, row[1] + t)
                        if successor not in all_rows:
                            continue
                        target_state = state_of_row[self.row_to_hashable(successor)]

                        transition = (t[0], target_state)
                        # avoid duplicated (output, target) pairs on the same input
                        if transition not in source_state.transitions[a[0]]:
                            source_state.transitions[a[0]].append(transition)

        assert initial_state
        automaton = Onfsm(initial_state, list(state_of_prefix.values()))
        automaton.characterization_set = self.E

        return automaton

Generate automaton based on the values found in the abstracted observation table.

Returns:

Current abstracted hypothesis
#   def extend_S_dot_A(self, cex_prefixes: list):
View Source
    def extend_S_dot_A(self, cex_prefixes: list):
        """
        Extends S.A based on counterexample prefixes.

        Args:

        cex_prefixes: input/output sequences that are added to S.A

        Returns:

        input/output sequences that have been added to the S.A
        """
        prefixes = self.S + self.S_dot_A
        prefixes_to_extend = []
        for cex_prefix in cex_prefixes:
            if cex_prefix not in prefixes:
                prefixes_to_extend.append(cex_prefix)
                self.S_dot_A.append(cex_prefix)
        return prefixes_to_extend

Extends S.A based on counterexample prefixes.

Args:

cex_prefixes: input/output sequences that are added to S.A

Returns:

input/output sequences that have been added to the S.A

#   def get_abstraction(self, out):
View Source
    def get_abstraction(self, out):
        """
        Get an abstraction for a concrete output. If such abstraction is not defined, return output.

        Args:

            out: output to be abstracted if possible

        Returns:

            abstracted output or output itself
        """
        return self.abstraction_mapping[out] if out in self.abstraction_mapping.keys() else out

Get an abstraction for a concrete output. If such abstraction is not defined, return output.

Args:

out: output to be abstracted if possible

Returns:

abstracted output or output itself
#   def cex_processing(self, cex: tuple, hypothesis: aalpy.automata.Onfsm.Onfsm):
View Source
    def cex_processing(self, cex: tuple, hypothesis: Onfsm):
        """
        Add counterexample to the observation table. If the counterexample leads to a state where an output of the
        same equivalence class already exists, the prefixes of the counterexample are added to S.A.
        Otherwise, the postfixes of counterexample are added to E.


        Args:

            cex: counterexample that should be added to the observation table
            hypothesis: onfsm that implements the counterexample
        """

        cex_len = len(cex[0])
        hypothesis.reset_to_initial()

        for step in range(0, cex_len - 1):
            hypothesis.step_to(cex[0][step], cex[1][step])

        possible_outputs = hypothesis.outputs_on_input(cex[0][cex_len - 1])

        equivalent_output = False

        for out in possible_outputs:
            abstracted_out = self.get_abstraction(out)
            abstracted_out_cex = self.get_abstraction(cex[1][cex_len - 1])
            if abstracted_out_cex == abstracted_out:
                equivalent_output = True
                break

        if equivalent_output:
            # add prefixes of cex to S_dot_A
            cex_prefixes = [(tuple(cex[0][0:i + 1]), tuple(cex[1][0:i + 1])) for i in range(0, len(cex[0]))]
            prefixes_to_extend = self.extend_S_dot_A(cex_prefixes)

            # CHANGED: REMOVED
            # self.observation_table.S_dot_A.extend(prefixes_to_extend)
            self.update_obs_table(s_set=prefixes_to_extend)
        else:
            # add distinguishing suffixes of cex to E
            # CHANGED CEX PROX
            # TODO: this will now not work as cex processing was changed
            # cex_suffixes = non_det_longest_prefix_cex_processing(self.observation_table, cex)
            cex_suffixes = all_suffixes(cex[0])

            added_suffixes = extend_set(self.observation_table.E, cex_suffixes)
            self.update_obs_table(e_set=added_suffixes)

Add counterexample to the observation table. If the counterexample leads to a state where an output of the same equivalence class already exists, the prefixes of the counterexample are added to S.A. Otherwise, the postfixes of counterexample are added to E.

Args:

cex: counterexample that should be added to the observation table
hypothesis: onfsm that implements the counterexample
#   def clean_tables(self):
View Source
    def clean_tables(self):

        self.observation_table.clean_obs_table()
        self.abstract_obs_table()

        update_S = self.S.copy()
        whole_S = self.S + self.S_dot_A

        update_S.sort()
        update_S.sort(key=lambda t: len(t[0]))

        s_rows = set()
        for s in update_S:
            hashed_s_row = self.row_to_hashable(s)
            if hashed_s_row not in s_rows:
                s_rows.add(hashed_s_row)
            else:
                size = len(s[0])
                for row in whole_S:
                    cmp_row = (row[0][:size], row[1][:size])
                    if s == cmp_row:
                        if row in self.S_dot_A:
                            self.S_dot_A.remove(row)
                        elif row in self.S:
                            self.S.remove(row)

                self.S_dot_A.append(s)
                self.S.remove(s)