likes
comments
collection
share

ChatGPT|用ChatGLM-6B实现图文对话

作者站长头像
站长
· 阅读数 3

ChatGLM-6B版本是文本交互模型,和GPT3.5一样,不能识别图片,但是是否有方法可以实现呢?

那先看看效果(不过如果你GPT3.5,效果应该会更好,这里为了验证功能,所以使用ChatGLM-6B版本)。

效果

第一个:

ChatGPT|用ChatGLM-6B实现图文对话

第二个:

ChatGPT|用ChatGLM-6B实现图文对话

设计架构

对于图片的元素,一般我们只需要知道场景,标签和文字,就能描述这张图片,说干就干,架构如下:

ChatGPT|用ChatGLM-6B实现图文对话

1、我们判断对话的类型,如果是文本则直接使用ChatGLM的对话功能; 2、如果是图片则执行如下步骤: 1)识别图片的场景; 2)识别图片的标签有哪些; 3)用OCR服务识别图片中的文字; 3、将如上的三个信息汇总,按照Prompt模板将信息给到ChatGLM; 4、拿到返回的结果,返回给对话发起者;

实现代码

1、使用ChatGLM

(1)私有搭建 

具体参考https://github.com/THUDM/ChatGLM-6B,按照步骤搭建即可,然后提供API;

(2)使用API 

如果自己没有GPU资源,可以去这里直接注册,使用智谱提供的API,地址:https://open.bigmodel.ai/

(3)对话代码

...

kUseChatGLM = Flase
class ZhipuAI(BaseChat):
    ability_type = "chatglm_qa_6b"  # 能力类型
    engine_type = "chatglm_6b"  # 引擎类型
    if kUseChatGLM:
        ability_type = "chatGLM"
        engine_type = "chatGLM"
    API_KEY = "xxx"  # 接口API KEY
    PUBLIC_KEY = "xxx"  # 公钥

    def __init__(self, apitype="", dict_args_input={}):
        self.dict_args = {}
        for k, v in dict_args_input:
            self.dict_args[k] = v
        self.system_mess = []
        self.user_mess = []

    @staticmethod
    def getToken():
        token_result = kTokenCache.getValue('token')
        if not token_result:
            token_result = getToken(ZhipuAI.API_KEY, ZhipuAI.PUBLIC_KEY)
            kTokenCache.setValue('token', token_result, 60)
        return token_result

    ...

    def chat(self, mess):
        isok, response = self.openaiRequest(mess)
        if isok:
            return isok, response
        else:
            return isok, "error:" + str(response)

    def openaiRequest(self, mess):
        try:
            token_result = ZhipuAI.getToken()
            uuid1 = uuid.uuid1()
            request_task_no = str(uuid1).replace("-""")
            data = {
                "requestTaskNo": request_task_no,
                "prompt": mess
            }
            if self.user_mess and len(self.user_mess) > 0:
                if kUseChatGLM:
                    historyFormat = []
                    for history in self.user_mess:
                        logging.info("history: " + str(history))
                        try:
                            if len(history["query"]) > 0 and len(history["content"]):
                                historyFormat.append(history["query"])
                                historyFormat.append(history["content"])
                        except Exception as ex:
                            logging.error("err: " + str(ex))
                    data["history"] = historyFormat
                else:
                    data["history"] = self.user_mess
            logging.info("request data: " + str(data))
            if token_result and token_result["code"] == 200:
                token = token_result["data"]
                if kUseChatGLM:
                    resp = executeEngine(ZhipuAI.ability_type,
                                         ZhipuAI.engine_type, token, data)
                else:
                    resp = executeEngineV2(ZhipuAI.ability_type,
                                           ZhipuAI.engine_type, token, data)
                    while resp["code"] == 200 and resp['data']['taskStatus'] == 'PROCESSING':
                        taskOrderNo = resp['data']['taskOrderNo']
                        time.sleep(1)
                        resp = queryTaskResult(token, taskOrderNo)
                outputText = resp["data"]["outputText"]
                if outputText:
                    # keep userid to kMaxUserMessLength
                    if len(outputText) > 0:
                        if len(self.user_mess) > kMaxUserMessLength:
                            self.user_mess = self.user_mess[-kMaxUserMessLength]
                        else:
                            self.user_mess.append(
                                {"query": mess, "content": outputText})
                return True, resp["data"]["outputText"]
        except Exception as ex:
            logging.error("err: " + str(ex))
            return Falsestr(ex)
...

以上非完整代码,大家可以参考修改,主要功能就是用智谱提供的API,进行文本对话,然后存储历史记录。

2、分析图片

(1)方案选项 

