You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
84 lines
2.7 KiB
84 lines
2.7 KiB
# 认证模块
|
|
|
|
from flask import Flask, request, redirect, url_for, flash
|
|
from flask_login import LoginManager, UserMixin, login_user, login_required, logout_user, current_user
|
|
from werkzeug.security import generate_password_hash, check_password_hash
|
|
import sqlite3
|
|
import os
|
|
|
|
# Initialize Flask-Login.
# NOTE(review): no Flask app is created or bound in this module; presumably
# login_manager.init_app(app) is called elsewhere — confirm.
login_manager = LoginManager()
|
|
|
|
class User(UserMixin):
    """Logged-in user object handed to Flask-Login.

    In this module, instances are built from rows of the ``users`` table
    (id, username, password_hash).
    """

    def __init__(self, id, username, password_hash):
        # ``id`` shadows the builtin, but the name is part of the public
        # signature (callers may pass it by keyword), so it is kept.
        self.id, self.username, self.password_hash = id, username, password_hash

    def check_password(self, password):
        """Report whether *password* matches the stored werkzeug hash."""
        matches = check_password_hash(self.password_hash, password)
        return matches
|
|
|
|
@login_manager.user_loader
def load_user(user_id):
    """Load the user with the given ID for Flask-Login.

    Args:
        user_id: The user ID that Flask-Login passes to its ``user_loader``
            callback (presumably the string from ``UserMixin.get_id()`` —
            SQLite's numeric affinity on the INTEGER ``id`` column makes the
            comparison work either way).

    Returns:
        A ``User`` for the matching row, or ``None`` when no row matches.
    """
    conn = get_db_connection()
    try:
        row = conn.execute(
            'SELECT id, username, password_hash FROM users WHERE id = ?',
            (user_id,),
        ).fetchone()
    finally:
        # Close even if the query raises, so connections are never leaked.
        conn.close()

    if row is None:
        return None
    return User(row['id'], row['username'], row['password_hash'])
|
|
|
|
def get_db_connection(db_path=None):
    """Open a SQLite connection whose rows can be indexed by column name.

    Args:
        db_path: Database location. Defaults to ``data/futures_analysis.db``
            next to this module. Anything ``sqlite3.connect`` accepts (e.g.
            ``':memory:'`` for tests) may be passed.

    Returns:
        A new ``sqlite3.Connection`` with ``row_factory = sqlite3.Row``. The
        caller is responsible for closing it.
    """
    if db_path is None:
        db_path = os.path.join(os.path.dirname(__file__), 'data', 'futures_analysis.db')

    # sqlite3 creates the database file but not missing parent directories;
    # without this, a fresh checkout fails with "unable to open database file".
    # ':memory:' and bare filenames have no directory part and are skipped.
    parent_dir = os.path.dirname(db_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    conn = sqlite3.connect(db_path)
    conn.row_factory = sqlite3.Row
    return conn
|
|
|
|
def init_db():
    """Initialize the database by creating the ``users`` table.

    Safe to call repeatedly: the DDL uses ``CREATE TABLE IF NOT EXISTS``.
    Usernames are constrained ``UNIQUE NOT NULL``; only password hashes are
    stored.
    """
    conn = get_db_connection()
    try:
        conn.execute('''
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                username TEXT UNIQUE NOT NULL,
                password_hash TEXT NOT NULL
            )
        ''')
        conn.commit()
    finally:
        # Release the connection even if the DDL or commit fails.
        conn.close()
|
|
|
|
def register_user(username, password):
    """Register a new user, storing only a werkzeug hash of the password.

    Args:
        username: Desired username; must be unique in ``users``.
        password: Plain-text password; it is hashed with
            ``generate_password_hash`` before storage.

    Returns:
        A ``(success, message)`` tuple:
        ``(True, "注册成功")`` on success,
        ``(False, "用户名已存在")`` if the username is already taken, or
        ``(False, str(error))`` for any other failure.
    """
    conn = get_db_connection()
    try:
        # Fast pre-check for the common case. The UNIQUE constraint on
        # ``username`` is the real guard (see the IntegrityError branch).
        existing_user = conn.execute(
            'SELECT id FROM users WHERE username = ?', (username,)
        ).fetchone()
        if existing_user:
            return False, "用户名已存在"

        password_hash = generate_password_hash(password)

        conn.execute(
            'INSERT INTO users (username, password_hash) VALUES (?, ?)',
            (username, password_hash),
        )
        conn.commit()
        return True, "注册成功"
    except sqlite3.IntegrityError as e:
        conn.rollback()
        # A concurrent registration can insert the same username between the
        # SELECT above and this INSERT. Report that as "already exists" rather
        # than leaking the raw constraint text. Other integrity failures
        # (e.g. NOT NULL) keep the original behaviour.
        if 'UNIQUE' in str(e):
            return False, "用户名已存在"
        return False, str(e)
    except Exception as e:
        conn.rollback()
        # NOTE(review): raw exception text is returned to the caller and may
        # reach end users; consider logging it and returning a generic message.
        return False, str(e)
    finally:
        conn.close()
|
|
|
|
def login_user_by_credentials(username, password):
    """Authenticate by username and password, then log the user in.

    On success the user is logged in via Flask-Login's ``login_user``.

    Args:
        username: Username to look up in ``users``.
        password: Plain-text password to verify against the stored hash.

    Returns:
        ``(True, "登录成功")`` on success. Otherwise
        ``(False, "用户名或密码错误")``; the same message is used for an
        unknown user and a wrong password, so callers cannot tell them apart.
    """
    conn = get_db_connection()
    try:
        row = conn.execute(
            'SELECT id, username, password_hash FROM users WHERE username = ?',
            (username,),
        ).fetchone()
    finally:
        # Close even if the query raises, so connections are never leaked.
        conn.close()

    if row is not None:
        user_obj = User(row['id'], row['username'], row['password_hash'])
        # Reuse the class's own hash check instead of duplicating it here.
        if user_obj.check_password(password):
            login_user(user_obj)
            return True, "登录成功"

    return False, "用户名或密码错误"
|