function [U, out] = ftvcs_al_TVL2p(A,b,p,q,opts)
%
% Goal: solve    min sum ||D_i u|| + mu/2||Au-b||_2^2    (with or without 
%       the constraint u>=0) to recover image/signal u from encoded b,
%
%       which is equivalent to solve     min sum ||w_i|| + mu/2||Au-b||_2^2
%                                          s.t. D_i u = w_i
% ftvcs_al solves Augmented Lagrangian function:
% 
% min_{u,w} sum ||w_i|| - sigma'(Du-w) 
%                   + beta/2 ||Du-w||_2^2 + mu/2||Au-b||_2^2 ,
%
% by an alternating algorithm:
% i)  while norm(up-u)/norm(up) > tol_inn
%     1) Fix w^k, do Gradient Descent to 
%            - sigma'(Du-w^k) + beta/2||Du-w^k||^2 + mu/2||Au-b||^2;
%            u^k+1 is determined in the following way:
%         a) compute step length tau > 0 by BB formula
%         b) determine u^k+1 by
%                  u^k+1 = u^k - alpha*g^k,
%            where g^k = -D'sigma + beta D'(Du^k - w^k) + mu A'(Au^k-b), 
%            and alpha is determined by Armijo-like nonmonotone line search;
%     2) Given u^k+1, compute w^k+1 by shrinkage
%                 w^k+1 = shrink(Du^k+1-sigma/beta, 1/beta);
%     end
% ii) update the Lagrangian multiplier by
%             sigma^k+1 = sigma^k - beta(Du^k+1 - w^k+1).
% iii)accept current u as the initial guess to run the loop again
%
% Inputs:
%       A        : either a matrix representing the measurement or a
%                  function handle A(x,mode) with
%                           A(x,1) returning A*x;
%                           A(x,2) returning A'*x;
%       b        :  either real or complex input vector representing the noisy observation of a
%                   grayscale image
%       p, q     :  size of original image
%       opts     :  structure storing the solver parameters (completed and
%                   checked by ftvcs_al_opts)
%
% Outputs:
%       U        :  p-by-q recovered image, rescaled by 1/scl (real part
%                   only when opts.isreal is set)
%       out      :  struct of iteration histories (mus, betas, res, itrs,
%                   f, obj, reer, lam1..lam4, tau, alpha, C, cnt), the
%                   final iteration count itr (Inf if opts.maxit was hit),
%                   and n2re when opts.Ut is provided
%
% variables in this code:
%
% lam1 = sum ||wi||
% lam2 = ||Du-w||^2 (at current w).
% lam3 = ||Au-b||^2
% lam4 = sigma'(Du-w)
%
%   f  = lam1 + beta/2 lam2 + mu/2 lam3 - lam4
%
%   g  = A'(Au-b)
%   g2 = D'(Du-w) (coefficients beta and mu are not included)
%
% 
% Numerical tests illustrate that this solver requires large beta.
%
%
% Written by: Chengbo Li
% Advisor: Prof. Yin Zhang and Wotao Yin
% Computational and Applied Mathematics department, Rice University
% May. 12, 2009


global D Dt
[D,Dt] = defDDt;

% unify implementation of A
if ~isa(A,'function_handle')
    A = @(u,mode) f_handleA(A,u,mode); 
end

% get or check opts
opts = ftvcs_al_opts(opts); 

% problem dimension
n = p*q;

% mark important constants
mu = opts.mu;
beta = opts.beta;
tol_inn = opts.tol_inn;
tol_out = opts.tol;
gam = opts.gam;

% check if A*A'=I
tmp = rand(length(b),1);
if norm(A(A(tmp,2),1)-tmp,1)/norm(tmp,1) < 1e-3
    opts.scale_A = false;
end
clear tmp;

% check scaling A
if opts.scale_A
    [mu,A,b] = ScaleA(n,mu,A,b,opts.consist_mu); 
