From 5c8c33532f14cd33a7ddd3fdeb8c1ddf353db589 Mon Sep 17 00:00:00 2001
From: Steven Murray <Steven.Murray@cern.ch>
Date: Mon, 2 Dec 2019 13:52:29 +0100
Subject: [PATCH] Added unit-test insert_with_bindString_invalid_bool_value in
 order to verify check constraints are working

---
 rdbms/StmtTest.cpp             | 30 +++++++++++++++++++++++++-----
 rdbms/wrapper/MysqlStmt.cpp    |  3 +++
 rdbms/wrapper/OcciStmt.cpp     |  7 ++++++-
 rdbms/wrapper/Postgres.hpp     | 10 +++++++++-
 rdbms/wrapper/PostgresStmt.cpp |  8 +++++---
 rdbms/wrapper/PostgresStmt.hpp |  6 +++---
 6 files changed, 51 insertions(+), 13 deletions(-)

diff --git a/rdbms/StmtTest.cpp b/rdbms/StmtTest.cpp
index c48a794dbf..8cdb54c947 100644
--- a/rdbms/StmtTest.cpp
+++ b/rdbms/StmtTest.cpp
@@ -16,6 +16,7 @@
  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
  */
 
+#include "common/exception/DatabaseConstraintError.hpp"
 #include "common/exception/Exception.hpp"
 #include "common/make_unique.hpp"
 #include "common/utils/utils.hpp"
@@ -60,11 +61,12 @@ std::string cta_rdbms_StmtTest::getCreateStmtTestTableSql() {
 
   try {
     std::string sql =
-      "CREATE TABLE STMT_TEST("
-        "DOUBLE_COL FLOAT,"
-        "UINT64_COL UINT64TYPE,"
-        "STRING_COL VARCHAR(100),"
-        "BOOL_COL CHAR(1)"
+      "CREATE TABLE STMT_TEST("                                     "\n"
+      "  DOUBLE_COL FLOAT,"                                         "\n"
+      "  UINT64_COL UINT64TYPE,"                                    "\n"
+      "  STRING_COL VARCHAR(100),"                                  "\n"
+      "  BOOL_COL   CHAR(1) DEFAULT '0',"                           "\n"
+      "  CONSTRAINT BOOL_COL_BOOL_CK CHECK(BOOL_COL IN ('0', '1'))" "\n"
       ")";
 
     switch(m_login.dbType) {
@@ -383,4 +385,22 @@ TEST_P(cta_rdbms_StmtTest, insert_with_bindBool_false) {
   }
 }
 
+TEST_P(cta_rdbms_StmtTest, insert_with_bindString_invalid_bool_value) {
+  using namespace cta::rdbms;
+
+  const std::string insertValue = "2"; // null, "0" and "1" are valid values
+
+  // Insert a row into the test table
+  {
+    const char *const sql =
+      "INSERT INTO STMT_TEST("
+        "BOOL_COL) "
+      "VALUES("
+        ":BOOL_COL)";
+    auto stmt = m_conn.createStmt(sql);
+    stmt.bindString(":BOOL_COL", insertValue);
+    ASSERT_THROW(stmt.executeNonQuery(), cta::exception::DatabaseConstraintError);
+  }
+}
+
 } // namespace unitTests
diff --git a/rdbms/wrapper/MysqlStmt.cpp b/rdbms/wrapper/MysqlStmt.cpp
index 31c713947f..ab94691190 100644
--- a/rdbms/wrapper/MysqlStmt.cpp
+++ b/rdbms/wrapper/MysqlStmt.cpp
@@ -382,6 +382,9 @@ void MysqlStmt::executeNonQuery() {
       case ER_DUP_ENTRY:
         throw exception::DatabasePrimaryKeyError(std::string(__FUNCTION__) + " " +  msg);
         break;
+      case 4025: // Newer MariaDB versions have ER_CONSTRAINT_FAILED = 4025
+        throw exception::DatabaseConstraintError(std::string(__FUNCTION__) + " " +  msg);
+        break;
       }
 
       throw exception::Exception(std::string(__FUNCTION__) + " " +  msg);
diff --git a/rdbms/wrapper/OcciStmt.cpp b/rdbms/wrapper/OcciStmt.cpp
index 44451197d3..b9155cba9a 100644
--- a/rdbms/wrapper/OcciStmt.cpp
+++ b/rdbms/wrapper/OcciStmt.cpp
@@ -16,6 +16,7 @@
  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
  */
 
+#include "common/exception/DatabaseConstraintError.hpp"
 #include "common/exception/Exception.hpp"
 #include "common/exception/LostDatabaseConnection.hpp"
 #include "common/make_unique.hpp"
@@ -267,7 +268,11 @@ void OcciStmt::executeNonQuery() {
       }
       throw exception::LostDatabaseConnection(msg.str());
     }
