Files
labweb/models/utils_decoding.py
2025-12-16 11:39:15 +08:00

63 lines
2.2 KiB
Python

import pandas as pd
import ast
from collections import Counter
################################ Compute frequencies and sort ##############################
def _as_beam(cell):
    """Return the collection stored in *cell*, or ``[]`` if there is none.

    A string is parsed with ``ast.literal_eval``; a string that cannot be
    parsed, or any value that is not a list/tuple/set, yields ``[]``.
    """
    if isinstance(cell, str):
        try:
            cell = ast.literal_eval(cell)
        except (ValueError, SyntaxError):
            return []
    # Scalars (NaN, numbers, a parsed bare string) carry no candidates and
    # would otherwise break or distort the set() de-duplication downstream.
    return cell if isinstance(cell, (list, tuple, set)) else []


def statistical_frequency(combined_df):
    """Add a per-row ``frequency`` column of candidate occurrence ratios.

    Every column except ``zhenshi`` is read as one candidate collection per
    row. A cell may be a list/tuple/set or the string repr of one; anything
    else counts as empty. Duplicates inside a single cell are counted once,
    so each ratio is the fraction of columns whose cell contains the
    candidate.

    Args:
        combined_df: DataFrame of candidate columns, optionally with a
            ``zhenshi`` column (presumably the ground-truth label; confirm
            with the caller). The input frame is not modified.

    Returns:
        A new DataFrame holding the candidate columns, then ``frequency``
        (a dict mapping candidate -> ratio, ordered from highest to lowest
        ratio), then ``zhenshi`` if the input had it.
    """
    has_truth = 'zhenshi' in combined_df.columns
    if has_truth:
        # Keep the label column out of the counting and re-attach it last.
        zhenshi = combined_df['zhenshi']
        result = combined_df.drop(columns=['zhenshi'])
    else:
        result = combined_df.copy()

    # The denominator is the number of candidate columns, not the number of
    # collected elements.
    n_columns = len(result.columns)
    percentages_list = []
    for _, row in result.iterrows():
        counts = Counter()
        for cell in row:
            # sorted() keeps the insertion order deterministic, which fixes
            # the order of candidates that end up with equal ratios.
            counts.update(sorted(set(_as_beam(cell))))
        # With zero columns `counts` is empty, so no division takes place.
        ratios = {key: count / n_columns for key, count in counts.items()}
        ordered = dict(
            sorted(ratios.items(), key=lambda item: item[1], reverse=True)
        )
        percentages_list.append(ordered)

    result['frequency'] = percentages_list
    if has_truth:
        result['zhenshi'] = zhenshi
    return result
###################################### Look up keys by value ##########################
def keys_valus(tgt_dict, key):
    """Return every key of *tgt_dict* whose value equals *key*.

    Despite its name, *key* is the value being searched for. Matches are
    returned in the dict's iteration order; the list is empty when nothing
    matches.
    """
    return [k for k, v in tgt_dict.items() if v == key]
# Convert a tensor into nutrient components
def medium(b, tgt_dict):
    """Translate the elements of *b* into *tgt_dict* keys, minus special tokens.

    Each element of *b* is looked up by value in *tgt_dict* via
    ``keys_valus``; all matching keys are collected in order, and the
    entries ``'</s>'``, ``'<s>'`` and ``'blank'`` are dropped.

    Args:
        b: Iterable of values to look up (presumably a tensor of token ids;
            confirm with the caller).
        tgt_dict: Mapping whose values are compared against each element.

    Returns:
        List of matching keys with the three special tokens removed.
    """
    special_tokens = ('</s>', '<s>', 'blank')
    names = []
    for token in b:
        names.extend(keys_valus(tgt_dict, token))
    # Filter once after collecting; the result is the same as filtering on
    # every iteration, without rescanning the list each time.
    return [name for name in names if name not in special_tokens]