jellyfin-user-clone/main.go

418 lines
11 KiB
Go

package main
import (
"database/sql"
"fmt"
"io"
"io/ioutil"
"log"
"os"
"path"
"slices"
"strings"
//_ "github.com/glebarez/sqlite"
"github.com/google/uuid"
_ "modernc.org/sqlite"
)
// jellyDB is the file name, inside the data directory given on the command
// line, of the database that holds the Users table.
const jellyDB = "jellyfin.db"

// libDB is the file name, inside the same data directory, of the database
// that holds the UserDatas table.
const libDB = "library.db"
// main clones one Jellyfin user's data onto another user.
//
// Usage:
//
//	jellyfin-user-clone <data-dir> <from-user> <to-user> <out-dir>
//
// The jellyDB and libDB files are copied from <data-dir> into a temporary
// working directory, and only those copies are modified. <to-user> is
// created from <from-user>'s Users row if it does not exist yet. Then every
// row that references <from-user> through a column named "user" or "userid"
// (any case) is duplicated for <to-user>. The resulting files are written to
// <out-dir>; nothing in <data-dir> is changed.
func main() {
	if len(os.Args) != 5 {
		log.Fatalf("usage: %s <data-dir> <from-user> <to-user> <out-dir>", os.Args[0])
	}
	if err := runClone(os.Args[1], os.Args[2], os.Args[3], os.Args[4]); err != nil {
		log.Fatalf("%v", err)
	}
}

// runClone does the work of main. It returns the first error instead of
// exiting, so that the deferred cleanup (closing the databases, removing the
// working directory) always runs.
func runClone(data, fromU, toU, outd string) error {
	workd, err := ioutil.TempDir(os.TempDir(), "jellyfin-user-clone.*")
	if err != nil {
		return err
	}
	defer os.RemoveAll(workd)

	for _, name := range []string{jellyDB, libDB} {
		if err := cp(path.Join(data, name), path.Join(workd, name)); err != nil {
			return err
		}
	}

	jellyfinDB, err := sql.Open("sqlite", path.Join(workd, jellyDB))
	if err != nil {
		return err
	}
	defer jellyfinDB.Close() // no-op after the explicit Close below
	libraryDB, err := sql.Open("sqlite", path.Join(workd, libDB))
	if err != nil {
		return err
	}
	defer libraryDB.Close()

	fromUUID, fromID, err := userIDs(jellyfinDB, fromU)
	if err != nil {
		return err
	}
	log.Println(fromU, fromUUID, fromID)

	if err := ensureUser(jellyfinDB, libraryDB, fromU, toU); err != nil {
		return err
	}

	toUUID, toID, err := userIDs(jellyfinDB, toU)
	if err != nil {
		return err
	}
	log.Println(toU, toUUID, toID)

	for _, db := range []*sql.DB{jellyfinDB, libraryDB} {
		if err := cloneUserRows(db, fromID, toID, fromUUID, toUUID); err != nil {
			return err
		}
	}

	// Close both databases before copying the files out, so SQLite is done
	// with them. An open connection may still keep data in a side file such
	// as a journal/WAL.
	if err := jellyfinDB.Close(); err != nil {
		return err
	}
	if err := libraryDB.Close(); err != nil {
		return err
	}
	return writeOutput(workd, outd)
}

// userIDs returns the Id and the InternalId of the Users row named name.
func userIDs(db *sql.DB, name string) (string, int, error) {
	id, err := SelectOne[string](db, `SELECT Id FROM Users WHERE Username = $1`, name)
	if err != nil {
		return "", 0, err
	}
	internalID, err := SelectOne[int](db, `SELECT InternalId FROM Users WHERE Id = $1`, id)
	if err != nil {
		return "", 0, err
	}
	return id, internalID, nil
}

