first add
This commit is contained in:
24
BLOSUM62
Normal file
24
BLOSUM62
Normal file
@@ -0,0 +1,24 @@
|
||||
A 4 -1 -2 -2 0 -1 -1 0 -2 -1 -1 -1 -1 -2 -1 1 0 -3 -2 0 -2 -1 0 -4
|
||||
R -1 5 0 -2 -3 1 0 -2 0 -3 -2 2 -1 -3 -2 -1 -1 -3 -2 -3 -1 0 -1 -4
|
||||
N -2 0 6 1 -3 0 0 0 1 -3 -3 0 -2 -3 -2 1 0 -4 -2 -3 3 0 -1 -4
|
||||
D -2 -2 1 6 -3 0 2 -1 -1 -3 -4 -1 -3 -3 -1 0 -1 -4 -3 -3 4 1 -1 -4
|
||||
C 0 -3 -3 -3 9 -3 -4 -3 -3 -1 -1 -3 -1 -2 -3 -1 -1 -2 -2 -1 -3 -3 -2 -4
|
||||
Q -1 1 0 0 -3 5 2 -2 0 -3 -2 1 0 -3 -1 0 -1 -2 -1 -2 0 3 -1 -4
|
||||
E -1 0 0 2 -4 2 5 -2 0 -3 -3 1 -2 -3 -1 0 -1 -3 -2 -2 1 4 -1 -4
|
||||
G 0 -2 0 -1 -3 -2 -2 6 -2 -4 -4 -2 -3 -3 -2 0 -2 -2 -3 -3 -1 -2 -1 -4
|
||||
H -2 0 1 -1 -3 0 0 -2 8 -3 -3 -1 -2 -1 -2 -1 -2 -2 2 -3 0 0 -1 -4
|
||||
I -1 -3 -3 -3 -1 -3 -3 -4 -3 4 2 -3 1 0 -3 -2 -1 -3 -1 3 -3 -3 -1 -4
|
||||
L -1 -2 -3 -4 -1 -2 -3 -4 -3 2 4 -2 2 0 -3 -2 -1 -2 -1 1 -4 -3 -1 -4
|
||||
K -1 2 0 -1 -3 1 1 -2 -1 -3 -2 5 -1 -3 -1 0 -1 -3 -2 -2 0 1 -1 -4
|
||||
M -1 -1 -2 -3 -1 0 -2 -3 -2 1 2 -1 5 0 -2 -1 -1 -1 -1 1 -3 -1 -1 -4
|
||||
F -2 -3 -3 -3 -2 -3 -3 -3 -1 0 0 -3 0 6 -4 -2 -2 1 3 -1 -3 -3 -1 -4
|
||||
P -1 -2 -2 -1 -3 -1 -1 -2 -2 -3 -3 -1 -2 -4 7 -1 -1 -4 -3 -2 -2 -1 -2 -4
|
||||
S 1 -1 1 0 -1 0 0 0 -1 -2 -2 0 -1 -2 -1 4 1 -3 -2 -2 0 0 0 -4
|
||||
T 0 -1 0 -1 -1 -1 -1 -2 -2 -1 -1 -1 -1 -2 -1 1 5 -2 -2 0 -1 -1 0 -4
|
||||
W -3 -3 -4 -4 -2 -2 -3 -2 -2 -3 -2 -3 -1 1 -4 -3 -2 11 2 -3 -4 -3 -2 -4
|
||||
Y -2 -2 -2 -3 -2 -1 -2 -3 2 -1 -1 -2 -1 3 -3 -2 -2 2 7 -1 -3 -2 -1 -4
|
||||
V 0 -3 -3 -3 -1 -2 -2 -3 -3 3 1 -2 1 -1 -2 -2 0 -3 -1 4 -3 -2 -1 -4
|
||||
B -2 -1 3 4 -3 0 1 -1 0 -3 -4 0 -3 -3 -2 0 -1 -4 -3 -3 4 1 -1 -4
|
||||
Z -1 0 0 1 -3 3 4 -2 0 -3 -3 1 -1 -3 -1 0 -1 -3 -2 -2 1 4 -1 -4
|
||||
X 0 -1 -1 -1 -2 -1 -1 -1 -1 -1 -1 -1 -1 -1 -2 0 0 -2 -1 -1 -1 -1 -1 -4
|
||||
* -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 1
|
||||
94
DNN_final.py
Normal file
94
DNN_final.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from Tools import getMMScoreType,WeightAndMatrix
|
||||
from sklearn.metrics import roc_auc_score
|
||||
from sklearn.model_selection import KFold,StratifiedKFold
|
||||
from keras import metrics
|
||||
from keras.models import Model
|
||||
import numpy as np
|
||||
from keras import backend as K
|
||||
from keras import optimizers
|
||||
from keras.models import Sequential,load_model
|
||||
from keras.layers import Dense, Dropout
|
||||
from keras.callbacks import EarlyStopping,ModelCheckpoint
|
||||
def final_dnn():
    """Train the final DNN on MM-scored peptide features.

    Reads positive peptides from POS.txt and negatives from NEG.txt
    (tab-separated; first column is the peptide), scores every peptide with
    the weight/matrix parameters loaded from traningout_best.txt, runs
    10-fold stratified DNN training (see dnn()) and appends the best AUC
    and best fold index to AUCs.txt.
    """
    length = 30  # half-window size used by the scoring matrix (2*30+1 positions)
    nfold = 10   # stratified CV folds

    pos = set()
    neg = set()
    # fix: context managers -- the original handles were never closed
    with open(r"POS.txt", "r") as f1:
        for line in f1:
            sp = line.strip().split("\t")
            pos.add(sp[0])
    with open(r"NEG.txt", "r") as f2:
        for line in f2:
            sp = line.strip().split("\t")
            pep = sp[0]
            if pep not in pos:  # a known positive never counts as a negative
                neg.add(pep)

    AAscores, l_aas, weight_coef, AAs = \
        WeightAndMatrix(r"traningout_best.txt")
    l_scores, l_type, peps = getMMScoreType(pos, neg, AAscores, weight_coef, l_aas, AAs, length)

    X = np.array(l_scores)
    Y = np.array(l_type)
    PEP = np.array(peps)
    # [first-layer units, dropout rate, number of halving layers, input dim]
    parameter = [512, 0.2, 2, X.shape[1]]
    auc_all, best_model = dnn(X, Y, nfold, parameter, PEP)

    # fix: context manager for the report file too (was flushed/closed by hand)
    with open(r"AUCs.txt", "a") as fw:
        fw.write("Best AUC:" + "\t" + str(auc_all) + "\t" + str(best_model) + "\n")
|
||||
|
||||
|
||||
|
||||
def dnn(X,Y,nfold,parameter,PEP):
    """Run stratified n-fold CV training of the DNN built by create_model.

    Each fold trains with early stopping on validation AUC, reloads the
    best checkpoint, and pools the held-out predictions so a single overall
    AUC can be computed at the end.

    Args:
        X: feature matrix, shape (n_samples, n_features).
        Y: binary labels, shape (n_samples,).
        nfold: number of StratifiedKFold splits.
        parameter: [first-layer units, dropout rate, n halving layers, input dim].
        PEP: peptides aligned with X; accepted but not read in this function.

    Returns:
        (auc_all, best_model): ROC-AUC over all pooled held-out predictions,
        and the 1-based index of the fold with the highest per-fold AUC.
    """
    skf = StratifiedKFold(n_splits=nfold)
    num = 0           # 1-based fold counter; also names the checkpoint file
    best_auc = 0.0
    best_model = 0    # fold index whose model achieved best_auc
    Y_last = []       # pooled true labels across folds
    Score_last = []   # pooled predicted scores across folds
    for train_index, test_index in skf.split(X, Y):
        num += 1
        print("dnn_" + str(num))
        X_train, X_test = X[train_index], X[test_index]
        Y_train, Y_test = Y[train_index], Y[test_index]
        model = create_model(parameter)
        # Early-stop on validation AUC; checkpoint the best epoch to "<fold>.model".
        # NOTE(review): the fold's test split doubles as the validation set,
        # so the pooled AUC below is an optimistic estimate.
        model.fit(X_train, Y_train, epochs=300, batch_size=100,validation_data=(X_test,Y_test),verbose=1,
                  callbacks=[EarlyStopping(monitor="val_auc", mode="max", min_delta=0, patience=30),
                             ModelCheckpoint(str(num) +'.model', monitor="val_auc", mode="max", save_best_only=True)])
        # Reload the checkpointed best-epoch weights before evaluating.
        model = load_model(str(num) +".model")
        predict_x = model.predict(X_test)[:, 0]
        auc = roc_auc_score(Y_test, predict_x)
        if auc > best_auc:
            best_auc = auc
            best_model = num
        Y_last.extend(Y_test)
        Score_last.extend(predict_x)
        # Drop the TF graph between folds to avoid memory growth.
        K.clear_session()
    auc_all = roc_auc_score(np.array(Y_last), np.array(Score_last))
    return auc_all,best_model