1)图片识别标签,可以自己搭建,如果有兴趣可以参考百度的PaddlePaddle,具体搭建的方式:https://github.com/PaddlePaddle/PaddleClas; 2)为了快速验证,这里也可以使用云服务,如阿里云的https://ai.aliyun.com/image,腾讯云的https://console.cloud.tencent.com/tiia/detectlabel; 3)OCR的服务也有一些开源的,不过云上使用更方便,可以用腾讯云的https://console.cloud.tencent.com/ocr/overview

这里我就是用腾讯云的服务验证,当然不是商用,可以不需要花钱(有一定的免费额度)。

(2)获取标签

先安装SDK:

python3 -m pip install tencentcloud-sdk-python-tiia
python3 -m pip install tencentcloud-sdk-python

调用腾讯云的API,获取图片标签:

import base64
import json
import logging
from tencentcloud.common import credential
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
from tencentcloud.tiia.v20190529 import tiia_client, models as tiia_models # 图片标签的库
from tencentcloud.ocr.v20181119 import ocr_client, models as ocr_models # OCR库

SecretId = 'xxx' # 腾讯云的SecretId
SecretKey = 'xxx' # 腾讯云的SecretKey
kMaxLabels = 128

def get_images_tags(base64_data):
    try:
        # 实例化一个认证对象,入参需要传入腾讯云账户 SecretId 和 SecretKey,此处还需注意密钥对的保密
        # 代码泄露可能会导致 SecretId 和 SecretKey 泄露,并威胁账号下所有资源的安全性。以下代码示例仅供参考,建议采用更安全的方式来使用密钥,请参见:https://cloud.tencent.com/document/product/1278/85305
        # 密钥可前往官网控制台 https://console.cloud.tencent.com/cam/capi 进行获取
        cred = credential.Credential(SecretId, SecretKey)
        # 实例化一个http选项,可选的,没有特殊需求可以跳过
        http_profile = HttpProfile()
        http_profile.endpoint = "tiia.tencentcloudapi.com"
        # 实例化一个client选项,可选的,没有特殊需求可以跳过
        client_profile = ClientProfile()
        client_profile.httpProfile = http_profile
        # 实例化要请求产品的client对象,client_profile是可选的
        client = tiia_client.TiiaClient(cred, "ap-guangzhou", client_profile)
        # 实例化一个请求对象,每个接口都会对应一个request对象
        req = tiia_models.DetectLabelRequest()
        params = {'ImageBase64': base64_data}
        req.from_json_string(json.dumps(params))
        # 返回的resp是一个DetectLabelResponse的实例,与请求对象对应
        resp = client.DetectLabel(req)
        logging.info("get_images_tags: " + resp.to_json_string())
        return resp  # 输出json格式的字符串回包
    except TencentCloudSDKException as err:
        logging.error("get_images_tags err: " + str(err))
        return None

获取结果样例:

{"Response":{"Labels":[{"Name":"字体","Confidence":92,"FirstCategory":"其他","SecondCategory":"其他"},{"Name":"文本","Confidence":85,"FirstCategory":"卡证文档","SecondCategory":"其他"},{"Name":"品牌","Confidence":68,"FirstCategory":"物品","SecondCategory":"标牌标识"},{"Name":"线","Confidence":50,"FirstCategory":"物品","SecondCategory":"日常用品"},{"Name":"报告","Confidence":27,"FirstCategory":"物品","SecondCategory":"其他"}],"CameraLabels":null,"AlbumLabels":null,"NewsLabels":null,"RequestId":"e710697f-3054-494a-bf33-1990dbe25bbd"}}

(3)获取文字

先安装SDK:

python3 -m pip install tencentcloud-sdk-python-tiia
python3 -m pip install tencentcloud-sdk-python

调用腾讯云的API,获取图片标签:

import base64
import json
import logging
from tencentcloud.common import credential
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
from tencentcloud.tiia.v20190529 import tiia_client, models as tiia_models # 图片标签的库
from tencentcloud.ocr.v20181119 import ocr_client, models as ocr_models # OCR库

SecretId = 'xxx' # 腾讯云的SecretId
SecretKey = 'xxx' # 腾讯云的SecretKey
kMaxLabels = 128

