From 7786ab3d086c51f41b2c4af1cb1a68b11190ce5b Mon Sep 17 00:00:00 2001
From: Carsten Patzke <carsten.patzke@desy.de>
Date: Mon, 6 Apr 2020 12:52:08 +0200
Subject: [PATCH] [asapo::fabric] Refactoring

---
 .../common/fabric_context_impl.cpp            | 117 +++++++++++++-----
 .../asapo_fabric/common/fabric_context_impl.h |  78 ++++--------
 2 files changed, 107 insertions(+), 88 deletions(-)

diff --git a/common/cpp/src/asapo_fabric/common/fabric_context_impl.cpp b/common/cpp/src/asapo_fabric/common/fabric_context_impl.cpp
index 51d627f11..11d0f8e5d 100644
--- a/common/cpp/src/asapo_fabric/common/fabric_context_impl.cpp
+++ b/common/cpp/src/asapo_fabric/common/fabric_context_impl.cpp
@@ -19,6 +19,8 @@ std::string __PRETTY_FUNCTION_TO_NAMESPACE__(const std::string& prettyFunction)
     return prettyFunction.substr(spaceBegin + 1, functionParamBegin - spaceBegin - 1);
 }
 
+// This macro checks if the call that is being made returns FI_SUCCESS. Should only be used with LiFabric functions
+// *error is set to the corresponding LiFabric error
 #define FI_OK(functionCall)                                     \
     do {                                                        \
         int tmp_fi_status = functionCall;                       \
@@ -27,7 +29,7 @@ std::string __PRETTY_FUNCTION_TO_NAMESPACE__(const std::string& prettyFunction)
             *error = ErrorFromFabricInternal(__PRETTY_FUNCTION_TO_NAMESPACE__(__PRETTY_FUNCTION__) + " Line " + std::to_string(__LINE__) + ": " + tmp_fi_s.substr(0, tmp_fi_s.find('(')), tmp_fi_status);\
         return;                                                 \
         }                                                       \
-    } while(0)
+    } while(0) // Enforce ';'
 
 // TODO: It is super important that version 1.10 is installed, but since its not released yet we go with 1.9
 const uint32_t FabricContextImpl::kMinExpectedLibFabricVersion = FI_VERSION(1, 9);
