likes
comments
collection
share

一个orm类库是怎么诞生的

作者站长头像
站长
· 阅读数 8

事情是这样的

下面的完整示例代码都开源到github了

因为我们的业务需要用python来实现一些定时脚本、数据爬虫之类的工具,这些工具也经常使用到SQL操作。

在python里调用数据库肯定优先会想到 cx_Oracle、pymysql 这样现成的工具库啦!

(我这次也是在这两个工具库的基础上二次封装开发的)

但是对于用习惯thinkphp、laravel这种框架的我来说,pymysql和cx_Oracle这种操作数据库的方式多少有点不太习惯

    import pymysql
    # Open the database connection
    db = pymysql.connect(host='localhost', user='testuser', password='test123', database='TESTDB')
    # Get a cursor object with cursor()
    cursor = db.cursor()
    # SQL INSERT statement
    sql = """INSERT INTO EMPLOYEE(FIRST_NAME, LAST_NAME, AGE, SEX, INCOME) VALUES ('Mac', 'Mohan', 20, 'M', 2000)"""
    try:
        # Execute the SQL statement
        cursor.execute(sql)
        # Commit the change to the database
        db.commit()
    except Exception:
        # Roll back on any error, then let the error surface instead of hiding it
        db.rollback()
        raise
    finally:
        # Close the database connection (the original comment promised this but never did it)
        db.close()

    # --- second example: a simple query ---
    import pymysql
    # Open the database connection
    db = pymysql.connect(host='localhost', user='testuser', password='test123', database='TESTDB')
    # Create a cursor object with cursor()
    cursor = db.cursor()
    # Run an SQL query with execute()
    cursor.execute("SELECT VERSION()")
    # Fetch a single row with fetchone()
    data = cursor.fetchone()
    print ("Database version : %s " % data)
    # Close the database connection
    db.close()
    

那么,开始改造它吧

我可是用习惯了世界上最好的编程语言的开发仔😄

在php或者go的编码过程中我早已经习惯那种链式操作,所以我在想python中是不是也可以继续沿用这种操作习惯或者类似这种风格也可以,比如:

    // php
    $data = Model::where()->fields()->order()->select();
    
    // go
    data := XXXX.Where().Select().First()

于是乎,它来了

第一步,先确定一下调用方式

'''
Author: psq
Date: 2023-09-12 10:38:46
LastEditors: psq
LastEditTime: 2023-09-12 11:21:50
'''

from utils.databases.mysql    import  mysql
from utils.databases.oracle   import  oracle

if __name__ == '__main__':

    db = mysql.construct(env = 'mysql')

    # Query data
    select = db.select(
        table   =   'test',     # table to query
        alias   =   'u',        # alias for the table
        where   =   {'u.xxx': 'xxx', 'u.yyy': 'yyy'},  # conditions, ANDed together; keys must be distinct
        fields  =   ['u.xxx', 'a.xxx', 't.xxx'],  # columns to return
        orderby =   'u.xxx asc',    # ORDER BY expression
        limit   =   [5, 1],     # [page size, page index]; offset = size * index, so the first page is index 0
        join    =   [
            {
                'connect':  'left', # optional; defaults to a left join
                'alias':    'a',
                'table':    'account',
                'on':       'u.xxx = a.xxx'
            },
            {
                'connect':  'left', # optional; defaults to a left join
                'alias':    't',
                'table':    'account_tree',
                'on':       'u.xxxx = t.xxx'
            }
        ],
    )

    print(select) # list of rows

    # Update data
    update = db.update(
        table   =   'test',
        values  =   {'test_name': 2},
        where   =   {'test_id': 1}
    )

    print(update) # bool

    # Insert one row
    insert = db.insert(
        table   =   'test',
        values  =   {'test_id': 21, 'test_name': 1, 'test_value': 1, 'test_age': 1}
    )

    print(insert) # bool

    # Insert several rows in one statement
    insertall = db.insertall(
        table   =   'test',
        values  =   [
            {'test_id': 22, 'test_name': 1, 'test_value': 1, 'test_age': 1},
            {'test_id': 23, 'test_name': 1, 'test_value': 1, 'test_age': 1}
        ]
    )

    print(insertall) # bool

    # Custom read statement
    query = db.query(
        sql = 'select test_id from test'
    )

    print(query) # list of rows

    # Custom write statement
    execute = db.execute(
        sql = 'insert into test (test_id) values (23)'
    )

    print(execute) # bool

    # Cross-database (XA) transaction
    XAtransaction = db.XAtransaction(
        env     =   ['sit', 'sit2'],
        sql     =   {
            'sit':  [
                'insert into test_table_1 (tb_a) values ("1");',
                'insert into test_table_1 (tb_a) values ("2");'
            ],
            'sit2': [
                'insert into test_table_2 (tb_b) values ("3");',
                'insert into test_table_2 (tb_v) values ("4");',
                'insert into test_table_2 (tb_v) values ("5");'
            ]
        }
    )

    print(XAtransaction) # bool
    # Only MySQL is supported for now !!!
    # Every table taking part in the transaction must use the InnoDB storage engine !!!
    # Every env listed must have an entry in sql (a missing entry raises KeyError) !!!
    # If two of the databases live on the same MySQL server instance the call fails,
    # because the same XA xid is reused on every branch !!!
    

