@@ -1589,6 +1589,7 @@ def _clean_interp_method(method, **kwargs):
15891589
15901590
15911591def interpolate_1d (xvalues , yvalues , method = 'linear' , limit = None ,
1592+ limit_direction = 'forward' ,
15921593 fill_value = None , bounds_error = False , order = None , ** kwargs ):
15931594 """
15941595 Logic for the 1-d interpolation. The result should be 1-d, inputs
@@ -1602,9 +1603,15 @@ def interpolate_1d(xvalues, yvalues, method='linear', limit=None,
16021603 invalid = isnull (yvalues )
16031604 valid = ~ invalid
16041605
1605- valid_y = yvalues [valid ]
1606- valid_x = xvalues [valid ]
1607- new_x = xvalues [invalid ]
1606+ if not valid .any ():
1607+ # have to call np.asarray(xvalues) since xvalues could be an Index
1608+ # which can't be mutated
1609+ result = np .empty_like (np .asarray (xvalues ), dtype = np .float64 )
1610+ result .fill (np .nan )
1611+ return result
1612+
1613+ if valid .all ():
1614+ return yvalues
16081615
16091616 if method == 'time' :
16101617 if not getattr (xvalues , 'is_all_dates' , None ):
@@ -1614,66 +1621,82 @@ def interpolate_1d(xvalues, yvalues, method='linear', limit=None,
16141621 'DatetimeIndex' )
16151622 method = 'values'
16161623
1617- def _interp_limit (invalid , limit ):
1618- """mask off values that won't be filled since they exceed the limit"" "
1624+ def _interp_limit (invalid , fw_limit , bw_limit ):
1625+ "Get idx of values that won't be forward- filled b/c they exceed the limit. "
16191626 all_nans = np .where (invalid )[0 ]
16201627 if all_nans .size == 0 : # no nans anyway
16211628 return []
1622- violate = [invalid [x :x + limit + 1 ] for x in all_nans ]
1623- violate = np .array ([x .all () & (x .size > limit ) for x in violate ])
1624- return all_nans [violate ] + limit
1629+ violate = [invalid [max (0 , x - bw_limit ):x + fw_limit + 1 ] for x in all_nans ]
1630+ violate = np .array ([x .all () & (x .size > bw_limit + fw_limit ) for x in violate ])
1631+ return all_nans [violate ] + fw_limit - bw_limit
1632+
1633+ valid_limit_directions = ['forward' , 'backward' , 'both' ]
1634+ limit_direction = limit_direction .lower ()
1635+ if limit_direction not in valid_limit_directions :
1636+ msg = 'Invalid limit_direction: expecting one of %r, got %r.' % (
1637+ valid_limit_directions , limit_direction )
1638+ raise ValueError (msg )
16251639
1626- xvalues = getattr (xvalues , 'values' , xvalues )
1627- yvalues = getattr (yvalues , 'values' , yvalues )
1640+ from pandas import Series
1641+ ys = Series (yvalues )
1642+ start_nans = set (range (ys .first_valid_index ()))
1643+ end_nans = set (range (1 + ys .last_valid_index (), len (valid )))
1644+
1645+ # This is a list of the indexes in the series whose yvalue is currently NaN,
1646+ # but whose interpolated yvalue will be overwritten with NaN after computing
1647+ # the interpolation. For each index in this list, one of these conditions is
1648+ # true of the corresponding NaN in the yvalues:
1649+ #
1650+ # a) It is one of a chain of NaNs at the beginning of the series, and either
1651+ # limit is not specified or limit_direction is 'forward'.
1652+ # b) It is one of a chain of NaNs at the end of the series, and limit is
1653+ # specified and limit_direction is 'backward'.
1654+ # c) Limit is nonzero and it is further than limit from the nearest non-NaN
1655+ # value (with respect to the limit_direction setting).
1656+ #
1657+ # The default behavior is to fill forward with no limit, ignoring NaNs at
1658+ # the beginning (see issues #9218 and #10420)
1659+ violate_limit = sorted (start_nans )
16281660
16291661 if limit :
1630- violate_limit = _interp_limit (invalid , limit )
1631- if valid .any ():
1632- firstIndex = valid .argmax ()
1633- valid = valid [firstIndex :]
1634- invalid = invalid [firstIndex :]
1635- result = yvalues .copy ()
1636- if valid .all ():
1637- return yvalues
1638- else :
1639- # have to call np.array(xvalues) since xvalues could be an Index
1640- # which cant be mutated
1641- result = np .empty_like (np .array (xvalues ), dtype = np .float64 )
1642- result .fill (np .nan )
1643- return result
1662+ if limit_direction == 'forward' :
1663+ violate_limit = sorted (start_nans | set (_interp_limit (invalid , limit , 0 )))
1664+ if limit_direction == 'backward' :
1665+ violate_limit = sorted (end_nans | set (_interp_limit (invalid , 0 , limit )))
1666+ if limit_direction == 'both' :
1667+ violate_limit = _interp_limit (invalid , limit , limit )
1668+
1669+ xvalues = getattr (xvalues , 'values' , xvalues )
1670+ yvalues = getattr (yvalues , 'values' , yvalues )
1671+ result = yvalues .copy ()
16441672
16451673 if method in ['linear' , 'time' , 'index' , 'values' ]:
16461674 if method in ('values' , 'index' ):
16471675 inds = np .asarray (xvalues )
16481676 # hack for DatetimeIndex, #1646
16491677 if issubclass (inds .dtype .type , np .datetime64 ):
16501678 inds = inds .view (np .int64 )
1651-
16521679 if inds .dtype == np .object_ :
16531680 inds = lib .maybe_convert_objects (inds )
16541681 else :
16551682 inds = xvalues
1656-
1657- inds = inds [firstIndex :]
1658-
1659- result [firstIndex :][invalid ] = np .interp (inds [invalid ], inds [valid ],
1660- yvalues [firstIndex :][valid ])
1661-
1662- if limit :
1663- result [violate_limit ] = np .nan
1683+ result [invalid ] = np .interp (inds [invalid ], inds [valid ], yvalues [valid ])
1684+ result [violate_limit ] = np .nan
16641685 return result
16651686
16661687 sp_methods = ['nearest' , 'zero' , 'slinear' , 'quadratic' , 'cubic' ,
16671688 'barycentric' , 'krogh' , 'spline' , 'polynomial' ,
16681689 'piecewise_polynomial' , 'pchip' ]
16691690 if method in sp_methods :
1670- new_x = new_x [firstIndex :]
1671-
1672- result [firstIndex :][invalid ] = _interpolate_scipy_wrapper (
1673- valid_x , valid_y , new_x , method = method , fill_value = fill_value ,
1691+ inds = np .asarray (xvalues )
1692+ # hack for DatetimeIndex, #1646
1693+ if issubclass (inds .dtype .type , np .datetime64 ):
1694+ inds = inds .view (np .int64 )
1695+ result [invalid ] = _interpolate_scipy_wrapper (
1696+ inds [valid ], yvalues [valid ], inds [invalid ], method = method ,
1697+ fill_value = fill_value ,
16741698 bounds_error = bounds_error , order = order , ** kwargs )
1675- if limit :
1676- result [violate_limit ] = np .nan
1699+ result [violate_limit ] = np .nan
16771700 return result
16781701
16791702
0 commit comments