Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 27 additions & 18 deletions ultraplot/axes/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,8 @@
Width of streamlines.
cmap, norm : optional
Colormap and normalization for array colors.
colorbar, colorbar_kw : optional
Add a colorbar for array-valued streamline colors.
arrowsize : float, optional
Arrow size scaling.
arrowstyle : str, optional
Expand Down Expand Up @@ -1918,6 +1920,8 @@ def curved_quiver(
grains: Optional[int] = None,
density: Optional[int] = None,
arrow_at_end: Optional[bool] = None,
colorbar: Optional[str] = None,
colorbar_kw: Optional[dict[str, Any]] = None,
):
"""
%(plot.curved_quiver)s
Expand All @@ -1935,6 +1939,7 @@ def curved_quiver(
zorder = _not_none(zorder, mlines.Line2D.zorder)
transform = _not_none(transform, self.transData)
color = _not_none(color, self._get_lines.get_next_color())
colorbar_kw = colorbar_kw or {}
linewidth = _not_none(linewidth, rc["lines.linewidth"])
scale = _not_none(scale, rc["curved_quiver.scale"])
grains = _not_none(grains, rc["curved_quiver.grains"])
Expand Down Expand Up @@ -1968,6 +1973,7 @@ def curved_quiver(
raise ValueError(
"If 'linewidth' is given, must have the shape of 'Grid(x,y)'"
)
linewidth = np.ma.masked_invalid(linewidth)
line_kw["linewidth"] = []
else:
line_kw["linewidth"] = linewidth
Expand All @@ -1990,7 +1996,6 @@ def curved_quiver(

integrate = solver.get_integrator(u, v, minlength, resolution, magnitude)
trajectories = []
edges = []

if start_points is None:
start_points = solver.gen_starting_points(x, y, grains)
Expand Down Expand Up @@ -2026,18 +2031,19 @@ def curved_quiver(

for xs, ys in sp2:
xg, yg = solver.domain_map.data2grid(xs, ys)
t = integrate(xg, yg)
if t is not None:
trajectories.append(t[0])
edges.append(t[1])
trajectory = integrate(xg, yg)
if trajectory is not None:
trajectories.append(trajectory)
streamlines = []
arrows = []
for t, edge in zip(trajectories, edges):
tgx = np.array(t[0])
tgy = np.array(t[1])
for trajectory in trajectories:
tgx = np.array(trajectory.x)
tgy = np.array(trajectory.y)

# Rescale from grid-coordinates to data-coordinates.
tx, ty = solver.domain_map.grid2data(*np.array(t))
tx, ty = solver.domain_map.grid2data(
*np.array([trajectory.x, trajectory.y])
)
tx += solver.grid.x_origin
ty += solver.grid.y_origin

Expand All @@ -2054,14 +2060,9 @@ def curved_quiver(
continue

arrow_tail = (tx[-1], ty[-1])

# Extrapolate to find arrow head
xg, yg = solver.domain_map.data2grid(
tx[-1] - solver.grid.x_origin, ty[-1] - solver.grid.y_origin
)

ui = solver.interpgrid(u, xg, yg)
vi = solver.interpgrid(v, xg, yg)
if trajectory.end_direction is None:
continue
ui, vi = trajectory.end_direction

norm_v = np.sqrt(ui**2 + vi**2)
if norm_v > 0:
Expand All @@ -2087,14 +2088,16 @@ def curved_quiver(
if isinstance(linewidth, np.ndarray):
line_widths = solver.interpgrid(linewidth, tgx, tgy)[:-1]
line_kw["linewidth"].extend(line_widths)
if np.ma.is_masked(line_widths[n]):
continue
arrow_kw["linewidth"] = line_widths[n]

if use_multicolor_lines:
color_values = solver.interpgrid(color, tgx, tgy)[:-1]
line_colors.append(color_values)
arrow_kw["color"] = cmap(norm(color_values[n]))

if not edge:
if not trajectory.hit_edge:
p = mpatches.FancyArrowPatch(
arrow_tail, arrow_head, transform=transform, **arrow_kw
)
Expand Down Expand Up @@ -2125,6 +2128,12 @@ def curved_quiver(
lc.set_array(np.ma.hstack(line_colors))
lc.set_cmap(cmap)
lc.set_norm(norm)
self._update_guide(
lc,
colorbar=colorbar,
colorbar_kw=colorbar_kw,
queue_colorbar=False,
)

self.add_collection(lc)
self.autoscale_view()
Expand Down
55 changes: 42 additions & 13 deletions ultraplot/axes/plot_types/curved_quiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ class CurvedQuiverSet(StreamplotSet):
arrows: object


@dataclass
class _CurvedQuiverTrajectory:
x: list[float]
y: list[float]
hit_edge: bool
end_direction: tuple[float, float] | None


class _DomainMap(object):
"""Map representing different coordinate systems.

