You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

532 lines
14 KiB

package tushare
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
)
// DefaultBaseURL is the Tushare Pro endpoint used by clients created with
// NewClient. Override it per client with SetBaseURL (e.g. in tests).
const DefaultBaseURL = "http://api.tushare.pro"
// Client Tushare API客户端
type Client struct {
token string
baseURL string
client *http.Client
}
// NewClient returns a Client that authenticates with token, targets
// DefaultBaseURL and uses an HTTP client with a 30-second timeout.
func NewClient(token string) *Client {
	httpClient := &http.Client{Timeout: 30 * time.Second}
	c := new(Client)
	c.token = token
	c.baseURL = DefaultBaseURL
	c.client = httpClient
	return c
}
// SetBaseURL replaces the endpoint the client POSTs to. It is intended for
// pointing the client at a test server.
func (c *Client) SetBaseURL(baseURL string) {
	c.baseURL = baseURL
}
// Request is the JSON body sent to the Tushare API. Fields is omitted from
// the encoded body when empty.
type Request struct {
	APIName string                 `json:"api_name"` // Tushare interface name, e.g. "daily"
	Token   string                 `json:"token"`    // caller's API token
	Params  map[string]interface{} `json:"params"`   // interface-specific query parameters
	Fields  string                 `json:"fields,omitempty"`
}
// Response is the envelope returned by every Tushare API call. A Code of 0
// means success (see IsSuccess); Msg carries the server's accompanying
// message and Data the tabular payload, if any.
type Response struct {
	Code int    `json:"code"`
	Msg  string `json:"msg"`
	Data *Data  `json:"data"`
}

// Data is the column-oriented payload of a Response: Fields names the
// columns, and each entry of Items is one row whose values follow the same
// order as Fields.
type Data struct {
	Fields []string        `json:"fields"`
	Items  [][]interface{} `json:"items"`
}

// *Response (not Response) satisfies error, because Error has a pointer
// receiver.
var _ error = (*Response)(nil)

// Error implements the error interface by formatting the API code and message.
func (r *Response) Error() string {
	return fmt.Sprintf("tushare error: code=%d, msg=%s", r.Code, r.Msg)
}

