likes
comments
collection
share

从零开始:开发讯飞星火认知大模型Python SDK的详细指南

作者站长头像
站长
· 阅读数 4

一、前言

最近讯飞发布了星火大模型2.0的API,我为此开发了一个简易的SDK,方便大家快速接入和开发。国内有这样的大模型API对于普通开发者来说是个不错的消息,讯飞还提供了大量免费的token,正好可以用来练习和实践。希望我们国内的大型模型能不断优化,创造出更多优秀的产品,为用户提供更好的服务。

二、申请星火认知大模型应用

首先访问官网申请API xinghuo.xfyun.cn/sparkapi 个人一年内有免费试用200w 的token数

从零开始:开发讯飞星火认知大模型Python SDK的详细指南

然后到服务管理面板获取申请到的 app_id、api_key、api_secret 应用信息

从零开始:开发讯飞星火认知大模型Python SDK的详细指南

最后就是熟悉下Web对接方式的接口文档 www.xfyun.cn/doc/spark/W…

三、制作简易版SDK

确认请求方式

官方API提供的是websocket连接的方式来进行通信。

从零开始:开发讯飞星火认知大模型Python SDK的详细指南

什么是websocket?

WebSocket是一种在单个TCP连接上进行全双工通信的协议。它为客户端和服务器之间的双向通信提供了一种更简单的方法,可以使数据在一个持久连接上进行交换,而不需要客户端不断发起HTTP请求。它与传统的HTTP请求-响应模式不同,可以实现服务器向客户端推送数据的功能,而不需要客户端发出请求。

python支持websocket通信常用的库有

库名相关地址说明
websocket-clientpypi.org/project/web…websocket-client 是 Python 的 WebSocket 客户端。它提供对 WebSocket 低级 API 的访问。
websocketswebsockets.readthedocs.io/en/stable/websockets 是一个用 Python 构建 WebSocket 服务器和客户端的库,重点关注正确性、简单性、稳健性和性能。
aiohttpdocs.aiohttp.org/en/stable/w…用于 asyncio 和 Python 的异步 HTTP 客户端/服务器。也支持websocket通信。

讯飞官方也提供了基于 websocket-client 库的DEMO案例,大家感兴趣可以下载看看。

但官方的DEMO好像不支持 asyncio,因此我打算用websockets与aiohttp库简单的重新封装下。

websocket请求响应示例

这里先展示下这两个库该如何发送websocket请求与处理响应。

websockets

ws服务端demo

#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author: Hui
# @File: ws_server_demo.py
# @Desc: { ws服务端demo }
# @Date: 2023/10/19 14:58
import asyncio
import websockets


async def hello(websocket):
    name = await websocket.recv()
    print(f"<<< {name}")

    greeting = f"Hello {name}!"

    await websocket.send(greeting)
    print(f">>> {greeting}")


async def main():
    print("ws server run on localhost:8765")
    async with websockets.serve(hello, "localhost", 8765):
        await asyncio.Future()  # run forever


if __name__ == "__main__":
    asyncio.run(main())

ws客户端demo

#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author: Hui
# @File: ws_client_demo.py
# @Desc: { ws客户端测试 }
# @Date: 2023/10/19 14:59
import asyncio
import websockets


async def hello():
    uri = "ws://localhost:8765"
    async with websockets.connect(uri) as websocket:
        name = input("What's your name? ")

        await websocket.send(name)
        print(f">>> {name}")

        greeting = await websocket.recv()
        print(f"<<< {greeting}")


if __name__ == "__main__":
    asyncio.run(hello())

demo运行效果

从零开始:开发讯飞星火认知大模型Python SDK的详细指南

aiohttp

ws服务端

#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author: Hui
# @File: ws_aiohttp_server.py
# @Desc: { aiohttp ws服务端demo }
# @Date: 2023/10/22 18:53
from aiohttp import web

app = web.Application()