Expand Down Expand Up @@ -197,7 +205,7 @@ def get_integrator(
minlength: float,
resolution: float,
magnitude: np.ndarray,
) -> Callable[[float, float], tuple[tuple[list[float], list[float]], bool] | None]:
) -> Callable[[float, float], _CurvedQuiverTrajectory | None]:
# rescale velocity onto grid-coordinates for integrations.
u, v = self.domain_map.data2grid(u, v)

Expand All @@ -215,9 +223,7 @@ def forward_time(xi: float, yi: float) -> tuple[float, float]:
vi = self.interpgrid(v, xi, yi)
return ui * dt_ds, vi * dt_ds

def integrate(
x0: float, y0: float
) -> tuple[tuple[list[float], list[float], bool]] | None:
def integrate(x0: float, y0: float) -> _CurvedQuiverTrajectory | None:
"""Return x, y grid-coordinates of trajectory based on starting point.

Integrate both forward and backward in time from starting point
Expand All @@ -226,15 +232,26 @@ def integrate(
occupied cell in the StreamMask. The resulting trajectory is
None if it is shorter than `minlength`.
"""
stotal, x_traj, y_traj = 0.0, [], []
self.domain_map.start_trajectory(x0, y0)
self.domain_map.reset_start_point(x0, y0)
stotal, x_traj, y_traj, m_total, hit_edge = self.integrate_rk12(
x_traj, y_traj, hit_edge = self.integrate_rk12(
x0, y0, forward_time, resolution, magnitude
)

if len(x_traj) > 1:
return (x_traj, y_traj), hit_edge
end_dx = x_traj[-1] - x_traj[-2]
end_dy = y_traj[-1] - y_traj[-2]
end_direction = (
None
if end_dx == 0 and end_dy == 0
else self.domain_map.grid2data(end_dx, end_dy)
)
return _CurvedQuiverTrajectory(
x=x_traj,
y=y_traj,
hit_edge=hit_edge,
end_direction=end_direction,
)
else:
# reject short trajectories
self.domain_map.undo_trajectory()
Expand All @@ -249,7 +266,7 @@ def integrate_rk12(
f: Callable[[float, float], tuple[float, float]],
resolution: float,
magnitude: np.ndarray,
) -> tuple[float, list[float], list[float], list[float], bool]:
) -> tuple[list[float], list[float], bool]:
"""2nd-order Runge-Kutta algorithm with adaptive step size.

This method is also referred to as the improved Euler's method, or
Expand Down Expand Up @@ -296,9 +313,14 @@ def integrate_rk12(
hit_edge = False

while self.domain_map.grid.within_grid(xi, yi):
try:
current_magnitude = self.interpgrid(magnitude, xi, yi)
except _CurvedQuiverTerminateTrajectory:
break

xf_traj.append(xi)
yf_traj.append(yi)
m_total.append(self.interpgrid(magnitude, xi, yi))
m_total.append(current_magnitude)

try:
k1x, k1y = f(xi, yi)
Expand All @@ -324,8 +346,15 @@ def integrate_rk12(

# Only save step if within error tolerance
if error < maxerror:
xi += dx2
yi += dy2
next_xi = xi + dx2
next_yi = yi + dy2
if self.domain_map.grid.within_grid(next_xi, next_yi):
try:
self.interpgrid(magnitude, next_xi, next_yi)
except _CurvedQuiverTerminateTrajectory:
break
xi = next_xi
yi = next_yi
self.domain_map.update_trajectory(xi, yi)
if not self.domain_map.grid.within_grid(xi, yi):
hit_edge = True
Expand All @@ -339,7 +368,7 @@ def integrate_rk12(
else:
ds = min(maxds, 0.85 * ds * (maxerror / error) ** 0.5)

return stotal, xf_traj, yf_traj, m_total, hit_edge
return xf_traj, yf_traj, hit_edge

def euler_step(self, xf_traj, yf_traj, f):
"""Simple Euler integration step that extends streamline to boundary."""
Expand Down Expand Up @@ -400,7 +429,7 @@ def interpgrid(self, a, xi, yi):

if not isinstance(xi, np.ndarray):
if np.ma.is_masked(ai):
raise _CurvedQuiverTerminateTrajectory
raise _CurvedQuiverTerminateTrajectory()
return ai

def gen_starting_points(self, x, y, grains):
Expand Down
97 changes: 95 additions & 2 deletions ultraplot/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,32 @@ def test_curved_quiver(rng):
return fig


def test_curved_quiver_integrator_skips_nan_seed():
"""
Test that masked seed points terminate cleanly instead of escaping the solver.
"""
from ultraplot.axes.plot_types.curved_quiver import CurvedQuiverSolver

x = np.linspace(0, 1, 5)
y = np.linspace(0, 1, 5)
u = np.ones((5, 5))
v = np.ones((5, 5))
u[2, 2] = np.nan
v[2, 2] = np.nan
u = np.ma.masked_invalid(u)
v = np.ma.masked_invalid(v)
magnitude = np.sqrt(u**2 + v**2)
magnitude /= np.max(magnitude)

solver = CurvedQuiverSolver(x, y, density=5)
integrator = solver.get_integrator(
u, v, minlength=0.1, resolution=1.0, magnitude=magnitude
)

assert integrator(2.0, 2.0) is None
assert not solver.mask._mask.any()


def test_validate_vector_shapes_pass():
"""
Test that vector shapes match the grid shape using CurvedQuiverSolver.
Expand Down Expand Up @@ -738,8 +764,8 @@ def test_generate_start_points():