end 

% check scaling b
% scl defaults to 1 so that the initialisation, the error tracking and the
% final rescaling below are well defined when opts.scale_b is false.
scl = 1;
if opts.scale_b
    [mu,b,scl] = Scaleb(mu,b,opts.consist_mu);
end

% calculate A'*b
Atb = A(b,2);

% initialize U, beta
muf = mu;
betaf = beta;     % final beta
[U,mu,beta] = ftvcs_al_init(p,q,Atb,scl,opts);    % U: p*q
if mu > muf; mu = muf; end
if beta > betaf; beta = betaf; end
muDbeta = mu/beta;                      % muDbeta: constant
rcdU = U;

% initialize multipliers
sigmax = zeros(p,q);                       % sigmax, sigmay: p*q 
sigmay = zeros(p,q);

% initialize D^T sigma
DtsAtd = zeros(p*q,1); 

% initialize out.n2re
if isfield(opts,'Ut')
    Ut = opts.Ut*scl;        %true U, just for computing the error
    nrmUt = norm(Ut,'fro');
else
    Ut = []; 
end
if ~isempty(Ut)
    out.n2re = norm(U - Ut,'fro')/nrmUt; 
end

% prepare for iterations
out.mus = mu; out.betas = beta;
out.res = []; out.itrs = []; out.f = []; out.obj = []; out.reer = [];
out.lam1 = []; out.lam2 = []; out.lam3 = []; out.lam4 = [];
out.itr = Inf;
out.tau = []; out.alpha = []; out.C = []; gp = [];
out.cnt = [];

[Ux,Uy] = D(U);                   % Ux, Uy: p*q
if opts.TVnorm == 1
    Wx = max(abs(Ux) - 1/beta, 0).*sign(Ux);
    Wy = max(abs(Uy) - 1/beta, 0).*sign(Uy);
    lam1 = sum(sum(abs(Wx) + abs(Wy)));
else
    V = sqrt(Ux.*conj(Ux) + Uy.*conj(Uy));        % V: p*q
    V(V==0) = 1;                      % avoid 0/0 in the shrinkage factor
    S = max(V - 1/beta, 0)./V;        % S: p*q
    Wx = S.*Ux;                       % Wx, Wy: p*q
    Wy = S.*Uy;
    lam1 = sum(sum(sqrt(Wx.*conj(Wx) + Wy.*conj(Wy))));
end

[lam2,lam3,lam4,f,g2,Au,g] = get_g(U,Ux,Uy,Wx,Wy,lam1,beta,mu,A,b,...
    Atb,sigmax,sigmay);
%lam, f: constant      g2: pq        Au: m         g: pq

% compute gradient
d = g2 + muDbeta*g - DtsAtd;

count = 1;
Q = 1; C = f;                     % Q, C: constant
out.f = [out.f; f]; out.C = [out.C; C];
out.lam1 = [out.lam1; lam1]; out.lam2 = [out.lam2; lam2]; out.lam3 = [out.lam3; lam3];
out.lam4 = [out.lam4; lam4];