看上去就舒服多了😄!

那么怎么连接数据库、怎么将参数转化成SQL语句、怎么打日志、怎么处理异常呢?下面就开始动手。

#!/usr/bin/python
# -*- coding: UTF-8 -*-

import  json
import  sys
import  uuid
from typing import Any, Sequence
sys.path.append('./')

from utils.databases.mysql.connect  import pyconnection
from utils.databases.mysql.build    import queryconstruction
from utils.databases.mysql.build    import executeconstruction

class construct(object):
    """Keyword-argument facade over one MySQL connection (``self.db``).

    The CRUD helpers (select / insert / insertall / update / delete) hand
    their keyword arguments to the SQL builders, which derive the SQL verb
    from the *name of the calling function*. Each of those methods must
    therefore call ``Construction(...)`` directly from its own body; do not
    move that call into a shared helper or a lambda.
    """

    def __init__(self, env: str):
        """Connect to the environment ``env``.

        env: top-level key of ./config/db.json (path relative to the current
             working directory) holding the connection settings.
        """
        self.db = self._connect(env)

    @staticmethod
    def _load_config(env: str) -> dict:
        """Return the connection settings stored under ``env`` in ./config/db.json."""
        # JSON text is UTF-8 by specification; don't depend on the platform locale.
        with open("./config/db.json", "r", encoding="utf-8") as handle:
            return json.load(handle)[env]

    def _connect(self, env: str):
        """Open a new pyconnection.Connect for ``env`` without touching ``self.db``."""
        return pyconnection.Connect(env = self._load_config(env))

    def select(self, **parameter: Any) -> Sequence[dict]:
        """Run a SELECT built from table/alias/fields/join/where/orderby/limit.

        Returns the rows produced by Connect.querySQL (cursor.fetchall()).
        """
        # Called directly from "select": the builder uses this method's name as the verb.
        return self.db.querySQL(queryconstruction.Construction(parameter).sql)

    def insert(self, **parameter: Any) -> bool:
        """Insert one row: ``table`` plus ``values`` (column -> value).

        ``autocommit`` (default True) commits on success / rolls back on failure.
        Returns True on success, False if the statement failed.
        """
        return self.db.executeSQL(
            sql         = executeconstruction.Construction(parameter).sql,
            autocommit  = parameter.get('autocommit', True),
        )

    def insertall(self, **parameter: Any) -> bool:
        """Insert several rows in one statement: ``values`` is a list of column -> value dicts.

        ``autocommit`` as for insert. Returns True on success, False on failure.
        """
        return self.db.executeSQL(
            sql         = executeconstruction.Construction(parameter).sql,
            autocommit  = parameter.get('autocommit', True),
        )

    def update(self, **parameter: Any) -> bool:
        """Update ``table`` setting ``values`` on rows matching ``where``.

        ``autocommit`` as for insert. Returns True on success, False on failure.
        """
        return self.db.executeSQL(
            sql         = executeconstruction.Construction(parameter).sql,
            autocommit  = parameter.get('autocommit', True),
        )

    def delete(self, **parameter: Any) -> bool:
        """Delete rows of ``table`` matching ``where``.

        ``autocommit`` as for insert. Returns True on success, False on failure.
        """
        return self.db.executeSQL(
            sql         = executeconstruction.Construction(parameter).sql,
            autocommit  = parameter.get('autocommit', True),
        )

    def query(self, **parameter: Any) -> Sequence[dict]:
        """Run the raw read statement ``sql`` and return its rows."""
        return self.db.querySQL(parameter['sql'])

    def execute(self, **parameter: Any) -> bool:
        """Run the raw write statement ``sql``.

        ``autocommit`` as for insert. Returns True on success, False on failure.
        """
        return self.db.executeSQL(
            sql         = parameter['sql'],
            autocommit  = parameter.get('autocommit', True),
        )

    @staticmethod
    def _xa_abort(conn, xid: str) -> None:
        """Best-effort rollback of one XA branch, whatever phase it reached.

        "xa end" may fail if the branch was already ended or prepared; every
        error here is ignored on purpose because the caller is already
        propagating the original failure.
        """
        for statement in (f"xa end {xid};", f"xa rollback {xid};"):
            try:
                conn.querySQL(statement)
            except Exception:
                pass

    def XAtransaction(self, **parameter: Any) -> bool:
        """Run write statements on several environments as one XA transaction.

        parameter:
            env: list of environment names (keys of ./config/db.json), in order.
            sql: dict mapping every name in ``env`` to a list of write statements.

        Phase 1 opens one connection per environment and runs "xa start", the
        statements (autocommit off), "xa end" and "xa prepare". Phase 2 then
        issues "xa commit" on every branch if every statement succeeded, or
        "xa rollback" on every branch otherwise. The same connections are
        kept for both phases, and ``self.db`` is left untouched.

        Returns True when the transaction was committed, False when it was
        rolled back because a statement failed.
        Raises whatever phase 1 raised (RuntimeError from Connect, KeyError for
        an env missing from ``sql``) after rolling back the branches already
        started. The same xid is used on every branch.
        """
        xid = f"'{uuid.uuid1()}'"
        statementResults = []
        branches = []  # connections that have successfully issued "xa start", in env order

        try:
            for name in parameter['env']:
                conn = self._connect(name)
                conn.querySQL(f"xa start {xid};")
                branches.append(conn)
                for statement in parameter['sql'][name]:
                    # executeSQL reports failures as False instead of raising.
                    statementResults.append(conn.executeSQL(sql = statement, autocommit = False))
                conn.querySQL(f"xa end {xid};")
                conn.querySQL(f"xa prepare {xid};")
        except Exception:
            # Don't leave started/prepared branches dangling on the servers.
            for conn in branches:
                self._xa_abort(conn, xid)
            raise

        committed = False not in statementResults
        decision = 'commit' if committed else 'rollback'
        for conn in branches:
            conn.querySQL(f"xa {decision} {xid};")

        return committed

