From 7cc0cc7c8f9777e2a47c04391cb2cd6f4e8dde78 Mon Sep 17 00:00:00 2001 From: Gregory Lindsey Date: Wed, 3 Apr 2024 13:32:11 -0700 Subject: [PATCH] fix: pad audio arrays to same shape if sample rates differ --- docarray/typing/bytes/video_bytes.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/docarray/typing/bytes/video_bytes.py b/docarray/typing/bytes/video_bytes.py index a1003046720..f40f364e3a6 100644 --- a/docarray/typing/bytes/video_bytes.py +++ b/docarray/typing/bytes/video_bytes.py @@ -1,3 +1,4 @@ +import warnings from io import BytesIO from typing import TYPE_CHECKING, List, NamedTuple, TypeVar @@ -80,6 +81,11 @@ class MyDoc(BaseDoc): video_frames.append(frame.to_ndarray(format='rgb24')) + # Pad audio arrays to same shape if sample rates differ + if len({arr.shape for arr in audio_frames}) > 1: + warnings.warn('Audio frames have different sample rates') + audio_frames = self._pad_arrays_to_same_shape(audio_frames) + if len(audio_frames) == 0: audio = parse_obj_as(AudioNdArray, np.array(audio_frames)) else: @@ -89,3 +95,15 @@ class MyDoc(BaseDoc): indices = parse_obj_as(NdArray, keyframe_indices) return VideoLoadResult(video=video, audio=audio, key_frame_indices=indices) + + @staticmethod + def _pad_arrays_to_same_shape(arrays: List[np.ndarray]) -> List[np.ndarray]: + # Calculate the maximum number of samples in any array + max_samples = max(arr.shape[1] for arr in arrays) + + # Pad arrays with fewer samples + for i, arr in enumerate(arrays): + if arr.shape[1] < max_samples: + arrays[i] = np.pad(arr, ((0, 0), (0, max_samples - arr.shape[1]))) + + return arrays \ No newline at end of file