Custom Accessors

Definition

Custom accessors let you extend pandas with your own methods, grouped under a namespace and called with the familiar obj.accessor_name.method() syntax, just like the built-in .dt, .str, and .cat accessors. This gives domain-specific operations a reusable, clean API.

Key Concepts

  • @pd.api.extensions.register_dataframe_accessor: Decorator for DataFrame accessors
  • @pd.api.extensions.register_series_accessor: Decorator for Series accessors
  • Encapsulation: Bundle related functionality together
  • Reusability: Create once, use across projects
  • Clean API: Intuitive, method-chaining friendly interface

Example

python

import pandas as pd
import numpy as np

# ========== BASIC SERIES ACCESSOR ==========
# Section banner: rule line, title, rule line (one per output line).
print("=" * 60, "BASIC SERIES ACCESSOR", "=" * 60, sep="\n")

@pd.api.extensions.register_series_accessor("text_tools")
class TextToolsAccessor:
    """Series accessor (``s.text_tools``) with small text-processing helpers.

    Missing entries (NaN / None / pd.NA) are left missing by every method.
    """

    def __init__(self, pandas_obj):
        self._validate(pandas_obj)
        self._obj = pandas_obj

    @staticmethod
    def _validate(obj):
        """Raise AttributeError unless *obj* has a string dtype.

        AttributeError (rather than TypeError) is the convention used in the
        pandas extension docs, so ``hasattr(series, "text_tools")`` is simply
        False for non-text Series.
        """
        if not pd.api.types.is_string_dtype(obj):
            raise AttributeError("Can only use .text_tools with string data")

    def word_count(self):
        """Return the number of whitespace-separated words in each string."""
        return self._obj.str.split().str.len()

    def char_count(self):
        """Return each string's length, not counting literal space characters.

        Only ' ' is removed; tabs and newlines are still counted.
        """
        # regex=False: treat ' ' as a plain literal on every pandas version.
        return self._obj.str.replace(' ', '', regex=False).str.len()

    def title_case_smart(self):
        """Title-case each string while preserving all-caps acronyms.

        Words that are entirely upper-case and longer than one character
        (e.g. "NASA", "ML") are kept as-is; every other word goes through
        ``str.capitalize`` (first letter upper, rest lower). Words are
        re-joined with single spaces, so runs of whitespace collapse.
        Non-string entries such as NaN or pd.NA are returned unchanged.
        """
        def smart_title(text):
            if not isinstance(text, str):
                return text  # leave missing values missing instead of crashing
            return ' '.join(
                word if word.isupper() and len(word) > 1 else word.capitalize()
                for word in text.split()
            )

        return self._obj.apply(smart_title)

    def extract_hashtags(self):
        """Return, for each string, the list of ``#word`` tokens it contains."""
        return self._obj.str.findall(r'#\w+')

# Exercise the text accessor on a small sample of sentences
text_data = pd.Series([
    'Hello World from NASA',
    'Python is great for AI and ML',
    'Check out #pandas #python',
    'The USA and UK are allies',
])

print("Original text:", text_data, "\n", sep="\n")

# Each demo prints a caption, the accessor result, then a spacer line.
for caption, method_name in [
    ("Word count:", "word_count"),
    ("Character count:", "char_count"),
    ("Smart title case:", "title_case_smart"),
    ("Extract hashtags:", "extract_hashtags"),
]:
    print(caption)
    print(getattr(text_data.text_tools, method_name)())
    print("\n")

# ========== DATAFRAME ACCESSOR ==========
print("=" * 60, "DATAFRAME ACCESSOR", "=" * 60, sep="\n")

@pd.api.extensions.register_dataframe_accessor("financial")
class FinancialAccessor:
    """DataFrame accessor (``df.financial``) for common financial ratios.

    Every method takes the relevant column names explicitly, so the accessor
    makes no assumptions about the frame's schema.
    """

    def __init__(self, pandas_obj):
        self._obj = pandas_obj

    def profit_margin(self, revenue_col, cost_col):
        """Return (revenue - cost) / revenue as a percentage, rounded to 2 dp."""
        frame = self._obj
        gross = frame[revenue_col] - frame[cost_col]
        pct = gross / frame[revenue_col] * 100
        return pct.round(2)

    def roi(self, revenue_col, investment_col):
        """Return (revenue - investment) / investment as a percentage, rounded to 2 dp."""
        frame = self._obj
        net = frame[revenue_col] - frame[investment_col]
        pct = net / frame[investment_col] * 100
        return pct.round(2)

    def compound_growth_rate(self, value_col, periods):
        """Return the compound growth rate per period, in percent.

        Growth is measured from the first row to the last row of *value_col*
        over *periods* periods (which must be non-zero). The result is not
        rounded.
        """
        series = self._obj[value_col]
        first, last = series.iloc[0], series.iloc[-1]
        growth = (last / first) ** (1 / periods)
        return (growth - 1) * 100

    def summary_stats(self, amount_col):
        """Return a Series of total, average, median, min, max, std dev and range."""
        data = self._obj[amount_col]
        lo = data.min()
        hi = data.max()
        stats = {
            'Total': data.sum(),
            'Average': data.mean(),
            'Median': data.median(),
            'Min': lo,
            'Max': hi,
            'Std Dev': data.std(),
            'Range': hi - lo,
        }
        return pd.Series(stats)

