Skip to content

Commit

Permalink
🐛 Fixing bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
amilworks committed Sep 13, 2023
1 parent b09d11c commit 9746c3d
Show file tree
Hide file tree
Showing 7 changed files with 231 additions and 56 deletions.
2 changes: 1 addition & 1 deletion cells/streamlit/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ RUN pip install --no-cache-dir poetry && \
poetry install --no-dev

RUN pip install git+https://github.com/geomstats/geomstats.git && \
pip install pacmap
pip install pacmap open3d

# Copy the entire project directory into the container
COPY cellgeometry .
Expand Down
2 changes: 1 addition & 1 deletion cells/streamlit/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ build:

# Runs the docker container
run:
@docker run -itp 8501:8501 -d --rm --name streamlit -v ./cellgeometry:/app amilworks/streamlit-geomstats
@docker run -itp 8501:8501 --network pacmap2_default -d --rm --name streamlit -v ./cellgeometry:/app amilworks/streamlit-geomstats

# Calls all the above targets
all: stop build run
Expand Down
3 changes: 3 additions & 0 deletions cells/streamlit/cellgeometry/pages/1-Load_Data.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,9 @@ def handle_uploaded_file(uploaded_file, destination_folder):
st.session_state["selected_dataset"] = new_upload_folder
handle_uploaded_file(uploaded_file, new_upload_folder)
else:
# st.session_state["selected_dataset"] = os.path.join(
# upload_folder, uploaded_file.name
# )
handle_uploaded_file(uploaded_file, upload_folder)

