ONNX Runtime
Loading...
Searching...
No Matches
onnxruntime_training_c_api.h
1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4// This file contains the training c apis.
5
6#pragma once
7#include <stdbool.h>
8#include "onnxruntime_c_api.h"
9
// Handle types used throughout the training API. Per this page's index entries
// (e.g. `struct OrtTrainingSession OrtTrainingSession`), ORT_RUNTIME_CLASS(X) yields the
// type name OrtX; the macro itself lives outside this view — TODO confirm in onnxruntime_c_api.h.
// NOTE(review): the leading digits on these lines are source line numbers fused in by the
// page extraction, not C tokens.
104ORT_RUNTIME_CLASS(TrainingSession); // Type that enables performing training for the given user models.
105ORT_RUNTIME_CLASS(CheckpointState); // Type that holds the training states for the training session.
106
109typedef enum OrtPropertyType {
114
125
142 ORT_API2_STATUS(LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path,
143 _Outptr_ OrtCheckpointState** checkpoint_state);
144
158 ORT_API2_STATUS(SaveCheckpoint, _In_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* checkpoint_path,
159 const bool include_optimizer_state);
160
162
165
190 ORT_API2_STATUS(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options,
191 _Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path,
192 _In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path,
193 _Outptr_ OrtTrainingSession** out);
194
196
199
211 ORT_API2_STATUS(TrainingSessionGetTrainingModelOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
212
224 ORT_API2_STATUS(TrainingSessionGetEvalModelOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
225
239 ORT_API2_STATUS(TrainingSessionGetTrainingModelOutputName, _In_ const OrtTrainingSession* sess, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** output);
240
254 ORT_API2_STATUS(TrainingSessionGetEvalModelOutputName, _In_ const OrtTrainingSession* sess, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** output);
255
257
260
272 ORT_API2_STATUS(LazyResetGrad, _Inout_ OrtTrainingSession* session);
273
295 ORT_API2_STATUS(TrainStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options,
296 _In_ size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
297 _In_ size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
298
314 ORT_API2_STATUS(EvalStep, _In_ const OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options,
315 _In_ size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
316 _In_ size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
317
336 ORT_API2_STATUS(SetLearningRate, _Inout_ OrtTrainingSession* sess, _In_ float learning_rate);
337
350 ORT_API2_STATUS(GetLearningRate, _Inout_ OrtTrainingSession* sess, _Out_ float* learning_rate);
351
366 ORT_API2_STATUS(OptimizerStep, _Inout_ OrtTrainingSession* sess,
367 _In_opt_ const OrtRunOptions* run_options);
368
384 ORT_API2_STATUS(RegisterLinearLRScheduler, _Inout_ OrtTrainingSession* sess, _In_ const int64_t warmup_step_count,
385 _In_ const int64_t total_step_count, _In_ const float initial_lr);
386
400 ORT_API2_STATUS(SchedulerStep, _Inout_ OrtTrainingSession* sess);
401
403
406
419 ORT_API2_STATUS(GetParametersSize, _Inout_ OrtTrainingSession* sess, _Out_ size_t* out, bool trainable_only);
420
437 ORT_API2_STATUS(CopyParametersToBuffer, _Inout_ OrtTrainingSession* sess,
438 _Inout_ OrtValue* parameters_buffer, bool trainable_only);
439
456 ORT_API2_STATUS(CopyBufferToParameters, _Inout_ OrtTrainingSession* sess,
457 _Inout_ OrtValue* parameters_buffer, bool trainable_only);
458
460
463
470 ORT_CLASS_RELEASE(TrainingSession);
471
479 ORT_CLASS_RELEASE(CheckpointState);
480
482
485
502 ORT_API2_STATUS(ExportModelForInferencing, _Inout_ OrtTrainingSession* sess,
503 _In_ const ORTCHAR_T* inference_model_path, size_t graph_outputs_len,
504 _In_reads_(graph_outputs_len) const char* const* graph_output_names);
505
507
510
520 ORT_API2_STATUS(SetSeed, _In_ const int64_t seed);
521
523
526
537 ORT_API2_STATUS(TrainingSessionGetTrainingModelInputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
538
550 ORT_API2_STATUS(TrainingSessionGetEvalModelInputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
551
565 ORT_API2_STATUS(TrainingSessionGetTrainingModelInputName, _In_ const OrtTrainingSession* sess, size_t index,
566 _In_ OrtAllocator* allocator, _Outptr_ char** output);
567
581 ORT_API2_STATUS(TrainingSessionGetEvalModelInputName, _In_ const OrtTrainingSession* sess, size_t index,
582 _In_ OrtAllocator* allocator, _Outptr_ char** output);
583
585
588
603 ORT_API2_STATUS(AddProperty, _Inout_ OrtCheckpointState* checkpoint_state,
604 _In_ const char* property_name, _In_ enum OrtPropertyType property_type,
605 _In_ void* property_value);
606
621 ORT_API2_STATUS(GetProperty, _In_ const OrtCheckpointState* checkpoint_state,
622 _In_ const char* property_name, _Inout_ OrtAllocator* allocator,
623 _Out_ enum OrtPropertyType* property_type, _Outptr_ void** property_value);
624
626
629
647 ORT_API2_STATUS(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer,
648 _In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state);
649
651};
652
// Alias so callers can write `OrtTrainingApi` instead of `struct OrtTrainingApi` for the
// table of training function pointers declared above.
653typedef struct OrtTrainingApi OrtTrainingApi;
654
struct OrtRunOptions OrtRunOptions
Definition: onnxruntime_c_api.h:282
struct OrtSessionOptions OrtSessionOptions
Definition: onnxruntime_c_api.h:288
struct OrtValue OrtValue
Definition: onnxruntime_c_api.h:281
struct OrtEnv OrtEnv
Definition: onnxruntime_c_api.h:276
struct OrtTrainingSession OrtTrainingSession
Definition: onnxruntime_training_c_api.h:104
struct OrtCheckpointState OrtCheckpointState
Definition: onnxruntime_training_c_api.h:105
OrtPropertyType
Type of property to be added to or returned from the OrtCheckpointState.
Definition: onnxruntime_training_c_api.h:109
@ OrtIntProperty
Definition: onnxruntime_training_c_api.h:110
@ OrtStringProperty
Definition: onnxruntime_training_c_api.h:112
@ OrtFloatProperty
Definition: onnxruntime_training_c_api.h:111
Memory allocation interface.
Definition: onnxruntime_c_api.h:315
The Training C API that holds onnxruntime training function pointers.
Definition: onnxruntime_training_c_api.h:122
OrtStatus * CopyBufferToParameters(OrtTrainingSession *sess, OrtValue *parameters_buffer, bool trainable_only)
Copy parameter values from the given contiguous buffer held by parameters_buffer to the training state.
OrtStatus * EvalStep(const OrtTrainingSession *sess, const OrtRunOptions *run_options, size_t inputs_len, const OrtValue *const *inputs, size_t outputs_len, OrtValue **outputs)
Computes the outputs for the eval model for the given inputs.
OrtStatus * LazyResetGrad(OrtTrainingSession *session)
Reset the gradients of all trainable parameters to zero lazily.
OrtStatus * TrainingSessionGetEvalModelInputName(const OrtTrainingSession *sess, size_t index, OrtAllocator *allocator, char **output)
Retrieves the name of the user input at the given index in the eval model.
OrtStatus * CreateTrainingSession(const OrtEnv *env, const OrtSessionOptions *options, OrtCheckpointState *checkpoint_state, const char *train_model_path, const char *eval_model_path, const char *optimizer_model_path, OrtTrainingSession **out)
Create a training session that can be used to begin or resume training.
OrtStatus * TrainingSessionGetTrainingModelInputCount(const OrtTrainingSession *sess, size_t *out)
Retrieves the number of user inputs in the training model.
OrtStatus * LoadCheckpoint(const char *checkpoint_path, OrtCheckpointState **checkpoint_state)
Load a checkpoint state from a file on disk into checkpoint_state.
OrtStatus * TrainStep(OrtTrainingSession *sess, const OrtRunOptions *run_options, size_t inputs_len, const OrtValue *const *inputs, size_t outputs_len, OrtValue **outputs)
Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs.
OrtStatus * ExportModelForInferencing(OrtTrainingSession *sess, const char *inference_model_path, size_t graph_outputs_len, const char *const *graph_output_names)
Export a model that can be used for inferencing.
OrtStatus * GetLearningRate(OrtTrainingSession *sess, float *learning_rate)
Gets the current learning rate for this training session.
OrtStatus * TrainingSessionGetEvalModelInputCount(const OrtTrainingSession *sess, size_t *out)
Retrieves the number of user inputs in the eval model.
OrtStatus * TrainingSessionGetEvalModelOutputCount(const OrtTrainingSession *sess, size_t *out)
Retrieves the number of user outputs in the eval model.
OrtStatus * RegisterLinearLRScheduler(OrtTrainingSession *sess, const int64_t warmup_step_count, const int64_t total_step_count, const float initial_lr)
Registers a linear learning rate scheduler for the training session.
OrtStatus * CopyParametersToBuffer(OrtTrainingSession *sess, OrtValue *parameters_buffer, bool trainable_only)
Copy all parameters to a contiguous buffer held by the argument parameters_buffer.
OrtStatus * SetLearningRate(OrtTrainingSession *sess, float learning_rate)
Sets the learning rate for this training session.
OrtStatus * TrainingSessionGetTrainingModelOutputCount(const OrtTrainingSession *sess, size_t *out)
Retrieves the number of user outputs in the training model.
OrtStatus * TrainingSessionGetTrainingModelOutputName(const OrtTrainingSession *sess, size_t index, OrtAllocator *allocator, char **output)
Retrieves the name of the user output at the given index in the training model.
OrtStatus * TrainingSessionGetTrainingModelInputName(const OrtTrainingSession *sess, size_t index, OrtAllocator *allocator, char **output)
Retrieves the name of the user input at the given index in the training model.
OrtStatus * TrainingSessionGetEvalModelOutputName(const OrtTrainingSession *sess, size_t index, OrtAllocator *allocator, char **output)
Retrieves the name of the user output at the given index in the eval model.
OrtStatus * AddProperty(OrtCheckpointState *checkpoint_state, const char *property_name, enum OrtPropertyType property_type, void *property_value)
Adds the given property to the checkpoint state.
OrtStatus * SchedulerStep(OrtTrainingSession *sess)
Update the learning rate based on the registered learning rate scheduler.
OrtStatus * GetParametersSize(OrtTrainingSession *sess, size_t *out, bool trainable_only)
Retrieves the size of all the parameters.
OrtStatus * SetSeed(const int64_t seed)
Sets the seed used for random number generation in ONNX Runtime.
OrtStatus * LoadCheckpointFromBuffer(const void *checkpoint_buffer, const size_t num_bytes, OrtCheckpointState **checkpoint_state)
Load a checkpoint state from a buffer into checkpoint_state.
OrtStatus * GetProperty(const OrtCheckpointState *checkpoint_state, const char *property_name, OrtAllocator *allocator, enum OrtPropertyType *property_type, void **property_value)
Gets the property value associated with the given name from the checkpoint state.
OrtStatus * SaveCheckpoint(OrtCheckpointState *checkpoint_state, const char *checkpoint_path, const bool include_optimizer_state)
Save the given state to a checkpoint file on disk.
OrtStatus * OptimizerStep(OrtTrainingSession *sess, const OrtRunOptions *run_options)
Performs the weight updates for the trainable parameters using the optimizer model.