Logistic Regression#

%pip install pytensor pymc
import pymc as pm
import numpy as np
import pandas as pd
import seaborn as sns
import scipy.stats as stats
from scipy.special import expit as logistic
import matplotlib.pyplot as plt
import arviz as az
import requests
import io
az.style.use('arviz-darkgrid')
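Before applying it to data, recall what the logistic (sigmoid) function does: it squashes any real number \(z\) into the interval \((0, 1)\),

\[ \text{logistic}(z) = \frac{1}{1 + e^{-z}}, \]

which is what lets us interpret its output as a probability. The cell below simply plots this curve.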
z = np.linspace(-8, 8)
plt.plot(z, 1 / (1 + np.exp(-z)))
plt.xlabel('z')
plt.ylabel('logistic(z)')
[Figure: the logistic curve logistic(z) vs. z]

The Iris Dataset#

target_url = 'https://raw.githubusercontent.com/cfteach/brds/main/datasets/iris.csv'

download = requests.get(target_url).content
iris = pd.read_csv(io.StringIO(download.decode('utf-8')))

iris.head()
sepal_length sepal_width petal_length petal_width species
0 5.1 3.5 1.4 0.2 setosa
1 4.9 3.0 1.4 0.2 setosa
2 4.7 3.2 1.3 0.2 setosa
3 4.6 3.1 1.5 0.2 setosa
4 5.0 3.6 1.4 0.2 setosa
# species vs petal_width

figsize = (12, 1.2 * len(iris['species'].unique()))
plt.figure(figsize=figsize)
sns.violinplot(data=iris, x='petal_width', y='species', hue='species',
               inner='stick', palette='Dark2', legend=False)
sns.despine(top=True, right=True, bottom=True, left=True)
[Figure: violin plots of petal_width for each species]
# using seaborn's stripplot function

sns.stripplot(x="species", y="sepal_length", data=iris, jitter=True)
[Figure: strip plot of sepal_length by species]
sns.pairplot(iris, hue='species', diag_kind='kde', height=1.5)
[Figure: pair plot of the four iris features, colored by species]

The logistic model applied to the iris dataset#

# keep two of the three species: a binary classification problem
df = iris.query("species == ('setosa', 'versicolor')")
df
sepal_length sepal_width petal_length petal_width species
0 5.1 3.5 1.4 0.2 setosa
1 4.9 3.0 1.4 0.2 setosa
2 4.7 3.2 1.3 0.2 setosa
3 4.6 3.1 1.5 0.2 setosa
4 5.0 3.6 1.4 0.2 setosa
... ... ... ... ... ...
95 5.7 3.0 4.2 1.2 versicolor
96 5.7 2.9 4.2 1.3 versicolor
97 6.2 2.9 4.3 1.3 versicolor
98 5.1 2.5 3.0 1.1 versicolor
99 5.7 2.8 4.1 1.3 versicolor

100 rows × 5 columns

# convert the 'species' column of the DataFrame df into integer category codes (setosa → 0, versicolor → 1)
y_0 = pd.Categorical(df['species']).codes
y_0
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int8)
# let's select a single covariate
x_n = 'sepal_length'
x_0 = df[x_n].values
#print(x_0)

# let's center our dataset, as we have done in other exercises
x_c = x_0 - x_0.mean()
print(x_c)
[-0.371 -0.571 -0.771 -0.871 -0.471 -0.071 -0.871 -0.471 -1.071 -0.571
 -0.071 -0.671 -0.671 -1.171  0.329  0.229 -0.071 -0.371  0.229 -0.371
 -0.071 -0.371 -0.871 -0.371 -0.671 -0.471 -0.471 -0.271 -0.271 -0.771
 -0.671 -0.071 -0.271  0.029 -0.571 -0.471  0.029 -0.571 -1.071 -0.371
 -0.471 -0.971 -1.071 -0.471 -0.371 -0.671 -0.371 -0.871 -0.171 -0.471
  1.529  0.929  1.429  0.029  1.029  0.229  0.829 -0.571  1.129 -0.271
 -0.471  0.429  0.529  0.629  0.129  1.229  0.129  0.329  0.729  0.129
  0.429  0.629  0.829  0.629  0.929  1.129  1.329  1.229  0.529  0.229
  0.029  0.029  0.329  0.529 -0.071  0.529  1.229  0.829  0.129  0.029
  0.029  0.629  0.329 -0.471  0.129  0.229  0.229  0.729 -0.371  0.229]
with pm.Model() as model_logreg:
    α = pm.Normal('α', mu=0, sigma=10)
    β = pm.Normal('β', mu=0, sigma=10)

    # equivalently: μ = α + pm.math.dot(x_c, β)
    μ = α + β * x_c
    θ = pm.Deterministic('θ', pm.math.sigmoid(μ))
    # decision boundary: the value of x_c at which θ = 0.5
    bd = pm.Deterministic('bd', -α/β)

    yl = pm.Bernoulli('yl', p=θ, observed=y_0)

    idata_logreg = pm.sample(2000, tune=2000, return_inferencedata=True)
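The deterministic variable bd is the decision boundary: the value of the (centered) sepal length at which \(\theta = 0.5\). Since the sigmoid crosses 0.5 when its argument is zero,

\[ \alpha + \beta x = 0 \quad \Rightarrow \quad x_{bd} = -\frac{\alpha}{\beta}. \]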
varnames = ['α', 'β', 'bd']
az.summary(idata_logreg)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
α 0.310 0.327 -0.295 0.919 0.006 0.005 3204.0 2411.0 1.0
β 5.383 1.022 3.613 7.437 0.019 0.014 3075.0 2658.0 1.0
θ[0] 0.165 0.058 0.068 0.281 0.001 0.001 3771.0 2816.0 1.0
θ[1] 0.068 0.036 0.011 0.135 0.001 0.000 3734.0 2777.0 1.0
θ[2] 0.027 0.020 0.001 0.064 0.000 0.000 3677.0 2759.0 1.0
... ... ... ... ... ... ... ... ... ...
θ[96] 0.815 0.063 0.691 0.924 0.001 0.001 2902.0 2330.0 1.0
θ[97] 0.980 0.017 0.950 1.000 0.000 0.000 2851.0 2658.0 1.0
θ[98] 0.165 0.058 0.068 0.281 0.001 0.001 3771.0 2816.0 1.0
θ[99] 0.815 0.063 0.691 0.924 0.001 0.001 2902.0 2330.0 1.0
bd -0.057 0.060 -0.171 0.057 0.001 0.001 3410.0 2825.0 1.0

103 rows × 9 columns

theta_post= idata_logreg.posterior['θ']
print(np.shape(theta_post))
(2, 2000, 100)
theta = idata_logreg.posterior['θ'].mean(axis=0).mean(axis=0)
idx = np.argsort(x_c)

np.random.seed(123)

# plotting the sigmoid (logistic) curve
plt.plot(x_c[idx], theta[idx], color='C2', lw=3)

# plotting the mean boundary decision
plt.vlines(idata_logreg.posterior['bd'].mean(), 0, 1, color='k')

# az.hdi computes the lower and upper limits of the 94% HDI of bd
bd_hdp = az.hdi(idata_logreg.posterior['bd'], hdi_prob=0.94)


# fill the area between the two HDI limits of the decision boundary
plt.fill_betweenx([0, 1], bd_hdp.bd[0].values, bd_hdp.bd[1].values, color='k', alpha=0.25)


plt.scatter(x_c, np.random.normal(y_0, 0.02),
            marker='.', color=[f'C{x}' for x in y_0]) # 0.02 added for visualization purposes


az.plot_hdi(x_c, idata_logreg.posterior['θ'], color='C2')  #green band


plt.xlabel(x_n)
plt.ylabel('θ', rotation=0)
# use original scale for xticks
locs, _ = plt.xticks()
plt.xticks(locs, np.round(locs + x_0.mean(), 1))

[Figure: posterior mean logistic curve (green), mean decision boundary with its 94% HDI (black), 94% HDI band of θ, and jittered data points]

What happens if you change your priors?
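As a starting point for this exercise, here is a minimal sketch that refits the same model with much narrower priors; the names model_logreg_narrow and idata_narrow and the choice sigma=1 are illustrative, not part of the original notebook:

# prior-sensitivity check (illustrative): same likelihood, narrower priors
with pm.Model() as model_logreg_narrow:
    α = pm.Normal('α', mu=0, sigma=1)   # was sigma=10
    β = pm.Normal('β', mu=0, sigma=1)   # was sigma=10
    θ = pm.Deterministic('θ', pm.math.sigmoid(α + β * x_c))
    bd = pm.Deterministic('bd', -α/β)
    yl = pm.Bernoulli('yl', p=θ, observed=y_0)
    idata_narrow = pm.sample(2000, tune=2000, return_inferencedata=True)

# compare with the earlier posterior: with 100 data points the likelihood
# should dominate, so α, β and bd should barely move
az.summary(idata_narrow, var_names=['α', 'β', 'bd'])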

Multidimensional (feature space) Logistic Regression#

df = iris.query("species == ('setosa', 'versicolor')")
y_1 = pd.Categorical(df['species']).codes
x_n = ['sepal_length', 'sepal_width']
x_1 = df[x_n].values
print(np.shape(x_1), type(x_1))
print(np.shape(y_1), type(y_1))
(100, 2) <class 'numpy.ndarray'>
(100,) <class 'numpy.ndarray'>
np.shape(x_1[:,0])
(100,)
with pm.Model() as model_1:

    α = pm.Normal('α', mu=0, sigma=10)
    β = pm.Normal('β', mu=0, sigma=2, shape=len(x_n))

    # pm.Data registers a data container with the model:
    #x__1 = pm.Data('x', x_1, mutable=True)
    #y__1 = pm.Data('y', y_1, mutable=True)
    # Advantages: model reusability, updating beliefs with new data, and efficient
    # sampling (the same computational graph is reused even when the data change).
    # With mutable=True the variable is registered as a SharedVariable, so its value
    # and shape (but NOT its dimensionality) can be altered with pymc.set_data().
    # (A runnable sketch of this pattern appears in the prediction section below.)

    μ = α + pm.math.dot(x_1, β)

    θ = pm.Deterministic('θ', 1 / (1 + pm.math.exp(-μ)))
    # linear decision boundary in the (sepal_length, sepal_width) plane
    bd = pm.Deterministic('bd', -α/β[1] - β[0]/β[1] * x_1[:,0])

    yl = pm.Bernoulli('yl', p=θ, observed=y_1)

    trace_1 = pm.sample(1000, tune=2000, return_inferencedata=True, target_accept=0.9)
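The expression for bd again follows from setting \(\theta = 0.5\), i.e. \(\mu = 0\), and solving for the second feature:

\[ \alpha + \beta_0 x_0 + \beta_1 x_1 = 0 \quad \Rightarrow \quad x_1 = -\frac{\alpha}{\beta_1} - \frac{\beta_0}{\beta_1}\, x_0, \]

a straight line in the (sepal_length, sepal_width) plane.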
varnames = ['α', 'β']
az.plot_forest(trace_1, var_names=varnames);
[Figure: forest plot of the posteriors of α and β]
idx = np.argsort(x_1[:,0])


bd_mean = trace_1.posterior['bd'].mean(axis=0).mean(axis=0)


plt.scatter(x_1[:,0], x_1[:,1], c=[f'C{x}' for x in y_1])

bd = bd_mean[idx]

plt.plot(x_1[:,0][idx], bd, color='k');


az.plot_hdi(x_1[:,0], trace_1.posterior['bd'], color='k')


plt.xlabel(x_n[0])
plt.ylabel(x_n[1])
[Figure: data in the (sepal_length, sepal_width) plane with the posterior mean linear decision boundary and its 94% HDI band]
  • What happens if you change your priors?

  • What happens if your decision boundary is not linear? (See below)

Prediction on unseen data#

print(np.shape(x_1), type(x_1))

udata = np.array(((5.2,2.0),(7.0,2.0),(4.5,5.0),(5.5,3.2)))

print(np.shape(udata), type(udata))
(100, 2) <class 'numpy.ndarray'>
(4, 2) <class 'numpy.ndarray'>
print(udata)
[[5.2 2. ]
 [7.  2. ]
 [4.5 5. ]
 [5.5 3.2]]
# collapse the chain dimension by averaging matched draws across the two chains
# (stacking chain and draw would instead keep all 2000 posterior samples)
alpha_chain = trace_1.posterior['α'].mean(axis=0).values
beta_chain  = trace_1.posterior['β'].mean(axis=0).values

print(np.shape(alpha_chain), np.shape(beta_chain))
(1000,) (1000, 2)
logit = np.dot(udata, beta_chain.T) + alpha_chain
print(np.shape(logit))
(4, 1000)
probabilities = 1 / (1 + np.exp(-logit))
print(np.shape(probabilities))
(4, 1000)
# Average probabilities for prediction
mean_probabilities = np.mean(probabilities, axis=1)

# Class assignment (you might adjust the threshold if needed, default is 0.5)
class_assignments = (mean_probabilities > 0.5).astype(int)

# Uncertainty estimation: the 2.5th and 97.5th percentiles give a 95% equal-tailed credible interval
lower_bound = np.percentile(probabilities, 2.5, axis=1)
upper_bound = np.percentile(probabilities, 97.5, axis=1)

print("\n=======================================")
print("data: \n", udata.T)
print("class, mean probabilities, 95% credible intervals: ")
for h, i, j, k in zip(class_assignments, mean_probabilities, lower_bound, upper_bound):
  print(f"class: {h}, mean prob. {i:.4f}, 95% CI: [{j:.4f},{k:.4f}]")

print("=======================================\n")
=======================================
data: 
 [[5.2 7.  4.5 5.5]
 [2.  2.  5.  3.2]]
class, mean probabilities, 95% credible intervals: 
class: 1, mean prob. 0.9898, 95% CI: [0.9666,0.9986]
class: 1, mean prob. 1.0000, 95% CI: [1.0000,1.0000]
class: 0, mean prob. 0.0000, 95% CI: [0.0000,0.0000]
class: 0, mean prob. 0.4874, 95% CI: [0.3275,0.6525]
=======================================
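The same predictions can also be obtained with the data containers that were left commented out in model_1. A minimal sketch, assuming a PyMC version where pm.Data accepts mutable=True (the names model_1s, x_s, y_s and trace_1s are illustrative):

# illustrative variant of model_1 with mutable data containers
with pm.Model() as model_1s:
    x_s = pm.Data('x_s', x_1, mutable=True)
    y_s = pm.Data('y_s', y_1, mutable=True)
    α = pm.Normal('α', mu=0, sigma=10)
    β = pm.Normal('β', mu=0, sigma=2, shape=len(x_n))
    θ = pm.Deterministic('θ', pm.math.sigmoid(α + pm.math.dot(x_s, β)))
    yl = pm.Bernoulli('yl', p=θ, observed=y_s)
    trace_1s = pm.sample(1000, tune=2000, return_inferencedata=True, target_accept=0.9)

with model_1s:
    # swap in the unseen points (the labels are dummies, needed only for shape)
    pm.set_data({'x_s': udata, 'y_s': np.zeros(len(udata), dtype=int)})
    ppc = pm.sample_posterior_predictive(trace_1s, var_names=['θ', 'yl'])

# posterior-predictive mean probability for each unseen point
print(ppc.posterior_predictive['θ'].mean(dim=('chain', 'draw')).values)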
bd_mean = trace_1.posterior['bd'].mean(axis=0).mean(axis=0)


plt.scatter(udata[:,0], udata[:,1], c=[f'C{x}' for x in class_assignments])

bd = bd_mean[idx]

plt.plot(x_1[:,0][idx], bd, color='k');

az.plot_hdi(x_1[:,0], trace_1.posterior['bd'], color='k')

plt.xlabel(x_n[0])
plt.ylabel(x_n[1])
[Figure: unseen points colored by predicted class, plotted over the posterior mean linear decision boundary and its HDI band]
np.shape(x_1**2)
(100, 2)
x_1sq = x_1**2
print(np.shape(x_1sq), np.shape(x_1))
(100, 2) (100, 2)

Can I use a non-linear decision boundary?#

Let’s add, for example, a term \(\gamma \cdot x_{sepal \ length}^2\) to the equation for \(\mu\).
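Setting \(\mu = 0\) and solving for \(x_1\) as before, the decision boundary becomes a parabola,

\[ x_1 = -\frac{\alpha}{\beta_1} - \frac{\beta_0}{\beta_1}\, x_0 - \frac{\gamma}{\beta_1}\, x_0^2, \]

which is exactly the expression used for bd in the model below.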

with pm.Model() as model_2:
    α = pm.Normal('α', mu=0, sigma=10)
    β = pm.Normal('β', mu=0, sigma=2, shape=len(x_n))
    𝛾 = pm.Normal('𝛾', mu=0, sigma=2)

    μ = α + pm.math.dot(x_1, β) + 𝛾 * x_1[:,0]**2

    θ = pm.Deterministic('θ', 1 / (1 + pm.math.exp(-μ)))
    # quadratic decision boundary in the (sepal_length, sepal_width) plane
    bd = pm.Deterministic('bd', -α/β[1] - β[0]/β[1] * x_1[:,0] - 𝛾/β[1] * x_1[:,0]**2)

    yl = pm.Bernoulli('yl', p=θ, observed=y_1)

    trace_2 = pm.sample(1000, tune=2000, return_inferencedata=True, target_accept=0.95)
trace_2
arviz.InferenceData

  • posterior: Dataset with dimensions
        (chain: 2, draw: 1000, β_dim_0: 2, θ_dim_0: 100, bd_dim_0: 100)
    and data variables α, β, 𝛾, θ and bd.

  • sample_stats: Dataset with dimensions (chain: 2, draw: 1000) holding the
    NUTS diagnostics (step_size, acceptance_rate, energy, energy_error,
    tree_depth, reached_max_treedepth, ...).

  • observed_data: Dataset with the observed labels yl (yl_dim_0: 100).

Each group carries attributes recording the arviz and pymc versions, the
creation time, the sampling time (~120 s) and the number of tuning steps (2000).
az.summary(trace_2)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
α -2.588 5.816 -14.190 7.282 0.202 0.143 830.0 912.0 1.0
β[0] -0.215 1.781 -3.216 3.336 0.067 0.051 704.0 624.0 1.0
β[1] -5.914 1.136 -8.021 -3.783 0.040 0.030 872.0 836.0 1.0
𝛾 0.765 0.261 0.259 1.238 0.009 0.007 786.0 930.0 1.0
θ[0] 0.016 0.014 0.000 0.043 0.000 0.000 925.0 985.0 1.0
... ... ... ... ... ... ... ... ... ...
bd[95] 3.564 0.157 3.255 3.842 0.004 0.003 1613.0 1583.0 1.0
bd[96] 3.564 0.157 3.255 3.842 0.004 0.003 1613.0 1583.0 1.0
bd[97] 4.326 0.301 3.756 4.865 0.008 0.006 1452.0 1530.0 1.0
bd[98] 2.737 0.117 2.510 2.941 0.003 0.002 1662.0 1525.0 1.0
bd[99] 3.564 0.157 3.255 3.842 0.004 0.003 1613.0 1583.0 1.0

204 rows × 9 columns

idx = np.argsort(x_1[:,0])


bd_mean2 = trace_2.posterior['bd'].mean(axis=0).mean(axis=0)


plt.scatter(x_1[:,0], x_1[:,1], c=[f'C{x}' for x in y_1])

bd = bd_mean2[idx]

plt.plot(x_1[:,0][idx], bd, color='k');


az.plot_hdi(x_1[:,0], trace_2.posterior['bd'], color='k')


plt.xlabel(x_n[0])
plt.ylabel(x_n[1])

print(np.shape(bd_mean2))
(100,)
[Figure: data with the posterior mean quadratic decision boundary and its 94% HDI band]