async def websocket_handler(request):
    ws = web.WebSocketResponse()
    await ws.prepare(request)

    async for msg in ws:
        if msg.type == web.WSMsgType.text:
            if msg.data == 'close':
                print('websocket connection closed')
                await ws.close()
            else:
                print(f"recv data >>> {msg.data}")
                await ws.send_str(f"Echo: {msg.data}")
        elif msg.type == web.WSMsgType.error:
            print(f'ws connection closed with exception {ws.exception()}')

    return ws


app.router.add_get('/ws_demo', websocket_handler)

if __name__ == '__main__':
    web.run_app(app, host="localhost", port=8080)

ws客户端

#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author: Hui
# @File: ws_aiohttp_client.py
# @Desc: { aiohttp ws客户端demo }
# @Date: 2023/10/22 18:53
import aiohttp
import asyncio


async def ws_demo(session):
    async with session.ws_connect('ws://localhost:8080/ws_demo') as ws:
        test_data_list = ["hello ws", "close"]
        for test_data in test_data_list:
            print("send", test_data)
            await ws.send_str(test_data)

            msg = await ws.receive()
            print("recv", msg)


async def main():
    async with aiohttp.ClientSession() as session:
        await ws_demo(session)


if __name__ == '__main__':
    asyncio.run(main())

可以通过 aiohttp.ClientSession().ws_connect() 进行ws连接。

Demo运行效果

从零开始:开发讯飞星火认知大模型Python SDK的详细指南

封装简易SDK

交互流程

  1. 接受用户问题

  2. 组织api请求参数,

  3. api鉴权,获取鉴权后的ws url

  4. 建立ws连接,发送ws请求

  5. 处理响应

星火客户端初步封装

#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author: Hui
# @Desc: { 讯飞星火大模型客户端 }
# @Date: 2023/10/19 14:56
import base64
import hashlib
import hmac
import uuid
import json
from datetime import datetime
from time import mktime
from urllib.parse import urlparse, urlencode
from wsgiref.handlers import format_date_time

import aiohttp
import websockets


class SparkChatConfig(BaseModel):
    """星火聊天配置"""
    domain: str = Field(default="generalv2", description="api版本")
    temperature: float = Field(
        default=0.5,
        ge=0, le=1,
        description="取值为[0,1],默认为0.5, 核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高"
    )
    max_tokens: int = Field(default=2048, le=8192, ge=1, description="模型回答的tokens的最大长度")
    top_k: int = Field(default=4, le=6, ge=1, description="从k个候选中随机选择⼀个(⾮等概率)")


class SparkClient:
    SERVER_URI_MAPPING = {
        "general": "ws://spark-api.xf-yun.com/v1.1/chat",
        "generalv2": "ws://spark-api.xf-yun.com/v2.1/chat",
    }

    def __init__(
            self,
            app_id: str,
            api_secret: str,
            api_key: str,
            chat_conf: SparkChatConfig = None
    ):
        self.app_id = app_id
        self.api_secret = api_secret
        self.api_key = api_key
        self.chat_conf = chat_conf or SparkChatConfig()
        self.server_uri = self.SERVER_URI_MAPPING[self.chat_conf.domain]
        self.answer_full_content = ""

    def build_chat_params(self, msg_context_list=None, uid: str = None):
        """构造请求参数"""
        pass

    def _parse_chat_response(self, chat_resp: str) -> SparkMsgInfo:
        """解析chat响应"""
        pass

    def get_sign_url(self, host=None, path=None):
        """获取鉴权后url"""
        pass

    async def achat(self, msg_context_list: list, uid: str = None):
        chat_params = self.build_chat_params(msg_context_list, uid)
        sign_url = self.get_sign_url()

        async with websockets.connect(sign_url) as ws:
            await ws.send(chat_params)
            async for chat_resp in ws:
                spark_msg_info = self._parse_chat_response(chat_resp)
                yield spark_msg_info

SparkClient 初始化的基本属性是API服务的申请应用信息与api密钥

  • app_id
  • api_secret
  • api_key