-    throw exception::Exception(msg.str());
+    if(2290 == ex.getErrorCode()) {
+      throw exception::DatabaseConstraintError(msg.str());
+    } else {
+      throw exception::Exception(msg.str());
+    }
   }
 }
 
diff --git a/rdbms/wrapper/Postgres.hpp b/rdbms/wrapper/Postgres.hpp
index b630cbfe8c..52c2698be6 100644
--- a/rdbms/wrapper/Postgres.hpp
+++ b/rdbms/wrapper/Postgres.hpp
@@ -18,11 +18,14 @@
 
 #pragma once
 
+#include "common/exception/DatabaseConstraintError.hpp"
 #include "common/exception/Exception.hpp"
 #include "common/exception/LostDatabaseConnection.hpp"
+
+#include <algorithm>
+#include <cstring>
 #include <libpq-fe.h>
 #include <string>
-#include <algorithm>
 
 namespace cta {
 namespace rdbms {
@@ -47,10 +50,12 @@ public:
       pgstr.erase(std::remove(pgstr.begin(), pgstr.end(), '\n'), pgstr.end());
     }
     std::string resstr;
+    bool checkViolation = false;
     if (nullptr != res) {
       resstr = "DB Result Status:" + std::to_string(PQresultStatus(res));
       const char *const e = PQresultErrorField(res, PG_DIAG_SQLSTATE);
       if (nullptr != e && '\0' != *e) {
+        checkViolation = 0 == std::strcmp("23514", e);
         resstr += " SQLState:" + std::string(e);
       }
     }
@@ -85,6 +90,9 @@ public:
     if (badconn) {
       throw exception::LostDatabaseConnection(dbmsg);
     }
+    if (checkViolation) {
+      throw exception::DatabaseConstraintError(dbmsg);
+    }
     throw exception::Exception(dbmsg);
   }
 
diff --git a/rdbms/wrapper/PostgresStmt.cpp b/rdbms/wrapper/PostgresStmt.cpp
index a39c012bfa..0cb1a8f838 100644
--- a/rdbms/wrapper/PostgresStmt.cpp
+++ b/rdbms/wrapper/PostgresStmt.cpp
@@ -28,10 +28,11 @@
 #include "rdbms/wrapper/PostgresRset.hpp"
 #include "rdbms/wrapper/PostgresStmt.hpp"
 
+#include <algorithm>
 #include <exception>
+#include <pgsql/server/utils/errcodes.h>
 #include <sstream>
 #include <utility>
-#include <algorithm>
 
 namespace cta {
 namespace rdbms {
@@ -355,8 +356,9 @@ void PostgresStmt::executeNonQuery() {
     throw exception::LostDatabaseConnection(std::string(__FUNCTION__) +
       " detected lost connection for SQL statement " + getSqlForException() + ": " + ex.getMessage().str());
   } catch(exception::Exception &ex) {
-    throw exception::Exception(std::string(__FUNCTION__) + " failed for SQL statement " +
-      getSqlForException() + ": " + ex.getMessage().str());
+    ex.getMessage().str(std::string(__FUNCTION__) + " failed for SQL statement " + getSqlForException() + ": " +
+      ex.getMessage().str());
+    throw;
   }
 }
 
diff --git a/rdbms/wrapper/PostgresStmt.hpp b/rdbms/wrapper/PostgresStmt.hpp
index c1a6bc47f2..8b018d0b9a 100644
--- a/rdbms/wrapper/PostgresStmt.hpp
+++ b/rdbms/wrapper/PostgresStmt.hpp
@@ -25,10 +25,10 @@
 #include "rdbms/wrapper/PostgresConn.hpp"
 #include "rdbms/wrapper/PostgresColumn.hpp"
 
-#include <vector>
-#include <string>
-#include <memory>
 #include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
 
 namespace cta {
 namespace rdbms {
-- 
GitLab