Definition
Custom accessors let you extend pandas with your own namespaced methods, called with the familiar `.accessor_name.method()` syntax used by the built-in `.dt`, `.str`, and `.cat` accessors. They give you a clean, reusable API for domain-specific operations.
Key Concepts
- @pd.api.extensions.register_dataframe_accessor: Decorator for DataFrame accessors
- @pd.api.extensions.register_series_accessor: Decorator for Series accessors
- Encapsulation: Bundle related functionality together
- Reusability: Create once, use across projects
- Clean API: Intuitive, method-chaining friendly interface
Example
```python
import pandas as pd
import numpy as np
# ========== BASIC SERIES ACCESSOR ==========
# Section banner: the title framed by two 60-character rules, one per line.
print("=" * 60, "BASIC SERIES ACCESSOR", "=" * 60, sep="\n")
@pd.api.extensions.register_series_accessor("text_tools")
class TextToolsAccessor:
    """Text helpers exposed on string Series as ``Series.text_tools``.

    pandas instantiates this class with the Series when ``.text_tools`` is
    accessed; construction fails for Series that do not hold string data.
    """

    def __init__(self, pandas_obj):
        self._validate(pandas_obj)
        self._obj = pandas_obj

    @staticmethod
    def _validate(obj):
        """Raise ``AttributeError`` unless *obj* has a string dtype."""
        if not pd.api.types.is_string_dtype(obj):
            raise AttributeError("Can only use .text_tools with string data")

    def word_count(self):
        """Return the number of whitespace-separated words in each string."""
        return self._obj.str.split().str.len()

    def char_count(self):
        """Return the length of each string after removing space characters.

        Only the literal space ``' '`` is removed; tabs and newlines still
        count as characters.
        """
        # regex=False: the pattern is a literal space, not a regular expression.
        return self._obj.str.replace(' ', '', regex=False).str.len()

    def title_case_smart(self):
        """Capitalize each word while leaving all-caps acronyms untouched.

        Words longer than one character that are entirely upper case (for
        example ``NASA``) are kept as-is; every other word goes through
        ``str.capitalize``. Runs of whitespace collapse to a single space
        because the text is split and re-joined. Missing values are returned
        unchanged instead of raising.
        """

        def smart_title(text):
            # Missing entries (None / NaN / pd.NA) have no ``split`` method,
            # so hand them back untouched.
            if pd.isna(text):
                return text
            return ' '.join(
                word if word.isupper() and len(word) > 1 else word.capitalize()
                for word in text.split()
            )

        return self._obj.apply(smart_title)

    def extract_hashtags(self):
        """Return, per string, the list of ``#word`` hashtags it contains."""
        return self._obj.str.findall(r'#\w+')
# Try the text accessor on a handful of sample sentences.
text_data = pd.Series(
    [
        'Hello World from NASA',
        'Python is great for AI and ML',
        'Check out #pandas #python',
        'The USA and UK are allies',
    ]
)
# Each call prints a heading, the value, and a "\n" spacer on separate lines.
print("Original text:", text_data, "\n", sep="\n")
print("Word count:", text_data.text_tools.word_count(), "\n", sep="\n")
print("Character count:", text_data.text_tools.char_count(), "\n", sep="\n")
print("Smart title case:", text_data.text_tools.title_case_smart(), "\n", sep="\n")
print("Extract hashtags:", text_data.text_tools.extract_hashtags(), "\n", sep="\n")
# ========== DATAFRAME ACCESSOR ==========
print("=" * 60, "DATAFRAME ACCESSOR", "=" * 60, sep="\n")
@pd.api.extensions.register_dataframe_accessor("financial")
class FinancialAccessor:
    """Financial ratio helpers exposed on DataFrames as ``DataFrame.financial``."""

    def __init__(self, pandas_obj):
        self._obj = pandas_obj

    def profit_margin(self, revenue_col, cost_col):
        """Return ``(revenue - cost) / revenue`` as a percentage, rounded to 2 places."""
        frame = self._obj
        earned = frame[revenue_col]
        margin_pct = earned.sub(frame[cost_col]).div(earned).mul(100)
        return margin_pct.round(2)

    def roi(self, revenue_col, investment_col):
        """Return ``(revenue - investment) / investment`` as a percentage, rounded to 2 places."""
        frame = self._obj
        invested = frame[investment_col]
        roi_pct = frame[revenue_col].sub(invested).div(invested).mul(100)
        return roi_pct.round(2)

    def compound_growth_rate(self, value_col, periods):
        """Return the per-period compound growth rate, in percent.

        The rate is derived from the first and last rows of *value_col* only,
        spread over *periods* compounding periods.
        """
        series = self._obj[value_col]
        first, last = series.iloc[0], series.iloc[-1]
        growth_factor = (last / first) ** (1 / periods)
        return (growth_factor - 1) * 100

    def summary_stats(self, amount_col):
        """Return a Series of descriptive statistics for *amount_col*.

        The index holds the labels ``Total``, ``Average``, ``Median``,
        ``Min``, ``Max``, ``Std Dev`` and ``Range``, in that order.
        """
        values = self._obj[amount_col]
        lowest, highest = values.min(), values.max()
        stats = {
            'Total': values.sum(),
            'Average': values.mean(),
            'Median': values.median(),
            'Min': lowest,
            'Max': highest,
            'Std Dev': values.std(),
            'Range': highest - lowest,
        }
        return pd.Series(stats)
# Sample product figures used to exercise the financial accessor.
financial_data = pd.DataFrame(
    {
        'Product': ['A', 'B', 'C', 'D'],
        'Revenue': [10000, 15000, 8000, 20000],
        'Cost': [7000, 10000, 6000, 14000],
        'Investment': [5000, 8000, 4000, 10000],
    }
)
# Each call prints a heading, the value, and a "\n" spacer on separate lines.
print("Financial data:", financial_data, "\n", sep="\n")
print(
    "Profit margins:",
    financial_data.financial.profit_margin('Revenue', 'Cost'),
    "\n",
    sep="\n",
)
print(
    "ROI:",
    financial_data.financial.roi('Revenue', 'Investment'),
    "\n",
    sep="\n",
)
print(
    "Summary statistics for Revenue:",
    financial_data.financial.summary_stats('Revenue'),
    "\n",
    sep="\n",
)
# ========== ADVANCED: DATETIME ACCESSOR ==========
print("=" * 60, "ADVANCED: CUSTOM DATETIME ACCESSOR", "=" * 60, sep="\n")
@pd.api.extensions.register_series_accessor("business_time")
class BusinessTimeAccessor:
    """Business-calendar helpers exposed on datetime Series as ``Series.business_time``.

    Business days are Monday through Friday; holidays are not considered.
    """

    def __init__(self, pandas_obj):
        self._validate(pandas_obj)
        self._obj = pandas_obj

    @staticmethod
    def _validate(obj):
        """Raise ``AttributeError`` unless *obj* has a datetime64 dtype."""
        if not pd.api.types.is_datetime64_any_dtype(obj):
            raise AttributeError("Can only use .business_time with datetime data")

    def is_business_day(self):
        """Return a boolean Series that is True for Monday-Friday dates."""
        # dayofweek is 0 for Monday ... 6 for Sunday.
        return self._obj.dt.dayofweek < 5

    def is_business_hours(self, start_hour=9, end_hour=17):
        """Return True where the timestamp is on a weekday within business hours.

        Parameters
        ----------
        start_hour : int
            First hour that counts as business time (inclusive).
        end_hour : int
            Hour at which business time ends (exclusive).
        """
        hour = self._obj.dt.hour
        within_hours = (hour >= start_hour) & (hour < end_hour)
        return self.is_business_day() & within_hours

    def next_business_day(self):
        """Return each timestamp moved forward by one business day."""
        from pandas.tseries.offsets import BDay

        return self._obj + BDay(1)

    def business_days_until(self, target_date):
        """Return, per date, the business days remaining until *target_date*.

        For each date this is the number of business days in the inclusive
        range ``[date, target_date]`` minus one, floored at zero. Dates after
        *target_date* therefore yield 0, and an empty Series yields an empty
        result.
        """
        remaining = [
            max(len(pd.bdate_range(start=date, end=target_date)) - 1, 0)
            for date in self._obj
        ]
        # Explicit dtype keeps an empty result integer-typed as well.
        return pd.Series(remaining, index=self._obj.index, dtype="int64")

    def quarter_name(self):
        """Return the calendar quarter of each date as ``'Q1'`` .. ``'Q4'``."""
        return 'Q' + self._obj.dt.quarter.astype(str)

    def fiscal_year(self, fiscal_start_month=4):
        """Return the fiscal year of each date, named after the year it ends in.

        Parameters
        ----------
        fiscal_start_month : int
            Month (1-12) in which the fiscal year begins. With the default of
            4, April 2024 through March 2025 is fiscal year 2025. With 1 the
            fiscal year coincides with the calendar year.

        Raises
        ------
        ValueError
            If *fiscal_start_month* is not between 1 and 12.
        """
        if not 1 <= fiscal_start_month <= 12:
            raise ValueError(
                f"fiscal_start_month must be in 1..12, got {fiscal_start_month}"
            )
        year = self._obj.dt.year
        month = self._obj.dt.month
        # A January start never crosses a calendar-year boundary, so no date
        # may be pushed into the following year in that case.
        rolls_over = (month >= fiscal_start_month) & (fiscal_start_month > 1)
        return year + rolls_over.astype(int)
# Ten consecutive calendar days starting on 2024-01-01 for the datetime demo.
dates = pd.Series(pd.date_range(start='2024-01-01', periods=10, freq='D'))
# Each call prints a heading, the value, and a "\n" spacer on separate lines.
print("Dates:", dates, "\n", sep="\n")
print("Is business day:", dates.business_time.is_business_day(), "\n", sep="\n")
print("Next business day:", dates.business_time.next_business_day(), "\n", sep="\n")
print("Quarter name:", dates.business_time.quarter_name(), "\n", sep="\n")
print(
    "Fiscal year (April start):",
    dates.business_time.fiscal_year(fiscal_start_month=4),
    "\n",
    sep="\n",
)
# ========== ACCESSOR WITH PARAMETERS ==========
print("=" * 60, "ACCESSOR WITH CONFIGURATION", "=" * 60, sep="\n")
@pd.api.extensions.register_dataframe_accessor("analytics")
class AnalyticsAccessor:
    """Configurable analytics helpers exposed as ``DataFrame.analytics``."""

    def __init__(self, pandas_obj):
        self._obj = pandas_obj

    def outliers(self, column, method='iqr', threshold=1.5):
        """Return a boolean Series flagging outliers in *column*.

        Parameters
        ----------
        column : str
            Column name to check for outliers.
        method : str
            'iqr' - values more than ``threshold`` IQRs outside Q1/Q3.
            'zscore' - values whose absolute z-score exceeds ``threshold``.
        threshold : float
            Threshold for outlier detection.

        Raises
        ------
        ValueError
            If *method* is neither 'iqr' nor 'zscore'.
        """
        data = self._obj[column]
        if method == 'iqr':
            q1 = data.quantile(0.25)
            q3 = data.quantile(0.75)
            iqr = q3 - q1
            lower_bound = q1 - threshold * iqr
            upper_bound = q3 + threshold * iqr
            return (data < lower_bound) | (data > upper_bound)
        if method == 'zscore':
            z_scores = np.abs((data - data.mean()) / data.std())
            return z_scores > threshold
        raise ValueError(f"Unknown method: {method}")

    def normalize(self, column, method='minmax'):
        """Return *column* rescaled with the chosen method.

        Parameters
        ----------
        column : str
            Column to normalize.
        method : str
            'minmax' - Min-Max normalization (0-1).
            'zscore' - Z-score normalization.

        A column with zero spread (every value identical) is mapped to 0.0
        instead of NaN.

        Raises
        ------
        ValueError
            If *method* is neither 'minmax' nor 'zscore'.
        """
        data = self._obj[column]
        if method == 'minmax':
            offset, scale = data.min(), data.max() - data.min()
        elif method == 'zscore':
            offset, scale = data.mean(), data.std()
        else:
            raise ValueError(f"Unknown method: {method}")
        # Zero spread would give 0/0 = NaN; dividing by 1 yields 0.0 instead.
        # A NaN scale (empty or all-missing data) compares unequal to 0 and
        # still propagates NaN.
        if scale == 0:
            scale = 1
        return (data - offset) / scale

    def correlation_with(self, target_col, top_n=5):
        """Return the *top_n* numeric columns most correlated with *target_col*.

        Values are absolute correlation coefficients, largest first, with
        *target_col* itself excluded. Only numeric columns are considered.
        """
        numeric_cols = self._obj.select_dtypes(include=[np.number]).columns
        strength = self._obj[numeric_cols].corr()[target_col].abs()
        return strength.drop(target_col).nlargest(top_n)

    def summary(self):
        """Return a Series describing the frame.

        The index holds ``shape``, ``columns``, ``numeric_columns``,
        ``missing_values`` and ``memory_usage_mb`` (deep memory usage in MiB).
        """
        numeric_cols = self._obj.select_dtypes(include=[np.number]).columns
        summary_data = {
            'shape': self._obj.shape,
            'columns': len(self._obj.columns),
            'numeric_columns': len(numeric_cols),
            'missing_values': self._obj.isna().sum().sum(),
            'memory_usage_mb': self._obj.memory_usage(deep=True).sum() / 1024**2,
        }
        return pd.Series(summary_data)
# Sample frame for the analytics accessor; the 100 in column A is the outlier.
analytics_data = pd.DataFrame(
    {
        'A': [1, 2, 3, 4, 5, 6, 100],
        'B': [10, 20, 30, 40, 50, 60, 70],
        'C': [5, 15, 25, 35, 45, 55, 65],
        'D': [2, 4, 6, 8, 10, 12, 14],
    }
)
# Each call prints a heading, the value, and a "\n" spacer on separate lines.
print("Analytics data:", analytics_data, "\n", sep="\n")
print(
    "Outliers in column A (IQR method):",
    analytics_data.analytics.outliers('A', method='iqr'),
    "\n",
    sep="\n",
)
print(
    "Normalized column B (min-max):",
    analytics_data.analytics.normalize('B', method='minmax'),
    "\n",
    sep="\n",
)
print(
    "Top 3 correlations with column A:",
    analytics_data.analytics.correlation_with('A', top_n=3),
    "\n",
    sep="\n",
)
print("DataFrame summary:", analytics_data.analytics.summary(), "\n", sep="\n")
# ========== CHAINING WITH ACCESSORS ==========
print("=" * 60, "METHOD CHAINING WITH ACCESSORS", "=" * 60, sep="\n")
# Accessor calls slot into a method chain: derive the ratio columns, keep rows
# whose margin exceeds 25, then order by ROI from highest to lowest.
result = (
    financial_data
    .assign(
        profit_margin=lambda frame: frame.financial.profit_margin('Revenue', 'Cost'),
        roi=lambda frame: frame.financial.roi('Revenue', 'Investment'),
    )
    .query('profit_margin > 25')
    .sort_values('roi', ascending=False)
)
print("Method chaining with custom accessor:", result, sep="\n")