def get_images_ocr(base64_data):
    try:
        cred = credential.Credential(SecretId, SecretKey)
        # 实例化一个http选项,可选的,没有特殊需求可以跳过
        http_profile = HttpProfile()
        http_profile.endpoint = "ocr.tencentcloudapi.com"
        # 实例化一个client选项,可选的,没有特殊需求可以跳过
        client_profile = ClientProfile()
        client_profile.httpProfile = http_profile
        # 实例化要请求产品的client对象,client_profile是可选的
        client = ocr_client.OcrClient(cred, "ap-guangzhou", client_profile)
        # 实例化一个请求对象,每个接口都会对应一个request对象
        req = ocr_models.RecognizeTableAccurateOCRRequest()
        params = {'ImageBase64': base64_data}
        req.from_json_string(json.dumps(params))
        # 返回的resp是一个DetectLabelResponse的实例,与请求对象对应
        resp = client.RecognizeTableAccurateOCR(req)
        # 输出json格式的字符串回包
        logging.info("get_images_ocr: " + resp.to_json_string())
        return resp
    except TencentCloudSDKException as err:
        logging.error("get_images_ocr err: " + str(err))
        return None

获取结果样例:

{"TableDetections": [{"Cells": [{"ColTl": 0, "RowTl": 0, "ColBr": 1, "RowBr": 1, "Text": "周末程序猿", "Type": "header", "Confidence": 100, "Polygon": [{"X": 1407, "Y": 588}, {"X": 2176, "Y": 584}, {"X": 2177, "Y": 748}, {"X": 1408, "Y": 752}]}], "Type": 0, "TableCoordPoint": [{"X": 1407, "Y": 584}, {"X": 2177, "Y": 584}, {"X": 2177, "Y": 752}, {"X": 1407, "Y": 752}]}] ...

3、组装Prompt

可以按照Prompt模板组装数据,如:

#####
图片有以下标签:{拿到的图片标签列表}
#####
图片中有文字,请你理解以下文字:[{拿到的文字标签列表}]...
现在请综合以上信息(标签、文字描述等),自然并详细地描述这副图片。请你不要在回答中暴露上述信息来源是图片分析服务。

注意:如果文字过长,可以用省略号...,具体长度可以设置256-1024之间。

具体代码:

def get_images_prompt(base64_data):
    try:
        # "#####\n经过某个图片分析服务,得出以下关于这幅图片的信息:\n"
        prompt = ""
        resp = get_images_tags(base64_data)
        labels_str = []
        if resp:
            for labels in resp.Labels:
                labels_str.append(labels.Name)
        logging.info(labels_str)
        if len(labels_str) > 0:
            prompt += f"#####\n图片有以下标签:{','.join(labels_str)}\n"
        logging.info("labels_str prompt: " + str(prompt))
        resp = get_images_ocr(base64_data)
        labels_ocr_str = []
        if resp:
            for cells in resp.TableDetections:
                for text in cells.Cells:
                    if len(text.Text) > 0:
                        labels_ocr_str.append(text.Text)
        logging.info(str(labels_ocr_str))
        if len(labels_ocr_str) > 0:
            labels = ','.join(labels_ocr_str).replace("\n""")
            if len(labels) > kMaxLabels:
                labels = labels[:kMaxLabels] + "..."
            prompt += f"#####\n图片中有文字,请你理解以下文字:["+labels+"]\n"
        else:
            prompt += f"#####\n图片中没有文字\n"
        prompt += "现在请综合以上信息(标签、文字描述等),自然并详细地描述这副图片。请你不要在回答中暴露上述信息来源是图片分析服务。"
        logging.info("labels_ocr_str prompt: " + str(prompt))
        return prompt
    except Exception as err:
        logging.err("get_images_ocr err: " + str(err))
        return None

3、将组装的Prompt发给ChatGLM

省略了一些工程代码(如果需要整体代码可以留言给我),测试代码就是这样:

@staticmethod
def test(content, userid=""):
    """
    测试函数
    :return:
    """
    logging.info("content->"+str(content))
    chat = SimpleChatFactory.getInstance(kChatModel, userid)
    image_pos = content.find("data:image/")
    if image_pos >= 0:
        images_base64 = re.sub(
            '<img src="data:image/(.*);base64,''', content)
        images_base64 = re.sub('">''', images_base64)
        logging.info("images_base64->"+str(images_base64))
        content = get_images_prompt(images_base64)
        chat.withOption(user_mess=False)
    # 如果是图像需要预处理
    if len(content) > kMaxContentLength:
        return kERRTimeout
    isok, replaycontent = chat.chat(content)
    if isok:
        return replaycontent
    else:
        return kERRTimeout

以上就是基于云服务+ChatGLM实现图片理解,如果你想自己搭建一个不依赖云服务的,其实也可以,上文给出了一些开源的方案,按照同样的方式替换服务,将生成的文本描述给到ChatGLM或者其他的GPT文本对话模型,实现扩展为多模态。

demo页面:

service-mpjvpuxa-1251014631.gz.apigw.tencentcs.com/static/chat…