Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions docarray/array/mixins/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import warnings
from collections import Counter
from math import sqrt, ceil, floor
from typing import Optional, Tuple
from typing import Optional, Tuple, List

import numpy as np

Expand Down Expand Up @@ -126,8 +126,9 @@ def plot_embeddings(
host: str = '127.0.0.1',
port: Optional[int] = None,
image_source: str = 'tensor',
exclude_fields_metas: Optional[List[str]] = None,
) -> str:
"""Interactively visualize :attr:`.embeddings` using the Embedding Projector.
"""Interactively visualize :attr:`.embeddings` using the Embedding Projector and store the visualization information.

:param title: the title of this visualization. If you want to compare multiple embeddings at the same time,
make sure to give different names each time and set ``path`` to the same value.
Expand All @@ -140,6 +141,7 @@ def plot_embeddings(
:param channel_axis: only used when `image_sprites=True`. the axis id of the color channel, ``-1`` indicates the color channel info at the last axis
:param start_server: if set, start a HTTP server and open the frontend directly. Otherwise, you need to rely on ``return`` path and serve by yourself.
:param image_source: specify where the image comes from, can be ``uri`` or ``tensor``. An empty tensor will fall back to uri.
:param exclude_fields_metas: specify the fields that you want to exclude from the metadata TSV file
:return: the path to the embeddings visualization info.
"""
from docarray.helper import random_port, __resources_path__
Expand Down Expand Up @@ -178,14 +180,20 @@ def plot_embeddings(

self.save_embeddings_csv(os.path.join(path, emb_fn), delimiter='\t')

_exclude_fields = ('embedding', 'tensor', 'scores')
_exclude_fields_metas = ['embedding', 'tensor', 'scores'] + (
exclude_fields_metas or []
)

with_header = True
if len(set(self[0].non_empty_fields).difference(set(_exclude_fields))) <= 1:
if (
len(set(self[0].non_empty_fields).difference(set(_exclude_fields_metas)))
<= 1
):
with_header = False

self.save_csv(
os.path.join(path, meta_fn),
exclude_fields=_exclude_fields,
exclude_fields=_exclude_fields_metas,
dialect='excel-tab',
with_header=with_header,
)
Expand Down
7 changes: 6 additions & 1 deletion docs/fundamentals/documentarray/visualization.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,9 @@ da.plot_embeddings(image_sprites=True)

```{figure} images/embedding-projector.gif
:align: center
```
```

````{admonition} Note
:class: note
If you have a lot of metadata, plotting may be slow since that metadata is stored in a corresponding TSV file. You can speed up plotting with the `exclude_fields_metas` parameter, preventing fields (like `chunks` or `matches`) from being written to the TSV.
````