From 6425ed194b488864f23dbb3a51c9ca130acba36a Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 3 May 2023 09:05:48 -0700 Subject: [PATCH] Fix loading `2023-02-07-ppd-mp.pkl.gz` (#26) * fix fetch_process_wbm_dataset.py after pandas v2 breaking changes * drop pytest-markdown-docs from optional deps * fix double slash in PRED_FILES * bump deps * update docs, DataFiles and matbench_discovery/energy.py with updated (2023-02-07) MP elemental reference energies (closes #23) * update 2022-10-19-wbm-summary.csv formation energies with 2023-02-07 element reference energies compress data/mp/2023-02-07-mp-elemental-reference-entries.json.gz update data/figshare/1.0.0.json file links * pin pandas>=2.0.0 https://github.com/janosh/matbench-discovery/issues/22#issuecomment-1531464310 mark test_load_train_test_no_mock() for mp_computed_structure_entries as very_slow * load_train_test() support loading and caching pickle files (for mp_patched_phase_diagram) change signature from data_names (str | list[str], optional) = 'all' to data_key (str) * rename load_train_test() to load() --- .gitignore | 1 + data/figshare/1.0.0.json | 4 +- ...-07-mp-elemental-reference-entries.json.gz | Bin 0 -> 6350 bytes data/mp/build_phase_diagram.py | 8 +- data/wbm/fetch_process_wbm_dataset.py | 54 ++++--- matbench_discovery/data.py | 133 +++++++++--------- matbench_discovery/energy.py | 10 +- matbench_discovery/preds.py | 2 +- pyproject.toml | 11 +- site/package.json | 8 +- .../src/figs/mp-elemental-ref-energies.svelte | 2 +- site/src/routes/contribute/+page.md | 8 +- site/src/routes/si/+page.md | 2 +- tests/test_data.py | 93 ++++++------ 14 files changed, 165 insertions(+), 171 deletions(-) create mode 100644 data/mp/2023-02-07-mp-elemental-reference-entries.json.gz diff --git a/.gitignore b/.gitignore index a56b9175..7a59fcd7 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ data/**/raw data/**/tsne data/2022-* data/m3gnet-* +!data/mp/2023-02-07-mp-elemental-reference-entries.json.gz # slurm + Weights and Biases logs wandb/ diff --git a/data/figshare/1.0.0.json b/data/figshare/1.0.0.json index 2760c740..7908f52a 100644 --- a/data/figshare/1.0.0.json +++ b/data/figshare/1.0.0.json @@ -1,10 +1,10 @@ { "mp_computed_structure_entries": "https://figshare.com/ndownloader/files/40344436", - "mp_elemental_ref_entries": "https://figshare.com/ndownloader/files/40344445", + "mp_elemental_ref_entries": "https://figshare.com/ndownloader/files/40387775", "mp_energies": "https://figshare.com/ndownloader/files/40344448", "mp_patched_phase_diagram": "https://figshare.com/ndownloader/files/40344451", "wbm_computed_structure_entries": "https://figshare.com/ndownloader/files/40344463", "wbm_initial_structures": "https://figshare.com/ndownloader/files/40344466", "wbm_cses_plus_init_structs": "https://figshare.com/ndownloader/files/40344469", - "wbm_summary": "https://figshare.com/ndownloader/files/40344475" + "wbm_summary": "https://figshare.com/ndownloader/files/40407575" } diff --git a/data/mp/2023-02-07-mp-elemental-reference-entries.json.gz b/data/mp/2023-02-07-mp-elemental-reference-entries.json.gz new file mode 100644 index 0000000000000000000000000000000000000000..d5e66456ad763b01fca099ac4c79b1cc642d83f1 GIT binary patch literal 6350 zcmV;<7%}G`iwFoUicn+%|1vN#Gc7PREif@HZE!7RY-Md_ZggR6EplaMWpZV1V`VL6 zZgg^KWpgfSb8l_{++Asp97l5fD+7ME8j%r^SAR0e)rzttnh?2%WEe!A)Y!Z*hhbr{ z|J{)_^g`>+Xrl;J0W5+fq;Ae+Pv`51cyYY=%cJM*(fR5xkN)}HhkE~Z{P&}ezrIU% zZ`%7)dw+Lx-EL3i!@H07cdfqu^uVKkx_b1_|m@c1qtZm;jIKfGT~Tn;uT`0Mf|^0lw8>-c`}K7M74`Rbc*9zR;X>gJ}E zPhTY8?{DA#b*leye|z^%zI42dfBScN=3}}^@7i6vxm`~9_~EXko7cA=TN(f2FOTl7 z@7mjO#*4>)di~}zl`TCg5%a2`H-XhoiN(uJa<#jhYpYa zI*z-(ef`UQ&go{n?w;P>w(%_<9$8+&_|b0e-@m^5_2cr+j-P2fR?}TtUg3xTzOL^g#ZC@0=VU19dLFZz&*aX0{^mo#PNpz z25`*>?8RY3hm*s(o>>V=dfU~Os_xtVe z=1-1oluv=5;BB1CQ@B}Y5NIE~{r(EVg*gqVu@xdS-Au}1mKJp<923u@(dc_Lf{9l`Hz5mxTe*Bd-`@u3rCTS z+!jv&w;Vh-aRwmBSMCG3r*Gw*e+c5}?=a_Mju8y7D?$#MHO1B-@ee!3)bf{47QkH@ zceXTwIOjE=8ocMI_|#7qwHkEKx5J!sCl^h$afZ3)s*%H;4BYd%fy*)U3g~NXxj0{w z#Ld(@mXvb=J_6Xh{N#g~GK)tg4QzoxsOvm{Q;pklukdiMQ@nk*Ep4#Qil?ln-IzB$7j6t90UERTCGGo>oKd<%TAX z@H8x3w0=e%#>w(8rkL9aZ!hOWRm@tb*@YOZgwHY9m~(3!j7dUNL-^9+#$OR!p|Dgs zlQ|hRh$HrlN;WQj2GlmRZej;co+N%_JF{=g(N0PKdXTH#32<#5!#OXo0&ca%Zj3>X zu2D7Qq7SW)(&LwoIE+`uoXMSsQD7~GBMjauKK0?A)~Dk(%>TMm9LYEnXPA4cX5-Kr zXTx6ExOo&;AO+)m%~@iC-ZX<9*79z?wvFY|rdZ!sm2;N#_DCZ)w!LK_^kc%#SVyD~%!&5$sgBf019; zIo?7dHYEmepQe~n8l|3rOzmikbN&#eRB@x^&lXSKl3KgB5~3-dbX?SG$o=@srj=C3 z1b*5pvK%~zoZl&I`|%gmz%l0>osjB~TM&=Zed-uvA_v`Y%*UwR^I3kQ6)$8}z)^?_ zJ@xCOwikCB)=*F2jUl30+_oI7H2C69nUEKE^Tx~Mgq0>LF2z!iil>;p!|uekb>?^s zXNqTZXQUHDP@O3|BBf@J>vJJLZb%VKI9Uz>EX>-;my_pdH`>XU{&=t2HXqS;uAE%} z@2ZE?eW}5<)?-eg9)>79^oTgn@y4hS1zPV(U)1g3W>2GSFPJcED_;&)g14P4TU9I@ zMcXnZ@twV-Y_O+7KtdA?yY9>ZyeNz^7KmkYCu@!FMJ+3q4cB*D-fo%WJ+Wm^9Netk z@>aEL9Czxqyd8E8Vb=RLLbj4hGXY}DEprvDRc`>f*xV0b*Wyy77rSOo!Rw^dhN$#g za$H}a->@enoB{1gMvqkZwCj_T9l9P#6aiimw)K<%9|4*ysi!D8>KamX{EY5H@ zPG(dlPh;@76ZBpv{p7U}jm}qZ$fV}H4>m!msbpuW?T6?mhb7T#=@hi0f*BJGTB7&U zJ*_WcZ5YSJ6PbYIri9_K9>(o8?ELsvN%R6(a$GZ_O=!7UEIE;d-0GOyI>wNe6=HfG z{s_Hl>f|ZSuda8iKYZBGufj<_9nON8VQo2COT6vG8GiV%<);z}vBFQydZ@k3Tk_MTs|9%m zwdG*Njqb%5US@@#1}t(Z8L$^i%q`elO@!EEYt;NvW>3x1>=uJwFhwR!I|SCpY)>}C z8l21s2n^1~6~@6v0h@!Jreo$BZX?-2**_$mVu^f0Be`dMf zHy83B!e#AV>E1VUycO(5ylJ+*TB7KRW%Q106OJ8+Cpkp$HYy7v3E~{OM)yH zwVo}!-;yni`(c_|xaD9CbUS&}m+#dEOO6?Qw9Rnnk&K}PF1H{E9ODY+7!8&%dP0nq z*hkO1F={jQa!u=Vyl=LI8%OV4fN75RjhY9EA$r;gagP-fC&=na$Z4Za$r_@Nib*A! z)S?SH9>t~7V&Z~xIzRQOr$f+*lrfg3Vis|E6k+rOdSr53E*V~q|EI8o} zv+lR$WCd`0wZFZ*o_AVi%+QkSEjw(!)P@k95j#yebl>eU>R2??8wgMSf?$XNO})W< zM)TW`Th7z-DTD}C2*a$0T+fK?q`&oJQ;l1Uz+*zz)R0}Sz9n|XcAv$o+Wf^k(~*M1 zp`|)1r{;pRB{Nv9Y`R{3>*v#K%tlRI~rfColAUt7~2KoJk z#U6Xfk?U^MMn`Onr3i$r5XR7x^EMF{+Z-PwYFn!GqsR3^)D~Es{98{EzDb*EHW-5x z7-p7jIa!g1=Dq4zNlWq?3e-AgpSd(>##1!Km70a6c59BH*Bgh#8{-)KV>Rg(ar886 zBw>9a>)ZF+b|6`>#?0Wh987v43wDCsw`yG>2XAdit#d|5R5$^I-l+xet$}ciT0M?h z{g6O%TIrO*1j83fS+E{@JdPL%SlKS(8t`$$# zL=GOuU5;Nf;gFy8vc2Ca=lk<}rG`qjHS|4jpyVj$p{QDrTq35_=@=C)@vVFYI)w|< z_2OO?uIksPeE+y*3*6W|A|`H@g!`lJRCyDJoBc@!KA5?(_Z@buik3R!3P_m=NwZl z=mcW5CoCf|RT^64W2`+DxWO`c8bBqh(FCvRk<({L;6Hx5smvf->us1dUM>e~*|?oD zzmMNu0Z&!9dX0%)YQ+TGIaN%hLdlI&!=8?@;vIK9|;_Ag&e?%99%iW2inlQ(z1@90PBQq$T(h>H8LVEKF9S zxAmI=pR_F(D+i#b4|T1v zYheiXT6G?9)U}*~wy?g}@`o+ysk0{+12gfAE-Xi5oEPnPGH*Y;R|!K;7JG=MBXQ-J zGnZ^ziY}!VdJG4sHP|u3FsfTY0&UH_1;b<%a!TvX+p7(`-;ho&IQ9-^1E|Z%r0m+= zi7Z^}{1md7v=nl8p&MwvnZ_QC@CPGNaymvXgvU-m4kr_Ii6jKC?%G`sw=FBFX6!_l zUvB13my?wl)Sa@uuaXM4k#K6sXP=`|Y&v4Th6*;qENl>lMvbTF`bn^Jk2mmQ7od#I8jO(+;?o`k0K{YwWZvymV&TYN6n| z1)C}qCJHurZ>+%1n3BCcN)|-mwltEP3t13N%fWrd4B*EtlWsID87sA zE6e-*qtZ(bB@<+CM*g#K%|n>eLdFnM;Zl!*H@DQWqQuwwmeXK@rolC?X?>-B%f=9Y zf@m2G&Vsj}^~MlGJJqp%zEMqFz$9i(98K@Gg|XwbiiH#7Mh%SdaDcs&i-hWEyc`mP znu#M2%b`JzK10)T!_j(k5`v9GjC7WHdvk#-Sm*cZy8Pjd+E|`*>8=M;3UMNk_-*uv zR3~oL2hZ*h{M3v}){x{vl`#al59c+;u-w)9oXDmdv)G8S2vZ5TSNg`Rz2qET{h~0` zmVE4lDDQYru0!t0bIF*Db+z>#g#k7aW9nRy4_STEwHVn6pNKPI3J@iUBa5=kI`!>fE(MK zzkf}NiDO6+a?0Hxat@*Q z%!P)cYN4T@e2{)5#+C|ST4@hDYPhBJ=UQ*@z<8uYD3~ z)aF?!bvap|5;SHfTle*iYU@&RHUS6k$@m@$6&tD6mQqOO5L-WR8J@vqbK8Z1Est6+ z+^4Hr&mlf+i#zs+&dn+h%gLI(b9SfcnP(p~)XA$3LV+A7^ey<(eG{y4t@1Es?~g%R z2V*_MSXZmo9w^FfXx8r2Q>!rd)3!5$F|r|=g>*k%vUFlxcEa3GY9A&yL6>jdmuk72 z7_vlma$D#_mW!yz7)^C&6PKtCE=c=EwafhBs#c@!cfW3G=_Do~2Xzlf1!oh7%%<7yyo;5L{OQO(MocZo!*rVuf!~vj=2Ex0R(`PS zW4W*IH=U(H>j9?QQNPv&;vACQYemD?_j9~ecCG@uR6CmPB&KUEI(hdwwuCMmMZKJP zL~_oA3)wkM+xOv~);BI~nJy6y?wOeFES6hNo@*(&z0_7mrerSnU>kZKN~dIZash3wRfQxPP zDv3``!8(U7w~CQN66j@#z z3oqVnYux~ii#QXCd!hT%;D!k3PKZstt8?XIsN{|^CLaM2o*kHAoQ_)}O` z-214{r=qvoeCl#j>lM4&wliGL86Rkt0Dh{EL|f-+FV^t1U16Bj^gAVp9ea-wIfrIL z%bcSGG%XI*hhu0H!pQ#`e_EsowX6t4R`&?fXRwB!HWYo}L4X2sm?nVpMOhoSQ~UW( z`88?@U$dSMK6p$7+y6h_K=k&V@4$dFBg-E;zNk zDi^iBcPVWe(Rm)-Ry^Ym-|I@#xRYAk_emkEnTi;*YC`t~G6Q>)g7L%8W(@u)&A=l8 zEH~iPeGTM;QQ(H_rQ#lMIYQ`8{GjPMmF+4Q*aNz~)K?!DrMl%LWD~Ad5gOIzn(Nh5 zj26g8Q_W%aLC2v@$ohT|0Izk*J>1iJ?vS>WUj?u>#xN^5Jkg zOhdy3SnQn{O3N*nObLBGQoSLbEa~+I<5^kI`x&j7-*3CQ5}kKp)_eZ__2tIQy@po5 zzaEG~ny3?w(rHc^45vV&Nz+-b-b7bID?J^9HjM0V7(O*dP!R{}3i3~PwZ6o4vnA}n z1>t)%T_||5G1G(JsqXMnA6_lVC2_8?7=#2Vd$2;iDY)caZ}tEyLx%3J1afeOOM_jJ zEEv-JmE^A07rQoet`ttA+)HF_%i>LiR%1-Tn-Wbg zCACQG<>l4zg^S)i2*|msA;xNlKBiSWY3R^3Rx+|_!K%q?l_DVh-de22r63pz`C zjXylm<2s4uyu>t=dUjNMYQz+~e2}>34qv#+`ESPxg6C zJ#^PbVSlznHD3@T@zL3xQYAlT6>}-oVw|(+ZN}70YZYqgkxH~ZmSgm0j*eh6W*<;p zrAl5tYwDip<*rua?z>Ipy%2o#VQS~b!DEh68M%qD6Nmjyg&noPJ-AA~BTy0p2c|Z1 zsaxOp5F32|cht*9Ar#ua8sZS~!pO}hZR&MdliLk(XHU_|ui&SfRc|i_c5L?A>+tk; zLEfxc6QawOlCJ?u0SGR7mrG3+swLO#0rYU str: # %% -df_wbm["computed_structure_entry"] = pd.concat(dfs_wbm_cses.values()).to_numpy() +df_wbm["computed_structure_entry"] = np.concatenate([*dfs_wbm_cses.values()]).squeeze() for mat_id, cse in df_wbm.computed_structure_entry.items(): # needed to ensure MaterialsProjectCompatibility can process the entries @@ -319,9 +319,9 @@ def increment_wbm_material_id(wbm_id: str) -> str: ) -assert sum(df_summary.index == "None") == 6 +assert sum(no_id_mask := df_summary.index.isna()) == 6, f"{sum(no_id_mask)=}" # the 'None' materials have 0 volume, energy, n_sites, bandgap, etc. -assert all(df_summary[df_summary.index == "None"].drop(columns=["formula"]) == 0) +assert all(df_summary[no_id_mask].drop(columns=["formula"]) == 0) assert len(df_summary.query("volume > 0")) == len(df_wbm) + len(nan_init_structs_ids) # make sure dropping materials with 0 volume removes exactly 6 materials, the same ones # listed in bad_struct_ids above @@ -332,7 +332,7 @@ def increment_wbm_material_id(wbm_id: str) -> str: df_summary.index = df_summary.index.map(increment_wbm_material_id) # format IDs # drop materials with id='None' and missing initial structures -df_summary = df_summary.drop(index=[*nan_init_structs_ids, "None"]) +df_summary = df_summary.drop(index=[*nan_init_structs_ids, float("NaN")]) # the 8403 material IDs in step 3 with final number larger than any of the ones in # bad_struct_ids are now misaligned between df_summary and df_wbm @@ -340,6 +340,14 @@ def increment_wbm_material_id(wbm_id: str) -> str: # bad_struct_ids. we fix this with fix_bad_struct_index_mismatch() by mapping the IDs in # df_wbm to the ones in df_summary so that both indices become consecutive. assert sum(df_summary.index != df_wbm.index) == 8403 +assert {*df_summary.index} - {*df_wbm.index} == { + "wbm-3-70803", + "wbm-3-70804", + "wbm-3-70826", + "wbm-3-70827", + "wbm-3-70829", + "wbm-3-70830", +} def fix_bad_struct_index_mismatch(material_id: str) -> str: @@ -559,7 +567,6 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str: assert sum(df_wbm.index != df_summary.index) == 0 e_form_col = "e_form_per_atom_uncorrected" -assert e_form_col not in df_summary for row in tqdm(df_wbm.itertuples(), total=len(df_wbm)): mat_id, cse, formula = row.Index, row.cse, row.formula_from_cse @@ -568,17 +575,21 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str: entry_like = dict(composition=formula, energy=cse.uncorrected_energy) e_form = get_e_form_per_atom(entry_like) - e_form_ppd = ppd_mp.get_form_energy_per_atom(cse) + e_form_ppd = ppd_mp.get_form_energy_per_atom(cse) - cse.correction_per_atom - correction = cse.correction_per_atom # make sure the PPD.get_e_form_per_atom() and standalone get_e_form_per_atom() # method of calculating formation energy agree assert ( - abs(e_form - (e_form_ppd - correction)) < 1e-4 - ), f"{mat_id=}: {e_form=:.5} != {e_form_ppd - correction=:.5}" + abs(e_form - e_form_ppd) < 1e-4 + ), f"{mat_id}: {e_form=:.3} != {e_form_ppd=:.3} (diff={e_form - e_form_ppd:.3}))" df_summary.at[cse.entry_id, e_form_col] = e_form +df_summary[e_form_col.replace("uncorrected", "mp2020_corrected")] = ( + df_summary[e_form_col] + df_summary["e_correction_per_atom_mp2020"] +) + + # %% try: from aviary.wren.utils import get_aflow_label_from_spglib @@ -623,17 +634,16 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str: df_summary.round(6).to_csv(f"{module_dir}/{today}-wbm-summary.csv") -# %% read summary data from disk -df_summary = pd.read_csv(f"{module_dir}/2022-10-19-wbm-summary.csv").set_index( - "material_id" -) - - -# %% read WBM initial structures and computed structure entries from disk -df_wbm = pd.read_json( - f"{module_dir}/2022-10-19-wbm-computed-structure-entries+init-structs.json.bz2" -).set_index("material_id") +# %% only here to load data quickly for later inspection +if False: + df_summary = pd.read_csv(f"{module_dir}/2022-10-19-wbm-summary.csv").set_index( + "material_id" + ) + df_wbm = pd.read_json( + f"{module_dir}/2022-10-19-wbm-computed-structure-entries+init-structs.json.bz2" + ).set_index("material_id") -df_wbm["cse"] = [ - ComputedStructureEntry.from_dict(x) for x in tqdm(df_wbm.computed_structure_entry) -] + df_wbm["cse"] = [ + ComputedStructureEntry.from_dict(x) + for x in tqdm(df_wbm.computed_structure_entry) + ] diff --git a/matbench_discovery/data.py b/matbench_discovery/data.py index afabe7c4..de2bf951 100644 --- a/matbench_discovery/data.py +++ b/matbench_discovery/data.py @@ -1,17 +1,19 @@ from __future__ import annotations +import gzip import json import os +import pickle import sys import urllib.error -from collections.abc import Sequence +import urllib.request from glob import glob from pathlib import Path from typing import Any, Callable import pandas as pd -from pymatgen.core import Structure -from pymatgen.entries.computed_entries import ComputedStructureEntry +from monty.json import MontyDecoder +from pymatgen.analysis.phase_diagram import PatchedPhaseDiagram from tqdm import tqdm from matbench_discovery import FIGSHARE @@ -40,13 +42,13 @@ def as_dict_handler(obj: Any) -> dict[str, Any] | None: # removes e.g. non-serializable AseAtoms from M3GNet relaxation trajectories -def load_train_test( - data_names: str | Sequence[str], +def load( + data_key: str, version: str = figshare_versions[-1], cache_dir: str | Path = default_cache_dir, hydrate: bool = False, **kwargs: Any, -) -> pd.DataFrame: +) -> pd.DataFrame | PatchedPhaseDiagram: """Download parts of or the full MP training data and WBM test data as pandas DataFrames. The full training and test sets are each about ~500 MB as compressed JSON which will be cached locally to cache_dir for faster re-loading unless @@ -56,8 +58,8 @@ def load_train_test( see https://janosh.github.io/matbench-discovery/contribute#--direct-download. Args: - data_names (str | list[str], optional): Which parts of the MP/WBM data to load. - Can be any subset of set(DATA_FILES) or 'all'. + data_key (str): Which parts of the MP/WBM data to load. Must be one of + list(DATA_FILES). version (str, optional): Which version of the dataset to load. Defaults to latest version of data files published to Figshare. Pass any invalid version to see valid options. @@ -71,77 +73,68 @@ def load_train_test( depending on which file is loaded. Raises: - ValueError: On bad version number or bad data names. + ValueError: On bad version number or bad data_key. Returns: pd.DataFrame: Single dataframe or dictionary of dfs if multiple data requested. """ if version not in figshare_versions: raise ValueError(f"Unexpected {version=}. Must be one of {figshare_versions}.") - if data_names == "all": - data_names = list(DATA_FILES) - elif isinstance(data_names, str): - data_names = [data_names] - if missing := set(data_names) - set(DATA_FILES): - raise ValueError(f"{missing} must be subset of {set(DATA_FILES)}") + if not isinstance(data_key, str) or data_key not in DATA_FILES: + raise ValueError(f"Unknown {data_key=}, must be one of {list(DATA_FILES)}.") with open(f"{FIGSHARE}/{version}.json") as json_file: file_urls = json.load(json_file) - dfs = {} - for key in data_names: - file = DataFiles.__dict__[key] - csv_ext = (".csv", ".csv.gz", ".csv.bz2") - reader = pd.read_csv if file.endswith(csv_ext) else pd.read_json - - cache_path = f"{cache_dir}/{file}" - if os.path.isfile(cache_path): # load from disk cache - print(f"Loading {key!r} from cached file at {cache_path!r}") - df = reader(cache_path, **kwargs) - else: # download from Figshare URL - # manually set compression since pandas can't infer from URL - if file.endswith(".gz"): - kwargs.setdefault("compression", "gzip") - elif file.endswith(".bz2"): - kwargs.setdefault("compression", "bz2") - url = file_urls[key] - print(f"Downloading {key!r} from {url}") + file = DataFiles.__dict__[data_key] + + cache_path = f"{cache_dir}/{file}" + if not os.path.isfile(cache_path): # download from Figshare URL + url = file_urls[data_key] + print(f"Downloading {data_key!r} from {url}") + try: + # ensure directory exists + os.makedirs(os.path.dirname(cache_path), exist_ok=True) + # download and save to disk + urllib.request.urlretrieve(url, cache_path) + print(f"Cached {data_key!r} to {cache_path!r}") + except urllib.error.HTTPError as exc: + raise ValueError(f"Bad {url=}") from exc + except Exception: + print(f"\n\nvariable dump:\n{file=},\n{url=}") + raise + + print(f"Loading {data_key!r} from cached file at {cache_path!r}") + if ".pkl" in file: # handle key='mp_patched_phase_diagram' separately + with gzip.open(cache_path, "rb") as zip_file: + return pickle.load(zip_file) + + csv_ext = (".csv", ".csv.gz", ".csv.bz2") + reader = pd.read_csv if file.endswith(csv_ext) else pd.read_json + try: + df = reader(cache_path, **kwargs) + except Exception: + print(f"\n\nvariable dump:\n{file=},\n{reader=}\n{kwargs=}") + raise + + if "material_id" in df: + df = df.set_index("material_id") + if hydrate: + for col in df: + if not isinstance(df[col].iloc[0], dict): + continue try: - df = reader(url, **kwargs) - except urllib.error.HTTPError as exc: - raise ValueError(f"Bad {url=}") from exc + # convert dicts to pymatgen Structures and ComputedStructureEntrys + df[col] = [ + MontyDecoder().process_decoded(dct) + for dct in tqdm(df[col], desc=col) + ] except Exception: - print(f"\n\nvariable dump:\n{file=},\n{url=},\n{reader=},\n{kwargs=}") + print(f"\n\nvariable dump:\n{col=},\n{df[col]=}") raise - if cache_dir and not os.path.isfile(cache_path): - os.makedirs(os.path.dirname(cache_path), exist_ok=True) - if ".csv" in file: - df.to_csv(cache_path, index=False) - elif ".json" in file: - df.to_json(cache_path, default_handler=as_dict_handler) - else: - raise ValueError(f"Unexpected file type {file}") - print(f"Cached {key!r} to {cache_path!r}") - if "material_id" in df: - df = df.set_index("material_id") - if hydrate: - for col in df: - if not isinstance(df[col].iloc[0], dict): - continue - try: - df[col] = [ - ComputedStructureEntry.from_dict(d) - for d in tqdm(df[col], desc=col) - ] - except Exception: - df[col] = [Structure.from_dict(d) for d in tqdm(df[col], desc=col)] - - dfs[key] = df - - if len(data_names) == 1: - return dfs[data_names[0]] - return dfs + + return df def glob_to_df( @@ -228,7 +221,7 @@ class DataFiles(Files): def _on_not_found(self, key: str, msg: str) -> None: # type: ignore[override] msg += ( " Would you like to download it now using matbench_discovery." - f"data.load_train_test({key!r}). This will cache the file for future use." + f"data.load({key!r}). This will cache the file for future use." ) # default to 'y' if not in interactive session, and user can't answer @@ -236,12 +229,12 @@ def _on_not_found(self, key: str, msg: str) -> None: # type: ignore[override] while answer not in ("y", "n"): answer = input(f"{msg} [y/n] ").lower().strip() if answer == "y": - load_train_test(key) # download and cache data file + load(key) # download and cache data file mp_computed_structure_entries = ( "mp/2023-02-07-mp-computed-structure-entries.json.gz" ) - mp_elemental_ref_entries = "mp/2022-09-19-mp-elemental-reference-entries.json" + mp_elemental_ref_entries = "mp/2023-02-07-mp-elemental-reference-entries.json.gz" mp_energies = "mp/2023-01-10-mp-energies.csv" mp_patched_phase_diagram = "mp/2023-02-07-ppd-mp.pkl.gz" wbm_computed_structure_entries = ( @@ -254,9 +247,9 @@ def _on_not_found(self, key: str, msg: str) -> None: # type: ignore[override] wbm_summary = "wbm/2022-10-19-wbm-summary.csv" -# data files can be downloaded and cached with matbench_discovery.data.load_train_test() +# data files can be downloaded and cached with matbench_discovery.data.load() DATA_FILES = DataFiles() -df_wbm = load_train_test("wbm_summary") +df_wbm = load("wbm_summary") df_wbm["material_id"] = df_wbm.index diff --git a/matbench_discovery/energy.py b/matbench_discovery/energy.py index c6e6b001..77e20086 100644 --- a/matbench_discovery/energy.py +++ b/matbench_discovery/energy.py @@ -70,15 +70,15 @@ def get_elemental_ref_entries( # tested to agree with TRI's MP reference energies # https://github.com/TRI-AMDD/CAMD/blob/1c965cba636531e542f4821a555b98b2d81ed034/camd/utils/data.py#L134 -# fmt: off mp_elemental_ref_energies = { - "Ne": -0.0259, "He": -0.0091, "Ar": -0.0688, "F": -1.9115, "O": -4.948, "Cl": -1.8485, "N": -8.3365, "Kr": -0.0567, "Br": -1.6369, "I": -1.524, "Xe": -0.0362, "S": -4.1364, "Se": -3.4959, "C": -9.2268, "Au": -3.2739, "W": -12.9581, "Pb": -3.7126, "Rh": -7.3643, "Pt": -6.0709, "Ru": -9.2744, "Pd": -5.1799, "Os": -11.2274, "Ir": -8.8384, "H": -3.3927, "P": -5.4133, "As": -4.6591, "Mo": -10.8456, "Te": -3.1433, "Sb": -4.129, "B": -6.6794, "Bi": -3.89, "Ge": -4.623, "Hg": -0.3037, "Sn": -4.0096, "Ag": -2.8326, "Ni": -5.7801, "Tc": -10.3606, "Si": -5.4253, "Re": -12.4445, "Cu": -4.0992, "Co": -7.1083, "Fe": -8.47, "Ga": -3.0281, "In": -2.7517, "Cd": -0.9229, "Cr": -9.653, "Zn": -1.2597, "V": -9.0839, "Tl": -2.3626, "Al": -3.7456, "Nb": -10.1013, "Be": -3.7394, "Mn": -9.162, "Ti": -7.8955, "Ta": -11.8578, "Pa": -9.5147, "U": -11.2914, "Sc": -6.3325, "Np": -12.9478, "Zr": -8.5477, "Mg": -1.6003, "Th": -7.4139, "Hf": -9.9572, "Pu": -14.2678, "Lu": -4.521, "Tm": -4.4758, "Er": -4.5677, "Ho": -4.5824, "Y": -6.4665, "Dy": -4.6068, "Gd": -14.0761, "Eu": -10.292, "Sm": -4.7186, "Nd": -4.7681, "Pr": -4.7809, "Pm": -4.7505, "Ce": -5.9331, "Yb": -1.5396, "Tb": -4.6344, "La": -4.936, "Ac": -4.1212, "Ca": -2.0056, "Li": -1.9089, "Sr": -1.6895, "Na": -1.3225, "Ba": -1.919, "Rb": -0.9805, "K": -1.1104, "Cs": -0.8954, # noqa: E501 + elem: round(entry.energy_per_atom, 4) + for elem, entry in mp_elem_reference_entries.items() } -# fmt: on def get_e_form_per_atom( - entry: EntryLike, elemental_ref_energies: dict[str, float] = None + entry: EntryLike, + elemental_ref_energies: dict[str, float] = mp_elemental_ref_energies, ) -> float: """Get the formation energy of a composition from a list of entries and a dict mapping elements to reference energies. @@ -96,8 +96,6 @@ def get_e_form_per_atom( Returns: float: formation energy in eV/atom. """ - elemental_ref_energies = elemental_ref_energies or mp_elemental_ref_energies - if isinstance(entry, dict): energy = entry["energy"] comp = Composition(entry["composition"]) # is idempotent if already Composition diff --git a/matbench_discovery/preds.py b/matbench_discovery/preds.py index cee814d9..7b9796d9 100644 --- a/matbench_discovery/preds.py +++ b/matbench_discovery/preds.py @@ -60,7 +60,7 @@ class PredFiles(Files): # model_labels remaps model keys to pretty plot labels (see Files) -PRED_FILES = PredFiles(root=f"{ROOT}/models/", key_map=model_labels) +PRED_FILES = PredFiles(root=f"{ROOT}/models", key_map=model_labels) def load_df_wbm_with_preds( diff --git a/pyproject.toml b/pyproject.toml index 020444f7..6a83a9c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ "numpy", # output_formatting needed to for pandas Stylers # see https://github.com/pandas-dev/pandas/blob/main/pyproject.toml#L78 - "pandas[output_formatting]", + "pandas[output_formatting]>=2.0.0", "plotly", "pymatgen", "pymatviz[export-figs]", @@ -52,7 +52,7 @@ Repo = "https://github.com/janosh/matbench-discovery" Package = "https://pypi.org/project/matbench-discovery" [project.optional-dependencies] -test = ["pytest", "pytest-cov", "pytest-markdown-docs"] +test = ["pytest", "pytest-cov"] # how to specify git deps: https://stackoverflow.com/a/73572379 running-models = [ # torch needs to install before aviary @@ -121,5 +121,8 @@ no_implicit_optional = false [tool.pytest.ini_options] testpaths = ["tests"] -addopts = "-p no:warnings -m 'not slow'" -markers = ["slow: deselect slow tests with -m 'not slow'"] +addopts = "-p no:warnings -m 'not slow and not very_slow'" +markers = [ + "slow: deselect slow tests with -m 'not slow'", + "very_slow: select with -m 'very_slow'", +] diff --git a/site/package.json b/site/package.json index e18bafdb..23a8db82 100644 --- a/site/package.json +++ b/site/package.json @@ -21,14 +21,14 @@ "@sveltejs/adapter-static": "^2.0.2", "@sveltejs/kit": "^1.15.9", "@sveltejs/vite-plugin-svelte": "^2.1.1", - "@typescript-eslint/eslint-plugin": "^5.59.1", - "@typescript-eslint/parser": "^5.59.1", + "@typescript-eslint/eslint-plugin": "^5.59.2", + "@typescript-eslint/parser": "^5.59.2", "elementari": "^0.1.6", "eslint": "^8.39.0", "eslint-plugin-svelte3": "^4.0.0", "hastscript": "^7.2.0", "js-yaml": "^4.1.0", - "katex": "^0.16.6", + "katex": "^0.16.7", "mdsvex": "^0.10.6", "prettier": "^2.8.8", "prettier-plugin-svelte": "^2.10.0", @@ -38,7 +38,7 @@ "remark-math": "3.0.0", "svelte": "^3.58.0", "svelte-check": "^3.2.0", - "svelte-multiselect": "^8.6.0", + "svelte-multiselect": "^8.6.1", "svelte-preprocess": "^5.0.3", "svelte-toc": "^0.5.5", "svelte-zoo": "^0.4.5", diff --git a/site/src/figs/mp-elemental-ref-energies.svelte b/site/src/figs/mp-elemental-ref-energies.svelte index 9e95886e..cc5c6fcf 100644 --- a/site/src/figs/mp-elemental-ref-energies.svelte +++ b/site/src/figs/mp-elemental-ref-energies.svelte @@ -1 +1 @@ -
+
diff --git a/site/src/routes/contribute/+page.md b/site/src/routes/contribute/+page.md index c5317fc4..6e475776 100644 --- a/site/src/routes/contribute/+page.md +++ b/site/src/routes/contribute/+page.md @@ -17,10 +17,10 @@ pip install matbench-discovery This example script downloads the training and test data for training a model: ```py -from matbench_discovery.data import load_train_test +from matbench_discovery.data import load from matbench_discovery.data import df_wbm, DATA_FILES -# any subset of these keys can be passed to load_train_test() +# any subset of these keys can be passed to load() assert sorted(DATA_FILES) == [ "mp-computed-structure-entries", "mp-elemental-ref-energies", @@ -31,7 +31,7 @@ assert sorted(DATA_FILES) == [ "wbm-summary", ] -df_wbm = load_train_test("wbm-summary", version="v1.0.0") +df_wbm = load("wbm-summary", version="v1.0.0") assert df_wbm.shape == (256963, 15) @@ -79,7 +79,7 @@ You can also download the data files directly from GitHub: 1. [`2023-01-10-mp-energies.json.gz`]({repo}/blob/-/data/mp/2023-01-10-mp-energies.json.gz): Materials Project formation energies and energies above convex hull 1. [`2023-02-07-mp-computed-structure-entries.json.gz`]({repo}/blob/-/data/mp/2023-02-07-mp-computed-structure-entries.json.gz): Materials Project computed structure entries 1. [`2023-02-07-ppd-mp.pkl.gz`]({repo}/blob/-/data/mp/2023-02-07-ppd-mp.pkl.gz): [PatchedPhaseDiagram](https://pymatgen.org/pymatgen.analysis.phase_diagram.html#pymatgen.analysis.phase_diagram.PatchedPhaseDiagram) constructed from all MP ComputedStructureEntries -1. [`2022-09-19-mp-elemental-reference-entries.json`]({repo}/blob/-/data/mp/2022-09-19-mp-elemental-reference-entries.json): Minimum energy PDEntries for each element present in the Materials Project +1. [`2023-02-07-mp-elemental-reference-entries.json.gz`]({repo}/blob/-/data/mp/2023-02-07-mp-elemental-reference-entries.json.gz): Minimum energy PDEntries for each element present in the Materials Project [wbm paper]: https://nature.com/articles/s41524-020-00481-6 diff --git a/site/src/routes/si/+page.md b/site/src/routes/si/+page.md index 2b6f8459..0e1dee7b 100644 --- a/site/src/routes/si/+page.md +++ b/site/src/routes/si/+page.md @@ -61,7 +61,7 @@ A further point of clarification: whenever we say distance to the convex hull we {/if} -> @label:fig:mp-elemental-reference-energies WBM formation energies were calculated w.r.t. these Materials Project elemental reference energies ([queried on 2022-09-19](https://github.com/janosh/matbench-discovery/blob/main/data/mp/2022-09-19-mp-elemental-reference-entries.json)). Marker size indicates the number of atoms in the reference structure. Hover points for details. +> @label:fig:mp-elemental-reference-energies WBM formation energies were calculated w.r.t. these Materials Project elemental reference energies ([queried on 2023-02-07](https://github.com/janosh/matbench-discovery/blob/main/data/mp/2023-02-07-mp-elemental-reference-entries.json.gz)). Marker size indicates the number of atoms in the reference structure. Hover points for details. ## Classification Histograms using Model-Predicted Energies diff --git a/tests/test_data.py b/tests/test_data.py index 410ce36b..f804141f 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -19,7 +19,7 @@ df_wbm, figshare_versions, glob_to_df, - load_train_test, + load, ) with open(f"{FIGSHARE}/{figshare_versions[-1]}.json") as file: @@ -33,35 +33,33 @@ @pytest.mark.parametrize( - "data_names, hydrate", + "data_key, hydrate", [ - (["wbm_summary"], True), - (["wbm_initial_structures"], True), - (["wbm_computed_structure_entries"], False), - (["wbm_summary", "wbm_initial_structures"], True), - (["mp_elemental_ref_entries"], True), - (["mp_energies"], True), + ("wbm_summary", True), + ("wbm_initial_structures", True), + ("wbm_computed_structure_entries", False), + ("mp_elemental_ref_entries", True), + ("mp_energies", True), ], ) -def test_load_train_test( - data_names: list[str], +def test_load( + data_key: str, hydrate: bool, + # df with Structures and ComputedStructureEntries as dicts dummy_df_serialized: pd.DataFrame, capsys: CaptureFixture[str], tmp_path: Path, ) -> None: - # intercept HTTP requests to GitHub raw user content and return dummy df instead - with patch("matbench_discovery.data.pd.read_csv") as read_csv, patch( - "matbench_discovery.data.pd.read_json" - ) as read_json: - # dummy df with Structures and ComputedStructureEntries - read_json.return_value = dummy_df_serialized + filepath = DATA_FILES[data_key] + # intercept HTTP requests and write dummy df to disk instead + with patch("urllib.request.urlretrieve") as urlretrieve: # dummy df with random floats and material_id column - read_csv.return_value = pd._testing.makeDataFrame().reset_index( - names="material_id" - ) - out = load_train_test( - data_names, + df_csv = pd._testing.makeDataFrame().reset_index(names="material_id") + + writer = dummy_df_serialized.to_json if ".json" in filepath else df_csv.to_csv + urlretrieve.side_effect = lambda url, path: writer(path) + out = load( + data_key, hydrate=hydrate, # test both str and Path for cache_dir cache_dir=str(tmp_path) if random() < 0.5 else tmp_path, @@ -69,41 +67,30 @@ def test_load_train_test( stdout, _stderr = capsys.readouterr() - expected_outs = [ - f"Downloading {key!r} from {figshare_urls[key]}" for key in data_names - ] - for expected_out in expected_outs: - assert expected_out in stdout + assert f"Downloading {data_key!r} from {figshare_urls[data_key]}" in stdout # check we called read_csv/read_json once for each data_name - assert read_json.call_count + read_csv.call_count == len(data_names) + assert urlretrieve.call_count == 1 - if len(data_names) > 1: - assert isinstance(out, dict) - assert list(out) == data_names - for key, df in out.items(): - assert isinstance(df, pd.DataFrame), f"{key} not a DataFrame but {type(df)}" - else: - assert isinstance(out, pd.DataFrame), f"{data_names[0]} not a DataFrame" + assert isinstance(out, pd.DataFrame), f"{data_key} not a DataFrame" # test that df loaded from cache is the same as initial df - from_cache = load_train_test(data_names, hydrate=hydrate, cache_dir=tmp_path) - if len(data_names) > 1: - for key, df in from_cache.items(): - pd.testing.assert_frame_equal(df, out[key]) - else: - pd.testing.assert_frame_equal(out, from_cache) + from_cache = load(data_key, hydrate=hydrate, cache_dir=tmp_path) + pd.testing.assert_frame_equal(out, from_cache) -def test_load_train_test_raises(tmp_path: Path) -> None: - # bad data name - with pytest.raises(ValueError, match=f"must be subset of {set(DATA_FILES)}"): - load_train_test(["bad-data-name"]) +def test_load_raises(tmp_path: Path) -> None: + data_key = "bad-key" + with pytest.raises(ValueError) as exc_info: + load(data_key) + + assert f"Unknown {data_key=}, must be one of {list(DATA_FILES)}" in str( + exc_info.value + ) - # bad_version version = "invalid-version" with pytest.raises(ValueError) as exc_info: - load_train_test("wbm_summary", version=version, cache_dir=tmp_path) + load("wbm_summary", version=version, cache_dir=tmp_path) assert ( str(exc_info.value) @@ -112,8 +99,8 @@ def test_load_train_test_raises(tmp_path: Path) -> None: assert os.listdir(tmp_path) == [], "cache_dir should be empty" -def test_load_train_test_doc_str() -> None: - doc_str = load_train_test.__doc__ +def test_load_doc_str() -> None: + doc_str = load.__doc__ assert isinstance(doc_str, str) # mypy type narrowing # check that we link to the right data description page @@ -144,7 +131,7 @@ def test_load_train_test_doc_str() -> None: @pytest.mark.parametrize( "file_key, version, expected_shape, expected_cols", [ - ("mp_elemental_ref_entries", figshare_versions[-1], (5, 89), set()), + ("mp_elemental_ref_entries", figshare_versions[-1], (9, 89), set()), pytest.param( "wbm_summary", figshare_versions[-1], @@ -158,11 +145,11 @@ def test_load_train_test_doc_str() -> None: figshare_versions[-1], (154718, 1), {"entry"}, - marks=pytest.mark.slow, + marks=pytest.mark.very_slow, ), ], ) -def test_load_train_test_no_mock( +def test_load_no_mock( file_key: str, version: str, expected_shape: tuple[int, int], @@ -173,7 +160,7 @@ def test_load_train_test_no_mock( assert os.listdir(tmp_path) == [], "cache_dir should be empty" # This function runs the download from Figshare for real hence takes some time and # requires being online - df = load_train_test(file_key, version=version, cache_dir=tmp_path) + df = load(file_key, version=version, cache_dir=tmp_path) assert len(os.listdir(tmp_path)) == 1, "cache_dir should have one file" assert df.shape == expected_shape assert ( @@ -191,7 +178,7 @@ def test_load_train_test_no_mock( # test that df loaded from cache is the same as initial df pd.testing.assert_frame_equal( - df, load_train_test(file_key, version=version, cache_dir=tmp_path) + df, load(file_key, version=version, cache_dir=tmp_path) ) stdout, stderr = capsys.readouterr()