Title: TensorFlow决策森林构建GBDT(Python) · Issue #50 · aialgorithm/Blog · GitHub
Open Graph Title: TensorFlow决策森林构建GBDT(Python) · Issue #50 · aialgorithm/Blog
X Title: TensorFlow决策森林构建GBDT(Python) · Issue #50 · aialgorithm/Blog
Description: 一、Deep Learning is Not All You Need 尽管神经网络在图像识别、自然语言等很多领域大放异彩,但回到表格数据的数据挖掘任务中,树模型才是低调王者,如论文《Tabular Data: Deep Learning is Not All You Need》提及的:深度学习可能不是解决所有机器学习问题的灵丹妙药,通过树模型在处理表格数据时性能与神经网络相当(甚至优于神经网络),而且树模型易于训练使用,有较好的可解释性。 二、树模型的使用 对于决策树...
Open Graph Description: 一、Deep Learning is Not All You Need 尽管神经网络在图像识别、自然语言等很多领域大放异彩,但回到表格数据的数据挖掘任务中,树模型才是低调王者,如论文《Tabular Data: Deep Learning is Not All You Need》提及的:深度学习可能不是解决所有机器学习问题的灵丹妙药,通过树模型在处理表格数据时性能与神经网络相当(甚至优于神经...
X Description: 一、Deep Learning is Not All You Need 尽管神经网络在图像识别、自然语言等很多领域大放异彩,但回到表格数据的数据挖掘任务中,树模型才是低调王者,如论文《Tabular Data: Deep Learning is Not All You Need》提及的:深度学习可能不是解决所有机器学习问题的灵丹妙药,通过树模型在处理表格数据时性能与神经网络相当(甚至优于神经...
Opengraph URL: https://github.com/aialgorithm/Blog/issues/50
X: @github
Domain: github.com
{"@context":"https://schema.org","@type":"DiscussionForumPosting","headline":"TensorFlow决策森林构建GBDT(Python)","articleBody":"### 一、Deep Learning is Not All You Need\r\n尽管神经网络在图像识别、自然语言等很多领域大放异彩,但回到表格数据的数据挖掘任务中,树模型才是低调王者,如论文《Tabular Data: Deep Learning is Not All You Need》提及的:深度学习可能不是解决所有机器学习问题的灵丹妙药,通过树模型在处理表格数据时性能与神经网络相当(甚至优于神经网络),而且树模型易于训练使用,有较好的可解释性。\r\n\r\n\r\n\r\n\r\n\r\n### 二、树模型的使用\r\n对于决策树等模型的使用,通常是要到scikit-learn、xgboost、lightgbm等机器学习库调用, 这和深度学习库是独立割裂的,不太方便树模型与神经网络的模型融合。\r\n\r\n一个好消息是,Google 开源了 TensorFlow 决策森林(TF-DF),为基于树的模型和神经网络提供统一的接口,可以直接用TensorFlow调用树模型。决策森林(TF-DF)简单来说就是用TensorFlow封装了常用的随机森林(RF)、梯度提升(GBDT)等算法,其底层算法是基于C++的 [Yggdrasil 决策森林 (YDF)](https://github.com/google/yggdrasil-decision-forests#:~:text=Yggdrasil%20Decision%20Forests%20(YDF)%20is,interpretation%20of%20Decision%20Forest%20models.)实现的。\r\n\r\n### 三、TensorFlow构建GBDT实践\r\nTF-DF安装很简单`pip install -U tensorflow_decision_forests`,有个遗憾是目前只支持Linux环境,如果本地用不了将代码复制到 Google Colab 试试~\r\n\r\n- 本例的数据集用的癌细胞分类的数据集,首先加载下常用的模块及数据集:\r\n```\r\nimport numpy as np \r\nimport pandas as pd\r\nimport matplotlib.pyplot as plt\r\nimport tensorflow as tf\r\ntf.random.set_seed(123)\r\n\r\nfrom sklearn import datasets\r\nfrom sklearn.model_selection import train_test_split\r\nfrom sklearn.metrics import precision_score, recall_score, f1_score,roc_curve\r\n\r\ndataset_cancer = datasets.load_breast_cancer() # 加载癌细胞数据集\r\n\r\n#print(dataset_cancer['DESCR'])\r\n\r\ndf = pd.DataFrame(dataset_cancer.data, columns=dataset_cancer.feature_names) \r\n\r\ndf['label'] = dataset_cancer.target\r\n\r\nprint(df.shape)\r\n\r\ndf.head()\r\n```\r\n\r\n\r\n- 划分数据集,并简单做下数据EDA分析:\r\n```\r\n# holdout验证法: 按3:7划分测试集 训练集\r\nx_train, x_test= train_test_split(df, test_size=0.3)\r\n\r\n# EDA分析:数据统计指标\r\nx_train.describe(include='all')\r\n```\r\n\r\n\r\n- 构建TensorFlow的GBDT模型:\r\nTD-DF 一个非常方便的地方是它不需要对数据进行任何预处理。它会自动处理数字和分类特征,以及缺失值,我们只需要将df转换为 TensorFlow 数据集,如下一些超参数设定:\r\n\r\n模型方面的树的一些常规超参数,类似于scikit-learn的GBDT\r\n\r\n\r\n\r\n此外,还有带有正则化(dropout、earlystop)、损失函数(focal-loss)、效率方面(goss基于梯度采样)等优化方法:\r\n\r\n\r\n\r\n构建模型、编译及训练,一步到位:\r\n```\r\n# 模型参数\r\nmodel_tf = tfdf.keras.GradientBoostedTreesModel(loss=\"BINARY_FOCAL_LOSS\")\r\n\r\n# 模型训练\r\nmodel_tf.compile()\r\nmodel_tf.fit(x=train_ds,validation_freq=0.1)\r\n```\r\n\r\n- 评估模型效果\r\n```\r\n## 模型评估\r\n可以看到test的准确率已经都接近1,可以再那个困难的数据任务试试~\r\nevaluation = model_tf.evaluate(test_ds,return_dict=True)\r\nprobs = model_tf.predict(test_ds)\r\nfpr, tpr, _ = roc_curve(x_test.label, probs)\r\nplt.plot(fpr, tpr)\r\nplt.title('ROC curve')\r\nplt.xlabel('false positive rate')\r\nplt.ylabel('true positive rate')\r\nplt.xlim(0,)\r\nplt.ylim(0,)\r\nplt.show()\r\nprint(evaluation)\r\n```\r\n\r\n\r\n- 模型解释性\r\nGBDT等树模型还有另外一个很大的优势是解释性,这里TF-DF也有实现。\r\n模型情况及特征重要性可以通过`print(model_tf.summary())`打印出来,\r\n\r\n特征重要性支持了几种不同的方法评估:\r\n\r\nMEAN_MIN_DEPTH指标。 平均最小深度越小,较低的值意味着大量样本是基于此特征进行分类的,变量越重要。\r\n\r\n\r\nNUM_NODES指标。它显示了给定特征被用作分割的次数,类似split。此外还有其他指标就不一一列举了。\r\n\r\n\r\n我们还可以打印出模型的具体决策的树结构,通过运行`tfdf.model_plotter.plot_model_in_colab(model_tf, tree_idx=0, max_depth=10)`,整个过程还是比较清晰的。\r\n\r\n\r\n#### 小结\r\n基于TensorFlow的TF-DF的树模型方法,我们可以方便训练树模型(特别对于熟练TensorFlow框架的同学),更进一步,也可以与TensorFlow的神经网络模型做效果对比、树模型与神经网络模型融合、利用异构模型先特征表示学习再输入模型(如GBDT+DNN、DNN embedding+GBDT),进一步了解可见如下参考文献。\r\n\r\n\u003e参考文献:\r\nhttps://www.tensorflow.org/decision_forests/\r\nhttps://keras.io/examples/structured_data/classification_with_tfdf/","author":{"url":"https://github.com/aialgorithm","@type":"Person","name":"aialgorithm"},"datePublished":"2022-05-05T10:26:46.000Z","interactionStatistic":{"@type":"InteractionCounter","interactionType":"https://schema.org/CommentAction","userInteractionCount":0},"url":"https://github.com/50/Blog/issues/50"}
| route-pattern | /_view_fragments/issues/show/:user_id/:repository/:id/issue_layout(.:format) |
| route-controller | voltron_issues_fragments |
| route-action | issue_layout |
| fetch-nonce | v2:dc085c4a-f906-33cb-cf4c-ae49192ac4af |
| current-catalog-service-hash | 81bb79d38c15960b92d99bca9288a9108c7a47b18f2423d0f6438c5b7bcd2114 |
| request-id | ABAC:BB9B6:BFDED9:1037603:6969E941 |
| html-safe-nonce | 0bb876b9deb2e1db2177aafed2ea49767e7e182bfd2fed54c1c199650f2aa9fa |
| visitor-payload | eyJyZWZlcnJlciI6IiIsInJlcXVlc3RfaWQiOiJBQkFDOkJCOUI2OkJGREVEOToxMDM3NjAzOjY5NjlFOTQxIiwidmlzaXRvcl9pZCI6IjQ4NjI3ODk1MzA4MDk5MTk4MDkiLCJyZWdpb25fZWRnZSI6ImlhZCIsInJlZ2lvbl9yZW5kZXIiOiJpYWQifQ== |
| visitor-hmac | e7b0c4c2327fdf9192a6dc7f0656535e6d5cbdec0318d6a74e9bd95d1a66ef53 |
| hovercard-subject-tag | issue:1226488227 |
| github-keyboard-shortcuts | repository,issues,copilot |
| google-site-verification | Apib7-x98H0j5cPqHWwSMm6dNU4GmODRoqxLiDzdx9I |
| octolytics-url | https://collector.github.com/github/collect |
| analytics-location | / |
| fb:app_id | 1401488693436528 |
| apple-itunes-app | app-id=1477376905, app-argument=https://github.com/_view_fragments/issues/show/aialgorithm/Blog/50/issue_layout |
| twitter:image | https://opengraph.githubassets.com/2ec452edfc67a94ececf346cae26d2e74574a305a73b118474d583794049745a/aialgorithm/Blog/issues/50 |
| twitter:card | summary_large_image |
| og:image | https://opengraph.githubassets.com/2ec452edfc67a94ececf346cae26d2e74574a305a73b118474d583794049745a/aialgorithm/Blog/issues/50 |
| og:image:alt | 一、Deep Learning is Not All You Need 尽管神经网络在图像识别、自然语言等很多领域大放异彩,但回到表格数据的数据挖掘任务中,树模型才是低调王者,如论文《Tabular Data: Deep Learning is Not All You Need》提及的:深度学习可能不是解决所有机器学习问题的灵丹妙药,通过树模型在处理表格数据时性能与神经网络相当(甚至优于神经... |
| og:image:width | 1200 |
| og:image:height | 600 |
| og:site_name | GitHub |
| og:type | object |
| og:author:username | aialgorithm |
| hostname | github.com |
| expected-hostname | github.com |
| None | 7b32f1c7c4549428ee399213e8345494fc55b5637195d3fc5f493657579235e8 |
| turbo-cache-control | no-preview |
| go-import | github.com/aialgorithm/Blog git https://github.com/aialgorithm/Blog.git |
| octolytics-dimension-user_id | 33707637 |
| octolytics-dimension-user_login | aialgorithm |
| octolytics-dimension-repository_id | 147093233 |
| octolytics-dimension-repository_nwo | aialgorithm/Blog |
| octolytics-dimension-repository_public | true |
| octolytics-dimension-repository_is_fork | false |
| octolytics-dimension-repository_network_root_id | 147093233 |
| octolytics-dimension-repository_network_root_nwo | aialgorithm/Blog |
| turbo-body-classes | logged-out env-production page-responsive |
| disable-turbo | false |
| browser-stats-url | https://api.github.com/_private/browser/stats |
| browser-errors-url | https://api.github.com/_private/browser/errors |
| release | bdde15ad1b403e23b08bbd89b53fbe6bdf688cad |
| ui-target | full |
| theme-color | #1e2327 |
| color-scheme | light dark |
Links:
Viewport: width=device-width