大家好,又见面了,我是你们的朋友全栈君。
本文介绍 Prophet 模型的简单调用。
（一）屏蔽日志输出（在文件描述符层面重定向 stdout 与 stderr）
import os
import sys
class SuppressStdoutStderr(object):
    """
    Context manager that performs a "deep" suppression of stdout and stderr.

    Instead of swapping the Python objects ``sys.stdout``/``sys.stderr`` it redirects
    the process-level file descriptors 1 and 2 to ``os.devnull``. Output written
    directly to those descriptors (e.g. by compiled C/C++/Fortran code) is therefore
    silenced as well as ``print``.

    Exceptions raised inside the ``with`` body are NOT suppressed. ``__exit__``
    restores the descriptors and returns ``None`` (falsy), so the exception
    propagates and any traceback is printed on the restored stderr.

    Descriptors are acquired in ``__enter__`` and released in ``__exit__``. An
    instance that is never entered therefore holds no descriptors, and one instance
    can be used for several consecutive ``with`` blocks.
    """

    def __init__(self):
        # Populated by __enter__: [null fd for stdout, null fd for stderr].
        self.null_fds = []
        # Populated by __enter__: duplicates of the real fds 1 and 2.
        self.save_fds = ()

    @staticmethod
    def _flush_std_streams():
        """Flush Python-level stdout/stderr buffers into the fds they currently point at."""
        for stream in (sys.stdout, sys.stderr):
            if stream is None:
                continue
            try:
                stream.flush()
            except (OSError, ValueError):
                # Best effort: a closed/broken stream must not prevent the fd swap.
                pass

    def __enter__(self):
        # Text already buffered by Python belongs on the real stdout/stderr,
        # so push it out before the descriptors are redirected.
        self._flush_std_streams()
        # Open a pair of null files and save the real stdout (1) / stderr (2) fds.
        self.null_fds = [os.open(os.devnull, os.O_RDWR) for _ in range(2)]
        self.save_fds = (os.dup(1), os.dup(2))
        # Point fds 1 and 2 at the null device.
        os.dup2(self.null_fds[0], 1)
        os.dup2(self.null_fds[1], 2)
        return self

    def __exit__(self, *_):
        # Text buffered while suppressed must go to devnull, not leak out after restore.
        self._flush_std_streams()
        try:
            # Reassign the real stdout/stderr back to fds 1 and 2.
            os.dup2(self.save_fds[0], 1)
            os.dup2(self.save_fds[1], 2)
        finally:
            # Always release the four helper descriptors, even if restoring failed.
            for fd in (*self.null_fds, *self.save_fds):
                os.close(fd)
            self.null_fds = []
            self.save_fds = ()
