diff --git a/git/refs/head.py b/git/refs/head.py index 683634451..b3bbd7e7f 100644 --- a/git/refs/head.py +++ b/git/refs/head.py @@ -17,7 +17,7 @@ # typing --------------------------------------------------- -from typing import Any, Sequence, TYPE_CHECKING, Union +from typing import Any, Sequence, TYPE_CHECKING, TypeVar, Union from git.types import Commit_ish, PathLike @@ -26,6 +26,8 @@ from git.refs import RemoteReference from git.repo import Repo +T_Heads = TypeVar("T_Heads", bound="Head") + # ------------------------------------------------------------------- @@ -124,6 +126,13 @@ def reset( return self + @property + def reference(self) -> "Head": + """Wrap the parent reference method to change the type hint.""" + return super().reference + + ref = reference + class Head(Reference): """A Head is a named reference to a :class:`~git.objects.commit.Commit`. Every Head @@ -301,4 +310,11 @@ def config_writer(self) -> SectionConstraint[GitConfigParser]: """ return self._config_parser(read_only=False) + @property + def reference(self) -> T_Heads: + """Wrap the parent reference method to change the type hint.""" + return super().reference + + ref = reference + # } END configuration diff --git a/git/refs/reference.py b/git/refs/reference.py index e5d473779..aad4ae8a0 100644 --- a/git/refs/reference.py +++ b/git/refs/reference.py @@ -9,13 +9,16 @@ # typing ------------------------------------------------------------------ -from typing import Any, Callable, Iterator, TYPE_CHECKING, Type, Union +from typing import Any, Callable, Iterator, TYPE_CHECKING, Type, TypeVar, Union from git.types import AnyGitObject, PathLike, _T if TYPE_CHECKING: from git.repo import Repo +# named this way to avoid collision with symbolic.T_References +T_NonsymbolicReferences = TypeVar("T_NonsymbolicReferences", bound="Reference") + # ------------------------------------------------------------------------------ # { Utilities @@ -173,4 +176,11 @@ def remote_head(self) -> str: tokens = self.path.split("/") return "/".join(tokens[3:]) + @property + def reference(self) -> T_NonsymbolicReferences: + """Wrap the parent reference method to change the type hint.""" + return super().reference + + ref = reference + # } END remote interface diff --git a/git/refs/remote.py b/git/refs/remote.py index b4f4f7b36..ced0894a0 100644 --- a/git/refs/remote.py +++ b/git/refs/remote.py @@ -77,3 +77,10 @@ def delete(cls, repo: "Repo", *refs: "RemoteReference", **kwargs: Any) -> None: def create(cls, *args: Any, **kwargs: Any) -> NoReturn: """Raise :exc:`TypeError`. Defined so the ``create`` method is disabled.""" raise TypeError("Cannot explicitly create remote references") + + @property + def reference(self) -> "RemoteReference": + """Wrap the parent reference method to change the type hint.""" + return super().reference + + ref = reference diff --git a/git/refs/symbolic.py b/git/refs/symbolic.py index 1b90a3115..e83cc9c69 100644 --- a/git/refs/symbolic.py +++ b/git/refs/symbolic.py @@ -404,7 +404,7 @@ def object(self) -> AnyGitObject: def object(self, object: Union[AnyGitObject, "SymbolicReference", str]) -> "SymbolicReference": return self.set_object(object) - def _get_reference(self) -> "SymbolicReference": + def _get_reference(self) -> T_References: """ :return: :class:`~git.refs.reference.Reference` object we point to @@ -502,7 +502,7 @@ def set_reference( # Aliased reference @property - def reference(self) -> "SymbolicReference": + def reference(self) -> T_References: return self._get_reference() @reference.setter