Skip to content
Snippets Groups Projects
Commit 85ec8450 authored by Sergey Yakubov's avatar Sergey Yakubov
Browse files

Merge pull request #17 in HIDRA2/hidra2 from feature/HIDRA2-41-start-to-end-data-flow to develop

* commit 'fe43ea53':
  correct exit from getnext_broker
  added manual script for fullchain, grafana dashboard file
  refactoring, add test for windows
  full chain, test, checkl max_id
  Create records in mongodb when receiving file
parents 8ba1e5c7 fe43ea53
No related branches found
No related tags found
No related merge requests found
Showing
with 747 additions and 97 deletions
......@@ -7,6 +7,7 @@ import (
"gopkg.in/mgo.v2"
"gopkg.in/mgo.v2/bson"
"hidra2_broker/utils"
"sync"
"time"
)
......@@ -19,28 +20,37 @@ const data_collection_name = "data"
const pointer_collection_name = "current_location"
const pointer_field_name = "current_pointer"
const no_session_msg = "database session not created"
const wrong_id_type = "wrong id type"
const already_connected_msg = "already connected"
var dbListLock sync.RWMutex
var dbPointersLock sync.RWMutex
type Mongodb struct {
main_session *mgo.Session
timeout time.Duration
databases []string
session *mgo.Session
timeout time.Duration
databases []string
parent_db *Mongodb
db_pointers_created map[string]bool
}
func (db *Mongodb) Copy() Agent {
new_db:= new(Mongodb)
new_db.main_session = db.main_session.Copy()
new_db.databases = make([]string,len(db.databases))
copy(new_db.databases,db.databases)
new_db := new(Mongodb)
new_db.session = db.session.Copy()
new_db.parent_db = db
return new_db
}
// databaseInList reports whether dbname is present in the cached list of
// known databases. The list is shared between copies, so reads are guarded
// by dbListLock.
func (db *Mongodb) databaseInList(dbname string) bool {
	dbListLock.RLock()
	found := utils.StringInSlice(dbname, db.databases)
	dbListLock.RUnlock()
	return found
}
func (db *Mongodb) updateDatabaseList() (err error) {
db.databases, err = db.main_session.DatabaseNames()
dbListLock.Lock()
db.databases, err = db.session.DatabaseNames()
dbListLock.Unlock()
return err
}
......@@ -61,11 +71,11 @@ func (db *Mongodb) dataBaseExist(dbname string) (err error) {
}
func (db *Mongodb) Connect(address string) (err error) {
if db.main_session != nil {
if db.session != nil {
return errors.New(already_connected_msg)
}
db.main_session, err = mgo.DialWithTimeout(address, time.Second)
db.session, err = mgo.DialWithTimeout(address, time.Second)
if err != nil {
return err
}
......@@ -78,47 +88,72 @@ func (db *Mongodb) Connect(address string) (err error) {
}
func (db *Mongodb) Close() {
if db.main_session != nil {
db.main_session.Close()
db.main_session = nil
if db.session != nil {
db.session.Close()
db.session = nil
}
}
func (db *Mongodb) DeleteAllRecords(dbname string) (err error) {
if db.main_session == nil {
if db.session == nil {
return errors.New(no_session_msg)
}
return db.main_session.DB(dbname).DropDatabase()
return db.session.DB(dbname).DropDatabase()
}
func (db *Mongodb) InsertRecord(dbname string, s interface{}) error {
if db.main_session == nil {
if db.session == nil {
return errors.New(no_session_msg)
}
c := db.main_session.DB(dbname).C(data_collection_name)
c := db.session.DB(dbname).C(data_collection_name)
return c.Insert(s)
}
func (db *Mongodb) incrementField(dbname string, res interface{}) (err error) {
// getMaxIndex returns the highest _id currently stored in the data
// collection of dbname, or 0 when the collection is empty.
//
// NOTE(review): the original returned (0, nil) on *every* query error,
// silently masking real database failures as "no data yet". Only
// mgo.ErrNotFound (empty collection) is treated as max id 0; any other
// error is now propagated to the caller.
func (db *Mongodb) getMaxIndex(dbname string) (max_id int, err error) {
	c := db.session.DB(dbname).C(data_collection_name)
	var id Pointer
	err = c.Find(nil).Sort("-_id").Select(bson.M{"_id": 1}).One(&id)
	if err == mgo.ErrNotFound {
		// no records inserted yet — the maximum index is 0 by convention
		return 0, nil
	}
	if err != nil {
		return 0, err
	}
	return id.ID, nil
}
func (db *Mongodb) createLocationPointers(dbname string) (err error) {
change := mgo.Change{
Update: bson.M{"$inc": bson.M{pointer_field_name: 1}},
Upsert: true,
ReturnNew: true,
Update: bson.M{"$inc": bson.M{pointer_field_name: 0}},
Upsert: true,
}
q := bson.M{"_id": 0}
c := db.main_session.DB(dbname).C(pointer_collection_name)
_, err = c.Find(q).Apply(change, res)
c := db.session.DB(dbname).C(pointer_collection_name)
var res map[string]interface{}
_, err = c.Find(q).Apply(change, &res)
return err
}
// incrementField atomically advances the shared current-location pointer of
// dbname by one via findAndModify, but only while the pointer is still below
// max_ind. The post-increment document is decoded into res. When no document
// matches (pointer already caught up with max_ind), mgo.ErrNotFound is
// translated into a DBError with StatusNoData so callers can tell "no new
// data" apart from real failures.
func (db *Mongodb) incrementField(dbname string, max_ind int, res interface{}) (err error) {
	query := bson.M{"_id": 0, pointer_field_name: bson.M{"$lt": max_ind}}
	change := mgo.Change{
		Update:    bson.M{"$inc": bson.M{pointer_field_name: 1}},
		Upsert:    false,
		ReturnNew: true,
	}
	col := db.session.DB(dbname).C(pointer_collection_name)
	if _, err = col.Find(query).Apply(change, res); err == mgo.ErrNotFound {
		return &DBError{utils.StatusNoData, err.Error()}
	}
	return err
}
func (db *Mongodb) getRecordByID(dbname string, id int) (interface{}, error) {
var res map[string]interface{}
q := bson.M{"_id": id}
c := db.main_session.DB(dbname).C(data_collection_name)
c := db.session.DB(dbname).C(data_collection_name)
err := c.Find(q).One(&res)
if err == mgo.ErrNotFound {
return nil, &DBError{utils.StatusNoData, err.Error()}
......@@ -126,20 +161,58 @@ func (db *Mongodb) getRecordByID(dbname string, id int) (interface{}, error) {
return &res, err
}
// needCreateLocationPointersInDb reports whether the location-pointer
// document for db_name has not yet been created by this agent.
// db.db_pointers_created is guarded by dbPointersLock.
func (db *Mongodb) needCreateLocationPointersInDb(db_name string) bool {
	dbPointersLock.RLock()
	defer dbPointersLock.RUnlock()
	return !db.db_pointers_created[db_name]
}
// SetLocationPointersCreateFlag records that the location pointers for
// db_name have been created, lazily allocating the tracking map under
// dbPointersLock.
func (db *Mongodb) SetLocationPointersCreateFlag(db_name string) {
	dbPointersLock.Lock()
	defer dbPointersLock.Unlock()
	if db.db_pointers_created == nil {
		db.db_pointers_created = make(map[string]bool)
	}
	db.db_pointers_created[db_name] = true
}
// generateLocationPointersInDbIfNeeded lazily creates the location-pointer
// document for db_name the first time this database is used by the agent.
//
// NOTE(review): the original set the created-flag unconditionally, so a
// failed createLocationPointers call was never retried and every later
// pointer increment would fail with "not found". The flag is now set only
// after a successful create, letting the next request retry.
func (db *Mongodb) generateLocationPointersInDbIfNeeded(db_name string) {
	if db.needCreateLocationPointersInDb(db_name) {
		if err := db.createLocationPointers(db_name); err == nil {
			db.SetLocationPointersCreateFlag(db_name)
		}
	}
}
// getParentDB returns the root Mongodb agent: the parent for an instance
// produced by Copy, or the receiver itself for the original connection.
func (db *Mongodb) getParentDB() *Mongodb {
	if db.parent_db != nil {
		return db.parent_db
	}
	return db
}
func (db *Mongodb) checkDatabaseOperationPrerequisites(db_name string) error {
if db.main_session == nil {
if db.session == nil {
return &DBError{utils.StatusError, no_session_msg}
}
if err := db.dataBaseExist(db_name); err != nil {
if err := db.getParentDB().dataBaseExist(db_name); err != nil {
return &DBError{utils.StatusWrongInput, err.Error()}
}
db.getParentDB().generateLocationPointersInDbIfNeeded(db_name)
return nil
}
func (db *Mongodb) getCurrentPointer(db_name string) (Pointer, error) {
max_ind, err := db.getMaxIndex(db_name)
if err != nil {
return Pointer{}, err
}
var curPointer Pointer
err := db.incrementField(db_name, &curPointer)
err = db.incrementField(db_name, max_ind, &curPointer)
if err != nil {
return Pointer{}, err
}
......@@ -148,6 +221,7 @@ func (db *Mongodb) getCurrentPointer(db_name string) (Pointer, error) {
}
func (db *Mongodb) GetNextRecord(db_name string) ([]byte, error) {
if err := db.checkDatabaseOperationPrerequisites(db_name); err != nil {
return nil, err
}
......
......@@ -27,6 +27,7 @@ var rec2_expect, _ = json.Marshal(rec2)
// cleanup removes all records a test created, resets the lazily-built
// location-pointer bookkeeping so the next test starts fresh, and closes
// the database connection.
func cleanup() {
db.DeleteAllRecords(dbname)
// forget which databases already have location pointers
db.db_pointers_created = nil
db.Close()
}
......@@ -59,10 +60,8 @@ func TestMongoDBGetNextErrorWhenWrongDatabasename(t *testing.T) {
func TestMongoDBGetNextErrorWhenEmptyCollection(t *testing.T) {
db.Connect(dbaddress)
db.databases = append(db.databases, dbname)
defer cleanup()
var curPointer Pointer
db.incrementField(dbname, &curPointer)
_, err := db.GetNextRecord(dbname)
assert.Equal(t, utils.StatusNoData, err.(*DBError).Code)
}
......
......@@ -28,19 +28,22 @@ func routeGetNext(w http.ResponseWriter, r *http.Request) {
w.Write(answer)
}
func returnError(err error) (answer []byte, code int) {
err_db, ok := err.(*database.DBError)
code = utils.StatusError
if ok {
code = err_db.Code
}
return []byte(err.Error()), code
}
func getNextRecord(db_name string) (answer []byte, code int) {
db_new := db.Copy()
defer db_new.Close()
statistics.IncreaseCounter()
answer, err := db_new.GetNextRecord(db_name)
if err != nil {
err_db, ok := err.(*database.DBError)
code = utils.StatusError
if ok {
code = err_db.Code
}
return []byte(err.Error()), code
return returnError(err)
}
return answer, utils.StatusOK
}
......@@ -7,10 +7,10 @@ import (
var db database.Agent
type serverSettings struct {
BrokerDbAddress string
BrokerDbAddress string
MonitorDbAddress string
MonitorDbName string
Port int
MonitorDbName string
Port int
}
var settings serverSettings
......@@ -18,7 +18,8 @@ var statistics serverStatistics
func InitDB(dbAgent database.Agent) error {
db = dbAgent
return db.Connect(settings.BrokerDbAddress)
err := db.Connect(settings.BrokerDbAddress)
return err
}
func CleanupDB() {
......
#ifndef HIDRA2_MOCKDATABASE_H
#define HIDRA2_MOCKDATABASE_H
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "database/database.h"
#include "common/error.h"
namespace hidra2 {
// Google Mock stand-in for the Database interface used in unit tests.
// Each override forwards to a raw-pointer *_t mock method and wraps the
// result back into Error — presumably because gmock cannot mock methods
// returning a move-only Error directly (TODO confirm against Error's
// definition in common/error.h).
class MockDatabase : public Database {
public:
Error Connect(const std::string& address, const std::string& database,
const std::string& collection ) override {
return Error{Connect_t(address, database, collection)};
}
Error Insert(const FileInfo& file, bool ignore_duplicates) const override {
return Error{Insert_t(file, ignore_duplicates)};
}
// Mockable forwarding targets; tests set expectations on these.
MOCK_METHOD3(Connect_t, SimpleError * (const std::string&, const std::string&, const std::string&));
MOCK_CONST_METHOD2(Insert_t, SimpleError * (const FileInfo&, bool));
// stuff to test db destructor is called and avoid "uninteresting call" messages
MOCK_METHOD0(Die, void());
virtual ~MockDatabase() override {
if (check_destructor)
Die();
}
// Opt-in flag: only tests that care about destruction enable the Die()
// expectation, keeping other tests free of "uninteresting call" noise.
bool check_destructor{false};
};
}
#endif //HIDRA2_MOCKDATABASE_H
set(TARGET_NAME curl_http_client)
set(SOURCE_FILES
curl_http_client.cpp
)
../../include/unittests/MockDatabase.h)
################################
......
{
"__inputs": [
{
"name": "DS_TEST",
"label": "test",
"description": "",
"type": "datasource",
"pluginId": "influxdb",
"pluginName": "InfluxDB"
}
],
"__requires": [
{
"type": "grafana",
"id": "grafana",
"name": "Grafana",
"version": "5.0.0-beta5"
},
{
"type": "panel",
"id": "graph",
"name": "Graph",
"version": "5.0.0"
},
{
"type": "datasource",
"id": "influxdb",
"name": "InfluxDB",
"version": "5.0.0"
}
],
"annotations": {
"list": [
{
"builtIn": 1,
"datasource": "-- Grafana --",
"enable": true,
"hide": true,
"iconColor": "rgba(0, 211, 255, 1)",
"name": "Annotations & Alerts",
"type": "dashboard"
}
]
},
"editable": true,
"gnetId": null,
"graphTooltip": 0,
"id": null,
"links": [],
"panels": [
{
"aliasColors": {},
"bars": false,
"dashLength": 10,
"dashes": false,
"datasource": "${DS_TEST}",
"fill": 0,
"gridPos": {
"h": 9,
"w": 12,
"x": 0,
"y": 0
},
"id": 6,
"legend": {
"avg": false,
"current": false,
"max": false,
"min": false,
"show": true,
"total": false,
"values": false
},
"lines": true,
"linewidth": 1,
"links": [],
"nullPointMode": "null",
"percentage": false,
"pointradius": 5,
"points": false,
"renderer": "flot",
"seriesOverrides": [],
"spaceLength": 10,
"stack": false,
"steppedLine": false,
"targets": [
{
"alias": "Database",
"groupBy": [],
"measurement": "statistics",
"orderByTime": "ASC",
"policy": "default",
"query": "SELECT \"db_share\" FROM \"statistics\" WHERE $timeFilter",
"rawQuery": false,
"refId": "A",
"resultFormat": "time_series",
"select": [
[
{
"params": [
"db_share"
],
"type": "field"
}
]
],
"tags": []
},
{
"alias": "Disk",
"groupBy": [],
"measurement": "statistics",
"orderByTime": "ASC",
"policy": "default",
"refId": "B",
"resultFormat": "time_series",
"select": [
[
{
"params": [
"disk_share"
],
"type": "field"
}
]
],
"tags": []
},
{
"alias": "Network",
"groupBy": [],
"measurement": "statistics",
"orderByTime": "ASC",
"policy": "default",
"refId": "C",
"resultFormat": "time_series",
"select": [
[
{
"params": [
"network_share"
],
"type": "field"
}
]
],
"tags": []
}
],
"thresholds": [],
"timeFrom": null,
"timeShift": null,
"title": "Shares",
"tooltip": {
"shared": true,
"sort": 0,
"value_type": "individual"
},
"type": "graph",
"xaxis": {
"buckets": null,
"mode": "time",
"name": null,
"show": true,
"values": []
},
"yaxes": [
{
"format": "short",
"label": null,
"logBase": 1,
"max": null,
"min": null,
"show": true
},
{
"format": "short",
"label": null,
"logBase": 1,
"max": null,
"min": null,
"show": true
}
]
},
{
"aliasColors": {},
"bars": false,
"dashLength": 10,
"dashes": false,
"datasource": "${DS_TEST}",
"fill": 0,
"gridPos": {
"h": 8,
"w": 11,
"x": 12,
"y": 0
},
"id": 2,
"legend": {
"avg": false,
"current": false,
"max": false,
"min": false,
"show": true,
"total": false,
"values": false
},
"lines": true,
"linewidth": 1,
"links": [],
"nullPointMode": "null",
"percentage": true,
"pointradius": 5,
"points": false,
"renderer": "flot",
"seriesOverrides": [],
"spaceLength": 10,
"stack": false,
"steppedLine": false,
"targets": [
{
"alias": "Total",
"groupBy": [],
"measurement": "statistics",
"orderByTime": "ASC",
"policy": "default",
"refId": "A",
"resultFormat": "time_series",
"select": [
[
{
"params": [
"data_volume"
],
"type": "field"
},
{
"params": [
" / elapsed_ms/1024/1024/1024*1000*8"
],
"type": "math"
}
]
],
"tags": []
}
],
"thresholds": [],
"timeFrom": null,
"timeShift": null,
"title": "Bandwidth",
"tooltip": {
"shared": true,
"sort": 0,
"value_type": "individual"
},
"type": "graph",
"xaxis": {
"buckets": null,
"mode": "time",
"name": null,
"show": true,
"values": []
},
"yaxes": [
{
"format": "short",
"label": null,
"logBase": 1,
"max": null,
"min": null,
"show": true
},
{
"format": "short",
"label": null,
"logBase": 1,
"max": null,
"min": null,
"show": true
}
]
},
{
"aliasColors": {},
"bars": false,
"dashLength": 10,
"dashes": false,
"datasource": "${DS_TEST}",
"fill": 0,
"gridPos": {
"h": 8,
"w": 11,
"x": 12,
"y": 8
},
"id": 4,
"legend": {
"avg": false,
"current": false,
"max": false,
"min": false,
"show": true,
"total": false,
"values": false
},
"lines": true,
"linewidth": 1,
"links": [],
"nullPointMode": "null",
"percentage": false,
"pointradius": 5,
"points": false,
"renderer": "flot",
"seriesOverrides": [],
"spaceLength": 10,
"stack": false,
"steppedLine": false,
"targets": [
{
"alias": "Receiver",
"groupBy": [
{
"params": [
"$__interval"
],
"type": "time"
},
{
"params": [
"null"
],
"type": "fill"
}
],
"measurement": "RequestsRate",
"orderByTime": "ASC",
"policy": "default",
"query": "SELECT \"n_requests\" / elapsed_ms*1000 FROM \"statistics\" WHERE $timeFilter",
"rawQuery": true,
"refId": "A",
"resultFormat": "time_series",
"select": [
[
{
"params": [
"n_requests"
],
"type": "field"
},
{
"params": [],
"type": "mean"
}
]
],
"tags": []
},
{
"alias": "Broker",
"groupBy": [],
"measurement": "RequestsRate",
"orderByTime": "ASC",
"policy": "default",
"refId": "B",
"resultFormat": "time_series",
"select": [
[
{
"params": [
"rate"
],
"type": "field"
}
]
],
"tags": []
}
],
"thresholds": [],
"timeFrom": null,
"timeShift": null,
"title": "Number of Requests",
"tooltip": {
"shared": true,
"sort": 0,
"value_type": "individual"
},
"transparent": false,
"type": "graph",
"xaxis": {
"buckets": null,
"mode": "time",
"name": null,
"show": true,
"values": []
},
"yaxes": [
{
"format": "short",
"label": null,
"logBase": 1,
"max": null,
"min": null,
"show": true
},
{
"format": "short",
"label": null,
"logBase": 1,
"max": null,
"min": null,
"show": true
}
]
}
],
"refresh": false,
"schemaVersion": 16,
"style": "dark",
"tags": [],
"templating": {
"list": []
},
"time": {
"from": "now/d",
"to": "now/d"
},
"timepicker": {
"refresh_intervals": [
"5s",
"10s",
"30s",
"1m",
"5m",
"15m",
"30m",
"1h",
"2h",
"1d"
],
"time_options": [
"5m",
"15m",
"1h",
"6h",
"12h",
"24h",
"2d",
"7d",
"30d"
]
},
"timezone": "",
"title": "ASAP::O",
"uid": "3JvTwliiz",
"version": 4
}
\ No newline at end of file
......@@ -30,7 +30,7 @@ bool SendDummyData(hidra2::Producer* producer, size_t number_of_byte, uint64_t i
for(uint64_t i = 0; i < iterations; i++) {
// std::cerr << "Send file " << i + 1 << "/" << iterations << std::endl;
auto err = producer->Send(i, buffer.get(), number_of_byte);
auto err = producer->Send(i + 1, buffer.get(), number_of_byte);
if (err) {
std::cerr << "File was not successfully send: " << err << std::endl;
......
......@@ -3,6 +3,7 @@ set(SOURCE_FILES getnext_broker.cpp)
add_executable(${TARGET_NAME} ${SOURCE_FILES})
target_link_libraries(${TARGET_NAME} hidra2-worker)
#use expression generator to get rid of VS adding Debug/Release folders
set_target_properties(${TARGET_NAME} PROPERTIES RUNTIME_OUTPUT_DIRECTORY
${CMAKE_CURRENT_BINARY_DIR}$<$<CONFIG:Debug>:>
......
......@@ -18,24 +18,26 @@ void WaitThreads(std::vector<std::thread>* threads) {
}
}
void ProcessError(const Error& err) {
if (err == nullptr) return;
int ProcessError(const Error& err) {
if (err == nullptr) return 0;
if (err->GetErrorType() != hidra2::ErrorType::kEndOfFile) {
std::cout << err->Explain() << std::endl;
exit(EXIT_FAILURE);
return 1;
}
return 0;
}
std::vector<std::thread> StartThreads(const std::string& server, const std::string& run_name, int nthreads,
std::vector<int>* nfiles) {
auto exec_next = [server, run_name, nfiles](int i) {
std::vector<int>* nfiles, std::vector<int>* errors) {
auto exec_next = [server, run_name, nfiles, errors](int i) {
hidra2::FileInfo fi;
Error err;
auto broker = hidra2::DataBrokerFactory::CreateServerBroker(server, run_name, &err);
broker->SetTimeout(1000);
while ((err = broker->GetNext(&fi, nullptr)) == nullptr) {
(*nfiles)[i] ++;
}
ProcessError(err);
(*errors)[i] = ProcessError(err);
};
std::vector<std::thread> threads;
......@@ -50,11 +52,16 @@ int ReadAllData(const std::string& server, const std::string& run_name, int nthr
high_resolution_clock::time_point t1 = high_resolution_clock::now();
std::vector<int>nfiles(nthreads, 0);
std::vector<int>errors(nthreads, 0);
auto threads = StartThreads(server, run_name, nthreads, &nfiles);
auto threads = StartThreads(server, run_name, nthreads, &nfiles, &errors);
WaitThreads(&threads);
int n_total = std::accumulate(nfiles.begin(), nfiles.end(), 0);
int errors_total = std::accumulate(errors.begin(), errors.end(), 0);
if (errors_total) {
exit(EXIT_FAILURE);
}
high_resolution_clock::time_point t2 = high_resolution_clock::now();
auto duration_read = std::chrono::duration_cast<std::chrono::milliseconds>( t2 - t1 );
......
......@@ -7,7 +7,8 @@ set(SOURCE_FILES
src/request_handler_file_write.cpp
src/statistics.cpp
src/statistics_sender_influx_db.cpp
src/receiver_config.cpp src/receiver_config.h)
src/receiver_config.cpp
src/request_handler_db_write.cpp)
################################
......@@ -19,7 +20,7 @@ add_library(${TARGET_NAME} STATIC ${SOURCE_FILES} $<TARGET_OBJECTS:system_io> $<
$<TARGET_OBJECTS:json_parser>)
set_target_properties(${TARGET_NAME} PROPERTIES LINKER_LANGUAGE CXX)
target_include_directories(${TARGET_NAME} PUBLIC ${HIDRA2_CXX_COMMON_INCLUDE_DIR} ${CURL_INCLUDE_DIRS})
target_link_libraries(${TARGET_NAME} ${CURL_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
target_link_libraries(${TARGET_NAME} ${CURL_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT} database)
add_executable(${TARGET_NAME}-bin src/main.cpp)
......@@ -38,7 +39,9 @@ set(TEST_SOURCE_FILES
unittests/test_statistics.cpp
unittests/test_config.cpp
unittests/test_request.cpp
unittests/test_request_factory.cpp
unittests/test_request_handler_file_write.cpp
unittests/test_request_handler_db_writer.cpp
unittests/test_statistics_sender_influx_db.cpp
unittests/mock_receiver_config.cpp
)
......
......@@ -17,6 +17,10 @@ Error ReceiverConfigFactory::SetConfigFromFile(std::string file_name) {
(err = parser.GetString("MonitorDbAddress", &config.monitor_db_uri)) ||
(err = parser.GetUInt64("ListenPort", &config.listen_port)) ||
(err = parser.GetBool("WriteToDisk", &config.write_to_disk)) ||
(err = parser.GetBool("WriteToDb", &config.write_to_db)) ||
(err = parser.GetString("BrokerDbAddress", &config.broker_db_uri)) ||
(err = parser.GetString("BrokerDbName", &config.broker_db_name)) ||
(err = parser.GetString("MonitorDbName", &config.monitor_db_name));
return err;
}
......
......@@ -9,8 +9,12 @@ namespace hidra2 {
struct ReceiverConfig {
std::string monitor_db_uri;
std::string monitor_db_name;
std::string broker_db_uri;
std::string broker_db_name;
uint64_t listen_port = 0;
bool write_to_disk = false;
bool write_to_db = false;
};
const ReceiverConfig* GetReceiverConfig();
......
......@@ -60,6 +60,11 @@ void Request::AddHandler(const RequestHandler* handler) {
}
uint64_t Request::GetDataID() const {
return request_header_.data_id;
}
uint64_t Request::GetDataSize() const {
return request_header_.data_size;
}
......@@ -72,7 +77,6 @@ std::string Request::GetFileName() const {
return std::to_string(request_header_.data_id) + ".bin";
}
std::unique_ptr<Request> RequestFactory::GenerateRequest(const GenericNetworkRequestHeader&
request_header, SocketDescriptor socket_fd,
Error* err) const noexcept {
......@@ -80,16 +84,22 @@ std::unique_ptr<Request> RequestFactory::GenerateRequest(const GenericNetworkReq
switch (request_header.op_code) {
case Opcode::kNetOpcodeSendData: {
auto request = std::unique_ptr<Request> {new Request{request_header, socket_fd}};
if (GetReceiverConfig()->write_to_disk) {
request->AddHandler(&request_handler_filewrite_);
}
if (GetReceiverConfig()->write_to_db) {
request->AddHandler(&request_handler_dbwrite_);
}
return request;
}
default:
*err = ReceiverErrorTemplates::kInvalidOpCode.Generate();
return nullptr;
}
}
}
\ No newline at end of file
......@@ -6,6 +6,7 @@
#include "io/io.h"
#include "request_handler.h"
#include "request_handler_file_write.h"
#include "request_handler_db_write.h"
#include "statistics.h"
namespace hidra2 {
......@@ -20,6 +21,7 @@ class Request {
void AddHandler(const RequestHandler*);
const RequestHandlerList& GetListHandlers() const;
virtual uint64_t GetDataSize() const;
virtual uint64_t GetDataID() const;
virtual std::string GetFileName() const;
virtual const FileData& GetData() const;
......@@ -39,8 +41,10 @@ class RequestFactory {
SocketDescriptor socket_fd, Error* err) const noexcept;
private:
RequestHandlerFileWrite request_handler_filewrite_;
RequestHandlerDbWrite request_handler_dbwrite_;
};
}
#endif //HIDRA2_REQUEST_H
#include "request_handler_db_write.h"
#include "request.h"
#include "receiver_config.h"
namespace hidra2 {
// Writes a record describing the received file (name, size, id) into the
// broker database, connecting lazily on the first request.
Error RequestHandlerDbWrite::ProcessRequest(const Request& request) const {
if (Error err = ConnectToDbIfNeeded() ) {
return err;
}
FileInfo file_info;
file_info.name = request.GetFileName();
file_info.size = request.GetDataSize();
file_info.id = request.GetDataID();
// ignore_duplicates = false: duplicate ids are reported as errors
return db_client__->Insert(file_info, false);
}
// NOTE(review): the error from factory.Create is discarded; if creation
// fails, db_client__ may be unusable (possibly null — TODO confirm
// DatabaseFactory::Create's contract) and ProcessRequest would misbehave.
RequestHandlerDbWrite::RequestHandlerDbWrite() {
DatabaseFactory factory;
Error err;
db_client__ = factory.Create(&err);
}
// Accounts this handler's work under the database statistics bucket.
StatisticEntity RequestHandlerDbWrite::GetStatisticEntity() const {
return StatisticEntity::kDatabase;
}
// Establishes the database connection on first use; subsequent calls are
// no-ops. connected_to_db is mutable so this const handler can memoize it.
Error RequestHandlerDbWrite::ConnectToDbIfNeeded() const {
if (!connected_to_db) {
Error err = db_client__->Connect(GetReceiverConfig()->broker_db_uri, GetReceiverConfig()->broker_db_name,
kDBCollectionName);
if (err) {
return err;
}
connected_to_db = true;
}
return nullptr;
}
}
\ No newline at end of file
#ifndef HIDRA2_REQUEST_HANDLER_DB_WRITE_H
#define HIDRA2_REQUEST_HANDLER_DB_WRITE_H
#include "request_handler.h"
#include "database/database.h"
#include "io/io.h"
namespace hidra2 {
// Request handler that records metadata of each received file in the
// broker database. The connection is opened lazily on the first request.
class RequestHandlerDbWrite final: public RequestHandler {
public:
RequestHandlerDbWrite();
StatisticEntity GetStatisticEntity() const override;
Error ProcessRequest(const Request& request) const override;
// Public (double-underscore suffix) so tests can inject a mock database.
std::unique_ptr<Database> db_client__;
private:
Error ConnectToDbIfNeeded() const;
// mutable: ProcessRequest is const but memoizes the connection state
mutable bool connected_to_db = false;
};
}
#endif //HIDRA2_REQUEST_HANDLER_DB_WRITE_H
......@@ -17,8 +17,12 @@ Error SetReceiverConfig (const ReceiverConfig& config) {
auto config_string = std::string("{\"MonitorDbAddress\":") + "\"" + config.monitor_db_uri + "\"";
config_string += "," + std::string("\"MonitorDbName\":") + "\"" + config.monitor_db_name + "\"";
config_string += "," + std::string("\"BrokerDbName\":") + "\"" + config.broker_db_name + "\"";
config_string += "," + std::string("\"BrokerDbAddress\":") + "\"" + config.broker_db_uri + "\"";
config_string += "," + std::string("\"ListenPort\":") + std::to_string(config.listen_port);
config_string += "," + std::string("\"WriteToDisk\":") + (config.write_to_disk ? "true" : "false");
config_string += "," + std::string("\"WriteToDb\":") + (config.write_to_db ? "true" : "false");
config_string += "}";
EXPECT_CALL(mock_io, ReadFileToString_t("fname", _)).WillOnce(
......
......@@ -53,7 +53,9 @@ TEST_F(ConfigTests, ReadSettings) {
test_config.monitor_db_name = "db_test";
test_config.monitor_db_uri = "localhost:8086";
test_config.write_to_disk = true;
test_config.write_to_db = true;
test_config.broker_db_uri = "localhost:27017";
test_config.broker_db_name = "test";
auto err = hidra2::SetReceiverConfig(test_config);
......@@ -62,8 +64,11 @@ TEST_F(ConfigTests, ReadSettings) {
ASSERT_THAT(err, Eq(nullptr));
ASSERT_THAT(config->monitor_db_uri, Eq("localhost:8086"));
ASSERT_THAT(config->monitor_db_name, Eq("db_test"));
ASSERT_THAT(config->broker_db_uri, Eq("localhost:27017"));
ASSERT_THAT(config->broker_db_name, Eq("test"));
ASSERT_THAT(config->listen_port, Eq(4200));
ASSERT_THAT(config->write_to_disk, true);
ASSERT_THAT(config->write_to_db, true);
}
......
......@@ -6,6 +6,8 @@
#include "../src/request.h"
#include "../src/request_handler.h"
#include "../src/request_handler_file_write.h"
#include "../src/request_handler_db_write.h"
#include "database/database.h"
#include "mock_statistics.h"
#include "mock_receiver_config.h"
......@@ -40,7 +42,7 @@ using hidra2::StatisticEntity;
using hidra2::ReceiverConfig;
using hidra2::SetReceiverConfig;
using hidra2::RequestFactory;
namespace {
......@@ -58,47 +60,6 @@ class MockReqestHandler : public hidra2::RequestHandler {
};
class FactoryTests : public Test {
public:
hidra2::RequestFactory factory;
Error err{nullptr};
GenericNetworkRequestHeader generic_request_header;
ReceiverConfig config;
void SetUp() override {
config.write_to_disk = true;
SetReceiverConfig(config);
}
void TearDown() override {
}
};
TEST_F(FactoryTests, ErrorOnWrongCode) {
generic_request_header.op_code = hidra2::Opcode::kNetOpcodeUnknownOp;
auto request = factory.GenerateRequest(generic_request_header, 1, &err);
ASSERT_THAT(err, Ne(nullptr));
}
TEST_F(FactoryTests, ReturnsDataRequestOnkNetOpcodeSendDataCode) {
generic_request_header.op_code = hidra2::Opcode::kNetOpcodeSendData;
auto request = factory.GenerateRequest(generic_request_header, 1, &err);
ASSERT_THAT(err, Eq(nullptr));
ASSERT_THAT(dynamic_cast<hidra2::Request*>(request.get()), Ne(nullptr));
ASSERT_THAT(dynamic_cast<const hidra2::RequestHandlerFileWrite*>(request->GetListHandlers()[0]), Ne(nullptr));
}
TEST_F(FactoryTests, DoNotAddWriterIfNotWanted) {
generic_request_header.op_code = hidra2::Opcode::kNetOpcodeSendData;
config.write_to_disk = false;
SetReceiverConfig(config);
auto request = factory.GenerateRequest(generic_request_header, 1, &err);
ASSERT_THAT(err, Eq(nullptr));
ASSERT_THAT(request->GetListHandlers().size(), Eq(0));
}
class RequestTests : public Test {
public:
......@@ -212,6 +173,13 @@ TEST_F(RequestTests, GetDataIsNotNullptr) {
}
TEST_F(RequestTests, GetDataID) {
auto id = request->GetDataID();
ASSERT_THAT(id, Eq(data_id_));
}
TEST_F(RequestTests, GetDataSize) {
auto size = request->GetDataSize();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment