Skip to content

Gorm 插件在多语言重构中的应用

现状：项目内有很多与 Movie 表相关的查询，包括单条数据查询、列表查询和各种连表查询，现在需要让 Movie 表支持多语言。

  • 直接修改原有查询 Movie 表的代码改动过大，因此使用 gorm 插件来支持语言的切换
  • 根据前端请求头（header）中的 Accept-Language 判断需要返回的语言，并将其放入 context 随查询传给 gorm

Gorm 插件

此处参考gorm官方文档:https://gorm.io/zh_CN/docs/write_plugins.html#Registering-a-Plugin

go
// TranslationPlugin is a stateless gorm plugin; registering it with db.Use
// hooks a translation callback into gorm's query pipeline.
type TranslationPlugin struct{}

// Name returns the identifier of this plugin, as required by gorm's Plugin
// interface.
func (*TranslationPlugin) Name() string {
	const pluginName = "TranslationPlugin"
	return pluginName
}

// Initialize is invoked by db.Use. It registers loadTranslations to run on
// every query, after gorm's built-in "gorm:after_query" callback.
// The error from Register is returned rather than discarded, so a failed
// registration surfaces from db.Use instead of silently disabling the plugin.
func (p *TranslationPlugin) Initialize(db *gorm.DB) error {
	return db.Callback().Query().After("gorm:after_query").Register("translation:load_translations", p.loadTranslations)
}

// loadTranslations is the after-query callback registered in Initialize.
// Its body is elided here: the "..." below is an article placeholder, not
// valid Go, and the implementation is given in the next section.
func (p *TranslationPlugin) loadTranslations(db *gorm.DB) {
    ... // TODO: implemented in the next section
}

实现 loadTranslations

go
// loadTranslations (type-switch version) overwrites Title and Overview of
// models.Movie rows after a successful query on movie_table, using the
// language stored in the statement context under LanguageContextKey.
//
// Only *models.Movie, *[]models.Movie and *[]*models.Movie destinations are
// handled; any other receiver type is left untouched.
func (p *TranslationPlugin) loadTranslations(db *gorm.DB) {
	// Schema must be nil-checked before Schema.Table is read, and a failed
	// query has nothing to translate.
	if db.Error != nil || db.Statement.Schema == nil || db.Statement.Schema.Table != "movie_table" {
		return
	}
	language, ok := db.Statement.Context.Value(LanguageContextKey).(string)
	if !ok || language == "" {
		return
	}

	// Placeholder translation; a real implementation would look up the
	// title/overview for m.ID in the requested language.
	translate := func(m *models.Movie, i int) {
		m.Title = "translation.Title_" + language + strconv.Itoa(i)
		m.Overview = "translation.Overview_" + language + strconv.Itoa(i)
	}

	// gorm scans rows into Statement.Dest; Statement.Model may be a separate,
	// empty model passed via db.Model(...), so it is only a fallback.
	dest := db.Statement.Dest
	if dest == nil {
		dest = db.Statement.Model
	}

	// Mutate the caller's rows in place. Copying them into a []models.Movie and
	// writing that back with ReflectValue.Set panics for a []*models.Movie
	// destination, because the slice element types differ.
	switch t := dest.(type) {
	case *models.Movie:
		if t != nil {
			translate(t, 0)
		}
	case *[]models.Movie:
		for i := range *t {
			translate(&(*t)[i], i)
		}
	case *[]*models.Movie:
		for i, m := range *t {
			if m != nil {
				translate(m, i)
			}
		}
	}
}

上面的实现支持了 Movie 对象的赋值，但项目中查询 movie 表时用了很多不同的结构体来接收返回结果。

如果对每种结构体都单独判断再赋值，会导致 switch case 块过长、逻辑冗长，因此这里牺牲一些性能，改用反射统一处理。

Golang 反射的使用

[12.Golang 反射性能优化.md](<12.Golang 反射性能优化.md>)

go
// setFieldValue assigns newValue to the string-kinded field fieldName of
// structValue. It silently does nothing when the field does not exist, is not
// of string kind, or cannot be set (structValue not addressable, or the field
// is unexported).
func setFieldValue(structValue reflect.Value, fieldName, newValue string) {
	target := structValue.FieldByName(fieldName)
	if !target.IsValid() || !target.CanSet() || target.Kind() != reflect.String {
		return
	}
	target.SetString(newValue)
}

// getFieldValue returns the integer field fieldName of structValue widened to
// int64, or 0 when the field is missing or not an integer.
//
// Both signed (int, int8..int64) and unsigned (uint, uint8..uint64, uintptr)
// kinds are accepted: gorm.Model's ID is a uint, and an Int64-only check
// would always report 0 for it. Unsigned values above math.MaxInt64 wrap when
// converted.
func getFieldValue(structValue reflect.Value, fieldName string) int64 {
	field := structValue.FieldByName(fieldName)
	if !field.IsValid() {
		return 0
	}
	switch field.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return field.Int()
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return int64(field.Uint())
	default:
		return 0
	}
}