然后也可以初始化聊天对话的配置,默认None使用默认的对话配置

class SparkChatConfig(BaseModel):
    """星火聊天配置"""
    domain: str = Field(default="generalv2", description="api版本")
    temperature: float = Field(
        default=0.5,
        ge=0, le=1,
        description="取值为[0,1],默认为0.5, 核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高"
    )
    max_tokens: int = Field(default=2048, le=8192, ge=1, description="模型回答的tokens的最大长度")
    top_k: int = Field(default=4, le=6, ge=1, description="从k个候选中随机选择⼀个(⾮等概率)")

然后根据 domain 参数获取对应api版本的ws的url。

对外的功能就是 achat 对话

async def achat(self, msg_context_list: list, uid: str = None):
    chat_params = self.build_chat_params(msg_context_list, uid)
    sign_url = self.get_sign_url()

    async with websockets.connect(sign_url) as ws:
        await ws.send(chat_params)
        async for chat_resp in ws:
            spark_msg_info = self._parse_chat_response(chat_resp)
            yield spark_msg_info
  • msg_context_list 用户提问的上下文信息列表

    • 
      msg_context_list = [
          {"role": 'user', "content": content},  # 用户的历史问题
          # {"role": 'assistant', "content": "....."},  # AI的历史回答结果
          # ....... 省略的历史对话
          # {"role": "user", "content": "你会做什么"}  # 最新的一条问题,如无需上下文,可只传最新一条问题
      ]
      
  • uid 用户唯一标识id,默认None使用uuid

这里使用的是websockets来进行ws连接处理,由于星火返回的数据是一段一段,因此这里使用 async for 来处接受返回的数据,然后调用 _parse_chat_response 方法处理聊天的数据,最后使用yield返回(异步生成器)。OK,到这里初步结构已经好了,接下来就是具体实现了。

构造对话聊天参数


def build_chat_params(self, msg_context_list=None, uid: str = None):
    """构造请求参数"""
    return json.dumps({
        "header": self._build_header(uid=uid),
        "parameter": self._build_parameter(),
        "payload": self._build_payload(msg_context_list)
    })

这里分别通过三个方法一起构造请求参数信息,分别是

  • _build_header 请求头部信息 (应用、用户信息)

  • _build_parameter 请求参数信息(对话的配置)

  • _build_payload 请求载体信息 (问题内容)

def _build_header(self, uid=None):
    return {
        "app_id": self.app_id,
        "uid": uid or uuid.uuid4().hex
    }

def _build_parameter(self):
    return {
        "chat": {
            "domain": self.chat_conf.domain,
            "temperature": self.chat_conf.temperature,
            "max_tokens": self.chat_conf.max_tokens,
            "top_k": self.chat_conf.top_k
        }
    }

def _build_payload(self, msg_context_list: list):
    return {
        "message": {
            # 如果想获取结合上下文的回答,需要开发者每次将历史问答信息一起传给服务端,如下示例
            # 注意:text里面的所有content内容加一起的tokens需要控制在8192以内,开发者如有较长对话需求,需要适当裁剪历史信息
            "text": msg_context_list
        }
    }

具体组织的信息就是用应用服务配置、聊天配置、以及用户传的问题信息。这样封装看起来就非常的清晰,也好在不同方法中扩展信息,不然一个大字典的组织不美观。

获取鉴权后的ws地址

应该是先获取鉴权后的url再构造请求参数,其实都可以,不影响。

def get_sign_url(self, host=None, path=None):
    """获取鉴权后url"""
    host = host or urlparse(self.server_uri).hostname
    path = path or urlparse(self.server_uri).path

    # 生成RFC1123格式的时间戳
    now = datetime.now()
    date = format_date_time(mktime(now.timetuple()))

    # 拼接字符串
    signature_origin = "host: " + host + "\n"
    signature_origin += "date: " + date + "\n"
    signature_origin += "GET " + path + " HTTP/1.1"

    # 进行hmac-sha256进行加密
    signature_sha = hmac.new(self.api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
                             digestmod=hashlib.sha256).digest()

    signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')

    authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'

    authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

    # 将请求的鉴权参数组合为字典
    v = {
        "authorization": authorization,
        "date": date,
        "host": host
    }
    # 拼接鉴权参数,生成url
    sign_url = self.server_uri + '?' + urlencode(v)
    return sign_url

