diff --git a/docarray/typing/url/any_url.py b/docarray/typing/url/any_url.py
index 6d930aa53f3..4be64b9e057 100644
--- a/docarray/typing/url/any_url.py
+++ b/docarray/typing/url/any_url.py
@@ -1,8 +1,9 @@
+import mimetypes
 import os
 import urllib
 import urllib.parse
 import urllib.request
-from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Type, TypeVar, Union
 
 import numpy as np
 from pydantic import AnyUrl as BaseAnyUrl
@@ -27,6 +28,17 @@ class AnyUrl(BaseAnyUrl, AbstractType):
         False  # turn off host requirement to allow passing of local paths as URL
     )
 
+    @classmethod
+    def mime_type(cls) -> str:
+        """Returns the mime type this class deals with."""
+        raise NotImplementedError
+
+    @classmethod
+    def extra_extensions(cls) -> List[str]:
+        """Returns a list of allowed file extensions for this class which
+        fall outside the scope of the mimetypes library."""
+        raise NotImplementedError
+
     def _to_node_protobuf(self) -> 'NodeProto':
         """Convert Document into a NodeProto protobuf message. This function should
         be called when the Document is nested into another Document that need to
@@ -38,6 +50,37 @@ def _to_node_protobuf(self) -> 'NodeProto':
 
         return NodeProto(text=str(self), type=self._proto_type_name)
 
+    @classmethod
+    def is_extension_allowed(cls, value: Any) -> bool:
+        """
+        Check if the file extension of the url is allowed for that class.
+        Either the guessed mime type or one of the extra extensions has to match.
+
+        :param value: url to the file
+        :return: True if the extension is allowed, False otherwise
+        """
+        if cls is AnyUrl:  # no check for AnyUrl class
+            return True
+        path = str(value).split('?')[0]  # ignore the query string of the url
+        mimetype, _ = mimetypes.guess_type(path)
+        if mimetype and mimetype.startswith(cls.mime_type()):
+            return True
+        # compare the real suffix, accepting both '.md' and 'md' style entries
+        extension = os.path.splitext(path)[1].lstrip('.').lower()
+        return any(
+            extension == ext.lstrip('.').lower() for ext in cls.extra_extensions()
+        )
+
+    @classmethod
+    def is_special_case(cls, value: Any) -> bool:
+        """
+        Check if the url is a special case.
+
+        :param value: url to the file
+        :return: True if the url is a special case, False otherwise
+        """
+        return False
+
     @classmethod
     def validate(
         cls: Type[T],
@@ -61,10 +104,14 @@ def validate(
 
         url = super().validate(abs_path, field, config)  # basic url validation
 
-        if input_is_relative_path:
-            return cls(str(value), scheme=None)
-        else:
-            return cls(str(url), scheme=None)
+        # perform check only for subclasses of AnyUrl
+        if not cls.is_extension_allowed(value):
+            if not cls.is_special_case(value):  # check for special cases
+                raise ValueError(
+                    f'file {value} is not a valid file format for class {cls}'
+                )
+
+        return cls(str(value if input_is_relative_path else url), scheme=None)
 
     @classmethod
     def validate_parts(cls, parts: 'Parts', validate_port: bool = True) -> 'Parts':
diff --git a/docarray/typing/url/audio_url.py b/docarray/typing/url/audio_url.py
index a84a68754ee..8a3b83c6648 100644
--- a/docarray/typing/url/audio_url.py
+++ b/docarray/typing/url/audio_url.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import Optional, Tuple, TypeVar
+from typing import List, Optional, Tuple, TypeVar
 
 from docarray.typing import AudioNdArray
 from docarray.typing.bytes.audio_bytes import AudioBytes
@@ -17,6 +17,15 @@ class AudioUrl(AnyUrl):
     Can be remote (web) URL, or a local file path.
     """
 
