aind_behavior_curriculum.trainer.Trainer

class aind_behavior_curriculum.trainer.Trainer(curriculum: TCurriculum)

Bases: Generic[TCurriculum]

Trainer class for managing and evaluating curriculum stages and policy transitions, and for updating the task parameters based on the active policies and provided metrics. The entry point is the “evaluate” method.

Attributes:

curriculum (TCurriculum) – The curriculum used by the trainer.

Methods:

__init__(self, curriculum: TCurriculum)

curriculum(self) -> TCurriculum

_evaluate_stage_transition(curriculum: Curriculum, current_stage: Stage, metrics: TMetrics) -> Optional[Stage]

_evaluate_policy_transitions(cls, current_stage: Stage, active_policies: Iterable[Policy], metrics: TMetrics) -> List[Policy]

Evaluates policy transitions for the given current stage and currently active policies, based on the provided metrics.

evaluate(self, trainer_state: TrainerState, metrics: TMetrics) -> TrainerState

get_net_parameter_update(stage_parameters: TaskParameters, stage_policies: Iterable[Policy], curr_metrics: Metrics) -> TaskParameters

Aggregates parameter updates of input stage_policies given current stage_parameters and current metrics.

_get_unique_policies(policies: List[Policy]) -> List[Policy]

Filters unique policies based on their rule functions and reassembles the Policy objects.
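The de-duplication described for _get_unique_policies can be illustrated with a minimal sketch. The Policy class below is a hypothetical stand-in (a frozen dataclass wrapping a rule function), not the library's own class, and the order-preserving behavior is an assumption rather than something this page states.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass(frozen=True)
class Policy:
    """Hypothetical stand-in for the library's Policy: wraps one rule function."""

    rule: Callable[[dict, dict], dict]


def get_unique_policies(policies: List[Policy]) -> List[Policy]:
    """Keep the first occurrence of each distinct rule function, preserving order."""
    seen = set()
    unique_rules = []
    for policy in policies:
        if policy.rule not in seen:  # plain functions are hashable
            seen.add(policy.rule)
            unique_rules.append(policy.rule)
    # Reassemble Policy objects from the de-duplicated rule functions.
    return [Policy(rule=rule) for rule in unique_rules]


def increase_reward(metrics: dict, params: dict) -> dict:
    return {**params, "reward": params["reward"] + 1}


def shorten_delay(metrics: dict, params: dict) -> dict:
    return {**params, "delay": params["delay"] / 2}


policies = [Policy(increase_reward), Policy(shorten_delay), Policy(increase_reward)]
unique = get_unique_policies(policies)
print([p.rule.__name__ for p in unique])  # ['increase_reward', 'shorten_delay']
```

Keying on the rule function rather than on the Policy object means two separately constructed policies that wrap the same function count as one.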

__init__(curriculum: TCurriculum)

Initializes the Trainer with the given curriculum.

Parameters:

curriculum (TCurriculum) – The curriculum to be used by the trainer.

Methods

__init__(curriculum)

Initializes the Trainer with the given curriculum.

create_trainer_state(*, stage[, ...])

Creates a type-aware TrainerState from the given stage and active policies.

evaluate(trainer_state, metrics)

Evaluates the current state of the trainer and updates the stage and policies based on the provided metrics.

get_net_parameter_update(stage_parameters, ...)

Aggregates the parameter updates of the input stage_policies, given the current stage_parameters and current metrics.

Attributes

curriculum

Property that returns the current curriculum.

create_trainer_state(*, stage: Stage | None, is_on_curriculum: bool = True, active_policies: Iterable[Policy] | None = None) -> TrainerState

Creates a type-aware TrainerState from the given stage, curriculum flag, and active policies. All arguments are keyword-only.

Returns:

A type-aware TrainerState instance.

Return type:

TrainerState

property curriculum: TCurriculum

Property that returns the current curriculum.

Returns:

The current curriculum instance.

Return type:

TCurriculum

evaluate(trainer_state: TrainerState, metrics: TMetrics) -> TrainerState

Evaluates the current state of the trainer and updates the stage and policies based on the provided metrics.

Parameters:

trainer_state (TrainerState) – The current state of the trainer, including the current stage and active policies.

metrics (TMetrics) – The metrics used to evaluate the current state and determine transitions.

Returns:

The updated state of the trainer, including the new stage and active policies.

Return type:

TrainerState

Raises:

ValueError – If the current stage or active policies are not set in the trainer state.
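The evaluation flow described above can be sketched with toy stand-ins. Everything in this block is hypothetical (ToyStage, ToyTrainerState, toy_evaluate, and the two transition tables); it does not use the library's classes. It also assumes that stage transitions are checked before policy transitions, that a new stage starts with no active policies, and that a triggered policy transition adds its target policy. None of these details is confirmed by this page.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple

Metrics = Dict[str, float]
Rule = Callable[[Metrics], bool]


@dataclass(frozen=True)
class ToyStage:
    name: str


@dataclass(frozen=True)
class ToyTrainerState:
    stage: Optional[ToyStage]
    active_policies: Optional[Tuple[str, ...]]
    is_on_curriculum: bool = True


def toy_evaluate(
    state: ToyTrainerState,
    metrics: Metrics,
    stage_transitions: Dict[str, List[Tuple[Rule, ToyStage]]],
    policy_transitions: Dict[str, List[Tuple[str, Rule, str]]],
) -> ToyTrainerState:
    """Return the next state: try stage transitions first, then policy transitions."""
    if state.stage is None or state.active_policies is None:
        raise ValueError("stage and active_policies must be set in the trainer state")

    # 1. Stage transitions: the first rule that fires moves to the next stage.
    for rule, next_stage in stage_transitions.get(state.stage.name, []):
        if rule(metrics):
            return ToyTrainerState(stage=next_stage, active_policies=())

    # 2. Policy transitions: only edges leaving a currently active policy count.
    policies = list(state.active_policies)
    for source, rule, target in policy_transitions.get(state.stage.name, []):
        if source in state.active_policies and rule(metrics) and target not in policies:
            policies.append(target)
    return ToyTrainerState(stage=state.stage, active_policies=tuple(policies))


easy, hard = ToyStage("easy"), ToyStage("hard")
stage_transitions = {"easy": [(lambda m: m["accuracy"] >= 0.8, hard)]}
policy_transitions = {"easy": [("baseline", lambda m: m["trials"] >= 100, "longer_delay")]}

state = ToyTrainerState(stage=easy, active_policies=("baseline",))
state = toy_evaluate(state, {"accuracy": 0.5, "trials": 150}, stage_transitions, policy_transitions)
print(state.stage.name, state.active_policies)  # easy ('baseline', 'longer_delay')

state = toy_evaluate(state, {"accuracy": 0.9, "trials": 200}, stage_transitions, policy_transitions)
print(state.stage.name, state.active_policies)  # hard ()
```

The sketch returns a new state rather than mutating the input, which mirrors the signature above: evaluate takes a TrainerState and returns a TrainerState.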

static get_net_parameter_update(stage_parameters: TaskParameters, stage_policies: Iterable[Policy], curr_metrics: Metrics) -> TaskParameters

Aggregates the parameter updates of the input stage_policies, given the current stage_parameters and current metrics.
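One plausible way to aggregate updates is to apply each policy's rule in turn, feeding the output of one rule into the next. The sketch below assumes that strategy. ToyTaskParameters, net_parameter_update, and both rule functions are hypothetical stand-ins, and the rule signature (metrics, parameters) -> parameters is an assumption, not taken from this page.

```python
from dataclasses import dataclass, replace
from typing import Callable, Dict, Iterable

Metrics = Dict[str, float]


@dataclass(frozen=True)
class ToyTaskParameters:
    reward_ul: float
    delay_s: float


Rule = Callable[[Metrics, ToyTaskParameters], ToyTaskParameters]


def net_parameter_update(
    stage_parameters: ToyTaskParameters,
    stage_policies: Iterable[Rule],
    curr_metrics: Metrics,
) -> ToyTaskParameters:
    """Fold the policy rules over the parameters, one after another."""
    updated = stage_parameters
    for rule in stage_policies:
        updated = rule(curr_metrics, updated)
    return updated


def lower_reward(metrics: Metrics, params: ToyTaskParameters) -> ToyTaskParameters:
    # Only reduce the reward once accuracy is high enough.
    if metrics["accuracy"] > 0.8:
        return replace(params, reward_ul=params.reward_ul - 1.0)
    return params


def lengthen_delay(metrics: Metrics, params: ToyTaskParameters) -> ToyTaskParameters:
    return replace(params, delay_s=params.delay_s * 2)


start = ToyTaskParameters(reward_ul=5.0, delay_s=0.5)
result = net_parameter_update(start, [lower_reward, lengthen_delay], {"accuracy": 0.9})
print(result)  # ToyTaskParameters(reward_ul=4.0, delay_s=1.0)
```

Because the toy parameters are frozen, each rule returns a new object and the starting parameters stay unchanged. With sequential application the order of the policies matters whenever two rules touch the same field.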