diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 63d6facd16a..eda1318e7b1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -171,7 +171,7 @@ jobs: uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - - name: Prepare enviroment + - name: Prepare environment run: | python -m pip install --upgrade pip python -m pip install wheel diff --git a/docarray/document/mixins/mesh.py b/docarray/document/mixins/mesh.py index 3b26489f030..8f5de66918b 100644 --- a/docarray/document/mixins/mesh.py +++ b/docarray/document/mixins/mesh.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import TYPE_CHECKING, Union import numpy as np @@ -7,7 +8,7 @@ import trimesh -class Mesh: +class MeshEnum(Enum): FILE_EXTENSIONS = [ 'glb', 'obj', @@ -17,6 +18,10 @@ class Mesh: FACES = 'faces' +class PointCloudEnum(Enum): + COLORS = 'point_cloud_colors' + + class MeshDataMixin: """Provide helper functions for :class:`Document` to support 3D mesh data and point cloud.""" @@ -80,8 +85,8 @@ def load_uri_to_vertices_and_faces(self: 'T') -> 'T': faces = mesh.faces.view(np.ndarray) self.chunks = [ - Document(name=Mesh.VERTICES, tensor=vertices), - Document(name=Mesh.FACES, tensor=faces), + Document(name=MeshEnum.VERTICES, tensor=vertices), + Document(name=MeshEnum.FACES, tensor=faces), ] return self @@ -96,9 +101,9 @@ def load_vertices_and_faces_to_point_cloud(self: 'T', samples: int) -> 'T': faces = None for chunk in self.chunks: - if chunk.tags['name'] == Mesh.VERTICES: + if chunk.tags['name'] == MeshEnum.VERTICES: vertices = chunk.tensor - if chunk.tags['name'] == Mesh.FACES: + if chunk.tags['name'] == MeshEnum.FACES: faces = chunk.tensor if vertices is not None and faces is not None: diff --git a/docarray/document/mixins/plot.py b/docarray/document/mixins/plot.py index af1779c2362..00f6570c6e0 100644 --- a/docarray/document/mixins/plot.py +++ b/docarray/document/mixins/plot.py @@ -3,7 +3,7 @@ import numpy as np -from docarray.document.mixins.mesh import Mesh +from docarray.document.mixins.mesh import MeshEnum, PointCloudEnum class PlotMixin: @@ -133,7 +133,7 @@ def _is_3d_vertices_and_faces(self): """ if self.chunks is not None: name_tags = [c.tags['name'] for c in self.chunks] - if Mesh.VERTICES in name_tags and Mesh.FACES in name_tags: + if MeshEnum.VERTICES in name_tags and MeshEnum.FACES in name_tags: return True else: return False @@ -173,9 +173,11 @@ def display_vertices_and_faces(self): import trimesh vertices = [ - c.tensor for c in self.chunks if c.tags['name'] == Mesh.VERTICES + c.tensor for c in self.chunks if c.tags['name'] == MeshEnum.VERTICES ][-1] - faces = [c.tensor for c in self.chunks if c.tags['name'] == Mesh.FACES][-1] + faces = [c.tensor for c in self.chunks if c.tags['name'] == MeshEnum.FACES][ + -1 + ] mesh = trimesh.Trimesh(vertices=vertices, faces=faces) display(mesh.show()) @@ -185,15 +187,24 @@ def display_point_cloud_tensor(self) -> None: from IPython.display import display from hubble.utils.notebook import is_notebook + colors = np.tile(np.array([0, 0, 0]), (len(self.tensor), 1)) + for chunk in self.chunks: + if ( + 'name' in chunk.tags.keys() + and chunk.tags['name'] == PointCloudEnum.COLORS + and chunk.tensor.shape[-1] in [3, 4] + ): + colors = chunk.tensor + + pc = trimesh.points.PointCloud( + vertices=self.tensor, + colors=colors, + ) + if is_notebook(): - pc = trimesh.points.PointCloud( - vertices=self.tensor, - colors=np.tile(np.array([0, 0, 0, 1]), (len(self.tensor), 1)), - ) s = trimesh.Scene(geometry=pc) display(s.show()) else: - pc = trimesh.points.PointCloud(vertices=self.tensor) display(pc.show()) def display_rgbd_tensor(self) -> None: diff --git a/docs/datatypes/mesh/index.md b/docs/datatypes/mesh/index.md index 0e3c68e75c9..d6fdbb413db 100644 --- a/docs/datatypes/mesh/index.md +++ b/docs/datatypes/mesh/index.md @@ -2576,6 +2576,15 @@ init(); " width="100%" height="500px" style="border:none;">