Logistic Regression, LDA QDA and KNN

在本文档中,我将向您展示如何在股票市场数据集(Smarket)上运行LR,LDA,QDA和KNN的分类方法。 对于每种方法,我将单独读取数据,因为其中一些方法要求我们以不同的方式输入数据(尤其是KNN函数)。

广义线性模型(GLM)

Generalized Linear Models
广义线性模型(GLM)

使用glm()函数拟合广义线性模型。glm函数的形式是

1
glm(formula, family=familytype(link=linkfunction), data=)

family Default Link Function
binomial (link = “logit”)
gaussian (link = “identity”)
Gamma (link = “inverse”)
inverse.gaussian (link = “1/mu^2”)
poisson (link = “log”)
quasi (link = “identity”, variance = “constant”)
quasibinomial (link = “logit”)
quasipoisson (link = “log”)

有关其他建模选项,请参阅help(glm)。有关每个系列的其他允许link功能,请参阅help(family)。

这里将介绍三种广义线性模型的子类型:逻辑回归,泊松回归和生存分析。

Logistic回归

当您从一组连续预测变量预测二元结果时,逻辑回归很有用。由于其限制性较低的假设,它通常优于判别函数分析。

1
2
3
4
5
6
7
8
9
10
# Logistic Regression
# where F is a binary factor and
# x1-x3 are continuous predictors
fit <- glm(F~x1+x2+x3,data=mydata,family=binomial())
summary(fit) # display results
confint(fit) # 95% CI for the coefficients
exp(coef(fit)) # exponentiated coefficients
exp(confint(fit)) # 95% CI for exponentiated coefficients
predict(fit, type="response") # predicted values
residuals(fit, type="deviance") # residuals

可以使用anova(fit1 ,fit2 ,test =“Chisq”)来比较嵌套模型。另外,cdplot(F ~ x ,data = mydata )将在连续x变量上显示二元结果F的条件密度图。

Logistic Regression(LR)

Smarket数据集是ISLR包的一部分。我们首先阅读加载ISLR包,然后将Smarket数据集拆分为训练和测试数据。在这个问题中,我们将使用子集技术(向量子集)。训练数据集将包含2005年之前的所有观测资料,测试数据集将包含2005年的所有观测资料。

以下命令将Smarket的数据框附加到R的工作目录(内存)。这将使我们能够直接访问数据集中的变量而无需指定数据框的名称,因此我们可以直接键入Smarket数据集中Year的变量,而不是键入Smarket $ Year。

1
2
3
4
5
library(ISLR) attach(Smarket)
## find the indecies for the observations in Smarket that consitutes the training
## and testing data
train = Year < 2005
test = !train

上面的命令创建了两个布尔向量。 布尔向量的值为TRUE或FALSE。 对于训练矢量,将为具有与年份<2005的Smarket中的观测具有相同index的单元分配TRUE。 “!”否定了训练中的内容。
选择Smarket中将进入培训和测试数据集的值。请注意,我们摆脱了第8个变量today因为它与方向类似。实际上,这就是Direction的计算方法:

1
2
training_data = Smarket[train, -8]
testing_data = Smarket[test, -8]

对于模型评估目的,我们将创建一个包含测试数据集中所有y值的向量。 在我们使用训练数据创建模型之后,模型评估将在稍后进行。

1
testing_y = Direction[test]

我们使用Direction因为这是我们想要预测的(我们的y变量)。 请注意,当我们索引Direction时,我们没有输入逗号,因为它是一个向量(一列)而不是Dataframe!现在是时候使用训练数据集训练我们的模型了:

1
logistic_model = glm(Direction ~ .,data = training_data,family = "binomial")

在上面的逻辑回归模型中,我们使用glm()函数,它是一般的线性模型。 第一个参数是我们的回归方程,它指定我们使用我们的data集中的所有预测变量来预测Direction(.意味着使用所有变量)。如果你想使用特定的变量,比如说Lag1和Lag2,那么公式就是Direction~Lag1 + Lag2。 我们使用训练数据集来训练我们的模型,并且我们将模型的family指定为二项式,因为我们正在运行逻辑回归。 如果我们不指定线性模型的family,则模型将是常规线性回归。

接下来,我们要评估我们的模型logistic_model。 为此,我们将预测测试数据集的y值,然后将预测的y与我们之前在名称testing_y下保存的实际值进行比较。 当在逻辑回归中使用predict()函数时,它会计算在一个类别或另一个类别中的预测概率(在我们的例子中为Down或Up)。

1
2
3
4
 logistic_probs = predict(logistic_model, testing_data, type = "response") head(logistic_probs)

## 999 1000 1001 1002 1003 1004
## 0.6385 0.6017 0.6038 0.5962 0.5875 0.5928

由于predict()计算概率,因此我们必须将它们转换为实际的类别(Up或Down)。 不幸的是,在逻辑回归中,predict()函数不会产生类别。 那么,让我们转换这些可能性。 我们首先创建一个向量来保存这些类。 这个数组将具有相同的testing_y长度(本例中为252),我们将初始化它以使其所有单元格标记为Down,然后我们将更新相应的预测可能性为大于0.5(此阈值可能会根据应用程序而变化)的单元格为Up。

1
2
3
4
5
6
7
8
9
logistic_pred_y = rep("Down", length(testing_y)) 
## the function rep(), repeats "Down" 252 times

logistic_pred_y[logistic_probs > 0.5] = "Up"

## R will first evaluate "logistic_probs >0.5", and it will be a vector of TRUE and FALSE.
## TRUE when the value of the cell in logistic_prob > 0.5, otherwise FALSE.
## R will replace all the "Down" values in "logistic_pred_y" vector with "Up"
## when "logistic_probs >0.5" is TRUE.

评估的最后几个步骤包括找到模型的混淆矩阵。

1
2
3
4
5
6
7
## the following command creates the confusion matrix 
table(logistic_pred_y, testing_y)

## testing_y
## logistic_pred_y Down Up
## Down 2 1
## Up 109 140

现在,让我们计算错误分类错误率:

1
2
3
mean(logistic_pred_y != testing_y)

## [1] 0.4365

我们在上面创建的逻辑回归模型的错误分类错误率是43.65%,这被认为是高错误分类错误率。

代码脚本:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
--LR
library(ISLR)
attach(Smarket)
train = Year < 2005
test = !train

names(Smarket)
training_data = Smarket[train, -8]
testing_data = Smarket[test, -8]

test_y = Direction[test]

logistic_model = glm(Direction ~ .,
data = training_data,
family = "binomial")
summary(logistic_model)

logistic_probs = predict(logistic_model, testing_data, type = "response")
head(logistic_probs)

logistic_pred_y = rep("Down", length(test_y))
## the function rep(), repeats "Down" 252 times
logistic_pred_y[logistic_probs > 0.5] = "Up"