这里的host、path其实是你服务器获取到请求的host与path,如果不给的话,默认使用的是讯飞api的host、path。这里鉴权流程如下

  1. 通过 host date request-line 信息进行hmac-sha256进行加密然后进行base64编码

  2. 然后把加密后得到的信息 与 api_key 再次进行base64编码

  3. 最后把鉴权的参数拼接到ws的uri上生成新的sign_url

发送请求处理响应

async def aiohttp_chat(self, msg_context_list: list, uid: str = None):
    chat_params = self.build_chat_params(msg_context_list, uid)
    sign_url = self.get_sign_url()

    async with aiohttp.ClientSession() as session:
        async with session.ws_connect(sign_url) as ws:
            await ws.send_str(chat_params)
            async for chat_resp in ws:
                spark_msg_info = self._parse_chat_response(chat_resp.data)
                yield spark_msg_info
                
async def achat(self, msg_context_list: list, uid: str = None):
    chat_params = self.build_chat_params(msg_context_list, uid)
    sign_url = self.get_sign_url()

    async with websockets.connect(sign_url) as ws:
        await ws.send(chat_params)
        async for chat_resp in ws:
            spark_msg_info = self._parse_chat_response(chat_resp)
            yield spark_msg_info
            
def _parse_chat_response(self, chat_resp: str) -> SparkMsgInfo:
    """解析chat响应"""
    chat_resp = json.loads(chat_resp)
    code = chat_resp["header"]["code"]
    if code != 0:
        raise ValueError(f"对话错误,{chat_resp}")

    text_list = chat_resp["payload"]["choices"]["text"]
    answer_content = text_list[0]["content"]
    self.answer_full_content += answer_content
    spark_msg_info = SparkMsgInfo()

    status = chat_resp["header"]["status"]
    sid = chat_resp["header"]["sid"]
    spark_msg_info.msg_sid = sid
    spark_msg_info.msg_status = status
    spark_msg_info.msg_content = answer_content

    if status == SparkMessageStatus.END_RET.value:
        usage_info = chat_resp["payload"]["usage"]["text"]
        spark_msg_info.usage_info = usage_info
        spark_msg_info.msg_content = self.answer_full_content
        self.answer_full_content = ""

    return spark_msg_info

解析响应其实就是获取星火回答的内容并组装成我们自己定义的格式 SparkMsgInfo

星火返回的格式内容如下

# 接口为流式返回,此示例为最后一次返回结果,开发者需要将接口多次返回的结果进行拼接展示
{
    "header":{
        "code":0,
        "message":"Success",
        "sid":"cht000cb087@dx18793cd421fb894542",
        "status":2
    },
    "payload":{
        "choices":{
            "status":2,
            "seq":0,
            "text":[
                {
                    "content":"我可以帮助你的吗?",
                    "role":"assistant",
                    "index":0
                }
            ]
        },
        "usage":{
            "text":{
                "question_tokens":4,
                "prompt_tokens":5,
                "completion_tokens":9,
                "total_tokens":14
            }
        }
    }
}

封装的内容如下


class SparkMessageStatus(Enum):
    """
    星火消息响应状态
    0-代表首个文本结果;1-代表中间文本结果;2-代表最后一个文本结果
    """

    FIRST_RET = 0
    MID_RET = 1
    END_RET = 2
    
    
