Fork of github.com/mattn/go-sqlite3 with adjustments for Go 1.16.2
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

2570 lines
62 KiB

  1. // Copyright (C) 2019 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
  2. //
  3. // Use of this source code is governed by an MIT-style
  4. // license that can be found in the LICENSE file.
  5. //go:build cgo
  6. // +build cgo
  7. package sqlite3
  8. import (
  9. "bytes"
  10. "database/sql"
  11. "database/sql/driver"
  12. "errors"
  13. "fmt"
  14. "io/ioutil"
  15. "math/rand"
  16. "net/url"
  17. "os"
  18. "reflect"
  19. "regexp"
  20. "runtime"
  21. "strconv"
  22. "strings"
  23. "sync"
  24. "testing"
  25. "time"
  26. )
  27. func TempFilename(t testing.TB) string {
  28. f, err := ioutil.TempFile("", "go-sqlite3-test-")
  29. if err != nil {
  30. t.Fatal(err)
  31. }
  32. f.Close()
  33. return f.Name()
  34. }
  35. func doTestOpen(t *testing.T, option string) (string, error) {
  36. tempFilename := TempFilename(t)
  37. url := tempFilename + option
  38. defer func() {
  39. err := os.Remove(tempFilename)
  40. if err != nil {
  41. t.Error("temp file remove error:", err)
  42. }
  43. }()
  44. db, err := sql.Open("sqlite3", url)
  45. if err != nil {
  46. return "Failed to open database:", err
  47. }
  48. defer func() {
  49. err = db.Close()
  50. if err != nil {
  51. t.Error("db close error:", err)
  52. }
  53. }()
  54. err = db.Ping()
  55. if err != nil {
  56. return "ping error:", err
  57. }
  58. _, err = db.Exec("drop table foo")
  59. _, err = db.Exec("create table foo (id integer)")
  60. if err != nil {
  61. return "Failed to create table:", err
  62. }
  63. if stat, err := os.Stat(tempFilename); err != nil || stat.IsDir() {
  64. return "Failed to create ./foo.db", nil
  65. }
  66. return "", nil
  67. }
  68. func TestOpen(t *testing.T) {
  69. cases := map[string]bool{
  70. "": true,
  71. "?_txlock=immediate": true,
  72. "?_txlock=deferred": true,
  73. "?_txlock=exclusive": true,
  74. "?_txlock=bogus": false,
  75. }
  76. for option, expectedPass := range cases {
  77. result, err := doTestOpen(t, option)
  78. if result == "" {
  79. if !expectedPass {
  80. errmsg := fmt.Sprintf("_txlock error not caught at dbOpen with option: %s", option)
  81. t.Fatal(errmsg)
  82. }
  83. } else if expectedPass {
  84. if err == nil {
  85. t.Fatal(result)
  86. } else {
  87. t.Fatal(result, err)
  88. }
  89. }
  90. }
  91. }
  92. func TestOpenWithVFS(t *testing.T) {
  93. filename := t.Name() + ".sqlite"
  94. if err := os.Remove(filename); err != nil && !os.IsNotExist(err) {
  95. t.Fatal(err)
  96. }
  97. defer os.Remove(filename)
  98. db, err := sql.Open("sqlite3", fmt.Sprintf("file:%s?vfs=hello", filename))
  99. if err != nil {
  100. t.Fatal("Failed to open", err)
  101. }
  102. err = db.Ping()
  103. if err == nil {
  104. t.Fatal("Failed to open", err)
  105. }
  106. db.Close()
  107. defer os.Remove(filename)
  108. var vfs string
  109. if runtime.GOOS == "windows" {
  110. vfs = "win32-none"
  111. } else {
  112. vfs = "unix-none"
  113. }
  114. db, err = sql.Open("sqlite3", fmt.Sprintf("file:%s?vfs=%s", filename, vfs))
  115. if err != nil {
  116. t.Fatal("Failed to open", err)
  117. }
  118. err = db.Ping()
  119. if err != nil {
  120. t.Fatal("Failed to ping", err)
  121. }
  122. db.Close()
  123. }
  124. func TestOpenNoCreate(t *testing.T) {
  125. filename := t.Name() + ".sqlite"
  126. if err := os.Remove(filename); err != nil && !os.IsNotExist(err) {
  127. t.Fatal(err)
  128. }
  129. defer os.Remove(filename)
  130. // https://golang.org/pkg/database/sql/#Open
  131. // "Open may just validate its arguments without creating a connection
  132. // to the database. To verify that the data source name is valid, call Ping."
  133. db, err := sql.Open("sqlite3", fmt.Sprintf("file:%s?mode=rw", filename))
  134. if err == nil {
  135. defer db.Close()
  136. err = db.Ping()
  137. if err == nil {
  138. t.Fatal("expected error from Open or Ping")
  139. }
  140. }
  141. sqlErr, ok := err.(Error)
  142. if !ok {
  143. t.Fatalf("expected sqlite3.Error, but got %T", err)
  144. }
  145. if sqlErr.Code != ErrCantOpen {
  146. t.Fatalf("expected SQLITE_CANTOPEN, but got %v", sqlErr)
  147. }
  148. // make sure database file truly was not created
  149. if _, err := os.Stat(filename); !os.IsNotExist(err) {
  150. if err != nil {
  151. t.Fatal(err)
  152. }
  153. t.Fatal("expected database file to not exist")
  154. }
  155. // verify that it works if the mode is "rwc" instead
  156. db, err = sql.Open("sqlite3", fmt.Sprintf("file:%s?mode=rwc", filename))
  157. if err != nil {
  158. t.Fatal(err)
  159. }
  160. defer db.Close()
  161. if err := db.Ping(); err != nil {
  162. t.Fatal(err)
  163. }
  164. // make sure database file truly was created
  165. if _, err := os.Stat(filename); err != nil {
  166. if !os.IsNotExist(err) {
  167. t.Fatal(err)
  168. }
  169. t.Fatal("expected database file to exist")
  170. }
  171. }
  172. func TestReadonly(t *testing.T) {
  173. tempFilename := TempFilename(t)
  174. defer os.Remove(tempFilename)
  175. db1, err := sql.Open("sqlite3", "file:"+tempFilename)
  176. if err != nil {
  177. t.Fatal(err)
  178. }
  179. db1.Exec("CREATE TABLE test (x int, y float)")
  180. db2, err := sql.Open("sqlite3", "file:"+tempFilename+"?mode=ro")
  181. if err != nil {
  182. t.Fatal(err)
  183. }
  184. _ = db2
  185. _, err = db2.Exec("INSERT INTO test VALUES (1, 3.14)")
  186. if err == nil {
  187. t.Fatal("didn't expect INSERT into read-only database to work")
  188. }
  189. }
  190. func TestForeignKeys(t *testing.T) {
  191. cases := map[string]bool{
  192. "?_foreign_keys=1": true,
  193. "?_foreign_keys=0": false,
  194. }
  195. for option, want := range cases {
  196. fname := TempFilename(t)
  197. uri := "file:" + fname + option
  198. db, err := sql.Open("sqlite3", uri)
  199. if err != nil {
  200. os.Remove(fname)
  201. t.Errorf("sql.Open(\"sqlite3\", %q): %v", uri, err)
  202. continue
  203. }
  204. var enabled bool
  205. err = db.QueryRow("PRAGMA foreign_keys;").Scan(&enabled)
  206. db.Close()
  207. os.Remove(fname)
  208. if err != nil {
  209. t.Errorf("query foreign_keys for %s: %v", uri, err)
  210. continue
  211. }
  212. if enabled != want {
  213. t.Errorf("\"PRAGMA foreign_keys;\" for %q = %t; want %t", uri, enabled, want)
  214. continue
  215. }
  216. }
  217. }
  218. func TestDeferredForeignKey(t *testing.T) {
  219. fname := TempFilename(t)
  220. uri := "file:" + fname + "?_foreign_keys=1"
  221. db, err := sql.Open("sqlite3", uri)
  222. if err != nil {
  223. os.Remove(fname)
  224. t.Errorf("sql.Open(\"sqlite3\", %q): %v", uri, err)
  225. }
  226. _, err = db.Exec("CREATE TABLE bar (id INTEGER PRIMARY KEY)")
  227. if err != nil {
  228. t.Errorf("failed creating tables: %v", err)
  229. }
  230. _, err = db.Exec("CREATE TABLE foo (bar_id INTEGER, FOREIGN KEY(bar_id) REFERENCES bar(id) DEFERRABLE INITIALLY DEFERRED)")
  231. if err != nil {
  232. t.Errorf("failed creating tables: %v", err)
  233. }
  234. tx, err := db.Begin()
  235. if err != nil {
  236. t.Errorf("Failed to begin transaction: %v", err)
  237. }
  238. _, err = tx.Exec("INSERT INTO foo (bar_id) VALUES (123)")
  239. if err != nil {
  240. t.Errorf("Failed to insert row: %v", err)
  241. }
  242. err = tx.Commit()
  243. if err == nil {
  244. t.Errorf("Expected an error: %v", err)
  245. }
  246. _, err = db.Begin()
  247. if err != nil {
  248. t.Errorf("Failed to begin transaction: %v", err)
  249. }
  250. db.Close()
  251. os.Remove(fname)
  252. }
  253. func TestRecursiveTriggers(t *testing.T) {
  254. cases := map[string]bool{
  255. "?_recursive_triggers=1": true,
  256. "?_recursive_triggers=0": false,
  257. }
  258. for option, want := range cases {
  259. fname := TempFilename(t)
  260. uri := "file:" + fname + option
  261. db, err := sql.Open("sqlite3", uri)
  262. if err != nil {
  263. os.Remove(fname)
  264. t.Errorf("sql.Open(\"sqlite3\", %q): %v", uri, err)
  265. continue
  266. }
  267. var enabled bool
  268. err = db.QueryRow("PRAGMA recursive_triggers;").Scan(&enabled)
  269. db.Close()
  270. os.Remove(fname)
  271. if err != nil {
  272. t.Errorf("query recursive_triggers for %s: %v", uri, err)
  273. continue
  274. }
  275. if enabled != want {
  276. t.Errorf("\"PRAGMA recursive_triggers;\" for %q = %t; want %t", uri, enabled, want)
  277. continue
  278. }
  279. }
  280. }
  281. func TestClose(t *testing.T) {
  282. tempFilename := TempFilename(t)
  283. defer os.Remove(tempFilename)
  284. db, err := sql.Open("sqlite3", tempFilename)
  285. if err != nil {
  286. t.Fatal("Failed to open database:", err)
  287. }
  288. _, err = db.Exec("drop table foo")
  289. _, err = db.Exec("create table foo (id integer)")
  290. if err != nil {
  291. t.Fatal("Failed to create table:", err)
  292. }
  293. stmt, err := db.Prepare("select id from foo where id = ?")
  294. if err != nil {
  295. t.Fatal("Failed to select records:", err)
  296. }
  297. db.Close()
  298. _, err = stmt.Exec(1)
  299. if err == nil {
  300. t.Fatal("Failed to operate closed statement")
  301. }
  302. }
  303. func TestInsert(t *testing.T) {
  304. tempFilename := TempFilename(t)
  305. defer os.Remove(tempFilename)
  306. db, err := sql.Open("sqlite3", tempFilename)
  307. if err != nil {
  308. t.Fatal("Failed to open database:", err)
  309. }
  310. defer db.Close()
  311. _, err = db.Exec("drop table foo")
  312. _, err = db.Exec("create table foo (id integer)")
  313. if err != nil {
  314. t.Fatal("Failed to create table:", err)
  315. }
  316. res, err := db.Exec("insert into foo(id) values(123)")
  317. if err != nil {
  318. t.Fatal("Failed to insert record:", err)
  319. }
  320. affected, _ := res.RowsAffected()
  321. if affected != 1 {
  322. t.Fatalf("Expected %d for affected rows, but %d:", 1, affected)
  323. }
  324. rows, err := db.Query("select id from foo")
  325. if err != nil {
  326. t.Fatal("Failed to select records:", err)
  327. }
  328. defer rows.Close()
  329. rows.Next()
  330. var result int
  331. rows.Scan(&result)
  332. if result != 123 {
  333. t.Errorf("Expected %d for fetched result, but %d:", 123, result)
  334. }
  335. }
  336. func TestUpsert(t *testing.T) {
  337. _, n, _ := Version()
  338. if n < 3024000 {
  339. t.Skip("UPSERT requires sqlite3 >= 3.24.0")
  340. }
  341. tempFilename := TempFilename(t)
  342. defer os.Remove(tempFilename)
  343. db, err := sql.Open("sqlite3", tempFilename)
  344. if err != nil {
  345. t.Fatal("Failed to open database:", err)
  346. }
  347. defer db.Close()
  348. _, err = db.Exec("drop table foo")
  349. _, err = db.Exec("create table foo (name string primary key, counter integer)")
  350. if err != nil {
  351. t.Fatal("Failed to create table:", err)
  352. }
  353. for i := 0; i < 10; i++ {
  354. res, err := db.Exec("insert into foo(name, counter) values('key', 1) on conflict (name) do update set counter=counter+1")
  355. if err != nil {
  356. t.Fatal("Failed to upsert record:", err)
  357. }
  358. affected, _ := res.RowsAffected()
  359. if affected != 1 {
  360. t.Fatalf("Expected %d for affected rows, but %d:", 1, affected)
  361. }
  362. }
  363. rows, err := db.Query("select name, counter from foo")
  364. if err != nil {
  365. t.Fatal("Failed to select records:", err)
  366. }
  367. defer rows.Close()
  368. rows.Next()
  369. var resultName string
  370. var resultCounter int
  371. rows.Scan(&resultName, &resultCounter)
  372. if resultName != "key" {
  373. t.Errorf("Expected %s for fetched result, but %s:", "key", resultName)
  374. }
  375. if resultCounter != 10 {
  376. t.Errorf("Expected %d for fetched result, but %d:", 10, resultCounter)
  377. }
  378. }
  379. func TestUpdate(t *testing.T) {
  380. tempFilename := TempFilename(t)
  381. defer os.Remove(tempFilename)
  382. db, err := sql.Open("sqlite3", tempFilename)
  383. if err != nil {
  384. t.Fatal("Failed to open database:", err)
  385. }
  386. defer db.Close()
  387. _, err = db.Exec("drop table foo")
  388. _, err = db.Exec("create table foo (id integer)")
  389. if err != nil {
  390. t.Fatal("Failed to create table:", err)
  391. }
  392. res, err := db.Exec("insert into foo(id) values(123)")
  393. if err != nil {
  394. t.Fatal("Failed to insert record:", err)
  395. }
  396. expected, err := res.LastInsertId()
  397. if err != nil {
  398. t.Fatal("Failed to get LastInsertId:", err)
  399. }
  400. affected, _ := res.RowsAffected()
  401. if err != nil {
  402. t.Fatal("Failed to get RowsAffected:", err)
  403. }
  404. if affected != 1 {
  405. t.Fatalf("Expected %d for affected rows, but %d:", 1, affected)
  406. }
  407. res, err = db.Exec("update foo set id = 234")
  408. if err != nil {
  409. t.Fatal("Failed to update record:", err)
  410. }
  411. lastID, err := res.LastInsertId()
  412. if err != nil {
  413. t.Fatal("Failed to get LastInsertId:", err)
  414. }
  415. if expected != lastID {
  416. t.Errorf("Expected %q for last Id, but %q:", expected, lastID)
  417. }
  418. affected, _ = res.RowsAffected()
  419. if err != nil {
  420. t.Fatal("Failed to get RowsAffected:", err)
  421. }
  422. if affected != 1 {
  423. t.Fatalf("Expected %d for affected rows, but %d:", 1, affected)
  424. }
  425. rows, err := db.Query("select id from foo")
  426. if err != nil {
  427. t.Fatal("Failed to select records:", err)
  428. }
  429. defer rows.Close()
  430. rows.Next()
  431. var result int
  432. rows.Scan(&result)
  433. if result != 234 {
  434. t.Errorf("Expected %d for fetched result, but %d:", 234, result)
  435. }
  436. }
  437. func TestDelete(t *testing.T) {
  438. tempFilename := TempFilename(t)
  439. defer os.Remove(tempFilename)
  440. db, err := sql.Open("sqlite3", tempFilename)
  441. if err != nil {
  442. t.Fatal("Failed to open database:", err)
  443. }
  444. defer db.Close()
  445. _, err = db.Exec("drop table foo")
  446. _, err = db.Exec("create table foo (id integer)")
  447. if err != nil {
  448. t.Fatal("Failed to create table:", err)
  449. }
  450. res, err := db.Exec("insert into foo(id) values(123)")
  451. if err != nil {
  452. t.Fatal("Failed to insert record:", err)
  453. }
  454. expected, err := res.LastInsertId()
  455. if err != nil {
  456. t.Fatal("Failed to get LastInsertId:", err)
  457. }
  458. affected, err := res.RowsAffected()
  459. if err != nil {
  460. t.Fatal("Failed to get RowsAffected:", err)
  461. }
  462. if affected != 1 {
  463. t.Errorf("Expected %d for cout of affected rows, but %q:", 1, affected)
  464. }
  465. res, err = db.Exec("delete from foo where id = 123")
  466. if err != nil {
  467. t.Fatal("Failed to delete record:", err)
  468. }
  469. lastID, err := res.LastInsertId()
  470. if err != nil {
  471. t.Fatal("Failed to get LastInsertId:", err)
  472. }
  473. if expected != lastID {
  474. t.Errorf("Expected %q for last Id, but %q:", expected, lastID)
  475. }
  476. affected, err = res.RowsAffected()
  477. if err != nil {
  478. t.Fatal("Failed to get RowsAffected:", err)
  479. }
  480. if affected != 1 {
  481. t.Errorf("Expected %d for cout of affected rows, but %q:", 1, affected)
  482. }
  483. rows, err := db.Query("select id from foo")
  484. if err != nil {
  485. t.Fatal("Failed to select records:", err)
  486. }
  487. defer rows.Close()
  488. if rows.Next() {
  489. t.Error("Fetched row but expected not rows")
  490. }
  491. }
  492. func TestBooleanRoundtrip(t *testing.T) {
  493. tempFilename := TempFilename(t)
  494. defer os.Remove(tempFilename)
  495. db, err := sql.Open("sqlite3", tempFilename)
  496. if err != nil {
  497. t.Fatal("Failed to open database:", err)
  498. }
  499. defer db.Close()
  500. _, err = db.Exec("DROP TABLE foo")
  501. _, err = db.Exec("CREATE TABLE foo(id INTEGER, value BOOL)")
  502. if err != nil {
  503. t.Fatal("Failed to create table:", err)
  504. }
  505. _, err = db.Exec("INSERT INTO foo(id, value) VALUES(1, ?)", true)
  506. if err != nil {
  507. t.Fatal("Failed to insert true value:", err)
  508. }
  509. _, err = db.Exec("INSERT INTO foo(id, value) VALUES(2, ?)", false)
  510. if err != nil {
  511. t.Fatal("Failed to insert false value:", err)
  512. }
  513. rows, err := db.Query("SELECT id, value FROM foo")
  514. if err != nil {
  515. t.Fatal("Unable to query foo table:", err)
  516. }
  517. defer rows.Close()
  518. for rows.Next() {
  519. var id int
  520. var value bool
  521. if err := rows.Scan(&id, &value); err != nil {
  522. t.Error("Unable to scan results:", err)
  523. continue
  524. }
  525. if id == 1 && !value {
  526. t.Error("Value for id 1 should be true, not false")
  527. } else if id == 2 && value {
  528. t.Error("Value for id 2 should be false, not true")
  529. }
  530. }
  531. }
  532. func timezone(t time.Time) string { return t.Format("-07:00") }
  533. func TestTimestamp(t *testing.T) {
  534. tempFilename := TempFilename(t)
  535. defer os.Remove(tempFilename)
  536. db, err := sql.Open("sqlite3", tempFilename)
  537. if err != nil {
  538. t.Fatal("Failed to open database:", err)
  539. }
  540. defer db.Close()
  541. _, err = db.Exec("DROP TABLE foo")
  542. _, err = db.Exec("CREATE TABLE foo(id INTEGER, ts timeSTAMP, dt DATETIME)")
  543. if err != nil {
  544. t.Fatal("Failed to create table:", err)
  545. }
  546. timestamp1 := time.Date(2012, time.April, 6, 22, 50, 0, 0, time.UTC)
  547. timestamp2 := time.Date(2006, time.January, 2, 15, 4, 5, 123456789, time.UTC)
  548. timestamp3 := time.Date(2012, time.November, 4, 0, 0, 0, 0, time.UTC)
  549. tzTest := time.FixedZone("TEST", -9*3600-13*60)
  550. tests := []struct {
  551. value interface{}
  552. expected time.Time
  553. }{
  554. {"nonsense", time.Time{}},
  555. {"0000-00-00 00:00:00", time.Time{}},
  556. {time.Time{}.Unix(), time.Time{}},
  557. {timestamp1, timestamp1},
  558. {timestamp2.Unix(), timestamp2.Truncate(time.Second)},
  559. {timestamp2.UnixNano() / int64(time.Millisecond), timestamp2.Truncate(time.Millisecond)},
  560. {timestamp1.In(tzTest), timestamp1.In(tzTest)},
  561. {timestamp1.Format("2006-01-02 15:04:05.000"), timestamp1},
  562. {timestamp1.Format("2006-01-02T15:04:05.000"), timestamp1},
  563. {timestamp1.Format("2006-01-02 15:04:05"), timestamp1},
  564. {timestamp1.Format("2006-01-02T15:04:05"), timestamp1},
  565. {timestamp2, timestamp2},
  566. {"2006-01-02 15:04:05.123456789", timestamp2},
  567. {"2006-01-02T15:04:05.123456789", timestamp2},
  568. {"2006-01-02T05:51:05.123456789-09:13", timestamp2.In(tzTest)},
  569. {"2012-11-04", timestamp3},
  570. {"2012-11-04 00:00", timestamp3},
  571. {"2012-11-04 00:00:00", timestamp3},
  572. {"2012-11-04 00:00:00.000", timestamp3},
  573. {"2012-11-04T00:00", timestamp3},
  574. {"2012-11-04T00:00:00", timestamp3},
  575. {"2012-11-04T00:00:00.000", timestamp3},
  576. {"2006-01-02T15:04:05.123456789Z", timestamp2},
  577. {"2012-11-04Z", timestamp3},
  578. {"2012-11-04 00:00Z", timestamp3},
  579. {"2012-11-04 00:00:00Z", timestamp3},
  580. {"2012-11-04 00:00:00.000Z", timestamp3},
  581. {"2012-11-04T00:00Z", timestamp3},
  582. {"2012-11-04T00:00:00Z", timestamp3},
  583. {"2012-11-04T00:00:00.000Z", timestamp3},
  584. }
  585. for i := range tests {
  586. _, err = db.Exec("INSERT INTO foo(id, ts, dt) VALUES(?, ?, ?)", i, tests[i].value, tests[i].value)
  587. if err != nil {
  588. t.Fatal("Failed to insert timestamp:", err)
  589. }
  590. }
  591. rows, err := db.Query("SELECT id, ts, dt FROM foo ORDER BY id ASC")
  592. if err != nil {
  593. t.Fatal("Unable to query foo table:", err)
  594. }
  595. defer rows.Close()
  596. seen := 0
  597. for rows.Next() {
  598. var id int
  599. var ts, dt time.Time
  600. if err := rows.Scan(&id, &ts, &dt); err != nil {
  601. t.Error("Unable to scan results:", err)
  602. continue
  603. }
  604. if id < 0 || id >= len(tests) {
  605. t.Error("Bad row id: ", id)
  606. continue
  607. }
  608. seen++
  609. if !tests[id].expected.Equal(ts) {
  610. t.Errorf("Timestamp value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, dt)
  611. }
  612. if !tests[id].expected.Equal(dt) {
  613. t.Errorf("Datetime value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, dt)
  614. }
  615. if timezone(tests[id].expected) != timezone(ts) {
  616. t.Errorf("Timezone for id %v (%v) should be %v, not %v", id, tests[id].value,
  617. timezone(tests[id].expected), timezone(ts))
  618. }
  619. if timezone(tests[id].expected) != timezone(dt) {
  620. t.Errorf("Timezone for id %v (%v) should be %v, not %v", id, tests[id].value,
  621. timezone(tests[id].expected), timezone(dt))
  622. }
  623. }
  624. if seen != len(tests) {
  625. t.Errorf("Expected to see %d rows", len(tests))
  626. }
  627. }
  628. func TestBoolean(t *testing.T) {
  629. tempFilename := TempFilename(t)
  630. defer os.Remove(tempFilename)
  631. db, err := sql.Open("sqlite3", tempFilename)
  632. if err != nil {
  633. t.Fatal("Failed to open database:", err)
  634. }
  635. defer db.Close()
  636. _, err = db.Exec("CREATE TABLE foo(id INTEGER, fbool BOOLEAN)")
  637. if err != nil {
  638. t.Fatal("Failed to create table:", err)
  639. }
  640. bool1 := true
  641. _, err = db.Exec("INSERT INTO foo(id, fbool) VALUES(1, ?)", bool1)
  642. if err != nil {
  643. t.Fatal("Failed to insert boolean:", err)
  644. }
  645. bool2 := false
  646. _, err = db.Exec("INSERT INTO foo(id, fbool) VALUES(2, ?)", bool2)
  647. if err != nil {
  648. t.Fatal("Failed to insert boolean:", err)
  649. }
  650. bool3 := "nonsense"
  651. _, err = db.Exec("INSERT INTO foo(id, fbool) VALUES(3, ?)", bool3)
  652. if err != nil {
  653. t.Fatal("Failed to insert nonsense:", err)
  654. }
  655. rows, err := db.Query("SELECT id, fbool FROM foo where fbool = ?", bool1)
  656. if err != nil {
  657. t.Fatal("Unable to query foo table:", err)
  658. }
  659. counter := 0
  660. var id int
  661. var fbool bool
  662. for rows.Next() {
  663. if err := rows.Scan(&id, &fbool); err != nil {
  664. t.Fatal("Unable to scan results:", err)
  665. }
  666. counter++
  667. }
  668. if counter != 1 {
  669. t.Fatalf("Expected 1 row but %v", counter)
  670. }
  671. if id != 1 && !fbool {
  672. t.Fatalf("Value for id 1 should be %v, not %v", bool1, fbool)
  673. }
  674. rows, err = db.Query("SELECT id, fbool FROM foo where fbool = ?", bool2)
  675. if err != nil {
  676. t.Fatal("Unable to query foo table:", err)
  677. }
  678. counter = 0
  679. for rows.Next() {
  680. if err := rows.Scan(&id, &fbool); err != nil {
  681. t.Fatal("Unable to scan results:", err)
  682. }
  683. counter++
  684. }
  685. if counter != 1 {
  686. t.Fatalf("Expected 1 row but %v", counter)
  687. }
  688. if id != 2 && fbool {
  689. t.Fatalf("Value for id 2 should be %v, not %v", bool2, fbool)
  690. }
  691. // make sure "nonsense" triggered an error
  692. rows, err = db.Query("SELECT id, fbool FROM foo where id=?;", 3)
  693. if err != nil {
  694. t.Fatal("Unable to query foo table:", err)
  695. }
  696. rows.Next()
  697. err = rows.Scan(&id, &fbool)
  698. if err == nil {
  699. t.Error("Expected error from \"nonsense\" bool")
  700. }
  701. }
  702. func TestFloat32(t *testing.T) {
  703. tempFilename := TempFilename(t)
  704. defer os.Remove(tempFilename)
  705. db, err := sql.Open("sqlite3", tempFilename)
  706. if err != nil {
  707. t.Fatal("Failed to open database:", err)
  708. }
  709. defer db.Close()
  710. _, err = db.Exec("CREATE TABLE foo(id INTEGER)")
  711. if err != nil {
  712. t.Fatal("Failed to create table:", err)
  713. }
  714. _, err = db.Exec("INSERT INTO foo(id) VALUES(null)")
  715. if err != nil {
  716. t.Fatal("Failed to insert null:", err)
  717. }
  718. rows, err := db.Query("SELECT id FROM foo")
  719. if err != nil {
  720. t.Fatal("Unable to query foo table:", err)
  721. }
  722. if !rows.Next() {
  723. t.Fatal("Unable to query results:", err)
  724. }
  725. var id interface{}
  726. if err := rows.Scan(&id); err != nil {
  727. t.Fatal("Unable to scan results:", err)
  728. }
  729. if id != nil {
  730. t.Error("Expected nil but not")
  731. }
  732. }
  733. func TestNull(t *testing.T) {
  734. tempFilename := TempFilename(t)
  735. defer os.Remove(tempFilename)
  736. db, err := sql.Open("sqlite3", tempFilename)
  737. if err != nil {
  738. t.Fatal("Failed to open database:", err)
  739. }
  740. defer db.Close()
  741. rows, err := db.Query("SELECT 3.141592")
  742. if err != nil {
  743. t.Fatal("Unable to query foo table:", err)
  744. }
  745. if !rows.Next() {
  746. t.Fatal("Unable to query results:", err)
  747. }
  748. var v interface{}
  749. if err := rows.Scan(&v); err != nil {
  750. t.Fatal("Unable to scan results:", err)
  751. }
  752. f, ok := v.(float64)
  753. if !ok {
  754. t.Error("Expected float but not")
  755. }
  756. if f != 3.141592 {
  757. t.Error("Expected 3.141592 but not")
  758. }
  759. }
  760. func TestTransaction(t *testing.T) {
  761. tempFilename := TempFilename(t)
  762. defer os.Remove(tempFilename)
  763. db, err := sql.Open("sqlite3", tempFilename)
  764. if err != nil {
  765. t.Fatal("Failed to open database:", err)
  766. }
  767. defer db.Close()
  768. _, err = db.Exec("CREATE TABLE foo(id INTEGER)")
  769. if err != nil {
  770. t.Fatal("Failed to create table:", err)
  771. }
  772. tx, err := db.Begin()
  773. if err != nil {
  774. t.Fatal("Failed to begin transaction:", err)
  775. }
  776. _, err = tx.Exec("INSERT INTO foo(id) VALUES(1)")
  777. if err != nil {
  778. t.Fatal("Failed to insert null:", err)
  779. }
  780. rows, err := tx.Query("SELECT id from foo")
  781. if err != nil {
  782. t.Fatal("Unable to query foo table:", err)
  783. }
  784. err = tx.Rollback()
  785. if err != nil {
  786. t.Fatal("Failed to rollback transaction:", err)
  787. }
  788. if rows.Next() {
  789. t.Fatal("Unable to query results:", err)
  790. }
  791. tx, err = db.Begin()
  792. if err != nil {
  793. t.Fatal("Failed to begin transaction:", err)
  794. }
  795. _, err = tx.Exec("INSERT INTO foo(id) VALUES(1)")
  796. if err != nil {
  797. t.Fatal("Failed to insert null:", err)
  798. }
  799. err = tx.Commit()
  800. if err != nil {
  801. t.Fatal("Failed to commit transaction:", err)
  802. }
  803. rows, err = tx.Query("SELECT id from foo")
  804. if err == nil {
  805. t.Fatal("Expected failure to query")
  806. }
  807. }
  808. func TestWAL(t *testing.T) {
  809. tempFilename := TempFilename(t)
  810. defer os.Remove(tempFilename)
  811. db, err := sql.Open("sqlite3", tempFilename)
  812. if err != nil {
  813. t.Fatal("Failed to open database:", err)
  814. }
  815. defer db.Close()
  816. if _, err = db.Exec("PRAGMA journal_mode=WAL;"); err != nil {
  817. t.Fatal("Failed to Exec PRAGMA journal_mode:", err)
  818. }
  819. if _, err = db.Exec("PRAGMA locking_mode=EXCLUSIVE;"); err != nil {
  820. t.Fatal("Failed to Exec PRAGMA locking_mode:", err)
  821. }
  822. if _, err = db.Exec("CREATE TABLE test (id SERIAL, user TEXT NOT NULL, name TEXT NOT NULL);"); err != nil {
  823. t.Fatal("Failed to Exec CREATE TABLE:", err)
  824. }
  825. if _, err = db.Exec("INSERT INTO test (user, name) VALUES ('user','name');"); err != nil {
  826. t.Fatal("Failed to Exec INSERT:", err)
  827. }
  828. trans, err := db.Begin()
  829. if err != nil {
  830. t.Fatal("Failed to Begin:", err)
  831. }
  832. s, err := trans.Prepare("INSERT INTO test (user, name) VALUES (?, ?);")
  833. if err != nil {
  834. t.Fatal("Failed to Prepare:", err)
  835. }
  836. var count int
  837. if err = trans.QueryRow("SELECT count(user) FROM test;").Scan(&count); err != nil {
  838. t.Fatal("Failed to QueryRow:", err)
  839. }
  840. if _, err = s.Exec("bbbb", "aaaa"); err != nil {
  841. t.Fatal("Failed to Exec prepared statement:", err)
  842. }
  843. if err = s.Close(); err != nil {
  844. t.Fatal("Failed to Close prepared statement:", err)
  845. }
  846. if err = trans.Commit(); err != nil {
  847. t.Fatal("Failed to Commit:", err)
  848. }
  849. }
  850. func TestTimezoneConversion(t *testing.T) {
  851. zones := []string{"UTC", "US/Central", "US/Pacific", "Local"}
  852. for _, tz := range zones {
  853. tempFilename := TempFilename(t)
  854. defer os.Remove(tempFilename)
  855. db, err := sql.Open("sqlite3", tempFilename+"?_loc="+url.QueryEscape(tz))
  856. if err != nil {
  857. t.Fatal("Failed to open database:", err)
  858. }
  859. defer db.Close()
  860. _, err = db.Exec("DROP TABLE foo")
  861. _, err = db.Exec("CREATE TABLE foo(id INTEGER, ts TIMESTAMP, dt DATETIME)")
  862. if err != nil {
  863. t.Fatal("Failed to create table:", err)
  864. }
  865. loc, err := time.LoadLocation(tz)
  866. if err != nil {
  867. t.Fatal("Failed to load location:", err)
  868. }
  869. timestamp1 := time.Date(2012, time.April, 6, 22, 50, 0, 0, time.UTC)
  870. timestamp2 := time.Date(2006, time.January, 2, 15, 4, 5, 123456789, time.UTC)
  871. timestamp3 := time.Date(2012, time.November, 4, 0, 0, 0, 0, time.UTC)
  872. tests := []struct {
  873. value interface{}
  874. expected time.Time
  875. }{
  876. {"nonsense", time.Time{}.In(loc)},
  877. {"0000-00-00 00:00:00", time.Time{}.In(loc)},
  878. {timestamp1, timestamp1.In(loc)},
  879. {timestamp1.Unix(), timestamp1.In(loc)},
  880. {timestamp1.In(time.FixedZone("TEST", -7*3600)), timestamp1.In(loc)},
  881. {timestamp1.Format("2006-01-02 15:04:05.000"), timestamp1.In(loc)},
  882. {timestamp1.Format("2006-01-02T15:04:05.000"), timestamp1.In(loc)},
  883. {timestamp1.Format("2006-01-02 15:04:05"), timestamp1.In(loc)},
  884. {timestamp1.Format("2006-01-02T15:04:05"), timestamp1.In(loc)},
  885. {timestamp2, timestamp2.In(loc)},
  886. {"2006-01-02 15:04:05.123456789", timestamp2.In(loc)},
  887. {"2006-01-02T15:04:05.123456789", timestamp2.In(loc)},
  888. {"2012-11-04", timestamp3.In(loc)},
  889. {"2012-11-04 00:00", timestamp3.In(loc)},
  890. {"2012-11-04 00:00:00", timestamp3.In(loc)},
  891. {"2012-11-04 00:00:00.000", timestamp3.In(loc)},
  892. {"2012-11-04T00:00", timestamp3.In(loc)},
  893. {"2012-11-04T00:00:00", timestamp3.In(loc)},
  894. {"2012-11-04T00:00:00.000", timestamp3.In(loc)},
  895. }
  896. for i := range tests {
  897. _, err = db.Exec("INSERT INTO foo(id, ts, dt) VALUES(?, ?, ?)", i, tests[i].value, tests[i].value)
  898. if err != nil {
  899. t.Fatal("Failed to insert timestamp:", err)
  900. }
  901. }
  902. rows, err := db.Query("SELECT id, ts, dt FROM foo ORDER BY id ASC")
  903. if err != nil {
  904. t.Fatal("Unable to query foo table:", err)
  905. }
  906. defer rows.Close()
  907. seen := 0
  908. for rows.Next() {
  909. var id int
  910. var ts, dt time.Time
  911. if err := rows.Scan(&id, &ts, &dt); err != nil {
  912. t.Error("Unable to scan results:", err)
  913. continue
  914. }
  915. if id < 0 || id >= len(tests) {
  916. t.Error("Bad row id: ", id)
  917. continue
  918. }
  919. seen++
  920. if !tests[id].expected.Equal(ts) {
  921. t.Errorf("Timestamp value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, ts)
  922. }
  923. if !tests[id].expected.Equal(dt) {
  924. t.Errorf("Datetime value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, dt)
  925. }
  926. if tests[id].expected.Location().String() != ts.Location().String() {
  927. t.Errorf("Location for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected.Location().String(), ts.Location().String())
  928. }
  929. if tests[id].expected.Location().String() != dt.Location().String() {
  930. t.Errorf("Location for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected.Location().String(), dt.Location().String())
  931. }
  932. }
  933. if seen != len(tests) {
  934. t.Errorf("Expected to see %d rows", len(tests))
  935. }
  936. }
  937. }
  938. // TODO: Execer & Queryer currently disabled
  939. // https://github.com/mattn/go-sqlite3/issues/82
  940. func TestExecer(t *testing.T) {
  941. tempFilename := TempFilename(t)
  942. defer os.Remove(tempFilename)
  943. db, err := sql.Open("sqlite3", tempFilename)
  944. if err != nil {
  945. t.Fatal("Failed to open database:", err)
  946. }
  947. defer db.Close()
  948. _, err = db.Exec(`
  949. create table foo (id integer); -- one comment
  950. insert into foo(id) values(?);
  951. insert into foo(id) values(?);
  952. insert into foo(id) values(?); -- another comment
  953. `, 1, 2, 3)
  954. if err != nil {
  955. t.Error("Failed to call db.Exec:", err)
  956. }
  957. }
  958. func TestQueryer(t *testing.T) {
  959. tempFilename := TempFilename(t)
  960. defer os.Remove(tempFilename)
  961. db, err := sql.Open("sqlite3", tempFilename)
  962. if err != nil {
  963. t.Fatal("Failed to open database:", err)
  964. }
  965. defer db.Close()
  966. _, err = db.Exec(`
  967. create table foo (id integer);
  968. `)
  969. if err != nil {
  970. t.Error("Failed to call db.Query:", err)
  971. }
  972. _, err = db.Exec(`
  973. insert into foo(id) values(?);
  974. insert into foo(id) values(?);
  975. insert into foo(id) values(?);
  976. `, 3, 2, 1)
  977. if err != nil {
  978. t.Error("Failed to call db.Exec:", err)
  979. }
  980. rows, err := db.Query(`
  981. select id from foo order by id;
  982. `)
  983. if err != nil {
  984. t.Error("Failed to call db.Query:", err)
  985. }
  986. defer rows.Close()
  987. n := 0
  988. for rows.Next() {
  989. var id int
  990. err = rows.Scan(&id)
  991. if err != nil {
  992. t.Error("Failed to db.Query:", err)
  993. }
  994. if id != n + 1 {
  995. t.Error("Failed to db.Query: not matched results")
  996. }
  997. n = n + 1
  998. }
  999. if err := rows.Err(); err != nil {
  1000. t.Errorf("Post-scan failed: %v\n", err)
  1001. }
  1002. if n != 3 {
  1003. t.Errorf("Expected 3 rows but retrieved %v", n)
  1004. }
  1005. }
  1006. func TestStress(t *testing.T) {
  1007. tempFilename := TempFilename(t)
  1008. defer os.Remove(tempFilename)
  1009. db, err := sql.Open("sqlite3", tempFilename)
  1010. if err != nil {
  1011. t.Fatal("Failed to open database:", err)
  1012. }
  1013. db.Exec("CREATE TABLE foo (id int);")
  1014. db.Exec("INSERT INTO foo VALUES(1);")
  1015. db.Exec("INSERT INTO foo VALUES(2);")
  1016. db.Close()
  1017. for i := 0; i < 10000; i++ {
  1018. db, err := sql.Open("sqlite3", tempFilename)
  1019. if err != nil {
  1020. t.Fatal("Failed to open database:", err)
  1021. }
  1022. for j := 0; j < 3; j++ {
  1023. rows, err := db.Query("select * from foo where id=1;")
  1024. if err != nil {
  1025. t.Error("Failed to call db.Query:", err)
  1026. }
  1027. for rows.Next() {
  1028. var i int
  1029. if err := rows.Scan(&i); err != nil {
  1030. t.Errorf("Scan failed: %v\n", err)
  1031. }
  1032. }
  1033. if err := rows.Err(); err != nil {
  1034. t.Errorf("Post-scan failed: %v\n", err)
  1035. }
  1036. rows.Close()
  1037. }
  1038. db.Close()
  1039. }
  1040. }
  1041. func TestDateTimeLocal(t *testing.T) {
  1042. zone := "Asia/Tokyo"
  1043. tempFilename := TempFilename(t)
  1044. defer os.Remove(tempFilename)
  1045. db, err := sql.Open("sqlite3", tempFilename+"?_loc="+zone)
  1046. if err != nil {
  1047. t.Fatal("Failed to open database:", err)
  1048. }
  1049. db.Exec("CREATE TABLE foo (dt datetime);")
  1050. db.Exec("INSERT INTO foo VALUES('2015-03-05 15:16:17');")
  1051. row := db.QueryRow("select * from foo")
  1052. var d time.Time
  1053. err = row.Scan(&d)
  1054. if err != nil {
  1055. t.Fatal("Failed to scan datetime:", err)
  1056. }
  1057. if d.Hour() == 15 || !strings.Contains(d.String(), "JST") {
  1058. t.Fatal("Result should have timezone", d)
  1059. }
  1060. db.Close()
  1061. db, err = sql.Open("sqlite3", tempFilename)
  1062. if err != nil {
  1063. t.Fatal("Failed to open database:", err)
  1064. }
  1065. row = db.QueryRow("select * from foo")
  1066. err = row.Scan(&d)
  1067. if err != nil {
  1068. t.Fatal("Failed to scan datetime:", err)
  1069. }
  1070. if d.UTC().Hour() != 15 || !strings.Contains(d.String(), "UTC") {
  1071. t.Fatalf("Result should not have timezone %v %v", zone, d.String())
  1072. }
  1073. _, err = db.Exec("DELETE FROM foo")
  1074. if err != nil {
  1075. t.Fatal("Failed to delete table:", err)
  1076. }
  1077. dt, err := time.Parse("2006/1/2 15/4/5 -0700 MST", "2015/3/5 15/16/17 +0900 JST")
  1078. if err != nil {
  1079. t.Fatal("Failed to parse datetime:", err)
  1080. }
  1081. db.Exec("INSERT INTO foo VALUES(?);", dt)
  1082. db.Close()
  1083. db, err = sql.Open("sqlite3", tempFilename+"?_loc="+zone)
  1084. if err != nil {
  1085. t.Fatal("Failed to open database:", err)
  1086. }
  1087. row = db.QueryRow("select * from foo")
  1088. err = row.Scan(&d)
  1089. if err != nil {
  1090. t.Fatal("Failed to scan datetime:", err)
  1091. }
  1092. if d.Hour() != 15 || !strings.Contains(d.String(), "JST") {
  1093. t.Fatalf("Result should have timezone %v %v", zone, d.String())
  1094. }
  1095. }
  1096. func TestVersion(t *testing.T) {
  1097. s, n, id := Version()
  1098. if s == "" || n == 0 || id == "" {
  1099. t.Errorf("Version failed %q, %d, %q\n", s, n, id)
  1100. }
  1101. }
  1102. func TestStringContainingZero(t *testing.T) {
  1103. tempFilename := TempFilename(t)
  1104. defer os.Remove(tempFilename)
  1105. db, err := sql.Open("sqlite3", tempFilename)
  1106. if err != nil {
  1107. t.Fatal("Failed to open database:", err)
  1108. }
  1109. defer db.Close()
  1110. _, err = db.Exec(`
  1111. create table foo (id integer, name, extra text);
  1112. `)
  1113. if err != nil {
  1114. t.Error("Failed to call db.Query:", err)
  1115. }
  1116. const text = "foo\x00bar"
  1117. _, err = db.Exec(`insert into foo(id, name, extra) values($1, $2, $2)`, 1, text)
  1118. if err != nil {
  1119. t.Error("Failed to call db.Exec:", err)
  1120. }
  1121. row := db.QueryRow(`select id, extra from foo where id = $1 and extra = $2`, 1, text)
  1122. if row == nil {
  1123. t.Error("Failed to call db.QueryRow")
  1124. }
  1125. var id int
  1126. var extra string
  1127. err = row.Scan(&id, &extra)
  1128. if err != nil {
  1129. t.Error("Failed to db.Scan:", err)
  1130. }
  1131. if id != 1 || extra != text {
  1132. t.Error("Failed to db.QueryRow: not matched results")
  1133. }
  1134. }
  1135. const CurrentTimeStamp = "2006-01-02 15:04:05"
  1136. type TimeStamp struct{ *time.Time }
  1137. func (t TimeStamp) Scan(value interface{}) error {
  1138. var err error
  1139. switch v := value.(type) {
  1140. case string:
  1141. *t.Time, err = time.Parse(CurrentTimeStamp, v)
  1142. case []byte:
  1143. *t.Time, err = time.Parse(CurrentTimeStamp, string(v))
  1144. default:
  1145. err = errors.New("invalid type for current_timestamp")
  1146. }
  1147. return err
  1148. }
  1149. func (t TimeStamp) Value() (driver.Value, error) {
  1150. return t.Time.Format(CurrentTimeStamp), nil
  1151. }
  1152. func TestDateTimeNow(t *testing.T) {
  1153. tempFilename := TempFilename(t)
  1154. defer os.Remove(tempFilename)
  1155. db, err := sql.Open("sqlite3", tempFilename)
  1156. if err != nil {
  1157. t.Fatal("Failed to open database:", err)
  1158. }
  1159. defer db.Close()
  1160. var d time.Time
  1161. err = db.QueryRow("SELECT datetime('now')").Scan(TimeStamp{&d})
  1162. if err != nil {
  1163. t.Fatal("Failed to scan datetime:", err)
  1164. }
  1165. }
  1166. func TestFunctionRegistration(t *testing.T) {
  1167. addi8_16_32 := func(a int8, b int16) int32 { return int32(a) + int32(b) }
  1168. addi64 := func(a, b int64) int64 { return a + b }
  1169. addu8_16_32 := func(a uint8, b uint16) uint32 { return uint32(a) + uint32(b) }
  1170. addu64 := func(a, b uint64) uint64 { return a + b }
  1171. addiu := func(a int, b uint) int64 { return int64(a) + int64(b) }
  1172. addf32_64 := func(a float32, b float64) float64 { return float64(a) + b }
  1173. not := func(a bool) bool { return !a }
  1174. regex := func(re, s string) (bool, error) {
  1175. return regexp.MatchString(re, s)
  1176. }
  1177. generic := func(a interface{}) int64 {
  1178. switch a.(type) {
  1179. case int64:
  1180. return 1
  1181. case float64:
  1182. return 2
  1183. case []byte:
  1184. return 3
  1185. case string:
  1186. return 4
  1187. default:
  1188. panic("unreachable")
  1189. }
  1190. }
  1191. variadic := func(a, b int64, c ...int64) int64 {
  1192. ret := a + b
  1193. for _, d := range c {
  1194. ret += d
  1195. }
  1196. return ret
  1197. }
  1198. variadicGeneric := func(a ...interface{}) int64 {
  1199. return int64(len(a))
  1200. }
  1201. sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{
  1202. ConnectHook: func(conn *SQLiteConn) error {
  1203. if err := conn.RegisterFunc("addi8_16_32", addi8_16_32, true); err != nil {
  1204. return err
  1205. }
  1206. if err := conn.RegisterFunc("addi64", addi64, true); err != nil {
  1207. return err
  1208. }
  1209. if err := conn.RegisterFunc("addu8_16_32", addu8_16_32, true); err != nil {
  1210. return err
  1211. }
  1212. if err := conn.RegisterFunc("addu64", addu64, true); err != nil {
  1213. return err
  1214. }
  1215. if err := conn.RegisterFunc("addiu", addiu, true); err != nil {
  1216. return err
  1217. }
  1218. if err := conn.RegisterFunc("addf32_64", addf32_64, true); err != nil {
  1219. return err
  1220. }
  1221. if err := conn.RegisterFunc("not", not, true); err != nil {
  1222. return err
  1223. }
  1224. if err := conn.RegisterFunc("regex", regex, true); err != nil {
  1225. return err
  1226. }
  1227. if err := conn.RegisterFunc("generic", generic, true); err != nil {
  1228. return err
  1229. }
  1230. if err := conn.RegisterFunc("variadic", variadic, true); err != nil {
  1231. return err
  1232. }
  1233. if err := conn.RegisterFunc("variadicGeneric", variadicGeneric, true); err != nil {
  1234. return err
  1235. }
  1236. return nil
  1237. },
  1238. })
  1239. db, err := sql.Open("sqlite3_FunctionRegistration", ":memory:")
  1240. if err != nil {
  1241. t.Fatal("Failed to open database:", err)
  1242. }
  1243. defer db.Close()
  1244. ops := []struct {
  1245. query string
  1246. expected interface{}
  1247. }{
  1248. {"SELECT addi8_16_32(1,2)", int32(3)},
  1249. {"SELECT addi64(1,2)", int64(3)},
  1250. {"SELECT addu8_16_32(1,2)", uint32(3)},
  1251. {"SELECT addu64(1,2)", uint64(3)},
  1252. {"SELECT addiu(1,2)", int64(3)},
  1253. {"SELECT addf32_64(1.5,1.5)", float64(3)},
  1254. {"SELECT not(1)", false},
  1255. {"SELECT not(0)", true},
  1256. {`SELECT regex('^foo.*', 'foobar')`, true},
  1257. {`SELECT regex('^foo.*', 'barfoobar')`, false},
  1258. {"SELECT generic(1)", int64(1)},
  1259. {"SELECT generic(1.1)", int64(2)},
  1260. {`SELECT generic(NULL)`, int64(3)},
  1261. {`SELECT generic('foo')`, int64(4)},
  1262. {"SELECT variadic(1,2)", int64(3)},
  1263. {"SELECT variadic(1,2,3,4)", int64(10)},
  1264. {"SELECT variadic(1,1,1,1,1,1,1,1,1,1)", int64(10)},
  1265. {`SELECT variadicGeneric(1,'foo',2.3, NULL)`, int64(4)},
  1266. }
  1267. for _, op := range ops {
  1268. ret := reflect.New(reflect.TypeOf(op.expected))
  1269. err = db.QueryRow(op.query).Scan(ret.Interface())
  1270. if err != nil {
  1271. t.Errorf("Query %q failed: %s", op.query, err)
  1272. } else if !reflect.DeepEqual(ret.Elem().Interface(), op.expected) {
  1273. t.Errorf("Query %q returned wrong value: got %v (%T), want %v (%T)", op.query, ret.Elem().Interface(), ret.Elem().Interface(), op.expected, op.expected)
  1274. }
  1275. }
  1276. }
  1277. type sumAggregator int64
  1278. func (s *sumAggregator) Step(x int64) {
  1279. *s += sumAggregator(x)
  1280. }
  1281. func (s *sumAggregator) Done() int64 {
  1282. return int64(*s)
  1283. }
  1284. func TestAggregatorRegistration(t *testing.T) {
  1285. customSum := func() *sumAggregator {
  1286. var ret sumAggregator
  1287. return &ret
  1288. }
  1289. sql.Register("sqlite3_AggregatorRegistration", &SQLiteDriver{
  1290. ConnectHook: func(conn *SQLiteConn) error {
  1291. return conn.RegisterAggregator("customSum", customSum, true)
  1292. },
  1293. })
  1294. db, err := sql.Open("sqlite3_AggregatorRegistration", ":memory:")
  1295. if err != nil {
  1296. t.Fatal("Failed to open database:", err)
  1297. }
  1298. defer db.Close()
  1299. _, err = db.Exec("create table foo (department integer, profits integer)")
  1300. if err != nil {
  1301. // trace feature is not implemented
  1302. t.Skip("Failed to create table:", err)
  1303. }
  1304. _, err = db.Exec("insert into foo values (1, 10), (1, 20), (2, 42)")
  1305. if err != nil {
  1306. t.Fatal("Failed to insert records:", err)
  1307. }
  1308. tests := []struct {
  1309. dept, sum int64
  1310. }{
  1311. {1, 30},
  1312. {2, 42},
  1313. }
  1314. for _, test := range tests {
  1315. var ret int64
  1316. err = db.QueryRow("select customSum(profits) from foo where department = $1 group by department", test.dept).Scan(&ret)
  1317. if err != nil {
  1318. t.Fatal("Query failed:", err)
  1319. }
  1320. if ret != test.sum {
  1321. t.Fatalf("Custom sum returned wrong value, got %d, want %d", ret, test.sum)
  1322. }
  1323. }
  1324. }
  1325. type mode struct {
  1326. counts map[interface{}]int
  1327. top interface{}
  1328. topCount int
  1329. }
  1330. func newMode() *mode {
  1331. return &mode{
  1332. counts: map[interface{}]int{},
  1333. }
  1334. }
  1335. func (m *mode) Step(x interface{}) {
  1336. m.counts[x]++
  1337. c := m.counts[x]
  1338. if c > m.topCount {
  1339. m.top = x
  1340. m.topCount = c
  1341. }
  1342. }
  1343. func (m *mode) Done() interface{} {
  1344. return m.top
  1345. }
  1346. func TestAggregatorRegistration_GenericReturn(t *testing.T) {
  1347. sql.Register("sqlite3_AggregatorRegistration_GenericReturn", &SQLiteDriver{
  1348. ConnectHook: func(conn *SQLiteConn) error {
  1349. return conn.RegisterAggregator("mode", newMode, true)
  1350. },
  1351. })
  1352. db, err := sql.Open("sqlite3_AggregatorRegistration_GenericReturn", ":memory:")
  1353. if err != nil {
  1354. t.Fatal("Failed to open database:", err)
  1355. }
  1356. defer db.Close()
  1357. _, err = db.Exec("create table foo (department integer, profits integer)")
  1358. if err != nil {
  1359. t.Fatal("Failed to create table:", err)
  1360. }
  1361. _, err = db.Exec("insert into foo values (1, 10), (1, 20), (1, 45), (2, 42), (2, 115), (2, 20)")
  1362. if err != nil {
  1363. t.Fatal("Failed to insert records:", err)
  1364. }
  1365. var mode int
  1366. err = db.QueryRow("select mode(profits) from foo").Scan(&mode)
  1367. if err != nil {
  1368. t.Fatal("MODE query error:", err)
  1369. }
  1370. if mode != 20 {
  1371. t.Fatal("Got incorrect mode. Wanted 20, got: ", mode)
  1372. }
  1373. }
  1374. func rot13(r rune) rune {
  1375. switch {
  1376. case r >= 'A' && r <= 'Z':
  1377. return 'A' + (r-'A'+13)%26
  1378. case r >= 'a' && r <= 'z':
  1379. return 'a' + (r-'a'+13)%26
  1380. }
  1381. return r
  1382. }
  1383. func TestCollationRegistration(t *testing.T) {
  1384. collateRot13 := func(a, b string) int {
  1385. ra, rb := strings.Map(rot13, a), strings.Map(rot13, b)
  1386. return strings.Compare(ra, rb)
  1387. }
  1388. collateRot13Reverse := func(a, b string) int {
  1389. return collateRot13(b, a)
  1390. }
  1391. sql.Register("sqlite3_CollationRegistration", &SQLiteDriver{
  1392. ConnectHook: func(conn *SQLiteConn) error {
  1393. if err := conn.RegisterCollation("rot13", collateRot13); err != nil {
  1394. return err
  1395. }
  1396. if err := conn.RegisterCollation("rot13reverse", collateRot13Reverse); err != nil {
  1397. return err
  1398. }
  1399. return nil
  1400. },
  1401. })
  1402. db, err := sql.Open("sqlite3_CollationRegistration", ":memory:")
  1403. if err != nil {
  1404. t.Fatal("Failed to open database:", err)
  1405. }
  1406. defer db.Close()
  1407. populate := []string{
  1408. `CREATE TABLE test (s TEXT)`,
  1409. `INSERT INTO test VALUES ('aaaa')`,
  1410. `INSERT INTO test VALUES ('ffff')`,
  1411. `INSERT INTO test VALUES ('qqqq')`,
  1412. `INSERT INTO test VALUES ('tttt')`,
  1413. `INSERT INTO test VALUES ('zzzz')`,
  1414. }
  1415. for _, stmt := range populate {
  1416. if _, err := db.Exec(stmt); err != nil {
  1417. t.Fatal("Failed to populate test DB:", err)
  1418. }
  1419. }
  1420. ops := []struct {
  1421. query string
  1422. want []string
  1423. }{
  1424. {
  1425. "SELECT * FROM test ORDER BY s COLLATE rot13 ASC",
  1426. []string{
  1427. "qqqq",
  1428. "tttt",
  1429. "zzzz",
  1430. "aaaa",
  1431. "ffff",
  1432. },
  1433. },
  1434. {
  1435. "SELECT * FROM test ORDER BY s COLLATE rot13 DESC",
  1436. []string{
  1437. "ffff",
  1438. "aaaa",
  1439. "zzzz",
  1440. "tttt",
  1441. "qqqq",
  1442. },
  1443. },
  1444. {
  1445. "SELECT * FROM test ORDER BY s COLLATE rot13reverse ASC",
  1446. []string{
  1447. "ffff",
  1448. "aaaa",
  1449. "zzzz",
  1450. "tttt",
  1451. "qqqq",
  1452. },
  1453. },
  1454. {
  1455. "SELECT * FROM test ORDER BY s COLLATE rot13reverse DESC",
  1456. []string{
  1457. "qqqq",
  1458. "tttt",
  1459. "zzzz",
  1460. "aaaa",
  1461. "ffff",
  1462. },
  1463. },
  1464. }
  1465. for _, op := range ops {
  1466. rows, err := db.Query(op.query)
  1467. if err != nil {
  1468. t.Fatalf("Query %q failed: %s", op.query, err)
  1469. }
  1470. got := []string{}
  1471. defer rows.Close()
  1472. for rows.Next() {
  1473. var s string
  1474. if err = rows.Scan(&s); err != nil {
  1475. t.Fatalf("Reading row for %q: %s", op.query, err)
  1476. }
  1477. got = append(got, s)
  1478. }
  1479. if err = rows.Err(); err != nil {
  1480. t.Fatalf("Reading rows for %q: %s", op.query, err)
  1481. }
  1482. if !reflect.DeepEqual(got, op.want) {
  1483. t.Fatalf("Unexpected output from %q\ngot:\n%s\n\nwant:\n%s", op.query, strings.Join(got, "\n"), strings.Join(op.want, "\n"))
  1484. }
  1485. }
  1486. }
  1487. func TestDeclTypes(t *testing.T) {
  1488. d := SQLiteDriver{}
  1489. conn, err := d.Open(":memory:")
  1490. if err != nil {
  1491. t.Fatal("Failed to begin transaction:", err)
  1492. }
  1493. defer conn.Close()
  1494. sqlite3conn := conn.(*SQLiteConn)
  1495. _, err = sqlite3conn.Exec("create table foo (id integer not null primary key, name text)", nil)
  1496. if err != nil {
  1497. t.Fatal("Failed to create table:", err)
  1498. }
  1499. _, err = sqlite3conn.Exec("insert into foo(name) values('bar')", nil)
  1500. if err != nil {
  1501. t.Fatal("Failed to insert:", err)
  1502. }
  1503. rs, err := sqlite3conn.Query("select * from foo", nil)
  1504. if err != nil {
  1505. t.Fatal("Failed to select:", err)
  1506. }
  1507. defer rs.Close()
  1508. declTypes := rs.(*SQLiteRows).DeclTypes()
  1509. if !reflect.DeepEqual(declTypes, []string{"integer", "text"}) {
  1510. t.Fatal("Unexpected declTypes:", declTypes)
  1511. }
  1512. }
  1513. func TestPinger(t *testing.T) {
  1514. db, err := sql.Open("sqlite3", ":memory:")
  1515. if err != nil {
  1516. t.Fatal(err)
  1517. }
  1518. err = db.Ping()
  1519. if err != nil {
  1520. t.Fatal(err)
  1521. }
  1522. db.Close()
  1523. err = db.Ping()
  1524. if err == nil {
  1525. t.Fatal("Should be closed")
  1526. }
  1527. }
  1528. func TestUpdateAndTransactionHooks(t *testing.T) {
  1529. var events []string
  1530. var commitHookReturn = 0
  1531. sql.Register("sqlite3_UpdateHook", &SQLiteDriver{
  1532. ConnectHook: func(conn *SQLiteConn) error {
  1533. conn.RegisterCommitHook(func() int {
  1534. events = append(events, "commit")
  1535. return commitHookReturn
  1536. })
  1537. conn.RegisterRollbackHook(func() {
  1538. events = append(events, "rollback")
  1539. })
  1540. conn.RegisterUpdateHook(func(op int, db string, table string, rowid int64) {
  1541. events = append(events, fmt.Sprintf("update(op=%v db=%v table=%v rowid=%v)", op, db, table, rowid))
  1542. })
  1543. return nil
  1544. },
  1545. })
  1546. db, err := sql.Open("sqlite3_UpdateHook", ":memory:")
  1547. if err != nil {
  1548. t.Fatal("Failed to open database:", err)
  1549. }
  1550. defer db.Close()
  1551. statements := []string{
  1552. "create table foo (id integer primary key)",
  1553. "insert into foo values (9)",
  1554. "update foo set id = 99 where id = 9",
  1555. "delete from foo where id = 99",
  1556. }
  1557. for _, statement := range statements {
  1558. _, err = db.Exec(statement)
  1559. if err != nil {
  1560. t.Fatalf("Unable to prepare test data [%v]: %v", statement, err)
  1561. }
  1562. }
  1563. commitHookReturn = 1
  1564. _, err = db.Exec("insert into foo values (5)")
  1565. if err == nil {
  1566. t.Error("Commit hook failed to rollback transaction")
  1567. }
  1568. var expected = []string{
  1569. "commit",
  1570. fmt.Sprintf("update(op=%v db=main table=foo rowid=9)", SQLITE_INSERT),
  1571. "commit",
  1572. fmt.Sprintf("update(op=%v db=main table=foo rowid=99)", SQLITE_UPDATE),
  1573. "commit",
  1574. fmt.Sprintf("update(op=%v db=main table=foo rowid=99)", SQLITE_DELETE),
  1575. "commit",
  1576. fmt.Sprintf("update(op=%v db=main table=foo rowid=5)", SQLITE_INSERT),
  1577. "commit",
  1578. "rollback",
  1579. }
  1580. if !reflect.DeepEqual(events, expected) {
  1581. t.Errorf("Expected notifications %v but got %v", expected, events)
  1582. }
  1583. }
  1584. func TestAuthorizer(t *testing.T) {
  1585. var authorizerReturn = 0
  1586. sql.Register("sqlite3_Authorizer", &SQLiteDriver{
  1587. ConnectHook: func(conn *SQLiteConn) error {
  1588. conn.RegisterAuthorizer(func(op int, arg1, arg2, arg3 string) int {
  1589. return authorizerReturn
  1590. })
  1591. return nil
  1592. },
  1593. })
  1594. db, err := sql.Open("sqlite3_Authorizer", ":memory:")
  1595. if err != nil {
  1596. t.Fatal("Failed to open database:", err)
  1597. }
  1598. defer db.Close()
  1599. statements := []string{
  1600. "create table foo (id integer primary key, name varchar)",
  1601. "insert into foo values (9, 'test9')",
  1602. "update foo set name = 'test99' where id = 9",
  1603. "select * from foo",
  1604. }
  1605. authorizerReturn = SQLITE_OK
  1606. for _, statement := range statements {
  1607. _, err = db.Exec(statement)
  1608. if err != nil {
  1609. t.Fatalf("No error expected [%v]: %v", statement, err)
  1610. }
  1611. }
  1612. authorizerReturn = SQLITE_DENY
  1613. for _, statement := range statements {
  1614. _, err = db.Exec(statement)
  1615. if err == nil {
  1616. t.Fatalf("Authorizer didn't worked - nil received, but error expected: [%v]", statement)
  1617. }
  1618. }
  1619. }
  1620. func TestSetFileControlInt(t *testing.T) {
  1621. t.Run("PERSIST_WAL", func(t *testing.T) {
  1622. tempFilename := TempFilename(t)
  1623. defer os.Remove(tempFilename)
  1624. sql.Register("sqlite3_FCNTL_PERSIST_WAL", &SQLiteDriver{
  1625. ConnectHook: func(conn *SQLiteConn) error {
  1626. if err := conn.SetFileControlInt("", SQLITE_FCNTL_PERSIST_WAL, 1); err != nil {
  1627. return fmt.Errorf("Unexpected error from SetFileControlInt(): %w", err)
  1628. }
  1629. return nil
  1630. },
  1631. })
  1632. db, err := sql.Open("sqlite3_FCNTL_PERSIST_WAL", tempFilename)
  1633. if err != nil {
  1634. t.Fatal("Failed to open database:", err)
  1635. }
  1636. defer db.Close()
  1637. // Set to WAL mode & write a page.
  1638. if _, err := db.Exec(`PRAGMA journal_mode = wal`); err != nil {
  1639. t.Fatal("Failed to set journal mode:", err)
  1640. } else if _, err := db.Exec(`CREATE TABLE t (x)`); err != nil {
  1641. t.Fatal("Failed to create table:", err)
  1642. }
  1643. if err := db.Close(); err != nil {
  1644. t.Fatal("Failed to close database", err)
  1645. }
  1646. // Ensure WAL file persists after close.
  1647. if _, err := os.Stat(tempFilename + "-wal"); err != nil {
  1648. t.Fatal("Expected WAL file to be persisted after close", err)
  1649. }
  1650. })
  1651. }
  1652. func TestNonColumnString(t *testing.T) {
  1653. db, err := sql.Open("sqlite3", ":memory:")
  1654. if err != nil {
  1655. t.Fatal(err)
  1656. }
  1657. defer db.Close()
  1658. var x interface{}
  1659. if err := db.QueryRow("SELECT 'hello'").Scan(&x); err != nil {
  1660. t.Fatal(err)
  1661. }
  1662. s, ok := x.(string)
  1663. if !ok {
  1664. t.Fatalf("non-column string must return string but got %T", x)
  1665. }
  1666. if s != "hello" {
  1667. t.Fatalf("non-column string must return %q but got %q", "hello", s)
  1668. }
  1669. }
  1670. func TestNilAndEmptyBytes(t *testing.T) {
  1671. db, err := sql.Open("sqlite3", ":memory:")
  1672. if err != nil {
  1673. t.Fatal(err)
  1674. }
  1675. defer db.Close()
  1676. actualNil := []byte("use this to use an actual nil not a reference to nil")
  1677. emptyBytes := []byte{}
  1678. for tsti, tst := range []struct {
  1679. name string
  1680. columnType string
  1681. insertBytes []byte
  1682. expectedBytes []byte
  1683. }{
  1684. {"actual nil blob", "blob", actualNil, nil},
  1685. {"referenced nil blob", "blob", nil, nil},
  1686. {"empty blob", "blob", emptyBytes, emptyBytes},
  1687. {"actual nil text", "text", actualNil, nil},
  1688. {"referenced nil text", "text", nil, nil},
  1689. {"empty text", "text", emptyBytes, emptyBytes},
  1690. } {
  1691. if _, err = db.Exec(fmt.Sprintf("create table tbl%d (txt %s)", tsti, tst.columnType)); err != nil {
  1692. t.Fatal(tst.name, err)
  1693. }
  1694. if bytes.Equal(tst.insertBytes, actualNil) {
  1695. if _, err = db.Exec(fmt.Sprintf("insert into tbl%d (txt) values (?)", tsti), nil); err != nil {
  1696. t.Fatal(tst.name, err)
  1697. }
  1698. } else {
  1699. if _, err = db.Exec(fmt.Sprintf("insert into tbl%d (txt) values (?)", tsti), &tst.insertBytes); err != nil {
  1700. t.Fatal(tst.name, err)
  1701. }
  1702. }
  1703. rows, err := db.Query(fmt.Sprintf("select txt from tbl%d", tsti))
  1704. if err != nil {
  1705. t.Fatal(tst.name, err)
  1706. }
  1707. if !rows.Next() {
  1708. t.Fatal(tst.name, "no rows")
  1709. }
  1710. var scanBytes []byte
  1711. if err = rows.Scan(&scanBytes); err != nil {
  1712. t.Fatal(tst.name, err)
  1713. }
  1714. if err = rows.Err(); err != nil {
  1715. t.Fatal(tst.name, err)
  1716. }
  1717. if tst.expectedBytes == nil && scanBytes != nil {
  1718. t.Errorf("%s: %#v != %#v", tst.name, scanBytes, tst.expectedBytes)
  1719. } else if !bytes.Equal(scanBytes, tst.expectedBytes) {
  1720. t.Errorf("%s: %#v != %#v", tst.name, scanBytes, tst.expectedBytes)
  1721. }
  1722. }
  1723. }
  1724. func TestInsertNilByteSlice(t *testing.T) {
  1725. db, err := sql.Open("sqlite3", ":memory:")
  1726. if err != nil {
  1727. t.Fatal(err)
  1728. }
  1729. defer db.Close()
  1730. if _, err := db.Exec("create table blob_not_null (b blob not null)"); err != nil {
  1731. t.Fatal(err)
  1732. }
  1733. var nilSlice []byte
  1734. if _, err := db.Exec("insert into blob_not_null (b) values (?)", nilSlice); err == nil {
  1735. t.Fatal("didn't expect INSERT to 'not null' column with a nil []byte slice to work")
  1736. }
  1737. zeroLenSlice := []byte{}
  1738. if _, err := db.Exec("insert into blob_not_null (b) values (?)", zeroLenSlice); err != nil {
  1739. t.Fatal("failed to insert zero-length slice")
  1740. }
  1741. }
  1742. func TestNamedParam(t *testing.T) {
  1743. tempFilename := TempFilename(t)
  1744. defer os.Remove(tempFilename)
  1745. db, err := sql.Open("sqlite3", tempFilename)
  1746. if err != nil {
  1747. t.Fatal("Failed to open database:", err)
  1748. }
  1749. defer db.Close()
  1750. _, err = db.Exec("drop table foo")
  1751. _, err = db.Exec("create table foo (id integer, name text, amount integer)")
  1752. if err != nil {
  1753. t.Fatal("Failed to create table:", err)
  1754. }
  1755. _, err = db.Exec("insert into foo(id, name, amount) values(:id, @name, $amount)",
  1756. sql.Named("bar", 42), sql.Named("baz", "quux"),
  1757. sql.Named("amount", 123), sql.Named("corge", "waldo"),
  1758. sql.Named("id", 2), sql.Named("name", "grault"))
  1759. if err != nil {
  1760. t.Fatal("Failed to insert record with named parameters:", err)
  1761. }
  1762. rows, err := db.Query("select id, name, amount from foo")
  1763. if err != nil {
  1764. t.Fatal("Failed to select records:", err)
  1765. }
  1766. defer rows.Close()
  1767. rows.Next()
  1768. var id, amount int
  1769. var name string
  1770. rows.Scan(&id, &name, &amount)
  1771. if id != 2 || name != "grault" || amount != 123 {
  1772. t.Errorf("Expected %d, %q, %d for fetched result, but got %d, %q, %d:", 2, "grault", 123, id, name, amount)
  1773. }
  1774. }
  1775. var customFunctionOnce sync.Once
  1776. func BenchmarkCustomFunctions(b *testing.B) {
  1777. customFunctionOnce.Do(func() {
  1778. customAdd := func(a, b int64) int64 {
  1779. return a + b
  1780. }
  1781. sql.Register("sqlite3_BenchmarkCustomFunctions", &SQLiteDriver{
  1782. ConnectHook: func(conn *SQLiteConn) error {
  1783. // Impure function to force sqlite to reexecute it each time.
  1784. return conn.RegisterFunc("custom_add", customAdd, false)
  1785. },
  1786. })
  1787. })
  1788. db, err := sql.Open("sqlite3_BenchmarkCustomFunctions", ":memory:")
  1789. if err != nil {
  1790. b.Fatal("Failed to open database:", err)
  1791. }
  1792. defer db.Close()
  1793. b.ResetTimer()
  1794. for i := 0; i < b.N; i++ {
  1795. var i int64
  1796. err = db.QueryRow("SELECT custom_add(1,2)").Scan(&i)
  1797. if err != nil {
  1798. b.Fatal("Failed to run custom add:", err)
  1799. }
  1800. }
  1801. }
  1802. func TestSuite(t *testing.T) {
  1803. initializeTestDB(t)
  1804. defer freeTestDB()
  1805. for _, test := range tests {
  1806. t.Run(test.Name, test.F)
  1807. }
  1808. }
  1809. func BenchmarkSuite(b *testing.B) {
  1810. initializeTestDB(b)
  1811. defer freeTestDB()
  1812. for _, benchmark := range benchmarks {
  1813. b.Run(benchmark.Name, benchmark.F)
  1814. }
  1815. }
  1816. // Dialect is a type of dialect of databases.
  1817. type Dialect int
  1818. // Dialects for databases.
  1819. const (
  1820. SQLITE Dialect = iota // SQLITE mean SQLite3 dialect
  1821. POSTGRESQL // POSTGRESQL mean PostgreSQL dialect
  1822. MYSQL // MYSQL mean MySQL dialect
  1823. )
// TestDB provides context for the shared test helpers. It combines the
// testing.TB used to report failures, the open *sql.DB, the SQL dialect in
// use and the path of the temporary database file.
//
// Keep the field order stable: code may build this with a positional
// composite literal.
type TestDB struct {
	testing.TB
	*sql.DB
	dialect Dialect // selects dialect-specific SQL in the helper methods
	// NOTE(review): not referenced in the visible code; presumably intended
	// to guard one-time setup. Confirm before relying on it.
	once         sync.Once
	tempFilename string // database file path, removed by the cleanup helper
}
  1832. var db *TestDB
  1833. func initializeTestDB(t testing.TB) {
  1834. tempFilename := TempFilename(t)
  1835. d, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999")
  1836. if err != nil {
  1837. os.Remove(tempFilename)
  1838. t.Fatal(err)
  1839. }
  1840. db = &TestDB{t, d, SQLITE, sync.Once{}, tempFilename}
  1841. }
  1842. func freeTestDB() {
  1843. err := db.DB.Close()
  1844. if err != nil {
  1845. panic(err)
  1846. }
  1847. err = os.Remove(db.tempFilename)
  1848. if err != nil {
  1849. panic(err)
  1850. }
  1851. }
  1852. // the following tables will be created and dropped during the test
  1853. var testTables = []string{"foo", "bar", "t", "bench"}
  1854. var tests = []testing.InternalTest{
  1855. {Name: "TestResult", F: testResult},
  1856. {Name: "TestBlobs", F: testBlobs},
  1857. {Name: "TestMultiBlobs", F: testMultiBlobs},
  1858. {Name: "TestNullZeroLengthBlobs", F: testNullZeroLengthBlobs},
  1859. {Name: "TestManyQueryRow", F: testManyQueryRow},
  1860. {Name: "TestTxQuery", F: testTxQuery},
  1861. {Name: "TestPreparedStmt", F: testPreparedStmt},
  1862. {Name: "TestExecEmptyQuery", F: testExecEmptyQuery},
  1863. }
  1864. var benchmarks = []testing.InternalBenchmark{
  1865. {Name: "BenchmarkExec", F: benchmarkExec},
  1866. {Name: "BenchmarkQuery", F: benchmarkQuery},
  1867. {Name: "BenchmarkParams", F: benchmarkParams},
  1868. {Name: "BenchmarkStmt", F: benchmarkStmt},
  1869. {Name: "BenchmarkRows", F: benchmarkRows},
  1870. {Name: "BenchmarkStmtRows", F: benchmarkStmtRows},
  1871. }
  1872. func (db *TestDB) mustExec(sql string, args ...interface{}) sql.Result {
  1873. res, err := db.Exec(sql, args...)
  1874. if err != nil {
  1875. db.Fatalf("Error running %q: %v", sql, err)
  1876. }
  1877. return res
  1878. }
  1879. func (db *TestDB) tearDown() {
  1880. for _, tbl := range testTables {
  1881. switch db.dialect {
  1882. case SQLITE:
  1883. db.mustExec("drop table if exists " + tbl)
  1884. case MYSQL, POSTGRESQL:
  1885. db.mustExec("drop table if exists " + tbl)
  1886. default:
  1887. db.Fatal("unknown dialect")
  1888. }
  1889. }
  1890. }
  1891. // q replaces ? parameters if needed
  1892. func (db *TestDB) q(sql string) string {
  1893. switch db.dialect {
  1894. case POSTGRESQL: // replace with $1, $2, ..
  1895. qrx := regexp.MustCompile(`\?`)
  1896. n := 0
  1897. return qrx.ReplaceAllStringFunc(sql, func(string) string {
  1898. n++
  1899. return "$" + strconv.Itoa(n)
  1900. })
  1901. }
  1902. return sql
  1903. }
  1904. func (db *TestDB) blobType(size int) string {
  1905. switch db.dialect {
  1906. case SQLITE:
  1907. return fmt.Sprintf("blob[%d]", size)
  1908. case POSTGRESQL:
  1909. return "bytea"
  1910. case MYSQL:
  1911. return fmt.Sprintf("VARBINARY(%d)", size)
  1912. }
  1913. panic("unknown dialect")
  1914. }
  1915. func (db *TestDB) serialPK() string {
  1916. switch db.dialect {
  1917. case SQLITE:
  1918. return "integer primary key autoincrement"
  1919. case POSTGRESQL:
  1920. return "serial primary key"
  1921. case MYSQL:
  1922. return "integer primary key auto_increment"
  1923. }
  1924. panic("unknown dialect")
  1925. }
  1926. func (db *TestDB) now() string {
  1927. switch db.dialect {
  1928. case SQLITE:
  1929. return "datetime('now')"
  1930. case POSTGRESQL:
  1931. return "now()"
  1932. case MYSQL:
  1933. return "now()"
  1934. }
  1935. panic("unknown dialect")
  1936. }
  1937. func makeBench() {
  1938. if _, err := db.Exec("create table bench (n varchar(32), i integer, d double, s varchar(32), t datetime)"); err != nil {
  1939. panic(err)
  1940. }
  1941. st, err := db.Prepare("insert into bench values (?, ?, ?, ?, ?)")
  1942. if err != nil {
  1943. panic(err)
  1944. }
  1945. defer st.Close()
  1946. for i := 0; i < 100; i++ {
  1947. if _, err = st.Exec(nil, i, float64(i), fmt.Sprintf("%d", i), time.Now()); err != nil {
  1948. panic(err)
  1949. }
  1950. }
  1951. }
  1952. // testResult is test for result
  1953. func testResult(t *testing.T) {
  1954. db.tearDown()
  1955. db.mustExec("create temporary table test (id " + db.serialPK() + ", name varchar(10))")
  1956. for i := 1; i < 3; i++ {
  1957. r := db.mustExec(db.q("insert into test (name) values (?)"), fmt.Sprintf("row %d", i))
  1958. n, err := r.RowsAffected()
  1959. if err != nil {
  1960. t.Fatal(err)
  1961. }
  1962. if n != 1 {
  1963. t.Errorf("got %v, want %v", n, 1)
  1964. }
  1965. n, err = r.LastInsertId()
  1966. if err != nil {
  1967. t.Fatal(err)
  1968. }
  1969. if n != int64(i) {
  1970. t.Errorf("got %v, want %v", n, i)
  1971. }
  1972. }
  1973. if _, err := db.Exec("error!"); err == nil {
  1974. t.Fatalf("expected error")
  1975. }
  1976. }
  1977. // testBlobs is test for blobs
  1978. func testBlobs(t *testing.T) {
  1979. db.tearDown()
  1980. var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
  1981. db.mustExec("create table foo (id integer primary key, bar " + db.blobType(16) + ")")
  1982. db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 0, blob)
  1983. want := fmt.Sprintf("%x", blob)
  1984. b := make([]byte, 16)
  1985. err := db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&b)
  1986. got := fmt.Sprintf("%x", b)
  1987. if err != nil {
  1988. t.Errorf("[]byte scan: %v", err)
  1989. } else if got != want {
  1990. t.Errorf("for []byte, got %q; want %q", got, want)
  1991. }
  1992. err = db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&got)
  1993. want = string(blob)
  1994. if err != nil {
  1995. t.Errorf("string scan: %v", err)
  1996. } else if got != want {
  1997. t.Errorf("for string, got %q; want %q", got, want)
  1998. }
  1999. }
  2000. func testMultiBlobs(t *testing.T) {
  2001. db.tearDown()
  2002. db.mustExec("create table foo (id integer primary key, bar " + db.blobType(16) + ")")
  2003. var blob0 = []byte{0, 1, 2, 3, 4, 5, 6, 7}
  2004. db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 0, blob0)
  2005. var blob1 = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
  2006. db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 1, blob1)
  2007. r, err := db.Query(db.q("select bar from foo order by id"))
  2008. if err != nil {
  2009. t.Fatal(err)
  2010. }
  2011. defer r.Close()
  2012. if !r.Next() {
  2013. if r.Err() != nil {
  2014. t.Fatal(err)
  2015. }
  2016. t.Fatal("expected one rows")
  2017. }
  2018. want0 := fmt.Sprintf("%x", blob0)
  2019. b0 := make([]byte, 8)
  2020. err = r.Scan(&b0)
  2021. if err != nil {
  2022. t.Fatal(err)
  2023. }
  2024. got0 := fmt.Sprintf("%x", b0)
  2025. if !r.Next() {
  2026. if r.Err() != nil {
  2027. t.Fatal(err)
  2028. }
  2029. t.Fatal("expected one rows")
  2030. }
  2031. want1 := fmt.Sprintf("%x", blob1)
  2032. b1 := make([]byte, 16)
  2033. err = r.Scan(&b1)
  2034. if err != nil {
  2035. t.Fatal(err)
  2036. }
  2037. got1 := fmt.Sprintf("%x", b1)
  2038. if got0 != want0 {
  2039. t.Errorf("for []byte, got %q; want %q", got0, want0)
  2040. }
  2041. if got1 != want1 {
  2042. t.Errorf("for []byte, got %q; want %q", got1, want1)
  2043. }
  2044. }
  2045. // testBlobs tests that we distinguish between null and zero-length blobs
  2046. func testNullZeroLengthBlobs(t *testing.T) {
  2047. db.tearDown()
  2048. db.mustExec("create table foo (id integer primary key, bar " + db.blobType(16) + ")")
  2049. db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 0, nil)
  2050. db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 1, []byte{})
  2051. r0 := db.QueryRow(db.q("select bar from foo where id=0"))
  2052. var b0 []byte
  2053. err := r0.Scan(&b0)
  2054. if err != nil {
  2055. t.Fatal(err)
  2056. }
  2057. if b0 != nil {
  2058. t.Errorf("for id=0, got %x; want nil", b0)
  2059. }
  2060. r1 := db.QueryRow(db.q("select bar from foo where id=1"))
  2061. var b1 []byte
  2062. err = r1.Scan(&b1)
  2063. if err != nil {
  2064. t.Fatal(err)
  2065. }
  2066. if b1 == nil {
  2067. t.Error("for id=1, got nil; want zero-length slice")
  2068. } else if len(b1) > 0 {
  2069. t.Errorf("for id=1, got %x; want zero-length slice", b1)
  2070. }
  2071. }
  2072. // testManyQueryRow is test for many query row
  2073. func testManyQueryRow(t *testing.T) {
  2074. if testing.Short() {
  2075. t.Log("skipping in short mode")
  2076. return
  2077. }
  2078. db.tearDown()
  2079. db.mustExec("create table foo (id integer primary key, name varchar(50))")
  2080. db.mustExec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
  2081. var name string
  2082. for i := 0; i < 10000; i++ {
  2083. err := db.QueryRow(db.q("select name from foo where id = ?"), 1).Scan(&name)
  2084. if err != nil || name != "bob" {
  2085. t.Fatalf("on query %d: err=%v, name=%q", i, err, name)
  2086. }
  2087. }
  2088. }
  2089. // testTxQuery is test for transactional query
  2090. func testTxQuery(t *testing.T) {
  2091. db.tearDown()
  2092. tx, err := db.Begin()
  2093. if err != nil {
  2094. t.Fatal(err)
  2095. }
  2096. defer tx.Rollback()
  2097. _, err = tx.Exec("create table foo (id integer primary key, name varchar(50))")
  2098. if err != nil {
  2099. t.Fatal(err)
  2100. }
  2101. _, err = tx.Exec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
  2102. if err != nil {
  2103. t.Fatal(err)
  2104. }
  2105. r, err := tx.Query(db.q("select name from foo where id = ?"), 1)
  2106. if err != nil {
  2107. t.Fatal(err)
  2108. }
  2109. defer r.Close()
  2110. if !r.Next() {
  2111. if r.Err() != nil {
  2112. t.Fatal(err)
  2113. }
  2114. t.Fatal("expected one rows")
  2115. }
  2116. var name string
  2117. err = r.Scan(&name)
  2118. if err != nil {
  2119. t.Fatal(err)
  2120. }
  2121. }
  2122. // testPreparedStmt is test for prepared statement
  2123. func testPreparedStmt(t *testing.T) {
  2124. db.tearDown()
  2125. db.mustExec("CREATE TABLE t (count INT)")
  2126. sel, err := db.Prepare("SELECT count FROM t ORDER BY count DESC")
  2127. if err != nil {
  2128. t.Fatalf("prepare 1: %v", err)
  2129. }
  2130. ins, err := db.Prepare(db.q("INSERT INTO t (count) VALUES (?)"))
  2131. if err != nil {
  2132. t.Fatalf("prepare 2: %v", err)
  2133. }
  2134. for n := 1; n <= 3; n++ {
  2135. if _, err := ins.Exec(n); err != nil {
  2136. t.Fatalf("insert(%d) = %v", n, err)
  2137. }
  2138. }
  2139. const nRuns = 10
  2140. var wg sync.WaitGroup
  2141. for i := 0; i < nRuns; i++ {
  2142. wg.Add(1)
  2143. go func() {
  2144. defer wg.Done()
  2145. for j := 0; j < 10; j++ {
  2146. count := 0
  2147. if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows {
  2148. t.Errorf("Query: %v", err)
  2149. return
  2150. }
  2151. if _, err := ins.Exec(rand.Intn(100)); err != nil {
  2152. t.Errorf("Insert: %v", err)
  2153. return
  2154. }
  2155. }
  2156. }()
  2157. }
  2158. wg.Wait()
  2159. }
  2160. // testEmptyQuery is test for validating the API in case of empty query
  2161. func testExecEmptyQuery(t *testing.T) {
  2162. db.tearDown()
  2163. res, err := db.Exec(" -- this is just a comment ")
  2164. if err != nil {
  2165. t.Fatalf("empty query err: %v", err)
  2166. }
  2167. _, err = res.LastInsertId()
  2168. if err != nil {
  2169. t.Fatalf("LastInsertId returned an error: %v", err)
  2170. }
  2171. _, err = res.RowsAffected()
  2172. if err != nil {
  2173. t.Fatalf("RowsAffected returned an error: %v", err)
  2174. }
  2175. }
