# Hugging Face Space (status: Sleeping)
| import gradio as gr | |
| import numpy as np | |
| import pandas as pd | |
| import plotly.express as px | |
| import plotly.subplots as sp | |
| from datasets import load_dataset | |
| import umap | |
# Load dataset function
def load_dataset_from_hub(dataset_name, split="test"):
    """Load one split of a Hugging Face Hub dataset without raising.

    Args:
        dataset_name: Hub repository id, passed straight to ``load_dataset``.
        split: Name of the split to load; defaults to ``"test"``.

    Returns:
        A ``(dataset, error)`` pair. On success ``error`` is ``None``. On
        failure ``dataset`` is ``None`` and ``error`` is a non-empty string
        describing the exception.
    """
    try:
        dataset = load_dataset(dataset_name, split=split)
    except Exception as exc:
        # This is the UI boundary, so every failure is reported rather than
        # raised. Some exceptions stringify to "", which a caller testing
        # ``if error:`` would mistake for success, so fall back to the
        # exception's type name to keep the message truthy.
        return None, str(exc) or type(exc).__name__
    return dataset, None
# Create visualization function
def create_visualization(split, color_col, log):
    """Build a three-panel scatter plot of galaxy embeddings coloured by a property.

    Each panel plots the first two components of one embedding column and
    colours the points by ``color_col``.

    Args:
        split: Dataset split to load (for example ``"test"``).
        color_col: Name of the dataset column used to colour the points.
        log: When truthy, colour by the natural log of ``color_col``.

    Returns:
        ``(figure, None)`` on success, or ``(None, message)`` when loading the
        dataset or building the figure fails.
    """
    # Load the dataset
    dataset, error = load_dataset_from_hub("Smith42/galaxies_with_embeddings", split)
    # Also check the dataset itself so that an empty error string can never be
    # mistaken for a successful load.
    if error or dataset is None:
        return None, f"Error loading dataset: {error}"

    try:
        embedding_cols = ["p16k00_pca", "p16k01_pca", "p16k10_pca"]
        panel_titles = ["k = 0%", "k = 1%", "k = 10%"]

        # Extract embeddings and color values
        embeddings = dataset.select_columns(embedding_cols)
        colors = np.array(dataset[color_col], dtype=float)
        if log:
            # log(0) is -inf and log(negative) is NaN. dropna() below removes
            # NaN but not infinities, so mask every non-finite result to NaN;
            # otherwise a single -inf point would wreck the colour scale.
            with np.errstate(divide="ignore", invalid="ignore"):
                colors = np.log(colors)
            colors[~np.isfinite(colors)] = np.nan

        fig = sp.make_subplots(cols=len(embedding_cols), subplot_titles=panel_titles)
        # Subplot columns are 1-indexed.
        for col, embedding_col in enumerate(embedding_cols, start=1):
            emb_ar = np.array(embeddings[embedding_col])
            df = pd.DataFrame({
                'x': emb_ar[:, 0],
                'y': emb_ar[:, 1],
                'color': colors,
            }).dropna()
            scatter = px.scatter(df, x='x', y='y', color='color')
            fig.add_trace(scatter.data[0], row=1, col=col)
        return fig, None
    except Exception as e:
        # UI boundary: report the failure as text instead of raising.
        return None, f"Error creating viz: {e}"