（二）Prophet 预测模型类：结构与官方 Prophet 类相似，但并不继承它
1. 初始化
from datetime import datetime, timedelta
from typing import Tuple
import numpy as np
import pandas as pd
from fbprophet import Prophet
from fbprophet.diagnostics import cross_validation
# Date-only format ("YYYY-MM-DD") used to parse and format dates.
# Beware: '%m' is the month, whereas '%M' would be the minute.
STD_D_STR: str = "%Y-%m-%d"
class ProphetPredictor(object):
    """
    Wrapper around a Prophet model for one fixed forecasting window.

    Its structure mirrors the official Prophet class (fit / predict), but it does not
    inherit from it. It also offers cross-validation and grid search helpers.
    """

    def __init__(self, x_train: pd.DataFrame,
                 trg_st_dt: datetime, tm_step: int, his_st_dt: datetime, his_en_dt: datetime,
                 cv_horizon: str, cv_period: str, cv_initial: str,
                 n_changepoints=None, changepoint_range=0.7,
                 yearly_seasonality=False, weekly_seasonality=True, daily_seasonality=False,
                 holidays=None, seasonality_mode='multiplicative',
                 seasonality_prior_scale=10, holidays_prior_scale=0, changepoint_prior_scale=0.05):
        """
        Store the data, the date window and the Prophet hyper-parameters.

        :param x_train: dataset; only the columns ['ds', 'y'] are kept (as a copy)
        :param trg_st_dt: first forecast date; a datetime or a STD_D_STR string
        :param tm_step: number of consecutive days to forecast (must be >= 1)
        :param his_st_dt: first date of the training window; a datetime or a STD_D_STR string
        :param his_en_dt: last date of the training window; a datetime or a STD_D_STR string.
            Must be strictly earlier than ``trg_st_dt``.
        :param cv_horizon: cross-validation ``horizon`` argument, e.g. '3 days'
        :param cv_period: cross-validation ``period`` argument, e.g. '3 days'
        :param cv_initial: cross-validation ``initial`` argument, e.g. '3 days'
        :param n_changepoints: maximum number of changepoints
        :param changepoint_range: fraction of the history in which changepoints may occur
        :param yearly_seasonality: enable yearly seasonality
        :param weekly_seasonality: enable weekly seasonality
        :param daily_seasonality: enable daily seasonality
        :param holidays: holidays / special dates
        :param seasonality_mode: 'additive' or 'multiplicative'
        :param seasonality_prior_scale: strength of the seasonality components
        :param holidays_prior_scale: strength of the holiday components
        :param changepoint_prior_scale: flexibility of automatic changepoint selection;
            larger values allow more changepoints
        :raises ValueError: if ``tm_step < 1`` or ``his_en_dt >= trg_st_dt``
        """
        # Keyword arguments forwarded verbatim to Prophet(...)
        self.params = {
            "n_changepoints": n_changepoints,
            "changepoint_range": changepoint_range,
            "yearly_seasonality": yearly_seasonality,
            "weekly_seasonality": weekly_seasonality,
            "daily_seasonality": daily_seasonality,
            "holidays": holidays,
            "seasonality_mode": seasonality_mode,
            "seasonality_prior_scale": seasonality_prior_scale,
            "holidays_prior_scale": holidays_prior_scale,
            "changepoint_prior_scale": changepoint_prior_scale
        }
        trg_st = self._to_datetime(trg_st_dt)
        his_st = self._to_datetime(his_st_dt)
        his_en = self._to_datetime(his_en_dt)
        # Enforce the window invariants. Otherwise `ahead` becomes negative and the
        # history/forecast split in predict() can silently come out wrong.
        if tm_step < 1:
            raise ValueError("tm_step must be >= 1, got {}".format(tm_step))
        if his_en >= trg_st:
            raise ValueError("his_en_dt ({}) must be strictly earlier than trg_st_dt ({})".format(his_en, trg_st))

        self.trg_st_dt = trg_st
        self.tm_step = tm_step
        # Last forecast date (inclusive)
        self.trg_en_dt = self.trg_st_dt + timedelta(days=tm_step - 1)
        self.his_st_dt = his_st
        self.his_en_dt = his_en
        # Lead time = forecast start - last history date - 1 (a next-day forecast has lead time 0)
        self.ahead = (self.trg_st_dt - self.his_en_dt).days - 1
        self.x_train = x_train[['ds', 'y']].copy()
        self.model = None
        self.cv_horizon = cv_horizon
        self.cv_period = cv_period
        self.cv_initial = cv_initial
        # Best cross-validation MAPE seen so far; +inf so the first finite score is accepted
        self.map_err = float("inf")

    @staticmethod
    def _to_datetime(value):
        """Return ``value`` parsed with STD_D_STR if it is a string, otherwise unchanged."""
        return datetime.strptime(value, STD_D_STR) if isinstance(value, str) else value
2. 模型训练
def fit(self):
    """
    Fit the final Prophet model on the training window.

    ``self.x_train`` is trimmed in place to the rows whose ``ds`` lies within
    [his_st_dt, his_en_dt] (both ends inclusive). The bounds are formatted with
    STD_D_STR; this presumably expects 'ds' to hold strings in that format or values
    pandas can compare against them. The index is then reset. A new Prophet model
    built from ``self.params`` is stored in ``self.model`` and fitted on the trimmed
    data, with stdout/stderr silenced.

    :return: None
    """
    first_day = datetime.strftime(self.his_st_dt, STD_D_STR)
    last_day = datetime.strftime(self.his_en_dt, STD_D_STR)
    ds = self.x_train['ds']
    in_window = (first_day <= ds) & (ds <= last_day)
    self.x_train = self.x_train[in_window].reset_index(drop=True)
    model = Prophet(**self.params)
    self.model = model
    with SuppressStdoutStderr():
        model.fit(df=self.x_train)