// ensureUser creates toU as a clone of fromU's Users row unless exactly one
// user named toU already exists. More than one is reported as an error.
func ensureUser(jellyfinDB, libraryDB *sql.DB, fromU, toU string) error {
	n, err := SelectOne[int](jellyfinDB, `SELECT COUNT(*) FROM Users WHERE Username = $1`, toU)
	if err != nil {
		return err
	}
	switch {
	case n == 1:
		log.Println("user", toU, "already exists")
		return nil
	case n > 1:
		return fmt.Errorf("found %d users named %q", n, toU)
	}

	log.Println("creating user", toU, "from", fromU)
	// Choose an InternalId above both the largest UserDatas.userId and the
	// largest Users.InternalId. UserDatas can presumably still reference ids
	// of deleted users, and Users can hold ids that have no UserDatas rows.
	// Either source alone could hand out an id that is already taken.
	maxDataID, err := SelectOne[int](libraryDB, `SELECT COALESCE(MAX(userId), 0)+1 FROM UserDatas`)
	if err != nil {
		return fmt.Errorf("failed to get max userid ever: %w", err)
	}
	maxUserID, err := SelectOne[int](jellyfinDB, `SELECT COALESCE(MAX(InternalId), 0)+1 FROM Users`)
	if err != nil {
		return fmt.Errorf("failed to get max InternalId: %w", err)
	}
	nextID := max(maxDataID, maxUserID)

	if err := CloneForColumn(jellyfinDB, `Users`, `Username`, fromU, toU, map[string]any{`InternalId`: nextID}); err != nil {
		return err
	}
	n, err = SelectOne[int](jellyfinDB, `SELECT COUNT(*) FROM Users WHERE Username = $1 AND InternalId = $2`, toU, nextID)
	if err != nil {
		return err
	}
	if n != 1 {
		return fmt.Errorf("still no username=%q after insert", toU)
	}
	return nil
}

// cloneUserRows visits every table in db that has a column named "user" or
// "userid" (case-insensitively). For each such column, it clones the rows
// that reference the source user onto the target user. The integer id is
// tried first; the Id string is used only when no row matches the integer.
func cloneUserRows(db *sql.DB, fromID, toID int, fromUUID, toUUID string) error {
	tables, err := Tables(db)
	if err != nil {
		return err
	}
	for _, table := range tables {
		columns, err := Columns(db, table)
		if err != nil {
			return err
		}
		for _, column := range columns {
			switch strings.ToLower(column) {
			case "user", "userid":
			default:
				continue
			}
			log.Println(table, column, "...")
			countQ := fmt.Sprintf(`SELECT COUNT(*) FROM %q WHERE %q = $1`, table, column)
			n, err := SelectOne[int](db, countQ, fromID)
			if err != nil {
				return err
			}
			if n > 0 {
				if err := CloneForColumn(db, table, column, fromID, toID, nil); err != nil {
					return err
				}
				continue
			}
			n, err = SelectOne[int](db, countQ, fromUUID)
			if err != nil {
				return err
			}
			if n > 0 {
				if err := CloneForColumn(db, table, column, fromUUID, toUUID, nil); err != nil {
					return err
				}
			}
		}
	}
	return nil
}

