update
This commit is contained in:
61
plot.py
Normal file
61
plot.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
import pandas as pd
|
||||
|
||||
# 数据准备:不同模型的 MSE 和 R² 值
|
||||
data = {
|
||||
'Model': [
|
||||
'1D-Linear Regression', '1D-Stochastic Gradient Descent', '1D-K-Nearest Neighbors',
|
||||
'1D-Decision Tree', '1D-Random Forest', '1D-XGBoost', '1D-Multi-layer Perceptron',
|
||||
'2D-Linear Regression', '2D-Stochastic Gradient Descent', '2D-K-Nearest Neighbors',
|
||||
'2D-Decision Tree', '2D-Random Forest', '2D-XGBoost', '2D-Multi-layer Perceptron',
|
||||
'3D-Stochastic Gradient Descent', '3D-K-Nearest Neighbors', '3D-Decision Tree',
|
||||
'3D-Random Forest', '3D-XGBoost', '3D-Multi-layer Perceptron', 'UniMol QSAR'
|
||||
],
|
||||
'MSE': [
|
||||
32.3949, 230009980374197965960989638656.0, 30.2081,
|
||||
27.7150, 26.5204, 27.7147, 143.3505,
|
||||
30.1093, 33.7336, 48.8179,
|
||||
30.2360, 28.7916, 30.2351, 30.1715,
|
||||
64.5768, 38.6921, 30.2360,
|
||||
30.8310, 30.2362, 29.9844, 59.7204
|
||||
],
|
||||
'R2': [
|
||||
0.6525, -2467672699617844819673481216.0, 0.6759,
|
||||
0.7027, 0.7155, 0.7027, -0.5379,
|
||||
0.6770, 0.6381, 0.4763,
|
||||
0.6756, 0.6911, 0.6756, 0.6763,
|
||||
0.3072, 0.5849, 0.6756,
|
||||
0.6692, 0.6756, 0.6783, 0.3593
|
||||
]
|
||||
}
|
||||
|
||||
# 转换为 DataFrame
|
||||
df = pd.DataFrame(data)
|
||||
|
||||
# 过滤掉负值
|
||||
df_filtered = df[(df['MSE'] >= 0) & (df['R2'] >= 0)]
|
||||
|
||||
# 设置画布
|
||||
plt.figure(figsize=(14, 10))
|
||||
|
||||
# 绘制 MSE 比较的条形图
|
||||
plt.subplot(2, 1, 1)
|
||||
sns.barplot(x='MSE', y='Model', data=df_filtered, palette='viridis')
|
||||
plt.title('Comparison of MSE across QSAR Models (Filtered)')
|
||||
plt.xlabel('Mean Squared Error')
|
||||
plt.ylabel('QSAR Model')
|
||||
|
||||
# 绘制 R² 比较的条形图
|
||||
plt.subplot(2, 1, 2)
|
||||
sns.barplot(x='R2', y='Model', data=df_filtered, palette='plasma')
|
||||
plt.title('Comparison of R² across QSAR Models (Filtered)')
|
||||
plt.xlabel('R-squared')
|
||||
plt.ylabel('QSAR Model')
|
||||
|
||||
# 调整布局以避免重叠
|
||||
plt.tight_layout()
|
||||
|
||||
# 显示图表
|
||||
# plt.show()
|
||||
plt.savefig('qsar_r2_scores.png', dpi=300)
|
||||
Reference in New Issue
Block a user