Prophet Beginner's Notes

This post shows basic usage of the Prophet forecasting model.

 

(I) Suppressing log output

import os
import sys


class SuppressStdoutStderr(object):
    """
    A context manager for doing a "deep suppression" of stdout and stderr in
    Python, i.e. it suppresses all printing, even if the print originates in a
    compiled C/Fortran sub-function (such as the Stan backend used by Prophet).

    This will not suppress raised exceptions, since exceptions are printed
    to stderr just before a script exits, and after the context manager has
    exited (at least, I think that is why it lets exceptions through).
    """

    def __init__(self):
        # open a pair of null files
        self.null_fds = [os.open(os.devnull, os.O_RDWR) for _ in range(2)]

        # save the actual stdout (1) and stderr (2) file descriptors
        self.save_fds = (os.dup(1), os.dup(2))

    def __enter__(self):
        # flush Python-level buffers so output printed earlier is not swallowed
        sys.stdout.flush()
        sys.stderr.flush()
        # assign the null pointers to stdout and stderr
        os.dup2(self.null_fds[0], 1)
        os.dup2(self.null_fds[1], 2)
        return self

    def __exit__(self, *_):
        # flush output buffered inside the block so it goes to the null files
        sys.stdout.flush()
        sys.stderr.flush()
        # reassign the real stdout/stderr back to (1) and (2)
        os.dup2(self.save_fds[0], 1)
        os.dup2(self.save_fds[1], 2)

        # close the null files and the saved duplicates
        for fd in (*self.null_fds, *self.save_fds):
            os.close(fd)
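In newer releases (the package renamed to prophet, using cmdstanpy as the Stan backend), much of the fitting chatter goes through Python's logging module rather than raw C-level writes, so raising the relevant logger levels is often enough. A minimal sketch, assuming the logger names "cmdstanpy", "prophet" and (for old installs) "fbprophet"; the helper quiet_prophet_loggers is hypothetical:

```python
import logging

# Logger names are assumptions for illustration: cmdstanpy (Stan backend of prophet>=1.0),
# prophet (prophet>=1.0) and fbprophet (older releases).
PROPHET_LOGGERS = ("cmdstanpy", "prophet", "fbprophet")


def quiet_prophet_loggers(level: int = logging.WARNING) -> None:
    """Hypothetical helper: only records at `level` or above pass these loggers."""
    for name in PROPHET_LOGGERS:
        logging.getLogger(name).setLevel(level)


quiet_prophet_loggers()
print(logging.getLogger("cmdstanpy").level)  # 30, i.e. logging.WARNING
```

Unlike the file-descriptor trick above, this does not silence text that compiled code writes directly to the terminal, so the two approaches can be combined.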

 

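(II) A Prophet forecasting class, structured like the official Prophet class but not inheriting from it

1. Initialization

from datetime import datetime, timedelta
from typing import Tuple

import numpy as np
import pandas as pd

# the package was renamed from fbprophet to prophet in version 1.0; try the new name first
try:
    from prophet import Prophet
    from prophet.diagnostics import cross_validation
except ImportError:
    from fbprophet import Prophet
    from fbprophet.diagnostics import cross_validation

STD_D_STR = "%Y-%m-%d"  # '%m': month, '%M': minute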

class ProphetPredictor(object):
    def __init__(self, x_train: pd.DataFrame,
                 trg_st_dt: datetime, tm_step: int, his_st_dt: datetime, his_en_dt: datetime,
                 cv_horizon: str, cv_period: str, cv_initial: str,
                 n_changepoints=25, changepoint_range=0.7,
                 yearly_seasonality=False, weekly_seasonality=True, daily_seasonality=False,
                 holidays=None, seasonality_mode='multiplicative',
                 seasonality_prior_scale=10, holidays_prior_scale=0, changepoint_prior_scale=0.05):
        """
        initialisation

        :param x_train:  dataset with columns ['ds', 'y']
        :param trg_st_dt:  forecast start date (datetime or 'YYYY-MM-DD' string; same for the other dates)
        :param tm_step:  number of days to forecast, starting at trg_st_dt
        :param his_st_dt:  training set start date
        :param his_en_dt:  training set end date; must be strictly earlier than trg_st_dt
        :param cv_horizon:  horizon argument of cross-validation, in the form '3 days'
        :param cv_period:  period argument of cross-validation, in the form '3 days'
        :param cv_initial:  initial argument of cross-validation, in the form '3 days'
        :param n_changepoints:  number of potential changepoints (Prophet's own default is 25)
        :param changepoint_range:  proportion of the history in which changepoints may be placed
        :param yearly_seasonality:  yearly seasonality
        :param weekly_seasonality:  weekly seasonality
        :param daily_seasonality:  daily seasonality
        :param holidays:  holidays or special dates
        :param seasonality_mode:  seasonality mode, {'additive', 'multiplicative'}
        :param seasonality_prior_scale:  strength of the seasonality components
        :param holidays_prior_scale:  strength of the holiday components
        :param changepoint_prior_scale:  flexibility of automatic changepoint selection; larger values allow more changepoints
        """

        # Prophet model parameters
        self.params = {
            "n_changepoints": n_changepoints,
            "changepoint_range": changepoint_range,
            "yearly_seasonality": yearly_seasonality,
            "weekly_seasonality": weekly_seasonality,
            "daily_seasonality": daily_seasonality,
            "holidays": holidays,
            "seasonality_mode": seasonality_mode,
            "seasonality_prior_scale": seasonality_prior_scale,
            "holidays_prior_scale": holidays_prior_scale,
            "changepoint_prior_scale": changepoint_prior_scale
        }

        self.trg_st_dt = datetime.strptime(trg_st_dt, STD_D_STR) if isinstance(trg_st_dt, str) else trg_st_dt
        self.tm_step = tm_step
        self.trg_en_dt = self.trg_st_dt + timedelta(days=tm_step - 1)
        self.his_st_dt = datetime.strptime(his_st_dt, STD_D_STR) if isinstance(his_st_dt, str) else his_st_dt
        self.his_en_dt = datetime.strptime(his_en_dt, STD_D_STR) if isinstance(his_en_dt, str) else his_en_dt
        if self.his_en_dt >= self.trg_st_dt:
            raise ValueError("his_en_dt must be strictly earlier than trg_st_dt")

        # lead time = forecast start date - latest history date - 1 (a next-day forecast has lead time 0)
        self.ahead = (self.trg_st_dt - self.his_en_dt).days - 1

        # keep only rows in [his_st_dt, his_en_dt], so cv() (which may run before fit()) sees the same data
        x_train = x_train[['ds', 'y']].copy()
        x_train['ds'] = pd.to_datetime(x_train['ds'])
        self.x_train = x_train[(x_train['ds'] >= self.his_st_dt)
                               & (x_train['ds'] <= self.his_en_dt)].reset_index(drop=True)
        self.model = None
        self.cv_horizon = cv_horizon
        self.cv_period = cv_period
        self.cv_initial = cv_initial
        self.map_err = float('inf')  # best MAPE found so far by grid_search()
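The date fields set up in __init__ can be checked with a small worked example. The dates are made up for illustration (history ends on 2021-03-31, the forecast starts on 2021-04-03 and covers tm_step = 3 days), and forecast_window is a hypothetical helper that repeats the same arithmetic:

```python
from datetime import datetime, timedelta

STD_D_STR = "%Y-%m-%d"


def forecast_window(trg_st: str, tm_step: int, his_en: str):
    """Repeat the trg_en_dt / ahead arithmetic of ProphetPredictor.__init__."""
    trg_st_dt = datetime.strptime(trg_st, STD_D_STR)
    his_en_dt = datetime.strptime(his_en, STD_D_STR)
    # last target day: tm_step days counting the start day itself
    trg_en_dt = trg_st_dt + timedelta(days=tm_step - 1)
    # lead time: days skipped between the last history day and the forecast start
    ahead = (trg_st_dt - his_en_dt).days - 1
    return trg_en_dt.strftime(STD_D_STR), ahead


print(forecast_window("2021-04-03", 3, "2021-03-31"))  # ('2021-04-05', 2)
print(forecast_window("2021-04-01", 3, "2021-03-31"))  # ('2021-04-03', 0)
```

So a next-day forecast has ahead = 0, and the model has to look ahead + tm_step days past the history to reach the last target day.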

2. Model training

    def fit(self):
        """
        Train the model on the rows between his_st_dt and his_en_dt
        :return: None
        """

        self.x_train = self.x_train[
            (datetime.strftime(self.his_st_dt, STD_D_STR) <= self.x_train['ds'])
            & (self.x_train['ds'] <= datetime.strftime(self.his_en_dt, STD_D_STR))].reset_index(drop=True)

        self.model = Prophet(**self.params)
        with SuppressStdoutStderr():
            self.model.fit(df=self.x_train)

3. Cross-validation

    def cv(self, params=None) -> float:
        """
        Cross-validation
        :param params:  model parameters; given during grid search, otherwise self.params is used
        :return: map_err:  mean absolute percentage error (MAPE), in percent
        """

        params_ = params if params else self.params
        self.model = Prophet(**params_)
        with SuppressStdoutStderr():
            self.model.fit(self.x_train)
            cv_result = cross_validation(self.model,
                                         horizon=self.cv_horizon, period=self.cv_period, initial=self.cv_initial)

        # mean absolute percentage error (MAPE); undefined if any actual y is 0
        map_err = np.mean(np.abs(cv_result['yhat'] - cv_result['y']) / np.abs(cv_result['y'])) * 100

        return map_err
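The MAPE line in cv() divides by the actual value y, so a single zero makes the score infinite or NaN. A plain-Python sketch of the same formula, with |y| in the denominator as in the usual definition, that skips zero actuals; skipping them is an assumption for illustration, not what the code above does (prophet.diagnostics.performance_metrics also reports a mape column):

```python
from typing import Sequence


def mape(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean absolute percentage error in percent, ignoring rows where y_true == 0."""
    pairs = [(t, p) for t, p in zip(y_true, y_pred) if t != 0]
    if not pairs:
        raise ValueError("MAPE is undefined when every actual value is zero")
    return sum(abs(p - t) / abs(t) for t, p in pairs) / len(pairs) * 100


print(mape([100, 200, 0], [110, 180, 5]))  # 10.0
```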

4. Grid search

    def grid_search(self) -> pd.DataFrame:
        """
        Grid search; the parameters with the lowest MAPE are stored in self.params
        :return: df_search:  search log, one row per parameter combination
        """

        list_n_changepoints = list(range(2, 7))
        list_changepoint_range = [i / 10 for i in range(5, 10)]
        list_seasonality_mode = ["additive", "multiplicative"]
        list_seasonality_prior_scale = [0.1, 0.5, 1, 5, 10]
        list_changepoint_prior_scale = [0.1, 0.5, 1, 5, 10]

        list_search = []
        for nc in list_n_changepoints:
            for cr in list_changepoint_range:
                for sm in list_seasonality_mode:
                    for sps in list_seasonality_prior_scale:
                        for cps in list_changepoint_prior_scale:
                            params = {
                                "n_changepoints": nc,
                                "changepoint_range": cr,
                                "yearly_seasonality": False,
                                "weekly_seasonality": True,
                                "daily_seasonality": False,
                                "holidays": None,
                                "seasonality_mode": sm,
                                "seasonality_prior_scale": sps,
                                "holidays_prior_scale": 0,
                                "changepoint_prior_scale": cps
                            }
                            score = self.cv(params=params)
                            list_search.append([nc, cr, sm, sps, cps, score])

                            if score < self.map_err:
                                self.map_err, self.params = score, params
                                print("current best mape:  {0};  current params:  {1}".format(round(self.map_err, 4),
                                                                                              params))

        df_search = pd.DataFrame(data=list_search,
                                 columns=['n_changepoints', 'changepoint_range', 'seasonality_mode',
                                          'seasonality_prior_scale', 'changepoint_prior_scale', 'mape'])

        return df_search
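The five nested loops in grid_search() try 5 × 5 × 2 × 5 × 5 = 1250 parameter combinations, and each one runs a full Prophet cross-validation, so the search can take a long time. The same grid can be built flat with itertools.product, which makes its size easy to check before starting. A sketch; the names GRID, FIXED and iter_param_grid are hypothetical:

```python
from itertools import product
from typing import Dict, Iterator

# Search space copied from grid_search(); the fixed settings are merged into every candidate.
GRID = {
    "n_changepoints": list(range(2, 7)),
    "changepoint_range": [i / 10 for i in range(5, 10)],
    "seasonality_mode": ["additive", "multiplicative"],
    "seasonality_prior_scale": [0.1, 0.5, 1, 5, 10],
    "changepoint_prior_scale": [0.1, 0.5, 1, 5, 10],
}
FIXED = {
    "yearly_seasonality": False,
    "weekly_seasonality": True,
    "daily_seasonality": False,
    "holidays": None,
    "holidays_prior_scale": 0,
}


def iter_param_grid() -> Iterator[Dict]:
    """Yield one Prophet parameter dict per combination, in the same order as the nested loops."""
    keys = list(GRID)
    for values in product(*(GRID[k] for k in keys)):
        yield {**dict(zip(keys, values)), **FIXED}


candidates = list(iter_param_grid())
print(len(candidates))  # 1250
```

Inside grid_search() the loop body would then read: for params in iter_param_grid(): score = self.cv(params=params).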

5. Prediction

    def predict(self) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """
        Forecast
        :return: df_predict_y:  forecasts for the tm_step target days
        :return: df_history_y:  in-sample fitted values over the history
        """

        # extend past the history by the lead-time gap plus the target window
        df_future = self.model.make_future_dataframe(periods=self.ahead + self.tm_step,
                                                     include_history=True).dropna().reset_index(drop=True)
        df_predict = self.model.predict(df=df_future)
        df_predict['ds'] = df_predict['ds'].dt.strftime(STD_D_STR)

        # the last tm_step rows are the target days; everything before the future rows is history
        df_predict_y = df_predict[['ds', 'yhat', 'yhat_lower', 'yhat_upper']].tail(self.tm_step)
        df_history_y = df_predict[['ds', 'yhat']].iloc[:-(self.tm_step + self.ahead)]

        return df_predict_y, df_history_y
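predict() asks make_future_dataframe for ahead + tm_step extra days after the history and then splits the result by position. The slicing can be mimicked on plain lists of dates; the five-day history below is made up for illustration and split_forecast is a hypothetical helper:

```python
from datetime import date, timedelta
from typing import List, Tuple


def split_forecast(history: List[date], ahead: int, tm_step: int) -> Tuple[List[date], List[date]]:
    """Mimic predict(): extend daily dates, then take tail(tm_step) and [:-(tm_step + ahead)]."""
    last = history[-1]
    future = [last + timedelta(days=i) for i in range(1, ahead + tm_step + 1)]
    frame = history + future                     # include_history=True
    predicted = frame[-tm_step:]                 # rows of df_predict_y
    fitted_history = frame[:-(tm_step + ahead)]  # rows of df_history_y
    return predicted, fitted_history


hist = [date(2021, 3, 27) + timedelta(days=i) for i in range(5)]  # 2021-03-27 .. 2021-03-31
pred, back = split_forecast(hist, ahead=2, tm_step=3)
print([d.isoformat() for d in pred])  # ['2021-04-03', '2021-04-04', '2021-04-05']
print(back == hist)  # True
```

The ahead rows between the history and the target window (here 2021-04-01 and 2021-04-02) are predicted but returned in neither frame.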

 

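(III) Other attempts

Adding a custom seasonality component. Prophet measures period in days, so a 7-day sine/cosine cycle uses period=7 (2 * np.pi / 7 would describe a cycle of about 0.9 days). add_seasonality() must be called after the Prophet object is created and before fit(), for example inside fit(). Note that with weekly_seasonality=True the model already contains a 7-day component.

self.model = Prophet(**self.params)
self.model.add_seasonality(name='sin', period=7, fourier_order=1)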
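To see why the period has to be given in days, the sketch below builds the sin(2πnt/P) and cos(2πnt/P) terms of a partial Fourier sum, with t in days, which is how the Prophet documentation describes its seasonal components. fourier_terms and repeats_after are hypothetical stdlib-only helpers, not Prophet's own code:

```python
import math
from typing import List


def fourier_terms(t_days: float, period: float, order: int) -> List[float]:
    """[sin, cos] pairs for n = 1..order of a partial Fourier sum with the given period (days)."""
    terms = []
    for n in range(1, order + 1):
        angle = 2 * math.pi * n * t_days / period
        terms.extend([math.sin(angle), math.cos(angle)])
    return terms


def repeats_after(days: int, period: float, order: int = 1, start: float = 10.0) -> bool:
    """True if the features at `start` and `start + days` are numerically identical."""
    a = fourier_terms(start, period, order)
    b = fourier_terms(start + days, period, order)
    return all(math.isclose(x, y, abs_tol=1e-9) for x, y in zip(a, b))


print(repeats_after(7, period=7))                # True: a weekly cycle
print(repeats_after(7, period=2 * math.pi / 7))  # False: a cycle of about 0.9 days
```

With daily data, a 0.9-day cycle is sampled less than once per cycle, so what gets fitted is an aliased pattern rather than a weekly one.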

 

(IV) Usage examples

1. Grid search

# initialise (df_input with columns ['ds', 'y'] and the three dates are supplied by the caller)
prophet_predictor = ProphetPredictor(x_train=df_input,
                                     trg_st_dt=trg_st_dt, tm_step=3, his_st_dt=his_st_dt, his_en_dt=his_en_dt,
                                     cv_horizon='3 days', cv_period='3 days', cv_initial='135 days')

# grid search (stores the best parameters in prophet_predictor.params)
dts_search = datetime.now()
df_search = prophet_predictor.grid_search()
print("df_search:\n", df_search, '\n')
dte_search = datetime.now()
tm_search = round((dte_search - dts_search).total_seconds(), 3)
print("grid search time:  {} s".format(tm_search), '\n')

# train with the best parameters
dts_train = datetime.now()
prophet_predictor.fit()
dte_train = datetime.now()
tm_train = round((dte_train - dts_train).total_seconds(), 3)
print("train time:  {} s".format(tm_train), '\n')

# predict
df_predict_y, df_history_y = prophet_predictor.predict()
print("df_predict_y:\n", df_predict_y, '\n')
print("df_history_y:\n", df_history_y, '\n')
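With cv_initial='135 days' and cv_horizon = cv_period = '3 days', the number of cross-validation folds depends on how much history there is. A simplified sketch of how Prophet places cutoffs, assuming gap-free daily data: the last cutoff is the final history date minus the horizon, earlier ones step back by period, and only cutoffs at least initial days after the first date are kept. The helper cv_cutoffs and the 180-day history are illustrative assumptions; Prophet's own implementation additionally handles gaps in the data.

```python
from datetime import date, timedelta
from typing import List


def cv_cutoffs(start: date, end: date, initial_days: int, period_days: int, horizon_days: int) -> List[date]:
    """Simplified cutoff placement for gap-free daily data, earliest cutoff first."""
    earliest_allowed = start + timedelta(days=initial_days)
    cutoff = end - timedelta(days=horizon_days)
    cutoffs = []
    while cutoff >= earliest_allowed:
        cutoffs.append(cutoff)
        cutoff -= timedelta(days=period_days)
    if not cutoffs:
        raise ValueError("not enough history for the requested initial + horizon")
    return list(reversed(cutoffs))


# hypothetical 180-day history: 2021-01-01 .. 2021-06-29
folds = cv_cutoffs(date(2021, 1, 1), date(2021, 6, 29), initial_days=135, period_days=3, horizon_days=3)
print(len(folds), folds[0], folds[-1])  # 14 2021-05-18 2021-06-26
```

Each cutoff means one more Prophet fit inside cross_validation, so with the grid in grid_search() the total cost is roughly 1250 × (number of cutoffs) fits.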

2. With given parameters

# parameter settings
params = {
    "n_changepoints": 2,
    "changepoint_range": 0.7,
    "seasonality_mode": 'additive',
    "seasonality_prior_scale": 0.5,
    "changepoint_prior_scale": 10,
    "yearly_seasonality": False,
    "weekly_seasonality": True,
    "daily_seasonality": False
}

# initialise (the keys of params match the constructor's keyword arguments)
prophet_predictor = ProphetPredictor(x_train=df_input,
                                     trg_st_dt=trg_st_dt, tm_step=3, his_st_dt=his_st_dt, his_en_dt=his_en_dt,
                                     cv_horizon='3 days', cv_period='3 days', cv_initial='135 days',
                                     **params)

# train
dts_train = datetime.now()
prophet_predictor.fit()
dte_train = datetime.now()
tm_train = round((dte_train - dts_train).total_seconds(), 3)
print("train time:  {} s".format(tm_train), '\n')

# predict
df_predict_y, df_history_y = prophet_predictor.predict()
print("df_predict_y:\n", df_predict_y, '\n')
print("df_history_y:\n", df_history_y, '\n')
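The examples time each step by subtracting two datetime.now() values. The standard library's time.perf_counter() is a monotonic, high-resolution clock meant for measuring durations, and a small context manager keeps the timing code in one place. A sketch; the name timed is hypothetical:

```python
import time
from contextlib import contextmanager
from typing import Dict, Iterator


@contextmanager
def timed(label: str, results: Dict[str, float]) -> Iterator[None]:
    """Record the elapsed time of the with-block (seconds, 3 decimals) under `label`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        results[label] = round(time.perf_counter() - start, 3)
        print("{} time:  {} s".format(label, results[label]))


timings: Dict[str, float] = {}
with timed("demo", timings):
    time.sleep(0.05)
```

Applied to the examples, the training step would become: with timed("train", timings): prophet_predictor.fit(). The finally clause records the time even if the block raises.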

 

 



Publisher: 全栈程序员-用户IM. Please credit the source when reposting: https://javaforall.cn/151111.html Original link: https://javaforall.cn
