Commit 3c822842 authored by wu's avatar wu
Browse files

Update models.py

parent f19165b7
Loading
Loading
Loading
Loading
+17 −1
Original line number Diff line number Diff line
@@ -201,13 +201,29 @@ class SummarisationModelWithCrossEntropyLoss(SummarisationModel):

class ActorCriticSummarisationModel(SummarisationModel):

    def __init__(self, gpu):
    def __init__(self, actor_wts, critic_wts, gpu):

        super().__init__(gpu)

        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.001)

        # actor 
        self.load_state_dict(actor_wts)

        # critic
        critic = Critic(self) # was übergeben wird? weights? => critic
        critic.load_state_dict(critic_wts)
        self.loss_fn = critic

    def forward()

    def train()

    def test()



# gpu, set_parameters
class Critic(nn.Module):

    def __init__(self, model, steepness=8, denoise=100):