|
||||
|
||||
|
||||
def create_model(parameter):
    """Build and compile the funnel-shaped binary-classification DNN.

    Args:
        parameter: [first-layer units, dropout rate, n extra layers, input dim].
            Extra layer i (0-based) has parameter[0] // 2**(i+1) units.

    Returns:
        A compiled keras Sequential model (Adam optimizer, binary
        cross-entropy loss, AUC metric named "auc" so callbacks can
        monitor "val_auc").
    """
    model = Sequential()
    model.add(Dense(parameter[0], activation='linear', input_dim=parameter[3]))
    model.add(Dropout(parameter[1]))
    for i in range(parameter[2]):
        fold = 2 ** (i + 1)
        # fix: integer division -- Dense units must be an int; under Python 3,
        # parameter[0] / fold produced a float (rejected by modern Keras)
        model.add(Dense(parameter[0] // fold, activation='linear'))
        model.add(Dropout(parameter[1]))
    model.add(Dense(1, activation='sigmoid'))
    # NOTE(review): `lr`/`decay` are legacy Keras argument names (newer
    # releases use `learning_rate`); kept for the pinned keras version.
    model.compile(optimizer=optimizers.Adam(lr=1e-3, decay=3e-5), loss='binary_crossentropy', metrics=[metrics.AUC(name="auc")])
    model.summary()

    return model
|
||||
|
||||
# fix: guard the entry point so importing this module does not start training
if __name__ == "__main__":
    final_dnn()
|
||||
33
Dockerfile
Normal file
33
Dockerfile
Normal file
@@ -0,0 +1,33 @@
|
||||
# 使用Python 3.8作为基础镜像
|
||||
FROM python:3.8-slim
|
||||
|
||||
# 设置工作目录
|
||||
WORKDIR /app
|
||||
|
||||
# 安装系统依赖
|
||||
RUN apt-get update && apt-get install -y \
|
||||
build-essential \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# 复制requirements.txt
|
||||
COPY requirements.txt .
|
||||
|
||||
# 安装Python依赖
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# 安装项目所需的Python包
|
||||
RUN pip install --no-cache-dir \
|
||||
tensorflow \
|
||||
keras \
|
||||
scikit-learn \
|
||||
numpy \
|
||||
pandas
|
||||
|
||||
# 复制项目文件
|
||||
COPY . .
|
||||
|
||||
# 设置环境变量
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
# 设置容器启动命令
|
||||
CMD ["python", "DNN_final.py"]
|
||||
33
Dockerfile.pytorch
Normal file
33
Dockerfile.pytorch
Normal file
@@ -0,0 +1,33 @@
|
||||
# 使用支持CUDA的PyTorch基础镜像
|
||||
FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime
|
||||
|
||||
# 设置工作目录
|
||||
WORKDIR /app
|
||||
|
||||
# 安装系统依赖
|
||||
RUN apt-get update && apt-get install -y \
|
||||
build-essential \
|
||||
git \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# 复制requirements.pytorch.txt
|
||||
COPY requirements.pytorch.txt .
|
||||
|
||||
# 安装Python依赖
|
||||
RUN pip install --no-cache-dir -r requirements.pytorch.txt
|
||||
|
||||
# 安装额外的机器学习包
|
||||
RUN pip install --no-cache-dir \
|
||||
scikit-learn \
|
||||
pandas \
|
||||
matplotlib \
|
||||
tensorboard
|
||||
|
||||
# 复制项目文件
|
||||
COPY . .
|
||||
|
||||
# 设置环境变量
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
# 设置容器启动命令
|
||||
CMD ["python", "train.py"]
|
||||
400
GPS 5.0M.py
Normal file
400
GPS 5.0M.py
Normal file
@@ -0,0 +1,400 @@
|
||||
from Tools import blosum62
|
||||
import numpy as np
|
||||
from sklearn.metrics import roc_auc_score
|
||||
from sklearn.linear_model import LogisticRegressionCV
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from pathlib import Path
|
||||
|
||||
def trainning():
    """Iteratively train GPS position weights ("WW") and matrix weights ("MM").

    Alternates logistic-regression fits of per-position weights and
    amino-acid-pair matrix weights, rescaling the scoring matrix after each
    round, and stops as soon as the MM AUC no longer improves.  The best
    parameters of each round are written to traningout_*.txt files.
    """
    length = 30  # half-window: each peptide spans length*2+1 positions
    fold = 10    # tag embedded in the output file names
    # fix: the error message names both files but only POS.txt was checked
    if not Path('POS.txt').exists() or not Path('NEG.txt').exists():
        raise FileNotFoundError("POS.txt or NEG.txt not found.")
    pos = set()
    neg = set()
    # fix: context managers -- the original handles were never closed
    with open(r"POS.txt", "r") as f1:
        for line in f1:
            sp = line.strip().split("\t")
            pos.add(sp[0])
    with open(r"NEG.txt", "r") as f2:
        for line in f2:
            sp = line.strip().split("\t")
            pep = sp[0]
            if pep not in pos:  # a known positive never counts as a negative
                neg.add(pep)

    print("Frist round。。。。。。。。。。。。。")
    AAscores, l_aas, AAs = blosum62()
    # First pass: fit per-position weights from the raw BLOSUM62 scores.
    l_scores, l_type, l_peps = getWeightScoreType(pos, neg, AAscores, AAs, length)
    X = np.array(l_scores)
    Y = np.array(l_type)
    PEP = np.array(l_peps)
    weight_coef, weight_auc = logistic_GPS(X, Y, PEP, "WW", 0)
    print("First weight AUC:" + str(weight_auc))

    # MM training: fit matrix-element weights under the learned position weights.
    l_scores, l_type, peps = getMMScoreType(pos, neg, AAscores, weight_coef, l_aas, AAs, length)
    X = np.array(l_scores)
    Y = np.array(l_type)
    PEP = np.array(peps)  # fix: was np.array(l_peps), a stale value from the WW pass
    MM_coef, MM_auc = logistic_GPS(X, Y, PEP, "MM", 0)
    print("First matix AUC:" + str(MM_auc))
    best_weight_auc = weight_auc
    best_MM_auc = MM_auc

    file = "traningout_first_" + str(fold) + ".txt"
    writeParameter_MM(file, 10, 10, AAscores, l_aas, weight_coef, MM_coef, MM_auc)

    AAscores = newAAScore(AAscores, l_aas, AAs, MM_coef)
    for i in range(1, 100):  # fix: range(100)[1:] built a throwaway list slice
        print("The " + str(i + 1) + "th trainning")
        l_scores, l_type, l_peps = getWeightScoreType(pos, neg, AAscores, AAs, length)
        X = np.array(l_scores)
        Y = np.array(l_type)
        PEP = np.array(l_peps)
        weight_coef, weight_auc = logistic_GPS(X, Y, PEP, "WW", i)
        if weight_auc > best_weight_auc:
            best_weight_auc = weight_auc
            file2 = "traningout_weight_best_" + str(fold) + "_" + str(i + 1) + ".txt"
            writeParameter_WW(file2, 30, 30, AAscores, l_aas, weight_coef, weight_auc)
        # MM training for this round
        l_scores, l_type, peps = getMMScoreType(pos, neg, AAscores, weight_coef, l_aas, AAs, length)
        X = np.array(l_scores)
        Y = np.array(l_type)
        PEP = np.array(peps)  # fix: was the stale l_peps here as well
        MM_coef, MM_auc = logistic_GPS(X, Y, PEP, "MM", i)
        print("The " + str(i + 1) + "round AUC:" + str(MM_auc))
        if MM_auc > best_MM_auc:
            best_MM_auc = MM_auc
            file2 = "traningout_MM_best_" + str(fold) + "_" + str(i + 1) + ".txt"
            writeParameter_MM(file2, 30, 30, AAscores, l_aas, weight_coef, MM_coef, MM_auc)
        else:
            break  # MM AUC stopped improving: converged
        AAscores = newAAScore(AAscores, l_aas, AAs, MM_coef)
|
||||
|
||||
def newAAScore(AAscores, l_aas, AAs, MM_coef):
    """Rescale the pair-score matrix by the fitted per-pair MM weights.

    Args:
        AAscores: dict mapping amino-acid-pair keys (e.g. "A_R") to scores.
        l_aas: ordered list of pair keys; position i pairs with MM_coef[i].
        AAs: amino-acid alphabet (accepted for interface parity; not used).
        MM_coef: per-pair multipliers from the MM logistic regression.

    Returns:
        A new dict where each listed pair's score is multiplied by its weight.
    """
    return {pair: AAscores[pair] * MM_coef[idx] for idx, pair in enumerate(l_aas)}
|
||||
|
||||
|
||||
def getWeightScoreType(pos, neg, matrix, AAs, length):
    """Build per-position scoring features for every peptide.

    For each window position, the feature is the mean pair-score of the
    peptide's residue against the residues of all positive peptides at that
    position.  Positives use a leave-one-out mean (their own self-match is
    subtracted first); negatives use the plain mean over all positives.

    Args:
        pos: positive peptides (strings spanning length*2+1 positions).
        neg: negative peptides.
        matrix: pair-score dict keyed "X_Y" (each pair stored in one orientation).
        AAs: amino-acid alphabet defining the scoring-table column order.
        length: half-window size.

    Returns:
        (l_scores, l_type, l_peps): per-peptide score vectors, 1/0 labels,
        and the peptides themselves, positives first.
    """
    window = length * 2 + 1
    n_pos = len(pos)

    def pair_score(a, b):
        # The matrix stores each unordered pair once, in either orientation.
        key = a + "_" + b
        return matrix[key] if key in matrix else matrix[b + "_" + a]

    # scores[i][j]: total score of alphabet letter j against every positive
    # peptide's residue at window position i.
    scores = [
        [sum(pair_score(aa1, oth[i:i + 1]) for oth in pos) for aa1 in AAs]
        for i in range(window)
    ]

    l_scores = []
    l_type = []
    l_peps = []

    for pep in pos:
        # Leave-one-out mean: drop the peptide's own self-match contribution.
        vec = [
            (scores[i][AAs.index(ch)] - matrix[ch + "_" + ch]) / (n_pos - 1)
            for i, ch in enumerate(pep)
        ]
        l_scores.append(vec)
        l_type.append(1)
        l_peps.append(pep)

    for pep in neg:
        # Negatives are plain means over all positives; no self-correction.
        vec = [scores[i][AAs.index(ch)] / n_pos for i, ch in enumerate(pep)]
        l_scores.append(vec)
        l_type.append(0)
        l_peps.append(pep)

    return l_scores, l_type, l_peps
|
||||
|
||||
def getMMScoreType(pos, neg, matrix, weights, l_aas, AAs, length):
    """Build per-pair feature vectors for every peptide under position weights.

    For each window position i and alphabet letter a, a vector over the pair
    keys in l_aas accumulates weights[i] * matrix[pair] for the pair formed
    by a with each positive peptide's residue at i.  A peptide's feature
    vector sums these over its own residues; positives use a leave-one-out
    mean (the self-pair contribution removed), negatives a plain mean.

    Args:
        pos: positive peptides (strings spanning length*2+1 positions).
        neg: negative peptides.
        matrix: pair-score dict keyed "X_Y".
        weights: per-position multipliers, len length*2+1.
        l_aas: ordered pair keys defining the feature order
            (assumed unique -- TODO confirm).
        AAs: amino-acid alphabet.
        length: half-window size.

    Returns:
        (l_scores, l_type, l_peps): per-peptide feature vectors (lists of
        float), 1/0 labels, and the peptides, positives first.
    """
    window = length * 2 + 1
    n_pairs = len(l_aas)
    n_pos = len(pos)
    # fix: the original called list.index inside triple-nested loops
    # (accidental quadratic work); precompute O(1) lookup tables instead
    pair_idx = {pair: k for k, pair in enumerate(l_aas)}
    aa_idx = {aa: k for k, aa in enumerate(AAs)}

    # scorespos/scoresneg[i][j]: accumulated pair-score vector for alphabet
    # letter j at position i; the "pos" variant removes the self-pair term.
    scorespos = []
    scoresneg = []
    for i in range(window):
        w = weights[i]
        row_pos = []
        row_neg = []
        for aa1 in AAs:
            acc = [0.0] * n_pairs
            for oth in pos:
                aa2 = oth[i:i + 1]
                fwd = aa1 + "_" + aa2
                rev = aa2 + "_" + aa1
                if fwd in pair_idx:
                    acc[pair_idx[fwd]] += matrix[fwd] * w
                elif rev in pair_idx:
                    acc[pair_idx[rev]] += matrix[rev] * w
            row_neg.append(np.array(acc))  # snapshot before self-pair removal
            diag = aa1 + "_" + aa1
            acc[pair_idx[diag]] -= matrix[diag] * w
            row_pos.append(np.array(acc))
        scorespos.append(row_pos)
        scoresneg.append(row_neg)

    l_scores = []
    l_type = []
    l_peps = []

    for pep in pos:
        total = np.zeros(n_pairs)
        for i in range(len(pep)):
            total += scorespos[i][aa_idx[pep[i:i + 1]]]
        # Leave-one-out mean over the other positives.
        l_scores.append((total / (n_pos - 1)).tolist())
        l_type.append(1)
        l_peps.append(pep)

    for pep in neg:
        total = np.zeros(n_pairs)
        for i in range(len(pep)):
            total += scoresneg[i][aa_idx[pep[i:i + 1]]]
        l_scores.append((total / n_pos).tolist())
        l_type.append(0)
        l_peps.append(pep)

    return l_scores, l_type, l_peps
|
||||
|
||||
|
||||
def getArray(l_aas):
    """Return a float64 zero vector with one slot per pair key in l_aas."""
    return np.zeros(len(l_aas))
|
||||
|
||||
def writeParameter_WW(file, left, right, AAscores, l_aas, weight_coef, weight_auc):
    """Write the position-weight ("WW") parameter file.

    Emits the KprFunc header, the AUC, the per-position weight vector, and a
    full 24x24 amino-acid score matrix (including the B/Z/X/* codes) built
    from the pair scores in AAscores.

    Args:
        file: output path.
        left: upstream window size recorded in the header.
        right: downstream window size recorded in the header.
        AAscores: pair-score dict keyed "X_Y".
        l_aas: pair keys whose scores populate the matrix.
        weight_coef: per-position logistic-regression weights.
        weight_auc: AUC recorded in the header.
    """
    list_aa = ["A", "R", "N", "D", "C", "Q", "E", "G", "H", "I", "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V",
               "B", "Z", "X", "*"]
    # Matrix entries are the raw pair scores (no MM reweighting here).
    dict_weight = {aas: AAscores[aas] for aas in l_aas}
    # fix: context manager -- the handle leaked if any write raised
    with open(file, "w") as fw:
        fw.write("#KprFunc 1.0 Parameters\n")
        fw.write("#Version: 1.0\n")
        fw.write("#By Chenwei Wang @HUST\n")
        fw.write("@param\tCode=K\tUp=" + str(left) + "\tDown=" + str(right) + "\n")
        fw.write("@AUC=" + str(weight_auc) + "\n")
        fw.write("@weight")
        for w in weight_coef:
            fw.write("\t" + str(w))
        fw.write("\n")

        for a in list_aa:
            fw.write(" " + a)
        fw.write("\n")

        for a1 in list_aa:
            fw.write(a1)
            for a2 in list_aa:
                aas1 = a1 + "_" + a2
                aas2 = a2 + "_" + a1
                score = 0.0
                if aas1 in dict_weight:
                    score = dict_weight[aas1]
                elif aas2 in dict_weight:
                    score = dict_weight[aas2]
                else:
                    # Pair missing from l_aas: keep 0.0 but flag it loudly.
                    print(aas1 + "wrong!")
                fw.write(" " + str(score))
            fw.write("\n")
|
||||
|
||||
def writeParameter_MM(file, left, right, AAscores, l_aas, weight_coef, MM_coef, MM_auc):
    """Write the matrix-weight ("MM") parameter file.

    Like writeParameter_WW, but each matrix entry is the pair score rescaled
    by its fitted MM coefficient, and the recorded AUC is the MM-stage AUC.

    Args:
        file: output path.
        left: upstream window size recorded in the header.
        right: downstream window size recorded in the header.
        AAscores: pair-score dict keyed "X_Y".
        l_aas: pair keys; position i pairs with MM_coef[i].
        weight_coef: per-position logistic-regression weights.
        MM_coef: per-pair multipliers from the MM logistic regression.
        MM_auc: AUC recorded in the header.
    """
    list_aa = ["A", "R", "N", "D", "C", "Q", "E", "G", "H", "I", "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V",
               "B", "Z", "X", "*"]
    # Matrix entries are the pair scores rescaled by the MM coefficients.
    dict_weight = {l_aas[i]: AAscores[l_aas[i]] * MM_coef[i] for i in range(len(l_aas))}
    # fix: context manager -- the handle leaked if any write raised
    with open(file, "w") as fw:
        fw.write("#KprFunc 1.0 Parameters\n")
        fw.write("#Version: 1.0\n")
        fw.write("#By Chenwei Wang @HUST\n")
        fw.write("@param\tCode=K\tUp=" + str(left) + "\tDown=" + str(right) + "\n")
        fw.write("@AUC=" + str(MM_auc) + "\n")
        fw.write("@weight")
        for w in weight_coef:
            fw.write("\t" + str(w))
        fw.write("\n")

        for a in list_aa:
            fw.write(" " + a)
        fw.write("\n")

        for a1 in list_aa:
            fw.write(a1)
            for a2 in list_aa:
                aas1 = a1 + "_" + a2
                aas2 = a2 + "_" + a1
                score = 0.0
                if aas1 in dict_weight:
                    score = dict_weight[aas1]
                elif aas2 in dict_weight:
                    score = dict_weight[aas2]
                else:
                    # Pair missing from l_aas: keep 0.0 but flag it loudly.
                    print(aas1 + "wrong!")
                fw.write(" " + str(score))
            fw.write("\n")
|
||||
|
||||
def logistic_GPS(X: np.ndarray, Y: np.ndarray, PEP, type, turn):
    """Fit a logistic regression on peptide score features and report its AUC.

    A cross-validated fit (LogisticRegressionCV, scored by ROC-AUC) first
    selects the regularisation strength C; that C is then shrunk by a factor
    of 10**turn and a plain LogisticRegression is refit on all of X with it.

    Args:
        X: feature matrix, shape (n_samples, n_features).
            (fix: was annotated ``bytearray``, which is wrong -- callers
            pass numpy arrays built from the score lists.)
        Y: binary labels, shape (n_samples,).
        PEP: peptides aligned with X; accepted but not read here.
        type: stage tag, "WW" (position weights) or "MM" (matrix weights);
            informational only -- not read in this function.  NOTE(review):
            the name shadows the ``type`` builtin; kept for caller
            compatibility.
        turn: current outer training round; tightens regularisation as
            C * 10**(-turn).

    Returns:
        (list_coef, auc): the fitted coefficient vector (one weight per
        feature) and the ROC-AUC of the refit model.  NOTE: the AUC is
        computed on the training data itself, so it is optimistic.
    """
    solverchose = 'sag'
    # CV over candidate C values, maximising ROC-AUC.
    clscv = LogisticRegressionCV(max_iter=10000, cv=10, solver=solverchose, scoring='roc_auc')
    clscv.fit(X, Y)
    # Tighten regularisation progressively with each outer training round.
    regularization = clscv.C_[0] * (10 ** (-turn))
    print("C=" + str(regularization))
    cls = LogisticRegression(max_iter=10000, solver=solverchose, C=regularization)
    cls.fit(X, Y)
    list_coef = cls.coef_[0]
    predict_prob_x = cls.predict_proba(X)
    predict_x = predict_prob_x[:, 1]
    auc = roc_auc_score(Y, np.array(predict_x))
    print("AUC:" + str(auc))
    return list_coef, auc
|
||||
|
||||
# fix: guard the entry point so importing this module does not start training
if __name__ == "__main__":
    trainning()
|
||||
|
||||
|
||||
|
||||
|
||||
201
LICENSE
Normal file
201
LICENSE
Normal file
@@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
119
MAML.py
Normal file
119
MAML.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from sklearn.metrics import roc_auc_score
|
||||
from sklearn.model_selection import KFold,StratifiedKFold
|
||||
from keras import metrics
|
||||
from keras.models import Model
|
||||
import numpy as np
|
||||
import copy
|
||||
from keras import backend as K
|
||||
from keras import optimizers
|
||||
from keras.models import Sequential,load_model
|
||||
from keras.layers import Dense, Dropout
|
||||
from keras.callbacks import EarlyStopping,ModelCheckpoint
|
||||
import random
|
||||
from Tools import getMMScoreType,WeightAndMatrix
|
||||
def fetshot():
    """Few-shot training driver for functional-propionylation-site models.

    Reads known functional sites from ``functionsite.txt`` and all
    propionylation peptides from ``POS.txt``, splits the peptides into
    positives (known functional sites) and negatives (all others), then
    repeatedly trains the MAML-style DNN on the positives plus a freshly
    sampled negative subset until the negative pool is exhausted.

    Side effects: appends one line per round to ``MAML_AUCs.txt`` and
    writes model checkpoint files via ``dnn``.
    """
    length = 30   # flank length; peptides are length*2+1 residues wide
    nfold = 5     # cross-validation folds used by dnn()

    # Known functional sites, keyed as "<protein>\t<position>".
    with open(r"functionsite.txt", "r") as f1:
        funcinf = {line.strip() for line in f1}

    # Split peptides into functional (positive) and other (negative) sites.
    pos = []
    neg = []
    with open(r"POS.txt", "r") as f2:
        for line in f2:
            sp = line.strip().split("\t")
            pep = sp[0]
            site = sp[1] + "\t" + sp[2]
            if site in funcinf:
                pos.append(pep)
            else:
                neg.append(pep)
    print(len(pos))

    with open("MAML_AUCs.txt", "a") as fw:
        pos_size = len(pos)
        for a in range(10000):
            if len(neg) > pos_size * 6:
                # Draw a 1:5 pos:neg training subset and drop the drawn
                # peptides from the pool so later rounds see fresh negatives.
                new_neg = random.sample(neg, pos_size * 5)
                drawn = set(new_neg)
                neg = [p for p in neg if p not in drawn]
                print(len(neg))
                _train_round(pos, new_neg, length, nfold, fw, a)
            else:
                # Final round: use whatever negatives remain, then stop.
                _train_round(pos, neg, length, nfold, fw, a)
                break
        fw.flush()


def _train_round(pos, neg_subset, length, nfold, fw, a):
    """Score one pos/neg split, train the DNN, and log the round's AUC."""
    AAscores, l_aas, weight_coef, AAs = \
        WeightAndMatrix("traningout_best.txt")
    l_scores, l_type, peps = getMMScoreType(
        pos, neg_subset, AAscores, weight_coef, l_aas, AAs, length)
    X = np.array(l_scores)
    Y = np.array(l_type)
    PEP = np.array(peps)
    # [units, dropout, hidden layers, input dim] — consumed by dnn().
    parameter = [512, 0.2, 2, X.shape[1]]
    auc_all, best_model = dnn(X, Y, nfold, parameter, PEP, a)
    fw.write(str(a + 1) + "\tBest:" + "\t" + str(auc_all) + "\t" + str(best_model) + "\n")
    fw.flush()
|
||||
|
||||
def dnn(X, Y, nfold, parameter, PEP, a):
    """Fine-tune the pre-trained DNN with stratified k-fold cross-validation.

    For every fold the frozen-base model stored in ``original.model`` is
    re-loaded, fine-tuned on the training split with early stopping, and
    the best checkpoint is saved as ``"<a+1>_<fold>.model"``.

    Args:
        X, Y: feature matrix and binary labels.
        nfold: number of stratified folds.
        parameter: network hyper-parameters (kept for the caller's interface).
        PEP: peptide identifiers (kept for the caller's interface).
        a: round index, used to name checkpoint files.

    Returns:
        (auc_all, best_model): pooled AUC over all held-out folds and the
        1-based index of the fold with the highest individual AUC.
    """
    splitter = StratifiedKFold(n_splits=nfold)
    fold = 0
    best_auc = 0.0
    best_model = 0
    all_labels = []
    all_scores = []

    for train_idx, test_idx in splitter.split(X, Y):
        fold += 1
        print("dnn_" + str(fold))
        X_tr, X_te = X[train_idx], X[test_idx]
        Y_tr, Y_te = Y[train_idx], Y[test_idx]

        # Start from the pre-trained base and freeze its first six layers.
        model = load_model("original.model")
        for layer in model.layers[:6]:
            layer.trainable = False
        model.compile(optimizer=optimizers.Adam(lr=1e-3, decay=3e-5),
                      loss='binary_crossentropy',
                      metrics=[metrics.AUC(name="auc")])
        model.fit(X_tr, Y_tr, epochs=300, batch_size=8,
                  validation_data=(X_te, Y_te), verbose=1,
                  callbacks=[EarlyStopping(monitor="val_auc", mode="max",
                                           min_delta=0, patience=10),
                             ModelCheckpoint(str(a + 1) + "_" + str(fold) + '.model',
                                             monitor="val_auc", mode="max",
                                             save_best_only=True)])

        # Re-load the checkpointed (best) weights and score the held-out fold.
        model = load_model(str(a + 1) + "_" + str(fold) + ".model")
        fold_scores = model.predict(X_te)[:, 0]
        fold_auc = roc_auc_score(Y_te, fold_scores)
        if fold_auc > best_auc:
            best_auc = fold_auc
            best_model = fold
        all_labels.extend(Y_te)
        all_scores.extend(fold_scores)
        K.clear_session()

    auc_all = roc_auc_score(np.array(all_labels), np.array(all_scores))
    return auc_all, best_model
|
||||
|
||||
if __name__ == "__main__":
    # Run the few-shot training loop only when executed as a script.
    fetshot()
|
||||
39
README.md
Normal file
39
README.md
Normal file
@@ -0,0 +1,39 @@
|
||||
# KprFunc
|
||||
A hybrid-learning AI framework for the prediction of functional propionylation site
|
||||
|
||||
## The description of each source code
|
||||
### GPS 5.0M.py
|
||||
The position weight determination (PWD) and scoring matrix optimization (SMO) methods were adopted iteratively to generate the optimal position weights and similarity matrix
|
||||
### DNN_final.py
|
||||
A 4-layer DNN framework was implemented in Keras 2.4.3 (http://github.com/fchollet/keras) to generate the final model for the prediction of propionylation sites based on the parameters determined by GPS 5.0M.py
|
||||
### MAML.py
|
||||
A 4-layer DNN framework implemented with a MAML strategy to generate the model for the prediction of functional propionylation sites
|
||||
### Tools.py
|
||||
Supported methods for GPS 5.0M.py, DNN_final.py and MAML.py
|
||||
### demo
|
||||
A small dataset to demo the above codes, including the positive & negative datasets, the BLOSUM62 matrix, and the typical weights and models generated by GPS 5.0M.py.
|
||||
|
||||
## Software Requirements
|
||||
### OS Requirements
|
||||
Above codes have been tested on the following systems:
|
||||
Windows: Windows 7, Windows 10
|
||||
Linux: CentOS linux 7.8.2003
|
||||
### Hardware Requirements
|
||||
All codes and softwares could run on a "normal" desktop computer, no non-standard hardware is needed
|
||||
|
||||
## Installation guide
|
||||
All codes can run directly on a "normal" computer with Python 3.7.9 installed, no extra installation is required
|
||||
|
||||
## Instruction
|
||||
For users who want to run KprFunc on their own computer, you should first get the optimal position weights and similarity matrix using GPS 5.0M.py with the positive and negative datasets in /demo; the best output of GPS 5.0M.py will then be adopted by DNN_final.py to generate the models for the prediction of propionylation sites. Finally, the known functional propionylation sites contained in "functionsite" are taken as secondary positive data, while the other propionylation sites serve as negative data, to generate the models for the prediction of functional propionylation sites with MAML.py
|
||||
|
||||
## Additional information
|
||||
Expected run time depends on the hardware of your computer. In general, it will take about 1 hour to get the final models.
|
||||
## Contact
|
||||
Dr. Yu Xue: xueyu@hust.edu.cn
|
||||
Dr. Luoying Zhang: zhangluoying@hust.edu.cn
|
||||
Chenwei Wang: wangchenwei@hust.edu.cn
|
||||
Ke Shui: shuike@hust.edu.cn
|
||||
|
||||
|
||||
|
||||
180
Tools.py
Normal file
180
Tools.py
Normal file
@@ -0,0 +1,180 @@
|
||||
import numpy as np
|
||||
def blosum62():
    """Parse the BLOSUM62 matrix file in the working directory.

    The file has one row per residue: the residue letter followed by its
    substitution scores against every column residue, with columns in the
    same residue order as the rows.

    Returns:
        scores: dict mapping "X_Y" (one orientation per unordered pair,
            first-seen) to the substitution score.
        l_AAS: the "X_Y" keys in insertion order.
        AAs: residue letters in file order.
    """
    # Read the file once and close it (the original opened it twice and
    # never closed either handle).
    with open("BLOSUM62", "r") as f1:
        lines = f1.readlines()

    # First pass: collect the residue alphabet from the row labels.
    AAs = [line.split()[0] for line in lines]

    l_AAS = []
    scores = {}
    # Second pass: store each unordered residue pair once, keeping the
    # orientation in which it was first encountered.
    for num, line in enumerate(lines):
        sp = line.split()
        for i in range(1, len(sp)):
            score = float(sp[i])
            aas = AAs[num] + "_" + AAs[i - 1]
            aas2 = AAs[i - 1] + "_" + AAs[num]
            if aas not in scores and aas2 not in scores:
                l_AAS.append(aas)
                scores[aas] = score
    return scores, l_AAS, AAs
|
||||
|
||||
|
||||
def getWeightScoreType(pos, neg, matrix, AAs, length):
    """Compute per-position similarity score vectors for each peptide.

    For every window position and residue, the pooled similarity of that
    residue against all positive peptides at that position is precomputed;
    each peptide is then encoded as one score per position.  Positive
    peptides use leave-one-out pooling (their self-match is subtracted and
    the sum divided by ``len(pos) - 1``); negatives divide by ``len(pos)``.

    Args:
        pos: positive peptides, ``length * 2 + 1`` residues each.
        neg: negative peptides of the same window size.
        matrix: dict of pair scores keyed "X_Y" (one orientation per pair).
        AAs: residue alphabet.
        length: flank length (window is ``length * 2 + 1``).

    Returns:
        (l_scores, l_type, l_peps): score vectors, 1/0 labels, peptides.
    """
    window = length * 2 + 1

    # scores[i][j] = sum over positive peptides of matrix[AAs[j] vs the
    # residue at position i]; matrix keys are unordered "X_Y" pairs.
    scores = []
    for i in range(window):
        column = []
        for aa1 in AAs:
            total = 0.0
            for oth in pos:
                aa2 = oth[i:i + 1]
                key = aa1 + "_" + aa2
                if key in matrix:
                    total += matrix[key]
                else:
                    total += matrix[aa2 + "_" + aa1]
            column.append(total)
        scores.append(column)

    l_scores = []
    l_type = []
    l_peps = []

    # Positive peptides: leave-one-out average (exclude the self-match).
    for pep in pos:
        vec = []
        for i, aa in enumerate(pep):
            idx = AAs.index(aa)
            vec.append((scores[i][idx] - matrix[aa + "_" + aa]) / (len(pos) - 1))
        l_scores.append(vec)
        l_type.append(1)
        l_peps.append(pep)

    # Negative peptides: plain average over all positives.
    for pep in neg:
        vec = []
        for i, aa in enumerate(pep):
            idx = AAs.index(aa)
            vec.append(scores[i][idx] / len(pos))
        l_scores.append(vec)
        l_type.append(0)
        l_peps.append(pep)

    return l_scores, l_type, l_peps
|
||||
|
||||
def getMMScoreType(pos, neg, matrix, weights, l_aas, AAs, length):
    """Encode peptides as weighted per-residue-pair similarity vectors.

    For each window position ``i`` and residue ``aa1``, a vector over the
    residue pairs in ``l_aas`` is pooled across all positive peptides and
    scaled by the position weight ``weights[i]``.  Two variants are kept:
    a plain pooled sum for negatives, and a leave-one-out variant for
    positives where the self-pair term ``matrix[aa1_aa1] * weights[i]`` is
    removed so a positive peptide never counts its own match.

    Args:
        pos: positive peptides, ``length * 2 + 1`` residues each.
        neg: negative peptides of the same window size.
        matrix: dict of pair scores keyed "X_Y" (one orientation per pair).
        weights: per-position weights, length ``length * 2 + 1``.
        l_aas: ordered pair keys defining the output vector layout.
        AAs: residue alphabet.
        length: flank length (window is ``length * 2 + 1``).

    Returns:
        (l_scores, l_type, l_peps): per-peptide vectors (lists of floats of
        length ``len(l_aas)``), 1/0 labels, and the peptides themselves.
    """
    window = length * 2 + 1
    n_pairs = len(l_aas)
    # Precompute pair -> slot and residue -> index once instead of calling
    # list.index() inside the hot loops (was a linear scan per lookup).
    pair_index = {pair: idx for idx, pair in enumerate(l_aas)}
    aa_index = {aa: idx for idx, aa in enumerate(AAs)}

    scorespos = []  # scorespos[i][j]: positive (leave-one-out) vector
    scoresneg = []  # scoresneg[i][j]: negative (plain pooled) vector
    for i in range(window):
        w = weights[i]
        pos_col = []
        neg_col = []
        for aa1 in AAs:
            acc = [0.0] * n_pairs
            for oth in pos:
                aa2 = oth[i:i + 1]
                fwd = aa1 + "_" + aa2
                rev = aa2 + "_" + aa1
                if fwd in pair_index:
                    acc[pair_index[fwd]] += matrix[fwd] * w
                elif rev in pair_index:
                    acc[pair_index[rev]] += matrix[rev] * w
            neg_vec = np.array(acc)
            # Drop the self-pair contribution for the positive variant.
            self_pair = aa1 + "_" + aa1
            acc[pair_index[self_pair]] -= matrix[self_pair] * w
            pos_vec = np.array(acc)
            pos_col.append(pos_vec)
            neg_col.append(neg_vec)
        scorespos.append(pos_col)
        scoresneg.append(neg_col)

    l_scores = []
    l_type = []
    l_peps = []

    # Positives: leave-one-out average over the other positive peptides.
    for pep in pos:
        total = np.zeros(n_pairs)
        for i in range(len(pep)):
            total += scorespos[i][aa_index[pep[i:i + 1]]]
        l_scores.append((total / (len(pos) - 1)).tolist())
        l_type.append(1)
        l_peps.append(pep)

    # Negatives: plain average over all positive peptides.
    for pep in neg:
        total = np.zeros(n_pairs)
        for i in range(len(pep)):
            total += scoresneg[i][aa_index[pep[i:i + 1]]]
        l_scores.append((total / len(pos)).tolist())
        l_type.append(0)
        l_peps.append(pep)

    return l_scores, l_type, l_peps
|
||||
|
||||
def getArray(l_aas):
    """Return a float64 zero vector with one slot per residue pair."""
    return np.array([0.0 for _ in l_aas])
|
||||
|
||||
def WeightAndMatrix(path):
    """Parse a GPS training-output file into weights and a similarity matrix.

    The file contains a tab-separated "@weight" line, a residue header line
    beginning with " A " (note the leading space), and matrix rows after
    that header.  Matrix rows are stored once per unordered residue pair in
    first-seen orientation, like ``blosum62``.

    Args:
        path: path to the training output file (e.g. "traningout_best.txt").

    Returns:
        (scores, l_AAS, weights, AAs): pair-score dict, ordered pair keys,
        position weights, and residue alphabet.
    """
    # Read the file once and close it (the original opened it twice and
    # never closed either handle).
    with open(path, "r") as f1:
        lines = f1.readlines()

    weights = []
    AAs = []
    for line in lines:
        if line.startswith(" A "):
            # Residue alphabet header.
            AAs.extend(line.strip().split())
        if line.startswith("@weight"):
            sp = line.strip().split("\t")
            weights.extend(float(w) for w in sp[1:])

    l_AAS = []
    scores = {}
    num = 0
    in_matrix = False  # rows after the " A " header line are matrix rows
    for line in lines:
        if in_matrix:
            sp = line.strip().split()
            for i in range(1, len(sp)):
                score = float(sp[i])
                aas = AAs[num] + "_" + AAs[i - 1]
                aas2 = AAs[i - 1] + "_" + AAs[num]
                if aas not in l_AAS and aas2 not in l_AAS:
                    l_AAS.append(aas)
                    scores[aas] = score
            num += 1
        if line.startswith(" A "):
            in_matrix = True
    return scores, l_AAS, weights, AAs
|
||||
24
demo/BLOSUM62
Normal file
24
demo/BLOSUM62
Normal file
@@ -0,0 +1,24 @@
|
||||
A 4 -1 -2 -2 0 -1 -1 0 -2 -1 -1 -1 -1 -2 -1 1 0 -3 -2 0 -2 -1 0 -4
|
||||
R -1 5 0 -2 -3 1 0 -2 0 -3 -2 2 -1 -3 -2 -1 -1 -3 -2 -3 -1 0 -1 -4
|
||||
N -2 0 6 1 -3 0 0 0 1 -3 -3 0 -2 -3 -2 1 0 -4 -2 -3 3 0 -1 -4
|
||||
D -2 -2 1 6 -3 0 2 -1 -1 -3 -4 -1 -3 -3 -1 0 -1 -4 -3 -3 4 1 -1 -4
|
||||
C 0 -3 -3 -3 9 -3 -4 -3 -3 -1 -1 -3 -1 -2 -3 -1 -1 -2 -2 -1 -3 -3 -2 -4
|
||||
Q -1 1 0 0 -3 5 2 -2 0 -3 -2 1 0 -3 -1 0 -1 -2 -1 -2 0 3 -1 -4
|
||||
E -1 0 0 2 -4 2 5 -2 0 -3 -3 1 -2 -3 -1 0 -1 -3 -2 -2 1 4 -1 -4
|
||||
G 0 -2 0 -1 -3 -2 -2 6 -2 -4 -4 -2 -3 -3 -2 0 -2 -2 -3 -3 -1 -2 -1 -4
|
||||
H -2 0 1 -1 -3 0 0 -2 8 -3 -3 -1 -2 -1 -2 -1 -2 -2 2 -3 0 0 -1 -4
|
||||
I -1 -3 -3 -3 -1 -3 -3 -4 -3 4 2 -3 1 0 -3 -2 -1 -3 -1 3 -3 -3 -1 -4
|
||||
L -1 -2 -3 -4 -1 -2 -3 -4 -3 2 4 -2 2 0 -3 -2 -1 -2 -1 1 -4 -3 -1 -4
|
||||
K -1 2 0 -1 -3 1 1 -2 -1 -3 -2 5 -1 -3 -1 0 -1 -3 -2 -2 0 1 -1 -4
|
||||
M -1 -1 -2 -3 -1 0 -2 -3 -2 1 2 -1 5 0 -2 -1 -1 -1 -1 1 -3 -1 -1 -4
|
||||
F -2 -3 -3 -3 -2 -3 -3 -3 -1 0 0 -3 0 6 -4 -2 -2 1 3 -1 -3 -3 -1 -4
|
||||
P -1 -2 -2 -1 -3 -1 -1 -2 -2 -3 -3 -1 -2 -4 7 -1 -1 -4 -3 -2 -2 -1 -2 -4
|
||||
S 1 -1 1 0 -1 0 0 0 -1 -2 -2 0 -1 -2 -1 4 1 -3 -2 -2 0 0 0 -4
|
||||
T 0 -1 0 -1 -1 -1 -1 -2 -2 -1 -1 -1 -1 -2 -1 1 5 -2 -2 0 -1 -1 0 -4
|
||||
W -3 -3 -4 -4 -2 -2 -3 -2 -2 -3 -2 -3 -1 1 -4 -3 -2 11 2 -3 -4 -3 -2 -4
|
||||
Y -2 -2 -2 -3 -2 -1 -2 -3 2 -1 -1 -2 -1 3 -3 -2 -2 2 7 -1 -3 -2 -1 -4
|
||||
V 0 -3 -3 -3 -1 -2 -2 -3 -3 3 1 -2 1 -1 -2 -2 0 -3 -1 4 -3 -2 -1 -4
|
||||
B -2 -1 3 4 -3 0 1 -1 0 -3 -4 0 -3 -3 -2 0 -1 -4 -3 -3 4 1 -1 -4
|
||||
Z -1 0 0 1 -3 3 4 -2 0 -3 -3 1 -1 -3 -1 0 -1 -3 -2 -2 1 4 -1 -4
|
||||
X 0 -1 -1 -1 -2 -1 -1 -1 -1 -1 -1 -1 -1 -1 -2 0 0 -2 -1 -1 -1 -1 -1 -4
|
||||
* -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 1
|
||||
20213
demo/NEG.txt
Normal file
20213
demo/NEG.txt
Normal file
File diff suppressed because it is too large
Load Diff
1707
demo/POS.txt
Normal file
1707
demo/POS.txt
Normal file
File diff suppressed because it is too large
Load Diff
13
demo/functionsite
Normal file
13
demo/functionsite
Normal file
@@ -0,0 +1,13 @@
|
||||
Q09472 1554
|
||||
Q09472 1555
|
||||
Q09472 1558
|
||||
Q09472 1560
|
||||
P62806 17
|
||||
P08515 11
|
||||
P08515 180
|
||||
P08515 181
|
||||
P68431 24
|
||||
P61830 24
|
||||
P61830 57
|
||||
P55912 592
|
||||
Q6P980 132
|
||||
31
demo/traningout_best.txt
Normal file
31
demo/traningout_best.txt
Normal file
@@ -0,0 +1,31 @@
|
||||
#KprFunc 1.0 Parameters
|
||||
#Version: 1.0
|
||||
#By Chenwei Wang @HUST
|
||||
@param Code=K Up=10 Down=10
|
||||
@AUC=0.6965301379789716
|
||||
@weight -0.018689462273396077 -0.018744888854332252 0.00294702479025041 -0.018721678351478477 -0.01800558569261358 -0.008358276798366189 -0.00907753373048285 0.005118357005704048 0.004476784921562541 0.001917935779225377 -0.012826394014935651 -0.010954111727165613 0.011940789952102696 0.012207912112939114 -0.002808208285854016 -0.0009658925061748041 0.0011404113910593882 0.012443692092668724 0.013060442658975291 -0.019556766645198398 0.004787683132730011 0.0009189775531212159 0.019910902066826296 0.025785778782636607 0.02280656388019064 0.03548725141382741 0.020884032586196186 0.03378028343513734 -0.006675521197197722 0.01805101776679797 -2.6856017654315826e-05 0.09983429259831926 -0.005350043836287985 0.035294469712788816 0.04530583290961643 0.030707674319242206 0.03538658589192208 0.04989596526751631 0.03753221601130363 0.004211112881363887 0.009338818135657265 0.009921848303412887 0.004476961340685209 -0.00398459142499561 0.008164531072787896 0.007821462401426152 0.010531762609291787 -0.006370575721155825 0.015161718270273355 0.006265233004570455 -0.007496366116987143 0.011051331297383008 -0.0026721724621365022 0.007655540505946334 -0.028687938189757952 -0.0241422469916008 -0.006096023243209063 -0.025903648378292152 -0.016341332872108923 -0.008695073272698103 -0.028922372950143112
|
||||
A R N D C Q E G H I L K M F P S T W Y V B Z X *
|
||||
A -79.21686178125424 -65.92132479850945 -79.98837212206928 92.62139439026826 0.0 -21.673970536667984 -38.81557522754138 0.0 74.70048538908974 22.26925623514351 32.55176650198224 16.260618898029154 0.844571769278751 229.2800946549054 -120.58516522445346 9.831165105980736 0.0 -30.11430460434257 157.99685552930555 0.0 -0.0 -0.0 0.0 -72.28615205702248
|
||||
R -65.92132479850945 30.09743981271335 0.0 -127.02603043821476 -47.51118830327828 -8.61542160592377 0.0 32.14691537634928 0.0 99.03478796565255 -15.398450570193724 184.41772332322412 -19.513532542042633 59.38657875242247 -97.91990790887277 7.909847690097471 -120.73934856938855 259.1479731984058 -57.80617886251458 3.15784189370199 -0.0 0.0 -2.2430271649812012 -83.17589601530933
|
||||
N -79.98837212206928 0.0 103.35925354177455 20.41222620657458 113.62257738361774 0.0 0.0 0.0 -22.913232019703905 39.41527892555478 -16.34179104717671 0.0 66.55298009842075 -115.03738702753992 62.38574308704257 -59.70824767428982 0.0 -126.64758255182352 -29.341962910726068 -27.96035539921929 0.0 0.0 -0.23799812081034613 59.40017704226467
|
||||
D 92.62139439026826 -127.02603043821476 20.41222620657458 -166.9984768142573 -298.0180353277666 0.0 112.39046530449396 141.25195148672574 78.7055610292056 60.9359525727567 -107.26096533357821 7.14594395804312 125.22062150229573 -101.84047524195205 33.01074173760774 0.0 -50.70922225927597 -272.499852043676 -378.0551730872354 0.28823662800545113 0.0 0.0 -2.357496950620939 13.483413642880151
|
||||
C 0.0 -47.51118830327828 113.62257738361774 -298.0180353277666 492.59096915638327 7.414999786573917 349.82196783846547 -83.29268620399931 39.09460958221237 -7.026576196714204 -23.517740862810317 -108.49426532779171 -6.140783256356095 58.485369954483964 -153.2385067054147 -22.131647853054037 12.81319193625693 -62.90926556444103 -21.46768802400408 -95.7555256521471 -0.0 -0.0 1.0433529656873173 15.170410091971686
|
||||
Q -21.673970536667984 -8.61542160592377 0.0 0.0 7.414999786573917 -128.29372461437015 72.73437533733868 36.57712103211332 0.0 94.4442736338702 35.11394163994572 -79.45814968467835 0.0 -312.90370011939405 65.67191337898352 0.0 53.113460301023835 64.28326791769194 -42.34239872981483 -91.41211386862538 0.0 0.0 -0.7884247508056094 50.904456241584874
|
||||
E -38.81557522754138 0.0 0.0 112.39046530449396 349.82196783846547 72.73437533733868 -56.71503509361014 30.319768311411483 0.0 -202.11092180585268 -41.28978478267595 -54.40717794468668 -83.33700073981204 125.85623503422804 151.79293168921117 0.0 -44.4578980238522 -45.84358702785538 126.28513741886557 -0.9960624102945992 0.0 0.0 -3.8888696091556834 -49.22440112687386
|
||||
G 0.0 32.14691537634928 0.0 141.25195148672574 -83.29268620399931 36.57712103211332 30.319768311411483 -403.3729867646581 -349.72183059949555 120.69236333382085 19.05212621188991 -26.126972634359213 6.469825848552872 110.3068590225507 -34.51877747119401 0.0 307.56378189169965 -204.99307699784714 -124.48622953813785 23.5754978183186 -0.0 -0.0 -0.5676051387596681 -37.5380905290468
|
||||
H 74.70048538908974 0.0 -22.913232019703905 78.7055610292056 39.09460958221237 0.0 0.0 -349.72183059949555 -199.5675794762037 64.11038662625575 -2.1286239816855694 55.507429663136115 122.77702460075692 1.3510533585612652 103.6913785579249 27.327718354460142 -12.014242019512894 92.24661803309353 -16.577928267940926 12.72213814812037 0.0 0.0 -0.4261959470553026 -30.946696207802955
|
||||
I 22.26925623514351 99.03478796565255 39.41527892555478 60.9359525727567 -7.026576196714204 94.4442736338702 -202.11092180585268 120.69236333382085 64.11038662625575 -337.9806468653485 -119.12646972648812 -81.45732610916914 -64.17912608236941 0.0 90.00817181318124 81.13242575369226 -22.442321947035097 -96.40634303220645 -19.604999555912414 141.40317709321485 -0.0 -0.0 -0.3129748998700986 -0.7337125394934492
|
||||
L 32.55176650198224 -15.398450570193724 -16.34179104717671 -107.26096533357821 -23.517740862810317 35.11394163994572 -41.28978478267595 19.05212621188991 -2.1286239816855694 -119.12646972648812 56.051759657080964 55.116747195509035 -35.62343511963436 0.0 23.335349459352063 -76.08144081690865 -46.27414973333793 -214.5806513794929 -111.90357389854505 32.99635456881741 -0.0 -0.0 0.2916698996042472 56.557383414628916
|
||||
K 16.260618898029154 184.41772332322412 0.0 7.14594395804312 -108.49426532779171 -79.45814968467835 -54.40717794468668 -26.126972634359213 55.507429663136115 -81.45732610916914 55.116747195509035 -78.75268382796665 40.13095072867636 -157.52814111846226 36.24737383146892 0.0 -17.474921927682338 102.32952605615694 114.22615700194471 -51.71780708741426 0.0 0.0 -2.7164115917493046 -9.690341603582475
|
||||
M 0.844571769278751 -19.513532542042633 66.55298009842075 125.22062150229573 -6.140783256356095 0.0 -83.33700073981204 6.469825848552872 122.77702460075692 -64.17912608236941 -35.62343511963436 40.13095072867636 -41.23204113299315 0.0 -86.65262964865319 -19.216205065919947 -43.64623499562159 5.846479010920261 -24.87394221794222 -18.693380838323208 -0.0 -0.0 -0.05397067529379592 82.24793446282418
|
||||
F 229.2800946549054 59.38657875242247 -115.03738702753992 -101.84047524195205 58.485369954483964 -312.90370011939405 125.85623503422804 110.3068590225507 1.3510533585612652 0.0 0.0 -157.52814111846226 0.0 -612.6517054598662 -118.16563738729351 187.24293780426152 -126.71638003680711 21.984863606707254 74.84992027204291 -77.34998690903478 -0.0 -0.0 -0.25165023465121844 -8.227278957779806
|
||||
P -120.58516522445346 -97.91990790887277 62.38574308704257 33.01074173760774 -153.2385067054147 65.67191337898352 151.79293168921117 -34.51877747119401 103.6913785579249 90.00817181318124 23.335349459352063 36.24737383146892 -86.65262964865319 -118.16563738729351 -220.77117992791 31.203741262404847 -23.21406824589217 180.17267210895102 143.6005722168128 -70.56658554386473 -0.0 -0.0 -3.4159732225792228 12.238138416341974
|
||||
S 9.831165105980736 7.909847690097471 -59.70824767428982 0.0 -22.131647853054037 0.0 0.0 0.0 27.327718354460142 81.13242575369226 -76.08144081690865 0.0 -19.216205065919947 187.24293780426152 31.203741262404847 -248.04389065721458 50.00369037933264 20.65980432304682 -145.01776399359676 -62.43589859119674 0.0 0.0 0.0 102.60166838915332
|
||||
T 0.0 -120.73934856938855 0.0 -50.70922225927597 12.81319193625693 53.113460301023835 -44.4578980238522 307.56378189169965 -12.014242019512894 -22.442321947035097 -46.27414973333793 -17.474921927682338 -43.64623499562159 -126.71638003680711 -23.21406824589217 50.00369037933264 -121.7444087876097 -6.535735822302903 119.90609700147013 0.0 -0.0 -0.0 0.0 -41.644490043119156
|
||||
W -30.11430460434257 259.1479731984058 -126.64758255182352 -272.499852043676 -62.90926556444103 64.28326791769194 -45.84358702785538 -204.99307699784714 92.24661803309353 -96.40634303220645 -214.5806513794929 102.32952605615694 5.846479010920261 21.984863606707254 180.17267210895102 20.65980432304682 -6.535735822302903 -1021.4479806205242 57.811761195943006 198.61423085950054 -0.0 -0.0 0.47845888878625414 102.66355520252834
|
||||
Y 157.99685552930555 -57.80617886251458 -29.341962910726068 -378.0551730872354 -21.46768802400408 -42.34239872981483 126.28513741886557 -124.48622953813785 -16.577928267940926 -19.604999555912414 -111.90357389854505 114.22615700194471 -24.87394221794222 74.84992027204291 143.6005722168128 -145.01776399359676 119.90609700147013 57.811761195943006 -447.9360470667096 39.1163467109986 -0.0 -0.0 -0.15500573992694627 89.48920801856288
|
||||
V 0.0 3.15784189370199 -27.96035539921929 0.28823662800545113 -95.7555256521471 -91.41211386862538 -0.9960624102945992 23.5754978183186 12.72213814812037 141.40317709321485 32.99635456881741 -51.71780708741426 -18.693380838323208 -77.34998690903478 -70.56658554386473 -62.43589859119674 0.0 198.61423085950054 39.1163467109986 -36.00803636365247 -0.0 -0.0 -0.4101578746084907 -55.158673993070664
|
||||
B -0.0 -0.0 0.0 0.0 -0.0 0.0 0.0 -0.0 0.0 -0.0 -0.0 0.0 -0.0 -0.0 -0.0 0.0 -0.0 -0.0 -0.0 -0.0 0.0 0.0 -0.0 -0.0
|
||||
Z -0.0 0.0 0.0 0.0 -0.0 0.0 0.0 -0.0 0.0 -0.0 -0.0 0.0 -0.0 -0.0 -0.0 0.0 -0.0 -0.0 -0.0 -0.0 0.0 0.0 -0.0 -0.0
|
||||
X 0.0 -2.2430271649812012 -0.23799812081034613 -2.357496950620939 1.0433529656873173 -0.7884247508056094 -3.8888696091556834 -0.5676051387596681 -0.4261959470553026 -0.3129748998700986 0.2916698996042472 -2.7164115917493046 -0.05397067529379592 -0.25165023465121844 -3.4159732225792228 0.0 0.0 0.47845888878625414 -0.15500573992694627 -0.4101578746084907 -0.0 -0.0 -0.0 79.82173475395903
|
||||
* -72.28615205702248 -83.17589601530933 59.40017704226467 13.483413642880151 15.170410091971686 50.904456241584874 -49.22440112687386 -37.5380905290468 -30.946696207802955 -0.7337125394934492 56.557383414628916 -9.690341603582475 82.24793446282418 -8.227278957779806 12.238138416341974 102.60166838915332 -41.644490043119156 102.66355520252834 89.48920801856288 -55.158673993070664 -0.0 -0.0 79.82173475395903 -29.027882579458264
|
||||
22
docker/docker-compose.pytorch.yml
Normal file
22
docker/docker-compose.pytorch.yml
Normal file
@@ -0,0 +1,22 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
pytorch-training:
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: docker/Dockerfile.pytorch
|
||||
volumes:
|
||||
- ../:/app
|
||||
- ../data:/app/data
|
||||
- ../models:/app/models
|
||||
environment:
|
||||
- PYTHONUNBUFFERED=1
|
||||
- CUDA_VISIBLE_DEVICES=0
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
shm_size: '8gb'
|
||||
16
docker/docker-compose.yml
Normal file
16
docker/docker-compose.yml
Normal file
@@ -0,0 +1,16 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
ml-training:
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: docker/Dockerfile
|
||||
volumes:
|
||||
- ../:/app
|
||||
environment:
|
||||
- PYTHONUNBUFFERED=1
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- capabilities: [gpu]
|
||||
158
rebuildinpython/README.md
Normal file
158
rebuildinpython/README.md
Normal file
@@ -0,0 +1,158 @@
|
||||
## Andromeda得分计算
|
||||
|
||||
`Spec/Pscore.cs` 中实现了Andromeda算法的得分计算,包括匹配峰数、质量匹配、和得分的归一化。Score 函数是Andromeda算法的核心。
|
||||
|
||||
Andromeda算法的流程包括以下步骤:
|
||||
|
||||
1. 输入准备:输入理论肽段质量数组和实验质谱数据,包括实验峰的质量和强度。
|
||||
|
||||
2. 修饰组合生成:根据修饰位置和修饰类型生成所有可能的组合,例如可能的修饰位置上分别有磷酸化、乙酰化等修饰。
|
||||
|
||||
3. 匹配计算:对于每个修饰组合的理论质量,与实验质谱中的峰质量进行匹配。根据指定的容差(Da或ppm),计算匹配的峰数
|
||||
|
||||
4. 得分计算:基于匹配的峰数 $k$ 和总峰数 $n$,计算每个组合的得分。得分公式如下:
|
||||
$$
|
||||
s=-k\cdot\ln(p)-(n-k)\cdot\ln(1-p)-\mathrm{Factln}(n)+\mathrm{Factln}(k)+\mathrm{Factln}(n-k)
|
||||
$$
|
||||
缩放:$$\mathrm{Score}=10\cdot s\cdot\log_{10}(e)=\frac{10\cdot s}{\ln(10)}$$
|
||||
|
||||
|
||||
第一项: $-k\cdot\ln(p)$
|
||||
|
||||
含义:该项计算了匹配的理论峰所贡献的得分
|
||||
|
||||
正负号原因:因为我们希望匹配峰增加得分,而匹配概率 $p$ 通常是小值(如 $p=0.06$),因此 $ln(p)$ 是负值,加上负号后整体是正值,表示匹配的理论峰对总得分的正向贡献。
|
||||
|
||||
原因:我们希望与实验数据匹配的理论峰越多,总得分越高。这里的 $p$ 是匹配概率的自然对数(一般设定为小值,如 $p=0.06$ ),用来降低匹配的理论峰数对得分的影响。
|
||||
|
||||
逻辑:每个匹配的理论峰都会使得得分增加一个 $-\ln(p)$ 的权重,因此随着匹配数 $k$ 的增加,该项使得得分增加。
|
||||
|
||||
第二项: $-(n-k)\cdot\ln(1-p)$
|
||||
|
||||
含义:该项计算了不匹配的理论峰对总得分的贡献
|
||||
|
||||
正负号原因:因为不匹配峰会减少得分,$ln(1−p)$ 是负值($1−p$ 接近 1),乘以负号后为正值,表示不匹配的理论峰对得分的负向影响。
|
||||
|
||||
原因:我们需要在得分计算中考虑到不匹配的峰数量,因为在实验数据中,非匹配峰的存在会降低总得分。这里的 $1-p$ 表示不匹配的概率。
|
||||
|
||||
逻辑:每个未匹配的峰会导致得分降低一个 $-\ln(1-p)$ 的权重。因此,随着未匹配数 $n-k$ 的增加,该项会进一步降低得分。
|
||||
|
||||
第三项:$-Factln(n)$
|
||||
|
||||
含义:该项是对 $n$ 的阶乘的对数(或称为组合项),用于计算总的排列数
|
||||
|
||||
正负号原因:这项表示理论峰的总排列数。由于 $Factln(n)$ 是常数,它是所有组合的基数,减去这一项是为了防止总得分因总排列数的增加而不合理地被放大。
|
||||
|
||||
原因:该项是组合学的部分,表示从 $n$ 个理论峰中选择 $k$ 个匹配峰的总排列数的一部分。
|
||||
|
||||
逻辑:该项是固定的,因此对不同组合之间的相对得分影响不大,但保证了计算的完整性,并在数值上稳定得分计算。
|
||||
|
||||
### 为什么要引入组合学中的对数项?
|
||||
|
||||
> 在Andromeda算法中,得分不仅仅取决于匹配峰的数量,还取决于匹配峰的组合方式。这些对数阶乘项 $(\mathrm{Factln}(n)\mathrm{、Factln}(k)\text{ 和 Factln}(n-k))$ 的引入是为了计算所有可能的组合的对数得分,这在统计学和概率论中被广泛用于处理多项式分布和二项式分布的问题。
|
||||
>
|
||||
> $\mathrm{Factln}(x)=\ln(x!)=\ln(1)+\ln(2)+\ldots+\ln(x)$
|
||||
|
||||
### 为什么要这样计算?
|
||||
|
||||
在质谱匹配中,我们不仅关心匹配数,还关心匹配的组合方式。引入组合项是为了确保在计算匹配和不匹配对得分的贡献时,能够体现出组合的影响。简单来说,它在计算时考虑了所有可能的匹配和不匹配组合,使得得分更加符合实际的概率分布情况。
|
||||
|
||||
第四项:Factln $(k)$
|
||||
|
||||
含义:该项是 $k$ 的阶乘的对数
|
||||
|
||||
正负号原因:匹配峰对得分是正向贡献,因此这一项为正值,表示匹配组合的数目会提升得分。
|
||||
|
||||
原因:这是组合学中的一部分,表示从 $k$ 个匹配峰中选择匹配组合数
|
||||
|
||||
逻辑:该项配合其余项用于组合匹配数的计算,保证得分计算的平衡性。
|
||||
|
||||
第五项:Factln $(n-k)$
|
||||
|
||||
含义:这一项是未匹配峰数 $n-k$ 的阶乘的对数
|
||||
|
||||
原因:这一项的引入同样是为了计入组合方式的影响。对于未匹配峰来说, $n-k$ 个非匹配峰的排列方式也会影响整体得分的分布。这项确保了非匹配峰的组合方式也对得分有影响,使得得分公式能够更加全面地反映匹配和非匹配峰的概率分布。
|
||||
|
||||
逻辑:同样地,该项也是用于组合计算中,保证了不匹配部分在得分计算中的平衡性。
|
||||
|
||||
缩放: $\mathrm{Score}=\frac{10\cdot s}{\ln(10)}$
|
||||
|
||||
含义:最终得分 $s$ 乘以 $\log_{10}(e)$(即除以 $\ln(10)$)并乘以10,使得得分在数值上更易读且更符合人类认知。
|
||||
|
||||
原因:缩放将自然对数的结果转换为常见的对数刻度(以10为底),便于与其他得分系统对比。
|
||||
|
||||
逻辑:这种缩放通常会将得分缩放到0-100范围内,便于质谱研究人员解读和分析。
|
||||
|
||||
### 为什么第三项要减,第四和第五项要加?
|
||||
|
||||
1.组合学中的归一化:
|
||||
|
||||
Factln(n)是所有理论峰的排列总数。减去这一项是为了归一化得分,确保匹配峰的贡献是相对于总峰数的,而不是绝对的。如果没有减去 $\mathrm{Factln}(n)$,当理论峰总数 $n$ 较大时,得分会被夸大。
|
||||
|
||||
2.匹配和未匹配的相对影响
|
||||
|
||||
匹配峰的贡献是正向的,因此 $\mathrm{Factln}(k)$ 为正不匹配峰的贡献也是组合学的一部分,但对得分的贡献是中性的(因为其本质是计算所有可能的排列数),所以它的符号也为正。
|
||||
|
||||
|
||||
$k$:匹配的峰数。
|
||||
|
||||
$n$:理论质量数和匹配数的最大值。
|
||||
|
||||
$\mathrm{Factln}\left(x\right)$:是阶乘的对数,避免大数的直接计算。
|
||||
|
||||
$p$:理论峰匹配到实验峰的概率。
|
||||
|
||||
|
||||
|
||||
5. 归一化:对于所有修饰组合的得分,进行归一化计算,即将所有组合得分的和作为分母,得到每个组合的局部化概率(LP分数)。
|
||||
|
||||
6. 输出:最终输出每个修饰组合的得分以及其对应的局部化概率。
|
||||
|
||||
## MsmsHit.cs 和 Pscore.cs 文件中实现的主要算法
|
||||
|
||||
MsmsHit.cs
|
||||
|
||||
修饰组合的生成:在质谱分析中,肽段上可能存在多个修饰(如丙酰化或磷酸化),该文件中包含生成修饰组合的算法。这一功能通过修饰状态对象 PeptideModificationState 和组合生成方法实现。
|
||||
|
||||
Silac指数计算:涉及稳定同位素标记氨基酸(SILAC)数据处理的算法,用于计算SILAC标签。
|
||||
|
||||
峰强度、质量差和注释的提取:包含计算不同修饰状态的质量与实验质谱数据匹配的注释和强度的处理。
|
||||
|
||||
序列质量计算:使用修饰后的肽段序列计算单一的“理论质量”值,用于与实验质量匹配。
|
||||
|
||||
计算得分的核心部分:该文件包含对PScore算法的调用,用于计算质谱峰的匹配得分和修饰位置的局部化概率。
|
||||
|
||||
Pscore.cs
|
||||
|
||||
Andromeda得分计算:Pscore.cs 中实现了Andromeda算法的得分计算,包括匹配峰数、质量匹配、和得分的归一化。Score 函数是Andromeda算法的核心。
|
||||
|
||||
局部化概率(LP)分数计算:基于Andromeda得分,计算特定修饰组合在多个修饰位点的概率分布,即局部化概率。
|
||||
|
||||
匹配数计算:通过 CountMatches 函数,计算理论质量与实验峰的匹配数。
|
||||
|
||||
修饰组合生成:通过 ApplyVariableModificationsFixedNumbers 等函数生成所有可能的修饰组合,并计算这些组合的Andromeda得分。
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
## LP 分数介绍
|
||||
|
||||
代码逻辑位于:Pscore.cs
|
||||
|
||||
局部化概率 (LP) 是用来衡量翻译后修饰 (PTM),例如磷酸化位点,在肽段序列中定位可靠性的指标。
|
||||
|
||||
严格复刻 MsmsHit.cs 和 Pscore.cs 的关键是细节的还原,包括修饰生成逻辑、匹配容差、得分公式中的缩放方式等。
|
||||
|
||||
## 单位介绍
|
||||
|
||||
### Da 或 ppm 单位介绍
|
||||
|
||||
Da(道尔顿):是一种质量单位,主要用于分子和原子的质量。1 Da 等于1/12的碳-12原子质量(约为 1.66053906660 × 10⁻²⁷ 千克)。在质谱中,Da 表示质量差,例如某个肽段或离子的理论质量和实验质量之间的差异,以绝对数值表示。
|
||||
|
||||
ppm(百万分率):是一种相对单位,用于表示误差或偏差的相对大小。
|
||||
|
||||
在质谱中,ppm表示理论质量和实验质量之间的差异相对于理论质量的比例。例如,如果理论质量为 1000 Da,容差为 10 ppm,则允许的误差范围是 1000 × 10⁻⁶ = 0.01 Da。
|
||||
35
rebuildinpython/andromeda.py
Normal file
35
rebuildinpython/andromeda.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import numpy as np
|
||||
from scipy.special import gammaln # gammaln(x) 计算 ln(x!) 更为高效
|
||||
|
||||
def andromeda_score(k, n, p=0.06):
    """Compute the Andromeda peptide-spectrum match score.

    Implements
    ``s = -k*ln(p) - (n-k)*ln(1-p) - ln(n!) + ln(k!) + ln((n-k)!)``,
    then rescales to a decibel-like log10 score: ``10 * s / ln(10)``.

    Args:
        k: number of matched theoretical peaks.
        n: total number of theoretical peaks.
        p: per-peak match probability (default 0.06).

    Returns:
        The scaled Andromeda score.
    """
    # Contributions of matched and unmatched peaks.
    matched_term = -k * np.log(p)
    unmatched_term = -(n - k) * np.log(1 - p)

    # Combinatorial terms; gammaln(x + 1) == ln(x!).
    ln_fact_n = gammaln(n + 1)
    ln_fact_k = gammaln(k + 1)
    ln_fact_n_k = gammaln(n - k + 1)

    raw = matched_term + unmatched_term - ln_fact_n + ln_fact_k + ln_fact_n_k

    # Convert from natural log to the log10-based score scale.
    return 10.0 * raw / np.log(10)
|
||||
|
||||
if __name__ == "__main__":
    # Example: 2 of 5 theoretical peaks matched.  The original computed
    # this on import and discarded the result; print it under a main guard.
    print(andromeda_score(2, 5))
|
||||
40
rebuildinpython/lp_score.py
Normal file
40
rebuildinpython/lp_score.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import numpy as np
|
||||
from andromeda import andromeda_score
|
||||
|
||||
def lp_score(masses, spec_masses, spec_intensities, tol=0.1, p=0.06):
    """Compute the localization-probability (LP) value for one candidate.

    Counts how many theoretical peaks have an experimental peak within
    ``tol``, converts the match count to an Andromeda score, and maps the
    (shifted) score back to probability space.

    Parameters
    ----------
    masses : sequence of float
        Theoretical fragment masses.
    spec_masses : sequence of float
        Experimental peak masses (must be non-empty).
    spec_intensities : sequence of float
        Experimental peak intensities.  Currently unused; kept for
        interface compatibility — presumably reserved for intensity
        weighting (TODO confirm intended use).
    tol : float, optional
        Absolute mass tolerance for declaring a peak match.
    p : float, optional
        Per-peak random match probability forwarded to ``andromeda_score``.

    Returns
    -------
    float
        The unnormalized LP value for this candidate.  Normalization
        across competing modification placements must be done by the
        caller, which holds all candidate scores.
    """
    spec = np.asarray(spec_masses, dtype=float)

    # Count theoretical peaks whose nearest experimental peak lies within tol.
    k = sum(1 for mass in masses if np.abs(spec - mass).min() <= tol)

    # Andromeda score from k matches out of n theoretical peaks.
    n = len(masses)
    base_score = andromeda_score(k, n, p)

    # Map the shifted score back from the -10*log10 scale to probability space.
    # BUG FIX: the previous code then executed `lp /= np.sum(lp)` on this
    # scalar — dividing a scalar by its own sum always yields 1.0, so the
    # function returned 1.0 for every input and carried no information.
    # The no-op self-normalization has been removed; normalizing over a
    # single candidate is meaningless.
    return np.exp((base_score - 100) * np.log(10) / 10)
|
||||
|
||||
# Example calculation (return value is discarded).
# NOTE(review): every experimental peak is 1 Da away from its theoretical
# peak, which exceeds the default tol=0.1, so zero peaks match here.
masses = [100, 105, 110, 115, 120]  # example theoretical peaks
spec_masses = [101, 106, 111, 116, 121]  # example experimental peaks
spec_intensities = [10, 20, 30, 40, 50]  # experimental peak intensities
lp_score(masses, spec_masses, spec_intensities)
|
||||
30
rebuildinpython/test_andromeda.py
Normal file
30
rebuildinpython/test_andromeda.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import unittest
|
||||
from andromeda import Peptide, Modification, Pscore
|
||||
|
||||
class TestAndromeda(unittest.TestCase):
    """Exercises Pscore.calc_ptm_score on a small synthetic spectrum."""

    def test_calc_ptm_score(self):
        # Candidate variable modifications: +80 on E/D and +42 on P,
        # each allowed anywhere in the peptide.
        variable_modifications = [
            Modification(1, 80, ['E', 'D'], 'anywhere'),
            Modification(2, 42, ['P'], 'anywhere'),
        ]

        result = Pscore.calc_ptm_score(
            0.1,                         # ms2_tol
            "Da",                        # ms2_tol_unit
            1,                           # topx
            "PEPTIDE",                   # sequence
            [],                          # fixed_modifications
            variable_modifications,
            [1, 1],                      # mod_count: one of each modification
            [100, 105, 110, 115, 120],   # spec_masses
            [10, 20, 30, 40, 50],        # spec_intensities
            500,                         # mz
            2,                           # charge
        )
        (score, best_pep, counts, _delta, _description,
         _intensities, _mass_diffs, _mod_prob, _mod_score_diffs) = result

        self.assertIsNotNone(best_pep)
        self.assertGreater(score, 0)
        self.assertEqual(counts, 24)  # number of possible modification placements
|
||||
|
||||
# Allow running this test module directly: `python test_andromeda.py`.
if __name__ == "__main__":
    unittest.main()
|
||||
19793
s41467-023-38414-8.pdf
Normal file
19793
s41467-023-38414-8.pdf
Normal file
File diff suppressed because one or more lines are too long
53
src/README.md
Normal file
53
src/README.md
Normal file
@@ -0,0 +1,53 @@
|
||||
## env
|
||||
|
||||
```shell
|
||||
conda create -n l2l python=3.10 pytorch torchvision torchaudio pytorch-cuda=11.8 ipython ipykernel jupyter -c pytorch -c nvidia --yes
|
||||
conda activate l2l
|
||||
pip install learn2learn
|
||||
```
|
||||
|
||||
## [Emu3](https://github.com/baaivision/Emu3) env
|
||||
|
||||
```shell
|
||||
conda create -n emu3 pytorch=2.2.1 transformers=4.44.0 tiktoken=0.6.0 flash-attn=2.5.8 pillow gradio=4.44.0 ipython ipykernel jupyter ninja packaging -c pytorch -c nvidia --yes
|
||||
|
||||
```
|
||||
|
||||
## pytorch 检测
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
# 检查是否检测到 CUDA
|
||||
cuda_available = torch.cuda.is_available()
|
||||
print(f"CUDA 是否可用: {cuda_available}")
|
||||
|
||||
if cuda_available:
|
||||
# 显示当前 CUDA 设备数量
|
||||
num_devices = torch.cuda.device_count()
|
||||
print(f"CUDA 设备数量: {num_devices}")
|
||||
|
||||
# 显示当前设备名称
|
||||
for i in range(num_devices):
|
||||
print(f"设备 {i}: {torch.cuda.get_device_name(i)}")
|
||||
|
||||
# 测试设备分配
|
||||
device = torch.device("cuda:0") # 选择第一个 CUDA 设备
|
||||
print(f"当前设备: {torch.cuda.current_device()}")
|
||||
|
||||
# 测试 Tensor 在 CUDA 上创建和计算
|
||||
tensor = torch.tensor([1.0, 2.0, 3.0], device=device)
|
||||
print(f"Tensor 在 {device} 上创建: {tensor}")
|
||||
|
||||
result = tensor * 2
|
||||
print(f"计算结果: {result}")
|
||||
|
||||
else:
|
||||
print("CUDA 不可用,PyTorch 将使用 CPU。")
|
||||
```
|
||||
|
||||
## 元数据训练
|
||||
|
||||
训练是基于元数据集、元任务。
|
||||
|
||||
问题在于:例子中的数据集都是MNIST/CIFAR等标准化数据集,而我自己想要实现时,需要使用自己的数据,于是要了解这两个数据集如何构建。
|
||||
7
src/learn2learn_test.py
Normal file
7
src/learn2learn_test.py
Normal file
@@ -0,0 +1,7 @@
|
||||
import learn2learn as l2l
|
||||
import torch
|
||||
|
||||
# Sanity check that learn2learn and torch are importable; report versions
# and whether CUDA is visible to this interpreter.
for report_line in (
    f"Learn2Learn version: {l2l.__version__}",
    f"Torch version: {torch.__version__}",
    f"CUDA Available: {torch.cuda.is_available()}",
):
    print(report_line)
|
||||
45
src/learn2learn_test2.py
Normal file
45
src/learn2learn_test2.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import learn2learn as l2l
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.optim import SGD
|
||||
|
||||
# Select the compute device, preferring GPU when CUDA is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
|
||||
|
||||
# A small MLP used as the MAML base model.
class LinearModel(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        hidden_width = 10  # fixed hidden-layer width
        self.layer = nn.Sequential(
            nn.Linear(input_dim, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, output_dim),
        )

    def forward(self, x):
        """Apply the network to an input batch of shape (batch, input_dim)."""
        return self.layer(x)
|
||||
|
||||
# Training data: y = 2x, six points, moved to the selected device.
train_data = torch.tensor([[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]], dtype=torch.float32).to(device)
train_labels = torch.tensor([[0.0], [2.0], [4.0], [6.0], [8.0], [10.0]], dtype=torch.float32).to(device)

# Model, MAML wrapper, optimizer, and loss function.
model = LinearModel(1, 1).to(device)
maml = l2l.algorithms.MAML(model, lr=0.1)  # MAML wrapper (inner-loop lr = 0.1)
optimizer = SGD(maml.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

# Fast-adaptation smoke test.
# NOTE(review): `optimizer` is created but never stepped and no meta-gradient
# is applied — this loop only verifies that clone()/adapt() run on the chosen
# device. Confirm this is intentional for a smoke test.
for step in range(5):  # 5 tasks
    learner = maml.clone()  # per-task clone; adaptation leaves maml's params untouched
    print(f"Device of model: {next(learner.parameters()).device}")  # device check
    for _ in range(3):  # inner-loop adaptation steps
        predictions = learner(train_data)
        loss = loss_fn(predictions, train_labels)
        learner.adapt(loss)

    # Loss on the same data after adaptation
    test_predictions = learner(train_data)
    test_loss = loss_fn(test_predictions, train_labels)
    print(f"Step {step + 1}, Test Loss: {test_loss.item():.4f}")
|
||||
77
src/maml_demo.py
Normal file
77
src/maml_demo.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import learn2learn as l2l
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
|
||||
# Base model for the MAML demo.
class Net(nn.Module):
    """Three-layer MLP: ReLU hidden activations, sigmoid output in (0, 1)."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map a (batch, input_size) tensor to (batch, output_size)."""
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        return torch.sigmoid(self.fc3(hidden))
|
||||
|
||||
def main():
    """Run a toy MAML training loop over randomly generated binary tasks."""
    # Hyperparameters
    input_size = 30  # kept consistent with the original project
    hidden_size = 512
    output_size = 1
    meta_lr = 0.01        # outer-loop (meta) learning rate
    adapt_lr = 0.1        # inner-loop (adaptation) learning rate
    num_iterations = 1000

    # Build the base model and wrap it with MAML (second-order gradients).
    model = Net(input_size, hidden_size, output_size)
    maml = l2l.algorithms.MAML(model, lr=adapt_lr, first_order=False)
    opt = optim.Adam(maml.parameters(), meta_lr)

    # Training loop
    for iteration in range(num_iterations):
        opt.zero_grad()
        meta_train_error = 0.0
        # NOTE(review): meta_valid_error is initialized but never updated or
        # read — confirm whether a validation pass was intended here.
        meta_valid_error = 0.0

        # Sample a batch of tasks
        for task in range(5):  # 5 tasks per meta-batch
            # Clone the MAML model so adaptation does not touch meta-parameters.
            learner = maml.clone()

            # Support / query sets (placeholder random data; replace with real tasks).
            support_x = torch.randn(10, input_size)  # example data
            support_y = torch.randint(0, 2, (10, 1)).float()
            query_x = torch.randn(10, input_size)
            query_y = torch.randint(0, 2, (10, 1)).float()

            # Inner-loop adaptation on the support set
            for _ in range(5):  # 5 adaptation steps
                support_pred = learner(support_x)
                support_loss = F.binary_cross_entropy(support_pred, support_y)
                learner.adapt(support_loss)

            # Query-set loss after adaptation
            query_pred = learner(query_x)
            query_loss = F.binary_cross_entropy(query_pred, query_y)

            # Accumulate the meta-training error
            meta_train_error += query_loss

        # Average the meta-training error over the 5 tasks
        meta_train_error = meta_train_error / 5

        # Meta-optimization step (backprop through the adaptation steps)
        meta_train_error.backward()
        opt.step()

        # Progress reporting
        if iteration % 100 == 0:
            print(f'Iteration {iteration}: Meta Train Error {meta_train_error.item():.4f}')
|
||||
|
||||
# Script entry point.
if __name__ == '__main__':
    main()
|
||||
Reference in New Issue
Block a user