## the following command creates the confusion matrix
table(logistic_pred_y, test_y)
mean(logistic_pred_y != test_y)

泊松回归

泊松回归模型常常应用于因变量是计数变量(count variable)的情形。

1
2
3
4
5
# Poisson Regression
# where count is a count and
# x1-x3 are continuous predictors
fit <- glm(count ~ x1+x2+x3, data=mydata, family=poisson())
summary(fit) display results

如果你有过度离散(看看剩余偏差是否远大于自由度),你可能想要使用quasipoisson()而不是poisson()。

生存分析

生存分析(也称为事件历史分析或可靠性分析)涵盖了一组用于对事件时间建模的技术。数据可能是正确的审查 - 事件可能不会在研究结束时发生,或者我们可能有关于观察的不完整信息,但知道事件没有发生到某一时间(例如参与者在10周内退出研究)但当时还活着)。

虽然通常使用glm()函数分析广义线性模型,但通常使用来自生存包。生存包可以处理一个和两个样本问题,参数加速失效模型和Cox比例风险模型。

通常以格式化开始时间,停止时间和状态输入数据(1 =发生事件,0 =未发生事件)。或者,数据可以是事件和状态的时间格式(1 =发生事件,0 =未发生事件)。状态= 0表示观察是正确的。在进一步分析之前,数据通过Surv()函数捆绑到Surv对象中。

1
2
3
4
5
6
survfit()用于估计一个或多个组的生存分布。创建一个生存对象。
survdiff()测试两组或更多组之间生存分布的差异。使用公式或已构建的Cox模型拟合生存曲线。
coxph()对一组预测变量的危险函数进行建模。拟合Cox比例风险回归模型。

cox.zph():检验一个Cox回归模型的比例风险假设。
survdiff():用log-rank/Mantel-Haenszel检验检验生存差异。

生存分析备查表
背景介绍

1
2
3
# Mayo Clinic Lung Cancer Data
library(survival)
?lung

inst: Institution code

time: Survival time in days

status: censoring status 1=censored, 2=dead

age: Age in years

sex: Male=1 Female=2

ph.ecog: ECOG performance score (0=good 5=dead)

ph.karno: Karnofsky performance score as rated by physician

pat.karno: Karnofsky performance score as rated by patient

meal.cal: Calories consumed at meals

wt.loss: Weight loss in last six months

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
lung <- as_tibble(lung)
lung

# A tibble: 228 x 10
inst time status age sex ph.ecog ph.karno pat.karno meal.cal wt.loss
* <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 3 306 2 74 1 1 90 100 1175 NA
2 3 455 2 68 1 0 90 90 1225 15
3 3 1010 1 56 1 0 90 90 NA 15
4 5 210 2 57 1 1 90 60 1150 11
5 1 883 2 60 1 0 100 90 NA 0
6 12 1022 1 74 1 1 50 80 513 0
7 7 310 2 68 2 2 70 60 384 10
8 11 361 2 71 2 2 60 80 538 1
9 1 218 2 53 1 1 70 80 825 16
10 7 166 2 61 1 2 70 70 271 34
# ... with 218 more rows

生存曲线

image

构建生存对象:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
survobj <- with(lung, Surv(time,status))
survobj
[1] 306 455 1010+ 210 883 1022+ 310 361 218 166 170 654 728 71 567 144
[17] 613 707 61 88 301 81 624 371 394 520 574 118 390 12 473 26
[33] 533 107 53 122 814 965+ 93 731 460 153 433 145 583 95 303 519
[49] 643 765 735 189 53 246 689 65 5 132 687 345 444 223 175 60
[65] 163 65 208 821+ 428 230 840+ 305 11 132 226 426 705 363 11 176
[81] 791 95 196+ 167 806+ 284 641 147 740+ 163 655 239 88 245 588+ 30
[97] 179 310 477 166 559+ 450 364 107 177 156 529+ 11 429 351 15 181
[113] 283 201 524 13 212 524 288 363 442 199 550 54 558 207 92 60
[129] 551+ 543+ 293 202 353 511+ 267 511+ 371 387 457 337 201 404+ 222 62
[145] 458+ 356+ 353 163 31 340 229 444+ 315+ 182 156 329 364+ 291 179 376+
[161] 384+ 268 292+ 142 413+ 266+ 194 320 181 285 301+ 348 197 382+ 303+ 296+
[177] 180 186 145 269+ 300+ 284+ 350 272+ 292+ 332+ 285 259+ 110 286 270 81
[193] 131 225+ 269 225+ 243+ 279+ 276+ 135 79 59 240+ 202+ 235+ 105 224+ 239
[209] 237+ 173+ 252+ 221+ 185+ 92+ 13 222+ 192+ 183 211+ 175+ 197+ 203+ 116 188+
[225] 191+ 105+ 174+ 177+

使用survfit()函数拟合一条生存曲线。这里让我们先创建一条不考虑任何比较的生存曲线,所以我们只需要指定survfit()在公式里期望的截距(比如~1)

1
2
3
4
5
6
fit0 <- survfit(survobj~1, data=lung)

Call: survfit(formula = survobj ~ 1, data = lung)

n events median 0.95LCL 0.95UCL
228 165 310 285 363

但模型对象本身不会给出太多的价值信息,我们需要使用summary函数查看模型汇总结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
summary(fit0)

Call: survfit(formula = survobj ~ 1, data = lung)