def test_calculate_trajectories():
"""
Test that CurvedQuiverSolver.get_integrator returns callable for each seed point
and returns lists of trajectories and edges of correct length.
Test that CurvedQuiverSolver.get_integrator returns trajectory objects for each
seed point with the expected rendering metadata.
"""
from ultraplot.axes.plot_types.curved_quiver import CurvedQuiverSolver

Expand All @@ -755,6 +781,17 @@ def test_calculate_trajectories():
seeds = solver.gen_starting_points(x, y, grains=2)
results = [integrator(pt[0], pt[1]) for pt in seeds]
assert len(results) == seeds.shape[0]
trajectories = [result for result in results if result is not None]
assert trajectories
for trajectory in trajectories:
assert len(trajectory.x) == len(trajectory.y)
assert isinstance(trajectory.hit_edge, bool)
if trajectory.end_direction is not None:
expected = solver.domain_map.grid2data(
trajectory.x[-1] - trajectory.x[-2],
trajectory.y[-1] - trajectory.y[-2],
)
assert np.allclose(trajectory.end_direction, expected)


@pytest.mark.mpl_image_compare
Expand All @@ -779,6 +816,62 @@ def test_curved_quiver_multicolor_lines():
return fig


def test_curved_quiver_nan_vectors():
"""
Test that curved_quiver skips NaN vector regions without failing.
"""
x = np.linspace(-1, 1, 21)
y = np.linspace(-1, 1, 21)
X, Y = np.meshgrid(x, y)
U = -Y.copy()
V = X.copy()
speed = np.sqrt(U**2 + V**2)
invalid = (np.abs(X) < 0.2) & (np.abs(Y) < 0.2)
U[invalid] = np.nan
V[invalid] = np.nan
speed[invalid] = np.nan

fig, ax = uplt.subplots()
m = ax.curved_quiver(
X, Y, U, V, color=speed, arrow_at_end=True, scale=2.0, grains=10
)

segments = m.lines.get_segments()
assert segments
assert all(np.isfinite(segment).all() for segment in segments)
assert len(ax.patches) > 0
uplt.close(fig)


def test_curved_quiver_colorbar_argument():
"""
Test that curved_quiver forwards array colors to the shared colorbar guide path.
"""
x = np.linspace(-1, 1, 11)
y = np.linspace(-1, 1, 11)
X, Y = np.meshgrid(x, y)
U = -Y
V = X
speed = np.sqrt(U**2 + V**2)

fig, ax = uplt.subplots()
m = ax.curved_quiver(
X,
Y,
U,
V,
color=speed,
colorbar="r",
colorbar_kw={"label": "speed"},
)

assert ("right", "center") in ax[0]._colorbar_dict
cbar = ax[0]._colorbar_dict[("right", "center")]
assert cbar.mappable is m.lines
assert cbar.ax.get_ylabel() == "speed"
uplt.close(fig)


@pytest.mark.mpl_image_compare
@pytest.mark.parametrize(
"cmap",
Expand Down