From 25595323edef39d6eba4accec94411b75c73cdd0 Mon Sep 17 00:00:00 2001 From: zero323 Date: Sun, 15 Nov 2020 17:29:32 +0100 Subject: [PATCH 01/26] Update mypy.ini --- python/mypy.ini | 87 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/python/mypy.ini b/python/mypy.ini index 4a5368a519097..5103452a053be 100644 --- a/python/mypy.ini +++ b/python/mypy.ini @@ -16,10 +16,97 @@ ; [mypy] +strict_optional = True +no_implicit_optional = True +disallow_untyped_defs = True + +; Allow untyped def in internal modules and tests + +[mypy-pyspark.daemon] +disallow_untyped_defs = False + +[mypy-pyspark.find_spark_home] +disallow_untyped_defs = False + +[mypy-pyspark._globals] +disallow_untyped_defs = False + +[mypy-pyspark.install] +disallow_untyped_defs = False + +[mypy-pyspark.java_gateway] +disallow_untyped_defs = False + +[mypy-pyspark.join] +disallow_untyped_defs = False + +[mypy-pyspark.ml.tests.*] +disallow_untyped_defs = False + +[mypy-pyspark.mllib.tests.*] +disallow_untyped_defs = False + +[mypy-pyspark.rddsampler] +disallow_untyped_defs = False + +[mypy-pyspark.resource.tests.*] +disallow_untyped_defs = False + +[mypy-pyspark.serializers] +disallow_untyped_defs = False + +[mypy-pyspark.shuffle] +disallow_untyped_defs = False + +[mypy-pyspark.streaming.tests.*] +disallow_untyped_defs = False + +[mypy-pyspark.streaming.util] +disallow_untyped_defs = False + +[mypy-pyspark.sql.tests.*] +disallow_untyped_defs = False + +[mypy-pyspark.sql.pandas.serializers] +disallow_untyped_defs = False + +[mypy-pyspark.sql.pandas.types] +disallow_untyped_defs = False + +[mypy-pyspark.sql.pandas.typehints] +disallow_untyped_defs = False + +[mypy-pyspark.sql.pandas.utils] +disallow_untyped_defs = False + +[mypy-pyspark.sql.pandas._typing.protocols.*] +disallow_untyped_defs = False + +[mypy-pyspark.sql.utils] +disallow_untyped_defs = False + +[mypy-pyspark.tests.*] +disallow_untyped_defs = False + +[mypy-pyspark.testing.*] +disallow_untyped_defs = False + +[mypy-pyspark.traceback_utils] +disallow_untyped_defs = False + +[mypy-pyspark.util] +disallow_untyped_defs = False + +[mypy-pyspark.worker] +disallow_untyped_defs = False + +; Ignore errors in embedded third party code [mypy-pyspark.cloudpickle.*] ignore_errors = True +; Ignore missing imports for external untyped packages + [mypy-py4j.*] ignore_missing_imports = True From 41e62f829f981f64e8405606cc1b7ac1d1bf650b Mon Sep 17 00:00:00 2001 From: zero323 Date: Sun, 15 Nov 2020 20:52:02 +0100 Subject: [PATCH 02/26] Adjust pyspark.sql.types --- python/pyspark/sql/types.pyi | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/types.pyi b/python/pyspark/sql/types.pyi index 31765e94884d7..3adf823d99a82 100644 --- a/python/pyspark/sql/types.pyi +++ b/python/pyspark/sql/types.pyi @@ -17,7 +17,8 @@ # under the License. from typing import overload -from typing import Any, Callable, Dict, Iterator, List, Optional, Union, Tuple, TypeVar +from typing import Any, Callable, Dict, Iterator, List, Optional, Union, Tuple, Type, TypeVar +from py4j.java_gateway import JavaGateway, JavaObject import datetime T = TypeVar("T") @@ -37,7 +38,7 @@ class DataType: def fromInternal(self, obj: Any) -> Any: ... class DataTypeSingleton(type): - def __call__(cls): ... + def __call__(cls: Type[T]) -> T: ... # type: ignore class NullType(DataType, metaclass=DataTypeSingleton): ... class AtomicType(DataType): ... 
@@ -85,8 +86,8 @@ class ShortType(IntegralType): class ArrayType(DataType): elementType: DataType containsNull: bool - def __init__(self, elementType=DataType, containsNull: bool = ...) -> None: ... - def simpleString(self): ... + def __init__(self, elementType: DataType, containsNull: bool = ...) -> None: ... + def simpleString(self) -> str: ... def jsonValue(self) -> Dict[str, Any]: ... @classmethod def fromJson(cls, json: Dict[str, Any]) -> ArrayType: ... @@ -197,8 +198,8 @@ class Row(tuple): class DateConverter: def can_convert(self, obj: Any) -> bool: ... - def convert(self, obj, gateway_client) -> Any: ... + def convert(self, obj: datetime.date, gateway_client: JavaGateway) -> JavaObject: ... class DatetimeConverter: - def can_convert(self, obj) -> bool: ... - def convert(self, obj, gateway_client) -> Any: ... + def can_convert(self, obj: Any) -> bool: ... + def convert(self, obj: datetime.datetime, gateway_client: JavaGateway) -> JavaObject: ... From 091def9f40c8cf748d38d1f97feb4bd32e458771 Mon Sep 17 00:00:00 2001 From: zero323 Date: Sun, 15 Nov 2020 21:19:50 +0100 Subject: [PATCH 03/26] Adjust pyspark.sql.column --- python/pyspark/sql/column.pyi | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/column.pyi b/python/pyspark/sql/column.pyi index 0fbb10053fdbf..1f63e65b3de81 100644 --- a/python/pyspark/sql/column.pyi +++ b/python/pyspark/sql/column.pyi @@ -32,7 +32,7 @@ from pyspark.sql.window import WindowSpec from py4j.java_gateway import JavaObject # type: ignore[import] class Column: - def __init__(self, JavaObject) -> None: ... + def __init__(self, jc: JavaObject) -> None: ... def __neg__(self) -> Column: ... def __add__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ... def __sub__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ... @@ -105,7 +105,11 @@ class Column: def name(self, *alias: str) -> Column: ... def cast(self, dataType: Union[DataType, str]) -> Column: ... def astype(self, dataType: Union[DataType, str]) -> Column: ... - def between(self, lowerBound, upperBound) -> Column: ... + def between( + self, + lowerBound: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral], + upperBound: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral], + ) -> Column: ... def when(self, condition: Column, value: Any) -> Column: ... def otherwise(self, value: Any) -> Column: ... def over(self, window: WindowSpec) -> Column: ... From 1967c70a72b47219657e1997bb38974ab560275f Mon Sep 17 00:00:00 2001 From: zero323 Date: Sun, 15 Nov 2020 21:27:57 +0100 Subject: [PATCH 04/26] Adjust pyspark.sql.context --- python/pyspark/sql/context.pyi | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/context.pyi b/python/pyspark/sql/context.pyi index 64927b37ac2a9..915a0fe1f6709 100644 --- a/python/pyspark/sql/context.pyi +++ b/python/pyspark/sql/context.pyi @@ -43,14 +43,14 @@ class SQLContext: sparkSession: SparkSession def __init__( self, - sparkContext, + sparkContext: SparkContext, sparkSession: Optional[SparkSession] = ..., jsqlContext: Optional[JavaObject] = ..., ) -> None: ... @classmethod def getOrCreate(cls: type, sc: SparkContext) -> SQLContext: ... def newSession(self) -> SQLContext: ... - def setConf(self, key: str, value) -> None: ... + def setConf(self, key: str, value: Union[bool, int, str]) -> None: ... def getConf(self, key: str, defaultValue: Optional[str] = ...) -> str: ... @property def udf(self) -> UDFRegistration: ... 
@@ -116,7 +116,7 @@ class SQLContext: path: Optional[str] = ..., source: Optional[str] = ..., schema: Optional[StructType] = ..., - **options + **options: str ) -> DataFrame: ... def sql(self, sqlQuery: str) -> DataFrame: ... def table(self, tableName: str) -> DataFrame: ... From 589a117204a48c0dfd6f4a8470bb957da0eb7d0f Mon Sep 17 00:00:00 2001 From: zero323 Date: Sun, 15 Nov 2020 21:38:45 +0100 Subject: [PATCH 05/26] Adjust pyspark.sql.session --- python/pyspark/sql/session.pyi | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/session.pyi b/python/pyspark/sql/session.pyi index 17ba8894c1731..6cd2d3bed2b2f 100644 --- a/python/pyspark/sql/session.pyi +++ b/python/pyspark/sql/session.pyi @@ -17,7 +17,8 @@ # under the License. from typing import overload -from typing import Any, Iterable, List, Optional, Tuple, TypeVar, Union +from typing import Any, Iterable, List, Optional, Tuple, Type, TypeVar, Union +from types import TracebackType from py4j.java_gateway import JavaObject # type: ignore[import] @@ -122,4 +123,9 @@ class SparkSession(SparkConversionMixin): def streams(self) -> StreamingQueryManager: ... def stop(self) -> None: ... def __enter__(self) -> SparkSession: ... - def __exit__(self, exc_type, exc_val, exc_tb) -> None: ... + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: ... From ba2f24a93257afca9838b493fc2a79b78376afd1 Mon Sep 17 00:00:00 2001 From: zero323 Date: Sun, 15 Nov 2020 21:51:50 +0100 Subject: [PATCH 06/26] Adjust pyspark.sql.udf --- python/pyspark/sql/udf.pyi | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/udf.pyi b/python/pyspark/sql/udf.pyi index 87c3672780037..ea61397a67ba1 100644 --- a/python/pyspark/sql/udf.pyi +++ b/python/pyspark/sql/udf.pyi @@ -18,8 +18,9 @@ from typing import Any, Callable, Optional -from pyspark.sql._typing import ColumnOrName, DataTypeOrString +from pyspark.sql._typing import ColumnOrName, DataTypeOrString, UserDefinedFunctionLike from pyspark.sql.column import Column +from pyspark.sql.types import DataType import pyspark.sql.session class UserDefinedFunction: @@ -35,7 +36,7 @@ class UserDefinedFunction: deterministic: bool = ..., ) -> None: ... @property - def returnType(self): ... + def returnType(self) -> DataType: ... def __call__(self, *cols: ColumnOrName) -> Column: ... def asNondeterministic(self) -> UserDefinedFunction: ... @@ -47,7 +48,7 @@ class UDFRegistration: name: str, f: Callable[..., Any], returnType: Optional[DataTypeOrString] = ..., - ): ... + ) -> UserDefinedFunctionLike: ... def registerJavaFunction( self, name: str, From d5a2995b1c5e7373b214bc4d900c1181408783f2 Mon Sep 17 00:00:00 2001 From: zero323 Date: Sun, 15 Nov 2020 21:58:18 +0100 Subject: [PATCH 07/26] Adjust pyspark.sql.functions --- python/pyspark/sql/functions.pyi | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/functions.pyi b/python/pyspark/sql/functions.pyi index 281c1d75436c6..252f883b5fb09 100644 --- a/python/pyspark/sql/functions.pyi +++ b/python/pyspark/sql/functions.pyi @@ -65,13 +65,13 @@ def round(col: ColumnOrName, scale: int = ...) -> Column: ... def bround(col: ColumnOrName, scale: int = ...) -> Column: ... def shiftLeft(col: ColumnOrName, numBits: int) -> Column: ... def shiftRight(col: ColumnOrName, numBits: int) -> Column: ... -def shiftRightUnsigned(col, numBits) -> Column: ... 
+def shiftRightUnsigned(col: ColumnOrName, numBits: int) -> Column: ... def spark_partition_id() -> Column: ... def expr(str: str) -> Column: ... def struct(*cols: ColumnOrName) -> Column: ... def greatest(*cols: ColumnOrName) -> Column: ... def least(*cols: Column) -> Column: ... -def when(condition: Column, value) -> Column: ... +def when(condition: Column, value: Any) -> Column: ... @overload def log(arg1: ColumnOrName) -> Column: ... @overload @@ -174,7 +174,9 @@ def create_map(*cols: ColumnOrName) -> Column: ... def array(*cols: ColumnOrName) -> Column: ... def array_contains(col: ColumnOrName, value: Any) -> Column: ... def arrays_overlap(a1: ColumnOrName, a2: ColumnOrName) -> Column: ... -def slice(x: ColumnOrName, start: Union[Column, int], length: Union[Column, int]) -> Column: ... +def slice( + x: ColumnOrName, start: Union[Column, int], length: Union[Column, int] +) -> Column: ... def array_join( col: ColumnOrName, delimiter: str, null_replacement: Optional[str] = ... ) -> Column: ... From c7089dab5baa28c6d113094737acc6332ce3c8ea Mon Sep 17 00:00:00 2001 From: zero323 Date: Sun, 15 Nov 2020 22:08:23 +0100 Subject: [PATCH 08/26] Adjust pyspark.broadcast --- python/pyspark/broadcast.pyi | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/pyspark/broadcast.pyi b/python/pyspark/broadcast.pyi index 4b019a509a003..944cb06d4178c 100644 --- a/python/pyspark/broadcast.pyi +++ b/python/pyspark/broadcast.pyi @@ -17,7 +17,7 @@ # under the License. import threading -from typing import Any, Dict, Generic, Optional, TypeVar +from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeVar T = TypeVar("T") @@ -32,14 +32,14 @@ class Broadcast(Generic[T]): path: Optional[Any] = ..., sock_file: Optional[Any] = ..., ) -> None: ... - def dump(self, value: Any, f: Any) -> None: ... - def load_from_path(self, path: Any): ... - def load(self, file: Any): ... + def dump(self, value: T, f: Any) -> None: ... + def load_from_path(self, path: Any) -> T: ... + def load(self, file: Any) -> T: ... @property def value(self) -> T: ... def unpersist(self, blocking: bool = ...) -> None: ... def destroy(self, blocking: bool = ...) -> None: ... - def __reduce__(self): ... + def __reduce__(self) -> Tuple[Callable[[int], T], Tuple[int]]: ... class BroadcastPickleRegistry(threading.local): def __init__(self) -> None: ... From 51ded72ca19d9d0216806222e19728c314c67ce2 Mon Sep 17 00:00:00 2001 From: zero323 Date: Sun, 15 Nov 2020 22:15:41 +0100 Subject: [PATCH 09/26] Adjust pyspark.context --- python/pyspark/context.pyi | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/python/pyspark/context.pyi b/python/pyspark/context.pyi index 2789a38b3be9f..640a69cad08ab 100644 --- a/python/pyspark/context.pyi +++ b/python/pyspark/context.pyi @@ -16,7 +16,19 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + NoReturn, + Optional, + Tuple, + Type, + TypeVar, +) +from types import TracebackType from py4j.java_gateway import JavaGateway, JavaObject # type: ignore[import] @@ -51,9 +63,14 @@ class SparkContext: jsc: Optional[JavaObject] = ..., profiler_cls: type = ..., ) -> None: ... - def __getnewargs__(self): ... - def __enter__(self): ... - def __exit__(self, type, value, trace): ... + def __getnewargs__(self) -> NoReturn: ... + def __enter__(self) -> SparkContext: ... 
+ def __exit__( + self, + type: Optional[Type[BaseException]], + value: Optional[BaseException], + trace: Optional[TracebackType], + ) -> None: ... @classmethod def getOrCreate(cls, conf: Optional[SparkConf] = ...) -> SparkContext: ... def setLogLevel(self, logLevel: str) -> None: ... From 0dde1116c5f3a9d9746c684558e2b394938cdee3 Mon Sep 17 00:00:00 2001 From: zero323 Date: Sun, 15 Nov 2020 22:38:48 +0100 Subject: [PATCH 10/26] Adjust pyspark.ml.linalg --- python/pyspark/ml/linalg/__init__.pyi | 36 ++++++++++++--------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/python/pyspark/ml/linalg/__init__.pyi b/python/pyspark/ml/linalg/__init__.pyi index a576b30aec308..b4fba8823b678 100644 --- a/python/pyspark/ml/linalg/__init__.pyi +++ b/python/pyspark/ml/linalg/__init__.pyi @@ -17,7 +17,7 @@ # under the License. from typing import overload -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, NoReturn, Optional, Tuple, Type, Union from pyspark.ml import linalg as newlinalg # noqa: F401 from pyspark.sql.types import StructType, UserDefinedType @@ -45,7 +45,7 @@ class MatrixUDT(UserDefinedType): @classmethod def scalaUDT(cls) -> str: ... def serialize( - self, obj + self, obj: Matrix ) -> Tuple[ int, int, int, Optional[List[int]], Optional[List[int]], List[float], bool ]: ... @@ -64,9 +64,7 @@ class DenseVector(Vector): def __init__(self, __arr: bytes) -> None: ... @overload def __init__(self, __arr: Iterable[float]) -> None: ... - @staticmethod - def parse(s) -> DenseVector: ... - def __reduce__(self) -> Tuple[type, bytes]: ... + def __reduce__(self) -> Tuple[Type[DenseVector], bytes]: ... def numNonzeros(self) -> int: ... def norm(self, p: Union[float, str]) -> float64: ... def dot(self, other: Iterable[float]) -> float64: ... @@ -112,16 +110,14 @@ class SparseVector(Vector): def __init__(self, size: int, __map: Dict[int, float]) -> None: ... def numNonzeros(self) -> int: ... def norm(self, p: Union[float, str]) -> float64: ... - def __reduce__(self): ... - @staticmethod - def parse(s: str) -> SparseVector: ... + def __reduce__(self) -> Tuple[Type[SparseVector], Tuple[int, bytes, bytes]]: ... def dot(self, other: Iterable[float]) -> float64: ... def squared_distance(self, other: Iterable[float]) -> float64: ... def toArray(self) -> ndarray: ... def __len__(self) -> int: ... - def __eq__(self, other) -> bool: ... + def __eq__(self, other: Any) -> bool: ... def __getitem__(self, index: int) -> float64: ... - def __ne__(self, other) -> bool: ... + def __ne__(self, other: Any) -> bool: ... def __hash__(self) -> int: ... class Vectors: @@ -144,13 +140,13 @@ class Vectors: def sparse(size: int, __map: Dict[int, float]) -> SparseVector: ... @overload @staticmethod - def dense(self, *elements: float) -> DenseVector: ... + def dense(*elements: float) -> DenseVector: ... @overload @staticmethod - def dense(self, __arr: bytes) -> DenseVector: ... + def dense(__arr: bytes) -> DenseVector: ... @overload @staticmethod - def dense(self, __arr: Iterable[float]) -> DenseVector: ... + def dense(__arr: Iterable[float]) -> DenseVector: ... @staticmethod def stringify(vector: Vector) -> str: ... @staticmethod @@ -158,8 +154,6 @@ class Vectors: @staticmethod def norm(vector: Vector, p: Union[float, str]) -> float64: ... @staticmethod - def parse(s: str) -> Vector: ... - @staticmethod def zeros(size: int) -> DenseVector: ... 
class Matrix: @@ -170,7 +164,7 @@ class Matrix: def __init__( self, numRows: int, numCols: int, isTransposed: bool = ... ) -> None: ... - def toArray(self): ... + def toArray(self) -> NoReturn: ... class DenseMatrix(Matrix): values: Any @@ -186,11 +180,11 @@ class DenseMatrix(Matrix): values: Iterable[float], isTransposed: bool = ..., ) -> None: ... - def __reduce__(self) -> Tuple[type, Tuple[int, int, bytes, int]]: ... + def __reduce__(self) -> Tuple[Type[DenseMatrix], Tuple[int, int, bytes, int]]: ... def toArray(self) -> ndarray: ... def toSparse(self) -> SparseMatrix: ... def __getitem__(self, indices: Tuple[int, int]) -> float64: ... - def __eq__(self, other) -> bool: ... + def __eq__(self, other: Any) -> bool: ... class SparseMatrix(Matrix): colPtrs: ndarray @@ -216,11 +210,13 @@ class SparseMatrix(Matrix): values: Iterable[float], isTransposed: bool = ..., ) -> None: ... - def __reduce__(self) -> Tuple[type, Tuple[int, int, bytes, bytes, bytes, int]]: ... + def __reduce__( + self, + ) -> Tuple[Type[SparseMatrix], Tuple[int, int, bytes, bytes, bytes, int]]: ... def __getitem__(self, indices: Tuple[int, int]) -> float64: ... def toArray(self) -> ndarray: ... def toDense(self) -> DenseMatrix: ... - def __eq__(self, other) -> bool: ... + def __eq__(self, other: Any) -> bool: ... class Matrices: @overload From 553d6a8697e712221e5eab9762cc61f49c63531b Mon Sep 17 00:00:00 2001 From: zero323 Date: Sun, 15 Nov 2020 22:53:04 +0100 Subject: [PATCH 11/26] Adjust pyspark.mllib.linalg --- python/pyspark/mllib/linalg/__init__.pyi | 45 +++++++++++++++--------- 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/python/pyspark/mllib/linalg/__init__.pyi b/python/pyspark/mllib/linalg/__init__.pyi index c0719c535c8f4..60d16b26f3590 100644 --- a/python/pyspark/mllib/linalg/__init__.pyi +++ b/python/pyspark/mllib/linalg/__init__.pyi @@ -17,7 +17,18 @@ # under the License. from typing import overload -from typing import Any, Dict, Generic, Iterable, List, Optional, Tuple, TypeVar, Union +from typing import ( + Any, + Dict, + Generic, + Iterable, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, +) from pyspark.ml import linalg as newlinalg from pyspark.sql.types import StructType, UserDefinedType from numpy import float64, ndarray # type: ignore[import] @@ -46,7 +57,7 @@ class MatrixUDT(UserDefinedType): @classmethod def scalaUDT(cls) -> str: ... def serialize( - self, obj + self, obj: Matrix ) -> Tuple[ int, int, int, Optional[List[int]], Optional[List[int]], List[float], bool ]: ... @@ -67,8 +78,8 @@ class DenseVector(Vector): @overload def __init__(self, __arr: Iterable[float]) -> None: ... @staticmethod - def parse(s) -> DenseVector: ... - def __reduce__(self) -> Tuple[type, bytes]: ... + def parse(s: str) -> DenseVector: ... + def __reduce__(self) -> Tuple[Type[DenseVector], bytes]: ... def numNonzeros(self) -> int: ... def norm(self, p: Union[float, str]) -> float64: ... def dot(self, other: Iterable[float]) -> float64: ... @@ -115,7 +126,7 @@ class SparseVector(Vector): def __init__(self, size: int, __map: Dict[int, float]) -> None: ... def numNonzeros(self) -> int: ... def norm(self, p: Union[float, str]) -> float64: ... - def __reduce__(self): ... + def __reduce__(self) -> Tuple[Type[SparseVector], Tuple[int, bytes, bytes]]: ... @staticmethod def parse(s: str) -> SparseVector: ... def dot(self, other: Iterable[float]) -> float64: ... @@ -123,9 +134,9 @@ class SparseVector(Vector): def toArray(self) -> ndarray: ... def asML(self) -> newlinalg.SparseVector: ... 
def __len__(self) -> int: ... - def __eq__(self, other) -> bool: ... + def __eq__(self, other: Any) -> bool: ... def __getitem__(self, index: int) -> float64: ... - def __ne__(self, other) -> bool: ... + def __ne__(self, other: Any) -> bool: ... def __hash__(self) -> int: ... class Vectors: @@ -148,13 +159,13 @@ class Vectors: def sparse(size: int, __map: Dict[int, float]) -> SparseVector: ... @overload @staticmethod - def dense(self, *elements: float) -> DenseVector: ... + def dense(*elements: float) -> DenseVector: ... @overload @staticmethod - def dense(self, __arr: bytes) -> DenseVector: ... + def dense(__arr: bytes) -> DenseVector: ... @overload @staticmethod - def dense(self, __arr: Iterable[float]) -> DenseVector: ... + def dense(__arr: Iterable[float]) -> DenseVector: ... @staticmethod def fromML(vec: newlinalg.DenseVector) -> DenseVector: ... @staticmethod @@ -176,8 +187,8 @@ class Matrix: def __init__( self, numRows: int, numCols: int, isTransposed: bool = ... ) -> None: ... - def toArray(self): ... - def asML(self): ... + def toArray(self) -> ndarray: ... + def asML(self) -> newlinalg.Matrix: ... class DenseMatrix(Matrix): values: Any @@ -193,12 +204,12 @@ class DenseMatrix(Matrix): values: Iterable[float], isTransposed: bool = ..., ) -> None: ... - def __reduce__(self) -> Tuple[type, Tuple[int, int, bytes, int]]: ... + def __reduce__(self) -> Tuple[Type[DenseMatrix], Tuple[int, int, bytes, int]]: ... def toArray(self) -> ndarray: ... def toSparse(self) -> SparseMatrix: ... def asML(self) -> newlinalg.DenseMatrix: ... def __getitem__(self, indices: Tuple[int, int]) -> float64: ... - def __eq__(self, other) -> bool: ... + def __eq__(self, other: Any) -> bool: ... class SparseMatrix(Matrix): colPtrs: ndarray @@ -224,12 +235,14 @@ class SparseMatrix(Matrix): values: Iterable[float], isTransposed: bool = ..., ) -> None: ... - def __reduce__(self) -> Tuple[type, Tuple[int, int, bytes, bytes, bytes, int]]: ... + def __reduce__( + self, + ) -> Tuple[Type[SparseMatrix], Tuple[int, int, bytes, bytes, bytes, int]]: ... def __getitem__(self, indices: Tuple[int, int]) -> float64: ... def toArray(self) -> ndarray: ... def toDense(self) -> DenseMatrix: ... def asML(self) -> newlinalg.SparseMatrix: ... - def __eq__(self, other) -> bool: ... + def __eq__(self, other: Any) -> bool: ... class Matrices: @overload From 7b83aaab78042c92cc5bc938215161d19e585208 Mon Sep 17 00:00:00 2001 From: zero323 Date: Sun, 15 Nov 2020 22:56:16 +0100 Subject: [PATCH 12/26] Adjust pyspark.rdd --- python/pyspark/rdd.pyi | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/pyspark/rdd.pyi b/python/pyspark/rdd.pyi index 35c49e952b0cd..a277cd9f7edae 100644 --- a/python/pyspark/rdd.pyi +++ b/python/pyspark/rdd.pyi @@ -85,12 +85,16 @@ class PythonEvalType: SQL_COGROUPED_MAP_PANDAS_UDF: PandasCogroupedMapUDFType class BoundedFloat(float): - def __new__(cls, mean: float, confidence: float, low: float, high: float): ... + def __new__( + cls, mean: float, confidence: float, low: float, high: float + ) -> BoundedFloat: ... class Partitioner: numPartitions: int partitionFunc: Callable[[Any], int] - def __init__(self, numPartitions, partitionFunc) -> None: ... + def __init__( + self, numPartitions: int, partitionFunc: Callable[[Any], int] + ) -> None: ... def __eq__(self, other: Any) -> bool: ... def __call__(self, k: Any) -> int: ... 
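
The annotations from patches 01-12 can be exercised with a small end-to-end script; the sketch below is illustrative (the file name, local master URL, and sample data are assumptions), but it should run against a local PySpark install and type-check cleanly under the stricter settings from patch 01, e.g. mypy --config-file python/mypy.ini stub_check.py:

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F

    # SparkSession.__enter__/__exit__ are annotated in patch 05, so the
    # context-manager form below is visible to mypy as well.
    with SparkSession.builder.master("local[1]").appName("stub-check").getOrCreate() as spark:
        df = spark.createDataFrame([(1, "a"), (5, "b")], ["id", "label"])
        # Column.between (patch 03) and functions.when (patch 07) now carry
        # precise argument types, so a wrong argument (e.g. a dict) is flagged.
        buckets = df.select(
            F.when(df.id.between(1, 3), F.lit("low")).otherwise(F.lit("high")).alias("bucket")
        )
        buckets.show()
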
From dcf8e3d1c022e58635df2d8d73412d0b6849d58d Mon Sep 17 00:00:00 2001 From: zero323 Date: Sun, 15 Nov 2020 22:59:02 +0100 Subject: [PATCH 13/26] Adjust pyspark.ml.feature --- python/pyspark/ml/feature.pyi | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/pyspark/ml/feature.pyi b/python/pyspark/ml/feature.pyi index f5b12a5b2ffc6..c1d3669b2479d 100644 --- a/python/pyspark/ml/feature.pyi +++ b/python/pyspark/ml/feature.pyi @@ -100,9 +100,9 @@ class _LSHParams(HasInputCol, HasOutputCol): def getNumHashTables(self) -> int: ... class _LSH(Generic[JM], JavaEstimator[JM], _LSHParams, JavaMLReadable, JavaMLWritable): - def setNumHashTables(self: P, value) -> P: ... - def setInputCol(self: P, value) -> P: ... - def setOutputCol(self: P, value) -> P: ... + def setNumHashTables(self: P, value: int) -> P: ... + def setInputCol(self: P, value: str) -> P: ... + def setOutputCol(self: P, value: str) -> P: ... class _LSHModel(JavaModel, _LSHParams): def setInputCol(self: P, value: str) -> P: ... @@ -1518,7 +1518,7 @@ class ChiSqSelector( fpr: float = ..., fdr: float = ..., fwe: float = ... - ): ... + ) -> ChiSqSelector: ... def setSelectorType(self, value: str) -> ChiSqSelector: ... def setNumTopFeatures(self, value: int) -> ChiSqSelector: ... def setPercentile(self, value: float) -> ChiSqSelector: ... @@ -1615,7 +1615,7 @@ class VarianceThresholdSelector( featuresCol: str = ..., outputCol: Optional[str] = ..., varianceThreshold: float = ..., - ): ... + ) -> VarianceThresholdSelector: ... def setVarianceThreshold(self, value: float) -> VarianceThresholdSelector: ... def setFeaturesCol(self, value: str) -> VarianceThresholdSelector: ... def setOutputCol(self, value: str) -> VarianceThresholdSelector: ... From 4efeeae6e47f032c689a62b02376612217d0dc82 Mon Sep 17 00:00:00 2001 From: zero323 Date: Sun, 15 Nov 2020 23:05:08 +0100 Subject: [PATCH 14/26] Adjust pyspark.ml.evaluation --- python/pyspark/ml/evaluation.pyi | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/python/pyspark/ml/evaluation.pyi b/python/pyspark/ml/evaluation.pyi index ea0a9f045cd6a..55a3ae2774115 100644 --- a/python/pyspark/ml/evaluation.pyi +++ b/python/pyspark/ml/evaluation.pyi @@ -39,9 +39,12 @@ from pyspark.ml.param.shared import ( HasWeightCol, ) from pyspark.ml.util import JavaMLReadable, JavaMLWritable +from pyspark.sql.dataframe import DataFrame class Evaluator(Params, metaclass=abc.ABCMeta): - def evaluate(self, dataset, params: Optional[ParamMap] = ...) -> float: ... + def evaluate( + self, dataset: DataFrame, params: Optional[ParamMap] = ... + ) -> float: ... def isLargerBetter(self) -> bool: ... class JavaEvaluator(JavaParams, Evaluator, metaclass=abc.ABCMeta): @@ -75,16 +78,15 @@ class BinaryClassificationEvaluator( def setLabelCol(self, value: str) -> BinaryClassificationEvaluator: ... def setRawPredictionCol(self, value: str) -> BinaryClassificationEvaluator: ... def setWeightCol(self, value: str) -> BinaryClassificationEvaluator: ... - -def setParams( - self, - *, - rawPredictionCol: str = ..., - labelCol: str = ..., - metricName: BinaryClassificationEvaluatorMetricType = ..., - weightCol: Optional[str] = ..., - numBins: int = ... -) -> BinaryClassificationEvaluator: ... + def setParams( + self, + *, + rawPredictionCol: str = ..., + labelCol: str = ..., + metricName: BinaryClassificationEvaluatorMetricType = ..., + weightCol: Optional[str] = ..., + numBins: int = ... + ) -> BinaryClassificationEvaluator: ... 
class RegressionEvaluator( JavaEvaluator, From a089d5cf19e66c038824058746c1855961e4f511 Mon Sep 17 00:00:00 2001 From: zero323 Date: Sun, 15 Nov 2020 23:07:02 +0100 Subject: [PATCH 15/26] Adjust pyspark.ml.regression --- python/pyspark/ml/regression.pyi | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/pyspark/ml/regression.pyi b/python/pyspark/ml/regression.pyi index 5cb0e7a5092f7..c9f9b8c32d93d 100644 --- a/python/pyspark/ml/regression.pyi +++ b/python/pyspark/ml/regression.pyi @@ -749,10 +749,10 @@ class _FactorizationMachinesParams( initStd: Param[float] solver: Param[str] def __init__(self, *args: Any): ... - def getFactorSize(self): ... - def getFitLinear(self): ... - def getMiniBatchFraction(self): ... - def getInitStd(self): ... + def getFactorSize(self) -> int: ... + def getFitLinear(self) -> bool: ... + def getMiniBatchFraction(self) -> float: ... + def getInitStd(self) -> float: ... class FMRegressor( _JavaRegressor[FMRegressionModel], From ee8a6549942859fe2f1372105a30ca443751ac87 Mon Sep 17 00:00:00 2001 From: zero323 Date: Sun, 15 Nov 2020 23:11:48 +0100 Subject: [PATCH 16/26] Adjust pyspark.ml.classification --- python/pyspark/ml/classification.pyi | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/ml/classification.pyi b/python/pyspark/ml/classification.pyi index 4bde851bb1e0d..c44176a13a69b 100644 --- a/python/pyspark/ml/classification.pyi +++ b/python/pyspark/ml/classification.pyi @@ -107,7 +107,7 @@ class _JavaProbabilisticClassifier( class _JavaProbabilisticClassificationModel( ProbabilisticClassificationModel, _JavaClassificationModel[T] ): - def predictProbability(self, value: Any): ... + def predictProbability(self, value: Vector) -> Vector: ... class _ClassificationSummary(JavaWrapper): @property @@ -543,7 +543,7 @@ class RandomForestClassificationModel( @property def trees(self) -> List[DecisionTreeClassificationModel]: ... def summary(self) -> RandomForestClassificationTrainingSummary: ... - def evaluate(self, dataset) -> RandomForestClassificationSummary: ... + def evaluate(self, dataset: DataFrame) -> RandomForestClassificationSummary: ... class RandomForestClassificationSummary(_ClassificationSummary): ... class RandomForestClassificationTrainingSummary( @@ -891,7 +891,7 @@ class FMClassifier( solver: str = ..., thresholds: Optional[Any] = ..., seed: Optional[Any] = ..., - ): ... + ) -> FMClassifier: ... def setFactorSize(self, value: int) -> FMClassifier: ... def setFitLinear(self, value: bool) -> FMClassifier: ... def setMiniBatchFraction(self, value: float) -> FMClassifier: ... From 6788a4b79d3a9c96025c58c20660825eebfbac09 Mon Sep 17 00:00:00 2001 From: zero323 Date: Sun, 15 Nov 2020 23:14:01 +0100 Subject: [PATCH 17/26] Adjust pyspark.resource.profile --- python/pyspark/resource/profile.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/resource/profile.pyi b/python/pyspark/resource/profile.pyi index 6763baf6590a3..04838692436df 100644 --- a/python/pyspark/resource/profile.pyi +++ b/python/pyspark/resource/profile.pyi @@ -49,7 +49,7 @@ class ResourceProfileBuilder: def __init__(self) -> None: ... def require( self, resourceRequest: Union[ExecutorResourceRequest, TaskResourceRequests] - ): ... + ) -> ResourceProfileBuilder: ... def clearExecutorResourceRequests(self) -> None: ... def clearTaskResourceRequests(self) -> None: ... 
@property From f2a70b5643ea3507f29a9155e5e11bc86f929bb4 Mon Sep 17 00:00:00 2001 From: zero323 Date: Sun, 15 Nov 2020 23:37:43 +0100 Subject: [PATCH 18/26] Adjust pyspark.{ml, mllib}.common --- python/pyspark/ml/common.pyi | 10 ++++++++-- python/pyspark/mllib/common.pyi | 20 ++++++++++++++------ 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/python/pyspark/ml/common.pyi b/python/pyspark/ml/common.pyi index 7bf0ed6183d8a..a38fc5734f466 100644 --- a/python/pyspark/ml/common.pyi +++ b/python/pyspark/ml/common.pyi @@ -16,5 +16,11 @@ # specific language governing permissions and limitations # under the License. -def callJavaFunc(sc, func, *args): ... -def inherit_doc(cls): ... +from typing import Any, TypeVar + +import pyspark.context + +C = TypeVar("C", bound=type) + +def callJavaFunc(sc: pyspark.context.SparkContext, func: Any, *args: Any) -> Any: ... +def inherit_doc(cls: C) -> C: ... diff --git a/python/pyspark/mllib/common.pyi b/python/pyspark/mllib/common.pyi index 1df308b91b5a1..daba212d93633 100644 --- a/python/pyspark/mllib/common.pyi +++ b/python/pyspark/mllib/common.pyi @@ -16,12 +16,20 @@ # specific language governing permissions and limitations # under the License. -def callJavaFunc(sc, func, *args): ... -def callMLlibFunc(name, *args): ... +from typing import Any, TypeVar + +import pyspark.context + +from py4j.java_gateway import JavaObject + +C = TypeVar("C", bound=type) + +def callJavaFunc(sc: pyspark.context.SparkContext, func: Any, *args: Any) -> Any: ... +def callMLlibFunc(name: str, *args: Any) -> Any: ... class JavaModelWrapper: - def __init__(self, java_model) -> None: ... - def __del__(self): ... - def call(self, name, *a): ... + def __init__(self, java_model: JavaObject) -> None: ... + def __del__(self) -> None: ... + def call(self, name: str, *a: Any) -> Any: ... -def inherit_doc(cls): ... +def inherit_doc(cls: C) -> C: ... From 461a0a241e074cee49f02bdc1cb4c6f6b01f19d0 Mon Sep 17 00:00:00 2001 From: zero323 Date: Sun, 15 Nov 2020 23:44:00 +0100 Subject: [PATCH 19/26] Adjust pyspark.streaming.dstream --- python/pyspark/streaming/dstream.pyi | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/python/pyspark/streaming/dstream.pyi b/python/pyspark/streaming/dstream.pyi index 7b76ce4c65233..1521d838fc2b5 100644 --- a/python/pyspark/streaming/dstream.pyi +++ b/python/pyspark/streaming/dstream.pyi @@ -30,9 +30,12 @@ from typing import ( ) import datetime from pyspark.rdd import RDD +import pyspark.serializers from pyspark.storagelevel import StorageLevel import pyspark.streaming.context +from py4j.java_gateway import JavaObject + S = TypeVar("S") T = TypeVar("T") U = TypeVar("U") @@ -42,7 +45,12 @@ V = TypeVar("V") class DStream(Generic[T]): is_cached: bool is_checkpointed: bool - def __init__(self, jdstream, ssc, jrdd_deserializer) -> None: ... + def __init__( + self, + jdstream: JavaObject, + ssc: pyspark.streaming.context.StreamingContext, + jrdd_deserializer: pyspark.serializers.Serializer, + ) -> None: ... def context(self) -> pyspark.streaming.context.StreamingContext: ... def count(self) -> DStream[int]: ... def filter(self, f: Callable[[T], bool]) -> DStream[T]: ... 
From 3bc2b3d73581a9d674adc3e7e8da455a8a0aa734 Mon Sep 17 00:00:00 2001 From: zero323 Date: Sun, 15 Nov 2020 23:47:29 +0100 Subject: [PATCH 20/26] Adjust pyspark.mllib.random --- python/pyspark/mllib/random.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/mllib/random.pyi b/python/pyspark/mllib/random.pyi index dc5f4701614da..ec83170625c74 100644 --- a/python/pyspark/mllib/random.pyi +++ b/python/pyspark/mllib/random.pyi @@ -90,7 +90,7 @@ class RandomRDDs: def logNormalVectorRDD( sc: SparkContext, mean: float, - std, + std: float, numRows: int, numCols: int, numPartitions: Optional[int] = ..., From d150ce00ebb2ba69eef719724bc0d3ccbcda3010 Mon Sep 17 00:00:00 2001 From: zero323 Date: Sun, 15 Nov 2020 23:50:06 +0100 Subject: [PATCH 21/26] Adjust pyspark.mllib.recommendation --- python/pyspark/mllib/recommendation.pyi | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/mllib/recommendation.pyi b/python/pyspark/mllib/recommendation.pyi index e2f15494209e9..4fea0acf3c1f9 100644 --- a/python/pyspark/mllib/recommendation.pyi +++ b/python/pyspark/mllib/recommendation.pyi @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Type, Union import array from collections import namedtuple @@ -27,7 +27,7 @@ from pyspark.mllib.common import JavaModelWrapper from pyspark.mllib.util import JavaLoader, JavaSaveable class Rating(namedtuple("Rating", ["user", "product", "rating"])): - def __reduce__(self): ... + def __reduce__(self) -> Tuple[Type[Rating], Tuple[int, int, float]]: ... class MatrixFactorizationModel( JavaModelWrapper, JavaSaveable, JavaLoader[MatrixFactorizationModel] From 0dcf608fea4058dedbfc0732ab510ef411e15a52 Mon Sep 17 00:00:00 2001 From: zero323 Date: Sun, 15 Nov 2020 23:56:34 +0100 Subject: [PATCH 22/26] Adjust pyspark.mllib.clustering --- python/pyspark/mllib/clustering.pyi | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/pyspark/mllib/clustering.pyi b/python/pyspark/mllib/clustering.pyi index 1c3eba17e201c..b4f349612f0fe 100644 --- a/python/pyspark/mllib/clustering.pyi +++ b/python/pyspark/mllib/clustering.pyi @@ -63,7 +63,7 @@ class BisectingKMeans: class KMeansModel(Saveable, Loader[KMeansModel]): centers: List[ndarray] - def __init__(self, centers: List[ndarray]) -> None: ... + def __init__(self, centers: List[VectorLike]) -> None: ... @property def clusterCenters(self) -> List[ndarray]: ... @property @@ -144,7 +144,9 @@ class PowerIterationClustering: class Assignment(NamedTuple("Assignment", [("id", int), ("cluster", int)])): ... class StreamingKMeansModel(KMeansModel): - def __init__(self, clusterCenters, clusterWeights) -> None: ... + def __init__( + self, clusterCenters: List[VectorLike], clusterWeights: VectorLike + ) -> None: ... @property def clusterWeights(self) -> List[float64]: ... 
centers: ndarray From 9cbdf0e355f3d9187aa3de7e13fa588a5d19d32d Mon Sep 17 00:00:00 2001 From: zero323 Date: Sun, 15 Nov 2020 23:59:29 +0100 Subject: [PATCH 23/26] Adjust pyspark.mllib.classification --- python/pyspark/mllib/classification.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/mllib/classification.pyi b/python/pyspark/mllib/classification.pyi index c51882c87bfc2..967b0a9f289dd 100644 --- a/python/pyspark/mllib/classification.pyi +++ b/python/pyspark/mllib/classification.pyi @@ -118,7 +118,7 @@ class NaiveBayesModel(Saveable, Loader[NaiveBayesModel]): labels: ndarray pi: ndarray theta: ndarray - def __init__(self, labels, pi, theta) -> None: ... + def __init__(self, labels: ndarray, pi: ndarray, theta: ndarray) -> None: ... @overload def predict(self, x: VectorLike) -> float64: ... @overload From 1649576c6f16ab041ae252fa6198a928c926de21 Mon Sep 17 00:00:00 2001 From: zero323 Date: Mon, 16 Nov 2020 00:01:50 +0100 Subject: [PATCH 24/26] Adjust pyspark.mllib.stat._statistics --- python/pyspark/mllib/stat/_statistics.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/mllib/stat/_statistics.pyi b/python/pyspark/mllib/stat/_statistics.pyi index 4d2701d486881..3834d51639eb2 100644 --- a/python/pyspark/mllib/stat/_statistics.pyi +++ b/python/pyspark/mllib/stat/_statistics.pyi @@ -65,5 +65,5 @@ class Statistics: def chiSqTest(observed: RDD[LabeledPoint]) -> List[ChiSqTestResult]: ... @staticmethod def kolmogorovSmirnovTest( - data, distName: Literal["norm"] = ..., *params: float + data: RDD[float], distName: Literal["norm"] = ..., *params: float ) -> KolmogorovSmirnovTestResult: ... From 426c7bbc31f358a0faaaeb0137d8ddcea5997388 Mon Sep 17 00:00:00 2001 From: zero323 Date: Mon, 16 Nov 2020 03:43:49 +0100 Subject: [PATCH 25/26] Drop unused imports --- python/pyspark/streaming/context.pyi | 2 +- python/pyspark/streaming/kinesis.pyi | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/streaming/context.pyi b/python/pyspark/streaming/context.pyi index 026163fc9a1db..117a6742e6b6b 100644 --- a/python/pyspark/streaming/context.pyi +++ b/python/pyspark/streaming/context.pyi @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Callable, List, Optional, TypeVar, Union +from typing import Any, Callable, List, Optional, TypeVar from py4j.java_gateway import JavaObject # type: ignore[import] diff --git a/python/pyspark/streaming/kinesis.pyi b/python/pyspark/streaming/kinesis.pyi index af7cd6f6ec13c..399c37f869620 100644 --- a/python/pyspark/streaming/kinesis.pyi +++ b/python/pyspark/streaming/kinesis.pyi @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. 
-from typing import Any, Callable, Optional, TypeVar +from typing import Callable, Optional, TypeVar from pyspark.storagelevel import StorageLevel from pyspark.streaming.context import StreamingContext from pyspark.streaming.dstream import DStream From 9182a9863c057c3b1974e613c1152c6e434a7b52 Mon Sep 17 00:00:00 2001 From: zero323 Date: Mon, 23 Nov 2020 13:01:18 +0100 Subject: [PATCH 26/26] Fill generics, when useful --- python/pyspark/ml/feature.pyi | 10 ++++++++-- python/pyspark/ml/pipeline.pyi | 4 ++-- python/pyspark/ml/regression.pyi | 2 +- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/python/pyspark/ml/feature.pyi b/python/pyspark/ml/feature.pyi index c1d3669b2479d..4999defdf8a70 100644 --- a/python/pyspark/ml/feature.pyi +++ b/python/pyspark/ml/feature.pyi @@ -1602,7 +1602,10 @@ class _VarianceThresholdSelectorParams(HasFeaturesCol, HasOutputCol): def getVarianceThreshold(self) -> float: ... class VarianceThresholdSelector( - JavaEstimator, _VarianceThresholdSelectorParams, JavaMLReadable, JavaMLWritable + JavaEstimator[VarianceThresholdSelectorModel], + _VarianceThresholdSelectorParams, + JavaMLReadable[VarianceThresholdSelector], + JavaMLWritable, ): def __init__( self, @@ -1621,7 +1624,10 @@ class VarianceThresholdSelector( def setOutputCol(self, value: str) -> VarianceThresholdSelector: ... class VarianceThresholdSelectorModel( - JavaModel, _VarianceThresholdSelectorParams, JavaMLReadable, JavaMLWritable + JavaModel, + _VarianceThresholdSelectorParams, + JavaMLReadable[VarianceThresholdSelectorModel], + JavaMLWritable, ): def setFeaturesCol(self, value: str) -> VarianceThresholdSelectorModel: ... def setOutputCol(self, value: str) -> VarianceThresholdSelectorModel: ... diff --git a/python/pyspark/ml/pipeline.pyi b/python/pyspark/ml/pipeline.pyi index 44680586d70d1..f47e9e012ae14 100644 --- a/python/pyspark/ml/pipeline.pyi +++ b/python/pyspark/ml/pipeline.pyi @@ -51,7 +51,7 @@ class PipelineWriter(MLWriter): def __init__(self, instance: Pipeline) -> None: ... def saveImpl(self, path: str) -> None: ... -class PipelineReader(MLReader): +class PipelineReader(MLReader[Pipeline]): cls: Type[Pipeline] def __init__(self, cls: Type[Pipeline]) -> None: ... def load(self, path: str) -> Pipeline: ... @@ -61,7 +61,7 @@ class PipelineModelWriter(MLWriter): def __init__(self, instance: PipelineModel) -> None: ... def saveImpl(self, path: str) -> None: ... -class PipelineModelReader(MLReader): +class PipelineModelReader(MLReader[PipelineModel]): cls: Type[PipelineModel] def __init__(self, cls: Type[PipelineModel]) -> None: ... def load(self, path: str) -> PipelineModel: ... diff --git a/python/pyspark/ml/regression.pyi b/python/pyspark/ml/regression.pyi index c9f9b8c32d93d..b8f1e61859c72 100644 --- a/python/pyspark/ml/regression.pyi +++ b/python/pyspark/ml/regression.pyi @@ -414,7 +414,7 @@ class RandomForestRegressionModel( _TreeEnsembleModel, _RandomForestRegressorParams, JavaMLWritable, - JavaMLReadable, + JavaMLReadable[RandomForestRegressionModel], ): @property def trees(self) -> List[DecisionTreeRegressionModel]: ...
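
The effect of patch 26 is easiest to see at a call site: with JavaMLReadable and the estimator/model bases parameterized, mypy infers concrete model types instead of falling back to Any. A minimal sketch follows (the path is illustrative and assumes a model was saved there earlier):

    from pyspark.ml.regression import RandomForestRegressionModel

    # With JavaMLReadable[RandomForestRegressionModel] filled in, load() is
    # inferred as returning RandomForestRegressionModel rather than Any, so
    # model.trees keeps its List[DecisionTreeRegressionModel] element type.
    model = RandomForestRegressionModel.load("/tmp/rf_model")
    print(len(model.trees))
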