Skip to content

Commit

Permalink
🎨 Drawing 3D plot keeping the original colours (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
Starlitnightly authored Jun 8, 2024
1 parent 8d49e76 commit e2c3a63
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions scSLAT/viz/multi_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,16 @@ def __init__(self,adatas:List[AnnData],
self.loc_list = []
self.anno_list = []
for adata in adatas:
loc = adata.obsm[spatial_key]
loc = adata.obsm[spatial_key].copy()
if scale_coordinate:
for i in range(2):
loc[:,i] = (loc[:,i]-np.min(loc[:,i]))/(np.max(loc[:,i])-np.min(loc[:,i]))
anno = adata.obs[anno_key]
self.loc_list.append(loc)
self.anno_list.append(anno)

self.adatas = adatas
self.anno_key=anno_key
self.celltypes = set(pd.concat(self.anno_list))
self.subsample_size = subsample_size

Expand Down Expand Up @@ -130,17 +132,31 @@ def draw_3D(self,
ax = fig.add_subplot(111, projection='3d')
ax.set_box_aspect([1, 1, height_scale * len(self.mappings)])
# color by different cell types



color = get_color(len(self.celltypes))
c_map = {}
for i, celltype in enumerate(self.celltypes):
c_map[celltype] = color[i]
for j, mapping in enumerate(self.mappings):
print(f"Mapping {j}th layer ")
# plot cells
for i, (layer, anno) in enumerate(zip(self.loc_list[j:j+2], self.anno_list[j:j+2])):
for i, (layer, anno,ad) in enumerate(zip(self.loc_list[j:j+2], self.anno_list[j:j+2],self.adatas[j:j+2])):
if i==0 and 0<j<len(self.mappings)-1:
continue
for cell_type in self.celltypes:

ad.obs[self.anno_key]=ad.obs[self.anno_key].astype('category')
if '{}_colors'.format(self.anno_key) in ad.uns.keys():
c_map=dict(zip(ad.obs[self.anno_key].cat.categories.tolist(),
ad.uns['{}_colors'.format(self.anno_key)]))
else:
if len(ad.obs[self.anno_key].cat.categories)>28:
c_map=dict(zip(ad.obs[self.anno_key].cat.categories,sc.pl.palettes.default_102))
else:
c_map=dict(zip(ad.obs[self.anno_key].cat.categories,sc.pl.palettes.zeileis_28))

for cell_type in ad.obs[self.anno_key].cat.categories:
slice = layer[anno == cell_type,:]
xs = slice[:,0]
ys = slice[:,1]
Expand All @@ -159,7 +175,8 @@ def draw_3D(self,

if hide_axis:
plt.axis('off')
plt.show()
return ax
#plt.show()


class match_3D_multi():
Expand Down

0 comments on commit e2c3a63

Please sign in to comment.