Files
hailort/hailort/common/genai/session_wrapper/session_wrapper.hpp
HailoRT-Automation 0df636dcb6 v4.21.0 (#25)
2025-04-01 15:01:01 +03:00

79 lines
2.7 KiB
C++

/**
* Copyright (c) 2019-2025 Hailo Technologies Ltd. All rights reserved.
* Distributed under the MIT license (https://opensource.org/licenses/MIT)
**/
/**
* @file session_wrapper.hpp
* @brief Wrapper around a Session that sends and receives size-prefixed messages
**/
#ifndef _HAILO_COMMON_GENAI_SESSION_WRAPPER_HPP_
#define _HAILO_COMMON_GENAI_SESSION_WRAPPER_HPP_
#include <utility>

#include "hailo/hailort.h"
#include "hailo/buffer.hpp"
#include "hailo/hailo_session.hpp"
#include "common/utils.hpp"
#include "common/genai/serializer/serializer.hpp"
namespace hailort
{
namespace genai
{
/**
 * Thin message-framing layer over a Session.
 *
 * Every message on the wire is a native `size_t` length prefix followed by exactly that
 * many payload bytes. Within a single read()/write() call, the prefix and the payload
 * share one overall timeout budget (tracked with TimeoutGuard).
 *
 * NOTE(review): the prefix is sent as a raw native size_t, so both peers presumably must
 * agree on sizeof(size_t) and byte order — confirm for all supported host/device pairs.
 */
class SessionWrapper final
{
public:
    /**
     * @param[in] session  Underlying session used for all I/O. Ownership is shared with the caller.
     */
    SessionWrapper(std::shared_ptr<Session> session) : m_session(std::move(session)) {}

    /**
     * Reads one size-prefixed message into a newly allocated buffer.
     *
     * @param[in] timeout  Total time allowed for reading both the size prefix and the payload.
     * @return On success, a buffer (allocated with BufferStorageParams::create_dma()) holding
     *         the payload. Otherwise, the failing status.
     */
    Expected<std::shared_ptr<Buffer>> read(std::chrono::milliseconds timeout = Session::DEFAULT_READ_TIMEOUT)
    {
        TimeoutGuard timeout_guard(timeout);
        TRY(auto size_to_read, read_size_header(timeout_guard));
        // NOTE(review): size_to_read is supplied by the peer and is not bounded before allocating.
        TRY(auto buffer, Buffer::create_shared(size_to_read, BufferStorageParams::create_dma()));
        CHECK_SUCCESS(m_session->read(buffer->data(), size_to_read, timeout_guard.get_remaining_timeout()));
        return buffer;
    }

    /**
     * Reads one size-prefixed message into a caller-provided buffer.
     *
     * @param[in] buffer   Destination. It must be at least as large as the incoming payload.
     * @param[in] timeout  Total time allowed for reading both the size prefix and the payload.
     * @return On success, the number of payload bytes written into @a buffer.
     *         Returns HAILO_INVALID_OPERATION if @a buffer is too small, or the session's
     *         failing status otherwise.
     */
    Expected<size_t> read(MemoryView buffer, std::chrono::milliseconds timeout = Session::DEFAULT_READ_TIMEOUT)
    {
        TimeoutGuard timeout_guard(timeout);
        TRY(auto size_to_read, read_size_header(timeout_guard));
        // NOTE(review): when this check fails, the payload is left unread in the session.
        // Any subsequent read would then treat payload bytes as a size prefix.
        CHECK(size_to_read <= buffer.size(), HAILO_INVALID_OPERATION,
            "Read buffer is smaller than necessary. Buffer size = {}, generation size = {}",
            buffer.size(), size_to_read);
        CHECK_SUCCESS(m_session->read(buffer.data(), size_to_read, timeout_guard.get_remaining_timeout()));
        return size_to_read;
    }

    /**
     * Sends @a buffer as one size-prefixed message: first its size, then its bytes.
     *
     * @param[in] buffer   Payload to send.
     * @param[in] timeout  Total time allowed for writing both the size prefix and the payload.
     * @return HAILO_SUCCESS, or the session's failing status.
     */
    hailo_status write(MemoryView buffer, std::chrono::milliseconds timeout = Session::DEFAULT_WRITE_TIMEOUT)
    {
        TimeoutGuard timeout_guard(timeout);
        // TODO: Use hrpc protocol
        const size_t size = buffer.size();
        CHECK_SUCCESS(m_session->write(reinterpret_cast<const uint8_t*>(&size), sizeof(size), timeout_guard.get_remaining_timeout()));
        CHECK_SUCCESS(m_session->write(buffer.data(), size, timeout_guard.get_remaining_timeout()));
        return HAILO_SUCCESS;
    }

private:
    // Reads the size_t length prefix that precedes every payload.
    // The time spent is taken from timeout_guard's remaining budget.
    Expected<size_t> read_size_header(TimeoutGuard &timeout_guard)
    {
        size_t size = 0;
        CHECK_SUCCESS(m_session->read(reinterpret_cast<uint8_t*>(&size), sizeof(size),
            timeout_guard.get_remaining_timeout()));
        return size;
    }

    std::shared_ptr<Session> m_session;
};
} /* namespace genai */
} /* namespace hailort */
#endif /* _HAILO_COMMON_GENAI_SESSION_WRAPPER_HPP_ */