// IsSuccess reports whether the API signalled success, i.e. Code is 0.
func (r *Response) IsSuccess() bool {
	return r.Code == 0
}
// Query performs a generic Tushare API call.
//
// apiName is the Tushare interface name (e.g. "daily"). params is sent as
// the request's "params" object. fields, when non-empty, is forwarded as the
// "fields" selector; it is omitted from the body when empty.
//
// Transport, encoding and decoding failures return a nil *Response and a
// wrapped error. When the API answers with a non-zero code, Query returns the
// decoded *Response both as the result and as the error. Callers can
// therefore recover Code and Msg with errors.As(err, &resp).
func (c *Client) Query(apiName string, params map[string]interface{}, fields string) (*Response, error) {
	reqBody := Request{
		APIName: apiName,
		Token:   c.token,
		Params:  params,
		Fields:  fields,
	}
	payload, err := json.Marshal(reqBody)
	if err != nil {
		return nil, fmt.Errorf("marshal request failed: %w", err)
	}

	resp, err := c.client.Post(c.baseURL, "application/json", bytes.NewReader(payload))
	if err != nil {
		return nil, fmt.Errorf("http request failed: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("read response body failed: %w", err)
	}

	var result Response
	if err := json.Unmarshal(body, &result); err != nil {
		// A non-JSON body usually means an HTTP-level failure (proxy page,
		// 5xx). Report the status code so that cause is not hidden behind a
		// bare decode error.
		return nil, fmt.Errorf("unmarshal response failed (http status %d): %w", resp.StatusCode, err)
	}
	if !result.IsSuccess() {
		// Must return the pointer: Error has a pointer receiver, so only
		// *Response implements error (returning the value does not compile).
		return &result, &result
	}
	return &result, nil
}
// ToMapList converts the column-oriented payload into one map per row, keyed
// by field name.
//
// It returns nil when d is nil or has no fields or no items. When a row holds
// fewer values than there are fields, the trailing fields are absent from
// that row's map. Extra values beyond the last field are ignored.
func (d *Data) ToMapList() []map[string]interface{} {
	if d == nil || len(d.Fields) == 0 || len(d.Items) == 0 {
		return nil
	}

	rows := make([]map[string]interface{}, 0, len(d.Items))
	for _, values := range d.Items {
		width := len(d.Fields)
		if len(values) < width {
			width = len(values)
		}
		row := make(map[string]interface{}, width)
		for col := 0; col < width; col++ {
			row[d.Fields[col]] = values[col]
		}
		rows = append(rows, row)
	}
	return rows
}
// StockDaily is one daily bar for a stock.
type StockDaily struct {
	TSCode    string  `json:"ts_code"`    // stock code
	TradeDate string  `json:"trade_date"` // trade date
	Open      float64 `json:"open"`       // opening price
	High      float64 `json:"high"`       // highest price
	Low       float64 `json:"low"`        // lowest price
	Close     float64 `json:"close"`      // closing price
	PreClose  float64 `json:"pre_close"`  // previous close
	Change    float64 `json:"change"`     // price change
	PctChange float64 `json:"pct_chg"`    // percentage change
	Volume    float64 `json:"vol"`        // volume, in lots (手)
	Amount    float64 `json:"amount"`     // turnover, in thousand CNY (千元)
}
// GetStockDaily fetches daily bars (Tushare api "daily") for tsCode between
// startDate and endDate.
//
// Each row is decoded column by column. A missing column decodes to the
// field's zero value, and so does a numeric column that cannot be read as a
// number.
func (c *Client) GetStockDaily(tsCode, startDate, endDate string) ([]StockDaily, error) {
	resp, err := c.Query("daily", map[string]interface{}{
		"ts_code":    tsCode,
		"start_date": startDate,
		"end_date":   endDate,
	}, "")
	if err != nil {
		return nil, err
	}

	rows := resp.Data.ToMapList()
	bars := make([]StockDaily, 0, len(rows))
	for _, row := range rows {
		bars = append(bars, StockDaily{
			TSCode:    getString(row, "ts_code"),
			TradeDate: getString(row, "trade_date"),
			Open:      getFloat64(row, "open"),
			High:      getFloat64(row, "high"),
			Low:       getFloat64(row, "low"),
			Close:     getFloat64(row, "close"),
			PreClose:  getFloat64(row, "pre_close"),
			Change:    getFloat64(row, "change"),
			PctChange: getFloat64(row, "pct_chg"),
			Volume:    getFloat64(row, "vol"),
			Amount:    getFloat64(row, "amount"),
		})
	}
	return bars, nil
}
// StockMinute is one minute bar for a stock.
type StockMinute struct {
	TSCode    string  `json:"ts_code"`    // stock code
	TradeTime string  `json:"trade_time"` // bar time, as returned by the API
	Open      float64 `json:"open"`       // opening price
	High      float64 `json:"high"`       // highest price
	Low       float64 `json:"low"`        // lowest price
	Close     float64 `json:"close"`      // closing price
	Volume    float64 `json:"vol"`        // volume
	Amount    float64 `json:"amount"`     // turnover
}
// GetStockMinute fetches minute bars for the stock tsCode between startDate
// and endDate.
//
// freq selects the bar size:
//   - "" (the default), "1" or "1min" for 1-minute bars;
//   - "5", "15", "30" or "60", with or without a "min" suffix, for larger bars.
//
// Any other value returns an error rather than silently falling back to
// 1-minute data.
//
// NOTE(review): the per-frequency api names (stk_mins5, ...) are carried over
// from the original code. Confirm them against the Tushare documentation,
// which may instead expect a single stk_mins endpoint with a freq parameter.
func (c *Client) GetStockMinute(tsCode, startDate, endDate string, freq string) ([]StockMinute, error) {
	var apiName string
	switch freq {
	case "", "1", "1min":
		apiName = "stk_mins"
	case "5", "5min":
		apiName = "stk_mins5"
	case "15", "15min":
		apiName = "stk_mins15"
	case "30", "30min":
		apiName = "stk_mins30"
	case "60", "60min":
		apiName = "stk_mins60"
	default:
		return nil, fmt.Errorf("unsupported stock minute frequency %q", freq)
	}

	params := map[string]interface{}{
		"ts_code":    tsCode,
		"start_date": startDate,
		"end_date":   endDate,
	}
	resp, err := c.Query(apiName, params, "")
	if err != nil {
		return nil, err
	}

	items := resp.Data.ToMapList()
	result := make([]StockMinute, len(items))
	for i, item := range items {
		result[i] = StockMinute{
			TSCode:    getString(item, "ts_code"),
			TradeTime: getString(item, "trade_time"),
			Open:      getFloat64(item, "open"),
			High:      getFloat64(item, "high"),
			Low:       getFloat64(item, "low"),
			Close:     getFloat64(item, "close"),
			Volume:    getFloat64(item, "vol"),
			Amount:    getFloat64(item, "amount"),
		}
	}
	return result, nil
}
// FuturesDaily is one daily bar for a futures contract.
type FuturesDaily struct {
	TSCode       string  `json:"ts_code"`    // contract code
	TradeDate    string  `json:"trade_date"` // trade date
	Open         float64 `json:"open"`       // opening price
	High         float64 `json:"high"`       // highest price
	Low          float64 `json:"low"`        // lowest price
	Close        float64 `json:"close"`      // closing price
	PreClose     float64 `json:"pre_close"`  // previous close
	Change       float64 `json:"change"`     // price change
	PctChange    float64 `json:"pct_chg"`    // percentage change
	Volume       float64 `json:"vol"`        // volume
	Amount       float64 `json:"amount"`     // turnover
	OpenInterest float64 `json:"oi"`         // open interest
	OiChange     float64 `json:"oi_chg"`     // change in open interest
}
// GetFuturesDaily fetches daily bars (Tushare api "fut_daily") for the
// futures contract tsCode between startDate and endDate.
//
// A missing column decodes to the field's zero value.
func (c *Client) GetFuturesDaily(tsCode, startDate, endDate string) ([]FuturesDaily, error) {
	resp, err := c.Query("fut_daily", map[string]interface{}{
		"ts_code":    tsCode,
		"start_date": startDate,
		"end_date":   endDate,
	}, "")
	if err != nil {
		return nil, err
	}

	rows := resp.Data.ToMapList()
	bars := make([]FuturesDaily, 0, len(rows))
	for _, row := range rows {
		bars = append(bars, FuturesDaily{
			TSCode:       getString(row, "ts_code"),
			TradeDate:    getString(row, "trade_date"),
			Open:         getFloat64(row, "open"),
			High:         getFloat64(row, "high"),
			Low:          getFloat64(row, "low"),
			Close:        getFloat64(row, "close"),
			PreClose:     getFloat64(row, "pre_close"),
			Change:       getFloat64(row, "change"),
			PctChange:    getFloat64(row, "pct_chg"),
			Volume:       getFloat64(row, "vol"),
			Amount:       getFloat64(row, "amount"),
			OpenInterest: getFloat64(row, "oi"),
			OiChange:     getFloat64(row, "oi_chg"),
		})
	}
	return bars, nil
}
// FuturesMinute is one minute bar for a futures contract.
type FuturesMinute struct {
	TSCode       string  `json:"ts_code"`    // contract code
	TradeTime    string  `json:"trade_time"` // bar time, as returned by the API
	Open         float64 `json:"open"`       // opening price
	High         float64 `json:"high"`       // highest price
	Low          float64 `json:"low"`        // lowest price
	Close        float64 `json:"close"`      // closing price
	Volume       float64 `json:"vol"`        // volume
	Amount       float64 `json:"amount"`     // turnover
	OpenInterest float64 `json:"oi"`         // open interest
}
// GetFuturesMinute fetches minute bars for the futures contract tsCode
// between startDate and endDate.
//
// freq selects the bar size:
//   - "" (the default), "1" or "1min" for 1-minute bars;
//   - "5", "15", "30" or "60", with or without a "min" suffix, for larger bars.
//
// Any other value returns an error rather than silently falling back to
// 1-minute data.
//
// NOTE(review): the per-frequency api names (fut_mins5, ...) are carried over
// from the original code. Confirm them against the Tushare documentation.
func (c *Client) GetFuturesMinute(tsCode, startDate, endDate string, freq string) ([]FuturesMinute, error) {
	var apiName string
	switch freq {
	case "", "1", "1min":
		apiName = "fut_mins"
	case "5", "5min":
		apiName = "fut_mins5"
	case "15", "15min":
		apiName = "fut_mins15"
	case "30", "30min":
		apiName = "fut_mins30"
	case "60", "60min":
		apiName = "fut_mins60"
	default:
		return nil, fmt.Errorf("unsupported futures minute frequency %q", freq)
	}

	params := map[string]interface{}{
		"ts_code":    tsCode,
		"start_date": startDate,
		"end_date":   endDate,
	}
	resp, err := c.Query(apiName, params, "")
	if err != nil {
		return nil, err
	}

	items := resp.Data.ToMapList()
	result := make([]FuturesMinute, len(items))
	for i, item := range items {
		result[i] = FuturesMinute{
			TSCode:       getString(item, "ts_code"),
			TradeTime:    getString(item, "trade_time"),
			Open:         getFloat64(item, "open"),
			High:         getFloat64(item, "high"),
			Low:          getFloat64(item, "low"),
			Close:        getFloat64(item, "close"),
			Volume:       getFloat64(item, "vol"),
			Amount:       getFloat64(item, "amount"),
			OpenInterest: getFloat64(item, "oi"),
		}
	}
	return result, nil
}
// StockBasic holds the basic listing information for one stock.
type StockBasic struct {
	TSCode     string `json:"ts_code"`
	Symbol     string `json:"symbol"`
	Name       string `json:"name"`
	Area       string `json:"area"`
	Industry   string `json:"industry"`
	FullName   string `json:"fullname"`
	EnName     string `json:"enname"`
	CNName     string `json:"cnspell"` // NOTE(review): field name does not match its "cnspell" column
	Market     string `json:"market"`
	Exchange   string `json:"exchange"`
	CurrType   string `json:"curr_type"`
	ListStatus string `json:"list_status"`
	ListDate   string `json:"list_date"`
	DelistDate string `json:"delist_date"`
	IsHS       string `json:"is_hs"`
}
// GetStockBasic fetches basic information (Tushare api "stock_basic") for
// stocks. Only rows with list_status "L" are requested. A missing column
// decodes to an empty string.
func (c *Client) GetStockBasic() ([]StockBasic, error) {
	filter := map[string]interface{}{"list_status": "L"}
	resp, err := c.Query("stock_basic", filter, "")
	if err != nil {
		return nil, err
	}

	rows := resp.Data.ToMapList()
	stocks := make([]StockBasic, 0, len(rows))
	for _, row := range rows {
		stocks = append(stocks, StockBasic{
			TSCode:     getString(row, "ts_code"),
			Symbol:     getString(row, "symbol"),
			Name:       getString(row, "name"),
			Area:       getString(row, "area"),
			Industry:   getString(row, "industry"),
			FullName:   getString(row, "fullname"),
			EnName:     getString(row, "enname"),
			CNName:     getString(row, "cnspell"),
			Market:     getString(row, "market"),
			Exchange:   getString(row, "exchange"),
			CurrType:   getString(row, "curr_type"),
			ListStatus: getString(row, "list_status"),
			ListDate:   getString(row, "list_date"),
			DelistDate: getString(row, "delist_date"),
			IsHS:       getString(row, "is_hs"),
		})
	}
	return stocks, nil
}
// FuturesBasic holds the basic information for one futures contract.
type FuturesBasic struct {
	TSCode       string  `json:"ts_code"`
	Symbol       string  `json:"symbol"`
	Name         string  `json:"name"`
	Exchange     string  `json:"exchange"`
	FutCode      string  `json:"fut_code"`
	Multiplier   float64 `json:"multiplier"`
	TradeUnit    string  `json:"trade_unit"`
	PerUnit      float64 `json:"per_unit"`
	DeliveryDate string  `json:"delivery_date"`
	ListDate     string  `json:"list_date"`
	DelistDate   string  `json:"delist_date"`
}
// GetFuturesBasic fetches futures contract information (Tushare api
// "fut_basic"). A non-empty exchange is forwarded as the "exchange" filter.
// An empty exchange sends no parameters at all.
func (c *Client) GetFuturesBasic(exchange string) ([]FuturesBasic, error) {
	params := make(map[string]interface{})
	if len(exchange) > 0 {
		params["exchange"] = exchange
	}

	resp, err := c.Query("fut_basic", params, "")
	if err != nil {
		return nil, err
	}

	rows := resp.Data.ToMapList()
	contracts := make([]FuturesBasic, 0, len(rows))
	for _, row := range rows {
		contracts = append(contracts, FuturesBasic{
			TSCode:       getString(row, "ts_code"),
			Symbol:       getString(row, "symbol"),
			Name:         getString(row, "name"),
			Exchange:     getString(row, "exchange"),
			FutCode:      getString(row, "fut_code"),
			Multiplier:   getFloat64(row, "multiplier"),
			TradeUnit:    getString(row, "trade_unit"),
			PerUnit:      getFloat64(row, "per_unit"),
			DeliveryDate: getString(row, "delivery_date"),
			ListDate:     getString(row, "list_date"),
			DelistDate:   getString(row, "delist_date"),
		})
	}
	return contracts, nil
}
// TradeCal is one day of an exchange's trading calendar.
type TradeCal struct {
	Exchange     string `json:"exchange"`
	CalDate      string `json:"cal_date"`
	IsOpen       int    `json:"is_open"`
	PretradeDate string `json:"pretrade_date"`
}
// GetTradeCal fetches the trading calendar (Tushare api "trade_cal") of
// exchange between startDate and endDate.
//
// A non-negative isOpen is forwarded as the "is_open" filter. Pass a negative
// value (e.g. -1) to omit that parameter entirely.
func (c *Client) GetTradeCal(exchange, startDate, endDate string, isOpen int) ([]TradeCal, error) {
	params := make(map[string]interface{}, 4)
	params["exchange"] = exchange
	params["start_date"] = startDate
	params["end_date"] = endDate
	if isOpen > -1 {
		params["is_open"] = isOpen
	}

	resp, err := c.Query("trade_cal", params, "")
	if err != nil {
		return nil, err
	}

	rows := resp.Data.ToMapList()
	calendar := make([]TradeCal, 0, len(rows))
	for _, row := range rows {
		calendar = append(calendar, TradeCal{
			Exchange:     getString(row, "exchange"),
			CalDate:      getString(row, "cal_date"),
			IsOpen:       getInt(row, "is_open"),
			PretradeDate: getString(row, "pretrade_date"),
		})
	}
	return calendar, nil
}
// FuturesHolding is one row of a futures position ranking (期货持仓排名).
type FuturesHolding struct {
	TSCode      string `json:"ts_code"`
	TradeDate   string `json:"trade_date"`
	Symbol      string `json:"symbol"`
	Broker      string `json:"broker"`
	Vol         int64  `json:"vol"`
	VolChange   int64  `json:"vol_chg"`
	LongHld     int64  `json:"long_hld"`
	LongChange  int64  `json:"long_chg"`
	ShortHld    int64  `json:"short_hld"`
	ShortChange int64  `json:"short_chg"`
}
// Helper functions that read loosely-typed row values produced by ToMapList.

// getString reads m[key] as a string. A missing key or nil value yields "".
// Strings and byte slices are returned as-is, and any other value is
// rendered with fmt's %v verb.
func getString(m map[string]interface{}, key string) string {
	raw, found := m[key]
	if !found || raw == nil {
		return ""
	}
	if s, isString := raw.(string); isString {
		return s
	}
	if b, isBytes := raw.([]byte); isBytes {
		return string(b)
	}
	return fmt.Sprintf("%v", raw)
}
// getFloat64 reads m[key] as a float64. float64, int and int64 values are
// converted directly, and strings are parsed with fmt.Sscanf's %f verb. A
// missing key, nil value, unparseable string or any other type yields 0.
func getFloat64(m map[string]interface{}, key string) float64 {
	raw, found := m[key]
	if !found || raw == nil {
		return 0
	}
	switch num := raw.(type) {
	case float64:
		return num
	case int:
		return float64(num)
	case int64:
		return float64(num)
	case string:
		var parsed float64
		// Parse errors are deliberately ignored: parsed stays 0 when nothing
		// could be scanned.
		fmt.Sscanf(num, "%f", &parsed)
		return parsed
	default:
		return 0
	}
}
// getInt reads m[key] as an int. int and int64 values are converted directly,
// float64 values are truncated toward zero, and strings are parsed with
// fmt.Sscanf's %d verb. A missing key, nil value, unparseable string or any
// other type yields 0.
func getInt(m map[string]interface{}, key string) int {
	raw, found := m[key]
	if !found || raw == nil {
		return 0
	}
	switch n := raw.(type) {
	case int:
		return n
	case int64:
		return int(n)
	case float64:
		return int(n)
	case string:
		var parsed int
		// Parse errors are deliberately ignored: parsed stays 0 when nothing
		// could be scanned.
		fmt.Sscanf(n, "%d", &parsed)
		return parsed
	}
	return 0
}