完整测试用例（testcase）示例

go
package movie_service

import (
	"context"
	"xxx/models"
	"xxx/models/enums"
	"xxx/models/params/res"
	"fmt"
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
	"reflect"
	"strconv"
	"testing"
)

// TranslationContextKey is the type of the context keys used by this package.
// A dedicated type (rather than a bare string) keeps these keys from
// colliding with plain string keys set by other code.
type TranslationContextKey string

// LanguageContextKey is the context key under which the requested language
// (e.g. "en") is stored for TranslationPlugin to read.
const LanguageContextKey = TranslationContextKey("language")

// TranslationPlugin is a stateless gorm plugin; registering it with db.Use
// hooks loadTranslations into gorm's query pipeline.
type TranslationPlugin struct{}

// Name returns the identifier of this plugin, as required by gorm's Plugin
// interface.
func (*TranslationPlugin) Name() string {
	const pluginName = "TranslationPlugin"
	return pluginName
}

// Initialize is invoked by db.Use. It registers loadTranslations to run on
// every query, after gorm's built-in "gorm:after_query" callback.
// The error from Register is returned rather than discarded, so a failed
// registration surfaces from db.Use instead of silently disabling the plugin.
func (p *TranslationPlugin) Initialize(db *gorm.DB) error {
	return db.Callback().Query().After("gorm:after_query").Register("translation:load_translations", p.loadTranslations)
}

// loadTranslations is the gorm after-query callback registered by Initialize.
// After a successful query whose schema table is movie_table, it replaces the
// Title and Overview fields of every scanned row with their translation for
// the language stored in the statement context under LanguageContextKey.
//
// Rows are read from Statement.Dest, where gorm scans results. With
// db.Model(&models.Movie{}).Find(&dtos), Statement.Model points at an empty
// model that is never returned to the caller, so Model is only a fallback when
// Dest is nil.
//
// Supported destinations are a struct, or a slice/array of structs, behind any
// number of pointers (e.g. &movie, &moviePtr, &[]Movie, &[]*Movie). Nil
// pointers are skipped.
func (p *TranslationPlugin) loadTranslations(db *gorm.DB) {
	// Nothing to translate when the query failed (e.g. record not found) or the
	// statement has no parsed schema. Schema is nil-checked before Table is read.
	if db.Error != nil || db.Statement.Schema == nil || db.Statement.Schema.Table != "movie_table" {
		return
	}

	language, ok := db.Statement.Context.Value(LanguageContextKey).(string)
	if !ok || language == "" {
		return
	}

	dest := db.Statement.Dest
	if dest == nil {
		dest = db.Statement.Model
	}
	translateRows(reflect.ValueOf(dest), language)
}

// translateRows unwraps pointers/interfaces around v and applies modifyFields
// to every struct it reaches: v itself, or each element of a slice or array.
func translateRows(v reflect.Value, language string) {
	for v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface {
		if v.IsNil() {
			return
		}
		v = v.Elem()
	}
	switch v.Kind() {
	case reflect.Struct:
		modifyFields(v, language)
	case reflect.Slice, reflect.Array:
		for i := 0; i < v.Len(); i++ {
			translateRows(v.Index(i), language)
		}
	}
}

// modifyFields replaces the Title and Overview string fields of a movie-like
// struct with their translations in language, keyed by the struct's integer
// ID field. A pointer is dereferenced once first.
//
// Non-struct values are ignored. Structs whose ID is missing or zero are also
// ignored, because the lookup would otherwise fetch the translation for ID 0
// and overwrite real data with it.
//
// NOTE(review): once getTitleByID/getOverviewByID really query the database,
// this issues two queries per row (N+1 for list results). Consider batching
// the lookups by ID for the whole result set.
func modifyFields(structValue reflect.Value, language string) {
	structValue = dereferenceIfNeeded(structValue)
	if structValue.Kind() != reflect.Struct {
		return
	}
	id := getFieldValue(structValue, "ID")
	if id == 0 {
		return
	}
	setFieldValue(structValue, "Title", getTitleByID(id, language))
	setFieldValue(structValue, "Overview", getOverviewByID(id, language))
}

// getTitleByID returns the translated title of movie id in language.
// It is currently a placeholder that builds a marker string instead of
// querying the translation table. strconv.FormatInt is used so that large IDs
// are not truncated by an int conversion on 32-bit platforms.
func getTitleByID(id int64, language string) string {
	return "func => query from db: translation.Title_" + language + strconv.FormatInt(id, 10)
}

