@@ -869,6 +869,12 @@ def _replace_coerce(
869869
870870 # ---------------------------------------------------------------------
871871
872+ def _maybe_squeeze_arg (self , arg : np .ndarray ) -> np .ndarray :
873+ """
874+ For compatibility with 1D-only ExtensionArrays.
875+ """
876+ return arg
877+
872878 def setitem (self , indexer , value ):
873879 """
874880 Attempt self.values[indexer] = value, possibly creating a new array.
@@ -1314,6 +1320,46 @@ class EABackedBlock(Block):
13141320
13151321 values : ExtensionArray
13161322
1323+ def putmask (self , mask , new ) -> list [Block ]:
1324+ """
1325+ See Block.putmask.__doc__
1326+ """
1327+ mask = extract_bool_array (mask )
1328+
1329+ values = self .values
1330+
1331+ mask = self ._maybe_squeeze_arg (mask )
1332+
1333+ try :
1334+ # Caller is responsible for ensuring matching lengths
1335+ values ._putmask (mask , new )
1336+ except (TypeError , ValueError ) as err :
1337+ if isinstance (err , ValueError ) and "Timezones don't match" not in str (err ):
1338+ # TODO(2.0): remove catching ValueError at all since
1339+ # DTA raising here is deprecated
1340+ raise
1341+
1342+ if is_interval_dtype (self .dtype ):
1343+ # Discussion about what we want to support in the general
1344+ # case GH#39584
1345+ blk = self .coerce_to_target_dtype (new )
1346+ if blk .dtype == _dtype_obj :
1347+ # For now at least, only support casting e.g.
1348+ # Interval[int64]->Interval[float64],
1349+ raise
1350+ return blk .putmask (mask , new )
1351+
1352+ elif isinstance (self , NDArrayBackedExtensionBlock ):
1353+ # NB: not (yet) the same as
1354+ # isinstance(values, NDArrayBackedExtensionArray)
1355+ blk = self .coerce_to_target_dtype (new )
1356+ return blk .putmask (mask , new )
1357+
1358+ else :
1359+ raise
1360+
1361+ return [self ]
1362+
13171363 def delete (self , loc ) -> None :
13181364 """
13191365 Delete given loc(-s) from block in-place.
@@ -1410,36 +1456,16 @@ def set_inplace(self, locs, values) -> None:
14101456 # _cache not yet initialized
14111457 pass
14121458
1413- def putmask (self , mask , new ) -> list [ Block ] :
1459+ def _maybe_squeeze_arg (self , arg ) :
14141460 """
1415- See Block.putmask.__doc__
1461+ If necessary, squeeze a (N, 1) ndarray to (N,)
14161462 """
1417- mask = extract_bool_array (mask )
1418-
1419- new_values = self .values
1420-
1421- if mask .ndim == new_values .ndim + 1 :
1463+ # e.g. if we are passed a 2D mask for putmask
1464+ if isinstance (arg , np .ndarray ) and arg .ndim == self .values .ndim + 1 :
14221465 # TODO(EA2D): unnecessary with 2D EAs
1423- mask = mask .reshape (new_values .shape )
1424-
1425- try :
1426- # Caller is responsible for ensuring matching lengths
1427- new_values ._putmask (mask , new )
1428- except TypeError :
1429- if not is_interval_dtype (self .dtype ):
1430- # Discussion about what we want to support in the general
1431- # case GH#39584
1432- raise
1433-
1434- blk = self .coerce_to_target_dtype (new )
1435- if blk .dtype == _dtype_obj :
1436- # For now at least, only support casting e.g.
1437- # Interval[int64]->Interval[float64],
1438- raise
1439- return blk .putmask (mask , new )
1440-
1441- nb = type (self )(new_values , placement = self ._mgr_locs , ndim = self .ndim )
1442- return [nb ]
1466+ assert arg .shape [1 ] == 1
1467+ arg = arg [:, 0 ]
1468+ return arg
14431469
14441470 @property
14451471 def is_view (self ) -> bool :
@@ -1595,15 +1621,8 @@ def where(self, other, cond) -> list[Block]:
15951621 cond = extract_bool_array (cond )
15961622 assert not isinstance (other , (ABCIndex , ABCSeries , ABCDataFrame ))
15971623
1598- if isinstance (other , np .ndarray ) and other .ndim == 2 :
1599- # TODO(EA2D): unnecessary with 2D EAs
1600- assert other .shape [1 ] == 1
1601- other = other [:, 0 ]
1602-
1603- if isinstance (cond , np .ndarray ) and cond .ndim == 2 :
1604- # TODO(EA2D): unnecessary with 2D EAs
1605- assert cond .shape [1 ] == 1
1606- cond = cond [:, 0 ]
1624+ other = self ._maybe_squeeze_arg (other )
1625+ cond = self ._maybe_squeeze_arg (cond )
16071626
16081627 if lib .is_scalar (other ) and isna (other ):
16091628 # The default `other` for Series / Frame is np.nan
@@ -1698,16 +1717,6 @@ def setitem(self, indexer, value):
16981717 values [indexer ] = value
16991718 return self
17001719
1701- def putmask (self , mask , new ) -> list [Block ]:
1702- mask = extract_bool_array (mask )
1703-
1704- if not self ._can_hold_element (new ):
1705- return self .coerce_to_target_dtype (new ).putmask (mask , new )
1706-
1707- arr = self .values
1708- arr .T ._putmask (mask , new )
1709- return [self ]
1710-
17111720 def where (self , other , cond ) -> list [Block ]:
17121721 arr = self .values
17131722
0 commit comments