# Exercise the financial accessor on a four-product sample
financial_data = pd.DataFrame({
    'Product': ['A', 'B', 'C', 'D'],
    'Revenue': [10000, 15000, 8000, 20000],
    'Cost': [7000, 10000, 6000, 14000],
    'Investment': [5000, 8000, 4000, 10000],
})

print("Financial data:", financial_data, "\n", sep="\n")

# Captions paired with deferred accessor calls, evaluated in order.
for caption, compute in (
    ("Profit margins:",
     lambda: financial_data.financial.profit_margin('Revenue', 'Cost')),
    ("ROI:",
     lambda: financial_data.financial.roi('Revenue', 'Investment')),
    ("Summary statistics for Revenue:",
     lambda: financial_data.financial.summary_stats('Revenue')),
):
    print(caption)
    print(compute())
    print("\n")

# ========== ADVANCED: DATETIME ACCESSOR ==========
print("=" * 60, "ADVANCED: CUSTOM DATETIME ACCESSOR", "=" * 60, sep="\n")

@pd.api.extensions.register_series_accessor("business_time")
class BusinessTimeAccessor:
    """Series accessor (``s.business_time``) for business-calendar helpers.

    A "business day" here is Monday-Friday; no holiday calendar is applied.
    """

    def __init__(self, pandas_obj):
        self._validate(pandas_obj)
        self._obj = pandas_obj

    @staticmethod
    def _validate(obj):
        """Raise AttributeError unless *obj* has a datetime64 dtype (naive or tz-aware)."""
        if not pd.api.types.is_datetime64_any_dtype(obj):
            raise AttributeError("Can only use .business_time with datetime data")

    def is_business_day(self):
        """Return a boolean Series: True where the date falls Monday-Friday."""
        return self._obj.dt.dayofweek < 5

    def is_business_hours(self, start_hour=9, end_hour=17):
        """Return True where the timestamp is on a weekday and
        ``start_hour <= hour < end_hour``."""
        hour = self._obj.dt.hour
        in_hours = (hour >= start_hour) & (hour < end_hour)
        return self.is_business_day() & in_hours

    def next_business_day(self):
        """Shift each timestamp forward by one business day (time of day is kept)."""
        from pandas.tseries.offsets import BDay
        return self._obj + BDay(1)

    def business_days_until(self, target_date):
        """Count business days after each date, up to and including *target_date*.

        Each entry is the number of Monday-Friday days in the half-open
        interval ``(date, target_date]``; times of day are ignored. A date on
        or after *target_date* yields 0, and a NaT entry yields NaN.

        Returns
        -------
        pd.Series
            Aligned with the original index. The dtype is int64 unless NaT is
            present, in which case it is float64.
        """
        one_day = np.timedelta64(1, 'D')

        target = pd.Timestamp(target_date)
        if target.tz is not None:
            target = target.tz_localize(None)
        target_day = target.normalize().to_datetime64().astype('datetime64[D]')

        dates = self._obj
        if dates.dt.tz is not None:
            # Compare wall-clock calendar days, not UTC instants.
            dates = dates.dt.tz_localize(None)
        days = dates.dt.normalize().to_numpy().astype('datetime64[D]')

        valid = ~np.isnat(days)
        counts = np.full(len(days), np.nan)
        if valid.any():
            # busday_count counts weekdays in [begin, end); shifting both ends
            # by one day turns that into (date, target], so a weekend start is
            # not penalised.
            raw = np.busday_count(days[valid] + one_day, target_day + one_day)
            counts[valid] = np.maximum(raw, 0)  # target already passed -> 0
        if valid.all():
            counts = counts.astype(np.int64)
        return pd.Series(counts, index=self._obj.index)

    def quarter_name(self):
        """Return the calendar quarter as 'Q1'..'Q4'; NaT entries become NaN."""
        quarter = self._obj.dt.quarter
        # Go through nullable Int64 so NaT-induced floats don't render as 'Q1.0'.
        names = 'Q' + quarter.astype('Int64').astype(str)
        return names.where(quarter.notna())

    def fiscal_year(self, fiscal_start_month=4):
        """Label each date with its fiscal year.

        The fiscal year is named after the calendar year in which it ends.
        For example, with an April start, 2024-04-01 belongs to FY2025 and
        2024-03-31 to FY2024. A January start gives the calendar year.

        Raises
        ------
        ValueError
            If *fiscal_start_month* is not in 1..12.
        """
        if not 1 <= fiscal_start_month <= 12:
            raise ValueError(
                f"fiscal_start_month must be in 1..12, got {fiscal_start_month}"
            )
        year = self._obj.dt.year
        month = self._obj.dt.month
        # Only a non-January start rolls late-year months into the next label.
        rolls_over = (month >= fiscal_start_month) & (fiscal_start_month > 1)
        return year + rolls_over.astype(int)

