From e99cfb4d3ef394084747cfb5410a8342d41b271d Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sat, 21 Sep 2024 15:15:17 -0700 Subject: [PATCH] gradio app should only be created after cli is invoked --- alphafold3_pytorch/app.py | 325 +++++++++++++++++++------------------- pyproject.toml | 2 +- 2 files changed, 164 insertions(+), 163 deletions(-) diff --git a/alphafold3_pytorch/app.py b/alphafold3_pytorch/app.py index e8a8f8ad..fdc280ec 100644 --- a/alphafold3_pytorch/app.py +++ b/alphafold3_pytorch/app.py @@ -70,178 +70,179 @@ def delete_cache(request: gr.Request): shutil.rmtree(str(user_dir)) -with gr.Blocks(delete_cache=(600, 3600)) as gradio_app: - entities = gr.State([]) - - with gr.Row(): - gr.Markdown("### AlphaFold3 PyTorch Web UI") - - with gr.Row(): - gr.Column(scale=8) - # upload_json_button = gr.Button("Upload JSON", scale=1, min_width=100) - clear_button = gr.Button("Clear", scale=1, min_width=100) - - with gr.Row(): - with gr.Column(scale=1, min_width=150): - mtype = gr.Dropdown( - value="Protein", - label="Molecule type", - choices=["Protein", "DNA", "RNA", "Ligand", "Ion"], - interactive=True, - ) - with gr.Column(scale=1, min_width=80): - c = gr.Number( - value=1, - label="Copies", - interactive=True, - ) - - with gr.Column(scale=8, min_width=200): - - @gr.render(inputs=mtype) - def render_sequence(mol_type): - if mol_type in ["Protein", "DNA", "RNA"]: - seq = gr.Textbox( - label="Paste sequence or fasta", - placeholder="Input", - interactive=True, - ) - elif mol_type == "Ligand": - seq = gr.Dropdown( - label="Select ligand", - choices=[ - "ADP - Adenosine disphosphate", - "ATP - Adenosine triphosphate", - "AMP - Adenosine monophosphate", - "GTP - Guanosine-5'-triphosphate", - "GDP - Guanosine-5'-diphosphate", - "FAD - Flavin adenine dinucleotide", - "NAD - Nicotinamide-adenine-dinucleotide", - "NAP - Nicotinamide-adenine-dinucleotide phosphate (NADP)", - "NDP - Dihydro-nicotinamide-adenine-dinucleotide-phosphate (NADPH)", - "HEM - Heme", - "HEC - Heme C", - "OLA - Oleic acid", - "MYR - Myristic acid", - "CIT - Citric acid", - "CLA - Chlorophyll A", - "CHL - Chlorophyll B", - "BCL - Bacteriochlorophyll A", - "BCB - Bacteriochlorophyll B", - ], - interactive=True, - ) - elif mol_type == "Ion": - seq = gr.Dropdown( - label="Select ion", - choices=[ - "Mg²⁺", - "Zn²⁺", - "Cl⁻", - "Ca²⁺", - "Na⁺", - "Mn²⁺", - "K⁺", - "Fe³⁺", - "Cu²⁺", - "Co²⁺", - ], - interactive=True, - ) - - add_button.click(add_entity, inputs=[entities, mtype, c, seq], outputs=[entities]) - clear_button.click(lambda: ("Protein", 1, None), None, outputs=[mtype, c, seq]) - - add_button = gr.Button("Add entity", scale=1, min_width=100) - - def add_entity(entities, mtype="Protein", c=1, seq=""): - if seq is None or len(seq) == 0: - gr.Info("Input required") - return entities - - seq_norm = seq.strip(" \t\n\r").upper() - - if mtype in ["Protein", "DNA", "RNA"]: - if mtype == "Protein" and any([x not in "ARDCQEGHILKMNFPSTWYV" for x in seq_norm]): - gr.Info("Invalid protein sequence. Allowed characters: A, R, D, C, Q, E, G, H, I, L, K, M, N, F, P, S, T, W, Y, V") - return entities - - if mtype == "DNA" and any([x not in "ACGT" for x in seq_norm]): - gr.Info("Invalid DNA sequence. Allowed characters: A, C, G, T") - return entities - - if mtype == "RNA" and any([x not in "ACGU" for x in seq_norm]): - gr.Info("Invalid RNA sequence. Allowed characters: A, C, G, U") - return entities - - if len(seq) < 4: - gr.Info("Minimum 4 characters required") - return entities - - elif mtype == "Ligand": - if seq is None or len(seq) == 0: - gr.Info("Select a ligand") - return entities - seq_norm = seq.split(" - ")[0] - elif mtype == "Ion": - if seq is None or len(seq) == 0: - gr.Info("Select an ion") - return entities - seq_norm = "".join([x for x in seq if x.isalpha()]) - - new_entity = {"mol_type": mtype, "num_copies": c, "sequence": seq_norm} - - return entities + [new_entity] - - @gr.render(inputs=entities) - def render_entities(entity_list): - for idx, entity in enumerate(entity_list): - with gr.Row(): - gr.Text( - value=entity["mol_type"], - label="Type", - scale=1, - min_width=90, - interactive=False, +def start_gradio_app(): + with gr.Blocks(delete_cache=(600, 3600)) as gradio_app: + entities = gr.State([]) + + with gr.Row(): + gr.Markdown("### AlphaFold3 PyTorch Web UI") + + with gr.Row(): + gr.Column(scale=8) + # upload_json_button = gr.Button("Upload JSON", scale=1, min_width=100) + clear_button = gr.Button("Clear", scale=1, min_width=100) + + with gr.Row(): + with gr.Column(scale=1, min_width=150): + mtype = gr.Dropdown( + value="Protein", + label="Molecule type", + choices=["Protein", "DNA", "RNA", "Ligand", "Ion"], + interactive=True, ) - gr.Text( - value=entity["num_copies"], + with gr.Column(scale=1, min_width=80): + c = gr.Number( + value=1, label="Copies", - scale=1, - min_width=80, - interactive=False, + interactive=True, ) - sequence = entity["sequence"] - if entity["mol_type"] not in ["Ligand", "Ion"]: - # Split every 10 characters, and add a \t after each split - sequence = "\t".join([sequence[i : i + 10] for i in range(0, len(sequence), 10)]) - - gr.Text( - value=sequence, - label="Sequence", - placeholder="Input", - scale=7, - min_width=200, - interactive=False, - ) + with gr.Column(scale=8, min_width=200): + + @gr.render(inputs=mtype) + def render_sequence(mol_type): + if mol_type in ["Protein", "DNA", "RNA"]: + seq = gr.Textbox( + label="Paste sequence or fasta", + placeholder="Input", + interactive=True, + ) + elif mol_type == "Ligand": + seq = gr.Dropdown( + label="Select ligand", + choices=[ + "ADP - Adenosine disphosphate", + "ATP - Adenosine triphosphate", + "AMP - Adenosine monophosphate", + "GTP - Guanosine-5'-triphosphate", + "GDP - Guanosine-5'-diphosphate", + "FAD - Flavin adenine dinucleotide", + "NAD - Nicotinamide-adenine-dinucleotide", + "NAP - Nicotinamide-adenine-dinucleotide phosphate (NADP)", + "NDP - Dihydro-nicotinamide-adenine-dinucleotide-phosphate (NADPH)", + "HEM - Heme", + "HEC - Heme C", + "OLA - Oleic acid", + "MYR - Myristic acid", + "CIT - Citric acid", + "CLA - Chlorophyll A", + "CHL - Chlorophyll B", + "BCL - Bacteriochlorophyll A", + "BCB - Bacteriochlorophyll B", + ], + interactive=True, + ) + elif mol_type == "Ion": + seq = gr.Dropdown( + label="Select ion", + choices=[ + "Mg²⁺", + "Zn²⁺", + "Cl⁻", + "Ca²⁺", + "Na⁺", + "Mn²⁺", + "K⁺", + "Fe³⁺", + "Cu²⁺", + "Co²⁺", + ], + interactive=True, + ) + + add_button.click(add_entity, inputs=[entities, mtype, c, seq], outputs=[entities]) + clear_button.click(lambda: ("Protein", 1, None), None, outputs=[mtype, c, seq]) + + add_button = gr.Button("Add entity", scale=1, min_width=100) + + def add_entity(entities, mtype="Protein", c=1, seq=""): + if seq is None or len(seq) == 0: + gr.Info("Input required") + return entities + + seq_norm = seq.strip(" \t\n\r").upper() + + if mtype in ["Protein", "DNA", "RNA"]: + if mtype == "Protein" and any([x not in "ARDCQEGHILKMNFPSTWYV" for x in seq_norm]): + gr.Info("Invalid protein sequence. Allowed characters: A, R, D, C, Q, E, G, H, I, L, K, M, N, F, P, S, T, W, Y, V") + return entities + + if mtype == "DNA" and any([x not in "ACGT" for x in seq_norm]): + gr.Info("Invalid DNA sequence. Allowed characters: A, C, G, T") + return entities + + if mtype == "RNA" and any([x not in "ACGU" for x in seq_norm]): + gr.Info("Invalid RNA sequence. Allowed characters: A, C, G, U") + return entities + + if len(seq) < 4: + gr.Info("Minimum 4 characters required") + return entities + + elif mtype == "Ligand": + if seq is None or len(seq) == 0: + gr.Info("Select a ligand") + return entities + seq_norm = seq.split(" - ")[0] + elif mtype == "Ion": + if seq is None or len(seq) == 0: + gr.Info("Select an ion") + return entities + seq_norm = "".join([x for x in seq if x.isalpha()]) + + new_entity = {"mol_type": mtype, "num_copies": c, "sequence": seq_norm} + + return entities + [new_entity] + + @gr.render(inputs=entities) + def render_entities(entity_list): + for idx, entity in enumerate(entity_list): + with gr.Row(): + gr.Text( + value=entity["mol_type"], + label="Type", + scale=1, + min_width=90, + interactive=False, + ) + gr.Text( + value=entity["num_copies"], + label="Copies", + scale=1, + min_width=80, + interactive=False, + ) - del_button = gr.Button("🗑️", scale=0, min_width=50) + sequence = entity["sequence"] + if entity["mol_type"] not in ["Ligand", "Ion"]: + # Split every 10 characters, and add a \t after each split + sequence = "\t".join([sequence[i : i + 10] for i in range(0, len(sequence), 10)]) + + gr.Text( + value=sequence, + label="Sequence", + placeholder="Input", + scale=7, + min_width=200, + interactive=False, + ) - def delete(entity_id=idx): - entity_list.pop(entity_id) - return entity_list + del_button = gr.Button("🗑️", scale=0, min_width=50) - del_button.click(delete, None, outputs=[entities]) + def delete(entity_id=idx): + entity_list.pop(entity_id) + return entity_list - pred_button = gr.Button("Predict", scale=1, min_width=100) - output_mol = Molecule3D(label="Output structure", config={"backgroundColor": "black"}) + del_button.click(delete, None, outputs=[entities]) - pred_button.click(fold, inputs=entities, outputs=output_mol) - clear_button.click(lambda: ([], None), None, outputs=[entities, output_mol]) + pred_button = gr.Button("Predict", scale=1, min_width=100) + output_mol = Molecule3D(label="Output structure", config={"backgroundColor": "black"}) - gradio_app.unload(delete_cache) + pred_button.click(fold, inputs=entities, outputs=output_mol) + clear_button.click(lambda: ([], None), None, outputs=[entities, output_mol]) + gradio_app.unload(delete_cache) + gradio_app.launch() # cli @click.command() @@ -271,4 +272,4 @@ def app(checkpoint: str, cache_dir: str, precision: str): # dtype = torch.float32 # model.to(device, dtype=dtype) - gradio_app.launch() + start_gradio_app() diff --git a/pyproject.toml b/pyproject.toml index fc00300b..aa065db1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "alphafold3-pytorch" -version = "0.5.31" +version = "0.5.32" description = "Alphafold 3 - Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" },