time n.risk n.event survival std.err lower 95% CI upper 95% CI
5 228 1 0.9956 0.00438 0.9871 1.000
11 227 3 0.9825 0.00869 0.9656 1.000
12 224 1 0.9781 0.00970 0.9592 0.997
13 223 2 0.9693 0.01142 0.9472 0.992
15 221 1 0.9649 0.01219 0.9413 0.989
26 220 1 0.9605 0.01290 0.9356 0.986
30 219 1 0.9561 0.01356 0.9299 0.983
31 218 1 0.9518 0.01419 0.9243 0.980
53 217 2 0.9430 0.01536 0.9134 0.974
54 215 1 0.9386 0.01590 0.9079 0.970
59 214 1 0.9342 0.01642 0.9026 0.967
60 213 2 0.9254 0.01740 0.8920 0.960
61 211 1 0.9211 0.01786 0.8867 0.957
62 210 1 0.9167 0.01830 0.8815 0.953
65 209 2 0.9079 0.01915 0.8711 0.946
71 207 1 0.9035 0.01955 0.8660 0.943
79 206 1 0.8991 0.01995 0.8609 0.939
81 205 2 0.8904 0.02069 0.8507 0.932
88 203 2 0.8816 0.02140 0.8406 0.925
92 201 1 0.8772 0.02174 0.8356 0.921
93 199 1 0.8728 0.02207 0.8306 0.917
95 198 2 0.8640 0.02271 0.8206 0.910
105 196 1 0.8596 0.02302 0.8156 0.906
107 194 2 0.8507 0.02362 0.8056 0.898
110 192 1 0.8463 0.02391 0.8007 0.894
116 191 1 0.8418 0.02419 0.7957 0.891
118 190 1 0.8374 0.02446 0.7908 0.887
122 189 1 0.8330 0.02473 0.7859 0.883
131 188 1 0.8285 0.02500 0.7810 0.879
132 187 2 0.8197 0.02550 0.7712 0.871
135 185 1 0.8153 0.02575 0.7663 0.867
142 184 1 0.8108 0.02598 0.7615 0.863
144 183 1 0.8064 0.02622 0.7566 0.859
145 182 2 0.7975 0.02667 0.7469 0.852
147 180 1 0.7931 0.02688 0.7421 0.848
153 179 1 0.7887 0.02710 0.7373 0.844
156 178 2 0.7798 0.02751 0.7277 0.836
163 176 3 0.7665 0.02809 0.7134 0.824
166 173 2 0.7577 0.02845 0.7039 0.816
167 171 1 0.7532 0.02863 0.6991 0.811
170 170 1 0.7488 0.02880 0.6944 0.807
175 167 1 0.7443 0.02898 0.6896 0.803
176 165 1 0.7398 0.02915 0.6848 0.799
177 164 1 0.7353 0.02932 0.6800 0.795
179 162 2 0.7262 0.02965 0.6704 0.787
180 160 1 0.7217 0.02981 0.6655 0.783
181 159 2 0.7126 0.03012 0.6559 0.774
182 157 1 0.7081 0.03027 0.6511 0.770
183 156 1 0.7035 0.03041 0.6464 0.766
186 154 1 0.6989 0.03056 0.6416 0.761
189 152 1 0.6943 0.03070 0.6367 0.757
194 149 1 0.6897 0.03085 0.6318 0.753
197 147 1 0.6850 0.03099 0.6269 0.749
199 145 1 0.6803 0.03113 0.6219 0.744
201 144 2 0.6708 0.03141 0.6120 0.735
202 142 1 0.6661 0.03154 0.6071 0.731
207 139 1 0.6613 0.03168 0.6020 0.726
208 138 1 0.6565 0.03181 0.5970 0.722
210 137 1 0.6517 0.03194 0.5920 0.717
212 135 1 0.6469 0.03206 0.5870 0.713
218 134 1 0.6421 0.03218 0.5820 0.708
222 132 1 0.6372 0.03231 0.5769 0.704
223 130 1 0.6323 0.03243 0.5718 0.699
226 126 1 0.6273 0.03256 0.5666 0.694
229 125 1 0.6223 0.03268 0.5614 0.690
230 124 1 0.6172 0.03280 0.5562 0.685
239 121 2 0.6070 0.03304 0.5456 0.675
245 117 1 0.6019 0.03316 0.5402 0.670
246 116 1 0.5967 0.03328 0.5349 0.666
267 112 1 0.5913 0.03341 0.5294 0.661
268 111 1 0.5860 0.03353 0.5239 0.656
269 110 1 0.5807 0.03364 0.5184 0.651
270 108 1 0.5753 0.03376 0.5128 0.645
283 104 1 0.5698 0.03388 0.5071 0.640
284 103 1 0.5642 0.03400 0.5014 0.635
285 101 2 0.5531 0.03424 0.4899 0.624
286 99 1 0.5475 0.03434 0.4841 0.619
288 98 1 0.5419 0.03444 0.4784 0.614
291 97 1 0.5363 0.03454 0.4727 0.608
293 94 1 0.5306 0.03464 0.4669 0.603
301 91 1 0.5248 0.03475 0.4609 0.597
303 89 1 0.5189 0.03485 0.4549 0.592
305 87 1 0.5129 0.03496 0.4488 0.586
306 86 1 0.5070 0.03506 0.4427 0.581
310 85 2 0.4950 0.03523 0.4306 0.569
320 82 1 0.4890 0.03532 0.4244 0.563
329 81 1 0.4830 0.03539 0.4183 0.558
337 79 1 0.4768 0.03547 0.4121 0.552
340 78 1 0.4707 0.03554 0.4060 0.546
345 77 1 0.4646 0.03560 0.3998 0.540
348 76 1 0.4585 0.03565 0.3937 0.534
350 75 1 0.4524 0.03569 0.3876 0.528
351 74 1 0.4463 0.03573 0.3815 0.522
353 73 2 0.4340 0.03578 0.3693 0.510
361 70 1 0.4278 0.03581 0.3631 0.504
363 69 2 0.4154 0.03583 0.3508 0.492
364 67 1 0.4092 0.03582 0.3447 0.486
371 65 2 0.3966 0.03581 0.3323 0.473
387 60 1 0.3900 0.03582 0.3258 0.467
390 59 1 0.3834 0.03582 0.3193 0.460
394 58 1 0.3768 0.03580 0.3128 0.454
426 55 1 0.3700 0.03580 0.3060 0.447
428 54 1 0.3631 0.03579 0.2993 0.440
429 53 1 0.3563 0.03576 0.2926 0.434
433 52 1 0.3494 0.03573 0.2860 0.427
442 51 1 0.3426 0.03568 0.2793 0.420
444 50 1 0.3357 0.03561 0.2727 0.413
450 48 1 0.3287 0.03555 0.2659 0.406
455 47 1 0.3217 0.03548 0.2592 0.399
457 46 1 0.3147 0.03539 0.2525 0.392
460 44 1 0.3076 0.03530 0.2456 0.385
473 43 1 0.3004 0.03520 0.2388 0.378
477 42 1 0.2933 0.03508 0.2320 0.371
519 39 1 0.2857 0.03498 0.2248 0.363
520 38 1 0.2782 0.03485 0.2177 0.356
524 37 2 0.2632 0.03455 0.2035 0.340
533 34 1 0.2554 0.03439 0.1962 0.333
550 32 1 0.2475 0.03423 0.1887 0.325
558 30 1 0.2392 0.03407 0.1810 0.316
567 28 1 0.2307 0.03391 0.1729 0.308
574 27 1 0.2221 0.03371 0.1650 0.299
583 26 1 0.2136 0.03348 0.1571 0.290
613 24 1 0.2047 0.03325 0.1489 0.281
624 23 1 0.1958 0.03297 0.1407 0.272
641 22 1 0.1869 0.03265 0.1327 0.263
643 21 1 0.1780 0.03229 0.1247 0.254
654 20 1 0.1691 0.03188 0.1169 0.245
655 19 1 0.1602 0.03142 0.1091 0.235
687 18 1 0.1513 0.03090 0.1014 0.226
689 17 1 0.1424 0.03034 0.0938 0.216
705 16 1 0.1335 0.02972 0.0863 0.207
707 15 1 0.1246 0.02904 0.0789 0.197
728 14 1 0.1157 0.02830 0.0716 0.187
731 13 1 0.1068 0.02749 0.0645 0.177
735 12 1 0.0979 0.02660 0.0575 0.167
765 10 1 0.0881 0.02568 0.0498 0.156
791 9 1 0.0783 0.02462 0.0423 0.145
814 7 1 0.0671 0.02351 0.0338 0.133
883 4 1 0.0503 0.02285 0.0207 0.123