# Catalogue columns that can be used to colour the embedding plots, grouped by
# category. The insertion order of the groups is the order of the
# "Property category" dropdown, and the first entry of each list is the
# property selected when that category is chosen.
property_groups = {
    "Basic Identifiers": [
        "dr8_id",
        "ra",
        "dec",
        "brickid",
        "objid",
        "file_name",
        "iauname",
    ],
    "Galaxy Morphology": [
        "smooth-or-featured_smooth_fraction",
        "smooth-or-featured_featured-or-disk_fraction",
        "smooth-or-featured_artifact_fraction",
        "disk-edge-on_yes_fraction",
        "disk-edge-on_no_fraction",
        "has-spiral-arms_yes_fraction",
        "has-spiral-arms_no_fraction",
        "bar_strong_fraction",
        "bar_weak_fraction",
        "bar_no_fraction",
        "bulge-size_dominant_fraction",
        "bulge-size_large_fraction",
        "bulge-size_moderate_fraction",
        "bulge-size_small_fraction",
        "bulge-size_none_fraction",
        "how-rounded_round_fraction",
        "how-rounded_in-between_fraction",
        "how-rounded_cigar-shaped_fraction",
        "edge-on-bulge_boxy_fraction",
        "edge-on-bulge_none_fraction",
        "edge-on-bulge_rounded_fraction",
        "spiral-winding_tight_fraction",
        "spiral-winding_medium_fraction",
        "spiral-winding_loose_fraction",
        "spiral-arm-count_1_fraction",
        "spiral-arm-count_2_fraction",
        "spiral-arm-count_3_fraction",
        "spiral-arm-count_4_fraction",
        "spiral-arm-count_more-than-4_fraction",
        "spiral-arm-count_cant-tell_fraction",
        "merging_none_fraction",
        "merging_minor-disturbance_fraction",
        "merging_major-disturbance_fraction",
        "merging_merger_fraction",
    ],
    "Physical Size Parameters": [
        "est_petro_th50",
        "est_petro_th50_kpc",
        "petro_theta",
        "petro_th50",
        "petro_th90",
        "petro_phi50",
        "petro_phi90",
        "petro_ba50",
        "petro_ba90",
        "elpetro_ba",
        "elpetro_phi",
        "elpetro_flux_r",
        "elpetro_theta_r",
    ],
    "Photometric Properties": [
        "mag_r_desi",
        "mag_g_desi",
        "mag_z_desi",
        "mag_f",
        "mag_n",
        "mag_u",
        "mag_g",
        "mag_r",
        "mag_i",
        "mag_z",
        "u_minus_r",
        "sersic_n",
        "sersic_ba",
        "sersic_phi",
        "elpetro_absmag_f",
        "elpetro_absmag_n",
        "elpetro_absmag_u",
        "elpetro_absmag_g",
        "elpetro_absmag_r",
        "elpetro_absmag_i",
        "elpetro_absmag_z",
        "sersic_nmgy_f",
        "sersic_nmgy_n",
        "sersic_nmgy_u",
        "sersic_nmgy_g",
        "sersic_nmgy_r",
        "sersic_nmgy_i",
        "sersic_nmgy_z",
    ],
    "Mass and Redshift": [
        "elpetro_mass",
        "elpetro_mass_log",
        "redshift",
        "redshift_nsa",
        "redshift_ossy",
        "photo_z",
        "photo_zerr",
        "spec_z",
    ],
    "Star Formation Properties": [
        "fibre_sfr_avg",
        "fibre_sfr_entropy",
        "fibre_sfr_median",
        "fibre_sfr_mode",
        "fibre_sfr_p16",
        "fibre_sfr_p2p5",
        "fibre_sfr_p84",
        "fibre_sfr_p97p5",
        "fibre_ssfr_avg",
        "fibre_ssfr_entropy",
        "fibre_ssfr_median",
        "fibre_ssfr_mode",
        "fibre_ssfr_p16",
        "fibre_ssfr_p2p5",
        "fibre_ssfr_p84",
        "fibre_ssfr_p97p5",
        "total_ssfr_avg",
        "total_ssfr_entropy",
        "total_ssfr_flag",
        "total_ssfr_median",
        "total_ssfr_mode",
        "total_ssfr_p16",
        "total_ssfr_p2p5",
        "total_ssfr_p84",
        "total_ssfr_p97p5",
        "total_sfr_avg",
        "total_sfr_entropy",
        "total_sfr_flag",
        "total_sfr_median",
        "total_sfr_mode",
        "total_sfr_p16",
        "total_sfr_p2p5",
        "total_sfr_p84",
        "total_sfr_p97p5",
    ],
    "AGN Properties": [
        "log_l_oiii",
        "fwhm",
        "e_fwhm",
        "equiv_width",
        "log_l_ha",
        "log_m_bh",
        "upper_e_log_m_bh",
        "lower_e_log_m_bh",
        "log_bolometric_l",
    ],
    "HI Properties": [
        "W50",
        "sigW",
        "W20",
        "HIflux",
        "sigflux",
        "SNR",
        "RMS",
        "Dist",
        "sigDist",
        "logMH",
        "siglogMH",
    ],
    "PhotoZ Catalog": [
        "photoz_id",
        "ra_photoz",
        "dec_photoz",
        "mag_abs_g_photoz",
        "mag_abs_r_photoz",
        "mag_abs_z_photoz",
        "mass_inf_photoz",
        "mass_med_photoz",
        "mass_sup_photoz",
        "sfr_inf_photoz",
        "sfr_sup_photoz",
        "ssfr_inf_photoz",
        "ssfr_med_photoz",
        "ssfr_sup_photoz",
        "sky_separation_arcsec_from_photoz",
    ],
}
# Define the Gradio interface
# NOTE(review): the nesting of components inside the rows below was
# reconstructed from a listing that had lost its indentation; confirm it
# matches the intended layout.
with gr.Blocks(title="Galaxy embeddings") as demo:
    gr.Markdown("# Sparse galaxy embeddings")

    # The first category and its first property are the initial selections.
    default_group = next(iter(property_groups))
    default_properties = property_groups[default_group]

    with gr.Row():
        split_input = gr.Dropdown(
            label="Split",
            value="test",
            choices=["test", "validation"],
        )
        group_dropdown = gr.Dropdown(
            label="Property category",
            choices=list(property_groups),
            value=default_group,
        )
        # Give the property dropdown an explicit initial value so that a click
        # on the button before any selection never passes None downstream.
        color_col = gr.Dropdown(
            label="Property",
            choices=default_properties,
            value=default_properties[0],
        )
        log = gr.Checkbox(
            label="Take log?",
            value=False,
        )

    visualize_btn = gr.Button("Let's go!")
    # Hidden until a run produces an error message (see show_visualization).
    error_output = gr.Textbox(label="Errors", visible=False)

    def update_properties(group):
        """Repopulate the property dropdown for the newly chosen category."""
        properties = property_groups[group]
        return gr.update(choices=properties, value=properties[0])

    def show_visualization(split, prop, take_log):
        """Run create_visualization and reveal the error box only on failure.

        The error textbox starts hidden; without toggling its visibility here
        the messages returned by create_visualization would never be shown.
        """
        fig, error = create_visualization(split, prop, take_log)
        return fig, gr.update(value=error or "", visible=bool(error))

    group_dropdown.change(
        fn=update_properties,
        inputs=[group_dropdown],
        outputs=[color_col],
    )

    with gr.Row():
        plot_output = gr.Plot(label="Visualization")

    visualize_btn.click(
        fn=show_visualization,
        inputs=[split_input, color_col, log],
        outputs=[plot_output, error_output],
    )

demo.launch()