// getOverviewByID returns the translated overview of movie id in language.
// It is currently a placeholder that builds a marker string instead of
// querying the translation table. strconv.FormatInt is used so that large IDs
// are not truncated by an int conversion on 32-bit platforms.
func getOverviewByID(id int64, language string) string {
	return "func => query from db: translation.Overview_" + language + strconv.FormatInt(id, 10)
}

// setFieldValue assigns newValue to the string-kinded field fieldName of
// structValue. It silently does nothing when the field does not exist, is not
// of string kind, or cannot be set (structValue not addressable, or the field
// is unexported).
func setFieldValue(structValue reflect.Value, fieldName, newValue string) {
	target := structValue.FieldByName(fieldName)
	if !target.IsValid() || !target.CanSet() || target.Kind() != reflect.String {
		return
	}
	target.SetString(newValue)
}

// getFieldValue returns the integer field fieldName of structValue widened to
// int64, or 0 when the field is missing or not an integer.
//
// Both signed (int, int8..int64) and unsigned (uint, uint8..uint64, uintptr)
// kinds are accepted: gorm.Model's ID is a uint, and an Int64-only check
// would always report 0 for it. Unsigned values above math.MaxInt64 wrap when
// converted.
func getFieldValue(structValue reflect.Value, fieldName string) int64 {
	field := structValue.FieldByName(fieldName)
	if !field.IsValid() {
		return 0
	}
	switch field.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return field.Int()
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return int64(field.Uint())
	default:
		return 0
	}
}

// dereferenceIfNeeded unwraps exactly one level of pointer. A nil pointer
// yields the zero value of its element type, which is not settable. Any
// non-pointer value is returned unchanged.
func dereferenceIfNeeded(value reflect.Value) reflect.Value {
	if value.Kind() != reflect.Ptr {
		return value
	}
	if !value.IsNil() {
		return value.Elem()
	}
	return reflect.Zero(value.Type().Elem())
}

// TestPlugin is an integration smoke test for TranslationPlugin against a real
// PostgreSQL database. With language "en" in the context, it runs three
// queries and logs the results for manual inspection:
//   - a single row
//   - a joined list scanned into res.MovieInfoForVodRes
//   - a plain list of models.Movie
//
// The DSN values are placeholders, so the test is skipped (instead of
// panicking) when gorm.Open fails. Query errors fail the test rather than
// only being printed.
func TestPlugin(t *testing.T) {
	dsn := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%s sslmode=disable TimeZone=UTC",
		"xxx",
		"xxx",
		"xxx",
		"xxx",
		"5432")
	db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
	if err != nil {
		t.Skipf("connect to database: %v", err)
	}

	if err := db.Use(&TranslationPlugin{}); err != nil {
		t.Fatalf("use TranslationPlugin: %v", err)
	}

	ctx := context.WithValue(context.Background(), LanguageContextKey, "en")

	// First already applies LIMIT 1, so no explicit Limit is needed.
	var movie *models.Movie
	if err := db.Debug().WithContext(ctx).First(&movie).Error; err != nil {
		t.Errorf("query single movie: %v", err)
	} else {
		t.Logf("%+v", movie)
	}

	var movies []*res.MovieInfoForVodRes
	title := "A"

	err = db.WithContext(ctx).Table("movie_table m").
		Select("DISTINCT m.*, cover.url as cover, poster.url as poster, p.is_offline").
		Joins("INNER JOIN license_table l on m.id=l.movie_id").
		Joins("LEFT JOIN product_table p ON p.source_id = l.id and p.source_type=1").
		Joins("LEFT JOIN movie_alternative_title_table malt on m.id = malt.movie_id").
		Joins("LEFT JOIN movie_images_table cover on cover.movie_id=m.id AND cover.is_movie_cover=true").
		Joins("LEFT JOIN movie_images_table poster on poster.movie_id=m.id AND poster.is_movie_poster=true").
		Where("m.type =? AND (malt.title ILIKE ? OR m.title ILIKE ?)", enums.TVShow, "%"+title+"%", "%"+title+"%").
		Limit(10).Offset(0).Find(&movies).Error
	if err != nil {
		t.Errorf("query joined movie list: %v", err)
	} else {
		for _, m := range movies {
			t.Logf("%+v", m)
		}
	}

	var movies2 []*models.Movie
	if err := db.Debug().WithContext(ctx).Limit(5).Find(&movies2).Error; err != nil {
		t.Errorf("query movie list: %v", err)
	} else {
		for _, m := range movies2 {
			t.Logf("%+v", m)
		}
	}
}

文章内容来源于个人总结及网络转载，如有任何问题，请各位大佬斧正！联系我