这个表格每一行显示了一个(多个)事件或截尾发生了,在风险中的样本数(就是还没死的),以及及时的累积生存率等。

1
2
3
plot(fit0, xlab="Survival Time in Days", 
ylab="% Surviving", yscale=100,
main="Survival Distribution (Overall)")

Kaplan-Meier 生存曲线
image
比较男性和女性的生存分布

1
2
3
4
5
6
7
8
fit1 <- survfit(survobj~sex,data=lung)
fit1

Call: survfit(formula = survobj ~ sex, data = lung)

n events median 0.95LCL 0.95UCL
sex=1 138 112 270 212 310
sex=2 90 53 426 348 550

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
summary(fit1)

Call: survfit(formula = survobj ~ sex, data = lung)

sex=1
time n.risk n.event survival std.err lower 95% CI upper 95% CI
11 138 3 0.9783 0.0124 0.9542 1.000
12 135 1 0.9710 0.0143 0.9434 0.999
13 134 2 0.9565 0.0174 0.9231 0.991
15 132 1 0.9493 0.0187 0.9134 0.987
26 131 1 0.9420 0.0199 0.9038 0.982
30 130 1 0.9348 0.0210 0.8945 0.977
31 129 1 0.9275 0.0221 0.8853 0.972
53 128 2 0.9130 0.0240 0.8672 0.961
54 126 1 0.9058 0.0249 0.8583 0.956
59 125 1 0.8986 0.0257 0.8496 0.950
60 124 1 0.8913 0.0265 0.8409 0.945
65 123 2 0.8768 0.0280 0.8237 0.933
71 121 1 0.8696 0.0287 0.8152 0.928
81 120 1 0.8623 0.0293 0.8067 0.922
88 119 2 0.8478 0.0306 0.7900 0.910
92 117 1 0.8406 0.0312 0.7817 0.904
93 116 1 0.8333 0.0317 0.7734 0.898
95 115 1 0.8261 0.0323 0.7652 0.892
105 114 1 0.8188 0.0328 0.7570 0.886
107 113 1 0.8116 0.0333 0.7489 0.880
110 112 1 0.8043 0.0338 0.7408 0.873
116 111 1 0.7971 0.0342 0.7328 0.867
118 110 1 0.7899 0.0347 0.7247 0.861
131 109 1 0.7826 0.0351 0.7167 0.855
132 108 2 0.7681 0.0359 0.7008 0.842
135 106 1 0.7609 0.0363 0.6929 0.835
142 105 1 0.7536 0.0367 0.6851 0.829
144 104 1 0.7464 0.0370 0.6772 0.823
147 103 1 0.7391 0.0374 0.6694 0.816
156 102 2 0.7246 0.0380 0.6538 0.803
163 100 3 0.7029 0.0389 0.6306 0.783
166 97 1 0.6957 0.0392 0.6230 0.777
170 96 1 0.6884 0.0394 0.6153 0.770
175 94 1 0.6811 0.0397 0.6076 0.763
176 93 1 0.6738 0.0399 0.5999 0.757
177 92 1 0.6664 0.0402 0.5922 0.750
179 91 2 0.6518 0.0406 0.5769 0.736
180 89 1 0.6445 0.0408 0.5693 0.730
181 88 2 0.6298 0.0412 0.5541 0.716
183 86 1 0.6225 0.0413 0.5466 0.709
189 83 1 0.6150 0.0415 0.5388 0.702
197 80 1 0.6073 0.0417 0.5309 0.695
202 78 1 0.5995 0.0419 0.5228 0.687
207 77 1 0.5917 0.0420 0.5148 0.680
210 76 1 0.5839 0.0422 0.5068 0.673
212 75 1 0.5762 0.0424 0.4988 0.665
218 74 1 0.5684 0.0425 0.4909 0.658
222 72 1 0.5605 0.0426 0.4829 0.651
223 70 1 0.5525 0.0428 0.4747 0.643
229 67 1 0.5442 0.0429 0.4663 0.635
230 66 1 0.5360 0.0431 0.4579 0.627
239 64 1 0.5276 0.0432 0.4494 0.619
246 63 1 0.5192 0.0433 0.4409 0.611
267 61 1 0.5107 0.0434 0.4323 0.603
269 60 1 0.5022 0.0435 0.4238 0.595
270 59 1 0.4937 0.0436 0.4152 0.587
283 57 1 0.4850 0.0437 0.4065 0.579
284 56 1 0.4764 0.0438 0.3979 0.570
285 54 1 0.4676 0.0438 0.3891 0.562
286 53 1 0.4587 0.0439 0.3803 0.553
288 52 1 0.4499 0.0439 0.3716 0.545
291 51 1 0.4411 0.0439 0.3629 0.536
301 48 1 0.4319 0.0440 0.3538 0.527
303 46 1 0.4225 0.0440 0.3445 0.518
306 44 1 0.4129 0.0440 0.3350 0.509
310 43 1 0.4033 0.0441 0.3256 0.500
320 42 1 0.3937 0.0440 0.3162 0.490
329 41 1 0.3841 0.0440 0.3069 0.481
337 40 1 0.3745 0.0439 0.2976 0.471
353 39 2 0.3553 0.0437 0.2791 0.452
363 37 1 0.3457 0.0436 0.2700 0.443
364 36 1 0.3361 0.0434 0.2609 0.433
371 35 1 0.3265 0.0432 0.2519 0.423
387 34 1 0.3169 0.0430 0.2429 0.413
390 33 1 0.3073 0.0428 0.2339 0.404
394 32 1 0.2977 0.0425 0.2250 0.394
428 29 1 0.2874 0.0423 0.2155 0.383
429 28 1 0.2771 0.0420 0.2060 0.373
442 27 1 0.2669 0.0417 0.1965 0.362
455 25 1 0.2562 0.0413 0.1868 0.351
457 24 1 0.2455 0.0410 0.1770 0.341
460 22 1 0.2344 0.0406 0.1669 0.329
477 21 1 0.2232 0.0402 0.1569 0.318
519 20 1 0.2121 0.0397 0.1469 0.306
524 19 1 0.2009 0.0391 0.1371 0.294
533 18 1 0.1897 0.0385 0.1275 0.282
558 17 1 0.1786 0.0378 0.1179 0.270
567 16 1 0.1674 0.0371 0.1085 0.258
574 15 1 0.1562 0.0362 0.0992 0.246
583 14 1 0.1451 0.0353 0.0900 0.234
613 13 1 0.1339 0.0343 0.0810 0.221
624 12 1 0.1228 0.0332 0.0722 0.209
643 11 1 0.1116 0.0320 0.0636 0.196
655 10 1 0.1004 0.0307 0.0552 0.183
689 9 1 0.0893 0.0293 0.0470 0.170
707 8 1 0.0781 0.0276 0.0390 0.156
791 7 1 0.0670 0.0259 0.0314 0.143
814 5 1 0.0536 0.0239 0.0223 0.128
883 3 1 0.0357 0.0216 0.0109 0.117