for ii = 1:opts.maxit
    if opts.disp
        fprintf('outer iter = %d, total iter = %d, normU = %4.2e; \n',count,ii,norm(U,'fro'));
    end

    % compute tau first
    if ~isempty(gp)
        dg = g - gp;                        % dg: pq
        dg2 = g2 - g2p;                     % dg2: pq
        ss = uup'*uup;                      % ss: constant
        sy = uup'*(dg2 + muDbeta*dg);       % sy: constant
        % sy = uup'*((dg2 + g2) + muDbeta*(dg + g));
        % compute BB step length
        tau = abs(ss/max(sy,eps));               % tau: constant
        
        fst_itr = false;
    else
        % do Steepest Descent at the 1st iteration
        %d = g2 + muDbeta*g - DtsAtd;         % d: pq
        [dx,dy] = D(reshape(d,p,q));                    %dx, dy: p*q
        dDd = norm(dx,'fro')^2 + norm(dy,'fro')^2;      % dDd: constant
        Ad = A(d,1);                        %Ad: m
        % compute Steepest Descent step length
        tau = abs((d'*d)/(dDd + muDbeta*Ad'*Ad));
        
        % mark the first iteration 
        fst_itr = true;
    end    
    
    % keep the previous values
    Up = U; gp = g; g2p = g2; Aup = Au; 
    Uxp = Ux; Uyp = Uy;

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % ONE-STEP GRADIENT DESCENT %
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    taud = tau*d;
    U = U(:) - taud;
    % projected gradient method for nonnegativity
    if opts.nonneg
        U = max(real(U),0);
    end
    U = reshape(U,p,q);                    % U: p*q (still)
    [Ux,Uy] = D(U);                        % Ux, Uy: p*q
    
    [lam2,lam3,lam4,f,g2,Au,g] = get_g(U,Ux,Uy,Wx,Wy,lam1,...
        beta,mu,A,b,Atb,sigmax,sigmay);
   
    % Nonmonotone Line Search
    alpha = 1;
    du = U - Up;                          % du: p*q
    const = opts.c*beta*(d'*taud);

    % Unew = Up + alpha*(U - Up)
    cnt = 0; flag = true;
    
    while f > C - alpha*const
        % 2015.06.01 NP: the back-tracking limit comes from opts.maxcnt
        % instead of the previously hard-coded value 5.
        if cnt == opts.maxcnt
            % shrink gam
            gam = opts.rate_gam*gam;

            % give up and take Steepest Descent step
            if opts.disp
                disp(['    count of back tracking attains ', int2str(cnt)]);
            end

            %d = g2p + muDbeta*gp - DtsAtd;
            [dx,dy] = D(reshape(d,p,q));
            dDd = norm(dx,'fro')^2 + norm(dy,'fro')^2;
            Ad = A(d,1);
            tau = abs((d'*d)/(dDd + muDbeta*Ad'*Ad));
            U = Up(:) - tau*d;
            % projected gradient method for nonnegativity
            if opts.nonneg
                U = max(real(U),0);
            end
            U = reshape(U,p,q);
            [Ux,Uy] = D(U);
            Uxbar = Ux - sigmax/beta;
            Uybar = Uy - sigmay/beta;
            if opts.TVnorm == 1
                % ONE-DIMENSIONAL SHRINKAGE STEP
                Wx = max(abs(Uxbar) - 1/beta, 0).*sign(Uxbar);
                Wy = max(abs(Uybar) - 1/beta, 0).*sign(Uybar);
                lam1 = sum(sum(abs(Wx) + abs(Wy)));
            else
                % TWO-DIMENSIONAL SHRINKAGE STEP
                V = sqrt(Uxbar.*conj(Uxbar) + Uybar.*conj(Uybar)); % V: p*q
                V(V==0) = 1;
                S = max(V - 1/beta, 0)./V;                         % S: p*q
                Wx = S.*Uxbar;
                Wy = S.*Uybar;
                lam1 = sum(sum(sqrt(Wx.*conj(Wx) + Wy.*conj(Wy))));
            end
            [lam2,lam3,lam4,f,g2,Au,g] = get_g(U,Ux,Uy,Wx,Wy,lam1,...
                beta,mu,A,b,Atb,sigmax,sigmay);
            alpha = 0; % remark the failure of back tracking
            break;
        end
        if flag
            % differences are computed once per line search and reused by
            % update_g for every trial alpha
            dg = g - gp;
            dg2 = g2 - g2p;
            dAu = Au - Aup;                 % dAu: m
            dUx = Ux - Uxp;
            dUy = Uy - Uyp;
            flag = false;
        end
        alpha = alpha*opts.gamma;
        [U,lam2,lam3,lam4,f,Ux,Uy,Au,g,g2] = update_g(p,q,lam1,...
            alpha,beta,mu,Up,du,gp,dg,g2p,dg2,Aup,dAu,Wx,Wy,Uxp,dUx,...
            Uyp,dUy,b,sigmax,sigmay);
        cnt = cnt + 1;
    end
    
    % if back tracking is successful, then recompute
    if alpha ~= 0
        Uxbar = Ux - sigmax/beta;
        Uybar = Uy - sigmay/beta;
        if opts.TVnorm == 1
            % ONE-DIMENSIONAL SHRINKAGE STEP
            Wx = max(abs(Uxbar) - 1/beta, 0).*sign(Uxbar);
            Wy = max(abs(Uybar) - 1/beta, 0).*sign(Uybar);
        else
            % TWO-DIMENSIONAL SHRINKAGE STEP
            V = sqrt(Uxbar.*conj(Uxbar) + Uybar.*conj(Uybar));
            V(V==0) = 1;
            S = max(V - 1/beta, 0)./V;
            Wx = S.*Uxbar;
            Wy = S.*Uybar;
        end
        
        % update parameters related to Wx, Wy
        [lam1,lam2,lam4,f,g2] = update_W(beta,Wx,Wy,Ux,Uy,sigmax,...
            sigmay,lam1,lam2,lam4,f,opts.TVnorm);
    end
    
    % update reference value
    Qp = Q; Q = gam*Qp + 1; C = (gam*Qp*C + f)/Q;
    uup = U - Up; uup = uup(:);           % uup: pq
    nrmuup = norm(uup,'fro');                   % nrmuup: constant
    
    out.res = [out.res; nrmuup];
    out.f = [out.f; f]; out.C = [out.C; C]; out.cnt = [out.cnt;cnt];
    out.lam1 = [out.lam1; lam1]; out.lam2 = [out.lam2; lam2]; 
    out.lam3 = [out.lam3; lam3];out.lam4 = [out.lam4; lam4];
    out.tau = [out.tau; tau]; out.alpha = [out.alpha; alpha];

    % reuse the cached norm of the reference image
    if ~isempty(Ut), out.n2re = [out.n2re; norm(U - Ut,'fro')/nrmUt]; end

    nrmup = norm(Up,'fro');
    RelChg = nrmuup/nrmup;

    % recompute gradient
    d = g2 + muDbeta*g - DtsAtd;
    
    if RelChg < tol_inn  && ~fst_itr
        count = count + 1;
        RelChgOut = norm(U-rcdU,'fro')/nrmup;
        out.reer = [out.reer; RelChgOut];
        rcdU = U;
        out.obj = [out.obj; f + lam4];
        if isempty(out.itrs)
            out.itrs = ii;
        else
            out.itrs = [out.itrs; ii - sum(out.itrs)];
        end

        % stop if already reached final multipliers
        % NOTE(review): opts.maxcnt bounds the outer-iteration count here
        % and the back-tracking count in the line search above; confirm
        % that sharing one option for both limits is intended.
        if RelChgOut < tol_out || count > opts.maxcnt
            if opts.isreal
                U = real(U);
            end
            U = U/scl;
            out.itr = ii;
            fprintf('Number of total iterations is %d. \n',out.itr);
            return
        end
        
        % update multipliers
        [sigmax,sigmay,lam4,f] = update_mlp(beta,Wx,Wy,Ux,Uy,...
            sigmax,sigmay,lam4,f);
        
        % update penalty parameters for continuation scheme
        beta0 = beta;
        beta = beta*opts.rate_ctn;
        mu = mu*opts.rate_ctn;
        if beta > betaf; beta = betaf; end
        if mu > muf; mu = muf; end
        muDbeta = mu/beta;
        out.mus = [out.mus; mu]; out.betas = [out.betas; beta];


        % update function value, gradient, and relevant constant
        f = lam1 + beta/2*lam2 + mu/2*lam3 - lam4;
        DtsAtd = -(beta0/beta)*d;     % DtsAtd should be divided by new beta instead of the old one for consistency!  
        d = g2 + muDbeta*g - DtsAtd;

        %initialize the constants
        gp = [];
        gam = opts.gam; Q = 1; C = f;
    end

end

% maximum number of iterations reached without meeting the outer tolerance
if opts.isreal
    U = real(U);
end
fprintf('Attain the maximum of iterations %d. \n',opts.maxit);
U = U/scl;




function [lam2,lam3,lam4,f,g2,Au,g] = get_g(U,Ux,Uy,Wx,Wy,lam1,...
    beta,mu,A,b,Atb,sigmax,sigmay)
% GET_G  Evaluate the augmented Lagrangian terms and gradient pieces at U.
%
%   Returns
%     lam2 : ||Du - w||^2 summed over both difference directions
%     lam3 : ||A*u - b||^2
%     lam4 : sigma'(Du - w)
%     f    : lam1 + beta/2*lam2 + mu/2*lam3 - lam4
%     g2   : Dt applied to (Du - w), without the beta factor
%     Au   : A applied to U(:)
%     g    : A'(A*u) - Atb, without the mu factor
global Dt

% forward operator on the vectorised image, then the data-fit gradient
Au = A(U(:),1);
g = A(Au,2) - Atb;

% splitting residual Du - w in each direction
resX = Ux - Wx;
resY = Uy - Wy;

% quadratic penalty value and its (unscaled) gradient
lam2 = sum(sum(resX.*conj(resX) + resY.*conj(resY)));
g2 = Dt(resX,resY);

% data-fidelity value
dataRes = Au-b;
lam3 = norm(dataRes,'fro')^2;

% multiplier term
lam4 = sum(sum(conj(sigmax).*resX + conj(sigmay).*resY));

% augmented Lagrangian value
f = lam1 + beta/2*lam2 + mu/2*lam3 - lam4;



function [U,lam2,lam3,lam4,f,Ux,Uy,Au,g,g2] = update_g(p,q,lam1,...
    alpha,beta,mu,Up,du,gp,dg,g2p,dg2,Aup,dAu,Wx,Wy,Uxp,dUx,Uyp,dUy,b,...
    sigmax,sigmay)
% UPDATE_G  Evaluate the line-search trial point Up + alpha*du.
%
%   Every quantity is formed as "previous value + alpha * difference", so
%   neither A nor the difference operators are applied here. The function
%   value f and its components are then recomputed at the trial point with
%   Wx, Wy held fixed.

% interpolate the iterate and its derived quantities
U  = Up  + alpha*reshape(du,p,q);
Ux = Uxp + alpha*dUx;
Uy = Uyp + alpha*dUy;
Au = Aup + alpha*dAu;
g  = gp  + alpha*dg;
g2 = g2p + alpha*dg2;

% splitting residual Du - w at the trial point
resX = Ux - Wx;
resY = Uy - Wy;

% individual terms of the augmented Lagrangian
lam2 = sum(sum(resX.*conj(resX) + resY.*conj(resY)));
dataRes = Au-b;
lam3 = norm(dataRes,'fro')^2;
lam4 = sum(sum(conj(sigmax).*resX + conj(sigmay).*resY));

f = lam1 + beta/2*lam2 + mu/2*lam3 - lam4;



function [lam1,lam2,lam4,f,g2] = update_W(beta,...
    Wx,Wy,Ux,Uy,sigmax,sigmay,lam1,lam2,lam4,f,option)
% UPDATE_W  Refresh the W-dependent terms after Wx, Wy have changed.
%
%   Inputs
%     beta           : penalty parameter
%     Wx, Wy         : updated splitting variables
%     Ux, Uy         : finite differences of the current iterate
%     sigmax, sigmay : multipliers
%     lam1,lam2,lam4 : values computed with the previous Wx, Wy
%     f              : function value computed with the previous Wx, Wy
%     option         : 1 selects the anisotropic sum |Wx| + |Wy|; any
%                      other value selects the isotropic magnitude
%   Outputs
%     lam1, lam2, lam4, f : the same quantities for the new Wx, Wy
%     g2                  : Dt applied to (Du - w) for the new Wx, Wy
global Dt

% strip the W-dependent terms from f; what is left does not depend on W
tmpf = f -lam1 - beta/2*lam2 + lam4;
if option == 1
    lam1 = sum(sum(abs(Wx) + abs(Wy)));
else
    % Use the squared modulus rather than Wx.^2 so that complex Wx, Wy
    % give a real, non-negative magnitude. For real data the value is
    % unchanged.
    lam1 = sum(sum(sqrt(Wx.*conj(Wx) + Wy.*conj(Wy))));
end
Vx = Ux - Wx;
Vy = Uy - Wy;
g2 = Dt(Vx,Vy);
lam2 = sum(sum(Vx.*conj(Vx) + Vy.*conj(Vy)));
lam4 = sum(sum(conj(sigmax).*Vx + conj(sigmay).*Vy));
f = tmpf +lam1 + beta/2*lam2 - lam4;



function [sigmax,sigmay,lam4,f] = update_mlp(beta, ...
    Wx,Wy,Ux,Uy,sigmax,sigmay,lam4,f)
% UPDATE_MLP  Multiplier update sigma <- sigma - beta*(Du - w).
%
%   Also recomputes lam4 = sigma'(Du - w) with the new multipliers and
%   replaces the old lam4 contribution in f by the new one.

% splitting residual Du - w
resX = Ux - Wx;
resY = Uy - Wy;

% function value with the old multiplier term removed
fNoMult = f + lam4;

% multiplier step
sigmax = sigmax - beta*resX;
sigmay = sigmay - beta*resY;

% multiplier term for the new sigma, folded back into f
lam4 = sum(sum(conj(sigmax).*resX + conj(sigmay).*resY));
f = fNoMult - lam4;




function [U,mu,beta] = ftvcs_al_init(p,q,Atb,scl,opts)
% FTVCS_AL_INIT  Initial iterate and initial penalty parameters.
%
%   Inputs
%     p, q : image size
%     Atb  : A'*b, used as the default initial guess (reshaped to p-by-q)
%     scl  : scaling factor applied to a user-supplied initial image
%     opts : must contain mu0, beta0 and init, where init is
%              0               -> U = zeros(p,q)
%              1               -> U = reshape(Atb,p,q)
%              p-by-q matrix   -> U = opts.init*scl
%   Outputs
%     U        : p-by-q initial guess
%     mu, beta : opts.mu0 and opts.beta0
%
%   Errors if mu0 or beta0 is missing, or if init is a scalar other than
%   0 or 1. A non-scalar init of the wrong size prints a message and falls
%   back to reshape(Atb,p,q).

% initialize mu beta
if isfield(opts,'mu0')
    mu = opts.mu0;
else
    error('Initial mu is not provided.');
end
if isfield(opts,'beta0')
    beta = opts.beta0;
else
    error('Initial beta is not provided.');
end

% initialize U
[mm,nn] = size(opts.init);
if max(mm,nn) == 1
    switch opts.init
        case 0, U = zeros(p,q);
        case 1, U = reshape(Atb,p,q);
        otherwise
            % Without this branch U would stay unassigned and MATLAB would
            % report an unhelpful "output argument not assigned" error.
            error('opts.init must be 0, 1, or a p-by-q matrix.');
    end
else
    U = opts.init*scl;  
    if mm ~= p || nn ~= q
        fprintf('Input initial guess has incompatible size! Switch to the default initial guess. \n');
        U = reshape(Atb,p,q);
    end
end