这里第一步当然是对语句进行分流啦


'''
Author: psq
Date: 2021-04-28 16:21:01
LastEditors: psq
LastEditTime: 2023-09-12 10:52:59
'''
#!/usr/bin/python
# -*- coding: UTF-8 -*-
#
from ast import Try
import  sys
import  time
import  pymysql

sys.path.append('./')


class Connect():
    """Wrapper around one pymysql connection with a DictCursor.

    Every statement run through querySQL / executeSQL records an
    operation-log dict on ``self.log`` that callers can persist however
    they like. ``self.sql`` and ``self.result`` hold the last statement and
    whether it succeeded.
    """

    def __init__(self, env: dict):
        """Open the connection.

        env: connection settings with the keys port, host, user, charset,
             password and db (coerced to int / str).
        Raises RuntimeError('connectErr:', original_error) if connecting fails.
        """
        # Time (ms since epoch) the connection was opened; kept for backward compatibility.
        self.starttime = self._now_ms()

        try:
            self.db  = pymysql.connect(
                port     = int(env['port']),
                host     = str(env['host']),
                user     = str(env['user']),
                charset  = str(env['charset']),
                password = str(env['password']),
                db       = str(env['db']),
            )
        except Exception as e:
            raise RuntimeError('connectErr:', e) from e

        self.cursor = self.db.cursor(cursor = pymysql.cursors.DictCursor)

    @staticmethod
    def _now_ms() -> int:
        """Current wall-clock time in milliseconds since the epoch."""
        return int(round(time.time() * 1000))

    def _build_log(self, started_ms: int, error=None) -> dict:
        """Build the operation log for the statement in ``self.sql`` that started at ``started_ms``."""
        ended_ms = self._now_ms()
        return {
            "logType":      "DBLogs",
            "dbType":       "mysql",
            "startTime":    time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(started_ms / 1000)),
            "endTime":      time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(ended_ms / 1000)),
            "runResult":    self.result,
            "runTime":      (ended_ms - started_ms) / 1000,
            "sql":          self.sql,
            "error":        None if error is None else str(error),
        }

    def querySQL(self, sql: str) -> tuple:
        """Execute a read statement and return cursor.fetchall().

        Raises RuntimeError('executeErr:', original_error) if execution fails;
        the failure is still recorded in ``self.log``.
        """
        # Time each statement individually, not from connection time.
        started = self._now_ms()
        self.sql = sql
        error = None

        try:
            self.cursor.execute(sql)
            self.result = True
        except Exception as e:
            self.result = False
            error = e
            raise RuntimeError('executeErr:', e) from e
        finally:
            self.log = self._build_log(started, error)

        return self.cursor.fetchall()

    def executeSQL(self, sql: str, autocommit: bool) -> bool:
        """Execute a write statement.

        autocommit: when True, commit on success and roll back on failure.
        Returns True on success, False on failure (the error is not raised but
        is recorded in ``self.log``).
        """
        started = self._now_ms()
        self.sql = sql
        error = None

        try:
            self.cursor.execute(sql)
            if autocommit:
                self.db.commit()
            self.result = True
        except Exception as e:
            # With autocommit=False the caller owns the transaction (e.g. an XA
            # branch), so rolling back is left to it.
            if autocommit:
                self.db.rollback()
            self.result = False
            error = e

        self.log = self._build_log(started, error)

        return self.result

    def __del__(self):
        # Best-effort cleanup. self.db may not exist if __init__ failed, and
        # closing an already-closed connection can raise; neither must escape __del__.
        try:
            self.db.close()
        except Exception:
            pass
  