sex=2
time n.risk n.event survival std.err lower 95% CI upper 95% CI
5 90 1 0.9889 0.0110 0.9675 1.000
60 89 1 0.9778 0.0155 0.9478 1.000
61 88 1 0.9667 0.0189 0.9303 1.000
62 87 1 0.9556 0.0217 0.9139 0.999
79 86 1 0.9444 0.0241 0.8983 0.993
81 85 1 0.9333 0.0263 0.8832 0.986
95 83 1 0.9221 0.0283 0.8683 0.979
107 81 1 0.9107 0.0301 0.8535 0.972
122 80 1 0.8993 0.0318 0.8390 0.964
145 79 2 0.8766 0.0349 0.8108 0.948
153 77 1 0.8652 0.0362 0.7970 0.939
166 76 1 0.8538 0.0375 0.7834 0.931
167 75 1 0.8424 0.0387 0.7699 0.922
182 71 1 0.8305 0.0399 0.7559 0.913
186 70 1 0.8187 0.0411 0.7420 0.903
194 68 1 0.8066 0.0422 0.7280 0.894
199 67 1 0.7946 0.0432 0.7142 0.884
201 66 2 0.7705 0.0452 0.6869 0.864
208 62 1 0.7581 0.0461 0.6729 0.854
226 59 1 0.7452 0.0471 0.6584 0.843
239 57 1 0.7322 0.0480 0.6438 0.833
245 54 1 0.7186 0.0490 0.6287 0.821
268 51 1 0.7045 0.0501 0.6129 0.810
285 47 1 0.6895 0.0512 0.5962 0.798
293 45 1 0.6742 0.0523 0.5791 0.785
305 43 1 0.6585 0.0534 0.5618 0.772
310 42 1 0.6428 0.0544 0.5447 0.759
340 39 1 0.6264 0.0554 0.5267 0.745
345 38 1 0.6099 0.0563 0.5089 0.731
348 37 1 0.5934 0.0572 0.4913 0.717
350 36 1 0.5769 0.0579 0.4739 0.702
351 35 1 0.5604 0.0586 0.4566 0.688
361 33 1 0.5434 0.0592 0.4390 0.673
363 32 1 0.5265 0.0597 0.4215 0.658
371 30 1 0.5089 0.0603 0.4035 0.642
426 26 1 0.4893 0.0610 0.3832 0.625
433 25 1 0.4698 0.0617 0.3632 0.608
444 24 1 0.4502 0.0621 0.3435 0.590
450 23 1 0.4306 0.0624 0.3241 0.572
473 22 1 0.4110 0.0626 0.3050 0.554
520 19 1 0.3894 0.0629 0.2837 0.534
524 18 1 0.3678 0.0630 0.2628 0.515
550 15 1 0.3433 0.0634 0.2390 0.493
641 11 1 0.3121 0.0649 0.2076 0.469
654 10 1 0.2808 0.0655 0.1778 0.443
687 9 1 0.2496 0.0652 0.1496 0.417
705 8 1 0.2184 0.0641 0.1229 0.388
728 7 1 0.1872 0.0621 0.0978 0.359
731 6 1 0.1560 0.0590 0.0743 0.328
735 5 1 0.1248 0.0549 0.0527 0.295
765 3 1 0.0832 0.0499 0.0257 0.270

summary()函数中可以设定时间参数用来选定一个时间区间,可以以此比对男生是不是比女生有更高的风险:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
summary(fit1, times=seq(0, 1000, 100))

Call: survfit(formula = survobj ~ sex, data = lung)

sex=1
time n.risk n.event survival std.err lower 95% CI upper 95% CI
0 138 0 1.0000 0.0000 1.0000 1.000
100 114 24 0.8261 0.0323 0.7652 0.892
200 78 30 0.6073 0.0417 0.5309 0.695
300 49 20 0.4411 0.0439 0.3629 0.536
400 31 15 0.2977 0.0425 0.2250 0.394
500 20 7 0.2232 0.0402 0.1569 0.318
600 13 7 0.1451 0.0353 0.0900 0.234
700 8 5 0.0893 0.0293 0.0470 0.170
800 6 2 0.0670 0.0259 0.0314 0.143
900 2 2 0.0357 0.0216 0.0109 0.117
1000 2 0 0.0357 0.0216 0.0109 0.117

sex=2
time n.risk n.event survival std.err lower 95% CI upper 95% CI
0 90 0 1.0000 0.0000 1.0000 1.000
100 82 7 0.9221 0.0283 0.8683 0.979
200 66 11 0.7946 0.0432 0.7142 0.884
300 43 9 0.6742 0.0523 0.5791 0.785
400 26 10 0.5089 0.0603 0.4035 0.642
500 21 5 0.4110 0.0626 0.3050 0.554
600 11 3 0.3433 0.0634 0.2390 0.493
700 8 3 0.2496 0.0652 0.1496 0.417
800 2 5 0.0832 0.0499 0.0257 0.270
900 1 0 0.0832 0.0499 0.0257 0.270

可视化

1
2
library(survminer)
ggsurvplot(fit1)

survminer的包提供的一个叫ggsurvplot()的函数可以帮助我们更简单地做出可以发表的生存曲线,代码脚本里是对ggplot2语法很熟悉的话还能更简单地进行修改。
image

添加曲线的置信区间,并增加long-rank检验的结果p值以及风险表格:

1
2
3
4
5
ggsurvplot(fit1, conf.int=TRUE, pval=TRUE, risk.table=TRUE, 
legend.labs=c("Male", "Female"), legend.title="Sex",
palette=c("dodgerblue2", "orchid2"),
title="Kaplan-Meier Curve for Lung Cancer Survival",
risk.table.height=.15)

image

Cox回归模型

Kaplan-Meier曲线用来对两个分类变量差异的可视化非常合适,但分类要是多,那就糟透了:

1
ggsurvplot(survfit(Surv(time, status)~nodes, data=survival::colon))

image

而且生存曲线另外不能可视化的是连续型变量的风险。

Cox PH回归模型正好是处理这类问题的一把好手,它同样内置于survival包中,语法与lm()和glm()一致。
比例风险回归也称为Cox回归,是评估不同变量对生存率影响的最常用方法。