@@ -64,7 +66,6 @@ void FabricContextImpl::InitCommon(const std::string& networkIpHint, uint16_t se
     uint64_t additionalFlags = isServer ? FI_SOURCE : 0;
 
     fi_info* hints = fi_allocinfo();
-    // We somehow have to know if we should allocate a dummy sockets domain or a real verbs domain
     if (networkIpHint == "127.0.0.1") {
         // sockets mode
         hints->fabric_attr->prov_name = strdup("sockets");
@@ -208,13 +209,13 @@ void FabricContextImpl::StartBackgroundThreads() {
 }
 
 void FabricContextImpl::StopBackgroundThreads() {
+    alive_check_response_task_.Stop(); // This has to be done before we kill the completion thread
+
     background_threads_running_ = false;
     if (completion_thread_) {
         completion_thread_->join();
         completion_thread_ = nullptr;
     }
-
-    alive_check_response_task_.Stop();
 }
 
 void FabricContextImpl::CompletionThread() {
@@ -224,48 +225,49 @@ void FabricContextImpl::CompletionThread() {
     while(background_threads_running_ && !error) {
         ssize_t ret;
         ret = fi_cq_sreadfrom(completion_queue_, &entry, 1, &tmpAddress, nullptr, 10 /*ms*/);
-        if (ret == -FI_EAGAIN) {
-            std::this_thread::yield();
-            continue; // No data
-        }
 
-        // TODO Refactor, maybe put it in other functions and/or switch case of ret
-
-        if (ret == -FI_EAVAIL) { // An error is available
-            fi_cq_err_entry errEntry{};
-            ret = fi_cq_readerr(completion_queue_, &errEntry, 0);
-            if (ret != 1) {
-                error = ErrorFromFabricInternal("Unknown error while fi_cq_readerr", ret);
-            } else {
-                auto task = (FabricWaitableTask*)(errEntry.op_context);
+        switch (ret) {
+            case -FI_EAGAIN: // No data
+                std::this_thread::yield();
+                break;
+            case -FI_EAVAIL: // An error is in the queue
+                CompletionThreadHandleErrorAvailable(&error);
+                break;
+            case 1: { // We got 1 data entry back
+                auto task = (FabricWaitableTask*)(entry.op_context);
                 if (task) {
-                    task->HandleErrorCompletion(&errEntry);
-                } else if (hotfix_using_sockets_) {
-                    printf("[Known Sockets bug libfabric/#5795] Ignoring nullptr task!\n");
+                    task->HandleCompletion(&entry, tmpAddress);
                 } else {
-                    error = FabricErrorTemplates::kInternalError.Generate("nullptr context from fi_cq_readerr");
+                    error = FabricErrorTemplates::kInternalError.Generate("nullptr context from fi_cq_sreadfrom");
                 }
+                break;
             }
-
-            continue;
+            default:
+                error = ErrorFromFabricInternal("Unknown error while fi_cq_readfrom", ret);
+                break;
         }
+    }
 
-        if (ret != 1) { // We expect to receive 1 event
-            error = ErrorFromFabricInternal("Unknown error while fi_cq_readfrom", ret);
-            break;
-        }
+    if (error) {
+        throw std::runtime_error("ASAPO Fabric CompletionThread exited with error: " + error->Explain());
+    }
+}
 
-        auto task = (FabricWaitableTask*)(entry.op_context);
+void FabricContextImpl::CompletionThreadHandleErrorAvailable(Error* error) {
+    fi_cq_err_entry errEntry{};
+    ssize_t ret = fi_cq_readerr(completion_queue_, &errEntry, 0);
+    if (ret != 1) {
+        *error = ErrorFromFabricInternal("Unknown error while fi_cq_readerr", ret);
+    } else {
+        auto task = (FabricWaitableTask*)(errEntry.op_context);
         if (task) {
-            task->HandleCompletion(&entry, tmpAddress);
+            task->HandleErrorCompletion(&errEntry);
+        } else if (hotfix_using_sockets_) {
+            printf("[Known Sockets bug libfabric/#5795] Ignoring nullptr task!\n");
         } else {
-            error = FabricErrorTemplates::kInternalError.Generate("nullptr context from fi_cq_sreadfrom");
+            *error = FabricErrorTemplates::kInternalError.Generate("nullptr context from fi_cq_readerr");
         }
     }
-
-    if (error) {
-        throw std::runtime_error("ASAPO Fabric CompletionThread exited with error: " + error->Explain());
-    }
 }
 
 bool FabricContextImpl::TargetIsAliveCheck(FabricAddress address) {
@@ -276,3 +278,50 @@ bool FabricContextImpl::TargetIsAliveCheck(FabricAddress address) {
     // If the send was successful, then we are still able to communicate with the peer
     return !(error != nullptr);
 }
+
+void FabricContextImpl::InternalWait(FabricAddress targetAddress, FabricWaitableTask* task, Error* error) {
+
+    // Check if we simply can wait for our task
+    task->Wait(requestTimeoutMs_, error);
+
+    if (*error == FabricErrorTemplates::kTimeout) {
+        if (targetAddress == FI_ASAPO_ADDR_NO_ALIVE_CHECK) {
+            CancelTask(task, error);
+            // We expect the task to fail with 'Operation canceled'
+            if (*error == FabricErrorTemplates::kInternalOperationCanceledError) {
+                // Switch it to a timeout so it's clearer what happened
+                *error = FabricErrorTemplates::kTimeout.Generate();
+            }
+        } else {
+            InternalWaitWithAliveCheck(targetAddress, task, error);
+        }
+    }
+}
+
+void FabricContextImpl::InternalWaitWithAliveCheck(FabricAddress targetAddress, FabricWaitableTask* task,
+                                                   Error* error) { // Handle advanced alive check
+    bool aliveCheckFailed = false;
+    for (uint32_t i = 0; i < maxTimeoutRetires_ && *error == FabricErrorTemplates::kTimeout; i++) {
+        *error = nullptr;
+        printf("HandleFiCommandAndWait - Tries: %d\n", i);
+        if (!TargetIsAliveCheck(targetAddress)) {
+            aliveCheckFailed = true;
+            break;
+        }
+        task->Wait(requestTimeoutMs_, error);
+    }
+
+    CancelTask(task, error);
+
+    if (aliveCheckFailed) {
+        *error = FabricErrorTemplates::kInternalConnectionError.Generate();
+    } else if(*error == FabricErrorTemplates::kInternalOperationCanceledError) {
+        *error = FabricErrorTemplates::kTimeout.Generate();
+    }
+}
+
+void FabricContextImpl::CancelTask(FabricWaitableTask* task, Error* error) {
+    *error = nullptr;
+    fi_cancel(&endpoint_->fid, task);
+    task->Wait(0, error); // You can probably expect a kInternalOperationCanceledError
+}
diff --git a/common/cpp/src/asapo_fabric/common/fabric_context_impl.h b/common/cpp/src/asapo_fabric/common/fabric_context_impl.h
index 0925be332..8d51c4cb1 100644
--- a/common/cpp/src/asapo_fabric/common/fabric_context_impl.h
+++ b/common/cpp/src/asapo_fabric/common/fabric_context_impl.h
@@ -102,7 +102,7 @@ class FabricContextImpl : public FabricContext {
     void StartBackgroundThreads();
     void StopBackgroundThreads();
 
-    // If the targetAddress is FI_ASAPO_ADDR_NO_ALIVE_CHECK and a timeout occurs, no further ping is being done
+    // If the targetAddress is FI_ASAPO_ADDR_NO_ALIVE_CHECK and a timeout occurs, no further ping is being done.
     // Alive check is generally only necessary if you are trying to receive data or RDMA send.
     template<class FuncType, class... ArgTypes>
     inline void HandleFiCommandWithBasicTaskAndWait(FabricAddress targetAddress, Error* error,
@@ -111,51 +111,12 @@ class FabricContextImpl : public FabricContext {
         HandleFiCommandAndWait(targetAddress, &task, error, func, args...);
     }
 
-    // If the targetAddress is FI_ASAPO_ADDR_NO_ALIVE_CHECK and a timeout occurs, no further ping is being done.
-    // Alive check is generally only necessary if you are trying to receive data or RDMA send.
     template<class FuncType, class... ArgTypes>
     inline void HandleFiCommandAndWait(FabricAddress targetAddress, FabricWaitableTask* task, Error* error,
                                        FuncType func, ArgTypes... args) {
         HandleRawFiCommand(task, error, func, args...);
-
-        if (!(*error)) {
-            task->Wait(requestTimeoutMs_, error);
-            if (*error == FabricErrorTemplates::kTimeout) {
-                if (targetAddress != FI_ASAPO_ADDR_NO_ALIVE_CHECK) {
-                    // Handle advanced alive check
-                    bool aliveCheckFailed = false;
-                    for (uint32_t i = 0; i < maxTimeoutRetires_ && *error == FabricErrorTemplates::kTimeout; i++) {
-                        *error = nullptr;
-                        printf("HandleFiCommandAndWait - Tries: %d\n", i);
-                        if (!TargetIsAliveCheck(targetAddress)) {
-                            aliveCheckFailed = true;
-                            break;
-                        }
-                        task->Wait(requestTimeoutMs_, error);
-                    }
-
-                    // TODO refactor this if/else mess
-                    if (aliveCheckFailed) {
-                        fi_cancel(&endpoint_->fid, task);
-                        task->Wait(0, error);
-                        *error = FabricErrorTemplates::kInternalConnectionError.Generate();
-                    } else if(*error == FabricErrorTemplates::kTimeout) {
-                        fi_cancel(&endpoint_->fid, task);
-                        task->Wait(0, error);
-                        *error = FabricErrorTemplates::kTimeout.Generate();
-                    }
-                } else {
-                    // If a timeout occurs we want to cancel the action,
-                    // which invokes an 'Operation canceled' error in the completion queue.
-                    fi_cancel(&endpoint_->fid, task);
-                    task->Wait(0, error);
-                    // We expect the task to fail with 'Operation canceled'
-                    if (*error == FabricErrorTemplates::kInternalOperationCanceledError) {
-                        // Switch it to a timeout so its more clearly what happened
-                        *error = FabricErrorTemplates::kTimeout.Generate();
-                    }
-                }
-            }
+        if (!(*error)) { // We successfully queued our request
+            InternalWait(targetAddress, task, error);
         }
     }
 
@@ -175,24 +136,33 @@ class FabricContextImpl : public FabricContext {
             } while (ret == -FI_EAGAIN && maxTime >= clock::now());
         }
 
-        if (ret != 0) {
-            switch (-ret) {
-            case FI_EAGAIN:
-                *error = FabricErrorTemplates::kTimeout.Generate();
-                break;
-            case FI_ENOENT:
-                *error = FabricErrorTemplates::kConnectionRefusedError.Generate();
-                break;
-            default:
-                *error = ErrorFromFabricInternal("HandleFiCommandAndWait", ret);
-            }
-            return;
+        switch (-ret) {
+        case FI_SUCCESS:
+            // Success
+            break;
+        case FI_EAGAIN: // We fell through our own timeout loop
+            *error = FabricErrorTemplates::kTimeout.Generate();
+            break;
+        case FI_ENOENT:
+            *error = FabricErrorTemplates::kConnectionRefusedError.Generate();
+            break;
+        default:
+            *error = ErrorFromFabricInternal("HandleRawFiCommand", ret);
+            break;
         }
     }
 
   private:
     bool TargetIsAliveCheck(FabricAddress address);
     void CompletionThread();
+
+    void InternalWait(FabricAddress targetAddress, FabricWaitableTask* task, Error* error);
+
+    void InternalWaitWithAliveCheck(FabricAddress targetAddress, FabricWaitableTask* task, Error* error);
+
+    void CompletionThreadHandleErrorAvailable(Error* error);
+
+    void CancelTask(FabricWaitableTask* task, Error* error);
 };
 
 }
-- 
GitLab