likes
comments
collection
share

Golang后端学习笔记 — 6. Golang操作数据库事务的方法

作者站长头像
站长
· 阅读数 21

之前,学习了对数据库的每个表执行CRUD操作。真实的场景中,我们经常需要执行一个事务,它组合了多个表的相关操作。本节学习如何在Golang中实现它。

在开始之前,先聊一下事务。

什么是数据库事务?

它是一个单一的工作单元,通常由多个表操作组成。 比如:在我们的小银行项目中,我们要从张三的账户中向李四的账户中转账10元。该交易就包括5个操作,涉及到accounts表、entries表和transfers表:

  1. 创建一个金额等于10的转账记录(transfers表)
  2. 张三创建一个账目记录,金额为-10(entries表)
  3. 李四创建一个账目记录,金额为+10(entries表)
  4. 更新张三的账户余额,减10元(accounts表)
  5. 更新李四的账户余额,加10元(accounts表)

为什么需要使用数据库事务?

主要原因有2个:

  1. 我们希望这个操作单元可靠且一致,即使系统出现某些故障的情况下也如此。
  2. 在程序和访问数据库之间提供隔离。

为了达到这两个条件,数据库事务必须满足ACID特性,其中:

  • A是原子性(Atomicity),表示要么事务的所有操作成功,要么整个事务失败,一切都回滚,数据不变。
  • C是一致性(Consistency),意思是事务执行后,数据库状态应该保持有效,准确的说,所有写入数据库的数据都必须按照预定义的规则生效,包括约束、级联和触发器。
  • I是隔离(Isolation),这意味着并发运行的所有事务不应该相互影响。(隔离有几种级别,在后面的学习中再介绍)
  • D是耐久性(Durability),意思是一个成功的事务写入的所有数据都必须保存在持久存储中,并且在系统故障的情况也不会丢失,比如,系统重启了。

SQL中如何操作数据库事务?

我们用BEGIN语句开始一个事务,然后编写一系列正常的SQL语句,没有错误的情况,使用COMMIT将事务提交; 当出现错误时,使用ROLLBACK回滚操作,ROLLBACK之前的所有更改都将恢复,数据库保持与执行事务之前相同。 Golang后端学习笔记 — 6. Golang操作数据库事务的方法 对数据库事务有了一些基本的了解之后,我们现在再看在Golang中如何实现它。

在Golang中实现数据库事务

在项目的sqlc目录内,新建store.go文件,在这个文件内,定义一个新的结构体Store,它将提供对数据库的所有操作,包括操作事务。

对于单个查询,我们已经有了SQLC生成的Queries结构体,但是,每个查询仅对1个特定的表执行1次操作,所以这个Queries并不支持事务。这就是为什么需要写个Store结构体来扩展它的功能,在Go中叫做组合(compostion),它是Golang中扩展功能,而不用继承的首选办法。

所以,我们需要Store有个Queries以及sql.DB对象,用sql.DB来创建数据库事务。

type Store struct {
	*Queries
	db *sql.DB
}

继续添加一个函数NewStore来创建新的Store对象:

func NewStore(db *sql.DB) *Store {
	return &Store{
		db:      db,
		Queries: New(db),
	}
}

这里,将sql.DB作为参数,并返回一个Store,只是构建一个新的Store对象并返回它。其中,db是传入进来的参数,而Queries是通过db对象调用New函数来创建的。这个New函数是由sqlc生成的,可以点进去看一下,它创建并返回一个Queries对象。

接下来,再新建个方法来执行通用的数据库事务,思路很简单,它将一个上下文和一个回调函数作为入参,然后它将启动一个新的数据库事务。

func (store *Store) execTx(ctx context.Context, fn func(*Queries) error) error {
	tx, err := store.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}

	q := New(tx)
	err = fn(q)
	if err != nil {
		if rbErr := tx.Rollback(); rbErr != nil {
			return fmt.Errorf("tx err: %v, rb err: %v", err, rbErr)
		}
		return err
	}

	return tx.Commit()
}