image
再来用肺癌数据集看看不同性别的风险,这次使用Cox模型。

1
2
3
4
5
6
7
8
9
10
11
fit <- coxph(Surv(time, status)~sex, data=lung)
fit

Call:
coxph(formula = Surv(time, status) ~ sex, data = lung)

coef exp(coef) se(coef) z p
sex -0.5310 0.5880 0.1672 -3.176 0.00149

Likelihood ratio test=10.63 on 1 df, p=0.001111
n= 228, number of events= 165

结果中的exp(coef)列包含eβ1。它就是风险比率——该变量对风险率的乘数效应(对于该变量每个单位增加的)。因此,对于像性别这样的分类变量,从男性(基线)到女性的结果大约减少约40%的危险。你也可以翻转coef列上的符号,并采用exp(0.531),你可以将其解释为男性导致危险增加1.7倍,或者单位时间男性的死亡率约为女性1.7倍(女性死亡率为男性的0.588倍)。

1
2
3
HR=1: 没有效应
HR>1: 风险增加
HR<1: 风险减少 (保护变量)

“性别”有一个对应的p值,整个模型中也有一个p值。0.00111的p值非常接近我们在Kaplan-Meier图上看到的p=0.00131的p值。
这是因为KM曲线显示的是对数秩检验的p值。你可以通过调用summary(fit)来获得Cox模型结果。
你也可以使用survdiff()直接计算log-rank测试p值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
summary(fit)

Call:
coxph(formula = Surv(time, status) ~ sex, data = lung)

n= 228, number of events= 165

coef exp(coef) se(coef) z Pr(>|z|)
sex -0.5310 0.5880 0.1672 -3.176 0.00149 **
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

exp(coef) exp(-coef) lower .95 upper .95
sex 0.588 1.701 0.4237 0.816

Concordance= 0.579 (se = 0.021 )
Rsquare= 0.046 (max possible= 0.999 )
Likelihood ratio test= 10.63 on 1 df, p=0.001
Wald test = 10.09 on 1 df, p=0.001
Score (logrank) test = 10.33 on 1 df, p=0.001

survdiff(Surv(time, status)~sex, data=lung)

Call:
survdiff(formula = Surv(time, status) ~ sex, data = lung)

N Observed Expected (O-E)^2/E (O-E)^2/V
sex=1 138 112 91.6 4.55 10.3
sex=2 90 53 73.4 5.68 10.3

Chisq= 10.3 on 1 degrees of freedom, p= 0.001

回到肺部数据并查看年龄的Cox模型。看起来年龄在模拟为连续变量时似乎有一点重要。

我们的的回归分析显示年龄有重要意义,让我们制作Kaplan-Meier图。但是,正如我们之前所看到的,我们不能这样做,因为我们会为每个独特的年龄值获得单独的曲线!

1
ggsurvplot(survfit(Surv(time, status)~age, data=lung))

image

试图将一个连续变量分成不同的组 - 三分位数,上四分位数与下四分位数,中位数分数等 - 这样你就可以生成生存曲线图。但是,你如何进行分组是有意义的!检查cut的帮助。cut()接受一个连续变量和一些断点,并从中创建一个分类变量。 我们来得到数据集的平均年龄,并绘制一个显示年龄分布的直方图。

1
2
3
mean(lung$age)
hist(lung$age)
ggplot(lung, aes(age)) + geom_histogram(bins=20)

image
image
现在,让我们尝试通过lung$age创建一个分类变量,其中0,62(平均值)和正无穷大。我们可以在这里继续添加labels =选项来标记我们创建的分组,例如,“年轻”和“老”。最后,我们可以将这个结果分配给肺数据集中的一个新对象。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
cut(lung$age, breaks=c(0, 62, Inf))

[1] (62,Inf] (62,Inf] (0,62] (0,62] (0,62] (62,Inf] (62,Inf] (62,Inf] (0,62]
[10] (0,62] (0,62] (62,Inf] (62,Inf] (0,62] (0,62] (62,Inf] (62,Inf] (62,Inf]
[19] (0,62] (0,62] (62,Inf] (0,62] (0,62] (0,62] (62,Inf] (62,Inf] (0,62]
[28] (62,Inf] (0,62] (62,Inf] (62,Inf] (62,Inf] (0,62] (0,62] (0,62] (0,62]
[37] (62,Inf] (62,Inf] (62,Inf] (62,Inf] (62,Inf] (62,Inf] (0,62] (0,62] (62,Inf]
[46] (62,Inf] (62,Inf] (62,Inf] (62,Inf] (0,62] (62,Inf] (62,Inf] (62,Inf] (0,62]
[55] (0,62] (0,62] (62,Inf] (0,62] (0,62] (62,Inf] (62,Inf] (0,62] (62,Inf]
[64] (62,Inf] (62,Inf] (62,Inf] (62,Inf] (62,Inf] (62,Inf] (62,Inf] (62,Inf] (0,62]
[73] (62,Inf] (0,62] (0,62] (62,Inf] (0,62] (0,62] (62,Inf] (62,Inf] (0,62]
[82] (0,62] (0,62] (0,62] (0,62] (62,Inf] (0,62] (0,62] (0,62] (62,Inf]
[91] (62,Inf] (62,Inf] (62,Inf] (0,62] (62,Inf] (62,Inf] (62,Inf] (62,Inf] (62,Inf]
[100] (62,Inf] (0,62] (62,Inf] (0,62] (62,Inf] (0,62] (62,Inf] (0,62] (62,Inf]
[109] (0,62] (62,Inf] (62,Inf] (0,62] (62,Inf] (62,Inf] (0,62] (62,Inf] (0,62]
[118] (62,Inf] (62,Inf] (62,Inf] (62,Inf] (0,62] (62,Inf] (62,Inf] (62,Inf] (62,Inf]
[127] (0,62] (62,Inf] (62,Inf] (0,62] (0,62] (0,62] (0,62] (0,62] (62,Inf]
[136] (62,Inf] (0,62] (0,62] (0,62] (0,62] (62,Inf] (62,Inf] (62,Inf] (62,Inf]
[145] (0,62] (0,62] (62,Inf] (0,62] (62,Inf] (0,62] (62,Inf] (0,62] (0,62]
[154] (0,62] (0,62] (62,Inf] (62,Inf] (0,62] (62,Inf] (0,62] (0,62] (0,62]
[163] (62,Inf] (62,Inf] (62,Inf] (0,62] (0,62] (0,62] (0,62] (62,Inf] (0,62]
[172] (0,62] (0,62] (0,62] (0,62] (0,62] (0,62] (0,62] (0,62] (62,Inf]
[181] (0,62] (0,62] (62,Inf] (62,Inf] (0,62] (0,62] (62,Inf] (0,62] (62,Inf]
[190] (0,62] (62,Inf] (0,62] (0,62] (62,Inf] (62,Inf] (62,Inf] (62,Inf] (62,Inf]
[199] (0,62] (0,62] (62,Inf] (62,Inf] (62,Inf] (0,62] (62,Inf] (0,62] (0,62]
[208] (0,62] (62,Inf] (0,62] (0,62] (62,Inf] (62,Inf] (62,Inf] (62,Inf] (62,Inf]
[217] (0,62] (62,Inf] (62,Inf] (0,62] (62,Inf] (62,Inf] (62,Inf] (62,Inf] (0,62]
[226] (62,Inf] (62,Inf] (0,62]
Levels: (0,62] (62,Inf]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
cut(lung$age, breaks=c(0, 62, Inf), labels=c("young", "old"))

