Multi-armed bandit¶

  • Michal Kubišta
  • https://github.com/kubistmi/contextual_MAB

credits: Murder (The Office)¶

Multi-armed bandit¶

credits: Alex Slivkins - Microsoft Research Silicon Valley¶

Table of contents:¶

  1. Multi-armed bandit introduction
  2. Data review
  3. Code review
  4. Model results

What is MAB?¶

  • a simplified reinforcement learning problem
  • a set of actions that yield unknown rewards
  • goal: maximise the total reward
  • exploration-exploitation tradeoff
    • handled by the policy
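To make the tradeoff concrete, here is a minimal, self-contained sketch (the arm means, epsilon and function name are invented for illustration): an epsilon-greedy policy on a 3-armed Gaussian bandit, which explores a random arm with probability eps and otherwise pulls the arm with the best observed mean reward.

```python
import random

def eps_greedy_bandit(true_means, iters=5000, eps=0.1, seed=42):
    """Epsilon-greedy on a Gaussian bandit; returns (total reward, pull counts)."""
    rng = random.Random(seed)
    counts = [0] * len(true_means)          # pulls per arm
    means = [0.0] * len(true_means)         # running mean reward per arm
    total = 0.0
    for _ in range(iters):
        if rng.random() < eps:
            arm = rng.randrange(len(true_means))                       # explore
        else:
            arm = max(range(len(true_means)), key=lambda a: means[a])  # exploit
        reward = rng.gauss(true_means[arm], 1.0)  # reward unknown to the agent
        counts[arm] += 1
        means[arm] += (reward - means[arm]) / counts[arm]  # incremental mean
        total += reward
    return total, counts

total, counts = eps_greedy_bandit([0.1, 0.5, 0.9])
```

With a small eps the policy quickly concentrates its pulls on the best arm while still sampling the others occasionally.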

Contextual MAB¶

  • receive some context before deciding on the action
  • we can now collect data and try to predict the reward
    • reward ~ context
  • each action gets its own oracle (an ML model)
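The reward ~ context idea can be sketched as one regression model per action, fit only on the rounds where that action was taken (the synthetic data and reward structure below are invented for illustration):

```python
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))               # context features
actions = rng.integers(0, 2, size=200)      # action taken in each round
# invented ground truth: each action's reward depends on a different feature
reward = np.where(actions == 0, X[:, 0], X[:, 1]) + rng.normal(0, 0.1, 200)

# one oracle per action, trained only on that action's history
oracles = {a: LinearRegression().fit(X[actions == a], reward[actions == a])
           for a in (0, 1)}

# given a new context, predict the reward of every action and pick the best
ctx = rng.normal(size=(1, 3))
preds = {a: m.predict(ctx)[0] for a, m in oracles.items()}
best = max(preds, key=preds.get)
```

Each oracle recovers its action's reward structure from the collected history, which is exactly what the Oracle class below formalises.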

Oracle and policy¶

And the data provider¶

Data review¶

In [3]:
# dataset from https://www.kaggle.com/blastchar/telco-customer-churn
churn = pd.read_csv("data/WA_Fn-UseC_-Telco-Customer-Churn.csv")
churn.rename(str.lower, axis = 'columns', inplace = True)

churn.head(5)
Out[3]:
customerid gender seniorcitizen partner dependents tenure phoneservice multiplelines internetservice onlinesecurity ... deviceprotection techsupport streamingtv streamingmovies contract paperlessbilling paymentmethod monthlycharges totalcharges churn
0 7590-VHVEG Female 0 Yes No 1 No No phone service DSL No ... No No No No Month-to-month Yes Electronic check 29.85 29.85 No
1 5575-GNVDE Male 0 No No 34 Yes No DSL Yes ... Yes No No No One year No Mailed check 56.95 1889.5 No
2 3668-QPYBK Male 0 No No 2 Yes No DSL Yes ... No No No No Month-to-month Yes Mailed check 53.85 108.15 Yes
3 7795-CFOCW Male 0 No No 45 No No phone service DSL Yes ... Yes Yes No No One year No Bank transfer (automatic) 42.30 1840.75 No
4 9237-HQITU Female 0 No No 2 Yes No Fiber optic No ... No No No No Month-to-month Yes Electronic check 70.70 151.65 Yes

5 rows × 21 columns

Numerical variables¶

In [5]:
churn.hist(figsize = (30,6), layout = (1,4))

churn.describe()
Out[5]:
tenure monthlycharges totalcharges
count 7043.000000 7043.000000 7032.000000
mean 32.371149 64.761692 2283.300441
std 24.559481 30.090047 2266.771362
min 0.000000 18.250000 18.800000
25% 9.000000 35.500000 401.450000
50% 29.000000 70.350000 1397.475000
75% 55.000000 89.850000 3794.737500
max 72.000000 118.750000 8684.800000

Categorical variables¶

In [7]:
display_dfs(cats)
gender
Male 3555
Female 3488
          
seniorcitizen
no 5901
yes 1142
          
partner
No 3641
Yes 3402
          
dependents
No 4933
Yes 2110
          
phoneservice
Yes 6361
No 682
          
multiplelines
No 3390
Yes 2971
No phone service 682
          
internetservice
Fiber optic 3096
DSL 2421
No 1526
          
onlinesecurity
No 3498
Yes 2019
No internet service 1526
          
onlinebackup
No 3088
Yes 2429
No internet service 1526
          
deviceprotection
No 3095
Yes 2422
No internet service 1526
          
techsupport
No 3473
Yes 2044
No internet service 1526
          
streamingtv
No 2810
Yes 2707
No internet service 1526
          
streamingmovies
No 2785
Yes 2732
No internet service 1526
          
contract
Month-to-month 3875
Two year 1695
One year 1473
          
paperlessbilling
Yes 4171
No 2872
          
paymentmethod
Electronic check 2365
Mailed check 1612
Bank transfer (automatic) 1544
Credit card (automatic) 1522
          
churn
No 5174
Yes 1869
          

NA handling¶

In [8]:
(
    churn
    .isna()
    .sum()
    .sort_values(ascending = False)
    .head(2)
)
Out[8]:
totalcharges    11
customerid       0
dtype: int64
In [9]:
churn.loc[churn.totalcharges.isna()]
Out[9]:
customerid gender seniorcitizen partner dependents tenure phoneservice multiplelines internetservice onlinesecurity ... deviceprotection techsupport streamingtv streamingmovies contract paperlessbilling paymentmethod monthlycharges totalcharges churn
488 4472-LVYGI Female no Yes Yes 0 No No phone service DSL Yes ... Yes Yes Yes No Two year Yes Bank transfer (automatic) 52.55 NaN No
753 3115-CZMZD Male no No Yes 0 Yes No No No internet service ... No internet service No internet service No internet service No internet service Two year No Mailed check 20.25 NaN No
936 5709-LVOEQ Female no Yes Yes 0 Yes No DSL Yes ... Yes No Yes Yes Two year No Mailed check 80.85 NaN No
1082 4367-NUYAO Male no Yes Yes 0 Yes Yes No No internet service ... No internet service No internet service No internet service No internet service Two year No Mailed check 25.75 NaN No
1340 1371-DWPAZ Female no Yes Yes 0 No No phone service DSL Yes ... Yes Yes Yes No Two year No Credit card (automatic) 56.05 NaN No
3331 7644-OMVMY Male no Yes Yes 0 Yes No No No internet service ... No internet service No internet service No internet service No internet service Two year No Mailed check 19.85 NaN No
3826 3213-VVOLG Male no Yes Yes 0 Yes Yes No No internet service ... No internet service No internet service No internet service No internet service Two year No Mailed check 25.35 NaN No
4380 2520-SGTTA Female no Yes Yes 0 Yes No No No internet service ... No internet service No internet service No internet service No internet service Two year No Mailed check 20.00 NaN No
5218 2923-ARZLG Male no Yes Yes 0 Yes No No No internet service ... No internet service No internet service No internet service No internet service One year Yes Mailed check 19.70 NaN No
6670 4075-WKNIU Female no Yes Yes 0 Yes Yes DSL No ... Yes Yes Yes No Two year No Mailed check 73.35 NaN No
6754 2775-SEFEE Male no No Yes 0 Yes Yes DSL Yes ... No Yes No No Two year Yes Bank transfer (automatic) 61.90 NaN No

11 rows × 21 columns

Check the hypothesis: totalcharges ≈ monthlycharges × tenure¶

In [10]:
lm = smf.ols("totalcharges ~ monthlycharges:tenure", churn).fit()
summary = lm.summary()
summary.extra_txt = ''
summary
Out[10]:
OLS Regression Results
Dep. Variable: totalcharges R-squared: 0.999
Model: OLS Adj. R-squared: 0.999
Method: Least Squares F-statistic: 7.981e+06
Date: Wed, 30 Jun 2021 Prob (F-statistic): 0.00
Time: 08:06:55 Log-Likelihood: -39571.
No. Observations: 7032 AIC: 7.915e+04
Df Residuals: 7030 BIC: 7.916e+04
Df Model: 1
Covariance Type: nonrobust
coef std err t P>|t| [0.025 0.975]
Intercept -0.9259 1.139 -0.813 0.416 -3.158 1.307
monthlycharges:tenure 1.0005 0.000 2825.026 0.000 1.000 1.001
Omnibus: 536.457 Durbin-Watson: 2.055
Prob(Omnibus): 0.000 Jarque-Bera (JB): 3044.087
Skew: -0.034 Prob(JB): 0.00
Kurtosis: 6.223 Cond. No. 4.57e+03


In [11]:
churn_wona = churn.assign(tenmon = churn.monthlycharges * churn.tenure).dropna()
churn_wona.plot("tenmon", "totalcharges", kind = "scatter", figsize = (10,6))
plt.plot(churn_wona.tenmon, lm.fittedvalues, color = 'orange')
plt.show()
del(churn_wona)
In [12]:
churn.fillna(0, inplace = True)
churn.to_csv('data/cleaned.csv', index = False)

Code review¶

Agent¶

  • part of the domain layer of the architecture
  • serves as a router between the other modules
In [14]:
# Agent is a wrapper around the DataProvider, Oracle and Policy
class Agent(ABC):
    #    
    def __init__(self, provider: DataProvider, oracle: Oracle, policy: Policy):
        self.provider = provider
        self.oracle = oracle
        self.policy = policy
    #
    def act(self, X: pd.DataFrame, time: int) -> int:
        pred = self.oracle.predict(X)
        return(self.policy.decide(pred, time))
    #
    def save_iter(self, context: pd.DataFrame, action:int, reward: int) -> None:
        self.provider.collect(context, action, reward)
    #
    def update(self) -> None:
        history = self.provider.provide()
        self.oracle.fit(history)
    #
    def replay(self) -> None:
        history = self.provider.provide(self.provider.size())
        self.oracle.fit(history)

Environment¶

  • handles the provision of the context
  • evaluates the chosen actions and returns a reward
In [15]:
# Environment handles the provision of context and rewards based on specified data
class Environment(ABC):
    #
    @abstractmethod
    def get_context(self) -> pd.DataFrame:
        pass
    #
    @abstractmethod
    def evaluate(self, action: int) -> float:
        pass

Learn¶

  • part of the domain layer of the architecture
  • defines the iterative learning process
In [16]:
# Learn function handles the iterative learning of the Agent using the Environment
def learn(agent: Agent, env: Environment, iters: int, update_freq: int, replay_freq: int = None) -> Agent:
    replay = replay_freq is not None
    #
    for i in range(iters):
        if i > 0:
            if replay and (i % replay_freq == 0):
                agent.replay()
            elif i % update_freq == 0:
                agent.update()
        cx = env.get_context()
        act = agent.act(cx, i)
        rew = env.evaluate(act)
        agent.save_iter(cx, act, rew)
    return(agent)

Oracle¶

  • a wrapper around a set of (ML) models, one per action
In [17]:
# Oracle is a wrapper around a set of models used to predict the reward based on context
class Oracle(ABC):
    #
    def __init__(self, actions: List[int], min_reward: int):
        self.minrew = min_reward
        self.actions = actions
        self.oracles = {a : LinearRegression() for a in actions}
    #
    def fit(self, X: pd.DataFrame) -> None:
        for i in self.actions:
            self.__fit_oracle__(i, X.query("action == @i"))
    #
    def predict(self, X: pd.DataFrame) -> pd.Series:
        out = {
            i: self.__predict_oracle__(i, X)
            for i in self.actions
        }
        return(pd.Series(out))
    #
    @abstractmethod
    def __fit_oracle__(self, oracle: int, X: pd.DataFrame):
        pass
    #
    @abstractmethod
    def __predict_oracle__(self, oracle:int, X: pd.DataFrame):
        pass

Policy¶

In [18]:
# Policy defines the tactic used to handle the exploration-exploitation tradeoff
class Policy(ABC):
    #
    @abstractmethod
    def decide(self, rewards: pd.Series, time: int) -> int:
        pass

Provider¶

In [19]:
class DataProvider(ABC):
    #
    def __init__(self, batchsize: int):
        self.contexts = []
        self.actions = []
        self.rewards = []
        self.defsize = batchsize
    #
    def size(self) -> int:
        return(len(self.contexts))
    #
    def collect(self, context: pd.DataFrame, action: int, reward: float) -> None:
        self.contexts.append(context)
        self.actions.append(action)
        self.rewards.append(reward)
    #
    @abstractmethod
    def provide(self, size: int) -> pd.DataFrame:
        pass

Implementations¶

Provider

  • Batch
  • Sample

Oracle

  • LinReg
  • OnReg
  • RegTree
  • Neural

Policy

  • AdaGreed
  • EpsGreed
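The concrete implementations live in the linked repository; as an illustration of the decide(rewards, time) interface, a decaying-epsilon policy might look like the sketch below (the class name and decay schedule are my own, not necessarily how AdaGreed or EpsGreed are actually implemented):

```python
import random
import pandas as pd

class DecayingEpsGreedy:
    """Illustrative policy: explore with probability eps(t) = start / (1 + decay * t)."""
    def __init__(self, start: float = 1.0, decay: float = 0.01, seed: int = 0):
        self.start = start
        self.decay = decay
        self.rng = random.Random(seed)
    #
    def decide(self, rewards: pd.Series, time: int) -> int:
        eps = self.start / (1 + self.decay * time)
        if self.rng.random() < eps:
            return self.rng.choice(list(rewards.index))  # explore: random action
        return int(rewards.idxmax())                     # exploit: best predicted reward
```

Early on eps is close to 1, so the agent mostly explores; as time grows it increasingly trusts the oracle's predicted rewards.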

Reward specification¶

In [21]:
rewards
Out[21]:
act 0 1 mean
pred
0 1.0 -2.0 -0.5
1 -1.5 0.5 -0.5
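One way to read the matrix (this interpretation is inferred, not stated on the slide): the row pred is the customer's true churn outcome and the column act is the chosen action, e.g. a retention offer. The Environment's evaluate step would then be a simple lookup:

```python
# assumed interpretation: (true outcome, chosen action) -> reward
REWARDS = {
    (0, 0): 1.0,   # no churn, no intervention: all good
    (0, 1): -2.0,  # no churn, wasted intervention
    (1, 0): -1.5,  # churn missed
    (1, 1): 0.5,   # churn prevented by the intervention
}

def evaluate(outcome: int, action: int) -> float:
    return REWARDS[(outcome, action)]
```

Note the mean column: both rows average to -0.5, so a uniformly random policy earns -0.5 per round regardless of the churn rate; only context-aware actions can beat that baseline.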

Model results¶