build_and_load_data(st.session_state["selected_dataset"])
Expand Down
4 changes: 2 additions & 2 deletions cells/streamlit/cellgeometry/pages/2-Mean_Shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,8 @@ def find_indices_with_selected_string(string_list, selected_string):
for cell in cell_shapes:
fig.add_trace(
go.Scatter(
x=cell[:, 0],
y=cell[:, 1],
x=cell[:250, 0],
y=cell[:250, 1],
mode="lines",
line=dict(color="lightgrey", width=1),
)
Expand Down
197 changes: 159 additions & 38 deletions cells/streamlit/cellgeometry/pages/3-PACMAP.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from sklearn.decomposition import PCA
import pacmap
import plotly.express as px

import requests
import numpy as np
import open3d as o3d
import plotly.graph_objects as go
from plotly.subplots import make_subplots

Expand All @@ -23,22 +25,38 @@
"""
)

if "cell_shapes" not in st.session_state:
st.warning(
"👈 Have you uploaded a zipped file of ROIs under Load Data? Afterwards, go the the Mean Shape page and run the analysis there."
)
st.stop()
# cells = st.session_state["cells"]
cell_shapes = st.session_state["cell_shapes"]
# st.help(pacmap.PaCMAP)


# # Send POST request
# response = requests.post(url, json=data)
# st.write("Sent request to server")
# # Check if the request was successful
# if response.status_code == 200:
# transformed_data = response.json().get("transformed_data")
# print("Transformed Data:", transformed_data)
# else:
# print(f"Error {response.status_code}: {response.text}")


# if "cell_shapes" not in st.session_state:
# st.warning(
# "👈 Have you uploaded a zipped file of ROIs under Load Data? Afterwards, go the the Mean Shape page and run the analysis there."
# )
# st.stop()
# # cells = st.session_state["cells"]
# cell_shapes = st.session_state["cell_shapes"]


if st.session_state["cell_lines"] is not None:
cell_lines = st.session_state["cell_lines"]
if st.session_state["treatment"] is not None:
treatment = st.session_state["treatment"]
# if st.session_state["cell_lines"] is not None:
# cell_lines = st.session_state["cell_lines"]
# if st.session_state["treatment"] is not None:
# treatment = st.session_state["treatment"]


# cells_flat = gs.reshape(cell_shapes, (len(cell_shapes), -1))


cells_flat = gs.reshape(cell_shapes, (len(cell_shapes), -1))
# st.write("Cells flat", cells_flat.shape)

# R1 = Euclidean(dim=1)
Expand All @@ -48,45 +66,25 @@

# n_components = st.slider("Select the Number of Sampling Points", 0, len(cells_flat), 10)


# st.write(treatment.shape)
# Perform PacMap dimensionality reduction
model = pacmap.PaCMAP()

# st.help(pacmap.PaCMAP)

runPacmap = st.toggle("Run PACMAP Analysis", False)

if runPacmap:

embedding = model.fit_transform(cells_flat)
# st.write(embedding.shape)
# st.write(cell_lines.shape)

# Visualize the embedding using Plotly Express
# Create a scatter plot with coloring based on 'cell_lines' and symbols based on 'treatments'
fig = px.scatter(
x=embedding[:, 0],
y=embedding[:, 1],
color=gs.squeeze(cell_lines), # differentiate by color based on cell_lines
symbol=gs.squeeze(treatment), # differentiate by symbol based on treatments
title="PacMap Embedding",
labels={"x": "Dimension 1", "y": "Dimension 2"},
color_discrete_sequence=px.colors.qualitative.Set1, # use a color palette
)

# Update layout for better clarity, if needed
fig.update_layout(legend_title_text="Cell Lines", legend_itemsizing="constant")

# Display the Plotly figure in Streamlit
st.plotly_chart(fig)

col1, col2, col3 = st.columns(3)

with col1:
st.markdown("##### Number of Components")
n_components = st.number_input(
"Default = 2",
min_value=2,
"Default = 3",
min_value=3,
max_value=None,
key="n_components",
)
st.write("Input dimensions of the embedded space.", n_components)

Expand All @@ -112,6 +110,78 @@
)

st.write("Select distance metric.")
model = pacmap.PaCMAP(n_components=st.session_state["n_components"])
embedding = model.fit_transform([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
st.write(embedding.shape)
# st.write(cell_lines.shape)

# Visualize the embedding using Plotly Express
# Create a scatter plot with coloring based on 'cell_lines' and symbols based on 'treatments'
fig = px.scatter_3d(
x=embedding[:, 0],
y=embedding[:, 1],
z=embedding[:, 2],
color=gs.squeeze(cell_lines), # differentiate by color based on cell_lines
symbol=gs.squeeze(treatment), # differentiate by symbol based on treatments
title="PacMap Embedding",
labels={"x": "Dimension 1", "y": "Dimension 2", "z": "Dimension 3"},
color_discrete_sequence=px.colors.qualitative.Light24, # use a color palette
)

# Update layout for better clarity, if needed
fig.update_layout(legend_title_text="Cell Lines", legend_itemsizing="constant")

# Display the Plotly figure in Streamlit
st.plotly_chart(fig)

st.stop()
# Create an Open3D PointCloud object
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(embedding[:, :3])
pcd.estimate_normals()
# Compute a mesh from the point cloud using the Ball-Pivoting Algorithm
radii = [0.005, 0.01, 0.02, 0.04]
mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(
pcd, o3d.utility.DoubleVector(radii)
)
# Compute normals for the point cloud
# pcd.estimate_normals()

# Access the vertices and faces
vertices = np.asarray(mesh.vertices)
# triangles = np.asarray(mesh.triangles)
from scipy.spatial import Delaunay

tri = Delaunay(vertices)
triangles = tri.simplices
st.write(
"""
### 3D Mesh Reconstruction from Reduced Dimension Embeddings
#### Description
The plot visualizes a 3D mesh that's been reconstructed from reduced dimension embeddings. Initially, a point cloud is formed from these embeddings. From this point cloud, a mesh is generated using the Ball-Pivoting Algorithm, which identifies possible triangles by virtually rolling a ball over the point cloud. The final visualization provides a geometric representation of the relationships and structures within the embeddings, showcasing patterns and clusters that might not be evident in a simple scatter plot.
"""
)
# fig = go.Figure(data=[go.Scatter3d(x=vertices[:, 0], y=vertices[:, 1], z=vertices[:, 2], mode='markers')])
# Create a Mesh3D plot
fig = go.Figure(
data=[
go.Mesh3d(
x=vertices[:, 0],
y=vertices[:, 1],
z=vertices[:, 2],
i=triangles[:, 0], # i, j, k define the vertices of the triangles
j=triangles[:, 1],
k=triangles[:, 2],
opacity=0.9,
)
]
)
# Set plot layout
# fig.update_layout(scene=dict(aspectmode='data'))

# Show the plot
st.plotly_chart(fig)

st.markdown(
""" ### Background on PACMAP
Expand All @@ -123,6 +193,57 @@
"""
)
# pcas = {}
st.header("Random Forest Classifier")

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import pandas as pd

selected_label = st.radio(
"Select the labels to train the classifier", ("cell_lines", "treatment")
)
# Convert data to DataFrame
df = pd.DataFrame(embedding)
X = df.values # Features (PCA components)
y = st.session_state[selected_label] # Labels

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.3, random_state=42
)

# Initialize and train the Random Forest classifier
clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)

# Predict on the test set
y_pred = clf.predict(X_test)

# Calculate accuracy
accuracy = accuracy_score(y_test, y_pred)
st.write(f"Accuracy: {accuracy * 100:.2f}%")

# Extract feature importances
feature_importances = clf.feature_importances_
df_importance = pd.DataFrame(
{
"Features": list(range(n_components)),
"Importance": feature_importances,
# 'Treatment': y # Assuming all treatments in the sample data are the same; adjust as needed
}
)

# Create the plot
fig = px.bar(
df_importance,
x="Features",
y="Importance",
title="Feature Importances of PCA Components by Treatment",
labels={"Importance": "Importance Value", "Features": "PCA Components"},
)

st.plotly_chart(fig)

# st.write(mean.estimate_)
# logs = CURVES_SPACE_SRV.metric.log(cells_flat, base_point=mean.estimate_)
Expand Down
42 changes: 42 additions & 0 deletions cells/streamlit/cellgeometry/pages/4-3D_Cell_Segmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import requests
import json
import streamlit as st
import numpy as np

# Define the API endpoint
url = "http://fastapi-app:8000/pacmap"

data = np.random.rand(590, 20)
data = data.tolist()
st.write(f"Data: {data}")
# Define the data you wish to send
data = {
"data": data,
}

# Send a POST request to the API endpoint
response = requests.post(url, json=data)

# Check the response
if response.status_code == 200:
reduced_data = response.json().get("reduced_data")
st.write(f"Reduced Data: {reduced_data}")
else:
st.write(f"Error {response.status_code}: {response.text}")


############################################################################################################
# Yes! you have the discrete_surfaces.py that is the generalization of discrete_curves.py
# https://github.com/geomstats/geomstats/blob/master/geomstats/geometry/discrete_surfaces.py
# 1:12
# You can create the manifold of surfaces, and it also has an elastic metric.
# Basically, substituting:
# manifold = DiscreteCurves(…)
# by
# manifold = DiscreteSurfaces(…)
# Could be the only modification you need to make to your current code.
# Beware that the surface code is currently quite slow

# Here is a simple example of discrete surface (a cube) that you can load an play with:
# https://github.com/geomstats/geomstats/blob/2eeee177044e38080cc1004ae4a0bf8dd9ceb601/geomstats/datasets/utils.py#L452
############################################################################################################
37 changes: 23 additions & 14 deletions cells/streamlit/cellgeometry/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,20 +212,29 @@ def parse_coordinates(file_path):
line = line.strip()

if line:
try:
x, y = map(int, line.split())
cell_data.append([x, y])
except ValueError:
print(f"Skipping invalid line: {line}")
else:
if cell_data:
coordinates[cell_id] = np.array(cell_data)
cell_id += 1
cell_data = []

# Handle the last cell if it doesn't have a line break after it
if cell_data:
coordinates[cell_id] = np.array(cell_data)
# Check if the line contains a comma; if yes, split by comma, else split by space
if "," in line:
try:
x, y = map(float, line.split(","))
cell_data.append([x, y])
except ValueError:
print(f"Skipping invalid line: {line}")
else:
try:
x, y = map(float, line.split())
cell_data.append([x, y])
except ValueError:
print(f"Skipping invalid line: {line}")
elif (
cell_data
): # if line is empty and cell_data has data, add to dictionary
coordinates[cell_id] = np.array(cell_data)
cell_id += 1
cell_data = []

# Handle the last cell if it doesn't have a line break after it
if cell_data:
coordinates[cell_id] = np.array(cell_data)

return coordinates

Expand Down

0 comments on commit 9746c3d

Please sign in to comment.