解释一下上面这段代码:

  • tx, err := store.db.BeginTx(ctx, nil),启动一个新的事务,ctx参数上下文,从入参获取到,再传入进去;nil这里是可以传入一个sql.TxOptions设置数据库事务隔离级别的,如果没设置,那么会使用数据库服务器的默认隔离级别,postgres是默认读提交的。在之后的学习中,会再学习这个隔离级别,目前就用nil默认级别就行。
  • if err != nil, 这里如果BeginTx()返回错误,就直接把错误返回。
  • q := New(tx),上面没有错误的情况下,用创建的事务调用New()新建一个返回新的Queries对象,这里与NewStore()里面的New()一样,所不同的是,NewStore里面的New()参数是一个sql.DB,而这里的New()参数是sql.Tx对象,因为New()函数接受一个参数为DBTX接口,所以,可以传入sql.Txsql.DB。这样我们就有了在事务中运行的Queries
  • err = fn(q),调用传入的函数,如果出错,就返回error
  • if err != nil {,如果err不为空,那么我们需要回滚事务,通过tx.Rollback()来实现,回滚异常也会返回个错误rbErr,如果也不为nil,我们需要输出两个错误,所以在返回错误信息之前,将它们组合成1个错误信息,fmt.Errorf("tx err: %v, rb err: %v", err, rbErr);如果回滚成功,只返回原始的事务错误return err
  • 最后,如果所有事务中的操作都正常,只需要使用tx.Commit()来提交事务,就好了。

execTx是小写字母开头,并不会导出,属于私有的,因为,我们不希望外部包直接调用它,我们会为每个特定的业务( transaction)提供一个公有调用函数。

接下来,就编写个转账的函数TransferTx(),在写这个之前,先定义两个结构体,一个用于转账的入参TransferTxParams,一个用于返回的结果TransferTxResult

type TransferTxParams struct {
	FromAccountID int64 `json:"from_account_id"`
	ToAccountID   int64 `json:"to_account_id"`
	Amount        int64 `json:"amount"`
}

type TransferTxResult struct {
	Transfer    Transfer `json:"transfer"`
	FromAccount Account  `json:"from_account"`
	ToAccount   Account  `json:"to_account"`
	FromEntry   Entry    `json:"from_entry"`
	ToEntry     Entry    `json:"to_entry"`
}

之后,写转账这个方法TransferTx():

func (store *Store) TransferTx(ctx context.Context, arg TransferTxParams) (TransferTxResult, error) {
	var result TransferTxResult

	err := store.execTx(ctx, func(q *Queries) error {
		var err error

		// 1. 创建一个金额等于`10`的转账记录
		result.Transfer, err = q.CreateTransfer(ctx, CreateTransferParams{
			FromAccountID: arg.FromAccountID,
			ToAccountID:   arg.ToAccountID,
			Amount:        arg.Amount,
		})
		if err != nil {
			return err
		}

		// 2. 为`FromAccount`创建一个账目记录,金额为`-10`
		result.FromEntry, err = q.CreateEntry(ctx, CreateEntryParams{
			AccountID: arg.FromAccountID,
			Amount:    -arg.Amount,
		})
		if err != nil {
			return err
		}

		// 3. 为`ToAccount`创建一个账目记录,金额为`+10`
		result.ToEntry, err = q.CreateEntry(ctx, CreateEntryParams{
			AccountID: arg.ToAccountID,
			Amount:    arg.Amount,
		})
		if err != nil {
			return err
		}

		// TODO: 更新账户余额操作后面再做

		return err
	})

	return result, err
}

这里用到的q.CreateTransferq.CreateEntry是之前自行练习用sqlc生成的,如果,之前没有做这个练习,可到这里下载sql文件,自行用make sqlc生成出来。

上面的代码,就是按转账拆解的5个步骤,分步实现的。

单元测试

接下来,写它的单元测试,在同级目录里新建store_test.go文件,在写这个单元测试之前,先把main_test.go内容改造一下:

package db

import (
	"database/sql"
	"log"
	"os"
	"testing"

	_ "github.com/lib/pq"
)

const (
	dbDriver = "postgres"
	dbSource = "postgresql://root:123456@localhost:5432/simple_bank?sslmode=disable"
)

var testQueries *Queries
var testDB *sql.DB

func TestMain(m *testing.M) {
	var err error
	testDB, err = sql.Open(dbDriver, dbSource)
	if err != nil {
		log.Fatal("cannot connect to db:", err)
	}

	testQueries = New(testDB)

	os.Exit(m.Run())
}

声明一个全局变量,var testDB *sql.DB,因为需要在store_test.go文件中使用到,把conn, err := sql.Open(dbDriver, dbSource)替换为testDB, err = sql.Open(dbDriver, dbSource),同样需要替换testQueries = New(conn)testQueries = New(testDB)

好,接着继续store_test.go:

package db

import (
	"context"
	"testing"

	"github.com/stretchr/testify/require"
)

func TestTransferTx(t *testing.T) {
	store := NewStore(testDB)

	account1 := createRandomAccount(t)
	account2 := createRandomAccount(t)

	n := 5
	amount := int64(10)

	errs := make(chan error)
	results := make(chan TransferTxResult)

	for i := 0; i < n; i++ {
		go func() {
			result, err := store.TransferTx(context.Background(), TransferTxParams{
				FromAccountID: account1.ID,
				ToAccountID:   account2.ID,
				Amount:        amount,
			})

			errs <- err
			results <- result
		}()
	}

	// 检查结果
	for i := 0; i < n; i++ {
		err := <-errs
		require.NoError(t, err)

		result := <-results
		require.NotEmpty(t, result)

		// check transfer
		transfer := result.Transfer
		require.NotEmpty(t, transfer)
		require.Equal(t, account1.ID, transfer.FromAccountID)
		require.Equal(t, account2.ID, transfer.ToAccountID)
		require.Equal(t, amount, transfer.Amount)
		require.NotZero(t, transfer.ID)
		require.NotZero(t, transfer.CreatedAt)

		_, err = store.GetTransfer(context.Background(), transfer.ID)
		require.NoError(t, err)

		// check entries
		formEntry := result.FromEntry
		require.NotEmpty(t, formEntry)
		require.Equal(t, account1.ID, formEntry.AccountID)
		require.Equal(t, -amount, formEntry.Amount)
		require.NotZero(t, formEntry.ID)
		require.NotZero(t, formEntry.CreatedAt)

		_, err = store.GetEntry(context.Background(), formEntry.ID)
		require.NoError(t, err)

		toEntry := result.ToEntry
		require.NotEmpty(t, toEntry)
		require.Equal(t, account2.ID, toEntry.AccountID)
		require.Equal(t, amount, toEntry.Amount)
		require.NotZero(t, toEntry.ID)
		require.NotZero(t, toEntry.CreatedAt)

		_, err = store.GetEntry(context.Background(), toEntry.ID)
		require.NoError(t, err)

		// TODO: 检查更新后的账户余额
	}
}

解释一下这个单元测试: 因为涉及到数据库事务,必须要非常小心,虽然代码写起来简单,但是,如果不小心处理并发,也很容易成为一场噩梦。所以,为了确保事务正常运行,使用go的协程(Goroutine)创建多个并发来运行它。

这里定义 n := 5,运行5个并发转账,每次从account1转账10到account2上,先不考虑币种的问题。

使用for i := 0; i < n; i++ {循环5次,里面使用go关键字,开始一个新的routine来运行它。

for i := 0; i < n; i++ {
		go func() {
		}()
}

go func() {里面,调用store.TransferTx()进行转账,我们不能在这个循环里面使用testifyrequire来检查运行结果,因为这个函数运行在go routine里面,它与TestTransferTx函数运行的是不同的go routine,所以,不能保证如果条件不满足时,它会停止整个测试。

验证错误和结果的正确方法是,将它们发送回正在运行的主go routine里面进行校验。所以,可以使用管道(channel),由它连接并发的go routine,并允许它们在没有显示锁定的情况下安全的互相共享数据。因此,定义:

	errs := make(chan error)
	results := make(chan TransferTxResult)

其中,1个channel用来接收错误,1个channel用来接收执行结果,使用make来创建channel。现在,在go func() {里面,就可以用箭头运算符将err发送到errs里面,result发送到results里面:

errs <- err
results <- result

接收者在左边,要发送的数据在右边

这样,就可以从外部主go routine中检查这些错误和结果了,为了接收这些数据,同样使用箭头操作符<-

好了,运行一下run test,测试通过;运行run package test,整个包也测试通过。

至此,本节学习完成,下节继续学习`数据库的事务锁以及Golang如何处理死锁