+    @classmethod
+    def mime_type(cls) -> str:
+        return 'audio'
+
+    @classmethod
+    def extra_extensions(cls) -> List[str]:
+        # add only those extensions that can not be identified by the mimetypes library but are valid
+        return []
+
     def load(self: T) -> Tuple[AudioNdArray, int]:
         """
         Load the data from the url into an [`AudioNdArray`][docarray.typing.AudioNdArray]
diff --git a/docarray/typing/url/image_url.py b/docarray/typing/url/image_url.py
index 43758cf7436..acde2e45e57 100644
--- a/docarray/typing/url/image_url.py
+++ b/docarray/typing/url/image_url.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import TYPE_CHECKING, Optional, Tuple, TypeVar
+from typing import TYPE_CHECKING, List, Optional, Tuple, TypeVar
 
 from docarray.typing import ImageBytes
 from docarray.typing.proto_register import _register_proto
@@ -20,6 +20,15 @@ class ImageUrl(AnyUrl):
     Can be remote (web) URL, or a local file path.
     """
 
+    @classmethod
+    def mime_type(cls) -> str:
+        return 'image'
+
+    @classmethod
+    def extra_extensions(cls) -> List[str]:
+        # add only those extensions that can not be identified by the mimetypes library but are valid
+        return []
+
     def load_pil(self, timeout: Optional[float] = None) -> 'PILImage.Image':
         """
         Load the image from the bytes into a `PIL.Image.Image` instance
diff --git a/docarray/typing/url/text_url.py b/docarray/typing/url/text_url.py
index 86da87790e6..817c6aa99d9 100644
--- a/docarray/typing/url/text_url.py
+++ b/docarray/typing/url/text_url.py
@@ -1,4 +1,4 @@
-from typing import Optional, TypeVar
+from typing import List, Optional, TypeVar
 
 from docarray.typing.proto_register import _register_proto
 from docarray.typing.url.any_url import AnyUrl
@@ -13,6 +13,32 @@ class TextUrl(AnyUrl):
     Can be remote (web) URL, or a local file path.
     """
 
+    @classmethod
+    def mime_type(cls) -> str:
+        return 'text'
+
+    @classmethod
+    def extra_extensions(cls) -> List[str]:
+        """
+        List of extra file extensions for this type of URL (outside the scope of mimetype library).
+        """
+        return ['.md']
+
+    @classmethod
+    def is_special_case(cls, value: 'AnyUrl') -> bool:
+        """
+        Check if the url is a special case that needs to be handled differently.
+
+        :param value: url to the file
+        :return: True if the url is a special case, False otherwise
+        """
+        if value.startswith(('http://', 'https://')):
+            if '.' not in value.split('?')[0].split('/')[-1]:
+                # This handles the case where the value is a URL without a file extension
+                # for e.g. https://de.wikipedia.org/wiki/Brixen
+                return True
+        return False
+
     def load(self, charset: str = 'utf-8', timeout: Optional[float] = None) -> str:
         """
         Load the text file into a string.
diff --git a/docarray/typing/url/url_3d/mesh_url.py b/docarray/typing/url/url_3d/mesh_url.py
index 70f32eb5581..5b558e38a94 100644
--- a/docarray/typing/url/url_3d/mesh_url.py
+++ b/docarray/typing/url/url_3d/mesh_url.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeVar
 
 import numpy as np
 from pydantic import parse_obj_as
@@ -20,6 +20,19 @@ class Mesh3DUrl(Url3D):
     Can be remote (web) URL, or a local file path.
     """
 
