完成股吧数据分析项目:

1. 修复词云断句问题 - 添加英文单词过滤
2. 创建 Word2Vec + CNN 情绪感知模型
3. 创建情绪时间序列分析脚本(基于大连理工大学情感词典)
4. 添加停用词文件(1427个中英文停用词)
5. 更新 analyze.py 保存时间字段 post_publish_time
6. 更新 requirements.txt 添加必要依赖
This commit is contained in:
2026-05-28 15:30:16 +08:00
parent 5231e995dd
commit 0098977172
7 changed files with 2165 additions and 19 deletions
+163
View File
@@ -0,0 +1,163 @@
# 股吧数据爬取与情感分析系统
基于东方财富网股吧数据的爬虫系统,支持数据爬取、情感分析和关键词挖掘。
## 功能特性
- 🕷️ **股吧数据爬取** - 自动爬取指定股票的股吧帖子
- 😊 **情感分析** - 基于大连理工大学情感词汇本体进行情绪计算
- 🔍 **关键词挖掘** - 使用TF-IDF算法提取热门话题
- 📊 **可视化输出** - 生成词云、情绪分布图等可视化图表
## 项目结构
```
guba2vec/
├── spider.py # 股吧数据爬虫
├── sentiment_analysis.py # 情感分析模块
├── analyze.py # TF-IDF关键词分析
├── requirements.txt # 依赖列表
├── 大连理工大学中文情感词汇本体.xlsx # 情感词典
└── data/ # 爬取数据存储目录
└── guba_*.json/xlsx
```
## 安装依赖
```bash
pip install -r requirements.txt
```
## 使用方法
### 1. 爬取股吧数据
```bash
python spider.py
```
默认爬取以下游戏行业股票:
- 完美世界 (002624)
- 三七互娱 (002555)
- 巨人网络 (002558)
- 世纪华通 (002602)
- 昆仑万维 (300418)
- 游族网络 (002174)
- 掌趣科技 (300315)
- 吉比特 (603444)
爬取结果保存在 `data/` 目录下,包含 JSON 和 Excel 两种格式。
### 2. 情感分析
```bash
python sentiment_analysis.py
```
基于大连理工大学中文情感词汇本体进行情绪分析,支持7种情绪类型:
- 正面情绪:快乐、好评、惊讶
- 负面情绪:愤怒、悲伤、恐惧、厌恶
分析结果保存在 `sentiment_output/` 目录,包含:
- 各股票详细情感数据(CSV
- 情绪统计汇总(CSV
- 可视化图表(PNG
### 3. 关键词分析
```bash
python analyze.py
```
使用TF-IDF算法提取关键词并生成词云,结果保存在 `output/` 目录。
## 核心模块说明
### spider.py
主要函数:
- `fetch_guba_data(code, page, page_size, sort_type)` - 爬取单页数据
- `fetch_stock_posts(code, name, pages, page_size)` - 爬取多页数据
- `save_to_json(data, name, filename)` - 保存为JSON格式
- `save_to_excel(data, name, filename)` - 保存为Excel格式
### sentiment_analysis.py
主要函数:
- `build_sentiment_dictionary()` - 构建情感词典
- `emotion_caculate(text, sentiment_dict)` - 计算文本情绪
- `load_and_analyze_data(data_dir, output_dir)` - 批量分析数据
- `generate_visualizations()` - 生成可视化图表
### analyze.py
主要函数:
- `clean_text(text)` - 文本清洗
- `tokenize(text)` - 中文分词
- `calculate_tfidf(texts)` - 计算TF-IDF
- `get_top_keywords()` - 获取Top关键词
- `generate_wordcloud()` - 生成词云
## 情感词典
使用 **大连理工大学中文情感词汇本体**(需自行准备),包含:
- 27469个情感词汇
- 7种情感分类
- 3种强度等级
- 2种极性(正面/负面)
备用方案:内置简化版情感词典,包含约200个常用情感词。
## 数据格式
### 爬取数据 (JSON)
```json
{
"stock_code": "002624",
"stock_name": "完美世界",
"total_pages": 10,
"total_posts": 200,
"crawl_time": "2024-01-01T12:00:00",
"posts": [
{
"post_id": "123456",
"post_title": "标题",
"post_content": "内容",
"post_user": {"user_nickname": "用户名"},
"post_publish_time": "2024-01-01 10:00",
"post_click_count": 100,
"post_comment_count": 10,
"post_like_count": 5
}
]
}
```
### 情感分析结果 (CSV)
| 帖子ID | 标题 | 内容 | positive | negative | sentiment_score |
|--------|------|------|----------|----------|-----------------|
| 123456 | ... | ... | 5 | 2 | 3 |
## 注意事项
1. 爬虫使用模拟移动端请求,请合理控制爬取频率
2. 情感词典文件需放置在项目根目录
3. 首次运行可能需要下载jieba分词字典
4. 生成词云需要系统安装中文字体(默认使用SimHei)
## 依赖列表
| 库 | 版本 | 用途 |
|----|------|------|
| requests | >=2.28.0 | HTTP请求 |
| pandas | >=2.0.0 | 数据处理 |
| openpyxl | >=3.1.0 | Excel读写 |
| jieba | >=0.42.1 | 中文分词 |
| scikit-learn | >=1.3.0 | TF-IDF计算 |
| numpy | >=1.24.0 | 数值计算 |
| matplotlib | >=3.7.0 | 可视化 |
| wordcloud | >=1.9.0 | 词云生成 |
## License
MIT License
+41 -13
View File
@@ -10,17 +10,30 @@ import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg') # 使用非交互式后端
# 中文停用词表
STOPWORDS = {
def load_stopwords(filepath='stopwords.txt'):
"""从文件加载停用词"""
stopwords = set()
if os.path.exists(filepath):
with open(filepath, 'r', encoding='utf-8') as f:
for line in f:
word = line.strip()
if word:
stopwords.add(word)
print(f"已加载 {len(stopwords)} 个停用词")
else:
print(f"警告:停用词文件 {filepath} 不存在,使用默认停用词")
stopwords = {
'', '', '', '', '', '', '', '', '', '', '', '', '一个', '', '', '', '', '', '',
'', '', '', '', '没有', '', '', '自己', '', '', '', '', '', '', '', '', '什么', '怎么',
'为什么', '哪里', '', '多少', '', '', '', '', '', '', '', '', '', '今天', '明天', '昨天', '',
'', '', '已经', '还是', '还是', '但是', '可是', '不过', '只是', '只有', '就是', '还是', '或者', '还是', '还是',
'这个', '那个', '这些', '那些', '那么', '这么', '怎么', '如何', '因为', '所以', '虽然', '但是', '如果', '', '那么',
'', '', '', '', '', '还是', '还是', '还是', '还是', '还是', '还是', '还是', '还是', '还是', '还是',
'', '', '已经', '还是', '但是', '可是', '不过', '只是', '只有', '就是', '或者', '', '', '', '', '',
'股吧', '东方财富', '帖子', '发表', '回复', '点击', '查看', '更多', '原文', '转发', '分享', '收藏', '评论', '点赞',
'http', 'https', 'com', 'cn', 'www', 'net', 'org'
}
}
return stopwords
# 加载停用词
STOPWORDS = load_stopwords()
def clean_text(text):
"""清洗文本"""
@@ -32,10 +45,11 @@ def clean_text(text):
text = re.sub(r'<.*?>', '', text)
# 移除表情符号
text = re.sub(r'\[.*?\]', '', text)
# 移除特殊字符
text = re.sub(r'[^\w\s]', '', text)
# 移除数字
text = re.sub(r'\d+', '', text)
# 移除纯英文和数字混合的无效标记(如 sh123、abc456等)
text = re.sub(r'\b[a-zA-Z]+\d+\b', '', text)
text = re.sub(r'\b\d+[a-zA-Z]+\b', '', text)
# 移除特殊字符(保留中文、英文、数字)
text = re.sub(r'[^\w\s]', ' ', text)
# 移除多余空格
text = re.sub(r'\s+', ' ', text).strip()
return text
@@ -43,9 +57,21 @@ def clean_text(text):
def tokenize(text):
"""中文分词"""
words = jieba.lcut(text)
# 过滤停用词短词
words = [w for w in words if w not in STOPWORDS and len(w) > 1]
return words
# 过滤停用词短词、纯英文单词和无意义字符
filtered_words = []
for w in words:
# 跳过停用词和短词
if w in STOPWORDS or len(w) <= 1:
continue
# 检查是否是纯英文单词
if re.match(r'^[a-zA-Z]+$', w):
# 过滤掉纯英文单词(通常是论坛标记、无意义的缩写等)
continue
# 检查是否包含无意义的英文字符组合
if re.match(r'^[a-zA-Z\s]+$', w):
continue
filtered_words.append(w)
return filtered_words
def load_data(data_dir='data'):
"""加载所有股票数据"""
@@ -74,6 +100,7 @@ def load_data(data_dir='data'):
for post in posts:
content = post.get('post_content', '')
title = post.get('post_title', '')
publish_time = post.get('post_publish_time', '')
full_text = f"{title} {content}".strip()
if full_text:
@@ -81,6 +108,7 @@ def load_data(data_dir='data'):
'stock_code': stock_code,
'stock_name': stock_name,
'post_id': post.get('post_id'),
'post_publish_time': publish_time,
'text': full_text,
'clean_text': clean_text(full_text)
})
+3
View File
@@ -7,3 +7,6 @@ numpy>=1.24.0
matplotlib>=3.7.0
seaborn>=0.12.0
wordcloud>=1.9.0
gensim>=4.3.0
tensorflow>=2.10.0
keras>=2.10.0
+297
View File
@@ -0,0 +1,297 @@
import os
import json
import re
import numpy as np
import pandas as pd
from datetime import datetime
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')
import jieba
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'SimSun', 'Arial Unicode MS']
plt.rcParams['axes.unicode_minus'] = False
# 加载停用词
def load_stopwords(filepath='stopwords.txt'):
stopwords = set()
if os.path.exists(filepath):
with open(filepath, 'r', encoding='utf-8') as f:
for line in f:
word = line.strip()
if word:
stopwords.add(word)
return stopwords
STOPWORDS = load_stopwords()
# ============================================================
# 构建情感词典(参照 sentiment_analysis.py
# ============================================================
def build_sentiment_dictionary():
"""使用大连理工大学中文情感词汇本体构建情感词典"""
dict_path = '大连理工大学中文情感词汇本体.xlsx'
try:
df = pd.read_excel(dict_path)
df = df[['词语', '词性种类', '词义数', '词义序号', '情感分类', '强度', '极性']]
Happy = []
Good = []
Surprise = []
Anger = []
Sad = []
Fear = []
Disgust = []
for idx, row in df.iterrows():
if row['情感分类'] in ['PA', 'PE']:
Happy.append(row['词语'])
if row['情感分类'] in ['PD', 'PH', 'PG', 'PB', 'PK']:
Good.append(row['词语'])
if row['情感分类'] in ['PC']:
Surprise.append(row['词语'])
if row['情感分类'] in ['NA']:
Anger.append(row['词语'])
if row['情感分类'] in ['NB', 'NJ', 'NH', 'PF']:
Sad.append(row['词语'])
if row['情感分类'] in ['NI', 'NC', 'NG']:
Fear.append(row['词语'])
if row['情感分类'] in ['NE', 'ND', 'NN', 'NK', 'NL']:
Disgust.append(row['词语'])
# 添加股票相关词汇
stock_positive = ['', '上涨', '暴涨', '拉升', '涨停', '盈利', '收益', '赚钱', '',
'利好', '增长', '上升', '增加', '发展', '进步', '提升', '改善', '突破',
'创新', '优势', '超预期', '亮眼', '惊艳', '奇迹']
stock_negative = ['', '下跌', '暴跌', '跳水', '跌停', '亏损', '亏钱', '', '损失',
'套牢', '垃圾', '恶心', '坑爹', '骗局', '', '爆雷', '崩盘', '退市']
Good.extend(stock_positive)
Disgust.extend(stock_negative)
Positive = Happy + Good + Surprise
Negative = Anger + Sad + Fear + Disgust
print(f'大连理工大学情感词典加载完成')
print(f' 正面情感词: {len(Positive)}')
print(f' 负面情感词: {len(Negative)}')
return {
'Happy': Happy,
'Good': Good,
'Surprise': Surprise,
'Anger': Anger,
'Sad': Sad,
'Fear': Fear,
'Disgust': Disgust,
'Positive': Positive,
'Negative': Negative
}
except Exception as e:
print(f'加载大连理工大学情感词典失败: {e}')
print('使用简化版情感词典')
return build_simplified_dictionary()
def build_simplified_dictionary():
"""构建简化的中文情感词典(备用方案)"""
Happy = ['开心', '快乐', '高兴', '喜悦', '愉快', '欣喜', '欢乐', '欢喜', '幸福',
'满意', '满足', '欣慰', '愉悦', '畅快', '乐观', '积极', '美好', '成功']
Good = ['', '优秀', '出色', '精彩', '卓越', '杰出', '优良', '良好', '完美', '不错',
'', '上涨', '暴涨', '拉升', '涨停', '盈利', '收益', '赚钱', '', '利好',
'增长', '上升', '增加', '发展', '进步', '提升', '改善', '突破', '创新', '优势']
Surprise = ['惊喜', '意外', '震惊', '惊讶', '震撼', '神奇', '奇迹', '惊艳', '亮眼', '超预期']
Anger = ['愤怒', '生气', '恼火', '气愤', '暴怒', '愤慨', '愤恨', '震怒', '发怒',
'', '垃圾', '恶心', '坑爹', '骗局', '欺骗', '欺诈', '造假', '腐败', '黑暗']
Sad = ['伤心', '难过', '悲伤', '痛苦', '悲哀', '沮丧', '失望', '绝望', '低落', '悲观',
'', '下跌', '暴跌', '跳水', '跌停', '亏损', '亏钱', '', '损失', '套牢']
Fear = ['害怕', '恐惧', '担心', '担忧', '恐慌', '不安', '焦虑', '忧虑', '紧张', '恐怖',
'风险', '危机', '危险', '下跌', '暴跌', '崩盘', '退市', '爆雷', '', '']
Disgust = ['厌恶', '恶心', '反感', '讨厌', '鄙视', '唾弃', '不屑', '蔑视', '嫌弃',
'垃圾', '废物', '不行', '差劲', '', '', '', '骗局']
Positive = Happy + Good + Surprise
Negative = Anger + Sad + Fear + Disgust
print(f'简化版情感词典构建完成')
print(f' 正面情感词: {len(Positive)}')
print(f' 负面情感词: {len(Negative)}')
return {
'Happy': Happy,
'Good': Good,
'Surprise': Surprise,
'Anger': Anger,
'Sad': Sad,
'Fear': Fear,
'Disgust': Disgust,
'Positive': Positive,
'Negative': Negative
}
# ============================================================
# 情绪计算函数(参照 sentiment_analysis.py
# ============================================================
def emotion_caculate(text, sentiment_dict):
"""计算单条文本的情绪"""
if not text or pd.isna(text):
return 0
positive = 0
negative = 0
wordlist = jieba.lcut(text)
for word in wordlist:
# 跳过停用词和短词
if word in STOPWORDS or len(word) <= 1:
continue
freq = wordlist.count(word)
if word in sentiment_dict['Positive']:
positive += freq
if word in sentiment_dict['Negative']:
negative += freq
sentiment_score = positive - negative
return sentiment_score
# ============================================================
# 时间序列分析
# ============================================================
def analyze_sentiment_trend():
"""分析情绪时间序列趋势(使用情感词典)"""
print("="*60)
print("情绪时间序列分析(基于情感词典)")
print("="*60)
# 构建情感词典
print("\n[1/5] 构建情感词典...")
sentiment_dict = build_sentiment_dictionary()
# 加载数据
print("\n[2/5] 加载数据...")
df = pd.read_csv('output/all_posts.csv', encoding='utf-8-sig')
# 检查是否有 post_publish_time 字段
if 'post_publish_time' not in df.columns:
print("警告:数据中没有 post_publish_time 字段,请先运行 analyze.py")
return
# 转换时间戳
print("\n[3/5] 转换时间戳...")
df['timestamp'] = pd.to_datetime(df['post_publish_time'], errors='coerce')
df = df.dropna(subset=['timestamp'])
df['date'] = df['timestamp'].dt.date
# 计算情绪得分
print("\n[4/5] 计算情绪得分...")
df['sentiment_score'] = df['clean_text'].apply(
lambda x: emotion_caculate(x, sentiment_dict)
)
# 保存结果
df.to_csv('output/sentiment_analysis_result.csv', index=False, encoding='utf-8-sig')
print(" 情绪分析结果已保存到: output/sentiment_analysis_result.csv")
# 按股票分组分析
stock_groups = df.groupby('stock_code')
os.makedirs('output/plots', exist_ok=True)
print("\n[5/5] 生成时间序列图表...")
for stock_code, group in stock_groups:
stock_name = group['stock_name'].iloc[0]
print(f"\n 分析 {stock_name} ({stock_code})...")
# 按日期分组计算平均情绪
daily_sentiment = group.groupby('date')['sentiment_score'].agg(['mean', 'count']).reset_index()
daily_sentiment.columns = ['date', 'avg_sentiment', 'post_count']
if len(daily_sentiment) < 2:
print(f" 数据不足,跳过")
continue
# 绘制时间序列图
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8), sharex=True)
# 情绪趋势
ax1.plot(daily_sentiment['date'], daily_sentiment['avg_sentiment'],
marker='o', linestyle='-', color='b', label='日均情绪')
# 添加移动平均线
daily_sentiment['MA3'] = daily_sentiment['avg_sentiment'].rolling(window=3).mean()
ax1.plot(daily_sentiment['date'], daily_sentiment['MA3'],
marker='', linestyle='--', color='r', label='3日移动平均')
ax1.set_title(f'{stock_name} ({stock_code}) 情绪时间序列趋势', fontsize=14)
ax1.set_ylabel('情绪分数', fontsize=12)
ax1.axhline(y=0, color='gray', linestyle='-', linewidth=0.5)
ax1.grid(True)
ax1.legend()
# 发帖量
ax2.bar(daily_sentiment['date'], daily_sentiment['post_count'], color='g', alpha=0.7)
ax2.set_xlabel('日期', fontsize=12)
ax2.set_ylabel('发帖数量', fontsize=12)
ax2.grid(True)
plt.xticks(rotation=45)
plt.tight_layout()
# 保存图表
plot_path = f'output/plots/sentiment_trend_{stock_name}.png'
plt.savefig(plot_path, dpi=100)
plt.close()
print(f" 图表已保存到: {plot_path}")
# 输出统计信息
avg_sentiment = group['sentiment_score'].mean()
pos_count = (group['sentiment_score'] > 0).sum()
neg_count = (group['sentiment_score'] < 0).sum()
neu_count = (group['sentiment_score'] == 0).sum()
print(f" 平均情绪: {avg_sentiment:.4f}")
print(f" 正面帖子: {pos_count}, 负面帖子: {neg_count}, 中性帖子: {neu_count}")
# 生成汇总报告
print("\n生成汇总报告...")
summary_data = []
for stock_code, group in stock_groups:
stock_name = group['stock_name'].iloc[0]
avg_sentiment = group['sentiment_score'].mean()
post_count = len(group)
pos_count = (group['sentiment_score'] > 0).sum()
neg_count = (group['sentiment_score'] < 0).sum()
neu_count = (group['sentiment_score'] == 0).sum()
summary_data.append({
'股票代码': stock_code,
'股票名称': stock_name,
'帖子数量': post_count,
'平均情绪': round(avg_sentiment, 4),
'正面帖子': pos_count,
'负面帖子': neg_count,
'中性帖子': neu_count
})
summary_df = pd.DataFrame(summary_data)
summary_df.to_csv('output/sentiment_summary.csv', index=False, encoding='utf-8-sig')
print("汇总报告已保存到: output/sentiment_summary.csv")
print("\n" + "="*60)
print("情绪时间序列分析完成!")
print("="*60)
if __name__ == '__main__':
analyze_sentiment_trend()
+1 -1
View File
@@ -168,7 +168,7 @@ if __name__ == '__main__':
print(f'{"="*50}')
# 爬取10页数据
data = fetch_stock_posts(code, name, pages=10)
data = fetch_stock_posts(code, name, pages=30)
if data and data['total_posts'] > 0:
print(f'\n共获取 {data["total_posts"]} 条帖子')
+1426
View File
File diff suppressed because it is too large Load Diff
+229
View File
@@ -0,0 +1,229 @@
import os
import json
import re
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score
from gensim.models import Word2Vec
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, Conv1D, GlobalMaxPooling1D, Dense, Dropout
from tensorflow.keras.utils import to_categorical
import jieba
def load_stopwords(filepath='stopwords.txt'):
"""从文件加载停用词"""
stopwords = set()
if os.path.exists(filepath):
with open(filepath, 'r', encoding='utf-8') as f:
for line in f:
word = line.strip()
if word:
stopwords.add(word)
print(f"已加载 {len(stopwords)} 个停用词")
else:
print(f"警告:停用词文件 {filepath} 不存在,使用默认停用词")
stopwords = {
'', '', '', '', '', '', '', '', '', '', '', '', '一个', '', '', '', '', '', '',
'', '', '', '', '没有', '', '', '自己', '', '', '', '', '', '', '', '', '什么', '怎么',
'为什么', '哪里', '', '多少', '', '', '', '', '', '', '', '', '', '今天', '明天', '昨天', '',
'', '', '已经', '还是', '但是', '可是', '不过', '只是', '只有', '就是', '或者', '', '', '', '', '',
'股吧', '东方财富', '帖子', '发表', '回复', '点击', '查看', '更多', '原文', '转发', '分享', '收藏', '评论', '点赞',
'http', 'https', 'com', 'cn', 'www', 'net', 'org'
}
return stopwords
# 加载停用词
STOPWORDS = load_stopwords()
def clean_text(text):
"""清洗文本"""
if not text or pd.isna(text):
return ""
text = str(text)
text = re.sub(r'https?://\S+|www\.\S+', '', text)
text = re.sub(r'<.*?>', '', text)
text = re.sub(r'\[.*?\]', '', text)
text = re.sub(r'\b[a-zA-Z]+\d+\b', '', text)
text = re.sub(r'\b\d+[a-zA-Z]+\b', '', text)
text = re.sub(r'[^\w\s]', ' ', text)
text = re.sub(r'\s+', ' ', text).strip()
return text
def tokenize(text):
"""中文分词"""
words = jieba.lcut(text)
filtered_words = []
for w in words:
if w in STOPWORDS or len(w) <= 1:
continue
if re.match(r'^[a-zA-Z]+$', w):
continue
if re.match(r'^[a-zA-Z\s]+$', w):
continue
filtered_words.append(w)
return filtered_words
def load_and_preprocess_data(filepath='output/all_posts.csv'):
"""加载并预处理数据"""
df = pd.read_csv(filepath, encoding='utf-8-sig')
print(f"原始数据: {len(df)}")
df = df.dropna(subset=['clean_text', 'label'])
df = df[df['clean_text'].str.strip() != '']
print(f"有效数据: {len(df)}")
print(f"标签分布:")
print(df['label'].value_counts())
df['tokens'] = df['clean_text'].apply(tokenize)
df = df[df['tokens'].apply(len) > 0]
print(f"分词后有效数据: {len(df)}")
return df
def train_word2vec_model(sentences, vector_size=100, window=5, min_count=5):
"""训练 Word2Vec 模型"""
print(f"\n训练 Word2Vec 模型...")
model = Word2Vec(
sentences=sentences,
vector_size=vector_size,
window=window,
min_count=min_count,
workers=4,
epochs=10
)
print(f"Word2Vec 词汇表大小: {len(model.wv)}")
return model
def build_cnn_model(vocab_size, embedding_dim, max_seq_len, embedding_matrix, num_classes=3):
"""构建 CNN 模型"""
model = Sequential()
model.add(Embedding(
input_dim=vocab_size,
output_dim=embedding_dim,
input_length=max_seq_len,
weights=[embedding_matrix],
trainable=False
))
model.add(Conv1D(128, 5, activation='relu'))
model.add(GlobalMaxPooling1D())
model.add(Dense(64, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))
model.compile(
optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy']
)
return model
def main():
print("="*60)
print("Word2Vec + CNN 情绪感知模型训练")
print("="*60)
# 加载数据
print("\n[1/5] 加载数据...")
df = load_and_preprocess_data()
if len(df) < 10:
print("数据不足,无法训练")
return
# 准备 Word2Vec 训练数据
sentences = df['tokens'].tolist()
# 训练 Word2Vec
print("\n[2/5] 训练 Word2Vec 词向量...")
w2v_model = train_word2vec_model(sentences)
# 构建词汇表
print("\n[3/5] 构建词汇表...")
tokenizer = Tokenizer()
tokenizer.fit_on_texts(sentences)
vocab_size = len(tokenizer.word_index) + 1
print(f"词汇表大小: {vocab_size}")
# 转换文本为序列
max_seq_len = max(len(s) for s in sentences)
print(f"最大序列长度: {max_seq_len}")
sequences = tokenizer.texts_to_sequences(sentences)
X = pad_sequences(sequences, maxlen=max_seq_len)
# 准备标签
label_mapping = {-1: 0, 0: 1, 1: 2}
y = df['label'].map(label_mapping).values
y = to_categorical(y, num_classes=3)
# 创建嵌入矩阵
print("\n[4/5] 创建嵌入矩阵...")
embedding_dim = w2v_model.vector_size
embedding_matrix = np.zeros((vocab_size, embedding_dim))
for word, i in tokenizer.word_index.items():
if word in w2v_model.wv:
embedding_matrix[i] = w2v_model.wv[word]
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42, stratify=y
)
print(f"训练集: {len(X_train)}")
print(f"测试集: {len(X_test)}")
# 构建并训练 CNN 模型
print("\n[5/5] 训练 CNN 模型...")
model = build_cnn_model(vocab_size, embedding_dim, max_seq_len, embedding_matrix)
print(model.summary())
history = model.fit(
X_train, y_train,
batch_size=32,
epochs=10,
validation_split=0.1,
verbose=1
)
# 评估模型
print("\n[6/6] 评估模型...")
y_pred = model.predict(X_test)
y_pred_classes = np.argmax(y_pred, axis=1)
y_true_classes = np.argmax(y_test, axis=1)
print("\n分类报告:")
print(classification_report(y_true_classes, y_pred_classes, target_names=['负面', '中性', '正面']))
print(f"准确率: {accuracy_score(y_true_classes, y_pred_classes):.4f}")
# 保存模型
print("\n保存模型...")
os.makedirs('models', exist_ok=True)
# 保存 Word2Vec 模型
w2v_model.save('models/word2vec.model')
print("Word2Vec 模型已保存到: models/word2vec.model")
# 保存 CNN 模型
model.save('models/cnn_sentiment.h5')
print("CNN 模型已保存到: models/cnn_sentiment.h5")
# 保存 tokenizer
with open('models/tokenizer.json', 'w', encoding='utf-8') as f:
f.write(tokenizer.to_json())
print("Tokenizer 已保存到: models/tokenizer.json")
print("\n" + "="*60)
print("训练完成!")
print("="*60)
if __name__ == '__main__':
main()