
Pre v002 release #16

Merged (30 commits, Aug 1, 2024)

Commits (30)
0f0a425
Updated setup and requirements
james-alvey-42 Mar 25, 2024
fe39389
First upload of fake LISA simulator for inference tests
james-alvey-42 Mar 25, 2024
bf0c2ff
Upload of initial saqqara source code: configs, priors, simulator
james-alvey-42 Mar 25, 2024
9ae2533
Upload of noise variation project example
james-alvey-42 Mar 25, 2024
521546f
saqqara testing framework
james-alvey-42 Mar 25, 2024
4ca1f63
Updated fake lisa example for 3 channels
james-alvey-42 Mar 28, 2024
a2ca323
Updated fake lisa simulator to include coarse graining
james-alvey-42 Apr 4, 2024
c011eda
Updated noise variation simulator to include coarse-graining
james-alvey-42 Apr 4, 2024
ec30775
Fixed fast coarse graining strategy
james-alvey-42 Apr 8, 2024
b603f15
Added script for generating data to .npy files
james-alvey-42 Apr 8, 2024
a5691ac
Updated array init
james-alvey-42 Apr 8, 2024
d421585
Updated .gitignore
james-alvey-42 Apr 8, 2024
4a7862f
Added dataset and dataloaders + new training notebook for NV
james-alvey-42 Apr 11, 2024
efd74f0
6 fr improved modular simulator undark path (#11)
james-alvey-42 Jun 6, 2024
9fdfd98
Debugged new scripted version and logging (#12)
james-alvey-42 Jun 6, 2024
dc2b9a0
Updated training script
james-alvey-42 Jun 11, 2024
150b9ef
Small run updates
james-alvey-42 Jun 11, 2024
47ecc2e
Dataset resampling notebook added
james-alvey-42 Jun 19, 2024
48e26fc
Fully implemented resampling training
james-alvey-42 Jun 20, 2024
cea4899
Big update with noise cross terms added
james-alvey-42 Jul 11, 2024
aeea11c
Implemented linear signal cross terms
james-alvey-42 Jul 12, 2024
17d3ecc
Full cross terms dataloaders
james-alvey-42 Jul 15, 2024
003c013
Corrected linear signal implementation (#14) (#15)
james-alvey-42 Jul 15, 2024
b53345a
Pre-arxiv update
james-alvey-42 Aug 1, 2024
5eb7a07
v0.0.2 clean up
james-alvey-42 Aug 1, 2024
cdc0750
Update .gitignore
james-alvey-42 Aug 1, 2024
f69a425
Update .gitignore
james-alvey-42 Aug 1, 2024
0c33b93
Clean up notebooks
james-alvey-42 Aug 1, 2024
f9cd38d
Merge branch 'main' into pre-v002-release
james-alvey-42 Aug 1, 2024
986a38e
Update requirements.txt
james-alvey-42 Aug 1, 2024
8 changes: 7 additions & 1 deletion .gitignore
@@ -128,4 +128,10 @@ dmypy.json

# Pyre type checker
.pyre/
-peregrine/tmnre_store/*
+**/simulations*/*
+**/lightning_logs/*
+**/*.ckpt
+**/wandb/**
+*.npy
+**/examples/vary_noise/data/**
+**/examples/vary_noise/training_dir/**
45 changes: 45 additions & 0 deletions examples/vary_noise/configs/arxiv_submission.yaml
@@ -0,0 +1,45 @@
model:
  name: noise_variation
  fmin: 3.e-5
  fmax: 5.e-1
  deltaf: 1.e-6
  ngrid: 2000
  noise_approx: False
priors:
  amp: [-11.12569522857666, -10.895430564880371]
  tilt: [-0.434512734413147, 0.6660090684890747]
  TM: [0., 6.]
  OMS: [0., 30.]
simulate:
  store_name: data/resampling_data
  store_size: 50_000
  chunk_size: 128
train:
  trainer_dir: training_dir
  type: resampling
  store_name: data/resampling_data
  channels: AET
  total_size: 50_000
  train_fraction: 0.7
  train_batch_size: 8192
  val_batch_size: 8192
  num_workers: 8
  device: gpu
  n_devices: 1
  min_epochs: 1
  max_epochs: 400
  early_stopping_patience: 400
  learning_rate: 5.e-5
  lr_scheduler:
    type: CosineWithWarmUp
    T_max: 380
    eta_min: 1.e-7
    total_warmup_steps: 20
logger:
  type: wandb
  name: resampling
  project: saqqara
  entity: j-b-g-alvey
  offline: False
  log_model: all
  save_dir: training_dir
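For orientation, a config like this is parsed with saqqara.load_settings and handed to the simulator and trainer, exactly as in the coverage notebook further down. A minimal sketch; the repo-relative paths are assumptions based on this PR's file layout:

import sys
import saqqara

sys.path.insert(0, "examples/vary_noise/simulator/")  # assumed path to the example's simulator module
from simulator import LISA_AET

# Parse the YAML above into a nested settings object
settings = saqqara.load_settings(
    config_path="examples/vary_noise/configs/arxiv_submission.yaml"
)

# The model/priors blocks configure the 3-channel (AET) LISA simulator
sim = LISA_AET(settings)

# The train block sets device, epochs, learning rate and scheduler
trainer = saqqara.setup_trainer(settings, logger=None)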
46 changes: 46 additions & 0 deletions examples/vary_noise/configs/bounded_mcmc.yaml
@@ -0,0 +1,46 @@
model:
  name: noise_variation
  fmin: 3.e-5
  fmax: 5.e-1
  deltaf: 1.e-6
  ngrid: 1000
  noise_approx: False
priors:
  amp: [-12.0, -10.0]
  tilt: [-1.5, 3.0] # [-0.7, 0.7]
  TM: [0., 6.]
  OMS: [0., 30.]
simulate:
  store_name: data/bounded_simulations
  store_size: 250_000
  chunk_size: 128
train:
  trainer_dir: training_dir
  type: resampling
  signal_dir: data/signal_store
  tm_dir: data/tm_store
  oms_dir: data/oms_store
  total_size: 50_000
  train_fraction: 0.5
  train_batch_size: 8192
  val_batch_size: 8192
  num_workers: 8
  device: gpu
  n_devices: 1
  min_epochs: 1
  max_epochs: 300
  early_stopping_patience: 300
  learning_rate: 8.e-5
  lr_scheduler:
    type: CosineWithWarmUp
    T_max: 250
    eta_min: 1.e-5
    total_warmup_steps: 20
logger:
  type: wandb
  name: resampling
  project: saqqara
  entity: j-b-g-alvey
  offline: False
  log_model: all
  save_dir: training_dir
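The CosineWithWarmUp scheduler itself is not defined in this diff; presumably it is a linear warm-up over total_warmup_steps followed by cosine annealing to eta_min over T_max steps. Under that assumption, the same schedule can be sketched with stock PyTorch schedulers:

import torch
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR

net = torch.nn.Linear(8, 2)  # hypothetical placeholder network
opt = torch.optim.Adam(net.parameters(), lr=8e-5)  # learning_rate from this config

# Assumed: linear warm-up over total_warmup_steps=20 steps ...
warmup = LinearLR(opt, start_factor=1e-2, total_iters=20)
# ... then cosine decay to eta_min=1.e-5 over T_max=250 steps
cosine = CosineAnnealingLR(opt, T_max=250, eta_min=1e-5)
sched = SequentialLR(opt, schedulers=[warmup, cosine], milestones=[20])

for step in range(300):  # max_epochs in this config
    opt.step()
    sched.step()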
193 changes: 193 additions & 0 deletions examples/vary_noise/explore/coverage_test_results.ipynb
@@ -0,0 +1,193 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import saqqara\n",
"import sys\n",
"sys.path.insert(0, '../inference/')\n",
"sys.path.insert(0, '../simulator/')\n",
"from networks import SignalAET\n",
"from simulator import LISA_AET\n",
"from dataloader import get_datasets, setup_dataloaders, get_data_npy_dataset\n",
"import numpy as np\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"import glob\n",
"import pickle\n",
"import swyft\n",
"import logging\n",
"import tqdm\n",
"log = logging.getLogger(\"pytorch_lightning\")\n",
"log.propagate = False\n",
"log.setLevel(logging.ERROR)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from scipy.integrate import simps\n",
"from scipy.interpolate import interp1d\n",
"def get_sigmas(logratios):\n",
" lrs = np.array(logratios.logratios[:, 0].reshape(int(np.sqrt(logratios.logratios.shape[0])), int(np.sqrt(logratios.logratios.shape[0]))))\n",
" params_alpha = np.array(logratios.params[:, 0, 0].reshape(int(np.sqrt(logratios.params.shape[0])), int(np.sqrt(logratios.params.shape[0]))))\n",
" params_gamma = np.array(logratios.params[:, 0, 1].reshape(int(np.sqrt(logratios.params.shape[0])), int(np.sqrt(logratios.params.shape[0]))))\n",
" posterior = np.exp(lrs - np.max(lrs)) / np.sum(np.exp(lrs - np.max(lrs))) / (params_alpha[1, 0] - params_alpha[0, 0]) * (params_gamma[0, 1] - params_gamma[0, 0])\n",
" alpha_marginal = simps(posterior, params_gamma, axis=1)\n",
" gamma_marginal = simps(posterior, params_alpha, axis=0)\n",
" alpha_ps = params_alpha[:, 0]\n",
" gamma_ps = params_gamma[0, :]\n",
" norm_alpha_marginal = alpha_marginal / simps(alpha_marginal, alpha_ps) \n",
" norm_gamma_marginal = gamma_marginal / simps(gamma_marginal, gamma_ps)\n",
" alpha_cumulant = np.cumsum(norm_alpha_marginal * (alpha_ps[1] - alpha_ps[0]))\n",
" gamma_cumulant = np.cumsum(norm_gamma_marginal * (gamma_ps[1] - gamma_ps[0]))\n",
" alpha_interp = interp1d(alpha_cumulant, alpha_ps)\n",
" gamma_interp = interp1d(gamma_cumulant, gamma_ps)\n",
" alpha_sigma = 0.5 * (alpha_interp(0.5 + 0.34) - alpha_interp(0.5 - 0.34))\n",
" gamma_sigma = 0.5 * (gamma_interp(0.5 + 0.34) - gamma_interp(0.5 - 0.34))\n",
" return alpha_sigma, gamma_sigma\n",
"\n",
"def get_resampling_dataset(sim, settings, path_to_data=None):\n",
" training_settings = settings.get(\"train\", {})\n",
" if training_settings[\"type\"] != \"resampling\":\n",
" raise ValueError(\"Training type must be resampling\")\n",
" data_dir = training_settings.get(\"store_name\") if path_to_data is None else path_to_data + training_settings.get(\"store_name\")\n",
" store_dataset = get_data_npy_dataset(data_dir)\n",
" resampling_dataset = saqqara.RandomSamplingDataset(\n",
" store_dataset,\n",
" shuffle=training_settings.get(\"shuffle\", True),\n",
" )\n",
" dataset = saqqara.ResamplingTraining(sim, resampling_dataset)\n",
" return dataset\n",
"\n",
"def get_grid(N=1000, a_low=-11.42961597442627, a_high=-10.696080207824707, g_low=-0.7066106200218201, g_high=1.0477334260940552):\n",
" a_samples = np.linspace(a_low, a_high, N)\n",
" g_samples = np.linspace(g_low, g_high, N)\n",
" ag_samples = np.array(np.meshgrid(a_samples, g_samples)).T.reshape(-1, 2)\n",
" A_samples = np.ones(N)\n",
" P_samples = np.ones(N)\n",
" AP_samples = np.array(np.meshgrid(A_samples, P_samples)).T.reshape(-1, 2)\n",
" return swyft.Samples(z=np.float32(np.concatenate((ag_samples, AP_samples), axis=1)))\n",
"\n",
"def get_network(id, sim):\n",
" config = glob.glob(f\"../training_dir/training_config_id={id}.yaml\")[0]\n",
" ckpt = glob.glob(f\"../training_dir/saqqara-*_id={id}.ckpt\")[0]\n",
" settings = saqqara.load_settings(config_path=config)\n",
" network = SignalAET(settings=settings, sim=sim)\n",
" network = saqqara.load_state(network=network, ckpt=ckpt)\n",
" return network"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"id = \"G7RG\"\n",
"config = glob.glob(f\"../training_dir/training_config_id={id}.yaml\")[0]\n",
"ckpt = glob.glob(f\"../training_dir/saqqara-*_id={id}.ckpt\")[0]\n",
"settings = saqqara.load_settings(config_path=config)\n",
"sim = LISA_AET(settings)\n",
"network = get_network(id, sim)\n",
"trainer = saqqara.setup_trainer(settings, logger=None)\n",
"dataset = get_resampling_dataset(sim, settings, path_to_data='../')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_cumulative_dist(sample):\n",
" dist = []\n",
" for pt in np.linspace(0, 1, 1000):\n",
" dist.append(len(sample[sample < pt]) / len(sample))\n",
" return np.array(dist)\n",
"dists = np.vstack([get_cumulative_dist(np.random.uniform(0, 1, 200)) for _ in tqdm.tqdm(range(4000))])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"prior_samples = swyft.Samples(z=torch.tensor(sim.prior.sample(10_000)).float())\n",
"coverage_data = dataset.sample(z=sim.prior.sample(4000))\n",
"dm = swyft.Samples(z=torch.tensor(coverage_data['z']).float(), data=torch.tensor(coverage_data['data']).float())\n",
"coverage_samples = trainer.test_coverage(get_network(id=\"G7RG\", sim=sim), dm, prior_samples, batch_size=100000)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, axes = plt.subplots(1, 1, figsize = (10, 10))\n",
"for i in range(2):\n",
" ax = axes\n",
" ax.plot([0, 0], [1, 1], c='r')\n",
" swyft.plot_pp(coverage_samples, ['z[0]', 'z[1]'], ax = axes)\n",
" axes.plot([0.0, 1.0], [0.0, 1.0])\n",
" axes.fill_between(np.linspace(0, 1, 1000), y1=np.quantile(dists, q=[0.025], axis=0)[0], y2=np.quantile(dists, q=[0.975], axis=0)[0], color='C1', alpha=0.8, zorder=-10)\n",
"plt.tight_layout()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pickle.dump(coverage_samples, open(\"../results/full_inference/coverage_samples.pkl\", \"wb\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "default",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
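Aside from the swyft bookkeeping, the core of the get_sigmas helper above is an inverse-CDF read-off of the central 68% interval of each 1D marginal. A self-contained illustration of that step on a toy marginal (a unit Gaussian, not data from this run):

import numpy as np
from scipy.integrate import simps
from scipy.interpolate import interp1d

ps = np.linspace(-3.0, 3.0, 1000)   # parameter grid
marginal = np.exp(-0.5 * ps**2)     # toy 1D marginal

norm_marginal = marginal / simps(marginal, ps)         # normalise to unit area
cumulant = np.cumsum(norm_marginal * (ps[1] - ps[0]))  # approximate CDF
inv_cdf = interp1d(cumulant, ps)                       # invert the CDF on the grid

# Half-width of the central 68% credible interval, i.e. the reported "sigma"
sigma = 0.5 * (inv_cdf(0.5 + 0.34) - inv_cdf(0.5 - 0.34))
print(sigma)  # close to 1.0 for a unit Gaussian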