Loading models.py +17 −1 Original line number Diff line number Diff line Loading @@ -201,13 +201,29 @@ class SummarisationModelWithCrossEntropyLoss(SummarisationModel): class ActorCriticSummarisationModel(SummarisationModel): def __init__(self, gpu): def __init__(self, actor_wts, critic_wts, gpu): super().__init__(gpu) self.optimizer = torch.optim.Adam(self.parameters(), lr=0.001) # actor self.load_state_dict(actor_wts) # critic critic = Critic(self) # was übergeben wird? weights? => critic critic.load_state_dict(critic_wts) self.loss_fn = critic def forward() def train() def test() # gpu, set_parameters class Critic(nn.Module): def __init__(self, model, steepness=8, denoise=100): Loading Loading
models.py +17 −1 Original line number Diff line number Diff line Loading @@ -201,13 +201,29 @@ class SummarisationModelWithCrossEntropyLoss(SummarisationModel): class ActorCriticSummarisationModel(SummarisationModel): def __init__(self, gpu): def __init__(self, actor_wts, critic_wts, gpu): super().__init__(gpu) self.optimizer = torch.optim.Adam(self.parameters(), lr=0.001) # actor self.load_state_dict(actor_wts) # critic critic = Critic(self) # was übergeben wird? weights? => critic critic.load_state_dict(critic_wts) self.loss_fn = critic def forward() def train() def test() # gpu, set_parameters class Critic(nn.Module): def __init__(self, model, steepness=8, denoise=100): Loading