然后就是要构造出执行语句

读语句

'''
Author: psq
Date: 2023-06-29 18:09:52
LastEditors: psq
LastEditTime: 2023-09-12 11:02:09
'''
#!/usr/bin/python
# -*- coding: UTF-8 -*-
import traceback

class Construction(object):
    """Build a MySQL SELECT statement from keyword-style parameters.

    The SQL verb is the *name of the calling function* (a method named
    ``select`` produces ``select ...``), so this class must be instantiated
    directly from such a method.

    param keys:
        table    (required) table name
        fields   list of column expressions, or a raw string; default '*'
        alias    alias for ``table``
        join     list of dicts with keys ``table`` and ``on``, optional
                 ``alias`` and ``connect`` (join type, default 'left')
        where    raw condition string, or dict column -> value (ANDed, values
                 rendered as quoted strings); empty means no WHERE
        orderby  raw ORDER BY expression
        limit    [count] or [count, page]; offset = count * page, so the
                 first page is 0
    The finished statement is stored in ``self.sql``.
    """

    def __init__(self, param):
        # extract_stack(limit=2) is [caller, this __init__]; the caller's
        # function name becomes the SQL verb. limit=2 avoids walking the whole stack.
        self.sql = traceback.extract_stack(limit=2)[-2].name

        self.__fields(param.get('fields', '*'))

        self.__table(param['table'])

        self.__alias(param.get('alias'))

        self.__join(param.get('join') or [])

        self.__where(param.get('where'))

        self.__orderby(param.get('orderby'))

        self.__limit(param.get('limit'))

    @staticmethod
    def _quote(value):
        """Render ``value`` as a single-quoted MySQL string literal.

        Backslashes and single quotes are doubled so the value cannot end the
        literal early. Assumes the server does not use the NO_BACKSLASH_ESCAPES
        SQL mode (there, doubled backslashes would be stored literally).
        """
        text = str(value).replace('\\', '\\\\').replace("'", "''")
        return f"'{text}'"

    def __fields(self, fields):
        # Every clause below starts with its own leading space, so clauses
        # never run together regardless of which optional parts are present.
        self.sql += ' ' + (fields if isinstance(fields, str) else ','.join(fields))

    def __table(self, table):
        self.sql += f" from {table}"

    def __alias(self, alias):
        if alias is not None:
            self.sql += f" as {alias}"

    def __join(self, join):
        for item in join:
            kind = item.get('connect', 'left').lower()
            try:
                clause = f" {kind} join {item['table']}"
                if item.get('alias'):
                    clause += f" as {item['alias']}"
                clause += f" on {item['on']}"
            except Exception as e:
                # Previously this error was constructed but never raised, so a
                # malformed join was silently dropped from the query.
                raise RuntimeError('executeErr:', e) from e
            self.sql += clause

    def __where(self, where):
        # None, '' and {} all mean "no filter"; an empty dict used to emit a dangling " where (".
        if not where:
            return
        if isinstance(where, str):
            self.sql += f" where {where}"
        else:
            conditions = ' and '.join(
                f"{key} = {self._quote(value)}" for key, value in where.items()
            )
            self.sql += f" where ({conditions})"

    def __limit(self, limit):
        if not limit:
            return
        count = int(limit[0])
        if len(limit) == 1:
            self.sql += f" limit {count}"
        else:
            # MySQL "LIMIT offset,count" with offset = count * page.
            self.sql += f" limit {count * int(limit[1])},{count}"

    def __orderby(self, orderby):
        if orderby is not None:
            self.sql += f" order by {orderby}"

    def __group(self, group):
        # Placeholder: GROUP BY is not implemented yet.
        pass

    def __having(self, having):
        # Placeholder: HAVING is not implemented yet.
        pass

