本文轉(zhuǎn)自:DeepHub IMBA
回歸任務(wù)在實(shí)際應(yīng)用中隨處可見——天氣預(yù)報(bào)、自動(dòng)駕駛、醫(yī)療診斷、經(jīng)濟(jì)預(yù)測(cè)、能耗分析,但大部分回歸模型只給出一個(gè)預(yù)測(cè)值,對(duì)這個(gè)值到底有多靠譜卻只字不提。這在某些應(yīng)用場(chǎng)景下會(huì)造成很多問題,比如用模型預(yù)測(cè)患者血壓,假設(shè)輸出是120/80這樣的正常值,表面看沒問題。但如果模型其實(shí)對(duì)這個(gè)預(yù)測(cè)很不確定呢?這時(shí)候光看數(shù)值就不夠了。
神經(jīng)網(wǎng)絡(luò)有幾種方法可以在給出預(yù)測(cè)的同時(shí)估計(jì)不確定性。
回歸中的不確定性問題
分類任務(wù)里,每個(gè)類別都有對(duì)應(yīng)的預(yù)測(cè)分?jǐn)?shù),經(jīng)過softmax之后就是概率值,可以直接當(dāng)作置信度來看——概率高說明模型比較有把握。
回歸就沒這么簡(jiǎn)單了。用MSE(均方誤差)訓(xùn)練回歸模型時(shí),模型對(duì)那些難擬合的樣本會(huì)傾向于預(yù)測(cè)平均值。比如說訓(xùn)練集里有幾個(gè)輸入特征幾乎一樣、但目標(biāo)值差異很大的樣本,模型為了降低MSE,會(huì)把預(yù)測(cè)往它們的均值靠攏。
這就帶來一個(gè)問題:當(dāng)預(yù)測(cè)值接近訓(xùn)練集的整體均值時(shí),很難判斷模型是真的有把握,還是純粹為了優(yōu)化損失函數(shù)在"混日子"。而如果預(yù)測(cè)值偏離均值較遠(yuǎn),可能說明模型比較確信——因?yàn)轭A(yù)測(cè)錯(cuò)了的話MSE懲罰會(huì)更重。
假設(shè)目標(biāo)值在-1到1之間,訓(xùn)練集均值接近0。模型預(yù)測(cè)0.1時(shí),不好說它是有信心還是在敷衍;但預(yù)測(cè)0.8時(shí),大部分情況說明它確實(shí)掌握了某些模式,當(dāng)然這個(gè)現(xiàn)象也不是在所有情況下都成立。
這里涉及兩類不確定性的概念:
任意不確定性(Aleatoric uncertainty)來自數(shù)據(jù)本身的隨機(jī)性和噪聲,比如測(cè)量誤差、自然界的隨機(jī)過程。這種不確定性沒法通過增加數(shù)據(jù)來消除,它就是客觀存在的。
認(rèn)知不確定性(Epistemic uncertainty)源于知識(shí)或數(shù)據(jù)的缺乏。弄日三級(jí)片訓(xùn)練數(shù)據(jù)太少或者測(cè)試樣本落在從沒見過的區(qū)域,模型就會(huì)產(chǎn)生這類不確定性。但是這種不確定性是可以通過改進(jìn)模型結(jié)構(gòu)、收集更多樣化的數(shù)據(jù)來降低的。
前面提到的那個(gè)例子主要說的就是任意不確定性。但即便預(yù)測(cè)0.8,如果是在數(shù)據(jù)稀疏的區(qū)域,模型也可能因?yàn)檎J(rèn)知不確定性而不靠譜。
四種不確定性估計(jì)方法
這里對(duì)比了四種在神經(jīng)網(wǎng)絡(luò)回歸中估計(jì)不確定性的方法。
1、均值 + 對(duì)數(shù)標(biāo)準(zhǔn)差(Mean + LogStd)
模型輸出兩個(gè)值:均值作為預(yù)測(cè)值,對(duì)數(shù)標(biāo)準(zhǔn)差表示不確定性——值越大說明越不確定。損失函數(shù)用的是負(fù)對(duì)數(shù)似然,假設(shè)目標(biāo)值服從正態(tài)分布,參數(shù)就是預(yù)測(cè)的均值和標(biāo)準(zhǔn)差。
x='input features'
y='targets'
mu,log_std=mean_logstd_model(x)
dist_obj=torch.distributions.Normal(loc=mu,scale=log_std.exp())
loss=-dist_obj.log_prob(y).mean()
2、均值 + 對(duì)數(shù)方差(Mean + LogVariance)
跟上一個(gè)類似,只是把標(biāo)準(zhǔn)差換成了方差。對(duì)數(shù)方差越大不確定性越高。損失計(jì)算方式如下:

x='input features'
y='targets'
mu,log_variance=mean_logvariance_model(x)
loss=(0.5*log_variance)+(((y-mu)**2)/(2*torch.exp(log_variance)))
loss=loss.mean()
3、蒙特卡洛Dropout(MC Dropout)
訓(xùn)練時(shí)用dropout,但預(yù)測(cè)階段的時(shí)候不關(guān)閉它。對(duì)同一個(gè)樣本預(yù)測(cè)多次,由于dropout的隨機(jī)性,每次結(jié)果會(huì)略有差異。把這些預(yù)測(cè)值的均值當(dāng)作最終預(yù)測(cè),標(biāo)準(zhǔn)差就是不確定性的度量。
4、簡(jiǎn)化版PPO方法
PPO本來是強(qiáng)化學(xué)習(xí)里的算法,這里做了簡(jiǎn)化改造用到監(jiān)督學(xué)習(xí)上。Actor網(wǎng)絡(luò)預(yù)測(cè)均值和標(biāo)準(zhǔn)差,獎(jiǎng)勵(lì)定義為采樣預(yù)測(cè)的負(fù)MSE。跟標(biāo)準(zhǔn)PPO的主要區(qū)別在于優(yōu)勢(shì)(advantage)的計(jì)算——監(jiān)督回歸可以看作單步環(huán)境不需要GAE,優(yōu)勢(shì)就是reward減去value。

