Skip to content

Commit

Permalink
Merge pull request #87 from WMD-group/publication
Browse files Browse the repository at this point in the history
Publication updates
  • Loading branch information
AntObi authored Aug 23, 2023
2 parents 8f0f6f3 + 9dfbd2e commit 3bcd214
Show file tree
Hide file tree
Showing 34 changed files with 2,710 additions and 104 deletions.
7 changes: 7 additions & 0 deletions Publication/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Publications

## Element similarity in high-dimensional materials representations

The folder `element_similarity` contains the scripts used to produce the results in the Publication: "Element similarity in high-dimensional materials representations"

Installing `SMACT==2.5.3` and `ElementEmbeddings==0.4` via pip should install all the necessary packages required to reproduce the results. We do provide all the packages which were installed at the time of running the notebooks for full transparency in the file `requirements_publication.txt`.
6 changes: 6 additions & 0 deletions Publication/element_similarity/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Element similarity in high-dimensional materials representations

* The notebook `similarity_periodic_trends/similarity_measures.ipynb` is used to reproduce the dis/similarity heatmaps and 2D projections
* The notebook `csp/data_radius_ratio_rules.ipynb` is used to get the dataset and calculate the radius ratio rules.
* The notebook `csp/cosine_embeddings.ipynb` is used to generate the pairwise cosine similarity values
* The notebook `csp/smact_sp` is used to perform the structure prediction using SMACT
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
"import os\n",
"import seaborn as sns\n",
"\n",
"sns.set_context(\"paper\", font_scale=1.5)"
"sns.set_context(\"paper\", font_scale=1.5)\n",
"random_state = 42\n",
"reducer_params = {\"random_state\": random_state}\n",
"scatter_params = {\"s\": 80}"
]
},
{
Expand All @@ -31,17 +34,24 @@
"metadata": {},
"outputs": [],
"source": [
"# Set up the 7 embeddings\n",
"# Load the embeddings\n",
"cbfvs = [\n",
" \"magpie_sc\",\n",
" \"magpie\",\n",
" \"matscholar\",\n",
" \"mat2vec\",\n",
" \"megnet16\",\n",
" \"oliynyk\",\n",
" \"random_200\",\n",
" \"matscholar\",\n",
" \"oliynyk_sc\",\n",
" \"skipatom\",\n",
"]\n",
"element_embedddings = {cbfv: Embedding.load_data(cbfv) for cbfv in cbfvs}"
"element_embedddings = {cbfv: Embedding.load_data(cbfv) for cbfv in cbfvs}\n",
"\n",
"# Standardise\n",
"for embedding in element_embedddings.values():\n",
" print(f\"Attempting to standardise {embedding.embedding_name}...\")\n",
" print(f\" Already standardised: {embedding.is_standardised}\")\n",
" embedding.standardise(inplace=True)\n",
" print(f\"Now standardised: {embedding.is_standardised}\")"
]
},
{
Expand Down Expand Up @@ -79,6 +89,18 @@
"# del element_embedddings[\"skipatom\"].embeddings[\"Kr\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Which elements are missing for skipatom\n",
"set(element_embedddings[\"magpie\"].element_list) - set(\n",
" element_embedddings[\"skipatom\"].element_list\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -109,6 +131,7 @@
" show_axislabels=False,\n",
" ax=ax,\n",
" )\n",
" print(cbfv.embedding_name)\n",
" # plt.subplots_adjust(wspace=0.001)\n",
"axes[-1][-1].remove()\n",
"\n",
Expand Down Expand Up @@ -145,6 +168,70 @@
"fig.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Kr and Xe contribute to the distorted images for Skipatom."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"element_embedddings[\"skipatom_no_nobles\"] = Embedding.load_data(\"skipatom\")\n",
"\n",
"for el in [\"Xe\", \"Kr\"]:\n",
" del element_embedddings[\"skipatom_no_nobles\"].embeddings[el]\n",
"element_embedddings[\"skipatom_no_nobles\"].standardise(inplace=True)\n",
"element_embedddings[\"skipatom_no_nobles\"].embedding_name = \"skipatom (Xe,Kr removed)\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots(figsize=(20, 20))\n",
"heatmap_plotter(\n",
" embedding=element_embedddings[\"skipatom_no_nobles\"],\n",
" metric=\"euclidean\",\n",
" sortaxisby=\"atomic_number\",\n",
" show_axislabels=True,\n",
" ax=ax,\n",
")\n",
"\n",
"fig.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, (axes) = plt.subplots(4, 2, figsize=(20, 20))\n",
"\n",
"for ax, cbfv in zip(axes.flatten(), element_embedddings.values()):\n",
" heatmap_plotter(\n",
" embedding=cbfv,\n",
" metric=\"euclidean\",\n",
" sortaxisby=\"atomic_number\",\n",
" show_axislabels=False,\n",
" ax=ax,\n",
" )\n",
" print(cbfv.embedding_name)\n",
" # plt.subplots_adjust(wspace=0.001)\n",
"\n",
"\n",
"fig.tight_layout()\n",
"fig.savefig(\"SI_euclidean.pdf\", bbox_inches=\"tight\")\n",
"fig.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand All @@ -169,7 +256,6 @@
" ax=ax,\n",
" )\n",
" # plt.subplots_adjust(wspace=0.001)\n",
"axes[-1][-1].remove()\n",
"\n",
"\n",
"fig.tight_layout()\n",
Expand Down Expand Up @@ -201,7 +287,6 @@
" ax=ax,\n",
" )\n",
" # plt.subplots_adjust(wspace=0.001)\n",
"axes[-1][-1].remove()\n",
"\n",
"\n",
"fig.tight_layout()\n",
Expand Down Expand Up @@ -233,7 +318,6 @@
" ax=ax,\n",
" )\n",
" # plt.subplots_adjust(wspace=0.001)\n",
"axes[-1][-1].remove()\n",
"\n",
"\n",
"fig.tight_layout()\n",
Expand Down Expand Up @@ -265,7 +349,6 @@
" ax=ax,\n",
" )\n",
" # plt.subplots_adjust(wspace=0.001)\n",
"axes[-1][-1].remove()\n",
"\n",
"\n",
"fig.tight_layout()\n",
Expand Down Expand Up @@ -299,7 +382,6 @@
" **heatmap_params\n",
" )\n",
" # plt.subplots_adjust(wspace=0.001)\n",
"axes[-1][-1].remove()\n",
"\n",
"\n",
"fig.tight_layout()\n",
Expand Down Expand Up @@ -333,7 +415,6 @@
" **heatmap_params\n",
" )\n",
" # plt.subplots_adjust(wspace=0.001)\n",
"axes[-1][-1].remove()\n",
"\n",
"\n",
"fig.tight_layout()\n",
Expand Down Expand Up @@ -367,7 +448,6 @@
" **heatmap_params\n",
" )\n",
" # plt.subplots_adjust(wspace=0.001)\n",
"axes[-1][-1].remove()\n",
"\n",
"\n",
"fig.tight_layout()\n",
Expand Down Expand Up @@ -408,9 +488,11 @@
" n_components=2,\n",
" ax=ax,\n",
" adjusttext=True,\n",
" reducer_params=reducer_params,\n",
" scatter_params=scatter_params,\n",
" )\n",
" ax.legend().remove()\n",
"axes[-1][-1].remove()\n",
"\n",
"\n",
"handles, labels = ax.get_legend_handles_labels()\n",
"fig.legend(handles, labels, bbox_to_anchor=(0.54, 1.06), loc=\"upper center\", ncol=5)\n",
Expand Down Expand Up @@ -445,9 +527,11 @@
" n_components=2,\n",
" ax=ax,\n",
" # adjusttext=True,\n",
" reducer_params=reducer_params,\n",
" scatter_params=scatter_params,\n",
" )\n",
" ax.legend().remove()\n",
"axes[-1][-1].remove()\n",
"\n",
"\n",
"handles, labels = ax.get_legend_handles_labels()\n",
"fig.legend(handles, labels, bbox_to_anchor=(0.54, 1.06), loc=\"upper center\", ncol=5)\n",
Expand Down Expand Up @@ -482,9 +566,11 @@
" n_components=2,\n",
" ax=ax,\n",
" adjusttext=True,\n",
" reducer_params=reducer_params,\n",
" scatter_params=scatter_params,\n",
" )\n",
" ax.legend().remove()\n",
"axes[-1][-1].remove()\n",
"\n",
"\n",
"handles, labels = ax.get_legend_handles_labels()\n",
"fig.legend(handles, labels, bbox_to_anchor=(0.54, 1.06), loc=\"upper center\", ncol=5)\n",
Expand Down Expand Up @@ -538,7 +624,7 @@
" ax.set_xlabel(\"Pearson correlation\")\n",
" ax.set_ylabel(\"Count\")\n",
"\n",
"axes[-1][-1].remove()\n",
"\n",
"plt.tight_layout()\n",
"plt.savefig(\"SI_pearson_distribution.pdf\")"
]
Expand Down Expand Up @@ -566,7 +652,7 @@
" ax.set_xlabel(\"Cosine similarity\")\n",
" ax.set_ylabel(\"Count\")\n",
"\n",
"axes[-1][-1].remove()\n",
"\n",
"plt.tight_layout()\n",
"plt.savefig(\"SI_cosine_similarity_distribution.pdf\")"
]
Expand Down
Loading

0 comments on commit 3bcd214

Please sign in to comment.