class SparkMsgInfo(BaseModel):
    """星火消息信息"""

    msg_sid: str = Field(default=uuid.uuid4().hex, description="消息id,用于唯一标识⼀条消息")
    msg_type: str = Field(default="text", description="消息类型,目前仅支持text")
    msg_content: str = Field(default="", description="消息内容")
    msg_status: SparkMessageStatus = Field(default=SparkMessageStatus.FIRST_RET, description="消息状态")

    usage_info: Optional[SparkChatUsageInfo] = Field(default=None, description="消息使用信息")

最后有一个判断就是对话消息状态为 2 代表最后一个文本结果的时候,我把之前的回复的内容拼接到了 answer_full_content 中去,然后就是获取消息token的使用信息后再返回。

这里顺便把aiohttp请求的方式也写了下,主要是当练习用的,SDK连接的方式最好是确认一种方式好,不要两种混用,我一开始不知道aiohttp也可以websocket通信,所以封装的时候使用websockets库,后面才发现aiohttp也支持异步的websocket通信,要说功能性的话感觉还是要使用aiohttp,因为后面可能还要封装http请求的api。

四、使用体验

#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author: Hui
# @Desc: { 主入口测试模块 }
# @Date: 2023/10/20 19:39
import asyncio
import random

from spark_ai_sdk.client import SparkClient
from spark_ai_sdk.config import SparkMsgRole, SparkChatConfig, SparkMsgInfo


def build_user_msg_context_list(content):
    msg_context_list = [
        {"role": SparkMsgRole.USER.value, "content": content},  # 用户的历史问题
        # {"role": SparkMsgRole.ASSISTANT.value, "content": "....."},  # AI的历史回答结果
        # ....... 省略的历史对话
        # {"role": "user", "content": "你会做什么"}  # 最新的一条问题,如无需上下文,可只传最新一条问题
    ]
    return msg_context_list


async def main():
    chat_conf = SparkChatConfig(domain="generalv2", temperature=0.5, max_tokens=2048, top_k=3)
    spark_client = SparkClient(
        app_id="",
        api_secret="",
        api_key="",
        chat_conf=chat_conf
    )

    questions = ["程序员如何技术提升?", "如何提升系统并发", "如何找女朋友"]
    ques = random.choice(questions)
    msg_context_list = build_user_msg_context_list(content=ques)
    answer_full_content = ""

    async for chat_resp in spark_client.achat(msg_context_list):
        chat_resp: SparkMsgInfo = chat_resp
        answer_full_content += chat_resp.msg_content
        print(chat_resp)
    print(answer_full_content)


if __name__ == '__main__':
    asyncio.run(main())

从零开始:开发讯飞星火认知大模型Python SDK的详细指南

五、封装总结

做事情不要太着急,写代码也是, 古话说的好,磨刀不误砍柴工 。

  • 首先熟悉API文档,确认请求方式与鉴权、下载示例Demo观摩学习体验下。

  • 调研你不熟悉的领域,例如python 如何进行 websocket 通信,利用Google搜索查询资料,学习一些Demo,获取关键信息,然后逐渐扩展知识面,了解相关的技术,再度扩张,例如 webscokets、aiohttp库具体使用,还是要看官方文档才是最新、最权威的,这时就可以去pypi、github去查找这些开源库学习官方文档和教程。

  • 学习下别人写的一些开源库,可以获得一些灵感,最后就是让代码组织自己的想法去实现。

我一开始的初始想法,就是消息的上下文让调用方自行组织比较好,然后就是数据格式使用pydantic进行封装组织,这样比字典更好维护。后续可能会继续扩展,存储对话的上下文,简化组织消息格式。例如支持本地内存的形式或者Redis的形式进行存储对话上下文,由于token的限制,还可以指定一些存储对话上下文的策略,例如

  • 一次会话只保留最近30对话

  • 总结压缩会话等。

大家也可以去学习下其他优秀开源项目的实践

六、源代码

欢迎大家一起贡献学习。

Github:github.com/HuiDBK/Spar…

转载自:https://juejin.cn/post/7292781589477916726
评论
请登录