_linalg.pyi 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482
  1. from collections.abc import Iterable
  2. from typing import (
  3. Any,
  4. NamedTuple,
  5. Never,
  6. SupportsIndex,
  7. SupportsInt,
  8. TypeAlias,
  9. TypeVar,
  10. overload,
  11. )
  12. from typing import Literal as L
  13. import numpy as np
  14. from numpy import (
  15. complex128,
  16. complexfloating,
  17. float64,
  18. # other
  19. floating,
  20. int32,
  21. object_,
  22. signedinteger,
  23. timedelta64,
  24. unsignedinteger,
  25. # re-exports
  26. vecdot,
  27. )
  28. from numpy._core.fromnumeric import matrix_transpose
  29. from numpy._core.numeric import tensordot
  30. from numpy._typing import (
  31. ArrayLike,
  32. DTypeLike,
  33. NDArray,
  34. _ArrayLike,
  35. _ArrayLikeBool_co,
  36. _ArrayLikeComplex_co,
  37. _ArrayLikeFloat_co,
  38. _ArrayLikeInt_co,
  39. _ArrayLikeObject_co,
  40. _ArrayLikeTD64_co,
  41. _ArrayLikeUInt_co,
  42. )
  43. from numpy.linalg import LinAlgError
  44. __all__ = [
  45. "matrix_power",
  46. "solve",
  47. "tensorsolve",
  48. "tensorinv",
  49. "inv",
  50. "cholesky",
  51. "eigvals",
  52. "eigvalsh",
  53. "pinv",
  54. "slogdet",
  55. "det",
  56. "svd",
  57. "svdvals",
  58. "eig",
  59. "eigh",
  60. "lstsq",
  61. "norm",
  62. "qr",
  63. "cond",
  64. "matrix_rank",
  65. "LinAlgError",
  66. "multi_dot",
  67. "trace",
  68. "diagonal",
  69. "cross",
  70. "outer",
  71. "tensordot",
  72. "matmul",
  73. "matrix_transpose",
  74. "matrix_norm",
  75. "vector_norm",
  76. "vecdot",
  77. ]
  78. _ArrayT = TypeVar("_ArrayT", bound=NDArray[Any])
  79. _ModeKind: TypeAlias = L["reduced", "complete", "r", "raw"]
  80. ###
  81. fortran_int = np.intc
  82. class EigResult(NamedTuple):
  83. eigenvalues: NDArray[Any]
  84. eigenvectors: NDArray[Any]
  85. class EighResult(NamedTuple):
  86. eigenvalues: NDArray[Any]
  87. eigenvectors: NDArray[Any]
  88. class QRResult(NamedTuple):
  89. Q: NDArray[Any]
  90. R: NDArray[Any]
  91. class SlogdetResult(NamedTuple):
  92. # TODO: `sign` and `logabsdet` are scalars for input 2D arrays and
  93. # a `(x.ndim - 2)`` dimensionl arrays otherwise
  94. sign: Any
  95. logabsdet: Any
  96. class SVDResult(NamedTuple):
  97. U: NDArray[Any]
  98. S: NDArray[Any]
  99. Vh: NDArray[Any]
  100. @overload
  101. def tensorsolve(
  102. a: _ArrayLikeInt_co,
  103. b: _ArrayLikeInt_co,
  104. axes: Iterable[int] | None = ...,
  105. ) -> NDArray[float64]: ...
  106. @overload
  107. def tensorsolve(
  108. a: _ArrayLikeFloat_co,
  109. b: _ArrayLikeFloat_co,
  110. axes: Iterable[int] | None = ...,
  111. ) -> NDArray[floating]: ...
  112. @overload
  113. def tensorsolve(
  114. a: _ArrayLikeComplex_co,
  115. b: _ArrayLikeComplex_co,
  116. axes: Iterable[int] | None = ...,
  117. ) -> NDArray[complexfloating]: ...
  118. @overload
  119. def solve(
  120. a: _ArrayLikeInt_co,
  121. b: _ArrayLikeInt_co,
  122. ) -> NDArray[float64]: ...
  123. @overload
  124. def solve(
  125. a: _ArrayLikeFloat_co,
  126. b: _ArrayLikeFloat_co,
  127. ) -> NDArray[floating]: ...
  128. @overload
  129. def solve(
  130. a: _ArrayLikeComplex_co,
  131. b: _ArrayLikeComplex_co,
  132. ) -> NDArray[complexfloating]: ...
  133. @overload
  134. def tensorinv(
  135. a: _ArrayLikeInt_co,
  136. ind: int = ...,
  137. ) -> NDArray[float64]: ...
  138. @overload
  139. def tensorinv(
  140. a: _ArrayLikeFloat_co,
  141. ind: int = ...,
  142. ) -> NDArray[floating]: ...
  143. @overload
  144. def tensorinv(
  145. a: _ArrayLikeComplex_co,
  146. ind: int = ...,
  147. ) -> NDArray[complexfloating]: ...
  148. @overload
  149. def inv(a: _ArrayLikeInt_co) -> NDArray[float64]: ...
  150. @overload
  151. def inv(a: _ArrayLikeFloat_co) -> NDArray[floating]: ...
  152. @overload
  153. def inv(a: _ArrayLikeComplex_co) -> NDArray[complexfloating]: ...
  154. # TODO: The supported input and output dtypes are dependent on the value of `n`.
  155. # For example: `n < 0` always casts integer types to float64
  156. def matrix_power(
  157. a: _ArrayLikeComplex_co | _ArrayLikeObject_co,
  158. n: SupportsIndex,
  159. ) -> NDArray[Any]: ...
  160. @overload
  161. def cholesky(a: _ArrayLikeInt_co, /, *, upper: bool = False) -> NDArray[float64]: ...
  162. @overload
  163. def cholesky(a: _ArrayLikeFloat_co, /, *, upper: bool = False) -> NDArray[floating]: ...
  164. @overload
  165. def cholesky(a: _ArrayLikeComplex_co, /, *, upper: bool = False) -> NDArray[complexfloating]: ...
  166. @overload
  167. def outer(x1: _ArrayLike[Never], x2: _ArrayLike[Never]) -> NDArray[Any]: ...
  168. @overload
  169. def outer(x1: _ArrayLikeBool_co, x2: _ArrayLikeBool_co) -> NDArray[np.bool]: ...
  170. @overload
  171. def outer(x1: _ArrayLikeUInt_co, x2: _ArrayLikeUInt_co) -> NDArray[unsignedinteger]: ...
  172. @overload
  173. def outer(x1: _ArrayLikeInt_co, x2: _ArrayLikeInt_co) -> NDArray[signedinteger]: ...
  174. @overload
  175. def outer(x1: _ArrayLikeFloat_co, x2: _ArrayLikeFloat_co) -> NDArray[floating]: ...
  176. @overload
  177. def outer(
  178. x1: _ArrayLikeComplex_co,
  179. x2: _ArrayLikeComplex_co,
  180. ) -> NDArray[complexfloating]: ...
  181. @overload
  182. def outer(
  183. x1: _ArrayLikeTD64_co,
  184. x2: _ArrayLikeTD64_co,
  185. out: None = ...,
  186. ) -> NDArray[timedelta64]: ...
  187. @overload
  188. def outer(x1: _ArrayLikeObject_co, x2: _ArrayLikeObject_co) -> NDArray[object_]: ...
  189. @overload
  190. def outer(
  191. x1: _ArrayLikeComplex_co | _ArrayLikeTD64_co | _ArrayLikeObject_co,
  192. x2: _ArrayLikeComplex_co | _ArrayLikeTD64_co | _ArrayLikeObject_co,
  193. ) -> _ArrayT: ...
  194. @overload
  195. def qr(a: _ArrayLikeInt_co, mode: _ModeKind = ...) -> QRResult: ...
  196. @overload
  197. def qr(a: _ArrayLikeFloat_co, mode: _ModeKind = ...) -> QRResult: ...
  198. @overload
  199. def qr(a: _ArrayLikeComplex_co, mode: _ModeKind = ...) -> QRResult: ...
  200. @overload
  201. def eigvals(a: _ArrayLikeInt_co) -> NDArray[float64] | NDArray[complex128]: ...
  202. @overload
  203. def eigvals(a: _ArrayLikeFloat_co) -> NDArray[floating] | NDArray[complexfloating]: ...
  204. @overload
  205. def eigvals(a: _ArrayLikeComplex_co) -> NDArray[complexfloating]: ...
  206. @overload
  207. def eigvalsh(a: _ArrayLikeInt_co, UPLO: L["L", "U", "l", "u"] = ...) -> NDArray[float64]: ...
  208. @overload
  209. def eigvalsh(a: _ArrayLikeComplex_co, UPLO: L["L", "U", "l", "u"] = ...) -> NDArray[floating]: ...
  210. @overload
  211. def eig(a: _ArrayLikeInt_co) -> EigResult: ...
  212. @overload
  213. def eig(a: _ArrayLikeFloat_co) -> EigResult: ...
  214. @overload
  215. def eig(a: _ArrayLikeComplex_co) -> EigResult: ...
  216. @overload
  217. def eigh(
  218. a: _ArrayLikeInt_co,
  219. UPLO: L["L", "U", "l", "u"] = ...,
  220. ) -> EighResult: ...
  221. @overload
  222. def eigh(
  223. a: _ArrayLikeFloat_co,
  224. UPLO: L["L", "U", "l", "u"] = ...,
  225. ) -> EighResult: ...
  226. @overload
  227. def eigh(
  228. a: _ArrayLikeComplex_co,
  229. UPLO: L["L", "U", "l", "u"] = ...,
  230. ) -> EighResult: ...
  231. @overload
  232. def svd(
  233. a: _ArrayLikeInt_co,
  234. full_matrices: bool = ...,
  235. compute_uv: L[True] = ...,
  236. hermitian: bool = ...,
  237. ) -> SVDResult: ...
  238. @overload
  239. def svd(
  240. a: _ArrayLikeFloat_co,
  241. full_matrices: bool = ...,
  242. compute_uv: L[True] = ...,
  243. hermitian: bool = ...,
  244. ) -> SVDResult: ...
  245. @overload
  246. def svd(
  247. a: _ArrayLikeComplex_co,
  248. full_matrices: bool = ...,
  249. compute_uv: L[True] = ...,
  250. hermitian: bool = ...,
  251. ) -> SVDResult: ...
  252. @overload
  253. def svd(
  254. a: _ArrayLikeInt_co,
  255. full_matrices: bool = ...,
  256. compute_uv: L[False] = ...,
  257. hermitian: bool = ...,
  258. ) -> NDArray[float64]: ...
  259. @overload
  260. def svd(
  261. a: _ArrayLikeComplex_co,
  262. full_matrices: bool = ...,
  263. compute_uv: L[False] = ...,
  264. hermitian: bool = ...,
  265. ) -> NDArray[floating]: ...
  266. def svdvals(
  267. x: _ArrayLikeInt_co | _ArrayLikeFloat_co | _ArrayLikeComplex_co
  268. ) -> NDArray[floating]: ...
  269. # TODO: Returns a scalar for 2D arrays and
  270. # a `(x.ndim - 2)`` dimensionl array otherwise
  271. def cond(x: _ArrayLikeComplex_co, p: float | L["fro", "nuc"] | None = ...) -> Any: ...
  272. # TODO: Returns `int` for <2D arrays and `intp` otherwise
  273. def matrix_rank(
  274. A: _ArrayLikeComplex_co,
  275. tol: _ArrayLikeFloat_co | None = ...,
  276. hermitian: bool = ...,
  277. *,
  278. rtol: _ArrayLikeFloat_co | None = ...,
  279. ) -> Any: ...
  280. @overload
  281. def pinv(
  282. a: _ArrayLikeInt_co,
  283. rcond: _ArrayLikeFloat_co = ...,
  284. hermitian: bool = ...,
  285. ) -> NDArray[float64]: ...
  286. @overload
  287. def pinv(
  288. a: _ArrayLikeFloat_co,
  289. rcond: _ArrayLikeFloat_co = ...,
  290. hermitian: bool = ...,
  291. ) -> NDArray[floating]: ...
  292. @overload
  293. def pinv(
  294. a: _ArrayLikeComplex_co,
  295. rcond: _ArrayLikeFloat_co = ...,
  296. hermitian: bool = ...,
  297. ) -> NDArray[complexfloating]: ...
  298. # TODO: Returns a 2-tuple of scalars for 2D arrays and
  299. # a 2-tuple of `(a.ndim - 2)`` dimensionl arrays otherwise
  300. def slogdet(a: _ArrayLikeComplex_co) -> SlogdetResult: ...
  301. # TODO: Returns a 2-tuple of scalars for 2D arrays and
  302. # a 2-tuple of `(a.ndim - 2)`` dimensionl arrays otherwise
  303. def det(a: _ArrayLikeComplex_co) -> Any: ...
  304. @overload
  305. def lstsq(a: _ArrayLikeInt_co, b: _ArrayLikeInt_co, rcond: float | None = ...) -> tuple[
  306. NDArray[float64],
  307. NDArray[float64],
  308. int32,
  309. NDArray[float64],
  310. ]: ...
  311. @overload
  312. def lstsq(a: _ArrayLikeFloat_co, b: _ArrayLikeFloat_co, rcond: float | None = ...) -> tuple[
  313. NDArray[floating],
  314. NDArray[floating],
  315. int32,
  316. NDArray[floating],
  317. ]: ...
  318. @overload
  319. def lstsq(a: _ArrayLikeComplex_co, b: _ArrayLikeComplex_co, rcond: float | None = ...) -> tuple[
  320. NDArray[complexfloating],
  321. NDArray[floating],
  322. int32,
  323. NDArray[floating],
  324. ]: ...
  325. @overload
  326. def norm(
  327. x: ArrayLike,
  328. ord: float | L["fro", "nuc"] | None = ...,
  329. axis: None = ...,
  330. keepdims: bool = ...,
  331. ) -> floating: ...
  332. @overload
  333. def norm(
  334. x: ArrayLike,
  335. ord: float | L["fro", "nuc"] | None = ...,
  336. axis: SupportsInt | SupportsIndex | tuple[int, ...] = ...,
  337. keepdims: bool = ...,
  338. ) -> Any: ...
  339. @overload
  340. def matrix_norm(
  341. x: ArrayLike,
  342. /,
  343. *,
  344. ord: float | L["fro", "nuc"] | None = ...,
  345. keepdims: bool = ...,
  346. ) -> floating: ...
  347. @overload
  348. def matrix_norm(
  349. x: ArrayLike,
  350. /,
  351. *,
  352. ord: float | L["fro", "nuc"] | None = ...,
  353. keepdims: bool = ...,
  354. ) -> Any: ...
  355. @overload
  356. def vector_norm(
  357. x: ArrayLike,
  358. /,
  359. *,
  360. axis: None = ...,
  361. ord: float | None = ...,
  362. keepdims: bool = ...,
  363. ) -> floating: ...
  364. @overload
  365. def vector_norm(
  366. x: ArrayLike,
  367. /,
  368. *,
  369. axis: SupportsInt | SupportsIndex | tuple[int, ...] = ...,
  370. ord: float | None = ...,
  371. keepdims: bool = ...,
  372. ) -> Any: ...
  373. # TODO: Returns a scalar or array
  374. def multi_dot(
  375. arrays: Iterable[_ArrayLikeComplex_co | _ArrayLikeObject_co | _ArrayLikeTD64_co],
  376. *,
  377. out: NDArray[Any] | None = ...,
  378. ) -> Any: ...
  379. def diagonal(
  380. x: ArrayLike, # >= 2D array
  381. /,
  382. *,
  383. offset: SupportsIndex = ...,
  384. ) -> NDArray[Any]: ...
  385. def trace(
  386. x: ArrayLike, # >= 2D array
  387. /,
  388. *,
  389. offset: SupportsIndex = ...,
  390. dtype: DTypeLike = ...,
  391. ) -> Any: ...
  392. @overload
  393. def cross(
  394. x1: _ArrayLikeUInt_co,
  395. x2: _ArrayLikeUInt_co,
  396. /,
  397. *,
  398. axis: int = ...,
  399. ) -> NDArray[unsignedinteger]: ...
  400. @overload
  401. def cross(
  402. x1: _ArrayLikeInt_co,
  403. x2: _ArrayLikeInt_co,
  404. /,
  405. *,
  406. axis: int = ...,
  407. ) -> NDArray[signedinteger]: ...
  408. @overload
  409. def cross(
  410. x1: _ArrayLikeFloat_co,
  411. x2: _ArrayLikeFloat_co,
  412. /,
  413. *,
  414. axis: int = ...,
  415. ) -> NDArray[floating]: ...
  416. @overload
  417. def cross(
  418. x1: _ArrayLikeComplex_co,
  419. x2: _ArrayLikeComplex_co,
  420. /,
  421. *,
  422. axis: int = ...,
  423. ) -> NDArray[complexfloating]: ...
  424. @overload
  425. def matmul(
  426. x1: _ArrayLikeInt_co,
  427. x2: _ArrayLikeInt_co,
  428. ) -> NDArray[signedinteger]: ...
  429. @overload
  430. def matmul(
  431. x1: _ArrayLikeUInt_co,
  432. x2: _ArrayLikeUInt_co,
  433. ) -> NDArray[unsignedinteger]: ...
  434. @overload
  435. def matmul(
  436. x1: _ArrayLikeFloat_co,
  437. x2: _ArrayLikeFloat_co,
  438. ) -> NDArray[floating]: ...
  439. @overload
  440. def matmul(
  441. x1: _ArrayLikeComplex_co,
  442. x2: _ArrayLikeComplex_co,
  443. ) -> NDArray[complexfloating]: ...