ONNX Runtime
Loading...
Searching...
No Matches
onnxruntime_training_cxx_api.h
1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
5#include "onnxruntime_training_c_api.h"
6#include <optional>
7#include <variant>
8
// Forward declarations of the OrtRelease overloads for the training handle types
// (OrtCheckpointState and OrtTrainingSession).
9namespace Ort::detail {
10
// Expands to a declaration of `void OrtRelease(Ort<NAME>* ptr);` for the handle type Ort<NAME>.
11#define ORT_DECLARE_TRAINING_RELEASE(NAME) \
12 void OrtRelease(Ort##NAME* ptr);
13
14// These release methods must be forward declared before including onnxruntime_cxx_api.h
15// otherwise class Base won't be aware of them
16ORT_DECLARE_TRAINING_RELEASE(CheckpointState);
17ORT_DECLARE_TRAINING_RELEASE(TrainingSession);
18
19} // namespace Ort::detail
20
21#include "onnxruntime_cxx_api.h"
22
23namespace Ort {
24
31
// Inline definitions of the OrtRelease overloads for the training handle types.
32namespace detail {
33
// Expands to an inline `OrtRelease(Ort<NAME>* ptr)` that forwards the pointer to
// GetTrainingApi().Release<NAME>(ptr).
34#define ORT_DEFINE_TRAINING_RELEASE(NAME) \
35 inline void OrtRelease(Ort##NAME* ptr) { GetTrainingApi().Release##NAME(ptr); }
36
37ORT_DEFINE_TRAINING_RELEASE(CheckpointState);
38ORT_DEFINE_TRAINING_RELEASE(TrainingSession);
39
// Both helper macros are undefined here, so they are not visible to code after this point.
40#undef ORT_DECLARE_TRAINING_RELEASE
41#undef ORT_DEFINE_TRAINING_RELEASE
42
43} // namespace detail
44
// Value type of a checkpoint property: holds exactly one of int64_t, float or std::string.
// Used as the value parameter of CheckpointState::AddProperty and the return type of
// CheckpointState::GetProperty.
45using Property = std::variant<int64_t, float, std::string>;
46
// Holds the state of the training session.
// Instances are obtained through the static LoadCheckpoint / LoadCheckpointFromBuffer
// functions; the default constructor is deleted and the handle constructor is private.
63class CheckpointState : public detail::Base<OrtCheckpointState> {
64 private:
 // Wraps an existing OrtCheckpointState handle by storing it in the inherited p_ member.
65 CheckpointState(OrtCheckpointState* checkpoint_state) { p_ = checkpoint_state; }
66
67 public:
68 // Construct the checkpoint state by loading the checkpoint by calling LoadCheckpoint
69 CheckpointState() = delete;
70
73
 // Load a checkpoint state from a file on disk.
 //   path_to_checkpoint: path of the checkpoint file (ORTCHAR_T string).
 //   Returns the loaded CheckpointState.
85 static CheckpointState LoadCheckpoint(const std::basic_string<ORTCHAR_T>& path_to_checkpoint);
86
 // Load a checkpoint state from a buffer.
 //   buffer: bytes of the checkpoint.
 //   Returns the loaded CheckpointState.
98 static CheckpointState LoadCheckpointFromBuffer(const std::vector<uint8_t>& buffer);
99
 // Save the given state to a checkpoint file on disk.
 //   checkpoint_state: state to save.
 //   path_to_checkpoint: destination path of the checkpoint file.
 //   include_optimizer_state: defaults to false; presumably controls whether the optimizer
 //     state is written alongside the model state -- TODO confirm against the C API docs.
111 static void SaveCheckpoint(const CheckpointState& checkpoint_state,
112 const std::basic_string<ORTCHAR_T>& path_to_checkpoint,
113 const bool include_optimizer_state = false);
114
 // Adds the given property to the checkpoint state.
 //   property_name: name under which the property is stored.
 //   property_value: an int64_t, float or std::string value (see Property).
125 void AddProperty(const std::string& property_name, const Property& property_value);
126
 // Gets the property value associated with the given name from the checkpoint state.
 //   property_name: name of the property to look up.
 //   Returns the value as a Property variant.
136 Property GetProperty(const std::string& property_name);
137
139};
140
// Trainer class that provides training, evaluation and optimizer methods for training ONNX models.
// NOTE(review): the member index for this class also lists LazyResetGrad(), OptimizerStep() and
// SchedulerStep(), but no declarations for them appear in this listing -- they look to have
// been dropped when the listing was extracted; verify against the real header.
152class TrainingSession : public detail::Base<OrtTrainingSession> {
153 private:
 // Output counts of the training model and of the eval model. Presumably filled in by the
 // constructor and used to size the results of TrainStep / EvalStep -- TODO confirm in
 // onnxruntime_training_cxx_inline.h.
154 size_t training_model_output_count_, eval_model_output_count_;
155
156 public:
159
 // Create a training session that can be used to begin or resume training.
 //   env: environment the session is created in.
 //   session_options: options for the session.
 //   checkpoint_state: checkpoint state to train from (taken by non-const reference).
 //   train_model_path: path to the training model.
 //   eval_model_path: optional path to the eval model; defaults to std::nullopt.
 //   optimizer_model_path: optional path to the optimizer model; defaults to std::nullopt.
174 TrainingSession(const Env& env, const SessionOptions& session_options, CheckpointState& checkpoint_state,
175 const std::basic_string<ORTCHAR_T>& train_model_path,
176 const std::optional<std::basic_string<ORTCHAR_T>>& eval_model_path = std::nullopt,
177 const std::optional<std::basic_string<ORTCHAR_T>>& optimizer_model_path = std::nullopt);
178
180
183
 // Computes the outputs of the training model and the gradients of the trainable parameters.
 //   input_values: user inputs to the training model.
 //   Returns the training model outputs.
199 std::vector<Value> TrainStep(const std::vector<Value>& input_values);
200
209
 // Computes the outputs for the eval model for the given inputs.
 //   input_values: user inputs to the eval model.
 //   Returns the eval model outputs.
219 std::vector<Value> EvalStep(const std::vector<Value>& input_values);
220
 // Sets the learning rate for this training session.
236 void SetLearningRate(float learning_rate);
237
 // Gets the current learning rate for this training session.
247 float GetLearningRate() const;
248
 // Registers a linear learning rate scheduler for the training session.
 //   warmup_step_count, total_step_count, initial_lr: scheduler parameters; their exact
 //     semantics are defined by the underlying C API -- see its documentation.
261 void RegisterLinearLRScheduler(int64_t warmup_step_count, int64_t total_step_count,
262 float initial_lr);
263
274
285
287
290
 // Export a model that can be used for inferencing.
 //   inference_model_path: path the inference model is written to.
 //   graph_output_names: names of the graph outputs of the exported model.
304 void ExportModelForInferencing(const std::basic_string<ORTCHAR_T>& inference_model_path,
305 const std::vector<std::string>& graph_output_names);
306
308
311
 // Retrieves the names of the user inputs for the training and eval models.
 //   training: presumably true selects the training model and false the eval model -- TODO confirm.
321 std::vector<std::string> InputNames(const bool training);
322
 // Retrieves the names of the user outputs for the training and eval models.
 //   training: presumably true selects the training model and false the eval model -- TODO confirm.
333 std::vector<std::string> OutputNames(const bool training);
334
336
339
 // Returns a contiguous buffer that holds a copy of all training state parameters.
 //   only_trainable: presumably restricts the copy to trainable parameters -- TODO confirm.
346 Value ToBuffer(const bool only_trainable);
347
 // Loads the training session model parameters from a contiguous buffer.
 //   buffer: buffer holding the parameters (taken by non-const reference).
352 void FromBuffer(Value& buffer);
353
355};
356
359
// This function sets the seed for generating random numbers.
//   seed: the seed value to use.
366void SetSeed(const int64_t seed);
368
370
371} // namespace Ort
372
373#include "onnxruntime_training_cxx_inline.h"
Holds the state of the training session.
Definition: onnxruntime_training_cxx_api.h:63
static CheckpointState LoadCheckpointFromBuffer(const std::vector< uint8_t > &buffer)
Load a checkpoint state from a buffer.
void AddProperty(const std::string &property_name, const Property &property_value)
Adds the given property to the checkpoint state.
static CheckpointState LoadCheckpoint(const std::basic_string< char > &path_to_checkpoint)
Load a checkpoint state from a file on disk into checkpoint_state.
static void SaveCheckpoint(const CheckpointState &checkpoint_state, const std::basic_string< char > &path_to_checkpoint, const bool include_optimizer_state=false)
Save the given state to a checkpoint file on disk.
Property GetProperty(const std::string &property_name)
Gets the property value associated with the given name from the checkpoint state.
Trainer class that provides training, evaluation and optimizer methods for training ONNX models.
Definition: onnxruntime_training_cxx_api.h:152
void OptimizerStep()
Performs the weight updates for the trainable parameters using the optimizer model.
void RegisterLinearLRScheduler(int64_t warmup_step_count, int64_t total_step_count, float initial_lr)
Registers a linear learning rate scheduler for the training session.
std::vector< Value > EvalStep(const std::vector< Value > &input_values)
Computes the outputs for the eval model for the given inputs.
std::vector< std::string > InputNames(const bool training)
Retrieves the names of the user inputs for the training and eval models.
float GetLearningRate() const
Gets the current learning rate for this training session.
void ExportModelForInferencing(const std::basic_string< char > &inference_model_path, const std::vector< std::string > &graph_output_names)
Export a model that can be used for inferencing.
void LazyResetGrad()
Reset the gradients of all trainable parameters to zero lazily.
Value ToBuffer(const bool only_trainable)
Returns a contiguous buffer that holds a copy of all training state parameters.
std::vector< Value > TrainStep(const std::vector< Value > &input_values)
Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs.
TrainingSession(const Env &env, const SessionOptions &session_options, CheckpointState &checkpoint_state, const std::basic_string< char > &train_model_path, const std::optional< std::basic_string< char > > &eval_model_path=std::nullopt, const std::optional< std::basic_string< char > > &optimizer_model_path=std::nullopt)
Create a training session that can be used to begin or resume training.
void SchedulerStep()
Update the learning rate based on the registered learning rate scheduler.
std::vector< std::string > OutputNames(const bool training)
Retrieves the names of the user outputs for the training and eval models.
void FromBuffer(Value &buffer)
Loads the training session model parameters from a contiguous buffer.
void SetLearningRate(float learning_rate)
Sets the learning rate for this training session.
#define ORT_API_VERSION
The API version defined in this header.
Definition: onnxruntime_c_api.h:40
struct OrtCheckpointState OrtCheckpointState
Definition: onnxruntime_training_c_api.h:105
void SetSeed(const int64_t seed)
This function sets the seed for generating random numbers.
Definition: onnxruntime_cxx_api.h:281
All C++ Onnxruntime APIs are defined inside this namespace.
Definition: onnxruntime_cxx_api.h:45
const OrtApi & GetApi() noexcept
This returns a reference to the OrtApi interface in use.
Definition: onnxruntime_cxx_api.h:122
std::variant< int64_t, float, std::string > Property
Definition: onnxruntime_training_cxx_api.h:45
const OrtTrainingApi & GetTrainingApi()
This function returns the C training api struct with the pointers to the ort training C functions....
Definition: onnxruntime_training_cxx_api.h:30
The Env (Environment)
Definition: onnxruntime_cxx_api.h:479
Wrapper around OrtSessionOptions.
Definition: onnxruntime_cxx_api.h:692
Wrapper around OrtValue.
Definition: onnxruntime_cxx_api.h:1365
Used internally by the C++ API. C++ wrapper types inherit from this. This is a zero cost abstraction ...
Definition: onnxruntime_cxx_api.h:338
contained_type * p_
Definition: onnxruntime_cxx_api.h:366
const OrtTrainingApi *(* GetTrainingApi)(uint32_t version)
Gets the Training C Api struct.
Definition: onnxruntime_c_api.h:3649
The Training C API that holds onnxruntime training function pointers.
Definition: onnxruntime_training_c_api.h:122