[1] old old young young young old old old young young young old old young
[15] young old old old young young old young young young old old young old
[29] young old old old young young young young old old old old old old
[43] young young old old old old old young old old old young young young
[57] old young young old old young old old old old old old old old
[71] old young old young young old young young old old young young young young
[85] young old young young young old old old old young old old old old
[99] old old young old young old young old young old young old old young
[113] old old young old young old old old old young old old old old
[127] young old old young young young young young old old young young young young
[141] old old old old young young old young old young old young young young
[155] young old old young old young young young old old old young young young
[169] young old young young young young young young young young young old young young
[183] old old young young old young old young old young young old old old
[197] old old young young old old old young old young young young old young
[211] young old old old old old young old old young old old old old
[225] young old old young
Levels: young old
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# the base r way:
lung$agecat <- cut(lung$age, breaks=c(0, 62, Inf), labels=c("young", "old"))

head(lung)

A tibble: 6 x 11
inst time status age sex ph.ecog ph.karno pat.karno meal.cal wt.loss agecat
<dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <fct>
1 3 306 2 74 1 1 90 100 1175 NA old
2 3 455 2 68 1 0 90 90 1225 15 old
3 3 1010 1 56 1 0 90 90 NA 15 young
4 5 210 2 57 1 1 90 60 1150 11 young
5 1 883 2 60 1 0 100 90 NA 0 young
6 12 1022 1 74 1 1 50 80 513 0 old
>

用这个新的分类生成KM图时会发生什么? 看起来“老”和“年轻”患者之间的曲线存在一些差异,老年患者的生存几率稍差。但是p=0.39时,62岁以下和62岁以上者的生存率差异不显著。

1
ggsurvplot(survfit(Surv(time, status)~agecat, data=lung), pval=TRUE)

image

但是,如果我们选择一个不同的切点,例如70岁,这大概是年龄分布上限的四分之一(见“分位数”)。结果现在非常重要!

1
2
3
4
# the base r way:
lung$agecat <- cut(lung$age, breaks=c(0, 70, Inf), labels=c("young", "old"))

ggsurvplot(survfit(Surv(time, status)~agecat, data=lung), pval=TRUE)

image

Cox回归分析整个分布范围内的连续变量,其中Kaplan-Meier图上的对数秩检验可能会根据您对连续变量进行分类而发生变化。他们以一种不同的方式回答类似的问题:回归模型提出的问题是“年龄对生存有什么影响?”,而对数秩检验和KM图则问:“那些不到70岁和70岁以上的人有差异吗?”。

让我们创建另一个模型,分析数据集中的所有变量!这向我们展示了当所有变量一起考虑时,如何影响生存。一些是非常强大的预测指标(性别,ECOG评分)。有趣的是,医师对Karnofsky表现评分的评分稍高,但患者评分相同。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# predict male survival from age and medical scores 
MaleMod <- coxph(survobj~sex+age+ph.ecog+ph.karno+pat.karno+meal.cal+wt.loss,
data=lung, subset=sex==1)

MaleMod
Call:
coxph(formula = survobj ~ sex + age + ph.ecog + ph.karno + pat.karno +
meal.cal + wt.loss, data = lung, subset = sex == 1)

coef exp(coef) se(coef) z p
sex NA NA 0.000e+00 NA NA
age 2.580e-02 1.026e+00 1.508e-02 1.711 0.0871
ph.ecog 6.513e-01 1.918e+00 2.606e-01 2.499 0.0124
ph.karno 2.840e-02 1.029e+00 1.345e-02 2.112 0.0347
pat.karno -1.777e-02 9.824e-01 1.058e-02 -1.680 0.0930
meal.cal -3.551e-05 1.000e+00 2.944e-04 -0.121 0.9040
wt.loss -8.912e-03 9.911e-01 9.537e-03 -0.935 0.3500

Likelihood ratio test=15.17 on 6 df, p=0.01898
n= 104, number of events= 83
(34 observations deleted due to missingness)

代码脚本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# Mayo Clinic Lung Cancer Data
library(survival)

# learn about the dataset
help(lung)

# create a Surv object
survobj <- with(lung, Surv(time,status))

# Plot survival distribution of the total sample
# Kaplan-Meier estimator
fit0 <- survfit(survobj~1, data=lung)
summary(fit0)
plot(fit0, xlab="Survival Time in Days",
ylab="% Surviving", yscale=100,
main="Survival Distribution (Overall)")

# Compare the survival distributions of men and women

fit1 <- survfit(survobj~sex,data=lung)
# plot the survival distributions by sex
plot(fit1, xlab="Survival Time in Days",
ylab="% Surviving", yscale=100, col=c("red","blue"),
main="Survival Distributions by Gender")
legend("topright", title="Gender", c("Male", "Female"),
fill=c("red", "blue"))

# test for difference between male and female
# survival curves (logrank test)
survdiff(survobj~sex, data=lung)

# predict male survival from age and medical scores
MaleMod <- coxph(survobj~age+ph.ecog+ph.karno+pat.karno,
data=lung, subset=sex==1)

# display results
MaleMod

# evaluate the proportional hazards assumption
cox.zph(MaleMod)

Linear Discriminant Analysis (LDA)

线性分类判别

要在R中运行LDA,我们将使用lda()函数,该函数是名为MASS的包的一部分。 这个软件包附带了R的核心版本,因此我们只需要在使用lda()函数之前将其加载到R的工作目录中。

1
2
3
4
library(MASS)
## create an LDA model. The lda() function takes a forumla and the name of the training data set as its argument

lda_model = lda(Direction~., data = training_data)

接下来,我们将评估模型,因此我们将再次对我们的测试数据集使用predict()函数。

1
2
3
4
lda_pred = predict(lda_model,testing_data) 
names(lda_pred)
## [1] "class" "posterior" "x"
lda_pred_y = lda_pred$class

好消息是,在预测()函数中使用LDA模型时,输出是类别本身(类),与我们使用逻辑回归模型(输出是概率)时发生的情况不同。
好的,现在是评估模型的时候了。我们创建混淆矩阵并计算错误分类错误。

