From 5243e2364c0c8171a9ee479dd31dfa957904a2dd Mon Sep 17 00:00:00 2001 From: Mohammad Kalim Akram Date: Wed, 31 May 2023 18:42:53 +0200 Subject: [PATCH 1/3] fix: validate file formats in url Signed-off-by: Mohammad Kalim Akram --- docarray/typing/url/any_url.py | 38 ++++++++++++++++++- docarray/typing/url/audio_url.py | 11 +++++- docarray/typing/url/image_url.py | 11 +++++- docarray/typing/url/text_url.py | 11 +++++- docarray/typing/url/url_3d/mesh_url.py | 15 +++++++- docarray/typing/url/url_3d/point_cloud_url.py | 13 ++++++- docarray/typing/url/url_3d/url_3d.py | 4 ++ docarray/typing/url/video_url.py | 11 +++++- tests/units/typing/url/test_audio_url.py | 23 +++++++++++ tests/units/typing/url/test_image_url.py | 25 ++++++++++++ tests/units/typing/url/test_mesh_url.py | 27 +++++++++++++ .../units/typing/url/test_point_cloud_url.py | 27 +++++++++++++ tests/units/typing/url/test_text_url.py | 21 ++++++++++ tests/units/typing/url/test_video_url.py | 24 ++++++++++++ 14 files changed, 254 insertions(+), 7 deletions(-) diff --git a/docarray/typing/url/any_url.py b/docarray/typing/url/any_url.py index 6d930aa53f3..838f2aef09f 100644 --- a/docarray/typing/url/any_url.py +++ b/docarray/typing/url/any_url.py @@ -1,8 +1,9 @@ +import mimetypes import os import urllib import urllib.parse import urllib.request -from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union +from typing import TYPE_CHECKING, Any, List, Optional, Type, TypeVar, Union import numpy as np from pydantic import AnyUrl as BaseAnyUrl @@ -27,6 +28,14 @@ class AnyUrl(BaseAnyUrl, AbstractType): False # turn off host requirement to allow passing of local paths as URL ) + @classmethod + def mime_type(cls) -> str: + raise NotImplementedError + + @classmethod + def allowed_extensions(cls) -> List[str]: + raise NotImplementedError + def _to_node_protobuf(self) -> 'NodeProto': """Convert Document into a NodeProto protobuf message. This function should be called when the Document is nested into another Document that need to @@ -61,6 +70,33 @@ def validate( url = super().validate(abs_path, field, config) # basic url validation + # Use mimetypes to validate file formats + mimetype, encoding = mimetypes.guess_type(value.split("?")[0]) + if not mimetype: + # try reading from the request headers if mimetypes failed - could be slow + try: + r = urllib.request.urlopen(value) + except Exception: # noqa + pass # should we raise an error/warning here, since url is not reachable(invalid)? + else: + mimetype = r.headers.get_content_maintype() + + skip_check = False + if not mimetype: # not able to automatically detect mimetype + # check if the file extension is among one of the allowed extensions + if not any( + value.endswith(ext) or value.split("?")[0].endswith(ext) + for ext in cls.allowed_extensions() + ): + raise ValueError( + f'file {value} is not a valid file format for class {cls}' + ) + else: + skip_check = True # one of the allowed extensions, skip the check + + if not skip_check and not mimetype.startswith(cls.mime_type()): + raise ValueError(f'file {value} is not a {cls.mime_type()} file format') + if input_is_relative_path: return cls(str(value), scheme=None) else: diff --git a/docarray/typing/url/audio_url.py b/docarray/typing/url/audio_url.py index a84a68754ee..7128e477823 100644 --- a/docarray/typing/url/audio_url.py +++ b/docarray/typing/url/audio_url.py @@ -1,5 +1,5 @@ import warnings -from typing import Optional, Tuple, TypeVar +from typing import List, Optional, Tuple, TypeVar from docarray.typing import AudioNdArray from docarray.typing.bytes.audio_bytes import AudioBytes @@ -17,6 +17,15 @@ class AudioUrl(AnyUrl): Can be remote (web) URL, or a local file path. """ + @classmethod + def mime_type(cls) -> str: + return 'audio' + + @classmethod + def allowed_extensions(cls) -> List[str]: + # add only those extensions that can not be identified by the mimetypes library but are valid + return [] + def load(self: T) -> Tuple[AudioNdArray, int]: """ Load the data from the url into an [`AudioNdArray`][docarray.typing.AudioNdArray] diff --git a/docarray/typing/url/image_url.py b/docarray/typing/url/image_url.py index 43758cf7436..ff3cab4136d 100644 --- a/docarray/typing/url/image_url.py +++ b/docarray/typing/url/image_url.py @@ -1,5 +1,5 @@ import warnings -from typing import TYPE_CHECKING, Optional, Tuple, TypeVar +from typing import TYPE_CHECKING, List, Optional, Tuple, TypeVar from docarray.typing import ImageBytes from docarray.typing.proto_register import _register_proto @@ -20,6 +20,15 @@ class ImageUrl(AnyUrl): Can be remote (web) URL, or a local file path. """ + @classmethod + def mime_type(cls) -> str: + return 'image' + + @classmethod + def allowed_extensions(cls) -> List[str]: + # add only those extensions that can not be identified by the mimetypes library but are valid + return [] + def load_pil(self, timeout: Optional[float] = None) -> 'PILImage.Image': """ Load the image from the bytes into a `PIL.Image.Image` instance diff --git a/docarray/typing/url/text_url.py b/docarray/typing/url/text_url.py index 86da87790e6..aa07c03b67f 100644 --- a/docarray/typing/url/text_url.py +++ b/docarray/typing/url/text_url.py @@ -1,4 +1,4 @@ -from typing import Optional, TypeVar +from typing import List, Optional, TypeVar from docarray.typing.proto_register import _register_proto from docarray.typing.url.any_url import AnyUrl @@ -13,6 +13,15 @@ class TextUrl(AnyUrl): Can be remote (web) URL, or a local file path. """ + @classmethod + def mime_type(cls) -> str: + return 'text' + + @classmethod + def allowed_extensions(cls) -> List[str]: + # add only those extensions that can not be identified by the mimetypes library but are valid + return ['.md'] + def load(self, charset: str = 'utf-8', timeout: Optional[float] = None) -> str: """ Load the text file into a string. diff --git a/docarray/typing/url/url_3d/mesh_url.py b/docarray/typing/url/url_3d/mesh_url.py index 70f32eb5581..709c2812c3d 100644 --- a/docarray/typing/url/url_3d/mesh_url.py +++ b/docarray/typing/url/url_3d/mesh_url.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeVar import numpy as np from pydantic import parse_obj_as @@ -20,6 +20,19 @@ class Mesh3DUrl(Url3D): Can be remote (web) URL, or a local file path. """ + @classmethod + def allowed_extensions(cls) -> List[str]: + # return list of allowed extensions to be used for mesh if mimetypes fail to detect + # generated with the help of chatGPT and definitely this list is not exhaustive + # bit hacky because of black formatting, making it a long vertical list + list_a = ['3ds', '3mf', 'ac', 'ac3d', 'amf', 'assimp', 'bvh', 'cob', 'collada'] + list_b = ['ctm', 'dxf', 'e57', 'fbx', 'gltf', 'glb', 'ifc', 'lwo', 'lws', 'lxo'] + list_c = ['md2', 'md3', 'md5', 'mdc', 'm3d', 'mdl', 'ms3d', 'nff', 'obj', 'off'] + list_d = ['pcd', 'pod', 'pmd', 'pmx', 'ply', 'q3o', 'q3s', 'raw', 'sib', 'smd'] + list_e = ['stl', 'ter' 'terragen', 'vtk', 'vrml', 'x3d', 'xaml', 'xgl', 'xml'] + list_f = ['xyz', 'zgl', 'vta'] + return list_a + list_b + list_c + list_d + list_e + list_f + def load( self: T, skip_materials: bool = True, diff --git a/docarray/typing/url/url_3d/point_cloud_url.py b/docarray/typing/url/url_3d/point_cloud_url.py index efe6ce6ae0e..3c566950363 100644 --- a/docarray/typing/url/url_3d/point_cloud_url.py +++ b/docarray/typing/url/url_3d/point_cloud_url.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeVar import numpy as np from pydantic import parse_obj_as @@ -21,6 +21,17 @@ class PointCloud3DUrl(Url3D): Can be remote (web) URL, or a local file path. """ + @classmethod + def allowed_extensions(cls) -> List[str]: + # return list of file format for point cloud if mimetypes fail to detect + # generated with the help of chatGPT and definitely this list is not exhaustive + # bit hacky because of black formatting, making it a long vertical list + list_a = ['ascii', 'bin', 'b3dm', 'bpf', 'dp', 'dxf', 'e57', 'fls', 'fls'] + list_b = ['glb', 'ply', 'gpf', 'las', 'obj', 'osgb', 'pcap', 'pcd', 'pdal'] + list_c = ['pfm', 'ply', 'ply2', 'pod', 'pods', 'pnts', 'ptg', 'ptx', 'pts'] + list_d = ['rcp', 'xyz', 'zfs'] + return list_a + list_b + list_c + list_d + def load( self: T, samples: int, diff --git a/docarray/typing/url/url_3d/url_3d.py b/docarray/typing/url/url_3d/url_3d.py index c55c0f954e7..3514a88bca1 100644 --- a/docarray/typing/url/url_3d/url_3d.py +++ b/docarray/typing/url/url_3d/url_3d.py @@ -18,6 +18,10 @@ class Url3D(AnyUrl, ABC): Can be remote (web) URL, or a local file path. """ + @classmethod + def mime_type(cls) -> str: + return 'application' + def _load_trimesh_instance( self: T, force: Optional[str] = None, diff --git a/docarray/typing/url/video_url.py b/docarray/typing/url/video_url.py index 5bd7b1be0b9..16a69cf33c0 100644 --- a/docarray/typing/url/video_url.py +++ b/docarray/typing/url/video_url.py @@ -1,5 +1,5 @@ import warnings -from typing import Optional, TypeVar +from typing import List, Optional, TypeVar from docarray.typing.bytes.video_bytes import VideoBytes, VideoLoadResult from docarray.typing.proto_register import _register_proto @@ -16,6 +16,15 @@ class VideoUrl(AnyUrl): Can be remote (web) URL, or a local file path. """ + @classmethod + def mime_type(cls) -> str: + return 'video' + + @classmethod + def allowed_extensions(cls) -> List[str]: + # add only those extensions that can not be identified by the mimetypes library but are valid + return [] + def load(self: T, **kwargs) -> VideoLoadResult: """ Load the data from the url into a `NamedTuple` of diff --git a/tests/units/typing/url/test_audio_url.py b/tests/units/typing/url/test_audio_url.py index 2e6b46bcabf..70771450ac3 100644 --- a/tests/units/typing/url/test_audio_url.py +++ b/tests/units/typing/url/test_audio_url.py @@ -1,3 +1,4 @@ +import os from typing import Optional import numpy as np @@ -123,3 +124,25 @@ def test_load_bytes(): assert isinstance(audio_bytes, bytes) assert isinstance(audio_bytes, AudioBytes) assert len(audio_bytes) > 0 + + +@pytest.mark.parametrize( + 'file_type, file_source', + [ + ('audio', AUDIO_FILES[0]), + ('audio', AUDIO_FILES[1]), + ('audio', REMOTE_AUDIO_FILE), + ('image', os.path.join(TOYDATA_DIR, 'test.png')), + ('video', os.path.join(TOYDATA_DIR, 'mov_bbb.mp4')), + ('text', os.path.join(TOYDATA_DIR, 'test' 'test.html')), + ('text', os.path.join(TOYDATA_DIR, 'test' 'test.md')), + ('text', os.path.join(TOYDATA_DIR, 'penal_colony.txt')), + ('application', os.path.join(TOYDATA_DIR, 'test.glb')), + ], +) +def test_file_validation(file_type, file_source): + if file_type != AudioUrl.mime_type(): + with pytest.raises(ValueError): + parse_obj_as(AudioUrl, file_source) + else: + parse_obj_as(AudioUrl, file_source) diff --git a/tests/units/typing/url/test_image_url.py b/tests/units/typing/url/test_image_url.py index 4054c997c80..86ea04ab11f 100644 --- a/tests/units/typing/url/test_image_url.py +++ b/tests/units/typing/url/test_image_url.py @@ -9,6 +9,7 @@ from docarray.base_doc.io.json import orjson_dumps from docarray.typing import ImageUrl +from tests import TOYDATA_DIR CUR_DIR = os.path.dirname(os.path.abspath(__file__)) PATH_TO_IMAGE_DATA = os.path.join(CUR_DIR, '..', '..', '..', 'toydata', 'image-data') @@ -174,3 +175,27 @@ def test_validation(path_to_img): url = parse_obj_as(ImageUrl, path_to_img) assert isinstance(url, ImageUrl) assert isinstance(url, str) + + +@pytest.mark.parametrize( + 'file_type, file_source', + [ + ('image', IMAGE_PATHS['png']), + ('image', IMAGE_PATHS['jpg']), + ('image', IMAGE_PATHS['jpeg']), + ('image', REMOTE_JPG), + ('audio', os.path.join(TOYDATA_DIR, 'hello.mp3')), + ('audio', os.path.join(TOYDATA_DIR, 'hello.wav')), + ('video', os.path.join(TOYDATA_DIR, 'mov_bbb.mp4')), + ('text', os.path.join(TOYDATA_DIR, 'test' 'test.html')), + ('text', os.path.join(TOYDATA_DIR, 'test' 'test.md')), + ('text', os.path.join(TOYDATA_DIR, 'penal_colony.txt')), + ('application', os.path.join(TOYDATA_DIR, 'test.glb')), + ], +) +def test_file_validation(file_type, file_source): + if file_type != ImageUrl.mime_type(): + with pytest.raises(ValueError): + parse_obj_as(ImageUrl, file_source) + else: + parse_obj_as(ImageUrl, file_source) diff --git a/tests/units/typing/url/test_mesh_url.py b/tests/units/typing/url/test_mesh_url.py index fb83a3362a2..14529394b54 100644 --- a/tests/units/typing/url/test_mesh_url.py +++ b/tests/units/typing/url/test_mesh_url.py @@ -1,3 +1,5 @@ +import os + import numpy as np import pytest from pydantic.tools import parse_obj_as, schema_json_of @@ -75,3 +77,28 @@ def test_validation(path_to_file): def test_proto_mesh_url(): uri = parse_obj_as(Mesh3DUrl, REMOTE_OBJ_FILE) uri._to_node_protobuf() + + +@pytest.mark.parametrize( + 'file_type, file_source', + [ + ('application', MESH_FILES['obj']), + ('application', MESH_FILES['glb']), + ('application', MESH_FILES['ply']), + ('application', REMOTE_OBJ_FILE), + ('audio', os.path.join(TOYDATA_DIR, 'hello.aac')), + ('audio', os.path.join(TOYDATA_DIR, 'hello.mp3')), + ('audio', os.path.join(TOYDATA_DIR, 'hello.ogg')), + ('video', os.path.join(TOYDATA_DIR, 'mov_bbb.mp4')), + ('image', os.path.join(TOYDATA_DIR, 'test.png')), + ('text', os.path.join(TOYDATA_DIR, 'test' 'test.html')), + ('text', os.path.join(TOYDATA_DIR, 'test' 'test.md')), + ('text', os.path.join(TOYDATA_DIR, 'penal_colony.txt')), + ], +) +def test_file_validation(file_type, file_source): + if file_type != Mesh3DUrl.mime_type(): + with pytest.raises(ValueError): + parse_obj_as(Mesh3DUrl, file_source) + else: + parse_obj_as(Mesh3DUrl, file_source) diff --git a/tests/units/typing/url/test_point_cloud_url.py b/tests/units/typing/url/test_point_cloud_url.py index e48404fe9ce..8ddd5fa5182 100644 --- a/tests/units/typing/url/test_point_cloud_url.py +++ b/tests/units/typing/url/test_point_cloud_url.py @@ -1,3 +1,5 @@ +import os + import numpy as np import pytest from pydantic.tools import parse_obj_as, schema_json_of @@ -79,3 +81,28 @@ def test_validation(path_to_file): def test_proto_point_cloud_url(): uri = parse_obj_as(PointCloud3DUrl, REMOTE_OBJ_FILE) uri._to_node_protobuf() + + +@pytest.mark.parametrize( + 'file_type, file_source', + [ + ('application', MESH_FILES['obj']), + ('application', MESH_FILES['glb']), + ('application', MESH_FILES['ply']), + ('application', REMOTE_OBJ_FILE), + ('audio', os.path.join(TOYDATA_DIR, 'hello.aac')), + ('audio', os.path.join(TOYDATA_DIR, 'hello.mp3')), + ('audio', os.path.join(TOYDATA_DIR, 'hello.ogg')), + ('video', os.path.join(TOYDATA_DIR, 'mov_bbb.mp4')), + ('image', os.path.join(TOYDATA_DIR, 'test.png')), + ('text', os.path.join(TOYDATA_DIR, 'test' 'test.html')), + ('text', os.path.join(TOYDATA_DIR, 'test' 'test.md')), + ('text', os.path.join(TOYDATA_DIR, 'penal_colony.txt')), + ], +) +def test_file_validation(file_type, file_source): + if file_type != PointCloud3DUrl.mime_type(): + with pytest.raises(ValueError): + parse_obj_as(PointCloud3DUrl, file_source) + else: + parse_obj_as(PointCloud3DUrl, file_source) diff --git a/tests/units/typing/url/test_text_url.py b/tests/units/typing/url/test_text_url.py index ebee337ab65..fe2016dfbce 100644 --- a/tests/units/typing/url/test_text_url.py +++ b/tests/units/typing/url/test_text_url.py @@ -89,3 +89,24 @@ def test_validation(path_to_file): url = parse_obj_as(TextUrl, path_to_file) assert isinstance(url, TextUrl) assert isinstance(url, str) + + +@pytest.mark.parametrize( + 'file_type, file_source', + [ + *[('text', file) for file in LOCAL_TEXT_FILES], + ('text', REMOTE_TEXT_FILE), + ('audio', os.path.join(TOYDATA_DIR, 'hello.aac')), + ('audio', os.path.join(TOYDATA_DIR, 'hello.mp3')), + ('audio', os.path.join(TOYDATA_DIR, 'hello.ogg')), + ('image', os.path.join(TOYDATA_DIR, 'test.png')), + ('video', os.path.join(TOYDATA_DIR, 'mov_bbb.mp4')), + ('application', os.path.join(TOYDATA_DIR, 'test.glb')), + ], +) +def test_file_validation(file_type, file_source): + if file_type != TextUrl.mime_type(): + with pytest.raises(ValueError): + parse_obj_as(TextUrl, file_source) + else: + parse_obj_as(TextUrl, file_source) diff --git a/tests/units/typing/url/test_video_url.py b/tests/units/typing/url/test_video_url.py index 726e66a0cb6..a76374aec56 100644 --- a/tests/units/typing/url/test_video_url.py +++ b/tests/units/typing/url/test_video_url.py @@ -1,3 +1,4 @@ +import os from typing import Optional import numpy as np @@ -146,3 +147,26 @@ def test_load_bytes(): assert isinstance(video_bytes, bytes) assert isinstance(video_bytes, VideoBytes) assert len(video_bytes) > 0 + + +@pytest.mark.parametrize( + 'file_type, file_source', + [ + ('video', LOCAL_VIDEO_FILE), + ('video', REMOTE_VIDEO_FILE), + ('audio', os.path.join(TOYDATA_DIR, 'hello.aac')), + ('audio', os.path.join(TOYDATA_DIR, 'hello.mp3')), + ('audio', os.path.join(TOYDATA_DIR, 'hello.ogg')), + ('image', os.path.join(TOYDATA_DIR, 'test.png')), + ('text', os.path.join(TOYDATA_DIR, 'test' 'test.html')), + ('text', os.path.join(TOYDATA_DIR, 'test' 'test.md')), + ('text', os.path.join(TOYDATA_DIR, 'penal_colony.txt')), + ('application', os.path.join(TOYDATA_DIR, 'test.glb')), + ], +) +def test_file_validation(file_type, file_source): + if file_type != VideoUrl.mime_type(): + with pytest.raises(ValueError): + parse_obj_as(VideoUrl, file_source) + else: + parse_obj_as(VideoUrl, file_source) From aa073128bad51eeb66da000843e2e5b3068b9bbf Mon Sep 17 00:00:00 2001 From: Mohammad Kalim Akram Date: Fri, 23 Jun 2023 10:41:37 +0530 Subject: [PATCH 2/3] fix: apply review suggestions Signed-off-by: Mohammad Kalim Akram --- docarray/typing/url/any_url.py | 67 +++++++++++-------- docarray/typing/url/audio_url.py | 2 +- docarray/typing/url/image_url.py | 2 +- docarray/typing/url/text_url.py | 20 +++++- docarray/typing/url/url_3d/mesh_url.py | 2 +- docarray/typing/url/url_3d/point_cloud_url.py | 2 +- docarray/typing/url/video_url.py | 2 +- tests/units/typing/url/test_text_url.py | 2 +- 8 files changed, 63 insertions(+), 36 deletions(-) diff --git a/docarray/typing/url/any_url.py b/docarray/typing/url/any_url.py index 838f2aef09f..638e69eba1f 100644 --- a/docarray/typing/url/any_url.py +++ b/docarray/typing/url/any_url.py @@ -30,10 +30,13 @@ class AnyUrl(BaseAnyUrl, AbstractType): @classmethod def mime_type(cls) -> str: + """Returns the mime type this class deals with.""" raise NotImplementedError @classmethod - def allowed_extensions(cls) -> List[str]: + def extra_extensions(cls) -> List[str]: + """Returns a list of allowed file extensions for this class which + falls outside the scope of mimetypes library.""" raise NotImplementedError def _to_node_protobuf(self) -> 'NodeProto': @@ -47,6 +50,37 @@ def _to_node_protobuf(self) -> 'NodeProto': return NodeProto(text=str(self), type=self._proto_type_name) + @classmethod + def is_extension_allowed(cls, value: 'AnyUrl') -> bool: + """ + Check if the file extension of the url is allowed for that class. + First read the mime type of the file, if it fails, then check the file extension. + + :param value: url to the file + :return: True if the extension is allowed, False otherwise + """ + if not issubclass(cls, AnyUrl): # no check for AnyUrl class + return True + mimetype, _ = mimetypes.guess_type(value.split("?")[0]) + if mimetype: + return mimetype.startswith(cls.mime_type()) + else: + # check if the extension is among the extra extensions of that class + return any( + value.endswith(ext) or value.split("?")[0].endswith(ext) + for ext in cls.extra_extensions() + ) + + @classmethod + def is_special_case(cls, value: 'AnyUrl') -> bool: + """ + Check if the url is a special case. + + :param value: url to the file + :return: True if the url is a special case, False otherwise + """ + return False + @classmethod def validate( cls: Type[T], @@ -70,37 +104,14 @@ def validate( url = super().validate(abs_path, field, config) # basic url validation - # Use mimetypes to validate file formats - mimetype, encoding = mimetypes.guess_type(value.split("?")[0]) - if not mimetype: - # try reading from the request headers if mimetypes failed - could be slow - try: - r = urllib.request.urlopen(value) - except Exception: # noqa - pass # should we raise an error/warning here, since url is not reachable(invalid)? - else: - mimetype = r.headers.get_content_maintype() - - skip_check = False - if not mimetype: # not able to automatically detect mimetype - # check if the file extension is among one of the allowed extensions - if not any( - value.endswith(ext) or value.split("?")[0].endswith(ext) - for ext in cls.allowed_extensions() - ): + # perform check only for subclasses of AnyUrl + if not cls.is_extension_allowed(url): + if not cls.is_special_case(url): # check for special cases raise ValueError( f'file {value} is not a valid file format for class {cls}' ) - else: - skip_check = True # one of the allowed extensions, skip the check - - if not skip_check and not mimetype.startswith(cls.mime_type()): - raise ValueError(f'file {value} is not a {cls.mime_type()} file format') - if input_is_relative_path: - return cls(str(value), scheme=None) - else: - return cls(str(url), scheme=None) + return cls(str(value if input_is_relative_path else url), scheme=None) @classmethod def validate_parts(cls, parts: 'Parts', validate_port: bool = True) -> 'Parts': diff --git a/docarray/typing/url/audio_url.py b/docarray/typing/url/audio_url.py index 7128e477823..8a3b83c6648 100644 --- a/docarray/typing/url/audio_url.py +++ b/docarray/typing/url/audio_url.py @@ -22,7 +22,7 @@ def mime_type(cls) -> str: return 'audio' @classmethod - def allowed_extensions(cls) -> List[str]: + def extra_extensions(cls) -> List[str]: # add only those extensions that can not be identified by the mimetypes library but are valid return [] diff --git a/docarray/typing/url/image_url.py b/docarray/typing/url/image_url.py index ff3cab4136d..acde2e45e57 100644 --- a/docarray/typing/url/image_url.py +++ b/docarray/typing/url/image_url.py @@ -25,7 +25,7 @@ def mime_type(cls) -> str: return 'image' @classmethod - def allowed_extensions(cls) -> List[str]: + def extra_extensions(cls) -> List[str]: # add only those extensions that can not be identified by the mimetypes library but are valid return [] diff --git a/docarray/typing/url/text_url.py b/docarray/typing/url/text_url.py index aa07c03b67f..9b39eb10877 100644 --- a/docarray/typing/url/text_url.py +++ b/docarray/typing/url/text_url.py @@ -18,10 +18,26 @@ def mime_type(cls) -> str: return 'text' @classmethod - def allowed_extensions(cls) -> List[str]: - # add only those extensions that can not be identified by the mimetypes library but are valid + def extra_extensions(cls) -> List[str]: + """ + List of extra file extensions for this type of URL (outside the scope of mimetype library). + """ return ['.md'] + @classmethod + def is_special_case(cls, value: 'AnyUrl') -> bool: + """ + Check if the url is a special case that needs to be handled differently. + + :param value: url to the file + :return: True if the url is a special case, False otherwise + """ + if value.startswith('http') or value.startswith('https'): + if len(value.split('/')[-1].split('.')) == 1: + # This handles the case where the value is a URL without a file extension + # for e.g. https://de.wikipedia.org/wiki/Brixen + return True + def load(self, charset: str = 'utf-8', timeout: Optional[float] = None) -> str: """ Load the text file into a string. diff --git a/docarray/typing/url/url_3d/mesh_url.py b/docarray/typing/url/url_3d/mesh_url.py index 709c2812c3d..5b558e38a94 100644 --- a/docarray/typing/url/url_3d/mesh_url.py +++ b/docarray/typing/url/url_3d/mesh_url.py @@ -21,7 +21,7 @@ class Mesh3DUrl(Url3D): """ @classmethod - def allowed_extensions(cls) -> List[str]: + def extra_extensions(cls) -> List[str]: # return list of allowed extensions to be used for mesh if mimetypes fail to detect # generated with the help of chatGPT and definitely this list is not exhaustive # bit hacky because of black formatting, making it a long vertical list diff --git a/docarray/typing/url/url_3d/point_cloud_url.py b/docarray/typing/url/url_3d/point_cloud_url.py index 3c566950363..a726aee9b5f 100644 --- a/docarray/typing/url/url_3d/point_cloud_url.py +++ b/docarray/typing/url/url_3d/point_cloud_url.py @@ -22,7 +22,7 @@ class PointCloud3DUrl(Url3D): """ @classmethod - def allowed_extensions(cls) -> List[str]: + def extra_extensions(cls) -> List[str]: # return list of file format for point cloud if mimetypes fail to detect # generated with the help of chatGPT and definitely this list is not exhaustive # bit hacky because of black formatting, making it a long vertical list diff --git a/docarray/typing/url/video_url.py b/docarray/typing/url/video_url.py index 16a69cf33c0..a86ee630d3a 100644 --- a/docarray/typing/url/video_url.py +++ b/docarray/typing/url/video_url.py @@ -21,7 +21,7 @@ def mime_type(cls) -> str: return 'video' @classmethod - def allowed_extensions(cls) -> List[str]: + def extra_extensions(cls) -> List[str]: # add only those extensions that can not be identified by the mimetypes library but are valid return [] diff --git a/tests/units/typing/url/test_text_url.py b/tests/units/typing/url/test_text_url.py index fe2016dfbce..05d19eaa6bb 100644 --- a/tests/units/typing/url/test_text_url.py +++ b/tests/units/typing/url/test_text_url.py @@ -54,7 +54,7 @@ def test_load_to_bytes(url): @pytest.mark.proto @pytest.mark.slow @pytest.mark.internet -@pytest.mark.parametrize('url', [REMOTE_TEXT_FILE, *LOCAL_TEXT_FILES]) +@pytest.mark.parametrize('url', [REMOTE_TEXT_FILE]) def test_proto_text_url(url): uri = parse_obj_as(TextUrl, url) From a7f515d1658d389546f831c9632bf9d742b22601 Mon Sep 17 00:00:00 2001 From: Mohammad Kalim Akram Date: Fri, 23 Jun 2023 11:50:09 +0530 Subject: [PATCH 3/3] fix: mypy and unit tests Signed-off-by: Mohammad Kalim Akram --- docarray/typing/url/any_url.py | 10 +++++----- docarray/typing/url/text_url.py | 1 + tests/index/weaviate/test_index_get_del_weaviate.py | 2 +- tests/integrations/predefined_document/test_audio.py | 2 -- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/docarray/typing/url/any_url.py b/docarray/typing/url/any_url.py index 638e69eba1f..4be64b9e057 100644 --- a/docarray/typing/url/any_url.py +++ b/docarray/typing/url/any_url.py @@ -51,7 +51,7 @@ def _to_node_protobuf(self) -> 'NodeProto': return NodeProto(text=str(self), type=self._proto_type_name) @classmethod - def is_extension_allowed(cls, value: 'AnyUrl') -> bool: + def is_extension_allowed(cls, value: Any) -> bool: """ Check if the file extension of the url is allowed for that class. First read the mime type of the file, if it fails, then check the file extension. @@ -59,7 +59,7 @@ def is_extension_allowed(cls, value: 'AnyUrl') -> bool: :param value: url to the file :return: True if the extension is allowed, False otherwise """ - if not issubclass(cls, AnyUrl): # no check for AnyUrl class + if cls == AnyUrl: # no check for AnyUrl class return True mimetype, _ = mimetypes.guess_type(value.split("?")[0]) if mimetype: @@ -72,7 +72,7 @@ def is_extension_allowed(cls, value: 'AnyUrl') -> bool: ) @classmethod - def is_special_case(cls, value: 'AnyUrl') -> bool: + def is_special_case(cls, value: Any) -> bool: """ Check if the url is a special case. @@ -105,8 +105,8 @@ def validate( url = super().validate(abs_path, field, config) # basic url validation # perform check only for subclasses of AnyUrl - if not cls.is_extension_allowed(url): - if not cls.is_special_case(url): # check for special cases + if not cls.is_extension_allowed(value): + if not cls.is_special_case(value): # check for special cases raise ValueError( f'file {value} is not a valid file format for class {cls}' ) diff --git a/docarray/typing/url/text_url.py b/docarray/typing/url/text_url.py index 9b39eb10877..817c6aa99d9 100644 --- a/docarray/typing/url/text_url.py +++ b/docarray/typing/url/text_url.py @@ -37,6 +37,7 @@ def is_special_case(cls, value: 'AnyUrl') -> bool: # This handles the case where the value is a URL without a file extension # for e.g. https://de.wikipedia.org/wiki/Brixen return True + return False def load(self, charset: str = 'utf-8', timeout: Optional[float] = None) -> str: """ diff --git a/tests/index/weaviate/test_index_get_del_weaviate.py b/tests/index/weaviate/test_index_get_del_weaviate.py index 8c1bd15636e..569a7f9c112 100644 --- a/tests/index/weaviate/test_index_get_del_weaviate.py +++ b/tests/index/weaviate/test_index_get_del_weaviate.py @@ -403,7 +403,7 @@ class MyMultiModalDoc(BaseDoc): def test_index_document_with_bytes(weaviate_client): - doc = ImageDoc(id="1", url="www.foo.com", bytes_=b"foo") + doc = ImageDoc(id="1", url="www.foo.com/test.png", bytes_=b"foo") index = WeaviateDocumentIndex[ImageDoc]() index.index([doc]) diff --git a/tests/integrations/predefined_document/test_audio.py b/tests/integrations/predefined_document/test_audio.py index 2ba207245f7..a69852d9ba1 100644 --- a/tests/integrations/predefined_document/test_audio.py +++ b/tests/integrations/predefined_document/test_audio.py @@ -29,7 +29,6 @@ str(TOYDATA_DIR / 'hello.ogg'), str(TOYDATA_DIR / 'hello.wma'), str(TOYDATA_DIR / 'hello.aac'), - str(TOYDATA_DIR / 'hello'), ] LOCAL_AUDIO_FILES_AND_FORMAT = [ @@ -40,7 +39,6 @@ (str(TOYDATA_DIR / 'hello.ogg'), 'ogg'), (str(TOYDATA_DIR / 'hello.wma'), 'asf'), (str(TOYDATA_DIR / 'hello.aac'), 'adts'), - (str(TOYDATA_DIR / 'hello'), 'wav'), ] NON_AUDIO_FILES = [