// writeOutput copies every file in workd into outd, creating outd if needed.
// Each file is removed from workd once it has been copied.
func writeOutput(workd, outd string) error {
	entries, err := os.ReadDir(workd)
	if err != nil {
		return err
	}
	if err := os.MkdirAll(outd, os.ModePerm); err != nil {
		return err
	}
	for _, entry := range entries {
		src, dst := path.Join(workd, entry.Name()), path.Join(outd, entry.Name())
		if err := cp(src, dst); err != nil {
			return err
		}
		if err := os.Remove(src); err != nil {
			return err
		}
	}
	return nil
}
func CloneForColumn[T any](db *sql.DB, table, column string, from, to T, fixed map[string]any) error {
columns, err := Columns(db, table)
if err != nil {
return err
}
uniqueIntColumns, err := UniqueTypeColumns(db, "INTEGER", table)
if err != nil {
return err
}
uniqueTextColumns, err := UniqueTypeColumns(db, "TEXT", table)
if err != nil {
return err
}
log.Printf("unique text columns %+v, unique int columns %+v", uniqueTextColumns, uniqueIntColumns)
extraUniques := []string{}
for k := range fixed {
extraUniques = append(extraUniques, k)
}
omit := append(append(extraUniques, uniqueIntColumns...), uniqueTextColumns...)
notTheseColumns := slices.DeleteFunc(slices.Clone(columns), func(s string) bool {
return s == column || slices.Contains(omit, s) || slices.Contains(extraUniques, s)
})
for i := range notTheseColumns {
notTheseColumns[i] = fmt.Sprintf("%q", notTheseColumns[i])
}
notNullTextColumns, err := NotNullTextColumns(db, table)
if err != nil {
return err
}
uuidGenColumns := slices.DeleteFunc(notNullTextColumns, func(s string) bool {
return s == column || !slices.Contains(omit, s) || slices.Contains(extraUniques, s)
})
for i := range uuidGenColumns {
uuidGenColumns[i] = fmt.Sprintf("%q", uuidGenColumns[i])
}
notNullIntColumns, err := NotNullIntColumns(db, table)
if err != nil {
return err
}
incrGenColumns := slices.DeleteFunc(notNullIntColumns, func(s string) bool {
return s == column || !slices.Contains(omit, s) || slices.Contains(extraUniques, s)
})
for i := range incrGenColumns {
incrGenColumns[i] = fmt.Sprintf("%q", incrGenColumns[i])
}
return ForEach(db, func(args []any) error {
selectMaxes := ""
for _, col := range incrGenColumns {
selectMaxes += fmt.Sprintf(`(SELECT COALESCE(MAX(%s)+1, 1) FROM %q), `, col, table)
}
values := []any{}
for _ = range uuidGenColumns {
values = append(values, fmt.Sprintf("'%s'", guid()))
}
for _, arg := range args {
values = append(values, *(arg.(*any)))
}
for _, k := range extraUniques {
values = append(values, fixed[k])
}
values = append(values, to)
q := fmt.Sprintf(
`INSERT INTO %q (%s, %q) VALUES (%s %s)`,
table, strings.Join(append(append(incrGenColumns, append(uuidGenColumns, notTheseColumns...)...), extraUniques...), ", "), column,
selectMaxes, strings.Join(slices.Repeat([]string{"?"}, len(values)), ", "),
)
log.Printf("INSERT | %s (%+v)", q, values)
return Exec(db, q, values...)
}, fmt.Sprintf(`SELECT %s FROM %q WHERE %q = $1`, strings.Join(notTheseColumns, ", "), table, column), from)
/*
q := fmt.Sprintf(
`INSERT INTO %q (%q, %s) SELECT $1, %s %s %s FROM %q WHERE %q = $2`,
table, column, strings.Join(append(incrGenColumns, append(uuidGenColumns, notTheseColumns...)...), ", "),
func() string {
incrs := make([]string, len(incrGenColumns))
for i := range incrs {
incrs[i] = fmt.Sprintf(`(SELECT COALESCE(MAX(%s), 1) FROM %q)`, incrGenColumns[i], table)
}
s := strings.Join(incrs, ", ")
if len(incrs) > 0 {
s += ", "
}
return s
}(),
func() string {
uuids := make([]string, len(uuidGenColumns))
for i := range uuids {
uuids[i] = fmt.Sprintf("'%s'", guid())
}
s := strings.Join(uuids, ", ")
if len(uuids) > 0 {
s += ", "
}
return s
}(), strings.Join(notTheseColumns, ", "), table, column,
)
log.Printf("EXEC | %s (%v, %v)", q, to, from)
return Insert(db, q, to, from)
*/
}
func ForEach(db *sql.DB, cb func([]any) error, q string, args ...any) error {
rows, err := db.Query(q, args...)
if err != nil {
return err
}
defer rows.Close()
columns, err := rows.Columns()
if err != nil {
return err
}
cols := make([]any, len(columns))
for rows.Next() {
for i := range cols {
var a any
cols[i] = &a
}
if err := rows.Scan(cols...); err != nil {
return err
}
if err := cb(cols); err != nil {
return err
}
}
return rows.Err()
}
// cp copies the contents of the file at from into the file at to. The
// destination is created if missing and truncated if it exists.
//
// The error from closing the destination is returned rather than dropped.
// A failed Close can mean the copied bytes never reached the file.
func cp(from, to string) error {
	src, err := os.Open(from)
	if err != nil {
		return err
	}
	defer src.Close() // read-only; a Close error cannot lose data

	dst, err := os.Create(to)
	if err != nil {
		return err
	}
	if _, err := io.Copy(dst, src); err != nil {
		dst.Close()
		return err
	}
	return dst.Close()
}
func Tables(db *sql.DB) ([]string, error) {
return Select[string](db, `SELECT name FROM sqlite_schema WHERE type = 'table' AND name NOT LIKE 'sqlite_%'`)
}
func Columns(db *sql.DB, table string) ([]string, error) {
return Select[string](db, `SELECT name FROM PRAGMA_TABLE_INFO($1)`, table)
}
func NotNullTextColumns(db *sql.DB, table string) ([]string, error) {
return Select[string](db, `SELECT name FROM PRAGMA_TABLE_INFO($1) WHERE "notnull" = 1 AND "type" = 'TEXT'`, table)
}
func NotNullIntColumns(db *sql.DB, table string) ([]string, error) {
return Select[string](db, `SELECT name FROM PRAGMA_TABLE_INFO($1) WHERE "notnull" = 1 AND "type" = 'INTEGER'`, table)
}
// UniqueTypeColumns returns the columns of table whose declared type is
// exactly t and whose values are unique on their own. That means either the
// column of a single-column primary key, or the only column of a unique
// index. The result is sorted and has no duplicates.
//
// Composite primary keys and multi-column unique indexes are logged and
// skipped. None of their columns is unique by itself, so reporting them
// would make CloneForColumn regenerate values that should be copied (for
// example the other half of a (user, item) key).
func UniqueTypeColumns(db *sql.DB, t, table string) ([]string, error) {
	cols := []string{}

	nPK, err := SelectOne[int](db, `SELECT COUNT(*) FROM PRAGMA_TABLE_INFO($1) WHERE "pk" > 0`, table)
	if err != nil {
		return nil, err
	}
	if nPK > 1 {
		log.Printf("not impl: compound primary key on %s", table)
	} else {
		pks, err := Select[string](db, `SELECT name FROM PRAGMA_TABLE_INFO($1) WHERE "pk" AND "type" = $2`, table, t)
		if err != nil {
			return nil, err
		}
		cols = append(cols, pks...)
	}

	idxes, err := Select[string](db, `SELECT name AS idx_name FROM PRAGMA_INDEX_LIST($1) WHERE "unique"`, table)
	if err != nil {
		return nil, err
	}
	for _, idx := range idxes {
		idxCols, err := Select[string](db, `SELECT name FROM PRAGMA_INDEX_INFO($1)`, idx)
		if err != nil {
			return nil, err
		}
		if len(idxCols) != 1 {
			if len(idxCols) > 1 {
				log.Printf("not impl: compound unique indexes like %s.%s's %+v", table, idx, idxCols)
			}
			continue // no usable single column (guards the idxCols[0] below)
		}
		col := idxCols[0]
		n, err := SelectOne[int](db, `SELECT COUNT(*) FROM PRAGMA_TABLE_INFO($1) WHERE "name" = $2 AND "type" = $3`, table, col, t)
		if err != nil {
			return nil, err
		}
		if n > 0 {
			cols = append(cols, col)
		}
	}

	slices.Sort(cols)
	return slices.Compact(cols), nil
}
// Insert runs the statement q with args on db and discards the sql.Result.
//
// Deprecated: Insert behaves exactly like Exec; use Exec.
func Insert(db *sql.DB, q string, args ...any) error {
	if _, err := db.Exec(q, args...); err != nil {
		return err
	}
	return nil
}
// Exec runs the statement q with args on db. The sql.Result is discarded and
// only the error is reported.
func Exec(db *sql.DB, q string, args ...any) error {
	if _, err := db.Exec(q, args...); err != nil {
		return err
	}
	return nil
}
// SelectOne runs q with args and returns its single scanned value. It fails
// if the query errors or if it does not yield exactly one row. On failure
// the returned T is the zero value.
func SelectOne[T any](db *sql.DB, q string, args ...any) (T, error) {
	var zero T
	results, err := Select[T](db, q, args...)
	switch {
	case err != nil:
		return zero, err
	case len(results) != 1:
		return zero, fmt.Errorf("expected 1 result but got %d (%+v)", len(results), results)
	default:
		return results[0], nil
	}
}
// Select runs q with args and scans each result row, which must consist of a
// single column, into a T. When no rows match it returns an empty, non-nil
// slice. If iteration fails part-way, the rows scanned so far are returned
// together with the error.
func Select[T any](db *sql.DB, q string, args ...any) ([]T, error) {
	rows, err := db.Query(q, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	out := make([]T, 0)
	for rows.Next() {
		var v T
		if scanErr := rows.Scan(&v); scanErr != nil {
			return nil, scanErr
		}
		out = append(out, v)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return out, iterErr
	}
	return out, nil
}
// guid returns the string form of a new uuid.New() in which every letter is
// replaced by a decimal digit, namely its character code modulo 10. For
// example 'a' becomes '7' and 'f' becomes '2'. The result is then
// upper-cased.
func guid() string {
	toDigit := func(r rune) rune {
		if ('a' <= r && r <= 'z') || ('A' <= r && r <= 'Z') {
			return '0' + r%10
		}
		return r
	}
	return strings.ToUpper(strings.Map(toDigit, uuid.New().String()))
}