Files
qsar/plot.py
2024-09-28 13:14:52 +08:00

62 lines
1.9 KiB
Python

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
# Evaluation results per QSAR model, one row each: (model name, MSE, R²).
# Keeping each model on a single row keeps the name and its two scores
# visibly aligned, which is easy to get wrong with three parallel lists.
_RESULTS = [
    ('1D-Linear Regression', 32.3949, 0.6525),
    # NOTE(review): extreme values — presumably a diverged SGD run; confirm.
    (
        '1D-Stochastic Gradient Descent',
        230009980374197965960989638656.0,
        -2467672699617844819673481216.0,
    ),
    ('1D-K-Nearest Neighbors', 30.2081, 0.6759),
    ('1D-Decision Tree', 27.7150, 0.7027),
    ('1D-Random Forest', 26.5204, 0.7155),
    ('1D-XGBoost', 27.7147, 0.7027),
    ('1D-Multi-layer Perceptron', 143.3505, -0.5379),
    ('2D-Linear Regression', 30.1093, 0.6770),
    ('2D-Stochastic Gradient Descent', 33.7336, 0.6381),
    ('2D-K-Nearest Neighbors', 48.8179, 0.4763),
    ('2D-Decision Tree', 30.2360, 0.6756),
    ('2D-Random Forest', 28.7916, 0.6911),
    ('2D-XGBoost', 30.2351, 0.6756),
    ('2D-Multi-layer Perceptron', 30.1715, 0.6763),
    ('3D-Stochastic Gradient Descent', 64.5768, 0.3072),
    ('3D-K-Nearest Neighbors', 38.6921, 0.5849),
    ('3D-Decision Tree', 30.2360, 0.6756),
    ('3D-Random Forest', 30.8310, 0.6692),
    ('3D-XGBoost', 30.2362, 0.6756),
    ('3D-Multi-layer Perceptron', 29.9844, 0.6783),
    ('UniMol QSAR', 59.7204, 0.3593),
]

# Column-oriented view of _RESULTS (column name -> list of values), the
# shape pd.DataFrame accepts directly.
data = {
    'Model': [model for model, _, _ in _RESULTS],
    'MSE': [mse for _, mse, _ in _RESULTS],
    'R2': [r2 for _, _, r2 in _RESULTS],
}

# One row per model with columns Model, MSE, R2.
df = pd.DataFrame(data)

# Drop models with a negative score. A mean *squared* error can never be
# negative, so the MSE half of the mask only additionally excludes rows whose
# MSE is NaN; in practice the R2 condition does the filtering. With the values
# above it removes '1D-Stochastic Gradient Descent' and
# '1D-Multi-layer Perceptron'.
_is_plottable = (df['MSE'] >= 0) & (df['R2'] >= 0)
df_filtered = df.loc[_is_plottable]
def _draw_comparison(ax, frame, column, palette, title, xlabel):
    """Draw a horizontal bar per model showing ``frame[column]`` on ``ax``.

    Args:
        ax: Matplotlib Axes to draw into.
        frame: DataFrame with a ``'Model'`` column and ``column``.
        column: Name of the metric column plotted along the x axis.
        palette: Seaborn/matplotlib palette name used to color the bars.
        title: Axes title.
        xlabel: Label for the x axis.
    """
    # Passing `palette` without `hue` is deprecated since seaborn 0.13 (and
    # slated for removal in 0.14). Mapping hue onto the same 'Model' column
    # gives one color per bar; dodge=False keeps each bar full width on
    # both older and newer seaborn releases.
    sns.barplot(
        x=column,
        y='Model',
        hue='Model',
        dodge=False,
        data=frame,
        palette=palette,
        ax=ax,
    )
    # The hue legend would just repeat the y-axis labels, so drop it.
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('QSAR Model')


# Two stacked panels: MSE on top, R² below.
fig, (ax_mse, ax_r2) = plt.subplots(2, 1, figsize=(14, 10))
_draw_comparison(
    ax_mse,
    df_filtered,
    'MSE',
    'viridis',
    'Comparison of MSE across QSAR Models (Filtered)',
    'Mean Squared Error',
)
_draw_comparison(
    ax_r2,
    df_filtered,
    'R2',
    'plasma',
    'Comparison of R² across QSAR Models (Filtered)',
    'R-squared',
)
# Adjust spacing so the long model names and the titles do not overlap.
fig.tight_layout()
# NOTE: the file name mentions only R², but the image contains both the MSE
# and the R² panels; the name is kept so existing references still resolve.
fig.savefig('qsar_r2_scores.png', dpi=300)
# Release the figure's resources (call plt.show() before this for an
# interactive window).
plt.close(fig)