+    @classmethod
+    def extra_extensions(cls) -> List[str]:
+        # return list of allowed extensions to be used for mesh if mimetypes fail to detect
+        # generated with the help of chatGPT and definitely this list is not exhaustive
+        # bit hacky because of black formatting, making it a long vertical list
+        list_a = ['3ds', '3mf', 'ac', 'ac3d', 'amf', 'assimp', 'bvh', 'cob', 'collada']
+        list_b = ['ctm', 'dxf', 'e57', 'fbx', 'gltf', 'glb', 'ifc', 'lwo', 'lws', 'lxo']
+        list_c = ['md2', 'md3', 'md5', 'mdc', 'm3d', 'mdl', 'ms3d', 'nff', 'obj', 'off']
+        list_d = ['pcd', 'pod', 'pmd', 'pmx', 'ply', 'q3o', 'q3s', 'raw', 'sib', 'smd']
+        list_e = ['stl', 'ter', 'terragen', 'vtk', 'vrml', 'x3d', 'xaml', 'xgl', 'xml']
+        list_f = ['xyz', 'zgl', 'vta']
+        return list_a + list_b + list_c + list_d + list_e + list_f
+
     def load(
         self: T,
         skip_materials: bool = True,
diff --git a/docarray/typing/url/url_3d/point_cloud_url.py b/docarray/typing/url/url_3d/point_cloud_url.py
index efe6ce6ae0e..a726aee9b5f 100644
--- a/docarray/typing/url/url_3d/point_cloud_url.py
+++ b/docarray/typing/url/url_3d/point_cloud_url.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeVar
 
 import numpy as np
 from pydantic import parse_obj_as
@@ -21,6 +21,17 @@ class PointCloud3DUrl(Url3D):
     Can be remote (web) URL, or a local file path.
     """
 
+    @classmethod
+    def extra_extensions(cls) -> List[str]:
+        # return list of file format for point cloud if mimetypes fail to detect
+        # generated with the help of chatGPT and definitely this list is not exhaustive
+        # bit hacky because of black formatting, making it a long vertical list
+        list_a = ['ascii', 'bin', 'b3dm', 'bpf', 'dp', 'dxf', 'e57', 'fls']
+        list_b = ['glb', 'ply', 'gpf', 'las', 'obj', 'osgb', 'pcap', 'pcd', 'pdal']
+        list_c = ['pfm', 'ply2', 'pod', 'pods', 'pnts', 'ptg', 'ptx', 'pts']
+        list_d = ['rcp', 'xyz', 'zfs']
+        return list_a + list_b + list_c + list_d
+
     def load(
         self: T,
         samples: int,
diff --git a/docarray/typing/url/url_3d/url_3d.py b/docarray/typing/url/url_3d/url_3d.py
index c55c0f954e7..3514a88bca1 100644
--- a/docarray/typing/url/url_3d/url_3d.py
+++ b/docarray/typing/url/url_3d/url_3d.py
@@ -18,6 +18,10 @@ class Url3D(AnyUrl, ABC):
     Can be remote (web) URL, or a local file path.
     """
 
+    @classmethod
+    def mime_type(cls) -> str:
+        return 'application'
+
     def _load_trimesh_instance(
         self: T,
         force: Optional[str] = None,
diff --git a/docarray/typing/url/video_url.py b/docarray/typing/url/video_url.py
index 5bd7b1be0b9..a86ee630d3a 100644
--- a/docarray/typing/url/video_url.py
+++ b/docarray/typing/url/video_url.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import Optional, TypeVar
+from typing import List, Optional, TypeVar
 
 from docarray.typing.bytes.video_bytes import VideoBytes, VideoLoadResult
 from docarray.typing.proto_register import _register_proto
@@ -16,6 +16,15 @@ class VideoUrl(AnyUrl):
     Can be remote (web) URL, or a local file path.
     """
 
+    @classmethod
+    def mime_type(cls) -> str:
+        return 'video'
+
+    @classmethod
+    def extra_extensions(cls) -> List[str]:
+        # add only those extensions that can not be identified by the mimetypes library but are valid
+        return []
+
     def load(self: T, **kwargs) -> VideoLoadResult:
         """
         Load the data from the url into a `NamedTuple` of
diff --git a/tests/index/weaviate/test_index_get_del_weaviate.py b/tests/index/weaviate/test_index_get_del_weaviate.py
index 8c1bd15636e..569a7f9c112 100644
--- a/tests/index/weaviate/test_index_get_del_weaviate.py
+++ b/tests/index/weaviate/test_index_get_del_weaviate.py
@@ -403,7 +403,7 @@ class MyMultiModalDoc(BaseDoc):
 
 
 def test_index_document_with_bytes(weaviate_client):
-    doc = ImageDoc(id="1", url="www.foo.com", bytes_=b"foo")
+    doc = ImageDoc(id="1", url="www.foo.com/test.png", bytes_=b"foo")
 
     index = WeaviateDocumentIndex[ImageDoc]()
     index.index([doc])
diff --git a/tests/integrations/predefined_document/test_audio.py b/tests/integrations/predefined_document/test_audio.py
index 2ba207245f7..a69852d9ba1 100644
--- a/tests/integrations/predefined_document/test_audio.py
+++ b/tests/integrations/predefined_document/test_audio.py
@@ -29,7 +29,6 @@
     str(TOYDATA_DIR / 'hello.ogg'),
     str(TOYDATA_DIR / 'hello.wma'),
     str(TOYDATA_DIR / 'hello.aac'),
-    str(TOYDATA_DIR / 'hello'),
 ]
 
 LOCAL_AUDIO_FILES_AND_FORMAT = [
@@ -40,7 +39,6 @@
     (str(TOYDATA_DIR / 'hello.ogg'), 'ogg'),
     (str(TOYDATA_DIR / 'hello.wma'), 'asf'),
     (str(TOYDATA_DIR / 'hello.aac'), 'adts'),
-    (str(TOYDATA_DIR / 'hello'), 'wav'),
 ]
 
 NON_AUDIO_FILES = [
diff --git a/tests/units/typing/url/test_audio_url.py b/tests/units/typing/url/test_audio_url.py
index 2e6b46bcabf..70771450ac3 100644
--- a/tests/units/typing/url/test_audio_url.py
+++ b/tests/units/typing/url/test_audio_url.py
@@ -1,3 +1,4 @@
+import os
 from typing import Optional
 
 import numpy as np
@@ -123,3 +124,25 @@ def test_load_bytes():
     assert isinstance(audio_bytes, bytes)
     assert isinstance(audio_bytes, AudioBytes)
     assert len(audio_bytes) > 0
+
+
+@pytest.mark.parametrize(
+    'file_type, file_source',
+    [
+        ('audio', AUDIO_FILES[0]),
+        ('audio', AUDIO_FILES[1]),
+        ('audio', REMOTE_AUDIO_FILE),
+        ('image', os.path.join(TOYDATA_DIR, 'test.png')),
+        ('video', os.path.join(TOYDATA_DIR, 'mov_bbb.mp4')),
+        ('text', os.path.join(TOYDATA_DIR, 'test', 'test.html')),
+        ('text', os.path.join(TOYDATA_DIR, 'test', 'test.md')),
+        ('text', os.path.join(TOYDATA_DIR, 'penal_colony.txt')),
+        ('application', os.path.join(TOYDATA_DIR, 'test.glb')),
+    ],
+)
+def test_file_validation(file_type, file_source):
+    if file_type != AudioUrl.mime_type():
+        with pytest.raises(ValueError):
+            parse_obj_as(AudioUrl, file_source)
+    else:
+        parse_obj_as(AudioUrl, file_source)
diff --git a/tests/units/typing/url/test_image_url.py b/tests/units/typing/url/test_image_url.py
index 4054c997c80..86ea04ab11f 100644
--- a/tests/units/typing/url/test_image_url.py
+++ b/tests/units/typing/url/test_image_url.py
@@ -9,6 +9,7 @@
 
 from docarray.base_doc.io.json import orjson_dumps
 from docarray.typing import ImageUrl
+from tests import TOYDATA_DIR
 
 CUR_DIR = os.path.dirname(os.path.abspath(__file__))
 PATH_TO_IMAGE_DATA = os.path.join(CUR_DIR, '..', '..', '..', 'toydata', 'image-data')
@@ -174,3 +175,27 @@ def test_validation(path_to_img):
     url = parse_obj_as(ImageUrl, path_to_img)
     assert isinstance(url, ImageUrl)
     assert isinstance(url, str)
+
+
+@pytest.mark.parametrize(
+    'file_type, file_source',
+    [
+        ('image', IMAGE_PATHS['png']),
+        ('image', IMAGE_PATHS['jpg']),
+        ('image', IMAGE_PATHS['jpeg']),
+        ('image', REMOTE_JPG),
+        ('audio', os.path.join(TOYDATA_DIR, 'hello.mp3')),
+        ('audio', os.path.join(TOYDATA_DIR, 'hello.wav')),
+        ('video', os.path.join(TOYDATA_DIR, 'mov_bbb.mp4')),
+        ('text', os.path.join(TOYDATA_DIR, 'test', 'test.html')),
+        ('text', os.path.join(TOYDATA_DIR, 'test', 'test.md')),
+        ('text', os.path.join(TOYDATA_DIR, 'penal_colony.txt')),
+        ('application', os.path.join(TOYDATA_DIR, 'test.glb')),
+    ],
+)
+def test_file_validation(file_type, file_source):
+    if file_type != ImageUrl.mime_type():
+        with pytest.raises(ValueError):
+            parse_obj_as(ImageUrl, file_source)
+    else:
+        parse_obj_as(ImageUrl, file_source)
diff --git a/tests/units/typing/url/test_mesh_url.py b/tests/units/typing/url/test_mesh_url.py
index fb83a3362a2..14529394b54 100644
--- a/tests/units/typing/url/test_mesh_url.py
+++ b/tests/units/typing/url/test_mesh_url.py
@@ -1,3 +1,5 @@
+import os
+
 import numpy as np
 import pytest
 from pydantic.tools import parse_obj_as, schema_json_of
@@ -75,3 +77,28 @@ def test_validation(path_to_file):
 def test_proto_mesh_url():
     uri = parse_obj_as(Mesh3DUrl, REMOTE_OBJ_FILE)
     uri._to_node_protobuf()
+
+
+@pytest.mark.parametrize(
+    'file_type, file_source',
+    [
+        ('application', MESH_FILES['obj']),
+        ('application', MESH_FILES['glb']),
+        ('application', MESH_FILES['ply']),
+        ('application', REMOTE_OBJ_FILE),
+        ('audio', os.path.join(TOYDATA_DIR, 'hello.aac')),
+        ('audio', os.path.join(TOYDATA_DIR, 'hello.mp3')),
+        ('audio', os.path.join(TOYDATA_DIR, 'hello.ogg')),
+        ('video', os.path.join(TOYDATA_DIR, 'mov_bbb.mp4')),
+        ('image', os.path.join(TOYDATA_DIR, 'test.png')),
+        ('text', os.path.join(TOYDATA_DIR, 'test', 'test.html')),
+        ('text', os.path.join(TOYDATA_DIR, 'test', 'test.md')),
+        ('text', os.path.join(TOYDATA_DIR, 'penal_colony.txt')),
+    ],
+)
+def test_file_validation(file_type, file_source):
+    if file_type != Mesh3DUrl.mime_type():
+        with pytest.raises(ValueError):
+            parse_obj_as(Mesh3DUrl, file_source)
+    else:
+        parse_obj_as(Mesh3DUrl, file_source)
diff --git a/tests/units/typing/url/test_point_cloud_url.py b/tests/units/typing/url/test_point_cloud_url.py
index e48404fe9ce..8ddd5fa5182 100644
--- a/tests/units/typing/url/test_point_cloud_url.py
+++ b/tests/units/typing/url/test_point_cloud_url.py
@@ -1,3 +1,5 @@
+import os
+
 import numpy as np
 import pytest
 from pydantic.tools import parse_obj_as, schema_json_of
@@ -79,3 +81,28 @@ def test_validation(path_to_file):
 def test_proto_point_cloud_url():
     uri = parse_obj_as(PointCloud3DUrl, REMOTE_OBJ_FILE)
     uri._to_node_protobuf()
+
+
+@pytest.mark.parametrize(
+    'file_type, file_source',
+    [
+        ('application', MESH_FILES['obj']),
+        ('application', MESH_FILES['glb']),
+        ('application', MESH_FILES['ply']),
+        ('application', REMOTE_OBJ_FILE),
+        ('audio', os.path.join(TOYDATA_DIR, 'hello.aac')),
+        ('audio', os.path.join(TOYDATA_DIR, 'hello.mp3')),
+        ('audio', os.path.join(TOYDATA_DIR, 'hello.ogg')),
+        ('video', os.path.join(TOYDATA_DIR, 'mov_bbb.mp4')),
+        ('image', os.path.join(TOYDATA_DIR, 'test.png')),
+        ('text', os.path.join(TOYDATA_DIR, 'test', 'test.html')),
+        ('text', os.path.join(TOYDATA_DIR, 'test', 'test.md')),
+        ('text', os.path.join(TOYDATA_DIR, 'penal_colony.txt')),
+    ],
+)
+def test_file_validation(file_type, file_source):
+    if file_type != PointCloud3DUrl.mime_type():
+        with pytest.raises(ValueError):
+            parse_obj_as(PointCloud3DUrl, file_source)
+    else:
+        parse_obj_as(PointCloud3DUrl, file_source)
diff --git a/tests/units/typing/url/test_text_url.py b/tests/units/typing/url/test_text_url.py
index ebee337ab65..05d19eaa6bb 100644
--- a/tests/units/typing/url/test_text_url.py
+++ b/tests/units/typing/url/test_text_url.py
@@ -54,7 +54,7 @@ def test_load_to_bytes(url):
 @pytest.mark.proto
 @pytest.mark.slow
 @pytest.mark.internet
-@pytest.mark.parametrize('url', [REMOTE_TEXT_FILE, *LOCAL_TEXT_FILES])
+@pytest.mark.parametrize('url', [REMOTE_TEXT_FILE])
 def test_proto_text_url(url):
     uri = parse_obj_as(TextUrl, url)
 
@@ -89,3 +89,24 @@ def test_validation(path_to_file):
     url = parse_obj_as(TextUrl, path_to_file)
     assert isinstance(url, TextUrl)
     assert isinstance(url, str)
+
+
+@pytest.mark.parametrize(
+    'file_type, file_source',
+    [
+        *[('text', file) for file in LOCAL_TEXT_FILES],
+        ('text', REMOTE_TEXT_FILE),
+        ('audio', os.path.join(TOYDATA_DIR, 'hello.aac')),
+        ('audio', os.path.join(TOYDATA_DIR, 'hello.mp3')),
+        ('audio', os.path.join(TOYDATA_DIR, 'hello.ogg')),
+        ('image', os.path.join(TOYDATA_DIR, 'test.png')),
+        ('video', os.path.join(TOYDATA_DIR, 'mov_bbb.mp4')),
+        ('application', os.path.join(TOYDATA_DIR, 'test.glb')),
+    ],
+)
+def test_file_validation(file_type, file_source):
+    if file_type != TextUrl.mime_type():
+        with pytest.raises(ValueError):
+            parse_obj_as(TextUrl, file_source)
+    else:
+        parse_obj_as(TextUrl, file_source)
diff --git a/tests/units/typing/url/test_video_url.py b/tests/units/typing/url/test_video_url.py
index 726e66a0cb6..a76374aec56 100644
--- a/tests/units/typing/url/test_video_url.py
+++ b/tests/units/typing/url/test_video_url.py
@@ -1,3 +1,4 @@
+import os
 from typing import Optional
 
 import numpy as np
@@ -146,3 +147,26 @@ def test_load_bytes():
     assert isinstance(video_bytes, bytes)
     assert isinstance(video_bytes, VideoBytes)
     assert len(video_bytes) > 0
+
+
+@pytest.mark.parametrize(
+    'file_type, file_source',
+    [
+        ('video', LOCAL_VIDEO_FILE),
+        ('video', REMOTE_VIDEO_FILE),
+        ('audio', os.path.join(TOYDATA_DIR, 'hello.aac')),
+        ('audio', os.path.join(TOYDATA_DIR, 'hello.mp3')),
+        ('audio', os.path.join(TOYDATA_DIR, 'hello.ogg')),
+        ('image', os.path.join(TOYDATA_DIR, 'test.png')),
+        ('text', os.path.join(TOYDATA_DIR, 'test', 'test.html')),
+        ('text', os.path.join(TOYDATA_DIR, 'test', 'test.md')),
+        ('text', os.path.join(TOYDATA_DIR, 'penal_colony.txt')),
+        ('application', os.path.join(TOYDATA_DIR, 'test.glb')),
+    ],
+)
+def test_file_validation(file_type, file_source):
+    if file_type != VideoUrl.mime_type():
+        with pytest.raises(ValueError):
+            parse_obj_as(VideoUrl, file_source)
+    else:
+        parse_obj_as(VideoUrl, file_source)