A very basic example from the NFlows Library

Ref. bayesiains/nflows

This example is intended to get you started and to give you an idea of what the code looks like.

Afterwards, we will move on to a hands-on example that uses data from a nuclear physics experiment.

!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
!pip install nflows
!pip install pkbar
!pip install FrEIA
!pip install scipy
# Basic imports

import matplotlib.pyplot as plt
import sklearn.datasets as datasets

import torch
from torch import nn
from torch import optim
import numpy as np
from nflows.flows.base import Flow
from nflows.distributions.normal import ConditionalDiagonalNormal
from nflows.transforms.base import CompositeTransform
from nflows.transforms.autoregressive import MaskedAffineAutoregressiveTransform
from nflows.transforms.permutations import ReversePermutation
from nflows.nn.nets import ResidualNet
import pkbar

x, y = datasets.make_moons(128, noise=.1)
plt.scatter(x[:, 0], x[:, 1], c=y);
[Figure: scatter plot of the two-moons samples, colored by class label]
# Check if we have a GPU
!nvidia-smi
Fri May 17 19:10:14 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   50C    P8              10W /  70W |      0MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

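num_layers = 5
# Base distribution: a diagonal Gaussian whose mean and log-std depend on the context.
# The context (the class label, 1 feature) is mapped to 4 numbers: 2 means and 2 log-stds.
base_dist = ConditionalDiagonalNormal(shape=[2],
                                      context_encoder=nn.Linear(1, 4))

transforms = []
for _ in range(num_layers):
    # Permutation: reverse the feature order, so that in alternating layers
    # each feature is transformed conditioned on the other one
    transforms.append(ReversePermutation(features=2))
    # Affine transformation
    # Notice this is affine, but also autoregressive: the shift and scale of each
    # feature are computed by a masked network (4 hidden units here) that only sees
    # the preceding features and the context. The details are not important at this time.
    # See slides for reference on affine functions.
    transforms.append(MaskedAffineAutoregressiveTransform(features=2,
                                                          hidden_features=4,
                                                          context_features=1))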
transform = CompositeTransform(transforms)

flow = Flow(transform, base_dist).to(device)
optimizer = optim.Adam(flow.parameters())
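The flow above chains `num_layers` pairs of `ReversePermutation` and `MaskedAffineAutoregressiveTransform`, combined by `CompositeTransform`. To see why the likelihood stays tractable, here is a toy NumPy sketch (not the nflows implementation): each layer returns its output together with \(\log|\det J|\), a permutation contributes 0, an affine autoregressive layer contributes the sum of its log-scales, and the composite simply adds them up. The functions `reverse_permutation`, `toy_autoregressive_affine` and `composite_forward`, and their fixed parameters, are hypothetical stand-ins for the learned networks; nflows may parameterize the scale differently.

```python
import numpy as np


def reverse_permutation(x):
    """Reverse the feature order. A permutation has |det J| = 1, so log|det J| = 0."""
    return x[..., ::-1], 0.0


def toy_autoregressive_affine(x, k):
    """Toy 2D affine autoregressive layer; the layer index k varies its parameters.

    Feature 1 gets a constant shift and log-scale; the shift and log-scale of
    feature 2 depend only on feature 1, so the Jacobian is lower triangular.
    """
    x1, x2 = x[..., 0], x[..., 1]
    log_s1, t1 = -0.1 * k, 0.3                       # hypothetical constants
    log_s2, t2 = 0.2 * np.sin(x1 + k), np.tanh(x1)   # hypothetical functions of x1
    z = np.stack([x1 * np.exp(log_s1) + t1, x2 * np.exp(log_s2) + t2], axis=-1)
    # log|det J| of a triangular Jacobian is the sum of the log diagonal entries
    return z, log_s1 + log_s2


def composite_forward(x, num_layers=5):
    """Apply (permutation, affine) pairs in sequence and accumulate log|det J|."""
    total_log_det = 0.0
    for k in range(num_layers):
        x, log_det = reverse_permutation(x)
        total_log_det += log_det
        x, log_det = toy_autoregressive_affine(x, k)
        total_log_det += log_det
    return x, total_log_det


x0 = np.array([0.7, -1.2])
z0, log_det = composite_forward(x0)
```

The accumulated `log_det` agrees with the log-determinant of a finite-difference Jacobian of `composite_forward`, because the determinant of a composition is the product of the individual determinants.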

Training and plotting

We train the model by likelihood maximization, i.e., by minimizing the negative log-likelihood of the training data. As training progresses, the learned density should increasingly resemble the distribution of the data we feed in.
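Concretely, maximizing the likelihood relies on the change-of-variables formula \(\log p(x) = \log p_z(f(x)) + \log|\det \partial f/\partial x|\), which is what the loss `-flow.log_prob(...).mean()` evaluates (with the label passed as context). As a minimal sketch, assuming the simplest possible flow, a fixed diagonal affine map \(z = (x - \mu)/\sigma\) with a standard normal base, the formula should reproduce the ordinary Gaussian log-density. The helper names and the values of `mu` and `sigma` are illustrative only.

```python
import numpy as np


def standard_normal_log_prob(z):
    """Log-density of a standard normal base, summed over the feature axis."""
    return -0.5 * np.sum(z**2 + np.log(2 * np.pi), axis=-1)


def affine_flow_log_prob(x, mu, sigma):
    """log p(x) by change of variables for the map z = f(x) = (x - mu) / sigma."""
    z = (x - mu) / sigma
    # df/dx is diagonal with entries 1/sigma, so log|det| = -sum(log sigma)
    log_det = -np.sum(np.log(sigma))
    return standard_normal_log_prob(z) + log_det


def gaussian_log_prob(x, mu, sigma):
    """Reference: log N(x; mu, diag(sigma**2)), summed over the feature axis."""
    return np.sum(-0.5 * ((x - mu) / sigma) ** 2 - np.log(sigma)
                  - 0.5 * np.log(2 * np.pi), axis=-1)


rng = np.random.default_rng(0)
x = rng.normal(size=(128, 2))                   # a batch of 128 2D points
mu, sigma = np.array([0.5, -0.25]), np.array([1.5, 0.4])
# Same form as the notebook's loss: mean negative log-likelihood over the batch
nll = -affine_flow_log_prob(x, mu, sigma).mean()
```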

num_iter = 5000
for i in range(num_iter):
    # Draw a fresh batch of 128 two-moons points every iteration
    x, y = datasets.make_moons(128, noise=.1)
    x = torch.tensor(x, dtype=torch.float32, device=device)
    # The class label is the conditioning context, with shape (batch, 1)
    y = torch.tensor(y, dtype=torch.float32, device=device).reshape(-1, 1)
    optimizer.zero_grad()
    # Maximizing the likelihood = minimizing the mean negative log-likelihood
    loss = -flow.log_prob(inputs=x, context=y).mean()
    loss.backward()
    optimizer.step()

    # Every 500 iterations, plot the learned density p(x | y) on a 100x100 grid
    if (i + 1) % 500 == 0:
        fig, ax = plt.subplots(1, 2)
        xline = torch.linspace(-1.5, 2.5, steps=100)
        yline = torch.linspace(-.75, 1.25, steps=100)
        xgrid, ygrid = torch.meshgrid(xline, yline, indexing='ij')
        xgrid, ygrid = xgrid.to(device), ygrid.to(device)
        xyinput = torch.cat([xgrid.reshape(-1, 1), ygrid.reshape(-1, 1)], dim=1)

        # Evaluate the density for context 0 (left panel) and context 1 (right panel)
        with torch.no_grad():
            zgrid0 = flow.log_prob(xyinput, torch.zeros(10000, 1).to(device)).exp().reshape(100, 100).cpu()
            zgrid1 = flow.log_prob(xyinput, torch.ones(10000, 1).to(device)).exp().reshape(100, 100).cpu()

        ax[0].contourf(xgrid.cpu().numpy(), ygrid.cpu().numpy(), zgrid0.numpy())
        ax[1].contourf(xgrid.cpu().numpy(), ygrid.cpu().numpy(), zgrid1.numpy())
        plt.suptitle('Iteration {}'.format(i + 1))
        plt.show()
[Figures: learned density for class 0 (left) and class 1 (right), one figure every 500 iterations, titled "Iteration 500" through "Iteration 5000"]

What are we doing here?

  1. We are training the model to learn the transformation \(z = f(x)\), which maps the data to the base distribution.

  2. We are plotting the learned density, i.e., the probability density function \(p(x \mid y)\) for each of the two classes.
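The contour plots show `exp(log_prob)` evaluated on a 100x100 grid. Because this is a normalized PDF, a Riemann sum of the grid values times the cell area estimates the probability mass inside the plotting window, which makes a quick sanity check. The sketch below uses a 2D standard normal as a hypothetical stand-in for `flow.log_prob` at a fixed context, and `density_on_grid` is an illustrative helper; with the trained flow you would wrap `flow.log_prob` and convert to NumPy, and expect a value somewhat below 1 if the window clips part of the density.

```python
import numpy as np


def log_prob_std_normal_2d(xy):
    """Stand-in for flow.log_prob: log-density of a 2D standard normal."""
    return -0.5 * np.sum(xy**2, axis=-1) - np.log(2 * np.pi)


def density_on_grid(log_prob_fn, xlim, ylim, steps=100):
    """Evaluate exp(log_prob) on a steps x steps grid, as done above for contourf."""
    xline = np.linspace(*xlim, steps)
    yline = np.linspace(*ylim, steps)
    xgrid, ygrid = np.meshgrid(xline, yline, indexing="ij")
    xy = np.stack([xgrid.reshape(-1), ygrid.reshape(-1)], axis=1)
    zgrid = np.exp(log_prob_fn(xy)).reshape(steps, steps)
    return xgrid, ygrid, zgrid


xgrid, ygrid, zgrid = density_on_grid(log_prob_std_normal_2d, (-5, 5), (-5, 5), steps=200)
# Riemann-sum estimate of the probability mass inside the window
cell_area = (xgrid[1, 0] - xgrid[0, 0]) * (ygrid[0, 1] - ygrid[0, 0])
mass = zgrid.sum() * cell_area
```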

Let's take it a step further. Now we want to sample from our normalizing flow's base distribution and apply the inverse transformation \(x = f^{-1}(z)\) to generate artificial data.
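Sampling runs in the opposite direction to density evaluation: draw \(z\) from the base distribution, then apply \(f^{-1}\). For an affine autoregressive layer the inverse is sequential, since \(x_1\) must be recovered before the shift and scale for \(x_2\) can be computed; this is why sampling from a masked autoregressive flow generally needs one network pass per feature, while density evaluation needs only one. The toy NumPy sketch below uses hypothetical fixed parameters (`T1`, `LOG_S1`, `conditioner_2`) in place of the learned masked network and an unconditional standard normal base, so it only illustrates the idea behind `flow.sample`, not its code.

```python
import numpy as np

# Hypothetical fixed parameters standing in for the learned masked network
T1, LOG_S1 = 0.5, -0.3          # constant shift and log-scale for feature 1


def conditioner_2(x1):
    """Shift and log-scale for feature 2, which may depend only on x1."""
    return np.tanh(x1), 0.2 * np.sin(x1)


def forward(x):
    """x -> z (density direction): all features are transformed in one pass."""
    t2, log_s2 = conditioner_2(x[:, 0])
    z1 = x[:, 0] * np.exp(LOG_S1) + T1
    z2 = x[:, 1] * np.exp(log_s2) + t2
    return np.stack([z1, z2], axis=1)


def inverse(z):
    """z -> x (sampling direction): x1 must be recovered before x2."""
    x1 = (z[:, 0] - T1) * np.exp(-LOG_S1)
    t2, log_s2 = conditioner_2(x1)          # needs x1, hence sequential
    x2 = (z[:, 1] - t2) * np.exp(-log_s2)
    return np.stack([x1, x2], axis=1)


rng = np.random.default_rng(0)
z = rng.standard_normal((500, 2))            # 1) sample the base distribution
x_generated = inverse(z)                     # 2) push the samples through f^{-1}
```

Mapping `x_generated` back with `forward` recovers `z`, which confirms the two functions are exact inverses.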

# First, let's sample some points from the true distribution to compare against.

x, y = datasets.make_moons(500, noise=.1)

# Use the true labels as the conditioning context, one label per generated point

z_y = torch.tensor(y).to(device).float().reshape(-1, 1)

# For an unconditional standard-normal base we would draw z by hand, e.g.:
# z = torch.tensor(np.random.normal(loc=0.0,scale=1.0,size=(500,2))).to(device).float()
# Here the base distribution is conditional, and flow.sample draws z from it
# (given the context) and applies the inverse transform for us.

# flow.sample returns num_samples samples per context row, so reshape to (500, 2)
x_generated = flow.sample(num_samples=1,context=z_y).reshape(-1,2)
print(x_generated.shape)
x_generated = x_generated.detach().cpu().numpy()

plt.scatter(x[:, 0], x[:, 1], c=y)
plt.title('True Data',fontsize=24)
plt.show()


plt.scatter(x_generated[:,0],x_generated[:,1],c=y)
plt.title('Generated Data',fontsize=24)
plt.show()
torch.Size([500, 2])
[Figures: "True Data" and "Generated Data" scatter plots, colored by class label]

Can you use this to perform classification?

# First, let's sample some points from the true distribution.
# The flow was trained on data with noise=0.1.
# Let's see the classification results when there is no noise.
sigma = 0
x, y = datasets.make_moons(1000, noise=sigma)


# Let's create two hypotheses and assume we do not know the ground-truth labels
# Hypothesis of class 0
x_ = torch.tensor(x,device=device,dtype=torch.float32)
hyp_0 = torch.tensor(np.zeros_like(y),device=device,dtype=torch.float32).reshape(-1,1)
# Hypothesis of class 1
hyp_1 = torch.tensor(np.ones_like(y),device=device,dtype=torch.float32).reshape(-1,1)

LL_class_0 = flow.log_prob(x_,context=hyp_0).detach().cpu().numpy()
LL_class_1 = flow.log_prob(x_,context=hyp_1).detach().cpu().numpy()
# We can make predictions based on which likelihood is larger
y_pred = np.zeros_like(y)

y_pred[LL_class_1 > LL_class_0] = 1.0

plt.scatter(x[:, 0], x[:, 1], c=y_pred)
plt.title('Data - Noise: {0}'.format(sigma),fontsize=24)
plt.show()

print("Accuracy: ",(y_pred == y).sum() * 100 / len(y),"%")

# Let's see the classification results with double the training noise
sigma = 0.2
x, y = datasets.make_moons(1000, noise=sigma)


# Let's create two hypotheses and assume we do not know the ground-truth labels
# Hypothesis of class 0
x_ = torch.tensor(x,device=device,dtype=torch.float32)
hyp_0 = torch.tensor(np.zeros_like(y),device=device,dtype=torch.float32).reshape(-1,1)
# Hypothesis of class 1
hyp_1 = torch.tensor(np.ones_like(y),device=device,dtype=torch.float32).reshape(-1,1)

LL_class_0 = flow.log_prob(x_,context=hyp_0).detach().cpu().numpy()
LL_class_1 = flow.log_prob(x_,context=hyp_1).detach().cpu().numpy()
# We can make predictions based on which likelihood is larger
y_pred = np.zeros_like(y)

y_pred[LL_class_1 > LL_class_0] = 1.0

plt.scatter(x[:, 0], x[:, 1], c=y_pred)
plt.title('Data - Noise: {0}'.format(sigma),fontsize=24)
plt.show()

print("Accuracy: ",(y_pred == y).sum() * 100 / len(y),"%")
[Figure: "Data - Noise: 0", points colored by predicted class]
Accuracy:  100.0 %
[Figure: "Data - Noise: 0.2", points colored by predicted class]
Accuracy:  96.7 %
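Predicting the class with the larger likelihood is Bayes' decision rule for equal class priors, which fits this setup because `make_moons` splits the samples evenly between the two moons (for an even `n_samples`). The same two log-likelihoods also give a posterior probability, \(p(y=1 \mid x) = \frac{p(x \mid y=1)\,p(y=1)}{\sum_k p(x \mid y=k)\,p(y=k)}\), which tells you how confident each prediction is. A minimal sketch, where `posterior_class_1` is a hypothetical helper and the log-likelihood values are made up for illustration; in the notebook you would pass `LL_class_0` and `LL_class_1`.

```python
import numpy as np


def posterior_class_1(ll_class_0, ll_class_1, prior_1=0.5):
    """p(y=1 | x) from log p(x | y=0) and log p(x | y=1) via Bayes' rule."""
    log_joint_0 = ll_class_0 + np.log(1.0 - prior_1)
    log_joint_1 = ll_class_1 + np.log(prior_1)
    # Normalize in log space to avoid underflow for very negative log-likelihoods
    return np.exp(log_joint_1 - np.logaddexp(log_joint_0, log_joint_1))


# Made-up log-likelihoods for three points (illustration only)
ll0 = np.array([-1.0, -3.0, -2.5])
ll1 = np.array([-3.0, -1.0, -2.0])
p1 = posterior_class_1(ll0, ll1)
y_pred = (p1 > 0.5).astype(int)
```

With `prior_1=0.5` the threshold `p1 > 0.5` reproduces the notebook's rule `LL_class_1 > LL_class_0`; other priors shift the decision boundary.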

Let's write a function that measures how the classification accuracy changes as a function of the noise level.

def performance_func_noise(sigmas, flow):
    """Return the classification accuracy (%) of the flow for each noise level in sigmas."""
    results = []
    for sigma in sigmas:
        x, y = datasets.make_moons(1000, noise=sigma)
        x_ = torch.tensor(x, device=device, dtype=torch.float32)
        # One hypothesis per class, used as the conditioning context
        hyp_0 = torch.zeros(len(y), 1, device=device)
        hyp_1 = torch.ones(len(y), 1, device=device)

        with torch.no_grad():
            LL_class_0 = flow.log_prob(x_, context=hyp_0).cpu().numpy()
            LL_class_1 = flow.log_prob(x_, context=hyp_1).cpu().numpy()
        # We can make predictions based on which likelihood is larger
        y_pred = (LL_class_1 > LL_class_0).astype(y.dtype)
        accuracy = (y_pred == y).sum() * 100 / len(y)
        print("Sigma = ", sigma, " Accuracy: ", accuracy, "%")
        results.append(accuracy)

    return results

sigmas = [0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0]
results = performance_func_noise(sigmas,flow)

plt.plot(sigmas, results, color='red', lw=2, linestyle='-')
plt.xlabel(r"$\sigma_{noise.}$",fontsize=20)
plt.ylabel("Accuracy (%)",fontsize=20)
plt.title(r"Accuracy as a function of $\sigma_{noise.}$",fontsize=20)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.show()
Sigma =  0.0  Accuracy:  100.0 %
Sigma =  0.1  Accuracy:  99.8 %
Sigma =  0.2  Accuracy:  96.2 %
Sigma =  0.3  Accuracy:  87.0 %
Sigma =  0.4  Accuracy:  84.4 %
Sigma =  0.5  Accuracy:  75.8 %
Sigma =  0.6  Accuracy:  73.5 %
Sigma =  0.7  Accuracy:  69.4 %
Sigma =  0.8  Accuracy:  69.0 %
Sigma =  0.9  Accuracy:  66.7 %
Sigma =  1.0  Accuracy:  64.1 %
[Figure: "Accuracy as a function of \(\sigma_{noise}\)", accuracy (%) versus \(\sigma_{noise}\)]