Daily Papers

by AK and the research community

Dec 24

Sharpness-Aware Training for Free

Modern deep neural networks (DNNs) have achieved state-of-the-art performances but are typically over-parameterized. The over-parameterization may result in undesirably large generalization error in the absence of other customized training strategies. Recently, a line of research under the name of Sharpness-Aware Minimization (SAM) has shown that minimizing a sharpness measure, which reflects the geometry of the loss landscape, can significantly reduce the generalization error. However, SAM-like methods incur a two-fold computational overhead of the given base optimizer (e.g. SGD) for approximating the sharpness measure. In this paper, we propose Sharpness-Aware Training for Free, or SAF, which mitigates the sharp landscape at almost zero additional computational cost over the base optimizer. Intuitively, SAF achieves this by avoiding sudden drops in the loss in the sharp local minima throughout the trajectory of the updates of the weights. Specifically, we suggest a novel trajectory loss, based on the KL-divergence between the outputs of DNNs with the current weights and past weights, as a replacement of the SAM's sharpness measure. This loss captures the rate of change of the training loss along the model's update trajectory. By minimizing it, SAF ensures the convergence to a flat minimum with improved generalization capabilities. Extensive empirical results show that SAF minimizes the sharpness in the same way that SAM does, yielding better results on the ImageNet dataset with essentially the same computational cost as the base optimizer.
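The trajectory loss described in this abstract can be sketched as a KL-divergence between outputs recorded with past weights and outputs produced with the current weights. The following is a minimal illustration only: the function names, the temperature parameter, the direction of the KL term, and the averaging over the batch are assumptions made for this sketch, not details confirmed by the abstract.

```python
import math

def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two discrete distributions of equal length."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q))

def trajectory_loss(past_logits, current_logits, temperature=1.0):
    """Mean KL between outputs stored from past weights and current outputs.

    Assumption: past outputs are treated as constants (no gradient), and the
    KL direction is KL(past || current).
    """
    terms = [
        kl_divergence(softmax(p, temperature), softmax(c, temperature))
        for p, c in zip(past_logits, current_logits)
    ]
    return sum(terms) / len(terms)

# Hypothetical logits for a batch of two examples.
past = [[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]]
unchanged = trajectory_loss(past, past)
moved = trajectory_loss(past, [[0.0, 2.5, -1.0], [0.1, 0.2, 0.3]])
```

In this sketch the loss is zero when the outputs have not moved and grows as the current predictions drift away from the recorded ones. Because the past outputs can be cached from earlier epochs, a term of this kind needs no extra forward or backward pass, which matches the "almost zero additional cost" claim in spirit.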

  • 5 authors
·
May 27, 2022

TRAM: Bridging Trust Regions and Sharpness Aware Minimization

Sharpness-aware minimization (SAM) reports improving domain generalization by reducing the loss surface curvature in the parameter space. However, generalization during fine-tuning is often more dependent on the transferability of representations in the function space. Trust-region methods (TR) target this goal by regularizing representation curvature to reduce catastrophic forgetting of pre-trained task-agnostic information while adopting task-specific skills. We consider unifying these strategies for low curvature in both parameter space and function space to improve out-of-domain (OOD) generalization. We propose Trust Region Aware Minimization (TRAM), a SAM algorithm fine-tuning for low parameter sharpness and smooth, informative representations preserving pre-trained structure. TRAM uses a trust region bound to inform the SAM adversarial neighborhood, introducing an awareness of function curvature within optimization for flatter minima. We empirically validate TRAM in vision (cross-dataset adaptation) and text (OOD language modeling, zero-shot cross-lingual transfer) tasks where robust domain transfer and representation generality are critical. TRAM outperforms SAM- and TR-based optimization across all tasks, notably surpassing competing methods for hard transfer between anticorrelated domains. TRAM establishes a novel standard in fine-tuning for domain-generalizable models with minimal additional computation over previous sharpness-aware methods.

  • 4 authors
·
Oct 5, 2023

Towards LLM Unlearning Resilient to Relearning Attacks: A Sharpness-Aware Minimization Perspective and Beyond

The LLM unlearning technique has recently been introduced to comply with data regulations and address the safety and ethical concerns of LLMs by removing the undesired data-model influence. However, state-of-the-art unlearning methods face a critical vulnerability: they are susceptible to "relearning" the removed information from a small number of forget data points, known as relearning attacks. In this paper, we systematically investigate how to make unlearned models robust against such attacks. For the first time, we establish a connection between robust unlearning and sharpness-aware minimization (SAM) through a unified robust optimization framework, in an analogy to adversarial training designed to defend against adversarial attacks. Our analysis for SAM reveals that smoothness optimization plays a pivotal role in mitigating relearning attacks. Thus, we further explore diverse smoothing strategies to enhance unlearning robustness. Extensive experiments on benchmark datasets, including WMDP and MUSE, demonstrate that SAM and other smoothness optimization approaches consistently improve the resistance of LLM unlearning to relearning attacks. Notably, smoothness-enhanced unlearning also helps defend against (input-level) jailbreaking attacks, broadening our proposal's impact in robustifying LLM unlearning. Codes are available at https://github.com/OPTML-Group/Unlearn-Smooth.

  • 6 authors
·
Feb 7

UU-Mamba: Uncertainty-aware U-Mamba for Cardiovascular Segmentation

Building on the success of deep learning models in cardiovascular structure segmentation, increasing attention has been focused on improving generalization and robustness, particularly in small, annotated datasets. Despite recent advancements, current approaches often face challenges such as overfitting and accuracy limitations, largely due to their reliance on large datasets and narrow optimization techniques. This paper introduces the UU-Mamba model, an extension of the U-Mamba architecture, designed to address these challenges in both cardiac and vascular segmentation. By incorporating Sharpness-Aware Minimization (SAM), the model enhances generalization by targeting flatter minima in the loss landscape. Additionally, we propose an uncertainty-aware loss function that combines region-based, distribution-based, and pixel-based components to improve segmentation accuracy by capturing both local and global features. While the UU-Mamba model has already demonstrated great performance, further testing is required to fully assess its generalization and robustness. We expand our evaluation by conducting new trials on the ImageCAS (coronary artery) and Aorta (aortic branches and zones) datasets, which present more complex segmentation challenges than the ACDC dataset (left and right ventricles) used in our previous work, showcasing the model's adaptability and resilience. We confirm UU-Mamba's superior performance over leading models such as TransUNet, Swin-Unet, nnUNet, and nnFormer. Moreover, we provide a more comprehensive evaluation of the model's robustness and segmentation accuracy, as demonstrated by extensive experiments.

  • 8 authors
·
Sep 21, 2024

GAQAT: gradient-adaptive quantization-aware training for domain generalization

Research on loss surface geometry, such as Sharpness-Aware Minimization (SAM), shows that flatter minima improve generalization. Recent studies further reveal that flatter minima can also reduce the domain generalization (DG) gap. However, existing flatness-based DG techniques predominantly operate within a full-precision training process, which is impractical for deployment on resource-constrained edge devices that typically rely on lower bit-width representations (e.g., 4 bits, 3 bits). Consequently, low-precision quantization-aware training is critical for optimizing these techniques in real-world applications. In this paper, we observe a significant degradation in performance when applying state-of-the-art DG-SAM methods to quantized models, suggesting that current approaches fail to preserve generalizability during the low-precision training process. To address this limitation, we propose a novel Gradient-Adaptive Quantization-Aware Training (GAQAT) framework for DG. Our approach begins by identifying the scale-gradient conflict problem in low-precision quantization, where the task loss and smoothness loss induce conflicting gradients for the scaling factors of quantizers, with certain layers exhibiting opposing gradient directions. This conflict renders the optimization of quantized weights highly unstable. To mitigate this, we further introduce a mechanism to quantify gradient inconsistencies and selectively freeze the gradients of scaling factors, thereby stabilizing the training process and enhancing out-of-domain generalization. Extensive experiments validate the effectiveness of the proposed GAQAT framework. On PACS, our 3-bit and 4-bit models outperform direct DG-QAT integration by up to 4.5%. On DomainNet, the 4-bit model achieves near-lossless performance compared to full precision, with improvements of 1.39% (4-bit) and 1.06% (3-bit) over the SOTA QAT baseline.

  • 7 authors
·
Dec 7, 2024

Improving Multi-task Learning via Seeking Task-based Flat Regions

Multi-Task Learning (MTL) is a widely-used and powerful learning paradigm for training deep neural networks that allows learning more than one objective by a single backbone. Compared to training tasks separately, MTL significantly reduces computational costs, improves data efficiency, and potentially enhances model performance by leveraging knowledge across tasks. Hence, it has been adopted in a variety of applications, ranging from computer vision to natural language processing and speech recognition. Among them, there is an emerging line of work in MTL that focuses on manipulating the task gradient to derive an ultimate gradient descent direction to benefit all tasks. Despite achieving impressive results on many benchmarks, directly applying these approaches without using appropriate regularization techniques might lead to suboptimal solutions on real-world problems. In particular, standard training that minimizes the empirical loss on the training data can easily suffer from overfitting to low-resource tasks or be spoiled by noisy-labeled ones, which can cause negative transfer between tasks and overall performance drop. To alleviate such problems, we propose to leverage a recently introduced training method, named Sharpness-aware Minimization, which can enhance model generalization ability on single-task learning. Accordingly, we present a novel MTL training methodology, encouraging the model to find task-based flat minima for coherently improving its generalization capability on all tasks. Finally, we conduct comprehensive experiments on a variety of applications to demonstrate the merit of our proposed approach to existing gradient-based MTL methods, as suggested by our developed theory.

  • 6 authors
·
Nov 24, 2022

Towards Stable Test-Time Adaptation in Dynamic Wild World

Test-time adaptation (TTA) has been shown to be effective at tackling distribution shifts between training and testing data by adapting a given model on test samples. However, the online model updating of TTA may be unstable and this is often a key obstacle preventing existing TTA methods from being deployed in the real world. Specifically, TTA may fail to improve or even harm the model performance when test data have: 1) mixed distribution shifts, 2) small batch sizes, and 3) online imbalanced label distribution shifts, which are quite common in practice. In this paper, we investigate the reasons for this instability and find that the batch norm layer is a crucial factor hindering TTA stability. Conversely, TTA can perform more stably with batch-agnostic norm layers, i.e., group or layer norm. However, we observe that TTA with group and layer norms does not always succeed and still suffers many failure cases. By digging into the failure cases, we find that certain noisy test samples with large gradients may disturb the model adaptation and result in collapsed trivial solutions, i.e., assigning the same class label for all samples. To address the above collapse issue, we propose a sharpness-aware and reliable entropy minimization method, called SAR, for further stabilizing TTA from two aspects: 1) remove partial noisy samples with large gradients, 2) encourage model weights to go to a flat minimum so that the model is robust to the remaining noisy samples. Promising results demonstrate that SAR performs more stably over prior methods and is computationally efficient under the above wild test scenarios.

  • 7 authors
·
Feb 23, 2023

When Vision Transformers Outperform ResNets without Pre-training or Strong Data Augmentations

Vision Transformers (ViTs) and MLPs signal further efforts on replacing hand-wired features or inductive biases with general-purpose neural architectures. Existing works empower the models by massive data, such as large-scale pre-training and/or repeated strong data augmentations, and still report optimization-related problems (e.g., sensitivity to initialization and learning rates). Hence, this paper investigates ViTs and MLP-Mixers from the lens of loss geometry, intending to improve the models' data efficiency at training and generalization at inference. Visualization and Hessian reveal extremely sharp local minima of converged models. By promoting smoothness with a recently proposed sharpness-aware optimizer, we substantially improve the accuracy and robustness of ViTs and MLP-Mixers on various tasks spanning supervised, adversarial, contrastive, and transfer learning (e.g., +5.3% and +11.0% top-1 accuracy on ImageNet for ViT-B/16 and Mixer-B/16, respectively, with the simple Inception-style preprocessing). We show that the improved smoothness is attributable to sparser active neurons in the first few layers. The resultant ViTs outperform ResNets of similar size and throughput when trained from scratch on ImageNet without large-scale pre-training or strong data augmentations. Model checkpoints are available at https://github.com/google-research/vision_transformer.
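For readers unfamiliar with the sharpness-aware optimizer referred to here, the basic SAM update can be sketched on a toy problem. This is an illustrative sketch, not the training setup of the paper: the quadratic loss, the learning rate, and the neighborhood radius `rho` are hypothetical values chosen only to make the example run.

```python
import math

def loss(w):
    # Toy 2-D loss (assumption): sharp along w[0], flat along w[1].
    return 10.0 * w[0] ** 2 + 0.1 * w[1] ** 2

def grad(w):
    # Analytic gradient of the toy loss above.
    return [20.0 * w[0], 0.2 * w[1]]

def sam_step(w, lr=0.01, rho=0.05):
    """One SAM update: ascend to a nearby worst-case point, then descend."""
    g = grad(w)                                        # first gradient evaluation
    norm = math.sqrt(sum(gi * gi for gi in g)) or 1.0  # guard against a zero gradient
    w_adv = [wi + rho * gi / norm for wi, gi in zip(w, g)]
    g_adv = grad(w_adv)                                # second gradient evaluation
    # The gradient taken at the perturbed point is applied to the original weights.
    return [wi - lr * gi for wi, gi in zip(w, g_adv)]

w_start = [1.0, 1.0]
w = list(w_start)
for _ in range(100):
    w = sam_step(w)
```

The two calls to `grad` per step are the source of the roughly doubled cost of SAM relative to its base optimizer that several abstracts in this list mention.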

  • 3 authors
·
Jun 2, 2021

Generalized Incremental Learning under Concept Drift across Evolving Data Streams

Real-world data streams exhibit inherent non-stationarity characterized by concept drift, posing significant challenges for adaptive learning systems. While existing methods address isolated distribution shifts, they overlook the critical co-evolution of label spaces and distributions under limited supervision and persistent uncertainty. To address this, we formalize Generalized Incremental Learning under Concept Drift (GILCD), characterizing the joint evolution of distributions and label spaces in open-environment streaming contexts, and propose a novel framework called Calibrated Source-Free Adaptation (CSFA). First, CSFA introduces a training-free prototype calibration mechanism that dynamically fuses emerging prototypes with base representations, enabling stable new-class identification without optimization overhead. Second, we design a novel source-free adaptation algorithm, i.e., Reliable Surrogate Gap Sharpness-aware (RSGS) minimization. It integrates sharpness-aware perturbation loss optimization with surrogate gap minimization, while employing entropy-based uncertainty filtering to discard unreliable samples. This mechanism ensures robust distribution alignment and mitigates generalization degradation caused by uncertainties. Therefore, CSFA establishes a unified framework for stable adaptation to evolving semantics and distributions in open-world streaming scenarios. Extensive experiments validate the superior performance and effectiveness of CSFA compared to state-of-the-art approaches.
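The entropy-based uncertainty filtering mentioned in this abstract can be illustrated with a short sketch that keeps only samples whose predictive entropy falls below a threshold. The threshold rule (a fixed fraction of the maximum entropy, log of the number of classes) and the function names are assumptions for illustration; the abstract does not state how the cutoff is chosen.

```python
import math

def entropy(probs, eps=1e-12):
    """Shannon entropy (in nats) of a discrete probability vector."""
    return -sum(p * math.log(p + eps) for p in probs)

def filter_reliable(batch_probs, threshold_ratio=0.4):
    """Return indices of samples whose entropy is below the cutoff.

    Assumption: the cutoff is threshold_ratio * log(num_classes).
    """
    num_classes = len(batch_probs[0])
    threshold = threshold_ratio * math.log(num_classes)
    return [i for i, p in enumerate(batch_probs) if entropy(p) < threshold]

# Hypothetical softmax outputs for three unlabeled stream samples.
batch = [
    [0.97, 0.02, 0.01],  # confident prediction
    [0.34, 0.33, 0.33],  # near-uniform, unreliable
    [0.05, 0.90, 0.05],  # confident prediction
]
kept = filter_reliable(batch)
```

Only the indices in `kept` would then contribute to the adaptation loss, so near-uniform predictions do not drive the weight update.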

  • 3 authors
·
Jun 6

Improving the Model Consistency of Decentralized Federated Learning

To mitigate the privacy leakages and communication burdens of Federated Learning (FL), decentralized FL (DFL) discards the central server and each client only communicates with its neighbors in a decentralized communication network. However, existing DFL suffers from high inconsistency among local clients, which results in severe distribution shift and inferior performance compared with centralized FL (CFL), especially on heterogeneous data or sparse communication topology. To alleviate this issue, we propose two DFL algorithms named DFedSAM and DFedSAM-MGS to improve the performance of DFL. Specifically, DFedSAM leverages gradient perturbation to generate local flat models via Sharpness Aware Minimization (SAM), which searches for models with uniformly low loss values. DFedSAM-MGS further boosts DFedSAM by adopting Multiple Gossip Steps (MGS) for better model consistency, which accelerates the aggregation of local flat models and better balances communication complexity and generalization. Theoretically, we present improved convergence rates $\mathcal{O}\big(\frac{1}{\sqrt{KT}}+\frac{1}{T}+\frac{1}{K^{1/2}T^{3/2}(1-\lambda)^2}\big)$ and $\mathcal{O}\big(\frac{1}{\sqrt{KT}}+\frac{1}{T}+\frac{\lambda^Q+1}{K^{1/2}T^{3/2}(1-\lambda^Q)^2}\big)$ in non-convex setting for DFedSAM and DFedSAM-MGS, respectively, where $1-\lambda$ is the spectral gap of gossip matrix and $Q$ is the number of MGS. Empirically, our methods can achieve competitive performance compared with CFL methods and outperform existing DFL methods.
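The effect of Multiple Gossip Steps can be sketched with plain neighbor averaging. In this illustration each client holds a single scalar standing in for its model parameters, and the clients sit on a ring where every node averages itself with its two neighbors; the topology, the uniform 1/3 weights, and all function names are assumptions for the sketch rather than the paper's configuration.

```python
def ring_mixing_matrix(n):
    """Doubly stochastic gossip matrix for a ring of n >= 3 clients."""
    w = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in (i - 1, i, i + 1):
            w[i][j % n] += 1.0 / 3.0
    return w

def gossip_step(values, weights):
    """One communication round: each client takes a weighted neighbor average."""
    n = len(values)
    return [sum(weights[i][j] * values[j] for j in range(n)) for i in range(n)]

def multi_gossip(values, weights, q):
    """Apply q gossip rounds, i.e. multiply by the q-th power of the matrix."""
    for _ in range(q):
        values = gossip_step(values, weights)
    return values

def spread(values):
    """Disagreement between clients, measured as max minus min."""
    return max(values) - min(values)

# Hypothetical local model values after a round of local training.
local = [0.0, 1.0, 4.0, 9.0, 16.0, 25.0]
W = ring_mixing_matrix(len(local))
after_1 = multi_gossip(local, W, 1)
after_4 = multi_gossip(local, W, 4)
```

Running more gossip rounds between local updates shrinks the disagreement between clients while preserving their average, at the price of extra communication. That is the trade-off between consistency and communication complexity the abstract refers to.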

  • 7 authors
·
Feb 8, 2023

Unknown Domain Inconsistency Minimization for Domain Generalization

The objective of domain generalization (DG) is to enhance the transferability of the model learned from a source domain to unobserved domains. To prevent overfitting to a specific domain, Sharpness-Aware Minimization (SAM) reduces source domain's loss sharpness. Although SAM variants have delivered significant improvements in DG, we highlight that there's still potential for improvement in generalizing to unknown domains through the exploration on data space. This paper introduces an objective rooted in both parameter and data perturbed regions for domain generalization, coined Unknown Domain Inconsistency Minimization (UDIM). UDIM reduces the loss landscape inconsistency between source domain and unknown domains. As unknown domains are inaccessible, these domains are empirically crafted by perturbing instances from the source domain dataset. In particular, by aligning the loss landscape acquired in the source domain to the loss landscape of perturbed domains, we expect to achieve generalization grounded on these flat minima for the unknown domains. Theoretically, we validate that merging SAM optimization with the UDIM objective establishes an upper bound for the true objective of the DG task. In an empirical aspect, UDIM consistently outperforms SAM variants across multiple DG benchmark datasets. Notably, UDIM shows statistically significant improvements in scenarios with more restrictive domain information, underscoring UDIM's generalization capability in unseen domains. Our code is available at https://github.com/SJShin-AI/UDIM.

  • 5 authors
·
Mar 12, 2024

Outliers with Opposing Signals Have an Outsized Effect on Neural Network Optimization

We identify a new phenomenon in neural network optimization which arises from the interaction of depth and a particular heavy-tailed structure in natural data. Our result offers intuitive explanations for several previously reported observations about network training dynamics. In particular, it implies a conceptually new cause for progressive sharpening and the edge of stability; we also highlight connections to other concepts in optimization and generalization including grokking, simplicity bias, and Sharpness-Aware Minimization. Experimentally, we demonstrate the significant influence of paired groups of outliers in the training data with strong opposing signals: consistent, large magnitude features which dominate the network output throughout training and provide gradients which point in opposite directions. Due to these outliers, early optimization enters a narrow valley which carefully balances the opposing groups; subsequent sharpening causes their loss to rise rapidly, oscillating between high on one group and then the other, until the overall loss spikes. We describe how to identify these groups, explore what sets them apart, and carefully study their effect on the network's optimization and behavior. We complement these experiments with a mechanistic explanation on a toy example of opposing signals and a theoretical analysis of a two-layer linear network on a simple model. Our finding enables new qualitative predictions of training behavior which we confirm experimentally. It also provides a new lens through which to study and improve modern training practices for stochastic optimization, which we highlight via a case study of Adam versus SGD.

  • 2 authors
·
Nov 7, 2023

Changing the Training Data Distribution to Reduce Simplicity Bias Improves In-distribution Generalization

Can we modify the training data distribution to encourage the underlying optimization method toward finding solutions with superior generalization performance on in-distribution data? In this work, we approach this question for the first time by comparing the inductive bias of gradient descent (GD) with that of sharpness-aware minimization (SAM). By studying a two-layer CNN, we rigorously prove that SAM learns different features more uniformly, particularly in early epochs. That is, SAM is less susceptible to simplicity bias compared to GD. We also show that examples containing features that are learned early are separable from the rest based on the model's output. Based on this observation, we propose a method that (i) clusters examples based on the network output early in training, (ii) identifies a cluster of examples with similar network output, and (iii) upsamples the rest of examples only once to alleviate the simplicity bias. We show empirically that USEFUL effectively improves the generalization performance on the original data distribution when training with various gradient methods, including (S)GD and SAM. Notably, we demonstrate that our method can be combined with SAM variants and existing data augmentation strategies to achieve, to the best of our knowledge, state-of-the-art performance for training ResNet18 on CIFAR10, STL10, CINIC10, Tiny-ImageNet; ResNet34 on CIFAR100; and VGG19 and DenseNet121 on CIFAR10.
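The three-step procedure in this abstract (cluster on early outputs, identify the cluster with similar outputs, upsample the rest once) can be sketched as follows. This is a loose illustration under several assumptions: each example is summarized by one scalar output, clustering is a simple one-dimensional 2-means, and the "similar output" cluster is taken to be the one with lower variance. None of these choices is confirmed by the abstract.

```python
def two_means_1d(xs, iters=20):
    """Tiny 1-D k-means with k=2, initialized at the min and max values."""
    c0, c1 = min(xs), max(xs)
    assign = [0] * len(xs)
    for _ in range(iters):
        assign = [0 if abs(x - c0) <= abs(x - c1) else 1 for x in xs]
        g0 = [x for x, a in zip(xs, assign) if a == 0]
        g1 = [x for x, a in zip(xs, assign) if a == 1]
        if g0:
            c0 = sum(g0) / len(g0)
        if g1:
            c1 = sum(g1) / len(g1)
    return assign

def variance(xs):
    mean = sum(xs) / len(xs)
    return sum((x - mean) ** 2 for x in xs) / len(xs)

def upsample_once(outputs):
    """Return training indices with the non-tight cluster duplicated once."""
    assign = two_means_1d(outputs)
    groups = {k: [i for i, a in enumerate(assign) if a == k] for k in (0, 1)}
    everything = list(range(len(outputs)))
    if not groups[0] or not groups[1]:
        return everything  # nothing to separate
    # Assumption: the tighter cluster holds the examples learned early.
    tight = min((0, 1), key=lambda k: variance([outputs[i] for i in groups[k]]))
    return everything + groups[1 - tight]

# Hypothetical early-epoch confidences on the true class for seven examples.
early_outputs = [0.95, 0.97, 0.96, 0.94, 0.40, 0.55, 0.20]
indices = upsample_once(early_outputs)
```

The returned index list would be used to build the training set for the remaining epochs, so the examples outside the tight cluster appear twice while the others appear once.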

  • 4 authors
·
Apr 26, 2024

A Three-regime Model of Network Pruning

Recent work has highlighted the complex influence training hyperparameters, e.g., the number of training epochs, can have on the prunability of machine learning models. Perhaps surprisingly, a systematic approach to predict precisely how adjusting a specific hyperparameter will affect prunability remains elusive. To address this gap, we introduce a phenomenological model grounded in the statistical mechanics of learning. Our approach uses temperature-like and load-like parameters to model the impact of neural network (NN) training hyperparameters on pruning performance. A key empirical result we identify is a sharp transition phenomenon: depending on the value of a load-like parameter in the pruned model, increasing the value of a temperature-like parameter in the pre-pruned model may either enhance or impair subsequent pruning performance. Based on this transition, we build a three-regime model by taxonomizing the global structure of the pruned NN loss landscape. Our model reveals that the dichotomous effect of high temperature is associated with transitions between distinct types of global structures in the post-pruned model. Based on our results, we present three case-studies: 1) determining whether to increase or decrease a hyperparameter for improved pruning; 2) selecting the best model to prune from a family of models; and 3) tuning the hyperparameter of the Sharpness Aware Minimization method for better pruning performance.

  • 4 authors
·
May 28, 2023

Noise-Adaptive Layerwise Learning Rates: Accelerating Geometry-Aware Optimization for Deep Neural Network Training

Geometry-aware optimization algorithms, such as Muon, have achieved remarkable success in training deep neural networks (DNNs). These methods leverage the underlying geometry of DNNs by selecting appropriate norms for different layers and updating parameters via norm-constrained linear minimization oracles (LMOs). However, even within a group of layers associated with the same norm, the local curvature can be heterogeneous across layers and vary dynamically over the course of training. For example, recent work shows that sharpness varies substantially across transformer layers and throughout training, yet standard geometry-aware optimizers impose fixed learning rates to layers within the same group, which may be inefficient for DNN training. In this paper, we introduce a noise-adaptive layerwise learning rate scheme on top of geometry-aware optimization algorithms and substantially accelerate DNN training compared to methods that use fixed learning rates within each group. Our method estimates gradient variance in the dual norm induced by the chosen LMO on the fly, and uses it to assign time-varying noise-adaptive layerwise learning rates within each group. We provide a theoretical analysis showing that our algorithm achieves a sharp convergence rate. Empirical results on transformer architectures such as LLaMA and GPT demonstrate that our approach achieves faster convergence than state-of-the-art optimizers.

  • 5 authors
·
Oct 15

VITON-HD: High-Resolution Virtual Try-On via Misalignment-Aware Normalization

The task of image-based virtual try-on aims to transfer a target clothing item onto the corresponding region of a person, which is commonly tackled by fitting the item to the desired body part and fusing the warped item with the person. While an increasing number of studies have been conducted, the resolution of synthesized images is still limited to low (e.g., 256x192), which acts as the critical limitation against satisfying online consumers. We argue that the limitation stems from several challenges: as the resolution increases, the artifacts in the misaligned areas between the warped clothes and the desired clothing regions become noticeable in the final results; the architectures used in existing methods have low performance in generating high-quality body parts and maintaining the texture sharpness of the clothes. To address the challenges, we propose a novel virtual try-on method called VITON-HD that successfully synthesizes 1024x768 virtual try-on images. Specifically, we first prepare the segmentation map to guide our virtual try-on synthesis, and then roughly fit the target clothing item to a given person's body. Next, we propose ALIgnment-Aware Segment (ALIAS) normalization and ALIAS generator to handle the misaligned areas and preserve the details of 1024x768 inputs. Through rigorous comparison with existing methods, we demonstrate that VITON-HD highly surpasses the baselines in terms of synthesized image quality both qualitatively and quantitatively. Code is available at https://github.com/shadow2496/VITON-HD.

  • 4 authors
·
Mar 31, 2021

Frequency-aware Feature Fusion for Dense Image Prediction

Dense image prediction tasks demand features with strong category information and precise spatial boundary details at high resolution. To achieve this, modern hierarchical models often utilize feature fusion, directly adding upsampled coarse features from deep layers and high-resolution features from lower levels. In this paper, we observe rapid variations in fused feature values within objects, resulting in intra-category inconsistency due to disturbed high-frequency features. Additionally, blurred boundaries in fused features lack accurate high frequency, leading to boundary displacement. Building upon these observations, we propose Frequency-Aware Feature Fusion (FreqFusion), integrating an Adaptive Low-Pass Filter (ALPF) generator, an offset generator, and an Adaptive High-Pass Filter (AHPF) generator. The ALPF generator predicts spatially-variant low-pass filters to attenuate high-frequency components within objects, reducing intra-class inconsistency during upsampling. The offset generator refines large inconsistent features and thin boundaries by replacing inconsistent features with more consistent ones through resampling, while the AHPF generator enhances high-frequency detailed boundary information lost during downsampling. Comprehensive visualization and quantitative analysis demonstrate that FreqFusion effectively improves feature consistency and sharpens object boundaries. Extensive experiments across various dense prediction tasks confirm its effectiveness. The code is made publicly available at https://github.com/Linwei-Chen/FreqFusion.

  • 6 authors
·
Aug 23, 2024