diff --git a/src/mridle/extras/datasets/pandas_styler.py b/src/mridle/extras/datasets/pandas_styler.py index 1b6b4f4a..2e6efdca 100644 --- a/src/mridle/extras/datasets/pandas_styler.py +++ b/src/mridle/extras/datasets/pandas_styler.py @@ -10,6 +10,7 @@ import fsspec import numpy as np +from matplotlib.pyplot import plot as plt from typing import Any, Dict @@ -41,9 +42,15 @@ def _save(self, data) -> None: # using get_filepath_str ensures that the protocol and path are appended correctly for different filesystems save_path = get_filepath_str(self._get_save_path(), self._protocol) df = data.data # Extract the DataFrame from the styler - df_html = df.to_html(index=False) # Convert DataFrame to HTML - with open(save_path, "w") as f: - f.write(df_html) + #df_html = df.to_html(index=False) # Convert DataFrame to HTML + #with open(save_path, "w") as f: + # f.write(df_html) + + df = data.data # Extract the DataFrame from the styler + plt.figure(figsize=(10, 6)) + plt.axis('off') + plt.table(cellText=df.values, colLabels=df.columns, cellLoc='center', loc='center') + plt.savefig(save_path, bbox_inches='tight') def _describe(self) -> Dict[str, Any]: """Returns a dict that describes the attributes of the dataset."""