1
2
3
4
5
6
7
## compute the confusion matrix 
table(lda_pred_y, testing_y)

## testing_y
## lda_pred_y Down Up
## Down 2 1
## Up 109 140

计算错误分类错误率

1
2
mean(lda_pred_y != testing_y)
## [1] 0.4365

LDA模型的错误分类错误率与逻辑回归模型的错误分类错误率相同。 让我们检查一下QDA模型在错误分类错误率方面是否会做得更好。

代码脚本

1
2
3
4
5
6
7
8
9
10
11
12
--LDA
library(MASS)
## create an LDA model. The lda() function takes a forumla and the name of the training data set as its argument

lda_model = lda(Direction~., data = training_data)
lda_pred = predict(lda_model,testing_data)
names(lda_pred)
lda_pred_y = lda_pred$class
## compute the confusion matrix
table(lda_pred_y, testing_y)
## compute the misclassification error rate
mean(lda_pred_y != testing_y)

Quadratic Discriminant Analysis (QDA)

二次分类判别

培训和评估QDA模型在语法上与培训和评估LDA模型非常相似。 唯一的区别在于函数名称lda()。

1
2
3
4
5
6
7
8
9
10
library(MASS)
qda_model = qda(Direction~., data = training_data)
qda_pred = predict(qda_model,testing_data)
qda_pred_y = qda_pred$class
table(qda_pred_y, testing_y)

## testing_y
## qda_pred_y Down Up
## Down 43 51
## Up 68 90

1
2
3
mean(qda_pred_y != testing_y)

## [1] 0.4722

QDA模型的错误分配错误率为47.22%,高于我们从逻辑回归和LDA模型得到的错误率。

KNN for Classification

为了训练KNN模型进行分类,我们将使用函数knn(),它是R类包的一部分。确保安装并加载此库。

1
2
## load the class R package. 
library(class)

这里的数据分割将与我们对逻辑回归,LDA和QDA所做的不同。 这是因为与glm(),lda()和qda()相比,knn()函数被构建采用不同的参数。

对于knn(),我们必须将y变量放在训练和测试数据的单独列中。除了这个问题,我们必须缩放或标准化我们的数值变量,因为KNN方法使用距离测量对观测进行分类。 有关此问题的更多信息,请参阅之前的“KNN for Classification”文档。

为了标准化数据集,我们可以使用函数scale(),如下所示:

1
2
3
4
5
6
7
## load this package to use Smarket data library(ISLR)
## scale the Smarket data without the 8th variable (Today), and the 9th variable (Direction)
## we got rid of Today variable because is is highly correlated with Direction
## we got rid of Direction because it is our response variable. Make sure to exclude all
## categorical variables should be excluded from scaling. We can't scale categorical varia

data = scale(Smarket[,-c(8,9)])

现在让我们拆分数据。 记住它会与我们之前做的有点不同:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
## the following two steps are similar to earlier steps
train = Year < 2005
test = !train

## the following two steps looks similar to what we did earlier, but they are actually not!!
## Remember that we got rid of the response variable "Direction" when we scaled the data!
## So, our training and testing data has only the predictors! That's how KNN() function work

training_data = data[train,]
testing_data = data[test,]

## KNN take the training response variable seperately
training_y = Smarket$Direction[train]
## we also need the have the testing_y seperately for assesing the model later on
testing_y = Smarket$Direction[test]

现在,我们已准备好训练KNN模型。knn()函数使用随机数生成器来训练模型。为了在每次运行R代码时获得相同的分析输出,请确保将R中随机数生成器的种子设置为您选择的数字。但是每次运行R代码时都必须坚持使用相同的种子。

1
2
3
4
5
6
7
8
9
10
11
12
set.seed(1)
knn_pred_y = knn(training_data, testing_data, training_y, k = 1)
table(knn_pred_y, testing_y)

## testing_y
## knn_pred_y Down Up
## Down 42 53
## Up 69 88

mean(knn_pred_y != testing_y)

## [1] 0.4841

当k = 1时,错误分类错误率为48.41%。 它并不比之前的回归,lda和qda模型更好。 让我们看看k的哪个值会给我们最低的错误分类错误率。 为此我们将有一个for循环:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
knn_pred_y = NULL
error_rate = NULL
for(i in 1:300){
set.seed(1)
knn_pred_y = knn(training_data,testing_data,training_y,k=i)
error_rate[i] = mean(testing_y != knn_pred_y) }

### find the minimum error rate

min_error_rate = min(error_rate)
print(min_error_rate)

## [1] 0.4127

### get the index of that error rate, which is the k

K = which(error_rate == min_error_rate)
print(K)

## [1] 197

可视化当我们增加k时错误分类错误率如何受到影响:

1
2
library(ggplot2)
qplot(1:300, error_rate, xlab = "K",ylab = "Error Rate", geom=c("point", "line"))

image
当我们训练k = 197的KNN模型时,我们得到最低的错误分类错误率为41.27%。 这将使我们得出结论,该数据集的最佳模型将是逻辑回归或LDA模型,因为它们具有最少的错误分类错误。 但要小心,这可能是过度拟合!你怎么看?

代码脚本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
--KNN 
library(class)
data = scale(Smarket[,-c(8,9)])

train = Year < 2005
test = !train

training_data = data[train,]
testing_data = data[test,]

training_y = Smarket$Direction[train]
testing_y = Smarket$Direction[test]

set.seed(1)
knn_pred_y = knn(training_data, testing_data, training_y, k = 1)
table(knn_pred_y, testing_y)
mean(knn_pred_y != testing_y)

knn_pred_y = NULL
error_rate = NULL
for(i in 1:300){
set.seed(1)
knn_pred_y = knn(training_data,testing_data,training_y,k=i)
### find the minimum error rate
error_rate[i] = mean(testing_y != knn_pred_y)
}
min_error_rate = min(error_rate)
print(min_error_rate)
K = which(error_rate == min_error_rate)
print(K)

library(ggplot2)
qplot(1:300, error_rate, xlab = "K",
ylab = "Error Rate", geom=c("point", "line"))
Directory
  1. 1. 广义线性模型(GLM)
    1. 1.1. Logistic回归
      1. 1.1.1. Logistic Regression(LR)
      2. 1.1.2. 代码脚本:
    2. 1.2. 泊松回归
    3. 1.3. 生存分析
      1. 1.3.1. 生存曲线
      2. 1.3.2. Cox回归模型
      3. 1.3.3. 代码脚本
  2. 2. Linear Discriminant Analysis (LDA)
    1. 2.1. 代码脚本
  3. 3. Quadratic Discriminant Analysis (QDA)
  4. 4. KNN for Classification
    1. 4.1. 代码脚本