3. 交叉验证
def cv(self, params=None) -> float:
    """
    Score one Prophet configuration with Prophet's cross_validation.

    A fresh Prophet model is built from ``params`` (or ``self.params`` when ``params``
    is falsy) and fitted on the rows of ``self.x_train`` inside [his_st_dt, his_en_dt].
    It is then evaluated with ``cross_validation`` using the stored horizon / period /
    initial settings. Output from both the fit and the cross-validation refits is
    silenced. Side effect: the fitted model is stored in ``self.model``.

    :param params: Prophet keyword arguments; passed by grid_search, None otherwise
    :return: map_err: mean absolute percentage error (MAPE, in percent) over all
        cross-validation predictions. Actual values of 0 make it inf/NaN.
    """
    params_ = params if params else self.params
    # Restrict to the training window even when fit() has not run yet (the grid
    # search is typically run before fit), so the CV never sees data after his_en_dt.
    history = self.x_train[
        (datetime.strftime(self.his_st_dt, STD_D_STR) <= self.x_train['ds'])
        & (self.x_train['ds'] <= datetime.strftime(self.his_en_dt, STD_D_STR))].reset_index(drop=True)
    self.model = Prophet(**params_)
    with SuppressStdoutStderr():
        self.model.fit(history)
        # cross_validation refits the model for every cutoff, so keep it silenced too.
        cv_result = cross_validation(self.model,
                                     horizon=self.cv_horizon, period=self.cv_period, initial=self.cv_initial)
    # MAPE uses |error / actual| so that negative actuals cannot produce negative errors.
    map_err = np.mean(np.abs((cv_result['yhat'] - cv_result['y']) / cv_result['y'])) * 100
    return map_err