實(shí)驗(yàn)設(shè)置
數(shù)據(jù)集
混凝土抗壓強(qiáng)度數(shù)據(jù)集包含1030個(gè)樣本,8個(gè)特征,1個(gè)目標(biāo)值。按7:3隨機(jī)劃分訓(xùn)練集和測(cè)試集。輸入特征標(biāo)準(zhǔn)化到均值0、標(biāo)準(zhǔn)差1,目標(biāo)值除以100歸一化到(0,1)區(qū)間。
模型結(jié)構(gòu)
基礎(chǔ)架構(gòu)都是全連接網(wǎng)絡(luò),4個(gè)隱藏層,每層64個(gè)神經(jīng)元。輸出層根據(jù)不同方法有各自的設(shè)計(jì)。
訓(xùn)練參數(shù)
統(tǒng)一訓(xùn)練2000個(gè)epoch,batch size 256,學(xué)習(xí)率0.0001。
實(shí)驗(yàn)結(jié)果分析

圖1展示了各方法在訓(xùn)練集和測(cè)試集上的MSE?;€方法(不估計(jì)不確定性的普通回歸)測(cè)試集MSE最低,均值對(duì)數(shù)方差最高。均值對(duì)數(shù)標(biāo)準(zhǔn)差和MC Dropout在測(cè)試集上表現(xiàn)相當(dāng),排第二,PPO排第三。

圖2畫出了真實(shí)值和預(yù)測(cè)值的散點(diǎn)圖,x軸是真值,y軸是預(yù)測(cè)。點(diǎn)越靠近對(duì)角線,預(yù)測(cè)越準(zhǔn)。

圖3展示了不確定性估計(jì)的實(shí)際效果。x軸是確定性閾值,比如0.3表示過濾掉30%最不確定的預(yù)測(cè),只保留70%最有把握的,y軸是這些篩選后樣本的MSE。
基線方法沒有不確定性估計(jì),所以MSE是條平線。PPO的表現(xiàn)比較奇怪——按理說高標(biāo)準(zhǔn)差應(yīng)該對(duì)應(yīng)高不確定性,但這里似乎反過來了,篩掉低標(biāo)準(zhǔn)差的預(yù)測(cè)后MSE反而上升。這可能跟PPO用標(biāo)準(zhǔn)差控制探索有關(guān),那些被分配低標(biāo)準(zhǔn)差的樣本探索不夠充分反而預(yù)測(cè)不準(zhǔn)。
均值對(duì)數(shù)標(biāo)準(zhǔn)差、均值方差和MC Dropout三者表現(xiàn)接近。均值對(duì)數(shù)標(biāo)準(zhǔn)差稍微好一點(diǎn),在0.55閾值之后還能繼續(xù)降低MSE,而均值方差已經(jīng)平了。

圖4把PPO的不確定性順序反過來試了試——最確定的當(dāng)最不確定,最不確定的當(dāng)最確定。雖然在0到0.2的區(qū)間MSE有所下降,但0.2之后又回升了。說明即便反著用,PPO的不確定性估計(jì)也不太對(duì)。
總結(jié)
在混凝土強(qiáng)度這個(gè)數(shù)據(jù)集上,均值對(duì)數(shù)標(biāo)準(zhǔn)差和均值對(duì)數(shù)方差兩種方法在估計(jì)預(yù)測(cè)不確定性方面效果最好。MC Dropout也不錯(cuò),但PPO的簡(jiǎn)化版本表現(xiàn)不佳,即使反轉(zhuǎn)其不確定性指標(biāo)也無法獲得可靠的估計(jì)。
代碼倉(cāng)庫:https://github.com/navid-bamdad-roshan/regression-with-uncertainty-methods-comparison
作者:Navid Bamdad Roshan
-
神經(jīng)網(wǎng)絡(luò)
+關(guān)注
關(guān)注
42文章
4840瀏覽量
108091 -
模型
+關(guān)注
關(guān)注
1文章
3808瀏覽量
52239 -
MSE
+關(guān)注
關(guān)注
0文章
7瀏覽量
6701
發(fā)布評(píng)論請(qǐng)先 登錄
去嵌入和不確定性是否使用了正確的設(shè)置
E8364C PNA的不確定性和跟蹤是什么?
是否可以使用全雙端口校準(zhǔn)中的S11不確定性來覆蓋單端口校準(zhǔn)的不確定性?
N5531S TRFL不確定性
容差模擬電路軟故障診斷的小波與量子神經(jīng)網(wǎng)絡(luò)方法設(shè)計(jì)
5G網(wǎng)絡(luò)架構(gòu)的不確定性及其對(duì)承載網(wǎng)的影響
一種求解動(dòng)態(tài)及不確定性優(yōu)化問題的新方法
LUCT工具主要特性及不確定性時(shí)鐘樹設(shè)計(jì)方法和算法的介紹
如何用不確定性解決模型問題
針對(duì)自閉癥輔助的不確定性聯(lián)合組稀疏建模方法
4種神經(jīng)網(wǎng)絡(luò)不確定性估計(jì)方法對(duì)比與代碼實(shí)現(xiàn)
評(píng)論