# Ten consecutive calendar days starting Monday 2024-01-01
dates = pd.Series(pd.date_range('2024-01-01', periods=10, freq='D'))

print("Dates:", dates, "\n", sep="\n")

# Captions paired with deferred accessor calls, evaluated in order.
for caption, compute in (
    ("Is business day:",
     lambda: dates.business_time.is_business_day()),
    ("Next business day:",
     lambda: dates.business_time.next_business_day()),
    ("Quarter name:",
     lambda: dates.business_time.quarter_name()),
    ("Fiscal year (April start):",
     lambda: dates.business_time.fiscal_year(fiscal_start_month=4)),
):
    print(caption)
    print(compute())
    print("\n")

# ========== ACCESSOR WITH PARAMETERS ==========
print("=" * 60, "ACCESSOR WITH CONFIGURATION", "=" * 60, sep="\n")

@pd.api.extensions.register_dataframe_accessor("analytics")
class AnalyticsAccessor:
    """DataFrame accessor (``df.analytics``) with outlier, scaling and summary helpers."""

    def __init__(self, pandas_obj):
        self._obj = pandas_obj
        # Reserved for memoised results; no method in this class populates it yet.
        self._cache = {}

    def outliers(self, column, method='iqr', threshold=1.5):
        """
        Flag outliers in one column.

        Parameters:
        -----------
        column : str
            Column to inspect.
        method : str
            'iqr'    - outside [Q1 - threshold*IQR, Q3 + threshold*IQR]
            'zscore' - |value - mean| / std (sample std) above threshold
        threshold : float
            Fence multiplier (iqr) or z-score cut-off (zscore).

        Returns a boolean Series; raises ValueError for any other method.
        """
        values = self._obj[column]

        if method == 'zscore':
            distance = np.abs((values - values.mean()) / values.std())
            return distance > threshold

        if method != 'iqr':
            raise ValueError(f"Unknown method: {method}")

        q1 = values.quantile(0.25)
        q3 = values.quantile(0.75)
        fence = threshold * (q3 - q1)
        return (values < q1 - fence) | (values > q3 + fence)

    def normalize(self, column, method='minmax'):
        """
        Rescale one column.

        Parameters:
        -----------
        column : str
            Column to rescale.
        method : str
            'minmax' - map onto [0, 1] via (x - min) / (max - min)
            'zscore' - (x - mean) / std (sample std)

        Raises ValueError for any other method.
        """
        values = self._obj[column]

        if method == 'minmax':
            low = values.min()
            return (values - low) / (values.max() - low)
        if method == 'zscore':
            return (values - values.mean()) / values.std()
        raise ValueError(f"Unknown method: {method}")

    def correlation_with(self, target_col, top_n=5):
        """Return the *top_n* numeric columns ranked by |correlation| with *target_col*."""
        numeric = self._obj.select_dtypes(include=[np.number])
        strength = numeric.corr()[target_col].abs().drop(target_col)
        return strength.nlargest(top_n)

    def summary(self):
        """Return a Series describing shape, column counts, NaN count and memory (MB)."""
        frame = self._obj
        numeric_count = len(frame.select_dtypes(include=[np.number]).columns)
        return pd.Series({
            'shape': frame.shape,
            'columns': len(frame.columns),
            'numeric_columns': numeric_count,
            'missing_values': frame.isna().sum().sum(),
            'memory_usage_mb': frame.memory_usage(deep=True).sum() / 1024**2,
        })

# Column A has one extreme value (100); B, C and D are evenly spaced
analytics_data = pd.DataFrame({
    'A': [1, 2, 3, 4, 5, 6, 100],  # 100 is outlier
    'B': [10, 20, 30, 40, 50, 60, 70],
    'C': [5, 15, 25, 35, 45, 55, 65],
    'D': [2, 4, 6, 8, 10, 12, 14],
})

print("Analytics data:", analytics_data, "\n", sep="\n")

# Captions paired with deferred accessor calls, evaluated in order.
for caption, compute in (
    ("Outliers in column A (IQR method):",
     lambda: analytics_data.analytics.outliers('A', method='iqr')),
    ("Normalized column B (min-max):",
     lambda: analytics_data.analytics.normalize('B', method='minmax')),
    ("Top 3 correlations with column A:",
     lambda: analytics_data.analytics.correlation_with('A', top_n=3)),
    ("DataFrame summary:",
     lambda: analytics_data.analytics.summary()),
):
    print(caption)
    print(compute())
    print("\n")

# ========== CHAINING WITH ACCESSORS ==========
print("=" * 60, "METHOD CHAINING WITH ACCESSORS", "=" * 60, sep="\n")

# Accessors slot into method chains: derive metrics, filter, then rank by ROI
result = (
    financial_data
    .assign(profit_margin=lambda df: df.financial.profit_margin('Revenue', 'Cost'))
    .assign(roi=lambda df: df.financial.roi('Revenue', 'Investment'))
    .loc[lambda df: df['profit_margin'] > 25]
    .sort_values(by='roi', ascending=False)
)

print("Method chaining with custom accessor:")
print(result)