// Benchmarks need to use panic() because b.Error errors are lost when
// running via testing.Benchmark(). It would be preferable to run these via
// "go test -bench", but calling Benchmark() from a benchmark test
// currently hangs go.
  2180. // benchmarkExec is benchmark for exec
  2181. func benchmarkExec(b *testing.B) {
  2182. for i := 0; i < b.N; i++ {
  2183. if _, err := db.Exec("select 1"); err != nil {
  2184. panic(err)
  2185. }
  2186. }
  2187. }
  2188. // benchmarkQuery is benchmark for query
  2189. func benchmarkQuery(b *testing.B) {
  2190. for i := 0; i < b.N; i++ {
  2191. var n sql.NullString
  2192. var i int
  2193. var f float64
  2194. var s string
  2195. // var t time.Time
  2196. if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil {
  2197. panic(err)
  2198. }
  2199. }
  2200. }
  2201. // benchmarkParams is benchmark for params
  2202. func benchmarkParams(b *testing.B) {
  2203. for i := 0; i < b.N; i++ {
  2204. var n sql.NullString
  2205. var i int
  2206. var f float64
  2207. var s string
  2208. // var t time.Time
  2209. if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
  2210. panic(err)
  2211. }
  2212. }
  2213. }
  2214. // benchmarkStmt is benchmark for statement
  2215. func benchmarkStmt(b *testing.B) {
  2216. st, err := db.Prepare("select ?, ?, ?, ?")
  2217. if err != nil {
  2218. panic(err)
  2219. }
  2220. defer st.Close()
  2221. for n := 0; n < b.N; n++ {
  2222. var n sql.NullString
  2223. var i int
  2224. var f float64
  2225. var s string
  2226. // var t time.Time
  2227. if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
  2228. panic(err)
  2229. }
  2230. }
  2231. }
  2232. // benchmarkRows is benchmark for rows
  2233. func benchmarkRows(b *testing.B) {
  2234. db.once.Do(makeBench)
  2235. for n := 0; n < b.N; n++ {
  2236. var n sql.NullString
  2237. var i int
  2238. var f float64
  2239. var s string
  2240. var t time.Time
  2241. r, err := db.Query("select * from bench")
  2242. if err != nil {
  2243. panic(err)
  2244. }
  2245. for r.Next() {
  2246. if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
  2247. panic(err)
  2248. }
  2249. }
  2250. if err = r.Err(); err != nil {
  2251. panic(err)
  2252. }
  2253. }
  2254. }
  2255. // benchmarkStmtRows is benchmark for statement rows
  2256. func benchmarkStmtRows(b *testing.B) {
  2257. db.once.Do(makeBench)
  2258. st, err := db.Prepare("select * from bench")
  2259. if err != nil {
  2260. panic(err)
  2261. }
  2262. defer st.Close()
  2263. for n := 0; n < b.N; n++ {
  2264. var n sql.NullString
  2265. var i int
  2266. var f float64
  2267. var s string
  2268. var t time.Time
  2269. r, err := st.Query()
  2270. if err != nil {
  2271. panic(err)
  2272. }
  2273. for r.Next() {
  2274. if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
  2275. panic(err)
  2276. }
  2277. }
  2278. if err = r.Err(); err != nil {
  2279. panic(err)
  2280. }
  2281. }
  2282. }