4. 网格寻参
def grid_search(self) -> pd.DataFrame:
"""
网格寻参
:return: df_search: 寻参记录
"""
list_n_changepoints = [i for i in range(2, 7)]
list_changepoint_range = [i / 10 for i in range(5, 10)]
list_seasonality_mode = ["additive", "multiplicative"]
list_seasonality_prior_scale = [0.1, 0.5, 1, 5, 10]
list_changepoint_prior_scale = [0.1, 0.5, 1, 5, 10]
list_search = []
for nc in list_n_changepoints:
for cr in list_changepoint_range:
for sm in list_seasonality_mode:
for sps in list_seasonality_prior_scale:
for cps in list_changepoint_prior_scale:
params = {
"n_changepoints": nc,
"changepoint_range": cr,
"yearly_seasonality": False,
"weekly_seasonality": True,
"daily_seasonality": False,
"holidays": None,
"seasonality_mode": sm,
"seasonality_prior_scale": sps,
"holidays_prior_scale": 0,
"changepoint_prior_scale": cps
}
score = self.cv(params=params)
list_search.append([nc, cr, sm, sps, cps, score])
if score < self.map_err:
self.map_err, self.params = score, params
print("current best mse: {0}; current params: {1}".format(round(self.map_err, 4),
params))
df_search = pd.DataFrame(data=list_search,
columns=['n_changepoints', 'changepoint_range', 'seasonality_mode',
'seasonality_prior_scale', 'changepoint_prior_scale', 'mse'])
return df_search
5. 预测
def predict(self) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Forecast the target window with the model currently in ``self.model``.

    The future frame contains the model's history dates plus ``ahead + tm_step`` new
    dates: first the gap between his_en_dt and trg_st_dt, then the target days.
    Call fit() first. After grid_search() alone, ``self.model`` holds the last
    cross-validated model.

    :return: (df_predict_y, df_history_y)
        df_predict_y: the last ``tm_step`` rows, columns ['ds', 'yhat', 'yhat_lower', 'yhat_upper']
        df_history_y: the rows before the ``ahead + tm_step`` future dates, columns ['ds', 'yhat']
        In both frames, 'ds' is converted to a "YYYY-MM-DD" string.
    """
    n_future = self.ahead + self.tm_step
    df_future = self.model.make_future_dataframe(periods=n_future,
                                                 include_history=True).dropna().reset_index(drop=True)
    df_predict = self.model.predict(df=df_future)
    df_predict['ds'] = df_predict['ds'].apply(lambda x: str(x.date()))
    df_predict_y = df_predict[['ds', 'yhat', 'yhat_lower', 'yhat_upper']].tail(self.tm_step)
    # Slice by an explicit end position: `[:-n]` would return an empty frame when n == 0.
    n_history = max(len(df_predict) - n_future, 0)
    df_history_y = df_predict[['ds', 'yhat']].iloc[:n_history]
    return df_predict_y, df_history_y
(三)其他尝试
添加季节性组件:
# Optional custom seasonal component (an experiment). Per the Prophet docs,
# add_seasonality() must be called before the model is fitted.
# NOTE(review): Prophet measures `period` in days, so 2 * np.pi / 7 (~0.90 days) is a
# sub-daily cycle. If a 7-day cycle was intended, the period would be 7, which would
# duplicate weekly_seasonality=True. Confirm intent.
self.model.add_seasonality(name='sin', period=2 * np.pi / 7, fourier_order=1)
(四)模型调用示例
1. 网格寻参
# Initialise the predictor (df_input, trg_st_dt, his_st_dt and his_en_dt are prepared beforehand)
prophet_predictor = ProphetPredictor(x_train=df_input,
                                     trg_st_dt=trg_st_dt, tm_step=3, his_st_dt=his_st_dt, his_en_dt=his_en_dt,
                                     cv_horizon='3 days', cv_period='3 days', cv_initial='135 days')
# Grid search over Prophet hyper-parameters (updates prophet_predictor.params)
dts_search = datetime.now()
df_search = prophet_predictor.grid_search()
print("df_search:\n", df_search, '\n')
dte_search = datetime.now()
# total_seconds() also counts whole days, which .seconds + .microseconds would silently drop
tm_search = round((dte_search - dts_search).total_seconds(), 3)
print("grid search time: {} s".format(tm_search), '\n')
# Train the final model with the best parameters found
dts_train = datetime.now()
prophet_predictor.fit()
dte_train = datetime.now()
tm_train = round((dte_train - dts_train).total_seconds(), 3)
print("train time: {} s".format(tm_train), '\n')
# Predict
df_predict_y, df_history_y = prophet_predictor.predict()
print("df_predict_y:\n", df_predict_y, '\n')
print("df_history_y:\n", df_history_y, '\n')
2. 给定参数
# Hand-picked Prophet parameters, forwarded to ProphetPredictor as keyword arguments
params = {
    "n_changepoints": 2,
    "changepoint_range": 0.7,
    "seasonality_mode": 'additive',
    "seasonality_prior_scale": 0.5,
    "changepoint_prior_scale": 10,
    "yearly_seasonality": False,
    "weekly_seasonality": True,
    "daily_seasonality": False
}
# Initialise the predictor with the given parameters instead of searching for them
prophet_predictor = ProphetPredictor(x_train=df_input,
                                     trg_st_dt=trg_st_dt, tm_step=3, his_st_dt=his_st_dt, his_en_dt=his_en_dt,
                                     cv_horizon='3 days', cv_period='3 days', cv_initial='135 days',
                                     **params)
# Train
dts_train = datetime.now()
prophet_predictor.fit()
dte_train = datetime.now()
# total_seconds() also counts whole days, which .seconds + .microseconds would silently drop
tm_train = round((dte_train - dts_train).total_seconds(), 3)
print("train time: {} s".format(tm_train), '\n')
# Predict
df_predict_y, df_history_y = prophet_predictor.predict()
print("df_predict_y:\n", df_predict_y, '\n')
print("df_history_y:\n", df_history_y, '\n')
参考资料:
https://www.cnblogs.com/fulu/p/13329656.html
https://www.cnblogs.com/zhazhaacmer/p/13786940.html
发布者：全栈程序员-用户IM，转载请注明出处：https://javaforall.cn/151111.html 原文链接：https://javaforall.cn
【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛
【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...