写语句


#!/usr/bin/python
# -*- coding: UTF-8 -*-
import traceback

class Construction(object):
    """Build a MySQL INSERT / UPDATE / DELETE statement from keyword-style parameters.

    The statement kind is the *name of the calling function*: one of
    insert, insertall, update or delete. This class must therefore be
    instantiated directly from a method with one of those names.

    parameter keys:
        table   (required) target table
        values  insert: dict column -> value; insertall: non-empty list of
                such dicts, all with the same columns (column order taken from
                the first row); update: non-empty dict column -> value
        where   raw condition string, or dict column -> value (ANDed, values
                rendered as quoted strings); None means no WHERE clause
    Value rendering: None -> NULL, str -> quoted literal (except strings
    starting with 'to_date', which are passed through unquoted), anything
    else -> str(value).
    The finished statement is stored in ``self.sql``.
    """

    def __init__(self, parameter: dict):

        sqllist = {
            'insert'    :   'insert into',
            'insertall' :   'insert into',
            'update'    :   'update',
            'delete'    :   'delete from'
        }

        self.parameter = parameter

        # extract_stack(limit=2) is [caller, this __init__]; the caller's
        # function name selects the statement kind.
        self.f = traceback.extract_stack(limit=2)[-2].name

        self.sql = sqllist[self.f]

        self.__table(parameter['table'])

        self.__values(parameter.get('values'))

        self.__where(parameter.get('where'))

    @staticmethod
    def _quote(text):
        """Render ``text`` as a single-quoted MySQL string literal.

        Backslashes and single quotes are doubled so the value cannot end the
        literal early. Assumes the server does not use the NO_BACKSLASH_ESCAPES
        SQL mode (there, doubled backslashes would be stored literally).
        """
        escaped = str(text).replace('\\', '\\\\').replace("'", "''")
        return f"'{escaped}'"

    @classmethod
    def _literal(cls, value):
        """Render one value for a VALUES list or a SET assignment."""
        if value is None:
            return 'NULL'
        if isinstance(value, str):
            # Kept from the original builder: a value starting with "to_date"
            # is emitted as a raw SQL expression, not as a string.
            if value[:7] == 'to_date':
                return value
            return cls._quote(value)
        return str(value)

    def __table(self, table):

        self.sql += f" {table}"

    def __row(self, row, columns):
        # Emit values in the given column order so every row lines up with the column list.
        return ', '.join(self._literal(row[column]) for column in columns)

    def __values(self, values):

        if self.f == 'insert':

            columns = list(values)
            self.sql += f" ({', '.join(columns)}) values ({self.__row(values, columns)})"

        elif self.f == 'insertall':

            if not values:
                raise ValueError('insertall needs at least one row in values')
            columns = list(values[0])
            for row in values:
                if set(row) != set(columns):
                    raise ValueError(f'insertall rows must all have the columns {columns}')
            rows = ', '.join(f"({self.__row(row, columns)})" for row in values)
            self.sql += f" ({', '.join(columns)}) values {rows}"

        elif self.f == 'update':

            if not values:
                raise ValueError('update needs at least one column in values')
            assignments = ', '.join(
                f"{key} = {self._literal(value)}" for key, value in values.items()
            )
            self.sql += f" set {assignments}"

    def __where(self, where):

        if where is None:
            return

        # An empty condition must not silently turn into an unfiltered
        # UPDATE/DELETE of the whole table.
        if not where:
            raise ValueError('where must not be empty; omit it to target every row')

        if isinstance(where, str):

            self.sql += f" where {where}"

        else:

            conditions = ' and '.join(
                f"{key} = {self._quote(value)}" for key, value in where.items()
            )
            self.sql += f" where ({conditions})"

OK,至此一个简单的orm就完工了。

转载自:https://juejin.cn/post/7278299358910300201
评论
请登录